agentfield 0.1.22rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentfield/__init__.py +66 -0
- agentfield/agent.py +3569 -0
- agentfield/agent_ai.py +1125 -0
- agentfield/agent_cli.py +386 -0
- agentfield/agent_field_handler.py +494 -0
- agentfield/agent_mcp.py +534 -0
- agentfield/agent_registry.py +29 -0
- agentfield/agent_server.py +1185 -0
- agentfield/agent_utils.py +269 -0
- agentfield/agent_workflow.py +323 -0
- agentfield/async_config.py +278 -0
- agentfield/async_execution_manager.py +1227 -0
- agentfield/client.py +1447 -0
- agentfield/connection_manager.py +280 -0
- agentfield/decorators.py +527 -0
- agentfield/did_manager.py +337 -0
- agentfield/dynamic_skills.py +304 -0
- agentfield/execution_context.py +255 -0
- agentfield/execution_state.py +453 -0
- agentfield/http_connection_manager.py +429 -0
- agentfield/litellm_adapters.py +140 -0
- agentfield/logger.py +249 -0
- agentfield/mcp_client.py +204 -0
- agentfield/mcp_manager.py +340 -0
- agentfield/mcp_stdio_bridge.py +550 -0
- agentfield/memory.py +723 -0
- agentfield/memory_events.py +489 -0
- agentfield/multimodal.py +173 -0
- agentfield/multimodal_response.py +403 -0
- agentfield/pydantic_utils.py +227 -0
- agentfield/rate_limiter.py +280 -0
- agentfield/result_cache.py +441 -0
- agentfield/router.py +190 -0
- agentfield/status.py +70 -0
- agentfield/types.py +710 -0
- agentfield/utils.py +26 -0
- agentfield/vc_generator.py +464 -0
- agentfield/vision.py +198 -0
- agentfield-0.1.22rc2.dist-info/METADATA +102 -0
- agentfield-0.1.22rc2.dist-info/RECORD +42 -0
- agentfield-0.1.22rc2.dist-info/WHEEL +5 -0
- agentfield-0.1.22rc2.dist-info/top_level.txt +1 -0
agentfield/http_connection_manager.py
@@ -0,0 +1,429 @@
"""
HTTP Connection Manager for async execution.

This module provides aiohttp session pooling with configurable connection limits,
connection reuse, proper cleanup, timeout handling, and connection health monitoring.
Supports both single requests and batch operations for the AgentField SDK async execution.
"""

import asyncio
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

import aiohttp

from .async_config import AsyncConfig
from .logger import get_logger

logger = get_logger(__name__)


@dataclass
class ConnectionMetrics:
    """Metrics for connection pool monitoring."""

    total_requests: int = 0
    successful_requests: int = 0
    failed_requests: int = 0
    timeout_requests: int = 0
    active_connections: int = 0
    pool_size: int = 0
    created_at: float = field(default_factory=time.time)

    @property
    def success_rate(self) -> float:
        """Calculate success rate as a percentage."""
        if self.total_requests == 0:
            return 0.0
        return (self.successful_requests / self.total_requests) * 100

    @property
    def uptime(self) -> float:
        """Get uptime in seconds."""
        return time.time() - self.created_at

    def record_request(self, success: bool, timeout: bool = False) -> None:
        """Record a request attempt."""
        self.total_requests += 1
        if success:
            self.successful_requests += 1
        else:
            self.failed_requests += 1
        if timeout:
            self.timeout_requests += 1


@dataclass
class ConnectionHealth:
    """Health status of connection pool."""

    is_healthy: bool = True
    last_check: float = field(default_factory=time.time)
    consecutive_failures: int = 0
    last_error: Optional[str] = None

    def mark_healthy(self) -> None:
        """Mark connection as healthy."""
        self.is_healthy = True
        self.consecutive_failures = 0
        self.last_error = None
        self.last_check = time.time()

    def mark_unhealthy(self, error: str) -> None:
        """Mark connection as unhealthy."""
        self.is_healthy = False
        self.consecutive_failures += 1
        self.last_error = error
        self.last_check = time.time()


class ConnectionManager:
    """
    HTTP Connection Manager with aiohttp session pooling.

    Provides efficient HTTP connection management for async execution with:
    - Configurable connection limits and timeouts
    - Connection reuse and proper cleanup
    - Health monitoring and metrics
    - Support for single requests and batch operations
    - Thread-safe operations for concurrent access
    """

    def __init__(self, config: Optional[AsyncConfig] = None):
        """
        Initialize the connection manager.

        Args:
            config: AsyncConfig instance for configuration parameters
        """
        self.config = config or AsyncConfig()
        self._session: Optional[aiohttp.ClientSession] = None
        self._connector: Optional[aiohttp.TCPConnector] = None
        self._lock = asyncio.Lock()
        self._closed = False

        # Metrics and health monitoring
        self.metrics = ConnectionMetrics()
        self.health = ConnectionHealth()

        # Background tasks
        self._health_check_task: Optional[asyncio.Task] = None
        self._cleanup_task: Optional[asyncio.Task] = None

        logger.debug(f"ConnectionManager initialized with config: {self.config}")

    async def __aenter__(self):
        """Async context manager entry."""
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        await self.close()

    async def start(self) -> None:
        """
        Start the connection manager and initialize session.

        Raises:
            RuntimeError: If manager is already started or closed
        """
        async with self._lock:
            if self._session is not None:
                raise RuntimeError("ConnectionManager is already started")

            if self._closed:
                raise RuntimeError(
                    "ConnectionManager is closed and cannot be restarted"
                )

            # Create TCP connector with configuration
            self._connector = aiohttp.TCPConnector(
                limit=self.config.connection_pool_size,
                limit_per_host=self.config.connection_pool_per_host,
                ttl_dns_cache=300,  # 5 minutes DNS cache
                use_dns_cache=True,
                keepalive_timeout=30,
                enable_cleanup_closed=True,
                force_close=False,
            )

            # Create session with timeout configuration
            timeout = aiohttp.ClientTimeout(
                total=self.config.max_execution_timeout,
                connect=self.config.polling_timeout,
                sock_read=self.config.polling_timeout,
            )

            self._session = aiohttp.ClientSession(
                connector=self._connector,
                timeout=timeout,
                headers={
                    "User-Agent": "AgentField-SDK-AsyncClient/1.0",
                    "Accept": "application/json",
                    "Content-Type": "application/json",
                },
            )

            # Update metrics
            self.metrics.pool_size = self.config.connection_pool_size

            # Start background tasks if enabled
            if self.config.enable_performance_logging:
                self._health_check_task = asyncio.create_task(self._health_check_loop())
                self._cleanup_task = asyncio.create_task(self._cleanup_loop())

            logger.info(
                f"ConnectionManager started with pool size {self.config.connection_pool_size}"
            )

    async def close(self) -> None:
        """
        Close the connection manager and cleanup resources.
        """
        async with self._lock:
            if self._closed:
                return

            self._closed = True

            # Cancel background tasks
            if self._health_check_task:
                self._health_check_task.cancel()
                try:
                    await self._health_check_task
                except asyncio.CancelledError:
                    pass

            if self._cleanup_task:
                self._cleanup_task.cancel()
                try:
                    await self._cleanup_task
                except asyncio.CancelledError:
                    pass

            # Close session and connector
            if self._session:
                await self._session.close()
                self._session = None

            if self._connector:
                await self._connector.close()
                self._connector = None

            logger.info("ConnectionManager closed")

    @asynccontextmanager
    async def get_session(self):
        """
        Get an aiohttp session for making requests.

        Yields:
            aiohttp.ClientSession: Active session for making requests

        Raises:
            RuntimeError: If manager is not started or is closed
        """
        if self._session is None:
            raise RuntimeError("ConnectionManager is not started. Call start() first.")

        if self._closed:
            raise RuntimeError("ConnectionManager is closed")

        try:
            yield self._session
        except Exception as e:
            self.health.mark_unhealthy(str(e))
            raise

    async def request(self, method: str, url: str, **kwargs) -> aiohttp.ClientResponse:
        """
        Make a single HTTP request.

        Args:
            method: HTTP method (GET, POST, etc.)
            url: Request URL
            **kwargs: Additional arguments for aiohttp request

        Returns:
            aiohttp.ClientResponse: Response object

        Raises:
            aiohttp.ClientError: For HTTP-related errors
            asyncio.TimeoutError: For timeout errors
        """
        start_time = time.time()
        success = False
        timeout_occurred = False

        try:
            async with self.get_session() as session:
                response = await session.request(method, url, **kwargs)
                success = True
                self.health.mark_healthy()
                return response

        except asyncio.TimeoutError:
            timeout_occurred = True
            logger.warn(f"Request timeout for {method} {url}")
            raise
        except Exception as e:
            self.health.mark_unhealthy(str(e))
            logger.error(f"Request failed for {method} {url}: {e}")
            raise
        finally:
            # Record metrics
            self.metrics.record_request(success, timeout_occurred)

            # Log slow requests
            duration = time.time() - start_time
            if (
                self.config.log_slow_executions
                and duration > self.config.slow_execution_threshold
            ):
                logger.warn(f"Slow request: {method} {url} took {duration:.2f}s")

    async def batch_request(
        self, requests: List[Dict[str, Any]]
    ) -> List[Union[aiohttp.ClientResponse, Exception]]:
        """
        Make multiple HTTP requests concurrently.

        Args:
            requests: List of request dictionaries with 'method', 'url', and optional kwargs

        Returns:
            List of responses or exceptions for each request
        """
        if not requests:
            return []

        # Limit concurrent requests
        semaphore = asyncio.Semaphore(self.config.max_active_polls)

        async def make_request(
            req_data: Dict[str, Any],
        ) -> Union[aiohttp.ClientResponse, Exception]:
            async with semaphore:
                try:
                    method = req_data.pop("method")
                    url = req_data.pop("url")
                    return await self.request(method, url, **req_data)
                except Exception as e:
                    return e

        # Execute all requests concurrently
        tasks = [make_request(req.copy()) for req in requests]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        logger.debug(f"Batch request completed: {len(requests)} requests")
        return results

    async def health_check(self) -> bool:
        """
        Perform a health check on the connection pool.

        Returns:
            bool: True if healthy, False otherwise
        """
        try:
            # Simple health check - try to create a request
            if self._session is None or self._session.closed:
                self.health.mark_unhealthy("Session is closed")
                return False

            # Check connector health
            if self._connector is None or self._connector.closed:
                self.health.mark_unhealthy("Connector is closed")
                return False

            self.health.mark_healthy()
            return True

        except Exception as e:
            self.health.mark_unhealthy(str(e))
            return False

    async def _health_check_loop(self) -> None:
        """Background task for periodic health checks."""
        while not self._closed:
            try:
                await asyncio.sleep(60)  # Check every minute
                await self.health_check()

                # Log health status if unhealthy
                if not self.health.is_healthy:
                    logger.warn(f"Connection pool unhealthy: {self.health.last_error}")

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Health check loop error: {e}")

    async def _cleanup_loop(self) -> None:
        """Background task for periodic cleanup."""
        while not self._closed:
            try:
                await asyncio.sleep(self.config.cleanup_interval)

                # Update active connections metric
                if self._connector:
                    self.metrics.active_connections = len(self._connector._conns)

                # Log metrics if performance logging is enabled
                if self.config.enable_performance_logging:
                    logger.debug(
                        f"Connection metrics: {self.metrics.total_requests} total, "
                        f"{self.metrics.success_rate:.1f}% success, "
                        f"{self.metrics.active_connections} active"
                    )

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Cleanup loop error: {e}")

    def get_metrics(self) -> ConnectionMetrics:
        """
        Get current connection metrics.

        Returns:
            ConnectionMetrics: Current metrics snapshot
        """
        # Update active connections if connector is available
        if self._connector:
            self.metrics.active_connections = len(self._connector._conns)

        return self.metrics

    def get_health(self) -> ConnectionHealth:
        """
        Get current health status.

        Returns:
            ConnectionHealth: Current health status
        """
        return self.health

    @property
    def is_healthy(self) -> bool:
        """Check if connection manager is healthy."""
        return self.health.is_healthy and not self._closed

    @property
    def is_closed(self) -> bool:
        """Check if connection manager is closed."""
        return self._closed

    def __repr__(self) -> str:
        """String representation of the connection manager."""
        return (
            f"ConnectionManager("
            f"pool_size={self.config.connection_pool_size}, "
            f"healthy={self.is_healthy}, "
            f"closed={self.is_closed}, "
            f"requests={self.metrics.total_requests}"
            f")"
        )
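
A minimal usage sketch for the connection manager above. It assumes the default AsyncConfig can be constructed with no arguments (as the module itself does) and uses placeholder URLs; nothing here is taken from the package's own examples or documentation.

# Hypothetical usage of ConnectionManager; URLs are placeholders.
import asyncio

from agentfield.async_config import AsyncConfig
from agentfield.http_connection_manager import ConnectionManager


async def main() -> None:
    # Assumed defaults; pool sizes and timeouts come from AsyncConfig.
    async with ConnectionManager(AsyncConfig()) as manager:
        # Single request: the caller owns the response and should release it.
        response = await manager.request("GET", "https://example.com/health")
        try:
            print(response.status)
        finally:
            response.release()

        # Batch: each entry carries 'method', 'url', plus optional aiohttp kwargs.
        results = await manager.batch_request(
            [
                {"method": "GET", "url": "https://example.com/a"},
                {"method": "GET", "url": "https://example.com/b"},
            ]
        )
        for result in results:
            if isinstance(result, Exception):
                print("failed:", result)
            else:
                print("status:", result.status)
                result.release()

        # Pool-level observability exposed by the manager.
        print(f"success rate: {manager.get_metrics().success_rate:.1f}%")


asyncio.run(main())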

agentfield/litellm_adapters.py
@@ -0,0 +1,140 @@
"""
LiteLLM Provider Adapters

This module centralizes provider-specific parameter transformations and patches
required to ensure compatibility across different LLM providers.

Each patch should be:
1. Well-documented with the reason for its existence
2. Tied to specific providers/models that require it
3. Transparent about what transformation is being applied

This abstraction allows the core SDK to remain clean while handling necessary
provider-specific quirks in one maintainable location.
"""

from typing import Dict, Any


def get_provider_from_model(model: str) -> str:
    """
    Extract provider name from model string.

    LiteLLM uses the format "provider/model-name" (e.g., "openai/gpt-4o").
    This function extracts the provider prefix.

    Args:
        model: Model string in LiteLLM format

    Returns:
        Provider name (e.g., "openai", "anthropic", "cohere")
        Returns "unknown" if format doesn't match

    Examples:
        >>> get_provider_from_model("openai/gpt-4o")
        'openai'
        >>> get_provider_from_model("anthropic/claude-3-opus")
        'anthropic'
        >>> get_provider_from_model("gpt-4o")
        'unknown'
    """
    if "/" in model:
        return model.split("/")[0]
    return "unknown"


def apply_openai_patches(params: Dict[str, Any]) -> Dict[str, Any]:
    """
    Apply OpenAI-specific parameter patches.

    **Patch 1: max_tokens → max_completion_tokens**

    Reason: OpenAI's newer models (gpt-4o, gpt-4o-mini, etc.) use
    `max_completion_tokens` instead of `max_tokens` to disambiguate between
    input tokens and output tokens. LiteLLM may not always handle this
    transformation automatically for all OpenAI models.

    This patch ensures compatibility by renaming the parameter when targeting
    OpenAI models.

    Reference: https://platform.openai.com/docs/api-reference/chat/create

    Args:
        params: Parameter dictionary to transform

    Returns:
        Transformed parameter dictionary
    """
    # Create a copy to avoid mutating the original
    patched = params.copy()

    # Patch: max_tokens → max_completion_tokens for OpenAI
    if "max_tokens" in patched:
        patched["max_completion_tokens"] = patched.pop("max_tokens")

    return patched


def apply_provider_patches(params: Dict[str, Any], model: str) -> Dict[str, Any]:
    """
    Apply provider-specific parameter transformations.

    This is the main entry point for all provider-specific patches. It detects
    the provider from the model string and applies appropriate transformations.

    **When to add a new patch:**
    1. A specific provider requires a different parameter name
    2. A provider has parameter constraints that differ from LiteLLM defaults
    3. There's a known incompatibility that needs a workaround

    **How to add a new patch:**
    1. Create a new function: `apply_{provider}_patches(params)`
    2. Document the patch reason and affected models
    3. Add a new elif branch in this function

    Args:
        params: Parameter dictionary from AIConfig.get_litellm_params()
        model: Model string (e.g., "openai/gpt-4o")

    Returns:
        Transformed parameter dictionary with provider-specific patches applied

    Examples:
        >>> params = {"model": "openai/gpt-4o", "max_tokens": 1000}
        >>> apply_provider_patches(params, "openai/gpt-4o")
        {'model': 'openai/gpt-4o', 'max_completion_tokens': 1000}
    """
    provider = get_provider_from_model(model)

    # Apply provider-specific patches
    if provider == "openai":
        return apply_openai_patches(params)

    # Add more providers here as needed:
    # elif provider == "anthropic":
    #     return apply_anthropic_patches(params)
    # elif provider == "cohere":
    #     return apply_cohere_patches(params)

    # No patches needed for this provider
    return params


def filter_none_values(params: Dict[str, Any]) -> Dict[str, Any]:
    """
    Remove None values from parameter dictionary.

    This ensures we only pass explicitly set parameters to LiteLLM,
    allowing it to use its own defaults for unset values.

    Args:
        params: Parameter dictionary potentially containing None values

    Returns:
        Dictionary with None values removed

    Examples:
        >>> filter_none_values({"a": 1, "b": None, "c": "test"})
        {'a': 1, 'c': 'test'}
    """
    return {k: v for k, v in params.items() if v is not None}