pydantic-ai-slim 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/__init__.py +12 -2
- pydantic_ai/_pydantic.py +7 -25
- pydantic_ai/_result.py +33 -18
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/_utils.py +9 -2
- pydantic_ai/agent.py +366 -171
- pydantic_ai/exceptions.py +20 -2
- pydantic_ai/messages.py +111 -50
- pydantic_ai/models/__init__.py +39 -14
- pydantic_ai/models/anthropic.py +344 -0
- pydantic_ai/models/function.py +62 -40
- pydantic_ai/models/gemini.py +164 -124
- pydantic_ai/models/groq.py +112 -94
- pydantic_ai/models/mistral.py +668 -0
- pydantic_ai/models/ollama.py +1 -1
- pydantic_ai/models/openai.py +120 -96
- pydantic_ai/models/test.py +78 -61
- pydantic_ai/models/vertexai.py +7 -3
- pydantic_ai/result.py +96 -68
- pydantic_ai/settings.py +137 -0
- pydantic_ai/tools.py +46 -26
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.14.dist-info}/METADATA +8 -3
- pydantic_ai_slim-0.0.14.dist-info/RECORD +26 -0
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.14.dist-info}/WHEEL +1 -1
- pydantic_ai_slim-0.0.12.dist-info/RECORD +0 -23
pydantic_ai/models/openai.py
CHANGED
@@ -4,23 +4,28 @@ from collections.abc import AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
+from itertools import chain
 from typing import Literal, Union, overload

 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never

 from .. import UnexpectedModelBehavior, _utils, result
+from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
-    … (eight removed import names are collapsed in the diff view)
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
 )
-from ..result import …
+from ..result import Usage
+from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
@@ -40,7 +45,7 @@ try:
 except ImportError as _import_error:
     raise ImportError(
         'Please install `openai` to use the OpenAI model, '
-        "you can use the `openai` optional group — `pip install 'pydantic-ai[openai]'`"
+        "you can use the `openai` optional group — `pip install 'pydantic-ai-slim[openai]'`"
     ) from _import_error

 OpenAIModelName = Union[ChatModel, str]
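The `Usage` import above replaces the old cost type. Its use with `+=` later in this file implies field-wise summation of token counts; a minimal sketch under that assumption, using only the constructor keywords that appear in `_map_usage` at the bottom of this diff:

```python
from pydantic_ai.result import Usage

# Accumulate usage the way the streamed-response classes below do with
# `self._usage += _map_usage(chunk)`; fields left unset default to None.
total = Usage()
total += Usage(request_tokens=12, response_tokens=7, total_tokens=19)
total += Usage(response_tokens=3, total_tokens=3)
print(total)  # expected: request_tokens=12, response_tokens=10, total_tokens=22
```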
@@ -66,6 +71,7 @@ class OpenAIModel(Model):
         self,
         model_name: OpenAIModelName,
         *,
+        base_url: str | None = None,
         api_key: str | None = None,
         openai_client: AsyncOpenAI | None = None,
         http_client: AsyncHTTPClient | None = None,
@@ -76,22 +82,25 @@ class OpenAIModel(Model):
             model_name: The name of the OpenAI model to use. List of model names available
                 [here](https://github.com/openai/openai-python/blob/v1.54.3/src/openai/types/chat_model.py#L7)
                 (Unfortunately, despite being ask to do so, OpenAI do not provide `.inv` files for their API).
+            base_url: The base url for the OpenAI requests. If not provided, the `OPENAI_BASE_URL` environment variable
+                will be used if available. Otherwise, defaults to OpenAI's base url.
             api_key: The API key to use for authentication, if not provided, the `OPENAI_API_KEY` environment variable
                 will be used if available.
             openai_client: An existing
                 [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
-                client to use
+                client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`.
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
         """
         self.model_name: OpenAIModelName = model_name
         if openai_client is not None:
             assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
+            assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
             assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
             self.client = openai_client
         elif http_client is not None:
-            self.client = AsyncOpenAI(api_key=api_key, http_client=http_client)
+            self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
         else:
-            self.client = AsyncOpenAI(api_key=api_key, http_client=cached_async_http_client())
+            self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())

     async def agent_model(
         self,
@@ -135,28 +144,34 @@ class OpenAIAgentModel(AgentModel):
     allow_text_result: bool
     tools: list[chat.ChatCompletionToolParam]

-    async def request(
-        … (two removed lines are collapsed in the diff view)
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Usage]:
+        response = await self._completions_create(messages, False, model_settings)
+        return self._process_response(response), _map_usage(response)

     @asynccontextmanager
-    async def request_stream(
-        … (one removed line is collapsed in the diff view)
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
+        response = await self._completions_create(messages, True, model_settings)
         async with response:
             yield await self._process_streamed_response(response)

     @overload
     async def _completions_create(
-        self, messages: list[…
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
     ) -> AsyncStream[ChatCompletionChunk]:
         pass

     @overload
-    async def _completions_create(
+    async def _completions_create(
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+    ) -> chat.ChatCompletion:
         pass

     async def _completions_create(
-        self, messages: list[…
+        self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
         # standalone function to make it easier to override
         if not self.tools:
@@ -166,7 +181,10 @@ class OpenAIAgentModel(AgentModel):
         else:
             tool_choice = 'auto'

-        openai_messages = …
+        openai_messages = list(chain(*(self._map_message(m) for m in messages)))
+
+        model_settings = model_settings or {}
+
         return await self.client.chat.completions.create(
             model=self.model_name,
             messages=openai_messages,
@@ -176,27 +194,30 @@ class OpenAIAgentModel(AgentModel):
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
             stream_options={'include_usage': True} if stream else NOT_GIVEN,
+            max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
+            temperature=model_settings.get('temperature', NOT_GIVEN),
+            top_p=model_settings.get('top_p', NOT_GIVEN),
+            timeout=model_settings.get('timeout', NOT_GIVEN),
         )

     @staticmethod
-    def _process_response(response: chat.ChatCompletion) -> …
+    def _process_response(response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         choice = response.choices[0]
+        items: list[ModelResponsePart] = []
+        if choice.message.content is not None:
+            items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
-            … (three removed lines are collapsed in the diff view)
-            )
-        else:
-            assert choice.message.content is not None, choice
-            return ModelTextResponse(choice.message.content, timestamp=timestamp)
+            for c in choice.message.tool_calls:
+                items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id))
+        return ModelResponse(items, timestamp=timestamp)

     @staticmethod
     async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         timestamp: datetime | None = None
-        …
+        start_usage = Usage()
         # the first chunk may contain enough information so we iterate until we get either `tool_calls` or `content`
         while True:
             try:
@@ -205,63 +226,73 @@ class OpenAIAgentModel(AgentModel):
                 raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e

             timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
-            …
+            start_usage += _map_usage(chunk)

             if chunk.choices:
                 delta = chunk.choices[0].delta

                 if delta.content is not None:
-                    return OpenAIStreamTextResponse(delta.content, response, timestamp, …
+                    return OpenAIStreamTextResponse(delta.content, response, timestamp, start_usage)
                 elif delta.tool_calls is not None:
                     return OpenAIStreamStructuredResponse(
                         response,
                         {c.index: c for c in delta.tool_calls},
                         timestamp,
-                        …
+                        start_usage,
                     )
                 # else continue until we get either delta.content or delta.tool_calls

-    @…
-    def _map_message(message: …
+    @classmethod
+    def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
-        if message…
-        … (11 removed lines are collapsed in the diff view)
-            )
-        … (seven removed lines are collapsed in the diff view)
-                tool_call_id=_guard_tool_id(message),
-                content=message.model_response(),
-            )
-        elif message.role == 'model-text-response':
-            # ModelTextResponse ->
-            return chat.ChatCompletionAssistantMessageParam(role='assistant', content=message.content)
-        elif message.role == 'model-structured-response':
-            assert (
-                message.role == 'model-structured-response'
-            ), f'Expected role to be "llm-tool-calls", got {message.role}'
-            # ModelStructuredResponse ->
-            return chat.ChatCompletionAssistantMessageParam(
-                role='assistant',
-                tool_calls=[_map_tool_call(t) for t in message.calls],
-            )
+        if isinstance(message, ModelRequest):
+            yield from cls._map_user_message(message)
+        elif isinstance(message, ModelResponse):
+            texts: list[str] = []
+            tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
+            for item in message.parts:
+                if isinstance(item, TextPart):
+                    texts.append(item.content)
+                elif isinstance(item, ToolCallPart):
+                    tool_calls.append(_map_tool_call(item))
+                else:
+                    assert_never(item)
+            message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
+            if texts:
+                # Note: model responses from this model should only have one text item, so the following
+                # shouldn't merge multiple texts into one unless you switch models between runs:
+                message_param['content'] = '\n\n'.join(texts)
+            if tool_calls:
+                message_param['tool_calls'] = tool_calls
+            yield message_param
         else:
             assert_never(message)

+    @classmethod
+    def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
+        for part in message.parts:
+            if isinstance(part, SystemPromptPart):
+                yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
+            elif isinstance(part, UserPromptPart):
+                yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
+            elif isinstance(part, ToolReturnPart):
+                yield chat.ChatCompletionToolMessageParam(
+                    role='tool',
+                    tool_call_id=_guard_tool_call_id(t=part, model_source='OpenAI'),
+                    content=part.model_response_str(),
+                )
+            elif isinstance(part, RetryPromptPart):
+                if part.tool_name is None:
+                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
+                else:
+                    yield chat.ChatCompletionToolMessageParam(
+                        role='tool',
+                        tool_call_id=_guard_tool_call_id(t=part, model_source='OpenAI'),
+                        content=part.model_response(),
+                    )
+            else:
+                assert_never(part)
+
@@ -270,7 +301,7 @@ class OpenAIStreamTextResponse(StreamTextResponse):
     _first: str | None
     _response: AsyncStream[ChatCompletionChunk]
     _timestamp: datetime
-    …
+    _usage: result.Usage
     _buffer: list[str] = field(default_factory=list, init=False)

     async def __anext__(self) -> None:
@@ -280,7 +311,7 @@ class OpenAIStreamTextResponse(StreamTextResponse):
             return None

         chunk = await self._response.__anext__()
-        self.…
+        self._usage += _map_usage(chunk)
         try:
             choice = chunk.choices[0]
         except IndexError:
@@ -296,8 +327,8 @@ class OpenAIStreamTextResponse(StreamTextResponse):
         yield from self._buffer
         self._buffer.clear()

-    def …
-        return self.…
+    def usage(self) -> Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp
@@ -310,11 +341,11 @@ class OpenAIStreamStructuredResponse(StreamStructuredResponse):
     _response: AsyncStream[ChatCompletionChunk]
     _delta_tool_calls: dict[int, ChoiceDeltaToolCall]
     _timestamp: datetime
-    …
+    _usage: result.Usage

     async def __anext__(self) -> None:
         chunk = await self._response.__anext__()
-        self.…
+        self._usage += _map_usage(chunk)
         try:
             choice = chunk.choices[0]
         except IndexError:
@@ -335,48 +366,41 @@ class OpenAIStreamStructuredResponse(StreamStructuredResponse):
         else:
             self._delta_tool_calls[new.index] = new

-    def get(self, *, final: bool = False) -> …
-        …
+    def get(self, *, final: bool = False) -> ModelResponse:
+        items: list[ModelResponsePart] = []
         for c in self._delta_tool_calls.values():
             if f := c.function:
                 if f.name is not None and f.arguments is not None:
-                    …
+                    items.append(ToolCallPart.from_raw_args(f.name, f.arguments, c.id))

-        return …
+        return ModelResponse(items, timestamp=self._timestamp)

-    def …
-        return self.…
+    def usage(self) -> Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp


-def …
-    """Type guard that checks a `tool_id` is not None both for static typing and runtime."""
-    assert t.tool_id is not None, f'OpenAI requires `tool_id` to be set: {t}'
-    return t.tool_id
-
-
-def _map_tool_call(t: ToolCall) -> chat.ChatCompletionMessageToolCallParam:
-    assert isinstance(t.args, ArgsJson), f'Expected ArgsJson, got {t.args}'
+def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
     return chat.ChatCompletionMessageToolCallParam(
-        id=…
+        id=_guard_tool_call_id(t=t, model_source='OpenAI'),
         type='function',
-        function={'name': t.tool_name, 'arguments': t.…
+        function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
     )


-def …
+def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> result.Usage:
     usage = response.usage
     if usage is None:
-        return result.…
+        return result.Usage()
     else:
         details: dict[str, int] = {}
         if usage.completion_tokens_details is not None:
             details.update(usage.completion_tokens_details.model_dump(exclude_none=True))
         if usage.prompt_tokens_details is not None:
             details.update(usage.prompt_tokens_details.model_dump(exclude_none=True))
-        return result.…
+        return result.Usage(
             request_tokens=usage.prompt_tokens,
             response_tokens=usage.completion_tokens,
             total_tokens=usage.total_tokens,
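Taken together, the openai.py changes add two user-facing knobs: a `base_url` for OpenAI-compatible endpoints and per-request `ModelSettings`. A minimal sketch of how they might be wired up; the proxy URL is a placeholder, `ModelSettings` is assumed to be the `TypedDict`-style mapping its `.get` usage above implies, and the `Agent`/`run_sync(..., model_settings=...)` plumbing is an assumption based on the new `settings.py` module rather than something shown in this diff:

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.settings import ModelSettings

# `base_url` (new in 0.0.14) points the client at any OpenAI-compatible
# endpoint; it falls back to the OPENAI_BASE_URL environment variable.
model = OpenAIModel('gpt-4o', base_url='https://llm-proxy.example.com/v1')

# Only the keys read in `_completions_create` above are forwarded:
# max_tokens, temperature, top_p, timeout.
settings: ModelSettings = {'max_tokens': 512, 'temperature': 0.2, 'top_p': 0.9}

agent = Agent(model)
result = agent.run_sync('Say hello', model_settings=settings)  # assumed entry point
print(result.data)
```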
pydantic_ai/models/test.py
CHANGED
@@ -9,18 +9,20 @@ from datetime import date, datetime, timedelta
 from typing import Any, Literal

 import pydantic_core
+from typing_extensions import assert_never

 from .. import _utils
 from ..messages import (
-    … (seven removed import names are collapsed in the diff view)
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    RetryPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
 )
-from ..result import …
+from ..result import Usage
+from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
@@ -29,6 +31,7 @@ from . import (
     StreamStructuredResponse,
     StreamTextResponse,
 )
+from .function import _estimate_string_usage, _estimate_usage  # pyright: ignore[reportPrivateUsage]


 @dataclass
@@ -127,74 +130,85 @@ class TestAgentModel(AgentModel):
     result_tools: list[ToolDefinition]
     seed: int

-    async def request(
-        …
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, Usage]:
+        model_response = self._request(messages, model_settings)
+        usage = _estimate_usage([*messages, model_response])
+        return model_response, usage

     @asynccontextmanager
-    async def request_stream(
-        … (three removed lines are collapsed in the diff view)
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
+        msg = self._request(messages, model_settings)
+        usage = _estimate_usage(messages)
+
+        # TODO: Rework this once we make StreamTextResponse more general
+        texts: list[str] = []
+        tool_calls: list[ToolCallPart] = []
+        for item in msg.parts:
+            if isinstance(item, TextPart):
+                texts.append(item.content)
+            elif isinstance(item, ToolCallPart):
+                tool_calls.append(item)
+            else:
+                assert_never(item)
+
+        if texts:
+            yield TestStreamTextResponse('\n\n'.join(texts), usage)
         else:
-            yield TestStreamStructuredResponse(msg, …
+            yield TestStreamStructuredResponse(msg, usage)

     def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
         return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()

-    def _request(self, messages: list[…
+    def _request(self, messages: list[ModelMessage], model_settings: ModelSettings | None) -> ModelResponse:
         # if there are tools, the first thing we want to do is call all of them
-        if self.tool_calls and not any(m…
-        … (15 removed lines are collapsed in the diff view)
+        if self.tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
+            return ModelResponse(
+                parts=[ToolCallPart.from_raw_args(name, self.gen_tool_args(args)) for name, args in self.tool_calls]
+            )
+
+        if messages:
+            last_message = messages[-1]
+            assert isinstance(last_message, ModelRequest), 'Expected last message to be a `ModelRequest`.'
+
+            # check if there are any retry prompts, if so retry them
+            new_retry_names = {p.tool_name for p in last_message.parts if isinstance(p, RetryPromptPart)}
+            if new_retry_names:
+                return ModelResponse(
+                    parts=[
+                        ToolCallPart.from_raw_args(name, self.gen_tool_args(args))
+                        for name, args in self.tool_calls
+                        if name in new_retry_names
+                    ]
+                )

         if response_text := self.result.left:
             if response_text.value is None:
                 # build up details of tool responses
                 output: dict[str, Any] = {}
                 for message in messages:
-                    if isinstance(message, …
-                        …
+                    if isinstance(message, ModelRequest):
+                        for part in message.parts:
+                            if isinstance(part, ToolReturnPart):
+                                output[part.tool_name] = part.content
                 if output:
-                    return …
+                    return ModelResponse.from_text(pydantic_core.to_json(output).decode())
                 else:
-                    return …
+                    return ModelResponse.from_text('success (no tool calls)')
             else:
-                return …
+                return ModelResponse.from_text(response_text.value)
         else:
             assert self.result_tools, 'No result tools provided'
             custom_result_args = self.result.right
             result_tool = self.result_tools[self.seed % len(self.result_tools)]
             if custom_result_args is not None:
-                return …
+                return ModelResponse(parts=[ToolCallPart.from_raw_args(result_tool.name, custom_result_args)])
             else:
                 response_args = self.gen_tool_args(result_tool)
-                return …
-
-
-def _get_new_messages(messages: list[Message]) -> list[Message]:
-    last_model_index = None
-    for i, m in enumerate(messages):
-        if m.role in ('model-structured-response', 'model-text-response'):
-            last_model_index = i
-
-    if last_model_index is not None:
-        return messages[last_model_index + 1 :]
-    else:
-        return []
+                return ModelResponse(parts=[ToolCallPart.from_raw_args(result_tool.name, response_args)])


 @dataclass
@@ -202,7 +216,7 @@ class TestStreamTextResponse(StreamTextResponse):
     """A text response that streams test data."""

     _text: str
-    …
+    _usage: Usage
     _iter: Iterator[str] = field(init=False)
     _timestamp: datetime = field(default_factory=_utils.now_utc)
     _buffer: list[str] = field(default_factory=list, init=False)
@@ -217,14 +231,17 @@ class TestStreamTextResponse(StreamTextResponse):
         self._iter = iter(words)

     async def __anext__(self) -> None:
-        …
+        next_str = _utils.sync_anext(self._iter)
+        response_tokens = _estimate_string_usage(next_str)
+        self._usage += Usage(response_tokens=response_tokens, total_tokens=response_tokens)
+        self._buffer.append(next_str)

     def get(self, *, final: bool = False) -> Iterable[str]:
         yield from self._buffer
         self._buffer.clear()

-    def …
-        return self.…
+    def usage(self) -> Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp
@@ -234,19 +251,19 @@ class TestStreamTextResponse(StreamTextResponse):
 class TestStreamStructuredResponse(StreamStructuredResponse):
     """A structured response that streams test data."""

-    _structured_response: …
-    …
+    _structured_response: ModelResponse
+    _usage: Usage
     _iter: Iterator[None] = field(default_factory=lambda: iter([None]))
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)

     async def __anext__(self) -> None:
         return _utils.sync_anext(self._iter)

-    def get(self, *, final: bool = False) -> …
+    def get(self, *, final: bool = False) -> ModelResponse:
         return self._structured_response

-    def …
-        return self.…
+    def usage(self) -> Usage:
+        return self._usage

     def timestamp(self) -> datetime:
         return self._timestamp
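The test-model rewrite tracks the new part-based message model: requests and responses now carry lists of typed parts instead of role-tagged messages. A small sketch of those types; the keyword form of the `ModelRequest` and `UserPromptPart` constructors is assumed from the attribute names used in this diff:

```python
from pydantic_ai.messages import (
    ModelRequest,
    ModelResponse,
    TextPart,
    ToolCallPart,
    UserPromptPart,
)

# One request can mix several part kinds (system prompt, user prompt,
# tool returns, retry prompts); here just a user prompt.
request = ModelRequest(parts=[UserPromptPart(content='What is the weather?')])

# A response can mix text and tool-call parts in a single object,
# exactly the shape TestAgentModel._request builds above.
response = ModelResponse(
    parts=[
        TextPart('Checking the forecast.'),
        ToolCallPart.from_raw_args('get_weather', {'city': 'London'}),
    ]
)
print([type(p).__name__ for p in response.parts])
```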
pydantic_ai/models/vertexai.py
CHANGED
@@ -21,7 +21,7 @@ try:
 except ImportError as _import_error:
     raise ImportError(
         'Please install `google-auth` to use the VertexAI model, '
-        "you can use the `vertexai` optional group — `pip install 'pydantic-ai[vertexai]'`"
+        "you can use the `vertexai` optional group — `pip install 'pydantic-ai-slim[vertexai]'`"
     ) from _import_error

 VERTEX_AI_URL_TEMPLATE = (
@@ -114,7 +114,7 @@ class VertexAIModel(Model):
         allow_text_result: bool,
         result_tools: list[ToolDefinition],
     ) -> GeminiAgentModel:
-        url, auth = await self.…
+        url, auth = await self.ainit()
         return GeminiAgentModel(
             http_client=self.http_client,
             model_name=self.model_name,
@@ -125,7 +125,11 @@ class VertexAIModel(Model):
             result_tools=result_tools,
         )

-    async def …
+    async def ainit(self) -> tuple[str, BearerTokenAuth]:
+        """Initialize the model, setting the URL and auth.
+
+        This will raise an error if authentication fails.
+        """
         if self.url is not None and self.auth is not None:
             return self.url, self.auth
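The newly documented `ainit` method can be awaited directly, so authentication failures surface before the first request rather than during it. A sketch; the `VertexAIModel` constructor arguments are assumed, since they are not shown in this diff:

```python
import asyncio

from pydantic_ai.models.vertexai import VertexAIModel


async def main() -> None:
    model = VertexAIModel('gemini-1.5-flash')  # model name argument assumed
    # `ainit` resolves the request URL and bearer-token auth up front,
    # raising here, rather than at the first request, if auth fails.
    url, auth = await model.ainit()
    print(url)


asyncio.run(main())
```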