megatron-core 0.16.0rc0.dev127461__cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- megatron/core/README.md +51 -0
- megatron/core/__init__.py +52 -0
- megatron/core/activations.py +23 -0
- megatron/core/config.py +14 -0
- megatron/core/config_logger.py +126 -0
- megatron/core/datasets/__init__.py +0 -0
- megatron/core/datasets/bert_dataset.py +190 -0
- megatron/core/datasets/blended_dataset.py +212 -0
- megatron/core/datasets/blended_megatron_dataset_builder.py +552 -0
- megatron/core/datasets/blended_megatron_dataset_config.py +197 -0
- megatron/core/datasets/gpt_dataset.py +809 -0
- megatron/core/datasets/helpers.cpp +848 -0
- megatron/core/datasets/helpers.py +66 -0
- megatron/core/datasets/helpers_cpp.cpython-311-aarch64-linux-gnu.so +0 -0
- megatron/core/datasets/indexed_dataset.py +953 -0
- megatron/core/datasets/masked_dataset.py +423 -0
- megatron/core/datasets/megatron_dataset.py +185 -0
- megatron/core/datasets/megatron_tokenizer.py +162 -0
- megatron/core/datasets/multimodal_dataset.py +62 -0
- megatron/core/datasets/object_storage_utils.py +281 -0
- megatron/core/datasets/retro/__init__.py +5 -0
- megatron/core/datasets/retro/config/__init__.py +16 -0
- megatron/core/datasets/retro/config/bert_embedders.py +49 -0
- megatron/core/datasets/retro/config/config.py +135 -0
- megatron/core/datasets/retro/config/gpt_chunk_datasets.py +15 -0
- megatron/core/datasets/retro/config/tokenizers.py +15 -0
- megatron/core/datasets/retro/db/__init__.py +9 -0
- megatron/core/datasets/retro/db/build.py +649 -0
- megatron/core/datasets/retro/db/dataset.py +114 -0
- megatron/core/datasets/retro/db/utils.py +398 -0
- megatron/core/datasets/retro/external_libs.py +13 -0
- megatron/core/datasets/retro/index/__init__.py +11 -0
- megatron/core/datasets/retro/index/build.py +339 -0
- megatron/core/datasets/retro/index/factory.py +40 -0
- megatron/core/datasets/retro/index/index.py +150 -0
- megatron/core/datasets/retro/index/indexes/__init__.py +10 -0
- megatron/core/datasets/retro/index/indexes/faiss_base.py +179 -0
- megatron/core/datasets/retro/index/indexes/faiss_par_add.py +253 -0
- megatron/core/datasets/retro/index/utils.py +126 -0
- megatron/core/datasets/retro/index/validate.py +194 -0
- megatron/core/datasets/retro/query/__init__.py +1 -0
- megatron/core/datasets/retro/query/gpt_chunk_dataset.py +109 -0
- megatron/core/datasets/retro/query/multi_split_gpt_dataset.py +115 -0
- megatron/core/datasets/retro/query/query.py +449 -0
- megatron/core/datasets/retro/query/retro_dataset.py +251 -0
- megatron/core/datasets/retro/query/utils.py +35 -0
- megatron/core/datasets/retro/utils.py +386 -0
- megatron/core/datasets/t5_dataset.py +338 -0
- megatron/core/datasets/utils.py +92 -0
- megatron/core/datasets/utils_s3.py +5 -0
- megatron/core/dist_checkpointing/__init__.py +13 -0
- megatron/core/dist_checkpointing/core.py +93 -0
- megatron/core/dist_checkpointing/dict_utils.py +256 -0
- megatron/core/dist_checkpointing/exchange_utils.py +576 -0
- megatron/core/dist_checkpointing/mapping.py +738 -0
- megatron/core/dist_checkpointing/optimizer.py +148 -0
- megatron/core/dist_checkpointing/serialization.py +454 -0
- megatron/core/dist_checkpointing/state_dict_utils.py +112 -0
- megatron/core/dist_checkpointing/strategies/__init__.py +7 -0
- megatron/core/dist_checkpointing/strategies/async_utils.py +602 -0
- megatron/core/dist_checkpointing/strategies/base.py +224 -0
- megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py +38 -0
- megatron/core/dist_checkpointing/strategies/checkpointable.py +196 -0
- megatron/core/dist_checkpointing/strategies/common.py +193 -0
- megatron/core/dist_checkpointing/strategies/filesystem_async.py +645 -0
- megatron/core/dist_checkpointing/strategies/fully_parallel.py +520 -0
- megatron/core/dist_checkpointing/strategies/resharding.py +320 -0
- megatron/core/dist_checkpointing/strategies/state_dict_saver.py +258 -0
- megatron/core/dist_checkpointing/strategies/tensorstore.py +149 -0
- megatron/core/dist_checkpointing/strategies/torch.py +1123 -0
- megatron/core/dist_checkpointing/strategies/two_stage.py +268 -0
- megatron/core/dist_checkpointing/strategies/zarr.py +357 -0
- megatron/core/dist_checkpointing/tensor_aware_state_dict.py +394 -0
- megatron/core/dist_checkpointing/utils.py +332 -0
- megatron/core/dist_checkpointing/validation.py +585 -0
- megatron/core/distributed/__init__.py +13 -0
- megatron/core/distributed/data_parallel_base.py +96 -0
- megatron/core/distributed/distributed_data_parallel.py +584 -0
- megatron/core/distributed/distributed_data_parallel_config.py +155 -0
- megatron/core/distributed/finalize_model_grads.py +488 -0
- megatron/core/distributed/fsdp/__init__.py +3 -0
- megatron/core/distributed/fsdp/mcore_fsdp_adapter.py +431 -0
- megatron/core/distributed/fsdp/src/__init__.py +13 -0
- megatron/core/distributed/fsdp/src/megatron_fsdp/__init__.py +51 -0
- megatron/core/distributed/fsdp/src/megatron_fsdp/distributed_data_parallel_config.py +146 -0
- megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py +540 -0
- megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py +1223 -0
- megatron/core/distributed/fsdp/src/megatron_fsdp/package_info.py +27 -0
- megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py +3812 -0
- megatron/core/distributed/fsdp/src/megatron_fsdp/uneven_dtensor.py +460 -0
- megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py +992 -0
- megatron/core/distributed/param_and_grad_buffer.py +1006 -0
- megatron/core/distributed/reduce_scatter_with_fp32_accumulation.py +92 -0
- megatron/core/distributed/torch_fully_sharded_data_parallel.py +154 -0
- megatron/core/distributed/torch_fully_sharded_data_parallel_config.py +19 -0
- megatron/core/energy_monitor.py +91 -0
- megatron/core/enums.py +36 -0
- megatron/core/export/__init__.py +1 -0
- megatron/core/export/data_type.py +5 -0
- megatron/core/export/export_config.py +32 -0
- megatron/core/export/model_type.py +8 -0
- megatron/core/export/trtllm/__init__.py +1 -0
- megatron/core/export/trtllm/engine_builder/__init__.py +1 -0
- megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py +172 -0
- megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py +1 -0
- megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +50 -0
- megatron/core/export/trtllm/trt_model_config.py +25 -0
- megatron/core/export/trtllm/trt_model_type.py +14 -0
- megatron/core/export/trtllm/trtllm_helper.py +614 -0
- megatron/core/export/trtllm/trtllm_layers.py +169 -0
- megatron/core/export/trtllm/trtllm_weights_converter/__init__.py +1 -0
- megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py +293 -0
- megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +512 -0
- megatron/core/export/trtllm/trtllm_weights_converter/utils.py +8 -0
- megatron/core/extensions/__init__.py +0 -0
- megatron/core/extensions/kitchen.py +1092 -0
- megatron/core/extensions/transformer_engine.py +2118 -0
- megatron/core/extensions/transformer_engine_spec_provider.py +95 -0
- megatron/core/fp4_utils.py +139 -0
- megatron/core/fp8_utils.py +750 -0
- megatron/core/full_cuda_graph.py +198 -0
- megatron/core/fusions/__init__.py +0 -0
- megatron/core/fusions/fused_bias_dropout.py +92 -0
- megatron/core/fusions/fused_bias_geglu.py +442 -0
- megatron/core/fusions/fused_bias_gelu.py +55 -0
- megatron/core/fusions/fused_bias_swiglu.py +255 -0
- megatron/core/fusions/fused_cross_entropy.py +148 -0
- megatron/core/fusions/fused_indices_converter.py +288 -0
- megatron/core/fusions/fused_layer_norm.py +169 -0
- megatron/core/fusions/fused_mla_yarn_rope_apply.py +784 -0
- megatron/core/fusions/fused_pad_routing_map.py +98 -0
- megatron/core/fusions/fused_softmax.py +359 -0
- megatron/core/fusions/fused_weighted_squared_relu.py +110 -0
- megatron/core/hyper_comm_grid.py +239 -0
- megatron/core/inference/__init__.py +1 -0
- megatron/core/inference/async_stream.py +73 -0
- megatron/core/inference/common_inference_params.py +4 -0
- megatron/core/inference/communication_utils.py +211 -0
- megatron/core/inference/contexts/__init__.py +23 -0
- megatron/core/inference/contexts/attention_context/mamba_metadata.py +106 -0
- megatron/core/inference/contexts/attention_context/metadata_base.py +72 -0
- megatron/core/inference/contexts/attention_context/mha_metadata.py +220 -0
- megatron/core/inference/contexts/base_context.py +43 -0
- megatron/core/inference/contexts/dynamic_block_allocator.py +118 -0
- megatron/core/inference/contexts/dynamic_context.py +1804 -0
- megatron/core/inference/contexts/fused_kv_append_kernel.py +174 -0
- megatron/core/inference/contexts/static_context.py +130 -0
- megatron/core/inference/data_parallel_inference_coordinator.py +255 -0
- megatron/core/inference/engines/__init__.py +5 -0
- megatron/core/inference/engines/abstract_engine.py +17 -0
- megatron/core/inference/engines/dynamic_engine.py +1102 -0
- megatron/core/inference/engines/mcore_engine.py +5 -0
- megatron/core/inference/engines/static_engine.py +388 -0
- megatron/core/inference/headers.py +17 -0
- megatron/core/inference/inference_client.py +193 -0
- megatron/core/inference/inference_request.py +357 -0
- megatron/core/inference/model_inference_wrappers/__init__.py +1 -0
- megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +389 -0
- megatron/core/inference/model_inference_wrappers/gpt/__init__.py +1 -0
- megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +131 -0
- megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +66 -0
- megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py +216 -0
- megatron/core/inference/model_inference_wrappers/t5/__init__.py +1 -0
- megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py +230 -0
- megatron/core/inference/sampling_params.py +56 -0
- megatron/core/inference/scheduler.py +193 -0
- megatron/core/inference/text_generation_controllers/__init__.py +1 -0
- megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +51 -0
- megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +5 -0
- megatron/core/inference/text_generation_controllers/text_generation_controller.py +1474 -0
- megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py +53 -0
- megatron/core/inference/text_generation_server/__init__.py +3 -0
- megatron/core/inference/text_generation_server/endpoints/common.py +14 -0
- megatron/core/inference/text_generation_server/endpoints/completions.py +212 -0
- megatron/core/inference/text_generation_server/run_mcore_engine.py +111 -0
- megatron/core/inference/text_generation_server/text_generation_server.py +211 -0
- megatron/core/inference/text_generation_server/tokenization.py +110 -0
- megatron/core/inference/unified_memory.py +127 -0
- megatron/core/inference/utils.py +163 -0
- megatron/core/inference_params.py +5 -0
- megatron/core/jit.py +18 -0
- megatron/core/model_parallel_config.py +404 -0
- megatron/core/models/T5/__init__.py +2 -0
- megatron/core/models/T5/t5_model.py +536 -0
- megatron/core/models/T5/t5_spec.py +251 -0
- megatron/core/models/__init__.py +1 -0
- megatron/core/models/backends.py +182 -0
- megatron/core/models/bert/__init__.py +0 -0
- megatron/core/models/bert/bert_layer_specs.py +118 -0
- megatron/core/models/bert/bert_lm_head.py +50 -0
- megatron/core/models/bert/bert_model.py +386 -0
- megatron/core/models/bert/pooler.py +52 -0
- megatron/core/models/common/__init__.py +0 -0
- megatron/core/models/common/embeddings/__init__.py +5 -0
- megatron/core/models/common/embeddings/language_model_embedding.py +150 -0
- megatron/core/models/common/embeddings/relative_pos_embedding.py +180 -0
- megatron/core/models/common/embeddings/rope_utils.py +345 -0
- megatron/core/models/common/embeddings/rotary_pos_embedding.py +325 -0
- megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +249 -0
- megatron/core/models/common/language_module/__init__.py +0 -0
- megatron/core/models/common/language_module/language_module.py +344 -0
- megatron/core/models/common/model_chunk_schedule_plan.py +508 -0
- megatron/core/models/common/vision_module/__init__.py +0 -0
- megatron/core/models/common/vision_module/vision_module.py +17 -0
- megatron/core/models/gpt/__init__.py +2 -0
- megatron/core/models/gpt/fine_grained_callables.py +585 -0
- megatron/core/models/gpt/gpt_layer_specs.py +673 -0
- megatron/core/models/gpt/gpt_model.py +765 -0
- megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py +220 -0
- megatron/core/models/gpt/moe_module_specs.py +74 -0
- megatron/core/models/huggingface/__init__.py +2 -0
- megatron/core/models/huggingface/clip_model.py +42 -0
- megatron/core/models/huggingface/module.py +97 -0
- megatron/core/models/huggingface/qwen_model.py +59 -0
- megatron/core/models/mamba/__init__.py +2 -0
- megatron/core/models/mamba/mamba_layer_specs.py +68 -0
- megatron/core/models/mamba/mamba_model.py +289 -0
- megatron/core/models/mimo/__init__.py +16 -0
- megatron/core/models/mimo/config/__init__.py +5 -0
- megatron/core/models/mimo/config/base_configs.py +34 -0
- megatron/core/models/mimo/model/__init__.py +4 -0
- megatron/core/models/mimo/model/base.py +290 -0
- megatron/core/models/mimo/submodules/audio.py +155 -0
- megatron/core/models/mimo/submodules/base.py +193 -0
- megatron/core/models/mimo/submodules/vision.py +184 -0
- megatron/core/models/multimodal/__init__.py +1 -0
- megatron/core/models/multimodal/context_parallel.py +111 -0
- megatron/core/models/multimodal/llava_model.py +1028 -0
- megatron/core/models/multimodal/llava_spec.py +90 -0
- megatron/core/models/retro/__init__.py +13 -0
- megatron/core/models/retro/base_attention.py +47 -0
- megatron/core/models/retro/config.py +88 -0
- megatron/core/models/retro/decoder_attention.py +319 -0
- megatron/core/models/retro/decoder_spec.py +195 -0
- megatron/core/models/retro/encoder_attention.py +231 -0
- megatron/core/models/retro/encoder_spec.py +171 -0
- megatron/core/models/retro/model.py +107 -0
- megatron/core/models/retro/utils.py +24 -0
- megatron/core/models/vision/__init__.py +0 -0
- megatron/core/models/vision/clip_vit_model.py +261 -0
- megatron/core/models/vision/multimodal_projector.py +88 -0
- megatron/core/models/vision/radio.py +380 -0
- megatron/core/models/vision/vit_layer_specs.py +96 -0
- megatron/core/msc_utils.py +69 -0
- megatron/core/nccl_allocator.py +316 -0
- megatron/core/num_microbatches_calculator.py +508 -0
- megatron/core/optimizer/__init__.py +635 -0
- megatron/core/optimizer/clip_grads.py +247 -0
- megatron/core/optimizer/cpu_offloading/__init__.py +2 -0
- megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +472 -0
- megatron/core/optimizer/distrib_optimizer.py +2602 -0
- megatron/core/optimizer/grad_scaler.py +142 -0
- megatron/core/optimizer/optimizer.py +1418 -0
- megatron/core/optimizer/optimizer_config.py +308 -0
- megatron/core/optimizer_param_scheduler.py +311 -0
- megatron/core/package_info.py +27 -0
- megatron/core/packed_seq_params.py +20 -0
- megatron/core/parallel_state.py +2097 -0
- megatron/core/pipeline_parallel/__init__.py +2 -0
- megatron/core/pipeline_parallel/bridge_communicator.py +922 -0
- megatron/core/pipeline_parallel/combined_1f1b.py +444 -0
- megatron/core/pipeline_parallel/p2p_communication.py +645 -0
- megatron/core/pipeline_parallel/schedules.py +2303 -0
- megatron/core/pipeline_parallel/utils.py +307 -0
- megatron/core/post_training/__init__.py +1 -0
- megatron/core/post_training/modelopt/__init__.py +10 -0
- megatron/core/post_training/modelopt/gpt/__init__.py +1 -0
- megatron/core/post_training/modelopt/gpt/model_specs.py +206 -0
- megatron/core/post_training/modelopt/gpt/state_dict_hooks.py +64 -0
- megatron/core/post_training/modelopt/layers.py +249 -0
- megatron/core/post_training/modelopt/mamba/__init__.py +1 -0
- megatron/core/post_training/modelopt/mamba/model_specs.py +91 -0
- megatron/core/process_groups_config.py +571 -0
- megatron/core/quantization/__init__.py +1 -0
- megatron/core/quantization/quant_config.py +219 -0
- megatron/core/quantization/utils.py +37 -0
- megatron/core/requirements.txt +2 -0
- megatron/core/rerun_state_machine.py +1345 -0
- megatron/core/safe_globals.py +39 -0
- megatron/core/ssm/__init__.py +1 -0
- megatron/core/ssm/mamba_block.py +414 -0
- megatron/core/ssm/mamba_context_parallel.py +389 -0
- megatron/core/ssm/mamba_hybrid_layer_allocation.py +218 -0
- megatron/core/ssm/mamba_layer.py +184 -0
- megatron/core/ssm/mamba_mixer.py +1171 -0
- megatron/core/ssm/mlp_layer.py +30 -0
- megatron/core/ssm/triton_cache_manager.py +81 -0
- megatron/core/tensor_parallel/__init__.py +74 -0
- megatron/core/tensor_parallel/cross_entropy.py +232 -0
- megatron/core/tensor_parallel/data.py +101 -0
- megatron/core/tensor_parallel/inference_layers.py +151 -0
- megatron/core/tensor_parallel/layers.py +1303 -0
- megatron/core/tensor_parallel/mappings.py +596 -0
- megatron/core/tensor_parallel/random.py +615 -0
- megatron/core/tensor_parallel/utils.py +121 -0
- megatron/core/timers.py +465 -0
- megatron/core/tokenizers/__init__.py +4 -0
- megatron/core/tokenizers/base_tokenizer.py +48 -0
- megatron/core/tokenizers/megatron_tokenizer.py +171 -0
- megatron/core/tokenizers/text/__init__.py +3 -0
- megatron/core/tokenizers/text/libraries/__init__.py +8 -0
- megatron/core/tokenizers/text/libraries/abstract_tokenizer.py +147 -0
- megatron/core/tokenizers/text/libraries/bytelevel_tokenizer.py +164 -0
- megatron/core/tokenizers/text/libraries/chat_template.py +71 -0
- megatron/core/tokenizers/text/libraries/huggingface_tokenizer.py +335 -0
- megatron/core/tokenizers/text/libraries/megatron_hf_tokenizer.py +179 -0
- megatron/core/tokenizers/text/libraries/null_tokenizer.py +79 -0
- megatron/core/tokenizers/text/libraries/sentencepiece_tokenizer.py +411 -0
- megatron/core/tokenizers/text/libraries/tiktoken_tokenizer.py +303 -0
- megatron/core/tokenizers/text/models/__init__.py +8 -0
- megatron/core/tokenizers/text/models/bert_tokenizer.py +12 -0
- megatron/core/tokenizers/text/models/default_tokenizer.py +12 -0
- megatron/core/tokenizers/text/models/gpt_tokenizer.py +12 -0
- megatron/core/tokenizers/text/models/mamba_tokenizer.py +12 -0
- megatron/core/tokenizers/text/models/retro_tokenizer.py +12 -0
- megatron/core/tokenizers/text/models/t5_tokenizer.py +12 -0
- megatron/core/tokenizers/text/text_tokenizer.py +254 -0
- megatron/core/tokenizers/text/utils/build_tokenizer.py +58 -0
- megatron/core/transformer/__init__.py +6 -0
- megatron/core/transformer/attention.py +1238 -0
- megatron/core/transformer/cuda_graphs.py +1676 -0
- megatron/core/transformer/custom_layers/__init__.py +0 -0
- megatron/core/transformer/custom_layers/transformer_engine.py +12 -0
- megatron/core/transformer/dot_product_attention.py +258 -0
- megatron/core/transformer/enums.py +67 -0
- megatron/core/transformer/fsdp_dtensor_checkpoint.py +455 -0
- megatron/core/transformer/heterogeneous/heterogeneous_config.py +267 -0
- megatron/core/transformer/heterogeneous/linear_replacements.py +115 -0
- megatron/core/transformer/identity_op.py +28 -0
- megatron/core/transformer/mlp.py +403 -0
- megatron/core/transformer/module.py +453 -0
- megatron/core/transformer/moe/__init__.py +0 -0
- megatron/core/transformer/moe/experts.py +1166 -0
- megatron/core/transformer/moe/fused_a2a.py +264 -0
- megatron/core/transformer/moe/grouped_gemm_util.py +22 -0
- megatron/core/transformer/moe/moe_layer.py +309 -0
- megatron/core/transformer/moe/moe_utils.py +1030 -0
- megatron/core/transformer/moe/router.py +572 -0
- megatron/core/transformer/moe/shared_experts.py +286 -0
- megatron/core/transformer/moe/token_dispatcher.py +1327 -0
- megatron/core/transformer/moe/upcycling_utils.py +359 -0
- megatron/core/transformer/multi_latent_attention.py +919 -0
- megatron/core/transformer/multi_token_prediction.py +955 -0
- megatron/core/transformer/pipeline_parallel_layer_layout.py +308 -0
- megatron/core/transformer/spec_utils.py +106 -0
- megatron/core/transformer/torch_layer_norm.py +4 -0
- megatron/core/transformer/torch_norm.py +96 -0
- megatron/core/transformer/transformer_block.py +815 -0
- megatron/core/transformer/transformer_config.py +1647 -0
- megatron/core/transformer/transformer_layer.py +852 -0
- megatron/core/transformer/utils.py +419 -0
- megatron/core/utils.py +2154 -0
- megatron_core-0.16.0rc0.dev127461.dist-info/METADATA +579 -0
- megatron_core-0.16.0rc0.dev127461.dist-info/RECORD +356 -0
- megatron_core-0.16.0rc0.dev127461.dist-info/WHEEL +6 -0
- megatron_core-0.16.0rc0.dev127461.dist-info/top_level.txt +1 -0
megatron/core/README.md
ADDED
@@ -0,0 +1,51 @@
<div align="center">

Megatron Core
=============
<h4>Production-ready library for building custom training frameworks</h4>

<div align="left">

## ⚡ Quick Start

```bash
# Install Megatron Core with required dependencies
pip install --no-build-isolation megatron-core[dev]

# Distributed training example (2 GPUs, mock data)
torchrun --nproc_per_node=2 examples/run_simple_mcore_train_loop.py
```

# What is Megatron Core?

**Megatron Core** is an open-source, PyTorch-based library of GPU-optimized techniques and cutting-edge system-level optimizations. It abstracts them into composable, modular APIs, giving developers and model researchers full flexibility to train custom transformers at scale on NVIDIA accelerated computing infrastructure.

## 🚀 Key Components

### GPU-Optimized Building Blocks
- **Transformer Components**: Attention mechanisms, MLP layers, embeddings
- **Memory Management**: Activation recomputation
- **FP8 Precision**: Optimized for NVIDIA Hopper, Ada, and Blackwell GPUs

### Parallelism Strategies
- **Tensor Parallelism (TP)**: Layer-wise parallelization (activation memory footprint can be further reduced using sequence parallelism)
- **Pipeline Parallelism (PP)**: Depth-wise model splitting and pipelining of microbatches to improve efficiency
- **Context Parallelism (CP)**: Long sequence handling ([documentation](https://docs.nvidia.com/megatron-core/developer-guide/latest/api-guide/context_parallel.html))
- **Expert Parallelism (EP)**: Splitting the experts of an MoE model across multiple GPUs

## 🔗 Examples & Documentation

**Examples:**
- **[Simple Training Loop](https://github.com/NVIDIA/Megatron-LM/blob/main/examples/run_simple_mcore_train_loop.py)** - Basic usage
- **[Multimodal Training](https://github.com/NVIDIA/Megatron-LM/blob/main/examples/multimodal/)** - Vision-language models
- **[Mixture-of-Experts](https://github.com/yanring/Megatron-MoE-ModelZoo)** - MoE examples
- **[Mamba Models](https://github.com/NVIDIA/Megatron-LM/blob/main/examples/mamba/)** - State-space models

**Documentation:**
- **[📚 API Guide](https://docs.nvidia.com/megatron-core/developer-guide/latest/api-guide/index.html)** - Complete API documentation
- **[💡 Developer Guide](https://docs.nvidia.com/megatron-core/developer-guide/latest/index.html)** - Custom framework development

---

*For complete installation instructions, performance benchmarks, and ecosystem information, see the [main README](../README.md).*
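The four parallelism strategies listed in the README all funnel into a single `parallel_state` initialization call. As a hedged sketch of how that looks in user code (the sizes are illustrative, and a `torch.distributed` process group must already be initialized, e.g. via `torchrun`):

```python
import torch

from megatron.core import parallel_state

# Assumes torch.distributed.init_process_group(...) has already run (torchrun does this).
parallel_state.initialize_model_parallel(
    tensor_model_parallel_size=2,    # TP: split individual layers across GPUs
    pipeline_model_parallel_size=1,  # PP: split the model depth-wise into stages
    context_parallel_size=1,         # CP: split long sequences across GPUs
    expert_model_parallel_size=1,    # EP: split MoE experts across GPUs
)
print(parallel_state.get_tensor_model_parallel_rank())
```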
megatron/core/__init__.py
ADDED
@@ -0,0 +1,52 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import megatron.core.tensor_parallel
import megatron.core.utils
from megatron.core import parallel_state
from megatron.core.distributed import DistributedDataParallel
from megatron.core.inference_params import InferenceParams
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.package_info import (
    __contact_emails__,
    __contact_names__,
    __description__,
    __download_url__,
    __homepage__,
    __keywords__,
    __license__,
    __package_name__,
    __repository_url__,
    __shortversion__,
    __version__,
)
from megatron.core.timers import Timers
from megatron.core.utils import is_torch_min_version

# Alias parallel_state as mpu, its legacy name
mpu = parallel_state

__all__ = [
    "parallel_state",
    "tensor_parallel",
    "utils",
    "DistributedDataParallel",
    "InferenceParams",
    "ModelParallelConfig",
    "Timers",
    "__contact_emails__",
    "__contact_names__",
    "__description__",
    "__download_url__",
    "__homepage__",
    "__keywords__",
    "__license__",
    "__package_name__",
    "__repository_url__",
    "__shortversion__",
    "__version__",
]

from .safe_globals import register_safe_globals

if is_torch_min_version("2.6a0"):
    register_safe_globals()
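Because `__init__.py` re-exports these names at the package top level, downstream code can use the short import paths directly. A minimal sketch (the constructor arguments are illustrative and assume recent release signatures):

```python
from megatron.core import ModelParallelConfig, Timers, mpu, parallel_state

# mpu is just the legacy alias defined above; both names refer to the same module.
assert mpu is parallel_state

config = ModelParallelConfig(tensor_model_parallel_size=1)
timers = Timers(log_level=2, log_option="minmax")
```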
megatron/core/activations.py
ADDED
@@ -0,0 +1,23 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import torch
import torch.nn.functional as F

from megatron.core.jit import jit_fuser


@jit_fuser
def squared_relu(x: torch.Tensor) -> torch.Tensor:
    """Squared ReLU activation"""
    return torch.pow(F.relu(x), 2)


@jit_fuser
def quick_gelu(x: torch.Tensor) -> torch.Tensor:
    """Quick GELU activation"""
    return x * torch.sigmoid(1.702 * x)


@jit_fuser
def fast_gelu(x: torch.Tensor) -> torch.Tensor:
    """Fast GELU activation"""
    return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
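All three activations are plain `Tensor -> Tensor` callables (the `@jit_fuser` decorator only affects compilation), so they can be sanity-checked directly. A quick numerical sketch, not part of the package, comparing the two GELU approximations against the exact GELU:

```python
import torch
import torch.nn.functional as F

from megatron.core.activations import fast_gelu, quick_gelu, squared_relu

x = torch.linspace(-3.0, 3.0, steps=101)
exact = F.gelu(x)

# fast_gelu is the tanh approximation, quick_gelu the sigmoid approximation.
print(torch.max(torch.abs(fast_gelu(x) - exact)).item())   # tiny (~1e-3 or less)
print(torch.max(torch.abs(quick_gelu(x) - exact)).item())  # small (~1e-2)
print(torch.allclose(squared_relu(x), torch.clamp(x, min=0) ** 2))  # True
```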
megatron/core/config.py
ADDED
@@ -0,0 +1,14 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

ENABLE_EXPERIMENTAL = False


def set_experimental_flag(flag: bool):
    """Set the experimental flag to the given value."""
    global ENABLE_EXPERIMENTAL
    ENABLE_EXPERIMENTAL = flag


def is_experimental_enabled():
    """Return the experimental flag."""
    return ENABLE_EXPERIMENTAL
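Since the flag is module-level mutable state, callers should read it through `is_experimental_enabled()` rather than importing `ENABLE_EXPERIMENTAL` directly, which would snapshot the value at import time. A minimal sketch:

```python
from megatron.core import config

assert not config.is_experimental_enabled()  # defaults to False

config.set_experimental_flag(True)
assert config.is_experimental_enabled()

# A direct `from megatron.core.config import ENABLE_EXPERIMENTAL` would bind the
# value once at import time and miss later set_experimental_flag() calls.
config.set_experimental_flag(False)
```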
megatron/core/config_logger.py
ADDED
@@ -0,0 +1,126 @@
# Copyright (c) 2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import json
import os

import torch
import torch.nn as nn

from megatron.core import parallel_state


def get_config_logger_path(config):
    """Get the path to the config logger directory."""
    return getattr(config, 'config_logger_dir', '')


def has_config_logger_enabled(config):
    """Check if config logger is enabled."""
    return get_config_logger_path(config) != ''


# For each prefix, holds a counter and increases it every time we dump with this
# prefix.
__config_logger_path_counts = {}


def get_path_count(path):
    """
    Keeps track of the number of times we've seen the input `path` and returns count - 1.
    """
    global __config_logger_path_counts
    if path not in __config_logger_path_counts:
        __config_logger_path_counts[path] = 0
    count = __config_logger_path_counts[path]
    __config_logger_path_counts[path] += 1
    return count


def get_path_with_count(path):
    """
    Calls get_path_count and appends the returned value to path.
    """
    return f'{path}.iter{get_path_count(path)}'


class JSONEncoderWithMcoreTypes(json.JSONEncoder):
    """
    Custom JSON encoder that serializes according to types in mcore.
    """

    def default(self, o):
        if type(o).__name__ in ['function', 'ProcessGroup']:
            return str(o)
        if type(o).__name__ in ['dict', 'OrderedDict']:
            return {k: self.default(v) for k, v in o.items()}
        if type(o).__name__ in ['list', 'ModuleList']:
            return [self.default(val) for val in o]
        if type(o).__name__ == 'UniqueDescriptor':
            return {
                attr: self.default(getattr(o, attr))
                for attr in filter(lambda x: not x.startswith('__'), dir(o))
            }
        if type(o) is torch.dtype:
            return str(o)
        # If it's a Float16Module, add "Float16Module" to the output dict
        if type(o).__name__ == 'Float16Module':
            return {'Float16Module': {'module': self.default(o.module)}}
        # If it's an nn.Module subchild, either print its children or itself if leaf.
        if issubclass(type(o), nn.Module):
            if len(getattr(o, '_modules', {})) > 0:
                return {key: self.default(val) for key, val in o._modules.items()}
            else:
                return str(o)
        if type(o).__name__ in ['ABCMeta', 'type', 'AttnMaskType']:
            return str(o)
        if dataclasses.is_dataclass(o) or type(o).__name__ in ['ModuleSpec', 'TransformerConfig']:
            return dataclasses.asdict(o)
        try:
            return super().default(o)
        except Exception:
            return str(o)


def log_config_to_disk(config, dict_data, prefix='', rank_str=''):
    """
    Encodes the input dict (dict_data) using the JSONEncoderWithMcoreTypes
    and dumps it to disk, at the path specified via the config.
    """
    path = get_config_logger_path(config)
    assert path is not None, 'Expected config_logger_dir to be non-empty in config.'

    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)

    if 'self' in dict_data:
        if prefix == '':
            prefix = type(dict_data['self']).__name__
        del dict_data['self']

    # The caller of the function can decide the most informative string;
    # rank_str defaults to '0_0_0_0_0' format (tp_dp_cp_pp_ep ranks).
    if rank_str == '':
        rank_str = parallel_state.get_all_ranks()

    path = get_path_with_count(os.path.join(path, f'{prefix}.rank_{rank_str}'))
    if type(dict_data).__name__ == 'OrderedDict':
        torch.save(dict_data, f'{path}.pth')
    else:
        with open(f'{path}.json', 'w') as fp:
            json.dump(dict_data, fp, cls=JSONEncoderWithMcoreTypes)


__all__ = ['has_config_logger_enabled', 'log_config_to_disk']
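A hedged sketch of the calling pattern (the `DemoConfig` dataclass is hypothetical; any object with a `config_logger_dir` attribute works, and passing `rank_str` explicitly avoids needing initialized parallel state):

```python
from dataclasses import dataclass

from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk


@dataclass
class DemoConfig:  # hypothetical stand-in; only config_logger_dir is consulted
    config_logger_dir: str = "/tmp/mcore_config_logs"


config = DemoConfig()
if has_config_logger_enabled(config):
    log_config_to_disk(config, {"lr": 1e-4, "hidden_size": 1024}, prefix="demo", rank_str="0")
    # -> writes /tmp/mcore_config_logs/demo.rank_0.iter0.json; a second call with the
    #    same prefix/rank would write ...iter1.json via the per-path counter.
```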
megatron/core/datasets/__init__.py
ADDED
File without changes (empty file)
megatron/core/datasets/bert_dataset.py
ADDED
@@ -0,0 +1,190 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from dataclasses import dataclass
from typing import Dict, List, Optional, Union

import numpy

from megatron.core.datasets.indexed_dataset import IndexedDataset
from megatron.core.datasets.masked_dataset import (
    MaskedWordPieceDataset,
    MaskedWordPieceDatasetConfig,
)
from megatron.core.datasets.utils import Split


@dataclass
class BERTMaskedWordPieceDatasetConfig(MaskedWordPieceDatasetConfig):
    """Configuration object for Megatron Core BERT WordPiece datasets"""

    classification_head: bool = None
    """Option to perform the next sequence prediction during sampling"""

    def __post_init__(self) -> None:
        """Do asserts and set fields post init"""
        super().__post_init__()

        assert self.classification_head is not None


class BERTMaskedWordPieceDataset(MaskedWordPieceDataset):
    """The BERT dataset that assumes WordPiece tokenization

    Args:
        indexed_dataset (IndexedDataset): The IndexedDataset around which
            to build the MegatronDataset
        dataset_path (str): The real path on disk to the dataset, for bookkeeping
        indexed_indices (numpy.ndarray): The set of document indices to expose
        num_samples (Optional[int]): The number of samples to draw from the indexed dataset.
            When None, build as many samples as correspond to one epoch.
        index_split (Split): The indexed_indices Split
        config (BERTMaskedWordPieceDatasetConfig): The config
    """

    def __init__(
        self,
        indexed_dataset: IndexedDataset,
        dataset_path: str,
        indexed_indices: numpy.ndarray,
        num_samples: Optional[int],
        index_split: Split,
        config: BERTMaskedWordPieceDatasetConfig,
    ) -> None:
        super().__init__(
            indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config
        )

        self.token_lookup = list(self.config.tokenizer.inv_vocab.keys())
        # Account for the single <cls> and two <sep> token ids
        self.sample_index = self._build_sample_index(
            self.config.sequence_length - 3, 2 if self.config.classification_head else 1
        )

    @staticmethod
    def _key_config_attributes() -> List[str]:
        """Inherited method implementation

        Returns:
            List[str]: The key config attributes
        """
        return super(
            BERTMaskedWordPieceDataset, BERTMaskedWordPieceDataset
        )._key_config_attributes() + ["classification_head"]

    def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]:
        """Abstract method implementation

        Args:
            idx (int): The index into the dataset

        Returns:
            Dict[str, Union[int, numpy.ndarray]]: The sample
        """

        idx_beg, idx_end, target_sequence_length = self.sample_index[idx]
        sample = [self.dataset[i] for i in range(idx_beg, idx_end)]
        numpy_random_state = numpy.random.RandomState(seed=(self.config.random_seed + idx) % 2**32)

        assert target_sequence_length <= self.config.sequence_length

        # Split the sample into contiguous subsegments A and B
        pivot = len(sample)
        is_next_random = False
        if self.config.classification_head:
            assert len(sample) > 1, "the sample must contain at least two sentences"
            pivot = 1
            if len(sample) >= 3:
                pivot = numpy_random_state.randint(low=1, high=len(sample))
            is_next_random = numpy_random_state.random() < 0.5
        split_A = []
        for sample_a in sample[:pivot]:
            split_A.extend(sample_a)
        split_B = []
        for sample_b in sample[pivot:]:
            split_B.extend(sample_b)
        if is_next_random:
            split_A, split_B = split_B, split_A

        # Trim the subsegments from either end to a desired joint length
        length_A = len(split_A)
        length_B = len(split_B)
        if length_A + length_B <= target_sequence_length:
            truncated = False
        else:
            while length_A + length_B > target_sequence_length:
                split = split_A if length_A > length_B else split_B
                if numpy_random_state.random() < 0.5:
                    del split[0]
                else:
                    del split[-1]
                length_A = len(split_A)
                length_B = len(split_B)
            truncated = True

        # Merge the subsegments and create the token assignment labels
        tokens = [self.config.tokenizer.cls, *split_A, self.config.tokenizer.sep]
        assignments = [0 for _ in range(1 + len(split_A) + 1)]
        if split_B:
            tokens += [*split_B, self.config.tokenizer.sep]
            assignments += [1 for _ in range(len(split_B) + 1)]

        # Masking
        tokens, masked_positions, masked_labels, _, _ = self._create_masked_lm_predictions(
            tokens, target_sequence_length, numpy_random_state
        )

        # Pad the sequences and convert to NumPy
        length_toks = len(tokens)
        length_pads = self.config.sequence_length - length_toks
        assert length_pads >= 0

        tokens = numpy.array(tokens, dtype=numpy.int64)
        tokens = numpy.pad(tokens, (0, length_pads), constant_values=self._pad_token_id)

        assignments = numpy.array(assignments, dtype=numpy.int64)
        assignments = numpy.pad(assignments, (0, length_pads), constant_values=self._pad_token_id)

        # Get the padding mask
        mask_pads = numpy.ones(self.config.sequence_length, dtype=numpy.int64)
        mask_pads[tokens == self._pad_token_id] = self._pad_token_id

        # Mask the labels
        labels = numpy.zeros(self.config.sequence_length, dtype=numpy.int64) - 1
        labels[masked_positions] = masked_labels

        # Get the loss mask
        mask_loss = numpy.zeros(self.config.sequence_length, dtype=numpy.int64)
        mask_loss[masked_positions] = 1

        # For padded sequences, ensure the embedding layer can map the token ID
        tokens[tokens == self._pad_token_id] = 0
        labels[labels == self._pad_token_id] = 0

        return {
            "text": tokens,
            "types": assignments,
            "labels": labels,
            "is_random": int(is_next_random),
            "padding_mask": mask_pads,
            "loss_mask": mask_loss,
            "truncated": int(truncated),
        }

    def _get_token_mask(self, numpy_random_state: numpy.random.RandomState) -> Optional[int]:
        """Abstract method implementation

        80% of the time, replace the token id with the mask token id. 10% of the time, replace
        the token id with a random token id from the vocabulary. 10% of the time, do nothing.

        Args:
            numpy_random_state (RandomState): The NumPy random state

        Returns:
            Optional[int]: The replacement token id or None
        """
        if numpy_random_state.random() < 0.8:
            return self.config.tokenizer.mask
        else:
            if numpy_random_state.random() >= 0.5:
                return self.token_lookup[numpy_random_state.randint(0, len(self.token_lookup))]
        return None
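`_get_token_mask` implements BERT's 80/10/10 masking rule with two sequential draws: 80% mask token, then a coin flip on the remaining 20% between a random vocabulary token and leaving the token unchanged (`None`). A standalone sketch, independent of the class above, that checks the resulting proportions empirically (the token ids are made up):

```python
import numpy


def token_mask_choice(rng, mask_id=103, vocab_size=1000):
    # Mirrors _get_token_mask: 80% mask id, 10% random id, 10% None (keep original).
    if rng.random() < 0.8:
        return mask_id
    if rng.random() >= 0.5:
        return int(rng.randint(0, vocab_size))
    return None


rng = numpy.random.RandomState(0)
draws = [token_mask_choice(rng) for _ in range(100_000)]
print(sum(d == 103 for d in draws) / len(draws))   # ~0.80 (plus rare random hits on 103)
print(sum(d is None for d in draws) / len(draws))  # ~0.10
```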
megatron/core/datasets/blended_dataset.py
ADDED
@@ -0,0 +1,212 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import hashlib
import json
import logging
import os
import time
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union

import numpy
import torch

from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
from megatron.core.datasets.megatron_dataset import MegatronDataset
from megatron.core.datasets.utils import normalize
from megatron.core.utils import log_single_rank

logger = logging.getLogger(__name__)

_VERBOSE = False


class BlendedDataset(torch.utils.data.Dataset):
    """Conjugating class for a set of MegatronDataset instances

    Args:
        datasets (List[MegatronDataset]): The MegatronDataset instances to blend

        weights (List[Union[int, float]]): The weights that determine the dataset blend ratios

        size (Optional[int]): The number of samples to draw from the blend. If None, for each
            dataset index idx draw exactly weights[idx] samples from datasets[idx].

        config (BlendedMegatronDatasetConfig): The config

    Raises:
        RuntimeError: When the dataset has fewer or more samples than 'size' post-initialization
    """

    def __init__(
        self,
        datasets: List[MegatronDataset],
        weights: List[Union[int, float]],
        size: Optional[int],
        config: BlendedMegatronDatasetConfig,
    ) -> None:
        assert len(datasets) == len(weights)
        assert len(datasets) < 32767
        assert all(map(lambda _: type(_) == type(datasets[0]), datasets))
        assert all(map(lambda _: _.index_split == datasets[0].index_split, datasets))
        assert all(map(lambda _: _ > 0, weights))
        assert all(map(lambda _: type(_) == type(weights[0]), weights))
        if size is None and isinstance(weights[0], float):
            assert all(map(lambda _: _ == int(_), weights))

        # Alert user to unnecessary blending
        if len(datasets) == 1:
            log_single_rank(
                logger, logging.WARNING, f"Building a BlendedDataset for a single MegatronDataset"
            )

        if size is not None:
            weights = normalize(weights)

        self.datasets = datasets
        self.split = self.datasets[0].index_split
        self.weights = weights
        self.size = size
        self.config = config

        unique_identifiers = OrderedDict()
        unique_identifiers["class"] = type(self).__name__
        unique_identifiers["datasets"] = [dataset.unique_identifiers for dataset in self.datasets]
        unique_identifiers["split"] = self.split.name
        unique_identifiers["weights"] = self.weights
        unique_identifiers["size"] = self.size

        self.unique_description = json.dumps(
            unique_identifiers, indent=4, default=lambda obj: obj.unique_identifiers
        )
        self.unique_description_hash = hashlib.md5(
            self.unique_description.encode("utf-8"), usedforsecurity=False
        ).hexdigest()

        self.dataset_index, self.dataset_sample_index = self._build_indices()

    def __len__(self) -> int:
        return self.dataset_index.shape[0]

    def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]:
        dataset_id = self.dataset_index[idx]
        dataset_sample_id = self.dataset_sample_index[idx]
        return {"dataset_id": dataset_id, **self.datasets[dataset_id][dataset_sample_id]}

    def _build_indices(self) -> Tuple[numpy.ndarray, numpy.ndarray]:
        """Build and optionally cache the dataset index and the dataset sample index

        The dataset index is a 1-D mapping which determines the dataset to query. The dataset
        sample index is a 1-D mapping which determines the sample to request from the queried
        dataset.

        Returns:
            Tuple[numpy.ndarray, numpy.ndarray]: The dataset index and the dataset sample index
        """

        path_to_cache = self.config.path_to_cache

        if path_to_cache:
            get_path_to = lambda suffix: os.path.join(
                path_to_cache,
                f"{self.unique_description_hash}-{type(self).__name__}-{self.split.name}-{suffix}",
            )
            path_to_description = get_path_to("description.txt")
            path_to_dataset_index = get_path_to("dataset_index.npy")
            path_to_dataset_sample_index = get_path_to("dataset_sample_index.npy")
            cache_hit = all(
                map(
                    os.path.isfile,
                    [path_to_description, path_to_dataset_index, path_to_dataset_sample_index],
                )
            )
        else:
            cache_hit = False

        if not path_to_cache or (not cache_hit and torch.distributed.get_rank() == 0):
            log_single_rank(
                logger, logging.INFO, f"Build and save the {type(self).__name__} indices"
            )

            # Build the dataset and dataset sample indexes
            log_single_rank(
                logger, logging.INFO, f"\tBuild and save the dataset and dataset sample indexes"
            )
            t_beg = time.time()
            from megatron.core.datasets import helpers

            if self.size is not None:
                dataset_index = numpy.zeros(self.size, dtype=numpy.int16)
                dataset_sample_index = numpy.zeros(self.size, dtype=numpy.int64)
                helpers.build_blending_indices(
                    dataset_index,
                    dataset_sample_index,
                    self.weights,
                    len(self.datasets),
                    self.size,
                    _VERBOSE,
                )
            else:
                size = sum(self.weights)
                dataset_index = numpy.zeros(size, dtype=numpy.int16)
                dataset_sample_index = numpy.zeros(size, dtype=numpy.int64)
                helpers.build_exhaustive_blending_indices(
                    dataset_index, dataset_sample_index, self.weights, len(self.datasets)
                )

            dataset_indices, dataset_sizes = numpy.unique(dataset_index, return_counts=True)
            for i, (_index, _size) in enumerate(zip(dataset_indices, dataset_sizes)):
                if len(self.datasets[_index]) < _size:
                    raise IndexError(
                        f"The {self.split.name} blend oversamples the contributing datasets and, "
                        f"for example, requests {_size} samples from "
                        f"{type(self.datasets[_index]).__name__} number {i} in excess of its size "
                        f"{len(self.datasets[_index])}. The current value of the config attribute "
                        f"mid_level_dataset_surplus may be increased, e.g. two- or ten-fold, from "
                        f"its current value ({self.config.mid_level_dataset_surplus}) to ensure a "
                        f"sufficient mid-level dataset sample margin from which to draw."
                    )

            if path_to_cache:
                os.makedirs(path_to_cache, exist_ok=True)
                # Write the description
                with open(path_to_description, "wt") as writer:
                    writer.write(self.unique_description)
                # Save the indexes
                numpy.save(path_to_dataset_index, dataset_index, allow_pickle=True)
                numpy.save(path_to_dataset_sample_index, dataset_sample_index, allow_pickle=True)
            else:
                log_single_rank(
                    logger,
                    logging.WARNING,
                    f"Cannot save the {type(self).__name__} indexes because path_to_cache is None",
                )

            t_end = time.time()
            log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")

            return dataset_index, dataset_sample_index

        log_single_rank(logger, logging.INFO, f"Load the {type(self).__name__} indices")

        log_single_rank(
            logger, logging.INFO, f"\tLoad the dataset index from {path_to_dataset_index}"
        )
        t_beg = time.time()
        dataset_index = numpy.load(path_to_dataset_index, allow_pickle=True, mmap_mode="r")
        t_end = time.time()
        log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")

        log_single_rank(
            logger,
            logging.INFO,
            f"\tLoad the dataset sample index from {path_to_dataset_sample_index}",
        )
        t_beg = time.time()
        dataset_sample_index = numpy.load(
            path_to_dataset_sample_index, allow_pickle=True, mmap_mode="r"
        )
        t_end = time.time()
        log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")

        return dataset_index, dataset_sample_index
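When `size` is given, the compiled helper `build_blending_indices` fills the two index arrays greedily: at each step it picks the dataset whose realized sample count lags furthest behind its target share. A pure-Python sketch of that rule (an assumed approximation of the helper's behavior, not the exact C++ implementation):

```python
import numpy


def build_blending_indices_py(weights, size):
    """Greedy blend: at each step pick the dataset furthest below its target share."""
    weights = numpy.asarray(weights, dtype=numpy.float64)
    weights = weights / weights.sum()
    dataset_index = numpy.zeros(size, dtype=numpy.int16)
    dataset_sample_index = numpy.zeros(size, dtype=numpy.int64)
    counts = numpy.zeros(len(weights), dtype=numpy.int64)
    for i in range(size):
        d = int(numpy.argmax(weights * (i + 1) - counts))
        dataset_index[i] = d                  # which dataset to query at position i
        dataset_sample_index[i] = counts[d]   # running per-dataset sample counter
        counts[d] += 1
    return dataset_index, dataset_sample_index


idx, sample_idx = build_blending_indices_py([0.7, 0.3], size=10)
print(idx)         # roughly 7 zeros and 3 ones, interleaved
print(sample_idx)  # per-dataset running counters
```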