loom-agent 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of loom-agent might be problematic. Click here for more details.

Files changed (89) hide show
  1. loom/__init__.py +77 -0
  2. loom/agent.py +217 -0
  3. loom/agents/__init__.py +10 -0
  4. loom/agents/refs.py +28 -0
  5. loom/agents/registry.py +50 -0
  6. loom/builtin/compression/__init__.py +4 -0
  7. loom/builtin/compression/structured.py +79 -0
  8. loom/builtin/embeddings/__init__.py +9 -0
  9. loom/builtin/embeddings/openai_embedding.py +135 -0
  10. loom/builtin/embeddings/sentence_transformers_embedding.py +145 -0
  11. loom/builtin/llms/__init__.py +8 -0
  12. loom/builtin/llms/mock.py +34 -0
  13. loom/builtin/llms/openai.py +168 -0
  14. loom/builtin/llms/rule.py +102 -0
  15. loom/builtin/memory/__init__.py +5 -0
  16. loom/builtin/memory/in_memory.py +21 -0
  17. loom/builtin/memory/persistent_memory.py +278 -0
  18. loom/builtin/retriever/__init__.py +9 -0
  19. loom/builtin/retriever/chroma_store.py +265 -0
  20. loom/builtin/retriever/in_memory.py +106 -0
  21. loom/builtin/retriever/milvus_store.py +307 -0
  22. loom/builtin/retriever/pinecone_store.py +237 -0
  23. loom/builtin/retriever/qdrant_store.py +274 -0
  24. loom/builtin/retriever/vector_store.py +128 -0
  25. loom/builtin/retriever/vector_store_config.py +217 -0
  26. loom/builtin/tools/__init__.py +32 -0
  27. loom/builtin/tools/calculator.py +49 -0
  28. loom/builtin/tools/document_search.py +111 -0
  29. loom/builtin/tools/glob.py +27 -0
  30. loom/builtin/tools/grep.py +56 -0
  31. loom/builtin/tools/http_request.py +86 -0
  32. loom/builtin/tools/python_repl.py +73 -0
  33. loom/builtin/tools/read_file.py +32 -0
  34. loom/builtin/tools/task.py +158 -0
  35. loom/builtin/tools/web_search.py +64 -0
  36. loom/builtin/tools/write_file.py +31 -0
  37. loom/callbacks/base.py +9 -0
  38. loom/callbacks/logging.py +12 -0
  39. loom/callbacks/metrics.py +27 -0
  40. loom/callbacks/observability.py +248 -0
  41. loom/components/agent.py +107 -0
  42. loom/core/agent_executor.py +450 -0
  43. loom/core/circuit_breaker.py +178 -0
  44. loom/core/compression_manager.py +329 -0
  45. loom/core/context_retriever.py +185 -0
  46. loom/core/error_classifier.py +193 -0
  47. loom/core/errors.py +66 -0
  48. loom/core/message_queue.py +167 -0
  49. loom/core/permission_store.py +62 -0
  50. loom/core/permissions.py +69 -0
  51. loom/core/scheduler.py +125 -0
  52. loom/core/steering_control.py +47 -0
  53. loom/core/structured_logger.py +279 -0
  54. loom/core/subagent_pool.py +232 -0
  55. loom/core/system_prompt.py +141 -0
  56. loom/core/system_reminders.py +283 -0
  57. loom/core/tool_pipeline.py +113 -0
  58. loom/core/types.py +269 -0
  59. loom/interfaces/compressor.py +59 -0
  60. loom/interfaces/embedding.py +51 -0
  61. loom/interfaces/llm.py +33 -0
  62. loom/interfaces/memory.py +29 -0
  63. loom/interfaces/retriever.py +179 -0
  64. loom/interfaces/tool.py +27 -0
  65. loom/interfaces/vector_store.py +80 -0
  66. loom/llm/__init__.py +14 -0
  67. loom/llm/config.py +228 -0
  68. loom/llm/factory.py +111 -0
  69. loom/llm/model_health.py +235 -0
  70. loom/llm/model_pool_advanced.py +305 -0
  71. loom/llm/pool.py +170 -0
  72. loom/llm/registry.py +201 -0
  73. loom/mcp/__init__.py +4 -0
  74. loom/mcp/client.py +86 -0
  75. loom/mcp/registry.py +58 -0
  76. loom/mcp/tool_adapter.py +48 -0
  77. loom/observability/__init__.py +5 -0
  78. loom/patterns/__init__.py +5 -0
  79. loom/patterns/multi_agent.py +123 -0
  80. loom/patterns/rag.py +262 -0
  81. loom/plugins/registry.py +55 -0
  82. loom/resilience/__init__.py +5 -0
  83. loom/tooling.py +72 -0
  84. loom/utils/agent_loader.py +218 -0
  85. loom/utils/token_counter.py +19 -0
  86. loom_agent-0.0.1.dist-info/METADATA +457 -0
  87. loom_agent-0.0.1.dist-info/RECORD +89 -0
  88. loom_agent-0.0.1.dist-info/WHEEL +4 -0
  89. loom_agent-0.0.1.dist-info/licenses/LICENSE +21 -0
loom/llm/factory.py ADDED
@@ -0,0 +1,111 @@
1
+ """LLM 工厂模式
2
+
3
+ 根据配置自动创建对应的 LLM 实例,支持:
4
+ - 自动注册与发现
5
+ - 统一的创建接口
6
+ - 类型安全
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from typing import Dict, Type, Any, cast
12
+ from .config import LLMConfig, LLMProvider
13
+ from loom.interfaces.llm import BaseLLM
14
+
15
+
16
class LLMFactory:
    """Factory that builds ``BaseLLM`` instances from an ``LLMConfig``.

    Implementations are looked up in a class-level ``provider -> class``
    registry. Built-in implementations are registered lazily the first time
    the factory is used; anything registered explicitly via :meth:`register`
    takes precedence over them.
    """

    # provider -> implementation class, shared at class level.
    _registry: Dict[LLMProvider, Type[BaseLLM]] = {}
    # Whether the lazy built-in registration has already run. Tracked
    # separately from ``_registry`` so that registering a custom provider
    # early does not prevent the built-ins from ever being loaded.
    _builtins_loaded: bool = False

    @classmethod
    def register(cls, provider: LLMProvider, llm_class: Type[BaseLLM]) -> None:
        """Register ``llm_class`` as the implementation for ``provider``.

        Replaces any implementation previously registered for that provider.
        """
        cls._registry[provider] = llm_class

    @classmethod
    def create(cls, config: LLMConfig) -> BaseLLM:
        """Create an LLM instance for ``config``.

        If the registered class exposes a callable ``from_config``, it is
        called with ``config``. Otherwise, for the OpenAI-compatible providers
        (OPENAI, AZURE_OPENAI, CUSTOM), the class is constructed with keyword
        arguments mapped from the config.

        Args:
            config: The LLM configuration.

        Returns:
            The constructed LLM instance.

        Raises:
            ValueError: if no implementation is registered for the provider,
                or the provider has no default constructor mapping.
        """
        cls._ensure_registered()

        if config.provider not in cls._registry:
            raise ValueError(
                f"Unsupported LLM provider: {config.provider}. Available: {list(cls._registry.keys())}"
            )

        llm_class = cls._registry[config.provider]

        # Prefer an implementation-provided ``from_config(config)``.
        from_config = getattr(llm_class, "from_config", None)
        if callable(from_config):
            return cast(BaseLLM, from_config(config))

        # No from_config: map the common constructor arguments per provider.
        if config.provider in (LLMProvider.OPENAI, LLMProvider.AZURE_OPENAI, LLMProvider.CUSTOM):
            kwargs: dict[str, Any] = {
                "model": config.model_name,
                "temperature": config.temperature,
            }
            # Only forward optional settings that are set, so the
            # implementation's own defaults apply otherwise.
            if config.max_tokens is not None:
                kwargs["max_tokens"] = config.max_tokens
            if config.base_url is not None:
                kwargs["base_url"] = config.base_url
            if config.api_key is not None:
                kwargs["api_key"] = config.api_key
            # extra_params may override the mapped keys; tolerate it being unset.
            kwargs.update(config.extra_params or {})
            return llm_class(**kwargs)  # type: ignore[arg-type]

        # Other providers have no built-in constructor mapping.
        raise ValueError(f"Provider {config.provider} has no default constructor mapping")

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> BaseLLM:
        """Build an ``LLMConfig`` from ``config_dict`` and create the LLM."""
        config = LLMConfig.from_dict(config_dict)
        return cls.create(config)

    # Convenience constructors
    @classmethod
    def create_openai(
        cls,
        api_key: str,
        model: str = "gpt-4",
        **kwargs
    ) -> BaseLLM:
        """Create an OpenAI LLM via ``LLMConfig.openai(api_key, model, **kwargs)``."""
        config = LLMConfig.openai(api_key=api_key, model=model, **kwargs)
        return cls.create(config)

    @classmethod
    def _ensure_registered(cls) -> None:
        """Register the built-in implementations once, on a best-effort basis.

        Built-ins are added with ``setdefault``, so they never override a
        class registered explicitly via :meth:`register`. If the built-in
        module cannot be imported, the failure is logged at DEBUG level and
        the affected providers stay unregistered.
        """
        if cls._builtins_loaded:
            return
        cls._builtins_loaded = True

        try:
            from loom.builtin.llms.openai import OpenAILLM
        except Exception:  # best effort: an optional dependency may be missing
            import logging

            logging.getLogger(__name__).debug(
                "Built-in OpenAILLM unavailable; skipping its registration",
                exc_info=True,
            )
        else:
            # Custom/proxy and Azure endpoints reuse the OpenAI-style client.
            for provider in (LLMProvider.OPENAI, LLMProvider.AZURE_OPENAI, LLMProvider.CUSTOM):
                cls._registry.setdefault(provider, OpenAILLM)

        # MockLLM and RuleLLM are intentionally not bound to any provider;
        # register them manually (e.g. in tests) when needed.

    @classmethod
    def list_available_providers(cls) -> list[str]:
        """Return the ``value`` of every registered provider."""
        cls._ensure_registered()
        return [provider.value for provider in cls._registry.keys()]
111
+
loom/llm/model_health.py ADDED
@@ -0,0 +1,235 @@
1
+ """US8: Model Health Checking
2
+
3
+ Tracks model health and enables intelligent fallback decisions.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import time
9
+ from typing import Dict, Optional
10
+ from dataclasses import dataclass
11
+ from enum import Enum
12
+
13
+
14
class HealthStatus(str, Enum):
    """Health classification assigned to a model by ``ModelHealthChecker``."""

    HEALTHY = "healthy"
    DEGRADED = "degraded"
    UNHEALTHY = "unhealthy"
    UNKNOWN = "unknown"  # no outcomes recorded yet (or since a reset)


@dataclass
class HealthMetrics:
    """Health counters and current status for a single model.

    The request counters are lifetime totals, counted since the record was
    created or last reset. ``status`` is maintained by ``ModelHealthChecker``
    from a rolling window of recent outcomes, so it can differ from what the
    lifetime ``success_rate`` alone would suggest.
    """

    status: HealthStatus
    success_count: int = 0
    failure_count: int = 0
    total_requests: int = 0
    last_success_time: Optional[float] = None  # time.time() of the latest success
    last_failure_time: Optional[float] = None  # time.time() of the latest failure
    avg_latency_ms: float = 0.0  # mean latency over recent successful requests
    consecutive_failures: int = 0  # failures since the last success
    last_error: Optional[str] = None  # latest error message given to record_failure

    @property
    def success_rate(self) -> float:
        """Lifetime fraction of requests that succeeded (0.0 with no requests)."""
        if self.total_requests == 0:
            return 0.0
        return self.success_count / self.total_requests

    @property
    def failure_rate(self) -> float:
        """Lifetime fraction of requests that failed (0.0 with no requests).

        Computed from ``failure_count`` rather than as ``1 - success_rate``,
        so that a model with no traffic is not reported as 100% failing.
        """
        if self.total_requests == 0:
            return 0.0
        return self.failure_count / self.total_requests


class ModelHealthChecker:
    """Tracks per-model request outcomes and derives a health status.

    The status is re-evaluated after every recorded outcome:

    * ``consecutive_failure_threshold`` failures in a row -> UNHEALTHY;
    * otherwise, the success rate over the last ``health_check_window``
      outcomes decides: ``>= degraded_threshold`` -> HEALTHY,
      ``>= unhealthy_threshold`` -> DEGRADED, anything lower -> UNHEALTHY.

    Example:
        checker = ModelHealthChecker()

        # Record success
        checker.record_success("gpt-4", latency_ms=234.5)

        # Record failure
        checker.record_failure("gpt-4", error="timeout")

        # Check health
        status = checker.get_status("gpt-4")
        print(f"Model health: {status}")  # HEALTHY / DEGRADED / UNHEALTHY
    """

    def __init__(
        self,
        degraded_threshold: float = 0.8,  # <80% success rate = degraded
        unhealthy_threshold: float = 0.5,  # <50% success rate = unhealthy
        consecutive_failure_threshold: int = 5,  # 5 consecutive failures = unhealthy
        health_check_window: int = 100,  # Last 100 requests
    ) -> None:
        """Initialize the health checker.

        Args:
            degraded_threshold: Windowed success rate below which a model is
                no longer HEALTHY.
            unhealthy_threshold: Windowed success rate below which a model is
                UNHEALTHY.
            consecutive_failure_threshold: Consecutive failures that mark a
                model UNHEALTHY regardless of its success rate.
            health_check_window: Number of most recent requests, and of
                success latencies, taken into account.

        Raises:
            ValueError: if ``health_check_window`` is smaller than 1.
        """
        if health_check_window < 1:
            raise ValueError(
                f"health_check_window must be >= 1, got {health_check_window}"
            )
        self.degraded_threshold = degraded_threshold
        self.unhealthy_threshold = unhealthy_threshold
        self.consecutive_failure_threshold = consecutive_failure_threshold
        self.health_check_window = health_check_window

        self._metrics: Dict[str, HealthMetrics] = {}
        # Rolling windows, each capped at ``health_check_window`` entries.
        self._latency_samples: Dict[str, list[float]] = {}  # latencies of recent successes
        self._outcomes: Dict[str, list[bool]] = {}  # recent outcomes, True = success

    def _metrics_for(self, model_id: str) -> HealthMetrics:
        """Return the metrics record for ``model_id``, creating it if needed."""
        metrics = self._metrics.get(model_id)
        if metrics is None:
            metrics = HealthMetrics(status=HealthStatus.UNKNOWN)
            self._metrics[model_id] = metrics
        return metrics

    def _append_bounded(self, buffer: list, value: object) -> None:
        """Append ``value`` and drop the oldest entries beyond the window size."""
        buffer.append(value)
        # max(..., 1) guards against the window attribute being lowered after init.
        overflow = len(buffer) - max(self.health_check_window, 1)
        if overflow > 0:
            del buffer[:overflow]

    def record_success(
        self,
        model_id: str,
        latency_ms: float = 0.0
    ) -> None:
        """Record a successful request.

        Args:
            model_id: Model identifier.
            latency_ms: Request latency in milliseconds.
        """
        metrics = self._metrics_for(model_id)
        metrics.success_count += 1
        metrics.total_requests += 1
        metrics.last_success_time = time.time()
        metrics.consecutive_failures = 0

        samples = self._latency_samples.setdefault(model_id, [])
        self._append_bounded(samples, latency_ms)
        metrics.avg_latency_ms = sum(samples) / len(samples)

        self._append_bounded(self._outcomes.setdefault(model_id, []), True)
        self._update_status(model_id)

    def record_failure(
        self,
        model_id: str,
        error: Optional[str] = None
    ) -> None:
        """Record a failed request.

        Args:
            model_id: Model identifier.
            error: Optional error message; when given, it is kept as
                ``HealthMetrics.last_error``.
        """
        metrics = self._metrics_for(model_id)
        metrics.failure_count += 1
        metrics.total_requests += 1
        metrics.last_failure_time = time.time()
        metrics.consecutive_failures += 1
        if error is not None:
            metrics.last_error = error

        self._append_bounded(self._outcomes.setdefault(model_id, []), False)
        self._update_status(model_id)

    def _update_status(self, model_id: str) -> None:
        """Recompute ``status`` for ``model_id`` from its recent outcomes."""
        metrics = self._metrics[model_id]

        # A run of failures overrides the rate-based classification.
        if metrics.consecutive_failures >= self.consecutive_failure_threshold:
            metrics.status = HealthStatus.UNHEALTHY
            return

        window = self._outcomes.get(model_id)
        if not window:
            metrics.status = HealthStatus.UNKNOWN
            return

        success_rate = sum(window) / len(window)
        if success_rate >= self.degraded_threshold:
            metrics.status = HealthStatus.HEALTHY
        elif success_rate >= self.unhealthy_threshold:
            metrics.status = HealthStatus.DEGRADED
        else:
            metrics.status = HealthStatus.UNHEALTHY

    def get_status(self, model_id: str) -> HealthStatus:
        """Return the current status, or UNKNOWN for an untracked model."""
        if model_id not in self._metrics:
            return HealthStatus.UNKNOWN
        return self._metrics[model_id].status

    def get_metrics(self, model_id: str) -> Optional[HealthMetrics]:
        """Return the live metrics record, or None if the model is untracked."""
        return self._metrics.get(model_id)

    def is_healthy(self, model_id: str) -> bool:
        """Return True only if the model's status is HEALTHY."""
        return self.get_status(model_id) == HealthStatus.HEALTHY

    def get_all_healthy_models(self) -> list[str]:
        """Return the IDs of all tracked models whose status is HEALTHY."""
        return [
            model_id
            for model_id, metrics in self._metrics.items()
            if metrics.status == HealthStatus.HEALTHY
        ]

    def reset(self, model_id: Optional[str] = None) -> None:
        """Reset health tracking.

        Args:
            model_id: Model to reset. ``None`` resets every model. Any other
                value, including an empty string, resets only that model.
        """
        if model_id is None:
            self._metrics.clear()
            self._latency_samples.clear()
            self._outcomes.clear()
            return

        if model_id in self._metrics:
            self._metrics[model_id] = HealthMetrics(status=HealthStatus.UNKNOWN)
        if model_id in self._latency_samples:
            self._latency_samples[model_id] = []
        self._outcomes.pop(model_id, None)
loom/llm/model_pool_advanced.py ADDED
@@ -0,0 +1,305 @@
1
+ """US8: Advanced Model Pool with Fallback Chain
2
+
3
+ Provides intelligent model selection with health-aware fallback.
4
+
5
+ Features:
6
+ - Fallback chain: [primary, fallback1, fallback2]
7
+ - Automatic fallback on 5xx errors
8
+ - Health-based model selection
9
+ - Per-model concurrency limits (semaphores)
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import asyncio
15
+ from typing import List, Optional, Dict, Any
16
+ from dataclasses import dataclass
17
+
18
+ from loom.interfaces.llm import BaseLLM
19
+ from loom.llm.model_health import ModelHealthChecker, HealthStatus
20
+ from loom.core.error_classifier import ErrorClassifier, ErrorCategory
21
+
22
+
23
@dataclass
class ModelConfig:
    """A single entry in a model pool / fallback chain.

    Attributes:
        model_id: Identifier used as the key for health tracking and for the
            per-model concurrency limit.
        llm: The LLM implementation that serves requests for this entry.
        priority: Preference weight; among candidates in the same health
            bucket, the highest priority is selected.
        max_concurrent: Maximum number of simultaneous in-flight requests
            allowed for this model.
    """

    model_id: str
    llm: BaseLLM
    priority: int = 0
    max_concurrent: int = 10
30
+
31
+
32
class FallbackChain:
    """Health-aware model selection with fallback across several models.

    Example:
        chain = FallbackChain([
            ModelConfig("gpt-4", gpt4_llm, priority=100),
            ModelConfig("gpt-3.5-turbo", gpt35_llm, priority=50),
            ModelConfig("claude-2", claude_llm, priority=30),
        ])

        # Pick the best model currently available (synchronous)
        model_config = chain.get_next_model()

        # Or call with automatic fallback on service (5xx) errors
        result = await chain.call_with_fallback("generate", messages)
    """

    def __init__(
        self,
        models: List[ModelConfig],
        health_checker: Optional[ModelHealthChecker] = None,
    ):
        """Initialize the fallback chain.

        Args:
            models: Model configurations. Selection is by health bucket and
                then ``priority``; list order only breaks priority ties.
            health_checker: Health checker to use. A new
                ``ModelHealthChecker`` is created when omitted.
        """
        self.models = models
        self.health_checker = (
            health_checker if health_checker is not None else ModelHealthChecker()
        )
        self._model_map = {m.model_id: m for m in models}
        # One semaphore per model caps its in-flight requests at max_concurrent.
        self._semaphores: Dict[str, asyncio.Semaphore] = {
            m.model_id: asyncio.Semaphore(m.max_concurrent) for m in models
        }

    def get_next_model(
        self,
        skip_unhealthy: bool = True,
        prefer_healthy: bool = True,
        exclude: Optional[set[str]] = None,
    ) -> Optional[ModelConfig]:
        """Return the best model to use next, or ``None`` if none qualifies.

        Models are grouped by their current health status. The candidate set
        is the first of these that applies:

        1. healthy models, when ``prefer_healthy`` is true and any exist;
        2. healthy + degraded models, if any exist;
        3. models with unknown status (e.g. no recorded traffic);
        4. unhealthy models, only when ``skip_unhealthy`` is false.

        The candidate with the highest ``priority`` wins; ties go to the
        model listed first.

        Args:
            skip_unhealthy: Never return models marked as unhealthy.
            prefer_healthy: Consider only healthy models when any exist,
                instead of mixing in degraded ones.
            exclude: Model IDs to leave out, e.g. models already tried.

        Returns:
            The selected ModelConfig, or None if no model is available.
        """
        excluded = set(exclude) if exclude else set()

        healthy: List[ModelConfig] = []
        degraded: List[ModelConfig] = []
        unhealthy: List[ModelConfig] = []
        unknown: List[ModelConfig] = []

        for model in self.models:
            if model.model_id in excluded:
                continue
            status = self.health_checker.get_status(model.model_id)
            if status == HealthStatus.HEALTHY:
                healthy.append(model)
            elif status == HealthStatus.DEGRADED:
                degraded.append(model)
            elif status == HealthStatus.UNHEALTHY:
                unhealthy.append(model)
            else:
                unknown.append(model)

        candidates: List[ModelConfig] = []
        if prefer_healthy and healthy:
            candidates = healthy
        elif healthy or degraded:
            candidates = healthy + degraded
        elif unknown:
            candidates = unknown
        elif not skip_unhealthy:
            candidates = unhealthy

        if not candidates:
            return None

        # max() returns the first maximal element, so ties keep list order.
        return max(candidates, key=lambda m: m.priority)

    async def call_with_fallback(
        self,
        operation: str,  # "generate" or "generate_with_tools"
        *args: Any,
        max_fallback_attempts: int = 3,
        **kwargs: Any,
    ) -> Any:
        """Call ``operation`` on the best model, falling back on service errors.

        Each model is tried at most once per call. When a model raises an
        error that ``ErrorClassifier`` categorizes as ``SERVICE_ERROR``, the
        next best model that has not been tried yet is used. Any other error
        is re-raised immediately.

        Args:
            operation: Name of the async LLM method to call.
            *args: Positional arguments for the method.
            max_fallback_attempts: Maximum number of distinct models to try.
            **kwargs: Keyword arguments for the method.

        Returns:
            The result of the first successful call.

        Raises:
            Exception: the last service error if every attempted model failed,
                or the first non-service error encountered.
            RuntimeError: if no model could be selected at all.
        """
        import time  # this module does not import ``time`` at file level

        attempted: set[str] = set()
        last_exception: Optional[BaseException] = None

        while len(attempted) < max_fallback_attempts:
            # Exclude already-tried models; otherwise a model that is still
            # HEALTHY after one failure would be selected again.
            model_config = self.get_next_model(exclude=attempted)
            if model_config is None:
                break
            attempted.add(model_config.model_id)

            # Resolved outside the try: a missing method is a programming
            # error, not a model-health signal.
            method = getattr(model_config.llm, operation)

            try:
                async with self._semaphores[model_config.model_id]:
                    start = time.perf_counter()
                    result = await method(*args, **kwargs)
                    latency_ms = (time.perf_counter() - start) * 1000
            except Exception as e:
                last_exception = e
                self.health_checker.record_failure(
                    model_config.model_id,
                    error=str(e)
                )
                if ErrorClassifier.classify(e) == ErrorCategory.SERVICE_ERROR:
                    # 5xx-style error: try the next model
                    continue
                # Non-retryable error: propagate
                raise
            else:
                self.health_checker.record_success(
                    model_config.model_id,
                    latency_ms=latency_ms
                )
                return result

        # All attempted models failed, or none was available
        if last_exception is not None:
            raise last_exception
        raise RuntimeError("No healthy models available")

    def get_health_summary(self) -> Dict[str, Any]:
        """Return per-model health information keyed by model ID.

        Untracked models are reported as ``{"status": "unknown"}``.
        """
        summary: Dict[str, Any] = {}

        for model in self.models:
            metrics = self.health_checker.get_metrics(model.model_id)
            if metrics is None:
                summary[model.model_id] = {"status": "unknown"}
            else:
                summary[model.model_id] = {
                    "status": metrics.status.value,
                    "success_rate": metrics.success_rate,
                    "avg_latency_ms": metrics.avg_latency_ms,
                    "consecutive_failures": metrics.consecutive_failures,
                }

        return summary
217
+
218
+
219
class ModelPoolLLM(BaseLLM):
    """``BaseLLM`` facade that routes calls through a ``FallbackChain``.

    It can be used wherever a single ``BaseLLM`` is expected.

    Example:
        # Create pool with fallback chain
        pool_llm = ModelPoolLLM([
            ModelConfig("gpt-4", gpt4_llm, priority=100),
            ModelConfig("gpt-3.5-turbo", gpt35_llm, priority=50),
        ])

        # Use like any LLM
        agent = Agent(llm=pool_llm, tools=tools)
    """

    def __init__(
        self,
        models: List[ModelConfig],
        max_fallback_attempts: int = 3,
    ):
        """Create the pool.

        Args:
            models: Model configurations that make up the pool.
            max_fallback_attempts: Passed to every ``call_with_fallback``.
        """
        self.fallback_chain = FallbackChain(models)
        self.max_fallback_attempts = max_fallback_attempts

        if not models:
            self._supports_tools = False
            self._model_name = "pool(empty)"
            return

        # Tool support is taken from the first configured model.
        self._supports_tools = models[0].llm.supports_tools
        member_ids = ",".join(cfg.model_id for cfg in models)
        self._model_name = f"pool({member_ids})"

    @property
    def model_name(self) -> str:
        """Name of the form ``pool(id1,id2,...)``, or ``pool(empty)``."""
        return self._model_name

    @property
    def supports_tools(self) -> bool:
        """Tool support as reported by the first model at construction."""
        return self._supports_tools

    async def generate(self, messages: List[dict]) -> str:
        """Delegate ``generate`` to the fallback chain."""
        chain = self.fallback_chain
        return await chain.call_with_fallback(
            "generate",
            messages,
            max_fallback_attempts=self.max_fallback_attempts,
        )

    async def generate_with_tools(
        self,
        messages: List[dict],
        tools: List[dict],
    ) -> dict:
        """Delegate ``generate_with_tools`` to the fallback chain."""
        chain = self.fallback_chain
        return await chain.call_with_fallback(
            "generate_with_tools",
            messages,
            tools,
            max_fallback_attempts=self.max_fallback_attempts,
        )

    async def stream(self, messages: List[dict]):
        """Stream from the currently selected model.

        There is no fallback or health recording while streaming. If the
        selected backend lacks ``stream``, the full ``generate`` result is
        yielded as a single chunk.
        """
        chosen = self.fallback_chain.get_next_model()
        if chosen is None:
            raise RuntimeError("No healthy models available")

        backend = chosen.llm
        if not hasattr(backend, 'stream'):
            yield await self.generate(messages)
            return

        async for piece in backend.stream(messages):
            yield piece

    def get_health_summary(self) -> Dict[str, Any]:
        """Return the fallback chain's per-model health summary."""
        return self.fallback_chain.get_health_summary()