isa-model 0.3.5__py3-none-any.whl → 0.3.6__py3-none-any.whl
This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
- isa_model/__init__.py +30 -1
- isa_model/client.py +770 -0
- isa_model/core/config/__init__.py +16 -0
- isa_model/core/config/config_manager.py +514 -0
- isa_model/core/config.py +426 -0
- isa_model/core/models/model_billing_tracker.py +476 -0
- isa_model/core/models/model_manager.py +399 -0
- isa_model/core/{storage/supabase_storage.py → models/model_repo.py} +72 -73
- isa_model/core/pricing_manager.py +426 -0
- isa_model/core/services/__init__.py +19 -0
- isa_model/core/services/intelligent_model_selector.py +547 -0
- isa_model/core/types.py +291 -0
- isa_model/deployment/__init__.py +2 -0
- isa_model/deployment/cloud/modal/isa_vision_doc_service.py +157 -3
- isa_model/deployment/cloud/modal/isa_vision_table_service.py +532 -0
- isa_model/deployment/cloud/modal/isa_vision_ui_service.py +104 -3
- isa_model/deployment/cloud/modal/register_models.py +321 -0
- isa_model/deployment/runtime/deployed_service.py +338 -0
- isa_model/deployment/services/__init__.py +9 -0
- isa_model/deployment/services/auto_deploy_vision_service.py +537 -0
- isa_model/deployment/services/model_service.py +332 -0
- isa_model/deployment/services/service_monitor.py +356 -0
- isa_model/deployment/services/service_registry.py +527 -0
- isa_model/eval/__init__.py +80 -44
- isa_model/eval/config/__init__.py +10 -0
- isa_model/eval/config/evaluation_config.py +108 -0
- isa_model/eval/evaluators/__init__.py +18 -0
- isa_model/eval/evaluators/base_evaluator.py +503 -0
- isa_model/eval/evaluators/llm_evaluator.py +472 -0
- isa_model/eval/factory.py +417 -709
- isa_model/eval/infrastructure/__init__.py +24 -0
- isa_model/eval/infrastructure/experiment_tracker.py +466 -0
- isa_model/eval/metrics.py +191 -21
- isa_model/inference/ai_factory.py +181 -605
- isa_model/inference/services/audio/base_stt_service.py +65 -1
- isa_model/inference/services/audio/base_tts_service.py +75 -1
- isa_model/inference/services/audio/openai_stt_service.py +189 -151
- isa_model/inference/services/audio/openai_tts_service.py +12 -10
- isa_model/inference/services/audio/replicate_tts_service.py +61 -56
- isa_model/inference/services/base_service.py +55 -17
- isa_model/inference/services/embedding/base_embed_service.py +65 -1
- isa_model/inference/services/embedding/ollama_embed_service.py +103 -43
- isa_model/inference/services/embedding/openai_embed_service.py +8 -10
- isa_model/inference/services/helpers/stacked_config.py +148 -0
- isa_model/inference/services/img/__init__.py +18 -0
- isa_model/inference/services/{vision → img}/base_image_gen_service.py +80 -1
- isa_model/inference/services/{stacked → img}/flux_professional_service.py +25 -1
- isa_model/inference/services/{stacked → img/helpers}/base_stacked_service.py +40 -35
- isa_model/inference/services/{vision → img}/replicate_image_gen_service.py +44 -31
- isa_model/inference/services/llm/__init__.py +3 -3
- isa_model/inference/services/llm/base_llm_service.py +492 -40
- isa_model/inference/services/llm/helpers/llm_prompts.py +258 -0
- isa_model/inference/services/llm/helpers/llm_utils.py +280 -0
- isa_model/inference/services/llm/ollama_llm_service.py +51 -17
- isa_model/inference/services/llm/openai_llm_service.py +70 -19
- isa_model/inference/services/llm/yyds_llm_service.py +24 -23
- isa_model/inference/services/vision/__init__.py +38 -4
- isa_model/inference/services/vision/base_vision_service.py +218 -117
- isa_model/inference/services/vision/{isA_vision_service.py → disabled/isA_vision_service.py} +98 -0
- isa_model/inference/services/{stacked → vision}/doc_analysis_service.py +1 -1
- isa_model/inference/services/vision/helpers/base_stacked_service.py +274 -0
- isa_model/inference/services/vision/helpers/image_utils.py +272 -3
- isa_model/inference/services/vision/helpers/vision_prompts.py +297 -0
- isa_model/inference/services/vision/openai_vision_service.py +104 -307
- isa_model/inference/services/vision/replicate_vision_service.py +140 -325
- isa_model/inference/services/{stacked → vision}/ui_analysis_service.py +2 -498
- isa_model/scripts/register_models.py +370 -0
- isa_model/scripts/register_models_with_embeddings.py +510 -0
- isa_model/serving/api/fastapi_server.py +6 -1
- isa_model/serving/api/routes/unified.py +202 -0
- {isa_model-0.3.5.dist-info → isa_model-0.3.6.dist-info}/METADATA +4 -1
- {isa_model-0.3.5.dist-info → isa_model-0.3.6.dist-info}/RECORD +77 -53
- isa_model/config/__init__.py +0 -9
- isa_model/config/config_manager.py +0 -213
- isa_model/core/model_manager.py +0 -213
- isa_model/core/model_registry.py +0 -375
- isa_model/core/vision_models_init.py +0 -116
- isa_model/inference/billing_tracker.py +0 -406
- isa_model/inference/services/llm/triton_llm_service.py +0 -481
- isa_model/inference/services/stacked/__init__.py +0 -26
- isa_model/inference/services/stacked/config.py +0 -426
- isa_model/inference/services/vision/ollama_vision_service.py +0 -194
- /isa_model/core/{model_storage.py → models/model_storage.py} +0 -0
- /isa_model/inference/services/{vision → embedding}/helpers/text_splitter.py +0 -0
- /isa_model/inference/services/llm/{llm_adapter.py → helpers/llm_adapter.py} +0 -0
- {isa_model-0.3.5.dist-info → isa_model-0.3.6.dist-info}/WHEEL +0 -0
- {isa_model-0.3.5.dist-info → isa_model-0.3.6.dist-info}/top_level.txt +0 -0
isa_model/eval/evaluators/llm_evaluator.py (new file)
@@ -0,0 +1,472 @@
"""
LLM Evaluator implementing industry best practices for large language model evaluation.

Features:
- Support for multiple LLM providers (OpenAI, Anthropic, local models)
- Comprehensive text generation metrics
- Benchmark evaluation (MMLU, HellaSwag, etc.)
- Token usage and cost tracking
- Safety and bias evaluation
"""

import logging
import asyncio
from typing import Dict, List, Any, Optional, Union
import json
import re

from .base_evaluator import BaseEvaluator, EvaluationResult

try:
    from ...inference.ai_factory import AIFactory
    AI_FACTORY_AVAILABLE = True
except ImportError:
    AI_FACTORY_AVAILABLE = False

logger = logging.getLogger(__name__)


class LLMEvaluator(BaseEvaluator):
    """
    Comprehensive LLM evaluator with industry-standard metrics and practices.

    Supports:
    - Text generation evaluation
    - Classification tasks
    - Question answering
    - Reasoning benchmarks
    - Safety and bias assessment
    """

    def __init__(self,
                 config: Optional[Dict[str, Any]] = None,
                 experiment_tracker: Optional[Any] = None):
        """
        Initialize LLM evaluator.

        Args:
            config: Evaluation configuration
            experiment_tracker: Optional experiment tracking instance
        """
        super().__init__("LLMEvaluator", config, experiment_tracker)

        # Initialize AI factory for model inference
        if AI_FACTORY_AVAILABLE:
            try:
                self.ai_factory = AIFactory()
                logger.info("AI Factory initialized successfully")
            except Exception as e:
                logger.warning(f"Failed to initialize AI Factory: {e}")
                self.ai_factory = None
        else:
            self.ai_factory = None
            logger.warning("AI Factory not available")

        # LLM-specific configuration
        self.provider = self.config.get("provider", "openai")
        self.model_name = self.config.get("model_name", "gpt-4.1-mini")
        self.temperature = self.config.get("temperature", 0.1)
        self.max_tokens = self.config.get("max_tokens", 512)
        self.system_prompt = self.config.get("system_prompt", "")

        # Token tracking
        self.total_input_tokens = 0
        self.total_output_tokens = 0
        self.total_cost_usd = 0.0

    async def evaluate_sample(self,
                              sample: Dict[str, Any],
                              model_interface: Any = None) -> Dict[str, Any]:
        """
        Evaluate a single sample with the LLM.

        Args:
            sample: Data sample containing prompt and expected response
            model_interface: Model interface (uses AI factory if None)

        Returns:
            Evaluation result for the sample
        """
        # Use provided model interface or default to AI factory
        if model_interface is None:
            if not self.ai_factory:
                raise ValueError("No model interface available for evaluation")
            model_interface = self.ai_factory.get_llm(
                model_name=self.model_name,
                provider=self.provider
            )

        # Extract prompt and reference from sample
        prompt = self._format_prompt(sample)
        reference = sample.get("reference") or sample.get("expected_output") or sample.get("answer")

        # Generate prediction
        try:
            response = await model_interface.ainvoke(prompt)

            # Extract prediction text
            if hasattr(response, 'content'):
                prediction = response.content
            elif isinstance(response, dict):
                prediction = response.get('text') or response.get('content') or str(response)
            elif isinstance(response, str):
                prediction = response
            else:
                prediction = str(response)

            # Track token usage if available
            if hasattr(response, 'usage'):
                usage = response.usage
                input_tokens = getattr(usage, 'prompt_tokens', 0)
                output_tokens = getattr(usage, 'completion_tokens', 0)
                self.total_input_tokens += input_tokens
                self.total_output_tokens += output_tokens

            return {
                "prediction": prediction.strip(),
                "reference": reference,
                "prompt": prompt,
                "sample_id": sample.get("id", "unknown"),
                "input_tokens": getattr(response, 'input_tokens', 0) if hasattr(response, 'input_tokens') else 0,
                "output_tokens": getattr(response, 'output_tokens', 0) if hasattr(response, 'output_tokens') else 0
            }

        except Exception as e:
            logger.error(f"Failed to evaluate sample {sample.get('id', 'unknown')}: {e}")
            raise

    def _format_prompt(self, sample: Dict[str, Any]) -> str:
        """
        Format prompt based on sample type and configuration.

        Args:
            sample: Data sample

        Returns:
            Formatted prompt string
        """
        # Handle different sample formats
        if "prompt" in sample:
            prompt = sample["prompt"]
        elif "question" in sample:
            prompt = sample["question"]
        elif "input" in sample:
            prompt = sample["input"]
        elif "text" in sample:
            prompt = sample["text"]
        else:
            prompt = str(sample)

        # Add system prompt if configured
        if self.system_prompt:
            prompt = f"{self.system_prompt}\n\n{prompt}"

        # Handle few-shot examples
        if "examples" in sample and sample["examples"]:
            examples_text = ""
            for example in sample["examples"]:
                if isinstance(example, dict):
                    ex_input = example.get("input", example.get("question", ""))
                    ex_output = example.get("output", example.get("answer", ""))
                    examples_text += f"Input: {ex_input}\nOutput: {ex_output}\n\n"

            prompt = f"{examples_text}Input: {prompt}\nOutput:"

        return prompt

    def compute_metrics(self,
                        predictions: List[str],
                        references: List[str],
                        **kwargs) -> Dict[str, float]:
        """
        Compute comprehensive LLM evaluation metrics.

        Args:
            predictions: Model predictions
            references: Ground truth references
            **kwargs: Additional parameters

        Returns:
            Dictionary of computed metrics
        """
        metrics = {}

        if not predictions or not references:
            logger.warning("Empty predictions or references, returning empty metrics")
            return metrics

        # Exact match accuracy
        exact_matches = sum(1 for pred, ref in zip(predictions, references)
                            if self._normalize_text(pred) == self._normalize_text(ref))
        metrics["exact_match"] = exact_matches / len(predictions)

        # Token-based F1 score
        f1_scores = []
        for pred, ref in zip(predictions, references):
            f1_score = self._compute_f1_score(pred, ref)
            f1_scores.append(f1_score)
        metrics["f1_score"] = sum(f1_scores) / len(f1_scores)

        # BLEU score (simplified)
        bleu_scores = []
        for pred, ref in zip(predictions, references):
            bleu_score = self._compute_bleu_score(pred, ref)
            bleu_scores.append(bleu_score)
        metrics["bleu_score"] = sum(bleu_scores) / len(bleu_scores)

        # ROUGE-L score (simplified)
        rouge_scores = []
        for pred, ref in zip(predictions, references):
            rouge_score = self._compute_rouge_l_score(pred, ref)
            rouge_scores.append(rouge_score)
        metrics["rouge_l"] = sum(rouge_scores) / len(rouge_scores)

        # Response length statistics
        pred_lengths = [len(pred.split()) for pred in predictions]
        ref_lengths = [len(ref.split()) for ref in references]

        metrics["avg_prediction_length"] = sum(pred_lengths) / len(pred_lengths)
        metrics["avg_reference_length"] = sum(ref_lengths) / len(ref_lengths)
        metrics["length_ratio"] = metrics["avg_prediction_length"] / metrics["avg_reference_length"] if metrics["avg_reference_length"] > 0 else 0

        # Diversity metrics
        metrics.update(self._compute_diversity_metrics(predictions))

        # Token and cost metrics
        if self.total_input_tokens > 0 or self.total_output_tokens > 0:
            metrics["total_input_tokens"] = float(self.total_input_tokens)
            metrics["total_output_tokens"] = float(self.total_output_tokens)
            metrics["total_tokens"] = float(self.total_input_tokens + self.total_output_tokens)
            metrics["estimated_cost_usd"] = self.total_cost_usd

        return metrics

    def _normalize_text(self, text: str) -> str:
        """Normalize text for comparison."""
        # Remove extra whitespace
        text = re.sub(r'\s+', ' ', text.strip())
        # Convert to lowercase
        text = text.lower()
        # Remove punctuation for comparison
        text = re.sub(r'[^\w\s]', '', text)
        return text

    def _compute_f1_score(self, prediction: str, reference: str) -> float:
        """Compute token-level F1 score."""
        pred_tokens = set(self._normalize_text(prediction).split())
        ref_tokens = set(self._normalize_text(reference).split())

        if not pred_tokens and not ref_tokens:
            return 1.0

        if not pred_tokens or not ref_tokens:
            return 0.0

        common_tokens = pred_tokens & ref_tokens

        if len(common_tokens) == 0:
            return 0.0

        precision = len(common_tokens) / len(pred_tokens)
        recall = len(common_tokens) / len(ref_tokens)

        return 2 * precision * recall / (precision + recall)

    def _compute_bleu_score(self, prediction: str, reference: str) -> float:
        """Compute simplified BLEU score."""
        pred_tokens = self._normalize_text(prediction).split()
        ref_tokens = self._normalize_text(reference).split()

        if not pred_tokens or not ref_tokens:
            return 0.0

        # Simplified unigram precision
        pred_set = set(pred_tokens)
        ref_set = set(ref_tokens)
        overlap = len(pred_set & ref_set)

        precision = overlap / len(pred_set) if pred_set else 0
        recall = overlap / len(ref_set) if ref_set else 0

        if precision + recall == 0:
            return 0.0

        return 2 * precision * recall / (precision + recall)

    def _compute_rouge_l_score(self, prediction: str, reference: str) -> float:
        """Compute simplified ROUGE-L score."""
        pred_tokens = self._normalize_text(prediction).split()
        ref_tokens = self._normalize_text(reference).split()

        if not pred_tokens or not ref_tokens:
            return 0.0

        # Simplified LCS computation
        lcs_length = self._longest_common_subsequence_length(pred_tokens, ref_tokens)

        if len(pred_tokens) == 0 or len(ref_tokens) == 0:
            return 0.0

        precision = lcs_length / len(pred_tokens)
        recall = lcs_length / len(ref_tokens)

        if precision + recall == 0:
            return 0.0

        return 2 * precision * recall / (precision + recall)

    def _longest_common_subsequence_length(self, seq1: List[str], seq2: List[str]) -> int:
        """Compute length of longest common subsequence."""
        m, n = len(seq1), len(seq2)
        dp = [[0] * (n + 1) for _ in range(m + 1)]

        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if seq1[i-1] == seq2[j-1]:
                    dp[i][j] = dp[i-1][j-1] + 1
                else:
                    dp[i][j] = max(dp[i-1][j], dp[i][j-1])

        return dp[m][n]

    def _compute_diversity_metrics(self, predictions: List[str]) -> Dict[str, float]:
        """Compute diversity metrics for predictions."""
        all_tokens = []
        all_bigrams = []

        for pred in predictions:
            tokens = self._normalize_text(pred).split()
            all_tokens.extend(tokens)

            # Generate bigrams
            for i in range(len(tokens) - 1):
                all_bigrams.append((tokens[i], tokens[i + 1]))

        # Distinct-n metrics
        distinct_1 = len(set(all_tokens)) / len(all_tokens) if all_tokens else 0
        distinct_2 = len(set(all_bigrams)) / len(all_bigrams) if all_bigrams else 0

        return {
            "distinct_1": distinct_1,
            "distinct_2": distinct_2,
            "vocab_size": float(len(set(all_tokens)))
        }

    def get_supported_metrics(self) -> List[str]:
        """Get list of metrics supported by this evaluator."""
        return [
            "exact_match",
            "f1_score",
            "bleu_score",
            "rouge_l",
            "avg_prediction_length",
            "avg_reference_length",
            "length_ratio",
            "distinct_1",
            "distinct_2",
            "vocab_size",
            "total_input_tokens",
            "total_output_tokens",
            "total_tokens",
            "estimated_cost_usd"
        ]

    async def evaluate_classification(self,
                                      dataset: List[Dict[str, Any]],
                                      class_labels: List[str],
                                      model_name: str = "unknown") -> EvaluationResult:
        """
        Evaluate classification tasks with specialized metrics.

        Args:
            dataset: Classification dataset
            class_labels: List of possible class labels
            model_name: Name of the model being evaluated

        Returns:
            Classification evaluation results
        """
        # Update config for classification
        self.config.update({
            "task_type": "classification",
            "class_labels": class_labels,
            "max_tokens": 10  # Short responses for classification
        })

        result = await self.evaluate(
            model_interface=None,
            dataset=dataset,
            dataset_name="classification_task",
            model_name=model_name
        )

        # Add classification-specific metrics
        if result.predictions and result.references:
            classification_metrics = self._compute_classification_metrics(
                result.predictions,
                result.references,
                class_labels
            )
            result.metrics.update(classification_metrics)

        return result

    def _compute_classification_metrics(self,
                                        predictions: List[str],
                                        references: List[str],
                                        class_labels: List[str]) -> Dict[str, float]:
        """Compute classification-specific metrics."""
        # Map predictions to class labels
        mapped_predictions = []
        for pred in predictions:
            mapped_pred = self._map_to_class_label(pred, class_labels)
            mapped_predictions.append(mapped_pred)

        # Compute accuracy
        correct = sum(1 for pred, ref in zip(mapped_predictions, references) if pred == ref)
        accuracy = correct / len(predictions) if predictions else 0

        # Compute per-class precision and recall
        class_metrics = {}
        for label in class_labels:
            tp = sum(1 for pred, ref in zip(mapped_predictions, references) if pred == label and ref == label)
            fp = sum(1 for pred, ref in zip(mapped_predictions, references) if pred == label and ref != label)
            fn = sum(1 for pred, ref in zip(mapped_predictions, references) if pred != label and ref == label)

            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

            class_metrics[f"{label}_precision"] = precision
            class_metrics[f"{label}_recall"] = recall
            class_metrics[f"{label}_f1"] = f1

        # Compute macro averages
        precisions = [class_metrics[f"{label}_precision"] for label in class_labels]
        recalls = [class_metrics[f"{label}_recall"] for label in class_labels]
        f1s = [class_metrics[f"{label}_f1"] for label in class_labels]

        return {
            "accuracy": accuracy,
            "macro_precision": sum(precisions) / len(precisions) if precisions else 0,
            "macro_recall": sum(recalls) / len(recalls) if recalls else 0,
            "macro_f1": sum(f1s) / len(f1s) if f1s else 0,
            **class_metrics
        }

    def _map_to_class_label(self, prediction: str, class_labels: List[str]) -> str:
        """Map prediction text to the most likely class label."""
        pred_normalized = self._normalize_text(prediction)

        # Direct match
        for label in class_labels:
            if self._normalize_text(label) == pred_normalized:
                return label

        # Substring match
        for label in class_labels:
            if self._normalize_text(label) in pred_normalized:
                return label

        # Return first label if no match found
        return class_labels[0] if class_labels else "unknown"
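For orientation, below is a minimal usage sketch that exercises only the offline metric path of the new evaluator (compute_metrics), with no provider calls. It assumes the module path shown in this release and that the BaseEvaluator constructor accepts a plain config dict; the sample predictions, references, and printed output are hypothetical, not taken from the package.

# Hypothetical sketch: scoring predictions against references offline,
# using only the metric helpers added in isa_model/eval/evaluators/llm_evaluator.py.
from isa_model.eval.evaluators.llm_evaluator import LLMEvaluator

# Assumption: passing a config dict here mirrors the defaults read in __init__.
evaluator = LLMEvaluator(config={"provider": "openai", "model_name": "gpt-4.1-mini"})

predictions = ["Paris is the capital of France.", "The answer is 42"]
references = ["Paris", "42"]

metrics = evaluator.compute_metrics(predictions, references)
# Keys include exact_match, f1_score, bleu_score, rouge_l, the length statistics,
# and the distinct_1 / distinct_2 / vocab_size diversity metrics.
print({name: round(value, 3) for name, value in metrics.items()})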