pydantic-ai-slim 0.0.6a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/__init__.py +8 -0
- pydantic_ai/_griffe.py +128 -0
- pydantic_ai/_pydantic.py +216 -0
- pydantic_ai/_result.py +258 -0
- pydantic_ai/_retriever.py +114 -0
- pydantic_ai/_system_prompt.py +33 -0
- pydantic_ai/_utils.py +247 -0
- pydantic_ai/agent.py +795 -0
- pydantic_ai/dependencies.py +83 -0
- pydantic_ai/exceptions.py +56 -0
- pydantic_ai/messages.py +205 -0
- pydantic_ai/models/__init__.py +300 -0
- pydantic_ai/models/function.py +268 -0
- pydantic_ai/models/gemini.py +720 -0
- pydantic_ai/models/groq.py +400 -0
- pydantic_ai/models/openai.py +379 -0
- pydantic_ai/models/test.py +389 -0
- pydantic_ai/models/vertexai.py +306 -0
- pydantic_ai/py.typed +0 -0
- pydantic_ai/result.py +314 -0
- pydantic_ai_slim-0.0.6a1.dist-info/METADATA +49 -0
- pydantic_ai_slim-0.0.6a1.dist-info/RECORD +23 -0
- pydantic_ai_slim-0.0.6a1.dist-info/WHEEL +4 -0
pydantic_ai/agent.py
ADDED
|
@@ -0,0 +1,795 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
|
|
5
|
+
from contextlib import asynccontextmanager, contextmanager
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from typing import Any, Callable, Generic, cast, final, overload
|
|
8
|
+
|
|
9
|
+
import logfire_api
|
|
10
|
+
from typing_extensions import assert_never
|
|
11
|
+
|
|
12
|
+
from . import (
|
|
13
|
+
_result,
|
|
14
|
+
_retriever as _r,
|
|
15
|
+
_system_prompt,
|
|
16
|
+
_utils,
|
|
17
|
+
exceptions,
|
|
18
|
+
messages as _messages,
|
|
19
|
+
models,
|
|
20
|
+
result,
|
|
21
|
+
)
|
|
22
|
+
from .dependencies import AgentDeps, CallContext, RetrieverContextFunc, RetrieverParams, RetrieverPlainFunc
|
|
23
|
+
from .result import ResultData
|
|
24
|
+
|
|
25
|
+
# Public API of this module: only `Agent` is exported.
__all__ = ('Agent',)

# Logfire handle (via `logfire_api`) tagged with the 'pydantic-ai' OTel scope; used to open spans in this module.
_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')

# The runtime type of `None`; used as the default value of `Agent.__init__`'s `deps_type` parameter.
NoneType = type(None)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@final
|
|
33
|
+
@dataclass(init=False)
|
|
34
|
+
class Agent(Generic[AgentDeps, ResultData]):
|
|
35
|
+
"""Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.
|
|
36
|
+
|
|
37
|
+
Agents are generic in the dependency type they take [`AgentDeps`][pydantic_ai.dependencies.AgentDeps]
|
|
38
|
+
and the result data type they return, [`ResultData`][pydantic_ai.result.ResultData].
|
|
39
|
+
|
|
40
|
+
By default, if neither generic parameter is customised, agents have type `Agent[None, str]`.
|
|
41
|
+
|
|
42
|
+
Minimal usage example:
|
|
43
|
+
|
|
44
|
+
```py
|
|
45
|
+
from pydantic_ai import Agent
|
|
46
|
+
|
|
47
|
+
agent = Agent('openai:gpt-4o')
|
|
48
|
+
result = agent.run_sync('What is the capital of France?')
|
|
49
|
+
print(result.data)
|
|
50
|
+
#> Paris
|
|
51
|
+
```
|
|
52
|
+
"""
|
|
53
|
+
|
|
54
|
+
# dataclass fields mostly for my sanity — knowing what attributes are available
|
|
55
|
+
model: models.Model | models.KnownModelName | None
|
|
56
|
+
"""The default model configured for this agent."""
|
|
57
|
+
_result_schema: _result.ResultSchema[ResultData] | None = field(repr=False)
|
|
58
|
+
_result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
|
|
59
|
+
_allow_text_result: bool = field(repr=False)
|
|
60
|
+
_system_prompts: tuple[str, ...] = field(repr=False)
|
|
61
|
+
_retrievers: dict[str, _r.Retriever[AgentDeps, Any]] = field(repr=False)
|
|
62
|
+
_default_retries: int = field(repr=False)
|
|
63
|
+
_system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False)
|
|
64
|
+
_deps_type: type[AgentDeps] = field(repr=False)
|
|
65
|
+
_max_result_retries: int = field(repr=False)
|
|
66
|
+
_current_result_retry: int = field(repr=False)
|
|
67
|
+
_override_deps: _utils.Option[AgentDeps] = field(default=None, repr=False)
|
|
68
|
+
_override_model: _utils.Option[models.Model] = field(default=None, repr=False)
|
|
69
|
+
last_run_messages: list[_messages.Message] | None = None
|
|
70
|
+
"""The messages from the last run, useful when a run raised an exception.
|
|
71
|
+
|
|
72
|
+
Note: these are not used by the agent, e.g. in future runs, they are just stored for developers' convenience.
|
|
73
|
+
"""
|
|
74
|
+
|
|
75
|
+
def __init__(
    self,
    model: models.Model | models.KnownModelName | None = None,
    result_type: type[ResultData] = str,
    *,
    system_prompt: str | Sequence[str] = (),
    deps_type: type[AgentDeps] = NoneType,
    retries: int = 1,
    result_tool_name: str = 'final_result',
    result_tool_description: str | None = None,
    result_retries: int | None = None,
    defer_model_check: bool = False,
):
    """Create an agent.

    Args:
        model: The default model to use for this agent, if not provided,
            you must provide the model when calling the agent.
        result_type: The type of the result data, used to validate the result data, defaults to `str`.
        system_prompt: Static system prompts to use for this agent, you can also register system
            prompts via a function with [`system_prompt`][pydantic_ai.Agent.system_prompt].
        deps_type: The type used for dependency injection, this parameter exists solely to allow you to fully
            parameterize the agent, and therefore get the best out of static type checking.
            If you're not using deps, but want type checking to pass, you can set `deps=None` to satisfy Pyright
            or add a type hint `: Agent[None, <return type>]`.
        retries: The default number of retries to allow before raising an error.
        result_tool_name: The name of the tool to use for the final result.
        result_tool_description: The description of the final result tool.
        result_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
        defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model,
            it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately,
            which checks for the necessary environment variables. Set this to `True`
            to defer the evaluation until the first run. Useful if you want to
            [override the model][pydantic_ai.Agent.override_model] for testing.
    """
    # Store the model untouched when there is nothing to infer or when inference is explicitly deferred;
    # otherwise resolve it to a `Model` instance right away.
    if model is None or defer_model_check:
        self.model = model
    else:
        self.model = models.infer_model(model)

    self._result_schema = _result.ResultSchema[result_type].build(
        result_type, result_tool_name, result_tool_description
    )
    # if the result tool is None, or its schema allows `str`, we allow plain text results
    self._allow_text_result = self._result_schema is None or self._result_schema.allow_text_result

    # A single string is normalised to a one-element tuple so prompts are always stored as a tuple.
    self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
    self._retrievers: dict[str, _r.Retriever[AgentDeps, Any]] = {}
    self._deps_type = deps_type
    self._default_retries = retries
    self._system_prompt_functions = []
    # Result validation gets its own retry budget, falling back to the general `retries` value.
    self._max_result_retries = result_retries if result_retries is not None else retries
    self._current_result_retry = 0
    self._result_validators = []
|
|
129
|
+
|
|
130
|
+
async def run(
    self,
    user_prompt: str,
    *,
    message_history: list[_messages.Message] | None = None,
    model: models.Model | models.KnownModelName | None = None,
    deps: AgentDeps = None,
) -> result.RunResult[ResultData]:
    """Run the agent with a user prompt in async mode.

    The method loops: it sends the accumulated messages to the model, appends the model's response,
    and hands the response to `_handle_model_response`. When that returns a `_MarkFinalResult` the run
    ends; otherwise the returned messages are appended and the model is called again.

    Args:
        user_prompt: User input to start/continue the conversation.
        message_history: History of the conversation so far.
        model: Optional model to use for this run, required if `model` was not set when creating the agent.
        deps: Optional dependencies to use for this run.

    Returns:
        The result of the run.
    """
    model_used, custom_model, agent_model = await self._get_agent_model(model)

    deps = self._get_deps(deps)

    new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
    # keep a reference on the agent so the messages can be inspected even if the run raises
    self.last_run_messages = messages

    # NOTE(review): retrievers are reset at the start of every run, but `_current_result_retry` is only
    # set to 0 in `__init__` — confirm whether the result retry counter should be reset here as well.
    for retriever in self._retrievers.values():
        retriever.reset()

    # accumulates the cost reported for each model request in this run
    cost = result.Cost()

    with _logfire.span(
        'agent run {prompt=}',
        prompt=user_prompt,
        agent=self,
        custom_model=custom_model,
        model_name=model_used.name(),
    ) as run_span:
        run_step = 0
        # loop until `_handle_model_response` reports a final result (the only exit is the `return` below,
        # or an exception propagating out)
        while True:
            run_step += 1
            with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
                model_response, request_cost = await agent_model.request(messages)
                model_req_span.set_attribute('response', model_response)
                model_req_span.set_attribute('cost', request_cost)
                model_req_span.message = f'model request -> {model_response.role}'

            messages.append(model_response)
            cost += request_cost

            with _logfire.span('handle model response') as handle_span:
                either = await self._handle_model_response(model_response, deps)

                if isinstance(either, _MarkFinalResult):
                    # we have a final result, end the conversation
                    result_data = either.data
                    run_span.set_attribute('all_messages', messages)
                    run_span.set_attribute('cost', cost)
                    handle_span.set_attribute('result', result_data)
                    handle_span.message = 'handle model response -> final result'
                    return result.RunResult(messages, new_message_index, result_data, cost)
                else:
                    # continue the conversation
                    tool_responses = either
                    handle_span.set_attribute('tool_responses', tool_responses)
                    response_msgs = ' '.join(m.role for m in tool_responses)
                    handle_span.message = f'handle model response -> {response_msgs}'
                    messages.extend(tool_responses)
|
|
198
|
+
|
|
199
|
+
def run_sync(
    self,
    user_prompt: str,
    *,
    message_history: list[_messages.Message] | None = None,
    model: models.Model | models.KnownModelName | None = None,
    deps: AgentDeps = None,
) -> result.RunResult[ResultData]:
    """Blocking variant of `run`: execute the agent for one user prompt and wait for the result.

    The coroutine returned by `self.run` is driven to completion with `asyncio.run()`, so all
    arguments are forwarded to `run` unchanged.

    Args:
        user_prompt: User input to start/continue the conversation.
        message_history: History of the conversation so far.
        model: Optional model to use for this run, required if `model` was not set when creating the agent.
        deps: Optional dependencies to use for this run.

    Returns:
        The result of the run.
    """
    run_coroutine = self.run(
        user_prompt,
        message_history=message_history,
        model=model,
        deps=deps,
    )
    return asyncio.run(run_coroutine)
|
|
221
|
+
|
|
222
|
+
@asynccontextmanager
async def run_stream(
    self,
    user_prompt: str,
    *,
    message_history: list[_messages.Message] | None = None,
    model: models.Model | models.KnownModelName | None = None,
    deps: AgentDeps = None,
) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
    """Run the agent with a user prompt in async mode, returning a streamed response.

    This is an async context manager: it requests streamed responses from the model in a loop and
    yields a `StreamedRunResult` once `_handle_streamed_model_response` marks a response as final.
    Responses that are not final produce messages that are appended to the conversation before the
    model is called again.

    Args:
        user_prompt: User input to start/continue the conversation.
        message_history: History of the conversation so far.
        model: Optional model to use for this run, required if `model` was not set when creating the agent.
        deps: Optional dependencies to use for this run.

    Returns:
        The result of the run.
    """
    model_used, custom_model, agent_model = await self._get_agent_model(model)

    deps = self._get_deps(deps)

    new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
    # keep a reference on the agent so the messages can be inspected even if the run raises
    self.last_run_messages = messages

    for retriever in self._retrievers.values():
        retriever.reset()

    # accumulates the cost of model responses that were consumed before the final one
    cost = result.Cost()

    with _logfire.span(
        'agent run stream {prompt=}',
        prompt=user_prompt,
        agent=self,
        custom_model=custom_model,
        model_name=model_used.name(),
    ) as run_span:
        run_step = 0
        # loop until a final result is yielded (the only exit is the `return` after the `yield`,
        # or an exception propagating out)
        while True:
            run_step += 1
            with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
                async with agent_model.request_stream(messages) as model_response:
                    model_req_span.set_attribute('response_type', model_response.__class__.__name__)
                    # We want to end the "model request" span here, but we can't exit the context manager
                    # in the traditional way
                    model_req_span.__exit__(None, None, None)

                    with _logfire.span('handle model response') as handle_span:
                        either = await self._handle_streamed_model_response(model_response, deps)

                        if isinstance(either, _MarkFinalResult):
                            result_stream = either.data
                            run_span.set_attribute('all_messages', messages)
                            handle_span.set_attribute('result_type', result_stream.__class__.__name__)
                            handle_span.message = 'handle model response -> final result'
                            # the caller consumes the stream while the surrounding context managers
                            # (spans and the model's request stream) are still open
                            yield result.StreamedRunResult(
                                messages,
                                new_message_index,
                                cost,
                                result_stream,
                                self._result_schema,
                                deps,
                                self._result_validators,
                            )
                            return
                        else:
                            tool_responses = either
                            handle_span.set_attribute('tool_responses', tool_responses)
                            response_msgs = ' '.join(m.role for m in tool_responses)
                            handle_span.message = f'handle model response -> {response_msgs}'
                            messages.extend(tool_responses)
                            # the model_response should have been fully streamed by now, we can add its cost
                            cost += model_response.cost()
|
|
297
|
+
|
|
298
|
+
@contextmanager
def override_deps(self, overriding_deps: AgentDeps) -> Iterator[None]:
    """Temporarily replace the agent's dependencies for the duration of a `with` block.

    Mainly intended for tests. The override that was active before entering the block is put back
    on exit, including when the block raises.

    Args:
        overriding_deps: The dependencies to use instead of the dependencies passed to the agent run.
    """
    replacement = _utils.Some(overriding_deps)
    previous = self._override_deps
    self._override_deps = replacement
    try:
        yield
    finally:
        # restore whatever was set before, so nested overrides unwind correctly
        self._override_deps = previous
|
|
311
|
+
|
|
312
|
+
@contextmanager
def override_model(self, overriding_model: models.Model | models.KnownModelName) -> Iterator[None]:
    """Temporarily replace the model used by the agent for the duration of a `with` block.

    The given model is passed through `models.infer_model` before being installed. The override that
    was active before entering the block is put back on exit, including when the block raises.

    Args:
        overriding_model: The model to use instead of the model passed to the agent run.
    """
    replacement = _utils.Some(models.infer_model(overriding_model))
    previous = self._override_model
    self._override_model = replacement
    try:
        yield
    finally:
        # restore whatever was set before, so nested overrides unwind correctly
        self._override_model = previous
|
|
325
|
+
|
|
326
|
+
@overload
def system_prompt(
    self, func: Callable[[CallContext[AgentDeps]], str], /
) -> Callable[[CallContext[AgentDeps]], str]: ...

@overload
def system_prompt(
    self, func: Callable[[CallContext[AgentDeps]], Awaitable[str]], /
) -> Callable[[CallContext[AgentDeps]], Awaitable[str]]: ...

@overload
def system_prompt(self, func: Callable[[], str], /) -> Callable[[], str]: ...

@overload
def system_prompt(self, func: Callable[[], Awaitable[str]], /) -> Callable[[], Awaitable[str]]: ...

def system_prompt(
    self, func: _system_prompt.SystemPromptFunc[AgentDeps], /
) -> _system_prompt.SystemPromptFunc[AgentDeps]:
    """Decorator that registers a function producing a system prompt.

    The function may optionally take [`CallContext`][pydantic_ai.dependencies.CallContext] as its only
    argument, and may be either sync or async. The function itself is returned unchanged.

    An overload exists for each supported signature so that the decorator does not obscure the type of
    the decorated function, see `tests/typed_agent.py` for tests.

    Example:
    ```py
    from pydantic_ai import Agent, CallContext

    agent = Agent('test', deps_type=str)

    @agent.system_prompt
    def simple_system_prompt() -> str:
        return 'foobar'

    @agent.system_prompt
    async def async_system_prompt(ctx: CallContext[str]) -> str:
        return f'{ctx.deps} is the best'

    result = agent.run_sync('foobar', deps='spam')
    print(result.data)
    #> success (no retriever calls)
    ```
    """
    runner = _system_prompt.SystemPromptRunner(func)
    self._system_prompt_functions.append(runner)
    return func
|
|
374
|
+
|
|
375
|
+
@overload
def result_validator(
    self, func: Callable[[CallContext[AgentDeps], ResultData], ResultData], /
) -> Callable[[CallContext[AgentDeps], ResultData], ResultData]: ...

@overload
def result_validator(
    self, func: Callable[[CallContext[AgentDeps], ResultData], Awaitable[ResultData]], /
) -> Callable[[CallContext[AgentDeps], ResultData], Awaitable[ResultData]]: ...

@overload
def result_validator(self, func: Callable[[ResultData], ResultData], /) -> Callable[[ResultData], ResultData]: ...

@overload
def result_validator(
    self, func: Callable[[ResultData], Awaitable[ResultData]], /
) -> Callable[[ResultData], Awaitable[ResultData]]: ...

def result_validator(
    self, func: _result.ResultValidatorFunc[AgentDeps, ResultData], /
) -> _result.ResultValidatorFunc[AgentDeps, ResultData]:
    """Decorator that registers a function used to validate result data.

    The function may optionally take [`CallContext`][pydantic_ai.dependencies.CallContext] as its first
    argument, and may be either sync or async. The function itself is returned unchanged.

    An overload exists for each supported signature so that the decorator does not obscure the type of
    the decorated function, see `tests/typed_agent.py` for tests.

    Example:
    ```py
    from pydantic_ai import Agent, CallContext, ModelRetry

    agent = Agent('test', deps_type=str)

    @agent.result_validator
    def result_validator_simple(data: str) -> str:
        if 'wrong' in data:
            raise ModelRetry('wrong response')
        return data

    @agent.result_validator
    async def result_validator_deps(ctx: CallContext[str], data: str) -> str:
        if ctx.deps in data:
            raise ModelRetry('wrong response')
        return data

    result = agent.run_sync('foobar', deps='spam')
    print(result.data)
    #> success (no retriever calls)
    ```
    """
    validator = _result.ResultValidator(func)
    self._result_validators.append(validator)
    return func
|
|
429
|
+
|
|
430
|
+
@overload
def retriever(
    self, func: RetrieverContextFunc[AgentDeps, RetrieverParams], /
) -> RetrieverContextFunc[AgentDeps, RetrieverParams]: ...

@overload
def retriever(
    self, /, *, retries: int | None = None
) -> Callable[
    [RetrieverContextFunc[AgentDeps, RetrieverParams]], RetrieverContextFunc[AgentDeps, RetrieverParams]
]: ...

def retriever(
    self,
    func: RetrieverContextFunc[AgentDeps, RetrieverParams] | None = None,
    /,
    *,
    retries: int | None = None,
) -> Any:
    """Decorator to register a retriever function that receives a `CallContext`.

    The decorated function takes [`CallContext`][pydantic_ai.dependencies.CallContext] as its first
    argument and may be sync or async. It can be applied bare (`@agent.retriever`) or called with
    options (`@agent.retriever(retries=2)`); in both cases the function is returned unchanged.

    The docstring is inspected to extract both the tool description and description of each parameter,
    [learn more](../agents.md#retrievers-tools-and-schema).

    We can't add overloads for every possible signature of retriever, since the return type is a recursive union
    so the signature of functions decorated with `@agent.retriever` is obscured.

    Example:
    ```py
    from pydantic_ai import Agent, CallContext

    agent = Agent('test', deps_type=int)

    @agent.retriever
    def foobar(ctx: CallContext[int], x: int) -> int:
        return ctx.deps + x

    @agent.retriever(retries=2)
    async def spam(ctx: CallContext[str], y: float) -> float:
        return ctx.deps + y

    result = agent.run_sync('foobar', deps=1)
    print(result.data)
    #> {"foobar":1,"spam":1.0}
    ```

    Args:
        func: The retriever function to register.
        retries: The number of retries to allow for this retriever, defaults to the agent's default retries,
            which defaults to 1.
    """

    def retriever_decorator(
        func_: RetrieverContextFunc[AgentDeps, RetrieverParams],
    ) -> RetrieverContextFunc[AgentDeps, RetrieverParams]:
        # context-taking functions go in the `left` slot of the `Either`
        # noinspection PyTypeChecker
        self._register_retriever(_utils.Either(left=func_), retries)
        return func_

    # called with options only: hand back the decorator; called bare: apply it immediately
    if func is None:
        return retriever_decorator
    return retriever_decorator(func)
|
|
498
|
+
|
|
499
|
+
@overload
def retriever_plain(self, func: RetrieverPlainFunc[RetrieverParams], /) -> RetrieverPlainFunc[RetrieverParams]: ...

@overload
def retriever_plain(
    self, /, *, retries: int | None = None
) -> Callable[[RetrieverPlainFunc[RetrieverParams]], RetrieverPlainFunc[RetrieverParams]]: ...

def retriever_plain(
    self, func: RetrieverPlainFunc[RetrieverParams] | None = None, /, *, retries: int | None = None
) -> Any:
    """Decorator to register a retriever function which DOES NOT take `CallContext` as an argument.

    Can decorate sync or async functions. It can be applied bare (`@agent.retriever_plain`) or called
    with options (`@agent.retriever_plain(retries=2)`); in both cases the function is returned unchanged.

    The docstring is inspected to extract both the tool description and description of each parameter,
    [learn more](../agents.md#retrievers-tools-and-schema).

    We can't add overloads for every possible signature of retriever, since the return type is a recursive union
    so the signature of functions decorated with `@agent.retriever_plain` is obscured.

    Example:
    ```py
    from pydantic_ai import Agent

    agent = Agent('test')

    @agent.retriever_plain
    def foobar() -> int:
        return 123

    @agent.retriever_plain(retries=2)
    async def spam() -> float:
        return 3.14

    result = agent.run_sync('foobar')
    print(result.data)
    #> {"foobar":123,"spam":3.14}
    ```

    Args:
        func: The retriever function to register.
        retries: The number of retries to allow for this retriever, defaults to the agent's default retries,
            which defaults to 1.
    """

    def retriever_decorator(
        func_: RetrieverPlainFunc[RetrieverParams],
    ) -> RetrieverPlainFunc[RetrieverParams]:
        # plain (context-free) functions go in the `right` slot of the `Either`
        # noinspection PyTypeChecker
        self._register_retriever(_utils.Either(right=func_), retries)
        return func_

    # called with options only: hand back the decorator; called bare: apply it immediately
    if func is None:
        return retriever_decorator
    return retriever_decorator(func)
|
|
557
|
+
|
|
558
|
+
def _register_retriever(
    self, func: _r.RetrieverEitherFunc[AgentDeps, RetrieverParams], retries: int | None
) -> None:
    """Wrap `func` in a `Retriever` and store it under its name.

    Args:
        func: The retriever function, wrapped in an `Either`.
        retries: Retry count for this retriever; `None` means use the agent's default retries.

    Raises:
        ValueError: If the retriever's name is already used by a result schema tool or by a
            previously registered retriever.
    """
    effective_retries = self._default_retries if retries is None else retries
    new_retriever = _r.Retriever[AgentDeps, RetrieverParams](func, effective_retries)
    name = new_retriever.name

    schema = self._result_schema
    if schema and name in schema.tools:
        raise ValueError(f'Retriever name conflicts with result schema name: {name!r}')

    if name in self._retrievers:
        raise ValueError(f'Retriever name conflicts with existing retriever: {name!r}')

    self._retrievers[name] = new_retriever
|
|
572
|
+
|
|
573
|
+
async def _get_agent_model(
    self, model: models.Model | models.KnownModelName | None
) -> tuple[models.Model, models.Model | None, models.AgentModel]:
    """Pick the model for a run and build its agent-specific `AgentModel`.

    Precedence is: an active `override_model()` override, then the `model` argument, then the
    agent's own default model.

    Args:
        model: model to use for this run, required if `model` was not set when creating the agent.

    Returns:
        a tuple of `(model used, custom_model if any, agent_model)`

    Raises:
        UserError: If no model was given for the run and the agent has no default model.
    """
    # only set when the caller passed `model` for this run and no override is active
    custom_model: models.Model | None = None
    chosen: models.Model

    override = self._override_model
    if override:
        # we don't want `override_model()` to cover up errors from the model not being defined, hence this check
        if model is None and self.model is None:
            raise exceptions.UserError(
                '`model` must be set either when creating the agent or when calling it. '
                '(Even when `override_model()` is customizing the model that will actually be called)'
            )
        chosen = override.value
    elif model is not None:
        chosen = custom_model = models.infer_model(model)
    elif self.model is not None:
        # store the inferred model back on the agent so it replaces a name kept from a deferred check
        # noinspection PyTypeChecker
        chosen = self.model = models.infer_model(self.model)
    else:
        raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')

    if self._result_schema:
        result_tools = list(self._result_schema.tools.values())
    else:
        result_tools = None

    agent_model = await chosen.agent_model(self._retrievers, self._allow_text_result, result_tools)
    return chosen, custom_model, agent_model
|
|
606
|
+
|
|
607
|
+
async def _prepare_messages(
    self, deps: AgentDeps, user_prompt: str, message_history: list[_messages.Message] | None
) -> tuple[int, list[_messages.Message]]:
    """Build the message list for a run and append the new user prompt to it.

    Args:
        deps: Dependencies for the run, passed to `_init_messages` when initial messages are generated.
        user_prompt: The new user input, appended as a `UserPrompt` message.
        message_history: Earlier messages of the conversation, if any; the list itself is not mutated.

    Returns:
        A tuple of `(index of the new user prompt in the list, the full message list)`.
    """
    # if message history includes system prompts, we don't want to regenerate them
    history_has_system_prompt = False
    if message_history:
        for past_message in message_history:
            if past_message.role == 'system':
                history_has_system_prompt = True
                break

    if history_has_system_prompt:
        # shallow copy, so the caller's list is left untouched
        messages = message_history.copy()
    else:
        messages = await self._init_messages(deps)
        if message_history:
            messages.extend(message_history)

    new_message_index = len(messages)
    messages.append(_messages.UserPrompt(user_prompt))
    return new_message_index, messages
|
|
622
|
+
|
|
623
|
+
async def _handle_model_response(
    self, model_response: _messages.ModelAnyResponse, deps: AgentDeps
) -> _MarkFinalResult[ResultData] | list[_messages.Message]:
    """Process a non-streamed response from the model.

    Args:
        model_response: The text or structured response received from the model.
        deps: Dependencies for the run, passed on to result validation and retrievers.

    Returns:
        Either a `_MarkFinalResult` wrapping the validated result data (the run is finished),
        or a list of messages to send back to the model.

    Raises:
        UnexpectedModelBehavior: If a structured response contains no tool calls.
    """
    if model_response.role == 'model-text-response':
        # plain string response
        if not self._allow_text_result:
            self._incr_result_retry()
            retry_prompt = _messages.RetryPrompt(
                content='Plain text responses are not permitted, please call one of the functions instead.',
            )
            return [retry_prompt]

        text_result = cast(ResultData, model_response.content)
        try:
            validated = await self._validate_result(text_result, deps, None)
        except _result.ToolRetryError as exc:
            self._incr_result_retry()
            return [exc.tool_retry]
        return _MarkFinalResult(validated)
    elif model_response.role == 'model-structured-response':
        if self._result_schema is not None:
            # if there's a result schema, and any of the calls match one of its tools, return the result
            # NOTE: this means we ignore any other tools called here
            found = self._result_schema.find_tool(model_response)
            if found:
                call, result_tool = found
                try:
                    validated = result_tool.validate(call)
                    validated = await self._validate_result(validated, deps, call)
                except _result.ToolRetryError as exc:
                    self._incr_result_retry()
                    return [exc.tool_retry]
                return _MarkFinalResult(validated)

        if not model_response.calls:
            raise exceptions.UnexpectedModelBehavior('Received empty tool call message')

        # otherwise we run all retriever functions in parallel
        replies: list[_messages.Message] = []
        pending: list[asyncio.Task[_messages.Message]] = []
        for call in model_response.calls:
            known_retriever = self._retrievers.get(call.tool_name)
            if known_retriever:
                # task is named after the tool so the span below can list what is running
                task = asyncio.create_task(known_retriever.run(deps, call), name=call.tool_name)
                pending.append(task)
            else:
                replies.append(self._unknown_tool(call.tool_name))

        with _logfire.span('running {tools=}', tools=[t.get_name() for t in pending]):
            # replies for unknown tools come first, followed by retriever results in call order
            replies += await asyncio.gather(*pending)
        return replies
    else:
        assert_never(model_response)
|
|
680
|
+
|
|
681
|
+
async def _handle_streamed_model_response(
    self, model_response: models.EitherStreamedResponse, deps: AgentDeps
) -> _MarkFinalResult[models.EitherStreamedResponse] | list[_messages.Message]:
    """Decide what to do with a streamed response from the model.

    TODO: change the response type to `models.EitherStreamedResponse | list[_messages.Message]` once we drop 3.9
    (with 3.9 we get `TypeError: Subscripted generics cannot be used with class and instance checks`)

    Args:
        model_response: The streamed response, either text or structured (tool calls).
        deps: Dependencies passed through to any retriever that gets run.

    Returns:
        Either a `_MarkFinalResult` wrapping the stream itself when it carries the final result,
        or the list of messages to send back to the model.

    Raises:
        exceptions.UnexpectedModelBehavior: If the fully consumed structured response has no tool calls.
    """
    if isinstance(model_response, models.StreamTextResponse):
        # plain text stream
        if self._allow_text_result:
            return _MarkFinalResult(model_response)

        self._incr_result_retry()
        retry_prompt = _messages.RetryPrompt(
            content='Plain text responses are not permitted, please call one of the functions instead.',
        )
        # drain the stream even though it is rejected, so cost is correct
        async for _ in model_response:
            pass
        return [retry_prompt]

    assert isinstance(model_response, models.StreamStructuredResponse), f'Unexpected response: {model_response}'

    if self._result_schema is not None:
        # with a result schema, advance the stream only until at least one tool call is visible
        # NOTE: this means we ignore any other tools called here
        while not (partial_msg := model_response.get()).calls:
            try:
                await model_response.__anext__()
            except StopAsyncIteration:
                break

        if self._result_schema.find_tool(partial_msg):
            return _MarkFinalResult(model_response)

    # a retriever function is being called: exhaust the stream to obtain the complete message
    async for _ in model_response:
        pass
    complete_msg = model_response.get()
    if not complete_msg.calls:
        raise exceptions.UnexpectedModelBehavior('Received empty tool call message')

    outgoing: list[_messages.Message] = [complete_msg]

    # schedule every known retriever so they run in parallel; unknown names get a retry prompt
    pending: list[asyncio.Task[_messages.Message]] = []
    for tool_call in complete_msg.calls:
        retriever = self._retrievers.get(tool_call.tool_name)
        if retriever:
            pending.append(asyncio.create_task(retriever.run(deps, tool_call), name=tool_call.tool_name))
        else:
            outgoing.append(self._unknown_tool(tool_call.tool_name))

    scheduled_names = [task.get_name() for task in pending]
    with _logfire.span('running {tools=}', tools=scheduled_names):
        outgoing += await asyncio.gather(*pending)
    return outgoing
|
|
741
|
+
|
|
742
|
+
async def _validate_result(
|
|
743
|
+
self, result_data: ResultData, deps: AgentDeps, tool_call: _messages.ToolCall | None
|
|
744
|
+
) -> ResultData:
|
|
745
|
+
for validator in self._result_validators:
|
|
746
|
+
result_data = await validator.validate(result_data, deps, self._current_result_retry, tool_call)
|
|
747
|
+
return result_data
|
|
748
|
+
|
|
749
|
+
def _incr_result_retry(self) -> None:
|
|
750
|
+
self._current_result_retry += 1
|
|
751
|
+
if self._current_result_retry > self._max_result_retries:
|
|
752
|
+
raise exceptions.UnexpectedModelBehavior(
|
|
753
|
+
f'Exceeded maximum retries ({self._max_result_retries}) for result validation'
|
|
754
|
+
)
|
|
755
|
+
|
|
756
|
+
async def _init_messages(self, deps: AgentDeps) -> list[_messages.Message]:
    """Build the initial messages for the conversation.

    Static system prompts come first, followed by one system prompt per
    system prompt function, each awaited in turn with `deps`.
    """
    initial: list[_messages.Message] = []
    for static_prompt in self._system_prompts:
        initial.append(_messages.SystemPrompt(static_prompt))

    for runner in self._system_prompt_functions:
        dynamic_prompt = await runner.run(deps)
        initial.append(_messages.SystemPrompt(dynamic_prompt))
    return initial
|
|
763
|
+
|
|
764
|
+
def _unknown_tool(self, tool_name: str) -> _messages.RetryPrompt:
    """Build the retry prompt sent back when the model calls a tool that does not exist.

    Counts as a result retry, so this may raise via `_incr_result_retry`. The prompt lists
    the retriever names plus the result schema's tool names, when there are any.
    """
    self._incr_result_retry()

    known = [*self._retrievers.keys()]
    if self._result_schema:
        known += self._result_schema.tool_names()

    hint = f'Available tools: {", ".join(known)}' if known else 'No tools available.'
    return _messages.RetryPrompt(content=f'Unknown tool name: {tool_name!r}. {hint}')
|
|
774
|
+
|
|
775
|
+
def _get_deps(self, deps: AgentDeps) -> AgentDeps:
|
|
776
|
+
"""Get deps for a run.
|
|
777
|
+
|
|
778
|
+
If we've overridden deps via `_override_deps_stack`, use that, otherwise use the deps passed to the call.
|
|
779
|
+
|
|
780
|
+
We could do runtime type checking of deps against `self._deps_type`, but that's a slippery slope.
|
|
781
|
+
"""
|
|
782
|
+
if some_deps := self._override_deps:
|
|
783
|
+
return some_deps.value
|
|
784
|
+
else:
|
|
785
|
+
return deps
|
|
786
|
+
|
|
787
|
+
|
|
788
|
+
@dataclass
class _MarkFinalResult(Generic[ResultData]):
    """Wrapper that tags a value as the final result.

    Wrapping the data makes an `isinstance` check possible, which it would not be
    if `ResultData` were returned directly.
    """

    # the wrapped final result
    data: ResultData
|