langtune 0.1.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langtune/generation.py ADDED
@@ -0,0 +1,95 @@
+ """
+ generation.py: Text generation utilities for Langtune
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Optional, Union
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ class TextGenerator:
+     """High-level text generation with sampling strategies."""
+
+     def __init__(self, model: nn.Module, tokenizer=None, device: Optional[torch.device] = None):
+         self.model = model
+         self.tokenizer = tokenizer
+         self.device = device or next(model.parameters()).device
+         self.model.eval()
+
+     @torch.no_grad()
+     def generate(
+         self,
+         prompt: Union[str, torch.Tensor],
+         max_length: int = 100,
+         temperature: float = 1.0,
+         top_k: Optional[int] = None,
+         top_p: Optional[float] = None,
+         repetition_penalty: float = 1.0,
+         do_sample: bool = True,
+         eos_token_id: int = 1
+     ) -> Union[str, torch.Tensor]:
+         """Generate text with various sampling strategies."""
+         # Encode prompt
+         if isinstance(prompt, str):
+             if self.tokenizer:
+                 input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
+             else:
+                 # Fallback when no tokenizer is set: one token per character code point
+                 input_ids = torch.tensor([[ord(c) for c in prompt]], device=self.device)
+         else:
+             input_ids = prompt.to(self.device)
+
+         for _ in range(max_length - input_ids.size(1)):
+             outputs = self.model(input_ids)
+             logits = outputs["logits"] if isinstance(outputs, dict) else outputs
+             next_logits = logits[:, -1, :]
+
+             # Repetition penalty: discourage tokens already present in the context
+             if repetition_penalty != 1.0:
+                 for token in input_ids[0].unique():
+                     # Shrink positive logits and scale up negative ones; plain division
+                     # would make an already-generated token with a negative logit *more* likely.
+                     score = next_logits[0, token]
+                     next_logits[0, token] = score / repetition_penalty if score > 0 else score * repetition_penalty
+
+             # Temperature
+             if do_sample and temperature != 1.0:
+                 next_logits = next_logits / temperature
+
+             # Top-k filtering
+             if top_k:
+                 top_k = min(top_k, next_logits.size(-1))
+                 indices_to_remove = next_logits < torch.topk(next_logits, top_k)[0][..., -1, None]
+                 next_logits[indices_to_remove] = float('-inf')
+
+             # Top-p (nucleus) filtering
+             if top_p and top_p < 1.0:
+                 sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
+                 cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+                 remove = cumulative > top_p
+                 # Shift right so the first token crossing the threshold is kept
+                 remove[..., 1:] = remove[..., :-1].clone()
+                 remove[..., 0] = False
+                 indices_to_remove = remove.scatter(1, sorted_indices, remove)
+                 next_logits[indices_to_remove] = float('-inf')
+
+             # Sample or greedy
+             if do_sample:
+                 next_token = torch.multinomial(F.softmax(next_logits, dim=-1), 1)
+             else:
+                 next_token = next_logits.argmax(dim=-1, keepdim=True)
+
+             input_ids = torch.cat([input_ids, next_token], dim=1)
+
+             # Stop on end-of-sequence (.item() assumes batch size 1)
+             if next_token.item() == eos_token_id:
+                 break
+
+         if self.tokenizer:
+             return self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
+         return input_ids
+
+
+ def generate(model, prompt, max_length=100, temperature=0.8, top_k=50, top_p=0.95, **kwargs):
+     """Convenience function for text generation."""
+     return TextGenerator(model).generate(
+         prompt, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, **kwargs
+     )
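
For orientation, a minimal usage sketch of the API above. The `TinyLM` stub is a hypothetical stand-in (not part of the package) for any `nn.Module` that maps token IDs to logits of shape (batch, seq_len, vocab_size):

    import torch.nn as nn
    from langtune.generation import TextGenerator

    # Hypothetical toy model, for illustration only.
    class TinyLM(nn.Module):
        def __init__(self, vocab_size=256, dim=32):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, dim)
            self.head = nn.Linear(dim, vocab_size)

        def forward(self, input_ids):
            return self.head(self.embed(input_ids))

    generator = TextGenerator(TinyLM())
    # Without a tokenizer the prompt is ord()-encoded and a tensor of token IDs is returned.
    output_ids = generator.generate("hello", max_length=20, top_k=50, top_p=0.95)
    print(output_ids.shape)  # torch.Size([1, 20]) unless eos_token_id is sampled early
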
langtune/logging_utils.py ADDED
@@ -0,0 +1,182 @@
+ """
+ logging_utils.py: Logging utilities for Langtune
+
+ Provides colorful logging and progress tracking.
+ """
+
+ import logging
+ import sys
+ from typing import Optional
+ from datetime import datetime
+
+ # Rich console for pretty output
+ try:
+     from rich.console import Console
+     from rich.logging import RichHandler
+     from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeElapsedColumn
+     RICH_AVAILABLE = True
+ except ImportError:
+     RICH_AVAILABLE = False
+
+
+ def setup_logging(
+     level: str = "INFO",
+     log_file: Optional[str] = None,
+     use_rich: bool = True
+ ):
+     """Set up logging with optional rich formatting."""
+     log_level = getattr(logging, level.upper(), logging.INFO)
+
+     handlers = []
+
+     if use_rich and RICH_AVAILABLE:
+         handlers.append(RichHandler(
+             console=Console(stderr=True),
+             show_time=True,
+             show_path=False,
+             rich_tracebacks=True
+         ))
+     else:
+         handler = logging.StreamHandler(sys.stderr)
+         handler.setFormatter(logging.Formatter(
+             '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+         ))
+         handlers.append(handler)
+
+     if log_file:
+         file_handler = logging.FileHandler(log_file)
+         file_handler.setFormatter(logging.Formatter(
+             '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+         ))
+         handlers.append(file_handler)
+
+     logging.basicConfig(level=log_level, handlers=handlers, force=True)
+
+     # Set langtune logger
+     logger = logging.getLogger("langtune")
+     logger.setLevel(log_level)
+
+
+ def get_logger(name: str = "langtune") -> logging.Logger:
+     """Get a logger with the given name."""
+     return logging.getLogger(name)
+
+
+ class TrainingLogger:
+     """Structured logger for training progress."""
+
+     def __init__(self, name: str = "training", log_file: Optional[str] = None):
+         self.logger = logging.getLogger(name)
+         self.log_file = log_file
+         self.start_time = None
+         self.metrics_history = []
+
+     def start(self, message: str = "Starting training"):
+         """Log training start."""
+         self.start_time = datetime.now()
+         self.logger.info(f"🚀 {message}")
+
+     def log_epoch(self, epoch: int, total_epochs: int, metrics: dict):
+         """Log epoch completion."""
+         metrics_str = ", ".join(f"{k}={v:.4f}" for k, v in metrics.items())
+         self.logger.info(f"📊 Epoch {epoch+1}/{total_epochs}: {metrics_str}")
+         self.metrics_history.append({"epoch": epoch + 1, **metrics})
+
+     def log_step(self, step: int, loss: float, lr: Optional[float] = None):
+         """Log training step."""
+         msg = f"Step {step}: loss={loss:.4f}"
+         if lr is not None:
+             msg += f", lr={lr:.2e}"
+         self.logger.debug(msg)
+
+     def log_validation(self, metrics: dict):
+         """Log validation results."""
+         metrics_str = ", ".join(f"{k}={v:.4f}" for k, v in metrics.items())
+         self.logger.info(f"✓ Validation: {metrics_str}")
+
+     def end(self, message: str = "Training complete"):
+         """Log training end."""
+         if self.start_time:
+             elapsed = datetime.now() - self.start_time
+             self.logger.info(f"🎉 {message} (took {elapsed})")
+         else:
+             self.logger.info(f"🎉 {message}")
+
+     def error(self, message: str):
+         """Log error."""
+         self.logger.error(f"❌ {message}")
+
+     def warning(self, message: str):
+         """Log warning."""
+         self.logger.warning(f"⚠️ {message}")
+
+
+ class ProgressTracker:
+     """Progress bar for training (falls back to a silent counter without rich)."""
+
+     def __init__(self, total: int, description: str = "Training"):
+         self.total = total
+         self.description = description
+         self.current = 0
+         self.progress = None
+         self.task = None
+
+     def __enter__(self):
+         if RICH_AVAILABLE:
+             self.progress = Progress(
+                 SpinnerColumn(),
+                 TextColumn("[bold blue]{task.description}"),
+                 BarColumn(),
+                 TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
+                 TimeElapsedColumn()
+             )
+             self.progress.start()
+             self.task = self.progress.add_task(self.description, total=self.total)
+         return self
+
+     def __exit__(self, *args):
+         if self.progress:
+             self.progress.stop()
+
+     def update(self, n: int = 1, description: Optional[str] = None):
+         """Update progress."""
+         self.current += n
+         if self.progress and self.task is not None:
+             self.progress.update(self.task, advance=n)
+             if description:
+                 self.progress.update(self.task, description=description)
+
+     def set_description(self, description: str):
+         """Update description."""
+         if self.progress and self.task is not None:
+             self.progress.update(self.task, description=description)
+
+
+ def print_banner(text: str, style: str = "cyan"):
+     """Print a styled banner."""
+     if RICH_AVAILABLE:
+         console = Console()
+         console.print(f"\n[bold {style}]{'='*60}[/]")
+         console.print(f"[bold {style}] {text}[/]")
+         console.print(f"[bold {style}]{'='*60}[/]\n")
+     else:
+         print(f"\n{'='*60}")
+         print(f" {text}")
+         print(f"{'='*60}\n")
+
+
+ def print_metrics(metrics: dict, title: str = "Metrics"):
+     """Print metrics in a nice format."""
+     if RICH_AVAILABLE:
+         from rich.table import Table
+         console = Console()
+         table = Table(title=title)
+         table.add_column("Metric", style="cyan")
+         table.add_column("Value", style="green")
+         for k, v in metrics.items():
+             table.add_row(k, f"{v:.4f}" if isinstance(v, float) else str(v))
+         console.print(table)
+     else:
+         print(f"\n{title}:")
+         for k, v in metrics.items():
+             print(f" {k}: {v:.4f}" if isinstance(v, float) else f" {k}: {v}")
langtune/metrics.py ADDED
@@ -0,0 +1,345 @@
+ """
+ metrics.py: Evaluation metrics for Langtune
+
+ Provides metrics for evaluating language models.
+ """
+
+ import math
+ import torch
+ from typing import List, Dict, Optional
+ from collections import Counter
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ def compute_perplexity(loss: float) -> float:
+     """
+     Compute perplexity from cross-entropy loss.
+
+     Args:
+         loss: Cross-entropy loss value
+
+     Returns:
+         Perplexity score
+     """
+     return math.exp(min(loss, 20))  # Clip to avoid overflow
+
+
+ def compute_bits_per_character(loss: float, log_base: float = 2) -> float:
+     """
+     Compute bits per character (BPC).
+
+     Args:
+         loss: Cross-entropy loss (in nats)
+         log_base: Logarithm base (2 for bits)
+
+     Returns:
+         BPC score
+     """
+     return loss / math.log(log_base)
+
+
+ def compute_accuracy(predictions: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> float:
+     """
+     Compute token prediction accuracy.
+
+     Args:
+         predictions: Predicted token IDs (batch, seq_len)
+         labels: True token IDs (batch, seq_len)
+         ignore_index: Index to ignore in labels
+
+     Returns:
+         Accuracy score (0-1)
+     """
+     mask = labels != ignore_index
+     if mask.sum() == 0:
+         return 0.0
+
+     correct = (predictions == labels) & mask
+     return correct.sum().item() / mask.sum().item()
+
+
+ def compute_top_k_accuracy(logits: torch.Tensor, labels: torch.Tensor, k: int = 5, ignore_index: int = -100) -> float:
+     """
+     Compute top-k token prediction accuracy.
+
+     Args:
+         logits: Model logits (batch, seq_len, vocab_size)
+         labels: True token IDs (batch, seq_len)
+         k: Number of top predictions to consider
+         ignore_index: Index to ignore in labels
+
+     Returns:
+         Top-k accuracy score (0-1)
+     """
+     mask = labels != ignore_index
+     if mask.sum() == 0:
+         return 0.0
+
+     top_k_preds = logits.topk(k, dim=-1).indices
+     labels_expanded = labels.unsqueeze(-1).expand_as(top_k_preds)
+
+     correct = (top_k_preds == labels_expanded).any(dim=-1) & mask
+     return correct.sum().item() / mask.sum().item()
+
+
+ def compute_ngram_overlap(generated: List[str], references: List[str], n: int = 1) -> Dict[str, float]:
+     """
+     Compute n-gram overlap metrics (precision, recall, F1).
+
+     Args:
+         generated: List of generated texts
+         references: List of reference texts
+         n: N-gram size
+
+     Returns:
+         Dict with precision, recall, and F1 scores
+     """
+     def get_ngrams(text: str, n: int) -> Counter:
+         tokens = text.lower().split()
+         return Counter(tuple(tokens[i:i+n]) for i in range(len(tokens) - n + 1))
+
+     total_precision = 0.0
+     total_recall = 0.0
+     count = 0
+
+     for gen, ref in zip(generated, references):
+         gen_ngrams = get_ngrams(gen, n)
+         ref_ngrams = get_ngrams(ref, n)
+
+         if not gen_ngrams or not ref_ngrams:
+             continue
+
+         overlap = sum((gen_ngrams & ref_ngrams).values())
+         precision = overlap / sum(gen_ngrams.values()) if gen_ngrams else 0
+         recall = overlap / sum(ref_ngrams.values()) if ref_ngrams else 0
+
+         total_precision += precision
+         total_recall += recall
+         count += 1
+
+     if count == 0:
+         return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
+
+     avg_precision = total_precision / count
+     avg_recall = total_recall / count
+     f1 = 2 * avg_precision * avg_recall / (avg_precision + avg_recall) if (avg_precision + avg_recall) > 0 else 0
+
+     return {
+         "precision": avg_precision,
+         "recall": avg_recall,
+         "f1": f1
+     }
+
+
+ def compute_bleu(generated: List[str], references: List[str], max_n: int = 4) -> float:
+     """
+     Compute BLEU score (simplified implementation).
+
+     Args:
+         generated: List of generated texts
+         references: List of reference texts
+         max_n: Maximum n-gram size
+
+     Returns:
+         BLEU score (0-1)
+     """
+     def get_ngrams(text: str, n: int) -> Counter:
+         tokens = text.lower().split()
+         return Counter(tuple(tokens[i:i+n]) for i in range(len(tokens) - n + 1))
+
+     def brevity_penalty(gen_len: int, ref_len: int) -> float:
+         if gen_len >= ref_len:
+             return 1.0
+         return math.exp(1 - ref_len / max(gen_len, 1))
+
+     precisions = []
+     total_gen_len = 0
+     total_ref_len = 0
+
+     for n in range(1, max_n + 1):
+         total_overlap = 0
+         total_count = 0
+
+         for gen, ref in zip(generated, references):
+             gen_ngrams = get_ngrams(gen, n)
+             ref_ngrams = get_ngrams(ref, n)
+
+             overlap = sum((gen_ngrams & ref_ngrams).values())
+             count = sum(gen_ngrams.values())
+
+             total_overlap += overlap
+             total_count += count
+
+             if n == 1:
+                 total_gen_len += len(gen.split())
+                 total_ref_len += len(ref.split())
+
+         precision = total_overlap / max(total_count, 1)
+         precisions.append(max(precision, 1e-10))  # Avoid log(0)
+
+     # Geometric mean of precisions
+     log_precision = sum(math.log(p) for p in precisions) / max_n
+     bleu = math.exp(log_precision)
+
+     # Apply brevity penalty
+     bp = brevity_penalty(total_gen_len, total_ref_len)
+
+     return bp * bleu
+
+
+ def compute_rouge_l(generated: str, reference: str) -> Dict[str, float]:
+     """
+     Compute ROUGE-L score based on longest common subsequence.
+
+     Args:
+         generated: Generated text
+         reference: Reference text
+
+     Returns:
+         Dict with precision, recall, and F1
+     """
+     def lcs_length(a: List[str], b: List[str]) -> int:
+         m, n = len(a), len(b)
+         dp = [[0] * (n + 1) for _ in range(m + 1)]
+
+         for i in range(1, m + 1):
+             for j in range(1, n + 1):
+                 if a[i-1] == b[j-1]:
+                     dp[i][j] = dp[i-1][j-1] + 1
+                 else:
+                     dp[i][j] = max(dp[i-1][j], dp[i][j-1])
+
+         return dp[m][n]
+
+     gen_tokens = generated.lower().split()
+     ref_tokens = reference.lower().split()
+
+     if not gen_tokens or not ref_tokens:
+         return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
+
+     lcs = lcs_length(gen_tokens, ref_tokens)
+
+     precision = lcs / len(gen_tokens)
+     recall = lcs / len(ref_tokens)
+     f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
+
+     return {"precision": precision, "recall": recall, "f1": f1}
+
+
+ def compute_diversity(texts: List[str], n: int = 2) -> float:
+     """
+     Compute n-gram diversity (ratio of unique n-grams to total n-grams).
+
+     Args:
+         texts: List of texts
+         n: N-gram size
+
+     Returns:
+         Diversity score (0-1)
+     """
+     all_ngrams = []
+
+     for text in texts:
+         tokens = text.lower().split()
+         ngrams = [tuple(tokens[i:i+n]) for i in range(len(tokens) - n + 1)]
+         all_ngrams.extend(ngrams)
+
+     if not all_ngrams:
+         return 0.0
+
+     unique = len(set(all_ngrams))
+     total = len(all_ngrams)
+
+     return unique / total
+
+
+ def compute_repetition_ratio(text: str, n: int = 3) -> float:
+     """
+     Compute ratio of repeated n-grams (lower is better).
+
+     Args:
+         text: Input text
+         n: N-gram size
+
+     Returns:
+         Repetition ratio (0-1)
+     """
+     tokens = text.lower().split()
+     if len(tokens) < n:
+         return 0.0
+
+     ngrams = [tuple(tokens[i:i+n]) for i in range(len(tokens) - n + 1)]
+
+     if not ngrams:
+         return 0.0
+
+     ngram_counts = Counter(ngrams)
+     repeated = sum(1 for count in ngram_counts.values() if count > 1)
+
+     return repeated / len(ngram_counts)
+
+
+ class MetricsCalculator:
+     """Convenience class for computing multiple metrics."""
+
+     def __init__(self):
+         self.reset()
+
+     def reset(self):
+         """Reset all metrics."""
+         self.total_loss = 0.0
+         self.total_tokens = 0
+         self.total_correct = 0
+         self.total_pred_tokens = 0
+         self.generated_texts = []
+         self.reference_texts = []
+
+     def update(
+         self,
+         loss: Optional[float] = None,
+         logits: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         generated: Optional[str] = None,
+         reference: Optional[str] = None,
+         ignore_index: int = -100
+     ):
+         """Update metrics with a batch."""
+         if loss is not None and labels is not None:
+             mask = labels != ignore_index
+             num_tokens = mask.sum().item()
+             self.total_loss += loss * num_tokens
+             self.total_tokens += num_tokens
+
+         if logits is not None and labels is not None:
+             predictions = logits.argmax(dim=-1)
+             mask = labels != ignore_index
+             # Track accuracy with its own token count so it stays correct
+             # even when loss and logits are reported in separate calls.
+             self.total_correct += ((predictions == labels) & mask).sum().item()
+             self.total_pred_tokens += mask.sum().item()
+
+         if generated is not None:
+             self.generated_texts.append(generated)
+         if reference is not None:
+             self.reference_texts.append(reference)
+
+     def compute(self) -> Dict[str, float]:
+         """Compute all metrics."""
+         results = {}
+
+         if self.total_tokens > 0:
+             avg_loss = self.total_loss / self.total_tokens
+             results["loss"] = avg_loss
+             results["perplexity"] = compute_perplexity(avg_loss)
+
+         if self.total_pred_tokens > 0:
+             results["accuracy"] = self.total_correct / self.total_pred_tokens
+
+         if self.generated_texts and self.reference_texts:
+             results["bleu"] = compute_bleu(self.generated_texts, self.reference_texts)
+
+             # Average ROUGE-L
+             rouge_scores = [compute_rouge_l(g, r) for g, r in zip(self.generated_texts, self.reference_texts)]
+             results["rouge_l_f1"] = sum(s["f1"] for s in rouge_scores) / len(rouge_scores)
+
+             results["diversity"] = compute_diversity(self.generated_texts)
+
+         return results
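
As a quick illustration of the accumulation pattern above (shapes and values are arbitrary, chosen only to exercise the API):

    import torch
    from langtune.metrics import MetricsCalculator

    calc = MetricsCalculator()
    logits = torch.randn(2, 8, 100)          # (batch, seq_len, vocab_size)
    labels = torch.randint(0, 100, (2, 8))   # (batch, seq_len)
    calc.update(loss=2.3, logits=logits, labels=labels)
    calc.update(generated="the cat sat", reference="the cat sat down")
    print(calc.compute())  # loss, perplexity, accuracy, bleu, rouge_l_f1, diversity
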
@@ -0,0 +1,20 @@
+ """
+ Langtune Model Loading Subsystem.
+
+ Provides high-performance loading primitives:
+ - HubResolver: Cached HF downloads
+ - TensorStreamer: Lazy safetensors loading
+ - ModelLoader: Orchestration
+ """
+
+ from .hub import HubResolver
+ from .safetensors import TensorStreamer
+ from .weights import WeightLoader
+ from .loader import ModelLoader
+
+ __all__ = [
+     "HubResolver",
+     "TensorStreamer",
+     "WeightLoader",
+     "ModelLoader"
+ ]