jaf-py 2.5.10__py3-none-any.whl → 2.5.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaf/__init__.py +154 -57
- jaf/a2a/__init__.py +42 -21
- jaf/a2a/agent.py +79 -126
- jaf/a2a/agent_card.py +87 -78
- jaf/a2a/client.py +30 -66
- jaf/a2a/examples/client_example.py +12 -12
- jaf/a2a/examples/integration_example.py +38 -47
- jaf/a2a/examples/server_example.py +56 -53
- jaf/a2a/memory/__init__.py +0 -4
- jaf/a2a/memory/cleanup.py +28 -21
- jaf/a2a/memory/factory.py +155 -133
- jaf/a2a/memory/providers/composite.py +21 -26
- jaf/a2a/memory/providers/in_memory.py +89 -83
- jaf/a2a/memory/providers/postgres.py +117 -115
- jaf/a2a/memory/providers/redis.py +128 -121
- jaf/a2a/memory/serialization.py +77 -87
- jaf/a2a/memory/tests/run_comprehensive_tests.py +112 -83
- jaf/a2a/memory/tests/test_cleanup.py +211 -94
- jaf/a2a/memory/tests/test_serialization.py +73 -68
- jaf/a2a/memory/tests/test_stress_concurrency.py +186 -133
- jaf/a2a/memory/tests/test_task_lifecycle.py +138 -120
- jaf/a2a/memory/types.py +91 -53
- jaf/a2a/protocol.py +95 -125
- jaf/a2a/server.py +90 -118
- jaf/a2a/standalone_client.py +30 -43
- jaf/a2a/tests/__init__.py +16 -33
- jaf/a2a/tests/run_tests.py +17 -53
- jaf/a2a/tests/test_agent.py +40 -140
- jaf/a2a/tests/test_client.py +54 -117
- jaf/a2a/tests/test_integration.py +28 -82
- jaf/a2a/tests/test_protocol.py +54 -139
- jaf/a2a/tests/test_types.py +50 -136
- jaf/a2a/types.py +58 -34
- jaf/cli.py +21 -41
- jaf/core/__init__.py +7 -1
- jaf/core/agent_tool.py +93 -72
- jaf/core/analytics.py +257 -207
- jaf/core/checkpoint.py +223 -0
- jaf/core/composition.py +249 -235
- jaf/core/engine.py +817 -519
- jaf/core/errors.py +55 -42
- jaf/core/guardrails.py +276 -202
- jaf/core/handoff.py +47 -31
- jaf/core/parallel_agents.py +69 -75
- jaf/core/performance.py +75 -73
- jaf/core/proxy.py +43 -44
- jaf/core/proxy_helpers.py +24 -27
- jaf/core/regeneration.py +220 -129
- jaf/core/state.py +68 -66
- jaf/core/streaming.py +115 -108
- jaf/core/tool_results.py +111 -101
- jaf/core/tools.py +114 -116
- jaf/core/tracing.py +310 -210
- jaf/core/types.py +403 -151
- jaf/core/workflows.py +209 -168
- jaf/exceptions.py +46 -38
- jaf/memory/__init__.py +1 -6
- jaf/memory/approval_storage.py +54 -77
- jaf/memory/factory.py +4 -4
- jaf/memory/providers/in_memory.py +216 -180
- jaf/memory/providers/postgres.py +216 -146
- jaf/memory/providers/redis.py +173 -116
- jaf/memory/types.py +70 -51
- jaf/memory/utils.py +36 -34
- jaf/plugins/__init__.py +12 -12
- jaf/plugins/base.py +105 -96
- jaf/policies/__init__.py +0 -1
- jaf/policies/handoff.py +37 -46
- jaf/policies/validation.py +76 -52
- jaf/providers/__init__.py +6 -3
- jaf/providers/mcp.py +97 -51
- jaf/providers/model.py +475 -283
- jaf/server/__init__.py +1 -1
- jaf/server/main.py +7 -11
- jaf/server/server.py +514 -359
- jaf/server/types.py +208 -52
- jaf/utils/__init__.py +17 -18
- jaf/utils/attachments.py +111 -116
- jaf/utils/document_processor.py +175 -174
- jaf/visualization/__init__.py +1 -1
- jaf/visualization/example.py +111 -110
- jaf/visualization/functional_core.py +46 -71
- jaf/visualization/graphviz.py +154 -189
- jaf/visualization/imperative_shell.py +7 -16
- jaf/visualization/types.py +8 -4
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/METADATA +2 -2
- jaf_py-2.5.12.dist-info/RECORD +97 -0
- jaf_py-2.5.10.dist-info/RECORD +0 -96
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/WHEEL +0 -0
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/entry_points.txt +0 -0
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/licenses/LICENSE +0 -0
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/top_level.txt +0 -0
jaf/providers/model.py
CHANGED
@@ -10,23 +10,36 @@ import httpx
 import time
 import os
 import base64
+import asyncio

 from openai import AsyncOpenAI
 from pydantic import BaseModel
 import litellm

 from ..core.types import (
+    Agent,
+    ContentRole,
+    Message,
+    ModelProvider,
+    RunConfig,
+    RunState,
+    CompletionStreamChunk,
+    ToolCallDelta,
+    ToolCallFunctionDelta,
+    MessageContentPart,
+    get_text_content,
+    RetryEvent,
+    RetryEventData,
 )
 from ..core.proxy import ProxyConfig
 from ..utils.document_processor import (
+    extract_document_content,
+    is_document_supported,
+    get_document_description,
+    DocumentProcessingError,
 )

+Ctx = TypeVar("Ctx")

 # Vision model caching
 VISION_MODEL_CACHE_TTL = 5 * 60  # 5 minutes
@@ -34,92 +47,183 @@ VISION_API_TIMEOUT = 3.0  # 3 seconds
 _vision_model_cache: Dict[str, Dict[str, Any]] = {}
 MAX_IMAGE_BYTES = int(os.environ.get("JAF_MAX_IMAGE_BYTES", 8 * 1024 * 1024))


 async def _is_vision_model(model: str, base_url: str) -> bool:
     """
     Check if a model supports vision capabilities.

     Args:
         model: Model name to check
         base_url: Base URL of the LiteLLM server

     Returns:
         True if model supports vision, False otherwise
     """
     cache_key = f"{base_url}:{model}"
     cached = _vision_model_cache.get(cache_key)

+    if cached and time.time() - cached["timestamp"] < VISION_MODEL_CACHE_TTL:
+        return cached["supports"]

     try:
         async with httpx.AsyncClient(timeout=VISION_API_TIMEOUT) as client:
             response = await client.get(
+                f"{base_url}/model_group/info", headers={"accept": "application/json"}
             )

             if response.status_code == 200:
                 data = response.json()
                 model_info = None

+                if "data" in data and isinstance(data["data"], list):
+                    for m in data["data"]:
+                        if m.get("model_group") == model or model in str(m.get("model_group", "")):
                             model_info = m
                             break

+                if model_info and "supports_vision" in model_info:
+                    result = model_info["supports_vision"]
+                    _vision_model_cache[cache_key] = {"supports": result, "timestamp": time.time()}
                     return result
             else:
+                print(
+                    f"Warning: Vision API returned status {response.status_code} for model {model}"
+                )

     except Exception as e:
         print(f"Warning: Vision API error for model {model}: {e}")

     # Fallback to known vision models
     known_vision_models = [
+        "gpt-4-vision-preview",
+        "gpt-4o",
+        "gpt-4o-mini",
+        "claude-sonnet-4",
+        "claude-sonnet-4-20250514",
+        "gemini-2.5-flash",
+        "gemini-2.5-pro",
     ]

     is_known_vision_model = any(
+        vision_model.lower() in model.lower() for vision_model in known_vision_models
     )

+    _vision_model_cache[cache_key] = {"supports": is_known_vision_model, "timestamp": time.time()}

     return is_known_vision_model


+async def _retry_with_events(
+    operation_func,
+    state: RunState,
+    config: RunConfig,
+    operation_name: str = "llm_call",
+    max_retries: int = 3,
+    backoff_factor: float = 1.0,
+):
+    """
+    Wrapper that retries an async operation and emits retry events.
+
+    Args:
+        operation_func: Async function to execute (should accept no arguments)
+        state: Current run state
+        config: Run configuration with event handler
+        operation_name: Name of the operation for logging
+        max_retries: Maximum number of retry attempts
+        backoff_factor: Exponential backoff multiplier
+
+    Returns:
+        Result from operation_func
+
+    Raises:
+        Last exception if all retries are exhausted
+    """
+    last_exception = None
+
+    for attempt in range(max_retries + 1):
+        try:
+            return await operation_func()
+        except Exception as e:
+            last_exception = e
+
+            # Check if this is a retryable HTTP error
+            is_retryable = False
+            reason = str(e)
+            error_details = {"error_type": type(e).__name__, "error_message": str(e)}
+
+            # Check for HTTP errors (common in OpenAI/LiteLLM)
+            if hasattr(e, "status_code"):
+                status_code = e.status_code
+                error_details["status_code"] = status_code
+
+                # Retry on rate limits (429) and server errors (5xx)
+                if status_code == 429:
+                    is_retryable = True
+                    reason = f"HTTP {status_code} - Rate Limit"
+                elif 500 <= status_code < 600:
+                    is_retryable = True
+                    reason = f"HTTP {status_code} - Server Error"
+                else:
+                    reason = f"HTTP {status_code}"
+
+            # Check for common exception names
+            elif "RateLimitError" in type(e).__name__:
+                is_retryable = True
+                reason = "Rate Limit Error"
+            elif "ServiceUnavailableError" in type(e).__name__ or "APIError" in type(e).__name__:
+                is_retryable = True
+                reason = "API Error"
+            elif "Timeout" in type(e).__name__:
+                is_retryable = True
+                reason = "Timeout"
+
+            # If not last attempt and is retryable, retry with backoff
+            if attempt < max_retries and is_retryable:
+                delay = backoff_factor * (2**attempt)  # Exponential backoff
+
+                # Emit retry event
+                if config.on_event:
+                    retry_event = RetryEvent(
+                        data=RetryEventData(
+                            attempt=attempt + 1,
+                            max_retries=max_retries,
+                            reason=reason,
+                            operation=operation_name,
+                            trace_id=state.trace_id,
+                            run_id=state.run_id,
+                            delay=delay,
+                            error_details=error_details,
+                        )
+                    )
+                    config.on_event(retry_event)
+
+                print(
+                    f"[JAF:RETRY] Attempt {attempt + 1}/{max_retries} failed: {reason}. Retrying in {delay}s..."
+                )
+                await asyncio.sleep(delay)
+            else:
+                # Not retryable or last attempt, re-raise
+                raise
+
+    # Should never reach here, but just in case
+    raise last_exception


 def make_litellm_provider(
     base_url: str,
     api_key: str = "anything",
     default_timeout: Optional[float] = None,
+    proxy_config: Optional[ProxyConfig] = None,
 ) -> ModelProvider[Ctx]:
     """
     Create a LiteLLM-compatible model provider.

     Args:
         base_url: Base URL for the LiteLLM server
         api_key: API key (defaults to "anything" for local servers)
         default_timeout: Default timeout for model API calls in seconds
         proxy_config: Optional proxy configuration

     Returns:
         ModelProvider instance
     """
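The new `_retry_with_events` helper is the core of the retry behaviour added in 2.5.12: it classifies the exception, emits a `RetryEvent` through `config.on_event`, sleeps `backoff_factor * 2**attempt`, and re-raises once retries are exhausted. A minimal standalone sketch of that control flow (the names `retry_with_backoff`, `operation`, and `on_retry` below are illustrative, not part of the package):

```python
import asyncio

async def retry_with_backoff(operation, max_retries=3, backoff_factor=1.0, on_retry=None):
    """Sketch of the retry loop: classify, back off exponentially, re-raise when exhausted."""
    last_exc = None
    for attempt in range(max_retries + 1):
        try:
            return await operation()
        except Exception as exc:
            last_exc = exc
            # The real wrapper treats HTTP 429/5xx, rate-limit, API and timeout
            # errors as retryable; everything else is re-raised immediately.
            retryable = getattr(exc, "status_code", None) == 429 or "Timeout" in type(exc).__name__
            if attempt < max_retries and retryable:
                delay = backoff_factor * (2 ** attempt)  # 1s, 2s, 4s, ...
                if on_retry:
                    on_retry(attempt + 1, max_retries, delay, exc)
                await asyncio.sleep(delay)
            else:
                raise
    raise last_exc
```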
@@ -128,48 +232,47 @@ def make_litellm_provider(
         def __init__(self):
             # Default to "anything" if api_key is not provided, for local servers
             effective_api_key = api_key if api_key is not None else "anything"

             # Configure HTTP client with proxy support
             client_kwargs = {
                 "base_url": base_url,
                 "api_key": effective_api_key,
             }

             if proxy_config:
                 proxies = proxy_config.to_httpx_proxies()
                 if proxies:
                     # Create httpx client with proxy configuration
                     try:
                         # Use the https proxy if available, otherwise http proxy
+                        proxy_url = proxies.get("https://") or proxies.get("http://")
                         if proxy_url:
                             http_client = httpx.AsyncClient(proxy=proxy_url)
                             client_kwargs["http_client"] = http_client
                     except Exception as e:
                         print(f"Warning: Could not configure proxy: {e}")
                         # Fall back to environment variables for proxy

             self.client = AsyncOpenAI(**client_kwargs)
             self.default_timeout = default_timeout

         async def get_completion(
+            self, state: RunState[Ctx], agent: Agent[Ctx, Any], config: RunConfig[Ctx]
         ) -> Dict[str, Any]:
             """Get completion from the model."""

             # Determine model to use
+            model = config.model_override or (
+                agent.model_config.name if agent.model_config else "gpt-4o"
+            )

             # Check if any message contains image content or image attachments
             has_image_content = any(
+                (
+                    isinstance(msg.content, list)
+                    and any(part.type == "image_url" for part in msg.content)
+                )
+                or (msg.attachments and any(att.kind == "image" for att in msg.attachments))
                 for msg in state.messages
             )

@@ -182,51 +285,59 @@ def make_litellm_provider(
             )

             # Create system message
+            system_message = {"role": "system", "content": agent.instructions(state)}

             # Convert messages to OpenAI format
             converted_messages = []
             for msg in state.messages:
                 converted_msg = await _convert_message(msg)
                 converted_messages.append(converted_msg)

             messages = [system_message] + converted_messages

             # Convert tools to OpenAI format
             tools = None
             if agent.tools:
+                # Check if we should inline schema refs
+                inline_refs = (
+                    agent.model_config.inline_tool_schemas if agent.model_config else False
+                )
                 tools = [
                     {
                         "type": "function",
                         "function": {
                             "name": tool.schema.name,
                             "description": tool.schema.description,
+                            "parameters": _pydantic_to_json_schema(
+                                tool.schema.parameters, inline_refs=inline_refs or False
+                            ),
+                        },
                     }
                     for tool in agent.tools
                 ]

             # Determine tool choice behavior
             last_message = state.messages[-1] if state.messages else None
+            is_after_tool_call = last_message and (
+                last_message.role == ContentRole.TOOL or last_message.role == "tool"
+            )

             # Prepare request parameters
+            request_params = {"model": model, "messages": messages, "stream": False}

             # Add optional parameters
             if agent.model_config:
                 if agent.model_config.temperature is not None:
                     request_params["temperature"] = agent.model_config.temperature
+                # Use agent's max_tokens if set, otherwise fall back to config's max_tokens
+                max_tokens = agent.model_config.max_tokens
+                if max_tokens is None:
+                    max_tokens = config.max_tokens
+                if max_tokens is not None:
+                    request_params["max_tokens"] = max_tokens
+            elif config.max_tokens is not None:
+                # No model_config but config has max_tokens
+                request_params["max_tokens"] = config.max_tokens

             if tools:
                 request_params["tools"] = tools
@@ -236,8 +347,14 @@ def make_litellm_provider(
             if agent.output_codec:
                 request_params["response_format"] = {"type": "json_object"}

+            # Make the API call with retry handling
+            async def _api_call():
+                return await self.client.chat.completions.create(**request_params)
+
+            # Use retry wrapper to track retries in Langfuse
+            response = await _retry_with_events(
+                _api_call, state, config, operation_name="llm_call", max_retries=3, backoff_factor=1.0
+            )

             # Return in the expected format that the engine expects
             choice = response.choices[0]
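Both providers now resolve `max_tokens` with the same precedence: the agent's `model_config.max_tokens` wins, then the run config's `max_tokens`, and the key is omitted when neither is set. A small sketch of that resolution (standalone helper, not part of the package API):

```python
from typing import Optional

def resolve_max_tokens(agent_max: Optional[int], config_max: Optional[int]) -> Optional[int]:
    # model_config.max_tokens wins, then config.max_tokens; None means "omit the key".
    return agent_max if agent_max is not None else config_max

params = {"model": "gpt-4o", "messages": []}
max_tokens = resolve_max_tokens(None, 1024)
if max_tokens is not None:
    params["max_tokens"] = max_tokens  # -> 1024, taken from the run config
```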
@@ -247,12 +364,9 @@ def make_litellm_provider(
             if choice.message.tool_calls:
                 tool_calls = [
                     {
+                        "id": tc.id,
+                        "type": tc.type,
+                        "function": {"name": tc.function.name, "arguments": tc.function.arguments},
                     }
                     for tc in choice.message.tool_calls
                 ]
@@ -267,64 +381,64 @@ def make_litellm_provider(
             }

             return {
+                "id": response.id,
+                "created": response.created,
+                "model": response.model,
+                "system_fingerprint": response.system_fingerprint,
+                "message": {"content": choice.message.content, "tool_calls": tool_calls},
+                "usage": usage_data,
+                "prompt": messages,
             }

         async def get_completion_stream(
+            self, state: RunState[Ctx], agent: Agent[Ctx, Any], config: RunConfig[Ctx]
         ) -> AsyncIterator[CompletionStreamChunk]:
             """
             Stream completion chunks from the model provider, yielding text deltas and tool-call deltas.
             Uses OpenAI-compatible streaming via LiteLLM endpoint.
             """
             # Determine model to use
+            model = config.model_override or (
+                agent.model_config.name if agent.model_config else "gpt-4o"
+            )

             # Create system message
+            system_message = {"role": "system", "content": agent.instructions(state)}

             # Convert messages to OpenAI format
             converted_messages = []
             for msg in state.messages:
                 converted_msg = await _convert_message(msg)
                 converted_messages.append(converted_msg)

             messages = [system_message] + converted_messages

             # Convert tools to OpenAI format
             tools = None
             if agent.tools:
+                # Check if we should inline schema refs
+                inline_refs = (
+                    agent.model_config.inline_tool_schemas if agent.model_config else False
+                )
                 tools = [
                     {
                         "type": "function",
                         "function": {
                             "name": tool.schema.name,
                             "description": tool.schema.description,
+                            "parameters": _pydantic_to_json_schema(
+                                tool.schema.parameters, inline_refs=inline_refs or False
+                            ),
+                        },
                     }
                     for tool in agent.tools
                 ]

             # Determine tool choice behavior
             last_message = state.messages[-1] if state.messages else None
+            is_after_tool_call = last_message and (
+                last_message.role == ContentRole.TOOL or last_message.role == "tool"
+            )

             # Prepare request parameters
             request_params: Dict[str, Any] = {
@@ -336,8 +450,15 @@ def make_litellm_provider(
             if agent.model_config:
                 if agent.model_config.temperature is not None:
                     request_params["temperature"] = agent.model_config.temperature
+                # Use agent's max_tokens if set, otherwise fall back to config's max_tokens
+                max_tokens = agent.model_config.max_tokens
+                if max_tokens is None:
+                    max_tokens = config.max_tokens
+                if max_tokens is not None:
+                    request_params["max_tokens"] = max_tokens
+            elif config.max_tokens is not None:
+                # No model_config but config has max_tokens
+                request_params["max_tokens"] = config.max_tokens

             if tools:
                 request_params["tools"] = tools
@@ -388,19 +509,20 @@ def make_litellm_provider(
                             fn = getattr(tc, "function", None)
                             fn_name = getattr(fn, "name", None) if fn is not None else None
                             # OpenAI streams "arguments" as incremental deltas
+                            args_delta = (
+                                getattr(fn, "arguments", None) if fn is not None else None
+                            )

                             yield CompletionStreamChunk(
                                 tool_call_delta=ToolCallDelta(
                                     index=idx,
                                     id=tc_id,
+                                    type="function",
                                     function=ToolCallFunctionDelta(
+                                        name=fn_name, arguments_delta=args_delta
+                                    ),
                                 ),
+                                raw=raw_obj,
                             )
                         except Exception:
                             # Skip malformed tool-call deltas
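The streaming path yields `CompletionStreamChunk` objects whose `tool_call_delta` carries incremental `arguments_delta` fragments keyed by `index`. A hedged sketch of how a caller might stitch those fragments back into complete tool calls; the chunk fields used here are the ones visible in this hunk, but the consuming loop itself is illustrative:

```python
import json
from collections import defaultdict

async def collect_tool_calls(chunks):
    """Accumulate arguments_delta fragments per tool-call index until the stream finishes."""
    names, args = {}, defaultdict(str)
    async for chunk in chunks:
        delta = getattr(chunk, "tool_call_delta", None)
        if delta is not None and delta.function is not None:
            if delta.function.name:
                names[delta.index] = delta.function.name
            if delta.function.arguments_delta:
                args[delta.index] += delta.function.arguments_delta
        if getattr(chunk, "is_done", False):
            break
    # Parse the concatenated JSON argument strings once streaming is complete.
    return [
        {"name": names.get(i), "arguments": json.loads(args[i] or "{}")}
        for i in sorted(args)
    ]
```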
@@ -408,26 +530,29 @@ def make_litellm_provider(

                     # Completion ended
                     if finish_reason:
+                        yield CompletionStreamChunk(
+                            is_done=True, finish_reason=finish_reason, raw=raw_obj
+                        )
                 except Exception:
                     # Skip individual chunk errors, keep streaming
                     continue

     return LiteLLMProvider()


 def make_litellm_sdk_provider(
     api_key: Optional[str] = None,
     model: str = "gpt-3.5-turbo",
     base_url: Optional[str] = None,
     default_timeout: Optional[float] = None,
+    **litellm_kwargs: Any,
 ) -> ModelProvider[Ctx]:
     """
     Create a LiteLLM SDK-based model provider with universal provider support.

     LiteLLM automatically detects the provider from the model name and handles
     API key management through environment variables or direct parameters.

     Args:
         api_key: API key for the provider (optional, can use env vars)
         model: Model name (e.g., "gpt-4", "claude-3-sonnet", "gemini-pro", "llama2", etc.)
@@ -440,23 +565,23 @@ def make_litellm_sdk_provider(
             - azure_deployment: "your-deployment" (for Azure OpenAI)
             - api_base: "https://your-endpoint.com" (custom endpoint)
             - custom_llm_provider: "custom_provider_name"

     Returns:
         ModelProvider instance

     Examples:
         # OpenAI
         make_litellm_sdk_provider(api_key="sk-...", model="gpt-4")

         # Anthropic Claude
         make_litellm_sdk_provider(api_key="sk-ant-...", model="claude-3-sonnet-20240229")

         # Google Gemini
         make_litellm_sdk_provider(model="gemini-pro", vertex_project="my-project")

         # Ollama (local)
         make_litellm_sdk_provider(model="ollama/llama2", base_url="http://localhost:11434")

         # Azure OpenAI
         make_litellm_sdk_provider(
             model="azure/gpt-4",
@@ -464,13 +589,13 @@ def make_litellm_sdk_provider(
             azure_deployment="gpt-4-deployment",
             api_base="https://your-resource.openai.azure.com"
         )

         # Hugging Face
         make_litellm_sdk_provider(
             model="huggingface/microsoft/DialoGPT-medium",
             api_key="hf_..."
         )

         # Any custom provider
         make_litellm_sdk_provider(
             model="custom_provider/model-name",
@@ -488,10 +613,7 @@ def make_litellm_sdk_provider(
             self.litellm_kwargs = litellm_kwargs

         async def get_completion(
+            self, state: RunState[Ctx], agent: Agent[Ctx, Any], config: RunConfig[Ctx]
         ) -> Dict[str, Any]:
             """Get completion from the model using LiteLLM SDK."""

@@ -499,10 +621,7 @@ def make_litellm_sdk_provider(
             model_name = config.model_override or self.model

             # Create system message
+            system_message = {"role": "system", "content": agent.instructions(state)}

             # Convert messages to OpenAI format
             messages = [system_message]
@@ -513,24 +632,26 @@ def make_litellm_sdk_provider(
             # Convert tools to OpenAI format
             tools = None
             if agent.tools:
+                # Check if we should inline schema refs
+                inline_refs = (
+                    agent.model_config.inline_tool_schemas if agent.model_config else False
+                )
                 tools = [
                     {
                         "type": "function",
                         "function": {
                             "name": tool.schema.name,
                             "description": tool.schema.description,
+                            "parameters": _pydantic_to_json_schema(
+                                tool.schema.parameters, inline_refs=inline_refs or False
+                            ),
+                        },
                     }
                     for tool in agent.tools
                 ]

             # Prepare request parameters for LiteLLM
+            request_params = {"model": model_name, "messages": messages, **self.litellm_kwargs}

             # Add API key if provided
             if self.api_key:
@@ -540,8 +661,15 @@ def make_litellm_sdk_provider(
             if agent.model_config:
                 if agent.model_config.temperature is not None:
                     request_params["temperature"] = agent.model_config.temperature
+                # Use agent's max_tokens if set, otherwise fall back to config's max_tokens
+                max_tokens = agent.model_config.max_tokens
+                if max_tokens is None:
+                    max_tokens = config.max_tokens
+                if max_tokens is not None:
+                    request_params["max_tokens"] = max_tokens
+            elif config.max_tokens is not None:
+                # No model_config but config has max_tokens
+                request_params["max_tokens"] = config.max_tokens

             if tools:
                 request_params["tools"] = tools
@@ -554,8 +682,14 @@ def make_litellm_sdk_provider(
             if self.base_url:
                 request_params["api_base"] = self.base_url

+            # Make the API call using litellm with retry handling
+            async def _api_call():
+                return await litellm.acompletion(**request_params)
+
+            # Use retry wrapper to track retries in Langfuse
+            response = await _retry_with_events(
+                _api_call, state, config, operation_name="llm_call", max_retries=3, backoff_factor=1.0
+            )

             # Return in the expected format that the engine expects
             choice = response.choices[0]
@@ -565,12 +699,9 @@ def make_litellm_sdk_provider(
             if choice.message.tool_calls:
                 tool_calls = [
                     {
+                        "id": tc.id,
+                        "type": tc.type,
+                        "function": {"name": tc.function.name, "arguments": tc.function.arguments},
                     }
                     for tc in choice.message.tool_calls
                 ]
@@ -585,23 +716,17 @@ def make_litellm_sdk_provider(
             }

             return {
+                "id": response.id,
+                "created": response.created,
+                "model": response.model,
+                "system_fingerprint": getattr(response, "system_fingerprint", None),
+                "message": {"content": choice.message.content, "tool_calls": tool_calls},
+                "usage": usage_data,
+                "prompt": messages,
             }

         async def get_completion_stream(
+            self, state: RunState[Ctx], agent: Agent[Ctx, Any], config: RunConfig[Ctx]
         ) -> AsyncIterator[CompletionStreamChunk]:
             """
             Stream completion chunks from the model provider using LiteLLM SDK.
@@ -610,10 +735,7 @@ def make_litellm_sdk_provider(
             model_name = config.model_override or self.model

             # Create system message
+            system_message = {"role": "system", "content": agent.instructions(state)}

             # Convert messages to OpenAI format
             messages = [system_message]
@@ -624,14 +746,20 @@ def make_litellm_sdk_provider(
             # Convert tools to OpenAI format
             tools = None
             if agent.tools:
+                # Check if we should inline schema refs
+                inline_refs = (
+                    agent.model_config.inline_tool_schemas if agent.model_config else False
+                )
                 tools = [
                     {
                         "type": "function",
                         "function": {
                             "name": tool.schema.name,
                             "description": tool.schema.description,
+                            "parameters": _pydantic_to_json_schema(
+                                tool.schema.parameters, inline_refs=inline_refs or False
+                            ),
+                        },
                     }
                     for tool in agent.tools
                 ]
@@ -641,7 +769,7 @@ def make_litellm_sdk_provider(
                 "model": model_name,
                 "messages": messages,
                 "stream": True,
+                **self.litellm_kwargs,
             }

             # Add API key if provided
@@ -652,8 +780,15 @@ def make_litellm_sdk_provider(
             if agent.model_config:
                 if agent.model_config.temperature is not None:
                     request_params["temperature"] = agent.model_config.temperature
+                # Use agent's max_tokens if set, otherwise fall back to config's max_tokens
+                max_tokens = agent.model_config.max_tokens
+                if max_tokens is None:
+                    max_tokens = config.max_tokens
+                if max_tokens is not None:
+                    request_params["max_tokens"] = max_tokens
+            elif config.max_tokens is not None:
+                # No model_config but config has max_tokens
+                request_params["max_tokens"] = config.max_tokens

             if tools:
                 request_params["tools"] = tools
@@ -668,12 +803,12 @@ def make_litellm_sdk_provider(

             # Stream using litellm
             stream = await litellm.acompletion(**request_params)

             async for chunk in stream:
                 try:
                     # Best-effort extraction of raw for debugging
                     try:
+                        raw_obj = chunk.model_dump() if hasattr(chunk, "model_dump") else None
                     except Exception:
                         raw_obj = None

@@ -702,52 +837,59 @@ def make_litellm_sdk_provider(
                             tc_id = getattr(tc, "id", None)
                             fn = getattr(tc, "function", None)
                             fn_name = getattr(fn, "name", None) if fn is not None else None
+                            args_delta = (
+                                getattr(fn, "arguments", None) if fn is not None else None
+                            )

                             yield CompletionStreamChunk(
                                 tool_call_delta=ToolCallDelta(
                                     index=idx,
                                     id=tc_id,
+                                    type="function",
                                     function=ToolCallFunctionDelta(
+                                        name=fn_name, arguments_delta=args_delta
+                                    ),
                                 ),
+                                raw=raw_obj,
                             )
                         except Exception:
                             continue

                     # Completion ended
                     if finish_reason:
+                        yield CompletionStreamChunk(
+                            is_done=True, finish_reason=finish_reason, raw=raw_obj
+                        )
                 except Exception:
                     continue

     return LiteLLMSDKProvider()


 async def _convert_message(msg: Message) -> Dict[str, Any]:
     """
     Handles all possible role types (string and enum) and content formats.
     """
     # Normalize role to handle both string and enum values
+    role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role).lower()

     # Handle user messages
+    if role_value in ("user", ContentRole.USER.value if hasattr(ContentRole, "USER") else "user"):
         if isinstance(msg.content, list):
             # Multi-part content
             return {
                 "role": "user",
+                "content": [_convert_content_part(part) for part in msg.content],
             }
         else:
             # Build message with attachments if available
+            return await _build_chat_message_with_attachments("user", msg)

     # Handle assistant messages
+    elif role_value in (
+        "assistant",
+        ContentRole.ASSISTANT.value if hasattr(ContentRole, "ASSISTANT") else "assistant",
+    ):
         result = {
             "role": "assistant",
             "content": get_text_content(msg.content) or "",  # Ensure content is never None
@@ -759,10 +901,7 @@ async def _convert_message(msg: Message) -> Dict[str, Any]:
                 {
                     "id": tc.id,
                     "type": tc.type,
+                    "function": {"name": tc.function.name, "arguments": tc.function.arguments},
                 }
                 for tc in msg.tool_calls
                 if tc.id and tc.function and tc.function.name  # Validate tool call structure
@@ -771,37 +910,37 @@ async def _convert_message(msg: Message) -> Dict[str, Any]:
         return result

     # Handle system messages
+    elif role_value in (
+        "system",
+        ContentRole.SYSTEM.value if hasattr(ContentRole, "SYSTEM") else "system",
+    ):
+        return {"role": "system", "content": get_text_content(msg.content) or ""}

     # Handle tool messages
+    elif role_value in ("tool", ContentRole.TOOL.value if hasattr(ContentRole, "TOOL") else "tool"):
         if not msg.tool_call_id:
             raise ValueError(f"Tool message must have tool_call_id. Message: {msg}")

         return {
             "role": "tool",
             "content": get_text_content(msg.content) or "",
+            "tool_call_id": msg.tool_call_id,
         }

     # Handle function messages (legacy support)
+    elif role_value == "function":
         if not msg.tool_call_id:
             raise ValueError(f"Function message must have tool_call_id. Message: {msg}")

         return {
             "role": "function",
             "content": get_text_content(msg.content) or "",
+            "name": getattr(msg, "name", "unknown_function"),
         }

     # Unknown role - provide helpful error message
     else:
+        available_roles = ["user", "assistant", "system", "tool", "function"]
         raise ValueError(
             f"Unknown message role: {msg.role} (type: {type(msg.role)}). "
             f"Supported roles: {available_roles}. "
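The `_convert_message` changes normalise both enum and plain-string roles into OpenAI chat format, with tool calls carried as `id`/`type`/`function` dicts and tool results keyed by `tool_call_id`. A short, hand-written example of the shape the conversion targets for a tool round-trip (the values are made up, for illustration only):

```python
# user -> assistant(tool_call) -> tool, as it looks after _convert_message.
converted = [
    {"role": "user", "content": "What's the weather in Paris?"},
    {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {
                "id": "call_1",
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
            }
        ],
    },
    {"role": "tool", "content": '{"temp_c": 18}', "tool_call_id": "call_1"},
]
```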
@@ -811,46 +950,31 @@ async def _convert_message(msg: Message) -> Dict[str, Any]:

 def _convert_content_part(part: MessageContentPart) -> Dict[str, Any]:
     """Convert MessageContentPart to OpenAI format."""
+    if part.type == "text":
+        return {"type": "text", "text": part.text}
+    elif part.type == "image_url":
+        return {"type": "image_url", "image_url": part.image_url}
+    elif part.type == "file":
+        return {"type": "file", "file": part.file}
     else:
         raise ValueError(f"Unknown content part type: {part.type}")


+async def _build_chat_message_with_attachments(role: str, msg: Message) -> Dict[str, Any]:
     """
     Build multi-part content for Chat Completions if attachments exist.
     Supports images via image_url and documents via content extraction.
     """
     has_attachments = msg.attachments and len(msg.attachments) > 0
     if not has_attachments:
+        if role == "assistant":
             base_msg = {"role": "assistant", "content": get_text_content(msg.content)}
             if msg.tool_calls:
                 base_msg["tool_calls"] = [
                     {
                         "id": tc.id,
                         "type": tc.type,
+                        "function": {"name": tc.function.name, "arguments": tc.function.arguments},
                     }
                     for tc in msg.tool_calls
                 ]
@@ -863,7 +987,7 @@ async def _build_chat_message_with_attachments(
         parts.append({"type": "text", "text": text_content})

     for att in msg.attachments:
+        if att.kind == "image":
             # Prefer explicit URL; otherwise construct a data URL from base64
             url = att.url
             if not url and att.data and att.mime_type:
@@ -871,100 +995,168 @@ async def _build_chat_message_with_attachments(
                 try:
                     # Estimate decoded size (base64 is ~4/3 of decoded size)
                     estimated_size = len(att.data) * 3 // 4

                     if estimated_size > MAX_IMAGE_BYTES:
+                        print(
+                            f"Warning: Skipping oversized image ({estimated_size} bytes > {MAX_IMAGE_BYTES}). "
+                            f"Set JAF_MAX_IMAGE_BYTES env var to adjust limit."
+                        )
+                        parts.append(
+                            {
+                                "type": "text",
+                                "text": f"[IMAGE SKIPPED: Size exceeds limit of {MAX_IMAGE_BYTES // 1024 // 1024}MB. "
+                                f"Image name: {att.name or 'unnamed'}]",
+                            }
+                        )
                         continue

                     # Create data URL for valid-sized images
                     url = f"data:{att.mime_type};base64,{att.data}"
                 except Exception as e:
                     print(f"Error processing image data: {e}")
+                    parts.append(
+                        {
+                            "type": "text",
+                            "text": f"[IMAGE ERROR: Failed to process image data. Image name: {att.name or 'unnamed'}]",
+                        }
+                    )
                     continue

             if url:
+                parts.append({"type": "image_url", "image_url": {"url": url}})
+
+        elif att.kind in ["document", "file"]:
             # Check if attachment has use_litellm_format flag or is a large document
             use_litellm_format = att.use_litellm_format is True

             if use_litellm_format and (att.url or att.data):
                 # For now, fall back to content extraction since most providers don't support native file format
                 # TODO: Add provider-specific file format support
+                print(
+                    f"Info: LiteLLM format requested for {att.name}, falling back to content extraction"
+                )
                 use_litellm_format = False

             if not use_litellm_format:
                 # Extract document content if supported and we have data or URL
                 if is_document_supported(att.mime_type) and (att.data or att.url):
                     try:
                         processed = await extract_document_content(att)
+                        file_name = att.name or "document"
                         description = get_document_description(att.mime_type)
+
+                        parts.append(
+                            {
+                                "type": "text",
+                                "text": f"DOCUMENT: {file_name} ({description}):\n\n{processed.content}",
+                            }
+                        )
                     except DocumentProcessingError as e:
                         # Fallback to filename if extraction fails
+                        label = att.name or att.format or att.mime_type or "attachment"
+                        parts.append(
+                            {
+                                "type": "text",
+                                "text": f"ERROR: Failed to process {att.kind}: {label} ({e})",
+                            }
+                        )
                 else:
                     # Unsupported document type - show placeholder
+                    label = att.name or att.format or att.mime_type or "attachment"
                     url_info = f" ({att.url})" if att.url else ""
+                    parts.append(
+                        {"type": "text", "text": f"ATTACHMENT: {att.kind}: {label}{url_info}"}
+                    )

     base_msg = {"role": role, "content": parts}
+    if role == "assistant" and msg.tool_calls:
         base_msg["tool_calls"] = [
             {
                 "id": tc.id,
                 "type": tc.type,
+                "function": {"name": tc.function.name, "arguments": tc.function.arguments},
             }
             for tc in msg.tool_calls
         ]

     return base_msg


+def _resolve_schema_refs(
+    schema: Dict[str, Any], defs: Optional[Dict[str, Any]] = None
+) -> Dict[str, Any]:
+    """
+    Recursively resolve $ref references in a JSON schema by inlining definitions.
+
+    Args:
+        schema: The schema object to process (may contain $ref)
+        defs: The $defs dictionary containing reusable definitions
+
+    Returns:
+        Schema with all references resolved inline
+    """
+    if defs is None:
+        # Extract $defs from root schema if present
+        defs = schema.get("$defs", {})
+
+    # If this is a reference, resolve it
+    if isinstance(schema, dict) and "$ref" in schema:
+        ref_path = schema["$ref"]
+
+        # Handle #/$defs/DefinitionName format
+        if ref_path.startswith("#/$defs/"):
+            def_name = ref_path.split("/")[-1]
+            if def_name in defs:
+                # Recursively resolve the definition (it might have refs too)
+                resolved_def = _resolve_schema_refs(defs[def_name], defs)
+                return resolved_def
+            else:
+                # If definition not found, return the original ref
+                return schema
+        else:
+            # Other ref formats - return as is
+            return schema
+
+    # If this is a dict, recursively process all values
+    if isinstance(schema, dict):
+        result = {}
+        for key, value in schema.items():
+            # Skip $defs as we're inlining them
+            if key == "$defs":
+                continue
+            result[key] = _resolve_schema_refs(value, defs)
+        return result
+
+    # If this is a list, recursively process all items
+    if isinstance(schema, list):
+        return [_resolve_schema_refs(item, defs) for item in schema]
+
+    # For primitive types, return as is
+    return schema
+
+
+def _pydantic_to_json_schema(
+    model_class: type[BaseModel], inline_refs: bool = False
+) -> Dict[str, Any]:
     """
     Convert a Pydantic model to JSON schema for OpenAI tools.

     Args:
         model_class: Pydantic model class
+        inline_refs: If True, resolve $refs and inline $defs in the schema

     Returns:
         JSON schema dictionary
     """
+    if hasattr(model_class, "model_json_schema"):
         # Pydantic v2
+        schema = model_class.model_json_schema()
     else:
         # Pydantic v1 fallback
+        schema = model_class.schema()
+
+    # If inline_refs is True, resolve all references
+    if inline_refs:
+        schema = _resolve_schema_refs(schema)
+
+    return schema