pydantic-ai-slim 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_pydantic.py +13 -29
- pydantic_ai/_result.py +52 -38
- pydantic_ai/_system_prompt.py +1 -1
- pydantic_ai/_utils.py +20 -8
- pydantic_ai/agent.py +431 -167
- pydantic_ai/messages.py +90 -48
- pydantic_ai/models/__init__.py +59 -42
- pydantic_ai/models/anthropic.py +344 -0
- pydantic_ai/models/function.py +66 -44
- pydantic_ai/models/gemini.py +160 -117
- pydantic_ai/models/groq.py +125 -108
- pydantic_ai/models/mistral.py +680 -0
- pydantic_ai/models/ollama.py +116 -0
- pydantic_ai/models/openai.py +145 -114
- pydantic_ai/models/test.py +109 -77
- pydantic_ai/models/vertexai.py +14 -9
- pydantic_ai/result.py +35 -37
- pydantic_ai/settings.py +72 -0
- pydantic_ai/tools.py +140 -45
- {pydantic_ai_slim-0.0.11.dist-info → pydantic_ai_slim-0.0.13.dist-info}/METADATA +8 -3
- pydantic_ai_slim-0.0.13.dist-info/RECORD +26 -0
- {pydantic_ai_slim-0.0.11.dist-info → pydantic_ai_slim-0.0.13.dist-info}/WHEEL +1 -1
- pydantic_ai_slim-0.0.11.dist-info/RECORD +0 -22
pydantic_ai/models/ollama.py
ADDED
@@ -0,0 +1,116 @@
+from __future__ import annotations as _annotations
+
+from dataclasses import dataclass
+from typing import Literal, Union
+
+from httpx import AsyncClient as AsyncHTTPClient
+
+from ..tools import ToolDefinition
+from . import (
+    AgentModel,
+    Model,
+    cached_async_http_client,
+)
+
+try:
+    from openai import AsyncOpenAI
+except ImportError as e:
+    raise ImportError(
+        'Please install `openai` to use the OpenAI model, '
+        "you can use the `openai` optional group — `pip install 'pydantic-ai-slim[openai]'`"
+    ) from e
+
+
+from .openai import OpenAIModel
+
+CommonOllamaModelNames = Literal[
+    'codellama',
+    'gemma',
+    'gemma2',
+    'llama3',
+    'llama3.1',
+    'llama3.2',
+    'llama3.2-vision',
+    'llama3.3',
+    'mistral',
+    'mistral-nemo',
+    'mixtral',
+    'phi3',
+    'qwq',
+    'qwen',
+    'qwen2',
+    'qwen2.5',
+    'starcoder2',
+]
+"""This contains just the most common ollama models.
+
+For a full list see [ollama.com/library](https://ollama.com/library).
+"""
+OllamaModelName = Union[CommonOllamaModelNames, str]
+"""Possible ollama models.
+
+Since Ollama supports hundreds of models, we explicitly list the most models but
+allow any name in the type hints.
+"""
+
+
+@dataclass(init=False)
+class OllamaModel(Model):
+    """A model that implements Ollama using the OpenAI API.
+
+    Internally, this uses the [OpenAI Python client](https://github.com/openai/openai-python) to interact with the Ollama server.
+
+    Apart from `__init__`, all methods are private or match those of the base class.
+    """
+
+    model_name: OllamaModelName
+    openai_model: OpenAIModel
+
+    def __init__(
+        self,
+        model_name: OllamaModelName,
+        *,
+        base_url: str | None = 'http://localhost:11434/v1/',
+        openai_client: AsyncOpenAI | None = None,
+        http_client: AsyncHTTPClient | None = None,
+    ):
+        """Initialize an Ollama model.
+
+        Ollama has built-in compatability for the OpenAI chat completions API ([source](https://ollama.com/blog/openai-compatibility)), so we reuse the
+        [`OpenAIModel`][pydantic_ai.models.openai.OpenAIModel] here.
+
+        Args:
+            model_name: The name of the Ollama model to use. List of models available [here](https://ollama.com/library)
+                You must first download the model (`ollama pull <MODEL-NAME>`) in order to use the model
+            base_url: The base url for the ollama requests. The default value is the ollama default
+            openai_client: An existing
+                [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
+                client to use, if provided, `base_url` and `http_client` must be `None`.
+            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
+        """
+        self.model_name = model_name
+        if openai_client is not None:
+            assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
+            assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
+            self.openai_model = OpenAIModel(model_name=model_name, openai_client=openai_client)
+        else:
+            # API key is not required for ollama but a value is required to create the client
+            http_client_ = http_client or cached_async_http_client()
+            oai_client = AsyncOpenAI(base_url=base_url, api_key='ollama', http_client=http_client_)
+            self.openai_model = OpenAIModel(model_name=model_name, openai_client=oai_client)
+
+    async def agent_model(
+        self,
+        *,
+        function_tools: list[ToolDefinition],
+        allow_text_result: bool,
+        result_tools: list[ToolDefinition],
+    ) -> AgentModel:
+        return await self.openai_model.agent_model(
+            function_tools=function_tools,
+            allow_text_result=allow_text_result,
+            result_tools=result_tools,
+        )
+
+    def name(self) -> str:
+        return f'ollama:{self.model_name}'
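The new module is purely additive: it delegates everything to `OpenAIModel` against Ollama's OpenAI-compatible endpoint. A minimal usage sketch, assuming a local `ollama serve` with the model already pulled and the `Agent` API of this release:

from pydantic_ai import Agent
from pydantic_ai.models.ollama import OllamaModel

# Assumes the Ollama server is running locally and the model was
# downloaded first with: ollama pull llama3.2
model = OllamaModel('llama3.2')  # base_url defaults to http://localhost:11434/v1/
agent = Agent(model)

result = agent.run_sync('What is the capital of France?')
print(result.data)

Because the class wraps `OpenAIModel` rather than subclassing it, a preconfigured `AsyncOpenAI` client pointed at a remote Ollama server can also be passed via `openai_client` (with `base_url` and `http_client` left as `None`).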
pydantic_ai/models/openai.py
CHANGED
@@ -1,28 +1,34 @@
 from __future__ import annotations as _annotations
 
-from collections.abc import AsyncIterator, Iterable
+from collections.abc import AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
-from …
+from itertools import chain
+from typing import Literal, Union, overload
 
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
 from .. import UnexpectedModelBehavior, _utils, result
+from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
     ArgsJson,
-…
-…
-…
-…
-…
-…
-…
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
 )
 from ..result import Cost
+from ..settings import ModelSettings
+from ..tools import ToolDefinition
 from . import (
-    AbstractToolDefinition,
     AgentModel,
     EitherStreamedResponse,
     Model,
@@ -37,11 +43,17 @@ try:
     from openai.types import ChatModel, chat
     from openai.types.chat import ChatCompletionChunk
     from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
-except ImportError as …
+except ImportError as _import_error:
     raise ImportError(
         'Please install `openai` to use the OpenAI model, '
-        "you can use the `openai` optional group — `pip install 'pydantic-ai[openai]'`"
-    ) from …
+        "you can use the `openai` optional group — `pip install 'pydantic-ai-slim[openai]'`"
+    ) from _import_error
+
+OpenAIModelName = Union[ChatModel, str]
+"""
+Using this more broad type for the model name instead of the ChatModel definition
+allows this model to be used more easily with other model types (ie, Ollama)
+"""
 
 
 @dataclass(init=False)
@@ -53,13 +65,14 @@ class OpenAIModel(Model):
     Apart from `__init__`, all methods are private or match those of the base class.
     """
 
-    model_name: …
+    model_name: OpenAIModelName
     client: AsyncOpenAI = field(repr=False)
 
     def __init__(
         self,
-        model_name: …
+        model_name: OpenAIModelName,
         *,
+        base_url: str | None = None,
         api_key: str | None = None,
         openai_client: AsyncOpenAI | None = None,
         http_client: AsyncHTTPClient | None = None,
@@ -70,32 +83,36 @@ class OpenAIModel(Model):
             model_name: The name of the OpenAI model to use. List of model names available
                 [here](https://github.com/openai/openai-python/blob/v1.54.3/src/openai/types/chat_model.py#L7)
                 (Unfortunately, despite being ask to do so, OpenAI do not provide `.inv` files for their API).
+            base_url: The base url for the OpenAI requests. If not provided, the `OPENAI_BASE_URL` environment variable
+                will be used if available. Otherwise, defaults to OpenAI's base url.
             api_key: The API key to use for authentication, if not provided, the `OPENAI_API_KEY` environment variable
                 will be used if available.
             openai_client: An existing
                 [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
-                client to use …
+                client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`.
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
         """
-        self.model_name: …
+        self.model_name: OpenAIModelName = model_name
         if openai_client is not None:
             assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
+            assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
             assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
             self.client = openai_client
         elif http_client is not None:
-            self.client = AsyncOpenAI(api_key=api_key, http_client=http_client)
+            self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
         else:
-            self.client = AsyncOpenAI(api_key=api_key, http_client=cached_async_http_client())
+            self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
 
     async def agent_model(
         self,
-…
+        *,
+        function_tools: list[ToolDefinition],
         allow_text_result: bool,
-        result_tools: …
+        result_tools: list[ToolDefinition],
     ) -> AgentModel:
         check_allow_model_requests()
-        tools = [self._map_tool_definition(r) for r in function_tools …
-        if result_tools …
+        tools = [self._map_tool_definition(r) for r in function_tools]
+        if result_tools:
             tools += [self._map_tool_definition(r) for r in result_tools]
         return OpenAIAgentModel(
             self.client,
@@ -108,13 +125,13 @@ class OpenAIModel(Model):
         return f'openai:{self.model_name}'
 
     @staticmethod
-    def _map_tool_definition(f: …
+    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
         return {
             'type': 'function',
             'function': {
                 'name': f.name,
                 'description': f.description,
-                'parameters': f.…
+                'parameters': f.parameters_json_schema,
             },
         }
 
@@ -124,32 +141,38 @@ class OpenAIAgentModel(AgentModel):
     """Implementation of `AgentModel` for OpenAI models."""
 
     client: AsyncOpenAI
-    model_name: …
+    model_name: OpenAIModelName
     allow_text_result: bool
     tools: list[chat.ChatCompletionToolParam]
 
-    async def request(
-…
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Cost]:
+        response = await self._completions_create(messages, False, model_settings)
         return self._process_response(response), _map_cost(response)
 
     @asynccontextmanager
-    async def request_stream(
-…
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
+        response = await self._completions_create(messages, True, model_settings)
         async with response:
             yield await self._process_streamed_response(response)
 
     @overload
     async def _completions_create(
-        self, messages: list[…
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
     ) -> AsyncStream[ChatCompletionChunk]:
         pass
 
     @overload
-    async def _completions_create(
+    async def _completions_create(
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+    ) -> chat.ChatCompletion:
         pass
 
     async def _completions_create(
-        self, messages: list[…
+        self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
    ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
         # standalone function to make it easier to override
         if not self.tools:
@@ -159,7 +182,10 @@ class OpenAIAgentModel(AgentModel):
         else:
             tool_choice = 'auto'
 
-        openai_messages = …
+        openai_messages = list(chain(*(self._map_message(m) for m in messages)))
+
+        model_settings = model_settings or {}
+
         return await self.client.chat.completions.create(
             model=self.model_name,
             messages=openai_messages,
@@ -169,93 +195,104 @@ class OpenAIAgentModel(AgentModel):
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
             stream_options={'include_usage': True} if stream else NOT_GIVEN,
+            max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
+            temperature=model_settings.get('temperature', NOT_GIVEN),
+            top_p=model_settings.get('top_p', NOT_GIVEN),
+            timeout=model_settings.get('timeout', NOT_GIVEN),
         )
 
     @staticmethod
-    def _process_response(response: chat.ChatCompletion) -> …
+    def _process_response(response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         choice = response.choices[0]
+        items: list[ModelResponsePart] = []
+        if choice.message.content is not None:
+            items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
-…
-…
-…
-            )
-        else:
-            assert choice.message.content is not None, choice
-            return ModelTextResponse(choice.message.content, timestamp=timestamp)
+            for c in choice.message.tool_calls:
+                items.append(ToolCallPart.from_json(c.function.name, c.function.arguments, c.id))
+        return ModelResponse(items, timestamp=timestamp)
 
     @staticmethod
     async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
-…
-…
-…
-…
-        timestamp = datetime.fromtimestamp(first_chunk.created, tz=timezone.utc)
-        delta = first_chunk.choices[0].delta
-        start_cost = _map_cost(first_chunk)
-
-        # the first chunk may only contain `role`, so we iterate until we get either `tool_calls` or `content`
-        while delta.tool_calls is None and delta.content is None:
+        timestamp: datetime | None = None
+        start_cost = Cost()
+        # the first chunk may contain enough information so we iterate until we get either `tool_calls` or `content`
+        while True:
             try:
-…
+                chunk = await response.__anext__()
             except StopAsyncIteration as e:
                 raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
-            delta = next_chunk.choices[0].delta
-            start_cost += _map_cost(next_chunk)
 
-…
-…
+            timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
+            start_cost += _map_cost(chunk)
+
+            if chunk.choices:
+                delta = chunk.choices[0].delta
+
+                if delta.content is not None:
+                    return OpenAIStreamTextResponse(delta.content, response, timestamp, start_cost)
+                elif delta.tool_calls is not None:
+                    return OpenAIStreamStructuredResponse(
+                        response,
+                        {c.index: c for c in delta.tool_calls},
+                        timestamp,
+                        start_cost,
+                    )
+            # else continue until we get either delta.content or delta.tool_calls
+
+    @classmethod
+    def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
+        """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
+        if isinstance(message, ModelRequest):
+            yield from cls._map_user_message(message)
+        elif isinstance(message, ModelResponse):
+            texts: list[str] = []
+            tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
+            for item in message.parts:
+                if isinstance(item, TextPart):
+                    texts.append(item.content)
+                elif isinstance(item, ToolCallPart):
+                    tool_calls.append(_map_tool_call(item))
+                else:
+                    assert_never(item)
+            message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
+            if texts:
+                # Note: model responses from this model should only have one text item, so the following
+                # shouldn't merge multiple texts into one unless you switch models between runs:
+                message_param['content'] = '\n\n'.join(texts)
+            if tool_calls:
+                message_param['tool_calls'] = tool_calls
+            yield message_param
         else:
-…
-            return OpenAIStreamStructuredResponse(
-                response,
-                {c.index: c for c in delta.tool_calls},
-                timestamp,
-                start_cost,
-            )
+            assert_never(message)
 
-    @…
-    def …
-…
-…
-…
-…
-…
-…
-…
-        elif message.role == 'tool-return':
-            # ToolReturn ->
-            return chat.ChatCompletionToolMessageParam(
-                role='tool',
-                tool_call_id=_guard_tool_id(message),
-                content=message.model_response_str(),
-            )
-        elif message.role == 'retry-prompt':
-            # RetryPrompt ->
-            if message.tool_name is None:
-                return chat.ChatCompletionUserMessageParam(role='user', content=message.model_response())
-            else:
-                return chat.ChatCompletionToolMessageParam(
+    @classmethod
+    def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
+        for part in message.parts:
+            if isinstance(part, SystemPromptPart):
+                yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
+            elif isinstance(part, UserPromptPart):
+                yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
+            elif isinstance(part, ToolReturnPart):
+                yield chat.ChatCompletionToolMessageParam(
                     role='tool',
-                    tool_call_id=…
-                    content=…
+                    tool_call_id=_guard_tool_call_id(t=part, model_source='OpenAI'),
+                    content=part.model_response_str(),
                 )
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-            )
-        else:
-            assert_never(message)
+            elif isinstance(part, RetryPromptPart):
+                if part.tool_name is None:
+                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
+                else:
+                    yield chat.ChatCompletionToolMessageParam(
+                        role='tool',
+                        tool_call_id=_guard_tool_call_id(t=part, model_source='OpenAI'),
+                        content=part.model_response(),
+                    )
+            else:
+                assert_never(part)
 
 
 @dataclass
@@ -330,14 +367,14 @@ class OpenAIStreamStructuredResponse(StreamStructuredResponse):
         else:
             self._delta_tool_calls[new.index] = new
 
-    def get(self, *, final: bool = False) -> …
-…
+    def get(self, *, final: bool = False) -> ModelResponse:
+        items: list[ModelResponsePart] = []
         for c in self._delta_tool_calls.values():
             if f := c.function:
                 if f.name is not None and f.arguments is not None:
-…
+                    items.append(ToolCallPart.from_json(f.name, f.arguments, c.id))
 
-        return …
+        return ModelResponse(items, timestamp=self._timestamp)
 
     def cost(self) -> Cost:
         return self._cost
@@ -346,16 +383,10 @@ class OpenAIStreamStructuredResponse(StreamStructuredResponse):
         return self._timestamp
 
 
-def …
-    """Type guard that checks a `tool_id` is not None both for static typing and runtime."""
-    assert t.tool_id is not None, f'OpenAI requires `tool_id` to be set: {t}'
-    return t.tool_id
-
-
-def _map_tool_call(t: ToolCall) -> chat.ChatCompletionMessageToolCallParam:
+def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
     assert isinstance(t.args, ArgsJson), f'Expected ArgsJson, got {t.args}'
     return chat.ChatCompletionMessageToolCallParam(
-        id=…
+        id=_guard_tool_call_id(t=t, model_source='OpenAI'),
     type='function',
         function={'name': t.tool_name, 'arguments': t.args.args_json},
     )