megatron_core-0.16.0rc0.dev127461-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (356)
  1. megatron/core/README.md +51 -0
  2. megatron/core/__init__.py +52 -0
  3. megatron/core/activations.py +23 -0
  4. megatron/core/config.py +14 -0
  5. megatron/core/config_logger.py +126 -0
  6. megatron/core/datasets/__init__.py +0 -0
  7. megatron/core/datasets/bert_dataset.py +190 -0
  8. megatron/core/datasets/blended_dataset.py +212 -0
  9. megatron/core/datasets/blended_megatron_dataset_builder.py +552 -0
  10. megatron/core/datasets/blended_megatron_dataset_config.py +197 -0
  11. megatron/core/datasets/gpt_dataset.py +809 -0
  12. megatron/core/datasets/helpers.cpp +848 -0
  13. megatron/core/datasets/helpers.py +66 -0
  14. megatron/core/datasets/helpers_cpp.cpython-311-aarch64-linux-gnu.so +0 -0
  15. megatron/core/datasets/indexed_dataset.py +953 -0
  16. megatron/core/datasets/masked_dataset.py +423 -0
  17. megatron/core/datasets/megatron_dataset.py +185 -0
  18. megatron/core/datasets/megatron_tokenizer.py +162 -0
  19. megatron/core/datasets/multimodal_dataset.py +62 -0
  20. megatron/core/datasets/object_storage_utils.py +281 -0
  21. megatron/core/datasets/retro/__init__.py +5 -0
  22. megatron/core/datasets/retro/config/__init__.py +16 -0
  23. megatron/core/datasets/retro/config/bert_embedders.py +49 -0
  24. megatron/core/datasets/retro/config/config.py +135 -0
  25. megatron/core/datasets/retro/config/gpt_chunk_datasets.py +15 -0
  26. megatron/core/datasets/retro/config/tokenizers.py +15 -0
  27. megatron/core/datasets/retro/db/__init__.py +9 -0
  28. megatron/core/datasets/retro/db/build.py +649 -0
  29. megatron/core/datasets/retro/db/dataset.py +114 -0
  30. megatron/core/datasets/retro/db/utils.py +398 -0
  31. megatron/core/datasets/retro/external_libs.py +13 -0
  32. megatron/core/datasets/retro/index/__init__.py +11 -0
  33. megatron/core/datasets/retro/index/build.py +339 -0
  34. megatron/core/datasets/retro/index/factory.py +40 -0
  35. megatron/core/datasets/retro/index/index.py +150 -0
  36. megatron/core/datasets/retro/index/indexes/__init__.py +10 -0
  37. megatron/core/datasets/retro/index/indexes/faiss_base.py +179 -0
  38. megatron/core/datasets/retro/index/indexes/faiss_par_add.py +253 -0
  39. megatron/core/datasets/retro/index/utils.py +126 -0
  40. megatron/core/datasets/retro/index/validate.py +194 -0
  41. megatron/core/datasets/retro/query/__init__.py +1 -0
  42. megatron/core/datasets/retro/query/gpt_chunk_dataset.py +109 -0
  43. megatron/core/datasets/retro/query/multi_split_gpt_dataset.py +115 -0
  44. megatron/core/datasets/retro/query/query.py +449 -0
  45. megatron/core/datasets/retro/query/retro_dataset.py +251 -0
  46. megatron/core/datasets/retro/query/utils.py +35 -0
  47. megatron/core/datasets/retro/utils.py +386 -0
  48. megatron/core/datasets/t5_dataset.py +338 -0
  49. megatron/core/datasets/utils.py +92 -0
  50. megatron/core/datasets/utils_s3.py +5 -0
  51. megatron/core/dist_checkpointing/__init__.py +13 -0
  52. megatron/core/dist_checkpointing/core.py +93 -0
  53. megatron/core/dist_checkpointing/dict_utils.py +256 -0
  54. megatron/core/dist_checkpointing/exchange_utils.py +576 -0
  55. megatron/core/dist_checkpointing/mapping.py +738 -0
  56. megatron/core/dist_checkpointing/optimizer.py +148 -0
  57. megatron/core/dist_checkpointing/serialization.py +454 -0
  58. megatron/core/dist_checkpointing/state_dict_utils.py +112 -0
  59. megatron/core/dist_checkpointing/strategies/__init__.py +7 -0
  60. megatron/core/dist_checkpointing/strategies/async_utils.py +602 -0
  61. megatron/core/dist_checkpointing/strategies/base.py +224 -0
  62. megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py +38 -0
  63. megatron/core/dist_checkpointing/strategies/checkpointable.py +196 -0
  64. megatron/core/dist_checkpointing/strategies/common.py +193 -0
  65. megatron/core/dist_checkpointing/strategies/filesystem_async.py +645 -0
  66. megatron/core/dist_checkpointing/strategies/fully_parallel.py +520 -0
  67. megatron/core/dist_checkpointing/strategies/resharding.py +320 -0
  68. megatron/core/dist_checkpointing/strategies/state_dict_saver.py +258 -0
  69. megatron/core/dist_checkpointing/strategies/tensorstore.py +149 -0
  70. megatron/core/dist_checkpointing/strategies/torch.py +1123 -0
  71. megatron/core/dist_checkpointing/strategies/two_stage.py +268 -0
  72. megatron/core/dist_checkpointing/strategies/zarr.py +357 -0
  73. megatron/core/dist_checkpointing/tensor_aware_state_dict.py +394 -0
  74. megatron/core/dist_checkpointing/utils.py +332 -0
  75. megatron/core/dist_checkpointing/validation.py +585 -0
  76. megatron/core/distributed/__init__.py +13 -0
  77. megatron/core/distributed/data_parallel_base.py +96 -0
  78. megatron/core/distributed/distributed_data_parallel.py +584 -0
  79. megatron/core/distributed/distributed_data_parallel_config.py +155 -0
  80. megatron/core/distributed/finalize_model_grads.py +488 -0
  81. megatron/core/distributed/fsdp/__init__.py +3 -0
  82. megatron/core/distributed/fsdp/mcore_fsdp_adapter.py +431 -0
  83. megatron/core/distributed/fsdp/src/__init__.py +13 -0
  84. megatron/core/distributed/fsdp/src/megatron_fsdp/__init__.py +51 -0
  85. megatron/core/distributed/fsdp/src/megatron_fsdp/distributed_data_parallel_config.py +146 -0
  86. megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py +540 -0
  87. megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py +1223 -0
  88. megatron/core/distributed/fsdp/src/megatron_fsdp/package_info.py +27 -0
  89. megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py +3812 -0
  90. megatron/core/distributed/fsdp/src/megatron_fsdp/uneven_dtensor.py +460 -0
  91. megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py +992 -0
  92. megatron/core/distributed/param_and_grad_buffer.py +1006 -0
  93. megatron/core/distributed/reduce_scatter_with_fp32_accumulation.py +92 -0
  94. megatron/core/distributed/torch_fully_sharded_data_parallel.py +154 -0
  95. megatron/core/distributed/torch_fully_sharded_data_parallel_config.py +19 -0
  96. megatron/core/energy_monitor.py +91 -0
  97. megatron/core/enums.py +36 -0
  98. megatron/core/export/__init__.py +1 -0
  99. megatron/core/export/data_type.py +5 -0
  100. megatron/core/export/export_config.py +32 -0
  101. megatron/core/export/model_type.py +8 -0
  102. megatron/core/export/trtllm/__init__.py +1 -0
  103. megatron/core/export/trtllm/engine_builder/__init__.py +1 -0
  104. megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py +172 -0
  105. megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py +1 -0
  106. megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +50 -0
  107. megatron/core/export/trtllm/trt_model_config.py +25 -0
  108. megatron/core/export/trtllm/trt_model_type.py +14 -0
  109. megatron/core/export/trtllm/trtllm_helper.py +614 -0
  110. megatron/core/export/trtllm/trtllm_layers.py +169 -0
  111. megatron/core/export/trtllm/trtllm_weights_converter/__init__.py +1 -0
  112. megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py +293 -0
  113. megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +512 -0
  114. megatron/core/export/trtllm/trtllm_weights_converter/utils.py +8 -0
  115. megatron/core/extensions/__init__.py +0 -0
  116. megatron/core/extensions/kitchen.py +1092 -0
  117. megatron/core/extensions/transformer_engine.py +2118 -0
  118. megatron/core/extensions/transformer_engine_spec_provider.py +95 -0
  119. megatron/core/fp4_utils.py +139 -0
  120. megatron/core/fp8_utils.py +750 -0
  121. megatron/core/full_cuda_graph.py +198 -0
  122. megatron/core/fusions/__init__.py +0 -0
  123. megatron/core/fusions/fused_bias_dropout.py +92 -0
  124. megatron/core/fusions/fused_bias_geglu.py +442 -0
  125. megatron/core/fusions/fused_bias_gelu.py +55 -0
  126. megatron/core/fusions/fused_bias_swiglu.py +255 -0
  127. megatron/core/fusions/fused_cross_entropy.py +148 -0
  128. megatron/core/fusions/fused_indices_converter.py +288 -0
  129. megatron/core/fusions/fused_layer_norm.py +169 -0
  130. megatron/core/fusions/fused_mla_yarn_rope_apply.py +784 -0
  131. megatron/core/fusions/fused_pad_routing_map.py +98 -0
  132. megatron/core/fusions/fused_softmax.py +359 -0
  133. megatron/core/fusions/fused_weighted_squared_relu.py +110 -0
  134. megatron/core/hyper_comm_grid.py +239 -0
  135. megatron/core/inference/__init__.py +1 -0
  136. megatron/core/inference/async_stream.py +73 -0
  137. megatron/core/inference/common_inference_params.py +4 -0
  138. megatron/core/inference/communication_utils.py +211 -0
  139. megatron/core/inference/contexts/__init__.py +23 -0
  140. megatron/core/inference/contexts/attention_context/mamba_metadata.py +106 -0
  141. megatron/core/inference/contexts/attention_context/metadata_base.py +72 -0
  142. megatron/core/inference/contexts/attention_context/mha_metadata.py +220 -0
  143. megatron/core/inference/contexts/base_context.py +43 -0
  144. megatron/core/inference/contexts/dynamic_block_allocator.py +118 -0
  145. megatron/core/inference/contexts/dynamic_context.py +1804 -0
  146. megatron/core/inference/contexts/fused_kv_append_kernel.py +174 -0
  147. megatron/core/inference/contexts/static_context.py +130 -0
  148. megatron/core/inference/data_parallel_inference_coordinator.py +255 -0
  149. megatron/core/inference/engines/__init__.py +5 -0
  150. megatron/core/inference/engines/abstract_engine.py +17 -0
  151. megatron/core/inference/engines/dynamic_engine.py +1102 -0
  152. megatron/core/inference/engines/mcore_engine.py +5 -0
  153. megatron/core/inference/engines/static_engine.py +388 -0
  154. megatron/core/inference/headers.py +17 -0
  155. megatron/core/inference/inference_client.py +193 -0
  156. megatron/core/inference/inference_request.py +357 -0
  157. megatron/core/inference/model_inference_wrappers/__init__.py +1 -0
  158. megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +389 -0
  159. megatron/core/inference/model_inference_wrappers/gpt/__init__.py +1 -0
  160. megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +131 -0
  161. megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +66 -0
  162. megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py +216 -0
  163. megatron/core/inference/model_inference_wrappers/t5/__init__.py +1 -0
  164. megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py +230 -0
  165. megatron/core/inference/sampling_params.py +56 -0
  166. megatron/core/inference/scheduler.py +193 -0
  167. megatron/core/inference/text_generation_controllers/__init__.py +1 -0
  168. megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +51 -0
  169. megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +5 -0
  170. megatron/core/inference/text_generation_controllers/text_generation_controller.py +1474 -0
  171. megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py +53 -0
  172. megatron/core/inference/text_generation_server/__init__.py +3 -0
  173. megatron/core/inference/text_generation_server/endpoints/common.py +14 -0
  174. megatron/core/inference/text_generation_server/endpoints/completions.py +212 -0
  175. megatron/core/inference/text_generation_server/run_mcore_engine.py +111 -0
  176. megatron/core/inference/text_generation_server/text_generation_server.py +211 -0
  177. megatron/core/inference/text_generation_server/tokenization.py +110 -0
  178. megatron/core/inference/unified_memory.py +127 -0
  179. megatron/core/inference/utils.py +163 -0
  180. megatron/core/inference_params.py +5 -0
  181. megatron/core/jit.py +18 -0
  182. megatron/core/model_parallel_config.py +404 -0
  183. megatron/core/models/T5/__init__.py +2 -0
  184. megatron/core/models/T5/t5_model.py +536 -0
  185. megatron/core/models/T5/t5_spec.py +251 -0
  186. megatron/core/models/__init__.py +1 -0
  187. megatron/core/models/backends.py +182 -0
  188. megatron/core/models/bert/__init__.py +0 -0
  189. megatron/core/models/bert/bert_layer_specs.py +118 -0
  190. megatron/core/models/bert/bert_lm_head.py +50 -0
  191. megatron/core/models/bert/bert_model.py +386 -0
  192. megatron/core/models/bert/pooler.py +52 -0
  193. megatron/core/models/common/__init__.py +0 -0
  194. megatron/core/models/common/embeddings/__init__.py +5 -0
  195. megatron/core/models/common/embeddings/language_model_embedding.py +150 -0
  196. megatron/core/models/common/embeddings/relative_pos_embedding.py +180 -0
  197. megatron/core/models/common/embeddings/rope_utils.py +345 -0
  198. megatron/core/models/common/embeddings/rotary_pos_embedding.py +325 -0
  199. megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +249 -0
  200. megatron/core/models/common/language_module/__init__.py +0 -0
  201. megatron/core/models/common/language_module/language_module.py +344 -0
  202. megatron/core/models/common/model_chunk_schedule_plan.py +508 -0
  203. megatron/core/models/common/vision_module/__init__.py +0 -0
  204. megatron/core/models/common/vision_module/vision_module.py +17 -0
  205. megatron/core/models/gpt/__init__.py +2 -0
  206. megatron/core/models/gpt/fine_grained_callables.py +585 -0
  207. megatron/core/models/gpt/gpt_layer_specs.py +673 -0
  208. megatron/core/models/gpt/gpt_model.py +765 -0
  209. megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py +220 -0
  210. megatron/core/models/gpt/moe_module_specs.py +74 -0
  211. megatron/core/models/huggingface/__init__.py +2 -0
  212. megatron/core/models/huggingface/clip_model.py +42 -0
  213. megatron/core/models/huggingface/module.py +97 -0
  214. megatron/core/models/huggingface/qwen_model.py +59 -0
  215. megatron/core/models/mamba/__init__.py +2 -0
  216. megatron/core/models/mamba/mamba_layer_specs.py +68 -0
  217. megatron/core/models/mamba/mamba_model.py +289 -0
  218. megatron/core/models/mimo/__init__.py +16 -0
  219. megatron/core/models/mimo/config/__init__.py +5 -0
  220. megatron/core/models/mimo/config/base_configs.py +34 -0
  221. megatron/core/models/mimo/model/__init__.py +4 -0
  222. megatron/core/models/mimo/model/base.py +290 -0
  223. megatron/core/models/mimo/submodules/audio.py +155 -0
  224. megatron/core/models/mimo/submodules/base.py +193 -0
  225. megatron/core/models/mimo/submodules/vision.py +184 -0
  226. megatron/core/models/multimodal/__init__.py +1 -0
  227. megatron/core/models/multimodal/context_parallel.py +111 -0
  228. megatron/core/models/multimodal/llava_model.py +1028 -0
  229. megatron/core/models/multimodal/llava_spec.py +90 -0
  230. megatron/core/models/retro/__init__.py +13 -0
  231. megatron/core/models/retro/base_attention.py +47 -0
  232. megatron/core/models/retro/config.py +88 -0
  233. megatron/core/models/retro/decoder_attention.py +319 -0
  234. megatron/core/models/retro/decoder_spec.py +195 -0
  235. megatron/core/models/retro/encoder_attention.py +231 -0
  236. megatron/core/models/retro/encoder_spec.py +171 -0
  237. megatron/core/models/retro/model.py +107 -0
  238. megatron/core/models/retro/utils.py +24 -0
  239. megatron/core/models/vision/__init__.py +0 -0
  240. megatron/core/models/vision/clip_vit_model.py +261 -0
  241. megatron/core/models/vision/multimodal_projector.py +88 -0
  242. megatron/core/models/vision/radio.py +380 -0
  243. megatron/core/models/vision/vit_layer_specs.py +96 -0
  244. megatron/core/msc_utils.py +69 -0
  245. megatron/core/nccl_allocator.py +316 -0
  246. megatron/core/num_microbatches_calculator.py +508 -0
  247. megatron/core/optimizer/__init__.py +635 -0
  248. megatron/core/optimizer/clip_grads.py +247 -0
  249. megatron/core/optimizer/cpu_offloading/__init__.py +2 -0
  250. megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +472 -0
  251. megatron/core/optimizer/distrib_optimizer.py +2602 -0
  252. megatron/core/optimizer/grad_scaler.py +142 -0
  253. megatron/core/optimizer/optimizer.py +1418 -0
  254. megatron/core/optimizer/optimizer_config.py +308 -0
  255. megatron/core/optimizer_param_scheduler.py +311 -0
  256. megatron/core/package_info.py +27 -0
  257. megatron/core/packed_seq_params.py +20 -0
  258. megatron/core/parallel_state.py +2097 -0
  259. megatron/core/pipeline_parallel/__init__.py +2 -0
  260. megatron/core/pipeline_parallel/bridge_communicator.py +922 -0
  261. megatron/core/pipeline_parallel/combined_1f1b.py +444 -0
  262. megatron/core/pipeline_parallel/p2p_communication.py +645 -0
  263. megatron/core/pipeline_parallel/schedules.py +2303 -0
  264. megatron/core/pipeline_parallel/utils.py +307 -0
  265. megatron/core/post_training/__init__.py +1 -0
  266. megatron/core/post_training/modelopt/__init__.py +10 -0
  267. megatron/core/post_training/modelopt/gpt/__init__.py +1 -0
  268. megatron/core/post_training/modelopt/gpt/model_specs.py +206 -0
  269. megatron/core/post_training/modelopt/gpt/state_dict_hooks.py +64 -0
  270. megatron/core/post_training/modelopt/layers.py +249 -0
  271. megatron/core/post_training/modelopt/mamba/__init__.py +1 -0
  272. megatron/core/post_training/modelopt/mamba/model_specs.py +91 -0
  273. megatron/core/process_groups_config.py +571 -0
  274. megatron/core/quantization/__init__.py +1 -0
  275. megatron/core/quantization/quant_config.py +219 -0
  276. megatron/core/quantization/utils.py +37 -0
  277. megatron/core/requirements.txt +2 -0
  278. megatron/core/rerun_state_machine.py +1345 -0
  279. megatron/core/safe_globals.py +39 -0
  280. megatron/core/ssm/__init__.py +1 -0
  281. megatron/core/ssm/mamba_block.py +414 -0
  282. megatron/core/ssm/mamba_context_parallel.py +389 -0
  283. megatron/core/ssm/mamba_hybrid_layer_allocation.py +218 -0
  284. megatron/core/ssm/mamba_layer.py +184 -0
  285. megatron/core/ssm/mamba_mixer.py +1171 -0
  286. megatron/core/ssm/mlp_layer.py +30 -0
  287. megatron/core/ssm/triton_cache_manager.py +81 -0
  288. megatron/core/tensor_parallel/__init__.py +74 -0
  289. megatron/core/tensor_parallel/cross_entropy.py +232 -0
  290. megatron/core/tensor_parallel/data.py +101 -0
  291. megatron/core/tensor_parallel/inference_layers.py +151 -0
  292. megatron/core/tensor_parallel/layers.py +1303 -0
  293. megatron/core/tensor_parallel/mappings.py +596 -0
  294. megatron/core/tensor_parallel/random.py +615 -0
  295. megatron/core/tensor_parallel/utils.py +121 -0
  296. megatron/core/timers.py +465 -0
  297. megatron/core/tokenizers/__init__.py +4 -0
  298. megatron/core/tokenizers/base_tokenizer.py +48 -0
  299. megatron/core/tokenizers/megatron_tokenizer.py +171 -0
  300. megatron/core/tokenizers/text/__init__.py +3 -0
  301. megatron/core/tokenizers/text/libraries/__init__.py +8 -0
  302. megatron/core/tokenizers/text/libraries/abstract_tokenizer.py +147 -0
  303. megatron/core/tokenizers/text/libraries/bytelevel_tokenizer.py +164 -0
  304. megatron/core/tokenizers/text/libraries/chat_template.py +71 -0
  305. megatron/core/tokenizers/text/libraries/huggingface_tokenizer.py +335 -0
  306. megatron/core/tokenizers/text/libraries/megatron_hf_tokenizer.py +179 -0
  307. megatron/core/tokenizers/text/libraries/null_tokenizer.py +79 -0
  308. megatron/core/tokenizers/text/libraries/sentencepiece_tokenizer.py +411 -0
  309. megatron/core/tokenizers/text/libraries/tiktoken_tokenizer.py +303 -0
  310. megatron/core/tokenizers/text/models/__init__.py +8 -0
  311. megatron/core/tokenizers/text/models/bert_tokenizer.py +12 -0
  312. megatron/core/tokenizers/text/models/default_tokenizer.py +12 -0
  313. megatron/core/tokenizers/text/models/gpt_tokenizer.py +12 -0
  314. megatron/core/tokenizers/text/models/mamba_tokenizer.py +12 -0
  315. megatron/core/tokenizers/text/models/retro_tokenizer.py +12 -0
  316. megatron/core/tokenizers/text/models/t5_tokenizer.py +12 -0
  317. megatron/core/tokenizers/text/text_tokenizer.py +254 -0
  318. megatron/core/tokenizers/text/utils/build_tokenizer.py +58 -0
  319. megatron/core/transformer/__init__.py +6 -0
  320. megatron/core/transformer/attention.py +1238 -0
  321. megatron/core/transformer/cuda_graphs.py +1676 -0
  322. megatron/core/transformer/custom_layers/__init__.py +0 -0
  323. megatron/core/transformer/custom_layers/transformer_engine.py +12 -0
  324. megatron/core/transformer/dot_product_attention.py +258 -0
  325. megatron/core/transformer/enums.py +67 -0
  326. megatron/core/transformer/fsdp_dtensor_checkpoint.py +455 -0
  327. megatron/core/transformer/heterogeneous/heterogeneous_config.py +267 -0
  328. megatron/core/transformer/heterogeneous/linear_replacements.py +115 -0
  329. megatron/core/transformer/identity_op.py +28 -0
  330. megatron/core/transformer/mlp.py +403 -0
  331. megatron/core/transformer/module.py +453 -0
  332. megatron/core/transformer/moe/__init__.py +0 -0
  333. megatron/core/transformer/moe/experts.py +1166 -0
  334. megatron/core/transformer/moe/fused_a2a.py +264 -0
  335. megatron/core/transformer/moe/grouped_gemm_util.py +22 -0
  336. megatron/core/transformer/moe/moe_layer.py +309 -0
  337. megatron/core/transformer/moe/moe_utils.py +1030 -0
  338. megatron/core/transformer/moe/router.py +572 -0
  339. megatron/core/transformer/moe/shared_experts.py +286 -0
  340. megatron/core/transformer/moe/token_dispatcher.py +1327 -0
  341. megatron/core/transformer/moe/upcycling_utils.py +359 -0
  342. megatron/core/transformer/multi_latent_attention.py +919 -0
  343. megatron/core/transformer/multi_token_prediction.py +955 -0
  344. megatron/core/transformer/pipeline_parallel_layer_layout.py +308 -0
  345. megatron/core/transformer/spec_utils.py +106 -0
  346. megatron/core/transformer/torch_layer_norm.py +4 -0
  347. megatron/core/transformer/torch_norm.py +96 -0
  348. megatron/core/transformer/transformer_block.py +815 -0
  349. megatron/core/transformer/transformer_config.py +1647 -0
  350. megatron/core/transformer/transformer_layer.py +852 -0
  351. megatron/core/transformer/utils.py +419 -0
  352. megatron/core/utils.py +2154 -0
  353. megatron_core-0.16.0rc0.dev127461.dist-info/METADATA +579 -0
  354. megatron_core-0.16.0rc0.dev127461.dist-info/RECORD +356 -0
  355. megatron_core-0.16.0rc0.dev127461.dist-info/WHEEL +6 -0
  356. megatron_core-0.16.0rc0.dev127461.dist-info/top_level.txt +1 -0
megatron/core/README.md
@@ -0,0 +1,51 @@
+ <div align="center">
+
+ Megatron Core
+ =============
+ <h4>Production-ready library for building custom training frameworks</h4>
+
+ <div align="left">
+
+ ## ⚡ Quick Start
+
+ ```bash
+ # Install Megatron Core with required dependencies
+ pip install --no-build-isolation megatron-core[dev]
+
+ # Distributed training example (2 GPUs, mock data)
+ torchrun --nproc_per_node=2 examples/run_simple_mcore_train_loop.py
+ ```
+
+ # What is Megatron Core?
+
+ **Megatron Core** is an open-source PyTorch-based library that contains GPU-optimized techniques and cutting-edge system-level optimizations. It abstracts them into composable and modular APIs, allowing full flexibility for developers and model researchers to train custom transformers at scale on NVIDIA accelerated computing infrastructure.
+
+ ## 🚀 Key Components
+
+ ### GPU-Optimized Building Blocks
+ - **Transformer Components**: Attention mechanisms, MLP layers, embeddings
+ - **Memory Management**: Activation recomputation
+ - **FP8 Precision**: Optimized for NVIDIA Hopper, Ada, and Blackwell GPUs
+
+ ### Parallelism Strategies
+ - **Tensor Parallelism (TP)**: Layer-wise parallelization (the activation memory footprint can be further reduced with sequence parallelism)
+ - **Pipeline Parallelism (PP)**: Depth-wise model splitting and pipelining of microbatches to improve efficiency
+ - **Context Parallelism (CP)**: Long-sequence handling ([documentation](https://docs.nvidia.com/megatron-core/developer-guide/latest/api-guide/context_parallel.html))
+ - **Expert Parallelism (EP)**: Splits the experts of an MoE model across multiple GPUs
+
+
+ ## 🔗 Examples & Documentation
+
+ **Examples:**
+ - **[Simple Training Loop](https://github.com/NVIDIA/Megatron-LM/blob/main/examples/run_simple_mcore_train_loop.py)** - Basic usage
+ - **[Multimodal Training](https://github.com/NVIDIA/Megatron-LM/blob/main/examples/multimodal/)** - Vision-language models
+ - **[Mixture-of-Experts](https://github.com/yanring/Megatron-MoE-ModelZoo)** - MoE examples
+ - **[Mamba Models](https://github.com/NVIDIA/Megatron-LM/blob/main/examples/mamba/)** - State-space models
+
+ **Documentation:**
+ - **[📚 API Guide](https://docs.nvidia.com/megatron-core/developer-guide/latest/api-guide/index.html)** - Complete API documentation
+ - **[💡 Developer Guide](https://docs.nvidia.com/megatron-core/developer-guide/latest/index.html)** - Custom framework development
+
+ ---
+
+ *For complete installation instructions, performance benchmarks, and ecosystem information, see the [main README](../README.md).*
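Note: the quick-start command above relies on `torch.distributed` and Megatron Core's parallel state being initialized before any model is built. A minimal sketch of that setup follows; the group sizes are illustrative, not defaults taken from this package:

```python
# Minimal sketch: initialize Megatron Core's parallel state for a 2-GPU
# tensor-parallel run. Launch with: torchrun --nproc_per_node=2 this_script.py
import torch
from megatron.core import parallel_state

torch.distributed.init_process_group(backend="nccl")
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())

# Carve the world into tensor/pipeline model-parallel groups; the remaining
# ranks form the data-parallel dimension.
parallel_state.initialize_model_parallel(
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=1,
)

print("TP rank:", parallel_state.get_tensor_model_parallel_rank())

parallel_state.destroy_model_parallel()
torch.distributed.destroy_process_group()
```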
megatron/core/__init__.py
@@ -0,0 +1,52 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+
+ import megatron.core.tensor_parallel
+ import megatron.core.utils
+ from megatron.core import parallel_state
+ from megatron.core.distributed import DistributedDataParallel
+ from megatron.core.inference_params import InferenceParams
+ from megatron.core.model_parallel_config import ModelParallelConfig
+ from megatron.core.package_info import (
+     __contact_emails__,
+     __contact_names__,
+     __description__,
+     __download_url__,
+     __homepage__,
+     __keywords__,
+     __license__,
+     __package_name__,
+     __repository_url__,
+     __shortversion__,
+     __version__,
+ )
+ from megatron.core.timers import Timers
+ from megatron.core.utils import is_torch_min_version
+
+ # Alias parallel_state as mpu, its legacy name
+ mpu = parallel_state
+
+ __all__ = [
+     "parallel_state",
+     "tensor_parallel",
+     "utils",
+     "DistributedDataParallel",
+     "InferenceParams",
+     "ModelParallelConfig",
+     "Timers",
+     "__contact_emails__",
+     "__contact_names__",
+     "__description__",
+     "__download_url__",
+     "__homepage__",
+     "__keywords__",
+     "__license__",
+     "__package_name__",
+     "__repository_url__",
+     "__shortversion__",
+     "__version__",
+ ]
+
+ from .safe_globals import register_safe_globals
+
+ if is_torch_min_version("2.6a0"):
+     register_safe_globals()
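Note: per the `mpu = parallel_state` alias above, legacy imports keep working. A quick check, assuming the package is installed:

```python
# The package-level alias: `mpu` and `parallel_state` are the same module.
from megatron.core import mpu, parallel_state

assert mpu is parallel_state
```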
megatron/core/activations.py
@@ -0,0 +1,23 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ import torch
+ import torch.nn.functional as F
+
+ from megatron.core.jit import jit_fuser
+
+
+ @jit_fuser
+ def squared_relu(x: torch.Tensor) -> torch.Tensor:
+     """Squared ReLU activation"""
+     return torch.pow(F.relu(x), 2)
+
+
+ @jit_fuser
+ def quick_gelu(x: torch.Tensor) -> torch.Tensor:
+     """Quick GELU activation"""
+     return x * torch.sigmoid(1.702 * x)
+
+
+ @jit_fuser
+ def fast_gelu(x: torch.Tensor) -> torch.Tensor:
+     """Fast GELU activation"""
+     return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
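Note: `fast_gelu` is the tanh approximation of GELU (0.7978845608 ≈ sqrt(2/π)), so it should agree with PyTorch's built-in approximate GELU up to float tolerance. A small sanity-check sketch:

```python
# Sanity check: fast_gelu matches F.gelu's tanh approximation, and the other
# activations produce tensors of the same shape as their input.
import torch
import torch.nn.functional as F

from megatron.core.activations import fast_gelu, quick_gelu, squared_relu

x = torch.randn(4, 8)
assert torch.allclose(fast_gelu(x), F.gelu(x, approximate="tanh"), atol=1e-5)

# quick_gelu is the sigmoid approximation; squared_relu is relu(x) ** 2.
print(quick_gelu(x).shape, squared_relu(x).shape)
```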
megatron/core/config.py
@@ -0,0 +1,14 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+
+ ENABLE_EXPERIMENTAL = False
+
+
+ def set_experimental_flag(flag: bool):
+     """Set the experimental flag to the given value."""
+     global ENABLE_EXPERIMENTAL
+     ENABLE_EXPERIMENTAL = flag
+
+
+ def is_experimental_enabled():
+     """Return the experimental flag."""
+     return ENABLE_EXPERIMENTAL
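Note: callers should read the flag through the accessor rather than from-importing `ENABLE_EXPERIMENTAL`, since a from-import captures the value at import time and will not see later rebinds. A minimal usage sketch:

```python
# Toggle the module-level experimental flag and read it back through the
# accessor, so the current value is always observed.
from megatron.core import config

assert not config.is_experimental_enabled()
config.set_experimental_flag(True)
assert config.is_experimental_enabled()
config.set_experimental_flag(False)
```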
megatron/core/config_logger.py
@@ -0,0 +1,126 @@
+ # Copyright (c) 2025, NVIDIA CORPORATION.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import dataclasses
+ import json
+ import os
+
+ import torch
+ import torch.nn as nn
+
+ from megatron.core import parallel_state
+
+
+ def get_config_logger_path(config):
+     """Get the path to the config logger directory."""
+     return getattr(config, 'config_logger_dir', '')
+
+
+ def has_config_logger_enabled(config):
+     """Check if config logger is enabled."""
+     return get_config_logger_path(config) != ''
+
+
+ # For each prefix, holds a counter and increases it every time we dump with this
+ # prefix.
+ __config_logger_path_counts = {}
+
+
+ def get_path_count(path):
+     """
+     Keeps track of the number of times we've seen the input `path` and returns count - 1
+     """
+     global __config_logger_path_counts
+     if not path in __config_logger_path_counts:
+         __config_logger_path_counts[path] = 0
+     count = __config_logger_path_counts[path]
+     __config_logger_path_counts[path] += 1
+     return count
+
+
+ def get_path_with_count(path):
+     """
+     Calls get_path_count and appends the returned value to path
+     """
+     return f'{path}.iter{get_path_count(path)}'
+
+
+ class JSONEncoderWithMcoreTypes(json.JSONEncoder):
+     """
+     Custom JSON encoder that serializes according to types in mcore.
+     """
+
+     def default(self, o):
+         if type(o).__name__ in ['function', 'ProcessGroup']:
+             return str(o)
+         if type(o).__name__ in ['dict', 'OrderedDict']:
+             return {k: self.default(v) for k, v in o.items()}
+         if type(o).__name__ in ['list', 'ModuleList']:
+             return [self.default(val) for val in o]
+         if type(o).__name__ == 'UniqueDescriptor':
+             return {
+                 attr: self.default(getattr(o, attr))
+                 for attr in filter(lambda x: not x.startswith('__'), dir(o))
+             }
+         if type(o) is torch.dtype:
+             return str(o)
+         # if it's a Float16Module, add "Float16Module" to the output dict
+         if type(o).__name__ == 'Float16Module':
+             return {'Float16Module': {'module': self.default(o.module)}}
+         # If it's an nn.Module subchild, either print its children or itself if leaf.
+         if issubclass(type(o), nn.Module):
+             if len(getattr(o, '_modules', {})) > 0:
+                 return {key: self.default(val) for key, val in o._modules.items()}
+             else:
+                 return str(o)
+         if type(o).__name__ in ['ABCMeta', 'type', 'AttnMaskType']:
+             return str(o)
+         if dataclasses.is_dataclass(o) or type(o).__name__ in ['ModuleSpec', 'TransformerConfig']:
+             return dataclasses.asdict(o)
+         try:
+             return super().default(o)
+         except:
+             return str(o)
+
+
+ def log_config_to_disk(config, dict_data, prefix='', rank_str=''):
+     """
+     Encodes the input dict (dict_data) using the JSONEncoderWithMcoreTypes
+     and dumps it to disk, at the path specified via the config
+     """
+     path = get_config_logger_path(config)
+     assert path is not None, 'Expected config_logger_dir to be non-empty in config.'
+
+     if not os.path.exists(path):
+         os.makedirs(path, exist_ok=True)
+
+     if 'self' in dict_data:
+         if prefix == '':
+             prefix = type(dict_data['self']).__name__
+         del dict_data['self']
+
+     # the caller of the function can decide the most informative string;
+     # rank_str defaults to '0_0_0_0_0' format (tp_dp_cp_pp_ep ranks)
+     if rank_str == '':
+         rank_str = parallel_state.get_all_ranks()
+
+     path = get_path_with_count(os.path.join(path, f'{prefix}.rank_{rank_str}'))
+     if type(dict_data).__name__ == 'OrderedDict':
+         torch.save(dict_data, f'{path}.pth')
+     else:
+         with open(f'{path}.json', 'w') as fp:
+             json.dump(dict_data, fp, cls=JSONEncoderWithMcoreTypes)
+
+
+ __all__ = ['has_config_logger_enabled', 'log_config_to_disk']
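Note: a hypothetical usage sketch (the `DemoConfig` class below is illustrative, not part of megatron-core). Any config object exposing a non-empty `config_logger_dir` attribute enables dumping, and passing `rank_str` explicitly avoids requiring initialized parallel state:

```python
# Dump an arbitrary dict of config values to disk via the config logger.
from dataclasses import dataclass

from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk

@dataclass
class DemoConfig:  # illustrative stand-in for any mcore config object
    config_logger_dir: str = '/tmp/mcore_config_logs'

cfg = DemoConfig()
if has_config_logger_enabled(cfg):
    # rank_str given explicitly so parallel_state need not be initialized.
    log_config_to_disk(cfg, {'hidden_size': 1024}, prefix='demo', rank_str='0')
    # Writes /tmp/mcore_config_logs/demo.rank_0.iter0.json
```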
megatron/core/datasets/__init__.py — file without content changes (empty file)
megatron/core/datasets/bert_dataset.py
@@ -0,0 +1,190 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional, Union
+
+ import numpy
+
+ from megatron.core.datasets.indexed_dataset import IndexedDataset
+ from megatron.core.datasets.masked_dataset import (
+     MaskedWordPieceDataset,
+     MaskedWordPieceDatasetConfig,
+ )
+ from megatron.core.datasets.utils import Split
+
+
+ @dataclass
+ class BERTMaskedWordPieceDatasetConfig(MaskedWordPieceDatasetConfig):
+     """Configuration object for Megatron Core BERT WordPiece datasets"""
+
+     classification_head: bool = None
+     """Option to perform the next sequence prediction during sampling"""
+
+     def __post_init__(self) -> None:
+         """Do asserts and set fields post init"""
+         super().__post_init__()
+
+         assert self.classification_head is not None
+
+
+ class BERTMaskedWordPieceDataset(MaskedWordPieceDataset):
+     """The BERT dataset that assumes WordPiece tokenization
+
+     Args:
+         indexed_dataset (IndexedDataset): The IndexedDataset around which
+             to build the MegatronDataset
+         dataset_path (str): The real path on disk to the dataset, for bookkeeping
+         indexed_indices (numpy.ndarray): The set of the document indices to expose
+         num_samples (Optional[int]): The number of samples to draw from the indexed dataset.
+             When None, build as many samples as correspond to one epoch.
+         index_split (Split): The indexed_indices Split
+         config (BERTMaskedWordPieceDatasetConfig): The config
+     """
+
+     def __init__(
+         self,
+         indexed_dataset: IndexedDataset,
+         dataset_path: str,
+         indexed_indices: numpy.ndarray,
+         num_samples: Optional[int],
+         index_split: Split,
+         config: BERTMaskedWordPieceDatasetConfig,
+     ) -> None:
+         super().__init__(
+             indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config
+         )
+
+         self.token_lookup = list(self.config.tokenizer.inv_vocab.keys())
+         # Account for the single <cls> and two <sep> token ids
+         self.sample_index = self._build_sample_index(
+             self.config.sequence_length - 3, 2 if self.config.classification_head else 1
+         )
+
+     @staticmethod
+     def _key_config_attributes() -> List[str]:
+         """Inherited method implementation
+
+         Returns:
+             List[str]: The key config attributes
+         """
+         return super(
+             BERTMaskedWordPieceDataset, BERTMaskedWordPieceDataset
+         )._key_config_attributes() + ["classification_head"]
+
+     def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]:
+         """Abstract method implementation
+
+         Args:
+             idx (int): The index into the dataset
+
+         Returns:
+             Dict[str, Union[int, numpy.ndarray]]: The sample dictionary
+         """
+
+         idx_beg, idx_end, target_sequence_length = self.sample_index[idx]
+         sample = [self.dataset[i] for i in range(idx_beg, idx_end)]
+         numpy_random_state = numpy.random.RandomState(seed=(self.config.random_seed + idx) % 2**32)
+
+         assert target_sequence_length <= self.config.sequence_length
+
+         # Split the sample into contiguous subsegments A and B
+         pivot = len(sample)
+         is_next_random = False
+         if self.config.classification_head:
+             assert len(sample) > 1, "the sample must contain at least two sentences"
+             pivot = 1
+             if len(sample) >= 3:
+                 pivot = numpy_random_state.randint(low=1, high=len(sample))
+             is_next_random = numpy_random_state.random() < 0.5
+         split_A = []
+         for sample_a in sample[:pivot]:
+             split_A.extend(sample_a)
+         split_B = []
+         for sample_b in sample[pivot:]:
+             split_B.extend(sample_b)
+         if is_next_random:
+             split_A, split_B = split_B, split_A
+
+         # Trim the subsegments from either end to a desired joint length
+         length_A = len(split_A)
+         length_B = len(split_B)
+         if length_A + length_B <= target_sequence_length:
+             truncated = False
+         else:
+             while length_A + length_B > target_sequence_length:
+                 split = split_A if length_A > length_B else split_B
+                 if numpy_random_state.random() < 0.5:
+                     del split[0]
+                 else:
+                     del split[-1]
+                 length_A = len(split_A)
+                 length_B = len(split_B)
+             truncated = True
+
+         # Merge the subsegments and create the token assignment labels
+         tokens = [self.config.tokenizer.cls, *split_A, self.config.tokenizer.sep]
+         assignments = [0 for _ in range(1 + len(split_A) + 1)]
+         if split_B:
+             tokens += [*split_B, self.config.tokenizer.sep]
+             assignments += [1 for _ in range(len(split_B) + 1)]
+
+         # Masking
+         tokens, masked_positions, masked_labels, _, _ = self._create_masked_lm_predictions(
+             tokens, target_sequence_length, numpy_random_state
+         )
+
+         # Pad the sequences and convert to NumPy
+         length_toks = len(tokens)
+         length_pads = self.config.sequence_length - length_toks
+         assert length_pads >= 0
+
+         tokens = numpy.array(tokens, dtype=numpy.int64)
+         tokens = numpy.pad(tokens, (0, length_pads), constant_values=self._pad_token_id)
+
+         assignments = numpy.array(assignments, dtype=numpy.int64)
+         assignments = numpy.pad(assignments, (0, length_pads), constant_values=self._pad_token_id)
+
+         # Get the padding mask
+         mask_pads = numpy.ones(self.config.sequence_length, dtype=numpy.int64)
+         mask_pads[tokens == self._pad_token_id] = self._pad_token_id
+
+         # Mask the labels
+         labels = numpy.zeros(self.config.sequence_length, dtype=numpy.int64) - 1
+         labels[masked_positions] = masked_labels
+
+         # Get the loss mask
+         mask_loss = numpy.zeros(self.config.sequence_length, dtype=numpy.int64)
+         mask_loss[masked_positions] = 1
+
+         # For padded sequences, ensure the embedding layer can map the token ID
+         tokens[tokens == self._pad_token_id] = 0
+         labels[labels == self._pad_token_id] = 0
+
+         return {
+             "text": tokens,
+             "types": assignments,
+             "labels": labels,
+             "is_random": int(is_next_random),
+             "padding_mask": mask_pads,
+             "loss_mask": mask_loss,
+             "truncated": int(truncated),
+         }
+
+     def _get_token_mask(self, numpy_random_state: numpy.random.RandomState) -> Optional[int]:
+         """Abstract method implementation
+
+         80% of the time, replace the token id with the mask token id. 10% of the time, replace
+         the token id with a random token id from the vocabulary. 10% of the time, do nothing.
+
+         Args:
+             numpy_random_state (RandomState): The NumPy random state
+
+         Returns:
+             Optional[int]: The replacement token id, or None to leave the token unchanged
+         """
+         if numpy_random_state.random() < 0.8:
+             return self.config.tokenizer.mask
+         else:
+             if numpy_random_state.random() >= 0.5:
+                 return self.token_lookup[numpy_random_state.randint(0, len(self.token_lookup))]
+         return None
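Note: a standalone sketch of the 80/10/10 replacement rule that `_get_token_mask` implements; `MASK_ID` and `VOCAB` are hypothetical stand-ins for `self.config.tokenizer.mask` and `self.token_lookup`:

```python
# BERT 80/10/10 masking: after the first draw fails (20% of cases), a second
# draw splits that remainder evenly, giving 10% random-token and 10% no-op.
import numpy

MASK_ID = 103            # hypothetical [MASK] token id
VOCAB = list(range(1000))  # hypothetical vocabulary

def replacement_token(rng: numpy.random.RandomState):
    if rng.random() < 0.8:      # 80%: substitute the [MASK] token
        return MASK_ID
    if rng.random() >= 0.5:     # 10%: substitute a random vocabulary token
        return VOCAB[rng.randint(0, len(VOCAB))]
    return None                 # 10%: keep the original token

rng = numpy.random.RandomState(0)
print([replacement_token(rng) for _ in range(10)])
```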
megatron/core/datasets/blended_dataset.py
@@ -0,0 +1,212 @@
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+
+ import hashlib
+ import json
+ import logging
+ import os
+ import time
+ from collections import OrderedDict
+ from typing import Dict, List, Optional, Tuple, Union
+
+ import numpy
+ import torch
+
+ from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
+ from megatron.core.datasets.megatron_dataset import MegatronDataset
+ from megatron.core.datasets.utils import normalize
+ from megatron.core.utils import log_single_rank
+
+ logger = logging.getLogger(__name__)
+
+ _VERBOSE = False
+
+
+ class BlendedDataset(torch.utils.data.Dataset):
+     """Conjugating class for a set of MegatronDataset instances
+
+     Args:
+         datasets (List[MegatronDataset]): The MegatronDataset instances to blend
+
+         weights (List[Union[int, float]]): The weights that determine the dataset blend ratios
+
+         size (Optional[int]): The number of samples to draw from the blend. If None, for each
+             dataset index idx draw exactly weights[idx] samples from datasets[idx].
+
+         config (BlendedMegatronDatasetConfig): The config
+
+     Raises:
+         RuntimeError: When the dataset has fewer or more samples than 'size' post-initialization
+     """
+
+     def __init__(
+         self,
+         datasets: List[MegatronDataset],
+         weights: List[Union[int, float]],
+         size: Optional[int],
+         config: BlendedMegatronDatasetConfig,
+     ) -> None:
+         assert len(datasets) == len(weights)
+         assert len(datasets) < 32767
+         assert all(map(lambda _: type(_) == type(datasets[0]), datasets))
+         assert all(map(lambda _: _.index_split == datasets[0].index_split, datasets))
+         assert all(map(lambda _: _ > 0, weights))
+         assert all(map(lambda _: type(_) == type(weights[0]), weights))
+         if size is None and isinstance(weights[0], float):
+             assert all(map(lambda _: _ == int(_), weights))
+
+         # Alert user to unnecessary blending
+         if len(datasets) == 1:
+             log_single_rank(
+                 logger, logging.WARNING, f"Building a BlendedDataset for a single MegatronDataset"
+             )
+
+         if size is not None:
+             weights = normalize(weights)
+
+         self.datasets = datasets
+         self.split = self.datasets[0].index_split
+         self.weights = weights
+         self.size = size
+         self.config = config
+
+         unique_identifiers = OrderedDict()
+         unique_identifiers["class"] = type(self).__name__
+         unique_identifiers["datasets"] = [dataset.unique_identifiers for dataset in self.datasets]
+         unique_identifiers["split"] = self.split.name
+         unique_identifiers["weights"] = self.weights
+         unique_identifiers["size"] = self.size
+
+         self.unique_description = json.dumps(
+             unique_identifiers, indent=4, default=lambda obj: obj.unique_identifiers
+         )
+         self.unique_description_hash = hashlib.md5(
+             self.unique_description.encode("utf-8"), usedforsecurity=False
+         ).hexdigest()
+
+         self.dataset_index, self.dataset_sample_index = self._build_indices()
+
+     def __len__(self) -> int:
+         return self.dataset_index.shape[0]
+
+     def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]:
+         dataset_id = self.dataset_index[idx]
+         dataset_sample_id = self.dataset_sample_index[idx]
+         return {"dataset_id": dataset_id, **self.datasets[dataset_id][dataset_sample_id]}
+
+     def _build_indices(self) -> Tuple[numpy.ndarray, numpy.ndarray]:
+         """Build and optionally cache the dataset index and the dataset sample index
+
+         The dataset index is a 1-D mapping which determines the dataset to query. The dataset
+         sample index is a 1-D mapping which determines the sample to request from the queried
+         dataset.
+
+         Returns:
+             Tuple[numpy.ndarray, numpy.ndarray]: The dataset index and the dataset sample index
+         """
+
+         path_to_cache = self.config.path_to_cache
+
+         if path_to_cache:
+             get_path_to = lambda suffix: os.path.join(
+                 path_to_cache,
+                 f"{self.unique_description_hash}-{type(self).__name__}-{self.split.name}-{suffix}",
+             )
+             path_to_description = get_path_to("description.txt")
+             path_to_dataset_index = get_path_to("dataset_index.npy")
+             path_to_dataset_sample_index = get_path_to("dataset_sample_index.npy")
+             cache_hit = all(
+                 map(
+                     os.path.isfile,
+                     [path_to_description, path_to_dataset_index, path_to_dataset_sample_index],
+                 )
+             )
+         else:
+             cache_hit = False
+
+         if not path_to_cache or (not cache_hit and torch.distributed.get_rank() == 0):
+             log_single_rank(
+                 logger, logging.INFO, f"Build and save the {type(self).__name__} indices"
+             )
+
+             # Build the dataset and dataset sample indexes
+             log_single_rank(
+                 logger, logging.INFO, f"\tBuild and save the dataset and dataset sample indexes"
+             )
+             t_beg = time.time()
+             from megatron.core.datasets import helpers
+
+             if self.size is not None:
+                 dataset_index = numpy.zeros(self.size, dtype=numpy.int16)
+                 dataset_sample_index = numpy.zeros(self.size, dtype=numpy.int64)
+                 helpers.build_blending_indices(
+                     dataset_index,
+                     dataset_sample_index,
+                     self.weights,
+                     len(self.datasets),
+                     self.size,
+                     _VERBOSE,
+                 )
+             else:
+                 size = sum(self.weights)
+                 dataset_index = numpy.zeros(size, dtype=numpy.int16)
+                 dataset_sample_index = numpy.zeros(size, dtype=numpy.int64)
+                 helpers.build_exhaustive_blending_indices(
+                     dataset_index, dataset_sample_index, self.weights, len(self.datasets)
+                 )
+
+             dataset_indices, dataset_sizes = numpy.unique(dataset_index, return_counts=True)
+             for i, (_index, _size) in enumerate(zip(dataset_indices, dataset_sizes)):
+                 if len(self.datasets[_index]) < _size:
+                     raise IndexError(
+                         f"The {self.split.name} blend oversamples the contributing datasets and, "
+                         f"for example, requests {_size} samples from "
+                         f"{type(self.datasets[_index]).__name__} number {i} in excess of its size "
+                         f"{len(self.datasets[_index])}. The current value of the config attribute "
+                         f"mid_level_dataset_surplus may be increased, e.g. two- or ten-fold, from "
+                         f"its current value ({self.config.mid_level_dataset_surplus}) to ensure a "
+                         f"sufficient mid-level dataset sample margin from which to draw."
+                     )
+
+             if path_to_cache:
+                 os.makedirs(path_to_cache, exist_ok=True)
+                 # Write the description
+                 with open(path_to_description, "wt") as writer:
+                     writer.write(self.unique_description)
+                 # Save the indexes
+                 numpy.save(path_to_dataset_index, dataset_index, allow_pickle=True)
+                 numpy.save(path_to_dataset_sample_index, dataset_sample_index, allow_pickle=True)
+             else:
+                 log_single_rank(
+                     logger,
+                     logging.WARNING,
+                     f"Cannot save the {type(self).__name__} indexes because path_to_cache is None",
+                 )
+
+             t_end = time.time()
+             log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
+
+             return dataset_index, dataset_sample_index
+
+         log_single_rank(logger, logging.INFO, f"Load the {type(self).__name__} indices")
+
+         log_single_rank(
+             logger, logging.INFO, f"\tLoad the dataset index from {path_to_dataset_index}"
+         )
+         t_beg = time.time()
+         dataset_index = numpy.load(path_to_dataset_index, allow_pickle=True, mmap_mode="r")
+         t_end = time.time()
+         log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
+
+         log_single_rank(
+             logger,
+             logging.INFO,
+             f"\tLoad the dataset sample index from {path_to_dataset_sample_index}",
+         )
+         t_beg = time.time()
+         dataset_sample_index = numpy.load(
+             path_to_dataset_sample_index, allow_pickle=True, mmap_mode="r"
+         )
+         t_end = time.time()
+         log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
+
+         return dataset_index, dataset_sample_index
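Note: `helpers.build_blending_indices` is a compiled C++ helper (see `helpers.cpp` in the file list above). A simplified pure-Python sketch of the proportional blending it performs, illustrative only and not the exact implementation or tie-breaking:

```python
# For each output slot, pick the dataset whose realized share lags its target
# share the most, then record which sample of that dataset to serve next.
import numpy

def build_blending_indices_sketch(weights, num_datasets, size):
    weights = numpy.asarray(weights, dtype=numpy.float64)
    weights = weights / weights.sum()  # normalize, as BlendedDataset does
    dataset_index = numpy.zeros(size, dtype=numpy.int16)
    dataset_sample_index = numpy.zeros(size, dtype=numpy.int64)
    samples_per_dataset = numpy.zeros(num_datasets, dtype=numpy.int64)
    for i in range(size):
        errors = weights * (i + 1) - samples_per_dataset
        d = int(numpy.argmax(errors))
        dataset_index[i] = d
        dataset_sample_index[i] = samples_per_dataset[d]
        samples_per_dataset[d] += 1
    return dataset_index, dataset_sample_index

idx, sample_idx = build_blending_indices_sketch([0.7, 0.3], 2, 10)
print(idx)  # 7 slots from dataset 0, 3 from dataset 1, interleaved
```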