jaf-py 2.5.10__py3-none-any.whl → 2.5.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. jaf/__init__.py +154 -57
  2. jaf/a2a/__init__.py +42 -21
  3. jaf/a2a/agent.py +79 -126
  4. jaf/a2a/agent_card.py +87 -78
  5. jaf/a2a/client.py +30 -66
  6. jaf/a2a/examples/client_example.py +12 -12
  7. jaf/a2a/examples/integration_example.py +38 -47
  8. jaf/a2a/examples/server_example.py +56 -53
  9. jaf/a2a/memory/__init__.py +0 -4
  10. jaf/a2a/memory/cleanup.py +28 -21
  11. jaf/a2a/memory/factory.py +155 -133
  12. jaf/a2a/memory/providers/composite.py +21 -26
  13. jaf/a2a/memory/providers/in_memory.py +89 -83
  14. jaf/a2a/memory/providers/postgres.py +117 -115
  15. jaf/a2a/memory/providers/redis.py +128 -121
  16. jaf/a2a/memory/serialization.py +77 -87
  17. jaf/a2a/memory/tests/run_comprehensive_tests.py +112 -83
  18. jaf/a2a/memory/tests/test_cleanup.py +211 -94
  19. jaf/a2a/memory/tests/test_serialization.py +73 -68
  20. jaf/a2a/memory/tests/test_stress_concurrency.py +186 -133
  21. jaf/a2a/memory/tests/test_task_lifecycle.py +138 -120
  22. jaf/a2a/memory/types.py +91 -53
  23. jaf/a2a/protocol.py +95 -125
  24. jaf/a2a/server.py +90 -118
  25. jaf/a2a/standalone_client.py +30 -43
  26. jaf/a2a/tests/__init__.py +16 -33
  27. jaf/a2a/tests/run_tests.py +17 -53
  28. jaf/a2a/tests/test_agent.py +40 -140
  29. jaf/a2a/tests/test_client.py +54 -117
  30. jaf/a2a/tests/test_integration.py +28 -82
  31. jaf/a2a/tests/test_protocol.py +54 -139
  32. jaf/a2a/tests/test_types.py +50 -136
  33. jaf/a2a/types.py +58 -34
  34. jaf/cli.py +21 -41
  35. jaf/core/__init__.py +7 -1
  36. jaf/core/agent_tool.py +93 -72
  37. jaf/core/analytics.py +257 -207
  38. jaf/core/checkpoint.py +223 -0
  39. jaf/core/composition.py +249 -235
  40. jaf/core/engine.py +817 -519
  41. jaf/core/errors.py +55 -42
  42. jaf/core/guardrails.py +276 -202
  43. jaf/core/handoff.py +47 -31
  44. jaf/core/parallel_agents.py +69 -75
  45. jaf/core/performance.py +75 -73
  46. jaf/core/proxy.py +43 -44
  47. jaf/core/proxy_helpers.py +24 -27
  48. jaf/core/regeneration.py +220 -129
  49. jaf/core/state.py +68 -66
  50. jaf/core/streaming.py +115 -108
  51. jaf/core/tool_results.py +111 -101
  52. jaf/core/tools.py +114 -116
  53. jaf/core/tracing.py +269 -210
  54. jaf/core/types.py +371 -151
  55. jaf/core/workflows.py +209 -168
  56. jaf/exceptions.py +46 -38
  57. jaf/memory/__init__.py +1 -6
  58. jaf/memory/approval_storage.py +54 -77
  59. jaf/memory/factory.py +4 -4
  60. jaf/memory/providers/in_memory.py +216 -180
  61. jaf/memory/providers/postgres.py +216 -146
  62. jaf/memory/providers/redis.py +173 -116
  63. jaf/memory/types.py +70 -51
  64. jaf/memory/utils.py +36 -34
  65. jaf/plugins/__init__.py +12 -12
  66. jaf/plugins/base.py +105 -96
  67. jaf/policies/__init__.py +0 -1
  68. jaf/policies/handoff.py +37 -46
  69. jaf/policies/validation.py +76 -52
  70. jaf/providers/__init__.py +6 -3
  71. jaf/providers/mcp.py +97 -51
  72. jaf/providers/model.py +360 -279
  73. jaf/server/__init__.py +1 -1
  74. jaf/server/main.py +7 -11
  75. jaf/server/server.py +514 -359
  76. jaf/server/types.py +208 -52
  77. jaf/utils/__init__.py +17 -18
  78. jaf/utils/attachments.py +111 -116
  79. jaf/utils/document_processor.py +175 -174
  80. jaf/visualization/__init__.py +1 -1
  81. jaf/visualization/example.py +111 -110
  82. jaf/visualization/functional_core.py +46 -71
  83. jaf/visualization/graphviz.py +154 -189
  84. jaf/visualization/imperative_shell.py +7 -16
  85. jaf/visualization/types.py +8 -4
  86. {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/METADATA +2 -2
  87. jaf_py-2.5.11.dist-info/RECORD +97 -0
  88. jaf_py-2.5.10.dist-info/RECORD +0 -96
  89. {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/WHEEL +0 -0
  90. {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/entry_points.txt +0 -0
  91. {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/licenses/LICENSE +0 -0
  92. {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/top_level.txt +0 -0
jaf/providers/model.py CHANGED
@@ -16,17 +16,27 @@ from pydantic import BaseModel
 import litellm

 from ..core.types import (
- Agent, ContentRole, Message, ModelProvider, RunConfig, RunState,
- CompletionStreamChunk, ToolCallDelta, ToolCallFunctionDelta,
- MessageContentPart, get_text_content
+ Agent,
+ ContentRole,
+ Message,
+ ModelProvider,
+ RunConfig,
+ RunState,
+ CompletionStreamChunk,
+ ToolCallDelta,
+ ToolCallFunctionDelta,
+ MessageContentPart,
+ get_text_content,
 )
 from ..core.proxy import ProxyConfig
 from ..utils.document_processor import (
- extract_document_content, is_document_supported,
- get_document_description, DocumentProcessingError
+ extract_document_content,
+ is_document_supported,
+ get_document_description,
+ DocumentProcessingError,
 )

- Ctx = TypeVar('Ctx')
+ Ctx = TypeVar("Ctx")

 # Vision model caching
 VISION_MODEL_CACHE_TTL = 5 * 60 # 5 minutes
@@ -34,92 +44,87 @@ VISION_API_TIMEOUT = 3.0 # 3 seconds
 _vision_model_cache: Dict[str, Dict[str, Any]] = {}
 MAX_IMAGE_BYTES = int(os.environ.get("JAF_MAX_IMAGE_BYTES", 8 * 1024 * 1024))

+
 async def _is_vision_model(model: str, base_url: str) -> bool:
 """
 Check if a model supports vision capabilities.
-
+
 Args:
 model: Model name to check
 base_url: Base URL of the LiteLLM server
-
+
 Returns:
 True if model supports vision, False otherwise
 """
 cache_key = f"{base_url}:{model}"
 cached = _vision_model_cache.get(cache_key)
-
- if cached and time.time() - cached['timestamp'] < VISION_MODEL_CACHE_TTL:
- return cached['supports']
-
+
+ if cached and time.time() - cached["timestamp"] < VISION_MODEL_CACHE_TTL:
+ return cached["supports"]
+
 try:
 async with httpx.AsyncClient(timeout=VISION_API_TIMEOUT) as client:
 response = await client.get(
- f"{base_url}/model_group/info",
- headers={'accept': 'application/json'}
+ f"{base_url}/model_group/info", headers={"accept": "application/json"}
 )
-
+
 if response.status_code == 200:
 data = response.json()
 model_info = None
-
- if 'data' in data and isinstance(data['data'], list):
- for m in data['data']:
- if (m.get('model_group') == model or
- model in str(m.get('model_group', ''))):
+
+ if "data" in data and isinstance(data["data"], list):
+ for m in data["data"]:
+ if m.get("model_group") == model or model in str(m.get("model_group", "")):
 model_info = m
 break
-
- if model_info and 'supports_vision' in model_info:
- result = model_info['supports_vision']
- _vision_model_cache[cache_key] = {
- 'supports': result,
- 'timestamp': time.time()
- }
+
+ if model_info and "supports_vision" in model_info:
+ result = model_info["supports_vision"]
+ _vision_model_cache[cache_key] = {"supports": result, "timestamp": time.time()}
 return result
 else:
- print(f"Warning: Vision API returned status {response.status_code} for model {model}")
-
+ print(
+ f"Warning: Vision API returned status {response.status_code} for model {model}"
+ )
+
 except Exception as e:
 print(f"Warning: Vision API error for model {model}: {e}")
-
+
 # Fallback to known vision models
 known_vision_models = [
- 'gpt-4-vision-preview',
- 'gpt-4o',
- 'gpt-4o-mini',
- 'claude-sonnet-4',
- 'claude-sonnet-4-20250514',
- 'gemini-2.5-flash',
- 'gemini-2.5-pro'
+ "gpt-4-vision-preview",
+ "gpt-4o",
+ "gpt-4o-mini",
+ "claude-sonnet-4",
+ "claude-sonnet-4-20250514",
+ "gemini-2.5-flash",
+ "gemini-2.5-pro",
 ]
-
+
 is_known_vision_model = any(
- vision_model.lower() in model.lower()
- for vision_model in known_vision_models
+ vision_model.lower() in model.lower() for vision_model in known_vision_models
 )
-
- _vision_model_cache[cache_key] = {
- 'supports': is_known_vision_model,
- 'timestamp': time.time()
- }
-
+
+ _vision_model_cache[cache_key] = {"supports": is_known_vision_model, "timestamp": time.time()}
+
 return is_known_vision_model

+
 def make_litellm_provider(
 base_url: str,
 api_key: str = "anything",
 default_timeout: Optional[float] = None,
- proxy_config: Optional[ProxyConfig] = None
+ proxy_config: Optional[ProxyConfig] = None,
 ) -> ModelProvider[Ctx]:
 """
 Create a LiteLLM-compatible model provider.
-
+
 Args:
 base_url: Base URL for the LiteLLM server
 api_key: API key (defaults to "anything" for local servers)
 default_timeout: Default timeout for model API calls in seconds
 proxy_config: Optional proxy configuration
-
+
 Returns:
 ModelProvider instance
 """
@@ -128,48 +133,47 @@ def make_litellm_provider(
 def __init__(self):
 # Default to "anything" if api_key is not provided, for local servers
 effective_api_key = api_key if api_key is not None else "anything"
-
+
 # Configure HTTP client with proxy support
 client_kwargs = {
 "base_url": base_url,
 "api_key": effective_api_key,
 }
-
+
 if proxy_config:
 proxies = proxy_config.to_httpx_proxies()
 if proxies:
 # Create httpx client with proxy configuration
 try:
 # Use the https proxy if available, otherwise http proxy
- proxy_url = proxies.get('https://') or proxies.get('http://')
+ proxy_url = proxies.get("https://") or proxies.get("http://")
 if proxy_url:
 http_client = httpx.AsyncClient(proxy=proxy_url)
 client_kwargs["http_client"] = http_client
 except Exception as e:
 print(f"Warning: Could not configure proxy: {e}")
 # Fall back to environment variables for proxy
-
+
 self.client = AsyncOpenAI(**client_kwargs)
 self.default_timeout = default_timeout

 async def get_completion(
- self,
- state: RunState[Ctx],
- agent: Agent[Ctx, Any],
- config: RunConfig[Ctx]
+ self, state: RunState[Ctx], agent: Agent[Ctx, Any], config: RunConfig[Ctx]
 ) -> Dict[str, Any]:
 """Get completion from the model."""

 # Determine model to use
- model = (config.model_override or
- (agent.model_config.name if agent.model_config else "gpt-4o"))
+ model = config.model_override or (
+ agent.model_config.name if agent.model_config else "gpt-4o"
+ )

 # Check if any message contains image content or image attachments
 has_image_content = any(
- (isinstance(msg.content, list) and
- any(part.type == 'image_url' for part in msg.content)) or
- (msg.attachments and
- any(att.kind == 'image' for att in msg.attachments))
+ (
+ isinstance(msg.content, list)
+ and any(part.type == "image_url" for part in msg.content)
+ )
+ or (msg.attachments and any(att.kind == "image" for att in msg.attachments))
 for msg in state.messages
 )

@@ -182,51 +186,59 @@ def make_litellm_provider(
 )

 # Create system message
- system_message = {
- "role": "system",
- "content": agent.instructions(state)
- }
+ system_message = {"role": "system", "content": agent.instructions(state)}

 # Convert messages to OpenAI format
 converted_messages = []
 for msg in state.messages:
 converted_msg = await _convert_message(msg)
 converted_messages.append(converted_msg)
-
+
 messages = [system_message] + converted_messages

 # Convert tools to OpenAI format
 tools = None
 if agent.tools:
+ # Check if we should inline schema refs
+ inline_refs = (
+ agent.model_config.inline_tool_schemas if agent.model_config else False
+ )
 tools = [
 {
 "type": "function",
 "function": {
 "name": tool.schema.name,
 "description": tool.schema.description,
- "parameters": _pydantic_to_json_schema(tool.schema.parameters),
- }
+ "parameters": _pydantic_to_json_schema(
+ tool.schema.parameters, inline_refs=inline_refs or False
+ ),
+ },
 }
 for tool in agent.tools
 ]

 # Determine tool choice behavior
 last_message = state.messages[-1] if state.messages else None
- is_after_tool_call = last_message and (last_message.role == ContentRole.TOOL or last_message.role == 'tool')
+ is_after_tool_call = last_message and (
+ last_message.role == ContentRole.TOOL or last_message.role == "tool"
+ )

 # Prepare request parameters
- request_params = {
- "model": model,
- "messages": messages,
- "stream": False
- }
+ request_params = {"model": model, "messages": messages, "stream": False}

 # Add optional parameters
 if agent.model_config:
 if agent.model_config.temperature is not None:
 request_params["temperature"] = agent.model_config.temperature
- if agent.model_config.max_tokens is not None:
- request_params["max_tokens"] = agent.model_config.max_tokens
+ # Use agent's max_tokens if set, otherwise fall back to config's max_tokens
+ max_tokens = agent.model_config.max_tokens
+ if max_tokens is None:
+ max_tokens = config.max_tokens
+ if max_tokens is not None:
+ request_params["max_tokens"] = max_tokens
+ elif config.max_tokens is not None:
+ # No model_config but config has max_tokens
+ request_params["max_tokens"] = config.max_tokens

 if tools:
 request_params["tools"] = tools
@@ -247,12 +259,9 @@ def make_litellm_provider(
 if choice.message.tool_calls:
 tool_calls = [
 {
- 'id': tc.id,
- 'type': tc.type,
- 'function': {
- 'name': tc.function.name,
- 'arguments': tc.function.arguments
- }
+ "id": tc.id,
+ "type": tc.type,
+ "function": {"name": tc.function.name, "arguments": tc.function.arguments},
 }
 for tc in choice.message.tool_calls
 ]
@@ -267,64 +276,64 @@ def make_litellm_provider(
 }

 return {
- 'id': response.id,
- 'created': response.created,
- 'model': response.model,
- 'system_fingerprint': response.system_fingerprint,
- 'message': {
- 'content': choice.message.content,
- 'tool_calls': tool_calls
- },
- 'usage': usage_data,
- 'prompt': messages
+ "id": response.id,
+ "created": response.created,
+ "model": response.model,
+ "system_fingerprint": response.system_fingerprint,
+ "message": {"content": choice.message.content, "tool_calls": tool_calls},
+ "usage": usage_data,
+ "prompt": messages,
 }

 async def get_completion_stream(
- self,
- state: RunState[Ctx],
- agent: Agent[Ctx, Any],
- config: RunConfig[Ctx]
+ self, state: RunState[Ctx], agent: Agent[Ctx, Any], config: RunConfig[Ctx]
 ) -> AsyncIterator[CompletionStreamChunk]:
 """
 Stream completion chunks from the model provider, yielding text deltas and tool-call deltas.
 Uses OpenAI-compatible streaming via LiteLLM endpoint.
 """
 # Determine model to use
- model = (config.model_override or
- (agent.model_config.name if agent.model_config else "gpt-4o"))
+ model = config.model_override or (
+ agent.model_config.name if agent.model_config else "gpt-4o"
+ )

 # Create system message
- system_message = {
- "role": "system",
- "content": agent.instructions(state)
- }
+ system_message = {"role": "system", "content": agent.instructions(state)}

- # Convert messages to OpenAI format
+ # Convert messages to OpenAI format
 converted_messages = []
 for msg in state.messages:
 converted_msg = await _convert_message(msg)
 converted_messages.append(converted_msg)
-
+
 messages = [system_message] + converted_messages

 # Convert tools to OpenAI format
 tools = None
 if agent.tools:
+ # Check if we should inline schema refs
+ inline_refs = (
+ agent.model_config.inline_tool_schemas if agent.model_config else False
+ )
 tools = [
 {
 "type": "function",
 "function": {
 "name": tool.schema.name,
 "description": tool.schema.description,
- "parameters": _pydantic_to_json_schema(tool.schema.parameters),
- }
+ "parameters": _pydantic_to_json_schema(
+ tool.schema.parameters, inline_refs=inline_refs or False
+ ),
+ },
 }
 for tool in agent.tools
 ]

 # Determine tool choice behavior
 last_message = state.messages[-1] if state.messages else None
- is_after_tool_call = last_message and (last_message.role == ContentRole.TOOL or last_message.role == 'tool')
+ is_after_tool_call = last_message and (
+ last_message.role == ContentRole.TOOL or last_message.role == "tool"
+ )

 # Prepare request parameters
 request_params: Dict[str, Any] = {
@@ -336,8 +345,15 @@ def make_litellm_provider(
 if agent.model_config:
 if agent.model_config.temperature is not None:
 request_params["temperature"] = agent.model_config.temperature
- if agent.model_config.max_tokens is not None:
- request_params["max_tokens"] = agent.model_config.max_tokens
+ # Use agent's max_tokens if set, otherwise fall back to config's max_tokens
+ max_tokens = agent.model_config.max_tokens
+ if max_tokens is None:
+ max_tokens = config.max_tokens
+ if max_tokens is not None:
+ request_params["max_tokens"] = max_tokens
+ elif config.max_tokens is not None:
+ # No model_config but config has max_tokens
+ request_params["max_tokens"] = config.max_tokens

 if tools:
 request_params["tools"] = tools
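Note: the max_tokens handling added above (and repeated in the non-streaming and SDK code paths elsewhere in this diff) reduces to a single precedence rule. A minimal sketch of that rule follows; resolve_max_tokens is a hypothetical helper written only for illustration, with its two arguments standing in for agent.model_config.max_tokens and config.max_tokens.

    from typing import Optional

    def resolve_max_tokens(
        agent_max_tokens: Optional[int], config_max_tokens: Optional[int]
    ) -> Optional[int]:
        # The agent's model_config.max_tokens wins when set; otherwise the
        # RunConfig's max_tokens applies; if neither is set, no "max_tokens"
        # key is added to the request at all.
        if agent_max_tokens is not None:
            return agent_max_tokens
        return config_max_tokens

    assert resolve_max_tokens(1024, 4096) == 1024   # agent setting wins
    assert resolve_max_tokens(None, 4096) == 4096   # falls back to RunConfig
    assert resolve_max_tokens(None, None) is None   # parameter omitted entirely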
@@ -388,19 +404,20 @@ def make_litellm_provider(
 fn = getattr(tc, "function", None)
 fn_name = getattr(fn, "name", None) if fn is not None else None
 # OpenAI streams "arguments" as incremental deltas
- args_delta = getattr(fn, "arguments", None) if fn is not None else None
+ args_delta = (
+ getattr(fn, "arguments", None) if fn is not None else None
+ )

 yield CompletionStreamChunk(
 tool_call_delta=ToolCallDelta(
 index=idx,
 id=tc_id,
- type='function',
+ type="function",
 function=ToolCallFunctionDelta(
- name=fn_name,
- arguments_delta=args_delta
- )
+ name=fn_name, arguments_delta=args_delta
+ ),
 ),
- raw=raw_obj
+ raw=raw_obj,
 )
 except Exception:
 # Skip malformed tool-call deltas
@@ -408,26 +425,29 @@ def make_litellm_provider(

 # Completion ended
 if finish_reason:
- yield CompletionStreamChunk(is_done=True, finish_reason=finish_reason, raw=raw_obj)
+ yield CompletionStreamChunk(
+ is_done=True, finish_reason=finish_reason, raw=raw_obj
+ )
 except Exception:
 # Skip individual chunk errors, keep streaming
 continue

 return LiteLLMProvider()

+
 def make_litellm_sdk_provider(
 api_key: Optional[str] = None,
 model: str = "gpt-3.5-turbo",
 base_url: Optional[str] = None,
 default_timeout: Optional[float] = None,
- **litellm_kwargs: Any
+ **litellm_kwargs: Any,
 ) -> ModelProvider[Ctx]:
 """
 Create a LiteLLM SDK-based model provider with universal provider support.
-
+
 LiteLLM automatically detects the provider from the model name and handles
 API key management through environment variables or direct parameters.
-
+
 Args:
 api_key: API key for the provider (optional, can use env vars)
 model: Model name (e.g., "gpt-4", "claude-3-sonnet", "gemini-pro", "llama2", etc.)
@@ -440,23 +460,23 @@ def make_litellm_sdk_provider(
 - azure_deployment: "your-deployment" (for Azure OpenAI)
 - api_base: "https://your-endpoint.com" (custom endpoint)
 - custom_llm_provider: "custom_provider_name"
-
+
 Returns:
 ModelProvider instance
-
+
 Examples:
 # OpenAI
 make_litellm_sdk_provider(api_key="sk-...", model="gpt-4")
-
+
 # Anthropic Claude
 make_litellm_sdk_provider(api_key="sk-ant-...", model="claude-3-sonnet-20240229")
-
+
 # Google Gemini
 make_litellm_sdk_provider(model="gemini-pro", vertex_project="my-project")
-
+
 # Ollama (local)
 make_litellm_sdk_provider(model="ollama/llama2", base_url="http://localhost:11434")
-
+
 # Azure OpenAI
 make_litellm_sdk_provider(
 model="azure/gpt-4",
@@ -464,13 +484,13 @@ def make_litellm_sdk_provider(
 azure_deployment="gpt-4-deployment",
 api_base="https://your-resource.openai.azure.com"
 )
-
+
 # Hugging Face
 make_litellm_sdk_provider(
 model="huggingface/microsoft/DialoGPT-medium",
 api_key="hf_..."
 )
-
+
 # Any custom provider
 make_litellm_sdk_provider(
 model="custom_provider/model-name",
@@ -488,10 +508,7 @@ def make_litellm_sdk_provider(
 self.litellm_kwargs = litellm_kwargs

 async def get_completion(
- self,
- state: RunState[Ctx],
- agent: Agent[Ctx, Any],
- config: RunConfig[Ctx]
+ self, state: RunState[Ctx], agent: Agent[Ctx, Any], config: RunConfig[Ctx]
 ) -> Dict[str, Any]:
 """Get completion from the model using LiteLLM SDK."""

@@ -499,10 +516,7 @@ def make_litellm_sdk_provider(
 model_name = config.model_override or self.model

 # Create system message
- system_message = {
- "role": "system",
- "content": agent.instructions(state)
- }
+ system_message = {"role": "system", "content": agent.instructions(state)}

 # Convert messages to OpenAI format
 messages = [system_message]
@@ -513,24 +527,26 @@ def make_litellm_sdk_provider(
 # Convert tools to OpenAI format
 tools = None
 if agent.tools:
+ # Check if we should inline schema refs
+ inline_refs = (
+ agent.model_config.inline_tool_schemas if agent.model_config else False
+ )
 tools = [
 {
 "type": "function",
 "function": {
 "name": tool.schema.name,
 "description": tool.schema.description,
- "parameters": _pydantic_to_json_schema(tool.schema.parameters),
- }
+ "parameters": _pydantic_to_json_schema(
+ tool.schema.parameters, inline_refs=inline_refs or False
+ ),
+ },
 }
 for tool in agent.tools
 ]

 # Prepare request parameters for LiteLLM
- request_params = {
- "model": model_name,
- "messages": messages,
- **self.litellm_kwargs
- }
+ request_params = {"model": model_name, "messages": messages, **self.litellm_kwargs}

 # Add API key if provided
 if self.api_key:
@@ -540,8 +556,15 @@ def make_litellm_sdk_provider(
 if agent.model_config:
 if agent.model_config.temperature is not None:
 request_params["temperature"] = agent.model_config.temperature
- if agent.model_config.max_tokens is not None:
- request_params["max_tokens"] = agent.model_config.max_tokens
+ # Use agent's max_tokens if set, otherwise fall back to config's max_tokens
+ max_tokens = agent.model_config.max_tokens
+ if max_tokens is None:
+ max_tokens = config.max_tokens
+ if max_tokens is not None:
+ request_params["max_tokens"] = max_tokens
+ elif config.max_tokens is not None:
+ # No model_config but config has max_tokens
+ request_params["max_tokens"] = config.max_tokens

 if tools:
 request_params["tools"] = tools
@@ -565,12 +588,9 @@ def make_litellm_sdk_provider(
 if choice.message.tool_calls:
 tool_calls = [
 {
- 'id': tc.id,
- 'type': tc.type,
- 'function': {
- 'name': tc.function.name,
- 'arguments': tc.function.arguments
- }
+ "id": tc.id,
+ "type": tc.type,
+ "function": {"name": tc.function.name, "arguments": tc.function.arguments},
 }
 for tc in choice.message.tool_calls
 ]
@@ -585,23 +605,17 @@ def make_litellm_sdk_provider(
 }

 return {
- 'id': response.id,
- 'created': response.created,
- 'model': response.model,
- 'system_fingerprint': getattr(response, 'system_fingerprint', None),
- 'message': {
- 'content': choice.message.content,
- 'tool_calls': tool_calls
- },
- 'usage': usage_data,
- 'prompt': messages
+ "id": response.id,
+ "created": response.created,
+ "model": response.model,
+ "system_fingerprint": getattr(response, "system_fingerprint", None),
+ "message": {"content": choice.message.content, "tool_calls": tool_calls},
+ "usage": usage_data,
+ "prompt": messages,
 }

 async def get_completion_stream(
- self,
- state: RunState[Ctx],
- agent: Agent[Ctx, Any],
- config: RunConfig[Ctx]
+ self, state: RunState[Ctx], agent: Agent[Ctx, Any], config: RunConfig[Ctx]
 ) -> AsyncIterator[CompletionStreamChunk]:
 """
 Stream completion chunks from the model provider using LiteLLM SDK.
@@ -610,10 +624,7 @@ def make_litellm_sdk_provider(
 model_name = config.model_override or self.model

 # Create system message
- system_message = {
- "role": "system",
- "content": agent.instructions(state)
- }
+ system_message = {"role": "system", "content": agent.instructions(state)}

 # Convert messages to OpenAI format
 messages = [system_message]
@@ -624,14 +635,20 @@ def make_litellm_sdk_provider(
 # Convert tools to OpenAI format
 tools = None
 if agent.tools:
+ # Check if we should inline schema refs
+ inline_refs = (
+ agent.model_config.inline_tool_schemas if agent.model_config else False
+ )
 tools = [
 {
 "type": "function",
 "function": {
 "name": tool.schema.name,
 "description": tool.schema.description,
- "parameters": _pydantic_to_json_schema(tool.schema.parameters),
- }
+ "parameters": _pydantic_to_json_schema(
+ tool.schema.parameters, inline_refs=inline_refs or False
+ ),
+ },
 }
 for tool in agent.tools
 ]
@@ -641,7 +658,7 @@ def make_litellm_sdk_provider(
 "model": model_name,
 "messages": messages,
 "stream": True,
- **self.litellm_kwargs
+ **self.litellm_kwargs,
 }

 # Add API key if provided
@@ -652,8 +669,15 @@ def make_litellm_sdk_provider(
 if agent.model_config:
 if agent.model_config.temperature is not None:
 request_params["temperature"] = agent.model_config.temperature
- if agent.model_config.max_tokens is not None:
- request_params["max_tokens"] = agent.model_config.max_tokens
+ # Use agent's max_tokens if set, otherwise fall back to config's max_tokens
+ max_tokens = agent.model_config.max_tokens
+ if max_tokens is None:
+ max_tokens = config.max_tokens
+ if max_tokens is not None:
+ request_params["max_tokens"] = max_tokens
+ elif config.max_tokens is not None:
+ # No model_config but config has max_tokens
+ request_params["max_tokens"] = config.max_tokens

 if tools:
 request_params["tools"] = tools
@@ -668,12 +692,12 @@ def make_litellm_sdk_provider(

 # Stream using litellm
 stream = await litellm.acompletion(**request_params)
-
+
 async for chunk in stream:
 try:
 # Best-effort extraction of raw for debugging
 try:
- raw_obj = chunk.model_dump() if hasattr(chunk, 'model_dump') else None
+ raw_obj = chunk.model_dump() if hasattr(chunk, "model_dump") else None
 except Exception:
 raw_obj = None

@@ -702,52 +726,59 @@ def make_litellm_sdk_provider(
 tc_id = getattr(tc, "id", None)
 fn = getattr(tc, "function", None)
 fn_name = getattr(fn, "name", None) if fn is not None else None
- args_delta = getattr(fn, "arguments", None) if fn is not None else None
+ args_delta = (
+ getattr(fn, "arguments", None) if fn is not None else None
+ )

 yield CompletionStreamChunk(
 tool_call_delta=ToolCallDelta(
 index=idx,
 id=tc_id,
- type='function',
+ type="function",
 function=ToolCallFunctionDelta(
- name=fn_name,
- arguments_delta=args_delta
- )
+ name=fn_name, arguments_delta=args_delta
+ ),
 ),
- raw=raw_obj
+ raw=raw_obj,
 )
 except Exception:
 continue

 # Completion ended
 if finish_reason:
- yield CompletionStreamChunk(is_done=True, finish_reason=finish_reason, raw=raw_obj)
+ yield CompletionStreamChunk(
+ is_done=True, finish_reason=finish_reason, raw=raw_obj
+ )
 except Exception:
 continue

 return LiteLLMSDKProvider()

+
 async def _convert_message(msg: Message) -> Dict[str, Any]:
 """
 Handles all possible role types (string and enum) and content formats.
 """
 # Normalize role to handle both string and enum values
- role_value = msg.role.value if hasattr(msg.role, 'value') else str(msg.role).lower()
+ role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role).lower()

 # Handle user messages
- if role_value in ('user', ContentRole.USER.value if hasattr(ContentRole, 'USER') else 'user'):
+ if role_value in ("user", ContentRole.USER.value if hasattr(ContentRole, "USER") else "user"):
 if isinstance(msg.content, list):
 # Multi-part content
 return {
 "role": "user",
- "content": [_convert_content_part(part) for part in msg.content]
+ "content": [_convert_content_part(part) for part in msg.content],
 }
 else:
 # Build message with attachments if available
- return await _build_chat_message_with_attachments('user', msg)
+ return await _build_chat_message_with_attachments("user", msg)

 # Handle assistant messages
- elif role_value in ('assistant', ContentRole.ASSISTANT.value if hasattr(ContentRole, 'ASSISTANT') else 'assistant'):
+ elif role_value in (
+ "assistant",
+ ContentRole.ASSISTANT.value if hasattr(ContentRole, "ASSISTANT") else "assistant",
+ ):
 result = {
 "role": "assistant",
 "content": get_text_content(msg.content) or "", # Ensure content is never None
@@ -759,10 +790,7 @@ async def _convert_message(msg: Message) -> Dict[str, Any]:
 {
 "id": tc.id,
 "type": tc.type,
- "function": {
- "name": tc.function.name,
- "arguments": tc.function.arguments
- }
+ "function": {"name": tc.function.name, "arguments": tc.function.arguments},
 }
 for tc in msg.tool_calls
 if tc.id and tc.function and tc.function.name # Validate tool call structure
@@ -771,37 +799,37 @@ async def _convert_message(msg: Message) -> Dict[str, Any]:
 return result

 # Handle system messages
- elif role_value in ('system', ContentRole.SYSTEM.value if hasattr(ContentRole, 'SYSTEM') else 'system'):
- return {
- "role": "system",
- "content": get_text_content(msg.content) or ""
- }
+ elif role_value in (
+ "system",
+ ContentRole.SYSTEM.value if hasattr(ContentRole, "SYSTEM") else "system",
+ ):
+ return {"role": "system", "content": get_text_content(msg.content) or ""}

 # Handle tool messages
- elif role_value in ('tool', ContentRole.TOOL.value if hasattr(ContentRole, 'TOOL') else 'tool'):
+ elif role_value in ("tool", ContentRole.TOOL.value if hasattr(ContentRole, "TOOL") else "tool"):
 if not msg.tool_call_id:
 raise ValueError(f"Tool message must have tool_call_id. Message: {msg}")

 return {
 "role": "tool",
 "content": get_text_content(msg.content) or "",
- "tool_call_id": msg.tool_call_id
+ "tool_call_id": msg.tool_call_id,
 }

 # Handle function messages (legacy support)
- elif role_value == 'function':
+ elif role_value == "function":
 if not msg.tool_call_id:
 raise ValueError(f"Function message must have tool_call_id. Message: {msg}")

 return {
 "role": "function",
 "content": get_text_content(msg.content) or "",
- "name": getattr(msg, 'name', 'unknown_function')
+ "name": getattr(msg, "name", "unknown_function"),
 }

 # Unknown role - provide helpful error message
 else:
- available_roles = ['user', 'assistant', 'system', 'tool', 'function']
+ available_roles = ["user", "assistant", "system", "tool", "function"]
 raise ValueError(
 f"Unknown message role: {msg.role} (type: {type(msg.role)}). "
 f"Supported roles: {available_roles}. "
@@ -811,46 +839,31 @@ async def _convert_message(msg: Message) -> Dict[str, Any]:

 def _convert_content_part(part: MessageContentPart) -> Dict[str, Any]:
 """Convert MessageContentPart to OpenAI format."""
- if part.type == 'text':
- return {
- "type": "text",
- "text": part.text
- }
- elif part.type == 'image_url':
- return {
- "type": "image_url",
- "image_url": part.image_url
- }
- elif part.type == 'file':
- return {
- "type": "file",
- "file": part.file
- }
+ if part.type == "text":
+ return {"type": "text", "text": part.text}
+ elif part.type == "image_url":
+ return {"type": "image_url", "image_url": part.image_url}
+ elif part.type == "file":
+ return {"type": "file", "file": part.file}
 else:
 raise ValueError(f"Unknown content part type: {part.type}")


- async def _build_chat_message_with_attachments(
- role: str,
- msg: Message
- ) -> Dict[str, Any]:
+ async def _build_chat_message_with_attachments(role: str, msg: Message) -> Dict[str, Any]:
 """
 Build multi-part content for Chat Completions if attachments exist.
 Supports images via image_url and documents via content extraction.
 """
 has_attachments = msg.attachments and len(msg.attachments) > 0
 if not has_attachments:
- if role == 'assistant':
+ if role == "assistant":
 base_msg = {"role": "assistant", "content": get_text_content(msg.content)}
 if msg.tool_calls:
 base_msg["tool_calls"] = [
 {
 "id": tc.id,
 "type": tc.type,
- "function": {
- "name": tc.function.name,
- "arguments": tc.function.arguments
- }
+ "function": {"name": tc.function.name, "arguments": tc.function.arguments},
 }
 for tc in msg.tool_calls
 ]
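Note: for orientation, this is roughly the Chat Completions message shape that the attachment handling in this and the following hunks produces for a user message carrying one image attachment; the text and the data URL below are illustrative placeholders, not values from the package.

    # Sketch of the converted message: one text part plus one image_url part,
    # where the URL is either the attachment's own URL or a base64 data URL.
    converted_user_message = {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this picture?"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR..."}},
        ],
    }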
@@ -863,7 +876,7 @@ async def _build_chat_message_with_attachments(
 parts.append({"type": "text", "text": text_content})

 for att in msg.attachments:
- if att.kind == 'image':
+ if att.kind == "image":
 # Prefer explicit URL; otherwise construct a data URL from base64
 url = att.url
 if not url and att.data and att.mime_type:
@@ -871,100 +884,168 @@ async def _build_chat_message_with_attachments(
 try:
 # Estimate decoded size (base64 is ~4/3 of decoded size)
 estimated_size = len(att.data) * 3 // 4
-
+
 if estimated_size > MAX_IMAGE_BYTES:
- print(f"Warning: Skipping oversized image ({estimated_size} bytes > {MAX_IMAGE_BYTES}). "
- f"Set JAF_MAX_IMAGE_BYTES env var to adjust limit.")
- parts.append({
- "type": "text",
- "text": f"[IMAGE SKIPPED: Size exceeds limit of {MAX_IMAGE_BYTES//1024//1024}MB. "
- f"Image name: {att.name or 'unnamed'}]"
- })
+ print(
+ f"Warning: Skipping oversized image ({estimated_size} bytes > {MAX_IMAGE_BYTES}). "
+ f"Set JAF_MAX_IMAGE_BYTES env var to adjust limit."
+ )
+ parts.append(
+ {
+ "type": "text",
+ "text": f"[IMAGE SKIPPED: Size exceeds limit of {MAX_IMAGE_BYTES // 1024 // 1024}MB. "
+ f"Image name: {att.name or 'unnamed'}]",
+ }
+ )
 continue
-
+
 # Create data URL for valid-sized images
 url = f"data:{att.mime_type};base64,{att.data}"
 except Exception as e:
 print(f"Error processing image data: {e}")
- parts.append({
- "type": "text",
- "text": f"[IMAGE ERROR: Failed to process image data. Image name: {att.name or 'unnamed'}]"
- })
+ parts.append(
+ {
+ "type": "text",
+ "text": f"[IMAGE ERROR: Failed to process image data. Image name: {att.name or 'unnamed'}]",
+ }
+ )
 continue
-
+
 if url:
- parts.append({
- "type": "image_url",
- "image_url": {"url": url}
- })
-
- elif att.kind in ['document', 'file']:
+ parts.append({"type": "image_url", "image_url": {"url": url}})
+
+ elif att.kind in ["document", "file"]:
 # Check if attachment has use_litellm_format flag or is a large document
 use_litellm_format = att.use_litellm_format is True
-
+
 if use_litellm_format and (att.url or att.data):
 # For now, fall back to content extraction since most providers don't support native file format
 # TODO: Add provider-specific file format support
- print(f"Info: LiteLLM format requested for {att.name}, falling back to content extraction")
+ print(
+ f"Info: LiteLLM format requested for {att.name}, falling back to content extraction"
+ )
 use_litellm_format = False
-
+
 if not use_litellm_format:
 # Extract document content if supported and we have data or URL
 if is_document_supported(att.mime_type) and (att.data or att.url):
 try:
 processed = await extract_document_content(att)
- file_name = att.name or 'document'
+ file_name = att.name or "document"
 description = get_document_description(att.mime_type)
-
- parts.append({
- "type": "text",
- "text": f"DOCUMENT: {file_name} ({description}):\n\n{processed.content}"
- })
+
+ parts.append(
+ {
+ "type": "text",
+ "text": f"DOCUMENT: {file_name} ({description}):\n\n{processed.content}",
+ }
+ )
 except DocumentProcessingError as e:
 # Fallback to filename if extraction fails
- label = att.name or att.format or att.mime_type or 'attachment'
- parts.append({
- "type": "text",
- "text": f"ERROR: Failed to process {att.kind}: {label} ({e})"
- })
+ label = att.name or att.format or att.mime_type or "attachment"
+ parts.append(
+ {
+ "type": "text",
+ "text": f"ERROR: Failed to process {att.kind}: {label} ({e})",
+ }
+ )
 else:
 # Unsupported document type - show placeholder
- label = att.name or att.format or att.mime_type or 'attachment'
+ label = att.name or att.format or att.mime_type or "attachment"
 url_info = f" ({att.url})" if att.url else ""
- parts.append({
- "type": "text",
- "text": f"ATTACHMENT: {att.kind}: {label}{url_info}"
- })
+ parts.append(
+ {"type": "text", "text": f"ATTACHMENT: {att.kind}: {label}{url_info}"}
+ )

 base_msg = {"role": role, "content": parts}
- if role == 'assistant' and msg.tool_calls:
+ if role == "assistant" and msg.tool_calls:
 base_msg["tool_calls"] = [
 {
 "id": tc.id,
 "type": tc.type,
- "function": {
- "name": tc.function.name,
- "arguments": tc.function.arguments
- }
+ "function": {"name": tc.function.name, "arguments": tc.function.arguments},
 }
 for tc in msg.tool_calls
 ]
-
+
 return base_msg

- def _pydantic_to_json_schema(model_class: type[BaseModel]) -> Dict[str, Any]:
+
+ def _resolve_schema_refs(
+ schema: Dict[str, Any], defs: Optional[Dict[str, Any]] = None
+ ) -> Dict[str, Any]:
+ """
+ Recursively resolve $ref references in a JSON schema by inlining definitions.
+
+ Args:
+ schema: The schema object to process (may contain $ref)
+ defs: The $defs dictionary containing reusable definitions
+
+ Returns:
+ Schema with all references resolved inline
+ """
+ if defs is None:
+ # Extract $defs from root schema if present
+ defs = schema.get("$defs", {})
+
+ # If this is a reference, resolve it
+ if isinstance(schema, dict) and "$ref" in schema:
+ ref_path = schema["$ref"]
+
+ # Handle #/$defs/DefinitionName format
+ if ref_path.startswith("#/$defs/"):
+ def_name = ref_path.split("/")[-1]
+ if def_name in defs:
+ # Recursively resolve the definition (it might have refs too)
+ resolved_def = _resolve_schema_refs(defs[def_name], defs)
+ return resolved_def
+ else:
+ # If definition not found, return the original ref
+ return schema
+ else:
+ # Other ref formats - return as is
+ return schema
+
+ # If this is a dict, recursively process all values
+ if isinstance(schema, dict):
+ result = {}
+ for key, value in schema.items():
+ # Skip $defs as we're inlining them
+ if key == "$defs":
+ continue
+ result[key] = _resolve_schema_refs(value, defs)
+ return result
+
+ # If this is a list, recursively process all items
+ if isinstance(schema, list):
+ return [_resolve_schema_refs(item, defs) for item in schema]
+
+ # For primitive types, return as is
+ return schema
+
+
+ def _pydantic_to_json_schema(
+ model_class: type[BaseModel], inline_refs: bool = False
+ ) -> Dict[str, Any]:
 """
 Convert a Pydantic model to JSON schema for OpenAI tools.
-
+
 Args:
 model_class: Pydantic model class
-
+ inline_refs: If True, resolve $refs and inline $defs in the schema
+
 Returns:
 JSON schema dictionary
 """
- if hasattr(model_class, 'model_json_schema'):
+ if hasattr(model_class, "model_json_schema"):
 # Pydantic v2
- return model_class.model_json_schema()
+ schema = model_class.model_json_schema()
 else:
 # Pydantic v1 fallback
- return model_class.schema()
+ schema = model_class.schema()
+
+ # If inline_refs is True, resolve all references
+ if inline_refs:
+ schema = _resolve_schema_refs(schema)
+
+ return schema
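Note: a small sketch of what the new inline_refs path does to a nested tool-parameter model. The model names below are illustrative only; the behavior follows _resolve_schema_refs and _pydantic_to_json_schema as added above, and is switched on per agent via the new model_config.inline_tool_schemas flag referenced earlier in this diff. Pydantic v2 emits $defs/$ref for nested models, which some OpenAI-compatible backends reject inside tool parameter schemas; inlining yields a self-contained schema.

    from pydantic import BaseModel

    from jaf.providers.model import _pydantic_to_json_schema  # helper shown in the diff above

    class Address(BaseModel):
        city: str
        country: str

    class CustomerLookupArgs(BaseModel):
        name: str
        address: Address

    # Default Pydantic v2 output keeps the nested model behind a reference:
    # {"$defs": {"Address": {...}}, "properties": {"address": {"$ref": "#/$defs/Address"}, ...}}
    nested = CustomerLookupArgs.model_json_schema()
    assert "$defs" in nested

    # With inline_refs=True the $ref is replaced by the Address definition and $defs is dropped.
    inlined = _pydantic_to_json_schema(CustomerLookupArgs, inline_refs=True)
    assert "$defs" not in inlined
    assert inlined["properties"]["address"]["type"] == "object"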