jaf-py 2.5.10__py3-none-any.whl → 2.5.11__py3-none-any.whl
This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
- jaf/__init__.py +154 -57
- jaf/a2a/__init__.py +42 -21
- jaf/a2a/agent.py +79 -126
- jaf/a2a/agent_card.py +87 -78
- jaf/a2a/client.py +30 -66
- jaf/a2a/examples/client_example.py +12 -12
- jaf/a2a/examples/integration_example.py +38 -47
- jaf/a2a/examples/server_example.py +56 -53
- jaf/a2a/memory/__init__.py +0 -4
- jaf/a2a/memory/cleanup.py +28 -21
- jaf/a2a/memory/factory.py +155 -133
- jaf/a2a/memory/providers/composite.py +21 -26
- jaf/a2a/memory/providers/in_memory.py +89 -83
- jaf/a2a/memory/providers/postgres.py +117 -115
- jaf/a2a/memory/providers/redis.py +128 -121
- jaf/a2a/memory/serialization.py +77 -87
- jaf/a2a/memory/tests/run_comprehensive_tests.py +112 -83
- jaf/a2a/memory/tests/test_cleanup.py +211 -94
- jaf/a2a/memory/tests/test_serialization.py +73 -68
- jaf/a2a/memory/tests/test_stress_concurrency.py +186 -133
- jaf/a2a/memory/tests/test_task_lifecycle.py +138 -120
- jaf/a2a/memory/types.py +91 -53
- jaf/a2a/protocol.py +95 -125
- jaf/a2a/server.py +90 -118
- jaf/a2a/standalone_client.py +30 -43
- jaf/a2a/tests/__init__.py +16 -33
- jaf/a2a/tests/run_tests.py +17 -53
- jaf/a2a/tests/test_agent.py +40 -140
- jaf/a2a/tests/test_client.py +54 -117
- jaf/a2a/tests/test_integration.py +28 -82
- jaf/a2a/tests/test_protocol.py +54 -139
- jaf/a2a/tests/test_types.py +50 -136
- jaf/a2a/types.py +58 -34
- jaf/cli.py +21 -41
- jaf/core/__init__.py +7 -1
- jaf/core/agent_tool.py +93 -72
- jaf/core/analytics.py +257 -207
- jaf/core/checkpoint.py +223 -0
- jaf/core/composition.py +249 -235
- jaf/core/engine.py +817 -519
- jaf/core/errors.py +55 -42
- jaf/core/guardrails.py +276 -202
- jaf/core/handoff.py +47 -31
- jaf/core/parallel_agents.py +69 -75
- jaf/core/performance.py +75 -73
- jaf/core/proxy.py +43 -44
- jaf/core/proxy_helpers.py +24 -27
- jaf/core/regeneration.py +220 -129
- jaf/core/state.py +68 -66
- jaf/core/streaming.py +115 -108
- jaf/core/tool_results.py +111 -101
- jaf/core/tools.py +114 -116
- jaf/core/tracing.py +269 -210
- jaf/core/types.py +371 -151
- jaf/core/workflows.py +209 -168
- jaf/exceptions.py +46 -38
- jaf/memory/__init__.py +1 -6
- jaf/memory/approval_storage.py +54 -77
- jaf/memory/factory.py +4 -4
- jaf/memory/providers/in_memory.py +216 -180
- jaf/memory/providers/postgres.py +216 -146
- jaf/memory/providers/redis.py +173 -116
- jaf/memory/types.py +70 -51
- jaf/memory/utils.py +36 -34
- jaf/plugins/__init__.py +12 -12
- jaf/plugins/base.py +105 -96
- jaf/policies/__init__.py +0 -1
- jaf/policies/handoff.py +37 -46
- jaf/policies/validation.py +76 -52
- jaf/providers/__init__.py +6 -3
- jaf/providers/mcp.py +97 -51
- jaf/providers/model.py +360 -279
- jaf/server/__init__.py +1 -1
- jaf/server/main.py +7 -11
- jaf/server/server.py +514 -359
- jaf/server/types.py +208 -52
- jaf/utils/__init__.py +17 -18
- jaf/utils/attachments.py +111 -116
- jaf/utils/document_processor.py +175 -174
- jaf/visualization/__init__.py +1 -1
- jaf/visualization/example.py +111 -110
- jaf/visualization/functional_core.py +46 -71
- jaf/visualization/graphviz.py +154 -189
- jaf/visualization/imperative_shell.py +7 -16
- jaf/visualization/types.py +8 -4
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/METADATA +2 -2
- jaf_py-2.5.11.dist-info/RECORD +97 -0
- jaf_py-2.5.10.dist-info/RECORD +0 -96
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/WHEEL +0 -0
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/entry_points.txt +0 -0
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/licenses/LICENSE +0 -0
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/top_level.txt +0 -0
jaf/providers/model.py
CHANGED
Version 2.5.11 combines a style pass over this module (double-quoted strings, trailing commas, collapsed multi-line signatures and dict literals) with a set of behavioural updates. The changed regions, by area:

Imports and module setup
- from ..core.types: Agent, ContentRole, Message, ModelProvider, RunConfig, RunState, CompletionStreamChunk, ToolCallDelta, ToolCallFunctionDelta, MessageContentPart, get_text_content.
- from ..utils.document_processor: extract_document_content, is_document_supported, get_document_description, DocumentProcessingError.
- Ctx = TypeVar("Ctx").

Vision-model detection (_is_vision_model)
- Cache entries are stored with explicit "supports" and "timestamp" keys, and a cached answer is reused while it is younger than VISION_MODEL_CACHE_TTL.
- The {base_url}/model_group/info response is scanned via data["data"], matching model_group exactly or as a substring; a model's supports_vision flag is cached and returned, and a warning is printed when the endpoint answers with a non-200 status.
- The hard-coded fallback list covers gpt-4-vision-preview, gpt-4o, gpt-4o-mini, claude-sonnet-4, claude-sonnet-4-20250514, gemini-2.5-flash, and gemini-2.5-pro; the fallback result is cached as well.

make_litellm_provider (OpenAI-compatible client pointed at a LiteLLM server)
- Proxy resolution takes the "https://" entry from ProxyConfig.to_httpx_proxies() and falls back to "http://".
- get_completion and get_completion_stream resolve the model as config.model_override or agent.model_config.name (default "gpt-4o") and detect image input from both image_url content parts and attachments with kind == "image".
- Tool definitions read agent.model_config.inline_tool_schemas and forward it as _pydantic_to_json_schema(tool.schema.parameters, inline_refs=...).
- is_after_tool_call accepts both ContentRole.TOOL and the plain string "tool".
- max_tokens uses agent.model_config.max_tokens when set, otherwise config.max_tokens; if neither is set, the parameter is omitted from the request.
- Tool calls are returned as {"id": ..., "type": ..., "function": {"name": ..., "arguments": ...}}, and the completion result carries "id", "created", "model", and "system_fingerprint" alongside "message", "usage", and "prompt".
- Streamed tool-call deltas are yielded as CompletionStreamChunk(tool_call_delta=ToolCallDelta(index=..., id=..., type="function", function=ToolCallFunctionDelta(name=..., arguments_delta=...)), raw=...); the terminal chunk carries is_done=True and the finish_reason.

make_litellm_sdk_provider (direct LiteLLM SDK)
- Applies the same inline_tool_schemas, max_tokens, tool-call, and streaming handling in its get_completion and get_completion_stream.
- system_fingerprint is read with getattr(response, "system_fingerprint", None), and a streamed chunk's raw payload uses chunk.model_dump() only when that method exists.

Message conversion
- _convert_message() normalizes the role via msg.role.value when available (str(msg.role).lower() otherwise) and accepts both plain strings and ContentRole values for user, assistant, system, and tool messages; the legacy "function" role is still handled (with getattr(msg, "name", "unknown_function")), tool and function messages must carry tool_call_id, and unknown roles raise a ValueError listing the supported roles.
- Multi-part user content goes through _convert_content_part(), which handles "text", "image_url", and "file" parts and rejects anything else.
- _build_chat_message_with_attachments() prefers an explicit att.url and otherwise builds a data URL from the base64 payload, inserts "[IMAGE SKIPPED: ...]" when an image exceeds MAX_IMAGE_BYTES (configurable via JAF_MAX_IMAGE_BYTES) and "[IMAGE ERROR: ...]" when decoding fails, extracts supported documents into "DOCUMENT: <name> (<description>): ..." text parts, reports extraction failures as "ERROR: Failed to process ...", and falls back to an "ATTACHMENT: ..." placeholder for unsupported types. Native LiteLLM file payloads remain a TODO and currently fall back to content extraction.

Tool schema generation gains $ref inlining. The new module-level helper and the updated _pydantic_to_json_schema() read:

def _resolve_schema_refs(
    schema: Dict[str, Any], defs: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Recursively resolve $ref references in a JSON schema by inlining definitions.

    Args:
        schema: The schema object to process (may contain $ref)
        defs: The $defs dictionary containing reusable definitions

    Returns:
        Schema with all references resolved inline
    """
    if defs is None:
        # Extract $defs from root schema if present
        defs = schema.get("$defs", {})

    # If this is a reference, resolve it
    if isinstance(schema, dict) and "$ref" in schema:
        ref_path = schema["$ref"]

        # Handle #/$defs/DefinitionName format
        if ref_path.startswith("#/$defs/"):
            def_name = ref_path.split("/")[-1]
            if def_name in defs:
                # Recursively resolve the definition (it might have refs too)
                resolved_def = _resolve_schema_refs(defs[def_name], defs)
                return resolved_def
            else:
                # If definition not found, return the original ref
                return schema
        else:
            # Other ref formats - return as is
            return schema

    # If this is a dict, recursively process all values
    if isinstance(schema, dict):
        result = {}
        for key, value in schema.items():
            # Skip $defs as we're inlining them
            if key == "$defs":
                continue
            result[key] = _resolve_schema_refs(value, defs)
        return result

    # If this is a list, recursively process all items
    if isinstance(schema, list):
        return [_resolve_schema_refs(item, defs) for item in schema]

    # For primitive types, return as is
    return schema


def _pydantic_to_json_schema(
    model_class: type[BaseModel], inline_refs: bool = False
) -> Dict[str, Any]:
    """
    Convert a Pydantic model to JSON schema for OpenAI tools.

    Args:
        model_class: Pydantic model class
        inline_refs: If True, resolve $refs and inline $defs in the schema

    Returns:
        JSON schema dictionary
    """
    if hasattr(model_class, "model_json_schema"):
        # Pydantic v2
        schema = model_class.model_json_schema()
    else:
        # Pydantic v1 fallback
        schema = model_class.schema()

    # If inline_refs is True, resolve all references
    if inline_refs:
        schema = _resolve_schema_refs(schema)

    return schema