valtron-core 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- valtron_core/__init__.py +3 -0
- valtron_core/__main__.py +16 -0
- valtron_core/attachments.py +32 -0
- valtron_core/bert_evaluator.py +227 -0
- valtron_core/bert_trainer.py +332 -0
- valtron_core/client.py +276 -0
- valtron_core/config.py +109 -0
- valtron_core/cost_utils.py +95 -0
- valtron_core/decompose.py +906 -0
- valtron_core/evaluation/__init__.py +0 -0
- valtron_core/evaluation/comparison_functions.py +607 -0
- valtron_core/evaluation/json_eval.py +557 -0
- valtron_core/evaluator.py +573 -0
- valtron_core/few_shot_training_data_generator.py +976 -0
- valtron_core/loader.py +382 -0
- valtron_core/models.py +331 -0
- valtron_core/optimized_evaluator.py +261 -0
- valtron_core/prompt_optimizer.py +474 -0
- valtron_core/recipes/README.md +122 -0
- valtron_core/recipes/__init__.py +25 -0
- valtron_core/recipes/base.py +223 -0
- valtron_core/recipes/config.py +146 -0
- valtron_core/recipes/model_eval.py +1118 -0
- valtron_core/report.py +1218 -0
- valtron_core/runner.py +1001 -0
- valtron_core/templates/common.css +36 -0
- valtron_core/templates/config_wizard.html +1191 -0
- valtron_core/templates/detailed_analysis.jinja2.html +1189 -0
- valtron_core/templates/evaluation_report.jinja2.html +3143 -0
- valtron_core/templates/pdf_report.jinja2.html +555 -0
- valtron_core/transformer_classifier.py +297 -0
- valtron_core/transformer_wrapper.py +86 -0
- valtron_core/utilities/__init__.py +1 -0
- valtron_core/utilities/aggregate_reports.py +359 -0
- valtron_core/utilities/cli_introspect.py +233 -0
- valtron_core/utilities/code_introspection.py +1076 -0
- valtron_core/utilities/config_wizard.py +255 -0
- valtron_core/utilities/field_config_generator.py +92 -0
- valtron_core/utilities/train_transformer.py +231 -0
- valtron_core-0.1.0.dist-info/METADATA +248 -0
- valtron_core-0.1.0.dist-info/RECORD +43 -0
- valtron_core-0.1.0.dist-info/WHEEL +4 -0
- valtron_core-0.1.0.dist-info/licenses/LICENSE +201 -0
valtron_core/__init__.py
ADDED
valtron_core/__main__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""Main entry point for the valtron_core package (``python -m valtron_core``)."""

import typer

# Typer application; ``main`` below is registered as its command.
app = typer.Typer()


@app.command()
def main() -> None:
    """Print the valtron_core banner and an initialization message.

    Takes no options; both lines are written with ``typer.echo``.
    """
    typer.echo("Valtron Core - LLM call optimization")
    typer.echo("Project initialized successfully!")


if __name__ == "__main__":
    app()
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
"""Shared attachment MIME detection utilities."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
# Lower-case file/URL extensions mapped to MIME types. detect_mime_hint
# looks the lower-cased path suffix up here after dropping any "?query".
_EXT_MIME: dict[str, str] = dict(
    [
        (".pdf", "application/pdf"),
        (".png", "image/png"),
        (".jpg", "image/jpeg"),
        (".jpeg", "image/jpeg"),
        (".gif", "image/gif"),
        (".webp", "image/webp"),
    ]
)

# Leading-byte signatures paired with the MIME type they suggest.
# Presumably matched as byte prefixes by a caller elsewhere — TODO confirm.
# NOTE(review): b"RIFF" is a generic container header (also used by e.g.
# WAV/AVI); a WebP file additionally has b"WEBP" at byte offset 8, so a
# b"RIFF" prefix match alone is only a hint.
_MAGIC: list[tuple[bytes, str]] = [
    (b"%PDF-", "application/pdf"),  # PDF header
    (b"\x89PNG", "image/png"),  # PNG signature (first 4 bytes)
    (b"\xff\xd8\xff", "image/jpeg"),  # JPEG SOI marker + next marker byte
    (b"GIF8", "image/gif"),  # GIF87a / GIF89a
    (b"RIFF", "image/webp"),  # RIFF container (see note above)
]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def detect_mime_hint(s: str) -> str:
    """
    Detect MIME type from a data URI header or file/URL extension without any I/O.

    * ``data:`` URIs (scheme matched case-insensitively): returns the media
      type between ``data:`` and the first ``;`` or ``,``, stripped and
      lower-cased (MIME types are case-insensitive). A data URI with no media
      type (e.g. ``data:,hello``) yields ``""``.
    * Anything else: the query string (``?...``) is dropped and the path's
      lower-cased extension is looked up in ``_EXT_MIME``. If that lookup
      fails and the path contains a ``#`` fragment, it is retried with the
      fragment removed.

    Returns an empty string if the type cannot be determined.
    """
    if s[:5].lower() == "data:":
        header = s.split(",", 1)[0]
        media_type = header[5:].split(";", 1)[0]
        return media_type.strip().lower()

    path = s.split("?", 1)[0]
    mime = _EXT_MIME.get(Path(path).suffix.lower(), "")
    if not mime and "#" in path:
        # URLs like ".../photo.png#page=2" put a fragment after the extension.
        # Retry only on a miss so local names that contain '#' keep working.
        mime = _EXT_MIME.get(Path(path.split("#", 1)[0]).suffix.lower(), "")
    return mime
|
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
"""BERT model evaluation integrated with the evaluation framework.

.. deprecated::
    Use :class:`valtron_core.transformer_classifier.TransformerClassifier` for
    inference and the recipe layer for evaluation. ``BERTEvaluator`` will be
    removed in a future release.
"""

import warnings

# Runs at import time, so the deprecation is reported as soon as this module
# is imported (before the heavy imports below).
warnings.warn(
    "valtron_core.bert_evaluator.BERTEvaluator is deprecated. "
    "Use valtron_core.transformer_classifier.TransformerClassifier with "
    "ModelEval (type='transformer') instead.",
    DeprecationWarning,
    stacklevel=2,
)
|
|
18
|
+
|
|
19
|
+
import time
|
|
20
|
+
from typing import Any
|
|
21
|
+
|
|
22
|
+
import structlog
|
|
23
|
+
|
|
24
|
+
from valtron_core.bert_trainer import BERTTrainer
|
|
25
|
+
from valtron_core.models import Document, EvaluationInput, EvaluationResult, Label, PredictionResult
|
|
26
|
+
|
|
27
|
+
# Module-level structlog logger used by BERTEvaluator and
# create_bert_model_for_comparison for structured event logging.
logger = structlog.get_logger()
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class BERTEvaluator:
    """Evaluate BERT models using the same framework as LLM evaluations.

    Wraps a ``BERTTrainer`` and produces ``PredictionResult`` /
    ``EvaluationResult`` objects so local BERT runs can be compared with LLM
    runs. Local inference is reported with zero cost and the model name
    ``"bert-local"``.
    """

    def __init__(self, trainer: BERTTrainer) -> None:
        """
        Initialize BERT evaluator.

        Args:
            trainer: Trained (or loaded) BERTTrainer instance; its
                ``predict_single`` method is called for every document.
        """
        self.trainer = trainer

    async def evaluate_single(
        self,
        document: Document,
        label: Label,
    ) -> PredictionResult:
        """
        Evaluate a single document.

        The prediction is compared to the expected label after stripping
        surrounding whitespace and lower-casing both sides. Any exception is
        caught and turned into an incorrect ``PredictionResult`` whose
        ``predicted_value`` starts with ``"ERROR: "``.

        Note: ``predict_single`` is a synchronous call, so this coroutine does
        not yield to the event loop while inference runs.

        Args:
            document: Document to evaluate
            label: Expected label

        Returns:
            PredictionResult
        """
        # perf_counter is monotonic: elapsed time cannot go negative or jump
        # if the wall clock is adjusted during inference (time.time() can).
        start_time = time.perf_counter()

        try:
            predicted_value = self.trainer.predict_single(document.content)
            response_time = time.perf_counter() - start_time

            # Case- and whitespace-insensitive exact match.
            is_correct = predicted_value.strip().lower() == label.value.strip().lower()

            logger.info(
                "bert_evaluation_single",
                document_id=document.id,
                predicted=predicted_value,
                expected=label.value,
                correct=is_correct,
                time=response_time,
            )

            return PredictionResult(
                document_id=document.id,
                predicted_value=predicted_value,
                expected_value=label.value,
                is_correct=is_correct,
                response_time=response_time,
                original_cost=0.0,
                cost=0.0,  # BERT inference is free
                model="bert-local",
                metadata={"content": document.content},
            )

        except Exception as e:
            # Deliberate best-effort: one failing document must not abort the
            # whole run, so the error is recorded on the result instead.
            response_time = time.perf_counter() - start_time

            logger.error(
                "bert_evaluation_error",
                document_id=document.id,
                error=str(e),
                time=response_time,
            )

            return PredictionResult(
                document_id=document.id,
                predicted_value=f"ERROR: {str(e)}",
                expected_value=label.value,
                is_correct=False,
                response_time=response_time,
                original_cost=0.0,
                cost=0.0,
                model="bert-local",
                metadata={"content": document.content, "error": str(e)},
            )

    async def evaluate(
        self,
        eval_input: EvaluationInput,
    ) -> EvaluationResult:
        """
        Evaluate all documents.

        Documents that have no matching label (by ``document_id``) are
        skipped with a ``missing_label`` warning. If anything raises while
        evaluating, the returned result has ``status="failed"`` and ``error``
        set instead of the exception propagating.

        Args:
            eval_input: Evaluation input configuration

        Returns:
            EvaluationResult with all predictions and metrics
        """
        import uuid
        from datetime import datetime

        run_id = str(uuid.uuid4())
        result = EvaluationResult(
            run_id=run_id,
            started_at=datetime.now(),
            prompt_template="BERT local inference (no prompt)",
            model="bert-local",
            status="running",
        )

        # document_id -> Label; a later duplicate id overrides an earlier one.
        label_map = {label.document_id: label for label in eval_input.labels}

        logger.info(
            "bert_evaluation_started",
            run_id=run_id,
            total_documents=len(eval_input.documents),
        )

        try:
            for doc in eval_input.documents:
                if doc.id not in label_map:
                    logger.warning("missing_label", document_id=doc.id)
                    continue

                prediction = await self.evaluate_single(
                    document=doc,
                    label=label_map[doc.id],
                )
                result.predictions.append(prediction)

            result.compute_metrics()
            result.completed_at = datetime.now()
            result.status = "completed"

            logger.info(
                "bert_evaluation_completed",
                run_id=run_id,
                accuracy=result.metrics.accuracy if result.metrics else 0,
                total_time=result.metrics.total_time if result.metrics else 0,
            )

        except Exception as e:
            result.status = "failed"
            result.error = str(e)
            result.completed_at = datetime.now()

            logger.error("bert_evaluation_failed", run_id=run_id, error=str(e))

        return result
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def create_bert_model_for_comparison(
    documents: list[Document],
    labels: list[Label],
    model_name: str = "bert-base-uncased",
    output_dir: str = "./bert_models",
    num_epochs: int = 3,
    batch_size: int = 8,
) -> BERTTrainer:
    """
    Fine-tune a BERT classifier so it can be compared against LLM runs.

    Builds a ``BERTTrainer``, lets it split ``documents``/``labels`` into
    train and test sets, trains it, and logs the resulting evaluation
    accuracy and model directory.

    Args:
        documents: Training documents
        labels: Training labels (matched to documents by document id)
        model_name: Pretrained BERT model to start from
        output_dir: Directory the trainer saves models into
        num_epochs: Number of training epochs
        batch_size: Per-device training/evaluation batch size

    Returns:
        The trained BERTTrainer instance, ready for prediction.
    """
    logger.info("creating_bert_model", model_name=model_name, num_docs=len(documents))

    bert = BERTTrainer(model_name=model_name, output_dir=output_dir)
    train_split, test_split = bert.prepare_data(documents, labels)
    metrics = bert.train(
        train_dataset=train_split,
        test_dataset=test_split,
        num_epochs=num_epochs,
        batch_size=batch_size,
    )

    logger.info(
        "bert_model_created",
        accuracy=metrics["eval_accuracy"],
        model_dir=metrics["model_dir"],
    )
    return bert
|
|
@@ -0,0 +1,332 @@
|
|
|
1
|
+
"""BERT model training for classification tasks.

.. deprecated::
    Use :class:`valtron_core.transformer_classifier.TransformerClassifier` instead.
    ``BERTTrainer`` will be removed in a future release.
"""

import warnings

# Runs at import time, so the deprecation is reported as soon as this module
# is imported (before the heavy torch/transformers imports below).
warnings.warn(
    "valtron_core.bert_trainer.BERTTrainer is deprecated. "
    "Use valtron_core.transformer_classifier.TransformerClassifier instead.",
    DeprecationWarning,
    stacklevel=2,
)
|
|
16
|
+
|
|
17
|
+
import json
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
import structlog
|
|
22
|
+
import torch
|
|
23
|
+
from datasets import Dataset
|
|
24
|
+
from sklearn.metrics import accuracy_score, classification_report
|
|
25
|
+
from sklearn.model_selection import train_test_split
|
|
26
|
+
from transformers import (
|
|
27
|
+
AutoModelForSequenceClassification,
|
|
28
|
+
AutoTokenizer,
|
|
29
|
+
Trainer,
|
|
30
|
+
TrainingArguments,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
from valtron_core.models import Document, Label
|
|
34
|
+
|
|
35
|
+
# Module-level structlog logger used by BERTTrainer for structured event logging.
logger = structlog.get_logger()
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class BERTTrainer:
    """Train and manage BERT-based classification models.

    Typical flow: ``prepare_data`` (builds the label mappings and the
    train/test split) -> ``train`` (fine-tunes and saves to
    ``<output_dir>/final_model``) -> ``predict`` / ``evaluate_on_documents``.
    A previously saved model can be restored with ``load_model``.
    """

    def __init__(
        self,
        model_name: str = "bert-base-uncased",
        output_dir: str | Path = "./bert_models",
    ) -> None:
        """
        Initialize BERT trainer.

        Creates ``output_dir`` (including parents) if it does not exist.

        Args:
            model_name: Pretrained model name from HuggingFace
            output_dir: Directory to save trained models
        """
        self.model_name = model_name
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.tokenizer: Any = None
        self.model: Any = None
        self.label_to_id: dict[str, int] = {}
        self.id_to_label: dict[int, str] = {}

    def prepare_data(
        self,
        documents: list[Document],
        labels: list[Label],
        test_size: float = 0.2,
        random_state: int = 42,
    ) -> tuple[Dataset, Dataset]:
        """
        Prepare data for training.

        Documents without a matching label (by ``document_id``) are skipped.
        Label ids are assigned in sorted label order and stored on
        ``self.label_to_id`` / ``self.id_to_label``. The split is stratified
        by label when possible; if sklearn cannot stratify (e.g. a label has
        a single example), an unstratified split is used and a warning logged.

        Args:
            documents: List of documents
            labels: List of labels
            test_size: Fraction of data to use for testing
            random_state: Random seed for reproducibility

        Returns:
            Tuple of (train_dataset, test_dataset)

        Raises:
            ValueError: If no document has a matching label, or the split
                itself is impossible.
        """
        label_map = {label.document_id: label.value for label in labels}

        texts: list[str] = []
        label_values: list[str] = []
        for doc in documents:
            if doc.id in label_map:
                texts.append(doc.content)
                label_values.append(label_map[doc.id])

        unique_labels = sorted(set(label_values))
        self.label_to_id = {label: idx for idx, label in enumerate(unique_labels)}
        self.id_to_label = {idx: label for label, idx in self.label_to_id.items()}

        if not texts:
            raise ValueError("No documents with matching labels; nothing to split.")

        label_ids = [self.label_to_id[label] for label in label_values]

        try:
            train_texts, test_texts, train_labels, test_labels = train_test_split(
                texts, label_ids, test_size=test_size, random_state=random_state, stratify=label_ids
            )
        except ValueError as exc:
            # sklearn refuses to stratify when a class has fewer than two
            # members or the test split is smaller than the number of classes.
            logger.warning("stratified_split_failed", error=str(exc))
            train_texts, test_texts, train_labels, test_labels = train_test_split(
                texts, label_ids, test_size=test_size, random_state=random_state
            )

        train_dataset = Dataset.from_dict({"text": train_texts, "label": train_labels})
        test_dataset = Dataset.from_dict({"text": test_texts, "label": test_labels})

        logger.info(
            "data_prepared",
            num_train=len(train_dataset),
            num_test=len(test_dataset),
            num_labels=len(unique_labels),
        )

        return train_dataset, test_dataset

    def _tokenize_function(self, examples: dict[str, Any]) -> dict[str, Any]:
        """Tokenize a batch of examples, padding to the tokenizer's max length."""
        return self.tokenizer(examples["text"], padding="max_length", truncation=True)

    def train(
        self,
        train_dataset: Dataset,
        test_dataset: Dataset,
        num_epochs: int = 3,
        batch_size: int = 8,
        learning_rate: float = 2e-5,
        warmup_steps: int = 500,
        weight_decay: float = 0.01,
        save_steps: int = 500,
        eval_steps: int = 500,
    ) -> dict[str, Any]:
        """
        Train BERT model.

        Loads the tokenizer and a classification head sized to the label
        mapping built by ``prepare_data``, fine-tunes, evaluates on
        ``test_dataset``, and saves the model, tokenizer and
        ``label_mapping.json`` under ``<output_dir>/final_model``.

        Args:
            train_dataset: Training dataset
            test_dataset: Test dataset
            num_epochs: Number of training epochs
            batch_size: Batch size for training
            learning_rate: Learning rate
            warmup_steps: Number of warmup steps
            weight_decay: Weight decay for optimizer
            save_steps: Save checkpoint every N steps
            eval_steps: Evaluate every N steps

        Returns:
            Training metrics and results (``train_loss``, ``eval_accuracy``,
            ``eval_loss``, ``model_dir``)

        Raises:
            ValueError: If no label mapping exists (``prepare_data`` not called).
        """
        if not self.label_to_id:
            # Without a mapping the model would be built with num_labels=0.
            raise ValueError("No label mapping available. Call prepare_data() before train().")

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=len(self.label_to_id),
            id2label=self.id_to_label,
            label2id=self.label_to_id,
        )

        train_dataset = train_dataset.map(self._tokenize_function, batched=True)
        test_dataset = test_dataset.map(self._tokenize_function, batched=True)

        train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
        test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

        training_args = TrainingArguments(
            output_dir=str(self.output_dir),
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            warmup_steps=warmup_steps,
            weight_decay=weight_decay,
            learning_rate=learning_rate,
            logging_dir=str(self.output_dir / "logs"),
            logging_steps=100,
            eval_strategy="steps",
            eval_steps=eval_steps,
            save_steps=save_steps,
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
        )

        def compute_metrics(eval_pred: Any) -> dict[str, float]:
            # eval_pred unpacks to (logits, label ids); take the argmax class.
            predictions, labels = eval_pred
            predictions = predictions.argmax(-1)
            accuracy = accuracy_score(labels, predictions)
            return {"accuracy": accuracy}

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=test_dataset,
            compute_metrics=compute_metrics,
        )

        logger.info("training_started", epochs=num_epochs, batch_size=batch_size)
        train_result = trainer.train()

        eval_result = trainer.evaluate()

        logger.info("training_completed", train_loss=train_result.training_loss, **eval_result)

        final_model_dir = self.output_dir / "final_model"
        trainer.save_model(str(final_model_dir))
        self.tokenizer.save_pretrained(str(final_model_dir))

        # JSON turns the int keys of id_to_label into strings; load_model
        # converts them back.
        label_map_path = final_model_dir / "label_mapping.json"
        with open(label_map_path, "w", encoding="utf-8") as f:
            json.dump(
                {"label_to_id": self.label_to_id, "id_to_label": self.id_to_label},
                f,
                indent=2,
            )

        return {
            "train_loss": train_result.training_loss,
            "eval_accuracy": eval_result["eval_accuracy"],
            "eval_loss": eval_result["eval_loss"],
            "model_dir": str(final_model_dir),
        }

    def load_model(self, model_dir: str | Path) -> None:
        """
        Load a trained model.

        Label mappings are read from ``label_mapping.json`` in ``model_dir``.
        If that file is missing and this trainer has no mapping yet, the
        model config's ``id2label`` is used instead so ``predict`` can still
        translate class ids.

        Args:
            model_dir: Directory containing saved model
        """
        model_dir = Path(model_dir)

        self.tokenizer = AutoTokenizer.from_pretrained(str(model_dir))
        self.model = AutoModelForSequenceClassification.from_pretrained(str(model_dir))

        label_map_path = model_dir / "label_mapping.json"
        if label_map_path.exists():
            with open(label_map_path, "r", encoding="utf-8") as f:
                mappings = json.load(f)
            self.label_to_id = mappings["label_to_id"]
            # Convert string keys back to integers for id_to_label
            self.id_to_label = {int(k): v for k, v in mappings["id_to_label"].items()}
        elif not self.id_to_label:
            # train() passes id2label to from_pretrained, so saved configs
            # carry the mapping too; use it rather than leaving predict()
            # to fail with a KeyError.
            config = getattr(self.model, "config", None)
            id2label = getattr(config, "id2label", None)
            if id2label:
                self.id_to_label = {int(k): v for k, v in id2label.items()}
                self.label_to_id = {v: k for k, v in self.id_to_label.items()}

        logger.info("model_loaded", model_dir=str(model_dir))

    def predict(self, texts: list[str]) -> list[str]:
        """
        Make predictions on new texts.

        All texts are tokenized and run as a single batch.

        Args:
            texts: List of texts to classify

        Returns:
            List of predicted labels (empty if ``texts`` is empty)

        Raises:
            ValueError: If no model/tokenizer has been trained or loaded.
        """
        if self.model is None or self.tokenizer is None:
            raise ValueError("Model not loaded. Train or load a model first.")
        if not texts:
            return []

        inputs = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )

        # Trainer.train() places the model on the training device (e.g. a
        # GPU) while the tokenizer returns CPU tensors; align them.
        device = getattr(self.model, "device", None)
        if device is not None:
            inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

        self.model.eval()
        with torch.no_grad():
            outputs = self.model(**inputs)
            predictions = outputs.logits.argmax(-1).tolist()

        return [self.id_to_label[pred] for pred in predictions]

    def predict_single(self, text: str) -> str:
        """
        Predict label for a single text.

        Args:
            text: Text to classify

        Returns:
            Predicted label
        """
        return self.predict([text])[0]

    def evaluate_on_documents(
        self,
        documents: list[Document],
        labels: list[Label],
    ) -> dict[str, Any]:
        """
        Evaluate model on a set of documents.

        Only documents with a matching label (by ``document_id``) are scored.

        Args:
            documents: List of documents
            labels: List of expected labels

        Returns:
            Evaluation metrics: ``accuracy``, sklearn ``classification_report``
            (as a dict) and ``num_samples``
        """
        label_map = {label.document_id: label.value for label in labels}

        texts = [doc.content for doc in documents if doc.id in label_map]
        expected_labels = [label_map[doc.id] for doc in documents if doc.id in label_map]

        predicted_labels = self.predict(texts)

        accuracy = accuracy_score(expected_labels, predicted_labels)
        report = classification_report(expected_labels, predicted_labels, output_dict=True)

        return {
            "accuracy": accuracy,
            "classification_report": report,
            "num_samples": len(texts),
        }
|