isa-model 0.3.4__py3-none-any.whl → 0.3.6__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- isa_model/__init__.py +30 -1
- isa_model/client.py +770 -0
- isa_model/core/config/__init__.py +16 -0
- isa_model/core/config/config_manager.py +514 -0
- isa_model/core/config.py +426 -0
- isa_model/core/models/model_billing_tracker.py +476 -0
- isa_model/core/models/model_manager.py +399 -0
- isa_model/core/models/model_repo.py +343 -0
- isa_model/core/pricing_manager.py +426 -0
- isa_model/core/services/__init__.py +19 -0
- isa_model/core/services/intelligent_model_selector.py +547 -0
- isa_model/core/types.py +291 -0
- isa_model/deployment/__init__.py +2 -0
- isa_model/deployment/cloud/__init__.py +9 -0
- isa_model/deployment/cloud/modal/__init__.py +10 -0
- isa_model/deployment/cloud/modal/isa_vision_doc_service.py +766 -0
- isa_model/deployment/cloud/modal/isa_vision_table_service.py +532 -0
- isa_model/deployment/cloud/modal/isa_vision_ui_service.py +406 -0
- isa_model/deployment/cloud/modal/register_models.py +321 -0
- isa_model/deployment/runtime/deployed_service.py +338 -0
- isa_model/deployment/services/__init__.py +9 -0
- isa_model/deployment/services/auto_deploy_vision_service.py +537 -0
- isa_model/deployment/services/model_service.py +332 -0
- isa_model/deployment/services/service_monitor.py +356 -0
- isa_model/deployment/services/service_registry.py +527 -0
- isa_model/eval/__init__.py +80 -44
- isa_model/eval/config/__init__.py +10 -0
- isa_model/eval/config/evaluation_config.py +108 -0
- isa_model/eval/evaluators/__init__.py +18 -0
- isa_model/eval/evaluators/base_evaluator.py +503 -0
- isa_model/eval/evaluators/llm_evaluator.py +472 -0
- isa_model/eval/factory.py +417 -709
- isa_model/eval/infrastructure/__init__.py +24 -0
- isa_model/eval/infrastructure/experiment_tracker.py +466 -0
- isa_model/eval/metrics.py +191 -21
- isa_model/inference/ai_factory.py +187 -387
- isa_model/inference/providers/modal_provider.py +109 -0
- isa_model/inference/providers/yyds_provider.py +108 -0
- isa_model/inference/services/__init__.py +2 -1
- isa_model/inference/services/audio/base_stt_service.py +65 -1
- isa_model/inference/services/audio/base_tts_service.py +75 -1
- isa_model/inference/services/audio/openai_stt_service.py +189 -151
- isa_model/inference/services/audio/openai_tts_service.py +12 -10
- isa_model/inference/services/audio/replicate_tts_service.py +61 -56
- isa_model/inference/services/base_service.py +55 -55
- isa_model/inference/services/embedding/base_embed_service.py +65 -1
- isa_model/inference/services/embedding/ollama_embed_service.py +103 -43
- isa_model/inference/services/embedding/openai_embed_service.py +8 -10
- isa_model/inference/services/helpers/stacked_config.py +148 -0
- isa_model/inference/services/img/__init__.py +18 -0
- isa_model/inference/services/{vision → img}/base_image_gen_service.py +80 -35
- isa_model/inference/services/img/flux_professional_service.py +603 -0
- isa_model/inference/services/img/helpers/base_stacked_service.py +274 -0
- isa_model/inference/services/{vision → img}/replicate_image_gen_service.py +210 -69
- isa_model/inference/services/llm/__init__.py +3 -3
- isa_model/inference/services/llm/base_llm_service.py +519 -35
- isa_model/inference/services/llm/{llm_adapter.py → helpers/llm_adapter.py} +40 -0
- isa_model/inference/services/llm/helpers/llm_prompts.py +258 -0
- isa_model/inference/services/llm/helpers/llm_utils.py +280 -0
- isa_model/inference/services/llm/ollama_llm_service.py +150 -15
- isa_model/inference/services/llm/openai_llm_service.py +134 -31
- isa_model/inference/services/llm/yyds_llm_service.py +255 -0
- isa_model/inference/services/vision/__init__.py +38 -4
- isa_model/inference/services/vision/base_vision_service.py +241 -96
- isa_model/inference/services/vision/disabled/isA_vision_service.py +500 -0
- isa_model/inference/services/vision/doc_analysis_service.py +640 -0
- isa_model/inference/services/vision/helpers/base_stacked_service.py +274 -0
- isa_model/inference/services/vision/helpers/image_utils.py +272 -3
- isa_model/inference/services/vision/helpers/vision_prompts.py +297 -0
- isa_model/inference/services/vision/openai_vision_service.py +109 -170
- isa_model/inference/services/vision/replicate_vision_service.py +508 -0
- isa_model/inference/services/vision/ui_analysis_service.py +823 -0
- isa_model/scripts/register_models.py +370 -0
- isa_model/scripts/register_models_with_embeddings.py +510 -0
- isa_model/serving/__init__.py +19 -0
- isa_model/serving/api/__init__.py +10 -0
- isa_model/serving/api/fastapi_server.py +89 -0
- isa_model/serving/api/middleware/__init__.py +9 -0
- isa_model/serving/api/middleware/request_logger.py +88 -0
- isa_model/serving/api/routes/__init__.py +5 -0
- isa_model/serving/api/routes/health.py +82 -0
- isa_model/serving/api/routes/llm.py +19 -0
- isa_model/serving/api/routes/ui_analysis.py +223 -0
- isa_model/serving/api/routes/unified.py +202 -0
- isa_model/serving/api/routes/vision.py +19 -0
- isa_model/serving/api/schemas/__init__.py +17 -0
- isa_model/serving/api/schemas/common.py +33 -0
- isa_model/serving/api/schemas/ui_analysis.py +78 -0
- {isa_model-0.3.4.dist-info → isa_model-0.3.6.dist-info}/METADATA +4 -1
- isa_model-0.3.6.dist-info/RECORD +147 -0
- isa_model/core/model_manager.py +0 -208
- isa_model/core/model_registry.py +0 -342
- isa_model/inference/billing_tracker.py +0 -406
- isa_model/inference/services/llm/triton_llm_service.py +0 -481
- isa_model/inference/services/vision/ollama_vision_service.py +0 -194
- isa_model-0.3.4.dist-info/RECORD +0 -91
- /isa_model/core/{model_storage.py → models/model_storage.py} +0 -0
- /isa_model/inference/services/{vision → embedding}/helpers/text_splitter.py +0 -0
- {isa_model-0.3.4.dist-info → isa_model-0.3.6.dist-info}/WHEEL +0 -0
- {isa_model-0.3.4.dist-info → isa_model-0.3.6.dist-info}/top_level.txt +0 -0
isa_model/eval/config/evaluation_config.py
@@ -0,0 +1,108 @@
+"""
+Configuration management for evaluation framework
+"""
+
+import os
+import json
+import logging
+from typing import Dict, Any, Optional, List
+from dataclasses import dataclass, asdict
+from pathlib import Path
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class EvaluationConfig:
+    """
+    Configuration class for evaluation settings.
+    """
+
+    # General settings
+    output_dir: str = "evaluation_results"
+    max_concurrent_evaluations: int = 3
+    timeout_seconds: int = 600
+
+    # Model settings
+    default_provider: str = "openai"
+    default_max_tokens: int = 150
+    default_temperature: float = 0.1
+    batch_size: int = 8
+
+    # Metrics settings
+    compute_all_metrics: bool = False
+    custom_metrics: List[str] = None
+
+    # Benchmark settings
+    max_samples_per_benchmark: Optional[int] = None
+    enable_few_shot: bool = True
+    num_shots: int = 5
+
+    # Experiment tracking
+    use_wandb: bool = False
+    wandb_project: Optional[str] = None
+    wandb_entity: Optional[str] = None
+    use_mlflow: bool = False
+    mlflow_tracking_uri: Optional[str] = None
+
+    # Results settings
+    save_predictions: bool = True
+    save_detailed_results: bool = True
+    export_format: str = "json"  # json, csv, html
+
+    def __post_init__(self):
+        """Initialize default values after creation."""
+        if self.custom_metrics is None:
+            self.custom_metrics = []
+
+        # Ensure output directory exists
+        os.makedirs(self.output_dir, exist_ok=True)
+
+    @classmethod
+    def from_dict(cls, config_dict: Dict[str, Any]) -> 'EvaluationConfig':
+        """
+        Create configuration from dictionary.
+
+        Args:
+            config_dict: Configuration dictionary
+
+        Returns:
+            EvaluationConfig instance
+        """
+        # Filter out unknown keys
+        valid_keys = {field.name for field in cls.__dataclass_fields__.values()}
+        filtered_dict = {k: v for k, v in config_dict.items() if k in valid_keys}
+
+        return cls(**filtered_dict)
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Convert configuration to dictionary.
+
+        Returns:
+            Configuration as dictionary
+        """
+        return asdict(self)
+
+
+class ConfigManager:
+    """Manager for handling multiple evaluation configurations."""
+
+    def __init__(self, config_dir: str = "configs"):
+        """Initialize configuration manager."""
+        self.config_dir = config_dir
+        self.configs: Dict[str, EvaluationConfig] = {}
+        self.default_config = EvaluationConfig()
+
+        # Ensure config directory exists
+        os.makedirs(config_dir, exist_ok=True)
+
+    def get_config(self, config_name: Optional[str] = None) -> EvaluationConfig:
+        """Get configuration by name."""
+        if config_name is None:
+            return self.default_config
+
+        if config_name in self.configs:
+            return self.configs[config_name]
+
+        return self.default_config
isa_model/eval/evaluators/__init__.py
@@ -0,0 +1,18 @@
+"""
+Evaluators module for ISA Model Framework
+
+Provides specialized evaluators for different model types and evaluation tasks.
+"""
+
+from .base_evaluator import BaseEvaluator, EvaluationResult
+from .llm_evaluator import LLMEvaluator
+from .vision_evaluator import VisionEvaluator
+from .multimodal_evaluator import MultimodalEvaluator
+
+__all__ = [
+    "BaseEvaluator",
+    "EvaluationResult",
+    "LLMEvaluator",
+    "VisionEvaluator",
+    "MultimodalEvaluator"
+]
isa_model/eval/evaluators/base_evaluator.py
@@ -0,0 +1,503 @@
+"""
+Base evaluator class implementing industry best practices for AI model evaluation.
+
+Features:
+- Async/await support for concurrent evaluation
+- Comprehensive error handling and retry logic
+- Experiment tracking integration (W&B, MLflow)
+- Distributed evaluation support
+- Memory-efficient batch processing
+- Comprehensive logging and metrics
+"""
+
+import asyncio
+import logging
+import time
+import traceback
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from typing import Dict, List, Any, Optional, Union, Callable, AsyncGenerator
+from datetime import datetime
+from pathlib import Path
+import json
+
+try:
+    import wandb
+    WANDB_AVAILABLE = True
+except ImportError:
+    WANDB_AVAILABLE = False
+
+try:
+    import mlflow
+    MLFLOW_AVAILABLE = True
+except ImportError:
+    MLFLOW_AVAILABLE = False
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class EvaluationResult:
+    """
+    Standardized evaluation result container.
+
+    Follows MLOps best practices for result tracking and reproducibility.
+    """
+
+    # Core results
+    metrics: Dict[str, float] = field(default_factory=dict)
+    predictions: List[Any] = field(default_factory=list)
+    references: List[Any] = field(default_factory=list)
+
+    # Metadata
+    model_name: str = ""
+    dataset_name: str = ""
+    evaluation_type: str = ""
+    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
+
+    # Performance metrics
+    total_samples: int = 0
+    successful_samples: int = 0
+    failed_samples: int = 0
+    evaluation_time_seconds: float = 0.0
+    throughput_samples_per_second: float = 0.0
+
+    # Cost and resource tracking
+    total_tokens_used: int = 0
+    estimated_cost_usd: float = 0.0
+    memory_peak_mb: float = 0.0
+
+    # Configuration
+    config: Dict[str, Any] = field(default_factory=dict)
+    environment_info: Dict[str, Any] = field(default_factory=dict)
+
+    # Error tracking
+    errors: List[Dict[str, Any]] = field(default_factory=list)
+    warnings: List[str] = field(default_factory=list)
+
+    # Detailed results
+    sample_results: List[Dict[str, Any]] = field(default_factory=list)
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dictionary for serialization."""
+        return {
+            "metrics": self.metrics,
+            "predictions": self.predictions,
+            "references": self.references,
+            "model_name": self.model_name,
+            "dataset_name": self.dataset_name,
+            "evaluation_type": self.evaluation_type,
+            "timestamp": self.timestamp,
+            "total_samples": self.total_samples,
+            "successful_samples": self.successful_samples,
+            "failed_samples": self.failed_samples,
+            "evaluation_time_seconds": self.evaluation_time_seconds,
+            "throughput_samples_per_second": self.throughput_samples_per_second,
+            "total_tokens_used": self.total_tokens_used,
+            "estimated_cost_usd": self.estimated_cost_usd,
+            "memory_peak_mb": self.memory_peak_mb,
+            "config": self.config,
+            "environment_info": self.environment_info,
+            "errors": self.errors,
+            "warnings": self.warnings,
+            "sample_results": self.sample_results
+        }
+
+    def save_to_file(self, file_path: Union[str, Path]) -> None:
+        """Save results to JSON file."""
+        with open(file_path, 'w', encoding='utf-8') as f:
+            json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
+
+    @classmethod
+    def load_from_file(cls, file_path: Union[str, Path]) -> 'EvaluationResult':
+        """Load results from JSON file."""
+        with open(file_path, 'r', encoding='utf-8') as f:
+            data = json.load(f)
+
+        result = cls()
+        for key, value in data.items():
+            if hasattr(result, key):
+                setattr(result, key, value)
+
+        return result
+
+    def get_summary(self) -> Dict[str, Any]:
+        """Get evaluation summary."""
+        success_rate = self.successful_samples / self.total_samples if self.total_samples > 0 else 0.0
+
+        return {
+            "model_name": self.model_name,
+            "dataset_name": self.dataset_name,
+            "evaluation_type": self.evaluation_type,
+            "timestamp": self.timestamp,
+            "success_rate": success_rate,
+            "total_samples": self.total_samples,
+            "evaluation_time_seconds": self.evaluation_time_seconds,
+            "throughput_samples_per_second": self.throughput_samples_per_second,
+            "estimated_cost_usd": self.estimated_cost_usd,
+            "key_metrics": self.metrics,
+            "error_count": len(self.errors),
+            "warning_count": len(self.warnings)
+        }
+
+
+class BaseEvaluator(ABC):
+    """
+    Abstract base evaluator implementing industry best practices.
+
+    Features:
+    - Async evaluation with concurrency control
+    - Comprehensive error handling and retry logic
+    - Experiment tracking integration
+    - Memory-efficient batch processing
+    - Progress monitoring and cancellation support
+    """
+
+    def __init__(self,
+                 evaluator_name: str,
+                 config: Optional[Dict[str, Any]] = None,
+                 experiment_tracker: Optional[Any] = None):
+        """
+        Initialize the base evaluator.
+
+        Args:
+            evaluator_name: Name identifier for this evaluator
+            config: Evaluation configuration
+            experiment_tracker: Optional experiment tracking instance
+        """
+        self.evaluator_name = evaluator_name
+        self.config = config or {}
+        self.experiment_tracker = experiment_tracker
+
+        # State management
+        self._is_running = False
+        self._should_stop = False
+        self._current_result: Optional[EvaluationResult] = None
+
+        # Performance monitoring
+        self._start_time: Optional[float] = None
+        self._peak_memory_mb: float = 0.0
+
+        # Concurrency control
+        self.max_concurrent_requests = self.config.get("max_concurrent_requests", 10)
+        self.semaphore = asyncio.Semaphore(self.max_concurrent_requests)
+
+        # Retry configuration
+        self.max_retries = self.config.get("max_retries", 3)
+        self.retry_delay = self.config.get("retry_delay_seconds", 1.0)
+
+        logger.info(f"Initialized {evaluator_name} evaluator with config: {self.config}")
+
+    @abstractmethod
+    async def evaluate_sample(self,
+                              sample: Dict[str, Any],
+                              model_interface: Any) -> Dict[str, Any]:
+        """
+        Evaluate a single sample.
+
+        Args:
+            sample: Data sample to evaluate
+            model_interface: Model interface for inference
+
+        Returns:
+            Evaluation result for the sample
+        """
+        pass
+
+    @abstractmethod
+    def compute_metrics(self,
+                        predictions: List[Any],
+                        references: List[Any],
+                        **kwargs) -> Dict[str, float]:
+        """
+        Compute evaluation metrics.
+
+        Args:
+            predictions: Model predictions
+            references: Ground truth references
+            **kwargs: Additional parameters
+
+        Returns:
+            Dictionary of computed metrics
+        """
+        pass
+
+    async def evaluate(self,
+                       model_interface: Any,
+                       dataset: List[Dict[str, Any]],
+                       dataset_name: str = "unknown",
+                       model_name: str = "unknown",
+                       batch_size: Optional[int] = None,
+                       save_predictions: bool = True,
+                       progress_callback: Optional[Callable] = None) -> EvaluationResult:
+        """
+        Perform comprehensive evaluation with industry best practices.
+
+        Args:
+            model_interface: Model interface for inference
+            dataset: Dataset to evaluate on
+            dataset_name: Name of the dataset
+            model_name: Name of the model
+            batch_size: Batch size for processing
+            save_predictions: Whether to save individual predictions
+            progress_callback: Optional callback for progress updates
+
+        Returns:
+            Comprehensive evaluation results
+        """
+
+        # Initialize evaluation
+        self._start_evaluation()
+        result = EvaluationResult(
+            model_name=model_name,
+            dataset_name=dataset_name,
+            evaluation_type=self.evaluator_name,
+            config=self.config.copy(),
+            environment_info=self._get_environment_info()
+        )
+
+        try:
+            # Start experiment tracking
+            await self._start_experiment_tracking(model_name, dataset_name)
+
+            # Process dataset in batches
+            batch_size = batch_size or self.config.get("batch_size", 32)
+            total_batches = (len(dataset) + batch_size - 1) // batch_size
+
+            all_predictions = []
+            all_references = []
+            all_sample_results = []
+
+            for batch_idx in range(total_batches):
+                if self._should_stop:
+                    logger.info("Evaluation stopped by user request")
+                    break
+
+                # Get batch
+                start_idx = batch_idx * batch_size
+                end_idx = min(start_idx + batch_size, len(dataset))
+                batch = dataset[start_idx:end_idx]
+
+                # Process batch
+                batch_results = await self._process_batch(batch, model_interface)
+
+                # Collect results
+                for sample_result in batch_results:
+                    if sample_result.get("success", False):
+                        all_predictions.append(sample_result.get("prediction"))
+                        all_references.append(sample_result.get("reference"))
+                        result.successful_samples += 1
+                    else:
+                        result.failed_samples += 1
+                        result.errors.append({
+                            "sample_id": sample_result.get("sample_id"),
+                            "error": sample_result.get("error"),
+                            "timestamp": datetime.now().isoformat()
+                        })
+
+                    if save_predictions:
+                        all_sample_results.append(sample_result)
+
+                # Update progress
+                progress = (batch_idx + 1) / total_batches
+                if progress_callback:
+                    await progress_callback(progress, batch_idx + 1, total_batches)
+
+                # Log progress
+                if (batch_idx + 1) % 10 == 0 or batch_idx == total_batches - 1:
+                    logger.info(f"Processed {batch_idx + 1}/{total_batches} batches "
+                                f"({result.successful_samples} successful, {result.failed_samples} failed)")
+
+            # Compute final metrics
+            if all_predictions and all_references:
+                result.metrics = self.compute_metrics(all_predictions, all_references)
+                logger.info(f"Computed metrics: {result.metrics}")
+            else:
+                logger.warning("No valid predictions available for metric computation")
+                result.warnings.append("No valid predictions available for metric computation")
+
+            # Finalize results
+            result.predictions = all_predictions
+            result.references = all_references
+            result.sample_results = all_sample_results
+            result.total_samples = len(dataset)
+
+            # Log experiment results
+            await self._log_experiment_results(result)
+
+        except Exception as e:
+            logger.error(f"Evaluation failed: {e}")
+            logger.error(traceback.format_exc())
+            result.errors.append({
+                "error": str(e),
+                "error_type": type(e).__name__,
+                "traceback": traceback.format_exc(),
+                "timestamp": datetime.now().isoformat()
+            })
+
+        finally:
+            # Finalize evaluation
+            self._end_evaluation(result)
+            await self._end_experiment_tracking()
+            self._current_result = result
+
+        return result
+
+    async def _process_batch(self,
+                             batch: List[Dict[str, Any]],
+                             model_interface: Any) -> List[Dict[str, Any]]:
+        """Process a batch of samples with concurrency control."""
+        tasks = []
+
+        for sample in batch:
+            task = asyncio.create_task(
+                self._process_sample_with_retry(sample, model_interface)
+            )
+            tasks.append(task)
+
+        # Wait for all tasks in batch to complete
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+
+        # Process results and handle exceptions
+        processed_results = []
+        for i, result in enumerate(results):
+            if isinstance(result, Exception):
+                processed_results.append({
+                    "sample_id": batch[i].get("id", f"sample_{i}"),
+                    "success": False,
+                    "error": str(result),
+                    "prediction": None,
+                    "reference": batch[i].get("reference")
+                })
+            else:
+                processed_results.append(result)
+
+        return processed_results
+
+    async def _process_sample_with_retry(self,
+                                         sample: Dict[str, Any],
+                                         model_interface: Any) -> Dict[str, Any]:
+        """Process a single sample with retry logic and concurrency control."""
+        async with self.semaphore:  # Limit concurrent requests
+            for attempt in range(self.max_retries + 1):
+                try:
+                    result = await self.evaluate_sample(sample, model_interface)
+                    result["success"] = True
+                    result["sample_id"] = sample.get("id", "unknown")
+                    result["reference"] = sample.get("reference")
+                    return result
+
+                except Exception as e:
+                    if attempt == self.max_retries:
+                        # Final attempt failed
+                        logger.error(f"Sample evaluation failed after {self.max_retries + 1} attempts: {e}")
+                        return {
+                            "sample_id": sample.get("id", "unknown"),
+                            "success": False,
+                            "error": str(e),
+                            "prediction": None,
+                            "reference": sample.get("reference")
+                        }
+                    else:
+                        # Retry with exponential backoff
+                        delay = self.retry_delay * (2 ** attempt)
+                        logger.warning(f"Sample evaluation failed (attempt {attempt + 1}), retrying in {delay}s: {e}")
+                        await asyncio.sleep(delay)
+
+    def _start_evaluation(self) -> None:
+        """Mark the start of evaluation."""
+        self._is_running = True
+        self._should_stop = False
+        self._start_time = time.time()
+
+        # Monitor memory usage
+        try:
+            import psutil
+            process = psutil.Process()
+            self._peak_memory_mb = process.memory_info().rss / 1024 / 1024
+        except ImportError:
+            pass
+
+    def _end_evaluation(self, result: EvaluationResult) -> None:
+        """Finalize evaluation with performance metrics."""
+        self._is_running = False
+        end_time = time.time()
+
+        if self._start_time:
+            result.evaluation_time_seconds = end_time - self._start_time
+            if result.total_samples > 0:
+                result.throughput_samples_per_second = result.total_samples / result.evaluation_time_seconds
+
+        result.memory_peak_mb = self._peak_memory_mb
+
+        logger.info(f"Evaluation completed in {result.evaluation_time_seconds:.2f}s "
+                    f"({result.throughput_samples_per_second:.2f} samples/sec)")
+
+    def _get_environment_info(self) -> Dict[str, Any]:
+        """Get environment information for reproducibility."""
+        import platform
+        import sys
+
+        env_info = {
+            "python_version": sys.version,
+            "platform": platform.platform(),
+            "hostname": platform.node(),
+            "timestamp": datetime.now().isoformat()
+        }
+
+        try:
+            import torch
+            env_info["torch_version"] = torch.__version__
+            env_info["cuda_available"] = torch.cuda.is_available()
+            if torch.cuda.is_available():
+                env_info["cuda_device_count"] = torch.cuda.device_count()
+                env_info["cuda_device_name"] = torch.cuda.get_device_name()
+        except ImportError:
+            pass
+
+        return env_info
+
+    async def _start_experiment_tracking(self, model_name: str, dataset_name: str) -> None:
+        """Start experiment tracking if available."""
+        if self.experiment_tracker:
+            try:
+                await self.experiment_tracker.start_run(
+                    name=f"{self.evaluator_name}_{model_name}_{dataset_name}",
+                    config=self.config
+                )
+            except Exception as e:
+                logger.warning(f"Failed to start experiment tracking: {e}")
+
+    async def _log_experiment_results(self, result: EvaluationResult) -> None:
+        """Log results to experiment tracker."""
+        if self.experiment_tracker:
+            try:
+                await self.experiment_tracker.log_metrics(result.metrics)
+                await self.experiment_tracker.log_params(result.config)
+            except Exception as e:
+                logger.warning(f"Failed to log experiment results: {e}")
+
+    async def _end_experiment_tracking(self) -> None:
+        """End experiment tracking."""
+        if self.experiment_tracker:
+            try:
+                await self.experiment_tracker.end_run()
+            except Exception as e:
+                logger.warning(f"Failed to end experiment tracking: {e}")
+
+    def stop_evaluation(self) -> None:
+        """Request evaluation to stop gracefully."""
+        self._should_stop = True
+        logger.info("Evaluation stop requested")
+
+    def is_running(self) -> bool:
+        """Check if evaluation is currently running."""
+        return self._is_running
+
+    def get_current_result(self) -> Optional[EvaluationResult]:
+        """Get the current/latest evaluation result."""
+        return self._current_result
+
+    def get_supported_metrics(self) -> List[str]:
+        """Get list of metrics supported by this evaluator."""
+        return []  # To be overridden by subclasses