jaf-py 2.5.10__py3-none-any.whl → 2.5.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. jaf/__init__.py +154 -57
  2. jaf/a2a/__init__.py +42 -21
  3. jaf/a2a/agent.py +79 -126
  4. jaf/a2a/agent_card.py +87 -78
  5. jaf/a2a/client.py +30 -66
  6. jaf/a2a/examples/client_example.py +12 -12
  7. jaf/a2a/examples/integration_example.py +38 -47
  8. jaf/a2a/examples/server_example.py +56 -53
  9. jaf/a2a/memory/__init__.py +0 -4
  10. jaf/a2a/memory/cleanup.py +28 -21
  11. jaf/a2a/memory/factory.py +155 -133
  12. jaf/a2a/memory/providers/composite.py +21 -26
  13. jaf/a2a/memory/providers/in_memory.py +89 -83
  14. jaf/a2a/memory/providers/postgres.py +117 -115
  15. jaf/a2a/memory/providers/redis.py +128 -121
  16. jaf/a2a/memory/serialization.py +77 -87
  17. jaf/a2a/memory/tests/run_comprehensive_tests.py +112 -83
  18. jaf/a2a/memory/tests/test_cleanup.py +211 -94
  19. jaf/a2a/memory/tests/test_serialization.py +73 -68
  20. jaf/a2a/memory/tests/test_stress_concurrency.py +186 -133
  21. jaf/a2a/memory/tests/test_task_lifecycle.py +138 -120
  22. jaf/a2a/memory/types.py +91 -53
  23. jaf/a2a/protocol.py +95 -125
  24. jaf/a2a/server.py +90 -118
  25. jaf/a2a/standalone_client.py +30 -43
  26. jaf/a2a/tests/__init__.py +16 -33
  27. jaf/a2a/tests/run_tests.py +17 -53
  28. jaf/a2a/tests/test_agent.py +40 -140
  29. jaf/a2a/tests/test_client.py +54 -117
  30. jaf/a2a/tests/test_integration.py +28 -82
  31. jaf/a2a/tests/test_protocol.py +54 -139
  32. jaf/a2a/tests/test_types.py +50 -136
  33. jaf/a2a/types.py +58 -34
  34. jaf/cli.py +21 -41
  35. jaf/core/__init__.py +7 -1
  36. jaf/core/agent_tool.py +93 -72
  37. jaf/core/analytics.py +257 -207
  38. jaf/core/checkpoint.py +223 -0
  39. jaf/core/composition.py +249 -235
  40. jaf/core/engine.py +817 -519
  41. jaf/core/errors.py +55 -42
  42. jaf/core/guardrails.py +276 -202
  43. jaf/core/handoff.py +47 -31
  44. jaf/core/parallel_agents.py +69 -75
  45. jaf/core/performance.py +75 -73
  46. jaf/core/proxy.py +43 -44
  47. jaf/core/proxy_helpers.py +24 -27
  48. jaf/core/regeneration.py +220 -129
  49. jaf/core/state.py +68 -66
  50. jaf/core/streaming.py +115 -108
  51. jaf/core/tool_results.py +111 -101
  52. jaf/core/tools.py +114 -116
  53. jaf/core/tracing.py +310 -210
  54. jaf/core/types.py +403 -151
  55. jaf/core/workflows.py +209 -168
  56. jaf/exceptions.py +46 -38
  57. jaf/memory/__init__.py +1 -6
  58. jaf/memory/approval_storage.py +54 -77
  59. jaf/memory/factory.py +4 -4
  60. jaf/memory/providers/in_memory.py +216 -180
  61. jaf/memory/providers/postgres.py +216 -146
  62. jaf/memory/providers/redis.py +173 -116
  63. jaf/memory/types.py +70 -51
  64. jaf/memory/utils.py +36 -34
  65. jaf/plugins/__init__.py +12 -12
  66. jaf/plugins/base.py +105 -96
  67. jaf/policies/__init__.py +0 -1
  68. jaf/policies/handoff.py +37 -46
  69. jaf/policies/validation.py +76 -52
  70. jaf/providers/__init__.py +6 -3
  71. jaf/providers/mcp.py +97 -51
  72. jaf/providers/model.py +475 -283
  73. jaf/server/__init__.py +1 -1
  74. jaf/server/main.py +7 -11
  75. jaf/server/server.py +514 -359
  76. jaf/server/types.py +208 -52
  77. jaf/utils/__init__.py +17 -18
  78. jaf/utils/attachments.py +111 -116
  79. jaf/utils/document_processor.py +175 -174
  80. jaf/visualization/__init__.py +1 -1
  81. jaf/visualization/example.py +111 -110
  82. jaf/visualization/functional_core.py +46 -71
  83. jaf/visualization/graphviz.py +154 -189
  84. jaf/visualization/imperative_shell.py +7 -16
  85. jaf/visualization/types.py +8 -4
  86. {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/METADATA +2 -2
  87. jaf_py-2.5.12.dist-info/RECORD +97 -0
  88. jaf_py-2.5.10.dist-info/RECORD +0 -96
  89. {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/WHEEL +0 -0
  90. {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/entry_points.txt +0 -0
  91. {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/licenses/LICENSE +0 -0
  92. {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/top_level.txt +0 -0
jaf/providers/model.py CHANGED
@@ -10,23 +10,36 @@ import httpx
  import time
  import os
  import base64
+ import asyncio

  from openai import AsyncOpenAI
  from pydantic import BaseModel
  import litellm

  from ..core.types import (
- Agent, ContentRole, Message, ModelProvider, RunConfig, RunState,
- CompletionStreamChunk, ToolCallDelta, ToolCallFunctionDelta,
- MessageContentPart, get_text_content
+ Agent,
+ ContentRole,
+ Message,
+ ModelProvider,
+ RunConfig,
+ RunState,
+ CompletionStreamChunk,
+ ToolCallDelta,
+ ToolCallFunctionDelta,
+ MessageContentPart,
+ get_text_content,
+ RetryEvent,
+ RetryEventData,
  )
  from ..core.proxy import ProxyConfig
  from ..utils.document_processor import (
- extract_document_content, is_document_supported,
- get_document_description, DocumentProcessingError
+ extract_document_content,
+ is_document_supported,
+ get_document_description,
+ DocumentProcessingError,
  )

- Ctx = TypeVar('Ctx')
+ Ctx = TypeVar("Ctx")

  # Vision model caching
  VISION_MODEL_CACHE_TTL = 5 * 60 # 5 minutes
@@ -34,92 +47,183 @@ VISION_API_TIMEOUT = 3.0 # 3 seconds
  _vision_model_cache: Dict[str, Dict[str, Any]] = {}
  MAX_IMAGE_BYTES = int(os.environ.get("JAF_MAX_IMAGE_BYTES", 8 * 1024 * 1024))

+
  async def _is_vision_model(model: str, base_url: str) -> bool:
  """
  Check if a model supports vision capabilities.
-
+
  Args:
  model: Model name to check
  base_url: Base URL of the LiteLLM server
-
+
  Returns:
  True if model supports vision, False otherwise
  """
  cache_key = f"{base_url}:{model}"
  cached = _vision_model_cache.get(cache_key)
-
- if cached and time.time() - cached['timestamp'] < VISION_MODEL_CACHE_TTL:
- return cached['supports']
-
+
+ if cached and time.time() - cached["timestamp"] < VISION_MODEL_CACHE_TTL:
+ return cached["supports"]
+
  try:
  async with httpx.AsyncClient(timeout=VISION_API_TIMEOUT) as client:
  response = await client.get(
- f"{base_url}/model_group/info",
- headers={'accept': 'application/json'}
+ f"{base_url}/model_group/info", headers={"accept": "application/json"}
  )
-
+
  if response.status_code == 200:
  data = response.json()
  model_info = None
-
- if 'data' in data and isinstance(data['data'], list):
- for m in data['data']:
- if (m.get('model_group') == model or
- model in str(m.get('model_group', ''))):
+
+ if "data" in data and isinstance(data["data"], list):
+ for m in data["data"]:
+ if m.get("model_group") == model or model in str(m.get("model_group", "")):
  model_info = m
  break
-
- if model_info and 'supports_vision' in model_info:
- result = model_info['supports_vision']
- _vision_model_cache[cache_key] = {
- 'supports': result,
- 'timestamp': time.time()
- }
+
+ if model_info and "supports_vision" in model_info:
+ result = model_info["supports_vision"]
+ _vision_model_cache[cache_key] = {"supports": result, "timestamp": time.time()}
  return result
  else:
- print(f"Warning: Vision API returned status {response.status_code} for model {model}")
-
+ print(
+ f"Warning: Vision API returned status {response.status_code} for model {model}"
+ )
+
  except Exception as e:
  print(f"Warning: Vision API error for model {model}: {e}")
-
+
  # Fallback to known vision models
  known_vision_models = [
- 'gpt-4-vision-preview',
- 'gpt-4o',
- 'gpt-4o-mini',
- 'claude-sonnet-4',
- 'claude-sonnet-4-20250514',
- 'gemini-2.5-flash',
- 'gemini-2.5-pro'
+ "gpt-4-vision-preview",
+ "gpt-4o",
+ "gpt-4o-mini",
+ "claude-sonnet-4",
+ "claude-sonnet-4-20250514",
+ "gemini-2.5-flash",
+ "gemini-2.5-pro",
  ]
-
+
  is_known_vision_model = any(
- vision_model.lower() in model.lower()
- for vision_model in known_vision_models
+ vision_model.lower() in model.lower() for vision_model in known_vision_models
  )
-
- _vision_model_cache[cache_key] = {
- 'supports': is_known_vision_model,
- 'timestamp': time.time()
- }
-
+
+ _vision_model_cache[cache_key] = {"supports": is_known_vision_model, "timestamp": time.time()}
+
  return is_known_vision_model

+
+ async def _retry_with_events(
+ operation_func,
+ state: RunState,
+ config: RunConfig,
+ operation_name: str = "llm_call",
+ max_retries: int = 3,
+ backoff_factor: float = 1.0,
+ ):
+ """
+ Wrapper that retries an async operation and emits retry events.
+
+ Args:
+ operation_func: Async function to execute (should accept no arguments)
+ state: Current run state
+ config: Run configuration with event handler
+ operation_name: Name of the operation for logging
+ max_retries: Maximum number of retry attempts
+ backoff_factor: Exponential backoff multiplier
+
+ Returns:
+ Result from operation_func
+
+ Raises:
+ Last exception if all retries are exhausted
+ """
+ last_exception = None
+
+ for attempt in range(max_retries + 1):
+ try:
+ return await operation_func()
+ except Exception as e:
+ last_exception = e
+
+ # Check if this is a retryable HTTP error
+ is_retryable = False
+ reason = str(e)
+ error_details = {"error_type": type(e).__name__, "error_message": str(e)}
+
+ # Check for HTTP errors (common in OpenAI/LiteLLM)
+ if hasattr(e, "status_code"):
+ status_code = e.status_code
+ error_details["status_code"] = status_code
+
+ # Retry on rate limits (429) and server errors (5xx)
+ if status_code == 429:
+ is_retryable = True
+ reason = f"HTTP {status_code} - Rate Limit"
+ elif 500 <= status_code < 600:
+ is_retryable = True
+ reason = f"HTTP {status_code} - Server Error"
+ else:
+ reason = f"HTTP {status_code}"
+
+ # Check for common exception names
+ elif "RateLimitError" in type(e).__name__:
+ is_retryable = True
+ reason = "Rate Limit Error"
+ elif "ServiceUnavailableError" in type(e).__name__ or "APIError" in type(e).__name__:
+ is_retryable = True
+ reason = "API Error"
+ elif "Timeout" in type(e).__name__:
+ is_retryable = True
+ reason = "Timeout"
+
+ # If not last attempt and is retryable, retry with backoff
+ if attempt < max_retries and is_retryable:
+ delay = backoff_factor * (2**attempt) # Exponential backoff
+
+ # Emit retry event
+ if config.on_event:
+ retry_event = RetryEvent(
+ data=RetryEventData(
+ attempt=attempt + 1,
+ max_retries=max_retries,
+ reason=reason,
+ operation=operation_name,
+ trace_id=state.trace_id,
+ run_id=state.run_id,
+ delay=delay,
+ error_details=error_details,
+ )
+ )
+ config.on_event(retry_event)
+
+ print(
+ f"[JAF:RETRY] Attempt {attempt + 1}/{max_retries} failed: {reason}. Retrying in {delay}s..."
+ )
+ await asyncio.sleep(delay)
+ else:
+ # Not retryable or last attempt, re-raise
+ raise
+
+ # Should never reach here, but just in case
+ raise last_exception
+
+
  def make_litellm_provider(
  base_url: str,
  api_key: str = "anything",
  default_timeout: Optional[float] = None,
- proxy_config: Optional[ProxyConfig] = None
+ proxy_config: Optional[ProxyConfig] = None,
  ) -> ModelProvider[Ctx]:
  """
  Create a LiteLLM-compatible model provider.
-
+
  Args:
  base_url: Base URL for the LiteLLM server
  api_key: API key (defaults to "anything" for local servers)
  default_timeout: Default timeout for model API calls in seconds
  proxy_config: Optional proxy configuration
-
+
  Returns:
  ModelProvider instance
  """
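
The `_retry_with_events` helper added in the hunk above wraps a zero-argument async callable and retries it on rate limits (429), 5xx responses, and timeout-like exceptions, with exponential backoff (backoff_factor * 2**attempt, i.e. 1s, 2s, 4s at the defaults), emitting a RetryEvent before each retry. Below is a self-contained sketch of the same retry-with-backoff pattern using invented stand-ins; it omits the retryable-error classification and the event emission, and is not code from the package.

import asyncio

calls = {"n": 0}

async def flaky_operation():
    # Invented stand-in for an LLM API call that fails twice before succeeding.
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("simulated 429 rate limit")
    return "ok"

async def retry(operation, max_retries=3, backoff_factor=1.0):
    # Mirrors the backoff schedule used by _retry_with_events (1 initial try + max_retries retries).
    for attempt in range(max_retries + 1):
        try:
            return await operation()
        except Exception as exc:
            if attempt >= max_retries:
                raise
            delay = backoff_factor * (2 ** attempt)  # 1s, 2s, 4s, ...
            print(f"attempt {attempt + 1} failed ({exc}); retrying in {delay}s")
            await asyncio.sleep(delay)

print(asyncio.run(retry(flaky_operation)))  # prints "ok" after two simulated failures
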
@@ -128,48 +232,47 @@ def make_litellm_provider(
  def __init__(self):
  # Default to "anything" if api_key is not provided, for local servers
  effective_api_key = api_key if api_key is not None else "anything"
-
+
  # Configure HTTP client with proxy support
  client_kwargs = {
  "base_url": base_url,
  "api_key": effective_api_key,
  }
-
+
  if proxy_config:
  proxies = proxy_config.to_httpx_proxies()
  if proxies:
  # Create httpx client with proxy configuration
  try:
  # Use the https proxy if available, otherwise http proxy
- proxy_url = proxies.get('https://') or proxies.get('http://')
+ proxy_url = proxies.get("https://") or proxies.get("http://")
  if proxy_url:
  http_client = httpx.AsyncClient(proxy=proxy_url)
  client_kwargs["http_client"] = http_client
  except Exception as e:
  print(f"Warning: Could not configure proxy: {e}")
  # Fall back to environment variables for proxy
-
+
  self.client = AsyncOpenAI(**client_kwargs)
  self.default_timeout = default_timeout

  async def get_completion(
- self,
- state: RunState[Ctx],
- agent: Agent[Ctx, Any],
- config: RunConfig[Ctx]
+ self, state: RunState[Ctx], agent: Agent[Ctx, Any], config: RunConfig[Ctx]
  ) -> Dict[str, Any]:
  """Get completion from the model."""

  # Determine model to use
- model = (config.model_override or
- (agent.model_config.name if agent.model_config else "gpt-4o"))
+ model = config.model_override or (
+ agent.model_config.name if agent.model_config else "gpt-4o"
+ )

  # Check if any message contains image content or image attachments
  has_image_content = any(
- (isinstance(msg.content, list) and
- any(part.type == 'image_url' for part in msg.content)) or
- (msg.attachments and
- any(att.kind == 'image' for att in msg.attachments))
+ (
+ isinstance(msg.content, list)
+ and any(part.type == "image_url" for part in msg.content)
+ )
+ or (msg.attachments and any(att.kind == "image" for att in msg.attachments))
  for msg in state.messages
  )

@@ -182,51 +285,59 @@
  )

  # Create system message
- system_message = {
- "role": "system",
- "content": agent.instructions(state)
- }
+ system_message = {"role": "system", "content": agent.instructions(state)}

  # Convert messages to OpenAI format
  converted_messages = []
  for msg in state.messages:
  converted_msg = await _convert_message(msg)
  converted_messages.append(converted_msg)
-
+
  messages = [system_message] + converted_messages

  # Convert tools to OpenAI format
  tools = None
  if agent.tools:
+ # Check if we should inline schema refs
+ inline_refs = (
+ agent.model_config.inline_tool_schemas if agent.model_config else False
+ )
  tools = [
  {
  "type": "function",
  "function": {
  "name": tool.schema.name,
  "description": tool.schema.description,
- "parameters": _pydantic_to_json_schema(tool.schema.parameters),
- }
+ "parameters": _pydantic_to_json_schema(
+ tool.schema.parameters, inline_refs=inline_refs or False
+ ),
+ },
  }
  for tool in agent.tools
  ]

  # Determine tool choice behavior
  last_message = state.messages[-1] if state.messages else None
- is_after_tool_call = last_message and (last_message.role == ContentRole.TOOL or last_message.role == 'tool')
+ is_after_tool_call = last_message and (
+ last_message.role == ContentRole.TOOL or last_message.role == "tool"
+ )

  # Prepare request parameters
- request_params = {
- "model": model,
- "messages": messages,
- "stream": False
- }
+ request_params = {"model": model, "messages": messages, "stream": False}

  # Add optional parameters
  if agent.model_config:
  if agent.model_config.temperature is not None:
  request_params["temperature"] = agent.model_config.temperature
- if agent.model_config.max_tokens is not None:
- request_params["max_tokens"] = agent.model_config.max_tokens
+ # Use agent's max_tokens if set, otherwise fall back to config's max_tokens
+ max_tokens = agent.model_config.max_tokens
+ if max_tokens is None:
+ max_tokens = config.max_tokens
+ if max_tokens is not None:
+ request_params["max_tokens"] = max_tokens
+ elif config.max_tokens is not None:
+ # No model_config but config has max_tokens
+ request_params["max_tokens"] = config.max_tokens

  if tools:
  request_params["tools"] = tools
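
The `inline_tool_schemas` flag read from `agent.model_config` above feeds `_pydantic_to_json_schema(..., inline_refs=...)`, which (via the `_resolve_schema_refs` helper added near the end of this file) inlines `#/$defs/...` references so providers that reject `$ref` still receive a flat JSON schema for tool parameters. A minimal sketch of the effect on an invented nested Pydantic model; the `resolve_refs` function below re-implements the same idea and is not the package's code.

from typing import Any, Dict, Optional
from pydantic import BaseModel

class Address(BaseModel):
    city: str
    zip_code: str

class LookupArgs(BaseModel):
    name: str
    address: Optional[Address] = None

def resolve_refs(schema: Any, defs: Optional[Dict[str, Any]] = None) -> Any:
    # Inline #/$defs/... references, mirroring what _resolve_schema_refs does.
    if defs is None:
        defs = schema.get("$defs", {})
    if isinstance(schema, dict):
        ref = schema.get("$ref", "")
        if ref.startswith("#/$defs/") and ref.split("/")[-1] in defs:
            return resolve_refs(defs[ref.split("/")[-1]], defs)
        return {k: resolve_refs(v, defs) for k, v in schema.items() if k != "$defs"}
    if isinstance(schema, list):
        return [resolve_refs(item, defs) for item in schema]
    return schema

raw = LookupArgs.model_json_schema()    # contains "$defs" plus a "$ref" to Address
flat = resolve_refs(raw)                # Address fields embedded inline, no "$defs"
print("$defs" in raw, "$defs" in flat)  # True False
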
@@ -236,8 +347,14 @@ def make_litellm_provider(
  if agent.output_codec:
  request_params["response_format"] = {"type": "json_object"}

- # Make the API call
- response = await self.client.chat.completions.create(**request_params)
+ # Make the API call with retry handling
+ async def _api_call():
+ return await self.client.chat.completions.create(**request_params)
+
+ # Use retry wrapper to track retries in Langfuse
+ response = await _retry_with_events(
+ _api_call, state, config, operation_name="llm_call", max_retries=3, backoff_factor=1.0
+ )

  # Return in the expected format that the engine expects
  choice = response.choices[0]
@@ -247,12 +364,9 @@
  if choice.message.tool_calls:
  tool_calls = [
  {
- 'id': tc.id,
- 'type': tc.type,
- 'function': {
- 'name': tc.function.name,
- 'arguments': tc.function.arguments
- }
+ "id": tc.id,
+ "type": tc.type,
+ "function": {"name": tc.function.name, "arguments": tc.function.arguments},
  }
  for tc in choice.message.tool_calls
  ]
@@ -267,64 +381,64 @@
  }

  return {
- 'id': response.id,
- 'created': response.created,
- 'model': response.model,
- 'system_fingerprint': response.system_fingerprint,
- 'message': {
- 'content': choice.message.content,
- 'tool_calls': tool_calls
- },
- 'usage': usage_data,
- 'prompt': messages
+ "id": response.id,
+ "created": response.created,
+ "model": response.model,
+ "system_fingerprint": response.system_fingerprint,
+ "message": {"content": choice.message.content, "tool_calls": tool_calls},
+ "usage": usage_data,
+ "prompt": messages,
  }

  async def get_completion_stream(
- self,
- state: RunState[Ctx],
- agent: Agent[Ctx, Any],
- config: RunConfig[Ctx]
+ self, state: RunState[Ctx], agent: Agent[Ctx, Any], config: RunConfig[Ctx]
  ) -> AsyncIterator[CompletionStreamChunk]:
  """
  Stream completion chunks from the model provider, yielding text deltas and tool-call deltas.
  Uses OpenAI-compatible streaming via LiteLLM endpoint.
  """
  # Determine model to use
- model = (config.model_override or
- (agent.model_config.name if agent.model_config else "gpt-4o"))
+ model = config.model_override or (
+ agent.model_config.name if agent.model_config else "gpt-4o"
+ )

  # Create system message
- system_message = {
- "role": "system",
- "content": agent.instructions(state)
- }
+ system_message = {"role": "system", "content": agent.instructions(state)}

- # Convert messages to OpenAI format
+ # Convert messages to OpenAI format
  converted_messages = []
  for msg in state.messages:
  converted_msg = await _convert_message(msg)
  converted_messages.append(converted_msg)
-
+
  messages = [system_message] + converted_messages

  # Convert tools to OpenAI format
  tools = None
  if agent.tools:
+ # Check if we should inline schema refs
+ inline_refs = (
+ agent.model_config.inline_tool_schemas if agent.model_config else False
+ )
  tools = [
  {
  "type": "function",
  "function": {
  "name": tool.schema.name,
  "description": tool.schema.description,
- "parameters": _pydantic_to_json_schema(tool.schema.parameters),
- }
+ "parameters": _pydantic_to_json_schema(
+ tool.schema.parameters, inline_refs=inline_refs or False
+ ),
+ },
  }
  for tool in agent.tools
  ]

  # Determine tool choice behavior
  last_message = state.messages[-1] if state.messages else None
- is_after_tool_call = last_message and (last_message.role == ContentRole.TOOL or last_message.role == 'tool')
+ is_after_tool_call = last_message and (
+ last_message.role == ContentRole.TOOL or last_message.role == "tool"
+ )

  # Prepare request parameters
  request_params: Dict[str, Any] = {
@@ -336,8 +450,15 @@
  if agent.model_config:
  if agent.model_config.temperature is not None:
  request_params["temperature"] = agent.model_config.temperature
- if agent.model_config.max_tokens is not None:
- request_params["max_tokens"] = agent.model_config.max_tokens
+ # Use agent's max_tokens if set, otherwise fall back to config's max_tokens
+ max_tokens = agent.model_config.max_tokens
+ if max_tokens is None:
+ max_tokens = config.max_tokens
+ if max_tokens is not None:
+ request_params["max_tokens"] = max_tokens
+ elif config.max_tokens is not None:
+ # No model_config but config has max_tokens
+ request_params["max_tokens"] = config.max_tokens

  if tools:
  request_params["tools"] = tools
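
The hunk above (and the matching hunks in the non-streaming and SDK paths) changes how `max_tokens` is chosen: an agent-level `model_config.max_tokens` wins, and the run-level `config.max_tokens` is only a fallback. A tiny illustrative restatement of that precedence (not code from the package):

def effective_max_tokens(agent_max_tokens, config_max_tokens):
    # Agent-level setting wins; run config is the fallback; None means "omit the parameter".
    return agent_max_tokens if agent_max_tokens is not None else config_max_tokens

assert effective_max_tokens(1024, 4096) == 1024   # agent setting wins
assert effective_max_tokens(None, 4096) == 4096   # falls back to config.max_tokens
assert effective_max_tokens(None, None) is None   # max_tokens left out of the request
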
@@ -388,19 +509,20 @@
  fn = getattr(tc, "function", None)
  fn_name = getattr(fn, "name", None) if fn is not None else None
  # OpenAI streams "arguments" as incremental deltas
- args_delta = getattr(fn, "arguments", None) if fn is not None else None
+ args_delta = (
+ getattr(fn, "arguments", None) if fn is not None else None
+ )

  yield CompletionStreamChunk(
  tool_call_delta=ToolCallDelta(
  index=idx,
  id=tc_id,
- type='function',
+ type="function",
  function=ToolCallFunctionDelta(
- name=fn_name,
- arguments_delta=args_delta
- )
+ name=fn_name, arguments_delta=args_delta
+ ),
  ),
- raw=raw_obj
+ raw=raw_obj,
  )
  except Exception:
  # Skip malformed tool-call deltas
@@ -408,26 +530,29 @@

  # Completion ended
  if finish_reason:
- yield CompletionStreamChunk(is_done=True, finish_reason=finish_reason, raw=raw_obj)
+ yield CompletionStreamChunk(
+ is_done=True, finish_reason=finish_reason, raw=raw_obj
+ )
  except Exception:
  # Skip individual chunk errors, keep streaming
  continue

  return LiteLLMProvider()

+
  def make_litellm_sdk_provider(
  api_key: Optional[str] = None,
  model: str = "gpt-3.5-turbo",
  base_url: Optional[str] = None,
  default_timeout: Optional[float] = None,
- **litellm_kwargs: Any
+ **litellm_kwargs: Any,
  ) -> ModelProvider[Ctx]:
  """
  Create a LiteLLM SDK-based model provider with universal provider support.
-
+
  LiteLLM automatically detects the provider from the model name and handles
  API key management through environment variables or direct parameters.
-
+
  Args:
  api_key: API key for the provider (optional, can use env vars)
  model: Model name (e.g., "gpt-4", "claude-3-sonnet", "gemini-pro", "llama2", etc.)
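
The streaming path above yields `CompletionStreamChunk` objects carrying either a text delta, a `ToolCallDelta` (with incremental `arguments_delta` fragments), or a final chunk with `is_done`/`finish_reason` set. Below is a self-contained sketch of a consumer loop; the dataclasses are invented stand-ins that only mirror the fields used here, and the text-delta field name in particular is an assumption, not taken from this diff.

import asyncio
from dataclasses import dataclass
from typing import Optional

@dataclass
class FunctionDelta:               # stand-in for ToolCallFunctionDelta
    name: Optional[str] = None
    arguments_delta: Optional[str] = None

@dataclass
class ToolDelta:                   # stand-in for ToolCallDelta
    index: int = 0
    function: Optional[FunctionDelta] = None

@dataclass
class Chunk:                       # stand-in for CompletionStreamChunk
    delta: Optional[str] = None    # assumed name for the text delta field
    tool_call_delta: Optional[ToolDelta] = None
    is_done: bool = False
    finish_reason: Optional[str] = None

async def fake_stream():
    yield Chunk(delta="Hel")
    yield Chunk(delta="lo")
    yield Chunk(tool_call_delta=ToolDelta(0, FunctionDelta("search", '{"q": "jaf"}')))
    yield Chunk(is_done=True, finish_reason="tool_calls")

async def consume():
    text, tool_args = [], []
    async for chunk in fake_stream():
        if chunk.delta:
            text.append(chunk.delta)
        if chunk.tool_call_delta and chunk.tool_call_delta.function:
            tool_args.append(chunk.tool_call_delta.function.arguments_delta or "")
        if chunk.is_done:
            break
    print("".join(text), "".join(tool_args))

asyncio.run(consume())   # -> Hello {"q": "jaf"}
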
@@ -440,23 +565,23 @@ def make_litellm_sdk_provider(
  - azure_deployment: "your-deployment" (for Azure OpenAI)
  - api_base: "https://your-endpoint.com" (custom endpoint)
  - custom_llm_provider: "custom_provider_name"
-
+
  Returns:
  ModelProvider instance
-
+
  Examples:
  # OpenAI
  make_litellm_sdk_provider(api_key="sk-...", model="gpt-4")
-
+
  # Anthropic Claude
  make_litellm_sdk_provider(api_key="sk-ant-...", model="claude-3-sonnet-20240229")
-
+
  # Google Gemini
  make_litellm_sdk_provider(model="gemini-pro", vertex_project="my-project")
-
+
  # Ollama (local)
  make_litellm_sdk_provider(model="ollama/llama2", base_url="http://localhost:11434")
-
+
  # Azure OpenAI
  make_litellm_sdk_provider(
  model="azure/gpt-4",
@@ -464,13 +589,13 @@ def make_litellm_sdk_provider(
  azure_deployment="gpt-4-deployment",
  api_base="https://your-resource.openai.azure.com"
  )
-
+
  # Hugging Face
  make_litellm_sdk_provider(
  model="huggingface/microsoft/DialoGPT-medium",
  api_key="hf_..."
  )
-
+
  # Any custom provider
  make_litellm_sdk_provider(
  model="custom_provider/model-name",
@@ -488,10 +613,7 @@ def make_litellm_sdk_provider(
  self.litellm_kwargs = litellm_kwargs

  async def get_completion(
- self,
- state: RunState[Ctx],
- agent: Agent[Ctx, Any],
- config: RunConfig[Ctx]
+ self, state: RunState[Ctx], agent: Agent[Ctx, Any], config: RunConfig[Ctx]
  ) -> Dict[str, Any]:
  """Get completion from the model using LiteLLM SDK."""

@@ -499,10 +621,7 @@ def make_litellm_sdk_provider(
  model_name = config.model_override or self.model

  # Create system message
- system_message = {
- "role": "system",
- "content": agent.instructions(state)
- }
+ system_message = {"role": "system", "content": agent.instructions(state)}

  # Convert messages to OpenAI format
  messages = [system_message]
@@ -513,24 +632,26 @@ def make_litellm_sdk_provider(
  # Convert tools to OpenAI format
  tools = None
  if agent.tools:
+ # Check if we should inline schema refs
+ inline_refs = (
+ agent.model_config.inline_tool_schemas if agent.model_config else False
+ )
  tools = [
  {
  "type": "function",
  "function": {
  "name": tool.schema.name,
  "description": tool.schema.description,
- "parameters": _pydantic_to_json_schema(tool.schema.parameters),
- }
+ "parameters": _pydantic_to_json_schema(
+ tool.schema.parameters, inline_refs=inline_refs or False
+ ),
+ },
  }
  for tool in agent.tools
  ]

  # Prepare request parameters for LiteLLM
- request_params = {
- "model": model_name,
- "messages": messages,
- **self.litellm_kwargs
- }
+ request_params = {"model": model_name, "messages": messages, **self.litellm_kwargs}

  # Add API key if provided
  if self.api_key:
@@ -540,8 +661,15 @@ def make_litellm_sdk_provider(
  if agent.model_config:
  if agent.model_config.temperature is not None:
  request_params["temperature"] = agent.model_config.temperature
- if agent.model_config.max_tokens is not None:
- request_params["max_tokens"] = agent.model_config.max_tokens
+ # Use agent's max_tokens if set, otherwise fall back to config's max_tokens
+ max_tokens = agent.model_config.max_tokens
+ if max_tokens is None:
+ max_tokens = config.max_tokens
+ if max_tokens is not None:
+ request_params["max_tokens"] = max_tokens
+ elif config.max_tokens is not None:
+ # No model_config but config has max_tokens
+ request_params["max_tokens"] = config.max_tokens

  if tools:
  request_params["tools"] = tools
@@ -554,8 +682,14 @@ def make_litellm_sdk_provider(
  if self.base_url:
  request_params["api_base"] = self.base_url

- # Make the API call using litellm
- response = await litellm.acompletion(**request_params)
+ # Make the API call using litellm with retry handling
+ async def _api_call():
+ return await litellm.acompletion(**request_params)
+
+ # Use retry wrapper to track retries in Langfuse
+ response = await _retry_with_events(
+ _api_call, state, config, operation_name="llm_call", max_retries=3, backoff_factor=1.0
+ )

  # Return in the expected format that the engine expects
  choice = response.choices[0]
@@ -565,12 +699,9 @@ def make_litellm_sdk_provider(
  if choice.message.tool_calls:
  tool_calls = [
  {
- 'id': tc.id,
- 'type': tc.type,
- 'function': {
- 'name': tc.function.name,
- 'arguments': tc.function.arguments
- }
+ "id": tc.id,
+ "type": tc.type,
+ "function": {"name": tc.function.name, "arguments": tc.function.arguments},
  }
  for tc in choice.message.tool_calls
  ]
@@ -585,23 +716,17 @@ def make_litellm_sdk_provider(
  }

  return {
- 'id': response.id,
- 'created': response.created,
- 'model': response.model,
- 'system_fingerprint': getattr(response, 'system_fingerprint', None),
- 'message': {
- 'content': choice.message.content,
- 'tool_calls': tool_calls
- },
- 'usage': usage_data,
- 'prompt': messages
+ "id": response.id,
+ "created": response.created,
+ "model": response.model,
+ "system_fingerprint": getattr(response, "system_fingerprint", None),
+ "message": {"content": choice.message.content, "tool_calls": tool_calls},
+ "usage": usage_data,
+ "prompt": messages,
  }

  async def get_completion_stream(
- self,
- state: RunState[Ctx],
- agent: Agent[Ctx, Any],
- config: RunConfig[Ctx]
+ self, state: RunState[Ctx], agent: Agent[Ctx, Any], config: RunConfig[Ctx]
  ) -> AsyncIterator[CompletionStreamChunk]:
  """
  Stream completion chunks from the model provider using LiteLLM SDK.
@@ -610,10 +735,7 @@ def make_litellm_sdk_provider(
  model_name = config.model_override or self.model

  # Create system message
- system_message = {
- "role": "system",
- "content": agent.instructions(state)
- }
+ system_message = {"role": "system", "content": agent.instructions(state)}

  # Convert messages to OpenAI format
  messages = [system_message]
@@ -624,14 +746,20 @@ def make_litellm_sdk_provider(
  # Convert tools to OpenAI format
  tools = None
  if agent.tools:
+ # Check if we should inline schema refs
+ inline_refs = (
+ agent.model_config.inline_tool_schemas if agent.model_config else False
+ )
  tools = [
  {
  "type": "function",
  "function": {
  "name": tool.schema.name,
  "description": tool.schema.description,
- "parameters": _pydantic_to_json_schema(tool.schema.parameters),
- }
+ "parameters": _pydantic_to_json_schema(
+ tool.schema.parameters, inline_refs=inline_refs or False
+ ),
+ },
  }
  for tool in agent.tools
  ]
@@ -641,7 +769,7 @@ def make_litellm_sdk_provider(
  "model": model_name,
  "messages": messages,
  "stream": True,
- **self.litellm_kwargs
+ **self.litellm_kwargs,
  }

  # Add API key if provided
@@ -652,8 +780,15 @@ def make_litellm_sdk_provider(
  if agent.model_config:
  if agent.model_config.temperature is not None:
  request_params["temperature"] = agent.model_config.temperature
- if agent.model_config.max_tokens is not None:
- request_params["max_tokens"] = agent.model_config.max_tokens
+ # Use agent's max_tokens if set, otherwise fall back to config's max_tokens
+ max_tokens = agent.model_config.max_tokens
+ if max_tokens is None:
+ max_tokens = config.max_tokens
+ if max_tokens is not None:
+ request_params["max_tokens"] = max_tokens
+ elif config.max_tokens is not None:
+ # No model_config but config has max_tokens
+ request_params["max_tokens"] = config.max_tokens

  if tools:
  request_params["tools"] = tools
@@ -668,12 +803,12 @@ def make_litellm_sdk_provider(

  # Stream using litellm
  stream = await litellm.acompletion(**request_params)
-
+
  async for chunk in stream:
  try:
  # Best-effort extraction of raw for debugging
  try:
- raw_obj = chunk.model_dump() if hasattr(chunk, 'model_dump') else None
+ raw_obj = chunk.model_dump() if hasattr(chunk, "model_dump") else None
  except Exception:
  raw_obj = None

@@ -702,52 +837,59 @@ def make_litellm_sdk_provider(
  tc_id = getattr(tc, "id", None)
  fn = getattr(tc, "function", None)
  fn_name = getattr(fn, "name", None) if fn is not None else None
- args_delta = getattr(fn, "arguments", None) if fn is not None else None
+ args_delta = (
+ getattr(fn, "arguments", None) if fn is not None else None
+ )

  yield CompletionStreamChunk(
  tool_call_delta=ToolCallDelta(
  index=idx,
  id=tc_id,
- type='function',
+ type="function",
  function=ToolCallFunctionDelta(
- name=fn_name,
- arguments_delta=args_delta
- )
+ name=fn_name, arguments_delta=args_delta
+ ),
  ),
- raw=raw_obj
+ raw=raw_obj,
  )
  except Exception:
  continue

  # Completion ended
  if finish_reason:
- yield CompletionStreamChunk(is_done=True, finish_reason=finish_reason, raw=raw_obj)
+ yield CompletionStreamChunk(
+ is_done=True, finish_reason=finish_reason, raw=raw_obj
+ )
  except Exception:
  continue

  return LiteLLMSDKProvider()

+
  async def _convert_message(msg: Message) -> Dict[str, Any]:
  """
  Handles all possible role types (string and enum) and content formats.
  """
  # Normalize role to handle both string and enum values
- role_value = msg.role.value if hasattr(msg.role, 'value') else str(msg.role).lower()
+ role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role).lower()

  # Handle user messages
- if role_value in ('user', ContentRole.USER.value if hasattr(ContentRole, 'USER') else 'user'):
+ if role_value in ("user", ContentRole.USER.value if hasattr(ContentRole, "USER") else "user"):
  if isinstance(msg.content, list):
  # Multi-part content
  return {
  "role": "user",
- "content": [_convert_content_part(part) for part in msg.content]
+ "content": [_convert_content_part(part) for part in msg.content],
  }
  else:
  # Build message with attachments if available
- return await _build_chat_message_with_attachments('user', msg)
+ return await _build_chat_message_with_attachments("user", msg)

  # Handle assistant messages
- elif role_value in ('assistant', ContentRole.ASSISTANT.value if hasattr(ContentRole, 'ASSISTANT') else 'assistant'):
+ elif role_value in (
+ "assistant",
+ ContentRole.ASSISTANT.value if hasattr(ContentRole, "ASSISTANT") else "assistant",
+ ):
  result = {
  "role": "assistant",
  "content": get_text_content(msg.content) or "", # Ensure content is never None
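
`_convert_message` above normalizes the role before branching, so enum members and plain strings are treated the same. A self-contained sketch of that normalization (the `Role` enum below is an invented stand-in for `ContentRole`):

from enum import Enum

class Role(str, Enum):             # stand-in for jaf's ContentRole
    USER = "user"
    ASSISTANT = "assistant"

def normalize(role):
    # Same trick as _convert_message: prefer .value, else lowercase the string form.
    return role.value if hasattr(role, "value") else str(role).lower()

assert normalize(Role.USER) == "user"
assert normalize("USER") == "user"
assert normalize("assistant") == normalize(Role.ASSISTANT)
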
@@ -759,10 +901,7 @@ async def _convert_message(msg: Message) -> Dict[str, Any]:
  {
  "id": tc.id,
  "type": tc.type,
- "function": {
- "name": tc.function.name,
- "arguments": tc.function.arguments
- }
+ "function": {"name": tc.function.name, "arguments": tc.function.arguments},
  }
  for tc in msg.tool_calls
  if tc.id and tc.function and tc.function.name # Validate tool call structure
@@ -771,37 +910,37 @@ async def _convert_message(msg: Message) -> Dict[str, Any]:
  return result

  # Handle system messages
- elif role_value in ('system', ContentRole.SYSTEM.value if hasattr(ContentRole, 'SYSTEM') else 'system'):
- return {
- "role": "system",
- "content": get_text_content(msg.content) or ""
- }
+ elif role_value in (
+ "system",
+ ContentRole.SYSTEM.value if hasattr(ContentRole, "SYSTEM") else "system",
+ ):
+ return {"role": "system", "content": get_text_content(msg.content) or ""}

  # Handle tool messages
- elif role_value in ('tool', ContentRole.TOOL.value if hasattr(ContentRole, 'TOOL') else 'tool'):
+ elif role_value in ("tool", ContentRole.TOOL.value if hasattr(ContentRole, "TOOL") else "tool"):
  if not msg.tool_call_id:
  raise ValueError(f"Tool message must have tool_call_id. Message: {msg}")

  return {
  "role": "tool",
  "content": get_text_content(msg.content) or "",
- "tool_call_id": msg.tool_call_id
+ "tool_call_id": msg.tool_call_id,
  }

  # Handle function messages (legacy support)
- elif role_value == 'function':
+ elif role_value == "function":
  if not msg.tool_call_id:
  raise ValueError(f"Function message must have tool_call_id. Message: {msg}")

  return {
  "role": "function",
  "content": get_text_content(msg.content) or "",
- "name": getattr(msg, 'name', 'unknown_function')
+ "name": getattr(msg, "name", "unknown_function"),
  }

  # Unknown role - provide helpful error message
  else:
- available_roles = ['user', 'assistant', 'system', 'tool', 'function']
+ available_roles = ["user", "assistant", "system", "tool", "function"]
  raise ValueError(
  f"Unknown message role: {msg.role} (type: {type(msg.role)}). "
  f"Supported roles: {available_roles}. "
@@ -811,46 +950,31 @@ async def _convert_message(msg: Message) -> Dict[str, Any]:

  def _convert_content_part(part: MessageContentPart) -> Dict[str, Any]:
  """Convert MessageContentPart to OpenAI format."""
- if part.type == 'text':
- return {
- "type": "text",
- "text": part.text
- }
- elif part.type == 'image_url':
- return {
- "type": "image_url",
- "image_url": part.image_url
- }
- elif part.type == 'file':
- return {
- "type": "file",
- "file": part.file
- }
+ if part.type == "text":
+ return {"type": "text", "text": part.text}
+ elif part.type == "image_url":
+ return {"type": "image_url", "image_url": part.image_url}
+ elif part.type == "file":
+ return {"type": "file", "file": part.file}
  else:
  raise ValueError(f"Unknown content part type: {part.type}")


- async def _build_chat_message_with_attachments(
- role: str,
- msg: Message
- ) -> Dict[str, Any]:
+ async def _build_chat_message_with_attachments(role: str, msg: Message) -> Dict[str, Any]:
  """
  Build multi-part content for Chat Completions if attachments exist.
  Supports images via image_url and documents via content extraction.
  """
  has_attachments = msg.attachments and len(msg.attachments) > 0
  if not has_attachments:
- if role == 'assistant':
+ if role == "assistant":
  base_msg = {"role": "assistant", "content": get_text_content(msg.content)}
  if msg.tool_calls:
  base_msg["tool_calls"] = [
  {
  "id": tc.id,
  "type": tc.type,
- "function": {
- "name": tc.function.name,
- "arguments": tc.function.arguments
- }
+ "function": {"name": tc.function.name, "arguments": tc.function.arguments},
  }
  for tc in msg.tool_calls
  ]
@@ -863,7 +987,7 @@ async def _build_chat_message_with_attachments(
  parts.append({"type": "text", "text": text_content})

  for att in msg.attachments:
- if att.kind == 'image':
+ if att.kind == "image":
  # Prefer explicit URL; otherwise construct a data URL from base64
  url = att.url
  if not url and att.data and att.mime_type:
@@ -871,100 +995,168 @@ async def _build_chat_message_with_attachments(
  try:
  # Estimate decoded size (base64 is ~4/3 of decoded size)
  estimated_size = len(att.data) * 3 // 4
-
+
  if estimated_size > MAX_IMAGE_BYTES:
- print(f"Warning: Skipping oversized image ({estimated_size} bytes > {MAX_IMAGE_BYTES}). "
- f"Set JAF_MAX_IMAGE_BYTES env var to adjust limit.")
- parts.append({
- "type": "text",
- "text": f"[IMAGE SKIPPED: Size exceeds limit of {MAX_IMAGE_BYTES//1024//1024}MB. "
- f"Image name: {att.name or 'unnamed'}]"
- })
+ print(
+ f"Warning: Skipping oversized image ({estimated_size} bytes > {MAX_IMAGE_BYTES}). "
+ f"Set JAF_MAX_IMAGE_BYTES env var to adjust limit."
+ )
+ parts.append(
+ {
+ "type": "text",
+ "text": f"[IMAGE SKIPPED: Size exceeds limit of {MAX_IMAGE_BYTES // 1024 // 1024}MB. "
+ f"Image name: {att.name or 'unnamed'}]",
+ }
+ )
  continue
-
+
  # Create data URL for valid-sized images
  url = f"data:{att.mime_type};base64,{att.data}"
  except Exception as e:
  print(f"Error processing image data: {e}")
- parts.append({
- "type": "text",
- "text": f"[IMAGE ERROR: Failed to process image data. Image name: {att.name or 'unnamed'}]"
- })
+ parts.append(
+ {
+ "type": "text",
+ "text": f"[IMAGE ERROR: Failed to process image data. Image name: {att.name or 'unnamed'}]",
+ }
+ )
  continue
-
+
  if url:
- parts.append({
- "type": "image_url",
- "image_url": {"url": url}
- })
-
- elif att.kind in ['document', 'file']:
+ parts.append({"type": "image_url", "image_url": {"url": url}})
+
+ elif att.kind in ["document", "file"]:
  # Check if attachment has use_litellm_format flag or is a large document
  use_litellm_format = att.use_litellm_format is True
-
+
  if use_litellm_format and (att.url or att.data):
  # For now, fall back to content extraction since most providers don't support native file format
  # TODO: Add provider-specific file format support
- print(f"Info: LiteLLM format requested for {att.name}, falling back to content extraction")
+ print(
+ f"Info: LiteLLM format requested for {att.name}, falling back to content extraction"
+ )
  use_litellm_format = False
-
+
  if not use_litellm_format:
  # Extract document content if supported and we have data or URL
  if is_document_supported(att.mime_type) and (att.data or att.url):
  try:
  processed = await extract_document_content(att)
- file_name = att.name or 'document'
+ file_name = att.name or "document"
  description = get_document_description(att.mime_type)
-
- parts.append({
- "type": "text",
- "text": f"DOCUMENT: {file_name} ({description}):\n\n{processed.content}"
- })
+
+ parts.append(
+ {
+ "type": "text",
+ "text": f"DOCUMENT: {file_name} ({description}):\n\n{processed.content}",
+ }
+ )
  except DocumentProcessingError as e:
  # Fallback to filename if extraction fails
- label = att.name or att.format or att.mime_type or 'attachment'
- parts.append({
- "type": "text",
- "text": f"ERROR: Failed to process {att.kind}: {label} ({e})"
- })
+ label = att.name or att.format or att.mime_type or "attachment"
+ parts.append(
+ {
+ "type": "text",
+ "text": f"ERROR: Failed to process {att.kind}: {label} ({e})",
+ }
+ )
  else:
  # Unsupported document type - show placeholder
- label = att.name or att.format or att.mime_type or 'attachment'
+ label = att.name or att.format or att.mime_type or "attachment"
  url_info = f" ({att.url})" if att.url else ""
- parts.append({
- "type": "text",
- "text": f"ATTACHMENT: {att.kind}: {label}{url_info}"
- })
+ parts.append(
+ {"type": "text", "text": f"ATTACHMENT: {att.kind}: {label}{url_info}"}
+ )

  base_msg = {"role": role, "content": parts}
- if role == 'assistant' and msg.tool_calls:
+ if role == "assistant" and msg.tool_calls:
  base_msg["tool_calls"] = [
  {
  "id": tc.id,
  "type": tc.type,
- "function": {
- "name": tc.function.name,
- "arguments": tc.function.arguments
- }
+ "function": {"name": tc.function.name, "arguments": tc.function.arguments},
  }
  for tc in msg.tool_calls
  ]
-
+
  return base_msg

- def _pydantic_to_json_schema(model_class: type[BaseModel]) -> Dict[str, Any]:
+
+ def _resolve_schema_refs(
+ schema: Dict[str, Any], defs: Optional[Dict[str, Any]] = None
+ ) -> Dict[str, Any]:
+ """
+ Recursively resolve $ref references in a JSON schema by inlining definitions.
+
+ Args:
+ schema: The schema object to process (may contain $ref)
+ defs: The $defs dictionary containing reusable definitions
+
+ Returns:
+ Schema with all references resolved inline
+ """
+ if defs is None:
+ # Extract $defs from root schema if present
+ defs = schema.get("$defs", {})
+
+ # If this is a reference, resolve it
+ if isinstance(schema, dict) and "$ref" in schema:
+ ref_path = schema["$ref"]
+
+ # Handle #/$defs/DefinitionName format
+ if ref_path.startswith("#/$defs/"):
+ def_name = ref_path.split("/")[-1]
+ if def_name in defs:
+ # Recursively resolve the definition (it might have refs too)
+ resolved_def = _resolve_schema_refs(defs[def_name], defs)
+ return resolved_def
+ else:
+ # If definition not found, return the original ref
+ return schema
+ else:
+ # Other ref formats - return as is
+ return schema
+
+ # If this is a dict, recursively process all values
+ if isinstance(schema, dict):
+ result = {}
+ for key, value in schema.items():
+ # Skip $defs as we're inlining them
+ if key == "$defs":
+ continue
+ result[key] = _resolve_schema_refs(value, defs)
+ return result
+
+ # If this is a list, recursively process all items
+ if isinstance(schema, list):
+ return [_resolve_schema_refs(item, defs) for item in schema]
+
+ # For primitive types, return as is
+ return schema
+
+
+ def _pydantic_to_json_schema(
+ model_class: type[BaseModel], inline_refs: bool = False
+ ) -> Dict[str, Any]:
  """
  Convert a Pydantic model to JSON schema for OpenAI tools.
-
+
  Args:
  model_class: Pydantic model class
-
+ inline_refs: If True, resolve $refs and inline $defs in the schema
+
  Returns:
  JSON schema dictionary
  """
- if hasattr(model_class, 'model_json_schema'):
+ if hasattr(model_class, "model_json_schema"):
  # Pydantic v2
- return model_class.model_json_schema()
+ schema = model_class.model_json_schema()
  else:
  # Pydantic v1 fallback
- return model_class.schema()
+ schema = model_class.schema()
+
+ # If inline_refs is True, resolve all references
+ if inline_refs:
+ schema = _resolve_schema_refs(schema)
+
+ return schema