megatron-core 0.5.0 (cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of megatron-core might be problematic.

Files changed (109) (a cross-check of this listing against the wheel itself is sketched after the list)
  1. megatron/core/__init__.py +20 -0
  2. megatron/core/datasets/__init__.py +0 -0
  3. megatron/core/datasets/bert_dataset.py +207 -0
  4. megatron/core/datasets/blended_dataset.py +192 -0
  5. megatron/core/datasets/blended_megatron_dataset_builder.py +329 -0
  6. megatron/core/datasets/blended_megatron_dataset_config.py +171 -0
  7. megatron/core/datasets/gpt_dataset.py +642 -0
  8. megatron/core/datasets/helpers.cpp +765 -0
  9. megatron/core/datasets/helpers.cpython-310-x86_64-linux-gnu.so +0 -0
  10. megatron/core/datasets/indexed_dataset.py +639 -0
  11. megatron/core/datasets/masked_dataset.py +430 -0
  12. megatron/core/datasets/megatron_dataset.py +187 -0
  13. megatron/core/datasets/megatron_tokenizer.py +141 -0
  14. megatron/core/datasets/t5_dataset.py +239 -0
  15. megatron/core/datasets/utils.py +64 -0
  16. megatron/core/dist_checkpointing/__init__.py +11 -0
  17. megatron/core/dist_checkpointing/core.py +77 -0
  18. megatron/core/dist_checkpointing/dict_utils.py +232 -0
  19. megatron/core/dist_checkpointing/mapping.py +346 -0
  20. megatron/core/dist_checkpointing/optimizer.py +127 -0
  21. megatron/core/dist_checkpointing/serialization.py +453 -0
  22. megatron/core/dist_checkpointing/strategies/__init__.py +20 -0
  23. megatron/core/dist_checkpointing/strategies/base.py +105 -0
  24. megatron/core/dist_checkpointing/strategies/tensorstore.py +131 -0
  25. megatron/core/dist_checkpointing/strategies/two_stage.py +255 -0
  26. megatron/core/dist_checkpointing/strategies/zarr.py +298 -0
  27. megatron/core/dist_checkpointing/utils.py +139 -0
  28. megatron/core/distributed/__init__.py +2 -0
  29. megatron/core/distributed/distributed_data_parallel.py +250 -0
  30. megatron/core/distributed/finalize_model_grads.py +158 -0
  31. megatron/core/distributed/grad_buffer.py +426 -0
  32. megatron/core/enums.py +10 -0
  33. megatron/core/fusions/__init__.py +0 -0
  34. megatron/core/fusions/fused_bias_dropout.py +73 -0
  35. megatron/core/fusions/fused_bias_gelu.py +50 -0
  36. megatron/core/fusions/fused_bias_swiglu.py +81 -0
  37. megatron/core/fusions/fused_layer_norm.py +172 -0
  38. megatron/core/fusions/fused_softmax.py +220 -0
  39. megatron/core/inference_params.py +27 -0
  40. megatron/core/jit.py +11 -0
  41. megatron/core/model_parallel_config.py +247 -0
  42. megatron/core/models/T5/__init__.py +1 -0
  43. megatron/core/models/T5/t5_model.py +428 -0
  44. megatron/core/models/T5/t5_spec.py +220 -0
  45. megatron/core/models/__init__.py +0 -0
  46. megatron/core/models/bert/__init__.py +0 -0
  47. megatron/core/models/bert/bert_layer_specs.py +64 -0
  48. megatron/core/models/bert/bert_lm_head.py +75 -0
  49. megatron/core/models/bert/bert_model.py +282 -0
  50. megatron/core/models/bert/pooler.py +51 -0
  51. megatron/core/models/common/__init__.py +0 -0
  52. megatron/core/models/common/embeddings/__init__.py +0 -0
  53. megatron/core/models/common/embeddings/language_model_embedding.py +128 -0
  54. megatron/core/models/common/embeddings/rotary_pos_embedding.py +249 -0
  55. megatron/core/models/common/language_module/__init__.py +0 -0
  56. megatron/core/models/common/language_module/language_module.py +105 -0
  57. megatron/core/models/gpt/__init__.py +1 -0
  58. megatron/core/models/gpt/gpt_layer_specs.py +99 -0
  59. megatron/core/models/gpt/gpt_model.py +247 -0
  60. megatron/core/models/retro/__init__.py +5 -0
  61. megatron/core/models/retro/base_attention.py +45 -0
  62. megatron/core/models/retro/config.py +43 -0
  63. megatron/core/models/retro/decoder_attention.py +301 -0
  64. megatron/core/models/retro/decoder_spec.py +152 -0
  65. megatron/core/models/retro/encoder_attention.py +223 -0
  66. megatron/core/models/retro/encoder_spec.py +141 -0
  67. megatron/core/models/retro/model.py +89 -0
  68. megatron/core/package_info.py +29 -0
  69. megatron/core/packed_seq_params.py +13 -0
  70. megatron/core/parallel_state.py +1014 -0
  71. megatron/core/pipeline_parallel/__init__.py +1 -0
  72. megatron/core/pipeline_parallel/p2p_communication.py +571 -0
  73. megatron/core/pipeline_parallel/schedules.py +1341 -0
  74. megatron/core/requirements.txt +1 -0
  75. megatron/core/tensor_parallel/__init__.py +65 -0
  76. megatron/core/tensor_parallel/cross_entropy.py +142 -0
  77. megatron/core/tensor_parallel/data.py +104 -0
  78. megatron/core/tensor_parallel/layers.py +998 -0
  79. megatron/core/tensor_parallel/mappings.py +358 -0
  80. megatron/core/tensor_parallel/random.py +266 -0
  81. megatron/core/tensor_parallel/utils.py +113 -0
  82. megatron/core/timers.py +391 -0
  83. megatron/core/transformer/__init__.py +6 -0
  84. megatron/core/transformer/attention.py +487 -0
  85. megatron/core/transformer/custom_layers/__init__.py +0 -0
  86. megatron/core/transformer/custom_layers/transformer_engine.py +495 -0
  87. megatron/core/transformer/dot_product_attention.py +205 -0
  88. megatron/core/transformer/enums.py +26 -0
  89. megatron/core/transformer/identity_op.py +28 -0
  90. megatron/core/transformer/mlp.py +188 -0
  91. megatron/core/transformer/module.py +185 -0
  92. megatron/core/transformer/moe/__init__.py +0 -0
  93. megatron/core/transformer/moe/experts.py +235 -0
  94. megatron/core/transformer/moe/grouped_gemm_util.py +20 -0
  95. megatron/core/transformer/moe/moe_layer.py +80 -0
  96. megatron/core/transformer/moe/moe_utils.py +101 -0
  97. megatron/core/transformer/moe/router.py +242 -0
  98. megatron/core/transformer/moe/token_dispatcher.py +279 -0
  99. megatron/core/transformer/spec_utils.py +109 -0
  100. megatron/core/transformer/transformer_block.py +418 -0
  101. megatron/core/transformer/transformer_config.py +256 -0
  102. megatron/core/transformer/transformer_layer.py +234 -0
  103. megatron/core/transformer/utils.py +184 -0
  104. megatron/core/utils.py +236 -0
  105. megatron_core-0.5.0.dist-info/LICENSE +291 -0
  106. megatron_core-0.5.0.dist-info/METADATA +34 -0
  107. megatron_core-0.5.0.dist-info/RECORD +109 -0
  108. megatron_core-0.5.0.dist-info/WHEEL +6 -0
  109. megatron_core-0.5.0.dist-info/top_level.txt +1 -0
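
The listing above can be verified directly against the published wheel. The short Python sketch below assumes the wheel has been downloaded into the working directory under its canonical PyPI filename (that filename and the download step are assumptions of the sketch, not something this page states); it simply counts and inspects the archive entries with the standard library.

import zipfile

# Hypothetical local copy of the wheel; adjust the path to wherever it was downloaded.
WHEEL = "megatron_core-0.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"

with zipfile.ZipFile(WHEEL) as wheel:
    names = wheel.namelist()

print(len(names))  # expected to match the 109 entries listed above
print([name for name in names if name.endswith(".so")])  # the prebuilt datasets helper extension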
megatron/core/__init__.py
@@ -0,0 +1,20 @@
+ import megatron.core.tensor_parallel
+ import megatron.core.utils
+ from megatron.core import parallel_state
+ from megatron.core.distributed import DistributedDataParallel
+ from megatron.core.inference_params import InferenceParams
+ from megatron.core.model_parallel_config import ModelParallelConfig
+ from megatron.core.timers import Timers
+
+ # Alias parallel_state as mpu, its legacy name
+ mpu = parallel_state
+
+ __all__ = [
+     "parallel_state",
+     "tensor_parallel",
+     "utils",
+     "DistributedDataParallel",
+     "InferenceParams",
+     "ModelParallelConfig",
+     "Timers",
+ ]
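
For orientation, the package __init__ above only re-exports the core public surface (parallel_state, tensor_parallel, DistributedDataParallel, and so on) and keeps mpu as a legacy alias for parallel_state. The sketch below shows one conventional way these re-exports get used; it assumes a process launched under torchrun with a CUDA device available, which is an assumption of the example rather than anything this diff prescribes.

import torch

from megatron.core import parallel_state  # also reachable via the legacy alias megatron.core.mpu

# Assumes torchrun has populated MASTER_ADDR/MASTER_PORT, RANK, and WORLD_SIZE.
torch.distributed.init_process_group(backend="nccl")

# Carve the world into tensor- and pipeline-parallel groups (trivial 1x1 split here).
parallel_state.initialize_model_parallel(
    tensor_model_parallel_size=1,
    pipeline_model_parallel_size=1,
)

print("TP rank:", parallel_state.get_tensor_model_parallel_rank())
print("PP rank:", parallel_state.get_pipeline_model_parallel_rank())

parallel_state.destroy_model_parallel()
torch.distributed.destroy_process_group()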
megatron/core/datasets/__init__.py: File without changes
megatron/core/datasets/bert_dataset.py
@@ -0,0 +1,207 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional, Union
+
+ import numpy
+
+ from megatron.core.datasets.indexed_dataset import MMapIndexedDataset
+ from megatron.core.datasets.masked_dataset import (
+     MaskedWordPieceDataset,
+     MaskedWordPieceDatasetConfig,
+ )
+ from megatron.core.datasets.utils import Split
+
+
+ @dataclass
+ class BERTMaskedWordPieceDatasetConfig(MaskedWordPieceDatasetConfig):
+     """Configuration object for Megatron Core BERT WordPiece datasets
+
+     Attributes:
+         classification_head (bool): Option to perform the next sequence prediction during
+         sampling
+     """
+
+     classification_head: bool = None
+
+     def __post_init__(self) -> None:
+         """Do asserts and set fields post init
+         """
+         super().__post_init__()
+
+         assert self.classification_head is not None
+
+
+ class BERTMaskedWordPieceDataset(MaskedWordPieceDataset):
+     """The BERT dataset that assumes WordPiece tokenization
+
+     Args:
+         indexed_dataset (MMapIndexedDataset): The MMapIndexedDataset around which to build the
+         MegatronDataset
+
+         dataset_path (str): The real path on disk to the dataset, for bookkeeping
+
+         indexed_indices (numpy.ndarray): The set of the documents indices to expose
+
+         num_samples (int): The number of samples to draw from the indexed dataset
+
+         index_split (Split): The indexed_indices Split
+
+         config (BERTMaskedWordPieceDatasetConfig): The config
+     """
+
+     def __init__(
+         self,
+         indexed_dataset: MMapIndexedDataset,
+         dataset_path: str,
+         indexed_indices: numpy.ndarray,
+         num_samples: int,
+         index_split: Split,
+         config: BERTMaskedWordPieceDatasetConfig,
+     ) -> None:
+         super().__init__(
+             indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config
+         )
+
+     def _finalize(self) -> None:
+         """Abstract method implementation
+         """
+         self.token_lookup = list(self.config.tokenizer.inv_vocab.keys())
+         # Account for the single <cls> and two <sep> token ids
+         self.sample_index = self._build_sample_index(
+             self.config.sequence_length - 3, 2 if self.config.classification_head else 1
+         )
+
+     @staticmethod
+     def _key_config_attributes() -> List[str]:
+         """Inherited method implementation
+
+         Returns:
+             List[str]: The key config attributes
+         """
+         return super(
+             BERTMaskedWordPieceDataset, BERTMaskedWordPieceDataset
+         )._key_config_attributes() + ["classification_head",]
+
+     def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]:
+         """Abstract method implementation
+
+         Args:
+             idx (int): The index into the dataset
+
+         Returns:
+             Dict[str, Union[int, numpy.ndarray]]: The
+         """
+         idx_beg, idx_end, target_sequence_length = self.sample_index[idx]
+         sample = [self.dataset[i] for i in range(idx_beg, idx_end)]
+         numpy_random_state = numpy.random.RandomState(
+             seed=(self.config.random_seed + idx) % 2 ** 32
+         )
+
+         assert target_sequence_length <= self.config.sequence_length
+
+         # Split the sample into contiguous subsegments A and B
+         pivot = len(sample)
+         is_next_random = False
+         if self.config.classification_head:
+             assert len(sample) > 1, "the sample must contain at least two sentences"
+             pivot = 1
+             if len(sample) >= 3:
+                 pivot = numpy_random_state.randint(low=1, high=len(sample))
+             is_next_random = numpy_random_state.random() < 0.5
+         split_A = []
+         for sample_a in sample[:pivot]:
+             split_A.extend(sample_a)
+         split_B = []
+         for sample_b in sample[pivot:]:
+             split_B.extend(sample_b)
+         if is_next_random:
+             split_A, split_B = split_B, split_A
+
+         # Trim the subsegments from either end to a desired joint length
+         length_A = len(split_A)
+         length_B = len(split_B)
+         if length_A + length_B <= target_sequence_length:
+             truncated = False
+         else:
+             while length_A + length_B > target_sequence_length:
+                 split = split_A if length_A > length_B else split_B
+                 if numpy_random_state.random() < 0.5:
+                     del split[0]
+                 else:
+                     del split[-1]
+                 length_A = len(split_A)
+                 length_B = len(split_B)
+             truncated = True
+
+         # Merge the subsegments and create the token assignment labels
+         tokens = [
+             self.config.tokenizer.cls,
+             *split_A,
+             self.config.tokenizer.sep,
+         ]
+         assignments = [0 for _ in range(1 + len(split_A) + 1)]
+         if split_B:
+             tokens += [*split_B, self.config.tokenizer.sep]
+             assignments += [1 for _ in range(len(split_B) + 1)]
+
+         # Masking
+         tokens, masked_positions, masked_labels, _, _ = self._create_masked_lm_predictions(
+             tokens, target_sequence_length, numpy_random_state
+         )
+
+         # Pad the sequences and convert to NumPy
+         length_toks = len(tokens)
+         length_pads = self.config.sequence_length - length_toks
+         assert length_pads >= 0
+
+         tokens = numpy.array(tokens, dtype=numpy.int64)
+         tokens = numpy.pad(tokens, (0, length_pads), constant_values=self.config.tokenizer.pad)
+
+         assignments = numpy.array(assignments, dtype=numpy.int64)
+         assignments = numpy.pad(
+             assignments, (0, length_pads), constant_values=self.config.tokenizer.pad
+         )
+
+         # Get the padding mask
+         mask_pads = numpy.ones(length_toks, dtype=numpy.int64)
+         mask_pads = numpy.pad(
+             mask_pads, (0, length_pads), constant_values=self.config.tokenizer.pad
+         )
+
+         # Mask the labels
+         labels = numpy.zeros(self.config.sequence_length, dtype=numpy.int64) - 1
+         labels[masked_positions] = masked_labels
+
+         # Get the loss mask
+         mask_loss = numpy.zeros(self.config.sequence_length, dtype=numpy.int64)
+         mask_loss[masked_positions] = 1
+
+         return {
+             "text": tokens,
+             "types": assignments,
+             "labels": labels,
+             "is_random": int(is_next_random),
+             "padding_mask": mask_pads,
+             "loss_mask": mask_loss,
+             "truncated": int(truncated),
+         }
+
+     def _get_token_mask(self, numpy_random_state: numpy.random.RandomState) -> Optional[int]:
+         """Abstract method implementation
+
+         80% of the time, replace the token id with mask token id. 10% of the time, replace token id
+         with a random token id from the vocabulary. 10% of the time, do nothing.
+
+         Args:
+             numpy_random_state (RandomState): The NumPy random state
+
+         Returns:
+             Optional[int]: The replacement token id or None
+         """
+         if numpy_random_state.random() < 0.8:
+             return self.config.tokenizer.mask
+         else:
+             if numpy_random_state.random() >= 0.5:
+                 return self.token_lookup[numpy_random_state.randint(0, len(self.token_lookup))]
+         return None
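
The _get_token_mask policy above is the familiar BERT 80/10/10 rule: mask the token 80% of the time, swap in a random vocabulary id 10% of the time, and leave it unchanged (return None) the remaining 10%. The standalone sketch below mirrors that two-draw branch structure with a hypothetical mask id and vocabulary size (neither is taken from this diff) and tallies the empirical split.

import numpy

MASK_ID = 103       # hypothetical mask token id, for illustration only
VOCAB_SIZE = 30522  # hypothetical vocabulary size

def token_mask(rs: numpy.random.RandomState):
    # First draw: 80% of the time, return the mask id.
    if rs.random() < 0.8:
        return MASK_ID
    # Second draw: half of the remaining 20% gets a random id, half is left alone.
    if rs.random() >= 0.5:
        return int(rs.randint(0, VOCAB_SIZE))
    return None

rs = numpy.random.RandomState(seed=1234)
counts = {"mask": 0, "random": 0, "keep": 0}
for _ in range(100_000):
    out = token_mask(rs)
    if out is None:
        counts["keep"] += 1
    elif out == MASK_ID:
        counts["mask"] += 1  # chance collisions from the random branch are negligible here
    else:
        counts["random"] += 1

print({k: v / 100_000 for k, v in counts.items()})  # roughly {'mask': 0.8, 'random': 0.1, 'keep': 0.1}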
megatron/core/datasets/blended_dataset.py
@@ -0,0 +1,192 @@
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+
+ import hashlib
+ import json
+ import logging
+ import os
+ import time
+ from collections import OrderedDict
+ from typing import Dict, List, Tuple, Union
+
+ import numpy
+ import torch
+
+ from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
+ from megatron.core.datasets.megatron_dataset import MegatronDataset
+ from megatron.core.datasets.utils import log_single_rank, normalize
+
+ logger = logging.getLogger(__name__)
+
+ _VERBOSE = False
+
+
+ class BlendedDataset(torch.utils.data.Dataset):
+     """Conjugating class for a set of MegatronDataset instances
+
+     Args:
+         datasets (List[MegatronDataset]): The MegatronDataset instances to blend
+
+         weights (List[float]): The weights which determines the dataset blend ratios
+
+         size (int): The number of samples to draw from the blend
+
+         config (BlendedMegatronDatasetConfig): The config
+
+     Raises:
+         RuntimeError: When the dataset has fewer or more samples than 'size' post-initialization
+     """
+
+     def __init__(
+         self,
+         datasets: List[MegatronDataset],
+         weights: List[float],
+         size: int,
+         config: BlendedMegatronDatasetConfig,
+     ) -> None:
+         assert len(datasets) < 32767
+         assert len(datasets) == len(weights)
+         assert numpy.isclose(sum(weights), 1.0)
+         assert all(map(lambda _: type(_) == type(datasets[0]), datasets))
+
+         # Alert user to unnecessary blending
+         if len(datasets) == 1:
+             log_single_rank(
+                 logger, logging.WARNING, f"Building a BlendedDataset for a single MegatronDataset"
+             )
+
+         # Redundant normalization for bitwise identical comparison with Megatron-LM
+         weights = normalize(weights)
+
+         self.datasets = datasets
+         self.weights = weights
+         self.size = size
+         self.config = config
+
+         unique_identifiers = OrderedDict()
+         unique_identifiers["class"] = type(self).__name__
+         unique_identifiers["datasets"] = [dataset.unique_identifiers for dataset in self.datasets]
+         unique_identifiers["weights"] = self.weights
+         unique_identifiers["size"] = self.size
+
+         self.unique_description = json.dumps(
+             unique_identifiers, indent=4, default=lambda obj: obj.unique_identifiers
+         )
+         self.unique_description_hash = hashlib.md5(
+             self.unique_description.encode("utf-8")
+         ).hexdigest()
+
+         self.dataset_index, self.dataset_sample_index = self._build_indices()
+
+         # Check size
+         _ = self[self.size - 1]
+         try:
+             _ = self[self.size]
+             raise RuntimeError(f"{type(self).__name__} size is improperly bounded")
+         except IndexError:
+             log_single_rank(logger, logging.INFO, f"> {type(self).__name__} length: {len(self)}")
+
+     def __len__(self) -> int:
+         return self.size
+
+     def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]:
+         dataset_id = self.dataset_index[idx]
+         dataset_sample_id = self.dataset_sample_index[idx]
+         return {
+             "dataset_id": dataset_id,
+             **self.datasets[dataset_id][dataset_sample_id],
+         }
+
+     def _build_indices(self) -> Tuple[numpy.ndarray, numpy.ndarray]:
+         """Build and optionally cache the dataset index and the dataset sample index
+
+         The dataset index is a 1-D mapping which determines the dataset to query. The dataset
+         sample index is a 1-D mapping which determines the sample to request from the queried
+         dataset.
+
+         Returns:
+             Tuple[numpy.ndarray, numpy.ndarray]: The dataset index and the dataset sample index
+         """
+         path_to_cache = self.config.path_to_cache
+
+         if path_to_cache:
+             get_path_to = lambda suffix: os.path.join(
+                 path_to_cache, f"{self.unique_description_hash}-{type(self).__name__}-{suffix}"
+             )
+             path_to_description = get_path_to("description.txt")
+             path_to_dataset_index = get_path_to("dataset_index.npy")
+             path_to_dataset_sample_index = get_path_to("dataset_sample_index.npy")
+             cache_hit = all(
+                 map(
+                     os.path.isfile,
+                     [path_to_description, path_to_dataset_index, path_to_dataset_sample_index],
+                 )
+             )
+         else:
+             cache_hit = False
+
+         if not path_to_cache or (not cache_hit and torch.distributed.get_rank() == 0):
+             log_single_rank(
+                 logger, logging.INFO, f"Build and save the {type(self).__name__} indices",
+             )
+
+             # Build the dataset and dataset sample indexes
+             log_single_rank(
+                 logger, logging.INFO, f"\tBuild and save the dataset and dataset sample indexes"
+             )
+             t_beg = time.time()
+             from megatron.core.datasets import helpers
+
+             dataset_index = numpy.zeros(self.size, dtype=numpy.int16)
+             dataset_sample_index = numpy.zeros(self.size, dtype=numpy.int64)
+             helpers.build_blending_indices(
+                 dataset_index,
+                 dataset_sample_index,
+                 self.weights,
+                 len(self.datasets),
+                 self.size,
+                 _VERBOSE,
+             )
+
+             if path_to_cache:
+                 os.makedirs(path_to_cache, exist_ok=True)
+                 # Write the description
+                 with open(path_to_description, "wt") as writer:
+                     writer.write(self.unique_description)
+                 # Save the indexes
+                 numpy.save(path_to_dataset_index, dataset_index, allow_pickle=True)
+                 numpy.save(path_to_dataset_sample_index, dataset_sample_index, allow_pickle=True)
+             else:
+                 log_single_rank(
+                     logger,
+                     logging.WARNING,
+                     "Unable to save the indexes because path_to_cache is None",
+                 )
+
+             t_end = time.time()
+             log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
+
+             return dataset_index, dataset_sample_index
+
+         log_single_rank(logger, logging.INFO, f"Load the {type(self).__name__} indices")
+
+         log_single_rank(
+             logger, logging.INFO, f"\tLoad the dataset index from {path_to_dataset_index}"
+         )
+         t_beg = time.time()
+         dataset_index = numpy.load(path_to_dataset_index, allow_pickle=True, mmap_mode='r')
+         t_end = time.time()
+         log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
+
+         log_single_rank(
+             logger,
+             logging.INFO,
+             f"\tLoad the dataset sample index from {path_to_dataset_sample_index}",
+         )
+         t_beg = time.time()
+         dataset_sample_index = numpy.load(
+             path_to_dataset_sample_index, allow_pickle=True, mmap_mode='r'
+         )
+         t_end = time.time()
+         log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
+
+         return dataset_index, dataset_sample_index
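
_build_indices delegates the actual blending to the compiled helper helpers.build_blending_indices (built from helpers.cpp in the file list above) and mainly handles caching and logging around it. As a rough sketch of the greedy scheme that helper implements, each global sample is assigned to whichever dataset currently lags its target weight the most, and that dataset's running sample counter becomes the local sample id. The pure-NumPy stand-in below is illustrative only, not the shipped C++ routine, and its exact tie-breaking may differ.

import numpy

def build_blending_indices_py(weights, size):
    """Illustrative pure-Python stand-in for helpers.build_blending_indices."""
    weights = numpy.asarray(weights, dtype=numpy.float64)
    dataset_index = numpy.zeros(size, dtype=numpy.int16)
    dataset_sample_index = numpy.zeros(size, dtype=numpy.int64)
    current_samples = numpy.zeros(len(weights), dtype=numpy.int64)

    for i in range(size):
        # How far each dataset falls short of its target share after i samples.
        errors = weights * max(i, 1) - current_samples
        choice = int(numpy.argmax(errors))
        dataset_index[i] = choice
        dataset_sample_index[i] = current_samples[choice]
        current_samples[choice] += 1

    return dataset_index, dataset_sample_index

# Example: a 70/30 blend over 10 samples.
d_idx, d_smp = build_blending_indices_py([0.7, 0.3], 10)
print(d_idx.tolist())  # roughly seven 0s and three 1s, interleaved
print(d_smp.tolist())  # per-dataset running sample counters at assignment time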