langtune 0.1.19__py3-none-any.whl
This diff shows the content of a publicly released package version as published to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in the public registry.
- langtune/__init__.py +315 -0
- langtune/acceleration.py +132 -0
- langtune/api.py +320 -0
- langtune/auth.py +434 -0
- langtune/callbacks.py +268 -0
- langtune/cli.py +687 -0
- langtune/client.py +721 -0
- langtune/config.py +356 -0
- langtune/data.py +526 -0
- langtune/distributed.py +154 -0
- langtune/facade.py +174 -0
- langtune/finetune.py +491 -0
- langtune/generation.py +95 -0
- langtune/logging_utils.py +182 -0
- langtune/metrics.py +345 -0
- langtune/model/__init__.py +20 -0
- langtune/model/hub.py +109 -0
- langtune/model/loader.py +84 -0
- langtune/model/safetensors.py +104 -0
- langtune/model/weights.py +100 -0
- langtune/models.py +19 -0
- langtune/nn/fast_transformer.py +399 -0
- langtune/nn/layers.py +178 -0
- langtune/nn/transformer.py +254 -0
- langtune/optimizations.py +870 -0
- langtune/py.typed +2 -0
- langtune/schedulers.py +234 -0
- langtune/tokenizers.py +275 -0
- langtune/trainer.py +889 -0
- langtune/training/neftune.py +80 -0
- langtune/utils.py +337 -0
- langtune-0.1.19.dist-info/METADATA +257 -0
- langtune-0.1.19.dist-info/RECORD +37 -0
- langtune-0.1.19.dist-info/WHEEL +5 -0
- langtune-0.1.19.dist-info/entry_points.txt +2 -0
- langtune-0.1.19.dist-info/licenses/LICENSE +21 -0
- langtune-0.1.19.dist-info/top_level.txt +1 -0
langtune/generation.py
ADDED
@@ -0,0 +1,95 @@
"""
generation.py: Text generation utilities for Langtune
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Union
import logging

logger = logging.getLogger(__name__)


class TextGenerator:
    """High-level text generation with sampling strategies."""

    def __init__(self, model: nn.Module, tokenizer=None, device: Optional[torch.device] = None):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device or next(model.parameters()).device
        self.model.eval()

    @torch.no_grad()
    def generate(
        self,
        prompt: Union[str, torch.Tensor],
        max_length: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: float = 1.0,
        do_sample: bool = True,
        eos_token_id: int = 1
    ) -> Union[str, torch.Tensor]:
        """Generate text with various sampling strategies."""
        # Encode prompt
        if isinstance(prompt, str):
            if self.tokenizer:
                input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
            else:
                input_ids = torch.tensor([[ord(c) for c in prompt]], device=self.device)
        else:
            input_ids = prompt.to(self.device)

        for _ in range(max_length - input_ids.size(1)):
            outputs = self.model(input_ids)
            logits = outputs["logits"] if isinstance(outputs, dict) else outputs
            next_logits = logits[:, -1, :]

            # Repetition penalty
            if repetition_penalty != 1.0:
                for token in input_ids[0].unique():
                    next_logits[0, token] /= repetition_penalty

            # Temperature
            if do_sample and temperature != 1.0:
                next_logits = next_logits / temperature

            # Top-k filtering
            if top_k:
                top_k = min(top_k, next_logits.size(-1))
                indices_to_remove = next_logits < torch.topk(next_logits, top_k)[0][..., -1, None]
                next_logits[indices_to_remove] = float('-inf')

            # Top-p filtering
            if top_p and top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
                cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                remove = cumulative > top_p
                remove[..., 1:] = remove[..., :-1].clone()
                remove[..., 0] = 0
                indices_to_remove = remove.scatter(1, sorted_indices, remove)
                next_logits[indices_to_remove] = float('-inf')

            # Sample or greedy
            if do_sample:
                next_token = torch.multinomial(F.softmax(next_logits, dim=-1), 1)
            else:
                next_token = next_logits.argmax(dim=-1, keepdim=True)

            input_ids = torch.cat([input_ids, next_token], dim=1)

            if next_token.item() == eos_token_id:
                break

        if self.tokenizer:
            return self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return input_ids


def generate(model, prompt, max_length=100, temperature=0.8, top_k=50, top_p=0.95, **kwargs):
    """Convenience function for text generation."""
    return TextGenerator(model).generate(
        prompt, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, **kwargs
    )
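
A minimal usage sketch (illustrative, not part of the package): TextGenerator accepts any nn.Module whose forward maps (batch, seq_len) token IDs to (batch, seq_len, vocab_size) logits; ToyLM below is a hypothetical stand-in, not a langtune model.

import torch
import torch.nn as nn

from langtune.generation import TextGenerator


class ToyLM(nn.Module):
    """Hypothetical stand-in model: embed token IDs, project back to vocab logits."""

    def __init__(self, vocab_size: int = 256, dim: int = 32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.proj = nn.Linear(dim, vocab_size)

    def forward(self, input_ids):
        # (batch, seq_len) -> (batch, seq_len, vocab_size)
        return self.proj(self.embed(input_ids))


gen = TextGenerator(ToyLM())
# With no tokenizer attached, the prompt falls back to ord()-per-character
# encoding and the result comes back as a tensor of token IDs, not a string.
ids = gen.generate("hello", max_length=20, temperature=0.8, top_k=10, top_p=0.9)
print(ids.shape)  # (1, n) with n <= 20; sampling the default eos_token_id=1 stops early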
langtune/logging_utils.py
ADDED
@@ -0,0 +1,182 @@
"""
logging_utils.py: Logging utilities for Langtune

Provides colorful logging and progress tracking.
"""

import logging
import sys
from typing import Optional
from datetime import datetime

# Rich console for pretty output
try:
    from rich.console import Console
    from rich.logging import RichHandler
    from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeElapsedColumn
    RICH_AVAILABLE = True
except ImportError:
    RICH_AVAILABLE = False


def setup_logging(
    level: str = "INFO",
    log_file: Optional[str] = None,
    use_rich: bool = True
):
    """Setup logging with optional rich formatting."""
    log_level = getattr(logging, level.upper(), logging.INFO)

    handlers = []

    if use_rich and RICH_AVAILABLE:
        handlers.append(RichHandler(
            console=Console(stderr=True),
            show_time=True,
            show_path=False,
            rich_tracebacks=True
        ))
    else:
        handler = logging.StreamHandler(sys.stderr)
        handler.setFormatter(logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        ))
        handlers.append(handler)

    if log_file:
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        ))
        handlers.append(file_handler)

    logging.basicConfig(level=log_level, handlers=handlers, force=True)

    # Set langtune logger
    logger = logging.getLogger("langtune")
    logger.setLevel(log_level)


def get_logger(name: str = "langtune") -> logging.Logger:
    """Get a logger with the given name."""
    return logging.getLogger(name)


class TrainingLogger:
    """Structured logger for training progress."""

    def __init__(self, name: str = "training", log_file: Optional[str] = None):
        self.logger = logging.getLogger(name)
        self.log_file = log_file
        self.start_time = None
        self.metrics_history = []

    def start(self, message: str = "Starting training"):
        """Log training start."""
        self.start_time = datetime.now()
        self.logger.info(f"🚀 {message}")

    def log_epoch(self, epoch: int, total_epochs: int, metrics: dict):
        """Log epoch completion."""
        metrics_str = ", ".join(f"{k}={v:.4f}" for k, v in metrics.items())
        self.logger.info(f"📊 Epoch {epoch+1}/{total_epochs}: {metrics_str}")
        self.metrics_history.append({"epoch": epoch + 1, **metrics})

    def log_step(self, step: int, loss: float, lr: Optional[float] = None):
        """Log training step."""
        msg = f"Step {step}: loss={loss:.4f}"
        if lr is not None:
            msg += f", lr={lr:.2e}"
        self.logger.debug(msg)

    def log_validation(self, metrics: dict):
        """Log validation results."""
        metrics_str = ", ".join(f"{k}={v:.4f}" for k, v in metrics.items())
        self.logger.info(f"✓ Validation: {metrics_str}")

    def end(self, message: str = "Training complete"):
        """Log training end."""
        if self.start_time:
            elapsed = datetime.now() - self.start_time
            self.logger.info(f"🎉 {message} (took {elapsed})")
        else:
            self.logger.info(f"🎉 {message}")

    def error(self, message: str):
        """Log error."""
        self.logger.error(f"❌ {message}")

    def warning(self, message: str):
        """Log warning."""
        self.logger.warning(f"⚠️ {message}")


class ProgressTracker:
    """Progress bar for training."""

    def __init__(self, total: int, description: str = "Training"):
        self.total = total
        self.description = description
        self.current = 0
        self.progress = None
        self.task = None

    def __enter__(self):
        if RICH_AVAILABLE:
            self.progress = Progress(
                SpinnerColumn(),
                TextColumn("[bold blue]{task.description}"),
                BarColumn(),
                TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
                TimeElapsedColumn()
            )
            self.progress.start()
            self.task = self.progress.add_task(self.description, total=self.total)
        return self

    def __exit__(self, *args):
        if self.progress:
            self.progress.stop()

    def update(self, n: int = 1, description: Optional[str] = None):
        """Update progress."""
        self.current += n
        if self.progress and self.task is not None:
            self.progress.update(self.task, advance=n)
            if description:
                self.progress.update(self.task, description=description)

    def set_description(self, description: str):
        """Update description."""
        if self.progress and self.task is not None:
            self.progress.update(self.task, description=description)


def print_banner(text: str, style: str = "cyan"):
    """Print a styled banner."""
    if RICH_AVAILABLE:
        console = Console()
        console.print(f"\n[bold {style}]{'='*60}[/]")
        console.print(f"[bold {style}] {text}[/]")
        console.print(f"[bold {style}]{'='*60}[/]\n")
    else:
        print(f"\n{'='*60}")
        print(f" {text}")
        print(f"{'='*60}\n")


def print_metrics(metrics: dict, title: str = "Metrics"):
    """Print metrics in a nice format."""
    if RICH_AVAILABLE:
        from rich.table import Table
        console = Console()
        table = Table(title=title)
        table.add_column("Metric", style="cyan")
        table.add_column("Value", style="green")
        for k, v in metrics.items():
            table.add_row(k, f"{v:.4f}" if isinstance(v, float) else str(v))
        console.print(table)
    else:
        print(f"\n{title}:")
        for k, v in metrics.items():
            print(f"  {k}: {v:.4f}" if isinstance(v, float) else f"  {k}: {v}")
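
Illustrative usage of the main entry points; the loss values and the train.log path are made up. Without the optional rich dependency, logging falls back to plain stderr output and ProgressTracker degrades to a silent counter.

from langtune.logging_utils import setup_logging, TrainingLogger, ProgressTracker, print_metrics

setup_logging(level="DEBUG", log_file="train.log")

tlog = TrainingLogger()
tlog.start("Fine-tuning demo")
with ProgressTracker(total=100, description="Epoch 1/1") as bar:
    for step in range(100):
        tlog.log_step(step, loss=1.0 / (step + 1), lr=3e-4)  # emitted at DEBUG level
        bar.update(1)
tlog.log_epoch(0, total_epochs=1, metrics={"loss": 0.42})
tlog.end()
print_metrics({"loss": 0.42, "perplexity": 1.52}, title="Final metrics")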
langtune/metrics.py
ADDED
@@ -0,0 +1,345 @@
"""
metrics.py: Evaluation metrics for Langtune

Provides metrics for evaluating language models.
"""

import math
import torch
import numpy as np
from typing import List, Dict, Any, Optional, Union
from collections import Counter
import logging

logger = logging.getLogger(__name__)


def compute_perplexity(loss: float) -> float:
    """
    Compute perplexity from cross-entropy loss.

    Args:
        loss: Cross-entropy loss value

    Returns:
        Perplexity score
    """
    return math.exp(min(loss, 20))  # Clip to avoid overflow


def compute_bits_per_character(loss: float, log_base: float = 2) -> float:
    """
    Compute bits per character (BPC).

    Args:
        loss: Cross-entropy loss (in nats)
        log_base: Logarithm base (2 for bits)

    Returns:
        BPC score
    """
    return loss / math.log(log_base)


def compute_accuracy(predictions: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> float:
    """
    Compute token prediction accuracy.

    Args:
        predictions: Predicted token IDs (batch, seq_len)
        labels: True token IDs (batch, seq_len)
        ignore_index: Index to ignore in labels

    Returns:
        Accuracy score (0-1)
    """
    mask = labels != ignore_index
    if mask.sum() == 0:
        return 0.0

    correct = (predictions == labels) & mask
    return correct.sum().item() / mask.sum().item()


def compute_top_k_accuracy(logits: torch.Tensor, labels: torch.Tensor, k: int = 5, ignore_index: int = -100) -> float:
    """
    Compute top-k token prediction accuracy.

    Args:
        logits: Model logits (batch, seq_len, vocab_size)
        labels: True token IDs (batch, seq_len)
        k: Number of top predictions to consider
        ignore_index: Index to ignore in labels

    Returns:
        Top-k accuracy score (0-1)
    """
    mask = labels != ignore_index
    if mask.sum() == 0:
        return 0.0

    top_k_preds = logits.topk(k, dim=-1).indices
    labels_expanded = labels.unsqueeze(-1).expand_as(top_k_preds)

    correct = (top_k_preds == labels_expanded).any(dim=-1) & mask
    return correct.sum().item() / mask.sum().item()


def compute_ngram_overlap(generated: List[str], references: List[str], n: int = 1) -> Dict[str, float]:
    """
    Compute n-gram overlap metrics (precision, recall, F1).

    Args:
        generated: List of generated texts
        references: List of reference texts
        n: N-gram size

    Returns:
        Dict with precision, recall, and F1 scores
    """
    def get_ngrams(text: str, n: int) -> Counter:
        tokens = text.lower().split()
        return Counter(tuple(tokens[i:i+n]) for i in range(len(tokens) - n + 1))

    total_precision = 0.0
    total_recall = 0.0
    count = 0

    for gen, ref in zip(generated, references):
        gen_ngrams = get_ngrams(gen, n)
        ref_ngrams = get_ngrams(ref, n)

        if not gen_ngrams or not ref_ngrams:
            continue

        overlap = sum((gen_ngrams & ref_ngrams).values())
        precision = overlap / sum(gen_ngrams.values()) if gen_ngrams else 0
        recall = overlap / sum(ref_ngrams.values()) if ref_ngrams else 0

        total_precision += precision
        total_recall += recall
        count += 1

    if count == 0:
        return {"precision": 0.0, "recall": 0.0, "f1": 0.0}

    avg_precision = total_precision / count
    avg_recall = total_recall / count
    f1 = 2 * avg_precision * avg_recall / (avg_precision + avg_recall) if (avg_precision + avg_recall) > 0 else 0

    return {
        "precision": avg_precision,
        "recall": avg_recall,
        "f1": f1
    }


def compute_bleu(generated: List[str], references: List[str], max_n: int = 4) -> float:
    """
    Compute BLEU score (simplified implementation).

    Args:
        generated: List of generated texts
        references: List of reference texts
        max_n: Maximum n-gram size

    Returns:
        BLEU score (0-1)
    """
    def get_ngrams(text: str, n: int) -> Counter:
        tokens = text.lower().split()
        return Counter(tuple(tokens[i:i+n]) for i in range(len(tokens) - n + 1))

    def brevity_penalty(gen_len: int, ref_len: int) -> float:
        if gen_len >= ref_len:
            return 1.0
        return math.exp(1 - ref_len / max(gen_len, 1))

    precisions = []
    total_gen_len = 0
    total_ref_len = 0

    for n in range(1, max_n + 1):
        total_overlap = 0
        total_count = 0

        for gen, ref in zip(generated, references):
            gen_ngrams = get_ngrams(gen, n)
            ref_ngrams = get_ngrams(ref, n)

            overlap = sum((gen_ngrams & ref_ngrams).values())
            count = sum(gen_ngrams.values())

            total_overlap += overlap
            total_count += count

            if n == 1:
                total_gen_len += len(gen.split())
                total_ref_len += len(ref.split())

        precision = total_overlap / max(total_count, 1)
        precisions.append(max(precision, 1e-10))  # Avoid log(0)

    # Geometric mean of precisions
    log_precision = sum(math.log(p) for p in precisions) / max_n
    bleu = math.exp(log_precision)

    # Apply brevity penalty
    bp = brevity_penalty(total_gen_len, total_ref_len)

    return bp * bleu


def compute_rouge_l(generated: str, reference: str) -> Dict[str, float]:
    """
    Compute ROUGE-L score based on longest common subsequence.

    Args:
        generated: Generated text
        reference: Reference text

    Returns:
        Dict with precision, recall, and F1
    """
    def lcs_length(a: List[str], b: List[str]) -> int:
        m, n = len(a), len(b)
        dp = [[0] * (n + 1) for _ in range(m + 1)]

        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if a[i-1] == b[j-1]:
                    dp[i][j] = dp[i-1][j-1] + 1
                else:
                    dp[i][j] = max(dp[i-1][j], dp[i][j-1])

        return dp[m][n]

    gen_tokens = generated.lower().split()
    ref_tokens = reference.lower().split()

    if not gen_tokens or not ref_tokens:
        return {"precision": 0.0, "recall": 0.0, "f1": 0.0}

    lcs = lcs_length(gen_tokens, ref_tokens)

    precision = lcs / len(gen_tokens)
    recall = lcs / len(ref_tokens)
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    return {"precision": precision, "recall": recall, "f1": f1}


def compute_diversity(texts: List[str], n: int = 2) -> float:
    """
    Compute n-gram diversity (ratio of unique n-grams to total n-grams).

    Args:
        texts: List of texts
        n: N-gram size

    Returns:
        Diversity score (0-1)
    """
    all_ngrams = []

    for text in texts:
        tokens = text.lower().split()
        ngrams = [tuple(tokens[i:i+n]) for i in range(len(tokens) - n + 1)]
        all_ngrams.extend(ngrams)

    if not all_ngrams:
        return 0.0

    unique = len(set(all_ngrams))
    total = len(all_ngrams)

    return unique / total


def compute_repetition_ratio(text: str, n: int = 3) -> float:
    """
    Compute ratio of repeated n-grams (lower is better).

    Args:
        text: Input text
        n: N-gram size

    Returns:
        Repetition ratio (0-1)
    """
    tokens = text.lower().split()
    if len(tokens) < n:
        return 0.0

    ngrams = [tuple(tokens[i:i+n]) for i in range(len(tokens) - n + 1)]

    if not ngrams:
        return 0.0

    ngram_counts = Counter(ngrams)
    repeated = sum(1 for count in ngram_counts.values() if count > 1)

    return repeated / len(ngram_counts)


class MetricsCalculator:
    """Convenience class for computing multiple metrics."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset all metrics."""
        self.total_loss = 0.0
        self.total_tokens = 0
        self.total_correct = 0
        self.generated_texts = []
        self.reference_texts = []

    def update(
        self,
        loss: Optional[float] = None,
        logits: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        generated: Optional[str] = None,
        reference: Optional[str] = None,
        ignore_index: int = -100
    ):
        """Update metrics with a batch."""
        if loss is not None and labels is not None:
            mask = labels != ignore_index
            num_tokens = mask.sum().item()
            self.total_loss += loss * num_tokens
            self.total_tokens += num_tokens

        if logits is not None and labels is not None:
            predictions = logits.argmax(dim=-1)
            mask = labels != ignore_index
            correct = ((predictions == labels) & mask).sum().item()
            self.total_correct += correct

        if generated is not None:
            self.generated_texts.append(generated)
        if reference is not None:
            self.reference_texts.append(reference)

    def compute(self) -> Dict[str, float]:
        """Compute all metrics."""
        results = {}

        if self.total_tokens > 0:
            avg_loss = self.total_loss / self.total_tokens
            results["loss"] = avg_loss
            results["perplexity"] = compute_perplexity(avg_loss)
            results["accuracy"] = self.total_correct / self.total_tokens

        if self.generated_texts and self.reference_texts:
            results["bleu"] = compute_bleu(self.generated_texts, self.reference_texts)

            # Average ROUGE-L
            rouge_scores = [compute_rouge_l(g, r) for g, r in zip(self.generated_texts, self.reference_texts)]
            results["rouge_l_f1"] = sum(s["f1"] for s in rouge_scores) / len(rouge_scores)

            results["diversity"] = compute_diversity(self.generated_texts)

        return results
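
A short sketch of the aggregation flow with MetricsCalculator; the tensors and texts are made-up placeholders.

import torch

from langtune.metrics import MetricsCalculator, compute_rouge_l

calc = MetricsCalculator()

# Made-up batch: 2 sequences of 4 tokens over a 10-token vocabulary.
logits = torch.randn(2, 4, 10)
labels = torch.randint(0, 10, (2, 4))
calc.update(loss=2.3, logits=logits, labels=labels)  # loss is token-count weighted

# Text metrics accumulate separately and are folded in by compute().
calc.update(generated="the cat sat on the mat", reference="a cat sat on a mat")

print(calc.compute())
# keys: loss, perplexity, accuracy, bleu, rouge_l_f1, diversity

# The standalone helpers also work on their own:
print(compute_rouge_l("the cat sat", "a cat sat"))  # precision/recall/f1 all ~0.67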
langtune/model/__init__.py
ADDED
@@ -0,0 +1,20 @@
"""
Langtune Model Loading Subsystem.

Provides high-performance loading primitives:
- HubResolver: Cached HF downloads
- TensorStreamer: Lazy safetensors loading
- ModelLoader: Orchestration
"""

from .hub import HubResolver
from .safetensors import TensorStreamer
from .weights import WeightLoader
from .loader import ModelLoader

__all__ = [
    "HubResolver",
    "TensorStreamer",
    "WeightLoader",
    "ModelLoader"
]
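
Because this __init__.py re-exports the loading primitives, downstream code imports them from the subpackage root; the implementations live in hub.py, safetensors.py, weights.py, and loader.py (listed above with their line counts):

from langtune.model import HubResolver, TensorStreamer, WeightLoader, ModelLoader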