isa-model 0.3.9__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +1 -1
- isa_model/client.py +732 -565
- isa_model/core/cache/redis_cache.py +401 -0
- isa_model/core/config/config_manager.py +53 -10
- isa_model/core/config.py +1 -1
- isa_model/core/database/__init__.py +1 -0
- isa_model/core/database/migrations.py +277 -0
- isa_model/core/database/supabase_client.py +123 -0
- isa_model/core/models/__init__.py +37 -0
- isa_model/core/models/model_billing_tracker.py +60 -88
- isa_model/core/models/model_manager.py +36 -18
- isa_model/core/models/model_repo.py +44 -38
- isa_model/core/models/model_statistics_tracker.py +234 -0
- isa_model/core/models/model_storage.py +0 -1
- isa_model/core/models/model_version_manager.py +959 -0
- isa_model/core/pricing_manager.py +2 -249
- isa_model/core/resilience/circuit_breaker.py +366 -0
- isa_model/core/security/secrets.py +358 -0
- isa_model/core/services/__init__.py +2 -4
- isa_model/core/services/intelligent_model_selector.py +101 -370
- isa_model/core/storage/hf_storage.py +1 -1
- isa_model/core/types.py +7 -0
- isa_model/deployment/cloud/modal/isa_audio_chatTTS_service.py +520 -0
- isa_model/deployment/cloud/modal/isa_audio_fish_service.py +0 -0
- isa_model/deployment/cloud/modal/isa_audio_openvoice_service.py +758 -0
- isa_model/deployment/cloud/modal/isa_audio_service_v2.py +1044 -0
- isa_model/deployment/cloud/modal/isa_embed_rerank_service.py +296 -0
- isa_model/deployment/cloud/modal/isa_video_hunyuan_service.py +423 -0
- isa_model/deployment/cloud/modal/isa_vision_ocr_service.py +519 -0
- isa_model/deployment/cloud/modal/isa_vision_qwen25_service.py +709 -0
- isa_model/deployment/cloud/modal/isa_vision_table_service.py +467 -323
- isa_model/deployment/cloud/modal/isa_vision_ui_service.py +607 -180
- isa_model/deployment/cloud/modal/isa_vision_ui_service_optimized.py +660 -0
- isa_model/deployment/core/deployment_manager.py +6 -4
- isa_model/deployment/services/auto_hf_modal_deployer.py +894 -0
- isa_model/eval/benchmarks/__init__.py +27 -0
- isa_model/eval/benchmarks/multimodal_datasets.py +460 -0
- isa_model/eval/benchmarks.py +244 -12
- isa_model/eval/evaluators/__init__.py +8 -2
- isa_model/eval/evaluators/audio_evaluator.py +727 -0
- isa_model/eval/evaluators/embedding_evaluator.py +742 -0
- isa_model/eval/evaluators/vision_evaluator.py +564 -0
- isa_model/eval/example_evaluation.py +395 -0
- isa_model/eval/factory.py +272 -5
- isa_model/eval/isa_benchmarks.py +700 -0
- isa_model/eval/isa_integration.py +582 -0
- isa_model/eval/metrics.py +159 -6
- isa_model/eval/tests/unit/test_basic.py +396 -0
- isa_model/inference/ai_factory.py +44 -8
- isa_model/inference/services/audio/__init__.py +21 -0
- isa_model/inference/services/audio/base_realtime_service.py +225 -0
- isa_model/inference/services/audio/isa_tts_service.py +0 -0
- isa_model/inference/services/audio/openai_realtime_service.py +320 -124
- isa_model/inference/services/audio/openai_stt_service.py +32 -6
- isa_model/inference/services/base_service.py +17 -1
- isa_model/inference/services/embedding/__init__.py +13 -0
- isa_model/inference/services/embedding/base_embed_service.py +111 -8
- isa_model/inference/services/embedding/isa_embed_service.py +305 -0
- isa_model/inference/services/embedding/openai_embed_service.py +2 -4
- isa_model/inference/services/embedding/tests/test_embedding.py +222 -0
- isa_model/inference/services/img/__init__.py +2 -2
- isa_model/inference/services/img/base_image_gen_service.py +24 -7
- isa_model/inference/services/img/replicate_image_gen_service.py +84 -422
- isa_model/inference/services/img/services/replicate_face_swap.py +193 -0
- isa_model/inference/services/img/services/replicate_flux.py +226 -0
- isa_model/inference/services/img/services/replicate_flux_kontext.py +219 -0
- isa_model/inference/services/img/services/replicate_sticker_maker.py +249 -0
- isa_model/inference/services/img/tests/test_img_client.py +297 -0
- isa_model/inference/services/llm/base_llm_service.py +30 -6
- isa_model/inference/services/llm/helpers/llm_adapter.py +63 -9
- isa_model/inference/services/llm/ollama_llm_service.py +2 -1
- isa_model/inference/services/llm/openai_llm_service.py +652 -55
- isa_model/inference/services/llm/yyds_llm_service.py +2 -1
- isa_model/inference/services/vision/__init__.py +5 -5
- isa_model/inference/services/vision/base_vision_service.py +118 -185
- isa_model/inference/services/vision/helpers/image_utils.py +11 -5
- isa_model/inference/services/vision/isa_vision_service.py +573 -0
- isa_model/inference/services/vision/tests/test_ocr_client.py +284 -0
- isa_model/serving/api/fastapi_server.py +88 -16
- isa_model/serving/api/middleware/auth.py +311 -0
- isa_model/serving/api/middleware/security.py +278 -0
- isa_model/serving/api/routes/analytics.py +486 -0
- isa_model/serving/api/routes/deployments.py +339 -0
- isa_model/serving/api/routes/evaluations.py +579 -0
- isa_model/serving/api/routes/logs.py +430 -0
- isa_model/serving/api/routes/settings.py +582 -0
- isa_model/serving/api/routes/unified.py +324 -165
- isa_model/serving/api/startup.py +304 -0
- isa_model/serving/modal_proxy_server.py +249 -0
- isa_model/training/__init__.py +100 -6
- isa_model/training/core/__init__.py +4 -1
- isa_model/training/examples/intelligent_training_example.py +281 -0
- isa_model/training/intelligent/__init__.py +25 -0
- isa_model/training/intelligent/decision_engine.py +643 -0
- isa_model/training/intelligent/intelligent_factory.py +888 -0
- isa_model/training/intelligent/knowledge_base.py +751 -0
- isa_model/training/intelligent/resource_optimizer.py +839 -0
- isa_model/training/intelligent/task_classifier.py +576 -0
- isa_model/training/storage/__init__.py +24 -0
- isa_model/training/storage/core_integration.py +439 -0
- isa_model/training/storage/training_repository.py +552 -0
- isa_model/training/storage/training_storage.py +628 -0
- {isa_model-0.3.9.dist-info → isa_model-0.4.0.dist-info}/METADATA +13 -1
- isa_model-0.4.0.dist-info/RECORD +182 -0
- isa_model/deployment/cloud/modal/isa_vision_doc_service.py +0 -766
- isa_model/deployment/cloud/modal/register_models.py +0 -321
- isa_model/inference/adapter/unified_api.py +0 -248
- isa_model/inference/services/helpers/stacked_config.py +0 -148
- isa_model/inference/services/img/flux_professional_service.py +0 -603
- isa_model/inference/services/img/helpers/base_stacked_service.py +0 -274
- isa_model/inference/services/others/table_transformer_service.py +0 -61
- isa_model/inference/services/vision/doc_analysis_service.py +0 -640
- isa_model/inference/services/vision/helpers/base_stacked_service.py +0 -274
- isa_model/inference/services/vision/ui_analysis_service.py +0 -823
- isa_model/scripts/inference_tracker.py +0 -283
- isa_model/scripts/mlflow_manager.py +0 -379
- isa_model/scripts/model_registry.py +0 -465
- isa_model/scripts/register_models.py +0 -370
- isa_model/scripts/register_models_with_embeddings.py +0 -510
- isa_model/scripts/start_mlflow.py +0 -95
- isa_model/scripts/training_tracker.py +0 -257
- isa_model-0.3.9.dist-info/RECORD +0 -138
- {isa_model-0.3.9.dist-info → isa_model-0.4.0.dist-info}/WHEEL +0 -0
- {isa_model-0.3.9.dist-info → isa_model-0.4.0.dist-info}/top_level.txt +0 -0
isa_model/eval/metrics.py
CHANGED
@@ -83,7 +83,7 @@ class LLMMetrics:
         else:
             self.ai_factory = None
 
-    def evaluate(
+    async def evaluate(
         self,
         model_path: str,
         dataset: List[Dict[str, Any]],
@@ -113,7 +113,7 @@ class LLMMetrics:
         }
 
         # Generate predictions
-        predictions, references = self._generate_predictions(
+        predictions, references = await self._generate_predictions(
             model_path, dataset, batch_size, provider, **kwargs
         )
 
@@ -149,7 +149,7 @@ class LLMMetrics:
 
         return results
 
-    def evaluate_generation(
+    async def evaluate_generation(
         self,
         model_path: str,
         prompts: List[str],
@@ -208,7 +208,7 @@ class LLMMetrics:
 
         return results
 
-    def _generate_predictions(
+    async def _generate_predictions(
         self,
         model_path: str,
         dataset: List[Dict[str, Any]],
@@ -306,7 +306,7 @@ class LLMMetrics:
         logger.info(f"Generated {len(predictions)} predictions")
         return predictions, references
 
-    def _generate_texts(
+    async def _generate_texts(
        self,
        model_path: str,
        prompts: List[str],
@@ -795,4 +795,157 @@ class BenchmarkRunner:
 
         except Exception as e:
             logger.error(f"Failed to generate prediction: {e}")
-            return "A"  # Fallback answer
+            return "A"  # Fallback answer
+
+
+# Utility functions for evaluators
+def compute_text_metrics(predictions: Union[str, List[str]],
+                         references: Union[str, List[str]],
+                         aggregate: bool = False) -> Dict[str, float]:
+    """
+    Compute standard text evaluation metrics.
+
+    Args:
+        predictions: Single prediction or list of predictions
+        references: Single reference or list of references
+        aggregate: Whether to compute aggregate metrics for lists
+
+    Returns:
+        Dictionary of computed metrics
+    """
+    try:
+        # Handle single string inputs
+        if isinstance(predictions, str) and isinstance(references, str):
+            pred_list = [predictions]
+            ref_list = [references]
+        else:
+            pred_list = predictions if isinstance(predictions, list) else [str(predictions)]
+            ref_list = references if isinstance(references, list) else [str(references)]
+
+        # Ensure equal lengths
+        min_len = min(len(pred_list), len(ref_list))
+        pred_list = pred_list[:min_len]
+        ref_list = ref_list[:min_len]
+
+        metrics = {}
+
+        # Exact match
+        exact_matches = sum(1 for p, r in zip(pred_list, ref_list) if p.strip().lower() == r.strip().lower())
+        metrics["exact_match"] = exact_matches / len(pred_list) if pred_list else 0.0
+
+        # F1 Score (token-level)
+        f1_scores = []
+        for pred, ref in zip(pred_list, ref_list):
+            pred_tokens = set(pred.lower().split())
+            ref_tokens = set(ref.lower().split())
+
+            if not ref_tokens and not pred_tokens:
+                f1_scores.append(1.0)
+            elif not ref_tokens or not pred_tokens:
+                f1_scores.append(0.0)
+            else:
+                intersection = len(pred_tokens & ref_tokens)
+                precision = intersection / len(pred_tokens)
+                recall = intersection / len(ref_tokens)
+
+                if precision + recall > 0:
+                    f1 = 2 * (precision * recall) / (precision + recall)
+                    f1_scores.append(f1)
+                else:
+                    f1_scores.append(0.0)
+
+        metrics["f1_score"] = np.mean(f1_scores) if f1_scores else 0.0
+
+        # BLEU Score (simplified)
+        bleu_scores = []
+        for pred, ref in zip(pred_list, ref_list):
+            pred_words = pred.lower().split()
+            ref_words = ref.lower().split()
+
+            # Simple n-gram overlap
+            overlap = len(set(pred_words) & set(ref_words))
+            total = len(set(pred_words) | set(ref_words))
+
+            bleu_scores.append(overlap / total if total > 0 else 0.0)
+
+        metrics["bleu_score"] = np.mean(bleu_scores) if bleu_scores else 0.0
+
+        # ROUGE-L (simplified)
+        rouge_scores = []
+        for pred, ref in zip(pred_list, ref_list):
+            pred_words = set(pred.lower().split())
+            ref_words = set(ref.lower().split())
+
+            if len(ref_words) > 0:
+                rouge_l = len(pred_words & ref_words) / len(ref_words)
+                rouge_scores.append(rouge_l)
+            else:
+                rouge_scores.append(0.0)
+
+        metrics["rouge_l"] = np.mean(rouge_scores) if rouge_scores else 0.0
+
+        # Response length metrics
+        pred_lengths = [len(p.split()) for p in pred_list]
+        ref_lengths = [len(r.split()) for r in ref_list]
+
+        metrics["avg_prediction_length"] = np.mean(pred_lengths) if pred_lengths else 0.0
+        metrics["avg_reference_length"] = np.mean(ref_lengths) if ref_lengths else 0.0
+        metrics["length_ratio"] = (np.mean(pred_lengths) / np.mean(ref_lengths)) if np.mean(ref_lengths) > 0 else 0.0
+
+        # Diversity metrics for predictions
+        if len(pred_list) > 1:
+            all_words = []
+            for pred in pred_list:
+                all_words.extend(pred.lower().split())
+
+            unique_words = len(set(all_words))
+            total_words = len(all_words)
+
+            metrics["vocabulary_diversity"] = unique_words / total_words if total_words > 0 else 0.0
+
+        return metrics
+
+    except Exception as e:
+        logger.error(f"Error computing text metrics: {e}")
+        return {"text_metrics_error": 1.0}
+
+
+def compute_vision_metrics(predictions: List[Any],
+                           references: List[Any],
+                           task_type: str = "general") -> Dict[str, float]:
+    """
+    Compute vision-specific evaluation metrics.
+
+    Args:
+        predictions: List of vision model predictions
+        references: List of reference outputs
+        task_type: Type of vision task (ocr, detection, etc.)
+
+    Returns:
+        Dictionary of computed metrics
+    """
+    try:
+        metrics = {}
+
+        # Basic success rate
+        successful_predictions = sum(1 for p in predictions if p is not None)
+        metrics["prediction_success_rate"] = successful_predictions / len(predictions) if predictions else 0.0
+
+        # Task-specific metrics would be computed by individual evaluators
+        # This is a placeholder for common vision metrics
+
+        if task_type == "ocr":
+            # OCR-specific metrics would be computed in VisionEvaluator
+            pass
+        elif task_type == "detection":
+            # Object detection metrics (IoU, mAP, etc.)
+            pass
+        elif task_type == "classification":
+            # Image classification metrics
+            pass
+
+        return metrics
+
+    except Exception as e:
+        logger.error(f"Error computing vision metrics: {e}")
+        return {"vision_metrics_error": 1.0}
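With this change, `LLMMetrics.evaluate`, `evaluate_generation`, and the internal `_generate_predictions`/`_generate_texts` helpers are coroutines, so existing synchronous callers must await them inside an event loop, while the new module-level `compute_text_metrics` helper can be called without any model at all. A minimal sketch of the updated call pattern; the import path, the no-argument constructor, the model identifier, and the dataset item keys are assumptions based on this diff, not confirmed API:

import asyncio

from isa_model.eval.metrics import LLMMetrics, compute_text_metrics  # import path assumed

async def main():
    metrics = LLMMetrics()  # constructor arguments assumed; adjust to your setup
    # evaluate() is a coroutine in 0.4.0 and must be awaited
    results = await metrics.evaluate(
        model_path="gpt-4o-mini",  # hypothetical model identifier
        dataset=[{"prompt": "What is 2+2?", "expected_output": "4"}],  # item keys assumed
    )
    print(results)

    # The new helper scores plain strings (or lists of strings) directly
    scores = compute_text_metrics(
        predictions=["The cat sits"],
        references=["The cat sits on mat"],
    )
    # Token-level F1 here: precision 3/3, recall 3/5, so F1 = 2*0.6/1.6 = 0.75
    print(scores["exact_match"], scores["f1_score"])

asyncio.run(main())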
isa_model/eval/tests/unit/test_basic.py
ADDED
@@ -0,0 +1,396 @@
+"""
+Unit tests for basic ISA Model evaluation framework functionality.
+
+This test file focuses on core functionality without complex dependencies.
+"""
+
+import pytest
+import asyncio
+from dataclasses import dataclass, field
+from typing import Dict, List, Any
+from abc import ABC, abstractmethod
+
+
+@dataclass
+class MockEvaluationResult:
+    """Mock evaluation result for testing."""
+    metrics: Dict[str, float] = field(default_factory=dict)
+    predictions: List[Any] = field(default_factory=list)
+    references: List[Any] = field(default_factory=list)
+
+    def to_dict(self):
+        """Convert to dictionary."""
+        return {
+            "metrics": self.metrics,
+            "predictions": self.predictions,
+            "references": self.references
+        }
+
+
+class MockBaseEvaluator(ABC):
+    """Mock base evaluator for testing."""
+
+    def __init__(self, config: Dict = None):
+        self.config = config or {}
+
+    @abstractmethod
+    async def evaluate(self, model_interface, dataset, **kwargs):
+        pass
+
+
+class TestEvaluationResult:
+    """Test the EvaluationResult data structure."""
+
+    def test_evaluation_result_creation(self):
+        """Test basic EvaluationResult creation and properties."""
+        result = MockEvaluationResult(
+            metrics={"accuracy": 0.85, "f1_score": 0.78},
+            predictions=["response1", "response2"],
+            references=["expected1", "expected2"]
+        )
+
+        assert result.metrics["accuracy"] == 0.85
+        assert result.metrics["f1_score"] == 0.78
+        assert len(result.predictions) == 2
+        assert len(result.references) == 2
+
+    def test_evaluation_result_default_values(self):
+        """Test EvaluationResult with default values."""
+        result = MockEvaluationResult()
+
+        assert isinstance(result.metrics, dict)
+        assert isinstance(result.predictions, list)
+        assert isinstance(result.references, list)
+        assert len(result.metrics) == 0
+        assert len(result.predictions) == 0
+        assert len(result.references) == 0
+
+    def test_evaluation_result_to_dict(self):
+        """Test EvaluationResult serialization."""
+        result = MockEvaluationResult(
+            metrics={"accuracy": 0.9},
+            predictions=["test"],
+            references=["expected"]
+        )
+
+        result_dict = result.to_dict()
+        assert isinstance(result_dict, dict)
+        assert "metrics" in result_dict
+        assert result_dict["metrics"]["accuracy"] == 0.9
+
+
+class MockModelInterface:
+    """Mock model interface for testing."""
+
+    def __init__(self, responses: List[str] = None):
+        self.responses = responses or ["mock response"]
+        self.call_count = 0
+
+    async def generate_response(self, prompt: str, **kwargs) -> str:
+        """Mock response generation."""
+        response = self.responses[self.call_count % len(self.responses)]
+        self.call_count += 1
+        await asyncio.sleep(0.01)  # Simulate async processing
+        return response
+
+
+class TestBasicMetrics:
+    """Test basic metric calculation functions."""
+
+    def test_exact_match_metric(self):
+        """Test exact match calculation."""
+        predictions = ["Paris", "London", "Berlin"]
+        references = ["Paris", "Madrid", "Berlin"]
+
+        def calculate_exact_match(pred_list, ref_list):
+            """Simple exact match implementation."""
+            matches = sum(1 for p, r in zip(pred_list, ref_list)
+                          if p.strip().lower() == r.strip().lower())
+            return matches / len(pred_list)
+
+        accuracy = calculate_exact_match(predictions, references)
+        assert accuracy == 2/3  # 2 out of 3 matches
+
+    def test_f1_score_calculation(self):
+        """Test F1 score calculation."""
+        predictions = ["The cat sits", "A dog runs"]
+        references = ["The cat sits on mat", "The dog runs fast"]
+
+        def calculate_f1_score(pred_list, ref_list):
+            """Simple token-based F1 calculation."""
+            total_f1 = 0
+            for pred, ref in zip(pred_list, ref_list):
+                pred_tokens = set(pred.lower().split())
+                ref_tokens = set(ref.lower().split())
+
+                if len(pred_tokens) == 0 and len(ref_tokens) == 0:
+                    f1 = 1.0
+                elif len(pred_tokens) == 0 or len(ref_tokens) == 0:
+                    f1 = 0.0
+                else:
+                    intersection = len(pred_tokens & ref_tokens)
+                    precision = intersection / len(pred_tokens)
+                    recall = intersection / len(ref_tokens)
+                    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
+
+                total_f1 += f1
+
+            return total_f1 / len(pred_list)
+
+        f1 = calculate_f1_score(predictions, references)
+        assert isinstance(f1, float)
+        assert 0 <= f1 <= 1
+
+
+class TestBasicEvaluator:
+    """Test basic evaluator functionality."""
+
+    @pytest.fixture
+    def mock_config(self):
+        """Create a mock evaluation config."""
+        return {
+            "batch_size": 2,
+            "max_concurrent_requests": 3,
+            "timeout_seconds": 30
+        }
+
+    @pytest.fixture
+    def sample_dataset(self):
+        """Create a sample dataset for testing."""
+        return [
+            {
+                "id": "test_1",
+                "prompt": "What is 2+2?",
+                "expected_output": "4",
+                "metadata": {"category": "math"}
+            },
+            {
+                "id": "test_2",
+                "input": "Name the capital of France",
+                "expected_output": "Paris",
+                "metadata": {"category": "geography"}
+            },
+            {
+                "id": "test_3",
+                "prompt": "What color is the sky?",
+                "expected_output": "blue",
+                "metadata": {"category": "general"}
+            }
+        ]
+
+    def test_evaluator_initialization(self, mock_config):
+        """Test basic evaluator initialization."""
+        class TestEvaluator(MockBaseEvaluator):
+            async def evaluate(self, model_interface, dataset, **kwargs):
+                return MockEvaluationResult()
+
+        evaluator = TestEvaluator(config=mock_config)
+        assert evaluator.config["batch_size"] == 2
+        assert evaluator.config["max_concurrent_requests"] == 3
+
+    @pytest.mark.asyncio
+    async def test_mock_evaluation_workflow(self, sample_dataset, mock_config):
+        """Test basic evaluation workflow with mocked components."""
+
+        class TestEvaluator(MockBaseEvaluator):
+            async def evaluate(self, model_interface, dataset, **kwargs):
+                """Simple evaluation implementation for testing."""
+                predictions = []
+                references = []
+
+                for item in dataset:
+                    # Mock model call
+                    response = await model_interface.generate_response(
+                        item.get("prompt", item.get("input", ""))
+                    )
+                    predictions.append(response)
+                    references.append(item["expected_output"])
+
+                # Calculate simple accuracy
+                matches = sum(1 for p, r in zip(predictions, references)
+                              if p.strip().lower() == r.strip().lower())
+                accuracy = matches / len(predictions) if predictions else 0
+
+                return MockEvaluationResult(
+                    metrics={"accuracy": accuracy, "total_samples": len(dataset)},
+                    predictions=predictions,
+                    references=references
+                )
+
+        # Create evaluator and mock model
+        evaluator = TestEvaluator(config=mock_config)
+        model_interface = MockModelInterface(responses=["4", "Paris", "blue"])
+
+        # Run evaluation
+        result = await evaluator.evaluate(
+            model_interface=model_interface,
+            dataset=sample_dataset,
+            dataset_name="test_dataset"
+        )
+
+        # Verify results
+        assert isinstance(result, MockEvaluationResult)
+        assert "accuracy" in result.metrics
+        assert "total_samples" in result.metrics
+        assert result.metrics["total_samples"] == 3
+        assert result.metrics["accuracy"] == 1.0  # All mock responses match expected
+        assert len(result.predictions) == 3
+        assert len(result.references) == 3
+
+
+class TestEvaluationConfig:
+    """Test evaluation configuration functionality."""
+
+    def test_config_creation(self):
+        """Test basic config creation."""
+        config_data = {
+            "batch_size": 16,
+            "max_concurrent_requests": 5,
+            "timeout_seconds": 60,
+            "output_dir": "test_results"
+        }
+
+        class MockConfig:
+            def __init__(self, **kwargs):
+                for k, v in kwargs.items():
+                    setattr(self, k, v)
+
+        config = MockConfig(**config_data)
+        assert config.batch_size == 16
+        assert config.output_dir == "test_results"
+
+    def test_config_validation(self):
+        """Test config validation logic."""
+        def validate_config(config_dict):
+            """Validate configuration values."""
+            if config_dict.get("batch_size", 1) <= 0:
+                raise ValueError("batch_size must be positive")
+            if config_dict.get("max_concurrent_requests", 1) <= 0:
+                raise ValueError("max_concurrent_requests must be positive")
+            if config_dict.get("timeout_seconds", 1) <= 0:
+                raise ValueError("timeout_seconds must be positive")
+            return True
+
+        # Test valid config
+        valid_config = {"batch_size": 10, "max_concurrent_requests": 5, "timeout_seconds": 60}
+        assert validate_config(valid_config) is True
+
+        # Test invalid configs
+        invalid_configs = [
+            {"batch_size": -1},
+            {"max_concurrent_requests": 0},
+            {"timeout_seconds": -5}
+        ]
+
+        for invalid_config in invalid_configs:
+            with pytest.raises(ValueError):
+                validate_config(invalid_config)
+
+
+class TestAsyncEvaluation:
+    """Test asynchronous evaluation capabilities."""
+
+    @pytest.mark.asyncio
+    async def test_concurrent_evaluation(self):
+        """Test that evaluations can run concurrently."""
+        async def mock_evaluation_task(task_id: int, delay: float = 0.1):
+            """Mock evaluation task with delay."""
+            await asyncio.sleep(delay)
+            return {"task_id": task_id, "result": f"completed_{task_id}"}
+
+        # Run multiple evaluations concurrently
+        start_time = asyncio.get_event_loop().time()
+
+        tasks = [mock_evaluation_task(i, 0.1) for i in range(3)]
+        results = await asyncio.gather(*tasks)
+
+        end_time = asyncio.get_event_loop().time()
+
+        # Should complete in roughly 0.1 seconds (concurrent) rather than 0.3 (sequential)
+        assert end_time - start_time < 0.2
+        assert len(results) == 3
+        assert all(r["result"].startswith("completed_") for r in results)
+
+    @pytest.mark.asyncio
+    async def test_batch_processing(self):
+        """Test batch processing functionality."""
+        async def process_batch(batch: List[Dict], batch_size: int = 2):
+            """Process items in batches."""
+            results = []
+            for i in range(0, len(batch), batch_size):
+                batch_items = batch[i:i + batch_size]
+                # Simulate processing time proportional to batch size
+                await asyncio.sleep(0.01 * len(batch_items))
+                batch_results = [{"processed": item["id"]} for item in batch_items]
+                results.extend(batch_results)
+            return results
+
+        # Test data
+        test_items = [{"id": f"item_{i}"} for i in range(5)]
+
+        # Process in batches
+        results = await process_batch(test_items, batch_size=2)
+
+        assert len(results) == 5
+        assert all("processed" in r for r in results)
+
+
+class TestErrorHandling:
+    """Test error handling and edge cases."""
+
+    @pytest.mark.asyncio
+    async def test_timeout_handling(self):
+        """Test timeout handling in async operations."""
+        async def slow_operation():
+            """Simulate a slow operation."""
+            await asyncio.sleep(1.0)
+            return "completed"
+
+        # Test timeout
+        with pytest.raises(asyncio.TimeoutError):
+            await asyncio.wait_for(slow_operation(), timeout=0.1)
+
+    def test_empty_dataset_handling(self):
+        """Test handling of empty datasets."""
+        def calculate_metrics(predictions, references):
+            """Calculate metrics with empty data handling."""
+            if not predictions or not references:
+                return {"accuracy": 0.0, "count": 0}
+
+            matches = sum(1 for p, r in zip(predictions, references) if p == r)
+            return {
+                "accuracy": matches / len(predictions),
+                "count": len(predictions)
+            }
+
+        # Test empty data
+        empty_metrics = calculate_metrics([], [])
+        assert empty_metrics["accuracy"] == 0.0
+        assert empty_metrics["count"] == 0
+
+    def test_mismatched_data_lengths(self):
+        """Test handling of mismatched prediction and reference lengths."""
+        def safe_calculate_accuracy(predictions, references):
+            """Safely calculate accuracy with length mismatch handling."""
+            if len(predictions) != len(references):
+                min_len = min(len(predictions), len(references))
+                predictions = predictions[:min_len]
+                references = references[:min_len]
+
+            if not predictions:
+                return 0.0
+
+            matches = sum(1 for p, r in zip(predictions, references) if p == r)
+            return matches / len(predictions)
+
+        # Test mismatched lengths
+        predictions = ["a", "b", "c"]
+        references = ["a", "b"]  # Shorter
+
+        accuracy = safe_calculate_accuracy(predictions, references)
+        assert accuracy == 1.0  # Both "a" and "b" match
+
+
+if __name__ == "__main__":
+    # Allow running tests directly
+    pytest.main([__file__, "-v"])
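The async test cases above use the `@pytest.mark.asyncio` marker, which is supplied by the pytest-asyncio plugin rather than pytest itself, so that plugin (or an equivalent asyncio mode) needs to be installed for `pytest.main([__file__, "-v"])` to collect them. For a quick smoke check without pytest, the mock model can also be driven directly with `asyncio.run`; a minimal sketch, assuming the shipped tests are importable at the package path shown in the file list above (not confirmed by this diff):

import asyncio

# Hypothetical import path; assumes the bundled tests form an importable package.
from isa_model.eval.tests.unit.test_basic import MockModelInterface

async def demo():
    model = MockModelInterface(responses=["4", "Paris", "blue"])
    prompts = ["What is 2+2?", "Name the capital of France", "What color is the sky?"]
    answers = [await model.generate_response(p) for p in prompts]
    # Responses cycle through the configured list, so this prints ['4', 'Paris', 'blue']
    print(answers)

asyncio.run(demo())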