isa-model 0.3.91__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/client.py +732 -573
- isa_model/core/cache/redis_cache.py +401 -0
- isa_model/core/config/config_manager.py +53 -10
- isa_model/core/config.py +1 -1
- isa_model/core/database/__init__.py +1 -0
- isa_model/core/database/migrations.py +277 -0
- isa_model/core/database/supabase_client.py +123 -0
- isa_model/core/models/__init__.py +37 -0
- isa_model/core/models/model_billing_tracker.py +60 -88
- isa_model/core/models/model_manager.py +36 -18
- isa_model/core/models/model_repo.py +44 -38
- isa_model/core/models/model_statistics_tracker.py +234 -0
- isa_model/core/models/model_storage.py +0 -1
- isa_model/core/models/model_version_manager.py +959 -0
- isa_model/core/pricing_manager.py +2 -249
- isa_model/core/resilience/circuit_breaker.py +366 -0
- isa_model/core/security/secrets.py +358 -0
- isa_model/core/services/__init__.py +2 -4
- isa_model/core/services/intelligent_model_selector.py +101 -370
- isa_model/core/storage/hf_storage.py +1 -1
- isa_model/core/types.py +7 -0
- isa_model/deployment/cloud/modal/isa_audio_chatTTS_service.py +520 -0
- isa_model/deployment/cloud/modal/isa_audio_fish_service.py +0 -0
- isa_model/deployment/cloud/modal/isa_audio_openvoice_service.py +758 -0
- isa_model/deployment/cloud/modal/isa_audio_service_v2.py +1044 -0
- isa_model/deployment/cloud/modal/isa_embed_rerank_service.py +296 -0
- isa_model/deployment/cloud/modal/isa_video_hunyuan_service.py +423 -0
- isa_model/deployment/cloud/modal/isa_vision_ocr_service.py +519 -0
- isa_model/deployment/cloud/modal/isa_vision_qwen25_service.py +709 -0
- isa_model/deployment/cloud/modal/isa_vision_table_service.py +467 -323
- isa_model/deployment/cloud/modal/isa_vision_ui_service.py +607 -180
- isa_model/deployment/cloud/modal/isa_vision_ui_service_optimized.py +660 -0
- isa_model/deployment/core/deployment_manager.py +6 -4
- isa_model/deployment/services/auto_hf_modal_deployer.py +894 -0
- isa_model/eval/benchmarks/__init__.py +27 -0
- isa_model/eval/benchmarks/multimodal_datasets.py +460 -0
- isa_model/eval/benchmarks.py +244 -12
- isa_model/eval/evaluators/__init__.py +8 -2
- isa_model/eval/evaluators/audio_evaluator.py +727 -0
- isa_model/eval/evaluators/embedding_evaluator.py +742 -0
- isa_model/eval/evaluators/vision_evaluator.py +564 -0
- isa_model/eval/example_evaluation.py +395 -0
- isa_model/eval/factory.py +272 -5
- isa_model/eval/isa_benchmarks.py +700 -0
- isa_model/eval/isa_integration.py +582 -0
- isa_model/eval/metrics.py +159 -6
- isa_model/eval/tests/unit/test_basic.py +396 -0
- isa_model/inference/ai_factory.py +44 -8
- isa_model/inference/services/audio/__init__.py +21 -0
- isa_model/inference/services/audio/base_realtime_service.py +225 -0
- isa_model/inference/services/audio/isa_tts_service.py +0 -0
- isa_model/inference/services/audio/openai_realtime_service.py +320 -124
- isa_model/inference/services/audio/openai_stt_service.py +32 -6
- isa_model/inference/services/base_service.py +17 -1
- isa_model/inference/services/embedding/__init__.py +13 -0
- isa_model/inference/services/embedding/base_embed_service.py +111 -8
- isa_model/inference/services/embedding/isa_embed_service.py +305 -0
- isa_model/inference/services/embedding/openai_embed_service.py +2 -4
- isa_model/inference/services/embedding/tests/test_embedding.py +222 -0
- isa_model/inference/services/img/__init__.py +2 -2
- isa_model/inference/services/img/base_image_gen_service.py +24 -7
- isa_model/inference/services/img/replicate_image_gen_service.py +84 -422
- isa_model/inference/services/img/services/replicate_face_swap.py +193 -0
- isa_model/inference/services/img/services/replicate_flux.py +226 -0
- isa_model/inference/services/img/services/replicate_flux_kontext.py +219 -0
- isa_model/inference/services/img/services/replicate_sticker_maker.py +249 -0
- isa_model/inference/services/img/tests/test_img_client.py +297 -0
- isa_model/inference/services/llm/base_llm_service.py +30 -6
- isa_model/inference/services/llm/helpers/llm_adapter.py +63 -9
- isa_model/inference/services/llm/ollama_llm_service.py +2 -1
- isa_model/inference/services/llm/openai_llm_service.py +652 -55
- isa_model/inference/services/llm/yyds_llm_service.py +2 -1
- isa_model/inference/services/vision/__init__.py +5 -5
- isa_model/inference/services/vision/base_vision_service.py +118 -185
- isa_model/inference/services/vision/helpers/image_utils.py +11 -5
- isa_model/inference/services/vision/isa_vision_service.py +573 -0
- isa_model/inference/services/vision/tests/test_ocr_client.py +284 -0
- isa_model/serving/api/fastapi_server.py +88 -16
- isa_model/serving/api/middleware/auth.py +311 -0
- isa_model/serving/api/middleware/security.py +278 -0
- isa_model/serving/api/routes/analytics.py +486 -0
- isa_model/serving/api/routes/deployments.py +339 -0
- isa_model/serving/api/routes/evaluations.py +579 -0
- isa_model/serving/api/routes/logs.py +430 -0
- isa_model/serving/api/routes/settings.py +582 -0
- isa_model/serving/api/routes/unified.py +324 -165
- isa_model/serving/api/startup.py +304 -0
- isa_model/serving/modal_proxy_server.py +249 -0
- isa_model/training/__init__.py +100 -6
- isa_model/training/core/__init__.py +4 -1
- isa_model/training/examples/intelligent_training_example.py +281 -0
- isa_model/training/intelligent/__init__.py +25 -0
- isa_model/training/intelligent/decision_engine.py +643 -0
- isa_model/training/intelligent/intelligent_factory.py +888 -0
- isa_model/training/intelligent/knowledge_base.py +751 -0
- isa_model/training/intelligent/resource_optimizer.py +839 -0
- isa_model/training/intelligent/task_classifier.py +576 -0
- isa_model/training/storage/__init__.py +24 -0
- isa_model/training/storage/core_integration.py +439 -0
- isa_model/training/storage/training_repository.py +552 -0
- isa_model/training/storage/training_storage.py +628 -0
- {isa_model-0.3.91.dist-info → isa_model-0.4.0.dist-info}/METADATA +13 -1
- isa_model-0.4.0.dist-info/RECORD +182 -0
- isa_model/deployment/cloud/modal/isa_vision_doc_service.py +0 -766
- isa_model/deployment/cloud/modal/register_models.py +0 -321
- isa_model/inference/adapter/unified_api.py +0 -248
- isa_model/inference/services/helpers/stacked_config.py +0 -148
- isa_model/inference/services/img/flux_professional_service.py +0 -603
- isa_model/inference/services/img/helpers/base_stacked_service.py +0 -274
- isa_model/inference/services/others/table_transformer_service.py +0 -61
- isa_model/inference/services/vision/doc_analysis_service.py +0 -640
- isa_model/inference/services/vision/helpers/base_stacked_service.py +0 -274
- isa_model/inference/services/vision/ui_analysis_service.py +0 -823
- isa_model/scripts/inference_tracker.py +0 -283
- isa_model/scripts/mlflow_manager.py +0 -379
- isa_model/scripts/model_registry.py +0 -465
- isa_model/scripts/register_models.py +0 -370
- isa_model/scripts/register_models_with_embeddings.py +0 -510
- isa_model/scripts/start_mlflow.py +0 -95
- isa_model/scripts/training_tracker.py +0 -257
- isa_model-0.3.91.dist-info/RECORD +0 -138
- {isa_model-0.3.91.dist-info → isa_model-0.4.0.dist-info}/WHEEL +0 -0
- {isa_model-0.3.91.dist-info → isa_model-0.4.0.dist-info}/top_level.txt +0 -0
@@ -15,7 +15,6 @@ from datetime import datetime, timedelta
|
|
15
15
|
from dataclasses import dataclass, field
|
16
16
|
|
17
17
|
from .types import Provider
|
18
|
-
from .config import config_manager
|
19
18
|
|
20
19
|
logger = logging.getLogger(__name__)
|
21
20
|
|
@@ -78,17 +77,7 @@ class PricingManager:
|
|
78
77
|
|
79
78
|
def _load_pricing_data(self):
|
80
79
|
"""Load pricing data from configuration files"""
|
81
|
-
# Try to load from
|
82
|
-
if self._load_from_supabase():
|
83
|
-
logger.info("Loaded pricing data from Supabase")
|
84
|
-
return
|
85
|
-
|
86
|
-
# Try to load from provider configurations
|
87
|
-
if self._load_from_provider_configs():
|
88
|
-
logger.info("Loaded pricing data from provider configurations")
|
89
|
-
return
|
90
|
-
|
91
|
-
# Try to load from specified config path
|
80
|
+
# Try to load from specified config path first
|
92
81
|
if self.config_path and self.config_path.exists():
|
93
82
|
self._load_from_file(self.config_path)
|
94
83
|
return
|
@@ -168,7 +157,7 @@ class PricingManager:
|
|
168
157
|
"""Load default pricing data as fallback"""
|
169
158
|
default_pricing = {
|
170
159
|
"openai": {
|
171
|
-
"gpt-4o-mini": {"input": 0.
|
160
|
+
"gpt-4o-mini": {"input": 0.00000015, "output": 0.0000006, "unit_type": "token"},
|
172
161
|
"gpt-4o": {"input": 0.000005, "output": 0.000015, "unit_type": "token"},
|
173
162
|
"gpt-4-turbo": {"input": 0.00001, "output": 0.00003, "unit_type": "token"},
|
174
163
|
"gpt-4": {"input": 0.00003, "output": 0.00006, "unit_type": "token"},
|
@@ -198,242 +187,6 @@ class PricingManager:
|
|
198
187
|
self._parse_pricing_data({"providers": default_pricing})
|
199
188
|
logger.info("Loaded default pricing data")
|
200
189
|
|
201
|
-
def _load_from_supabase(self) -> bool:
    """Try to load pricing data from the Supabase ``models`` table.

    Reads ``model_id``, ``provider`` and ``metadata`` for every row and turns
    each row's metadata into a ``ModelPricing`` via
    ``_extract_pricing_from_supabase_metadata``.

    ``self.pricing_data`` is replaced only when at least one model yields
    pricing; a failed or empty load leaves the existing data untouched.

    Returns:
        True if pricing for at least one model was loaded. False when
        Supabase is disabled, the library or credentials are missing, the
        query fails, or no row carries usable pricing.
    """
    try:
        global_config = config_manager.get_global_config()
        if not global_config.use_supabase:
            return False

        # Optional dependency: only needed when Supabase is enabled.
        try:
            from supabase import create_client, Client
        except ImportError:
            logger.debug("Supabase library not available")
            return False

        # Explicit configuration wins; fall back to the environment.
        supabase_url = global_config.supabase_url or os.getenv('SUPABASE_URL')
        supabase_key = global_config.supabase_key or os.getenv('SUPABASE_ANON_KEY')
        if not supabase_url or not supabase_key:
            logger.debug("Supabase credentials not configured")
            return False

        supabase: Client = create_client(supabase_url, supabase_key)
        result = supabase.table('models').select('model_id, provider, metadata').execute()
        if not result.data:
            logger.debug("No models found in Supabase")
            return False

        # Collect into a local dict so the current pricing data is only
        # replaced once this source has actually produced something.
        loaded = {}
        loaded_count = 0
        for model_record in result.data:
            model_id = model_record.get('model_id')
            provider = model_record.get('provider')
            # A NULL metadata column arrives as None; treat it as empty.
            metadata = model_record.get('metadata') or {}

            if not model_id or not provider:
                continue

            pricing = self._extract_pricing_from_supabase_metadata(metadata, provider, model_id)
            if pricing:
                loaded.setdefault(provider, {})[model_id] = pricing
                loaded_count += 1

        if loaded_count == 0:
            logger.debug("No pricing data found in Supabase models")
            return False

        self.pricing_data = loaded
        logger.info(f"Loaded pricing for {loaded_count} models from Supabase")
        return True

    except Exception as e:
        logger.debug(f"Failed to load pricing from Supabase: {e}")
        return False
|
262
|
-
|
263
|
-
def _load_from_provider_configs(self) -> bool:
    """Load pricing from every ``*.yaml`` file in the providers config directory.

    The directory is ``<project root>/isa_model/core/config/providers``.
    ``self.pricing_data`` is reset to an empty dict before any file is read.
    Each file is then handed to ``_load_provider_config_file``.

    Returns:
        True if at least one provider file loaded. False if the directory
        does not exist, no file loaded, or an unexpected error occurred
        (logged).
    """
    try:
        providers_dir = (
            self._find_project_root() / "isa_model" / "core" / "config" / "providers"
        )
        if not providers_dir.exists():
            return False

        self.pricing_data = {}

        # Build the full result list first: every file must be processed,
        # so a short-circuiting any() must not drive the iteration.
        outcomes = [
            self._load_provider_config_file(config_file)
            for config_file in providers_dir.glob("*.yaml")
        ]
        return any(outcomes)

    except Exception as e:
        logger.error(f"Failed to load pricing from provider configs: {e}")
        return False
|
283
|
-
|
284
|
-
def _load_provider_config_file(self, config_file: Path) -> bool:
    """Load pricing for one provider from a single YAML config file.

    The keys read here imply this layout::

        provider: <name>
        models:
          - model_id: <id>
            metadata: {<pricing fields>}

    On success ``self.pricing_data[<provider>]`` is replaced with the models
    whose metadata yielded pricing. This may be an empty dict if none did.

    Args:
        config_file: Path of the YAML file to read.

    Returns:
        True if the file names a provider and lists models. False otherwise,
        including read/parse errors, which are logged.
    """
    try:
        with open(config_file, 'r', encoding='utf-8') as f:
            provider_data = yaml.safe_load(f)

        # An empty file parses to None; anything but a mapping is unusable.
        if not isinstance(provider_data, dict):
            return False

        provider_name = provider_data.get("provider")
        if not provider_name:
            return False

        models = provider_data.get("models") or []
        if not models:
            return False

        # Fill a local dict first so a failure mid-file does not leave a
        # partially populated provider entry behind.
        provider_pricing = {}
        for model in models:
            if not isinstance(model, dict):
                continue

            model_id = model.get("model_id")
            if not model_id:
                continue

            metadata = model.get("metadata") or {}
            pricing = self._extract_pricing_from_metadata(metadata, provider_name, model_id)
            if pricing:
                provider_pricing[model_id] = pricing

        self.pricing_data[provider_name] = provider_pricing
        logger.debug(f"Loaded pricing for {len(provider_pricing)} models from {provider_name}")
        return True

    except Exception as e:
        logger.error(f"Failed to load provider config {config_file}: {e}")
        return False
|
318
|
-
|
319
|
-
def _extract_pricing_from_metadata(self, metadata: Dict[str, Any], provider: str, model_name: str) -> Optional[ModelPricing]:
    """Build a ``ModelPricing`` from a provider-config model's metadata.

    The first recognised key present in ``metadata`` decides the pricing. Keys
    are checked in this order: ``cost_per_1000_chars``,
    ``cost_per_1000_tokens``, ``cost_per_minute``, ``cost_per_image``,
    ``cost_per_request``. ``cost_per_request`` becomes ``base_cost``, with
    ``unit_type`` left at ``"token"``. Every other key becomes ``input_cost``
    with the matching ``unit_type``. ``output_cost`` is always 0.0.

    NOTE(review): the numeric value is stored as-is; no per-1000 to per-unit
    conversion is applied here. Confirm this matches how costs are consumed.

    Returns:
        The pricing. None when no recognised key is present, the value is
        zero, or the value cannot be parsed (a warning is logged).
    """
    # (metadata key, resulting unit) pairs; tuple order sets precedence.
    recognised = (
        ("cost_per_1000_chars", "character"),
        ("cost_per_1000_tokens", "token"),
        ("cost_per_minute", "minute"),
        ("cost_per_image", "image"),
        ("cost_per_request", "request"),
    )

    try:
        unit_type = "token"
        input_cost = 0.0
        base_cost = 0.0

        hit = next(
            ((key, unit) for key, unit in recognised if key in metadata),
            None,
        )
        if hit is not None:
            key, unit = hit
            value = float(metadata[key])
            if unit == "request":
                # Flat per-request charge; the unit type keeps its default.
                base_cost = value
            else:
                input_cost = value
                unit_type = unit

        # Nothing usable (or an explicit zero): skip this model.
        if input_cost == 0.0 and base_cost == 0.0:
            return None

        return ModelPricing(
            provider=provider,
            model_name=model_name,
            input_cost=input_cost,
            output_cost=0.0,
            unit_type=unit_type,
            base_cost=base_cost,
            last_updated=datetime.now()
        )

    except Exception as e:
        logger.warning(f"Failed to extract pricing for {provider}/{model_name}: {e}")
        return None
|
374
|
-
|
375
|
-
def _extract_pricing_from_supabase_metadata(self, metadata: Dict[str, Any], provider: str, model_name: str) -> Optional[ModelPricing]:
    """Build a ``ModelPricing`` from a Supabase ``models.metadata`` value.

    Pricing comes from ``metadata['pricing']`` when that is present and
    non-empty. Otherwise the first flat field found, in the order listed
    below, is turned into an equivalent pricing dict.

    From the pricing dict this reads:
    ``input``/``input_cost``, ``output``/``output_cost``, ``unit_type``
    (default ``'token'``), ``base_cost``, ``infrastructure_cost_per_hour``
    and ``currency`` (default ``'USD'``).

    Returns:
        The pricing. None when no pricing is found, all costs are zero, or
        the values cannot be parsed (a warning is logged).
    """
    try:
        # A NULL metadata column arrives as None; treat it as empty.
        metadata = metadata or {}
        pricing_info = metadata.get('pricing', {})

        if not pricing_info:
            pricing_info = {}
            flat_fields = (
                'cost_per_1000_chars', 'cost_per_1000_tokens', 'cost_per_minute',
                'cost_per_image', 'cost_per_request', 'input_cost', 'output_cost',
                'cost_per_1k_tokens', 'cost_per_1k_chars',
            )
            for key in flat_fields:
                if key not in metadata:
                    continue
                if key in ('input_cost', 'output_cost'):
                    # Direct input/output prices. These names contain none of
                    # the unit substrings checked below, so they need their
                    # own branch; carry both sides over together.
                    pricing_info = {
                        'input_cost': metadata.get('input_cost', 0.0),
                        'output_cost': metadata.get('output_cost', 0.0),
                    }
                elif 'char' in key:
                    pricing_info = {'input': metadata[key], 'unit_type': 'character'}
                elif 'token' in key:
                    pricing_info = {'input': metadata[key], 'unit_type': 'token'}
                elif 'minute' in key:
                    pricing_info = {'input': metadata[key], 'unit_type': 'minute'}
                elif 'image' in key:
                    pricing_info = {'input': metadata[key], 'unit_type': 'image'}
                elif 'request' in key:
                    pricing_info = {'base_cost': metadata[key], 'unit_type': 'request'}
                # The first field present wins.
                break

        if not pricing_info:
            return None

        input_cost = float(pricing_info.get('input', pricing_info.get('input_cost', 0.0)))
        output_cost = float(pricing_info.get('output', pricing_info.get('output_cost', 0.0)))
        unit_type = pricing_info.get('unit_type', 'token')
        base_cost = float(pricing_info.get('base_cost', 0.0))
        infrastructure_cost_per_hour = float(pricing_info.get('infrastructure_cost_per_hour', 0.0))
        currency = pricing_info.get('currency', 'USD')

        # No non-zero cost at all: skip this model.
        if input_cost == 0.0 and output_cost == 0.0 and base_cost == 0.0:
            return None

        return ModelPricing(
            provider=provider,
            model_name=model_name,
            input_cost=input_cost,
            output_cost=output_cost,
            unit_type=unit_type,
            base_cost=base_cost,
            infrastructure_cost_per_hour=infrastructure_cost_per_hour,
            currency=currency,
            last_updated=datetime.now(),
            metadata=pricing_info
        )

    except Exception as e:
        logger.warning(f"Failed to extract pricing from Supabase metadata for {provider}/{model_name}: {e}")
        return None
|
436
|
-
|
437
190
|
def get_model_pricing(self, provider: str, model_name: str) -> Optional[ModelPricing]:
|
438
191
|
"""Get pricing information for a specific model"""
|
439
192
|
self._refresh_if_needed()
|
@@ -0,0 +1,366 @@
|
|
1
|
+
"""
|
2
|
+
Circuit Breaker Implementation for External Service Calls
|
3
|
+
|
4
|
+
Provides resilience patterns including:
|
5
|
+
- Circuit breaker for external API calls
|
6
|
+
- Retry with exponential backoff
|
7
|
+
- Timeout handling
|
8
|
+
- Health monitoring
|
9
|
+
"""
|
10
|
+
|
11
|
+
import asyncio
|
12
|
+
import time
|
13
|
+
import logging
|
14
|
+
import functools
|
15
|
+
from typing import Dict, Any, Optional, Callable, Union, List
|
16
|
+
from enum import Enum
|
17
|
+
from dataclasses import dataclass, field
|
18
|
+
from circuitbreaker import circuit
|
19
|
+
import structlog
|
20
|
+
|
21
|
+
# Module-wide structlog logger; the calls in this module pass context as
# keyword fields (e.g. name=..., state=...).
logger = structlog.get_logger(__name__)
|
22
|
+
|
23
|
+
class CircuitState(Enum):
    """Lifecycle states of a circuit breaker.

    CLOSED    -- normal operation; calls pass through.
    OPEN      -- the breaker has tripped; calls are blocked.
    HALF_OPEN -- trial calls probe whether the service has recovered.
    """

    CLOSED = "closed"
    OPEN = "open"
    HALF_OPEN = "half_open"
|
27
|
+
|
28
|
+
@dataclass
class CircuitBreakerConfig:
    """Tunable parameters for a circuit breaker.

    Attributes:
        failure_threshold: consecutive failures that open the circuit.
        recovery_timeout: seconds to wait while open before trying to recover.
        expected_exception: exception type that counts as a failure.
        success_threshold: consecutive successful calls needed to close the
            circuit again.
        timeout: request timeout in seconds.
    """

    failure_threshold: int = 5
    recovery_timeout: int = 30
    expected_exception: type = Exception
    success_threshold: int = 3
    timeout: float = 30.0
|
36
|
+
|
37
|
+
@dataclass
class CircuitBreakerStats:
    """Running counters kept by a circuit breaker.

    Attributes:
        total_calls / successful_calls / failed_calls: lifetime call counts.
        consecutive_failures / consecutive_successes: current streak lengths.
        last_failure_time / last_success_time: timestamps of the most recent
            failure / success, or None if none happened yet.
        state_changes: one dict appended per state transition; each instance
            gets its own list.
    """

    total_calls: int = 0
    successful_calls: int = 0
    failed_calls: int = 0
    consecutive_failures: int = 0
    consecutive_successes: int = 0
    last_failure_time: Optional[float] = None
    last_success_time: Optional[float] = None
    state_changes: List[Dict[str, Any]] = field(default_factory=list)
|
48
|
+
|
49
|
+
class EnhancedCircuitBreaker:
    """Circuit breaker with its own state machine, statistics and logging.

    Transitions:
      * CLOSED -> OPEN once ``config.failure_threshold`` consecutive failures
        are recorded.
      * OPEN -> HALF_OPEN on the first call made at least
        ``config.recovery_timeout`` seconds after the breaker opened. Until
        then, calls fail fast with ``CircuitBreakerOpenException``.
      * HALF_OPEN -> CLOSED after ``config.success_threshold`` consecutive
        successes. Any failure while HALF_OPEN re-opens the breaker.

    A "failure" is either an exception matching ``config.expected_exception``
    or a coroutine call that exceeds ``config.timeout``. Other exceptions
    propagate without affecting the state.

    Use ``await breaker.call(func, *args, **kwargs)`` or decorate a function
    with the breaker instance.
    """

    def __init__(self, name: str, config: CircuitBreakerConfig):
        self.name = name
        self.config = config
        self.stats = CircuitBreakerStats()
        self.state = CircuitState.CLOSED
        # Wall-clock timestamp of the last transition (reported in stats).
        self.last_state_change = time.time()
        # Monotonic counterpart used to measure the recovery timeout, so
        # system clock adjustments cannot shorten or extend it.
        self._last_state_change_monotonic = time.monotonic()

        logger.info("Circuit breaker initialized", name=name, config=config)

    def __call__(self, func: Callable):
        """Decorate ``func`` so every invocation goes through :meth:`call`.

        The returned wrapper is always a coroutine function. Callers must
        ``await`` it even when ``func`` itself is synchronous.
        """

        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            return await self.call(func, *args, **kwargs)

        # Returned as-is on purpose. Stacking the third-party ``circuitbreaker``
        # decorator on top would run a second, independent breaker. That
        # breaker has its own thresholds and state, counts this class's
        # Open/Timeout exceptions as failures, and raises its own error type.
        return wrapper

    async def call(self, func: Callable, *args, **kwargs):
        """Invoke ``func(*args, **kwargs)`` under circuit-breaker protection.

        Coroutine functions are awaited with ``config.timeout``. Plain
        callables are invoked directly and are not subject to the timeout.

        Raises:
            CircuitBreakerOpenException: the breaker is open and the recovery
                timeout has not elapsed yet.
            CircuitBreakerTimeoutException: a coroutine call exceeded
                ``config.timeout``.
            Exception: anything raised by ``func`` is re-raised unchanged.
        """
        start_time = time.monotonic()

        try:
            # Rejected calls are counted too.
            self.stats.total_calls += 1

            if self.state == CircuitState.OPEN:
                elapsed = time.monotonic() - self._last_state_change_monotonic
                if elapsed < self.config.recovery_timeout:
                    raise CircuitBreakerOpenException(
                        f"Circuit breaker '{self.name}' is open"
                    )
                # Recovery window has passed: let this call probe the service.
                self._change_state(CircuitState.HALF_OPEN)

            try:
                if asyncio.iscoroutinefunction(func):
                    result = await asyncio.wait_for(
                        func(*args, **kwargs),
                        timeout=self.config.timeout
                    )
                else:
                    result = func(*args, **kwargs)
            except asyncio.TimeoutError as exc:
                self._record_failure()
                raise CircuitBreakerTimeoutException(
                    f"Timeout after {self.config.timeout}s for '{self.name}'"
                ) from exc
            except self.config.expected_exception:
                self._record_failure()
                raise

            self._record_success()
            return result

        except Exception as e:
            duration = time.monotonic() - start_time
            logger.error(
                "Circuit breaker call failed",
                name=self.name,
                error=str(e),
                duration=duration,
                state=self.state.value
            )
            raise

    def _record_success(self):
        """Update counters after a successful call; may close a half-open breaker."""
        self.stats.successful_calls += 1
        self.stats.consecutive_successes += 1
        self.stats.consecutive_failures = 0
        self.stats.last_success_time = time.time()

        if (self.state == CircuitState.HALF_OPEN and
                self.stats.consecutive_successes >= self.config.success_threshold):
            self._change_state(CircuitState.CLOSED)

        logger.debug(
            "Circuit breaker success",
            name=self.name,
            consecutive_successes=self.stats.consecutive_successes,
            state=self.state.value
        )

    def _record_failure(self):
        """Update counters after a failed call; may open the breaker."""
        self.stats.failed_calls += 1
        self.stats.consecutive_failures += 1
        self.stats.consecutive_successes = 0
        self.stats.last_failure_time = time.time()

        if (self.state == CircuitState.CLOSED and
                self.stats.consecutive_failures >= self.config.failure_threshold):
            self._change_state(CircuitState.OPEN)
        elif self.state == CircuitState.HALF_OPEN:
            # Any failure while probing re-opens the circuit immediately.
            self._change_state(CircuitState.OPEN)

        logger.warning(
            "Circuit breaker failure",
            name=self.name,
            consecutive_failures=self.stats.consecutive_failures,
            state=self.state.value
        )

    def _change_state(self, new_state: CircuitState):
        """Switch to ``new_state``, timestamp it and append it to the history."""
        old_state = self.state
        self.state = new_state
        self.last_state_change = time.time()
        self._last_state_change_monotonic = time.monotonic()

        if new_state == CircuitState.HALF_OPEN:
            # Successes from calls that were already in flight while the
            # breaker was open must not count toward closing it again.
            self.stats.consecutive_successes = 0

        state_change = {
            "from_state": old_state.value,
            "to_state": new_state.value,
            "timestamp": self.last_state_change,
            "total_calls": self.stats.total_calls,
            "consecutive_failures": self.stats.consecutive_failures
        }
        self.stats.state_changes.append(state_change)

        logger.warning(
            "Circuit breaker state changed",
            name=self.name,
            from_state=old_state.value,
            to_state=new_state.value,
            consecutive_failures=self.stats.consecutive_failures
        )

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of the breaker's counters, state and config."""
        return {
            "name": self.name,
            "state": self.state.value,
            "total_calls": self.stats.total_calls,
            "successful_calls": self.stats.successful_calls,
            "failed_calls": self.stats.failed_calls,
            "success_rate": (
                self.stats.successful_calls / self.stats.total_calls
                if self.stats.total_calls > 0 else 0
            ),
            "consecutive_failures": self.stats.consecutive_failures,
            "consecutive_successes": self.stats.consecutive_successes,
            "last_failure_time": self.stats.last_failure_time,
            "last_success_time": self.stats.last_success_time,
            "last_state_change": self.last_state_change,
            "config": {
                "failure_threshold": self.config.failure_threshold,
                "recovery_timeout": self.config.recovery_timeout,
                "timeout": self.config.timeout
            }
        }

    def reset(self):
        """Return to CLOSED with fresh statistics."""
        self.state = CircuitState.CLOSED
        self.stats = CircuitBreakerStats()
        self.last_state_change = time.time()
        self._last_state_change_monotonic = time.monotonic()

        logger.info("Circuit breaker reset", name=self.name)
|
228
|
+
|
229
|
+
class CircuitBreakerOpenException(Exception):
    """Raised instead of calling the protected function while the circuit is open."""
|
232
|
+
|
233
|
+
class CircuitBreakerTimeoutException(Exception):
    """Raised when a protected call does not finish within the configured timeout."""
|
236
|
+
|
237
|
+
# Process-wide registry of named circuit breakers.
_circuit_breakers: Dict[str, EnhancedCircuitBreaker] = {}


def get_circuit_breaker(
    name: str,
    config: Optional[CircuitBreakerConfig] = None
) -> EnhancedCircuitBreaker:
    """Return the breaker registered under ``name``, creating it on first use.

    ``config`` is only consulted when no breaker with that name exists yet;
    a default ``CircuitBreakerConfig`` is used when it is None. Later calls
    return the existing instance and ignore ``config``.
    """
    breaker = _circuit_breakers.get(name)
    if breaker is None:
        effective_config = CircuitBreakerConfig() if config is None else config
        breaker = EnhancedCircuitBreaker(name, effective_config)
        _circuit_breakers[name] = breaker
    return breaker
|
251
|
+
|
252
|
+
def circuit_breaker(
    name: str,
    failure_threshold: int = 5,
    recovery_timeout: int = 30,
    timeout: float = 30.0,
    expected_exception: type = Exception,
    success_threshold: int = 3
):
    """Return the named circuit breaker, for use as a decorator.

    Usage::

        @circuit_breaker("my-service", failure_threshold=3)
        async def fetch(...): ...

    Args:
        name: registry key; breakers are shared per name.
        failure_threshold: consecutive failures that open the circuit.
        recovery_timeout: seconds to stay open before half-open probing.
        timeout: per-call timeout in seconds.
        expected_exception: exception type counted as a failure.
        success_threshold: consecutive half-open successes needed to close
            the circuit again (previously fixed at the config default of 3).

    Note:
        The settings only take effect when ``name`` is not registered yet;
        an existing breaker with that name is returned unchanged.
    """
    config = CircuitBreakerConfig(
        failure_threshold=failure_threshold,
        recovery_timeout=recovery_timeout,
        timeout=timeout,
        expected_exception=expected_exception,
        success_threshold=success_threshold
    )

    return get_circuit_breaker(name, config)
|
269
|
+
|
270
|
+
# Predefined circuit breakers for common services


def openai_circuit_breaker():
    """Breaker for OpenAI API calls: 3 failures, 60 s recovery, 120 s timeout."""
    # Long timeout because OpenAI responses can be slow.
    return circuit_breaker("openai", failure_threshold=3, recovery_timeout=60, timeout=120.0)


def replicate_circuit_breaker():
    """Breaker for Replicate API calls: 3 failures, 45 s recovery, 300 s timeout."""
    # Very long timeout because Replicate image generation can be very slow.
    return circuit_breaker("replicate", failure_threshold=3, recovery_timeout=45, timeout=300.0)


def database_circuit_breaker():
    """Breaker for database calls: 5 failures, 20 s recovery, 10 s timeout."""
    return circuit_breaker("database", failure_threshold=5, recovery_timeout=20, timeout=10.0)


def redis_circuit_breaker():
    """Breaker for Redis calls: 3 failures, 15 s recovery, 5 s timeout."""
    return circuit_breaker("redis", failure_threshold=3, recovery_timeout=15, timeout=5.0)
|
306
|
+
|
307
|
+
# Health check for all circuit breakers
async def check_circuit_breakers_health() -> Dict[str, Any]:
    """Summarise the health of every registered circuit breaker.

    Returns a dict with:
      * ``circuit_breakers``: name -> ``get_stats()`` snapshot.
      * ``total_breakers``: number of registered breakers.
      * ``open_breakers``: how many of them are currently open.
      * ``status``: ``"degraded"`` if any breaker is open, else ``"healthy"``.
    """
    snapshots = {name: breaker.get_stats() for name, breaker in _circuit_breakers.items()}
    open_count = sum(1 for snapshot in snapshots.values() if snapshot["state"] == "open")

    return {
        "circuit_breakers": snapshots,
        "total_breakers": len(_circuit_breakers),
        "open_breakers": open_count,
        "status": "degraded" if open_count > 0 else "healthy",
    }
|
329
|
+
|
330
|
+
# Utility functions for retry with exponential backoff
async def retry_with_backoff(
    func: Callable,
    max_retries: int = 3,
    base_delay: float = 1.0,
    max_delay: float = 60.0,
    exponential_base: float = 2.0,
    exceptions: tuple = (Exception,)
):
    """Call ``func`` until it succeeds, sleeping with exponential backoff.

    ``func`` takes no arguments; bind any with a lambda or
    ``functools.partial``. Coroutine functions are awaited and plain
    callables are called. ``func`` is attempted at most ``max_retries + 1``
    times. The sleep after failed attempt ``k`` (0-based) is
    ``min(base_delay * exponential_base ** k, max_delay)`` seconds.

    Args:
        func: zero-argument callable or coroutine function.
        max_retries: number of retries after the first attempt (>= 0).
        base_delay: delay before the first retry, in seconds.
        max_delay: upper bound on any single delay, in seconds.
        exponential_base: growth factor of the delay per attempt.
        exceptions: exception types that trigger a retry.

    Returns:
        The value returned by the first successful attempt.

    Raises:
        ValueError: if ``max_retries`` is negative.
        Exception: the last error from ``func`` once all attempts fail.
            Exceptions not listed in ``exceptions`` propagate immediately.
    """
    if max_retries < 0:
        # Without this guard the loop would run zero times and the coroutine
        # would silently return None without ever calling func.
        raise ValueError(f"max_retries must be >= 0, got {max_retries}")

    for attempt in range(max_retries + 1):
        try:
            if asyncio.iscoroutinefunction(func):
                return await func()
            return func()
        except exceptions as e:
            if attempt == max_retries:
                logger.error(
                    "All retry attempts failed",
                    attempts=attempt + 1,
                    error=str(e)
                )
                raise

            delay = min(base_delay * (exponential_base ** attempt), max_delay)

            logger.warning(
                "Retry attempt failed, backing off",
                attempt=attempt + 1,
                max_retries=max_retries,
                delay=delay,
                error=str(e)
            )

            await asyncio.sleep(delay)
|