megatron-core 0.14.0rc6__tar.gz → 0.15.0rc0__tar.gz

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megatron-core might be problematic.

Files changed (327)
  1. {megatron_core-0.14.0rc6/megatron_core.egg-info → megatron_core-0.15.0rc0}/PKG-INFO +3 -3
  2. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/dict_utils.py +13 -5
  3. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/mapping.py +11 -11
  4. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/optimizer.py +6 -0
  5. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/async_utils.py +52 -14
  6. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/base.py +1 -5
  7. megatron_core-0.15.0rc0/megatron/core/dist_checkpointing/strategies/checkpointable.py +196 -0
  8. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/torch.py +38 -15
  9. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/zarr.py +6 -1
  10. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/validation.py +13 -3
  11. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/distributed/__init__.py +1 -0
  12. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/distributed/distributed_data_parallel_config.py +20 -6
  13. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/distributed/finalize_model_grads.py +27 -14
  14. megatron_core-0.15.0rc0/megatron/core/distributed/fsdp/__init__.py +3 -0
  15. megatron_core-0.15.0rc0/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py +317 -0
  16. megatron_core-0.15.0rc0/megatron/core/distributed/fsdp/src/__init__.py +13 -0
  17. megatron_core-0.15.0rc0/megatron/core/distributed/fsdp/src/megatron_fsdp/__init__.py +22 -0
  18. megatron_core-0.15.0rc0/megatron/core/distributed/fsdp/src/megatron_fsdp/distributed_data_parallel_config.py +141 -0
  19. megatron_core-0.15.0rc0/megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py +387 -0
  20. megatron_core-0.15.0rc0/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py +1107 -0
  21. {megatron_core-0.14.0rc6/megatron/core/distributed/custom_fsdp → megatron_core-0.15.0rc0/megatron/core/distributed/fsdp/src/megatron_fsdp}/param_and_grad_buffer.py +1649 -522
  22. megatron_core-0.15.0rc0/megatron/core/distributed/fsdp/src/megatron_fsdp/uneven_dtensor.py +458 -0
  23. megatron_core-0.15.0rc0/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py +908 -0
  24. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/distributed/param_and_grad_buffer.py +5 -7
  25. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/distributed/torch_fully_sharded_data_parallel.py +8 -0
  26. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/extensions/kitchen.py +4 -0
  27. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/extensions/transformer_engine.py +72 -2
  28. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/extensions/transformer_engine_spec_provider.py +5 -0
  29. megatron_core-0.15.0rc0/megatron/core/inference/data_parallel_inference_coordinator.py +322 -0
  30. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/engines/dynamic_engine.py +323 -6
  31. megatron_core-0.15.0rc0/megatron/core/inference/headers.py +17 -0
  32. megatron_core-0.15.0rc0/megatron/core/inference/inference_client.py +190 -0
  33. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/inference_request.py +11 -1
  34. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/sampling_params.py +11 -0
  35. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/text_generation_controllers/text_generation_controller.py +19 -1
  36. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/model_parallel_config.py +2 -2
  37. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/backends.py +9 -0
  38. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +23 -21
  39. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/common/language_module/language_module.py +19 -2
  40. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/gpt/gpt_layer_specs.py +13 -1
  41. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/gpt/moe_module_specs.py +7 -1
  42. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/huggingface/clip_model.py +1 -1
  43. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/huggingface/qwen_model.py +1 -1
  44. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/multimodal/context_parallel.py +25 -13
  45. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/multimodal/llava_model.py +5 -0
  46. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/vision/multimodal_projector.py +35 -30
  47. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/vision/radio.py +26 -0
  48. megatron_core-0.15.0rc0/megatron/core/nccl_allocator.py +249 -0
  49. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/optimizer/__init__.py +3 -22
  50. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/optimizer/clip_grads.py +15 -0
  51. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/optimizer/distrib_optimizer.py +556 -248
  52. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/optimizer/optimizer.py +15 -10
  53. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/optimizer/optimizer_config.py +6 -0
  54. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/package_info.py +2 -2
  55. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/post_training/modelopt/gpt/state_dict_hooks.py +1 -0
  56. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/tensor_parallel/layers.py +8 -8
  57. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/tensor_parallel/random.py +5 -2
  58. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/attention.py +30 -2
  59. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/cuda_graphs.py +16 -5
  60. megatron_core-0.15.0rc0/megatron/core/transformer/fsdp_dtensor_checkpoint.py +195 -0
  61. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/mlp.py +20 -2
  62. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/moe/experts.py +25 -31
  63. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/moe/moe_layer.py +28 -1
  64. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/moe/shared_experts.py +33 -2
  65. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/multi_latent_attention.py +28 -3
  66. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/transformer_config.py +55 -5
  67. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/transformer_layer.py +12 -2
  68. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/utils.py +0 -3
  69. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/utils.py +4 -41
  70. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0/megatron_core.egg-info}/PKG-INFO +3 -3
  71. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron_core.egg-info/SOURCES.txt +16 -3
  72. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron_core.egg-info/requires.txt +2 -2
  73. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/pyproject.toml +3 -3
  74. megatron_core-0.14.0rc6/megatron/core/distributed/custom_fsdp/__init__.py +0 -3
  75. megatron_core-0.14.0rc6/megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py +0 -835
  76. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/LICENSE +0 -0
  77. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/MANIFEST.in +0 -0
  78. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/README.md +0 -0
  79. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/README.md +0 -0
  80. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/__init__.py +0 -0
  81. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/activations.py +0 -0
  82. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/config.py +0 -0
  83. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/config_logger.py +0 -0
  84. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/__init__.py +0 -0
  85. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/bert_dataset.py +0 -0
  86. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/blended_dataset.py +0 -0
  87. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/blended_megatron_dataset_builder.py +0 -0
  88. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/blended_megatron_dataset_config.py +0 -0
  89. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/gpt_dataset.py +0 -0
  90. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/helpers.cpp +0 -0
  91. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/helpers.py +0 -0
  92. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/indexed_dataset.py +0 -0
  93. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/masked_dataset.py +0 -0
  94. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/megatron_dataset.py +0 -0
  95. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/megatron_tokenizer.py +0 -0
  96. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/multimodal_dataset.py +0 -0
  97. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/object_storage_utils.py +0 -0
  98. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/__init__.py +0 -0
  99. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/config/__init__.py +0 -0
  100. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/config/bert_embedders.py +0 -0
  101. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/config/config.py +0 -0
  102. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/config/gpt_chunk_datasets.py +0 -0
  103. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/config/tokenizers.py +0 -0
  104. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/db/__init__.py +0 -0
  105. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/db/build.py +0 -0
  106. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/db/dataset.py +0 -0
  107. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/db/utils.py +0 -0
  108. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/external_libs.py +0 -0
  109. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/index/__init__.py +0 -0
  110. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/index/build.py +0 -0
  111. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/index/factory.py +0 -0
  112. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/index/index.py +0 -0
  113. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/index/indexes/__init__.py +0 -0
  114. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/index/indexes/faiss_base.py +0 -0
  115. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/index/indexes/faiss_par_add.py +0 -0
  116. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/index/utils.py +0 -0
  117. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/index/validate.py +0 -0
  118. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/query/__init__.py +0 -0
  119. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/query/gpt_chunk_dataset.py +0 -0
  120. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/query/multi_split_gpt_dataset.py +0 -0
  121. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/query/query.py +0 -0
  122. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/query/retro_dataset.py +0 -0
  123. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/query/utils.py +0 -0
  124. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/retro/utils.py +0 -0
  125. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/t5_dataset.py +0 -0
  126. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/utils.py +0 -0
  127. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/utils_object_storage.py +0 -0
  128. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/datasets/utils_s3.py +0 -0
  129. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/__init__.py +0 -0
  130. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/core.py +0 -0
  131. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/exchange_utils.py +0 -0
  132. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/serialization.py +0 -0
  133. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/state_dict_utils.py +0 -0
  134. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/__init__.py +0 -0
  135. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py +0 -0
  136. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/common.py +0 -0
  137. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/filesystem_async.py +0 -0
  138. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/fully_parallel.py +0 -0
  139. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/resharding.py +0 -0
  140. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/state_dict_saver.py +0 -0
  141. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/tensorstore.py +0 -0
  142. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/strategies/two_stage.py +0 -0
  143. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/tensor_aware_state_dict.py +0 -0
  144. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/dist_checkpointing/utils.py +0 -0
  145. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/distributed/data_parallel_base.py +0 -0
  146. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/distributed/distributed_data_parallel.py +0 -0
  147. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/distributed/torch_fully_sharded_data_parallel_config.py +0 -0
  148. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/energy_monitor.py +0 -0
  149. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/enums.py +0 -0
  150. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/__init__.py +0 -0
  151. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/data_type.py +0 -0
  152. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/export_config.py +0 -0
  153. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/model_type.py +0 -0
  154. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/trtllm/__init__.py +0 -0
  155. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/trtllm/engine_builder/__init__.py +0 -0
  156. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py +0 -0
  157. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py +0 -0
  158. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +0 -0
  159. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/trtllm/trt_model_config.py +0 -0
  160. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/trtllm/trt_model_type.py +0 -0
  161. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/trtllm/trtllm_helper.py +0 -0
  162. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/trtllm/trtllm_layers.py +0 -0
  163. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/trtllm/trtllm_weights_converter/__init__.py +0 -0
  164. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py +0 -0
  165. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +0 -0
  166. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/export/trtllm/trtllm_weights_converter/utils.py +0 -0
  167. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/extensions/__init__.py +0 -0
  168. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/fp8_utils.py +0 -0
  169. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/full_cuda_graph.py +0 -0
  170. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/fusions/__init__.py +0 -0
  171. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/fusions/fused_bias_dropout.py +0 -0
  172. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/fusions/fused_bias_geglu.py +0 -0
  173. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/fusions/fused_bias_gelu.py +0 -0
  174. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/fusions/fused_bias_swiglu.py +0 -0
  175. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/fusions/fused_cross_entropy.py +0 -0
  176. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/fusions/fused_indices_converter.py +0 -0
  177. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/fusions/fused_layer_norm.py +0 -0
  178. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/fusions/fused_mla_yarn_rope_apply.py +0 -0
  179. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/fusions/fused_pad_routing_map.py +0 -0
  180. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/fusions/fused_softmax.py +0 -0
  181. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/fusions/fused_weighted_squared_relu.py +0 -0
  182. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/hyper_comm_grid.py +0 -0
  183. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/__init__.py +0 -0
  184. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/async_stream.py +0 -0
  185. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/common_inference_params.py +0 -0
  186. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/communication_utils.py +0 -0
  187. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/contexts/__init__.py +0 -0
  188. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/contexts/base_context.py +0 -0
  189. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/contexts/dynamic_chunk_allocator.py +0 -0
  190. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/contexts/dynamic_context.py +0 -0
  191. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/contexts/static_context.py +0 -0
  192. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/engines/__init__.py +0 -0
  193. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/engines/abstract_engine.py +0 -0
  194. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/engines/mcore_engine.py +0 -0
  195. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/engines/static_engine.py +0 -0
  196. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/model_inference_wrappers/__init__.py +0 -0
  197. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +0 -0
  198. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/model_inference_wrappers/gpt/__init__.py +0 -0
  199. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +0 -0
  200. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +0 -0
  201. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py +0 -0
  202. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/model_inference_wrappers/t5/__init__.py +0 -0
  203. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py +0 -0
  204. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/scheduler.py +0 -0
  205. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/text_generation_controllers/__init__.py +0 -0
  206. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +0 -0
  207. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +0 -0
  208. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py +0 -0
  209. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference/utils.py +0 -0
  210. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/inference_params.py +0 -0
  211. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/jit.py +0 -0
  212. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/T5/__init__.py +0 -0
  213. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/T5/t5_model.py +0 -0
  214. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/T5/t5_spec.py +0 -0
  215. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/__init__.py +0 -0
  216. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/bert/__init__.py +0 -0
  217. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/bert/bert_layer_specs.py +0 -0
  218. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/bert/bert_lm_head.py +0 -0
  219. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/bert/bert_model.py +0 -0
  220. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/bert/pooler.py +0 -0
  221. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/common/__init__.py +0 -0
  222. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/common/embeddings/__init__.py +0 -0
  223. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/common/embeddings/language_model_embedding.py +0 -0
  224. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/common/embeddings/relative_pos_embedding.py +0 -0
  225. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/common/embeddings/rope_utils.py +0 -0
  226. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/common/embeddings/rotary_pos_embedding.py +0 -0
  227. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/common/language_module/__init__.py +0 -0
  228. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/common/model_chunk_schedule_plan.py +0 -0
  229. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/common/vision_module/__init__.py +0 -0
  230. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/common/vision_module/vision_module.py +0 -0
  231. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/gpt/__init__.py +0 -0
  232. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/gpt/fine_grained_callables.py +0 -0
  233. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/gpt/gpt_model.py +0 -0
  234. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py +0 -0
  235. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/huggingface/__init__.py +0 -0
  236. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/huggingface/module.py +0 -0
  237. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/mamba/__init__.py +0 -0
  238. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/mamba/mamba_layer_specs.py +0 -0
  239. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/mamba/mamba_model.py +0 -0
  240. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/mimo/__init__.py +0 -0
  241. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/mimo/config/__init__.py +0 -0
  242. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/mimo/config/base_configs.py +0 -0
  243. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/mimo/model/__init__.py +0 -0
  244. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/mimo/model/base.py +0 -0
  245. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/mimo/submodules/audio.py +0 -0
  246. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/mimo/submodules/base.py +0 -0
  247. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/mimo/submodules/vision.py +0 -0
  248. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/multimodal/__init__.py +0 -0
  249. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/multimodal/llava_spec.py +0 -0
  250. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/retro/__init__.py +0 -0
  251. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/retro/base_attention.py +0 -0
  252. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/retro/config.py +0 -0
  253. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/retro/decoder_attention.py +0 -0
  254. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/retro/decoder_spec.py +0 -0
  255. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/retro/encoder_attention.py +0 -0
  256. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/retro/encoder_spec.py +0 -0
  257. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/retro/model.py +0 -0
  258. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/retro/utils.py +0 -0
  259. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/vision/__init__.py +0 -0
  260. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/vision/clip_vit_model.py +0 -0
  261. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/models/vision/vit_layer_specs.py +0 -0
  262. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/msc_utils.py +0 -0
  263. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/num_microbatches_calculator.py +0 -0
  264. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/optimizer/cpu_offloading/__init__.py +0 -0
  265. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +0 -0
  266. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/optimizer/grad_scaler.py +0 -0
  267. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/optimizer_param_scheduler.py +0 -0
  268. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/packed_seq_params.py +0 -0
  269. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/parallel_state.py +0 -0
  270. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/pipeline_parallel/__init__.py +0 -0
  271. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/pipeline_parallel/combined_1f1b.py +0 -0
  272. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/pipeline_parallel/p2p_communication.py +0 -0
  273. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/pipeline_parallel/schedules.py +0 -0
  274. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/pipeline_parallel/utils.py +0 -0
  275. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/post_training/__init__.py +0 -0
  276. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/post_training/modelopt/__init__.py +0 -0
  277. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/post_training/modelopt/gpt/__init__.py +0 -0
  278. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/post_training/modelopt/gpt/model_specs.py +0 -0
  279. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/post_training/modelopt/layers.py +0 -0
  280. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/post_training/modelopt/mamba/__init__.py +0 -0
  281. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/post_training/modelopt/mamba/model_specs.py +0 -0
  282. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/process_groups_config.py +0 -0
  283. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/quantization/__init__.py +0 -0
  284. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/quantization/quant_config.py +0 -0
  285. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/quantization/utils.py +0 -0
  286. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/requirements.txt +0 -0
  287. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/rerun_state_machine.py +0 -0
  288. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/ssm/__init__.py +0 -0
  289. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/ssm/mamba_block.py +0 -0
  290. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/ssm/mamba_context_parallel.py +0 -0
  291. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/ssm/mamba_hybrid_layer_allocation.py +0 -0
  292. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/ssm/mamba_layer.py +0 -0
  293. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/ssm/mamba_mixer.py +0 -0
  294. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/ssm/mlp_layer.py +0 -0
  295. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/ssm/triton_cache_manager.py +0 -0
  296. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/tensor_parallel/__init__.py +0 -0
  297. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/tensor_parallel/cross_entropy.py +0 -0
  298. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/tensor_parallel/data.py +0 -0
  299. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/tensor_parallel/mappings.py +0 -0
  300. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/tensor_parallel/utils.py +0 -0
  301. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/timers.py +0 -0
  302. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/__init__.py +0 -0
  303. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/custom_layers/__init__.py +0 -0
  304. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/custom_layers/transformer_engine.py +0 -0
  305. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/dot_product_attention.py +0 -0
  306. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/enums.py +0 -0
  307. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/heterogeneous/heterogeneous_config.py +0 -0
  308. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/heterogeneous/linear_replacements.py +0 -0
  309. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/identity_op.py +0 -0
  310. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/module.py +0 -0
  311. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/moe/__init__.py +0 -0
  312. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/moe/fused_a2a.py +0 -0
  313. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/moe/grouped_gemm_util.py +0 -0
  314. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/moe/moe_utils.py +0 -0
  315. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/moe/router.py +0 -0
  316. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/moe/token_dispatcher.py +0 -0
  317. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/moe/upcycling_utils.py +0 -0
  318. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/multi_token_prediction.py +0 -0
  319. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/pipeline_parallel_layer_layout.py +0 -0
  320. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/spec_utils.py +0 -0
  321. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/torch_layer_norm.py +0 -0
  322. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/torch_norm.py +0 -0
  323. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron/core/transformer/transformer_block.py +0 -0
  324. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron_core.egg-info/dependency_links.txt +0 -0
  325. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/megatron_core.egg-info/top_level.txt +0 -0
  326. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/setup.cfg +0 -0
  327. {megatron_core-0.14.0rc6 → megatron_core-0.15.0rc0}/setup.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: megatron-core
-Version: 0.14.0rc6
+Version: 0.15.0rc0
 Summary: Megatron Core - a library for efficient and scalable training of transformer based models
 Author-email: NVIDIA <nemo-toolkit@nvidia.com>
 Maintainer-email: NVIDIA <nemo-toolkit@nvidia.com>
@@ -31,7 +31,7 @@ Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: torch
 Requires-Dist: numpy<2.0.0
-Requires-Dist: packaging
+Requires-Dist: packaging>=24.2
 Provides-Extra: mlm
 Requires-Dist: flask-restful; extra == "mlm"
 Requires-Dist: sentencepiece; extra == "mlm"
@@ -43,7 +43,7 @@ Requires-Dist: einops~=0.8; extra == "dev"
 Requires-Dist: tensorstore!=0.1.46,!=0.1.72,~=0.1; extra == "dev"
 Requires-Dist: nvtx~=0.2; extra == "dev"
 Requires-Dist: transformers~=4.53; extra == "dev"
-Requires-Dist: multi-storage-client~=0.20; extra == "dev"
+Requires-Dist: multi-storage-client<0.26,~=0.25; extra == "dev"
 Requires-Dist: opentelemetry-api~=1.33.1; extra == "dev"
 Requires-Dist: setuptools<80.0.0; extra == "dev"
 Requires-Dist: mamba-ssm~=2.2; extra == "dev"
@@ -103,11 +103,19 @@ def diff(x1: Any, x2: Any, prefix: Tuple = ()) -> Tuple[list, list, list]:
     else:
         only_left = []
         only_right = []
+        mismatch_debug_data = [prefix, type(x1), type(x2)]
         if isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor):
-            if x1.device != x2.device:
-                _is_mismatch = not torch.all(x1.cpu() == x2.cpu())
-            else:
-                _is_mismatch = not torch.all(x1 == x2)
+            try:
+                if x1.device != x2.device:
+                    _is_mismatch = not torch.all(x1.cpu() == x2.cpu())
+                else:
+                    _is_mismatch = not torch.all(x1 == x2)
+                mismatch_debug_data.extend(
+                    [(x1 != x2).sum(), (x1 != x2).shape, (x1 != x2).nonzero().tolist()]
+                )
+            except (RuntimeError, TypeError, ValueError):
+                _is_mismatch = True
+                mismatch_debug_data.extend([x1.shape, x2.shape])
         # TODO: change with concrete type that has both replica_id and data attrs
         elif hasattr(x1, "replica_id") and hasattr(x2, "replica_id"):
             assert type(x1) == type(x2)
@@ -122,7 +130,7 @@ def diff(x1: Any, x2: Any, prefix: Tuple = ()) -> Tuple[list, list, list]:
             _is_mismatch = True

         if _is_mismatch:
-            mismatch.append((prefix, type(x1), type(x2)))
+            mismatch.append(tuple(mismatch_debug_data))

     return only_left, only_right, mismatch
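
Note: with this change each `mismatch` entry grows from a `(prefix, type, type)` triple into a tuple that also carries the mismatch count, the comparison shape, and the differing indices. A minimal illustration (the sample tensors and the printed shape are ours, not from the package):

    import torch
    from megatron.core.dist_checkpointing.dict_utils import diff

    left = {'w': torch.tensor([1.0, 2.0, 3.0])}
    right = {'w': torch.tensor([1.0, 9.0, 3.0])}
    only_left, only_right, mismatch = diff(left, right)
    # mismatch[0] is now roughly:
    # (('w',), <class 'torch.Tensor'>, <class 'torch.Tensor'>,
    #  tensor(1), torch.Size([3]), [[1]])
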
@@ -135,23 +135,23 @@ class ShardedTensor(ShardedBase):
                 f"equal to global shape dimensions for {self}"
             )

-        for off, sh in zip(self.global_offset[self.prepend_axis_num :], self.local_shape):
-            # NOTE: In custom FSDP, we have a case where a new parameter shard is created locally.
-            # For example, consider parameters [p0, p1, p2] sharded across GPU0 and GPU1.
-            # GPU0 receives p0 and a portion of p1, while GPU1 receives the
-            # remaining portion of p1 and p2.
-            # As a result, there is no parameter shard of p2 on GPU0, and
-            # the shape of p2 on GPU0 is zero.
-            if sh != 0 and off % sh != 0:
-                raise CheckpointingException(
-                    f"Global offset ({off}) must be divisible by local shape ({sh}) for {self}."
-                )
+        if self.axis_fragmentations is not None:
+            for off, sh in zip(self.global_offset[self.prepend_axis_num :], self.local_shape):
+                if sh != 0 and off % sh != 0:
+                    raise CheckpointingException(
+                        f"Global offset ({off}) must be divisible by local shape ({sh}) for {self}."
+                    )

         if has_flattened_range and self.flattened_range.step is not None:
             raise CheckpointingException(
                 f"`step` argument in the flattened range of a ShardedTensor is not supported."
             )

+    @property
+    def has_regular_grid(self):
+        """Alias for having a regular sharding grid."""
+        return self.axis_fragmentations is not None
+
     def global_slice(self) -> Tuple[Union[int, slice], ...]:
         """
         Returns a tuple of int and slice objects representing a slice of the
@@ -25,6 +25,12 @@ from .mapping import (
 )
 from .utils import extract_sharded_tensors_and_factories

+KEEP_VARS_HINT = (
+    " Make sure state dict contains original torch.nn.Parameters (not pure torch.Tensors)"
+    " by passing `keep_vars=True` to `.state_dict()`. If any transformation of the original"
+    " parameter is needed, use a ShardedTensorFactory."
+)
+

 def get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter]) -> Dict[int, int]:
     """Generate mapping from optimizer param to optimizer state id."""
@@ -79,9 +79,24 @@ class AsyncRequest(NamedTuple):

         This logic is equivalent to what should happen in case of the async call.
         """
+        # preload tensors.
+        async_fn_args = list(self.async_fn_args)
+        if self.preload_fn:
+            assert len(async_fn_args) == 3, "Expected 3 args to be passed to async function"
+            # The async_fn is passed as a partial functool with pre-determined args
+            # In the async_fn_args we pass the remaining positional args required by the async_fn
+            # async_fn_args[1] refers to the write_buckets
+            # To ensure we stage the write_buckets to CPU memory for sync CP,
+            # we replace it with preload_fn callable that returns the CPU staged tensors
+            async_fn_args[1] = self.preload_fn()
+        # persist the state
         if self.async_fn is not None:
-            self.async_fn(*self.async_fn_args)
+            self.async_fn(*async_fn_args, **self.async_fn_kwargs)
+
+        # This utility implements a sync cp save. Hence the barrier.
         torch.distributed.barrier()
+
+        # Finalize the CP state
         for finalize_fn in self.finalize_fns:
             finalize_fn()

@@ -150,7 +165,7 @@
         return ten[0] == 0

     @abstractmethod
-    def close(self):
+    def close(self, abort=False):
         """Terminate the async caller at exit of an application or some termination conditions"""
         logger.info(f"AsyncCaller: {torch.distributed.get_rank()}, Destroying Async Caller")

@@ -237,15 +252,23 @@ class TemporalAsyncCaller(AsyncCaller):
             is_done = True
         return is_done

-    def close(self):
+    def close(self, abort=False):
         """For TemporalAsyncCaller, this method is called explictly in `is_current_async_calls_done`

         This method make sure the TemporalAsyncCaller terminated
         with all its assigned async request completed
+
+        Args:
+            abort (bool, optional): Default to False. Needs to be manually set to true when
+                the checkpoint async process needs to be aborted.
         """
         if self.process:
             logger.debug(f"rank: {torch.distributed.get_rank()}, joining self.process")
-            self.process.join()
+            if abort:
+                logger.warning(f"Temporal worker aborted in rank {torch.distributed.get_rank()}")
+                self.process.kill()
+            else:
+                self.process.join()
             self.process = None
             logger.debug(
                 "TemporalAsyncCaller: Async process join finished "
@@ -388,18 +411,25 @@ class PersistentAsyncCaller(AsyncCaller):

         return is_done

-    def close(self):
+    def close(self, abort=False):
         """Wait on the left async requests and terminate the PersistentAsyncCaller

         Signals the PersistentAsyncCaller by sending a 'DONE' message to make it terminated
+        Args:
+            abort (bool, optional): Default to False. Needs to be manually set to true when
+                the checkpoint async process needs to be aborted.
         """
         logger.info(
             f"PersistentAsyncCaller: {torch.distributed.get_rank()}, Destroying Async Caller"
         )
         if self.process:
-            self.queue.put('DONE')
-            self.queue.join()
-            self.process.join()
+            if abort:
+                logger.warning(f"Persistent worker aborted in rank {torch.distributed.get_rank()}")
+                self.process.kill()
+            else:
+                self.queue.put('DONE')
+                self.queue.join()
+                self.process.join()
         self.process = None

     def __del__(self):
@@ -528,6 +558,9 @@ class AsyncCallsQueue:
             blocking (bool, optional): if True, will wait until all active requests
                 are done. Otherwise, finalizes only the async request that already
                 finished. Defaults to False.
+
+            no_dist (bool, Optional): if True, training ranks simply check its
+                asynchronous checkpoint writer without synchronization.
         Returns:
             List[int]: list of indices (as returned by `schedule_async_request`)
                 of async calls that have been successfully finalized.
@@ -545,8 +578,8 @@
                 finalize_fn()
             ten = torch.tensor([call_idx], dtype=torch.int, device=torch.cuda.current_device())
             torch.distributed.all_reduce(ten, op=torch.distributed.ReduceOp.MAX)
-            assert ten.item() == call_idx, 'Unmatched async calls. '
-            'That probably means not all ranks are participating in async finalization'
+            assert ten.item() == call_idx, "Unmatched async calls. "
+            "That probably means not all ranks are participating in async finalization"
             call_idx_finalized.append(call_idx)
         return call_idx_finalized

@@ -554,8 +587,13 @@
         """Get the number of active async calls."""
         return len(self.async_calls)

-    def close(self):
-        """Finalize all calls upon closing."""
-        self.maybe_finalize_async_calls(blocking=True)
+    def close(self, abort=False):
+        """Finalize all calls upon closing.
+        Args:
+            abort (bool, optional): Default to False. Needs to be manually set to true when
+                the checkpoint async process needs to be aborted.
+        """
+        if not abort:
+            self.maybe_finalize_async_calls(blocking=True)
         if self.persistent and self.persistent_caller:
-            self.persistent_caller.close()
+            self.persistent_caller.close(abort=abort)
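
The `abort` flag introduced across these classes propagates from `AsyncCallsQueue.close` down to the worker process, so a stuck checkpoint write can be killed instead of joined. A hedged usage sketch (the training-loop scaffolding is assumed, not part of the diff):

    from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue

    queue = AsyncCallsQueue(persistent=True)
    try:
        ...  # schedule_async_request(request) calls happen during training
        queue.close()            # normal path: finalize pending saves, join the worker
    except KeyboardInterrupt:
        queue.close(abort=True)  # abort path: skip finalization, kill the worker
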
@@ -221,8 +221,4 @@ class AsyncSaveShardedStrategy(SaveShardedStrategy):
     def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union[str, Path]):
         """Each async strategy can be trivially used as a sync strategy."""
         async_request = self.async_save(sharded_state_dict, checkpoint_dir)
-        # multiprocessing routines may cause issue when called on parent process
-        # We keep this verbose call for now
-        global async_calls
-        async_calls.schedule_async_request(async_request)
-        async_calls.maybe_finalize_async_calls(blocking=True)
+        async_request.execute_sync()
@@ -0,0 +1,196 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+from itertools import chain
+
+import torch
+from torch.distributed.checkpoint.metadata import (
+    ChunkStorageMetadata,
+    MetadataIndex,
+    TensorProperties,
+)
+from torch.distributed.checkpoint.planner import TensorWriteData, WriteItem, WriteItemType
+
+from ..mapping import ShardedTensor
+
+
+class CheckpointableShardedTensor(torch.Tensor):
+    """ShardedTensor extension compatible with PyTorch DCP checkpointing library.
+
+    Implements the torch.distributed._checkpointable._Checkpointable protocol.
+    """
+
+    def __new__(cls, data: torch.Tensor, sh_ten: ShardedTensor):
+        return torch.Tensor._make_wrapper_subclass(cls, torch.Size(sh_ten.global_shape))
+
+    def __init__(self, data: torch.Tensor, sh_ten: ShardedTensor):
+        self._data = data
+        self._sh_ten = sh_ten
+
+    def __create_write_items__(
+        self, fqn: str, sh_ten: 'CheckpointableShardedTensor', index: int = None
+    ) -> list[WriteItem]:
+        """Simple translation from ShardedTensor offsets into DCP offsets.
+
+        Args:
+            fqn (str): tensor FQN.
+            sh_ten (CheckpointableShardedTensor): same as `self`
+            index (int): specifies index within the LocalShardsContainer.
+                This is an optimization hint used in DCP.
+
+        Returns:
+            List[WriteItem]: list of DCP WriteItem metadata objects.
+        """
+        offsets = torch.Size(sh_ten._sh_ten.global_offset)
+        global_shape = torch.Size(sh_ten._sh_ten.global_shape)
+        chunk_size = torch.Size(sh_ten._sh_ten.local_shape)
+        assert chunk_size == sh_ten._sh_ten.data.size()
+
+        return [
+            WriteItem(
+                index=MetadataIndex(fqn, offsets, index),
+                type=WriteItemType.SHARD,
+                tensor_data=TensorWriteData(
+                    chunk=ChunkStorageMetadata(offsets=offsets, sizes=chunk_size),
+                    properties=TensorProperties.create_from_tensor(sh_ten._sh_ten.data),
+                    size=global_shape,
+                ),
+            )
+        ]
+
+    def __create_chunk_list__(self) -> list[ChunkStorageMetadata]:
+        """Simple translation from ShardedTensor offsets into DCP offsets.
+
+        Returns:
+            List[ChunkStorageMetadata]: list of DCP ChunkStorageMetadata metadata objects.
+        """
+        offsets = torch.Size(self._sh_ten.global_offset)
+        chunk_size = torch.Size(self._sh_ten.local_shape)
+        assert chunk_size == self._sh_ten.data.size()
+
+        return [ChunkStorageMetadata(offsets=offsets, sizes=chunk_size)]
+
+    def __get_tensor_shard__(self, index: MetadataIndex) -> torch.Tensor:
+        """Trivial implementation which simply yields the underlying tensor.
+
+        Args:
+            index (MetadataIndex): unused
+
+        Returns:
+            Tensor: the underlying data tensor
+        """
+        return self._sh_ten.data
+
+    @classmethod
+    def from_sh_ten(cls, sh_ten: ShardedTensor) -> 'CheckpointableShardedTensor':
+        """Constructor which turns a ShardedTensor into CheckpointableShardedTensor
+
+        Args:
+            sh_ten (ShardedTensor): a sharded tensor to wrap
+
+        Returns:
+            CheckpointableShardedTensor: wrapped ShardedTensor
+        """
+        assert isinstance(sh_ten, ShardedTensor)
+        return cls(sh_ten.data, sh_ten)
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs=None):
+        """Placeholder implementation."""
+        raise NotImplementedError(
+            f"{cls.__name__}.__torch_dispatch__ not implemented."
+            f" {cls.__name__} shouldn't be used with Tensor operations."
+        )
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}({self._sh_ten.__repr__()})'
+
+
+class LocalShardsContainer(torch.Tensor):
+    """DCP compatible container for local shards.
+
+    PyTorch DCP requires a single tensor per rank for a given global tensor FQN.
+    This class acts as a container allowing multiple checkpointable shards per rank.
+
+    Implements the torch.distributed._checkpointable._Checkpointable protocol.
+    """
+
+    @staticmethod
+    def __new__(cls, local_shards: list[torch.Tensor]) -> "LocalShardsContainer":
+        assert len(local_shards) > 0
+        # This assumes local shard already has correct size info
+        return torch.Tensor._make_wrapper_subclass(cls, local_shards[0].size())
+
+    def __init__(self, local_shards: list[torch.Tensor]):
+        for local_shard in local_shards:
+            # this is needed only for __get_tensor_shard__
+            assert isinstance(local_shard, CheckpointableShardedTensor)
+        self._local_shards = local_shards
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        """Placeholder implementation."""
+        raise NotImplementedError(
+            f"{cls.__name__}.__torch_dispatch__ not implemented."
+            f" {cls.__name__} shouldn't be used with Tensor operations."
+        )
+
+    def __create_write_items__(
+        self, fqn: str, local_shards_cont: 'LocalShardsContainer'
+    ) -> list[object]:
+        """Delegates creating write items to local shards.
+
+        Args:
+            fqn (str): tensor FQN.
+            local_shards_cont (LocalShardsContainer): same as `self`
+
+        Returns:
+            List[WriteItem]: list of DCP WriteItem metadata objects.
+        """
+        return list(
+            chain.from_iterable(
+                shard.__create_write_items__(fqn, shard, index=index)
+                for index, shard in enumerate(local_shards_cont._local_shards)
+            )
+        )
+
+    def __create_chunk_list__(self) -> list[ChunkStorageMetadata]:
+        """Delegates creating chunk items to local shards.
+
+        Returns:
+            List[ChunkStorageMetadata]: list of DCP ChunkStorageMetadata metadata objects.
+        """
+        return list(
+            chain.from_iterable(shard.__create_chunk_list__() for shard in self._local_shards)
+        )
+
+    def __get_tensor_shard__(self, index: MetadataIndex) -> torch.Tensor:
+        """Performs shard matching lookup based on index hint or offset.
+
+        Args:
+            index (MetadataIndex): metadata specifying the offset of the queried shard.
+                Optionally provides an index hint which speeds up the lookup.
+
+        Returns:
+            Tensor: the matching shard data tensor
+        """
+        if index.offset is None:
+            raise ValueError(
+                f"Cannot lookup {index.fqn} for a LocalShardsContainer without an offset"
+            )
+
+        shards = self._local_shards
+        # index hint direct lookup
+        if index.index is not None:
+            if (
+                len(shards) > index.index
+                and torch.Size(shards[index.index]._sh_ten.global_offset) == index.offset
+            ):
+                return shards[index.index].__get_tensor_shard__(index)
+
+        # slow linear search
+        for shard in shards:
+            if torch.Size(shard._sh_ten.global_offset) == index.offset:
+                return shard.__get_tensor_shard__(index)
+        raise ValueError(f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'")
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}({self._local_shards.__repr__()})'
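
These wrappers are what the `torch.py` changes below hand to PyTorch DCP when a tensor lacks a regular sharding grid. A sketch of the intended wrapping, mirroring `_mcore_to_dcp_compatible_tensor` (here `shards` stands for a hypothetical list of MCore ShardedTensor objects describing one global tensor on this rank):

    from megatron.core.dist_checkpointing.strategies.checkpointable import (
        CheckpointableShardedTensor,
        LocalShardsContainer,
    )

    wrapped = [CheckpointableShardedTensor.from_sh_ten(s) for s in shards]
    # DCP allows one tensor per FQN per rank, so multiple local shards
    # are grouped into a single container object.
    dcp_tensor = wrapped[0] if len(wrapped) == 1 else LocalShardsContainer(wrapped)
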
@@ -57,6 +57,7 @@ from .base import (
     register_default_strategy,
 )
 from .cached_metadata_filesystem_reader import CachedMetadataFileSystemReader
+from .checkpointable import CheckpointableShardedTensor, LocalShardsContainer
 from .filesystem_async import FileSystemWriterAsync
 from .resharding import (
     TensorReformulationMetadata,
@@ -240,14 +241,18 @@ def sharded_tensor_to_torch_sharded_tensor(
             placement = f"rank:{rank}/cuda"
             for sh_ten in local_global_offsets[offset]:
                 if has_flattened_range:
-                    assert offset == sh_ten.local_chunk_offset_in_global()
+                    assert offset == sh_ten.local_chunk_offset_in_global(), (
+                        offset,
+                        sh_ten.local_chunk_offset_in_global(),
+                    )
                     # This is not an actual offset, but an offset of the whole shard
                     # This is needed for a PyT Dist internal integrity check
-                    offset = sh_ten.local_chunk_offset_in_global() + (0,)
+                    _shard_offset = sh_ten.local_chunk_offset_in_global() + (0,)
                     size = (1,) * len(offsets_shape) + global_shape[-1:]
                 else:
                     size = sh_ten.data.shape
-                shard_metadata.append(ShardMetadata(offset, size, placement))
+                    _shard_offset = offset
+                shard_metadata.append(ShardMetadata(_shard_offset, size, placement))

     else:
         # pylint: disable=line-too-long
@@ -312,7 +317,7 @@ def mcore_to_pyt_state_dict(
     rank = torch.distributed.get_rank()
     pyt_state_dict = {}

-    def _mcore_to_torch_sharded_tensor(sh_tens: List[ShardedTensor]) -> TorchShardedTensor:
+    def _mcore_to_dcp_compatible_tensor(sh_tens: List[ShardedTensor]) -> TorchShardedTensor:
         """Build a PyT ShardedTensor from given shards.

         During loading:
@@ -335,11 +340,24 @@
            if sh_ten.allow_shape_mismatch and is_loading:
                sh_ten.data.zero_()

-        torch_sh_ten = sharded_tensor_to_torch_sharded_tensor(
-            sh_tens, rank, load_legacy_1d_flatten_tensors
-        )
-        torch_sh_ten.key = sh_tens[0].key
-        return torch_sh_ten
+        if not sh_tens[0].has_regular_grid:
+            if not is_torch_min_version("2.6a0"):
+                raise CheckpointingException(
+                    f"Uneven sharding not supported for PyTorch version {get_torch_version()}"
+                )
+            assert sh_tens[0].flattened_range is None
+            if len(sh_tens) > 1:
+                return LocalShardsContainer(
+                    [CheckpointableShardedTensor.from_sh_ten(sh_ten) for sh_ten in sh_tens]
+                )
+            else:
+                return CheckpointableShardedTensor.from_sh_ten(sh_tens[0])
+        else:
+            torch_sh_ten = sharded_tensor_to_torch_sharded_tensor(
+                sh_tens, rank, load_legacy_1d_flatten_tensors
+            )
+            torch_sh_ten.key = sh_tens[0].key
+            return torch_sh_ten

     def _mcore_to_torch_sharded_object(sh_objs: List[ShardedObject]) -> io.BytesIO:
         """Build io.BytesIO from given sharded objects data."""
@@ -351,7 +369,7 @@
     for k, v in state_dict.items():
         if isinstance(v[0], ShardedTensor):
             v = cast(List[ShardedTensor], v)
-            pyt_state_dict[k] = _mcore_to_torch_sharded_tensor(v)
+            pyt_state_dict[k] = _mcore_to_dcp_compatible_tensor(v)
         else:
             v = cast(List[ShardedObject], v)
             pyt_state_dict[k] = _mcore_to_torch_sharded_object(v)
@@ -359,12 +377,20 @@
     return pyt_state_dict


-def _unwrap_pyt_sharded_tensor(sh_ten: TorchShardedTensor) -> List[torch.Tensor]:
+def _unwrap_pyt_sharded_tensor(
+    sh_ten: Union[TorchShardedTensor, CheckpointableShardedTensor, LocalShardsContainer, Any]
+) -> Union[List[torch.Tensor], Any]:
     """Unwrap tensor from PyT ShardedTensor instance.

     If `prepend_axis_num` was non-zero (which is specific to MCore ShardedTensor)
     then the tensor has additional singleton dimensions which should be squeezed.
     """
+    if isinstance(sh_ten, CheckpointableShardedTensor):
+        return [sh_ten._sh_ten.data]
+    if isinstance(sh_ten, LocalShardsContainer):
+        return [local_shard._sh_ten.data for local_shard in sh_ten._local_shards]
+    if not isinstance(sh_ten, TorchShardedTensor):
+        return sh_ten
     mcore_sh_ten = sh_ten.mcore_sh_ten
     ret_tensors = []
     for sh in sh_ten.local_shards():
@@ -930,10 +956,7 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy):
             Dict[str, Union[TorchShardedTensor, List[io.BytesIO]]], pyt_state_dict
         )
         # Unwrap ShardedTensors and return to original state dict
-        mcore_state_dict = {
-            k: v if not isinstance(v, TorchShardedTensor) else _unwrap_pyt_sharded_tensor(v)
-            for k, v in pyt_state_dict.items()
-        }
+        mcore_state_dict = {k: _unwrap_pyt_sharded_tensor(v) for k, v in pyt_state_dict.items()}
         mcore_state_dict = _replace_sharded_keys_with_state_dict_keys(
             mcore_state_dict, flat_mapping, rename_mapping  # type: ignore[arg-type]
         )
@@ -175,6 +175,11 @@ def _create_zarr_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
             compressor=None,
             fill_value=None,
             write_empty_chunks=True,
+            synchronizer=(
+                zarr.ProcessSynchronizer(str(checkpoint_dir / f'{sharded_tensor.key}.sync'))
+                if sharded_tensor.flattened_range is not None
+                else None
+            ),
         )
         logger.debug(f"Created a new Zarr array at {checkpoint_dir / sharded_tensor.key}")
     except zarr.errors.ContainsArrayError as e:
@@ -328,7 +333,7 @@ def load_zarr_based_sharded_metadata(

     sharded_state_dict = {}
     for subdir in checkpoint_dir.iterdir():
-        if not subdir.is_dir() or not (subdir / ".zarray").exists():
+        if not subdir.is_dir() or not (subdir / ".zarray").exists() or subdir.suffix == ".sync":
             continue
         key = subdir.name
         arr_shape, arr_dtype = get_shape_dtype_fn(str(subdir))
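
The new `synchronizer` argument uses Zarr's file-lock based process synchronizer, which serializes concurrent writes to a shared chunk; the `.sync` lock directory it creates on disk is why the metadata loader above now skips `.sync` suffixes. A standalone sketch of the mechanism (Zarr v2 API, i.e. zarr<3; paths are illustrative):

    import zarr

    sync = zarr.ProcessSynchronizer('/tmp/ckpt/linear.weight.sync')
    arr = zarr.open_array(
        '/tmp/ckpt/linear.weight', mode='a', shape=(4, 4), chunks=(2, 4),
        dtype='f4', synchronizer=sync,
    )
    arr[0:2, :] = 1.0  # chunk writes are guarded by per-chunk file locks
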
@@ -450,6 +450,7 @@ def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]):
     local_shape = some_rank_shard.local_shape
     dtype = some_rank_shard.dtype
     has_flattened_range = some_rank_shard.flattened_range is not None
+    has_regular_sharding_grid = some_rank_shard.has_regular_grid
     for rank, sharding in rank_sharding:
         assert sharding.dtype == dtype, (sharding.dtype, dtype, some_rank_shard)
         assert sharding.global_shape == global_shape, (
@@ -457,17 +458,26 @@ def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]):
             global_shape,
             some_rank_shard,
         )
-        assert sharding.local_shape == local_shape, (
-            sharding.local_shape,
-            local_shape,
+        assert sharding.has_regular_grid == has_regular_sharding_grid, (
+            has_regular_sharding_grid,
             some_rank_shard,
         )
+        if has_regular_sharding_grid:
+            assert sharding.local_shape == local_shape, (
+                sharding.local_shape,
+                local_shape,
+                some_rank_shard,
+            )
         assert (sharding.flattened_range is not None) == has_flattened_range, (
             (sharding.flattened_range is not None),
             has_flattened_range,
             some_rank_shard,
         )

+    if not has_regular_sharding_grid:
+        # In case of uneven sharding we defer the validation to DCP
+        return
+
     shard_access_cnt = _compute_shards_access(rank_sharding)
     if has_flattened_range:
         map_reduce(
@@ -8,5 +8,6 @@ except ImportError:
 from .distributed_data_parallel import DistributedDataParallel
 from .distributed_data_parallel_config import DistributedDataParallelConfig
 from .finalize_model_grads import finalize_model_grads
+from .fsdp.mcore_fsdp_adapter import FullyShardedDataParallel
 from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel
 from .torch_fully_sharded_data_parallel_config import TorchFullyShardedDataParallelConfig
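
With this export, the new Megatron-FSDP adapter (added under megatron/core/distributed/fsdp/ in the files-changed list above, replacing the removed custom_fsdp package) becomes importable from the top-level distributed package:

    from megatron.core.distributed import FullyShardedDataParallel
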