isa-model 0.3.5__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. isa_model/__init__.py +30 -1
  2. isa_model/client.py +770 -0
  3. isa_model/core/config/__init__.py +16 -0
  4. isa_model/core/config/config_manager.py +514 -0
  5. isa_model/core/config.py +426 -0
  6. isa_model/core/models/model_billing_tracker.py +476 -0
  7. isa_model/core/models/model_manager.py +399 -0
  8. isa_model/core/{storage/supabase_storage.py → models/model_repo.py} +72 -73
  9. isa_model/core/pricing_manager.py +426 -0
  10. isa_model/core/services/__init__.py +19 -0
  11. isa_model/core/services/intelligent_model_selector.py +547 -0
  12. isa_model/core/types.py +291 -0
  13. isa_model/deployment/__init__.py +2 -0
  14. isa_model/deployment/cloud/modal/isa_vision_doc_service.py +157 -3
  15. isa_model/deployment/cloud/modal/isa_vision_table_service.py +532 -0
  16. isa_model/deployment/cloud/modal/isa_vision_ui_service.py +104 -3
  17. isa_model/deployment/cloud/modal/register_models.py +321 -0
  18. isa_model/deployment/runtime/deployed_service.py +338 -0
  19. isa_model/deployment/services/__init__.py +9 -0
  20. isa_model/deployment/services/auto_deploy_vision_service.py +537 -0
  21. isa_model/deployment/services/model_service.py +332 -0
  22. isa_model/deployment/services/service_monitor.py +356 -0
  23. isa_model/deployment/services/service_registry.py +527 -0
  24. isa_model/eval/__init__.py +80 -44
  25. isa_model/eval/config/__init__.py +10 -0
  26. isa_model/eval/config/evaluation_config.py +108 -0
  27. isa_model/eval/evaluators/__init__.py +18 -0
  28. isa_model/eval/evaluators/base_evaluator.py +503 -0
  29. isa_model/eval/evaluators/llm_evaluator.py +472 -0
  30. isa_model/eval/factory.py +417 -709
  31. isa_model/eval/infrastructure/__init__.py +24 -0
  32. isa_model/eval/infrastructure/experiment_tracker.py +466 -0
  33. isa_model/eval/metrics.py +191 -21
  34. isa_model/inference/ai_factory.py +181 -605
  35. isa_model/inference/services/audio/base_stt_service.py +65 -1
  36. isa_model/inference/services/audio/base_tts_service.py +75 -1
  37. isa_model/inference/services/audio/openai_stt_service.py +189 -151
  38. isa_model/inference/services/audio/openai_tts_service.py +12 -10
  39. isa_model/inference/services/audio/replicate_tts_service.py +61 -56
  40. isa_model/inference/services/base_service.py +55 -17
  41. isa_model/inference/services/embedding/base_embed_service.py +65 -1
  42. isa_model/inference/services/embedding/ollama_embed_service.py +103 -43
  43. isa_model/inference/services/embedding/openai_embed_service.py +8 -10
  44. isa_model/inference/services/helpers/stacked_config.py +148 -0
  45. isa_model/inference/services/img/__init__.py +18 -0
  46. isa_model/inference/services/{vision → img}/base_image_gen_service.py +80 -1
  47. isa_model/inference/services/{stacked → img}/flux_professional_service.py +25 -1
  48. isa_model/inference/services/{stacked → img/helpers}/base_stacked_service.py +40 -35
  49. isa_model/inference/services/{vision → img}/replicate_image_gen_service.py +44 -31
  50. isa_model/inference/services/llm/__init__.py +3 -3
  51. isa_model/inference/services/llm/base_llm_service.py +492 -40
  52. isa_model/inference/services/llm/helpers/llm_prompts.py +258 -0
  53. isa_model/inference/services/llm/helpers/llm_utils.py +280 -0
  54. isa_model/inference/services/llm/ollama_llm_service.py +51 -17
  55. isa_model/inference/services/llm/openai_llm_service.py +70 -19
  56. isa_model/inference/services/llm/yyds_llm_service.py +24 -23
  57. isa_model/inference/services/vision/__init__.py +38 -4
  58. isa_model/inference/services/vision/base_vision_service.py +218 -117
  59. isa_model/inference/services/vision/{isA_vision_service.py → disabled/isA_vision_service.py} +98 -0
  60. isa_model/inference/services/{stacked → vision}/doc_analysis_service.py +1 -1
  61. isa_model/inference/services/vision/helpers/base_stacked_service.py +274 -0
  62. isa_model/inference/services/vision/helpers/image_utils.py +272 -3
  63. isa_model/inference/services/vision/helpers/vision_prompts.py +297 -0
  64. isa_model/inference/services/vision/openai_vision_service.py +104 -307
  65. isa_model/inference/services/vision/replicate_vision_service.py +140 -325
  66. isa_model/inference/services/{stacked → vision}/ui_analysis_service.py +2 -498
  67. isa_model/scripts/register_models.py +370 -0
  68. isa_model/scripts/register_models_with_embeddings.py +510 -0
  69. isa_model/serving/api/fastapi_server.py +6 -1
  70. isa_model/serving/api/routes/unified.py +202 -0
  71. {isa_model-0.3.5.dist-info → isa_model-0.3.6.dist-info}/METADATA +4 -1
  72. {isa_model-0.3.5.dist-info → isa_model-0.3.6.dist-info}/RECORD +77 -53
  73. isa_model/config/__init__.py +0 -9
  74. isa_model/config/config_manager.py +0 -213
  75. isa_model/core/model_manager.py +0 -213
  76. isa_model/core/model_registry.py +0 -375
  77. isa_model/core/vision_models_init.py +0 -116
  78. isa_model/inference/billing_tracker.py +0 -406
  79. isa_model/inference/services/llm/triton_llm_service.py +0 -481
  80. isa_model/inference/services/stacked/__init__.py +0 -26
  81. isa_model/inference/services/stacked/config.py +0 -426
  82. isa_model/inference/services/vision/ollama_vision_service.py +0 -194
  83. /isa_model/core/{model_storage.py → models/model_storage.py} +0 -0
  84. /isa_model/inference/services/{vision → embedding}/helpers/text_splitter.py +0 -0
  85. /isa_model/inference/services/llm/{llm_adapter.py → helpers/llm_adapter.py} +0 -0
  86. {isa_model-0.3.5.dist-info → isa_model-0.3.6.dist-info}/WHEEL +0 -0
  87. {isa_model-0.3.5.dist-info → isa_model-0.3.6.dist-info}/top_level.txt +0 -0
isa_model/eval/evaluators/llm_evaluator.py (new file)
@@ -0,0 +1,472 @@
+"""
+LLM Evaluator implementing industry best practices for large language model evaluation.
+
+Features:
+- Support for multiple LLM providers (OpenAI, Anthropic, local models)
+- Comprehensive text generation metrics
+- Benchmark evaluation (MMLU, HellaSwag, etc.)
+- Token usage and cost tracking
+- Safety and bias evaluation
+"""
+
+import logging
+import asyncio
+from typing import Dict, List, Any, Optional, Union
+import json
+import re
+
+from .base_evaluator import BaseEvaluator, EvaluationResult
+
+try:
+    from ...inference.ai_factory import AIFactory
+    AI_FACTORY_AVAILABLE = True
+except ImportError:
+    AI_FACTORY_AVAILABLE = False
+
+logger = logging.getLogger(__name__)
+
+
+class LLMEvaluator(BaseEvaluator):
+    """
+    Comprehensive LLM evaluator with industry-standard metrics and practices.
+
+    Supports:
+    - Text generation evaluation
+    - Classification tasks
+    - Question answering
+    - Reasoning benchmarks
+    - Safety and bias assessment
+    """
+
+    def __init__(self,
+                 config: Optional[Dict[str, Any]] = None,
+                 experiment_tracker: Optional[Any] = None):
+        """
+        Initialize LLM evaluator.
+
+        Args:
+            config: Evaluation configuration
+            experiment_tracker: Optional experiment tracking instance
+        """
+        super().__init__("LLMEvaluator", config, experiment_tracker)
+
+        # Initialize AI factory for model inference
+        if AI_FACTORY_AVAILABLE:
+            try:
+                self.ai_factory = AIFactory()
+                logger.info("AI Factory initialized successfully")
+            except Exception as e:
+                logger.warning(f"Failed to initialize AI Factory: {e}")
+                self.ai_factory = None
+        else:
+            self.ai_factory = None
+            logger.warning("AI Factory not available")
+
+        # LLM-specific configuration
+        self.provider = self.config.get("provider", "openai")
+        self.model_name = self.config.get("model_name", "gpt-4.1-mini")
+        self.temperature = self.config.get("temperature", 0.1)
+        self.max_tokens = self.config.get("max_tokens", 512)
+        self.system_prompt = self.config.get("system_prompt", "")
+
+        # Token tracking
+        self.total_input_tokens = 0
+        self.total_output_tokens = 0
+        self.total_cost_usd = 0.0
+
+    async def evaluate_sample(self,
+                              sample: Dict[str, Any],
+                              model_interface: Any = None) -> Dict[str, Any]:
+        """
+        Evaluate a single sample with the LLM.
+
+        Args:
+            sample: Data sample containing prompt and expected response
+            model_interface: Model interface (uses AI factory if None)
+
+        Returns:
+            Evaluation result for the sample
+        """
+        # Use provided model interface or default to AI factory
+        if model_interface is None:
+            if not self.ai_factory:
+                raise ValueError("No model interface available for evaluation")
+            model_interface = self.ai_factory.get_llm(
+                model_name=self.model_name,
+                provider=self.provider
+            )
+
+        # Extract prompt and reference from sample
+        prompt = self._format_prompt(sample)
+        reference = sample.get("reference") or sample.get("expected_output") or sample.get("answer")
+
+        # Generate prediction
+        try:
+            response = await model_interface.ainvoke(prompt)
+
+            # Extract prediction text
+            if hasattr(response, 'content'):
+                prediction = response.content
+            elif isinstance(response, dict):
+                prediction = response.get('text') or response.get('content') or str(response)
+            elif isinstance(response, str):
+                prediction = response
+            else:
+                prediction = str(response)
+
+            # Track token usage if available
+            if hasattr(response, 'usage'):
+                usage = response.usage
+                input_tokens = getattr(usage, 'prompt_tokens', 0)
+                output_tokens = getattr(usage, 'completion_tokens', 0)
+                self.total_input_tokens += input_tokens
+                self.total_output_tokens += output_tokens
+
+            return {
+                "prediction": prediction.strip(),
+                "reference": reference,
+                "prompt": prompt,
+                "sample_id": sample.get("id", "unknown"),
+                "input_tokens": getattr(response, 'input_tokens', 0) if hasattr(response, 'input_tokens') else 0,
+                "output_tokens": getattr(response, 'output_tokens', 0) if hasattr(response, 'output_tokens') else 0
+            }
+
+        except Exception as e:
+            logger.error(f"Failed to evaluate sample {sample.get('id', 'unknown')}: {e}")
+            raise
+
+    def _format_prompt(self, sample: Dict[str, Any]) -> str:
+        """
+        Format prompt based on sample type and configuration.
+
+        Args:
+            sample: Data sample
+
+        Returns:
+            Formatted prompt string
+        """
+        # Handle different sample formats
+        if "prompt" in sample:
+            prompt = sample["prompt"]
+        elif "question" in sample:
+            prompt = sample["question"]
+        elif "input" in sample:
+            prompt = sample["input"]
+        elif "text" in sample:
+            prompt = sample["text"]
+        else:
+            prompt = str(sample)
+
+        # Add system prompt if configured
+        if self.system_prompt:
+            prompt = f"{self.system_prompt}\n\n{prompt}"
+
+        # Handle few-shot examples
+        if "examples" in sample and sample["examples"]:
+            examples_text = ""
+            for example in sample["examples"]:
+                if isinstance(example, dict):
+                    ex_input = example.get("input", example.get("question", ""))
+                    ex_output = example.get("output", example.get("answer", ""))
+                    examples_text += f"Input: {ex_input}\nOutput: {ex_output}\n\n"
+
+            prompt = f"{examples_text}Input: {prompt}\nOutput:"
+
+        return prompt
+
+    def compute_metrics(self,
+                        predictions: List[str],
+                        references: List[str],
+                        **kwargs) -> Dict[str, float]:
+        """
+        Compute comprehensive LLM evaluation metrics.
+
+        Args:
+            predictions: Model predictions
+            references: Ground truth references
+            **kwargs: Additional parameters
+
+        Returns:
+            Dictionary of computed metrics
+        """
+        metrics = {}
+
+        if not predictions or not references:
+            logger.warning("Empty predictions or references, returning empty metrics")
+            return metrics
+
+        # Exact match accuracy
+        exact_matches = sum(1 for pred, ref in zip(predictions, references)
+                            if self._normalize_text(pred) == self._normalize_text(ref))
+        metrics["exact_match"] = exact_matches / len(predictions)
+
+        # Token-based F1 score
+        f1_scores = []
+        for pred, ref in zip(predictions, references):
+            f1_score = self._compute_f1_score(pred, ref)
+            f1_scores.append(f1_score)
+        metrics["f1_score"] = sum(f1_scores) / len(f1_scores)
+
+        # BLEU score (simplified)
+        bleu_scores = []
+        for pred, ref in zip(predictions, references):
+            bleu_score = self._compute_bleu_score(pred, ref)
+            bleu_scores.append(bleu_score)
+        metrics["bleu_score"] = sum(bleu_scores) / len(bleu_scores)
+
+        # ROUGE-L score (simplified)
+        rouge_scores = []
+        for pred, ref in zip(predictions, references):
+            rouge_score = self._compute_rouge_l_score(pred, ref)
+            rouge_scores.append(rouge_score)
+        metrics["rouge_l"] = sum(rouge_scores) / len(rouge_scores)
+
+        # Response length statistics
+        pred_lengths = [len(pred.split()) for pred in predictions]
+        ref_lengths = [len(ref.split()) for ref in references]
+
+        metrics["avg_prediction_length"] = sum(pred_lengths) / len(pred_lengths)
+        metrics["avg_reference_length"] = sum(ref_lengths) / len(ref_lengths)
+        metrics["length_ratio"] = metrics["avg_prediction_length"] / metrics["avg_reference_length"] if metrics["avg_reference_length"] > 0 else 0
+
+        # Diversity metrics
+        metrics.update(self._compute_diversity_metrics(predictions))
+
+        # Token and cost metrics
+        if self.total_input_tokens > 0 or self.total_output_tokens > 0:
+            metrics["total_input_tokens"] = float(self.total_input_tokens)
+            metrics["total_output_tokens"] = float(self.total_output_tokens)
+            metrics["total_tokens"] = float(self.total_input_tokens + self.total_output_tokens)
+            metrics["estimated_cost_usd"] = self.total_cost_usd
+
+        return metrics
+
+    def _normalize_text(self, text: str) -> str:
+        """Normalize text for comparison."""
+        # Remove extra whitespace
+        text = re.sub(r'\s+', ' ', text.strip())
+        # Convert to lowercase
+        text = text.lower()
+        # Remove punctuation for comparison
+        text = re.sub(r'[^\w\s]', '', text)
+        return text
+
+    def _compute_f1_score(self, prediction: str, reference: str) -> float:
+        """Compute token-level F1 score."""
+        pred_tokens = set(self._normalize_text(prediction).split())
+        ref_tokens = set(self._normalize_text(reference).split())
+
+        if not pred_tokens and not ref_tokens:
+            return 1.0
+
+        if not pred_tokens or not ref_tokens:
+            return 0.0
+
+        common_tokens = pred_tokens & ref_tokens
+
+        if len(common_tokens) == 0:
+            return 0.0
+
+        precision = len(common_tokens) / len(pred_tokens)
+        recall = len(common_tokens) / len(ref_tokens)
+
+        return 2 * precision * recall / (precision + recall)
+
+    def _compute_bleu_score(self, prediction: str, reference: str) -> float:
+        """Compute simplified BLEU score."""
+        pred_tokens = self._normalize_text(prediction).split()
+        ref_tokens = self._normalize_text(reference).split()
+
+        if not pred_tokens or not ref_tokens:
+            return 0.0
+
+        # Simplified unigram precision
+        pred_set = set(pred_tokens)
+        ref_set = set(ref_tokens)
+        overlap = len(pred_set & ref_set)
+
+        precision = overlap / len(pred_set) if pred_set else 0
+        recall = overlap / len(ref_set) if ref_set else 0
+
+        if precision + recall == 0:
+            return 0.0
+
+        return 2 * precision * recall / (precision + recall)
+
+    def _compute_rouge_l_score(self, prediction: str, reference: str) -> float:
+        """Compute simplified ROUGE-L score."""
+        pred_tokens = self._normalize_text(prediction).split()
+        ref_tokens = self._normalize_text(reference).split()
+
+        if not pred_tokens or not ref_tokens:
+            return 0.0
+
+        # Simplified LCS computation
+        lcs_length = self._longest_common_subsequence_length(pred_tokens, ref_tokens)
+
+        if len(pred_tokens) == 0 or len(ref_tokens) == 0:
+            return 0.0
+
+        precision = lcs_length / len(pred_tokens)
+        recall = lcs_length / len(ref_tokens)
+
+        if precision + recall == 0:
+            return 0.0
+
+        return 2 * precision * recall / (precision + recall)
+
+    def _longest_common_subsequence_length(self, seq1: List[str], seq2: List[str]) -> int:
+        """Compute length of longest common subsequence."""
+        m, n = len(seq1), len(seq2)
+        dp = [[0] * (n + 1) for _ in range(m + 1)]
+
+        for i in range(1, m + 1):
+            for j in range(1, n + 1):
+                if seq1[i-1] == seq2[j-1]:
+                    dp[i][j] = dp[i-1][j-1] + 1
+                else:
+                    dp[i][j] = max(dp[i-1][j], dp[i][j-1])
+
+        return dp[m][n]
+
+    def _compute_diversity_metrics(self, predictions: List[str]) -> Dict[str, float]:
+        """Compute diversity metrics for predictions."""
+        all_tokens = []
+        all_bigrams = []
+
+        for pred in predictions:
+            tokens = self._normalize_text(pred).split()
+            all_tokens.extend(tokens)
+
+            # Generate bigrams
+            for i in range(len(tokens) - 1):
+                all_bigrams.append((tokens[i], tokens[i + 1]))
+
+        # Distinct-n metrics
+        distinct_1 = len(set(all_tokens)) / len(all_tokens) if all_tokens else 0
+        distinct_2 = len(set(all_bigrams)) / len(all_bigrams) if all_bigrams else 0
+
+        return {
+            "distinct_1": distinct_1,
+            "distinct_2": distinct_2,
+            "vocab_size": float(len(set(all_tokens)))
+        }
+
+    def get_supported_metrics(self) -> List[str]:
+        """Get list of metrics supported by this evaluator."""
+        return [
+            "exact_match",
+            "f1_score",
+            "bleu_score",
+            "rouge_l",
+            "avg_prediction_length",
+            "avg_reference_length",
+            "length_ratio",
+            "distinct_1",
+            "distinct_2",
+            "vocab_size",
+            "total_input_tokens",
+            "total_output_tokens",
+            "total_tokens",
+            "estimated_cost_usd"
+        ]
+
+    async def evaluate_classification(self,
+                                      dataset: List[Dict[str, Any]],
+                                      class_labels: List[str],
+                                      model_name: str = "unknown") -> EvaluationResult:
+        """
+        Evaluate classification tasks with specialized metrics.
+
+        Args:
+            dataset: Classification dataset
+            class_labels: List of possible class labels
+            model_name: Name of the model being evaluated
+
+        Returns:
+            Classification evaluation results
+        """
+        # Update config for classification
+        self.config.update({
+            "task_type": "classification",
+            "class_labels": class_labels,
+            "max_tokens": 10  # Short responses for classification
+        })
+
+        result = await self.evaluate(
+            model_interface=None,
+            dataset=dataset,
+            dataset_name="classification_task",
+            model_name=model_name
+        )
+
+        # Add classification-specific metrics
+        if result.predictions and result.references:
+            classification_metrics = self._compute_classification_metrics(
+                result.predictions,
+                result.references,
+                class_labels
+            )
+            result.metrics.update(classification_metrics)
+
+        return result
+
+    def _compute_classification_metrics(self,
+                                        predictions: List[str],
+                                        references: List[str],
+                                        class_labels: List[str]) -> Dict[str, float]:
+        """Compute classification-specific metrics."""
+        # Map predictions to class labels
+        mapped_predictions = []
+        for pred in predictions:
+            mapped_pred = self._map_to_class_label(pred, class_labels)
+            mapped_predictions.append(mapped_pred)
+
+        # Compute accuracy
+        correct = sum(1 for pred, ref in zip(mapped_predictions, references) if pred == ref)
+        accuracy = correct / len(predictions) if predictions else 0
+
+        # Compute per-class precision and recall
+        class_metrics = {}
+        for label in class_labels:
+            tp = sum(1 for pred, ref in zip(mapped_predictions, references) if pred == label and ref == label)
+            fp = sum(1 for pred, ref in zip(mapped_predictions, references) if pred == label and ref != label)
+            fn = sum(1 for pred, ref in zip(mapped_predictions, references) if pred != label and ref == label)
+
+            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
+            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
+            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
+
+            class_metrics[f"{label}_precision"] = precision
+            class_metrics[f"{label}_recall"] = recall
+            class_metrics[f"{label}_f1"] = f1
+
+        # Compute macro averages
+        precisions = [class_metrics[f"{label}_precision"] for label in class_labels]
+        recalls = [class_metrics[f"{label}_recall"] for label in class_labels]
+        f1s = [class_metrics[f"{label}_f1"] for label in class_labels]
+
+        return {
+            "accuracy": accuracy,
+            "macro_precision": sum(precisions) / len(precisions) if precisions else 0,
+            "macro_recall": sum(recalls) / len(recalls) if recalls else 0,
+            "macro_f1": sum(f1s) / len(f1s) if f1s else 0,
+            **class_metrics
+        }
+
+    def _map_to_class_label(self, prediction: str, class_labels: List[str]) -> str:
+        """Map prediction text to the most likely class label."""
+        pred_normalized = self._normalize_text(prediction)
+
+        # Direct match
+        for label in class_labels:
+            if self._normalize_text(label) == pred_normalized:
+                return label
+
+        # Substring match
+        for label in class_labels:
+            if self._normalize_text(label) in pred_normalized:
+                return label
+
+        # Return first label if no match found
+        return class_labels[0] if class_labels else "unknown"