isa-model 0.3.9__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. isa_model/__init__.py +1 -1
  2. isa_model/client.py +732 -565
  3. isa_model/core/cache/redis_cache.py +401 -0
  4. isa_model/core/config/config_manager.py +53 -10
  5. isa_model/core/config.py +1 -1
  6. isa_model/core/database/__init__.py +1 -0
  7. isa_model/core/database/migrations.py +277 -0
  8. isa_model/core/database/supabase_client.py +123 -0
  9. isa_model/core/models/__init__.py +37 -0
  10. isa_model/core/models/model_billing_tracker.py +60 -88
  11. isa_model/core/models/model_manager.py +36 -18
  12. isa_model/core/models/model_repo.py +44 -38
  13. isa_model/core/models/model_statistics_tracker.py +234 -0
  14. isa_model/core/models/model_storage.py +0 -1
  15. isa_model/core/models/model_version_manager.py +959 -0
  16. isa_model/core/pricing_manager.py +2 -249
  17. isa_model/core/resilience/circuit_breaker.py +366 -0
  18. isa_model/core/security/secrets.py +358 -0
  19. isa_model/core/services/__init__.py +2 -4
  20. isa_model/core/services/intelligent_model_selector.py +101 -370
  21. isa_model/core/storage/hf_storage.py +1 -1
  22. isa_model/core/types.py +7 -0
  23. isa_model/deployment/cloud/modal/isa_audio_chatTTS_service.py +520 -0
  24. isa_model/deployment/cloud/modal/isa_audio_fish_service.py +0 -0
  25. isa_model/deployment/cloud/modal/isa_audio_openvoice_service.py +758 -0
  26. isa_model/deployment/cloud/modal/isa_audio_service_v2.py +1044 -0
  27. isa_model/deployment/cloud/modal/isa_embed_rerank_service.py +296 -0
  28. isa_model/deployment/cloud/modal/isa_video_hunyuan_service.py +423 -0
  29. isa_model/deployment/cloud/modal/isa_vision_ocr_service.py +519 -0
  30. isa_model/deployment/cloud/modal/isa_vision_qwen25_service.py +709 -0
  31. isa_model/deployment/cloud/modal/isa_vision_table_service.py +467 -323
  32. isa_model/deployment/cloud/modal/isa_vision_ui_service.py +607 -180
  33. isa_model/deployment/cloud/modal/isa_vision_ui_service_optimized.py +660 -0
  34. isa_model/deployment/core/deployment_manager.py +6 -4
  35. isa_model/deployment/services/auto_hf_modal_deployer.py +894 -0
  36. isa_model/eval/benchmarks/__init__.py +27 -0
  37. isa_model/eval/benchmarks/multimodal_datasets.py +460 -0
  38. isa_model/eval/benchmarks.py +244 -12
  39. isa_model/eval/evaluators/__init__.py +8 -2
  40. isa_model/eval/evaluators/audio_evaluator.py +727 -0
  41. isa_model/eval/evaluators/embedding_evaluator.py +742 -0
  42. isa_model/eval/evaluators/vision_evaluator.py +564 -0
  43. isa_model/eval/example_evaluation.py +395 -0
  44. isa_model/eval/factory.py +272 -5
  45. isa_model/eval/isa_benchmarks.py +700 -0
  46. isa_model/eval/isa_integration.py +582 -0
  47. isa_model/eval/metrics.py +159 -6
  48. isa_model/eval/tests/unit/test_basic.py +396 -0
  49. isa_model/inference/ai_factory.py +44 -8
  50. isa_model/inference/services/audio/__init__.py +21 -0
  51. isa_model/inference/services/audio/base_realtime_service.py +225 -0
  52. isa_model/inference/services/audio/isa_tts_service.py +0 -0
  53. isa_model/inference/services/audio/openai_realtime_service.py +320 -124
  54. isa_model/inference/services/audio/openai_stt_service.py +32 -6
  55. isa_model/inference/services/base_service.py +17 -1
  56. isa_model/inference/services/embedding/__init__.py +13 -0
  57. isa_model/inference/services/embedding/base_embed_service.py +111 -8
  58. isa_model/inference/services/embedding/isa_embed_service.py +305 -0
  59. isa_model/inference/services/embedding/openai_embed_service.py +2 -4
  60. isa_model/inference/services/embedding/tests/test_embedding.py +222 -0
  61. isa_model/inference/services/img/__init__.py +2 -2
  62. isa_model/inference/services/img/base_image_gen_service.py +24 -7
  63. isa_model/inference/services/img/replicate_image_gen_service.py +84 -422
  64. isa_model/inference/services/img/services/replicate_face_swap.py +193 -0
  65. isa_model/inference/services/img/services/replicate_flux.py +226 -0
  66. isa_model/inference/services/img/services/replicate_flux_kontext.py +219 -0
  67. isa_model/inference/services/img/services/replicate_sticker_maker.py +249 -0
  68. isa_model/inference/services/img/tests/test_img_client.py +297 -0
  69. isa_model/inference/services/llm/base_llm_service.py +30 -6
  70. isa_model/inference/services/llm/helpers/llm_adapter.py +63 -9
  71. isa_model/inference/services/llm/ollama_llm_service.py +2 -1
  72. isa_model/inference/services/llm/openai_llm_service.py +652 -55
  73. isa_model/inference/services/llm/yyds_llm_service.py +2 -1
  74. isa_model/inference/services/vision/__init__.py +5 -5
  75. isa_model/inference/services/vision/base_vision_service.py +118 -185
  76. isa_model/inference/services/vision/helpers/image_utils.py +11 -5
  77. isa_model/inference/services/vision/isa_vision_service.py +573 -0
  78. isa_model/inference/services/vision/tests/test_ocr_client.py +284 -0
  79. isa_model/serving/api/fastapi_server.py +88 -16
  80. isa_model/serving/api/middleware/auth.py +311 -0
  81. isa_model/serving/api/middleware/security.py +278 -0
  82. isa_model/serving/api/routes/analytics.py +486 -0
  83. isa_model/serving/api/routes/deployments.py +339 -0
  84. isa_model/serving/api/routes/evaluations.py +579 -0
  85. isa_model/serving/api/routes/logs.py +430 -0
  86. isa_model/serving/api/routes/settings.py +582 -0
  87. isa_model/serving/api/routes/unified.py +324 -165
  88. isa_model/serving/api/startup.py +304 -0
  89. isa_model/serving/modal_proxy_server.py +249 -0
  90. isa_model/training/__init__.py +100 -6
  91. isa_model/training/core/__init__.py +4 -1
  92. isa_model/training/examples/intelligent_training_example.py +281 -0
  93. isa_model/training/intelligent/__init__.py +25 -0
  94. isa_model/training/intelligent/decision_engine.py +643 -0
  95. isa_model/training/intelligent/intelligent_factory.py +888 -0
  96. isa_model/training/intelligent/knowledge_base.py +751 -0
  97. isa_model/training/intelligent/resource_optimizer.py +839 -0
  98. isa_model/training/intelligent/task_classifier.py +576 -0
  99. isa_model/training/storage/__init__.py +24 -0
  100. isa_model/training/storage/core_integration.py +439 -0
  101. isa_model/training/storage/training_repository.py +552 -0
  102. isa_model/training/storage/training_storage.py +628 -0
  103. {isa_model-0.3.9.dist-info → isa_model-0.4.0.dist-info}/METADATA +13 -1
  104. isa_model-0.4.0.dist-info/RECORD +182 -0
  105. isa_model/deployment/cloud/modal/isa_vision_doc_service.py +0 -766
  106. isa_model/deployment/cloud/modal/register_models.py +0 -321
  107. isa_model/inference/adapter/unified_api.py +0 -248
  108. isa_model/inference/services/helpers/stacked_config.py +0 -148
  109. isa_model/inference/services/img/flux_professional_service.py +0 -603
  110. isa_model/inference/services/img/helpers/base_stacked_service.py +0 -274
  111. isa_model/inference/services/others/table_transformer_service.py +0 -61
  112. isa_model/inference/services/vision/doc_analysis_service.py +0 -640
  113. isa_model/inference/services/vision/helpers/base_stacked_service.py +0 -274
  114. isa_model/inference/services/vision/ui_analysis_service.py +0 -823
  115. isa_model/scripts/inference_tracker.py +0 -283
  116. isa_model/scripts/mlflow_manager.py +0 -379
  117. isa_model/scripts/model_registry.py +0 -465
  118. isa_model/scripts/register_models.py +0 -370
  119. isa_model/scripts/register_models_with_embeddings.py +0 -510
  120. isa_model/scripts/start_mlflow.py +0 -95
  121. isa_model/scripts/training_tracker.py +0 -257
  122. isa_model-0.3.9.dist-info/RECORD +0 -138
  123. {isa_model-0.3.9.dist-info → isa_model-0.4.0.dist-info}/WHEEL +0 -0
  124. {isa_model-0.3.9.dist-info → isa_model-0.4.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,727 @@
+ """
+ Audio Evaluator for ISA Model evaluation framework.
+ 
+ Provides comprehensive evaluation capabilities for audio tasks including:
+ - Speech-to-Text (STT) evaluation with WER/CER metrics
+ - Speaker diarization evaluation
+ - Emotion recognition evaluation
+ - Voice activity detection evaluation
+ - Speech enhancement evaluation
+ - Text-to-Speech (TTS) quality evaluation
+ 
+ Supports ISA custom audio services and standard audio models.
+ """
+ 
+ import asyncio
+ import logging
+ import librosa
+ import numpy as np
+ import re
+ from typing import Dict, List, Any, Optional, Union, Tuple
+ from pathlib import Path
+ import tempfile
+ import wave
+ 
+ from .base_evaluator import BaseEvaluator, EvaluationResult
+ from ..metrics import compute_text_metrics
+ 
+ logger = logging.getLogger(__name__)
+ 
+ 
+ class AudioEvaluator(BaseEvaluator):
+     """
+     Comprehensive audio model evaluator.
+ 
+     Supports evaluation of:
+     - Speech-to-Text accuracy (WER, CER, BLEU)
+     - Speaker diarization accuracy (DER, Speaker F1)
+     - Emotion recognition accuracy
+     - Voice activity detection (Precision, Recall, F1)
+     - Speech enhancement quality (SNR, PESQ, STOI)
+     - Text-to-Speech naturalness and intelligibility
+     """
+ 
+     def __init__(self,
+                  config: Optional[Dict[str, Any]] = None,
+                  experiment_tracker: Optional[Any] = None):
+         """
+         Initialize the audio evaluator.
+ 
+         Args:
+             config: Evaluation configuration
+             experiment_tracker: Optional experiment tracking instance
+         """
+         super().__init__(
+             evaluator_name="audio_evaluator",
+             config=config,
+             experiment_tracker=experiment_tracker
+         )
+ 
+         # Audio-specific configuration
+         self.sample_rate = self.config.get("sample_rate", 16000)
+         self.supported_formats = self.config.get("supported_formats", ["wav", "mp3", "flac", "m4a"])
+         self.max_duration = self.config.get("max_duration_seconds", 300)  # 5 minutes
+ 
+         # Evaluation task types
+         self.task_type = self.config.get("task_type", "stt")  # stt, diarization, emotion, tts, enhancement
+ 
+         # STT evaluation settings
+         self.normalize_text = self.config.get("normalize_text", True)
+         self.case_sensitive = self.config.get("case_sensitive", False)
+         self.remove_punctuation = self.config.get("remove_punctuation", True)
+ 
+         # Speaker diarization settings
+         self.collar_tolerance = self.config.get("collar_tolerance", 0.25)  # 250ms tolerance
+ 
+         logger.info(f"Initialized AudioEvaluator for task: {self.task_type}")
+ 
+     async def evaluate_sample(self,
+                               sample: Dict[str, Any],
+                               model_interface: Any) -> Dict[str, Any]:
+         """
+         Evaluate a single audio sample.
+ 
+         Args:
+             sample: Audio sample containing audio data and expected output
+             model_interface: Audio model interface
+ 
+         Returns:
+             Evaluation result for the sample
+         """
+         try:
+             # Extract sample data
+             audio_data = sample.get("audio")
+             expected_output = sample.get("expected_output", "")
+             task_type = sample.get("task_type", self.task_type)
+             metadata = sample.get("metadata", {})
+ 
+             # Process audio
+             processed_audio, audio_info = await self._process_audio(audio_data)
+ 
+             # Get model prediction based on task type
+             prediction = await self._get_model_prediction(
+                 model_interface, processed_audio, task_type, metadata
+             )
+ 
+             # Compute sample-level metrics
+             sample_metrics = self._compute_sample_metrics(
+                 prediction, expected_output, task_type, audio_info
+             )
+ 
+             return {
+                 "prediction": prediction,
+                 "expected_output": expected_output,
+                 "task_type": task_type,
+                 "sample_metrics": sample_metrics,
+                 "audio_info": audio_info,
+                 "metadata": metadata
+             }
+ 
+         except Exception as e:
+             logger.error(f"Error evaluating audio sample: {e}")
+             raise
+ 
+     async def _process_audio(self, audio_data: Union[str, bytes, np.ndarray, Path]) -> Tuple[np.ndarray, Dict[str, Any]]:
+         """
+         Process and validate audio data.
+ 
+         Args:
+             audio_data: Audio in various formats
+ 
+         Returns:
+             Tuple of (processed audio array, audio info dict)
+         """
+         try:
+             if isinstance(audio_data, str) and Path(audio_data).exists():
+                 # File path
+                 audio_array, sr = librosa.load(audio_data, sr=self.sample_rate)
+                 original_sr = librosa.get_samplerate(audio_data)
+ 
+             elif isinstance(audio_data, Path):
+                 # Path object
+                 audio_array, sr = librosa.load(str(audio_data), sr=self.sample_rate)
+                 original_sr = librosa.get_samplerate(str(audio_data))
+ 
+             elif isinstance(audio_data, bytes):
+                 # Raw audio bytes - save to temp file first
+                 with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
+                     tmp_file.write(audio_data)
+                     tmp_path = tmp_file.name
+ 
+                 audio_array, sr = librosa.load(tmp_path, sr=self.sample_rate)
+                 original_sr = self.sample_rate  # Assume target sample rate
+                 Path(tmp_path).unlink()  # Clean up temp file
+ 
+             elif isinstance(audio_data, np.ndarray):
+                 # NumPy array (assumed to already be at the target sample rate)
+                 audio_array = audio_data
+                 sr = self.sample_rate
+                 original_sr = self.sample_rate
+ 
+                 # Downmix multi-channel arrays to mono
+                 if len(audio_array.shape) > 1:
+                     audio_array = librosa.to_mono(audio_array)
+ 
+             else:
+                 raise ValueError(f"Unsupported audio data type: {type(audio_data)}")
+ 
+             # Validate duration
+             duration = len(audio_array) / sr
+             if duration > self.max_duration:
+                 logger.warning(f"Audio duration {duration:.2f}s exceeds max duration {self.max_duration}s, truncating")
+                 max_samples = int(self.max_duration * sr)
+                 audio_array = audio_array[:max_samples]
+                 duration = self.max_duration
+ 
+             # Compute audio features
+             audio_info = {
+                 "duration_seconds": duration,
+                 "sample_rate": sr,
+                 "original_sample_rate": original_sr,
+                 "num_samples": len(audio_array),
+                 "rms_energy": float(np.sqrt(np.mean(audio_array**2))),
+                 "zero_crossing_rate": float(np.mean(librosa.feature.zero_crossing_rate(audio_array))),
+                 "spectral_centroid": float(np.mean(librosa.feature.spectral_centroid(y=audio_array, sr=sr)))
+             }
+ 
+             return audio_array, audio_info
+ 
+         except Exception as e:
+             logger.error(f"Error processing audio: {e}")
+             raise
+ 
+     async def _get_model_prediction(self,
+                                     model_interface: Any,
+                                     audio: np.ndarray,
+                                     task_type: str,
+                                     metadata: Dict[str, Any]) -> Union[str, Dict[str, Any]]:
+         """
+         Get model prediction for an audio task.
+ 
+         Args:
+             model_interface: Audio model interface
+             audio: Processed audio array
+             task_type: Type of audio task
+             metadata: Additional metadata
+ 
+         Returns:
+             Model prediction (string for STT, dict for complex tasks)
+         """
+         try:
+             # Save audio to a temporary file for model processing
+             with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
+                 tmp_path = tmp_file.name
+ 
+             try:
+                 # Convert float audio in [-1, 1] to int16 for wav format,
+                 # clipping first so out-of-range samples cannot wrap around
+                 audio_int16 = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
+ 
+                 # Write WAV file
+                 with wave.open(tmp_path, 'wb') as wav_file:
+                     wav_file.setnchannels(1)  # Mono
+                     wav_file.setsampwidth(2)  # 16-bit
+                     wav_file.setframerate(self.sample_rate)
+                     wav_file.writeframes(audio_int16.tobytes())
+ 
+                 if task_type == "stt":
+                     # Speech-to-Text
+                     if hasattr(model_interface, 'transcribe'):
+                         result = await model_interface.transcribe(tmp_path)
+                         prediction = result.get("text", "") if isinstance(result, dict) else str(result)
+                     elif hasattr(model_interface, 'stt'):
+                         prediction = await model_interface.stt(tmp_path)
+                     else:
+                         prediction = await model_interface.predict(tmp_path)
+ 
+                     return str(prediction).strip()
+ 
+                 elif task_type == "diarization":
+                     # Speaker diarization
+                     if hasattr(model_interface, 'diarize'):
+                         result = await model_interface.diarize(tmp_path)
+                     else:
+                         result = await model_interface.predict(tmp_path, task="diarization")
+ 
+                     return result  # Should be a dict with speaker segments
+ 
+                 elif task_type == "emotion":
+                     # Emotion recognition
+                     if hasattr(model_interface, 'detect_emotion'):
+                         result = await model_interface.detect_emotion(tmp_path)
+                     else:
+                         result = await model_interface.predict(tmp_path, task="emotion")
+ 
+                     return result  # Should be an emotion label or dict
+ 
+                 elif task_type == "tts":
+                     # Text-to-Speech (reverse evaluation)
+                     text_input = metadata.get("text_input", "")
+                     if hasattr(model_interface, 'synthesize'):
+                         result = await model_interface.synthesize(text_input)
+                     else:
+                         result = await model_interface.predict(text_input, task="tts")
+ 
+                     return result  # Should be synthesized audio
+ 
+                 else:
+                     # Generic prediction
+                     prediction = await model_interface.predict(tmp_path)
+                     return prediction
+ 
+             finally:
+                 # Clean up the temporary file on success or failure
+                 Path(tmp_path).unlink()
+ 
+         except Exception as e:
+             logger.error(f"Error getting model prediction: {e}")
+             raise
+ 
+     def _compute_sample_metrics(self,
+                                 prediction: Union[str, Dict[str, Any]],
+                                 expected_output: Union[str, Dict[str, Any]],
+                                 task_type: str,
+                                 audio_info: Dict[str, Any]) -> Dict[str, float]:
+         """
+         Compute metrics for a single sample.
+ 
+         Args:
+             prediction: Model prediction
+             expected_output: Expected/reference output
+             task_type: Type of audio task
+             audio_info: Audio metadata
+ 
+         Returns:
+             Dictionary of sample-level metrics
+         """
+         try:
+             metrics = {}
+ 
+             if task_type == "stt":
+                 # Speech-to-Text metrics
+                 if isinstance(prediction, str) and isinstance(expected_output, str):
+                     metrics.update(self._compute_stt_metrics(prediction, expected_output))
+ 
+             elif task_type == "diarization":
+                 # Speaker diarization metrics
+                 metrics.update(self._compute_diarization_metrics(prediction, expected_output))
+ 
+             elif task_type == "emotion":
+                 # Emotion recognition metrics
+                 metrics.update(self._compute_emotion_metrics(prediction, expected_output))
+ 
+             elif task_type == "tts":
+                 # TTS quality metrics
+                 metrics.update(self._compute_tts_metrics(prediction, expected_output, audio_info))
+ 
+             # Add audio metadata
+             metrics.update({
+                 "audio_duration": audio_info.get("duration_seconds", 0.0),
+                 "audio_quality_score": self._compute_audio_quality_score(audio_info)
+             })
+ 
+             return metrics
+ 
+         except Exception as e:
+             logger.error(f"Error computing sample metrics: {e}")
+             return {"error": 1.0}
+ 
+     def _compute_stt_metrics(self, prediction: str, reference: str) -> Dict[str, float]:
+         """Compute Speech-to-Text specific metrics."""
+         try:
+             # Normalize text if configured
+             if self.normalize_text:
+                 prediction = self._normalize_text(prediction)
+                 reference = self._normalize_text(reference)
+ 
+             # Word Error Rate (WER)
+             wer = self._compute_wer(prediction, reference)
+ 
+             # Character Error Rate (CER)
+             cer = self._compute_cer(prediction, reference)
+ 
+             # Additional text metrics
+             text_metrics = compute_text_metrics(prediction, reference)
+ 
+             return {
+                 "wer": wer,
+                 "cer": cer,
+                 "word_accuracy": 1.0 - wer,
+                 "char_accuracy": 1.0 - cer,
+                 **text_metrics
+             }
+ 
+         except Exception as e:
+             logger.error(f"Error computing STT metrics: {e}")
+             return {"stt_error": 1.0}
+ 
+     def _compute_wer(self, prediction: str, reference: str) -> float:
+         """Compute Word Error Rate."""
+         try:
+             pred_words = prediction.strip().split()
+             ref_words = reference.strip().split()
+ 
+             if not ref_words:
+                 return 0.0 if not pred_words else 1.0
+ 
+             # Compute edit distance
+             distance = self._edit_distance(pred_words, ref_words)
+             wer = distance / len(ref_words)
+ 
+             return min(1.0, wer)  # Cap at 100%
+ 
+         except Exception as e:
+             logger.error(f"Error computing WER: {e}")
+             return 1.0
+ 
+     def _compute_cer(self, prediction: str, reference: str) -> float:
+         """Compute Character Error Rate."""
+         try:
+             pred_chars = list(prediction.strip())
+             ref_chars = list(reference.strip())
+ 
+             if not ref_chars:
+                 return 0.0 if not pred_chars else 1.0
+ 
+             # Compute edit distance
+             distance = self._edit_distance(pred_chars, ref_chars)
+             cer = distance / len(ref_chars)
+ 
+             return min(1.0, cer)  # Cap at 100%
+ 
+         except Exception as e:
+             logger.error(f"Error computing CER: {e}")
+             return 1.0
+ 
+     def _edit_distance(self, seq1: List[str], seq2: List[str]) -> int:
+         """Compute Levenshtein edit distance."""
+         m, n = len(seq1), len(seq2)
+         dp = [[0] * (n + 1) for _ in range(m + 1)]
+ 
+         # Initialize base cases
+         for i in range(m + 1):
+             dp[i][0] = i
+         for j in range(n + 1):
+             dp[0][j] = j
+ 
+         # Fill the DP table
+         for i in range(1, m + 1):
+             for j in range(1, n + 1):
+                 if seq1[i-1] == seq2[j-1]:
+                     dp[i][j] = dp[i-1][j-1]
+                 else:
+                     dp[i][j] = 1 + min(
+                         dp[i-1][j],    # deletion
+                         dp[i][j-1],    # insertion
+                         dp[i-1][j-1]   # substitution
+                     )
+ 
+         return dp[m][n]
+ 
+     def _normalize_text(self, text: str) -> str:
+         """Normalize text for evaluation."""
+         # Convert to lowercase if not case sensitive
+         if not self.case_sensitive:
+             text = text.lower()
+ 
+         # Remove punctuation if configured
+         if self.remove_punctuation:
+             text = re.sub(r'[^\w\s]', ' ', text)
+ 
+         # Normalize whitespace
+         text = re.sub(r'\s+', ' ', text).strip()
+ 
+         return text
+ 
+     def _compute_diarization_metrics(self,
+                                      prediction: Dict[str, Any],
+                                      reference: Dict[str, Any]) -> Dict[str, float]:
+         """Compute speaker diarization metrics."""
+         try:
+             # Extract speaker segments
+             pred_segments = prediction.get("segments", []) if isinstance(prediction, dict) else []
+             ref_segments = reference.get("segments", []) if isinstance(reference, dict) else []
+ 
+             if not ref_segments:
+                 return {"diarization_error": 1.0}
+ 
+             # Compute Diarization Error Rate (DER)
+             der = self._compute_der(pred_segments, ref_segments)
+ 
+             # Compute Speaker F1 score
+             speaker_f1 = self._compute_speaker_f1(pred_segments, ref_segments)
+ 
+             return {
+                 "diarization_error_rate": der,
+                 "speaker_f1_score": speaker_f1,
+                 "num_predicted_speakers": len(set(seg.get("speaker", "") for seg in pred_segments)),
+                 "num_reference_speakers": len(set(seg.get("speaker", "") for seg in ref_segments))
+             }
+ 
+         except Exception as e:
+             logger.error(f"Error computing diarization metrics: {e}")
+             return {"diarization_error": 1.0}
+ 
+     def _compute_der(self, pred_segments: List[Dict], ref_segments: List[Dict]) -> float:
+         """Compute Diarization Error Rate."""
+         try:
+             # This is a simplified DER computation.
+             # In practice, you'd use specialized libraries like pyannote.metrics
+ 
+             total_time = 0.0
+             error_time = 0.0
+ 
+             # Find overall time range
+             all_segments = pred_segments + ref_segments
+             if not all_segments:
+                 return 0.0
+ 
+             start_time = min(seg.get("start", 0.0) for seg in all_segments)
+             end_time = max(seg.get("end", 0.0) for seg in all_segments)
+             total_time = end_time - start_time
+ 
+             if total_time <= 0:
+                 return 0.0
+ 
+             # Sample time points and check for errors
+             time_step = 0.1  # 100ms resolution
+             num_steps = int(total_time / time_step)
+ 
+             for i in range(num_steps):
+                 t = start_time + i * time_step
+ 
+                 # Find speakers at time t
+                 pred_speaker = self._get_speaker_at_time(t, pred_segments)
+                 ref_speaker = self._get_speaker_at_time(t, ref_segments)
+ 
+                 if pred_speaker != ref_speaker:
+                     error_time += time_step
+ 
+             der = error_time / total_time if total_time > 0 else 0.0
+             return min(1.0, der)
+ 
+         except Exception as e:
+             logger.error(f"Error computing DER: {e}")
+             return 1.0
+ 
+     def _get_speaker_at_time(self, time: float, segments: List[Dict]) -> Optional[str]:
+         """Get the speaker at a specific time point."""
+         for segment in segments:
+             start = segment.get("start", 0.0)
+             end = segment.get("end", 0.0)
+             if start <= time < end:
+                 return segment.get("speaker")
+         return None
+ 
+     def _compute_speaker_f1(self, pred_segments: List[Dict], ref_segments: List[Dict]) -> float:
+         """Compute Speaker F1 score."""
+         try:
+             # Extract unique speakers
+             pred_speakers = set(seg.get("speaker", "") for seg in pred_segments)
+             ref_speakers = set(seg.get("speaker", "") for seg in ref_segments)
+ 
+             pred_speakers.discard("")  # Remove empty speakers
+             ref_speakers.discard("")
+ 
+             if not ref_speakers:
+                 return 1.0 if not pred_speakers else 0.0
+ 
+             # Simple speaker overlap metric
+             intersection = len(pred_speakers.intersection(ref_speakers))
+             precision = intersection / len(pred_speakers) if pred_speakers else 0.0
+             recall = intersection / len(ref_speakers) if ref_speakers else 0.0
+ 
+             if precision + recall == 0:
+                 return 0.0
+ 
+             f1 = 2 * precision * recall / (precision + recall)
+             return f1
+ 
+         except Exception as e:
+             logger.error(f"Error computing speaker F1: {e}")
+             return 0.0
+ 
+     def _compute_emotion_metrics(self,
+                                  prediction: Union[str, Dict[str, Any]],
+                                  reference: Union[str, Dict[str, Any]]) -> Dict[str, float]:
+         """Compute emotion recognition metrics."""
+         try:
+             # Extract emotion labels
+             if isinstance(prediction, dict):
+                 pred_emotion = prediction.get("emotion", "")
+                 pred_confidence = prediction.get("confidence", 0.0)
+             else:
+                 pred_emotion = str(prediction)
+                 pred_confidence = 1.0
+ 
+             if isinstance(reference, dict):
+                 ref_emotion = reference.get("emotion", "")
+             else:
+                 ref_emotion = str(reference)
+ 
+             # Compute accuracy
+             emotion_accuracy = 1.0 if pred_emotion.lower() == ref_emotion.lower() else 0.0
+ 
+             return {
+                 "emotion_accuracy": emotion_accuracy,
+                 "emotion_confidence": pred_confidence,
+                 "predicted_emotion": pred_emotion,
+                 "reference_emotion": ref_emotion
+             }
+ 
+         except Exception as e:
+             logger.error(f"Error computing emotion metrics: {e}")
+             return {"emotion_error": 1.0}
+ 
+     def _compute_tts_metrics(self,
+                              prediction: Any,
+                              reference: Any,
+                              audio_info: Dict[str, Any]) -> Dict[str, float]:
+         """Compute Text-to-Speech quality metrics."""
+         try:
+             # This is a placeholder for TTS evaluation.
+             # In practice, you'd use specialized metrics like MOS, PESQ, STOI
+ 
+             return {
+                 "tts_quality_score": 0.8,  # Placeholder
+                 "synthesis_success": 1.0 if prediction is not None else 0.0
+             }
+ 
+         except Exception as e:
+             logger.error(f"Error computing TTS metrics: {e}")
+             return {"tts_error": 1.0}
+ 
+     def _compute_audio_quality_score(self, audio_info: Dict[str, Any]) -> float:
+         """Compute a simple audio quality score based on audio features."""
+         try:
+             # Simple heuristic based on RMS energy and other features
+             rms_energy = audio_info.get("rms_energy", 0.0)
+             duration = audio_info.get("duration_seconds", 0.0)
+ 
+             # Normalize RMS energy (assuming good audio is in range 0.01-0.1)
+             energy_score = min(1.0, max(0.0, (rms_energy - 0.001) / 0.1))
+ 
+             # Duration score (prefer reasonable durations)
+             duration_score = 1.0 if 1.0 <= duration <= 60.0 else 0.5
+ 
+             quality_score = (energy_score + duration_score) / 2
+             return quality_score
+ 
+         except Exception as e:
+             logger.error(f"Error computing audio quality score: {e}")
+             return 0.5
+ 
+     def compute_metrics(self,
+                         predictions: List[Any],
+                         references: List[Any],
+                         **kwargs) -> Dict[str, float]:
+         """
+         Compute aggregate audio evaluation metrics.
+ 
+         Args:
+             predictions: List of model predictions
+             references: List of reference outputs
+             **kwargs: Additional parameters
+ 
+         Returns:
+             Dictionary of computed metrics
+         """
+         try:
+             if not predictions or not references:
+                 logger.warning("Empty predictions or references provided")
+                 return {}
+ 
+             # Ensure equal lengths
+             min_len = min(len(predictions), len(references))
+             predictions = predictions[:min_len]
+             references = references[:min_len]
+ 
+             task_type = self.task_type
+ 
+             if task_type == "stt":
+                 return self._compute_aggregate_stt_metrics(predictions, references)
+             elif task_type == "diarization":
+                 return self._compute_aggregate_diarization_metrics(predictions, references)
+             elif task_type == "emotion":
+                 return self._compute_aggregate_emotion_metrics(predictions, references)
+             else:
+                 # Generic metrics
+                 return {
+                     "total_samples": len(predictions),
+                     "task_type": task_type,
+                     "evaluation_success_rate": 1.0
+                 }
+ 
+         except Exception as e:
+             logger.error(f"Error computing aggregate metrics: {e}")
+             return {"error_rate": 1.0}
+ 
+     def _compute_aggregate_stt_metrics(self,
+                                        predictions: List[str],
+                                        references: List[str]) -> Dict[str, float]:
+         """Compute aggregate STT metrics."""
+         wer_scores = []
+         cer_scores = []
+ 
+         for pred, ref in zip(predictions, references):
+             if isinstance(pred, str) and isinstance(ref, str):
+                 sample_metrics = self._compute_stt_metrics(pred, ref)
+                 wer_scores.append(sample_metrics.get("wer", 1.0))
+                 cer_scores.append(sample_metrics.get("cer", 1.0))
+ 
+         return {
+             "avg_wer": np.mean(wer_scores) if wer_scores else 1.0,
+             "avg_cer": np.mean(cer_scores) if cer_scores else 1.0,
+             "avg_word_accuracy": 1.0 - np.mean(wer_scores) if wer_scores else 0.0,
+             "avg_char_accuracy": 1.0 - np.mean(cer_scores) if cer_scores else 0.0,
+             "total_samples": len(predictions)
+         }
+ 
+     def _compute_aggregate_diarization_metrics(self,
+                                                predictions: List[Dict],
+                                                references: List[Dict]) -> Dict[str, float]:
+         """Compute aggregate diarization metrics."""
+         der_scores = []
+         f1_scores = []
+ 
+         for pred, ref in zip(predictions, references):
+             if isinstance(pred, dict) and isinstance(ref, dict):
+                 sample_metrics = self._compute_diarization_metrics(pred, ref)
+                 der_scores.append(sample_metrics.get("diarization_error_rate", 1.0))
+                 f1_scores.append(sample_metrics.get("speaker_f1_score", 0.0))
+ 
+         return {
+             "avg_diarization_error_rate": np.mean(der_scores) if der_scores else 1.0,
+             "avg_speaker_f1_score": np.mean(f1_scores) if f1_scores else 0.0,
+             "total_samples": len(predictions)
+         }
+ 
+     def _compute_aggregate_emotion_metrics(self,
+                                            predictions: List[Any],
+                                            references: List[Any]) -> Dict[str, float]:
+         """Compute aggregate emotion recognition metrics."""
+         accuracies = []
+         confidences = []
+ 
+         for pred, ref in zip(predictions, references):
+             sample_metrics = self._compute_emotion_metrics(pred, ref)
+             accuracies.append(sample_metrics.get("emotion_accuracy", 0.0))
+             confidences.append(sample_metrics.get("emotion_confidence", 0.0))
+ 
+         return {
+             "avg_emotion_accuracy": np.mean(accuracies) if accuracies else 0.0,
+             "avg_confidence": np.mean(confidences) if confidences else 0.0,
+             "total_samples": len(predictions)
+         }
+ 
+     def get_supported_metrics(self) -> List[str]:
+         """Get list of metrics supported by this evaluator."""
+         base_metrics = ["total_samples", "evaluation_success_rate"]
+ 
+         task_specific_metrics = {
+             "stt": ["wer", "cer", "word_accuracy", "char_accuracy", "bleu_score", "rouge_l"],
+             "diarization": ["diarization_error_rate", "speaker_f1_score"],
+             "emotion": ["emotion_accuracy", "emotion_confidence"],
+             "tts": ["tts_quality_score", "synthesis_success"]
+         }
+ 
+         return base_metrics + task_specific_metrics.get(self.task_type, [])
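
A few of the new additions are worth illustrating. First, a minimal driver for the new AudioEvaluator. This is a sketch, not a verbatim recipe from the package: it assumes BaseEvaluator stores the config dict on self.config (which the __init__ above relies on), that "clip.wav" is a real audio file on disk, and DummySTT is a hypothetical model interface; the evaluator probes for transcribe, then stt, then predict.

import asyncio

from isa_model.eval.evaluators.audio_evaluator import AudioEvaluator

class DummySTT:
    # Hypothetical model interface: AudioEvaluator tries `transcribe` first.
    async def transcribe(self, wav_path: str) -> dict:
        # A real service would run inference on the WAV file at wav_path.
        return {"text": "hello world"}

async def main():
    evaluator = AudioEvaluator(config={"task_type": "stt"})
    sample = {
        "audio": "clip.wav",               # file path, raw bytes, or an np.ndarray
        "expected_output": "hello world",  # reference transcript
    }
    result = await evaluator.evaluate_sample(sample, DummySTT())
    print(result["sample_metrics"]["wer"], result["sample_metrics"]["cer"])

asyncio.run(main())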
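
Second, the WER computed above is a plain Levenshtein edit distance over whitespace-split words, divided by the reference length and capped at 1.0. A self-contained re-derivation of that logic, independent of the package, with one worked pair:

def edit_distance(a, b):
    # Standard dynamic-programming Levenshtein distance over token sequences.
    m, n = len(a), len(b)
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(m + 1):
        dp[i][0] = i
    for j in range(n + 1):
        dp[0][j] = j
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            cost = 0 if a[i - 1] == b[j - 1] else 1
            dp[i][j] = min(dp[i - 1][j] + 1,         # deletion
                           dp[i][j - 1] + 1,         # insertion
                           dp[i - 1][j - 1] + cost)  # substitution or match
    return dp[m][n]

pred = "the cat sat on mat".split()
ref = "the cat sat on the mat".split()
wer = min(1.0, edit_distance(pred, ref) / len(ref))
print(wer)  # one missing word against six reference words -> ~0.167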
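
Finally, note that _compute_der is a time-sampled label-mismatch rate rather than a full DER: it samples every 100 ms, compares raw speaker labels (so hypothesis and reference must already use the same label strings; there is no optimal speaker mapping), and its own comment points to pyannote.metrics for a proper implementation. The toy sketch below reproduces the sampling idea on two hand-made segmentations in the {"start", "end", "speaker"} shape the evaluator reads:

def speaker_at(t, segments):
    # Return the label of the first segment covering time t, else None.
    for seg in segments:
        if seg["start"] <= t < seg["end"]:
            return seg["speaker"]
    return None

ref = [{"start": 0.0, "end": 2.0, "speaker": "A"},
       {"start": 2.0, "end": 4.0, "speaker": "B"}]
hyp = [{"start": 0.0, "end": 2.5, "speaker": "A"},
       {"start": 2.5, "end": 4.0, "speaker": "B"}]

step, start, end = 0.1, 0.0, 4.0
errors = sum(
    1 for i in range(int((end - start) / step))
    if speaker_at(start + i * step, hyp) != speaker_at(start + i * step, ref)
)
print(errors * step / (end - start))  # 0.125: hyp mislabels 2.0-2.5 s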