langtune 0.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langtune/__init__.py +315 -0
- langtune/acceleration.py +132 -0
- langtune/api.py +320 -0
- langtune/auth.py +434 -0
- langtune/callbacks.py +268 -0
- langtune/cli.py +687 -0
- langtune/client.py +721 -0
- langtune/config.py +356 -0
- langtune/data.py +526 -0
- langtune/distributed.py +154 -0
- langtune/facade.py +174 -0
- langtune/finetune.py +491 -0
- langtune/generation.py +95 -0
- langtune/logging_utils.py +182 -0
- langtune/metrics.py +345 -0
- langtune/model/__init__.py +20 -0
- langtune/model/hub.py +109 -0
- langtune/model/loader.py +84 -0
- langtune/model/safetensors.py +104 -0
- langtune/model/weights.py +100 -0
- langtune/models.py +19 -0
- langtune/nn/fast_transformer.py +399 -0
- langtune/nn/layers.py +178 -0
- langtune/nn/transformer.py +254 -0
- langtune/optimizations.py +870 -0
- langtune/py.typed +2 -0
- langtune/schedulers.py +234 -0
- langtune/tokenizers.py +275 -0
- langtune/trainer.py +889 -0
- langtune/training/neftune.py +80 -0
- langtune/utils.py +337 -0
- langtune-0.1.19.dist-info/METADATA +257 -0
- langtune-0.1.19.dist-info/RECORD +37 -0
- langtune-0.1.19.dist-info/WHEEL +5 -0
- langtune-0.1.19.dist-info/entry_points.txt +2 -0
- langtune-0.1.19.dist-info/licenses/LICENSE +21 -0
- langtune-0.1.19.dist-info/top_level.txt +1 -0
langtune/data.py
ADDED
@@ -0,0 +1,526 @@
"""
data.py: Data loading and preprocessing utilities for Langtune
"""

import os
import json
import torch
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Any, Optional, Union, Iterator
import logging
from pathlib import Path
import random
import numpy as np

logger = logging.getLogger(__name__)

class TextDataset(Dataset):
    """
    A PyTorch Dataset for text data with tokenization support.
    """

    def __init__(
        self,
        texts: List[str],
        tokenizer=None,
        max_length: int = 512,
        padding: str = "max_length",
        truncation: bool = True,
        add_special_tokens: bool = True
    ):
        """
        Initialize the dataset.

        Args:
            texts: List of text strings
            tokenizer: Tokenizer object (optional)
            max_length: Maximum sequence length
            padding: Padding strategy
            truncation: Whether to truncate sequences
            add_special_tokens: Whether to add special tokens
        """
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding = padding
        self.truncation = truncation
        self.add_special_tokens = add_special_tokens

        logger.info(f"Initialized TextDataset with {len(texts)} samples")

    def __len__(self) -> int:
        return len(self.texts)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        text = self.texts[idx]

        if self.tokenizer:
            # Use provided tokenizer
            encoding = self.tokenizer(
                text,
                max_length=self.max_length,
                padding=self.padding,
                truncation=self.truncation,
                add_special_tokens=self.add_special_tokens,
                return_tensors="pt"
            )
            return {
                "input_ids": encoding["input_ids"].squeeze(0),
                "attention_mask": encoding.get("attention_mask", torch.ones_like(encoding["input_ids"])).squeeze(0)
            }
        else:
            # Simple character-level tokenization as fallback
            input_ids = torch.tensor([ord(c) for c in text[:self.max_length]], dtype=torch.long)

            # Pad or truncate
            if len(input_ids) < self.max_length:
                padding_length = self.max_length - len(input_ids)
                input_ids = torch.cat([input_ids, torch.zeros(padding_length, dtype=torch.long)])
                attention_mask = torch.cat([torch.ones(len(input_ids) - padding_length), torch.zeros(padding_length)])
            else:
                attention_mask = torch.ones(self.max_length)

            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask
            }
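
A minimal usage sketch of the character-level fallback path above (tokenizer=None), assuming the wheel installs so that this module is importable as langtune.data; the toy strings are invented for illustration:

    from langtune.data import TextDataset

    texts = ["hello world", "a second, slightly longer sample"]
    dataset = TextDataset(texts, tokenizer=None, max_length=16)

    item = dataset[0]
    print(item["input_ids"].shape)            # torch.Size([16]): character codes, zero-padded
    print(int(item["attention_mask"].sum()))  # 11 real characters in "hello world"
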

class LanguageModelingDataset(Dataset):
    """
    Dataset for language modeling tasks with next-token prediction.
    """

    def __init__(
        self,
        texts: List[str],
        tokenizer=None,
        max_length: int = 512,
        stride: int = 128,
        padding: str = "max_length",
        truncation: bool = True
    ):
        """
        Initialize the language modeling dataset.

        Args:
            texts: List of text strings
            tokenizer: Tokenizer object
            max_length: Maximum sequence length
            stride: Stride for sliding window
            padding: Padding strategy
            truncation: Whether to truncate sequences
        """
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.stride = stride
        self.padding = padding
        self.truncation = truncation

        # Process texts into sequences
        self.sequences = self._create_sequences()

        logger.info(f"Initialized LanguageModelingDataset with {len(self.sequences)} sequences")

    def _create_sequences(self) -> List[Dict[str, torch.Tensor]]:
        """Create sequences for language modeling."""
        sequences = []

        for text in self.texts:
            if self.tokenizer:
                # Tokenize the text
                tokens = self.tokenizer.encode(text, add_special_tokens=False)

                # Create sliding window sequences
                for i in range(0, len(tokens), self.stride):
                    sequence = tokens[i:i + self.max_length]

                    if len(sequence) < self.max_length:
                        # Pad sequence; build the attention mask before padding
                        # so the padded positions are masked out
                        pad_length = self.max_length - len(sequence)
                        attention_mask = [1] * len(sequence) + [0] * pad_length
                        sequence = sequence + [self.tokenizer.pad_token_id] * pad_length
                    else:
                        attention_mask = [1] * self.max_length

                    # Create labels (shifted by 1 for next token prediction)
                    labels = sequence[1:] + [-100]  # -100 is ignored in loss computation

                    sequences.append({
                        "input_ids": torch.tensor(sequence, dtype=torch.long),
                        "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
                        "labels": torch.tensor(labels, dtype=torch.long)
                    })
            else:
                # Simple character-level processing
                chars = [ord(c) for c in text]

                for i in range(0, len(chars), self.stride):
                    sequence = chars[i:i + self.max_length]

                    if len(sequence) < self.max_length:
                        pad_length = self.max_length - len(sequence)
                        attention_mask = [1] * len(sequence) + [0] * pad_length
                        sequence = sequence + [0] * pad_length
                    else:
                        attention_mask = [1] * self.max_length

                    labels = sequence[1:] + [-100]

                    sequences.append({
                        "input_ids": torch.tensor(sequence, dtype=torch.long),
                        "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
                        "labels": torch.tensor(labels, dtype=torch.long)
                    })

        return sequences

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        return self.sequences[idx]
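
To illustrate the sliding-window construction above, a rough sketch using the SimpleTokenizer defined further down in this file (again assuming langtune.data is importable):

    from langtune.data import LanguageModelingDataset, SimpleTokenizer

    tokenizer = SimpleTokenizer()
    texts = ["the quick brown fox jumps over the lazy dog " * 4]
    dataset = LanguageModelingDataset(texts, tokenizer=tokenizer, max_length=64, stride=32)

    sample = dataset[0]
    print(len(dataset))               # one entry per 32-token window over the text
    print(sample["input_ids"].shape)  # torch.Size([64])
    print(sample["labels"][:5])       # next-token targets, shifted left by one position
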

class DataCollator:
    """
    Data collator for batching sequences.
    """

    def __init__(self, pad_token_id: int = 0):
        self.pad_token_id = pad_token_id

    def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """
        Collate a batch of sequences.

        Args:
            batch: List of dictionaries containing sequences

        Returns:
            Batched tensors
        """
        # Get the maximum length in the batch
        max_length = max(item["input_ids"].size(0) for item in batch)

        # Pad sequences to the same length
        input_ids = []
        attention_masks = []
        labels = []

        for item in batch:
            seq_len = item["input_ids"].size(0)

            # Pad input_ids
            if seq_len < max_length:
                padding = torch.full((max_length - seq_len,), self.pad_token_id, dtype=torch.long)
                input_ids.append(torch.cat([item["input_ids"], padding]))
            else:
                input_ids.append(item["input_ids"])

            # Pad attention_mask
            if seq_len < max_length:
                padding = torch.zeros(max_length - seq_len, dtype=torch.long)
                attention_masks.append(torch.cat([item["attention_mask"], padding]))
            else:
                attention_masks.append(item["attention_mask"])

            # Pad labels if present
            if "labels" in item:
                if seq_len < max_length:
                    padding = torch.full((max_length - seq_len,), -100, dtype=torch.long)
                    labels.append(torch.cat([item["labels"], padding]))
                else:
                    labels.append(item["labels"])

        result = {
            "input_ids": torch.stack(input_ids),
            "attention_mask": torch.stack(attention_masks)
        }

        if labels:
            result["labels"] = torch.stack(labels)

        return result
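
A sketch of wiring the collator into a standard DataLoader; batch size and texts are arbitrary:

    from torch.utils.data import DataLoader
    from langtune.data import DataCollator, LanguageModelingDataset, SimpleTokenizer

    tokenizer = SimpleTokenizer()
    dataset = LanguageModelingDataset(
        ["a short text", "another, somewhat longer piece of text"],
        tokenizer=tokenizer,
        max_length=32,
    )
    collator = DataCollator(pad_token_id=tokenizer.pad_token_id)
    loader = DataLoader(dataset, batch_size=2, collate_fn=collator)

    batch = next(iter(loader))
    print(batch["input_ids"].shape)   # (2, 32): padded to the longest item in the batch
    print(batch["labels"].shape)      # same shape; missing label positions are filled with -100
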

def load_text_file(file_path: str, encoding: str = "utf-8") -> List[str]:
    """
    Load text from a file.

    Args:
        file_path: Path to the text file
        encoding: File encoding

    Returns:
        List of text lines
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    with open(file_path, 'r', encoding=encoding) as f:
        lines = f.readlines()

    # Remove empty lines and strip whitespace
    lines = [line.strip() for line in lines if line.strip()]

    logger.info(f"Loaded {len(lines)} lines from {file_path}")
    return lines

def load_json_file(file_path: str, text_key: str = "text") -> List[str]:
    """
    Load text from a JSON file.

    Args:
        file_path: Path to the JSON file
        text_key: Key containing the text data

    Returns:
        List of text strings
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    if isinstance(data, list):
        texts = [item[text_key] for item in data if text_key in item]
    elif isinstance(data, dict):
        if text_key in data:
            texts = [data[text_key]]
        else:
            raise ValueError(f"Key '{text_key}' not found in JSON data")
    else:
        raise ValueError("JSON data must be a list or dictionary")

    logger.info(f"Loaded {len(texts)} texts from {file_path}")
    return texts
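
A self-contained sketch of the JSON loader above; the temporary file and its records are invented for illustration:

    import json
    import tempfile
    from langtune.data import load_json_file

    records = [{"text": "first sample"}, {"text": "second sample"}, {"label": "no text key here"}]
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump(records, f)
        path = f.name

    texts = load_json_file(path, text_key="text")
    print(texts)  # ['first sample', 'second sample'] -- items without the key are skipped
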

def create_data_loader(
    dataset: Dataset,
    batch_size: int = 32,
    shuffle: bool = True,
    num_workers: int = 4,
    pin_memory: bool = True,
    collate_fn=None
) -> DataLoader:
    """
    Create a DataLoader for the dataset.

    Args:
        dataset: PyTorch Dataset
        batch_size: Batch size
        shuffle: Whether to shuffle the data
        num_workers: Number of worker processes
        pin_memory: Whether to pin memory
        collate_fn: Custom collate function

    Returns:
        DataLoader
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=collate_fn
    )

def split_dataset(
    texts: List[str],
    train_ratio: float = 0.8,
    val_ratio: float = 0.1,
    test_ratio: float = 0.1,
    seed: int = 42
) -> tuple[List[str], List[str], List[str]]:
    """
    Split dataset into train, validation, and test sets.

    Args:
        texts: List of text strings
        train_ratio: Ratio for training set
        val_ratio: Ratio for validation set
        test_ratio: Ratio for test set
        seed: Random seed

    Returns:
        Tuple of (train_texts, val_texts, test_texts)
    """
    if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
        raise ValueError("Ratios must sum to 1.0")

    random.seed(seed)
    np.random.seed(seed)

    # Shuffle the data
    indices = list(range(len(texts)))
    random.shuffle(indices)

    # Calculate split points
    train_size = int(len(texts) * train_ratio)
    val_size = int(len(texts) * val_ratio)

    # Split the data
    train_indices = indices[:train_size]
    val_indices = indices[train_size:train_size + val_size]
    test_indices = indices[train_size + val_size:]

    train_texts = [texts[i] for i in train_indices]
    val_texts = [texts[i] for i in val_indices]
    test_texts = [texts[i] for i in test_indices]

    logger.info(f"Split dataset: {len(train_texts)} train, {len(val_texts)} val, {len(test_texts)} test")

    return train_texts, val_texts, test_texts
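
As a rough end-to-end sketch of the helpers above (using create_sample_dataset, defined later in this file, to avoid needing real data):

    from langtune.data import (
        TextDataset, create_data_loader, create_sample_dataset, split_dataset
    )

    texts = create_sample_dataset(num_samples=100, text_length=50)
    train_texts, val_texts, test_texts = split_dataset(texts, 0.8, 0.1, 0.1, seed=0)
    print(len(train_texts), len(val_texts), len(test_texts))  # 80 10 10

    train_loader = create_data_loader(
        TextDataset(train_texts, max_length=64),
        batch_size=8,
        num_workers=0,      # keep it single-process for a quick local test
        pin_memory=False,
    )
    batch = next(iter(train_loader))
    print(batch["input_ids"].shape)   # torch.Size([8, 64])
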

class SimpleTokenizer:
    """
    A simple tokenizer for demonstration purposes.
    """

    def __init__(self, vocab_size: int = 32000):
        self.vocab_size = vocab_size
        self.pad_token_id = 0
        self.unk_token_id = 1
        self.bos_token_id = 2
        self.eos_token_id = 3

        # Create a simple vocabulary
        self.vocab = {
            "<pad>": self.pad_token_id,
            "<unk>": self.unk_token_id,
            "<bos>": self.bos_token_id,
            "<eos>": self.eos_token_id
        }

        # Add common characters
        for i, char in enumerate("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .,!?;:'\"()[]{}"):
            if len(self.vocab) < vocab_size:
                self.vocab[char] = len(self.vocab)

        # Create reverse vocabulary
        self.id_to_token = {v: k for k, v in self.vocab.items()}

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        """Encode text to token IDs."""
        tokens = []

        if add_special_tokens:
            tokens.append(self.bos_token_id)

        for char in text:
            if char in self.vocab:
                tokens.append(self.vocab[char])
            else:
                tokens.append(self.unk_token_id)

        if add_special_tokens:
            tokens.append(self.eos_token_id)

        return tokens

    def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str:
        """Decode token IDs to text."""
        tokens = []

        for token_id in token_ids:
            if token_id in self.id_to_token:
                token = self.id_to_token[token_id]
                if skip_special_tokens and token.startswith("<"):
                    continue
                tokens.append(token)

        return "".join(tokens)

    def __call__(self, text: str, **kwargs) -> Dict[str, List[int]]:
        """Callable interface for compatibility."""
        # Only add_special_tokens is honoured; other HuggingFace-style keyword
        # arguments (max_length, padding, truncation, ...) are accepted and
        # ignored so that callers using that signature do not raise a TypeError.
        token_ids = self.encode(text, add_special_tokens=kwargs.get("add_special_tokens", True))
        return {"input_ids": token_ids}

# Example usage and utility functions
def create_sample_dataset(num_samples: int = 1000, text_length: int = 100) -> List[str]:
    """
    Create a sample dataset for testing.

    Args:
        num_samples: Number of samples to generate
        text_length: Length of each text sample

    Returns:
        List of sample texts
    """
    sample_texts = []

    for i in range(num_samples):
        # Generate random text
        chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .,!?;:'\"()[]{}"
        text = "".join(random.choices(chars, k=text_length))
        sample_texts.append(text)

    return sample_texts

def load_dataset_from_config(config) -> tuple[Dataset, Dataset, Dataset]:
    """
    Load datasets based on configuration.

    Args:
        config: Configuration object

    Returns:
        Tuple of (train_dataset, val_dataset, test_dataset)
    """
    # Load training data
    if config.data.train_file:
        if config.data.train_file.endswith('.json'):
            train_texts = load_json_file(config.data.train_file)
        else:
            train_texts = load_text_file(config.data.train_file)
    else:
        logger.warning("No training file specified, creating sample dataset")
        train_texts = create_sample_dataset(1000)

    # Load validation data
    if config.data.eval_file:
        if config.data.eval_file.endswith('.json'):
            val_texts = load_json_file(config.data.eval_file)
        else:
            val_texts = load_text_file(config.data.eval_file)
    else:
        # Split training data for validation
        train_texts, val_texts, _ = split_dataset(train_texts, train_ratio=0.9, val_ratio=0.1, test_ratio=0.0)

    # Load test data
    if config.data.test_file:
        if config.data.test_file.endswith('.json'):
            test_texts = load_json_file(config.data.test_file)
        else:
            test_texts = load_text_file(config.data.test_file)
    else:
        # Split training data for test
        train_texts, _, test_texts = split_dataset(train_texts, train_ratio=0.8, val_ratio=0.0, test_ratio=0.2)

    # Create tokenizer
    tokenizer = None
    if config.data.tokenizer_name:
        # In a real implementation, you would load a proper tokenizer here
        # For now, we'll use the simple tokenizer
        tokenizer = SimpleTokenizer(config.model.vocab_size)

    # Create datasets
    train_dataset = LanguageModelingDataset(
        train_texts,
        tokenizer=tokenizer,
        max_length=config.data.max_length
    )

    val_dataset = LanguageModelingDataset(
        val_texts,
        tokenizer=tokenizer,
        max_length=config.data.max_length
    )

    test_dataset = LanguageModelingDataset(
        test_texts,
        tokenizer=tokenizer,
        max_length=config.data.max_length
    )

    return train_dataset, val_dataset, test_dataset
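
How load_dataset_from_config expects its config to be shaped can be sketched with a plain namespace; the real configuration class lives in langtune/config.py (not shown in this diff), so the stand-in below only mimics the attributes read above:

    from types import SimpleNamespace
    from langtune.data import load_dataset_from_config

    config = SimpleNamespace(
        data=SimpleNamespace(
            train_file=None,      # None -> falls back to create_sample_dataset(1000)
            eval_file=None,       # None -> carved out of the training split
            test_file=None,
            tokenizer_name="simple",
            max_length=128,
        ),
        model=SimpleNamespace(vocab_size=32000),
    )

    train_ds, val_ds, test_ds = load_dataset_from_config(config)
    print(len(train_ds), len(val_ds), len(test_ds))
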
langtune/distributed.py
ADDED
@@ -0,0 +1,154 @@
"""
distributed.py: Distributed training utilities for Langtune

Provides helpers for multi-GPU and distributed training.
"""

import os
import torch
import torch.distributed as dist
from typing import Optional
import logging

logger = logging.getLogger(__name__)


def is_distributed() -> bool:
    """Check if running in distributed mode."""
    return dist.is_available() and dist.is_initialized()


def get_rank() -> int:
    """Get current process rank."""
    if is_distributed():
        return dist.get_rank()
    return 0


def get_world_size() -> int:
    """Get total number of processes."""
    if is_distributed():
        return dist.get_world_size()
    return 1


def is_main_process() -> bool:
    """Check if this is the main process."""
    return get_rank() == 0


def setup_distributed(backend: str = "nccl", init_method: str = "env://"):
    """Initialize distributed training."""
    if not dist.is_initialized():
        rank = int(os.environ.get("RANK", 0))
        world_size = int(os.environ.get("WORLD_SIZE", 1))

        if world_size > 1:
            dist.init_process_group(
                backend=backend,
                init_method=init_method,
                world_size=world_size,
                rank=rank
            )
            logger.info(f"Initialized distributed: rank={rank}, world_size={world_size}")

def cleanup_distributed():
    """Clean up distributed training."""
    if is_distributed():
        dist.destroy_process_group()


def barrier():
    """Synchronize all processes."""
    if is_distributed():
        dist.barrier()


def all_reduce(tensor: torch.Tensor, op=dist.ReduceOp.SUM) -> torch.Tensor:
    """All-reduce across processes."""
    if is_distributed():
        dist.all_reduce(tensor, op=op)
    return tensor


def broadcast(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    """Broadcast tensor from source rank."""
    if is_distributed():
        dist.broadcast(tensor, src)
    return tensor


def all_gather(tensor: torch.Tensor) -> torch.Tensor:
    """Gather tensors from all processes."""
    if not is_distributed():
        return tensor

    world_size = get_world_size()
    tensors = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensors, tensor)
    return torch.cat(tensors, dim=0)
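
For example, averaging a per-process scalar could look like this (the helpers fall back to no-ops in a single process, so the snippet also runs locally):

    import torch
    from langtune.distributed import all_reduce, get_world_size

    loss = torch.tensor(0.5)              # this process's value
    total = all_reduce(loss.clone())      # summed in place across processes
    mean_loss = total / get_world_size()
    print(mean_loss.item())               # 0.5 when running on a single process
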

class DistributedDataParallelWrapper:
    """Simple DDP wrapper for models."""

    def __init__(self, model: torch.nn.Module, device_id: Optional[int] = None):
        self.model = model
        self.device_id = device_id

        if is_distributed():
            self.model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[device_id] if device_id is not None else None,
                output_device=device_id
            )

    def __getattr__(self, name):
        return getattr(self.model, name)

    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)


def wrap_model_ddp(model: torch.nn.Module, device_id: Optional[int] = None):
    """Wrap model with DDP if in distributed mode."""
    if is_distributed():
        return torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[device_id] if device_id is not None else None,
            output_device=device_id
        )
    return model


def get_distributed_sampler(dataset, shuffle: bool = True):
    """Get distributed sampler for dataset."""
    if is_distributed():
        return torch.utils.data.DistributedSampler(
            dataset,
            num_replicas=get_world_size(),
            rank=get_rank(),
            shuffle=shuffle
        )
    return None
return None
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def reduce_dict(input_dict: dict, average: bool = True) -> dict:
|
|
138
|
+
"""Reduce dictionary values across processes."""
|
|
139
|
+
if not is_distributed():
|
|
140
|
+
return input_dict
|
|
141
|
+
|
|
142
|
+
world_size = get_world_size()
|
|
143
|
+
keys = sorted(input_dict.keys())
|
|
144
|
+
values = torch.tensor([input_dict[k] for k in keys], dtype=torch.float32)
|
|
145
|
+
|
|
146
|
+
if values.is_cuda:
|
|
147
|
+
values = values.cuda()
|
|
148
|
+
|
|
149
|
+
dist.all_reduce(values, op=dist.ReduceOp.SUM)
|
|
150
|
+
|
|
151
|
+
if average:
|
|
152
|
+
values /= world_size
|
|
153
|
+
|
|
154
|
+
return {k: v.item() for k, v in zip(keys, values)}
|