megatron-core 0.5.0 (cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of megatron-core might be problematic.
- megatron/core/__init__.py +20 -0
- megatron/core/datasets/__init__.py +0 -0
- megatron/core/datasets/bert_dataset.py +207 -0
- megatron/core/datasets/blended_dataset.py +192 -0
- megatron/core/datasets/blended_megatron_dataset_builder.py +329 -0
- megatron/core/datasets/blended_megatron_dataset_config.py +171 -0
- megatron/core/datasets/gpt_dataset.py +642 -0
- megatron/core/datasets/helpers.cpp +765 -0
- megatron/core/datasets/helpers.cpython-310-x86_64-linux-gnu.so +0 -0
- megatron/core/datasets/indexed_dataset.py +639 -0
- megatron/core/datasets/masked_dataset.py +430 -0
- megatron/core/datasets/megatron_dataset.py +187 -0
- megatron/core/datasets/megatron_tokenizer.py +141 -0
- megatron/core/datasets/t5_dataset.py +239 -0
- megatron/core/datasets/utils.py +64 -0
- megatron/core/dist_checkpointing/__init__.py +11 -0
- megatron/core/dist_checkpointing/core.py +77 -0
- megatron/core/dist_checkpointing/dict_utils.py +232 -0
- megatron/core/dist_checkpointing/mapping.py +346 -0
- megatron/core/dist_checkpointing/optimizer.py +127 -0
- megatron/core/dist_checkpointing/serialization.py +453 -0
- megatron/core/dist_checkpointing/strategies/__init__.py +20 -0
- megatron/core/dist_checkpointing/strategies/base.py +105 -0
- megatron/core/dist_checkpointing/strategies/tensorstore.py +131 -0
- megatron/core/dist_checkpointing/strategies/two_stage.py +255 -0
- megatron/core/dist_checkpointing/strategies/zarr.py +298 -0
- megatron/core/dist_checkpointing/utils.py +139 -0
- megatron/core/distributed/__init__.py +2 -0
- megatron/core/distributed/distributed_data_parallel.py +250 -0
- megatron/core/distributed/finalize_model_grads.py +158 -0
- megatron/core/distributed/grad_buffer.py +426 -0
- megatron/core/enums.py +10 -0
- megatron/core/fusions/__init__.py +0 -0
- megatron/core/fusions/fused_bias_dropout.py +73 -0
- megatron/core/fusions/fused_bias_gelu.py +50 -0
- megatron/core/fusions/fused_bias_swiglu.py +81 -0
- megatron/core/fusions/fused_layer_norm.py +172 -0
- megatron/core/fusions/fused_softmax.py +220 -0
- megatron/core/inference_params.py +27 -0
- megatron/core/jit.py +11 -0
- megatron/core/model_parallel_config.py +247 -0
- megatron/core/models/T5/__init__.py +1 -0
- megatron/core/models/T5/t5_model.py +428 -0
- megatron/core/models/T5/t5_spec.py +220 -0
- megatron/core/models/__init__.py +0 -0
- megatron/core/models/bert/__init__.py +0 -0
- megatron/core/models/bert/bert_layer_specs.py +64 -0
- megatron/core/models/bert/bert_lm_head.py +75 -0
- megatron/core/models/bert/bert_model.py +282 -0
- megatron/core/models/bert/pooler.py +51 -0
- megatron/core/models/common/__init__.py +0 -0
- megatron/core/models/common/embeddings/__init__.py +0 -0
- megatron/core/models/common/embeddings/language_model_embedding.py +128 -0
- megatron/core/models/common/embeddings/rotary_pos_embedding.py +249 -0
- megatron/core/models/common/language_module/__init__.py +0 -0
- megatron/core/models/common/language_module/language_module.py +105 -0
- megatron/core/models/gpt/__init__.py +1 -0
- megatron/core/models/gpt/gpt_layer_specs.py +99 -0
- megatron/core/models/gpt/gpt_model.py +247 -0
- megatron/core/models/retro/__init__.py +5 -0
- megatron/core/models/retro/base_attention.py +45 -0
- megatron/core/models/retro/config.py +43 -0
- megatron/core/models/retro/decoder_attention.py +301 -0
- megatron/core/models/retro/decoder_spec.py +152 -0
- megatron/core/models/retro/encoder_attention.py +223 -0
- megatron/core/models/retro/encoder_spec.py +141 -0
- megatron/core/models/retro/model.py +89 -0
- megatron/core/package_info.py +29 -0
- megatron/core/packed_seq_params.py +13 -0
- megatron/core/parallel_state.py +1014 -0
- megatron/core/pipeline_parallel/__init__.py +1 -0
- megatron/core/pipeline_parallel/p2p_communication.py +571 -0
- megatron/core/pipeline_parallel/schedules.py +1341 -0
- megatron/core/requirements.txt +1 -0
- megatron/core/tensor_parallel/__init__.py +65 -0
- megatron/core/tensor_parallel/cross_entropy.py +142 -0
- megatron/core/tensor_parallel/data.py +104 -0
- megatron/core/tensor_parallel/layers.py +998 -0
- megatron/core/tensor_parallel/mappings.py +358 -0
- megatron/core/tensor_parallel/random.py +266 -0
- megatron/core/tensor_parallel/utils.py +113 -0
- megatron/core/timers.py +391 -0
- megatron/core/transformer/__init__.py +6 -0
- megatron/core/transformer/attention.py +487 -0
- megatron/core/transformer/custom_layers/__init__.py +0 -0
- megatron/core/transformer/custom_layers/transformer_engine.py +495 -0
- megatron/core/transformer/dot_product_attention.py +205 -0
- megatron/core/transformer/enums.py +26 -0
- megatron/core/transformer/identity_op.py +28 -0
- megatron/core/transformer/mlp.py +188 -0
- megatron/core/transformer/module.py +185 -0
- megatron/core/transformer/moe/__init__.py +0 -0
- megatron/core/transformer/moe/experts.py +235 -0
- megatron/core/transformer/moe/grouped_gemm_util.py +20 -0
- megatron/core/transformer/moe/moe_layer.py +80 -0
- megatron/core/transformer/moe/moe_utils.py +101 -0
- megatron/core/transformer/moe/router.py +242 -0
- megatron/core/transformer/moe/token_dispatcher.py +279 -0
- megatron/core/transformer/spec_utils.py +109 -0
- megatron/core/transformer/transformer_block.py +418 -0
- megatron/core/transformer/transformer_config.py +256 -0
- megatron/core/transformer/transformer_layer.py +234 -0
- megatron/core/transformer/utils.py +184 -0
- megatron/core/utils.py +236 -0
- megatron_core-0.5.0.dist-info/LICENSE +291 -0
- megatron_core-0.5.0.dist-info/METADATA +34 -0
- megatron_core-0.5.0.dist-info/RECORD +109 -0
- megatron_core-0.5.0.dist-info/WHEEL +6 -0
- megatron_core-0.5.0.dist-info/top_level.txt +1 -0
megatron/core/__init__.py

@@ -0,0 +1,20 @@
+import megatron.core.tensor_parallel
+import megatron.core.utils
+from megatron.core import parallel_state
+from megatron.core.distributed import DistributedDataParallel
+from megatron.core.inference_params import InferenceParams
+from megatron.core.model_parallel_config import ModelParallelConfig
+from megatron.core.timers import Timers
+
+# Alias parallel_state as mpu, its legacy name
+mpu = parallel_state
+
+__all__ = [
+    "parallel_state",
+    "tensor_parallel",
+    "utils",
+    "DistributedDataParallel",
+    "InferenceParams",
+    "ModelParallelConfig",
+    "Timers",
+]
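The package `__init__.py` above keeps the public surface small: the parallel-state, tensor-parallel, and utility modules plus a handful of classes, with `mpu` retained as a legacy alias for `parallel_state`. A minimal sketch of how downstream code consumes these re-exports; it assumes only that the megatron-core wheel is installed and uses no names beyond those exposed above:

```python
# Sketch: consuming the re-exports defined in megatron/core/__init__.py above.
# No distributed initialization is performed here.
import megatron.core as core
from megatron.core import (
    DistributedDataParallel,
    InferenceParams,
    ModelParallelConfig,
    Timers,
    parallel_state,
    tensor_parallel,
    utils,
)

# "mpu" is simply the legacy alias bound to parallel_state in __init__.py.
assert core.mpu is parallel_state
```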
megatron/core/datasets/__init__.py: file without changes (empty file, +0 -0).
megatron/core/datasets/bert_dataset.py

@@ -0,0 +1,207 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Union
+
+import numpy
+
+from megatron.core.datasets.indexed_dataset import MMapIndexedDataset
+from megatron.core.datasets.masked_dataset import (
+    MaskedWordPieceDataset,
+    MaskedWordPieceDatasetConfig,
+)
+from megatron.core.datasets.utils import Split
+
+
+@dataclass
+class BERTMaskedWordPieceDatasetConfig(MaskedWordPieceDatasetConfig):
+    """Configuration object for Megatron Core BERT WordPiece datasets
+
+    Attributes:
+        classification_head (bool): Option to perform the next sequence prediction during
+        sampling
+    """
+
+    classification_head: bool = None
+
+    def __post_init__(self) -> None:
+        """Do asserts and set fields post init
+        """
+        super().__post_init__()
+
+        assert self.classification_head is not None
+
+
+class BERTMaskedWordPieceDataset(MaskedWordPieceDataset):
+    """The BERT dataset that assumes WordPiece tokenization
+
+    Args:
+        indexed_dataset (MMapIndexedDataset): The MMapIndexedDataset around which to build the
+        MegatronDataset
+
+        dataset_path (str): The real path on disk to the dataset, for bookkeeping
+
+        indexed_indices (numpy.ndarray): The set of the documents indices to expose
+
+        num_samples (int): The number of samples to draw from the indexed dataset
+
+        index_split (Split): The indexed_indices Split
+
+        config (BERTMaskedWordPieceDatasetConfig): The config
+    """
+
+    def __init__(
+        self,
+        indexed_dataset: MMapIndexedDataset,
+        dataset_path: str,
+        indexed_indices: numpy.ndarray,
+        num_samples: int,
+        index_split: Split,
+        config: BERTMaskedWordPieceDatasetConfig,
+    ) -> None:
+        super().__init__(
+            indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config
+        )
+
+    def _finalize(self) -> None:
+        """Abstract method implementation
+        """
+        self.token_lookup = list(self.config.tokenizer.inv_vocab.keys())
+        # Account for the single <cls> and two <sep> token ids
+        self.sample_index = self._build_sample_index(
+            self.config.sequence_length - 3, 2 if self.config.classification_head else 1
+        )
+
+    @staticmethod
+    def _key_config_attributes() -> List[str]:
+        """Inherited method implementation
+
+        Returns:
+            List[str]: The key config attributes
+        """
+        return super(
+            BERTMaskedWordPieceDataset, BERTMaskedWordPieceDataset
+        )._key_config_attributes() + ["classification_head",]
+
+    def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]:
+        """Abstract method implementation
+
+        Args:
+            idx (int): The index into the dataset
+
+        Returns:
+            Dict[str, Union[int, numpy.ndarray]]: The
+        """
+        idx_beg, idx_end, target_sequence_length = self.sample_index[idx]
+        sample = [self.dataset[i] for i in range(idx_beg, idx_end)]
+        numpy_random_state = numpy.random.RandomState(
+            seed=(self.config.random_seed + idx) % 2 ** 32
+        )
+
+        assert target_sequence_length <= self.config.sequence_length
+
+        # Split the sample into contiguous subsegments A and B
+        pivot = len(sample)
+        is_next_random = False
+        if self.config.classification_head:
+            assert len(sample) > 1, "the sample must contain at least two sentences"
+            pivot = 1
+            if len(sample) >= 3:
+                pivot = numpy_random_state.randint(low=1, high=len(sample))
+            is_next_random = numpy_random_state.random() < 0.5
+        split_A = []
+        for sample_a in sample[:pivot]:
+            split_A.extend(sample_a)
+        split_B = []
+        for sample_b in sample[pivot:]:
+            split_B.extend(sample_b)
+        if is_next_random:
+            split_A, split_B = split_B, split_A
+
+        # Trim the subsegments from either end to a desired joint length
+        length_A = len(split_A)
+        length_B = len(split_B)
+        if length_A + length_B <= target_sequence_length:
+            truncated = False
+        else:
+            while length_A + length_B > target_sequence_length:
+                split = split_A if length_A > length_B else split_B
+                if numpy_random_state.random() < 0.5:
+                    del split[0]
+                else:
+                    del split[-1]
+                length_A = len(split_A)
+                length_B = len(split_B)
+            truncated = True
+
+        # Merge the subsegments and create the token assignment labels
+        tokens = [
+            self.config.tokenizer.cls,
+            *split_A,
+            self.config.tokenizer.sep,
+        ]
+        assignments = [0 for _ in range(1 + len(split_A) + 1)]
+        if split_B:
+            tokens += [*split_B, self.config.tokenizer.sep]
+            assignments += [1 for _ in range(len(split_B) + 1)]
+
+        # Masking
+        tokens, masked_positions, masked_labels, _, _ = self._create_masked_lm_predictions(
+            tokens, target_sequence_length, numpy_random_state
+        )
+
+        # Pad the sequences and convert to NumPy
+        length_toks = len(tokens)
+        length_pads = self.config.sequence_length - length_toks
+        assert length_pads >= 0
+
+        tokens = numpy.array(tokens, dtype=numpy.int64)
+        tokens = numpy.pad(tokens, (0, length_pads), constant_values=self.config.tokenizer.pad)
+
+        assignments = numpy.array(assignments, dtype=numpy.int64)
+        assignments = numpy.pad(
+            assignments, (0, length_pads), constant_values=self.config.tokenizer.pad
+        )
+
+        # Get the padding mask
+        mask_pads = numpy.ones(length_toks, dtype=numpy.int64)
+        mask_pads = numpy.pad(
+            mask_pads, (0, length_pads), constant_values=self.config.tokenizer.pad
+        )
+
+        # Mask the labels
+        labels = numpy.zeros(self.config.sequence_length, dtype=numpy.int64) - 1
+        labels[masked_positions] = masked_labels
+
+        # Get the loss mask
+        mask_loss = numpy.zeros(self.config.sequence_length, dtype=numpy.int64)
+        mask_loss[masked_positions] = 1
+
+        return {
+            "text": tokens,
+            "types": assignments,
+            "labels": labels,
+            "is_random": int(is_next_random),
+            "padding_mask": mask_pads,
+            "loss_mask": mask_loss,
+            "truncated": int(truncated),
+        }
+
+    def _get_token_mask(self, numpy_random_state: numpy.random.RandomState) -> Optional[int]:
+        """Abstract method implementation
+
+        80% of the time, replace the token id with mask token id. 10% of the time, replace token id
+        with a random token id from the vocabulary. 10% of the time, do nothing.
+
+        Args:
+            numpy_random_state (RandomState): The NumPy random state
+
+        Returns:
+            Optional[int]: The replacement token id or None
+        """
+        if numpy_random_state.random() < 0.8:
+            return self.config.tokenizer.mask
+        else:
+            if numpy_random_state.random() >= 0.5:
+                return self.token_lookup[numpy_random_state.randint(0, len(self.token_lookup))]
+            return None
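Two details of the dataset above are worth calling out. First, `_finalize` budgets `sequence_length - 3` tokens per sample because, per the in-code comment, each sample is later wrapped with one <cls> and up to two <sep> token ids. Second, `_get_token_mask` implements the usual BERT corruption rule: roughly 80% of selected positions are replaced by the mask token, 10% by a random vocabulary token, and 10% are left unchanged (signalled by returning `None`). The snippet below is a self-contained illustration of that decision rule only; `MASK_ID` and the toy vocabulary are placeholders invented for the example, and only the branching logic mirrors the dataset code.

```python
# Self-contained illustration of the 80/10/10 replacement rule implemented by
# BERTMaskedWordPieceDataset._get_token_mask above.
import numpy

MASK_ID = 0
TOKEN_LOOKUP = list(range(1, 101))  # toy vocabulary of 100 token ids

def get_token_mask(rng: numpy.random.RandomState):
    if rng.random() < 0.8:
        return MASK_ID  # ~80%: replace with the mask token
    if rng.random() >= 0.5:  # the remaining ~20% is split evenly
        return TOKEN_LOOKUP[rng.randint(0, len(TOKEN_LOOKUP))]  # ~10%: random token
    return None  # ~10%: keep the original token id

rng = numpy.random.RandomState(seed=1234)
draws = [get_token_mask(rng) for _ in range(100_000)]
frac_mask = sum(d == MASK_ID for d in draws) / len(draws)
frac_keep = sum(d is None for d in draws) / len(draws)
print(f"mask: {frac_mask:.3f}  random: {1 - frac_mask - frac_keep:.3f}  keep: {frac_keep:.3f}")
# Expected output is approximately: mask: 0.800  random: 0.100  keep: 0.100
```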
megatron/core/datasets/blended_dataset.py

@@ -0,0 +1,192 @@
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+
+import hashlib
+import json
+import logging
+import os
+import time
+from collections import OrderedDict
+from typing import Dict, List, Tuple, Union
+
+import numpy
+import torch
+
+from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
+from megatron.core.datasets.megatron_dataset import MegatronDataset
+from megatron.core.datasets.utils import log_single_rank, normalize
+
+logger = logging.getLogger(__name__)
+
+_VERBOSE = False
+
+
+class BlendedDataset(torch.utils.data.Dataset):
+    """Conjugating class for a set of MegatronDataset instances
+
+    Args:
+        datasets (List[MegatronDataset]): The MegatronDataset instances to blend
+
+        weights (List[float]): The weights which determines the dataset blend ratios
+
+        size (int): The number of samples to draw from the blend
+
+        config (BlendedMegatronDatasetConfig): The config
+
+    Raises:
+        RuntimeError: When the dataset has fewer or more samples than 'size' post-initialization
+    """
+
+    def __init__(
+        self,
+        datasets: List[MegatronDataset],
+        weights: List[float],
+        size: int,
+        config: BlendedMegatronDatasetConfig,
+    ) -> None:
+        assert len(datasets) < 32767
+        assert len(datasets) == len(weights)
+        assert numpy.isclose(sum(weights), 1.0)
+        assert all(map(lambda _: type(_) == type(datasets[0]), datasets))
+
+        # Alert user to unnecessary blending
+        if len(datasets) == 1:
+            log_single_rank(
+                logger, logging.WARNING, f"Building a BlendedDataset for a single MegatronDataset"
+            )
+
+        # Redundant normalization for bitwise identical comparison with Megatron-LM
+        weights = normalize(weights)
+
+        self.datasets = datasets
+        self.weights = weights
+        self.size = size
+        self.config = config
+
+        unique_identifiers = OrderedDict()
+        unique_identifiers["class"] = type(self).__name__
+        unique_identifiers["datasets"] = [dataset.unique_identifiers for dataset in self.datasets]
+        unique_identifiers["weights"] = self.weights
+        unique_identifiers["size"] = self.size
+
+        self.unique_description = json.dumps(
+            unique_identifiers, indent=4, default=lambda obj: obj.unique_identifiers
+        )
+        self.unique_description_hash = hashlib.md5(
+            self.unique_description.encode("utf-8")
+        ).hexdigest()
+
+        self.dataset_index, self.dataset_sample_index = self._build_indices()
+
+        # Check size
+        _ = self[self.size - 1]
+        try:
+            _ = self[self.size]
+            raise RuntimeError(f"{type(self).__name__} size is improperly bounded")
+        except IndexError:
+            log_single_rank(logger, logging.INFO, f"> {type(self).__name__} length: {len(self)}")
+
+    def __len__(self) -> int:
+        return self.size
+
+    def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]:
+        dataset_id = self.dataset_index[idx]
+        dataset_sample_id = self.dataset_sample_index[idx]
+        return {
+            "dataset_id": dataset_id,
+            **self.datasets[dataset_id][dataset_sample_id],
+        }
+
+    def _build_indices(self) -> Tuple[numpy.ndarray, numpy.ndarray]:
+        """Build and optionally cache the dataset index and the dataset sample index
+
+        The dataset index is a 1-D mapping which determines the dataset to query. The dataset
+        sample index is a 1-D mapping which determines the sample to request from the queried
+        dataset.
+
+        Returns:
+            Tuple[numpy.ndarray, numpy.ndarray]: The dataset index and the dataset sample index
+        """
+        path_to_cache = self.config.path_to_cache
+
+        if path_to_cache:
+            get_path_to = lambda suffix: os.path.join(
+                path_to_cache, f"{self.unique_description_hash}-{type(self).__name__}-{suffix}"
+            )
+            path_to_description = get_path_to("description.txt")
+            path_to_dataset_index = get_path_to("dataset_index.npy")
+            path_to_dataset_sample_index = get_path_to("dataset_sample_index.npy")
+            cache_hit = all(
+                map(
+                    os.path.isfile,
+                    [path_to_description, path_to_dataset_index, path_to_dataset_sample_index],
+                )
+            )
+        else:
+            cache_hit = False
+
+        if not path_to_cache or (not cache_hit and torch.distributed.get_rank() == 0):
+            log_single_rank(
+                logger, logging.INFO, f"Build and save the {type(self).__name__} indices",
+            )
+
+            # Build the dataset and dataset sample indexes
+            log_single_rank(
+                logger, logging.INFO, f"\tBuild and save the dataset and dataset sample indexes"
+            )
+            t_beg = time.time()
+            from megatron.core.datasets import helpers
+
+            dataset_index = numpy.zeros(self.size, dtype=numpy.int16)
+            dataset_sample_index = numpy.zeros(self.size, dtype=numpy.int64)
+            helpers.build_blending_indices(
+                dataset_index,
+                dataset_sample_index,
+                self.weights,
+                len(self.datasets),
+                self.size,
+                _VERBOSE,
+            )
+
+            if path_to_cache:
+                os.makedirs(path_to_cache, exist_ok=True)
+                # Write the description
+                with open(path_to_description, "wt") as writer:
+                    writer.write(self.unique_description)
+                # Save the indexes
+                numpy.save(path_to_dataset_index, dataset_index, allow_pickle=True)
+                numpy.save(path_to_dataset_sample_index, dataset_sample_index, allow_pickle=True)
+            else:
+                log_single_rank(
+                    logger,
+                    logging.WARNING,
+                    "Unable to save the indexes because path_to_cache is None",
+                )
+
+            t_end = time.time()
+            log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
+
+            return dataset_index, dataset_sample_index
+
+        log_single_rank(logger, logging.INFO, f"Load the {type(self).__name__} indices")
+
+        log_single_rank(
+            logger, logging.INFO, f"\tLoad the dataset index from {path_to_dataset_index}"
+        )
+        t_beg = time.time()
+        dataset_index = numpy.load(path_to_dataset_index, allow_pickle=True, mmap_mode='r')
+        t_end = time.time()
+        log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
+
+        log_single_rank(
+            logger,
+            logging.INFO,
+            f"\tLoad the dataset sample index from {path_to_dataset_sample_index}",
+        )
+        t_beg = time.time()
+        dataset_sample_index = numpy.load(
+            path_to_dataset_sample_index, allow_pickle=True, mmap_mode='r'
+        )
+        t_end = time.time()
+        log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
+
+        return dataset_index, dataset_sample_index
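`BlendedDataset` stores only two flat arrays: `dataset_index[i]` names which constituent dataset global sample `i` comes from, and `dataset_sample_index[i]` names the sample within that dataset. The arrays are produced by the compiled `helpers.build_blending_indices` routine (built from `helpers.cpp` in this wheel). The sketch below is a rough, pure-NumPy stand-in showing one greedy scheme that keeps the running draw fractions close to the target weights; it is an illustration of the role those arrays play, not necessarily the exact algorithm implemented in `helpers.cpp`.

```python
# Illustrative, pure-NumPy stand-in for helpers.build_blending_indices: greedily
# assign each global sample to the dataset whose realized share lags its target
# weight the most.
import numpy

def build_blending_indices_sketch(weights, size):
    weights = numpy.array(weights, dtype=numpy.float64)
    weights = weights / weights.sum()                      # same role as normalize()
    dataset_index = numpy.zeros(size, dtype=numpy.int16)   # dtypes match BlendedDataset
    dataset_sample_index = numpy.zeros(size, dtype=numpy.int64)
    counts = numpy.zeros(len(weights), dtype=numpy.int64)
    for i in range(size):
        # Pick the dataset that is currently most under-represented.
        errors = weights * (i + 1) - counts
        choice = int(numpy.argmax(errors))
        dataset_index[i] = choice
        dataset_sample_index[i] = counts[choice]  # consume each dataset in order
        counts[choice] += 1
    return dataset_index, dataset_sample_index

dataset_index, dataset_sample_index = build_blending_indices_sketch([0.7, 0.3], size=10)
print(dataset_index.tolist())         # 7 draws from dataset 0, 3 from dataset 1
print(dataset_sample_index.tolist())  # per-dataset running sample counters
```

In the real class these two arrays are all that `__getitem__` needs: it looks up `dataset_index[idx]`, then forwards `dataset_sample_index[idx]` to that constituent dataset, so the blending cost is paid once at construction and is cached on disk when `path_to_cache` is set.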