loom-agent 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of loom-agent might be problematic. Click here for more details.
- loom/__init__.py +77 -0
- loom/agent.py +217 -0
- loom/agents/__init__.py +10 -0
- loom/agents/refs.py +28 -0
- loom/agents/registry.py +50 -0
- loom/builtin/compression/__init__.py +4 -0
- loom/builtin/compression/structured.py +79 -0
- loom/builtin/embeddings/__init__.py +9 -0
- loom/builtin/embeddings/openai_embedding.py +135 -0
- loom/builtin/embeddings/sentence_transformers_embedding.py +145 -0
- loom/builtin/llms/__init__.py +8 -0
- loom/builtin/llms/mock.py +34 -0
- loom/builtin/llms/openai.py +168 -0
- loom/builtin/llms/rule.py +102 -0
- loom/builtin/memory/__init__.py +5 -0
- loom/builtin/memory/in_memory.py +21 -0
- loom/builtin/memory/persistent_memory.py +278 -0
- loom/builtin/retriever/__init__.py +9 -0
- loom/builtin/retriever/chroma_store.py +265 -0
- loom/builtin/retriever/in_memory.py +106 -0
- loom/builtin/retriever/milvus_store.py +307 -0
- loom/builtin/retriever/pinecone_store.py +237 -0
- loom/builtin/retriever/qdrant_store.py +274 -0
- loom/builtin/retriever/vector_store.py +128 -0
- loom/builtin/retriever/vector_store_config.py +217 -0
- loom/builtin/tools/__init__.py +32 -0
- loom/builtin/tools/calculator.py +49 -0
- loom/builtin/tools/document_search.py +111 -0
- loom/builtin/tools/glob.py +27 -0
- loom/builtin/tools/grep.py +56 -0
- loom/builtin/tools/http_request.py +86 -0
- loom/builtin/tools/python_repl.py +73 -0
- loom/builtin/tools/read_file.py +32 -0
- loom/builtin/tools/task.py +158 -0
- loom/builtin/tools/web_search.py +64 -0
- loom/builtin/tools/write_file.py +31 -0
- loom/callbacks/base.py +9 -0
- loom/callbacks/logging.py +12 -0
- loom/callbacks/metrics.py +27 -0
- loom/callbacks/observability.py +248 -0
- loom/components/agent.py +107 -0
- loom/core/agent_executor.py +450 -0
- loom/core/circuit_breaker.py +178 -0
- loom/core/compression_manager.py +329 -0
- loom/core/context_retriever.py +185 -0
- loom/core/error_classifier.py +193 -0
- loom/core/errors.py +66 -0
- loom/core/message_queue.py +167 -0
- loom/core/permission_store.py +62 -0
- loom/core/permissions.py +69 -0
- loom/core/scheduler.py +125 -0
- loom/core/steering_control.py +47 -0
- loom/core/structured_logger.py +279 -0
- loom/core/subagent_pool.py +232 -0
- loom/core/system_prompt.py +141 -0
- loom/core/system_reminders.py +283 -0
- loom/core/tool_pipeline.py +113 -0
- loom/core/types.py +269 -0
- loom/interfaces/compressor.py +59 -0
- loom/interfaces/embedding.py +51 -0
- loom/interfaces/llm.py +33 -0
- loom/interfaces/memory.py +29 -0
- loom/interfaces/retriever.py +179 -0
- loom/interfaces/tool.py +27 -0
- loom/interfaces/vector_store.py +80 -0
- loom/llm/__init__.py +14 -0
- loom/llm/config.py +228 -0
- loom/llm/factory.py +111 -0
- loom/llm/model_health.py +235 -0
- loom/llm/model_pool_advanced.py +305 -0
- loom/llm/pool.py +170 -0
- loom/llm/registry.py +201 -0
- loom/mcp/__init__.py +4 -0
- loom/mcp/client.py +86 -0
- loom/mcp/registry.py +58 -0
- loom/mcp/tool_adapter.py +48 -0
- loom/observability/__init__.py +5 -0
- loom/patterns/__init__.py +5 -0
- loom/patterns/multi_agent.py +123 -0
- loom/patterns/rag.py +262 -0
- loom/plugins/registry.py +55 -0
- loom/resilience/__init__.py +5 -0
- loom/tooling.py +72 -0
- loom/utils/agent_loader.py +218 -0
- loom/utils/token_counter.py +19 -0
- loom_agent-0.0.1.dist-info/METADATA +457 -0
- loom_agent-0.0.1.dist-info/RECORD +89 -0
- loom_agent-0.0.1.dist-info/WHEEL +4 -0
- loom_agent-0.0.1.dist-info/licenses/LICENSE +21 -0
loom/llm/factory.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
"""LLM 工厂模式
|
|
2
|
+
|
|
3
|
+
根据配置自动创建对应的 LLM 实例,支持:
|
|
4
|
+
- 自动注册与发现
|
|
5
|
+
- 统一的创建接口
|
|
6
|
+
- 类型安全
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from typing import Dict, Type, Any, cast
|
|
12
|
+
from .config import LLMConfig, LLMProvider
|
|
13
|
+
from loom.interfaces.llm import BaseLLM
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class LLMFactory:
|
|
17
|
+
"""LLM 工厂
|
|
18
|
+
|
|
19
|
+
负责根据配置创建对应的 LLM 实例。
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
_registry: Dict[LLMProvider, Type[BaseLLM]] = {}
|
|
23
|
+
|
|
24
|
+
@classmethod
|
|
25
|
+
def register(cls, provider: LLMProvider, llm_class: Type[BaseLLM]):
|
|
26
|
+
cls._registry[provider] = llm_class
|
|
27
|
+
|
|
28
|
+
@classmethod
|
|
29
|
+
def create(cls, config: LLMConfig) -> BaseLLM:
|
|
30
|
+
"""根据配置创建 LLM 实例。优先调用 from_config,其次按 provider 适配构造。"""
|
|
31
|
+
cls._ensure_registered()
|
|
32
|
+
|
|
33
|
+
if config.provider not in cls._registry:
|
|
34
|
+
raise ValueError(
|
|
35
|
+
f"Unsupported LLM provider: {config.provider}. Available: {list(cls._registry.keys())}"
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
llm_class = cls._registry[config.provider]
|
|
39
|
+
|
|
40
|
+
# 优先使用 classmethod from_config(config)
|
|
41
|
+
from_config = getattr(llm_class, "from_config", None)
|
|
42
|
+
if callable(from_config):
|
|
43
|
+
return cast(BaseLLM, from_config(config))
|
|
44
|
+
|
|
45
|
+
# 无 from_config,按 provider 适配常用构造参数
|
|
46
|
+
if config.provider in (LLMProvider.OPENAI, LLMProvider.AZURE_OPENAI, LLMProvider.CUSTOM):
|
|
47
|
+
kwargs: dict[str, Any] = {
|
|
48
|
+
"model": config.model_name,
|
|
49
|
+
"temperature": config.temperature,
|
|
50
|
+
}
|
|
51
|
+
if config.max_tokens is not None:
|
|
52
|
+
kwargs["max_tokens"] = config.max_tokens
|
|
53
|
+
if config.base_url is not None:
|
|
54
|
+
kwargs["base_url"] = config.base_url
|
|
55
|
+
if config.api_key is not None:
|
|
56
|
+
kwargs["api_key"] = config.api_key
|
|
57
|
+
kwargs.update(config.extra_params)
|
|
58
|
+
return llm_class(**kwargs) # type: ignore[arg-type]
|
|
59
|
+
|
|
60
|
+
# 其他提供商暂不内置适配
|
|
61
|
+
raise ValueError(f"Provider {config.provider} has no default constructor mapping")
|
|
62
|
+
|
|
63
|
+
@classmethod
|
|
64
|
+
def from_dict(cls, config_dict: Dict[str, Any]) -> BaseLLM:
|
|
65
|
+
config = LLMConfig.from_dict(config_dict)
|
|
66
|
+
return cls.create(config)
|
|
67
|
+
|
|
68
|
+
# 便捷创建方法
|
|
69
|
+
@classmethod
|
|
70
|
+
def create_openai(
|
|
71
|
+
cls,
|
|
72
|
+
api_key: str,
|
|
73
|
+
model: str = "gpt-4",
|
|
74
|
+
**kwargs
|
|
75
|
+
) -> BaseLLM:
|
|
76
|
+
config = LLMConfig.openai(api_key=api_key, model=model, **kwargs)
|
|
77
|
+
return cls.create(config)
|
|
78
|
+
|
|
79
|
+
# 私有:延迟注册可用实现
|
|
80
|
+
@classmethod
|
|
81
|
+
def _ensure_registered(cls):
|
|
82
|
+
if cls._registry:
|
|
83
|
+
return
|
|
84
|
+
|
|
85
|
+
# 注册内置实现(按实际存在的实现)
|
|
86
|
+
try:
|
|
87
|
+
from loom.builtin.llms.openai import OpenAILLM
|
|
88
|
+
cls.register(LLMProvider.OPENAI, OpenAILLM)
|
|
89
|
+
# 兼容自定义/代理/azure: 复用 OpenAI 客户端风格
|
|
90
|
+
cls.register(LLMProvider.AZURE_OPENAI, OpenAILLM)
|
|
91
|
+
cls.register(LLMProvider.CUSTOM, OpenAILLM)
|
|
92
|
+
except Exception:
|
|
93
|
+
pass
|
|
94
|
+
|
|
95
|
+
try:
|
|
96
|
+
from loom.builtin.llms.mock import MockLLM
|
|
97
|
+
# 仅用于测试时手动注册(无 provider 枚举映射)
|
|
98
|
+
except Exception:
|
|
99
|
+
pass
|
|
100
|
+
|
|
101
|
+
try:
|
|
102
|
+
from loom.builtin.llms.rule import RuleLLM
|
|
103
|
+
# 规则引擎型 LLM(不绑定 provider)
|
|
104
|
+
except Exception:
|
|
105
|
+
pass
|
|
106
|
+
|
|
107
|
+
@classmethod
|
|
108
|
+
def list_available_providers(cls) -> list[str]:
|
|
109
|
+
cls._ensure_registered()
|
|
110
|
+
return [provider.value for provider in cls._registry.keys()]
|
|
111
|
+
|
loom/llm/model_health.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
1
|
+
"""US8: Model Health Checking
|
|
2
|
+
|
|
3
|
+
Tracks model health and enables intelligent fallback decisions.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import time
|
|
9
|
+
from typing import Dict, Optional
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from enum import Enum
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class HealthStatus(str, Enum):
    """Health state assigned to a model by ``ModelHealthChecker``.

    The enum also subclasses ``str``, so each member compares equal to its
    string value (for example ``HealthStatus.HEALTHY == "healthy"``).
    """

    HEALTHY = "healthy"
    DEGRADED = "degraded"
    UNHEALTHY = "unhealthy"
    # Reported for models that have no recorded requests.
    UNKNOWN = "unknown"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class HealthMetrics:
    """Health counters for a single model.

    Only ``status`` is required; every counter starts at zero and the
    timestamps start as ``None``.
    """

    status: HealthStatus
    success_count: int = 0
    failure_count: int = 0
    total_requests: int = 0
    # time.time() of the most recent success / failure, None if there was none.
    last_success_time: Optional[float] = None
    last_failure_time: Optional[float] = None
    avg_latency_ms: float = 0.0
    # Failures since the last success.
    consecutive_failures: int = 0

    @property
    def success_rate(self) -> float:
        """Fraction of recorded requests that succeeded (0.0 when none recorded)."""
        if self.total_requests == 0:
            return 0.0
        return self.success_count / self.total_requests

    @property
    def failure_rate(self) -> float:
        """Fraction of recorded requests that failed (0.0 when none recorded).

        With no requests there have been no failures either, so this is 0.0
        rather than the complement of ``success_rate``.
        """
        if self.total_requests == 0:
            return 0.0
        return 1.0 - self.success_rate
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class ModelHealthChecker:
    """Tracks and evaluates model health.

    Lifetime counters are kept in ``HealthMetrics``; the health *status* is
    derived from the most recent ``health_check_window`` request outcomes so
    that old failures (or successes) age out.

    Example:
        checker = ModelHealthChecker()

        # Record success
        checker.record_success("gpt-4", latency_ms=234.5)

        # Record failure
        checker.record_failure("gpt-4", error="timeout")

        # Check health
        status = checker.get_status("gpt-4")
        print(f"Model health: {status}")  # HEALTHY / DEGRADED / UNHEALTHY
    """

    def __init__(
        self,
        degraded_threshold: float = 0.8,  # <80% success rate = degraded
        unhealthy_threshold: float = 0.5,  # <50% success rate = unhealthy
        consecutive_failure_threshold: int = 5,  # 5 consecutive failures = unhealthy
        health_check_window: int = 100,  # Last 100 requests
    ):
        """Initialize health checker.

        Args:
            degraded_threshold: Success rate below which a model is degraded
            unhealthy_threshold: Success rate below which a model is unhealthy
            consecutive_failure_threshold: Consecutive failures to mark unhealthy
            health_check_window: Number of recent requests to consider
        """
        self.degraded_threshold = degraded_threshold
        self.unhealthy_threshold = unhealthy_threshold
        self.consecutive_failure_threshold = consecutive_failure_threshold
        self.health_check_window = health_check_window

        self._metrics: Dict[str, HealthMetrics] = {}
        self._latency_samples: Dict[str, list[float]] = {}  # Rolling window
        # Rolling window of request outcomes (True = success) per model.
        self._outcomes: Dict[str, list[bool]] = {}

    def _metrics_for(self, model_id: str) -> HealthMetrics:
        """Return the metrics for ``model_id``, creating them on first use."""
        metrics = self._metrics.get(model_id)
        if metrics is None:
            metrics = HealthMetrics(status=HealthStatus.UNKNOWN)
            self._metrics[model_id] = metrics
        return metrics

    def _trim_to_window(self, samples: list) -> None:
        """Drop the oldest entries so at most ``health_check_window`` remain."""
        overflow = len(samples) - self.health_check_window
        if overflow > 0:
            del samples[:overflow]

    def _record_outcome(self, model_id: str, succeeded: bool) -> None:
        """Append one request outcome to the model's rolling window."""
        outcomes = self._outcomes.setdefault(model_id, [])
        outcomes.append(succeeded)
        self._trim_to_window(outcomes)

    def record_success(
        self,
        model_id: str,
        latency_ms: float = 0.0
    ) -> None:
        """Record a successful request.

        Args:
            model_id: Model identifier
            latency_ms: Request latency in milliseconds
        """
        metrics = self._metrics_for(model_id)
        metrics.success_count += 1
        metrics.total_requests += 1
        metrics.last_success_time = time.time()
        metrics.consecutive_failures = 0

        # Average latency is computed over the rolling window only.
        samples = self._latency_samples.setdefault(model_id, [])
        samples.append(latency_ms)
        self._trim_to_window(samples)
        if samples:
            metrics.avg_latency_ms = sum(samples) / len(samples)

        self._record_outcome(model_id, True)
        self._update_status(model_id)

    def record_failure(
        self,
        model_id: str,
        error: Optional[str] = None
    ) -> None:
        """Record a failed request.

        Args:
            model_id: Model identifier
            error: Error message (accepted for callers; currently not stored)
        """
        metrics = self._metrics_for(model_id)
        metrics.failure_count += 1
        metrics.total_requests += 1
        metrics.last_failure_time = time.time()
        metrics.consecutive_failures += 1

        self._record_outcome(model_id, False)
        self._update_status(model_id)

    def _update_status(self, model_id: str) -> None:
        """Update health status based on metrics.

        Args:
            model_id: Model identifier (must already be tracked)
        """
        metrics = self._metrics[model_id]

        # A run of consecutive failures overrides the success rate.
        if metrics.consecutive_failures >= self.consecutive_failure_threshold:
            metrics.status = HealthStatus.UNHEALTHY
            return

        recent = self._outcomes.get(model_id)
        if recent:
            # Windowed rate: only the most recent requests count.
            success_rate = sum(recent) / len(recent)
        elif metrics.total_requests > 0:
            # No window available (e.g. window size <= 0): use lifetime rate.
            success_rate = metrics.success_rate
        else:
            metrics.status = HealthStatus.UNKNOWN
            return

        if success_rate >= self.degraded_threshold:
            metrics.status = HealthStatus.HEALTHY
        elif success_rate >= self.unhealthy_threshold:
            metrics.status = HealthStatus.DEGRADED
        else:
            metrics.status = HealthStatus.UNHEALTHY

    def get_status(self, model_id: str) -> HealthStatus:
        """Get current health status.

        Args:
            model_id: Model identifier

        Returns:
            Current health status (UNKNOWN for untracked models)
        """
        metrics = self._metrics.get(model_id)
        if metrics is None:
            return HealthStatus.UNKNOWN
        return metrics.status

    def get_metrics(self, model_id: str) -> Optional[HealthMetrics]:
        """Get detailed health metrics.

        Args:
            model_id: Model identifier

        Returns:
            HealthMetrics or None if not tracked
        """
        return self._metrics.get(model_id)

    def is_healthy(self, model_id: str) -> bool:
        """Check if model is healthy.

        Args:
            model_id: Model identifier

        Returns:
            True if healthy, False otherwise (including degraded and unknown)
        """
        return self.get_status(model_id) == HealthStatus.HEALTHY

    def get_all_healthy_models(self) -> list[str]:
        """Get list of all healthy models.

        Returns:
            List of model IDs with healthy status
        """
        return [
            model_id
            for model_id, metrics in self._metrics.items()
            if metrics.status == HealthStatus.HEALTHY
        ]

    def reset(self, model_id: Optional[str] = None) -> None:
        """Reset health metrics.

        Args:
            model_id: Model to reset (None = reset all)
        """
        if model_id is None:
            self._metrics.clear()
            self._latency_samples.clear()
            self._outcomes.clear()
            return

        if model_id in self._metrics:
            self._metrics[model_id] = HealthMetrics(status=HealthStatus.UNKNOWN)
        if model_id in self._latency_samples:
            self._latency_samples[model_id] = []
        self._outcomes.pop(model_id, None)
|
loom/llm/model_pool_advanced.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
1
|
+
"""US8: Advanced Model Pool with Fallback Chain
|
|
2
|
+
|
|
3
|
+
Provides intelligent model selection with health-aware fallback.
|
|
4
|
+
|
|
5
|
+
Features:
|
|
6
|
+
- Fallback chain: [primary, fallback1, fallback2]
|
|
7
|
+
- Automatic fallback on 5xx errors
|
|
8
|
+
- Health-based model selection
|
|
9
|
+
- Connection pooling for reduced latency
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import asyncio
|
|
15
|
+
from typing import List, Optional, Dict, Any
|
|
16
|
+
from dataclasses import dataclass
|
|
17
|
+
|
|
18
|
+
from loom.interfaces.llm import BaseLLM
|
|
19
|
+
from loom.llm.model_health import ModelHealthChecker, HealthStatus
|
|
20
|
+
from loom.core.error_classifier import ErrorClassifier, ErrorCategory
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass
class ModelConfig:
    """Configuration for a model in the pool."""

    # Identifier used as the key for health tracking and semaphores.
    model_id: str
    # The LLM instance that requests are delegated to.
    llm: BaseLLM
    priority: int = 0  # Higher priority = preferred (for same health status)
    max_concurrent: int = 10  # Max concurrent requests to this model
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class FallbackChain:
    """Manages fallback chain for model selection.

    Example:
        chain = FallbackChain([
            ModelConfig("gpt-4", gpt4_llm, priority=100),
            ModelConfig("gpt-3.5-turbo", gpt35_llm, priority=50),
            ModelConfig("claude-2", claude_llm, priority=30),
        ])

        # Select the best available model (synchronous)
        model_config = chain.get_next_model()

        # Or call a method with automatic fallback
        text = await chain.call_with_fallback("generate", messages)
    """

    def __init__(
        self,
        models: List[ModelConfig],
        health_checker: Optional[ModelHealthChecker] = None,
    ):
        """Initialize fallback chain.

        Args:
            models: List of model configurations (ordered by preference)
            health_checker: Health checker instance (a new one is created
                when omitted)
        """
        self.models = models
        self.health_checker = (
            health_checker if health_checker is not None else ModelHealthChecker()
        )
        self._model_map = {m.model_id: m for m in models}

        # One semaphore per model caps its concurrent in-flight requests.
        self._semaphores: Dict[str, asyncio.Semaphore] = {
            model.model_id: asyncio.Semaphore(model.max_concurrent)
            for model in models
        }

    def get_next_model(
        self,
        skip_unhealthy: bool = True,
        prefer_healthy: bool = True,
        exclude: Optional[List[str]] = None,
    ) -> Optional[ModelConfig]:
        """Get next best model from fallback chain.

        Selection order: healthy models (or healthy + degraded when
        ``prefer_healthy`` is False or no model is healthy), then models with
        unknown health, then - only if ``skip_unhealthy`` is False - unhealthy
        models. Within the chosen group the highest ``priority`` wins; ties go
        to the model listed first.

        Args:
            skip_unhealthy: Skip models marked as unhealthy
            prefer_healthy: Prefer healthy models over degraded
            exclude: Model IDs to leave out of the selection (e.g. models
                already tried for the current request)

        Returns:
            ModelConfig or None if no models available
        """
        excluded = set(exclude or ())

        # Categorize models by health
        healthy: List[ModelConfig] = []
        degraded: List[ModelConfig] = []
        unhealthy: List[ModelConfig] = []
        unknown: List[ModelConfig] = []

        for model in self.models:
            if model.model_id in excluded:
                continue

            status = self.health_checker.get_status(model.model_id)
            if status == HealthStatus.HEALTHY:
                healthy.append(model)
            elif status == HealthStatus.DEGRADED:
                degraded.append(model)
            elif status == HealthStatus.UNHEALTHY:
                unhealthy.append(model)
            else:
                unknown.append(model)

        # Select based on preference
        candidates: List[ModelConfig] = []
        if prefer_healthy and healthy:
            candidates = healthy
        elif healthy or degraded:
            candidates = healthy + degraded
        elif unknown:
            candidates = unknown
        elif not skip_unhealthy:
            candidates = unhealthy

        if not candidates:
            return None

        # Highest priority wins; max() keeps the first model on ties.
        return max(candidates, key=lambda m: m.priority)

    async def call_with_fallback(
        self,
        operation: str,  # "generate" or "generate_with_tools"
        *args: Any,
        max_fallback_attempts: int = 3,
        **kwargs: Any,
    ) -> Any:
        """Call model with automatic fallback on failure.

        Each attempt prefers a model that has not been tried yet for this
        call. Once every selectable model has been tried, the best model is
        retried while attempts remain. Only errors classified as
        ``ErrorCategory.SERVICE_ERROR`` trigger another attempt; any other
        error is re-raised immediately.

        Args:
            operation: LLM method name
            *args: Positional arguments
            max_fallback_attempts: Max number of attempts
            **kwargs: Keyword arguments

        Returns:
            Result from successful model call

        Raises:
            The last exception if all attempts fail, or RuntimeError if no
            model could be selected at all
        """
        import time  # not imported at module level in this file

        tried: List[str] = []
        last_exception: Optional[Exception] = None

        for _ in range(max_fallback_attempts):
            model_config = self.get_next_model(exclude=tried) or self.get_next_model()
            if model_config is None:
                break
            tried.append(model_config.model_id)

            try:
                # The semaphore caps concurrent requests to this model.
                async with self._semaphores[model_config.model_id]:
                    # perf_counter is monotonic, so the measured latency is
                    # not affected by system clock adjustments.
                    start = time.perf_counter()
                    method = getattr(model_config.llm, operation)
                    result = await method(*args, **kwargs)
                    latency_ms = (time.perf_counter() - start) * 1000
            except Exception as exc:
                last_exception = exc
                self.health_checker.record_failure(
                    model_config.model_id,
                    error=str(exc),
                )

                if ErrorClassifier.classify(exc) != ErrorCategory.SERVICE_ERROR:
                    # Non-retryable error - propagate
                    raise
                # Service error (treated as 5xx-style) - try the next model
                continue

            self.health_checker.record_success(
                model_config.model_id,
                latency_ms=latency_ms,
            )
            return result

        # All attempts failed
        if last_exception:
            raise last_exception
        raise RuntimeError("No healthy models available")

    def get_health_summary(self) -> Dict[str, Any]:
        """Get health summary for all models.

        Returns:
            Dict mapping each model ID to its status, success rate, average
            latency and consecutive failures, or ``{"status": "unknown"}``
            for models without recorded metrics
        """
        summary: Dict[str, Any] = {}

        for model in self.models:
            metrics = self.health_checker.get_metrics(model.model_id)
            if metrics:
                summary[model.model_id] = {
                    "status": metrics.status.value,
                    "success_rate": metrics.success_rate,
                    "avg_latency_ms": metrics.avg_latency_ms,
                    "consecutive_failures": metrics.consecutive_failures,
                }
            else:
                summary[model.model_id] = {"status": "unknown"}

        return summary
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
class ModelPoolLLM(BaseLLM):
    """LLM wrapper that uses fallback chain.

    Drop-in replacement for any BaseLLM with automatic fallback.

    Example:
        # Create pool with fallback chain
        pool_llm = ModelPoolLLM([
            ModelConfig("gpt-4", gpt4_llm, priority=100),
            ModelConfig("gpt-3.5-turbo", gpt35_llm, priority=50),
        ])

        # Use like any LLM
        agent = Agent(llm=pool_llm, tools=tools)
    """

    def __init__(
        self,
        models: List[ModelConfig],
        max_fallback_attempts: int = 3,
        health_checker: Optional[ModelHealthChecker] = None,
    ):
        """Initialize model pool LLM.

        Args:
            models: List of model configurations
            max_fallback_attempts: Max attempts per request on failure
            health_checker: Optional health checker to share with other
                pools; the fallback chain creates its own when omitted
        """
        self.fallback_chain = FallbackChain(models, health_checker)
        self.max_fallback_attempts = max_fallback_attempts

        # Use first model's capabilities as default
        if models:
            self._supports_tools = models[0].llm.supports_tools
            self._model_name = f"pool({','.join(m.model_id for m in models)})"
        else:
            self._supports_tools = False
            self._model_name = "pool(empty)"

    @property
    def model_name(self) -> str:
        """Get pool model name (``pool(<id>,<id>,...)``)."""
        return self._model_name

    @property
    def supports_tools(self) -> bool:
        """Check if pool supports tools (taken from the first model only)."""
        return self._supports_tools

    async def generate(self, messages: List[dict]) -> str:
        """Generate completion with automatic fallback."""
        return await self.fallback_chain.call_with_fallback(
            "generate",
            messages,
            max_fallback_attempts=self.max_fallback_attempts,
        )

    async def generate_with_tools(
        self,
        messages: List[dict],
        tools: List[dict],
    ) -> dict:
        """Generate with tools, with automatic fallback."""
        return await self.fallback_chain.call_with_fallback(
            "generate_with_tools",
            messages,
            tools,
            max_fallback_attempts=self.max_fallback_attempts,
        )

    async def stream(self, messages: List[dict]):
        """Stream responses from the currently best model.

        Streaming itself has no mid-stream fallback and does not update the
        health checker. If the selected model has no ``stream`` attribute the
        full (non-streaming) result is yielded as a single chunk.

        Raises:
            RuntimeError: if no model can be selected
        """
        model_config = self.fallback_chain.get_next_model()
        if model_config is None:
            raise RuntimeError("No healthy models available")

        if hasattr(model_config.llm, 'stream'):
            # Delegate to model's stream
            async for chunk in model_config.llm.stream(messages):
                yield chunk
        else:
            # Fallback to non-streaming
            yield await self.generate(messages)

    def get_health_summary(self) -> Dict[str, Any]:
        """Get health summary for pool."""
        return self.fallback_chain.get_health_summary()
|