pydantic-ai-slim 0.0.19__py3-none-any.whl → 0.0.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic.
- pydantic_ai/_pydantic.py +1 -0
- pydantic_ai/_result.py +26 -21
- pydantic_ai/_system_prompt.py +4 -4
- pydantic_ai/agent.py +107 -87
- pydantic_ai/messages.py +3 -10
- pydantic_ai/models/__init__.py +29 -1
- pydantic_ai/models/anthropic.py +94 -30
- pydantic_ai/models/cohere.py +278 -0
- pydantic_ai/models/function.py +12 -8
- pydantic_ai/models/gemini.py +9 -9
- pydantic_ai/models/groq.py +9 -7
- pydantic_ai/models/mistral.py +12 -6
- pydantic_ai/models/ollama.py +3 -0
- pydantic_ai/models/openai.py +27 -13
- pydantic_ai/models/test.py +16 -8
- pydantic_ai/models/vertexai.py +2 -1
- pydantic_ai/result.py +45 -26
- pydantic_ai/settings.py +18 -1
- pydantic_ai/tools.py +18 -18
- {pydantic_ai_slim-0.0.19.dist-info → pydantic_ai_slim-0.0.20.dist-info}/METADATA +6 -4
- pydantic_ai_slim-0.0.20.dist-info/RECORD +30 -0
- pydantic_ai_slim-0.0.19.dist-info/RECORD +0 -29
- {pydantic_ai_slim-0.0.19.dist-info → pydantic_ai_slim-0.0.20.dist-info}/WHEEL +0 -0
pydantic_ai/models/groq.py
CHANGED
@@ -197,7 +197,7 @@ class GroqAgentModel(AgentModel):
             model=str(self.model_name),
             messages=groq_messages,
             n=1,
-            parallel_tool_calls=True if self.tools else NOT_GIVEN,
+            parallel_tool_calls=model_settings.get('parallel_tool_calls', True if self.tools else NOT_GIVEN),
             tools=self.tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
@@ -207,8 +207,7 @@ class GroqAgentModel(AgentModel):
             timeout=model_settings.get('timeout', NOT_GIVEN),
         )

-    @staticmethod
-    def _process_response(response: chat.ChatCompletion) -> ModelResponse:
+    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         choice = response.choices[0]
@@ -220,17 +219,20 @@ class GroqAgentModel(AgentModel):
             items.append(
                 ToolCallPart.from_raw_args(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id)
             )
-        return ModelResponse(items, timestamp=timestamp)
+        return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)

-    @staticmethod
-    async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
+    async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
         if isinstance(first_chunk, _utils.Unset):
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

-        return GroqStreamedResponse(
+        return GroqStreamedResponse(
+            _response=peekable_response,
+            _model_name=self.model_name,
+            _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
+        )

     @classmethod
     def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
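The `parallel_tool_calls` change above means an explicit setting now overrides the old tools-based default (`True` if any tools are registered, otherwise `NOT_GIVEN`). A minimal sketch of how a caller would exercise it, assuming `run_sync` accepts `model_settings` as in the released API; the model name and prompt are illustrative:

    from pydantic_ai import Agent
    from pydantic_ai.models.groq import GroqModel

    agent = Agent(GroqModel('llama-3.3-70b-versatile'))

    # The explicit 'parallel_tool_calls' key wins over the tools-based default.
    result = agent.run_sync(
        'What is the capital of France?',
        model_settings={'parallel_tool_calls': False},
    )
    print(result.data)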
pydantic_ai/models/mistral.py
CHANGED
@@ -36,6 +36,7 @@ from . import (
     Model,
     StreamedResponse,
     cached_async_http_client,
+    check_allow_model_requests,
 )

 try:
@@ -130,6 +131,7 @@ class MistralModel(Model):
         result_tools: list[ToolDefinition],
     ) -> AgentModel:
         """Create an agent model, this is called for each step of an agent run from Pydantic AI call."""
+        check_allow_model_requests()
         return MistralAgentModel(
             self.client,
             self.model_name,
@@ -147,7 +149,7 @@ class MistralAgentModel(AgentModel):
     """Implementation of `AgentModel` for Mistral models."""

     client: Mistral
-    model_name:
+    model_name: MistralModelName
     allow_text_result: bool
     function_tools: list[ToolDefinition]
     result_tools: list[ToolDefinition]
@@ -265,8 +267,7 @@ class MistralAgentModel(AgentModel):
         ]
         return tools if tools else None

-    @staticmethod
-    def _process_response(response: MistralChatCompletionResponse) -> ModelResponse:
+    def _process_response(self, response: MistralChatCompletionResponse) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         assert response.choices, 'Unexpected empty response choice.'

@@ -288,10 +289,10 @@ class MistralAgentModel(AgentModel):
             tool = _map_mistral_to_pydantic_tool_call(tool_call=tool_call)
             parts.append(tool)

-        return ModelResponse(parts, timestamp=timestamp)
+        return ModelResponse(parts, model_name=self.model_name, timestamp=timestamp)

-    @staticmethod
     async def _process_streamed_response(
+        self,
         result_tools: list[ToolDefinition],
         response: MistralEventStreamAsync[MistralCompletionEvent],
     ) -> StreamedResponse:
@@ -306,7 +307,12 @@ class MistralAgentModel(AgentModel):
         else:
             timestamp = datetime.now(tz=timezone.utc)

-        return MistralStreamedResponse(
+        return MistralStreamedResponse(
+            _response=peekable_response,
+            _model_name=self.model_name,
+            _timestamp=timestamp,
+            _result_tools={c.name: c for c in result_tools},
+        )

     @staticmethod
     def _map_to_mistral_tool_call(t: ToolCallPart) -> MistralToolCall:
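`check_allow_model_requests()` is the guard controlled by `pydantic_ai.models.ALLOW_MODEL_REQUESTS`; with this release Mistral (and Ollama and Vertex AI below) honour it like the other providers. A hedged sketch of how a test suite might rely on it:

    from pydantic_ai import models

    # Disallow real model requests, e.g. at the top of a test module. Building a
    # MistralAgentModel (or Ollama / Vertex AI agent model) for a live run now raises
    # instead of silently calling the provider's API.
    models.ALLOW_MODEL_REQUESTS = False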
pydantic_ai/models/ollama.py
CHANGED
@@ -10,6 +10,7 @@ from . import (
     AgentModel,
     Model,
     cached_async_http_client,
+    check_allow_model_requests,
 )

 try:
@@ -25,6 +26,7 @@ from .openai import OpenAIModel

 CommonOllamaModelNames = Literal[
     'codellama',
+    'deepseek-r1',
     'gemma',
     'gemma2',
     'llama3',
@@ -110,6 +112,7 @@ class OllamaModel(Model):
         allow_text_result: bool,
         result_tools: list[ToolDefinition],
     ) -> AgentModel:
+        check_allow_model_requests()
         return await self.openai_model.agent_model(
             function_tools=function_tools,
             allow_text_result=allow_text_result,
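With `'deepseek-r1'` added to `CommonOllamaModelNames`, the name is accepted by type checkers, while `OllamaModel` still takes arbitrary strings for models outside the literal. Illustrative usage, assuming a local Ollama server with the model pulled:

    from pydantic_ai import Agent
    from pydantic_ai.models.ollama import OllamaModel

    agent = Agent(OllamaModel('deepseek-r1'))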
pydantic_ai/models/openai.py
CHANGED
@@ -51,6 +51,8 @@ Using this more broad type for the model name instead of the ChatModel definition
 allows this model to be used more easily with other model types (ie, Ollama)
 """

+OpenAISystemPromptRole = Literal['system', 'developer', 'user']
+

 @dataclass(init=False)
 class OpenAIModel(Model):
@@ -63,6 +65,7 @@ class OpenAIModel(Model):

     model_name: OpenAIModelName
     client: AsyncOpenAI = field(repr=False)
+    system_prompt_role: OpenAISystemPromptRole | None = field(default=None)

     def __init__(
         self,
@@ -72,6 +75,7 @@ class OpenAIModel(Model):
         api_key: str | None = None,
         openai_client: AsyncOpenAI | None = None,
         http_client: AsyncHTTPClient | None = None,
+        system_prompt_role: OpenAISystemPromptRole | None = None,
     ):
         """Initialize an OpenAI model.

@@ -87,6 +91,8 @@ class OpenAIModel(Model):
                 [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
                 client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`.
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
+            system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
+                In the future, this may be inferred from the model name.
         """
         self.model_name: OpenAIModelName = model_name
         if openai_client is not None:
@@ -98,6 +104,7 @@ class OpenAIModel(Model):
             self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
         else:
             self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
+        self.system_prompt_role = system_prompt_role

     async def agent_model(
         self,
@@ -115,6 +122,7 @@ class OpenAIModel(Model):
             self.model_name,
             allow_text_result,
             tools,
+            self.system_prompt_role,
         )

     def name(self) -> str:
@@ -140,6 +148,7 @@ class OpenAIAgentModel(AgentModel):
     model_name: OpenAIModelName
     allow_text_result: bool
     tools: list[chat.ChatCompletionToolParam]
+    system_prompt_role: OpenAISystemPromptRole | None

     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
@@ -186,7 +195,7 @@ class OpenAIAgentModel(AgentModel):
             model=self.model_name,
             messages=openai_messages,
             n=1,
-            parallel_tool_calls=True if self.tools else NOT_GIVEN,
+            parallel_tool_calls=model_settings.get('parallel_tool_calls', True if self.tools else NOT_GIVEN),
             tools=self.tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
@@ -197,8 +206,7 @@ class OpenAIAgentModel(AgentModel):
             timeout=model_settings.get('timeout', NOT_GIVEN),
         )

-    @staticmethod
-    def _process_response(response: chat.ChatCompletion) -> ModelResponse:
+    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         choice = response.choices[0]
@@ -208,23 +216,25 @@ class OpenAIAgentModel(AgentModel):
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id))
-        return ModelResponse(items, timestamp=timestamp)
+        return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)

-    @staticmethod
-    async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
+    async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
         if isinstance(first_chunk, _utils.Unset):
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

-        return OpenAIStreamedResponse(
+        return OpenAIStreamedResponse(
+            _model_name=self.model_name,
+            _response=peekable_response,
+            _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
+        )

-    @classmethod
-    def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
+    def _map_message(self, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
         """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
         if isinstance(message, ModelRequest):
-            yield from cls._map_user_message(message)
+            yield from self._map_user_message(message)
         elif isinstance(message, ModelResponse):
             texts: list[str] = []
             tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
@@ -246,11 +256,15 @@ class OpenAIAgentModel(AgentModel):
         else:
             assert_never(message)

-    @classmethod
-    def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
+    def _map_user_message(self, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
         for part in message.parts:
             if isinstance(part, SystemPromptPart):
-                yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
+                if self.system_prompt_role == 'developer':
+                    yield chat.ChatCompletionDeveloperMessageParam(role='developer', content=part.content)
+                elif self.system_prompt_role == 'user':
+                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
+                else:
+                    yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
             elif isinstance(part, UserPromptPart):
                 yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
             elif isinstance(part, ToolReturnPart):
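A short sketch of the new `system_prompt_role` option; the model name is illustrative, and omitting the option keeps the previous behaviour of sending system prompts with the `'system'` role:

    from pydantic_ai import Agent
    from pydantic_ai.models.openai import OpenAIModel

    # Some reasoning models reject the 'system' role, so the system prompt can be
    # sent with the 'developer' or 'user' role instead.
    model = OpenAIModel('o1-mini', system_prompt_role='user')
    agent = Agent(model, system_prompt='Answer in one short sentence.')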
pydantic_ai/models/test.py
CHANGED
@@ -129,6 +129,7 @@ class TestAgentModel(AgentModel):
     result: _utils.Either[str | None, Any | None]
     result_tools: list[ToolDefinition]
     seed: int
+    model_name: str = 'test'

     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
@@ -142,7 +143,7 @@ class TestAgentModel(AgentModel):
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
         model_response = self._request(messages, model_settings)
-        yield TestStreamedResponse(model_response, messages)
+        yield TestStreamedResponse(_model_name=self.model_name, _structured_response=model_response, _messages=messages)

     def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
         return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
@@ -151,7 +152,8 @@ class TestAgentModel(AgentModel):
         # if there are tools, the first thing we want to do is call all of them
         if self.tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
             return ModelResponse(
-                parts=[ToolCallPart.from_raw_args(name, self.gen_tool_args(args)) for name, args in self.tool_calls]
+                parts=[ToolCallPart.from_raw_args(name, self.gen_tool_args(args)) for name, args in self.tool_calls],
+                model_name=self.model_name,
             )

         if messages:
@@ -177,7 +179,7 @@ class TestAgentModel(AgentModel):
                     if tool.name in new_retry_names
                 ]
             )
-            return ModelResponse(parts=retry_parts)
+            return ModelResponse(parts=retry_parts, model_name=self.model_name)

         if response_text := self.result.left:
             if response_text.value is None:
@@ -189,20 +191,26 @@ class TestAgentModel(AgentModel):
                        if isinstance(part, ToolReturnPart):
                            output[part.tool_name] = part.content
                if output:
-                    return ModelResponse(parts=[TextPart(pydantic_core.to_json(output).decode())])
+                    return ModelResponse(
+                        parts=[TextPart(pydantic_core.to_json(output).decode())], model_name=self.model_name
+                    )
                else:
-                    return ModelResponse(parts=[TextPart('success (no tool calls)')])
+                    return ModelResponse(parts=[TextPart('success (no tool calls)')], model_name=self.model_name)
            else:
-                return ModelResponse(parts=[TextPart(response_text.value)])
+                return ModelResponse(parts=[TextPart(response_text.value)], model_name=self.model_name)
        else:
            assert self.result_tools, 'No result tools provided'
            custom_result_args = self.result.right
            result_tool = self.result_tools[self.seed % len(self.result_tools)]
            if custom_result_args is not None:
-                return ModelResponse(
+                return ModelResponse(
+                    parts=[ToolCallPart.from_raw_args(result_tool.name, custom_result_args)], model_name=self.model_name
+                )
            else:
                response_args = self.gen_tool_args(result_tool)
-                return ModelResponse(
+                return ModelResponse(
+                    parts=[ToolCallPart.from_raw_args(result_tool.name, response_args)], model_name=self.model_name
+                )


 @dataclass
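Because `TestAgentModel` now stamps `model_name='test'` on every `ModelResponse`, messages captured from a `TestModel` run carry a model name. A minimal sketch (the prompt is illustrative):

    from pydantic_ai import Agent
    from pydantic_ai.models.test import TestModel

    agent = Agent(TestModel())
    result = agent.run_sync('anything')
    for message in result.all_messages():
        print(type(message).__name__, getattr(message, 'model_name', None))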
pydantic_ai/models/vertexai.py
CHANGED
@@ -10,7 +10,7 @@ from httpx import AsyncClient as AsyncHTTPClient
 from .._utils import run_in_executor
 from ..exceptions import UserError
 from ..tools import ToolDefinition
-from . import Model, cached_async_http_client
+from . import Model, cached_async_http_client, check_allow_model_requests
 from .gemini import GeminiAgentModel, GeminiModelName

 try:
@@ -114,6 +114,7 @@ class VertexAIModel(Model):
         allow_text_result: bool,
         result_tools: list[ToolDefinition],
     ) -> GeminiAgentModel:
+        check_allow_model_requests()
         url, auth = await self.ainit()
         return GeminiAgentModel(
             http_client=self.http_client,
pydantic_ai/result.py
CHANGED
@@ -11,35 +11,49 @@ import logfire_api
 from typing_extensions import TypeVar

 from . import _result, _utils, exceptions, messages as _messages, models
-from .tools import AgentDeps, RunContext
+from .tools import AgentDepsT, RunContext
 from .usage import Usage, UsageLimits

-__all__ = '
+__all__ = 'ResultDataT', 'ResultDataT_inv', 'ResultValidatorFunc', 'RunResult', 'StreamedRunResult'


-
-"""
+T = TypeVar('T')
+"""An invariant TypeVar."""
+ResultDataT_inv = TypeVar('ResultDataT_inv', default=str)
+"""
+An invariant type variable for the result data of a model.
+
+We need to use an invariant typevar for `ResultValidator` and `ResultValidatorFunc` because the result data type is used
+in both the input and output of a `ResultValidatorFunc`. This can theoretically lead to some issues assuming that types
+possessing ResultValidator's are covariant in the result data type, but in practice this is rarely an issue, and
+changing it would have negative consequences for the ergonomics of the library.
+
+At some point, it may make sense to change the input to ResultValidatorFunc to be `Any` or `object` as doing that would
+resolve these potential variance issues.
+"""
+ResultDataT = TypeVar('ResultDataT', default=str, covariant=True)
+"""Covariant type variable for the result data type of a run."""

 ResultValidatorFunc = Union[
-    Callable[[RunContext[AgentDeps], ResultData], ResultData],
-    Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]],
-    Callable[[ResultData], ResultData],
-    Callable[[ResultData], Awaitable[ResultData]],
+    Callable[[RunContext[AgentDepsT], ResultDataT_inv], ResultDataT_inv],
+    Callable[[RunContext[AgentDepsT], ResultDataT_inv], Awaitable[ResultDataT_inv]],
+    Callable[[ResultDataT_inv], ResultDataT_inv],
+    Callable[[ResultDataT_inv], Awaitable[ResultDataT_inv]],
 ]
 """
-A function that always takes
+A function that always takes and returns the same type of data (which is the result type of an agent run), and:

 * may or may not take [`RunContext`][pydantic_ai.tools.RunContext] as a first argument
 * may or may not be async

-Usage `ResultValidatorFunc[AgentDeps,
+Usage `ResultValidatorFunc[AgentDeps, T]`.
 """

 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')


 @dataclass
-class _BaseRunResult(ABC, Generic[ResultData]):
+class _BaseRunResult(ABC, Generic[ResultDataT]):
     """Base type for results.

     You should not import or use this type directly, instead use its subclasses `RunResult` and `StreamedRunResult`.
@@ -119,10 +133,10 @@ class _BaseRunResult(ABC, Generic[ResultData]):


 @dataclass
-class RunResult(_BaseRunResult[ResultData]):
+class RunResult(_BaseRunResult[ResultDataT]):
     """Result of a non-streamed run."""

-    data: ResultData
+    data: ResultDataT
     """Data from the final response in the run."""
     _result_tool_name: str | None
     _usage: Usage
@@ -165,14 +179,14 @@ class RunResult(_BaseRunResult[ResultData]):


 @dataclass
-class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
+class StreamedRunResult(_BaseRunResult[ResultDataT], Generic[AgentDepsT, ResultDataT]):
     """Result of a streamed run that returns structured data via a tool call."""

     _usage_limits: UsageLimits | None
     _stream_response: models.StreamedResponse
-    _result_schema: _result.ResultSchema[ResultData] | None
-    _run_ctx: RunContext[AgentDeps]
-    _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]]
+    _result_schema: _result.ResultSchema[ResultDataT] | None
+    _run_ctx: RunContext[AgentDepsT]
+    _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]]
     _result_tool_name: str | None
     _on_complete: Callable[[], Awaitable[None]]
     is_complete: bool = field(default=False, init=False)
@@ -185,7 +199,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
    [`get_data`][pydantic_ai.result.StreamedRunResult.get_data] completes.
    """

-    async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[ResultData]:
+    async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[ResultDataT]:
        """Stream the response as an async iterable.

        The pydantic validator for structured data will be called in
@@ -269,7 +283,12 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
                yield combined_validated_text

            lf_span.set_attribute('combined_text', combined_validated_text)
-            await self._marked_completed(
+            await self._marked_completed(
+                _messages.ModelResponse(
+                    parts=[_messages.TextPart(combined_validated_text)],
+                    model_name=self._stream_response.model_name(),
+                )
+            )

    async def stream_structured(
        self, *, debounce_by: float | None = 0.1
@@ -306,7 +325,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
            lf_span.set_attribute('structured_response', msg)
            await self._marked_completed(msg)

-    async def get_data(self) -> ResultData:
+    async def get_data(self) -> ResultDataT:
        """Stream the whole response, validate and return it."""
        usage_checking_stream = _get_usage_checking_stream_response(
            self._stream_response, self._usage_limits, self.usage
@@ -332,7 +351,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):

    async def validate_structured_result(
        self, message: _messages.ModelResponse, *, allow_partial: bool = False
-    ) -> ResultData:
+    ) -> ResultDataT:
        """Validate a structured result message."""
        if self._result_schema is not None and self._result_tool_name is not None:
            match = self._result_schema.find_named_tool(message.parts, self._result_tool_name)
@@ -351,17 +370,17 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
        text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
        for validator in self._result_validators:
            text = await validator.validate(
-                text,
+                text,
                None,
                self._run_ctx,
            )
-        # Since there is no result tool, we can assume that str is compatible with ResultData
-        return cast(ResultData, text)
+        # Since there is no result tool, we can assume that str is compatible with ResultDataT
+        return cast(ResultDataT, text)

    async def _validate_text_result(self, text: str) -> str:
        for validator in self._result_validators:
-            text = await validator.validate(
-                text,
+            text = await validator.validate(
+                text,
                None,
                self._run_ctx,
            )
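The four `ResultValidatorFunc` shapes above allow validators with or without `RunContext`, sync or async. A hedged example of the decorator-based usage; the model string and validation rule are illustrative:

    from pydantic_ai import Agent, ModelRetry, RunContext

    agent = Agent('openai:gpt-4o', result_type=str)

    @agent.result_validator
    async def non_empty(ctx: RunContext[None], data: str) -> str:
        # Takes RunContext as a first argument and is async; both are optional per the Union above.
        if not data.strip():
            raise ModelRetry('Empty result, please retry.')
        return data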
pydantic_ai/settings.py
CHANGED
@@ -12,7 +12,8 @@ if TYPE_CHECKING:
 class ModelSettings(TypedDict, total=False):
     """Settings to configure an LLM.

-    Here we include only settings which apply to multiple models / model providers
+    Here we include only settings which apply to multiple models / model providers,
+    though not all of these settings are supported by all models.
     """

     max_tokens: int
@@ -24,6 +25,8 @@ class ModelSettings(TypedDict, total=False):
     * Anthropic
     * OpenAI
     * Groq
+    * Cohere
+    * Mistral
     """

     temperature: float
@@ -40,6 +43,8 @@ class ModelSettings(TypedDict, total=False):
     * Anthropic
     * OpenAI
     * Groq
+    * Cohere
+    * Mistral
     """

     top_p: float
@@ -55,6 +60,8 @@ class ModelSettings(TypedDict, total=False):
     * Anthropic
     * OpenAI
     * Groq
+    * Cohere
+    * Mistral
     """

     timeout: float | Timeout
@@ -66,6 +73,16 @@ class ModelSettings(TypedDict, total=False):
     * Anthropic
     * OpenAI
     * Groq
+    * Mistral
+    """
+
+    parallel_tool_calls: bool
+    """Whether to allow parallel tool calls.
+
+    Supported by:
+    * OpenAI
+    * Groq
+    * Anthropic
     """

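From the user side, `parallel_tool_calls` is just another optional key in the `ModelSettings` TypedDict. A sketch, assuming `Agent` accepts `model_settings` at construction time as it does at run time; the model and values are illustrative, and keys a given provider does not read are ignored by that provider:

    from pydantic_ai import Agent

    agent = Agent(
        'openai:gpt-4o',
        model_settings={'max_tokens': 512, 'temperature': 0.0, 'parallel_tool_calls': False},
    )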