pydantic-ai-slim 0.0.12 → 0.0.13 (py3-none-any.whl)
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
Potentially problematic release.
- pydantic_ai/_pydantic.py +7 -25
- pydantic_ai/_result.py +34 -16
- pydantic_ai/_system_prompt.py +1 -1
- pydantic_ai/_utils.py +9 -2
- pydantic_ai/agent.py +333 -148
- pydantic_ai/messages.py +87 -48
- pydantic_ai/models/__init__.py +30 -6
- pydantic_ai/models/anthropic.py +344 -0
- pydantic_ai/models/function.py +59 -31
- pydantic_ai/models/gemini.py +150 -108
- pydantic_ai/models/groq.py +94 -74
- pydantic_ai/models/mistral.py +680 -0
- pydantic_ai/models/ollama.py +1 -1
- pydantic_ai/models/openai.py +102 -76
- pydantic_ai/models/test.py +62 -51
- pydantic_ai/models/vertexai.py +7 -3
- pydantic_ai/result.py +35 -37
- pydantic_ai/settings.py +72 -0
- pydantic_ai/tools.py +28 -18
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.13.dist-info}/METADATA +8 -3
- pydantic_ai_slim-0.0.13.dist-info/RECORD +26 -0
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.13.dist-info}/WHEEL +1 -1
- pydantic_ai_slim-0.0.12.dist-info/RECORD +0 -23
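
Of the files listed, `pydantic_ai/models/anthropic.py`, `pydantic_ai/models/mistral.py`, and `pydantic_ai/settings.py` are entirely new. `settings.py` supplies the `ModelSettings` type that the per-model diffs below thread through their `request` methods. Its source is not included in this section, but from the keys the `openai.py` hunks read off it (`max_tokens`, `temperature`, `top_p`, `timeout`), it is presumably a non-total TypedDict along these lines (a sketch inferred from usage, not the actual module):

```python
# Sketch of what pydantic_ai/settings.py plausibly defines, inferred from the
# model_settings.get(...) calls in the openai.py diff below; the real module
# is not shown here and may define more fields or different value types.
from typing_extensions import TypedDict


class ModelSettings(TypedDict, total=False):
    """Per-request tuning parameters forwarded to the provider API."""

    max_tokens: int
    temperature: float
    top_p: float
    timeout: float
```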
pydantic_ai/models/openai.py (CHANGED)
In the reconstructed hunks below, `…` marks removed lines whose content is collapsed in the registry's diff view.

```diff
@@ -4,23 +4,29 @@ from collections.abc import AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
+from itertools import chain
 from typing import Literal, Union, overload
 
 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
 from .. import UnexpectedModelBehavior, _utils, result
+from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
     ArgsJson,
-    …
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
 )
 from ..result import Cost
+from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
@@ -40,7 +46,7 @@ try:
 except ImportError as _import_error:
     raise ImportError(
         'Please install `openai` to use the OpenAI model, '
-        "you can use the `openai` optional group — `pip install 'pydantic-ai[openai]'`"
+        "you can use the `openai` optional group — `pip install 'pydantic-ai-slim[openai]'`"
     ) from _import_error
 
 OpenAIModelName = Union[ChatModel, str]
@@ -66,6 +72,7 @@ class OpenAIModel(Model):
         self,
         model_name: OpenAIModelName,
         *,
+        base_url: str | None = None,
         api_key: str | None = None,
         openai_client: AsyncOpenAI | None = None,
         http_client: AsyncHTTPClient | None = None,
@@ -76,22 +83,25 @@ class OpenAIModel(Model):
             model_name: The name of the OpenAI model to use. List of model names available
                 [here](https://github.com/openai/openai-python/blob/v1.54.3/src/openai/types/chat_model.py#L7)
                 (Unfortunately, despite being ask to do so, OpenAI do not provide `.inv` files for their API).
+            base_url: The base url for the OpenAI requests. If not provided, the `OPENAI_BASE_URL` environment variable
+                will be used if available. Otherwise, defaults to OpenAI's base url.
             api_key: The API key to use for authentication, if not provided, the `OPENAI_API_KEY` environment variable
                 will be used if available.
             openai_client: An existing
                 [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
-                client to use
+                client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`.
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
         """
         self.model_name: OpenAIModelName = model_name
         if openai_client is not None:
             assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
+            assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
             assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
             self.client = openai_client
         elif http_client is not None:
-            self.client = AsyncOpenAI(api_key=api_key, http_client=http_client)
+            self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
         else:
-            self.client = AsyncOpenAI(api_key=api_key, http_client=cached_async_http_client())
+            self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
 
     async def agent_model(
         self,
@@ -135,28 +145,34 @@ class OpenAIAgentModel(AgentModel):
     allow_text_result: bool
     tools: list[chat.ChatCompletionToolParam]
 
-    async def request(
-        …
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Cost]:
+        response = await self._completions_create(messages, False, model_settings)
         return self._process_response(response), _map_cost(response)
 
     @asynccontextmanager
-    async def request_stream(
-        …
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
+        response = await self._completions_create(messages, True, model_settings)
         async with response:
             yield await self._process_streamed_response(response)
 
     @overload
     async def _completions_create(
-        self, messages: list[…
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
     ) -> AsyncStream[ChatCompletionChunk]:
         pass
 
     @overload
-    async def _completions_create(
+    async def _completions_create(
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+    ) -> chat.ChatCompletion:
         pass
 
     async def _completions_create(
-        self, messages: list[…
+        self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
    ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
         # standalone function to make it easier to override
         if not self.tools:
@@ -166,7 +182,10 @@ class OpenAIAgentModel(AgentModel):
         else:
             tool_choice = 'auto'
 
-        openai_messages = …
+        openai_messages = list(chain(*(self._map_message(m) for m in messages)))
+
+        model_settings = model_settings or {}
+
         return await self.client.chat.completions.create(
             model=self.model_name,
             messages=openai_messages,
@@ -176,21 +195,24 @@ class OpenAIAgentModel(AgentModel):
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
             stream_options={'include_usage': True} if stream else NOT_GIVEN,
+            max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
+            temperature=model_settings.get('temperature', NOT_GIVEN),
+            top_p=model_settings.get('top_p', NOT_GIVEN),
+            timeout=model_settings.get('timeout', NOT_GIVEN),
         )
 
     @staticmethod
-    def _process_response(response: chat.ChatCompletion) -> …
+    def _process_response(response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         choice = response.choices[0]
+        items: list[ModelResponsePart] = []
+        if choice.message.content is not None:
+            items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
-            …
-            )
-        else:
-            assert choice.message.content is not None, choice
-            return ModelTextResponse(choice.message.content, timestamp=timestamp)
+            for c in choice.message.tool_calls:
+                items.append(ToolCallPart.from_json(c.function.name, c.function.arguments, c.id))
+        return ModelResponse(items, timestamp=timestamp)
 
     @staticmethod
     async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
@@ -221,47 +243,57 @@ class OpenAIAgentModel(AgentModel):
             )
         # else continue until we get either delta.content or delta.tool_calls
 
-    @…
-    def _map_message(message: …
+    @classmethod
+    def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
-        if message…
-            …
-            )
-            …
-                tool_call_id=_guard_tool_id(message),
-                content=message.model_response(),
-            )
-        elif message.role == 'model-text-response':
-            # ModelTextResponse ->
-            return chat.ChatCompletionAssistantMessageParam(role='assistant', content=message.content)
-        elif message.role == 'model-structured-response':
-            assert (
-                message.role == 'model-structured-response'
-            ), f'Expected role to be "llm-tool-calls", got {message.role}'
-            # ModelStructuredResponse ->
-            return chat.ChatCompletionAssistantMessageParam(
-                role='assistant',
-                tool_calls=[_map_tool_call(t) for t in message.calls],
-            )
+        if isinstance(message, ModelRequest):
+            yield from cls._map_user_message(message)
+        elif isinstance(message, ModelResponse):
+            texts: list[str] = []
+            tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
+            for item in message.parts:
+                if isinstance(item, TextPart):
+                    texts.append(item.content)
+                elif isinstance(item, ToolCallPart):
+                    tool_calls.append(_map_tool_call(item))
+                else:
+                    assert_never(item)
+            message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
+            if texts:
+                # Note: model responses from this model should only have one text item, so the following
+                # shouldn't merge multiple texts into one unless you switch models between runs:
+                message_param['content'] = '\n\n'.join(texts)
+            if tool_calls:
+                message_param['tool_calls'] = tool_calls
+            yield message_param
         else:
             assert_never(message)
 
+    @classmethod
+    def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
+        for part in message.parts:
+            if isinstance(part, SystemPromptPart):
+                yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
+            elif isinstance(part, UserPromptPart):
+                yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
+            elif isinstance(part, ToolReturnPart):
+                yield chat.ChatCompletionToolMessageParam(
+                    role='tool',
+                    tool_call_id=_guard_tool_call_id(t=part, model_source='OpenAI'),
+                    content=part.model_response_str(),
+                )
+            elif isinstance(part, RetryPromptPart):
+                if part.tool_name is None:
+                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
+                else:
+                    yield chat.ChatCompletionToolMessageParam(
+                        role='tool',
+                        tool_call_id=_guard_tool_call_id(t=part, model_source='OpenAI'),
+                        content=part.model_response(),
+                    )
+            else:
+                assert_never(part)
+
 
 @dataclass
 class OpenAIStreamTextResponse(StreamTextResponse):
@@ -335,14 +367,14 @@ class OpenAIStreamStructuredResponse(StreamStructuredResponse):
         else:
             self._delta_tool_calls[new.index] = new
 
-    def get(self, *, final: bool = False) -> …
-        …
+    def get(self, *, final: bool = False) -> ModelResponse:
+        items: list[ModelResponsePart] = []
         for c in self._delta_tool_calls.values():
             if f := c.function:
                 if f.name is not None and f.arguments is not None:
-                    …
+                    items.append(ToolCallPart.from_json(f.name, f.arguments, c.id))
 
-        return …
+        return ModelResponse(items, timestamp=self._timestamp)
 
     def cost(self) -> Cost:
         return self._cost
@@ -351,16 +383,10 @@ class OpenAIStreamStructuredResponse(StreamStructuredResponse):
         return self._timestamp
 
 
-def …
-    """Type guard that checks a `tool_id` is not None both for static typing and runtime."""
-    assert t.tool_id is not None, f'OpenAI requires `tool_id` to be set: {t}'
-    return t.tool_id
-
-
-def _map_tool_call(t: ToolCall) -> chat.ChatCompletionMessageToolCallParam:
+def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
     assert isinstance(t.args, ArgsJson), f'Expected ArgsJson, got {t.args}'
     return chat.ChatCompletionMessageToolCallParam(
-        id=…
+        id=_guard_tool_call_id(t=t, model_source='OpenAI'),
         type='function',
         function={'name': t.tool_name, 'arguments': t.args.args_json},
     )
```
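Taken together, the user-facing changes in this file are the new `base_url` constructor argument and the `ModelSettings` values forwarded to `chat.completions.create`. A minimal sketch of how the two combine in 0.0.13 (the local endpoint URL is illustrative, and passing `model_settings` to `run_sync` is inferred from the `agent.py` changes summarized above, which are not shown in this section):

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel

# base_url is new in 0.0.13: it points the underlying AsyncOpenAI client at
# any OpenAI-compatible endpoint (the URL below is illustrative).
model = OpenAIModel(
    'gpt-4o-mini',
    base_url='http://localhost:11434/v1',
    api_key='not-needed-for-a-local-server',
)
agent = Agent(model)

# Per the hunk above, these settings reach chat.completions.create as the
# max_tokens / temperature / top_p / timeout keyword arguments.
result = agent.run_sync(
    'Say hello.',
    model_settings={'temperature': 0.0, 'max_tokens': 64},
)
print(result.data)
```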
pydantic_ai/models/test.py (CHANGED)
```diff
@@ -9,18 +9,20 @@ from datetime import date, datetime, timedelta
 from typing import Any, Literal
 
 import pydantic_core
+from typing_extensions import assert_never
 
 from .. import _utils
 from ..messages import (
-    …
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    RetryPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
 )
 from ..result import Cost
+from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
@@ -127,74 +129,83 @@ class TestAgentModel(AgentModel):
     result_tools: list[ToolDefinition]
     seed: int
 
-    async def request(
-        …
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, Cost]:
+        return self._request(messages, model_settings), Cost()
 
     @asynccontextmanager
-    async def request_stream(
-        …
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
+        msg = self._request(messages, model_settings)
         cost = Cost()
-        …
+
+        # TODO: Rework this once we make StreamTextResponse more general
+        texts: list[str] = []
+        tool_calls: list[ToolCallPart] = []
+        for item in msg.parts:
+            if isinstance(item, TextPart):
+                texts.append(item.content)
+            elif isinstance(item, ToolCallPart):
+                tool_calls.append(item)
+            else:
+                assert_never(item)
+
+        if texts:
+            yield TestStreamTextResponse('\n\n'.join(texts), cost)
         else:
             yield TestStreamStructuredResponse(msg, cost)
 
     def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
         return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
 
-    def _request(self, messages: list[…
+    def _request(self, messages: list[ModelMessage], model_settings: ModelSettings | None) -> ModelResponse:
         # if there are tools, the first thing we want to do is call all of them
-        if self.tool_calls and not any(m…
-            …
+        if self.tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
+            return ModelResponse(
+                parts=[ToolCallPart.from_dict(name, self.gen_tool_args(args)) for name, args in self.tool_calls]
+            )
+
+        if messages:
+            last_message = messages[-1]
+            assert isinstance(last_message, ModelRequest), 'Expected last message to be a `ModelRequest`.'
+
+            # check if there are any retry prompts, if so retry them
+            new_retry_names = {p.tool_name for p in last_message.parts if isinstance(p, RetryPromptPart)}
+            if new_retry_names:
+                return ModelResponse(
+                    parts=[
+                        ToolCallPart.from_dict(name, self.gen_tool_args(args))
+                        for name, args in self.tool_calls
+                        if name in new_retry_names
+                    ]
+                )
 
         if response_text := self.result.left:
             if response_text.value is None:
                 # build up details of tool responses
                 output: dict[str, Any] = {}
                 for message in messages:
-                    if isinstance(message, …
-                        …
+                    if isinstance(message, ModelRequest):
+                        for part in message.parts:
+                            if isinstance(part, ToolReturnPart):
+                                output[part.tool_name] = part.content
                 if output:
-                    return …
+                    return ModelResponse.from_text(pydantic_core.to_json(output).decode())
                 else:
-                    return …
+                    return ModelResponse.from_text('success (no tool calls)')
             else:
-                return …
+                return ModelResponse.from_text(response_text.value)
         else:
             assert self.result_tools, 'No result tools provided'
             custom_result_args = self.result.right
             result_tool = self.result_tools[self.seed % len(self.result_tools)]
             if custom_result_args is not None:
-                return …
+                return ModelResponse(parts=[ToolCallPart.from_dict(result_tool.name, custom_result_args)])
             else:
                 response_args = self.gen_tool_args(result_tool)
-                return …
-
-
-def _get_new_messages(messages: list[Message]) -> list[Message]:
-    last_model_index = None
-    for i, m in enumerate(messages):
-        if m.role in ('model-structured-response', 'model-text-response'):
-            last_model_index = i
-
-    if last_model_index is not None:
-        return messages[last_model_index + 1 :]
-    else:
-        return []
+                return ModelResponse(parts=[ToolCallPart.from_dict(result_tool.name, response_args)])
 
 
 @dataclass
@@ -234,7 +245,7 @@ class TestStreamTextResponse(StreamTextResponse):
 class TestStreamStructuredResponse(StreamStructuredResponse):
     """A structured response that streams test data."""
 
-    _structured_response: …
+    _structured_response: ModelResponse
     _cost: Cost
     _iter: Iterator[None] = field(default_factory=lambda: iter([None]))
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
@@ -242,7 +253,7 @@ class TestStreamStructuredResponse(StreamStructuredResponse):
     async def __anext__(self) -> None:
         return _utils.sync_anext(self._iter)
 
-    def get(self, *, final: bool = False) -> …
+    def get(self, *, final: bool = False) -> ModelResponse:
         return self._structured_response
 
     def cost(self) -> Cost:
```
pydantic_ai/models/vertexai.py (CHANGED)
```diff
@@ -21,7 +21,7 @@ try:
 except ImportError as _import_error:
     raise ImportError(
         'Please install `google-auth` to use the VertexAI model, '
-        "you can use the `vertexai` optional group — `pip install 'pydantic-ai[vertexai]'`"
+        "you can use the `vertexai` optional group — `pip install 'pydantic-ai-slim[vertexai]'`"
     ) from _import_error
 
 VERTEX_AI_URL_TEMPLATE = (
@@ -114,7 +114,7 @@ class VertexAIModel(Model):
         allow_text_result: bool,
         result_tools: list[ToolDefinition],
     ) -> GeminiAgentModel:
-        url, auth = await self.…
+        url, auth = await self.ainit()
         return GeminiAgentModel(
             http_client=self.http_client,
             model_name=self.model_name,
@@ -125,7 +125,11 @@ class VertexAIModel(Model):
             result_tools=result_tools,
         )
 
-    async def …
+    async def ainit(self) -> tuple[str, BearerTokenAuth]:
+        """Initialize the model, setting the URL and auth.
+
+        This will raise an error if authentication fails.
+        """
         if self.url is not None and self.auth is not None:
             return self.url, self.auth
```
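The renamed `ainit` method makes eager initialization part of the model's surface: callers can authenticate up front and fail fast instead of deferring to the first request. A sketch (constructor arguments beyond the model name, such as the service account or project, are omitted and assumed to come from the environment):

```python
import asyncio

from pydantic_ai.models.vertexai import VertexAIModel


async def main() -> None:
    model = VertexAIModel('gemini-1.5-flash')
    # ainit resolves the request URL and bearer-token auth immediately,
    # raising here (rather than mid-run) if authentication fails.
    url, auth = await model.ainit()
    print(url)


asyncio.run(main())
```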