isa-model 0.4.0__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/client.py +466 -43
- isa_model/core/cache/redis_cache.py +12 -3
- isa_model/core/config/config_manager.py +230 -3
- isa_model/core/config.py +90 -0
- isa_model/core/database/direct_db_client.py +114 -0
- isa_model/core/database/migration_manager.py +563 -0
- isa_model/core/database/migrations.py +21 -1
- isa_model/core/database/supabase_client.py +154 -19
- isa_model/core/dependencies.py +316 -0
- isa_model/core/discovery/__init__.py +19 -0
- isa_model/core/discovery/consul_discovery.py +190 -0
- isa_model/core/logging/__init__.py +54 -0
- isa_model/core/logging/influx_logger.py +523 -0
- isa_model/core/logging/loki_logger.py +160 -0
- isa_model/core/models/__init__.py +27 -18
- isa_model/core/models/config_models.py +625 -0
- isa_model/core/models/deployment_billing_tracker.py +430 -0
- isa_model/core/models/model_manager.py +40 -17
- isa_model/core/models/model_metadata.py +690 -0
- isa_model/core/models/model_repo.py +174 -18
- isa_model/core/models/system_models.py +857 -0
- isa_model/core/repositories/__init__.py +9 -0
- isa_model/core/repositories/config_repository.py +912 -0
- isa_model/core/services/intelligent_model_selector.py +399 -21
- isa_model/core/storage/hf_storage.py +1 -1
- isa_model/core/types.py +1 -0
- isa_model/deployment/__init__.py +5 -48
- isa_model/deployment/core/__init__.py +2 -31
- isa_model/deployment/core/deployment_manager.py +1278 -370
- isa_model/deployment/local/__init__.py +31 -0
- isa_model/deployment/local/config.py +248 -0
- isa_model/deployment/local/gpu_gateway.py +607 -0
- isa_model/deployment/local/health_checker.py +428 -0
- isa_model/deployment/local/provider.py +586 -0
- isa_model/deployment/local/tensorrt_service.py +621 -0
- isa_model/deployment/local/transformers_service.py +644 -0
- isa_model/deployment/local/vllm_service.py +527 -0
- isa_model/deployment/modal/__init__.py +8 -0
- isa_model/deployment/modal/config.py +136 -0
- isa_model/deployment/{services/auto_hf_modal_deployer.py → modal/deployer.py} +1 -1
- isa_model/deployment/modal/services/__init__.py +3 -0
- isa_model/deployment/modal/services/audio/__init__.py +1 -0
- isa_model/deployment/modal/services/embedding/__init__.py +1 -0
- isa_model/deployment/modal/services/llm/__init__.py +1 -0
- isa_model/deployment/modal/services/llm/isa_llm_service.py +424 -0
- isa_model/deployment/modal/services/video/__init__.py +1 -0
- isa_model/deployment/modal/services/vision/__init__.py +1 -0
- isa_model/deployment/models/org-org-acme-corp-tenant-a-service-llm-20250825-225822/tenant-a-service_modal_service.py +48 -0
- isa_model/deployment/models/org-test-org-123-prefix-test-service-llm-20250825-225822/prefix-test-service_modal_service.py +48 -0
- isa_model/deployment/models/test-llm-service-llm-20250825-204442/test-llm-service_modal_service.py +48 -0
- isa_model/deployment/models/test-monitoring-gpt2-llm-20250825-212906/test-monitoring-gpt2_modal_service.py +48 -0
- isa_model/deployment/models/test-monitoring-gpt2-llm-20250825-213009/test-monitoring-gpt2_modal_service.py +48 -0
- isa_model/deployment/storage/__init__.py +5 -0
- isa_model/deployment/storage/deployment_repository.py +824 -0
- isa_model/deployment/triton/__init__.py +10 -0
- isa_model/deployment/triton/config.py +196 -0
- isa_model/deployment/triton/configs/__init__.py +1 -0
- isa_model/deployment/triton/provider.py +512 -0
- isa_model/deployment/triton/scripts/__init__.py +1 -0
- isa_model/deployment/triton/templates/__init__.py +1 -0
- isa_model/inference/__init__.py +47 -1
- isa_model/inference/ai_factory.py +137 -10
- isa_model/inference/legacy_services/__init__.py +21 -0
- isa_model/inference/legacy_services/model_evaluation.py +637 -0
- isa_model/inference/legacy_services/model_service.py +573 -0
- isa_model/inference/legacy_services/model_serving.py +717 -0
- isa_model/inference/legacy_services/model_training.py +561 -0
- isa_model/inference/models/__init__.py +21 -0
- isa_model/inference/models/inference_config.py +551 -0
- isa_model/inference/models/inference_record.py +675 -0
- isa_model/inference/models/performance_models.py +714 -0
- isa_model/inference/repositories/__init__.py +9 -0
- isa_model/inference/repositories/inference_repository.py +828 -0
- isa_model/inference/services/audio/base_stt_service.py +184 -11
- isa_model/inference/services/audio/openai_stt_service.py +22 -6
- isa_model/inference/services/custom_model_manager.py +277 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +15 -3
- isa_model/inference/services/embedding/resilient_embed_service.py +285 -0
- isa_model/inference/services/llm/__init__.py +10 -2
- isa_model/inference/services/llm/base_llm_service.py +335 -24
- isa_model/inference/services/llm/cerebras_llm_service.py +628 -0
- isa_model/inference/services/llm/helpers/llm_adapter.py +9 -4
- isa_model/inference/services/llm/helpers/llm_prompts.py +342 -0
- isa_model/inference/services/llm/helpers/llm_utils.py +321 -23
- isa_model/inference/services/llm/huggingface_llm_service.py +581 -0
- isa_model/inference/services/llm/local_llm_service.py +747 -0
- isa_model/inference/services/llm/ollama_llm_service.py +9 -2
- isa_model/inference/services/llm/openai_llm_service.py +33 -16
- isa_model/inference/services/llm/yyds_llm_service.py +8 -2
- isa_model/inference/services/vision/__init__.py +22 -1
- isa_model/inference/services/vision/blip_vision_service.py +359 -0
- isa_model/inference/services/vision/helpers/image_utils.py +8 -5
- isa_model/inference/services/vision/isa_vision_service.py +65 -4
- isa_model/inference/services/vision/openai_vision_service.py +19 -10
- isa_model/inference/services/vision/vgg16_vision_service.py +257 -0
- isa_model/serving/api/cache_manager.py +245 -0
- isa_model/serving/api/dependencies/__init__.py +1 -0
- isa_model/serving/api/dependencies/auth.py +194 -0
- isa_model/serving/api/dependencies/database.py +139 -0
- isa_model/serving/api/error_handlers.py +284 -0
- isa_model/serving/api/fastapi_server.py +172 -22
- isa_model/serving/api/middleware/auth.py +8 -2
- isa_model/serving/api/middleware/security.py +23 -33
- isa_model/serving/api/middleware/tenant_context.py +414 -0
- isa_model/serving/api/routes/analytics.py +4 -1
- isa_model/serving/api/routes/config.py +645 -0
- isa_model/serving/api/routes/deployment_billing.py +315 -0
- isa_model/serving/api/routes/deployments.py +138 -2
- isa_model/serving/api/routes/gpu_gateway.py +440 -0
- isa_model/serving/api/routes/health.py +32 -12
- isa_model/serving/api/routes/inference_monitoring.py +486 -0
- isa_model/serving/api/routes/local_deployments.py +448 -0
- isa_model/serving/api/routes/tenants.py +575 -0
- isa_model/serving/api/routes/unified.py +680 -18
- isa_model/serving/api/routes/webhooks.py +479 -0
- isa_model/serving/api/startup.py +68 -54
- isa_model/utils/gpu_utils.py +311 -0
- {isa_model-0.4.0.dist-info → isa_model-0.4.3.dist-info}/METADATA +66 -24
- isa_model-0.4.3.dist-info/RECORD +193 -0
- isa_model/core/storage/minio_storage.py +0 -0
- isa_model/deployment/cloud/__init__.py +0 -9
- isa_model/deployment/cloud/modal/__init__.py +0 -10
- isa_model/deployment/core/deployment_config.py +0 -356
- isa_model/deployment/core/isa_deployment_service.py +0 -401
- isa_model/deployment/gpu_int8_ds8/app/server.py +0 -66
- isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +0 -43
- isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +0 -35
- isa_model/deployment/runtime/deployed_service.py +0 -338
- isa_model/deployment/services/__init__.py +0 -9
- isa_model/deployment/services/auto_deploy_vision_service.py +0 -538
- isa_model/deployment/services/model_service.py +0 -332
- isa_model/deployment/services/service_monitor.py +0 -356
- isa_model/deployment/services/service_registry.py +0 -527
- isa_model/eval/__init__.py +0 -92
- isa_model/eval/benchmarks/__init__.py +0 -27
- isa_model/eval/benchmarks/multimodal_datasets.py +0 -460
- isa_model/eval/benchmarks.py +0 -701
- isa_model/eval/config/__init__.py +0 -10
- isa_model/eval/config/evaluation_config.py +0 -108
- isa_model/eval/evaluators/__init__.py +0 -24
- isa_model/eval/evaluators/audio_evaluator.py +0 -727
- isa_model/eval/evaluators/base_evaluator.py +0 -503
- isa_model/eval/evaluators/embedding_evaluator.py +0 -742
- isa_model/eval/evaluators/llm_evaluator.py +0 -472
- isa_model/eval/evaluators/vision_evaluator.py +0 -564
- isa_model/eval/example_evaluation.py +0 -395
- isa_model/eval/factory.py +0 -798
- isa_model/eval/infrastructure/__init__.py +0 -24
- isa_model/eval/infrastructure/experiment_tracker.py +0 -466
- isa_model/eval/isa_benchmarks.py +0 -700
- isa_model/eval/isa_integration.py +0 -582
- isa_model/eval/metrics.py +0 -951
- isa_model/eval/tests/unit/test_basic.py +0 -396
- isa_model/serving/api/routes/evaluations.py +0 -579
- isa_model/training/__init__.py +0 -168
- isa_model/training/annotation/annotation_schema.py +0 -47
- isa_model/training/annotation/processors/annotation_processor.py +0 -126
- isa_model/training/annotation/storage/dataset_manager.py +0 -131
- isa_model/training/annotation/storage/dataset_schema.py +0 -44
- isa_model/training/annotation/tests/test_annotation_flow.py +0 -109
- isa_model/training/annotation/tests/test_minio copy.py +0 -113
- isa_model/training/annotation/tests/test_minio_upload.py +0 -43
- isa_model/training/annotation/views/annotation_controller.py +0 -158
- isa_model/training/cloud/__init__.py +0 -22
- isa_model/training/cloud/job_orchestrator.py +0 -402
- isa_model/training/cloud/runpod_trainer.py +0 -454
- isa_model/training/cloud/storage_manager.py +0 -482
- isa_model/training/core/__init__.py +0 -26
- isa_model/training/core/config.py +0 -181
- isa_model/training/core/dataset.py +0 -222
- isa_model/training/core/trainer.py +0 -720
- isa_model/training/core/utils.py +0 -213
- isa_model/training/examples/intelligent_training_example.py +0 -281
- isa_model/training/factory.py +0 -424
- isa_model/training/intelligent/__init__.py +0 -25
- isa_model/training/intelligent/decision_engine.py +0 -643
- isa_model/training/intelligent/intelligent_factory.py +0 -888
- isa_model/training/intelligent/knowledge_base.py +0 -751
- isa_model/training/intelligent/resource_optimizer.py +0 -839
- isa_model/training/intelligent/task_classifier.py +0 -576
- isa_model/training/storage/__init__.py +0 -24
- isa_model/training/storage/core_integration.py +0 -439
- isa_model/training/storage/training_repository.py +0 -552
- isa_model/training/storage/training_storage.py +0 -628
- isa_model-0.4.0.dist-info/RECORD +0 -182
- /isa_model/deployment/{cloud/modal → modal/services/audio}/isa_audio_chatTTS_service.py +0 -0
- /isa_model/deployment/{cloud/modal → modal/services/audio}/isa_audio_fish_service.py +0 -0
- /isa_model/deployment/{cloud/modal → modal/services/audio}/isa_audio_openvoice_service.py +0 -0
- /isa_model/deployment/{cloud/modal → modal/services/audio}/isa_audio_service_v2.py +0 -0
- /isa_model/deployment/{cloud/modal → modal/services/embedding}/isa_embed_rerank_service.py +0 -0
- /isa_model/deployment/{cloud/modal → modal/services/video}/isa_video_hunyuan_service.py +0 -0
- /isa_model/deployment/{cloud/modal → modal/services/vision}/isa_vision_ocr_service.py +0 -0
- /isa_model/deployment/{cloud/modal → modal/services/vision}/isa_vision_qwen25_service.py +0 -0
- /isa_model/deployment/{cloud/modal → modal/services/vision}/isa_vision_table_service.py +0 -0
- /isa_model/deployment/{cloud/modal → modal/services/vision}/isa_vision_ui_service.py +0 -0
- /isa_model/deployment/{cloud/modal → modal/services/vision}/isa_vision_ui_service_optimized.py +0 -0
- /isa_model/deployment/{services → modal/services/vision}/simple_auto_deploy_vision_service.py +0 -0
- {isa_model-0.4.0.dist-info → isa_model-0.4.3.dist-info}/WHEEL +0 -0
- {isa_model-0.4.0.dist-info → isa_model-0.4.3.dist-info}/top_level.txt +0 -0
--- a/isa_model/eval/evaluators/audio_evaluator.py
+++ /dev/null
@@ -1,727 +0,0 @@
-"""
-Audio Evaluator for ISA Model evaluation framework.
-
-Provides comprehensive evaluation capabilities for audio tasks including:
-- Speech-to-Text (STT) evaluation with WER/CER metrics
-- Speaker diarization evaluation
-- Emotion recognition evaluation
-- Voice activity detection evaluation
-- Speech enhancement evaluation
-- Text-to-Speech (TTS) quality evaluation
-
-Supports ISA custom audio services and standard audio models.
-"""
-
-import asyncio
-import logging
-import librosa
-import numpy as np
-import re
-from typing import Dict, List, Any, Optional, Union, Tuple
-from pathlib import Path
-import tempfile
-import wave
-
-from .base_evaluator import BaseEvaluator, EvaluationResult
-from ..metrics import compute_text_metrics
-
-logger = logging.getLogger(__name__)
-
-
-class AudioEvaluator(BaseEvaluator):
-    """
-    Comprehensive audio model evaluator.
-
-    Supports evaluation of:
-    - Speech-to-Text accuracy (WER, CER, BLEU)
-    - Speaker diarization accuracy (DER, Speaker F1)
-    - Emotion recognition accuracy
-    - Voice activity detection (Precision, Recall, F1)
-    - Speech enhancement quality (SNR, PESQ, STOI)
-    - Text-to-Speech naturalness and intelligibility
-    """
-
-    def __init__(self,
-                 config: Optional[Dict[str, Any]] = None,
-                 experiment_tracker: Optional[Any] = None):
-        """
-        Initialize the audio evaluator.
-
-        Args:
-            config: Evaluation configuration
-            experiment_tracker: Optional experiment tracking instance
-        """
-        super().__init__(
-            evaluator_name="audio_evaluator",
-            config=config,
-            experiment_tracker=experiment_tracker
-        )
-
-        # Audio-specific configuration
-        self.sample_rate = self.config.get("sample_rate", 16000)
-        self.supported_formats = self.config.get("supported_formats", ["wav", "mp3", "flac", "m4a"])
-        self.max_duration = self.config.get("max_duration_seconds", 300)  # 5 minutes
-
-        # Evaluation task types
-        self.task_type = self.config.get("task_type", "stt")  # stt, diarization, emotion, tts, enhancement
-
-        # STT evaluation settings
-        self.normalize_text = self.config.get("normalize_text", True)
-        self.case_sensitive = self.config.get("case_sensitive", False)
-        self.remove_punctuation = self.config.get("remove_punctuation", True)
-
-        # Speaker diarization settings
-        self.collar_tolerance = self.config.get("collar_tolerance", 0.25)  # 250ms tolerance
-
-        logger.info(f"Initialized AudioEvaluator for task: {self.task_type}")
-
-    async def evaluate_sample(self,
-                              sample: Dict[str, Any],
-                              model_interface: Any) -> Dict[str, Any]:
-        """
-        Evaluate a single audio sample.
-
-        Args:
-            sample: Audio sample containing audio data and expected output
-            model_interface: Audio model interface
-
-        Returns:
-            Evaluation result for the sample
-        """
-        try:
-            # Extract sample data
-            audio_data = sample.get("audio")
-            expected_output = sample.get("expected_output", "")
-            task_type = sample.get("task_type", self.task_type)
-            metadata = sample.get("metadata", {})
-
-            # Process audio
-            processed_audio, audio_info = await self._process_audio(audio_data)
-
-            # Get model prediction based on task type
-            prediction = await self._get_model_prediction(
-                model_interface, processed_audio, task_type, metadata
-            )
-
-            # Compute sample-level metrics
-            sample_metrics = self._compute_sample_metrics(
-                prediction, expected_output, task_type, audio_info
-            )
-
-            return {
-                "prediction": prediction,
-                "expected_output": expected_output,
-                "task_type": task_type,
-                "sample_metrics": sample_metrics,
-                "audio_info": audio_info,
-                "metadata": metadata
-            }
-
-        except Exception as e:
-            logger.error(f"Error evaluating audio sample: {e}")
-            raise
-
-    async def _process_audio(self, audio_data: Union[str, bytes, np.ndarray, Path]) -> Tuple[np.ndarray, Dict[str, Any]]:
-        """
-        Process and validate audio data.
-
-        Args:
-            audio_data: Audio in various formats
-
-        Returns:
-            Tuple of (processed audio array, audio info dict)
-        """
-        try:
-            if isinstance(audio_data, str) and Path(audio_data).exists():
-                # File path
-                audio_array, sr = librosa.load(audio_data, sr=self.sample_rate)
-                original_sr = librosa.get_samplerate(audio_data)
-
-            elif isinstance(audio_data, Path):
-                # Path object
-                audio_array, sr = librosa.load(str(audio_data), sr=self.sample_rate)
-                original_sr = librosa.get_samplerate(str(audio_data))
-
-            elif isinstance(audio_data, bytes):
-                # Raw audio bytes - save to temp file first
-                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
-                    tmp_file.write(audio_data)
-                    tmp_path = tmp_file.name
-
-                audio_array, sr = librosa.load(tmp_path, sr=self.sample_rate)
-                original_sr = self.sample_rate  # Assume target sample rate
-                Path(tmp_path).unlink()  # Clean up temp file
-
-            elif isinstance(audio_data, np.ndarray):
-                # NumPy array
-                audio_array = audio_data
-                sr = self.sample_rate
-                original_sr = self.sample_rate
-
-                # Resample if needed
-                if len(audio_array.shape) > 1:
-                    audio_array = librosa.to_mono(audio_array)
-
-            else:
-                raise ValueError(f"Unsupported audio data type: {type(audio_data)}")
-
-            # Validate duration
-            duration = len(audio_array) / sr
-            if duration > self.max_duration:
-                logger.warning(f"Audio duration {duration:.2f}s exceeds max duration {self.max_duration}s, truncating")
-                max_samples = int(self.max_duration * sr)
-                audio_array = audio_array[:max_samples]
-                duration = self.max_duration
-
-            # Compute audio features
-            audio_info = {
-                "duration_seconds": duration,
-                "sample_rate": sr,
-                "original_sample_rate": original_sr,
-                "num_samples": len(audio_array),
-                "rms_energy": float(np.sqrt(np.mean(audio_array**2))),
-                "zero_crossing_rate": float(np.mean(librosa.feature.zero_crossing_rate(audio_array))),
-                "spectral_centroid": float(np.mean(librosa.feature.spectral_centroid(y=audio_array, sr=sr)))
-            }
-
-            return audio_array, audio_info
-
-        except Exception as e:
-            logger.error(f"Error processing audio: {e}")
-            raise
-
-    async def _get_model_prediction(self,
-                                    model_interface: Any,
-                                    audio: np.ndarray,
-                                    task_type: str,
-                                    metadata: Dict[str, Any]) -> Union[str, Dict[str, Any]]:
-        """
-        Get model prediction for audio task.
-
-        Args:
-            model_interface: Audio model interface
-            audio: Processed audio array
-            task_type: Type of audio task
-            metadata: Additional metadata
-
-        Returns:
-            Model prediction (string for STT, dict for complex tasks)
-        """
-        try:
-            # Save audio to temporary file for model processing
-            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
-                # Convert to int16 for wav format
-                audio_int16 = (audio * 32767).astype(np.int16)
-
-                # Write WAV file
-                with wave.open(tmp_file.name, 'wb') as wav_file:
-                    wav_file.setnchannels(1)  # Mono
-                    wav_file.setsampwidth(2)  # 16-bit
-                    wav_file.setframerate(self.sample_rate)
-                    wav_file.writeframes(audio_int16.tobytes())
-
-                tmp_path = tmp_file.name
-
-            try:
-                if task_type == "stt":
-                    # Speech-to-Text
-                    if hasattr(model_interface, 'transcribe'):
-                        result = await model_interface.transcribe(tmp_path)
-                        prediction = result.get("text", "") if isinstance(result, dict) else str(result)
-                    elif hasattr(model_interface, 'stt'):
-                        prediction = await model_interface.stt(tmp_path)
-                    else:
-                        prediction = await model_interface.predict(tmp_path)
-
-                    return str(prediction).strip()
-
-                elif task_type == "diarization":
-                    # Speaker diarization
-                    if hasattr(model_interface, 'diarize'):
-                        result = await model_interface.diarize(tmp_path)
-                    else:
-                        result = await model_interface.predict(tmp_path, task="diarization")
-
-                    return result  # Should be dict with speaker segments
-
-                elif task_type == "emotion":
-                    # Emotion recognition
-                    if hasattr(model_interface, 'detect_emotion'):
-                        result = await model_interface.detect_emotion(tmp_path)
-                    else:
-                        result = await model_interface.predict(tmp_path, task="emotion")
-
-                    return result  # Should be emotion label or dict
-
-                elif task_type == "tts":
-                    # Text-to-Speech (reverse evaluation)
-                    text_input = metadata.get("text_input", "")
-                    if hasattr(model_interface, 'synthesize'):
-                        result = await model_interface.synthesize(text_input)
-                    else:
-                        result = await model_interface.predict(text_input, task="tts")
-
-                    return result  # Should be synthesized audio
-
-                else:
-                    # Generic prediction
-                    prediction = await model_interface.predict(tmp_path)
-                    return prediction
-
-            finally:
-                # Clean up temporary file
-                Path(tmp_path).unlink()
-
-        except Exception as e:
-            logger.error(f"Error getting model prediction: {e}")
-            raise
-
-    def _compute_sample_metrics(self,
-                                prediction: Union[str, Dict[str, Any]],
-                                expected_output: Union[str, Dict[str, Any]],
-                                task_type: str,
-                                audio_info: Dict[str, Any]) -> Dict[str, float]:
-        """
-        Compute metrics for a single sample.
-
-        Args:
-            prediction: Model prediction
-            expected_output: Expected/reference output
-            task_type: Type of audio task
-            audio_info: Audio metadata
-
-        Returns:
-            Dictionary of sample-level metrics
-        """
-        try:
-            metrics = {}
-
-            if task_type == "stt":
-                # Speech-to-Text metrics
-                if isinstance(prediction, str) and isinstance(expected_output, str):
-                    metrics.update(self._compute_stt_metrics(prediction, expected_output))
-
-            elif task_type == "diarization":
-                # Speaker diarization metrics
-                metrics.update(self._compute_diarization_metrics(prediction, expected_output))
-
-            elif task_type == "emotion":
-                # Emotion recognition metrics
-                metrics.update(self._compute_emotion_metrics(prediction, expected_output))
-
-            elif task_type == "tts":
-                # TTS quality metrics
-                metrics.update(self._compute_tts_metrics(prediction, expected_output, audio_info))
-
-            # Add audio metadata
-            metrics.update({
-                "audio_duration": audio_info.get("duration_seconds", 0.0),
-                "audio_quality_score": self._compute_audio_quality_score(audio_info)
-            })
-
-            return metrics
-
-        except Exception as e:
-            logger.error(f"Error computing sample metrics: {e}")
-            return {"error": 1.0}
-
-    def _compute_stt_metrics(self, prediction: str, reference: str) -> Dict[str, float]:
-        """Compute Speech-to-Text specific metrics."""
-        try:
-            # Normalize text if configured
-            if self.normalize_text:
-                prediction = self._normalize_text(prediction)
-                reference = self._normalize_text(reference)
-
-            # Word Error Rate (WER)
-            wer = self._compute_wer(prediction, reference)
-
-            # Character Error Rate (CER)
-            cer = self._compute_cer(prediction, reference)
-
-            # Additional text metrics
-            text_metrics = compute_text_metrics(prediction, reference)
-
-            return {
-                "wer": wer,
-                "cer": cer,
-                "word_accuracy": 1.0 - wer,
-                "char_accuracy": 1.0 - cer,
-                **text_metrics
-            }
-
-        except Exception as e:
-            logger.error(f"Error computing STT metrics: {e}")
-            return {"stt_error": 1.0}
-
-    def _compute_wer(self, prediction: str, reference: str) -> float:
-        """Compute Word Error Rate."""
-        try:
-            pred_words = prediction.strip().split()
-            ref_words = reference.strip().split()
-
-            if not ref_words:
-                return 0.0 if not pred_words else 1.0
-
-            # Compute edit distance
-            distance = self._edit_distance(pred_words, ref_words)
-            wer = distance / len(ref_words)
-
-            return min(1.0, wer)  # Cap at 100%
-
-        except Exception as e:
-            logger.error(f"Error computing WER: {e}")
-            return 1.0
-
-    def _compute_cer(self, prediction: str, reference: str) -> float:
-        """Compute Character Error Rate."""
-        try:
-            pred_chars = list(prediction.strip())
-            ref_chars = list(reference.strip())
-
-            if not ref_chars:
-                return 0.0 if not pred_chars else 1.0
-
-            # Compute edit distance
-            distance = self._edit_distance(pred_chars, ref_chars)
-            cer = distance / len(ref_chars)
-
-            return min(1.0, cer)  # Cap at 100%
-
-        except Exception as e:
-            logger.error(f"Error computing CER: {e}")
-            return 1.0
-
-    def _edit_distance(self, seq1: List[str], seq2: List[str]) -> int:
-        """Compute Levenshtein edit distance."""
-        m, n = len(seq1), len(seq2)
-        dp = [[0] * (n + 1) for _ in range(m + 1)]
-
-        # Initialize base cases
-        for i in range(m + 1):
-            dp[i][0] = i
-        for j in range(n + 1):
-            dp[0][j] = j
-
-        # Fill the DP table
-        for i in range(1, m + 1):
-            for j in range(1, n + 1):
-                if seq1[i-1] == seq2[j-1]:
-                    dp[i][j] = dp[i-1][j-1]
-                else:
-                    dp[i][j] = 1 + min(
-                        dp[i-1][j],    # deletion
-                        dp[i][j-1],    # insertion
-                        dp[i-1][j-1]   # substitution
-                    )
-
-        return dp[m][n]
-
-    def _normalize_text(self, text: str) -> str:
-        """Normalize text for evaluation."""
-        # Convert to lowercase if not case sensitive
-        if not self.case_sensitive:
-            text = text.lower()
-
-        # Remove punctuation if configured
-        if self.remove_punctuation:
-            text = re.sub(r'[^\w\s]', ' ', text)
-
-        # Normalize whitespace
-        text = re.sub(r'\s+', ' ', text).strip()
-
-        return text
-
-    def _compute_diarization_metrics(self,
-                                     prediction: Dict[str, Any],
-                                     reference: Dict[str, Any]) -> Dict[str, float]:
-        """Compute speaker diarization metrics."""
-        try:
-            # Extract speaker segments
-            pred_segments = prediction.get("segments", []) if isinstance(prediction, dict) else []
-            ref_segments = reference.get("segments", []) if isinstance(reference, dict) else []
-
-            if not ref_segments:
-                return {"diarization_error": 1.0}
-
-            # Compute Diarization Error Rate (DER)
-            der = self._compute_der(pred_segments, ref_segments)
-
-            # Compute Speaker F1 score
-            speaker_f1 = self._compute_speaker_f1(pred_segments, ref_segments)
-
-            return {
-                "diarization_error_rate": der,
-                "speaker_f1_score": speaker_f1,
-                "num_predicted_speakers": len(set(seg.get("speaker", "") for seg in pred_segments)),
-                "num_reference_speakers": len(set(seg.get("speaker", "") for seg in ref_segments))
-            }
-
-        except Exception as e:
-            logger.error(f"Error computing diarization metrics: {e}")
-            return {"diarization_error": 1.0}
-
-    def _compute_der(self, pred_segments: List[Dict], ref_segments: List[Dict]) -> float:
-        """Compute Diarization Error Rate."""
-        try:
-            # This is a simplified DER computation
-            # In practice, you'd use specialized libraries like pyannote.metrics
-
-            total_time = 0.0
-            error_time = 0.0
-
-            # Find overall time range
-            all_segments = pred_segments + ref_segments
-            if not all_segments:
-                return 0.0
-
-            start_time = min(seg.get("start", 0.0) for seg in all_segments)
-            end_time = max(seg.get("end", 0.0) for seg in all_segments)
-            total_time = end_time - start_time
-
-            if total_time <= 0:
-                return 0.0
-
-            # Sample time points and check for errors
-            time_step = 0.1  # 100ms resolution
-            num_steps = int(total_time / time_step)
-
-            for i in range(num_steps):
-                t = start_time + i * time_step
-
-                # Find speakers at time t
-                pred_speaker = self._get_speaker_at_time(t, pred_segments)
-                ref_speaker = self._get_speaker_at_time(t, ref_segments)
-
-                if pred_speaker != ref_speaker:
-                    error_time += time_step
-
-            der = error_time / total_time if total_time > 0 else 0.0
-            return min(1.0, der)
-
-        except Exception as e:
-            logger.error(f"Error computing DER: {e}")
-            return 1.0
-
-    def _get_speaker_at_time(self, time: float, segments: List[Dict]) -> Optional[str]:
-        """Get the speaker at a specific time point."""
-        for segment in segments:
-            start = segment.get("start", 0.0)
-            end = segment.get("end", 0.0)
-            if start <= time < end:
-                return segment.get("speaker")
-        return None
-
-    def _compute_speaker_f1(self, pred_segments: List[Dict], ref_segments: List[Dict]) -> float:
-        """Compute Speaker F1 score."""
-        try:
-            # Extract unique speakers
-            pred_speakers = set(seg.get("speaker", "") for seg in pred_segments)
-            ref_speakers = set(seg.get("speaker", "") for seg in ref_segments)
-
-            pred_speakers.discard("")  # Remove empty speakers
-            ref_speakers.discard("")
-
-            if not ref_speakers:
-                return 1.0 if not pred_speakers else 0.0
-
-            # Simple speaker overlap metric
-            intersection = len(pred_speakers.intersection(ref_speakers))
-            precision = intersection / len(pred_speakers) if pred_speakers else 0.0
-            recall = intersection / len(ref_speakers) if ref_speakers else 0.0
-
-            if precision + recall == 0:
-                return 0.0
-
-            f1 = 2 * precision * recall / (precision + recall)
-            return f1
-
-        except Exception as e:
-            logger.error(f"Error computing speaker F1: {e}")
-            return 0.0
-
-    def _compute_emotion_metrics(self,
-                                 prediction: Union[str, Dict[str, Any]],
-                                 reference: Union[str, Dict[str, Any]]) -> Dict[str, float]:
-        """Compute emotion recognition metrics."""
-        try:
-            # Extract emotion labels
-            if isinstance(prediction, dict):
-                pred_emotion = prediction.get("emotion", "")
-                pred_confidence = prediction.get("confidence", 0.0)
-            else:
-                pred_emotion = str(prediction)
-                pred_confidence = 1.0
-
-            if isinstance(reference, dict):
-                ref_emotion = reference.get("emotion", "")
-            else:
-                ref_emotion = str(reference)
-
-            # Compute accuracy
-            emotion_accuracy = 1.0 if pred_emotion.lower() == ref_emotion.lower() else 0.0
-
-            return {
-                "emotion_accuracy": emotion_accuracy,
-                "emotion_confidence": pred_confidence,
-                "predicted_emotion": pred_emotion,
-                "reference_emotion": ref_emotion
-            }
-
-        except Exception as e:
-            logger.error(f"Error computing emotion metrics: {e}")
-            return {"emotion_error": 1.0}
-
-    def _compute_tts_metrics(self,
-                             prediction: Any,
-                             reference: Any,
-                             audio_info: Dict[str, Any]) -> Dict[str, float]:
-        """Compute Text-to-Speech quality metrics."""
-        try:
-            # This is a placeholder for TTS evaluation
-            # In practice, you'd use specialized metrics like MOS, PESQ, STOI
-
-            return {
-                "tts_quality_score": 0.8,  # Placeholder
-                "synthesis_success": 1.0 if prediction is not None else 0.0
-            }
-
-        except Exception as e:
-            logger.error(f"Error computing TTS metrics: {e}")
-            return {"tts_error": 1.0}
-
-    def _compute_audio_quality_score(self, audio_info: Dict[str, Any]) -> float:
-        """Compute a simple audio quality score based on audio features."""
-        try:
-            # Simple heuristic based on RMS energy and other features
-            rms_energy = audio_info.get("rms_energy", 0.0)
-            duration = audio_info.get("duration_seconds", 0.0)
-
-            # Normalize RMS energy (assuming good audio is in range 0.01-0.1)
-            energy_score = min(1.0, max(0.0, (rms_energy - 0.001) / 0.1))
-
-            # Duration score (prefer reasonable durations)
-            duration_score = 1.0 if 1.0 <= duration <= 60.0 else 0.5
-
-            quality_score = (energy_score + duration_score) / 2
-            return quality_score
-
-        except Exception as e:
-            logger.error(f"Error computing audio quality score: {e}")
-            return 0.5
-
-    def compute_metrics(self,
-                        predictions: List[Any],
-                        references: List[Any],
-                        **kwargs) -> Dict[str, float]:
-        """
-        Compute aggregate audio evaluation metrics.
-
-        Args:
-            predictions: List of model predictions
-            references: List of reference outputs
-            **kwargs: Additional parameters
-
-        Returns:
-            Dictionary of computed metrics
-        """
-        try:
-            if not predictions or not references:
-                logger.warning("Empty predictions or references provided")
-                return {}
-
-            # Ensure equal lengths
-            min_len = min(len(predictions), len(references))
-            predictions = predictions[:min_len]
-            references = references[:min_len]
-
-            task_type = self.task_type
-
-            if task_type == "stt":
-                return self._compute_aggregate_stt_metrics(predictions, references)
-            elif task_type == "diarization":
-                return self._compute_aggregate_diarization_metrics(predictions, references)
-            elif task_type == "emotion":
-                return self._compute_aggregate_emotion_metrics(predictions, references)
-            else:
-                # Generic metrics
-                return {
-                    "total_samples": len(predictions),
-                    "task_type": task_type,
-                    "evaluation_success_rate": 1.0
-                }
-
-        except Exception as e:
-            logger.error(f"Error computing aggregate metrics: {e}")
-            return {"error_rate": 1.0}
-
-    def _compute_aggregate_stt_metrics(self,
-                                       predictions: List[str],
-                                       references: List[str]) -> Dict[str, float]:
-        """Compute aggregate STT metrics."""
-        wer_scores = []
-        cer_scores = []
-
-        for pred, ref in zip(predictions, references):
-            if isinstance(pred, str) and isinstance(ref, str):
-                sample_metrics = self._compute_stt_metrics(pred, ref)
-                wer_scores.append(sample_metrics.get("wer", 1.0))
-                cer_scores.append(sample_metrics.get("cer", 1.0))
-
-        return {
-            "avg_wer": np.mean(wer_scores) if wer_scores else 1.0,
-            "avg_cer": np.mean(cer_scores) if cer_scores else 1.0,
-            "avg_word_accuracy": 1.0 - np.mean(wer_scores) if wer_scores else 0.0,
-            "avg_char_accuracy": 1.0 - np.mean(cer_scores) if cer_scores else 0.0,
-            "total_samples": len(predictions)
-        }
-
-    def _compute_aggregate_diarization_metrics(self,
-                                               predictions: List[Dict],
-                                               references: List[Dict]) -> Dict[str, float]:
-        """Compute aggregate diarization metrics."""
-        der_scores = []
-        f1_scores = []
-
-        for pred, ref in zip(predictions, references):
-            if isinstance(pred, dict) and isinstance(ref, dict):
-                sample_metrics = self._compute_diarization_metrics(pred, ref)
-                der_scores.append(sample_metrics.get("diarization_error_rate", 1.0))
-                f1_scores.append(sample_metrics.get("speaker_f1_score", 0.0))
-
-        return {
-            "avg_diarization_error_rate": np.mean(der_scores) if der_scores else 1.0,
-            "avg_speaker_f1_score": np.mean(f1_scores) if f1_scores else 0.0,
-            "total_samples": len(predictions)
-        }
-
-    def _compute_aggregate_emotion_metrics(self,
-                                           predictions: List[Any],
-                                           references: List[Any]) -> Dict[str, float]:
-        """Compute aggregate emotion recognition metrics."""
-        accuracies = []
-        confidences = []
-
-        for pred, ref in zip(predictions, references):
-            sample_metrics = self._compute_emotion_metrics(pred, ref)
-            accuracies.append(sample_metrics.get("emotion_accuracy", 0.0))
-            confidences.append(sample_metrics.get("emotion_confidence", 0.0))
-
-        return {
-            "avg_emotion_accuracy": np.mean(accuracies) if accuracies else 0.0,
-            "avg_confidence": np.mean(confidences) if confidences else 0.0,
-            "total_samples": len(predictions)
-        }
-
-    def get_supported_metrics(self) -> List[str]:
-        """Get list of metrics supported by this evaluator."""
-        base_metrics = ["total_samples", "evaluation_success_rate"]
-
-        task_specific_metrics = {
-            "stt": ["wer", "cer", "word_accuracy", "char_accuracy", "bleu_score", "rouge_l"],
-            "diarization": ["diarization_error_rate", "speaker_f1_score"],
-            "emotion": ["emotion_accuracy", "emotion_confidence"],
-            "tts": ["tts_quality_score", "synthesis_success"]
-        }
-
-        return base_metrics + task_specific_metrics.get(self.task_type, [])