agentrun_sdk-0.1.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of agentrun-sdk has been flagged as possibly problematic.
- agentrun_operation_sdk/cli/__init__.py +1 -0
- agentrun_operation_sdk/cli/cli.py +19 -0
- agentrun_operation_sdk/cli/common.py +21 -0
- agentrun_operation_sdk/cli/runtime/__init__.py +1 -0
- agentrun_operation_sdk/cli/runtime/commands.py +203 -0
- agentrun_operation_sdk/client/client.py +75 -0
- agentrun_operation_sdk/operations/runtime/__init__.py +8 -0
- agentrun_operation_sdk/operations/runtime/configure.py +101 -0
- agentrun_operation_sdk/operations/runtime/launch.py +82 -0
- agentrun_operation_sdk/operations/runtime/models.py +31 -0
- agentrun_operation_sdk/services/runtime.py +152 -0
- agentrun_operation_sdk/utils/logging_config.py +72 -0
- agentrun_operation_sdk/utils/runtime/config.py +94 -0
- agentrun_operation_sdk/utils/runtime/container.py +280 -0
- agentrun_operation_sdk/utils/runtime/entrypoint.py +203 -0
- agentrun_operation_sdk/utils/runtime/schema.py +56 -0
- agentrun_sdk/__init__.py +7 -0
- agentrun_sdk/agent/__init__.py +25 -0
- agentrun_sdk/agent/agent.py +696 -0
- agentrun_sdk/agent/agent_result.py +46 -0
- agentrun_sdk/agent/conversation_manager/__init__.py +26 -0
- agentrun_sdk/agent/conversation_manager/conversation_manager.py +88 -0
- agentrun_sdk/agent/conversation_manager/null_conversation_manager.py +46 -0
- agentrun_sdk/agent/conversation_manager/sliding_window_conversation_manager.py +179 -0
- agentrun_sdk/agent/conversation_manager/summarizing_conversation_manager.py +252 -0
- agentrun_sdk/agent/state.py +97 -0
- agentrun_sdk/event_loop/__init__.py +9 -0
- agentrun_sdk/event_loop/event_loop.py +499 -0
- agentrun_sdk/event_loop/streaming.py +319 -0
- agentrun_sdk/experimental/__init__.py +4 -0
- agentrun_sdk/experimental/hooks/__init__.py +15 -0
- agentrun_sdk/experimental/hooks/events.py +123 -0
- agentrun_sdk/handlers/__init__.py +10 -0
- agentrun_sdk/handlers/callback_handler.py +70 -0
- agentrun_sdk/hooks/__init__.py +49 -0
- agentrun_sdk/hooks/events.py +80 -0
- agentrun_sdk/hooks/registry.py +247 -0
- agentrun_sdk/models/__init__.py +10 -0
- agentrun_sdk/models/anthropic.py +432 -0
- agentrun_sdk/models/bedrock.py +649 -0
- agentrun_sdk/models/litellm.py +225 -0
- agentrun_sdk/models/llamaapi.py +438 -0
- agentrun_sdk/models/mistral.py +539 -0
- agentrun_sdk/models/model.py +95 -0
- agentrun_sdk/models/ollama.py +357 -0
- agentrun_sdk/models/openai.py +436 -0
- agentrun_sdk/models/sagemaker.py +598 -0
- agentrun_sdk/models/writer.py +449 -0
- agentrun_sdk/multiagent/__init__.py +22 -0
- agentrun_sdk/multiagent/a2a/__init__.py +15 -0
- agentrun_sdk/multiagent/a2a/executor.py +148 -0
- agentrun_sdk/multiagent/a2a/server.py +252 -0
- agentrun_sdk/multiagent/base.py +92 -0
- agentrun_sdk/multiagent/graph.py +555 -0
- agentrun_sdk/multiagent/swarm.py +656 -0
- agentrun_sdk/py.typed +1 -0
- agentrun_sdk/session/__init__.py +18 -0
- agentrun_sdk/session/file_session_manager.py +216 -0
- agentrun_sdk/session/repository_session_manager.py +152 -0
- agentrun_sdk/session/s3_session_manager.py +272 -0
- agentrun_sdk/session/session_manager.py +73 -0
- agentrun_sdk/session/session_repository.py +51 -0
- agentrun_sdk/telemetry/__init__.py +21 -0
- agentrun_sdk/telemetry/config.py +194 -0
- agentrun_sdk/telemetry/metrics.py +476 -0
- agentrun_sdk/telemetry/metrics_constants.py +15 -0
- agentrun_sdk/telemetry/tracer.py +563 -0
- agentrun_sdk/tools/__init__.py +17 -0
- agentrun_sdk/tools/decorator.py +569 -0
- agentrun_sdk/tools/executor.py +137 -0
- agentrun_sdk/tools/loader.py +152 -0
- agentrun_sdk/tools/mcp/__init__.py +13 -0
- agentrun_sdk/tools/mcp/mcp_agent_tool.py +99 -0
- agentrun_sdk/tools/mcp/mcp_client.py +423 -0
- agentrun_sdk/tools/mcp/mcp_instrumentation.py +322 -0
- agentrun_sdk/tools/mcp/mcp_types.py +63 -0
- agentrun_sdk/tools/registry.py +607 -0
- agentrun_sdk/tools/structured_output.py +421 -0
- agentrun_sdk/tools/tools.py +217 -0
- agentrun_sdk/tools/watcher.py +136 -0
- agentrun_sdk/types/__init__.py +5 -0
- agentrun_sdk/types/collections.py +23 -0
- agentrun_sdk/types/content.py +188 -0
- agentrun_sdk/types/event_loop.py +48 -0
- agentrun_sdk/types/exceptions.py +81 -0
- agentrun_sdk/types/guardrails.py +254 -0
- agentrun_sdk/types/media.py +89 -0
- agentrun_sdk/types/session.py +152 -0
- agentrun_sdk/types/streaming.py +201 -0
- agentrun_sdk/types/tools.py +258 -0
- agentrun_sdk/types/traces.py +5 -0
- agentrun_sdk-0.1.2.dist-info/METADATA +51 -0
- agentrun_sdk-0.1.2.dist-info/RECORD +115 -0
- agentrun_sdk-0.1.2.dist-info/WHEEL +5 -0
- agentrun_sdk-0.1.2.dist-info/entry_points.txt +2 -0
- agentrun_sdk-0.1.2.dist-info/top_level.txt +3 -0
- agentrun_wrapper/__init__.py +11 -0
- agentrun_wrapper/_utils/__init__.py +6 -0
- agentrun_wrapper/_utils/endpoints.py +16 -0
- agentrun_wrapper/identity/__init__.py +5 -0
- agentrun_wrapper/identity/auth.py +211 -0
- agentrun_wrapper/memory/__init__.py +6 -0
- agentrun_wrapper/memory/client.py +1697 -0
- agentrun_wrapper/memory/constants.py +103 -0
- agentrun_wrapper/memory/controlplane.py +626 -0
- agentrun_wrapper/py.typed +1 -0
- agentrun_wrapper/runtime/__init__.py +13 -0
- agentrun_wrapper/runtime/app.py +473 -0
- agentrun_wrapper/runtime/context.py +34 -0
- agentrun_wrapper/runtime/models.py +25 -0
- agentrun_wrapper/services/__init__.py +1 -0
- agentrun_wrapper/services/identity.py +192 -0
- agentrun_wrapper/tools/__init__.py +6 -0
- agentrun_wrapper/tools/browser_client.py +325 -0
- agentrun_wrapper/tools/code_interpreter_client.py +186 -0
agentrun_sdk/models/openai.py
@@ -0,0 +1,436 @@
"""OpenAI model provider.

- Docs: https://platform.openai.com/docs/overview
"""

import base64
import json
import logging
import mimetypes
from typing import Any, AsyncGenerator, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast

import openai
from openai.types.chat.parsed_chat_completion import ParsedChatCompletion
from pydantic import BaseModel
from typing_extensions import Unpack, override

from ..types.content import ContentBlock, Messages
from ..types.streaming import StreamEvent
from ..types.tools import ToolResult, ToolSpec, ToolUse
from .model import Model

logger = logging.getLogger(__name__)

T = TypeVar("T", bound=BaseModel)


class Client(Protocol):
    """Protocol defining the OpenAI-compatible interface for the underlying provider client."""

    @property
    # pragma: no cover
    def chat(self) -> Any:
        """Chat completions interface."""
        ...


class OpenAIModel(Model):
    """OpenAI model provider implementation."""

    client: Client

    class OpenAIConfig(TypedDict, total=False):
        """Configuration options for OpenAI models.

        Attributes:
            model_id: Model ID (e.g., "gpt-4o").
                For a complete list of supported models, see https://platform.openai.com/docs/models.
            params: Model parameters (e.g., max_tokens).
                For a complete list of supported parameters, see
                https://platform.openai.com/docs/api-reference/chat/create.
        """

        model_id: str
        params: Optional[dict[str, Any]]

    def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[OpenAIConfig]) -> None:
        """Initialize provider instance.

        Args:
            client_args: Arguments for the OpenAI client.
                For a complete list of supported arguments, see https://pypi.org/project/openai/.
            **model_config: Configuration options for the OpenAI model.
        """
        self.config = dict(model_config)

        logger.debug("config=<%s> | initializing", self.config)

        client_args = client_args or {}
        self.client = openai.AsyncOpenAI(**client_args)

    @override
    def update_config(self, **model_config: Unpack[OpenAIConfig]) -> None:  # type: ignore[override]
        """Update the OpenAI model configuration with the provided arguments.

        Args:
            **model_config: Configuration overrides.
        """
        self.config.update(model_config)

    @override
    def get_config(self) -> OpenAIConfig:
        """Get the OpenAI model configuration.

        Returns:
            The OpenAI model configuration.
        """
        return cast(OpenAIModel.OpenAIConfig, self.config)

    @classmethod
    def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]:
        """Format an OpenAI compatible content block.

        Args:
            content: Message content.

        Returns:
            OpenAI compatible content block.

        Raises:
            TypeError: If the content block type cannot be converted to an OpenAI-compatible format.
        """
        if "document" in content:
            mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream")
            file_data = base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8")
            return {
                "file": {
                    "file_data": f"data:{mime_type};base64,{file_data}",
                    "filename": content["document"]["name"],
                },
                "type": "file",
            }

        if "image" in content:
            mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream")
            image_data = base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8")

            return {
                "image_url": {
                    "detail": "auto",
                    "format": mime_type,
                    "url": f"data:{mime_type};base64,{image_data}",
                },
                "type": "image_url",
            }

        if "text" in content:
            return {"text": content["text"], "type": "text"}

        raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")

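    # Illustrative note (not part of the released file): the method above maps a
    # text block straight through, e.g.
    #   OpenAIModel.format_request_message_content({"text": "hello"})
    #   -> {"text": "hello", "type": "text"}
    # while image and document blocks are inlined as base64 data URLs.
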
    @classmethod
    def format_request_message_tool_call(cls, tool_use: ToolUse) -> dict[str, Any]:
        """Format an OpenAI compatible tool call.

        Args:
            tool_use: Tool use requested by the model.

        Returns:
            OpenAI compatible tool call.
        """
        return {
            "function": {
                "arguments": json.dumps(tool_use["input"]),
                "name": tool_use["name"],
            },
            "id": tool_use["toolUseId"],
            "type": "function",
        }

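    # Illustrative note (not part of the released file): the method above turns a
    # ToolUse such as {"toolUseId": "t1", "name": "get_weather", "input": {"city": "Paris"}}
    # into the OpenAI tool-call shape
    #   {"id": "t1", "type": "function",
    #    "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}}
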
    @classmethod
    def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]:
        """Format an OpenAI compatible tool message.

        Args:
            tool_result: Tool result collected from a tool execution.

        Returns:
            OpenAI compatible tool message.
        """
        contents = cast(
            list[ContentBlock],
            [
                {"text": json.dumps(content["json"])} if "json" in content else content
                for content in tool_result["content"]
            ],
        )

        return {
            "role": "tool",
            "tool_call_id": tool_result["toolUseId"],
            "content": [cls.format_request_message_content(content) for content in contents],
        }

    @classmethod
    def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
        """Format an OpenAI compatible messages array.

        Args:
            messages: List of message objects to be processed by the model.
            system_prompt: System prompt to provide context to the model.

        Returns:
            An OpenAI compatible messages array.
        """
        formatted_messages: list[dict[str, Any]]
        formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else []

        for message in messages:
            contents = message["content"]

            formatted_contents = [
                cls.format_request_message_content(content)
                for content in contents
                if not any(block_type in content for block_type in ["toolResult", "toolUse"])
            ]
            formatted_tool_calls = [
                cls.format_request_message_tool_call(content["toolUse"]) for content in contents if "toolUse" in content
            ]
            formatted_tool_messages = [
                cls.format_request_tool_message(content["toolResult"])
                for content in contents
                if "toolResult" in content
            ]

            formatted_message = {
                "role": message["role"],
                "content": formatted_contents,
                **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}),
            }
            formatted_messages.append(formatted_message)
            formatted_messages.extend(formatted_tool_messages)

        return [message for message in formatted_messages if message["content"] or "tool_calls" in message]

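    # Illustrative note (not part of the released file): in the method above,
    # "toolUse" blocks become tool_calls on the assistant message, while
    # "toolResult" blocks are emitted afterwards as separate {"role": "tool", ...}
    # messages, matching the ordering OpenAI expects.
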
    def format_request(
        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
    ) -> dict[str, Any]:
        """Format an OpenAI compatible chat streaming request.

        Args:
            messages: List of message objects to be processed by the model.
            tool_specs: List of tool specifications to make available to the model.
            system_prompt: System prompt to provide context to the model.

        Returns:
            An OpenAI compatible chat streaming request.

        Raises:
            TypeError: If a message contains a content block type that cannot be converted to an OpenAI-compatible
                format.
        """
        return {
            "messages": self.format_request_messages(messages, system_prompt),
            "model": self.config["model_id"],
            "stream": True,
            "stream_options": {"include_usage": True},
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": tool_spec["name"],
                        "description": tool_spec["description"],
                        "parameters": tool_spec["inputSchema"]["json"],
                    },
                }
                for tool_spec in tool_specs or []
            ],
            **cast(dict[str, Any], self.config.get("params", {})),
        }

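    # Illustrative note (not part of the released file): the method above returns
    # the kwargs dict passed to chat.completions.create, roughly
    #   {"messages": [...], "model": "gpt-4o", "stream": True,
    #    "stream_options": {"include_usage": True}, "tools": [...], **params}
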
    def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
        """Format an OpenAI response event into a standardized message chunk.

        Args:
            event: A response event from the OpenAI compatible model.

        Returns:
            The formatted chunk.

        Raises:
            RuntimeError: If chunk_type is not recognized.
                This error should never be encountered as chunk_type is controlled in the stream method.
        """
        match event["chunk_type"]:
            case "message_start":
                return {"messageStart": {"role": "assistant"}}

            case "content_start":
                if event["data_type"] == "tool":
                    return {
                        "contentBlockStart": {
                            "start": {
                                "toolUse": {
                                    "name": event["data"].function.name,
                                    "toolUseId": event["data"].id,
                                }
                            }
                        }
                    }

                return {"contentBlockStart": {"start": {}}}

            case "content_delta":
                if event["data_type"] == "tool":
                    return {
                        "contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments or ""}}}
                    }

                if event["data_type"] == "reasoning_content":
                    return {"contentBlockDelta": {"delta": {"reasoningContent": {"text": event["data"]}}}}

                return {"contentBlockDelta": {"delta": {"text": event["data"]}}}

            case "content_stop":
                return {"contentBlockStop": {}}

            case "message_stop":
                match event["data"]:
                    case "tool_calls":
                        return {"messageStop": {"stopReason": "tool_use"}}
                    case "length":
                        return {"messageStop": {"stopReason": "max_tokens"}}
                    case _:
                        return {"messageStop": {"stopReason": "end_turn"}}

            case "metadata":
                return {
                    "metadata": {
                        "usage": {
                            "inputTokens": event["data"].prompt_tokens,
                            "outputTokens": event["data"].completion_tokens,
                            "totalTokens": event["data"].total_tokens,
                        },
                        "metrics": {
                            "latencyMs": 0,  # TODO
                        },
                    },
                }

            case _:
                raise RuntimeError(f"chunk_type=<{event['chunk_type']}> | unknown type")

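    # Illustrative note (not part of the released file): for example, the method above maps
    #   format_chunk({"chunk_type": "content_delta", "data_type": "text", "data": "hi"})
    #   -> {"contentBlockDelta": {"delta": {"text": "hi"}}}
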
    @override
    async def stream(
        self,
        messages: Messages,
        tool_specs: Optional[list[ToolSpec]] = None,
        system_prompt: Optional[str] = None,
        **kwargs: Any,
    ) -> AsyncGenerator[StreamEvent, None]:
        """Stream conversation with the OpenAI model.

        Args:
            messages: List of message objects to be processed by the model.
            tool_specs: List of tool specifications to make available to the model.
            system_prompt: System prompt to provide context to the model.
            **kwargs: Additional keyword arguments for future extensibility.

        Yields:
            Formatted message chunks from the model.
        """
        logger.debug("formatting request")
        request = self.format_request(messages, tool_specs, system_prompt)
        logger.debug("formatted request=<%s>", request)

        logger.debug("invoking model")
        response = await self.client.chat.completions.create(**request)

        logger.debug("got response from model")
        yield self.format_chunk({"chunk_type": "message_start"})
        yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})

        tool_calls: dict[int, list[Any]] = {}

        async for event in response:
            # Defensive: skip events with empty or missing choices
            if not getattr(event, "choices", None):
                continue
            choice = event.choices[0]

            if choice.delta.content:
                yield self.format_chunk(
                    {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
                )

            if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
                yield self.format_chunk(
                    {
                        "chunk_type": "content_delta",
                        "data_type": "reasoning_content",
                        "data": choice.delta.reasoning_content,
                    }
                )

            for tool_call in choice.delta.tool_calls or []:
                tool_calls.setdefault(tool_call.index, []).append(tool_call)

            if choice.finish_reason:
                break

        yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})

        for tool_deltas in tool_calls.values():
            yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})

            for tool_delta in tool_deltas:
                yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})

            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})

        yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})

        # Skip remaining events as we don't have use for anything except the final usage payload
        async for event in response:
            _ = event

        if event.usage:
            yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})

        logger.debug("finished streaming response from model")

    @override
    async def structured_output(
        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
    ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
        """Get structured output from the model.

        Args:
            output_model: The output model to use for the agent.
            prompt: The prompt messages to use for the agent.
            system_prompt: System prompt to provide context to the model.
            **kwargs: Additional keyword arguments for future extensibility.

        Yields:
            Model events with the last being the structured output.
        """
        response: ParsedChatCompletion = await self.client.beta.chat.completions.parse(  # type: ignore
            model=self.get_config()["model_id"],
            messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
            response_format=output_model,
        )

        parsed: T | None = None
        # Find the first choice with tool_calls
        if len(response.choices) > 1:
            raise ValueError("Multiple choices found in the OpenAI response.")

        for choice in response.choices:
            if isinstance(choice.message.parsed, output_model):
                parsed = choice.message.parsed
                break

        if parsed:
            yield {"output": parsed}
        else:
            raise ValueError("No valid tool use or tool use input was found in the OpenAI response.")
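
For orientation, here is a minimal usage sketch of the provider above. It is not taken from the release; the import path is assumed to mirror the module layout listed earlier, and the API key, model ID, and prompt are placeholders:

    import asyncio

    from agentrun_sdk.models.openai import OpenAIModel


    async def main() -> None:
        # Hypothetical values; client_args is forwarded to openai.AsyncOpenAI.
        model = OpenAIModel(
            client_args={"api_key": "sk-..."},
            model_id="gpt-4o",
            params={"max_tokens": 256},
        )

        messages = [{"role": "user", "content": [{"text": "Say hello."}]}]
        async for chunk in model.stream(messages):
            print(chunk)


    asyncio.run(main())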