langtune-0.1.19-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langtune/data.py ADDED
@@ -0,0 +1,526 @@
1
+ """
2
+ data.py: Data loading and preprocessing utilities for Langtune
3
+ """
4
+
5
+ import os
6
+ import json
7
+ import torch
8
+ from torch.utils.data import Dataset, DataLoader
9
+ from typing import List, Dict, Any, Optional, Union, Iterator
10
+ import logging
11
+ from pathlib import Path
12
+ import random
13
+ import numpy as np
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ class TextDataset(Dataset):
18
+ """
19
+ A PyTorch Dataset for text data with tokenization support.
20
+ """
21
+
22
+ def __init__(
23
+ self,
24
+ texts: List[str],
25
+ tokenizer=None,
26
+ max_length: int = 512,
27
+ padding: str = "max_length",
28
+ truncation: bool = True,
29
+ add_special_tokens: bool = True
30
+ ):
31
+ """
32
+ Initialize the dataset.
33
+
34
+ Args:
35
+ texts: List of text strings
36
+ tokenizer: Tokenizer object (optional)
37
+ max_length: Maximum sequence length
38
+ padding: Padding strategy
39
+ truncation: Whether to truncate sequences
40
+ add_special_tokens: Whether to add special tokens
41
+ """
42
+ self.texts = texts
43
+ self.tokenizer = tokenizer
44
+ self.max_length = max_length
45
+ self.padding = padding
46
+ self.truncation = truncation
47
+ self.add_special_tokens = add_special_tokens
48
+
49
+ logger.info(f"Initialized TextDataset with {len(texts)} samples")
50
+
51
+ def __len__(self) -> int:
52
+ return len(self.texts)
53
+
54
+ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
55
+ text = self.texts[idx]
56
+
57
+ if self.tokenizer:
58
+ # Use provided tokenizer
59
+ encoding = self.tokenizer(
60
+ text,
61
+ max_length=self.max_length,
62
+ padding=self.padding,
63
+ truncation=self.truncation,
64
+ add_special_tokens=self.add_special_tokens,
65
+ return_tensors="pt"
66
+ )
67
+ return {
68
+ "input_ids": encoding["input_ids"].squeeze(0),
69
+ "attention_mask": encoding.get("attention_mask", torch.ones_like(encoding["input_ids"])).squeeze(0)
70
+ }
71
+ else:
72
+ # Simple character-level tokenization as fallback
73
+ input_ids = torch.tensor([ord(c) for c in text[:self.max_length]], dtype=torch.long)
74
+
75
+ # Pad to max_length (truncation is already handled by the slice above)
76
+ if len(input_ids) < self.max_length:
77
+ padding_length = self.max_length - len(input_ids)
78
+ attention_mask = torch.cat([torch.ones(len(input_ids), dtype=torch.long), torch.zeros(padding_length, dtype=torch.long)])
79
+ input_ids = torch.cat([input_ids, torch.zeros(padding_length, dtype=torch.long)])
80
+ else:
81
+ attention_mask = torch.ones(self.max_length, dtype=torch.long)
82
+
83
+ return {
84
+ "input_ids": input_ids,
85
+ "attention_mask": attention_mask
86
+ }
87
+
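As a quick orientation, a minimal usage sketch of the character-level fallback path in TextDataset (illustrative only; it assumes the module is importable as langtune.data and no external tokenizer is passed):

    from langtune.data import TextDataset

    # No tokenizer: __getitem__ falls back to ord()-based character IDs.
    dataset = TextDataset(["hello world", "langtune example"], max_length=16)
    sample = dataset[0]
    print(sample["input_ids"].shape)            # torch.Size([16]) after padding
    print(int(sample["attention_mask"].sum()))  # 11 real characters in "hello world"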
88
+ class LanguageModelingDataset(Dataset):
89
+ """
90
+ Dataset for language modeling tasks with next-token prediction.
91
+ """
92
+
93
+ def __init__(
94
+ self,
95
+ texts: List[str],
96
+ tokenizer=None,
97
+ max_length: int = 512,
98
+ stride: int = 128,
99
+ padding: str = "max_length",
100
+ truncation: bool = True
101
+ ):
102
+ """
103
+ Initialize the language modeling dataset.
104
+
105
+ Args:
106
+ texts: List of text strings
107
+ tokenizer: Tokenizer object
108
+ max_length: Maximum sequence length
109
+ stride: Stride for sliding window
110
+ padding: Padding strategy
111
+ truncation: Whether to truncate sequences
112
+ """
113
+ self.texts = texts
114
+ self.tokenizer = tokenizer
115
+ self.max_length = max_length
116
+ self.stride = stride
117
+ self.padding = padding
118
+ self.truncation = truncation
119
+
120
+ # Process texts into sequences
121
+ self.sequences = self._create_sequences()
122
+
123
+ logger.info(f"Initialized LanguageModelingDataset with {len(self.sequences)} sequences")
124
+
125
+ def _create_sequences(self) -> List[Dict[str, torch.Tensor]]:
126
+ """Create sequences for language modeling."""
127
+ sequences = []
128
+
129
+ for text in self.texts:
130
+ if self.tokenizer:
131
+ # Tokenize the text
132
+ tokens = self.tokenizer.encode(text, add_special_tokens=False)
133
+
134
+ # Create sliding window sequences
135
+ for i in range(0, len(tokens), self.stride):
136
+ sequence = tokens[i:i + self.max_length]
137
+
138
+ if len(sequence) < self.max_length:
139
+ # Build the attention mask from the unpadded length, then pad the sequence
140
+ attention_mask = [1] * len(sequence) + [0] * (self.max_length - len(sequence))
141
+ sequence = sequence + [self.tokenizer.pad_token_id] * (self.max_length - len(sequence))
142
+ else:
143
+ attention_mask = [1] * self.max_length
144
+
145
+ # Create labels (inputs shifted left by one); -100 is ignored in loss computation, including at padding positions
146
+ labels = [tok if mask == 1 else -100 for tok, mask in zip(sequence[1:] + [-100], attention_mask[1:] + [0])]
147
+
148
+ sequences.append({
149
+ "input_ids": torch.tensor(sequence, dtype=torch.long),
150
+ "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
151
+ "labels": torch.tensor(labels, dtype=torch.long)
152
+ })
153
+ else:
154
+ # Simple character-level processing
155
+ chars = [ord(c) for c in text]
156
+
157
+ for i in range(0, len(chars), self.stride):
158
+ sequence = chars[i:i + self.max_length]
159
+
160
+ if len(sequence) < self.max_length:
161
+ attention_mask = [1] * len(sequence) + [0] * (self.max_length - len(sequence))
162
+ sequence = sequence + [0] * (self.max_length - len(sequence))
163
+ else:
164
+ attention_mask = [1] * self.max_length
165
+
166
+ labels = [tok if mask == 1 else -100 for tok, mask in zip(sequence[1:] + [-100], attention_mask[1:] + [0])]
167
+
168
+ sequences.append({
169
+ "input_ids": torch.tensor(sequence, dtype=torch.long),
170
+ "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
171
+ "labels": torch.tensor(labels, dtype=torch.long)
172
+ })
173
+
174
+ return sequences
175
+
176
+ def __len__(self) -> int:
177
+ return len(self.sequences)
178
+
179
+ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
180
+ return self.sequences[idx]
181
+
182
+ class DataCollator:
183
+ """
184
+ Data collator for batching sequences.
185
+ """
186
+
187
+ def __init__(self, pad_token_id: int = 0):
188
+ self.pad_token_id = pad_token_id
189
+
190
+ def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
191
+ """
192
+ Collate a batch of sequences.
193
+
194
+ Args:
195
+ batch: List of dictionaries containing sequences
196
+
197
+ Returns:
198
+ Batched tensors
199
+ """
200
+ # Get the maximum length in the batch
201
+ max_length = max(item["input_ids"].size(0) for item in batch)
202
+
203
+ # Pad sequences to the same length
204
+ input_ids = []
205
+ attention_masks = []
206
+ labels = []
207
+
208
+ for item in batch:
209
+ seq_len = item["input_ids"].size(0)
210
+
211
+ # Pad input_ids
212
+ if seq_len < max_length:
213
+ padding = torch.full((max_length - seq_len,), self.pad_token_id, dtype=torch.long)
214
+ input_ids.append(torch.cat([item["input_ids"], padding]))
215
+ else:
216
+ input_ids.append(item["input_ids"])
217
+
218
+ # Pad attention_mask
219
+ if seq_len < max_length:
220
+ padding = torch.zeros(max_length - seq_len, dtype=torch.long)
221
+ attention_masks.append(torch.cat([item["attention_mask"], padding]))
222
+ else:
223
+ attention_masks.append(item["attention_mask"])
224
+
225
+ # Pad labels if present
226
+ if "labels" in item:
227
+ if seq_len < max_length:
228
+ padding = torch.full((max_length - seq_len,), -100, dtype=torch.long)
229
+ labels.append(torch.cat([item["labels"], padding]))
230
+ else:
231
+ labels.append(item["labels"])
232
+
233
+ result = {
234
+ "input_ids": torch.stack(input_ids),
235
+ "attention_mask": torch.stack(attention_masks)
236
+ }
237
+
238
+ if labels:
239
+ result["labels"] = torch.stack(labels)
240
+
241
+ return result
242
+
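A short sketch of how the collator above is typically wired into a DataLoader (illustrative; create_data_loader is defined later in this file and the sample texts are placeholders):

    from langtune.data import TextDataset, DataCollator, create_data_loader

    dataset = TextDataset(["first sample", "a slightly longer second sample"], max_length=32)
    collator = DataCollator(pad_token_id=0)
    loader = create_data_loader(dataset, batch_size=2, shuffle=False, num_workers=0, collate_fn=collator)

    batch = next(iter(loader))
    print(batch["input_ids"].shape)       # torch.Size([2, 32]); items are already padded to max_length
    print(batch["attention_mask"].shape)  # torch.Size([2, 32])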
243
+ def load_text_file(file_path: str, encoding: str = "utf-8") -> List[str]:
244
+ """
245
+ Load text from a file.
246
+
247
+ Args:
248
+ file_path: Path to the text file
249
+ encoding: File encoding
250
+
251
+ Returns:
252
+ List of text lines
253
+ """
254
+ if not os.path.exists(file_path):
255
+ raise FileNotFoundError(f"File not found: {file_path}")
256
+
257
+ with open(file_path, 'r', encoding=encoding) as f:
258
+ lines = f.readlines()
259
+
260
+ # Remove empty lines and strip whitespace
261
+ lines = [line.strip() for line in lines if line.strip()]
262
+
263
+ logger.info(f"Loaded {len(lines)} lines from {file_path}")
264
+ return lines
265
+
266
+ def load_json_file(file_path: str, text_key: str = "text") -> List[str]:
267
+ """
268
+ Load text from a JSON file.
269
+
270
+ Args:
271
+ file_path: Path to the JSON file
272
+ text_key: Key containing the text data
273
+
274
+ Returns:
275
+ List of text strings
276
+ """
277
+ if not os.path.exists(file_path):
278
+ raise FileNotFoundError(f"File not found: {file_path}")
279
+
280
+ with open(file_path, 'r', encoding='utf-8') as f:
281
+ data = json.load(f)
282
+
283
+ if isinstance(data, list):
284
+ texts = [item[text_key] for item in data if text_key in item]
285
+ elif isinstance(data, dict):
286
+ if text_key in data:
287
+ texts = [data[text_key]]
288
+ else:
289
+ raise ValueError(f"Key '{text_key}' not found in JSON data")
290
+ else:
291
+ raise ValueError("JSON data must be a list or dictionary")
292
+
293
+ logger.info(f"Loaded {len(texts)} texts from {file_path}")
294
+ return texts
295
+
296
+ def create_data_loader(
297
+ dataset: Dataset,
298
+ batch_size: int = 32,
299
+ shuffle: bool = True,
300
+ num_workers: int = 4,
301
+ pin_memory: bool = True,
302
+ collate_fn=None
303
+ ) -> DataLoader:
304
+ """
305
+ Create a DataLoader for the dataset.
306
+
307
+ Args:
308
+ dataset: PyTorch Dataset
309
+ batch_size: Batch size
310
+ shuffle: Whether to shuffle the data
311
+ num_workers: Number of worker processes
312
+ pin_memory: Whether to pin memory
313
+ collate_fn: Custom collate function
314
+
315
+ Returns:
316
+ DataLoader
317
+ """
318
+ return DataLoader(
319
+ dataset,
320
+ batch_size=batch_size,
321
+ shuffle=shuffle,
322
+ num_workers=num_workers,
323
+ pin_memory=pin_memory,
324
+ collate_fn=collate_fn
325
+ )
326
+
327
+ def split_dataset(
328
+ texts: List[str],
329
+ train_ratio: float = 0.8,
330
+ val_ratio: float = 0.1,
331
+ test_ratio: float = 0.1,
332
+ seed: int = 42
333
+ ) -> tuple[List[str], List[str], List[str]]:
334
+ """
335
+ Split dataset into train, validation, and test sets.
336
+
337
+ Args:
338
+ texts: List of text strings
339
+ train_ratio: Ratio for training set
340
+ val_ratio: Ratio for validation set
341
+ test_ratio: Ratio for test set
342
+ seed: Random seed
343
+
344
+ Returns:
345
+ Tuple of (train_texts, val_texts, test_texts)
346
+ """
347
+ if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
348
+ raise ValueError("Ratios must sum to 1.0")
349
+
350
+ random.seed(seed)
351
+ np.random.seed(seed)
352
+
353
+ # Shuffle the data
354
+ indices = list(range(len(texts)))
355
+ random.shuffle(indices)
356
+
357
+ # Calculate split points
358
+ train_size = int(len(texts) * train_ratio)
359
+ val_size = int(len(texts) * val_ratio)
360
+
361
+ # Split the data
362
+ train_indices = indices[:train_size]
363
+ val_indices = indices[train_size:train_size + val_size]
364
+ test_indices = indices[train_size + val_size:]
365
+
366
+ train_texts = [texts[i] for i in train_indices]
367
+ val_texts = [texts[i] for i in val_indices]
368
+ test_texts = [texts[i] for i in test_indices]
369
+
370
+ logger.info(f"Split dataset: {len(train_texts)} train, {len(val_texts)} val, {len(test_texts)} test")
371
+
372
+ return train_texts, val_texts, test_texts
373
+
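For example, a minimal sketch with placeholder texts:

    from langtune.data import split_dataset

    texts = [f"sample {i}" for i in range(100)]
    train_texts, val_texts, test_texts = split_dataset(
        texts, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, seed=42
    )
    print(len(train_texts), len(val_texts), len(test_texts))  # 80 10 10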
374
+ class SimpleTokenizer:
375
+ """
376
+ A simple tokenizer for demonstration purposes.
377
+ """
378
+
379
+ def __init__(self, vocab_size: int = 32000):
380
+ self.vocab_size = vocab_size
381
+ self.pad_token_id = 0
382
+ self.unk_token_id = 1
383
+ self.bos_token_id = 2
384
+ self.eos_token_id = 3
385
+
386
+ # Create a simple vocabulary
387
+ self.vocab = {
388
+ "<pad>": self.pad_token_id,
389
+ "<unk>": self.unk_token_id,
390
+ "<bos>": self.bos_token_id,
391
+ "<eos>": self.eos_token_id
392
+ }
393
+
394
+ # Add common characters
395
+ for i, char in enumerate("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .,!?;:'\"()[]{}"):
396
+ if len(self.vocab) < vocab_size:
397
+ self.vocab[char] = len(self.vocab)
398
+
399
+ # Create reverse vocabulary
400
+ self.id_to_token = {v: k for k, v in self.vocab.items()}
401
+
402
+ def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
403
+ """Encode text to token IDs."""
404
+ tokens = []
405
+
406
+ if add_special_tokens:
407
+ tokens.append(self.bos_token_id)
408
+
409
+ for char in text:
410
+ if char in self.vocab:
411
+ tokens.append(self.vocab[char])
412
+ else:
413
+ tokens.append(self.unk_token_id)
414
+
415
+ if add_special_tokens:
416
+ tokens.append(self.eos_token_id)
417
+
418
+ return tokens
419
+
420
+ def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str:
421
+ """Decode token IDs to text."""
422
+ tokens = []
423
+
424
+ for token_id in token_ids:
425
+ if token_id in self.id_to_token:
426
+ token = self.id_to_token[token_id]
427
+ if skip_special_tokens and token.startswith("<"):
428
+ continue
429
+ tokens.append(token)
430
+
431
+ return "".join(tokens)
432
+
433
+ def __call__(self, text: str, add_special_tokens: bool = True, **kwargs) -> Dict[str, List[int]]:
434
+ """Callable interface for compatibility; extra HF-style keyword arguments (max_length, padding, truncation, return_tensors) are accepted but ignored."""
435
+ token_ids = self.encode(text, add_special_tokens=add_special_tokens)
436
+ return {"input_ids": token_ids}
437
+
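A small round-trip example for SimpleTokenizer (illustrative only; characters outside the built-in vocabulary map to <unk> and are dropped when skip_special_tokens=True):

    from langtune.data import SimpleTokenizer

    tokenizer = SimpleTokenizer(vocab_size=32000)
    ids = tokenizer.encode("Hello, world!")                   # wrapped in <bos>/<eos> by default
    print(tokenizer.decode(ids))                              # Hello, world!
    print(tokenizer.decode(ids, skip_special_tokens=False))   # <bos>Hello, world!<eos>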
438
+ # Example usage and utility functions
439
+ def create_sample_dataset(num_samples: int = 1000, text_length: int = 100) -> List[str]:
440
+ """
441
+ Create a sample dataset for testing.
442
+
443
+ Args:
444
+ num_samples: Number of samples to generate
445
+ text_length: Length of each text sample
446
+
447
+ Returns:
448
+ List of sample texts
449
+ """
450
+ sample_texts = []
451
+
452
+ for i in range(num_samples):
453
+ # Generate random text
454
+ chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .,!?;:'\"()[]{}"
455
+ text = "".join(random.choices(chars, k=text_length))
456
+ sample_texts.append(text)
457
+
458
+ return sample_texts
459
+
460
+ def load_dataset_from_config(config) -> tuple[Dataset, Dataset, Dataset]:
461
+ """
462
+ Load datasets based on configuration.
463
+
464
+ Args:
465
+ config: Configuration object
466
+
467
+ Returns:
468
+ Tuple of (train_dataset, val_dataset, test_dataset)
469
+ """
470
+ # Load training data
471
+ if config.data.train_file:
472
+ if config.data.train_file.endswith('.json'):
473
+ train_texts = load_json_file(config.data.train_file)
474
+ else:
475
+ train_texts = load_text_file(config.data.train_file)
476
+ else:
477
+ logger.warning("No training file specified, creating sample dataset")
478
+ train_texts = create_sample_dataset(1000)
479
+
480
+ # Load validation data
481
+ if config.data.eval_file:
482
+ if config.data.eval_file.endswith('.json'):
483
+ val_texts = load_json_file(config.data.eval_file)
484
+ else:
485
+ val_texts = load_text_file(config.data.eval_file)
486
+ else:
487
+ # Split training data for validation
488
+ train_texts, val_texts, _ = split_dataset(train_texts, train_ratio=0.9, val_ratio=0.1, test_ratio=0.0)
489
+
490
+ # Load test data
491
+ if config.data.test_file:
492
+ if config.data.test_file.endswith('.json'):
493
+ test_texts = load_json_file(config.data.test_file)
494
+ else:
495
+ test_texts = load_text_file(config.data.test_file)
496
+ else:
497
+ # Split training data for test
498
+ train_texts, _, test_texts = split_dataset(train_texts, train_ratio=0.8, val_ratio=0.0, test_ratio=0.2)
499
+
500
+ # Create tokenizer
501
+ tokenizer = None
502
+ if config.data.tokenizer_name:
503
+ # In a real implementation, you would load a proper tokenizer here
504
+ # For now, we'll use the simple tokenizer
505
+ tokenizer = SimpleTokenizer(config.model.vocab_size)
506
+
507
+ # Create datasets
508
+ train_dataset = LanguageModelingDataset(
509
+ train_texts,
510
+ tokenizer=tokenizer,
511
+ max_length=config.data.max_length
512
+ )
513
+
514
+ val_dataset = LanguageModelingDataset(
515
+ val_texts,
516
+ tokenizer=tokenizer,
517
+ max_length=config.data.max_length
518
+ )
519
+
520
+ test_dataset = LanguageModelingDataset(
521
+ test_texts,
522
+ tokenizer=tokenizer,
523
+ max_length=config.data.max_length
524
+ )
525
+
526
+ return train_dataset, val_dataset, test_dataset
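Putting the pieces of this module together, a minimal end-to-end sketch (illustrative only; the texts are placeholders, and in practice load_dataset_from_config supplies the datasets from configuration):

    from langtune.data import (
        SimpleTokenizer, LanguageModelingDataset, DataCollator, create_data_loader,
    )

    tokenizer = SimpleTokenizer()
    texts = ["this is a tiny placeholder corpus for the sketch."] * 8
    dataset = LanguageModelingDataset(texts, tokenizer=tokenizer, max_length=64, stride=32)
    collator = DataCollator(pad_token_id=tokenizer.pad_token_id)
    loader = create_data_loader(dataset, batch_size=4, shuffle=True, num_workers=0, collate_fn=collator)

    for batch in loader:
        # input_ids, attention_mask and labels are all (batch, 64); labels use -100 for ignored positions.
        print(batch["input_ids"].shape, batch["labels"].shape)
        break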
langtune/distributed.py ADDED
@@ -0,0 +1,154 @@
1
+ """
2
+ distributed.py: Distributed training utilities for Langtune
3
+
4
+ Provides helpers for multi-GPU and distributed training.
5
+ """
6
+
7
+ import os
8
+ import torch
9
+ import torch.distributed as dist
10
+ from typing import Optional
11
+ import logging
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ def is_distributed() -> bool:
17
+ """Check if running in distributed mode."""
18
+ return dist.is_initialized()
19
+
20
+
21
+ def get_rank() -> int:
22
+ """Get current process rank."""
23
+ if is_distributed():
24
+ return dist.get_rank()
25
+ return 0
26
+
27
+
28
+ def get_world_size() -> int:
29
+ """Get total number of processes."""
30
+ if is_distributed():
31
+ return dist.get_world_size()
32
+ return 1
33
+
34
+
35
+ def is_main_process() -> bool:
36
+ """Check if this is the main process."""
37
+ return get_rank() == 0
38
+
39
+
40
+ def setup_distributed(backend: str = "nccl", init_method: str = "env://"):
41
+ """Initialize distributed training."""
42
+ if not dist.is_initialized():
43
+ rank = int(os.environ.get("RANK", 0))
44
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
45
+
46
+ if world_size > 1:
47
+ dist.init_process_group(
48
+ backend=backend,
49
+ init_method=init_method,
50
+ world_size=world_size,
51
+ rank=rank
52
+ )
53
+ logger.info(f"Initialized distributed: rank={rank}, world_size={world_size}")
54
+
55
+
56
+ def cleanup_distributed():
57
+ """Clean up distributed training."""
58
+ if is_distributed():
59
+ dist.destroy_process_group()
60
+
61
+
62
+ def barrier():
63
+ """Synchronize all processes."""
64
+ if is_distributed():
65
+ dist.barrier()
66
+
67
+
68
+ def all_reduce(tensor: torch.Tensor, op=dist.ReduceOp.SUM) -> torch.Tensor:
69
+ """All-reduce across processes."""
70
+ if is_distributed():
71
+ dist.all_reduce(tensor, op=op)
72
+ return tensor
73
+
74
+
75
+ def broadcast(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
76
+ """Broadcast tensor from source rank."""
77
+ if is_distributed():
78
+ dist.broadcast(tensor, src)
79
+ return tensor
80
+
81
+
82
+ def all_gather(tensor: torch.Tensor) -> torch.Tensor:
83
+ """Gather tensors from all processes."""
84
+ if not is_distributed():
85
+ return tensor
86
+
87
+ world_size = get_world_size()
88
+ tensors = [torch.zeros_like(tensor) for _ in range(world_size)]
89
+ dist.all_gather(tensors, tensor)
90
+ return torch.cat(tensors, dim=0)
91
+
92
+
93
+ class DistributedDataParallelWrapper:
94
+ """Simple DDP wrapper for models."""
95
+
96
+ def __init__(self, model: torch.nn.Module, device_id: Optional[int] = None):
97
+ self.model = model
98
+ self.device_id = device_id
99
+
100
+ if is_distributed():
101
+ self.model = torch.nn.parallel.DistributedDataParallel(
102
+ model,
103
+ device_ids=[device_id] if device_id is not None else None,
104
+ output_device=device_id
105
+ )
106
+
107
+ def __getattr__(self, name):
108
+ return getattr(self.model, name)
109
+
110
+ def __call__(self, *args, **kwargs):
111
+ return self.model(*args, **kwargs)
112
+
113
+
114
+ def wrap_model_ddp(model: torch.nn.Module, device_id: Optional[int] = None):
115
+ """Wrap model with DDP if in distributed mode."""
116
+ if is_distributed():
117
+ return torch.nn.parallel.DistributedDataParallel(
118
+ model,
119
+ device_ids=[device_id] if device_id is not None else None,
120
+ output_device=device_id
121
+ )
122
+ return model
123
+
124
+
125
+ def get_distributed_sampler(dataset, shuffle: bool = True):
126
+ """Get distributed sampler for dataset."""
127
+ if is_distributed():
128
+ return torch.utils.data.DistributedSampler(
129
+ dataset,
130
+ num_replicas=get_world_size(),
131
+ rank=get_rank(),
132
+ shuffle=shuffle
133
+ )
134
+ return None
135
+
136
+
137
+ def reduce_dict(input_dict: dict, average: bool = True) -> dict:
138
+ """Reduce dictionary values across processes."""
139
+ if not is_distributed():
140
+ return input_dict
141
+
142
+ world_size = get_world_size()
143
+ keys = sorted(input_dict.keys())
144
+ values = torch.tensor([input_dict[k] for k in keys], dtype=torch.float32)
145
+
146
+ if values.is_cuda:
147
+ values = values.cuda()
148
+
149
+ dist.all_reduce(values, op=dist.ReduceOp.SUM)
150
+
151
+ if average:
152
+ values /= world_size
153
+
154
+ return {k: v.item() for k, v in zip(keys, values)}
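A hedged sketch of how these helpers are typically combined in a training script (illustrative only; the model and dataset are supplied by the caller, a launcher such as torchrun is assumed to set RANK/WORLD_SIZE/LOCAL_RANK, and the NCCL backend assumes one CUDA device per process):

    import os
    import torch
    from langtune.distributed import (
        setup_distributed, cleanup_distributed, is_main_process,
        get_distributed_sampler, wrap_model_ddp, reduce_dict,
    )

    def main(model: torch.nn.Module, dataset):
        setup_distributed(backend="nccl")                  # no-op for single-process runs
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        torch.cuda.set_device(local_rank)

        model = wrap_model_ddp(model.cuda(), device_id=local_rank)
        sampler = get_distributed_sampler(dataset)         # None when not distributed
        loader = torch.utils.data.DataLoader(dataset, batch_size=8, sampler=sampler, shuffle=sampler is None)

        # ... training loop producing a per-rank loss ...
        stats = reduce_dict({"loss": 1.23}, average=True)  # placeholder value, averaged across ranks
        if is_main_process():
            print(stats)

        cleanup_distributed()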