isa-model 0.3.9__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. isa_model/__init__.py +1 -1
  2. isa_model/client.py +732 -565
  3. isa_model/core/cache/redis_cache.py +401 -0
  4. isa_model/core/config/config_manager.py +53 -10
  5. isa_model/core/config.py +1 -1
  6. isa_model/core/database/__init__.py +1 -0
  7. isa_model/core/database/migrations.py +277 -0
  8. isa_model/core/database/supabase_client.py +123 -0
  9. isa_model/core/models/__init__.py +37 -0
  10. isa_model/core/models/model_billing_tracker.py +60 -88
  11. isa_model/core/models/model_manager.py +36 -18
  12. isa_model/core/models/model_repo.py +44 -38
  13. isa_model/core/models/model_statistics_tracker.py +234 -0
  14. isa_model/core/models/model_storage.py +0 -1
  15. isa_model/core/models/model_version_manager.py +959 -0
  16. isa_model/core/pricing_manager.py +2 -249
  17. isa_model/core/resilience/circuit_breaker.py +366 -0
  18. isa_model/core/security/secrets.py +358 -0
  19. isa_model/core/services/__init__.py +2 -4
  20. isa_model/core/services/intelligent_model_selector.py +101 -370
  21. isa_model/core/storage/hf_storage.py +1 -1
  22. isa_model/core/types.py +7 -0
  23. isa_model/deployment/cloud/modal/isa_audio_chatTTS_service.py +520 -0
  24. isa_model/deployment/cloud/modal/isa_audio_fish_service.py +0 -0
  25. isa_model/deployment/cloud/modal/isa_audio_openvoice_service.py +758 -0
  26. isa_model/deployment/cloud/modal/isa_audio_service_v2.py +1044 -0
  27. isa_model/deployment/cloud/modal/isa_embed_rerank_service.py +296 -0
  28. isa_model/deployment/cloud/modal/isa_video_hunyuan_service.py +423 -0
  29. isa_model/deployment/cloud/modal/isa_vision_ocr_service.py +519 -0
  30. isa_model/deployment/cloud/modal/isa_vision_qwen25_service.py +709 -0
  31. isa_model/deployment/cloud/modal/isa_vision_table_service.py +467 -323
  32. isa_model/deployment/cloud/modal/isa_vision_ui_service.py +607 -180
  33. isa_model/deployment/cloud/modal/isa_vision_ui_service_optimized.py +660 -0
  34. isa_model/deployment/core/deployment_manager.py +6 -4
  35. isa_model/deployment/services/auto_hf_modal_deployer.py +894 -0
  36. isa_model/eval/benchmarks/__init__.py +27 -0
  37. isa_model/eval/benchmarks/multimodal_datasets.py +460 -0
  38. isa_model/eval/benchmarks.py +244 -12
  39. isa_model/eval/evaluators/__init__.py +8 -2
  40. isa_model/eval/evaluators/audio_evaluator.py +727 -0
  41. isa_model/eval/evaluators/embedding_evaluator.py +742 -0
  42. isa_model/eval/evaluators/vision_evaluator.py +564 -0
  43. isa_model/eval/example_evaluation.py +395 -0
  44. isa_model/eval/factory.py +272 -5
  45. isa_model/eval/isa_benchmarks.py +700 -0
  46. isa_model/eval/isa_integration.py +582 -0
  47. isa_model/eval/metrics.py +159 -6
  48. isa_model/eval/tests/unit/test_basic.py +396 -0
  49. isa_model/inference/ai_factory.py +44 -8
  50. isa_model/inference/services/audio/__init__.py +21 -0
  51. isa_model/inference/services/audio/base_realtime_service.py +225 -0
  52. isa_model/inference/services/audio/isa_tts_service.py +0 -0
  53. isa_model/inference/services/audio/openai_realtime_service.py +320 -124
  54. isa_model/inference/services/audio/openai_stt_service.py +32 -6
  55. isa_model/inference/services/base_service.py +17 -1
  56. isa_model/inference/services/embedding/__init__.py +13 -0
  57. isa_model/inference/services/embedding/base_embed_service.py +111 -8
  58. isa_model/inference/services/embedding/isa_embed_service.py +305 -0
  59. isa_model/inference/services/embedding/openai_embed_service.py +2 -4
  60. isa_model/inference/services/embedding/tests/test_embedding.py +222 -0
  61. isa_model/inference/services/img/__init__.py +2 -2
  62. isa_model/inference/services/img/base_image_gen_service.py +24 -7
  63. isa_model/inference/services/img/replicate_image_gen_service.py +84 -422
  64. isa_model/inference/services/img/services/replicate_face_swap.py +193 -0
  65. isa_model/inference/services/img/services/replicate_flux.py +226 -0
  66. isa_model/inference/services/img/services/replicate_flux_kontext.py +219 -0
  67. isa_model/inference/services/img/services/replicate_sticker_maker.py +249 -0
  68. isa_model/inference/services/img/tests/test_img_client.py +297 -0
  69. isa_model/inference/services/llm/base_llm_service.py +30 -6
  70. isa_model/inference/services/llm/helpers/llm_adapter.py +63 -9
  71. isa_model/inference/services/llm/ollama_llm_service.py +2 -1
  72. isa_model/inference/services/llm/openai_llm_service.py +652 -55
  73. isa_model/inference/services/llm/yyds_llm_service.py +2 -1
  74. isa_model/inference/services/vision/__init__.py +5 -5
  75. isa_model/inference/services/vision/base_vision_service.py +118 -185
  76. isa_model/inference/services/vision/helpers/image_utils.py +11 -5
  77. isa_model/inference/services/vision/isa_vision_service.py +573 -0
  78. isa_model/inference/services/vision/tests/test_ocr_client.py +284 -0
  79. isa_model/serving/api/fastapi_server.py +88 -16
  80. isa_model/serving/api/middleware/auth.py +311 -0
  81. isa_model/serving/api/middleware/security.py +278 -0
  82. isa_model/serving/api/routes/analytics.py +486 -0
  83. isa_model/serving/api/routes/deployments.py +339 -0
  84. isa_model/serving/api/routes/evaluations.py +579 -0
  85. isa_model/serving/api/routes/logs.py +430 -0
  86. isa_model/serving/api/routes/settings.py +582 -0
  87. isa_model/serving/api/routes/unified.py +324 -165
  88. isa_model/serving/api/startup.py +304 -0
  89. isa_model/serving/modal_proxy_server.py +249 -0
  90. isa_model/training/__init__.py +100 -6
  91. isa_model/training/core/__init__.py +4 -1
  92. isa_model/training/examples/intelligent_training_example.py +281 -0
  93. isa_model/training/intelligent/__init__.py +25 -0
  94. isa_model/training/intelligent/decision_engine.py +643 -0
  95. isa_model/training/intelligent/intelligent_factory.py +888 -0
  96. isa_model/training/intelligent/knowledge_base.py +751 -0
  97. isa_model/training/intelligent/resource_optimizer.py +839 -0
  98. isa_model/training/intelligent/task_classifier.py +576 -0
  99. isa_model/training/storage/__init__.py +24 -0
  100. isa_model/training/storage/core_integration.py +439 -0
  101. isa_model/training/storage/training_repository.py +552 -0
  102. isa_model/training/storage/training_storage.py +628 -0
  103. {isa_model-0.3.9.dist-info → isa_model-0.4.0.dist-info}/METADATA +13 -1
  104. isa_model-0.4.0.dist-info/RECORD +182 -0
  105. isa_model/deployment/cloud/modal/isa_vision_doc_service.py +0 -766
  106. isa_model/deployment/cloud/modal/register_models.py +0 -321
  107. isa_model/inference/adapter/unified_api.py +0 -248
  108. isa_model/inference/services/helpers/stacked_config.py +0 -148
  109. isa_model/inference/services/img/flux_professional_service.py +0 -603
  110. isa_model/inference/services/img/helpers/base_stacked_service.py +0 -274
  111. isa_model/inference/services/others/table_transformer_service.py +0 -61
  112. isa_model/inference/services/vision/doc_analysis_service.py +0 -640
  113. isa_model/inference/services/vision/helpers/base_stacked_service.py +0 -274
  114. isa_model/inference/services/vision/ui_analysis_service.py +0 -823
  115. isa_model/scripts/inference_tracker.py +0 -283
  116. isa_model/scripts/mlflow_manager.py +0 -379
  117. isa_model/scripts/model_registry.py +0 -465
  118. isa_model/scripts/register_models.py +0 -370
  119. isa_model/scripts/register_models_with_embeddings.py +0 -510
  120. isa_model/scripts/start_mlflow.py +0 -95
  121. isa_model/scripts/training_tracker.py +0 -257
  122. isa_model-0.3.9.dist-info/RECORD +0 -138
  123. {isa_model-0.3.9.dist-info → isa_model-0.4.0.dist-info}/WHEEL +0 -0
  124. {isa_model-0.3.9.dist-info → isa_model-0.4.0.dist-info}/top_level.txt +0 -0
isa_model/eval/metrics.py CHANGED
@@ -83,7 +83,7 @@ class LLMMetrics:
         else:
             self.ai_factory = None
 
-    def evaluate(
+    async def evaluate(
         self,
         model_path: str,
         dataset: List[Dict[str, Any]],
@@ -113,7 +113,7 @@ class LLMMetrics:
         }
 
         # Generate predictions
-        predictions, references = self._generate_predictions(
+        predictions, references = await self._generate_predictions(
             model_path, dataset, batch_size, provider, **kwargs
         )
 
@@ -149,7 +149,7 @@ class LLMMetrics:
 
         return results
 
-    def evaluate_generation(
+    async def evaluate_generation(
        self,
        model_path: str,
        prompts: List[str],
@@ -208,7 +208,7 @@ class LLMMetrics:
 
        return results
 
-    def _generate_predictions(
+    async def _generate_predictions(
        self,
        model_path: str,
        dataset: List[Dict[str, Any]],
@@ -306,7 +306,7 @@ class LLMMetrics:
        logger.info(f"Generated {len(predictions)} predictions")
        return predictions, references
 
-    def _generate_texts(
+    async def _generate_texts(
        self,
        model_path: str,
        prompts: List[str],
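Note that the hunks above turn the LLMMetrics entry points and their generation helpers into coroutines, so any existing synchronous caller now has to await them. A minimal calling sketch follows; it assumes LLMMetrics can be constructed with defaults, and the model identifier, dataset keys, and batch_size value are illustrative rather than taken from the package:

    # Hedged sketch of the new async calling convention for LLMMetrics.evaluate.
    # Everything besides the import and the awaited method name is an assumption.
    import asyncio

    from isa_model.eval.metrics import LLMMetrics

    async def run_llm_eval():
        metrics = LLMMetrics()
        dataset = [{"prompt": "What is 2+2?", "expected_output": "4"}]  # assumed item shape
        return await metrics.evaluate(
            model_path="my-finetuned-model",  # placeholder identifier
            dataset=dataset,
            batch_size=8,
        )

    results = asyncio.run(run_llm_eval())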
@@ -795,4 +795,157 @@ class BenchmarkRunner:
 
         except Exception as e:
             logger.error(f"Failed to generate prediction: {e}")
-            return "A"  # Fallback answer
+            return "A"  # Fallback answer
+
+
+# Utility functions for evaluators
+def compute_text_metrics(predictions: Union[str, List[str]],
+                         references: Union[str, List[str]],
+                         aggregate: bool = False) -> Dict[str, float]:
+    """
+    Compute standard text evaluation metrics.
+
+    Args:
+        predictions: Single prediction or list of predictions
+        references: Single reference or list of references
+        aggregate: Whether to compute aggregate metrics for lists
+
+    Returns:
+        Dictionary of computed metrics
+    """
+    try:
+        # Handle single string inputs
+        if isinstance(predictions, str) and isinstance(references, str):
+            pred_list = [predictions]
+            ref_list = [references]
+        else:
+            pred_list = predictions if isinstance(predictions, list) else [str(predictions)]
+            ref_list = references if isinstance(references, list) else [str(references)]
+
+        # Ensure equal lengths
+        min_len = min(len(pred_list), len(ref_list))
+        pred_list = pred_list[:min_len]
+        ref_list = ref_list[:min_len]
+
+        metrics = {}
+
+        # Exact match
+        exact_matches = sum(1 for p, r in zip(pred_list, ref_list) if p.strip().lower() == r.strip().lower())
+        metrics["exact_match"] = exact_matches / len(pred_list) if pred_list else 0.0
+
+        # F1 Score (token-level)
+        f1_scores = []
+        for pred, ref in zip(pred_list, ref_list):
+            pred_tokens = set(pred.lower().split())
+            ref_tokens = set(ref.lower().split())
+
+            if not ref_tokens and not pred_tokens:
+                f1_scores.append(1.0)
+            elif not ref_tokens or not pred_tokens:
+                f1_scores.append(0.0)
+            else:
+                intersection = len(pred_tokens & ref_tokens)
+                precision = intersection / len(pred_tokens)
+                recall = intersection / len(ref_tokens)
+
+                if precision + recall > 0:
+                    f1 = 2 * (precision * recall) / (precision + recall)
+                    f1_scores.append(f1)
+                else:
+                    f1_scores.append(0.0)
+
+        metrics["f1_score"] = np.mean(f1_scores) if f1_scores else 0.0
+
+        # BLEU Score (simplified)
+        bleu_scores = []
+        for pred, ref in zip(pred_list, ref_list):
+            pred_words = pred.lower().split()
+            ref_words = ref.lower().split()
+
+            # Simple n-gram overlap
+            overlap = len(set(pred_words) & set(ref_words))
+            total = len(set(pred_words) | set(ref_words))
+
+            bleu_scores.append(overlap / total if total > 0 else 0.0)
+
+        metrics["bleu_score"] = np.mean(bleu_scores) if bleu_scores else 0.0
+
+        # ROUGE-L (simplified)
+        rouge_scores = []
+        for pred, ref in zip(pred_list, ref_list):
+            pred_words = set(pred.lower().split())
+            ref_words = set(ref.lower().split())
+
+            if len(ref_words) > 0:
+                rouge_l = len(pred_words & ref_words) / len(ref_words)
+                rouge_scores.append(rouge_l)
+            else:
+                rouge_scores.append(0.0)
+
+        metrics["rouge_l"] = np.mean(rouge_scores) if rouge_scores else 0.0
+
+        # Response length metrics
+        pred_lengths = [len(p.split()) for p in pred_list]
+        ref_lengths = [len(r.split()) for r in ref_list]
+
+        metrics["avg_prediction_length"] = np.mean(pred_lengths) if pred_lengths else 0.0
+        metrics["avg_reference_length"] = np.mean(ref_lengths) if ref_lengths else 0.0
+        metrics["length_ratio"] = (np.mean(pred_lengths) / np.mean(ref_lengths)) if np.mean(ref_lengths) > 0 else 0.0
+
+        # Diversity metrics for predictions
+        if len(pred_list) > 1:
+            all_words = []
+            for pred in pred_list:
+                all_words.extend(pred.lower().split())
+
+            unique_words = len(set(all_words))
+            total_words = len(all_words)
+
+            metrics["vocabulary_diversity"] = unique_words / total_words if total_words > 0 else 0.0
+
+        return metrics
+
+    except Exception as e:
+        logger.error(f"Error computing text metrics: {e}")
+        return {"text_metrics_error": 1.0}
+
+
+def compute_vision_metrics(predictions: List[Any],
+                           references: List[Any],
+                           task_type: str = "general") -> Dict[str, float]:
+    """
+    Compute vision-specific evaluation metrics.
+
+    Args:
+        predictions: List of vision model predictions
+        references: List of reference outputs
+        task_type: Type of vision task (ocr, detection, etc.)
+
+    Returns:
+        Dictionary of computed metrics
+    """
+    try:
+        metrics = {}
+
+        # Basic success rate
+        successful_predictions = sum(1 for p in predictions if p is not None)
+        metrics["prediction_success_rate"] = successful_predictions / len(predictions) if predictions else 0.0
+
+        # Task-specific metrics would be computed by individual evaluators
+        # This is a placeholder for common vision metrics
+
+        if task_type == "ocr":
+            # OCR-specific metrics would be computed in VisionEvaluator
+            pass
+        elif task_type == "detection":
+            # Object detection metrics (IoU, mAP, etc.)
+            pass
+        elif task_type == "classification":
+            # Image classification metrics
+            pass
+
+        return metrics
+
+    except Exception as e:
+        logger.error(f"Error computing vision metrics: {e}")
+        return {"vision_metrics_error": 1.0}
isa_model/eval/tests/unit/test_basic.py ADDED
@@ -0,0 +1,396 @@
+"""
+Unit tests for basic ISA Model evaluation framework functionality.
+
+This test file focuses on core functionality without complex dependencies.
+"""
+
+import pytest
+import asyncio
+from dataclasses import dataclass, field
+from typing import Dict, List, Any
+from abc import ABC, abstractmethod
+
+
+@dataclass
+class MockEvaluationResult:
+    """Mock evaluation result for testing."""
+    metrics: Dict[str, float] = field(default_factory=dict)
+    predictions: List[Any] = field(default_factory=list)
+    references: List[Any] = field(default_factory=list)
+
+    def to_dict(self):
+        """Convert to dictionary."""
+        return {
+            "metrics": self.metrics,
+            "predictions": self.predictions,
+            "references": self.references
+        }
+
+
+class MockBaseEvaluator(ABC):
+    """Mock base evaluator for testing."""
+
+    def __init__(self, config: Dict = None):
+        self.config = config or {}
+
+    @abstractmethod
+    async def evaluate(self, model_interface, dataset, **kwargs):
+        pass
+
+
+class TestEvaluationResult:
+    """Test the EvaluationResult data structure."""
+
+    def test_evaluation_result_creation(self):
+        """Test basic EvaluationResult creation and properties."""
+        result = MockEvaluationResult(
+            metrics={"accuracy": 0.85, "f1_score": 0.78},
+            predictions=["response1", "response2"],
+            references=["expected1", "expected2"]
+        )
+
+        assert result.metrics["accuracy"] == 0.85
+        assert result.metrics["f1_score"] == 0.78
+        assert len(result.predictions) == 2
+        assert len(result.references) == 2
+
+    def test_evaluation_result_default_values(self):
+        """Test EvaluationResult with default values."""
+        result = MockEvaluationResult()
+
+        assert isinstance(result.metrics, dict)
+        assert isinstance(result.predictions, list)
+        assert isinstance(result.references, list)
+        assert len(result.metrics) == 0
+        assert len(result.predictions) == 0
+        assert len(result.references) == 0
+
+    def test_evaluation_result_to_dict(self):
+        """Test EvaluationResult serialization."""
+        result = MockEvaluationResult(
+            metrics={"accuracy": 0.9},
+            predictions=["test"],
+            references=["expected"]
+        )
+
+        result_dict = result.to_dict()
+        assert isinstance(result_dict, dict)
+        assert "metrics" in result_dict
+        assert result_dict["metrics"]["accuracy"] == 0.9
+
+
+class MockModelInterface:
+    """Mock model interface for testing."""
+
+    def __init__(self, responses: List[str] = None):
+        self.responses = responses or ["mock response"]
+        self.call_count = 0
+
+    async def generate_response(self, prompt: str, **kwargs) -> str:
+        """Mock response generation."""
+        response = self.responses[self.call_count % len(self.responses)]
+        self.call_count += 1
+        await asyncio.sleep(0.01)  # Simulate async processing
+        return response
+
+
+class TestBasicMetrics:
+    """Test basic metric calculation functions."""
+
+    def test_exact_match_metric(self):
+        """Test exact match calculation."""
+        predictions = ["Paris", "London", "Berlin"]
+        references = ["Paris", "Madrid", "Berlin"]
+
+        def calculate_exact_match(pred_list, ref_list):
+            """Simple exact match implementation."""
+            matches = sum(1 for p, r in zip(pred_list, ref_list)
+                          if p.strip().lower() == r.strip().lower())
+            return matches / len(pred_list)
+
+        accuracy = calculate_exact_match(predictions, references)
+        assert accuracy == 2/3  # 2 out of 3 matches
+
+    def test_f1_score_calculation(self):
+        """Test F1 score calculation."""
+        predictions = ["The cat sits", "A dog runs"]
+        references = ["The cat sits on mat", "The dog runs fast"]
+
+        def calculate_f1_score(pred_list, ref_list):
+            """Simple token-based F1 calculation."""
+            total_f1 = 0
+            for pred, ref in zip(pred_list, ref_list):
+                pred_tokens = set(pred.lower().split())
+                ref_tokens = set(ref.lower().split())
+
+                if len(pred_tokens) == 0 and len(ref_tokens) == 0:
+                    f1 = 1.0
+                elif len(pred_tokens) == 0 or len(ref_tokens) == 0:
+                    f1 = 0.0
+                else:
+                    intersection = len(pred_tokens & ref_tokens)
+                    precision = intersection / len(pred_tokens)
+                    recall = intersection / len(ref_tokens)
+                    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
+
+                total_f1 += f1
+
+            return total_f1 / len(pred_list)
+
+        f1 = calculate_f1_score(predictions, references)
+        assert isinstance(f1, float)
+        assert 0 <= f1 <= 1
+
+
+class TestBasicEvaluator:
+    """Test basic evaluator functionality."""
+
+    @pytest.fixture
+    def mock_config(self):
+        """Create a mock evaluation config."""
+        return {
+            "batch_size": 2,
+            "max_concurrent_requests": 3,
+            "timeout_seconds": 30
+        }
+
+    @pytest.fixture
+    def sample_dataset(self):
+        """Create a sample dataset for testing."""
+        return [
+            {
+                "id": "test_1",
+                "prompt": "What is 2+2?",
+                "expected_output": "4",
+                "metadata": {"category": "math"}
+            },
+            {
+                "id": "test_2",
+                "input": "Name the capital of France",
+                "expected_output": "Paris",
+                "metadata": {"category": "geography"}
+            },
+            {
+                "id": "test_3",
+                "prompt": "What color is the sky?",
+                "expected_output": "blue",
+                "metadata": {"category": "general"}
+            }
+        ]
+
+    def test_evaluator_initialization(self, mock_config):
+        """Test basic evaluator initialization."""
+        class TestEvaluator(MockBaseEvaluator):
+            async def evaluate(self, model_interface, dataset, **kwargs):
+                return MockEvaluationResult()
+
+        evaluator = TestEvaluator(config=mock_config)
+        assert evaluator.config["batch_size"] == 2
+        assert evaluator.config["max_concurrent_requests"] == 3
+
+    @pytest.mark.asyncio
+    async def test_mock_evaluation_workflow(self, sample_dataset, mock_config):
+        """Test basic evaluation workflow with mocked components."""
+
+        class TestEvaluator(MockBaseEvaluator):
+            async def evaluate(self, model_interface, dataset, **kwargs):
+                """Simple evaluation implementation for testing."""
+                predictions = []
+                references = []
+
+                for item in dataset:
+                    # Mock model call
+                    response = await model_interface.generate_response(
+                        item.get("prompt", item.get("input", ""))
+                    )
+                    predictions.append(response)
+                    references.append(item["expected_output"])
+
+                # Calculate simple accuracy
+                matches = sum(1 for p, r in zip(predictions, references)
+                              if p.strip().lower() == r.strip().lower())
+                accuracy = matches / len(predictions) if predictions else 0
+
+                return MockEvaluationResult(
+                    metrics={"accuracy": accuracy, "total_samples": len(dataset)},
+                    predictions=predictions,
+                    references=references
+                )
+
+        # Create evaluator and mock model
+        evaluator = TestEvaluator(config=mock_config)
+        model_interface = MockModelInterface(responses=["4", "Paris", "blue"])
+
+        # Run evaluation
+        result = await evaluator.evaluate(
+            model_interface=model_interface,
+            dataset=sample_dataset,
+            dataset_name="test_dataset"
+        )
+
+        # Verify results
+        assert isinstance(result, MockEvaluationResult)
+        assert "accuracy" in result.metrics
+        assert "total_samples" in result.metrics
+        assert result.metrics["total_samples"] == 3
+        assert result.metrics["accuracy"] == 1.0  # All mock responses match expected
+        assert len(result.predictions) == 3
+        assert len(result.references) == 3
+
+
+class TestEvaluationConfig:
+    """Test evaluation configuration functionality."""
+
+    def test_config_creation(self):
+        """Test basic config creation."""
+        config_data = {
+            "batch_size": 16,
+            "max_concurrent_requests": 5,
+            "timeout_seconds": 60,
+            "output_dir": "test_results"
+        }
+
+        class MockConfig:
+            def __init__(self, **kwargs):
+                for k, v in kwargs.items():
+                    setattr(self, k, v)
+
+        config = MockConfig(**config_data)
+        assert config.batch_size == 16
+        assert config.output_dir == "test_results"
+
+    def test_config_validation(self):
+        """Test config validation logic."""
+        def validate_config(config_dict):
+            """Validate configuration values."""
+            if config_dict.get("batch_size", 1) <= 0:
+                raise ValueError("batch_size must be positive")
+            if config_dict.get("max_concurrent_requests", 1) <= 0:
+                raise ValueError("max_concurrent_requests must be positive")
+            if config_dict.get("timeout_seconds", 1) <= 0:
+                raise ValueError("timeout_seconds must be positive")
+            return True
+
+        # Test valid config
+        valid_config = {"batch_size": 10, "max_concurrent_requests": 5, "timeout_seconds": 60}
+        assert validate_config(valid_config) is True
+
+        # Test invalid configs
+        invalid_configs = [
+            {"batch_size": -1},
+            {"max_concurrent_requests": 0},
+            {"timeout_seconds": -5}
+        ]
+
+        for invalid_config in invalid_configs:
+            with pytest.raises(ValueError):
+                validate_config(invalid_config)
+
+
+class TestAsyncEvaluation:
+    """Test asynchronous evaluation capabilities."""
+
+    @pytest.mark.asyncio
+    async def test_concurrent_evaluation(self):
+        """Test that evaluations can run concurrently."""
+        async def mock_evaluation_task(task_id: int, delay: float = 0.1):
+            """Mock evaluation task with delay."""
+            await asyncio.sleep(delay)
+            return {"task_id": task_id, "result": f"completed_{task_id}"}
+
+        # Run multiple evaluations concurrently
+        start_time = asyncio.get_event_loop().time()
+
+        tasks = [mock_evaluation_task(i, 0.1) for i in range(3)]
+        results = await asyncio.gather(*tasks)
+
+        end_time = asyncio.get_event_loop().time()
+
+        # Should complete in roughly 0.1 seconds (concurrent) rather than 0.3 (sequential)
+        assert end_time - start_time < 0.2
+        assert len(results) == 3
+        assert all(r["result"].startswith("completed_") for r in results)
+
+    @pytest.mark.asyncio
+    async def test_batch_processing(self):
+        """Test batch processing functionality."""
+        async def process_batch(batch: List[Dict], batch_size: int = 2):
+            """Process items in batches."""
+            results = []
+            for i in range(0, len(batch), batch_size):
+                batch_items = batch[i:i + batch_size]
+                # Simulate processing time proportional to batch size
+                await asyncio.sleep(0.01 * len(batch_items))
+                batch_results = [{"processed": item["id"]} for item in batch_items]
+                results.extend(batch_results)
+            return results
+
+        # Test data
+        test_items = [{"id": f"item_{i}"} for i in range(5)]
+
+        # Process in batches
+        results = await process_batch(test_items, batch_size=2)
+
+        assert len(results) == 5
+        assert all("processed" in r for r in results)
+
+
+class TestErrorHandling:
+    """Test error handling and edge cases."""
+
+    @pytest.mark.asyncio
+    async def test_timeout_handling(self):
+        """Test timeout handling in async operations."""
+        async def slow_operation():
+            """Simulate a slow operation."""
+            await asyncio.sleep(1.0)
+            return "completed"
+
+        # Test timeout
+        with pytest.raises(asyncio.TimeoutError):
+            await asyncio.wait_for(slow_operation(), timeout=0.1)
+
+    def test_empty_dataset_handling(self):
+        """Test handling of empty datasets."""
+        def calculate_metrics(predictions, references):
+            """Calculate metrics with empty data handling."""
+            if not predictions or not references:
+                return {"accuracy": 0.0, "count": 0}
+
+            matches = sum(1 for p, r in zip(predictions, references) if p == r)
+            return {
+                "accuracy": matches / len(predictions),
+                "count": len(predictions)
+            }
+
+        # Test empty data
+        empty_metrics = calculate_metrics([], [])
+        assert empty_metrics["accuracy"] == 0.0
+        assert empty_metrics["count"] == 0
+
+    def test_mismatched_data_lengths(self):
+        """Test handling of mismatched prediction and reference lengths."""
+        def safe_calculate_accuracy(predictions, references):
+            """Safely calculate accuracy with length mismatch handling."""
+            if len(predictions) != len(references):
+                min_len = min(len(predictions), len(references))
+                predictions = predictions[:min_len]
+                references = references[:min_len]
+
+            if not predictions:
+                return 0.0
+
+            matches = sum(1 for p, r in zip(predictions, references) if p == r)
+            return matches / len(predictions)
+
+        # Test mismatched lengths
+        predictions = ["a", "b", "c"]
+        references = ["a", "b"]  # Shorter
+
+        accuracy = safe_calculate_accuracy(predictions, references)
+        assert accuracy == 1.0  # Both "a" and "b" match
+
+
+if __name__ == "__main__":
+    # Allow running tests directly
+    pytest.main([__file__, "-v"])
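The new test module is self-contained (mocks only, no provider calls), so it can be run on its own. A sketch of a programmatic invocation, assuming pytest and the pytest-asyncio plugin (implied by the @pytest.mark.asyncio markers) are installed in the environment:

    # Run the new unit tests programmatically, mirroring the file's own
    # __main__ block; the path is the one listed in the "Files changed" table.
    import pytest

    exit_code = pytest.main(["-v", "isa_model/eval/tests/unit/test_basic.py"])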