pydantic-ai-slim: 0.0.18-py3-none-any.whl → 0.0.20-py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
- pydantic_ai/_griffe.py +10 -3
- pydantic_ai/_parts_manager.py +239 -0
- pydantic_ai/_pydantic.py +17 -3
- pydantic_ai/_result.py +26 -21
- pydantic_ai/_system_prompt.py +4 -4
- pydantic_ai/_utils.py +80 -17
- pydantic_ai/agent.py +187 -159
- pydantic_ai/format_as_xml.py +2 -1
- pydantic_ai/messages.py +217 -15
- pydantic_ai/models/__init__.py +58 -71
- pydantic_ai/models/anthropic.py +112 -48
- pydantic_ai/models/cohere.py +278 -0
- pydantic_ai/models/function.py +57 -85
- pydantic_ai/models/gemini.py +83 -129
- pydantic_ai/models/groq.py +60 -130
- pydantic_ai/models/mistral.py +86 -142
- pydantic_ai/models/ollama.py +4 -0
- pydantic_ai/models/openai.py +75 -136
- pydantic_ai/models/test.py +55 -80
- pydantic_ai/models/vertexai.py +2 -1
- pydantic_ai/result.py +132 -114
- pydantic_ai/settings.py +18 -1
- pydantic_ai/tools.py +42 -23
- {pydantic_ai_slim-0.0.18.dist-info → pydantic_ai_slim-0.0.20.dist-info}/METADATA +7 -3
- pydantic_ai_slim-0.0.20.dist-info/RECORD +30 -0
- pydantic_ai_slim-0.0.18.dist-info/RECORD +0 -28
- {pydantic_ai_slim-0.0.18.dist-info → pydantic_ai_slim-0.0.20.dist-info}/WHEEL +0 -0
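
The through-line of this release is a rework of streaming: the separate `EitherStreamedResponse` / `StreamTextResponse` / `StreamStructuredResponse` types are replaced by a single `StreamedResponse` whose `_get_event_iterator` yields `ModelResponseStreamEvent` items, assembled by the new `pydantic_ai/_parts_manager.py`. A minimal consumer-side sketch, assuming the agent-level streaming API (`run_stream` / `stream_text`) that pydantic-ai exposed at the time:

    import asyncio

    from pydantic_ai import Agent

    agent = Agent('openai:gpt-4o')

    async def main() -> None:
        # Internally, each model now emits ModelResponseStreamEvent items;
        # stream_text() surfaces the accumulated text deltas.
        async with agent.run_stream('Tell me a joke.') as result:
            async for text in result.stream_text():
                print(text)

    asyncio.run(main())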
pydantic_ai/models/openai.py
CHANGED
@@ -1,6 +1,6 @@
 from __future__ import annotations as _annotations

-from collections.abc import AsyncIterator, Iterable
+from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
@@ -10,13 +10,14 @@ from typing import Literal, Union, overload
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never

-from .. import UnexpectedModelBehavior, _utils, result
+from .. import UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
     ModelMessage,
     ModelRequest,
     ModelResponse,
     ModelResponsePart,
+    ModelResponseStreamEvent,
     RetryPromptPart,
     SystemPromptPart,
     TextPart,
@@ -24,15 +25,12 @@ from ..messages import (
     ToolReturnPart,
     UserPromptPart,
 )
-from ..result import Usage
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
-    EitherStreamedResponse,
     Model,
-    StreamStructuredResponse,
-    StreamTextResponse,
+    StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,
 )
@@ -41,7 +39,6 @@ try:
     from openai import NOT_GIVEN, AsyncOpenAI, AsyncStream
     from openai.types import ChatModel, chat
     from openai.types.chat import ChatCompletionChunk
-    from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
 except ImportError as _import_error:
     raise ImportError(
         'Please install `openai` to use the OpenAI model, '
@@ -54,6 +51,8 @@ Using this more broad type for the model name instead of the ChatModel definition
 allows this model to be used more easily with other model types (ie, Ollama)
 """

+OpenAISystemPromptRole = Literal['system', 'developer', 'user']
+

 @dataclass(init=False)
 class OpenAIModel(Model):
@@ -66,6 +65,7 @@ class OpenAIModel(Model):

     model_name: OpenAIModelName
     client: AsyncOpenAI = field(repr=False)
+    system_prompt_role: OpenAISystemPromptRole | None = field(default=None)

     def __init__(
         self,
@@ -75,6 +75,7 @@ class OpenAIModel(Model):
         api_key: str | None = None,
         openai_client: AsyncOpenAI | None = None,
         http_client: AsyncHTTPClient | None = None,
+        system_prompt_role: OpenAISystemPromptRole | None = None,
     ):
         """Initialize an OpenAI model.

@@ -90,6 +91,8 @@ class OpenAIModel(Model):
                 [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
                 client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`.
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
+            system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
+                In the future, this may be inferred from the model name.
         """
         self.model_name: OpenAIModelName = model_name
         if openai_client is not None:
@@ -101,6 +104,7 @@ class OpenAIModel(Model):
             self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
         else:
             self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
+        self.system_prompt_role = system_prompt_role

     async def agent_model(
         self,
@@ -118,6 +122,7 @@ class OpenAIModel(Model):
             self.model_name,
             allow_text_result,
             tools,
+            self.system_prompt_role,
         )

     def name(self) -> str:
@@ -143,17 +148,18 @@ class OpenAIAgentModel(AgentModel):
     model_name: OpenAIModelName
     allow_text_result: bool
     tools: list[chat.ChatCompletionToolParam]
+    system_prompt_role: OpenAISystemPromptRole | None

     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, result.Usage]:
+    ) -> tuple[ModelResponse, usage.Usage]:
         response = await self._completions_create(messages, False, model_settings)
         return self._process_response(response), _map_usage(response)

     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[EitherStreamedResponse]:
+    ) -> AsyncIterator[StreamedResponse]:
         response = await self._completions_create(messages, True, model_settings)
         async with response:
             yield await self._process_streamed_response(response)
@@ -189,7 +195,7 @@ class OpenAIAgentModel(AgentModel):
                 model=self.model_name,
                 messages=openai_messages,
                 n=1,
-                parallel_tool_calls=True if self.tools else NOT_GIVEN,
+                parallel_tool_calls=model_settings.get('parallel_tool_calls', True if self.tools else NOT_GIVEN),
                 tools=self.tools or NOT_GIVEN,
                 tool_choice=tool_choice or NOT_GIVEN,
                 stream=stream,
@@ -200,8 +206,7 @@ class OpenAIAgentModel(AgentModel):
             timeout=model_settings.get('timeout', NOT_GIVEN),
         )

-    @staticmethod
-    def _process_response(response: chat.ChatCompletion) -> ModelResponse:
+    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         choice = response.choices[0]
@@ -211,42 +216,25 @@ class OpenAIAgentModel(AgentModel):
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id))
-        return ModelResponse(items, timestamp=timestamp)
+        return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)

-    @staticmethod
-    async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
+    async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
-        timestamp: datetime | None = None
-        start_usage = Usage()
-        # the first chunk may contain enough information so we iterate until we get either `tool_calls` or `content`
-        while True:
-            try:
-                chunk = await response.__anext__()
-            except StopAsyncIteration as e:
-                raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
-
-            timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
-            start_usage += _map_usage(chunk)
-
-            if chunk.choices:
-                delta = chunk.choices[0].delta
-
-                if delta.content is not None:
-                    return OpenAIStreamTextResponse(delta.content, response, timestamp, start_usage)
-                elif delta.tool_calls is not None:
-                    return OpenAIStreamStructuredResponse(
-                        response,
-                        {c.index: c for c in delta.tool_calls},
-                        timestamp,
-                        start_usage,
-                    )
-                # else continue until we get either delta.content or delta.tool_calls
+        peekable_response = _utils.PeekableAsyncStream(response)
+        first_chunk = await peekable_response.peek()
+        if isinstance(first_chunk, _utils.Unset):
+            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
+
+        return OpenAIStreamedResponse(
+            _model_name=self.model_name,
+            _response=peekable_response,
+            _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
+        )

-    @classmethod
-    def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
+    def _map_message(self, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
         if isinstance(message, ModelRequest):
-            yield from cls._map_user_message(message)
+            yield from self._map_user_message(message)
         elif isinstance(message, ModelResponse):
             texts: list[str] = []
             tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
@@ -268,11 +256,15 @@ class OpenAIAgentModel(AgentModel):
         else:
             assert_never(message)

-    @classmethod
-    def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
+    def _map_user_message(self, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
         for part in message.parts:
             if isinstance(part, SystemPromptPart):
-                yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
+                if self.system_prompt_role == 'developer':
+                    yield chat.ChatCompletionDeveloperMessageParam(role='developer', content=part.content)
+                elif self.system_prompt_role == 'user':
+                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
+                else:
+                    yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
             elif isinstance(part, UserPromptPart):
                 yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
             elif isinstance(part, ToolReturnPart):
@@ -295,88 +287,35 @@ class OpenAIAgentModel(AgentModel):


 @dataclass
-class OpenAIStreamTextResponse(StreamTextResponse):
-    """Implementation of `StreamTextResponse` for OpenAI models."""
-
-    _first: str | None
-    _response: AsyncStream[ChatCompletionChunk]
-    _timestamp: datetime
-    _usage: result.Usage
-    _buffer: list[str] = field(default_factory=list, init=False)
-
-    async def __anext__(self) -> None:
-        if self._first is not None:
-            self._buffer.append(self._first)
-            self._first = None
-            return None
-
-        chunk = await self._response.__anext__()
-        self._usage += _map_usage(chunk)
-        try:
-            choice = chunk.choices[0]
-        except IndexError:
-            raise StopAsyncIteration()
-
-        # we don't raise StopAsyncIteration on the last chunk because usage comes after this
-        if choice.finish_reason is None:
-            assert choice.delta.content is not None, f'Expected delta with content, invalid chunk: {chunk!r}'
-        if choice.delta.content is not None:
-            self._buffer.append(choice.delta.content)
-
-    def get(self, *, final: bool = False) -> Iterable[str]:
-        yield from self._buffer
-        self._buffer.clear()
-
-    def usage(self) -> Usage:
-        return self._usage
-
-    def timestamp(self) -> datetime:
-        return self._timestamp
-
-
-@dataclass
-class OpenAIStreamStructuredResponse(StreamStructuredResponse):
-    """Implementation of `StreamStructuredResponse` for OpenAI models."""
+class OpenAIStreamedResponse(StreamedResponse):
+    """Implementation of `StreamedResponse` for OpenAI models."""

-    _response: AsyncStream[ChatCompletionChunk]
-    _delta_tool_calls: dict[int, ChoiceDeltaToolCall]
+    _response: AsyncIterable[ChatCompletionChunk]
     _timestamp: datetime
-    _usage: result.Usage
-
-    async def __anext__(self) -> None:
-        chunk = await self._response.__anext__()
-        self._usage += _map_usage(chunk)
-        try:
-            choice = chunk.choices[0]
-        except IndexError:
-            raise StopAsyncIteration()
-
-        if choice.finish_reason is not None:
-            raise StopAsyncIteration()
-
-        assert choice.delta.content is None, f'Expected tool calls, got content instead, invalid chunk: {chunk!r}'
-
-        for new in choice.delta.tool_calls or []:
-            if current := self._delta_tool_calls.get(new.index):
-                if current.function is None:
-                    current.function = new.function
-                elif new.function is not None:
-                    current.function.name = _utils.add_optional(current.function.name, new.function.name)
-                    current.function.arguments = _utils.add_optional(current.function.arguments, new.function.arguments)
-            else:
-                self._delta_tool_calls[new.index] = new
-
-    def get(self, *, final: bool = False) -> ModelResponse:
-        items: list[ModelResponsePart] = []
-        for c in self._delta_tool_calls.values():
-            if f := c.function:
-                if f.name is not None and f.arguments is not None:
-                    items.append(ToolCallPart.from_raw_args(f.name, f.arguments, c.id))

-        return ModelResponse(items, timestamp=self._timestamp)
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        async for chunk in self._response:
+            self._usage += _map_usage(chunk)

-    def usage(self) -> Usage:
-        return self._usage
+            try:
+                choice = chunk.choices[0]
+            except IndexError:
+                continue
+
+            # Handle the text part of the response
+            content = choice.delta.content
+            if content is not None:
+                yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
+
+            for dtc in choice.delta.tool_calls or []:
+                maybe_event = self._parts_manager.handle_tool_call_delta(
+                    vendor_part_id=dtc.index,
+                    tool_name=dtc.function and dtc.function.name,
+                    args=dtc.function and dtc.function.arguments,
+                    tool_call_id=dtc.id,
+                )
+                if maybe_event is not None:
+                    yield maybe_event

     def timestamp(self) -> datetime:
         return self._timestamp
@@ -390,19 +329,19 @@ def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
     )


-def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> result.Usage:
-    usage = response.usage
-    if usage is None:
-        return result.Usage()
+def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> usage.Usage:
+    response_usage = response.usage
+    if response_usage is None:
+        return usage.Usage()
     else:
         details: dict[str, int] = {}
-        if usage.completion_tokens_details is not None:
-            details.update(usage.completion_tokens_details.model_dump(exclude_none=True))
-        if usage.prompt_tokens_details is not None:
-            details.update(usage.prompt_tokens_details.model_dump(exclude_none=True))
-        return result.Usage(
-            request_tokens=usage.prompt_tokens,
-            response_tokens=usage.completion_tokens,
-            total_tokens=usage.total_tokens,
+        if response_usage.completion_tokens_details is not None:
+            details.update(response_usage.completion_tokens_details.model_dump(exclude_none=True))
+        if response_usage.prompt_tokens_details is not None:
+            details.update(response_usage.prompt_tokens_details.model_dump(exclude_none=True))
+        return usage.Usage(
+            request_tokens=response_usage.prompt_tokens,
+            response_tokens=response_usage.completion_tokens,
+            total_tokens=response_usage.total_tokens,
             details=details,
         )
pydantic_ai/models/test.py
CHANGED
@@ -2,21 +2,22 @@ from __future__ import annotations as _annotations

 import re
 import string
-from collections.abc import AsyncIterator, Iterable, Iterator
+from collections.abc import AsyncIterator, Iterable
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, field
+from dataclasses import InitVar, dataclass, field
 from datetime import date, datetime, timedelta
 from typing import Any, Literal

 import pydantic_core
-from typing_extensions import assert_never

 from .. import _utils
 from ..messages import (
+    ArgsJson,
     ModelMessage,
     ModelRequest,
     ModelResponse,
     ModelResponsePart,
+    ModelResponseStreamEvent,
     RetryPromptPart,
     TextPart,
     ToolCallPart,
@@ -27,12 +28,10 @@ from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
-    EitherStreamedResponse,
     Model,
-    StreamStructuredResponse,
-    StreamTextResponse,
+    StreamedResponse,
 )
-from .function import _estimate_string_usage, _estimate_usage  # pyright: ignore[reportPrivateUsage]
+from .function import _estimate_string_tokens, _estimate_usage  # pyright: ignore[reportPrivateUsage]


 @dataclass
@@ -130,6 +129,7 @@ class TestAgentModel(AgentModel):
     result: _utils.Either[str | None, Any | None]
     result_tools: list[ToolDefinition]
     seed: int
+    model_name: str = 'test'

     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
@@ -141,25 +141,9 @@ class TestAgentModel(AgentModel):
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> AsyncIterator[EitherStreamedResponse]:
-        msg = self._request(messages, model_settings)
-        usage = _estimate_usage(messages)
-
-        # TODO: Rework this once we make StreamTextResponse more general
-        texts: list[str] = []
-        tool_calls: list[ToolCallPart] = []
-        for item in msg.parts:
-            if isinstance(item, TextPart):
-                texts.append(item.content)
-            elif isinstance(item, ToolCallPart):
-                tool_calls.append(item)
-            else:
-                assert_never(item)
-
-        if texts:
-            yield TestStreamTextResponse('\n\n'.join(texts), usage)
-        else:
-            yield TestStreamStructuredResponse(msg, usage)
+    ) -> AsyncIterator[StreamedResponse]:
+        model_response = self._request(messages, model_settings)
+        yield TestStreamedResponse(_model_name=self.model_name, _structured_response=model_response, _messages=messages)

     def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
         return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
@@ -168,7 +152,8 @@ class TestAgentModel(AgentModel):
         # if there are tools, the first thing we want to do is call all of them
         if self.tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
             return ModelResponse(
-                parts=[ToolCallPart.from_raw_args(name, self.gen_tool_args(args)) for name, args in self.tool_calls]
+                parts=[ToolCallPart.from_raw_args(name, self.gen_tool_args(args)) for name, args in self.tool_calls],
+                model_name=self.model_name,
             )

         if messages:
@@ -194,7 +179,7 @@ class TestAgentModel(AgentModel):
                     if tool.name in new_retry_names
                 ]
             )
-            return ModelResponse(parts=retry_parts)
+            return ModelResponse(parts=retry_parts, model_name=self.model_name)

         if response_text := self.result.left:
             if response_text.value is None:
@@ -206,75 +191,60 @@ class TestAgentModel(AgentModel):
                         if isinstance(part, ToolReturnPart):
                             output[part.tool_name] = part.content
                 if output:
-                    return ModelResponse(parts=[TextPart(pydantic_core.to_json(output).decode())])
+                    return ModelResponse(
+                        parts=[TextPart(pydantic_core.to_json(output).decode())], model_name=self.model_name
+                    )
                 else:
-                    return ModelResponse(parts=[TextPart('success (no tool calls)')])
+                    return ModelResponse(parts=[TextPart('success (no tool calls)')], model_name=self.model_name)
             else:
-                return ModelResponse(parts=[TextPart(response_text.value)])
+                return ModelResponse(parts=[TextPart(response_text.value)], model_name=self.model_name)
         else:
             assert self.result_tools, 'No result tools provided'
             custom_result_args = self.result.right
             result_tool = self.result_tools[self.seed % len(self.result_tools)]
             if custom_result_args is not None:
-                return ModelResponse(parts=[ToolCallPart.from_raw_args(result_tool.name, custom_result_args)])
+                return ModelResponse(
+                    parts=[ToolCallPart.from_raw_args(result_tool.name, custom_result_args)], model_name=self.model_name
+                )
             else:
                 response_args = self.gen_tool_args(result_tool)
-                return ModelResponse(parts=[ToolCallPart.from_raw_args(result_tool.name, response_args)])
+                return ModelResponse(
+                    parts=[ToolCallPart.from_raw_args(result_tool.name, response_args)], model_name=self.model_name
+                )


 @dataclass
-class TestStreamTextResponse(StreamTextResponse):
-    """A text response that streams test data."""
-
-    _text: str
-    _usage: Usage
-    _iter: Iterator[str] = field(init=False)
-    _timestamp: datetime = field(default_factory=_utils.now_utc)
-    _buffer: list[str] = field(default_factory=list, init=False)
-
-    def __post_init__(self):
-        *words, last_word = self._text.split(' ')
-        words = [f'{word} ' for word in words]
-        words.append(last_word)
-        if len(words) == 1 and len(self._text) > 2:
-            mid = len(self._text) // 2
-            words = [self._text[:mid], self._text[mid:]]
-        self._iter = iter(words)
-
-    async def __anext__(self) -> None:
-        next_str = _utils.sync_anext(self._iter)
-        response_tokens = _estimate_string_usage(next_str)
-        self._usage += Usage(response_tokens=response_tokens, total_tokens=response_tokens)
-        self._buffer.append(next_str)
-
-    def get(self, *, final: bool = False) -> Iterable[str]:
-        yield from self._buffer
-        self._buffer.clear()
-
-    def usage(self) -> Usage:
-        return self._usage
-
-    def timestamp(self) -> datetime:
-        return self._timestamp
-
-
-@dataclass
-class TestStreamStructuredResponse(StreamStructuredResponse):
+class TestStreamedResponse(StreamedResponse):
     """A structured response that streams test data."""

     _structured_response: ModelResponse
-    _usage: Usage
-    _iter: Iterator[None] = field(default_factory=lambda: iter([None]))
-    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
-
-    async def __anext__(self) -> None:
-        return _utils.sync_anext(self._iter)
+    _messages: InitVar[Iterable[ModelMessage]]

-    def get(self, *, final: bool = False) -> ModelResponse:
-        return self._structured_response
+    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)

-    def usage(self) -> Usage:
-        return self._usage
+    def __post_init__(self, _messages: Iterable[ModelMessage]):
+        self._usage = _estimate_usage(_messages)
+
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        for i, part in enumerate(self._structured_response.parts):
+            if isinstance(part, TextPart):
+                text = part.content
+                *words, last_word = text.split(' ')
+                words = [f'{word} ' for word in words]
+                words.append(last_word)
+                if len(words) == 1 and len(text) > 2:
+                    mid = len(text) // 2
+                    words = [text[:mid], text[mid:]]
+                self._usage += _get_string_usage('')
+                yield self._parts_manager.handle_text_delta(vendor_part_id=i, content='')
+                for word in words:
+                    self._usage += _get_string_usage(word)
+                    yield self._parts_manager.handle_text_delta(vendor_part_id=i, content=word)
+            else:
+                args = part.args.args_json if isinstance(part.args, ArgsJson) else part.args.args_dict
+                yield self._parts_manager.handle_tool_call_part(
+                    vendor_part_id=i, tool_name=part.tool_name, args=args, tool_call_id=part.tool_call_id
+                )

     def timestamp(self) -> datetime:
         return self._timestamp
@@ -434,3 +404,8 @@ class _JsonSchemaTestData:
             rem //= chars
             s += _chars[self.seed % chars]
         return s
+
+
+def _get_string_usage(text: str) -> Usage:
+    response_tokens = _estimate_string_tokens(text)
+    return Usage(response_tokens=response_tokens, total_tokens=response_tokens)
pydantic_ai/models/vertexai.py
CHANGED
@@ -10,7 +10,7 @@ from httpx import AsyncClient as AsyncHTTPClient
 from .._utils import run_in_executor
 from ..exceptions import UserError
 from ..tools import ToolDefinition
-from . import Model, cached_async_http_client
+from . import Model, cached_async_http_client, check_allow_model_requests
 from .gemini import GeminiAgentModel, GeminiModelName

 try:
@@ -114,6 +114,7 @@ class VertexAIModel(Model):
         allow_text_result: bool,
         result_tools: list[ToolDefinition],
     ) -> GeminiAgentModel:
+        check_allow_model_requests()
         url, auth = await self.ainit()
         return GeminiAgentModel(
             http_client=self.http_client,