megatron-core 0.11.0__cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of megatron-core might be problematic.

Files changed (246)
  1. megatron/core/README.md +14 -0
  2. megatron/core/__init__.py +34 -0
  3. megatron/core/config_logger.py +104 -0
  4. megatron/core/datasets/__init__.py +0 -0
  5. megatron/core/datasets/bert_dataset.py +192 -0
  6. megatron/core/datasets/blended_dataset.py +201 -0
  7. megatron/core/datasets/blended_megatron_dataset_builder.py +579 -0
  8. megatron/core/datasets/blended_megatron_dataset_config.py +172 -0
  9. megatron/core/datasets/gpt_dataset.py +810 -0
  10. megatron/core/datasets/helpers.cpp +846 -0
  11. megatron/core/datasets/helpers.py +64 -0
  12. megatron/core/datasets/helpers_cpp.cpython-310-aarch64-linux-gnu.so +0 -0
  13. megatron/core/datasets/indexed_dataset.py +857 -0
  14. megatron/core/datasets/masked_dataset.py +425 -0
  15. megatron/core/datasets/megatron_dataset.py +139 -0
  16. megatron/core/datasets/megatron_tokenizer.py +154 -0
  17. megatron/core/datasets/multimodal_dataset.py +62 -0
  18. megatron/core/datasets/retro/__init__.py +5 -0
  19. megatron/core/datasets/retro/config/__init__.py +16 -0
  20. megatron/core/datasets/retro/config/bert_embedders.py +48 -0
  21. megatron/core/datasets/retro/config/config.py +135 -0
  22. megatron/core/datasets/retro/config/gpt_chunk_datasets.py +15 -0
  23. megatron/core/datasets/retro/config/tokenizers.py +15 -0
  24. megatron/core/datasets/retro/db/__init__.py +9 -0
  25. megatron/core/datasets/retro/db/build.py +633 -0
  26. megatron/core/datasets/retro/db/dataset.py +105 -0
  27. megatron/core/datasets/retro/db/utils.py +367 -0
  28. megatron/core/datasets/retro/external_libs.py +15 -0
  29. megatron/core/datasets/retro/index/__init__.py +11 -0
  30. megatron/core/datasets/retro/index/build.py +313 -0
  31. megatron/core/datasets/retro/index/factory.py +40 -0
  32. megatron/core/datasets/retro/index/index.py +133 -0
  33. megatron/core/datasets/retro/index/indexes/__init__.py +10 -0
  34. megatron/core/datasets/retro/index/indexes/faiss_base.py +150 -0
  35. megatron/core/datasets/retro/index/indexes/faiss_par_add.py +208 -0
  36. megatron/core/datasets/retro/index/utils.py +126 -0
  37. megatron/core/datasets/retro/index/validate.py +191 -0
  38. megatron/core/datasets/retro/query/__init__.py +1 -0
  39. megatron/core/datasets/retro/query/gpt_chunk_dataset.py +109 -0
  40. megatron/core/datasets/retro/query/multi_split_gpt_dataset.py +107 -0
  41. megatron/core/datasets/retro/query/query.py +393 -0
  42. megatron/core/datasets/retro/query/retro_dataset.py +238 -0
  43. megatron/core/datasets/retro/query/utils.py +35 -0
  44. megatron/core/datasets/retro/utils.py +349 -0
  45. megatron/core/datasets/t5_dataset.py +331 -0
  46. megatron/core/datasets/utils.py +87 -0
  47. megatron/core/datasets/utils_s3.py +164 -0
  48. megatron/core/dist_checkpointing/__init__.py +12 -0
  49. megatron/core/dist_checkpointing/core.py +77 -0
  50. megatron/core/dist_checkpointing/dict_utils.py +248 -0
  51. megatron/core/dist_checkpointing/exchange_utils.py +544 -0
  52. megatron/core/dist_checkpointing/mapping.py +719 -0
  53. megatron/core/dist_checkpointing/optimizer.py +142 -0
  54. megatron/core/dist_checkpointing/serialization.py +424 -0
  55. megatron/core/dist_checkpointing/state_dict_utils.py +85 -0
  56. megatron/core/dist_checkpointing/strategies/__init__.py +7 -0
  57. megatron/core/dist_checkpointing/strategies/async_utils.py +228 -0
  58. megatron/core/dist_checkpointing/strategies/base.py +227 -0
  59. megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py +38 -0
  60. megatron/core/dist_checkpointing/strategies/common.py +157 -0
  61. megatron/core/dist_checkpointing/strategies/filesystem_async.py +477 -0
  62. megatron/core/dist_checkpointing/strategies/fully_parallel.py +515 -0
  63. megatron/core/dist_checkpointing/strategies/resharding.py +318 -0
  64. megatron/core/dist_checkpointing/strategies/state_dict_saver.py +247 -0
  65. megatron/core/dist_checkpointing/strategies/tensorstore.py +128 -0
  66. megatron/core/dist_checkpointing/strategies/torch.py +1010 -0
  67. megatron/core/dist_checkpointing/strategies/two_stage.py +268 -0
  68. megatron/core/dist_checkpointing/strategies/zarr.py +321 -0
  69. megatron/core/dist_checkpointing/tensor_aware_state_dict.py +347 -0
  70. megatron/core/dist_checkpointing/utils.py +319 -0
  71. megatron/core/dist_checkpointing/validation.py +560 -0
  72. megatron/core/distributed/__init__.py +8 -0
  73. megatron/core/distributed/data_parallel_base.py +96 -0
  74. megatron/core/distributed/distributed_data_parallel.py +483 -0
  75. megatron/core/distributed/distributed_data_parallel_config.py +49 -0
  76. megatron/core/distributed/finalize_model_grads.py +316 -0
  77. megatron/core/distributed/param_and_grad_buffer.py +845 -0
  78. megatron/core/distributed/torch_fully_sharded_data_parallel.py +123 -0
  79. megatron/core/enums.py +10 -0
  80. megatron/core/export/__init__.py +1 -0
  81. megatron/core/export/data_type.py +5 -0
  82. megatron/core/export/export_config.py +19 -0
  83. megatron/core/export/model_type.py +7 -0
  84. megatron/core/export/trtllm/__init__.py +1 -0
  85. megatron/core/export/trtllm/engine_builder/__init__.py +1 -0
  86. megatron/core/export/trtllm/engine_builder/trtllm_engine_builder.py +154 -0
  87. megatron/core/export/trtllm/model_to_trllm_mapping/__init__.py +1 -0
  88. megatron/core/export/trtllm/model_to_trllm_mapping/default_conversion_dict.py +40 -0
  89. megatron/core/export/trtllm/trt_model_config.py +15 -0
  90. megatron/core/export/trtllm/trt_model_type.py +13 -0
  91. megatron/core/export/trtllm/trtllm_helper.py +588 -0
  92. megatron/core/export/trtllm/trtllm_layers.py +157 -0
  93. megatron/core/export/trtllm/trtllm_weights_converter/__init__.py +1 -0
  94. megatron/core/export/trtllm/trtllm_weights_converter/distributed_trtllm_model_weights_converter.py +280 -0
  95. megatron/core/export/trtllm/trtllm_weights_converter/single_device_trtllm_model_weights_converter.py +471 -0
  96. megatron/core/extensions/__init__.py +0 -0
  97. megatron/core/extensions/transformer_engine.py +1353 -0
  98. megatron/core/fusions/__init__.py +0 -0
  99. megatron/core/fusions/fused_bias_dropout.py +73 -0
  100. megatron/core/fusions/fused_bias_geglu.py +85 -0
  101. megatron/core/fusions/fused_bias_gelu.py +55 -0
  102. megatron/core/fusions/fused_bias_swiglu.py +89 -0
  103. megatron/core/fusions/fused_cross_entropy.py +143 -0
  104. megatron/core/fusions/fused_layer_norm.py +169 -0
  105. megatron/core/fusions/fused_softmax.py +220 -0
  106. megatron/core/inference/__init__.py +1 -0
  107. megatron/core/inference/async_stream.py +67 -0
  108. megatron/core/inference/common_inference_params.py +4 -0
  109. megatron/core/inference/communication_utils.py +54 -0
  110. megatron/core/inference/engines/__init__.py +1 -0
  111. megatron/core/inference/engines/abstract_engine.py +17 -0
  112. megatron/core/inference/engines/mcore_engine.py +210 -0
  113. megatron/core/inference/inference_request.py +52 -0
  114. megatron/core/inference/model_inference_wrappers/__init__.py +1 -0
  115. megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +311 -0
  116. megatron/core/inference/model_inference_wrappers/gpt/__init__.py +1 -0
  117. megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py +102 -0
  118. megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +44 -0
  119. megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py +212 -0
  120. megatron/core/inference/model_inference_wrappers/t5/__init__.py +1 -0
  121. megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py +225 -0
  122. megatron/core/inference/modelopt_support/__init__.py +10 -0
  123. megatron/core/inference/modelopt_support/gpt/__init__.py +1 -0
  124. megatron/core/inference/modelopt_support/gpt/model_specs.py +68 -0
  125. megatron/core/inference/modelopt_support/gpt/state_dict_hooks.py +133 -0
  126. megatron/core/inference/modelopt_support/mamba/__init__.py +1 -0
  127. megatron/core/inference/modelopt_support/mamba/model_specs.py +89 -0
  128. megatron/core/inference/sampling_params.py +36 -0
  129. megatron/core/inference/scheduler.py +175 -0
  130. megatron/core/inference/text_generation_controllers/__init__.py +1 -0
  131. megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +38 -0
  132. megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +5 -0
  133. megatron/core/inference/text_generation_controllers/text_generation_controller.py +665 -0
  134. megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py +40 -0
  135. megatron/core/inference/utils.py +17 -0
  136. megatron/core/inference_params.py +89 -0
  137. megatron/core/jit.py +10 -0
  138. megatron/core/model_parallel_config.py +387 -0
  139. megatron/core/models/T5/__init__.py +2 -0
  140. megatron/core/models/T5/t5_model.py +517 -0
  141. megatron/core/models/T5/t5_spec.py +248 -0
  142. megatron/core/models/__init__.py +0 -0
  143. megatron/core/models/bert/__init__.py +0 -0
  144. megatron/core/models/bert/bert_layer_specs.py +116 -0
  145. megatron/core/models/bert/bert_lm_head.py +50 -0
  146. megatron/core/models/bert/bert_model.py +373 -0
  147. megatron/core/models/bert/pooler.py +52 -0
  148. megatron/core/models/common/__init__.py +0 -0
  149. megatron/core/models/common/embeddings/__init__.py +5 -0
  150. megatron/core/models/common/embeddings/language_model_embedding.py +143 -0
  151. megatron/core/models/common/embeddings/relative_pos_embedding.py +173 -0
  152. megatron/core/models/common/embeddings/rope_utils.py +261 -0
  153. megatron/core/models/common/embeddings/rotary_pos_embedding.py +215 -0
  154. megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +179 -0
  155. megatron/core/models/common/language_module/__init__.py +0 -0
  156. megatron/core/models/common/language_module/language_module.py +244 -0
  157. megatron/core/models/common/vision_module/__init__.py +0 -0
  158. megatron/core/models/common/vision_module/vision_module.py +17 -0
  159. megatron/core/models/gpt/__init__.py +2 -0
  160. megatron/core/models/gpt/gpt_layer_specs.py +383 -0
  161. megatron/core/models/gpt/gpt_model.py +331 -0
  162. megatron/core/models/gpt/moe_module_specs.py +81 -0
  163. megatron/core/models/mamba/__init__.py +2 -0
  164. megatron/core/models/mamba/mamba_layer_specs.py +67 -0
  165. megatron/core/models/mamba/mamba_model.py +228 -0
  166. megatron/core/models/multimodal/__init__.py +1 -0
  167. megatron/core/models/multimodal/llava_model.py +915 -0
  168. megatron/core/models/multimodal/llava_spec.py +89 -0
  169. megatron/core/models/retro/__init__.py +13 -0
  170. megatron/core/models/retro/base_attention.py +43 -0
  171. megatron/core/models/retro/config.py +88 -0
  172. megatron/core/models/retro/decoder_attention.py +305 -0
  173. megatron/core/models/retro/decoder_spec.py +185 -0
  174. megatron/core/models/retro/encoder_attention.py +226 -0
  175. megatron/core/models/retro/encoder_spec.py +168 -0
  176. megatron/core/models/retro/model.py +99 -0
  177. megatron/core/models/retro/utils.py +24 -0
  178. megatron/core/models/vision/__init__.py +0 -0
  179. megatron/core/models/vision/clip_vit_model.py +221 -0
  180. megatron/core/models/vision/multimodal_projector.py +74 -0
  181. megatron/core/models/vision/radio.py +325 -0
  182. megatron/core/models/vision/vit_layer_specs.py +95 -0
  183. megatron/core/num_microbatches_calculator.py +508 -0
  184. megatron/core/optimizer/__init__.py +487 -0
  185. megatron/core/optimizer/clip_grads.py +232 -0
  186. megatron/core/optimizer/distrib_optimizer.py +1930 -0
  187. megatron/core/optimizer/grad_scaler.py +142 -0
  188. megatron/core/optimizer/optimizer.py +1118 -0
  189. megatron/core/optimizer/optimizer_config.py +181 -0
  190. megatron/core/optimizer_param_scheduler.py +297 -0
  191. megatron/core/package_info.py +29 -0
  192. megatron/core/packed_seq_params.py +20 -0
  193. megatron/core/parallel_state.py +2012 -0
  194. megatron/core/pipeline_parallel/__init__.py +2 -0
  195. megatron/core/pipeline_parallel/p2p_communication.py +632 -0
  196. megatron/core/pipeline_parallel/schedules.py +1887 -0
  197. megatron/core/requirements.txt +2 -0
  198. megatron/core/rerun_state_machine.py +1133 -0
  199. megatron/core/ssm/__init__.py +0 -0
  200. megatron/core/ssm/mamba_block.py +336 -0
  201. megatron/core/ssm/mamba_hybrid_layer_allocation.py +191 -0
  202. megatron/core/ssm/mamba_layer.py +145 -0
  203. megatron/core/ssm/mamba_mixer.py +718 -0
  204. megatron/core/ssm/triton_cache_manager.py +81 -0
  205. megatron/core/tensor_parallel/__init__.py +72 -0
  206. megatron/core/tensor_parallel/cross_entropy.py +232 -0
  207. megatron/core/tensor_parallel/data.py +105 -0
  208. megatron/core/tensor_parallel/layers.py +1224 -0
  209. megatron/core/tensor_parallel/mappings.py +576 -0
  210. megatron/core/tensor_parallel/random.py +431 -0
  211. megatron/core/tensor_parallel/utils.py +113 -0
  212. megatron/core/timers.py +449 -0
  213. megatron/core/transformer/__init__.py +6 -0
  214. megatron/core/transformer/attention.py +748 -0
  215. megatron/core/transformer/cuda_graphs.py +893 -0
  216. megatron/core/transformer/custom_layers/__init__.py +0 -0
  217. megatron/core/transformer/custom_layers/transformer_engine.py +12 -0
  218. megatron/core/transformer/dot_product_attention.py +206 -0
  219. megatron/core/transformer/enums.py +48 -0
  220. megatron/core/transformer/identity_op.py +28 -0
  221. megatron/core/transformer/mlp.py +261 -0
  222. megatron/core/transformer/module.py +195 -0
  223. megatron/core/transformer/moe/__init__.py +0 -0
  224. megatron/core/transformer/moe/experts.py +854 -0
  225. megatron/core/transformer/moe/grouped_gemm_util.py +22 -0
  226. megatron/core/transformer/moe/legacy_a2a_token_dispatcher.py +317 -0
  227. megatron/core/transformer/moe/moe_layer.py +147 -0
  228. megatron/core/transformer/moe/moe_utils.py +655 -0
  229. megatron/core/transformer/moe/router.py +364 -0
  230. megatron/core/transformer/moe/shared_experts.py +243 -0
  231. megatron/core/transformer/moe/token_dispatcher.py +643 -0
  232. megatron/core/transformer/moe/upcycling_utils.py +196 -0
  233. megatron/core/transformer/multi_latent_attention.py +414 -0
  234. megatron/core/transformer/spec_utils.py +106 -0
  235. megatron/core/transformer/torch_layer_norm.py +4 -0
  236. megatron/core/transformer/torch_norm.py +48 -0
  237. megatron/core/transformer/transformer_block.py +664 -0
  238. megatron/core/transformer/transformer_config.py +920 -0
  239. megatron/core/transformer/transformer_layer.py +502 -0
  240. megatron/core/transformer/utils.py +188 -0
  241. megatron/core/utils.py +1453 -0
  242. megatron_core-0.11.0.dist-info/LICENSE +272 -0
  243. megatron_core-0.11.0.dist-info/METADATA +998 -0
  244. megatron_core-0.11.0.dist-info/RECORD +246 -0
  245. megatron_core-0.11.0.dist-info/WHEEL +6 -0
  246. megatron_core-0.11.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,14 @@
+ # Megatron-Core
+
+ Megatron-Core is an open-source, PyTorch-based library that contains GPU-optimized techniques and cutting-edge system-level optimizations. It abstracts them into composable and modular APIs, giving developers and model researchers full flexibility to train custom transformers at scale on NVIDIA accelerated computing infrastructure. This library is compatible with all NVIDIA Tensor Core GPUs, including FP8 acceleration support for [NVIDIA Hopper architectures](https://www.nvidia.com/en-us/data-center/technologies/hopper-architecture/).
+
+ Megatron-Core offers core building blocks such as attention mechanisms, transformer blocks and layers, normalization layers, and embedding techniques. Additional functionality, such as activation recomputation and distributed checkpointing, is also natively built into the library. The building blocks and functionality are all GPU optimized and can be combined with advanced parallelization strategies for optimal training speed and stability on NVIDIA accelerated computing infrastructure. Another key component of the Megatron-Core library is its advanced model parallelism techniques (tensor, sequence, pipeline, context, and MoE expert parallelism).
+
+ Megatron-Core can be used with [NVIDIA NeMo](https://www.nvidia.com/en-us/ai-data-science/products/nemo/), an enterprise-grade AI platform. Alternatively, you can explore Megatron-Core with the native PyTorch training loop [here](https://github.com/NVIDIA/Megatron-LM/tree/main/examples). Visit [Megatron-Core documentation](https://docs.nvidia.com/megatron-core/developer-guide/latest/index.html) to learn more.
+
+ ## Quick links
+
+ - [Benchmark using NVIDIA NeMo](https://docs.nvidia.com/nemo-framework/user-guide/latest/overview.html#performance-benchmarks)
+ - [Multimodal example (LLaVA training pipeline)](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/multimodal)
+ - [Mixture-of-Experts](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core/transformer/moe)
+ - [Training Mamba-based Language Models](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/mamba)
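For orientation, here is a minimal sketch of building a small GPT model from these building blocks, closely mirroring the Megatron-LM examples linked above. The tiny hyperparameters are illustrative only, and single-process distributed setup is assumed:

```python
import os
import torch
from megatron.core import parallel_state
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
from megatron.core.models.gpt import GPTModel

# Megatron-Core expects torch.distributed to be initialized, even on one GPU.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
torch.distributed.init_process_group(backend="nccl", world_size=1, rank=0)
parallel_state.initialize_model_parallel(
    tensor_model_parallel_size=1, pipeline_model_parallel_size=1
)
model_parallel_cuda_manual_seed(123)  # seed the tensor-parallel RNG tracker

# A deliberately tiny config; real runs derive these from command-line args.
config = TransformerConfig(
    num_layers=2, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True
)
model = GPTModel(
    config=config,
    transformer_layer_spec=get_gpt_layer_local_spec(),  # local (non-Transformer-Engine) layers
    vocab_size=128,
    max_sequence_length=64,
)
```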
@@ -0,0 +1,34 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ import megatron.core.tensor_parallel
+ import megatron.core.utils
+ from megatron.core import parallel_state
+ from megatron.core.distributed import DistributedDataParallel
+ from megatron.core.inference_params import InferenceParams
+ from megatron.core.model_parallel_config import ModelParallelConfig
+ from megatron.core.package_info import (
+     __contact_emails__,
+     __contact_names__,
+     __description__,
+     __download_url__,
+     __homepage__,
+     __keywords__,
+     __license__,
+     __package_name__,
+     __repository_url__,
+     __shortversion__,
+     __version__,
+ )
+ from megatron.core.timers import Timers
+
+ # Alias parallel_state as mpu, its legacy name
+ mpu = parallel_state
+
+ __all__ = [
+     "parallel_state",
+     "tensor_parallel",
+     "utils",
+     "DistributedDataParallel",
+     "InferenceParams",
+     "ModelParallelConfig",
+     "Timers",
+ ]
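As the alias above suggests, legacy code written against `megatron.mpu` can keep its call sites. A short sketch, assuming model parallelism has already been initialized as in the previous example:

```python
from megatron.core import mpu, parallel_state

# mpu is the very same module object as parallel_state, kept for legacy code.
assert mpu is parallel_state

# After parallel_state.initialize_model_parallel(...), these report this
# rank's coordinates in the tensor- and pipeline-parallel grids.
tp_rank = mpu.get_tensor_model_parallel_rank()
pp_rank = mpu.get_pipeline_model_parallel_rank()
print(f"TP rank {tp_rank}, PP rank {pp_rank}")
```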
@@ -0,0 +1,104 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+
+ import dataclasses
+ import json
+ import os
+
+ import torch
+ import torch.nn as nn
+
+ from megatron.core import parallel_state
+
+
+ def get_config_logger_path(config):
+     return getattr(config, 'config_logger_dir', '')
+
+
+ def has_config_logger_enabled(config):
+     return get_config_logger_path(config) != ''
+
+
+ # For each prefix, holds a counter and increases it every time we dump with this
+ # prefix.
+ __config_logger_path_counts = {}
+
+
+ def get_path_count(path):
+     """
+     Keeps track of the number of times we've seen the input `path` and returns count-1
+     """
+     global __config_logger_path_counts
+     if path not in __config_logger_path_counts:
+         __config_logger_path_counts[path] = 0
+     count = __config_logger_path_counts[path]
+     __config_logger_path_counts[path] += 1
+     return count
+
+
+ def get_path_with_count(path):
+     """
+     Calls get_path_count and appends the returned value to path
+     """
+     return f'{path}.iter{get_path_count(path)}'
+
+
+ class JSONEncoderWithMcoreTypes(json.JSONEncoder):
+     def default(self, o):
+         if type(o).__name__ in ['function', 'ProcessGroup']:
+             return str(o)
+         if type(o).__name__ in ['dict', 'OrderedDict']:
+             return {k: self.default(v) for k, v in o.items()}
+         if type(o).__name__ in ['list', 'ModuleList']:
+             return [self.default(val) for val in o]
+         if type(o).__name__ == 'UniqueDescriptor':
+             return {
+                 attr: self.default(getattr(o, attr))
+                 for attr in filter(lambda x: not x.startswith('__'), dir(o))
+             }
+         if type(o) is torch.dtype:
+             return str(o)
+         # if it's a Float16Module, add "Float16Module" to the output dict
+         if type(o).__name__ == 'Float16Module':
+             return {'Float16Module': {'module': self.default(o.module)}}
+         # If it's a nn.Module subchild, either print its children or itself if leaf.
+         if issubclass(type(o), nn.Module):
+             if len(getattr(o, '_modules', {})) > 0:
+                 return {key: self.default(val) for key, val in o._modules.items()}
+             else:
+                 return str(o)
+         if type(o).__name__ in ['ABCMeta', 'type', 'AttnMaskType']:
+             return str(o)
+         if dataclasses.is_dataclass(o) or type(o).__name__ in ['ModuleSpec', 'TransformerConfig']:
+             return dataclasses.asdict(o)
+         try:
+             return super().default(o)
+         except:
+             return str(o)
+
+
+ def log_config_to_disk(config, dict_data, prefix=''):
+     """
+     Encodes the input dict (dict_data) using the JSONEncoderWithMcoreTypes
+     and dumps it to disk at the path specified via config
+     """
+     path = get_config_logger_path(config)
+     assert path is not None, 'Expected config_logger_dir to be non-empty in config.'
+
+     if 'self' in dict_data:
+         if prefix == '':
+             prefix = type(dict_data['self']).__name__
+         del dict_data['self']
+
+     if not os.path.exists(path):
+         os.makedirs(path, exist_ok=True)
+
+     rank = parallel_state.get_all_ranks()
+     path = get_path_with_count(os.path.join(path, f'{prefix}.rank_{rank}'))
+     if type(dict_data).__name__ == 'OrderedDict':
+         torch.save(dict_data, f'{path}.pth')
+     else:
+         with open(f'{path}.json', 'w') as fp:
+             json.dump(dict_data, fp, cls=JSONEncoderWithMcoreTypes)
+
+
+ __all__ = ['has_config_logger_enabled', 'log_config_to_disk']
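A hedged usage sketch of the helpers above. The only attribute they read from the config is `config_logger_dir`; `DummyConfig` here is a stand-in of mine, not part of the package, and torch.distributed/parallel state must already be initialized because `log_config_to_disk` queries ranks:

```python
from dataclasses import dataclass

from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk

@dataclass
class DummyConfig:  # hypothetical stand-in; any object with this attribute works
    config_logger_dir: str = '/tmp/mcore_config_logs'

config = DummyConfig()
if has_config_logger_enabled(config):
    # Passing a dict with a 'self' key lets its class name become the file prefix,
    # so this writes /tmp/mcore_config_logs/DummyConfig.rank_<ranks>.iter0.json.
    log_config_to_disk(config, {'self': config, 'note': 'example dump'})
```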
File without changes
@@ -0,0 +1,192 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional, Union
+
+ import numpy
+
+ from megatron.core.datasets.indexed_dataset import IndexedDataset
+ from megatron.core.datasets.masked_dataset import (
+     MaskedWordPieceDataset,
+     MaskedWordPieceDatasetConfig,
+ )
+ from megatron.core.datasets.utils import Split
+
+
+ @dataclass
+ class BERTMaskedWordPieceDatasetConfig(MaskedWordPieceDatasetConfig):
+     """Configuration object for Megatron Core BERT WordPiece datasets"""
+
+     classification_head: bool = None
+     """Option to perform the next sequence prediction during sampling"""
+
+     def __post_init__(self) -> None:
+         """Do asserts and set fields post init"""
+         super().__post_init__()
+
+         assert self.classification_head is not None
+
+
+ class BERTMaskedWordPieceDataset(MaskedWordPieceDataset):
+     """The BERT dataset that assumes WordPiece tokenization
+
+     Args:
+         indexed_dataset (IndexedDataset): The IndexedDataset around which to build the MegatronDataset
+
+         dataset_path (str): The real path on disk to the dataset, for bookkeeping
+
+         indexed_indices (numpy.ndarray): The set of document indices to expose
+
+         num_samples (Optional[int]): The number of samples to draw from the indexed dataset. When None, build as many samples as correspond to one epoch.
+
+         index_split (Split): The indexed_indices Split
+
+         config (BERTMaskedWordPieceDatasetConfig): The config
+     """
+
+     def __init__(
+         self,
+         indexed_dataset: IndexedDataset,
+         dataset_path: str,
+         indexed_indices: numpy.ndarray,
+         num_samples: Optional[int],
+         index_split: Split,
+         config: BERTMaskedWordPieceDatasetConfig,
+     ) -> None:
+         super().__init__(
+             indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config
+         )
+
+         self.token_lookup = list(self.config.tokenizer.inv_vocab.keys())
+         # Account for the single <cls> and two <sep> token ids
+         self.sample_index = self._build_sample_index(
+             self.config.sequence_length - 3, 2 if self.config.classification_head else 1
+         )
+
+     @staticmethod
+     def _key_config_attributes() -> List[str]:
+         """Inherited method implementation
+
+         Returns:
+             List[str]: The key config attributes
+         """
+         return super(
+             BERTMaskedWordPieceDataset, BERTMaskedWordPieceDataset
+         )._key_config_attributes() + ["classification_head"]
+
+     def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]:
+         """Abstract method implementation
+
+         Args:
+             idx (int): The index into the dataset
+
+         Returns:
+             Dict[str, Union[int, numpy.ndarray]]: The sample, as a dictionary
+         """
+         idx_beg, idx_end, target_sequence_length = self.sample_index[idx]
+         sample = [self.dataset[i] for i in range(idx_beg, idx_end)]
+         numpy_random_state = numpy.random.RandomState(seed=(self.config.random_seed + idx) % 2**32)
+
+         assert target_sequence_length <= self.config.sequence_length
+
+         # Split the sample into contiguous subsegments A and B
+         pivot = len(sample)
+         is_next_random = False
+         if self.config.classification_head:
+             assert len(sample) > 1, "the sample must contain at least two sentences"
+             pivot = 1
+             if len(sample) >= 3:
+                 pivot = numpy_random_state.randint(low=1, high=len(sample))
+             is_next_random = numpy_random_state.random() < 0.5
+         split_A = []
+         for sample_a in sample[:pivot]:
+             split_A.extend(sample_a)
+         split_B = []
+         for sample_b in sample[pivot:]:
+             split_B.extend(sample_b)
+         if is_next_random:
+             split_A, split_B = split_B, split_A
+
+         # Trim the subsegments from either end to a desired joint length
+         length_A = len(split_A)
+         length_B = len(split_B)
+         if length_A + length_B <= target_sequence_length:
+             truncated = False
+         else:
+             while length_A + length_B > target_sequence_length:
+                 split = split_A if length_A > length_B else split_B
+                 if numpy_random_state.random() < 0.5:
+                     del split[0]
+                 else:
+                     del split[-1]
+                 length_A = len(split_A)
+                 length_B = len(split_B)
+             truncated = True
+
+         # Merge the subsegments and create the token assignment labels
+         tokens = [self.config.tokenizer.cls, *split_A, self.config.tokenizer.sep]
+         assignments = [0 for _ in range(1 + len(split_A) + 1)]
+         if split_B:
+             tokens += [*split_B, self.config.tokenizer.sep]
+             assignments += [1 for _ in range(len(split_B) + 1)]
+
+         # Masking
+         tokens, masked_positions, masked_labels, _, _ = self._create_masked_lm_predictions(
+             tokens, target_sequence_length, numpy_random_state
+         )
+
+         # Pad the sequences and convert to NumPy
+         length_toks = len(tokens)
+         length_pads = self.config.sequence_length - length_toks
+         assert length_pads >= 0
+
+         tokens = numpy.array(tokens, dtype=numpy.int64)
+         tokens = numpy.pad(tokens, (0, length_pads), constant_values=self.config.tokenizer.pad)
+
+         assignments = numpy.array(assignments, dtype=numpy.int64)
+         assignments = numpy.pad(
+             assignments, (0, length_pads), constant_values=self.config.tokenizer.pad
+         )
+
+         # Get the padding mask
+         mask_pads = numpy.ones(length_toks, dtype=numpy.int64)
+         mask_pads = numpy.pad(
+             mask_pads, (0, length_pads), constant_values=self.config.tokenizer.pad
+         )
+
+         # Mask the labels
+         labels = numpy.zeros(self.config.sequence_length, dtype=numpy.int64) - 1
+         labels[masked_positions] = masked_labels
+
+         # Get the loss mask
+         mask_loss = numpy.zeros(self.config.sequence_length, dtype=numpy.int64)
+         mask_loss[masked_positions] = 1
+
+         return {
+             "text": tokens,
+             "types": assignments,
+             "labels": labels,
+             "is_random": int(is_next_random),
+             "padding_mask": mask_pads,
+             "loss_mask": mask_loss,
+             "truncated": int(truncated),
+         }
+
+     def _get_token_mask(self, numpy_random_state: numpy.random.RandomState) -> Optional[int]:
+         """Abstract method implementation
+
+         80% of the time, replace the token id with the mask token id. 10% of the time, replace
+         the token id with a random token id from the vocabulary. 10% of the time, do nothing.
+
+         Args:
+             numpy_random_state (RandomState): The NumPy random state
+
+         Returns:
+             Optional[int]: The replacement token id or None
+         """
+         if numpy_random_state.random() < 0.8:
+             return self.config.tokenizer.mask
+         else:
+             if numpy_random_state.random() >= 0.5:
+                 return self.token_lookup[numpy_random_state.randint(0, len(self.token_lookup))]
+             return None
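The 80/10/10 rule in `_get_token_mask` is easy to sanity-check in isolation. A small standalone sketch with made-up `MASK_ID`/`VOCAB` stand-ins (not the package's tokenizer) that reproduces the same branching:

```python
import numpy

MASK_ID = 103                   # hypothetical [MASK] token id, for illustration only
VOCAB = list(range(1000, 2000))  # hypothetical vocabulary of token ids

def get_token_mask(rng: numpy.random.RandomState):
    """Same branching as BERTMaskedWordPieceDataset._get_token_mask."""
    if rng.random() < 0.8:
        return MASK_ID                            # 80%: replace with [MASK]
    if rng.random() >= 0.5:
        return VOCAB[rng.randint(0, len(VOCAB))]  # 10%: random vocabulary token
    return None                                   # 10%: keep the original token

rng = numpy.random.RandomState(0)
draws = [get_token_mask(rng) for _ in range(100_000)]
frac_mask = sum(d == MASK_ID for d in draws) / len(draws)
frac_keep = sum(d is None for d in draws) / len(draws)
print(f"mask ~{frac_mask:.2f}, keep ~{frac_keep:.2f}, random ~{1 - frac_mask - frac_keep:.2f}")
```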
@@ -0,0 +1,201 @@
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+
+ import hashlib
+ import json
+ import logging
+ import os
+ import time
+ from collections import OrderedDict
+ from typing import Dict, List, Optional, Tuple, Union
+
+ import numpy
+ import torch
+
+ from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
+ from megatron.core.datasets.megatron_dataset import MegatronDataset
+ from megatron.core.datasets.utils import normalize
+ from megatron.core.utils import log_single_rank
+
+ logger = logging.getLogger(__name__)
+
+ _VERBOSE = False
+
+
+ class BlendedDataset(torch.utils.data.Dataset):
+     """Conjugating class for a set of MegatronDataset instances
+
+     Args:
+         datasets (List[MegatronDataset]): The MegatronDataset instances to blend
+
+         weights (List[Union[int, float]]): The weights that determine the dataset blend ratios
+
+         size (Optional[int]): The number of samples to draw from the blend. If None, for each
+             dataset index idx draw exactly weights[idx] samples from datasets[idx].
+
+         config (BlendedMegatronDatasetConfig): The config
+
+     Raises:
+         RuntimeError: When the dataset has fewer or more samples than 'size' post-initialization
+     """
+
+     def __init__(
+         self,
+         datasets: List[MegatronDataset],
+         weights: List[Union[int, float]],
+         size: Optional[int],
+         config: BlendedMegatronDatasetConfig,
+     ) -> None:
+         assert len(datasets) == len(weights)
+         assert len(datasets) < 32767
+         assert all(map(lambda _: type(_) == type(datasets[0]), datasets))
+         assert all(map(lambda _: _.index_split == datasets[0].index_split, datasets))
+         assert all(map(lambda _: _ > 0, weights))
+         assert all(map(lambda _: type(_) == type(weights[0]), weights))
+         if size is None and isinstance(weights[0], float):
+             assert all(map(lambda _: _ == int(_), weights))
+
+         # Alert user to unnecessary blending
+         if len(datasets) == 1:
+             log_single_rank(
+                 logger, logging.WARNING, f"Building a BlendedDataset for a single MegatronDataset"
+             )
+
+         if size is not None:
+             weights = normalize(weights)
+
+         self.datasets = datasets
+         self.split = self.datasets[0].index_split
+         self.weights = weights
+         self.size = size
+         self.config = config
+
+         unique_identifiers = OrderedDict()
+         unique_identifiers["class"] = type(self).__name__
+         unique_identifiers["datasets"] = [dataset.unique_identifiers for dataset in self.datasets]
+         unique_identifiers["split"] = self.split.name
+         unique_identifiers["weights"] = self.weights
+         unique_identifiers["size"] = self.size
+
+         self.unique_description = json.dumps(
+             unique_identifiers, indent=4, default=lambda obj: obj.unique_identifiers
+         )
+         self.unique_description_hash = hashlib.md5(
+             self.unique_description.encode("utf-8")
+         ).hexdigest()
+
+         self.built_anew_on_cache_miss = False
+
+         self.dataset_index, self.dataset_sample_index = self._build_indices()
+
+     def __len__(self) -> int:
+         return self.dataset_index.shape[0]
+
+     def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]:
+         dataset_id = self.dataset_index[idx]
+         dataset_sample_id = self.dataset_sample_index[idx]
+         return {"dataset_id": dataset_id, **self.datasets[dataset_id][dataset_sample_id]}
+
+     def _build_indices(self) -> Tuple[numpy.ndarray, numpy.ndarray]:
+         """Build and optionally cache the dataset index and the dataset sample index
+
+         The dataset index is a 1-D mapping which determines the dataset to query. The dataset
+         sample index is a 1-D mapping which determines the sample to request from the queried
+         dataset.
+
+         Returns:
+             Tuple[numpy.ndarray, numpy.ndarray]: The dataset index and the dataset sample index
+         """
+         path_to_cache = self.config.path_to_cache
+
+         if path_to_cache:
+             get_path_to = lambda suffix: os.path.join(
+                 path_to_cache,
+                 f"{self.unique_description_hash}-{type(self).__name__}-{self.split.name}-{suffix}",
+             )
+             path_to_description = get_path_to("description.txt")
+             path_to_dataset_index = get_path_to("dataset_index.npy")
+             path_to_dataset_sample_index = get_path_to("dataset_sample_index.npy")
+             cache_hit = all(
+                 map(
+                     os.path.isfile,
+                     [path_to_description, path_to_dataset_index, path_to_dataset_sample_index],
+                 )
+             )
+         else:
+             cache_hit = False
+
+         if not path_to_cache or (not cache_hit and torch.distributed.get_rank() == 0):
+             log_single_rank(
+                 logger, logging.INFO, f"Build and save the {type(self).__name__} indices"
+             )
+             self.built_anew_on_cache_miss = True
+
+             # Build the dataset and dataset sample indexes
+             log_single_rank(
+                 logger, logging.INFO, f"\tBuild and save the dataset and dataset sample indexes"
+             )
+             t_beg = time.time()
+             from megatron.core.datasets import helpers
+
+             if self.size is not None:
+                 dataset_index = numpy.zeros(self.size, dtype=numpy.int16)
+                 dataset_sample_index = numpy.zeros(self.size, dtype=numpy.int64)
+                 helpers.build_blending_indices(
+                     dataset_index,
+                     dataset_sample_index,
+                     self.weights,
+                     len(self.datasets),
+                     self.size,
+                     _VERBOSE,
+                 )
+             else:
+                 size = sum(self.weights)
+                 dataset_index = numpy.zeros(size, dtype=numpy.int16)
+                 dataset_sample_index = numpy.zeros(size, dtype=numpy.int64)
+                 helpers.build_exhaustive_blending_indices(
+                     dataset_index, dataset_sample_index, self.weights, len(self.datasets)
+                 )
+
+             if path_to_cache:
+                 os.makedirs(path_to_cache, exist_ok=True)
+                 # Write the description
+                 with open(path_to_description, "wt") as writer:
+                     writer.write(self.unique_description)
+                 # Save the indexes
+                 numpy.save(path_to_dataset_index, dataset_index, allow_pickle=True)
+                 numpy.save(path_to_dataset_sample_index, dataset_sample_index, allow_pickle=True)
+             else:
+                 log_single_rank(
+                     logger,
+                     logging.WARNING,
+                     f"Cannot save the {type(self).__name__} indexes because path_to_cache is None",
+                 )
+
+             t_end = time.time()
+             log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
+
+             return dataset_index, dataset_sample_index
+
+         log_single_rank(logger, logging.INFO, f"Load the {type(self).__name__} indices")
+
+         log_single_rank(
+             logger, logging.INFO, f"\tLoad the dataset index from {path_to_dataset_index}"
+         )
+         t_beg = time.time()
+         dataset_index = numpy.load(path_to_dataset_index, allow_pickle=True, mmap_mode='r')
+         t_end = time.time()
+         log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
+
+         log_single_rank(
+             logger,
+             logging.INFO,
+             f"\tLoad the dataset sample index from {path_to_dataset_sample_index}",
+         )
+         t_beg = time.time()
+         dataset_sample_index = numpy.load(
+             path_to_dataset_sample_index, allow_pickle=True, mmap_mode='r'
+         )
+         t_end = time.time()
+         log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
+
+         return dataset_index, dataset_sample_index
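To make the two blending modes above concrete, here is a small arithmetic sketch in plain NumPy, independent of the compiled helpers extension; the normalization mirrors what `datasets.utils.normalize` does, roughly:

```python
import numpy

# Mode 1: a target size plus float weights, normalized to sum to 1.0.
weights = numpy.array([2.0, 1.0, 1.0])
weights = weights / weights.sum()     # normalize the blend ratios
size = 1000
# build_blending_indices draws samples so that dataset i contributes
# approximately weights[i] * size entries of the blended index.
print((weights * size).astype(int))   # -> [500 250 250]

# Mode 2: size=None plus integer weights; dataset i contributes exactly
# weights[i] samples, so the blend length is simply their sum.
int_weights = [3, 1, 2]
print(sum(int_weights))               # -> 6 total samples in the blend
```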