polos-sdk 0.1.0 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- polos/__init__.py +105 -0
- polos/agents/__init__.py +7 -0
- polos/agents/agent.py +746 -0
- polos/agents/conversation_history.py +121 -0
- polos/agents/stop_conditions.py +280 -0
- polos/agents/stream.py +635 -0
- polos/core/__init__.py +0 -0
- polos/core/context.py +143 -0
- polos/core/state.py +26 -0
- polos/core/step.py +1380 -0
- polos/core/workflow.py +1192 -0
- polos/features/__init__.py +0 -0
- polos/features/events.py +456 -0
- polos/features/schedules.py +110 -0
- polos/features/tracing.py +605 -0
- polos/features/wait.py +82 -0
- polos/llm/__init__.py +9 -0
- polos/llm/generate.py +152 -0
- polos/llm/providers/__init__.py +5 -0
- polos/llm/providers/anthropic.py +615 -0
- polos/llm/providers/azure.py +42 -0
- polos/llm/providers/base.py +196 -0
- polos/llm/providers/fireworks.py +41 -0
- polos/llm/providers/gemini.py +40 -0
- polos/llm/providers/groq.py +40 -0
- polos/llm/providers/openai.py +1021 -0
- polos/llm/providers/together.py +40 -0
- polos/llm/stream.py +183 -0
- polos/middleware/__init__.py +0 -0
- polos/middleware/guardrail.py +148 -0
- polos/middleware/guardrail_executor.py +253 -0
- polos/middleware/hook.py +164 -0
- polos/middleware/hook_executor.py +104 -0
- polos/runtime/__init__.py +0 -0
- polos/runtime/batch.py +87 -0
- polos/runtime/client.py +841 -0
- polos/runtime/queue.py +42 -0
- polos/runtime/worker.py +1365 -0
- polos/runtime/worker_server.py +249 -0
- polos/tools/__init__.py +0 -0
- polos/tools/tool.py +587 -0
- polos/types/__init__.py +23 -0
- polos/types/types.py +116 -0
- polos/utils/__init__.py +27 -0
- polos/utils/agent.py +27 -0
- polos/utils/client_context.py +41 -0
- polos/utils/config.py +12 -0
- polos/utils/output_schema.py +311 -0
- polos/utils/retry.py +47 -0
- polos/utils/serializer.py +167 -0
- polos/utils/tracing.py +27 -0
- polos/utils/worker_singleton.py +40 -0
- polos_sdk-0.1.0.dist-info/METADATA +650 -0
- polos_sdk-0.1.0.dist-info/RECORD +55 -0
- polos_sdk-0.1.0.dist-info/WHEEL +4 -0
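The largest addition in this release is the OpenAI provider, polos/llm/providers/openai.py, reproduced in full below. For orientation, here is a minimal usage sketch based on that source; it assumes the optional OpenAI dependency is installed (pip install polos[openai]), that OPENAI_API_KEY is set in the environment, and that LLMResponse exposes the content and usage fields it is constructed with in the code below. The model name is illustrative.

    import asyncio

    from polos.llm.providers.openai import OpenAIProvider


    async def main():
        # Defaults to the Responses API; pass llm_api="chat_completions" to use Chat Completions.
        provider = OpenAIProvider()
        response = await provider.generate(
            messages=[{"role": "user", "content": "Say hello in one sentence."}],
            model="gpt-4o-mini",  # illustrative model name
        )
        print(response.content)
        print(response.usage)


    asyncio.run(main())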
polos/llm/providers/openai.py
@@ -0,0 +1,1021 @@
"""OpenAI provider implementation supporting both Responses API and Chat Completions API."""

import json
import logging
import os
from typing import Any

from .base import LLMProvider, LLMResponse, register_provider

logger = logging.getLogger(__name__)


@register_provider("openai")
class OpenAIProvider(LLMProvider):
    """OpenAI provider for LLM calls supporting both Responses API and Chat Completions API."""

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        llm_api: str = "responses",
    ):
        """
        Initialize OpenAI provider.

        Args:
            api_key: OpenAI API key. If not provided, uses OPENAI_API_KEY env var.
            base_url: Optional base URL for the API. If not provided, defaults to OpenAI's URL.
                Useful for Azure OpenAI or other OpenAI-compatible endpoints.
            llm_api: API version to use - "responses" (default) or "chat_completions"
        """
        # Import OpenAI SDK only when this provider is used (lazy loading)
        try:
            from openai import AsyncOpenAI  # noqa: F401
        except ImportError:
            raise ImportError(
                "OpenAI SDK not installed. Install it with: pip install polos[openai]"
            ) from None

        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        if not self.api_key:
            raise ValueError(
                "OpenAI API key not provided. Set OPENAI_API_KEY environment variable "
                "or pass api_key parameter."
            )

        # Get base URL from parameter or default to OpenAI
        self.base_url = base_url or os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")

        # Store API version
        self.llm_api = llm_api

        # Initialize OpenAI async client
        self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)

        # For chat_completions, we need supports_structured_output flag
        # This is used when llm_api is "chat_completions"
        self.supports_structured_output = True  # Can be overridden by subclasses

    async def generate(
        self,
        messages: list[dict[str, Any]],
        model: str,
        tools: list[dict[str, Any]] | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        top_p: float | None = None,
        agent_config: dict[str, Any] | None = None,
        tool_results: list[dict[str, Any]] | None = None,
        output_schema: dict[str, Any] | None = None,
        output_schema_name: str | None = None,
        **kwargs,
    ) -> LLMResponse:
        """
        Make a request to OpenAI using either Responses API or Chat Completions API.

        Args:
            messages: List of message dicts with 'role' and 'content' keys
            model: Model identifier (e.g., "gpt-4", "gpt-3.5-turbo")
            tools: Optional list of tool schemas for function calling
            temperature: Optional temperature parameter (0-2)
            max_tokens: Optional max tokens parameter
            agent_config: Optional AgentConfig dict containing system_prompt and other config
            tool_results: Optional list of tool results in OpenAI format to add to messages
            output_schema: Optional structured output schema
            output_schema_name: Optional schema name
            **kwargs: Additional OpenAI-specific parameters

        Returns:
            LLMResponse with content, usage, and tool_calls
        """
        if self.llm_api == "chat_completions":
            return await self._generate_chat_completions(
                messages,
                model,
                tools,
                temperature,
                max_tokens,
                top_p,
                agent_config,
                tool_results,
                output_schema,
                output_schema_name,
                **kwargs,
            )
        else:
            return await self._generate_responses(
                messages,
                model,
                tools,
                temperature,
                max_tokens,
                top_p,
                agent_config,
                tool_results,
                output_schema,
                output_schema_name,
                **kwargs,
            )

    async def stream(
        self,
        messages: list[dict[str, Any]],
        model: str,
        tools: list[dict[str, Any]] | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        top_p: float | None = None,
        agent_config: dict[str, Any] | None = None,
        tool_results: list[dict[str, Any]] | None = None,
        output_schema: dict[str, Any] | None = None,
        output_schema_name: str | None = None,
        **kwargs,
    ):
        """
        Stream responses from OpenAI using either Responses API or Chat Completions API.

        Args:
            messages: List of message dicts with 'role' and 'content' keys
            model: Model identifier (e.g., "gpt-4", "gpt-3.5-turbo")
            tools: Optional list of tool schemas for function calling
            temperature: Optional temperature parameter (0-2)
            max_tokens: Optional max tokens parameter
            top_p: Optional top_p parameter
            agent_config: Optional AgentConfig dict containing system_prompt and other config
            tool_results: Optional list of tool results in OpenAI format to add to messages
            output_schema: Optional structured output schema
            output_schema_name: Optional schema name for structured output
            **kwargs: Additional OpenAI-specific parameters

        Yields:
            Dict with event information:
            - type: "text_delta", "tool_call", "done", "error"
            - data: Event-specific data
        """
        if self.llm_api == "chat_completions":
            async for event in self._stream_chat_completions(
                messages,
                model,
                tools,
                temperature,
                max_tokens,
                top_p,
                agent_config,
                tool_results,
                output_schema,
                output_schema_name,
                **kwargs,
            ):
                yield event
        else:
            async for event in self._stream_responses(
                messages,
                model,
                tools,
                temperature,
                max_tokens,
                top_p,
                agent_config,
                tool_results,
                output_schema,
                output_schema_name,
                **kwargs,
            ):
                yield event

    async def _generate_responses(
        self,
        messages: list[dict[str, Any]],
        model: str,
        tools: list[dict[str, Any]] | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        top_p: float | None = None,
        agent_config: dict[str, Any] | None = None,
        tool_results: list[dict[str, Any]] | None = None,
        output_schema: dict[str, Any] | None = None,
        output_schema_name: str | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Generate using Responses API."""
        # Prepare messages - copy to avoid mutating input
        processed_messages = messages.copy() if messages else []

        # Add tool_results to messages (OpenAI format, no conversion needed)
        if tool_results:
            processed_messages.extend(tool_results)

        # Prepare request parameters for Responses API
        request_params = {
            "model": model,
            "input": processed_messages,
            "stream": False,
        }

        # Add system prompt from agent_config if present
        # OpenAI Responses API uses "instructions" parameter
        if agent_config and agent_config.get("system_prompt"):
            request_params["instructions"] = agent_config.get("system_prompt")

        if temperature is not None:
            request_params["temperature"] = temperature
        if max_tokens is not None:
            request_params["max_output_tokens"] = max_tokens
        if tools:
            # Validate tools format - Responses API expects:
            # {"type": "function", "name": "...", "description": "...", "parameters": {...}}
            # (name, description, parameters at top level, not nested in "function")
            validated_tools = _validate_tools_responses(tools)
            if validated_tools:
                request_params["tools"] = validated_tools
            else:
                # If tools were provided but none validated, log a warning
                if tools:
                    logger.warning("Tools provided but none were valid. Tools: %s", tools)

        # Add structured output format if output_schema is provided
        if output_schema and output_schema_name:
            request_params["text"] = {
                "format": {
                    "type": "json_schema",
                    "name": output_schema_name,
                    "strict": True,
                    "schema": output_schema,
                }
            }

        # Add any additional kwargs
        request_params.update(kwargs)

        try:
            # Use the SDK's Responses API
            response = await self.client.responses.create(**request_params)
            if not response:
                raise RuntimeError("OpenAI API returned no response")

            response_dict = response.model_dump(exclude_none=True, mode="json")

            # Check for errors
            if response_dict.get("error"):
                error_msg = response_dict.get("error").get("message", "Unknown error")
                raise RuntimeError(f"OpenAI API error: {error_msg}")

            # Extract content from response using the output_text property
            # This property aggregates all text from output items
            content = response.output_text
            raw_output = response_dict.get("output")
            processed_messages.extend(raw_output)

            # Extract tool calls from output items
            # Tool calls are in the output items as content items with type "tool_call"
            tool_calls = []
            for output_item in response.output:
                # Output items have a type and content
                if hasattr(output_item, "type") and output_item.type == "function_call":
                    # Extract tool call information
                    tool_call_data = {
                        "id": getattr(output_item, "id", ""),
                        "call_id": getattr(output_item, "call_id", ""),
                        "type": "function",
                        "function": {
                            "name": getattr(output_item, "name", ""),
                            "arguments": getattr(output_item, "arguments", "")
                            if hasattr(output_item, "arguments")
                            else "",
                        },
                    }
                    tool_calls.append(tool_call_data)

            # Extract usage information
            # ResponseUsage has input_tokens, output_tokens, and total_tokens
            usage_data = response.usage
            usage = {
                "input_tokens": usage_data.input_tokens if usage_data else 0,
                "output_tokens": usage_data.output_tokens if usage_data else 0,
                "total_tokens": usage_data.total_tokens if usage_data else 0,
            }

            # Extract model and stop_reason from response
            response_model = getattr(response, "model", None) or model
            incomplete_details = getattr(response, "incomplete_details", None)
            if incomplete_details:
                response_stop_reason = getattr(incomplete_details, "reason", None)
            else:
                response_stop_reason = None

            return LLMResponse(
                content=content,
                usage=usage,
                tool_calls=tool_calls,
                raw_output=processed_messages,
                model=response_model,
                stop_reason=response_stop_reason,
            )

        except Exception as e:
            # Re-raise with more context
            raise RuntimeError(f"OpenAI Responses API call failed: {str(e)}") from e

    async def _generate_chat_completions(
        self,
        messages: list[dict[str, Any]],
        model: str,
        tools: list[dict[str, Any]] | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        top_p: float | None = None,
        agent_config: dict[str, Any] | None = None,
        tool_results: list[dict[str, Any]] | None = None,
        output_schema: dict[str, Any] | None = None,
        output_schema_name: str | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Generate using Chat Completions API."""
        # Prepare messages - copy to avoid mutating input
        processed_messages = messages.copy() if messages else []

        # Add system prompt from agent_config if present
        # Chat Completions API uses "system" role in messages
        if agent_config and agent_config.get("system_prompt"):
            # Check if there's already a system message
            has_system = any(msg.get("role") == "system" for msg in processed_messages)
            if not has_system:
                processed_messages.insert(
                    0, {"role": "system", "content": agent_config.get("system_prompt")}
                )
            else:
                # Update existing system message
                for msg in processed_messages:
                    if msg.get("role") == "system":
                        msg["content"] = (
                            msg.get("content", "") + "\n\n" + agent_config.get("system_prompt")
                        )
                        break

        # Add tool_results to messages.
        if tool_results:
            for tool_result in tool_results:
                if tool_result.get("type") == "function_call_output":
                    tool_call_id = tool_result.get("call_id")
                    output = tool_result.get("output")
                    processed_messages.append(
                        {
                            "role": "tool",
                            "content": output if isinstance(output, str) else json.dumps(output),
                            "tool_call_id": tool_call_id,
                        }
                    )

        # Prepare request parameters for Chat Completions API
        request_params = {
            "model": model,
            "messages": processed_messages,
        }

        if temperature is not None:
            request_params["temperature"] = temperature
        if max_tokens is not None:
            request_params["max_tokens"] = max_tokens
        if top_p is not None:
            request_params["top_p"] = top_p

        if tools:
            # Chat Completions API expects tools in format:
            # [{"type": "function", "function": {"name": "...", "description": "...",
            # "parameters": {...}}}]
            validated_tools = _validate_tools_chat_completions(tools)
            if validated_tools:
                request_params["tools"] = validated_tools

        # Handle structured output
        if output_schema:
            if (
                self.supports_structured_output
                and output_schema_name
                and not request_params.get("tools")
            ):
                # Use response_format parameter if supported and there are no tools
                request_params["response_format"] = {
                    "type": "json_schema",
                    "json_schema": {
                        "name": output_schema_name,
                        "strict": True,
                        "schema": output_schema,
                    },
                }
            else:
                # Add structured output instructions to system prompt
                schema_json = json.dumps(output_schema, indent=2)
                structured_output_instruction = (
                    f"\n\nIMPORTANT: You must structure your text response as valid JSON "
                    f"that strictly conforms to this schema:\n\n{schema_json}\n\n"
                    f"Return ONLY valid JSON that matches this schema. Do not include any "
                    f"text, explanation, markdown code fences (```json or ```), or "
                    f"formatting outside of the JSON structure. Return only the raw JSON "
                    f"without any markdown formatting."
                )
                # Update system message
                has_system = any(msg.get("role") == "system" for msg in processed_messages)
                if has_system:
                    for msg in processed_messages:
                        if msg.get("role") == "system":
                            msg["content"] = msg.get("content", "") + structured_output_instruction
                            break
                else:
                    processed_messages.insert(
                        0, {"role": "system", "content": structured_output_instruction}
                    )

        # Add any additional kwargs
        request_params.update(kwargs)

        try:
            usage = {
                "input_tokens": 0,
                "output_tokens": 0,
                "total_tokens": 0,
            }
            response_model = model
            response_stop_reason = None
            tool_calls = []
            content = None

            # Use the Chat Completions API
            response = await self.client.chat.completions.create(**request_params)
            if not response:
                raise RuntimeError("OpenAI API returned no response")

            # Extract content from response
            if response.choices and len(response.choices) > 0:
                choice = response.choices[0]
                if not choice.message:
                    raise RuntimeError("OpenAI API returned no message")

                processed_messages.append(choice.message.model_dump(exclude_none=True, mode="json"))
                content = choice.message.content or ""
                response_stop_reason = choice.finish_reason

                # Extract tool calls
                if choice.message.tool_calls:
                    for tool_call in choice.message.tool_calls:
                        tool_calls.append(
                            {
                                "call_id": tool_call.id,
                                "id": "",
                                "type": "function",
                                "function": {
                                    "name": tool_call.function.name,
                                    "arguments": tool_call.function.arguments,
                                },
                            }
                        )

            # Extract usage information
            if response.usage:
                usage["input_tokens"] = response.usage.prompt_tokens
                usage["output_tokens"] = response.usage.completion_tokens
                usage["total_tokens"] = response.usage.total_tokens

            # Extract model and stop_reason
            response_model = response.model or model

            return LLMResponse(
                content=content,
                usage=usage,
                tool_calls=tool_calls,
                raw_output=processed_messages,
                model=response_model,
                stop_reason=response_stop_reason,
            )

        except Exception as e:
            # Re-raise with more context
            raise RuntimeError(f"OpenAI Chat Completions API call failed: {str(e)}") from e

    async def _stream_responses(
        self,
        messages: list[dict[str, Any]],
        model: str,
        tools: list[dict[str, Any]] | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        top_p: float | None = None,
        agent_config: dict[str, Any] | None = None,
        tool_results: list[dict[str, Any]] | None = None,
        output_schema: dict[str, Any] | None = None,
        output_schema_name: str | None = None,
        **kwargs,
    ):
        """Stream using Responses API."""
        # Prepare messages - copy to avoid mutating input
        processed_messages = messages.copy() if messages else []

        # Add tool_results to messages (OpenAI format, no conversion needed)
        if tool_results:
            processed_messages.extend(tool_results)

        # Prepare request parameters for Responses API
        request_params = {
            "model": model,
            "input": processed_messages,
            "stream": True,  # Enable streaming
        }

        # Add system prompt from agent_config if present
        # OpenAI Responses API uses "instructions" parameter
        if agent_config and agent_config.get("system_prompt"):
            request_params["instructions"] = agent_config.get("system_prompt")

        if temperature is not None:
            request_params["temperature"] = temperature
        if max_tokens is not None:
            request_params["max_output_tokens"] = max_tokens
        if tools:
            # Validate tools format - Responses API expects:
            # {"type": "function", "name": "...", "description": "...", "parameters": {...}}
            validated_tools = _validate_tools_responses(tools)
            if validated_tools:
                request_params["tools"] = validated_tools

        # Add structured output format if output_schema is provided
        if output_schema and output_schema_name:
            request_params["text"] = {
                "format": {
                    "type": "json_schema",
                    "name": output_schema_name,
                    "strict": True,
                    "schema": output_schema,
                }
            }

        # Add any additional kwargs
        request_params.update(kwargs)
        try:
            # Use the SDK's Responses API with streaming
            stream = await self.client.responses.create(**request_params)

            # The Responses API returns an async iterator of events
            # Events include: response.created, response.in_progress, response.output_item.added,
            # response.content_part.added, response.output_text.delta, response.output_text.done,
            # response.content_part.done, response.output_item.done, response.completed
            accumulated_text = ""
            tool_calls = []
            usage = {
                "input_tokens": 0,
                "output_tokens": 0,
                "total_tokens": 0,
            }

            async for event in stream:
                event = event.model_dump(exclude_none=True, mode="json")

                event_type = event.get("type")

                if event_type == "response.output_text.delta":
                    # Text delta - incremental text chunk
                    delta_text = event.get("delta", "")
                    if delta_text:
                        accumulated_text += delta_text
                        yield {
                            "type": "text_delta",
                            "data": {
                                "content": delta_text,
                            },
                        }

                elif event_type == "response.output_item.added":
                    # Output item added - could be message, tool call, etc.
                    # For tool calls, we wait until output_item.done to get complete arguments
                    # Just acknowledge, don't yield anything yet
                    pass

                elif event_type == "response.output_item.done":
                    # Output item done - extract final content if it's a message or tool call
                    item = event.get("item")
                    item_type = item.get("type")

                    if item_type == "function_call":
                        # Function tool call - extract complete tool call with arguments
                        tool_call_id = item.get("id", "")
                        tool_call_name = item.get("name", "")
                        tool_call_arguments = item.get("arguments", "")
                        tool_call_call_id = item.get("call_id", "")

                        if tool_call_name and tool_call_arguments:
                            tool_call_data = {
                                "id": tool_call_id,
                                "call_id": tool_call_call_id,
                                "type": "function",
                                "function": {
                                    "name": tool_call_name,
                                    "arguments": tool_call_arguments,
                                },
                            }
                            tool_calls.append(tool_call_data)
                            yield {
                                "type": "tool_call",
                                "data": {
                                    "tool_call": tool_call_data,
                                },
                            }

                elif event_type == "response.content_part.added":
                    # Content part added - could be text
                    # For tool calls, we wait until output_item.done to get complete arguments
                    part = event.get("part")
                    if part:
                        part_type = part.get("type")
                        if part_type == "output_text":
                            # Text content part
                            text = part.get("text", "")
                            if text:
                                accumulated_text += text
                                yield {
                                    "type": "text_delta",
                                    "data": {
                                        "content": text,
                                    },
                                }

                elif event_type == "response.completed":
                    # Stream complete - final event with usage
                    raw_output = None
                    response_dict = None
                    response_model = model  # Default to input model
                    response_stop_reason = None
                    response = event.get("response")
                    if response:
                        response_dict = (
                            response.model_dump(exclude_none=True, mode="json")
                            if hasattr(response, "model_dump")
                            else response
                        )
                    if response_dict:
                        raw_output = response_dict.get("output")
                        processed_messages.extend(raw_output)
                        usage_data = response_dict.get("usage")

                        if usage_data:
                            usage["input_tokens"] = usage_data.get("input_tokens", 0)
                            usage["output_tokens"] = usage_data.get("output_tokens", 0)
                            usage["total_tokens"] = usage_data.get(
                                "input_tokens", 0
                            ) + usage_data.get("output_tokens", 0)

                        response_model = response_dict.get("model", model)
                        incomplete_details = response_dict.get("incomplete_details")
                        response_stop_reason = (
                            incomplete_details.get("reason", None) if incomplete_details else None
                        )

                    yield {
                        "type": "done",
                        "data": {
                            "usage": usage,
                            "raw_output": processed_messages,
                            "model": response_model,
                            "stop_reason": response_stop_reason,
                        },
                    }

                elif event_type == "response.failed":
                    # Stream failed
                    response = event.get("response")
                    error_msg = "Stream failed"
                    if response and isinstance(response, dict):
                        error_obj = response.get("error")
                        if error_obj and isinstance(error_obj, dict):
                            error_msg = error_obj.get("message", "Stream failed")

                    yield {
                        "type": "error",
                        "data": {
                            "error": error_msg,
                        },
                    }
                    break

                elif event_type == "error":
                    # Error event
                    error_msg = event.get("message", "Unknown error")
                    error_code = event.get("code", "")
                    if error_code:
                        error_msg = f"{error_code}: {error_msg}"

                    yield {
                        "type": "error",
                        "data": {
                            "error": error_msg,
                        },
                    }
                    break

        except Exception as e:
            # Yield error event and re-raise
            yield {
                "type": "error",
                "data": {
                    "error": str(e),
                },
            }
            raise RuntimeError(f"OpenAI Responses API streaming failed: {str(e)}") from e

    async def _stream_chat_completions(
        self,
        messages: list[dict[str, Any]],
        model: str,
        tools: list[dict[str, Any]] | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        top_p: float | None = None,
        agent_config: dict[str, Any] | None = None,
        tool_results: list[dict[str, Any]] | None = None,
        output_schema: dict[str, Any] | None = None,
        output_schema_name: str | None = None,
        **kwargs,
    ):
        """Stream using Chat Completions API."""
        # Prepare messages - copy to avoid mutating input
        processed_messages = messages.copy() if messages else []

        # Add system prompt from agent_config if present
        if agent_config and agent_config.get("system_prompt"):
            has_system = any(msg.get("role") == "system" for msg in processed_messages)
            if not has_system:
                processed_messages.insert(
                    0, {"role": "system", "content": agent_config.get("system_prompt")}
                )
            else:
                for msg in processed_messages:
                    if msg.get("role") == "system":
                        msg["content"] = (
                            msg.get("content", "") + "\n\n" + agent_config.get("system_prompt")
                        )
                        break

        # Add tool_results to messages
        if tool_results:
            for tool_result in tool_results:
                if tool_result.get("type") == "function_call_output":
                    tool_call_id = tool_result.get("call_id")
                    output = tool_result.get("output")
                    processed_messages.append(
                        {
                            "role": "tool",
                            "content": output if isinstance(output, str) else json.dumps(output),
                            "tool_call_id": tool_call_id,
                        }
                    )

        # Prepare request parameters for Chat Completions API
        request_params = {
            "model": model,
            "messages": processed_messages,
            "stream": True,  # Enable streaming
        }

        if temperature is not None:
            request_params["temperature"] = temperature
        if max_tokens is not None:
            request_params["max_tokens"] = max_tokens
        if top_p is not None:
            request_params["top_p"] = top_p

        if tools:
            validated_tools = _validate_tools_chat_completions(tools)
            if validated_tools:
                request_params["tools"] = validated_tools

        # Handle structured output
        if output_schema:
            if (
                self.supports_structured_output
                and output_schema_name
                and not request_params.get("tools")
            ):
                # Use response_format parameter if supported
                request_params["response_format"] = {
                    "type": "json_schema",
                    "json_schema": {
                        "name": output_schema_name,
                        "strict": True,
                        "schema": output_schema,
                    },
                }
            else:
                # Add structured output instructions to system prompt (like Anthropic)
                schema_json = json.dumps(output_schema, indent=2)
                structured_output_instruction = (
                    f"\n\nIMPORTANT: You must structure your text response as valid JSON "
                    f"that strictly conforms to this schema:\n\n{schema_json}\n\n"
                    f"Return ONLY valid JSON that matches this schema. Do not include any "
                    f"text, explanation, markdown code fences (```json or ```), or "
                    f"formatting outside of the JSON structure. Return only the raw JSON "
                    f"without any markdown formatting."
                )
                # Update system message
                has_system = any(msg.get("role") == "system" for msg in processed_messages)
                if has_system:
                    for msg in processed_messages:
                        if msg.get("role") == "system":
                            msg["content"] = msg.get("content", "") + structured_output_instruction
                            break
                else:
                    processed_messages.insert(
                        0, {"role": "system", "content": structured_output_instruction}
                    )

        # Add any additional kwargs
        request_params.update(kwargs)

        tool_calls = []
        usage = {
            "input_tokens": 0,
            "output_tokens": 0,
            "total_tokens": 0,
        }
        stop_reason = None
        response_model = model
        accumulated_text = ""
        accumulated_tool_calls = []

        try:
            # Stream from Chat Completions API
            stream = await self.client.chat.completions.create(**request_params)

            async for chunk in stream:
                # Process streaming chunks
                if chunk.choices and len(chunk.choices) > 0:
                    choice = chunk.choices[0]
                    delta = choice.delta

                    # Extract text delta
                    if delta.content:
                        accumulated_text += delta.content
                        yield {
                            "type": "text_delta",
                            "data": {
                                "content": delta.content,
                            },
                        }

                    # Extract tool calls
                    if delta.tool_calls:
                        for tool_call_delta in delta.tool_calls:
                            # Tool calls come in parts - need to accumulate
                            idx = tool_call_delta.index
                            delta_dict = tool_call_delta.model_dump(exclude_none=True, mode="json")

                            if idx is None or len(tool_calls) <= idx:
                                # New tool call
                                tool_calls.append(
                                    {
                                        "call_id": tool_call_delta.id or "",
                                        "id": "",
                                        "type": "function",
                                        "function": {
                                            "name": tool_call_delta.function.name or "",
                                            "arguments": tool_call_delta.function.arguments or "",
                                        },
                                    }
                                )

                                # To feed back to the next model request
                                accumulated_tool_calls.append(
                                    {
                                        "id": delta_dict.get("id"),
                                        "type": delta_dict.get("type"),
                                        "function": delta_dict.get("function"),
                                    }
                                )

                                # Merge any extra top-level fields
                                # (preserve new fields, don't overwrite existing)
                                known_keys = {"index", "id", "function", "type"}
                                for key, value in delta_dict.items():
                                    if (
                                        key not in known_keys
                                        and key not in accumulated_tool_calls[-1]
                                    ):
                                        accumulated_tool_calls[-1][key] = value

                            else:
                                # Append to existing tool call
                                existing = tool_calls[idx]
                                accumulated_existing = accumulated_tool_calls[idx]
                                if tool_call_delta.function.name:
                                    existing["function"]["name"] = tool_call_delta.function.name
                                    accumulated_existing["function"]["name"] = (
                                        tool_call_delta.function.name
                                    )
                                if tool_call_delta.function.arguments:
                                    existing["function"]["arguments"] += (
                                        tool_call_delta.function.arguments
                                    )
                                    accumulated_existing["function"]["arguments"] += (
                                        tool_call_delta.function.arguments
                                    )

                                # Merge any extra top-level fields
                                # (preserve new fields, don't overwrite existing)
                                known_keys = {"index", "id", "function", "type"}
                                for key, value in delta_dict.items():
                                    if key not in known_keys and key not in accumulated_existing:
                                        accumulated_existing[key] = value

                    # Update finish_reason if available
                    if choice.finish_reason:
                        stop_reason = choice.finish_reason

                # Update model if available
                if chunk.model:
                    response_model = chunk.model

                # Update usage if available
                if chunk.usage:
                    usage["input_tokens"] = chunk.usage.prompt_tokens or 0
                    usage["output_tokens"] = chunk.usage.completion_tokens or 0
                    usage["total_tokens"] = chunk.usage.total_tokens or 0

            # Yield tool calls if any
            for tool_call in tool_calls:
                if tool_call["function"]["name"]:  # Only yield complete tool calls
                    yield {
                        "type": "tool_call",
                        "data": {
                            "tool_call": tool_call,
                        },
                    }

            # Final done event
            usage["total_tokens"] = usage.get("input_tokens", 0) + usage.get("output_tokens", 0)

            processed_messages.append(
                {
                    "role": "assistant",
                    "content": accumulated_text,
                    "tool_calls": accumulated_tool_calls,
                }
            )
            yield {
                "type": "done",
                "data": {
                    "usage": usage,
                    "model": response_model,
                    "stop_reason": stop_reason,
                    "raw_output": processed_messages,
                },
            }

        except Exception as e:
            # Stream error
            yield {
                "type": "error",
                "data": {
                    "error": str(e),
                },
            }

def _validate_tools_responses(tools: list[dict[str, Any]] | None) -> list[dict[str, Any]]:
    """Validate tools for Responses API format."""
    validated_tools = []
    for tool in tools:
        # Default to type "function" if not specified
        if "type" not in tool:
            tool["type"] = "function"

        validated_tools.append(tool)
    return validated_tools


def _validate_tools_chat_completions(tools: list[dict[str, Any]] | None) -> list[dict[str, Any]]:
    """
    Validate and normalize tools to OpenAI Chat Completions format.

    OpenAI Chat Completions expects:
    [{"type": "function", "function": {"name": "...", "description": "...", "parameters": {...}}}]
    """
    validated = []
    for tool in tools:
        if not isinstance(tool, dict):
            continue

        # Handle OpenAI format: {"type": "function", "function": {...}}
        if tool.get("type", "function") == "function":
            if tool.get("name"):
                validated.append(
                    {
                        "type": "function",
                        "function": {
                            "name": tool.get("name"),
                            "description": tool.get("description", ""),
                            "parameters": tool.get("parameters", {}),
                        },
                    }
                )
            else:
                validated.append(tool)

    return validated
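For completeness, a minimal sketch of consuming the streaming interface defined above. It relies only on the event shapes ("text_delta", "tool_call", "done", "error") yielded by _stream_responses and _stream_chat_completions; the provider setup and model name are illustrative assumptions.

    async def print_stream(provider: OpenAIProvider) -> None:
        async for event in provider.stream(
            messages=[{"role": "user", "content": "Write one sentence about queues."}],
            model="gpt-4o-mini",  # illustrative model name
        ):
            if event["type"] == "text_delta":
                print(event["data"]["content"], end="", flush=True)
            elif event["type"] == "tool_call":
                print("\ntool requested:", event["data"]["tool_call"]["function"]["name"])
            elif event["type"] == "done":
                print("\nusage:", event["data"]["usage"])
            elif event["type"] == "error":
                raise RuntimeError(event["data"]["error"])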