pydantic-ai-slim 0.0.19__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_parts_manager.py +1 -1
- pydantic_ai/_pydantic.py +1 -0
- pydantic_ai/_result.py +29 -28
- pydantic_ai/_system_prompt.py +4 -4
- pydantic_ai/_utils.py +1 -56
- pydantic_ai/agent.py +137 -113
- pydantic_ai/messages.py +24 -56
- pydantic_ai/models/__init__.py +122 -51
- pydantic_ai/models/anthropic.py +109 -38
- pydantic_ai/models/cohere.py +290 -0
- pydantic_ai/models/function.py +12 -8
- pydantic_ai/models/gemini.py +29 -15
- pydantic_ai/models/groq.py +27 -23
- pydantic_ai/models/mistral.py +34 -29
- pydantic_ai/models/openai.py +45 -23
- pydantic_ai/models/test.py +47 -24
- pydantic_ai/models/vertexai.py +2 -1
- pydantic_ai/result.py +45 -26
- pydantic_ai/settings.py +58 -1
- pydantic_ai/tools.py +29 -26
- {pydantic_ai_slim-0.0.19.dist-info → pydantic_ai_slim-0.0.21.dist-info}/METADATA +6 -4
- pydantic_ai_slim-0.0.21.dist-info/RECORD +29 -0
- pydantic_ai/models/ollama.py +0 -120
- pydantic_ai_slim-0.0.19.dist-info/RECORD +0 -29
- {pydantic_ai_slim-0.0.19.dist-info → pydantic_ai_slim-0.0.21.dist-info}/WHEEL +0 -0
pydantic_ai/models/mistral.py
CHANGED
|
@@ -6,7 +6,7 @@ from contextlib import asynccontextmanager
|
|
|
6
6
|
from dataclasses import dataclass, field
|
|
7
7
|
from datetime import datetime, timezone
|
|
8
8
|
from itertools import chain
|
|
9
|
-
from typing import Any, Callable, Literal, Union
|
|
9
|
+
from typing import Any, Callable, Literal, Union, cast
|
|
10
10
|
|
|
11
11
|
import pydantic_core
|
|
12
12
|
from httpx import AsyncClient as AsyncHTTPClient, Timeout
|
|
@@ -15,7 +15,6 @@ from typing_extensions import assert_never
|
|
|
15
15
|
from .. import UnexpectedModelBehavior, _utils
|
|
16
16
|
from .._utils import now_utc as _now_utc
|
|
17
17
|
from ..messages import (
|
|
18
|
-
ArgsJson,
|
|
19
18
|
ModelMessage,
|
|
20
19
|
ModelRequest,
|
|
21
20
|
ModelResponse,
|
|
@@ -36,6 +35,7 @@ from . import (
|
|
|
36
35
|
Model,
|
|
37
36
|
StreamedResponse,
|
|
38
37
|
cached_async_http_client,
|
|
38
|
+
check_allow_model_requests,
|
|
39
39
|
)
|
|
40
40
|
|
|
41
41
|
try:
|
|
@@ -84,6 +84,12 @@ Since [the Mistral docs](https://docs.mistral.ai/getting-started/models/models_o
|
|
|
84
84
|
"""
|
|
85
85
|
|
|
86
86
|
|
|
87
|
+
class MistralModelSettings(ModelSettings):
|
|
88
|
+
"""Settings used for a Mistral model request."""
|
|
89
|
+
|
|
90
|
+
# This class is a placeholder for any future mistral-specific settings
|
|
91
|
+
|
|
92
|
+
|
|
87
93
|
@dataclass(init=False)
|
|
88
94
|
class MistralModel(Model):
|
|
89
95
|
"""A model that uses Mistral.
|
|
@@ -130,6 +136,7 @@ class MistralModel(Model):
|
|
|
130
136
|
result_tools: list[ToolDefinition],
|
|
131
137
|
) -> AgentModel:
|
|
132
138
|
"""Create an agent model, this is called for each step of an agent run from Pydantic AI call."""
|
|
139
|
+
check_allow_model_requests()
|
|
133
140
|
return MistralAgentModel(
|
|
134
141
|
self.client,
|
|
135
142
|
self.model_name,
|
|
@@ -147,7 +154,7 @@ class MistralAgentModel(AgentModel):
|
|
|
147
154
|
"""Implementation of `AgentModel` for Mistral models."""
|
|
148
155
|
|
|
149
156
|
client: Mistral
|
|
150
|
-
model_name:
|
|
157
|
+
model_name: MistralModelName
|
|
151
158
|
allow_text_result: bool
|
|
152
159
|
function_tools: list[ToolDefinition]
|
|
153
160
|
result_tools: list[ToolDefinition]
|
|
@@ -157,7 +164,7 @@ class MistralAgentModel(AgentModel):
|
|
|
157
164
|
self, messages: list[ModelMessage], model_settings: ModelSettings | None
|
|
158
165
|
) -> tuple[ModelResponse, Usage]:
|
|
159
166
|
"""Make a non-streaming request to the model from Pydantic AI call."""
|
|
160
|
-
response = await self._completions_create(messages, model_settings)
|
|
167
|
+
response = await self._completions_create(messages, cast(MistralModelSettings, model_settings or {}))
|
|
161
168
|
return self._process_response(response), _map_usage(response)
|
|
162
169
|
|
|
163
170
|
@asynccontextmanager
|
|
@@ -165,15 +172,14 @@ class MistralAgentModel(AgentModel):
|
|
|
165
172
|
self, messages: list[ModelMessage], model_settings: ModelSettings | None
|
|
166
173
|
) -> AsyncIterator[StreamedResponse]:
|
|
167
174
|
"""Make a streaming request to the model from Pydantic AI call."""
|
|
168
|
-
response = await self._stream_completions_create(messages, model_settings)
|
|
175
|
+
response = await self._stream_completions_create(messages, cast(MistralModelSettings, model_settings or {}))
|
|
169
176
|
async with response:
|
|
170
177
|
yield await self._process_streamed_response(self.result_tools, response)
|
|
171
178
|
|
|
172
179
|
async def _completions_create(
|
|
173
|
-
self, messages: list[ModelMessage], model_settings:
|
|
180
|
+
self, messages: list[ModelMessage], model_settings: MistralModelSettings
|
|
174
181
|
) -> MistralChatCompletionResponse:
|
|
175
182
|
"""Make a non-streaming request to the model."""
|
|
176
|
-
model_settings = model_settings or {}
|
|
177
183
|
response = await self.client.chat.complete_async(
|
|
178
184
|
model=str(self.model_name),
|
|
179
185
|
messages=list(chain(*(self._map_message(m) for m in messages))),
|
|
@@ -185,6 +191,7 @@ class MistralAgentModel(AgentModel):
|
|
|
185
191
|
temperature=model_settings.get('temperature', UNSET),
|
|
186
192
|
top_p=model_settings.get('top_p', 1),
|
|
187
193
|
timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
|
|
194
|
+
random_seed=model_settings.get('seed', UNSET),
|
|
188
195
|
)
|
|
189
196
|
assert response, 'A unexpected empty response from Mistral.'
|
|
190
197
|
return response
|
|
@@ -192,12 +199,11 @@ class MistralAgentModel(AgentModel):
|
|
|
192
199
|
async def _stream_completions_create(
|
|
193
200
|
self,
|
|
194
201
|
messages: list[ModelMessage],
|
|
195
|
-
model_settings:
|
|
202
|
+
model_settings: MistralModelSettings,
|
|
196
203
|
) -> MistralEventStreamAsync[MistralCompletionEvent]:
|
|
197
204
|
"""Create a streaming completion request to the Mistral model."""
|
|
198
205
|
response: MistralEventStreamAsync[MistralCompletionEvent] | None
|
|
199
206
|
mistral_messages = list(chain(*(self._map_message(m) for m in messages)))
|
|
200
|
-
model_settings = model_settings or {}
|
|
201
207
|
|
|
202
208
|
if self.result_tools and self.function_tools or self.function_tools:
|
|
203
209
|
# Function Calling
|
|
@@ -211,6 +217,8 @@ class MistralAgentModel(AgentModel):
|
|
|
211
217
|
top_p=model_settings.get('top_p', 1),
|
|
212
218
|
max_tokens=model_settings.get('max_tokens', UNSET),
|
|
213
219
|
timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
|
|
220
|
+
presence_penalty=model_settings.get('presence_penalty'),
|
|
221
|
+
frequency_penalty=model_settings.get('frequency_penalty'),
|
|
214
222
|
)
|
|
215
223
|
|
|
216
224
|
elif self.result_tools:
|
|
@@ -265,8 +273,7 @@ class MistralAgentModel(AgentModel):
|
|
|
265
273
|
]
|
|
266
274
|
return tools if tools else None
|
|
267
275
|
|
|
268
|
-
|
|
269
|
-
def _process_response(response: MistralChatCompletionResponse) -> ModelResponse:
|
|
276
|
+
def _process_response(self, response: MistralChatCompletionResponse) -> ModelResponse:
|
|
270
277
|
"""Process a non-streamed response, and prepare a message to return."""
|
|
271
278
|
assert response.choices, 'Unexpected empty response choice.'
|
|
272
279
|
|
|
@@ -288,10 +295,10 @@ class MistralAgentModel(AgentModel):
|
|
|
288
295
|
tool = _map_mistral_to_pydantic_tool_call(tool_call=tool_call)
|
|
289
296
|
parts.append(tool)
|
|
290
297
|
|
|
291
|
-
return ModelResponse(parts, timestamp=timestamp)
|
|
298
|
+
return ModelResponse(parts, model_name=self.model_name, timestamp=timestamp)
|
|
292
299
|
|
|
293
|
-
@staticmethod
|
|
294
300
|
async def _process_streamed_response(
|
|
301
|
+
self,
|
|
295
302
|
result_tools: list[ToolDefinition],
|
|
296
303
|
response: MistralEventStreamAsync[MistralCompletionEvent],
|
|
297
304
|
) -> StreamedResponse:
|
|
@@ -306,23 +313,21 @@ class MistralAgentModel(AgentModel):
|
|
|
306
313
|
else:
|
|
307
314
|
timestamp = datetime.now(tz=timezone.utc)
|
|
308
315
|
|
|
309
|
-
return MistralStreamedResponse(
|
|
316
|
+
return MistralStreamedResponse(
|
|
317
|
+
_response=peekable_response,
|
|
318
|
+
_model_name=self.model_name,
|
|
319
|
+
_timestamp=timestamp,
|
|
320
|
+
_result_tools={c.name: c for c in result_tools},
|
|
321
|
+
)
|
|
310
322
|
|
|
311
323
|
@staticmethod
|
|
312
324
|
def _map_to_mistral_tool_call(t: ToolCallPart) -> MistralToolCall:
|
|
313
325
|
"""Maps a pydantic-ai ToolCall to a MistralToolCall."""
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
)
|
|
320
|
-
else:
|
|
321
|
-
return MistralToolCall(
|
|
322
|
-
id=t.tool_call_id,
|
|
323
|
-
type='function',
|
|
324
|
-
function=MistralFunctionCall(name=t.tool_name, arguments=t.args.args_dict),
|
|
325
|
-
)
|
|
326
|
+
return MistralToolCall(
|
|
327
|
+
id=t.tool_call_id,
|
|
328
|
+
type='function',
|
|
329
|
+
function=MistralFunctionCall(name=t.tool_name, arguments=t.args),
|
|
330
|
+
)
|
|
326
331
|
|
|
327
332
|
def _generate_user_output_format(self, schemas: list[dict[str, Any]]) -> MistralUserMessage:
|
|
328
333
|
"""Get a message with an example of the expected output format."""
|
|
@@ -505,7 +510,7 @@ class MistralStreamedResponse(StreamedResponse):
|
|
|
505
510
|
continue
|
|
506
511
|
|
|
507
512
|
# The following part_id will be thrown away
|
|
508
|
-
return ToolCallPart
|
|
513
|
+
return ToolCallPart(tool_name=result_tool.name, args=output_json)
|
|
509
514
|
|
|
510
515
|
@staticmethod
|
|
511
516
|
def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool:
|
|
@@ -563,7 +568,7 @@ def _map_mistral_to_pydantic_tool_call(tool_call: MistralToolCall) -> ToolCallPa
|
|
|
563
568
|
tool_call_id = tool_call.id or None
|
|
564
569
|
func_call = tool_call.function
|
|
565
570
|
|
|
566
|
-
return ToolCallPart
|
|
571
|
+
return ToolCallPart(func_call.name, func_call.arguments, tool_call_id)
|
|
567
572
|
|
|
568
573
|
|
|
569
574
|
def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> Usage:
|
|
@@ -594,7 +599,7 @@ def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None
|
|
|
594
599
|
elif isinstance(content, str):
|
|
595
600
|
result = content
|
|
596
601
|
|
|
597
|
-
# Note: Check len to handle potential mismatch between function calls and responses from the API. (`msg: not the same number of function class and
|
|
602
|
+
# Note: Check len to handle potential mismatch between function calls and responses from the API. (`msg: not the same number of function class and responses`)
|
|
598
603
|
if result and len(result) == 0:
|
|
599
604
|
result = None
|
|
600
605
|
|
pydantic_ai/models/openai.py
CHANGED
|
@@ -5,7 +5,7 @@ from contextlib import asynccontextmanager
|
|
|
5
5
|
from dataclasses import dataclass, field
|
|
6
6
|
from datetime import datetime, timezone
|
|
7
7
|
from itertools import chain
|
|
8
|
-
from typing import Literal, Union, overload
|
|
8
|
+
from typing import Literal, Union, cast, overload
|
|
9
9
|
|
|
10
10
|
from httpx import AsyncClient as AsyncHTTPClient
|
|
11
11
|
from typing_extensions import assert_never
|
|
@@ -48,9 +48,17 @@ except ImportError as _import_error:
|
|
|
48
48
|
OpenAIModelName = Union[ChatModel, str]
|
|
49
49
|
"""
|
|
50
50
|
Using this more broad type for the model name instead of the ChatModel definition
|
|
51
|
-
allows this model to be used more easily with other model types (ie, Ollama)
|
|
51
|
+
allows this model to be used more easily with other model types (ie, Ollama, Deepseek)
|
|
52
52
|
"""
|
|
53
53
|
|
|
54
|
+
OpenAISystemPromptRole = Literal['system', 'developer', 'user']
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class OpenAIModelSettings(ModelSettings):
|
|
58
|
+
"""Settings used for an OpenAI model request."""
|
|
59
|
+
|
|
60
|
+
# This class is a placeholder for any future openai-specific settings
|
|
61
|
+
|
|
54
62
|
|
|
55
63
|
@dataclass(init=False)
|
|
56
64
|
class OpenAIModel(Model):
|
|
@@ -63,6 +71,7 @@ class OpenAIModel(Model):
|
|
|
63
71
|
|
|
64
72
|
model_name: OpenAIModelName
|
|
65
73
|
client: AsyncOpenAI = field(repr=False)
|
|
74
|
+
system_prompt_role: OpenAISystemPromptRole | None = field(default=None)
|
|
66
75
|
|
|
67
76
|
def __init__(
|
|
68
77
|
self,
|
|
@@ -72,6 +81,7 @@ class OpenAIModel(Model):
|
|
|
72
81
|
api_key: str | None = None,
|
|
73
82
|
openai_client: AsyncOpenAI | None = None,
|
|
74
83
|
http_client: AsyncHTTPClient | None = None,
|
|
84
|
+
system_prompt_role: OpenAISystemPromptRole | None = None,
|
|
75
85
|
):
|
|
76
86
|
"""Initialize an OpenAI model.
|
|
77
87
|
|
|
@@ -87,6 +97,8 @@ class OpenAIModel(Model):
|
|
|
87
97
|
[`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
|
|
88
98
|
client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`.
|
|
89
99
|
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
|
|
100
|
+
system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
|
|
101
|
+
In the future, this may be inferred from the model name.
|
|
90
102
|
"""
|
|
91
103
|
self.model_name: OpenAIModelName = model_name
|
|
92
104
|
if openai_client is not None:
|
|
@@ -98,6 +110,7 @@ class OpenAIModel(Model):
|
|
|
98
110
|
self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
|
|
99
111
|
else:
|
|
100
112
|
self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
|
|
113
|
+
self.system_prompt_role = system_prompt_role
|
|
101
114
|
|
|
102
115
|
async def agent_model(
|
|
103
116
|
self,
|
|
@@ -115,6 +128,7 @@ class OpenAIModel(Model):
|
|
|
115
128
|
self.model_name,
|
|
116
129
|
allow_text_result,
|
|
117
130
|
tools,
|
|
131
|
+
self.system_prompt_role,
|
|
118
132
|
)
|
|
119
133
|
|
|
120
134
|
def name(self) -> str:
|
|
@@ -140,35 +154,36 @@ class OpenAIAgentModel(AgentModel):
|
|
|
140
154
|
model_name: OpenAIModelName
|
|
141
155
|
allow_text_result: bool
|
|
142
156
|
tools: list[chat.ChatCompletionToolParam]
|
|
157
|
+
system_prompt_role: OpenAISystemPromptRole | None
|
|
143
158
|
|
|
144
159
|
async def request(
|
|
145
160
|
self, messages: list[ModelMessage], model_settings: ModelSettings | None
|
|
146
161
|
) -> tuple[ModelResponse, usage.Usage]:
|
|
147
|
-
response = await self._completions_create(messages, False, model_settings)
|
|
162
|
+
response = await self._completions_create(messages, False, cast(OpenAIModelSettings, model_settings or {}))
|
|
148
163
|
return self._process_response(response), _map_usage(response)
|
|
149
164
|
|
|
150
165
|
@asynccontextmanager
|
|
151
166
|
async def request_stream(
|
|
152
167
|
self, messages: list[ModelMessage], model_settings: ModelSettings | None
|
|
153
168
|
) -> AsyncIterator[StreamedResponse]:
|
|
154
|
-
response = await self._completions_create(messages, True, model_settings)
|
|
169
|
+
response = await self._completions_create(messages, True, cast(OpenAIModelSettings, model_settings or {}))
|
|
155
170
|
async with response:
|
|
156
171
|
yield await self._process_streamed_response(response)
|
|
157
172
|
|
|
158
173
|
@overload
|
|
159
174
|
async def _completions_create(
|
|
160
|
-
self, messages: list[ModelMessage], stream: Literal[True], model_settings:
|
|
175
|
+
self, messages: list[ModelMessage], stream: Literal[True], model_settings: OpenAIModelSettings
|
|
161
176
|
) -> AsyncStream[ChatCompletionChunk]:
|
|
162
177
|
pass
|
|
163
178
|
|
|
164
179
|
@overload
|
|
165
180
|
async def _completions_create(
|
|
166
|
-
self, messages: list[ModelMessage], stream: Literal[False], model_settings:
|
|
181
|
+
self, messages: list[ModelMessage], stream: Literal[False], model_settings: OpenAIModelSettings
|
|
167
182
|
) -> chat.ChatCompletion:
|
|
168
183
|
pass
|
|
169
184
|
|
|
170
185
|
async def _completions_create(
|
|
171
|
-
self, messages: list[ModelMessage], stream: bool, model_settings:
|
|
186
|
+
self, messages: list[ModelMessage], stream: bool, model_settings: OpenAIModelSettings
|
|
172
187
|
) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
|
|
173
188
|
# standalone function to make it easier to override
|
|
174
189
|
if not self.tools:
|
|
@@ -180,13 +195,11 @@ class OpenAIAgentModel(AgentModel):
|
|
|
180
195
|
|
|
181
196
|
openai_messages = list(chain(*(self._map_message(m) for m in messages)))
|
|
182
197
|
|
|
183
|
-
model_settings = model_settings or {}
|
|
184
|
-
|
|
185
198
|
return await self.client.chat.completions.create(
|
|
186
199
|
model=self.model_name,
|
|
187
200
|
messages=openai_messages,
|
|
188
201
|
n=1,
|
|
189
|
-
parallel_tool_calls=
|
|
202
|
+
parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
|
|
190
203
|
tools=self.tools or NOT_GIVEN,
|
|
191
204
|
tool_choice=tool_choice or NOT_GIVEN,
|
|
192
205
|
stream=stream,
|
|
@@ -195,10 +208,13 @@ class OpenAIAgentModel(AgentModel):
|
|
|
195
208
|
temperature=model_settings.get('temperature', NOT_GIVEN),
|
|
196
209
|
top_p=model_settings.get('top_p', NOT_GIVEN),
|
|
197
210
|
timeout=model_settings.get('timeout', NOT_GIVEN),
|
|
211
|
+
seed=model_settings.get('seed', NOT_GIVEN),
|
|
212
|
+
presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
|
|
213
|
+
frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
|
|
214
|
+
logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
|
|
198
215
|
)
|
|
199
216
|
|
|
200
|
-
|
|
201
|
-
def _process_response(response: chat.ChatCompletion) -> ModelResponse:
|
|
217
|
+
def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
|
|
202
218
|
"""Process a non-streamed response, and prepare a message to return."""
|
|
203
219
|
timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
|
|
204
220
|
choice = response.choices[0]
|
|
@@ -207,24 +223,26 @@ class OpenAIAgentModel(AgentModel):
|
|
|
207
223
|
items.append(TextPart(choice.message.content))
|
|
208
224
|
if choice.message.tool_calls is not None:
|
|
209
225
|
for c in choice.message.tool_calls:
|
|
210
|
-
items.append(ToolCallPart
|
|
211
|
-
return ModelResponse(items, timestamp=timestamp)
|
|
226
|
+
items.append(ToolCallPart(c.function.name, c.function.arguments, c.id))
|
|
227
|
+
return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)
|
|
212
228
|
|
|
213
|
-
|
|
214
|
-
async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
|
|
229
|
+
async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
|
|
215
230
|
"""Process a streamed response, and prepare a streaming response to return."""
|
|
216
231
|
peekable_response = _utils.PeekableAsyncStream(response)
|
|
217
232
|
first_chunk = await peekable_response.peek()
|
|
218
233
|
if isinstance(first_chunk, _utils.Unset):
|
|
219
234
|
raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
|
|
220
235
|
|
|
221
|
-
return OpenAIStreamedResponse(
|
|
236
|
+
return OpenAIStreamedResponse(
|
|
237
|
+
_model_name=self.model_name,
|
|
238
|
+
_response=peekable_response,
|
|
239
|
+
_timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
|
|
240
|
+
)
|
|
222
241
|
|
|
223
|
-
|
|
224
|
-
def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
|
|
242
|
+
def _map_message(self, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
|
|
225
243
|
"""Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
|
|
226
244
|
if isinstance(message, ModelRequest):
|
|
227
|
-
yield from
|
|
245
|
+
yield from self._map_user_message(message)
|
|
228
246
|
elif isinstance(message, ModelResponse):
|
|
229
247
|
texts: list[str] = []
|
|
230
248
|
tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
|
|
@@ -246,11 +264,15 @@ class OpenAIAgentModel(AgentModel):
|
|
|
246
264
|
else:
|
|
247
265
|
assert_never(message)
|
|
248
266
|
|
|
249
|
-
|
|
250
|
-
def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
|
|
267
|
+
def _map_user_message(self, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
|
|
251
268
|
for part in message.parts:
|
|
252
269
|
if isinstance(part, SystemPromptPart):
|
|
253
|
-
|
|
270
|
+
if self.system_prompt_role == 'developer':
|
|
271
|
+
yield chat.ChatCompletionDeveloperMessageParam(role='developer', content=part.content)
|
|
272
|
+
elif self.system_prompt_role == 'user':
|
|
273
|
+
yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
|
|
274
|
+
else:
|
|
275
|
+
yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
|
|
254
276
|
elif isinstance(part, UserPromptPart):
|
|
255
277
|
yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
|
|
256
278
|
elif isinstance(part, ToolReturnPart):
|
pydantic_ai/models/test.py
CHANGED
|
@@ -12,7 +12,6 @@ import pydantic_core
|
|
|
12
12
|
|
|
13
13
|
from .. import _utils
|
|
14
14
|
from ..messages import (
|
|
15
|
-
ArgsJson,
|
|
16
15
|
ModelMessage,
|
|
17
16
|
ModelRequest,
|
|
18
17
|
ModelResponse,
|
|
@@ -34,6 +33,20 @@ from . import (
|
|
|
34
33
|
from .function import _estimate_string_tokens, _estimate_usage # pyright: ignore[reportPrivateUsage]
|
|
35
34
|
|
|
36
35
|
|
|
36
|
+
@dataclass
|
|
37
|
+
class _TextResult:
|
|
38
|
+
"""A private wrapper class to tag a result that came from the custom_result_text field."""
|
|
39
|
+
|
|
40
|
+
value: str | None
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass
|
|
44
|
+
class _FunctionToolResult:
|
|
45
|
+
"""A wrapper class to tag a result that came from the custom_result_args field."""
|
|
46
|
+
|
|
47
|
+
value: Any | None
|
|
48
|
+
|
|
49
|
+
|
|
37
50
|
@dataclass
|
|
38
51
|
class TestModel(Model):
|
|
39
52
|
"""A model specifically for testing purposes.
|
|
@@ -53,7 +66,7 @@ class TestModel(Model):
|
|
|
53
66
|
call_tools: list[str] | Literal['all'] = 'all'
|
|
54
67
|
"""List of tools to call. If `'all'`, all tools will be called."""
|
|
55
68
|
custom_result_text: str | None = None
|
|
56
|
-
"""If set, this text is
|
|
69
|
+
"""If set, this text is returned as the final result."""
|
|
57
70
|
custom_result_args: Any | None = None
|
|
58
71
|
"""If set, these args will be passed to the result tool."""
|
|
59
72
|
seed: int = 0
|
|
@@ -95,21 +108,21 @@ class TestModel(Model):
|
|
|
95
108
|
if self.custom_result_text is not None:
|
|
96
109
|
assert allow_text_result, 'Plain response not allowed, but `custom_result_text` is set.'
|
|
97
110
|
assert self.custom_result_args is None, 'Cannot set both `custom_result_text` and `custom_result_args`.'
|
|
98
|
-
result:
|
|
111
|
+
result: _TextResult | _FunctionToolResult = _TextResult(self.custom_result_text)
|
|
99
112
|
elif self.custom_result_args is not None:
|
|
100
113
|
assert result_tools is not None, 'No result tools provided, but `custom_result_args` is set.'
|
|
101
114
|
result_tool = result_tools[0]
|
|
102
115
|
|
|
103
116
|
if k := result_tool.outer_typed_dict_key:
|
|
104
|
-
result =
|
|
117
|
+
result = _FunctionToolResult({k: self.custom_result_args})
|
|
105
118
|
else:
|
|
106
|
-
result =
|
|
119
|
+
result = _FunctionToolResult(self.custom_result_args)
|
|
107
120
|
elif allow_text_result:
|
|
108
|
-
result =
|
|
121
|
+
result = _TextResult(None)
|
|
109
122
|
elif result_tools:
|
|
110
|
-
result =
|
|
123
|
+
result = _FunctionToolResult(None)
|
|
111
124
|
else:
|
|
112
|
-
result =
|
|
125
|
+
result = _TextResult(None)
|
|
113
126
|
|
|
114
127
|
return TestAgentModel(tool_calls, result, result_tools, self.seed)
|
|
115
128
|
|
|
@@ -126,9 +139,10 @@ class TestAgentModel(AgentModel):
|
|
|
126
139
|
|
|
127
140
|
tool_calls: list[tuple[str, ToolDefinition]]
|
|
128
141
|
# left means the text is plain text; right means it's a function call
|
|
129
|
-
result:
|
|
142
|
+
result: _TextResult | _FunctionToolResult
|
|
130
143
|
result_tools: list[ToolDefinition]
|
|
131
144
|
seed: int
|
|
145
|
+
model_name: str = 'test'
|
|
132
146
|
|
|
133
147
|
async def request(
|
|
134
148
|
self, messages: list[ModelMessage], model_settings: ModelSettings | None
|
|
@@ -142,7 +156,7 @@ class TestAgentModel(AgentModel):
|
|
|
142
156
|
self, messages: list[ModelMessage], model_settings: ModelSettings | None
|
|
143
157
|
) -> AsyncIterator[StreamedResponse]:
|
|
144
158
|
model_response = self._request(messages, model_settings)
|
|
145
|
-
yield TestStreamedResponse(model_response, messages)
|
|
159
|
+
yield TestStreamedResponse(_model_name=self.model_name, _structured_response=model_response, _messages=messages)
|
|
146
160
|
|
|
147
161
|
def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
|
|
148
162
|
return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
|
|
@@ -151,7 +165,8 @@ class TestAgentModel(AgentModel):
|
|
|
151
165
|
# if there are tools, the first thing we want to do is call all of them
|
|
152
166
|
if self.tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
|
|
153
167
|
return ModelResponse(
|
|
154
|
-
parts=[ToolCallPart
|
|
168
|
+
parts=[ToolCallPart(name, self.gen_tool_args(args)) for name, args in self.tool_calls],
|
|
169
|
+
model_name=self.model_name,
|
|
155
170
|
)
|
|
156
171
|
|
|
157
172
|
if messages:
|
|
@@ -164,7 +179,7 @@ class TestAgentModel(AgentModel):
|
|
|
164
179
|
# Handle retries for both function tools and result tools
|
|
165
180
|
# Check function tools first
|
|
166
181
|
retry_parts: list[ModelResponsePart] = [
|
|
167
|
-
ToolCallPart
|
|
182
|
+
ToolCallPart(name, self.gen_tool_args(args))
|
|
168
183
|
for name, args in self.tool_calls
|
|
169
184
|
if name in new_retry_names
|
|
170
185
|
]
|
|
@@ -172,15 +187,20 @@ class TestAgentModel(AgentModel):
|
|
|
172
187
|
if self.result_tools:
|
|
173
188
|
retry_parts.extend(
|
|
174
189
|
[
|
|
175
|
-
ToolCallPart
|
|
190
|
+
ToolCallPart(
|
|
191
|
+
tool.name,
|
|
192
|
+
self.result.value
|
|
193
|
+
if isinstance(self.result, _FunctionToolResult) and self.result.value is not None
|
|
194
|
+
else self.gen_tool_args(tool),
|
|
195
|
+
)
|
|
176
196
|
for tool in self.result_tools
|
|
177
197
|
if tool.name in new_retry_names
|
|
178
198
|
]
|
|
179
199
|
)
|
|
180
|
-
return ModelResponse(parts=retry_parts)
|
|
200
|
+
return ModelResponse(parts=retry_parts, model_name=self.model_name)
|
|
181
201
|
|
|
182
|
-
if
|
|
183
|
-
if response_text.value is None:
|
|
202
|
+
if isinstance(self.result, _TextResult):
|
|
203
|
+
if (response_text := self.result.value) is None:
|
|
184
204
|
# build up details of tool responses
|
|
185
205
|
output: dict[str, Any] = {}
|
|
186
206
|
for message in messages:
|
|
@@ -189,20 +209,24 @@ class TestAgentModel(AgentModel):
|
|
|
189
209
|
if isinstance(part, ToolReturnPart):
|
|
190
210
|
output[part.tool_name] = part.content
|
|
191
211
|
if output:
|
|
192
|
-
return ModelResponse
|
|
212
|
+
return ModelResponse(
|
|
213
|
+
parts=[TextPart(pydantic_core.to_json(output).decode())], model_name=self.model_name
|
|
214
|
+
)
|
|
193
215
|
else:
|
|
194
|
-
return ModelResponse
|
|
216
|
+
return ModelResponse(parts=[TextPart('success (no tool calls)')], model_name=self.model_name)
|
|
195
217
|
else:
|
|
196
|
-
return ModelResponse
|
|
218
|
+
return ModelResponse(parts=[TextPart(response_text)], model_name=self.model_name)
|
|
197
219
|
else:
|
|
198
220
|
assert self.result_tools, 'No result tools provided'
|
|
199
|
-
custom_result_args = self.result.
|
|
221
|
+
custom_result_args = self.result.value
|
|
200
222
|
result_tool = self.result_tools[self.seed % len(self.result_tools)]
|
|
201
223
|
if custom_result_args is not None:
|
|
202
|
-
return ModelResponse(
|
|
224
|
+
return ModelResponse(
|
|
225
|
+
parts=[ToolCallPart(result_tool.name, custom_result_args)], model_name=self.model_name
|
|
226
|
+
)
|
|
203
227
|
else:
|
|
204
228
|
response_args = self.gen_tool_args(result_tool)
|
|
205
|
-
return ModelResponse(parts=[ToolCallPart
|
|
229
|
+
return ModelResponse(parts=[ToolCallPart(result_tool.name, response_args)], model_name=self.model_name)
|
|
206
230
|
|
|
207
231
|
|
|
208
232
|
@dataclass
|
|
@@ -233,9 +257,8 @@ class TestStreamedResponse(StreamedResponse):
|
|
|
233
257
|
self._usage += _get_string_usage(word)
|
|
234
258
|
yield self._parts_manager.handle_text_delta(vendor_part_id=i, content=word)
|
|
235
259
|
else:
|
|
236
|
-
args = part.args.args_json if isinstance(part.args, ArgsJson) else part.args.args_dict
|
|
237
260
|
yield self._parts_manager.handle_tool_call_part(
|
|
238
|
-
vendor_part_id=i, tool_name=part.tool_name, args=args, tool_call_id=part.tool_call_id
|
|
261
|
+
vendor_part_id=i, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id
|
|
239
262
|
)
|
|
240
263
|
|
|
241
264
|
def timestamp(self) -> datetime:
|
pydantic_ai/models/vertexai.py
CHANGED
|
@@ -10,7 +10,7 @@ from httpx import AsyncClient as AsyncHTTPClient
|
|
|
10
10
|
from .._utils import run_in_executor
|
|
11
11
|
from ..exceptions import UserError
|
|
12
12
|
from ..tools import ToolDefinition
|
|
13
|
-
from . import Model, cached_async_http_client
|
|
13
|
+
from . import Model, cached_async_http_client, check_allow_model_requests
|
|
14
14
|
from .gemini import GeminiAgentModel, GeminiModelName
|
|
15
15
|
|
|
16
16
|
try:
|
|
@@ -114,6 +114,7 @@ class VertexAIModel(Model):
|
|
|
114
114
|
allow_text_result: bool,
|
|
115
115
|
result_tools: list[ToolDefinition],
|
|
116
116
|
) -> GeminiAgentModel:
|
|
117
|
+
check_allow_model_requests()
|
|
117
118
|
url, auth = await self.ainit()
|
|
118
119
|
return GeminiAgentModel(
|
|
119
120
|
http_client=self.http_client,
|