pydantic-ai-slim 0.0.55__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/__init__.py +10 -3
- pydantic_ai/_agent_graph.py +70 -59
- pydantic_ai/_cli.py +1 -2
- pydantic_ai/{_result.py → _output.py} +69 -47
- pydantic_ai/_utils.py +20 -0
- pydantic_ai/agent.py +511 -161
- pydantic_ai/format_as_xml.py +6 -113
- pydantic_ai/format_prompt.py +116 -0
- pydantic_ai/messages.py +104 -21
- pydantic_ai/models/__init__.py +24 -4
- pydantic_ai/models/_json_schema.py +160 -0
- pydantic_ai/models/anthropic.py +5 -3
- pydantic_ai/models/bedrock.py +100 -22
- pydantic_ai/models/cohere.py +48 -44
- pydantic_ai/models/fallback.py +2 -1
- pydantic_ai/models/function.py +8 -8
- pydantic_ai/models/gemini.py +82 -75
- pydantic_ai/models/groq.py +32 -28
- pydantic_ai/models/instrumented.py +4 -4
- pydantic_ai/models/mistral.py +62 -58
- pydantic_ai/models/openai.py +110 -158
- pydantic_ai/models/test.py +45 -46
- pydantic_ai/result.py +203 -90
- pydantic_ai/tools.py +4 -4
- {pydantic_ai_slim-0.0.55.dist-info → pydantic_ai_slim-0.1.1.dist-info}/METADATA +5 -5
- pydantic_ai_slim-0.1.1.dist-info/RECORD +53 -0
- pydantic_ai_slim-0.0.55.dist-info/RECORD +0 -51
- {pydantic_ai_slim-0.0.55.dist-info → pydantic_ai_slim-0.1.1.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.0.55.dist-info → pydantic_ai_slim-0.1.1.dist-info}/entry_points.txt +0 -0
pydantic_ai/agent.py
CHANGED
|
@@ -2,6 +2,7 @@ from __future__ import annotations as _annotations
|
|
|
2
2
|
|
|
3
3
|
import dataclasses
|
|
4
4
|
import inspect
|
|
5
|
+
import warnings
|
|
5
6
|
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
|
|
6
7
|
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
|
|
7
8
|
from copy import deepcopy
|
|
@@ -10,14 +11,14 @@ from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, final,
|
|
|
10
11
|
|
|
11
12
|
from opentelemetry.trace import NoOpTracer, use_span
|
|
12
13
|
from pydantic.json_schema import GenerateJsonSchema
|
|
13
|
-
from typing_extensions import TypeGuard, TypeVar, deprecated
|
|
14
|
+
from typing_extensions import Literal, Never, TypeGuard, TypeVar, deprecated
|
|
14
15
|
|
|
15
16
|
from pydantic_graph import End, Graph, GraphRun, GraphRunContext
|
|
16
17
|
from pydantic_graph._utils import get_event_loop
|
|
17
18
|
|
|
18
19
|
from . import (
|
|
19
20
|
_agent_graph,
|
|
20
|
-
|
|
21
|
+
_output,
|
|
21
22
|
_system_prompt,
|
|
22
23
|
_utils,
|
|
23
24
|
exceptions,
|
|
@@ -26,8 +27,9 @@ from . import (
|
|
|
26
27
|
result,
|
|
27
28
|
usage as _usage,
|
|
28
29
|
)
|
|
30
|
+
from ._utils import AbstractSpan
|
|
29
31
|
from .models.instrumented import InstrumentationSettings, InstrumentedModel
|
|
30
|
-
from .result import FinalResult,
|
|
32
|
+
from .result import FinalResult, OutputDataT, StreamedRunResult, ToolOutput
|
|
31
33
|
from .settings import ModelSettings, merge_model_settings
|
|
32
34
|
from .tools import (
|
|
33
35
|
AgentDepsT,
|
|
@@ -52,6 +54,7 @@ UserPromptNode = _agent_graph.UserPromptNode
|
|
|
52
54
|
if TYPE_CHECKING:
|
|
53
55
|
from pydantic_ai.mcp import MCPServer
|
|
54
56
|
|
|
57
|
+
|
|
55
58
|
__all__ = (
|
|
56
59
|
'Agent',
|
|
57
60
|
'AgentRun',
|
|
@@ -68,17 +71,17 @@ __all__ = (
|
|
|
68
71
|
T = TypeVar('T')
|
|
69
72
|
S = TypeVar('S')
|
|
70
73
|
NoneType = type(None)
|
|
71
|
-
|
|
72
|
-
"""Type variable for the result data of a run where `
|
|
74
|
+
RunOutputDataT = TypeVar('RunOutputDataT')
|
|
75
|
+
"""Type variable for the result data of a run where `output_type` was customized on the run call."""
|
|
73
76
|
|
|
74
77
|
|
|
75
78
|
@final
|
|
76
79
|
@dataclasses.dataclass(init=False)
|
|
77
|
-
class Agent(Generic[AgentDepsT,
|
|
80
|
+
class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
78
81
|
"""Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.
|
|
79
82
|
|
|
80
83
|
Agents are generic in the dependency type they take [`AgentDepsT`][pydantic_ai.tools.AgentDepsT]
|
|
81
|
-
and the result data type they return, [`
|
|
84
|
+
and the result data type they return, [`OutputDataT`][pydantic_ai.result.OutputDataT].
|
|
82
85
|
|
|
83
86
|
By default, if neither generic parameter is customised, agents have type `Agent[None, str]`.
|
|
84
87
|
|
|
@@ -89,7 +92,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
89
92
|
|
|
90
93
|
agent = Agent('openai:gpt-4o')
|
|
91
94
|
result = agent.run_sync('What is the capital of France?')
|
|
92
|
-
print(result.
|
|
95
|
+
print(result.output)
|
|
93
96
|
#> Paris
|
|
94
97
|
```
|
|
95
98
|
"""
|
|
@@ -115,9 +118,9 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
115
118
|
be merged with this value, with the runtime argument taking priority.
|
|
116
119
|
"""
|
|
117
120
|
|
|
118
|
-
|
|
121
|
+
output_type: type[OutputDataT] | ToolOutput[OutputDataT]
|
|
119
122
|
"""
|
|
120
|
-
The type of
|
|
123
|
+
The type of data output by agent runs, used to validate the data returned by the model, defaults to `str`.
|
|
121
124
|
"""
|
|
122
125
|
|
|
123
126
|
instrument: InstrumentationSettings | bool | None
|
|
@@ -126,10 +129,12 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
126
129
|
_instrument_default: ClassVar[InstrumentationSettings | bool] = False
|
|
127
130
|
|
|
128
131
|
_deps_type: type[AgentDepsT] = dataclasses.field(repr=False)
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
132
|
+
_deprecated_result_tool_name: str | None = dataclasses.field(repr=False)
|
|
133
|
+
_deprecated_result_tool_description: str | None = dataclasses.field(repr=False)
|
|
134
|
+
_output_schema: _output.OutputSchema[OutputDataT] | None = dataclasses.field(repr=False)
|
|
135
|
+
_output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]] = dataclasses.field(repr=False)
|
|
136
|
+
_instructions: str | None = dataclasses.field(repr=False)
|
|
137
|
+
_instructions_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False)
|
|
133
138
|
_system_prompts: tuple[str, ...] = dataclasses.field(repr=False)
|
|
134
139
|
_system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False)
|
|
135
140
|
_system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(
|
|
@@ -142,11 +147,36 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
142
147
|
_override_deps: _utils.Option[AgentDepsT] = dataclasses.field(default=None, repr=False)
|
|
143
148
|
_override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False)
|
|
144
149
|
|
|
150
|
+
@overload
|
|
151
|
+
def __init__(
|
|
152
|
+
self,
|
|
153
|
+
model: models.Model | models.KnownModelName | str | None = None,
|
|
154
|
+
*,
|
|
155
|
+
output_type: type[OutputDataT] | ToolOutput[OutputDataT] = str,
|
|
156
|
+
instructions: str | _system_prompt.SystemPromptFunc[AgentDepsT] | None = None,
|
|
157
|
+
system_prompt: str | Sequence[str] = (),
|
|
158
|
+
deps_type: type[AgentDepsT] = NoneType,
|
|
159
|
+
name: str | None = None,
|
|
160
|
+
model_settings: ModelSettings | None = None,
|
|
161
|
+
retries: int = 1,
|
|
162
|
+
output_retries: int | None = None,
|
|
163
|
+
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
|
|
164
|
+
mcp_servers: Sequence[MCPServer] = (),
|
|
165
|
+
defer_model_check: bool = False,
|
|
166
|
+
end_strategy: EndStrategy = 'early',
|
|
167
|
+
instrument: InstrumentationSettings | bool | None = None,
|
|
168
|
+
) -> None: ...
|
|
169
|
+
|
|
170
|
+
@overload
|
|
171
|
+
@deprecated(
|
|
172
|
+
'`result_type`, `result_tool_name`, `result_tool_description` & `result_retries` are deprecated, use `output_type` instead. `result_retries` is deprecated, use `output_retries` instead.'
|
|
173
|
+
)
|
|
145
174
|
def __init__(
|
|
146
175
|
self,
|
|
147
176
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
148
177
|
*,
|
|
149
|
-
result_type: type[
|
|
178
|
+
result_type: type[OutputDataT] = str,
|
|
179
|
+
instructions: str | _system_prompt.SystemPromptFunc[AgentDepsT] | None = None,
|
|
150
180
|
system_prompt: str | Sequence[str] = (),
|
|
151
181
|
deps_type: type[AgentDepsT] = NoneType,
|
|
152
182
|
name: str | None = None,
|
|
@@ -160,13 +190,37 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
160
190
|
defer_model_check: bool = False,
|
|
161
191
|
end_strategy: EndStrategy = 'early',
|
|
162
192
|
instrument: InstrumentationSettings | bool | None = None,
|
|
193
|
+
) -> None: ...
|
|
194
|
+
|
|
195
|
+
def __init__(
|
|
196
|
+
self,
|
|
197
|
+
model: models.Model | models.KnownModelName | str | None = None,
|
|
198
|
+
*,
|
|
199
|
+
# TODO change this back to `output_type: type[OutputDataT] | ToolOutput[OutputDataT] = str,` when we remove the overloads
|
|
200
|
+
output_type: Any = str,
|
|
201
|
+
instructions: str | _system_prompt.SystemPromptFunc[AgentDepsT] | None = None,
|
|
202
|
+
system_prompt: str | Sequence[str] = (),
|
|
203
|
+
deps_type: type[AgentDepsT] = NoneType,
|
|
204
|
+
name: str | None = None,
|
|
205
|
+
model_settings: ModelSettings | None = None,
|
|
206
|
+
retries: int = 1,
|
|
207
|
+
output_retries: int | None = None,
|
|
208
|
+
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
|
|
209
|
+
mcp_servers: Sequence[MCPServer] = (),
|
|
210
|
+
defer_model_check: bool = False,
|
|
211
|
+
end_strategy: EndStrategy = 'early',
|
|
212
|
+
instrument: InstrumentationSettings | bool | None = None,
|
|
213
|
+
**_deprecated_kwargs: Any,
|
|
163
214
|
):
|
|
164
215
|
"""Create an agent.
|
|
165
216
|
|
|
166
217
|
Args:
|
|
167
218
|
model: The default model to use for this agent, if not provide,
|
|
168
219
|
you must provide the model when calling it. We allow str here since the actual list of allowed models changes frequently.
|
|
169
|
-
|
|
220
|
+
output_type: The type of the output data, used to validate the data returned by the model,
|
|
221
|
+
defaults to `str`.
|
|
222
|
+
instructions: Instructions to use for this agent, you can also register instructions via a function with
|
|
223
|
+
[`instructions`][pydantic_ai.Agent.instructions].
|
|
170
224
|
system_prompt: Static system prompts to use for this agent, you can also register system
|
|
171
225
|
prompts via a function with [`system_prompt`][pydantic_ai.Agent.system_prompt].
|
|
172
226
|
deps_type: The type used for dependency injection, this parameter exists solely to allow you to fully
|
|
@@ -177,9 +231,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
177
231
|
when the agent is first run.
|
|
178
232
|
model_settings: Optional model request settings to use for this agent's runs, by default.
|
|
179
233
|
retries: The default number of retries to allow before raising an error.
|
|
180
|
-
|
|
181
|
-
result_tool_description: The description of the final result tool.
|
|
182
|
-
result_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
|
|
234
|
+
output_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
|
|
183
235
|
tools: Tools to register with the agent, you can also register tools via the decorators
|
|
184
236
|
[`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
|
|
185
237
|
mcp_servers: MCP servers to register with the agent. You should register a [`MCPServer`][pydantic_ai.mcp.MCPServer]
|
|
@@ -207,17 +259,48 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
207
259
|
self.end_strategy = end_strategy
|
|
208
260
|
self.name = name
|
|
209
261
|
self.model_settings = model_settings
|
|
210
|
-
|
|
262
|
+
|
|
263
|
+
if 'result_type' in _deprecated_kwargs: # pragma: no cover
|
|
264
|
+
if output_type is not str:
|
|
265
|
+
raise TypeError('`result_type` and `output_type` cannot be set at the same time.')
|
|
266
|
+
warnings.warn('`result_type` is deprecated, use `output_type` instead', DeprecationWarning)
|
|
267
|
+
output_type = _deprecated_kwargs['result_type']
|
|
268
|
+
|
|
269
|
+
self.output_type = output_type
|
|
270
|
+
|
|
211
271
|
self.instrument = instrument
|
|
212
272
|
|
|
213
273
|
self._deps_type = deps_type
|
|
214
274
|
|
|
215
|
-
self.
|
|
216
|
-
self.
|
|
217
|
-
|
|
218
|
-
|
|
275
|
+
self._deprecated_result_tool_name = _deprecated_kwargs.get('result_tool_name')
|
|
276
|
+
if self._deprecated_result_tool_name is not None: # pragma: no cover
|
|
277
|
+
warnings.warn(
|
|
278
|
+
'`result_tool_name` is deprecated, use `output_type` with `ToolOutput` instead',
|
|
279
|
+
DeprecationWarning,
|
|
280
|
+
)
|
|
281
|
+
|
|
282
|
+
self._deprecated_result_tool_description = _deprecated_kwargs.get('result_tool_description')
|
|
283
|
+
if self._deprecated_result_tool_description is not None: # pragma: no cover
|
|
284
|
+
warnings.warn(
|
|
285
|
+
'`result_tool_description` is deprecated, use `output_type` with `ToolOutput` instead',
|
|
286
|
+
DeprecationWarning,
|
|
287
|
+
)
|
|
288
|
+
result_retries = _deprecated_kwargs.get('result_retries')
|
|
289
|
+
if result_retries is not None: # pragma: no cover
|
|
290
|
+
if output_retries is not None:
|
|
291
|
+
raise TypeError('`output_retries` and `result_retries` cannot be set at the same time.')
|
|
292
|
+
warnings.warn('`result_retries` is deprecated, use `max_result_retries` instead', DeprecationWarning)
|
|
293
|
+
output_retries = result_retries
|
|
294
|
+
|
|
295
|
+
self._output_schema = _output.OutputSchema[OutputDataT].build(
|
|
296
|
+
output_type, self._deprecated_result_tool_name, self._deprecated_result_tool_description
|
|
297
|
+
)
|
|
298
|
+
self._output_validators = []
|
|
299
|
+
|
|
300
|
+
self._instructions_functions = (
|
|
301
|
+
[_system_prompt.SystemPromptRunner(instructions)] if callable(instructions) else []
|
|
219
302
|
)
|
|
220
|
-
self.
|
|
303
|
+
self._instructions = instructions if isinstance(instructions, str) else None
|
|
221
304
|
|
|
222
305
|
self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
|
|
223
306
|
self._system_prompt_functions = []
|
|
@@ -226,7 +309,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
226
309
|
self._function_tools = {}
|
|
227
310
|
|
|
228
311
|
self._default_retries = retries
|
|
229
|
-
self._max_result_retries =
|
|
312
|
+
self._max_result_retries = output_retries if output_retries is not None else retries
|
|
230
313
|
self._mcp_servers = mcp_servers
|
|
231
314
|
for tool in tools:
|
|
232
315
|
if isinstance(tool, Tool):
|
|
@@ -244,7 +327,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
244
327
|
self,
|
|
245
328
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
246
329
|
*,
|
|
247
|
-
|
|
330
|
+
output_type: None = None,
|
|
248
331
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
249
332
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
250
333
|
deps: AgentDepsT = None,
|
|
@@ -252,14 +335,14 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
252
335
|
usage_limits: _usage.UsageLimits | None = None,
|
|
253
336
|
usage: _usage.Usage | None = None,
|
|
254
337
|
infer_name: bool = True,
|
|
255
|
-
) -> AgentRunResult[
|
|
338
|
+
) -> AgentRunResult[OutputDataT]: ...
|
|
256
339
|
|
|
257
340
|
@overload
|
|
258
341
|
async def run(
|
|
259
342
|
self,
|
|
260
343
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
261
344
|
*,
|
|
262
|
-
|
|
345
|
+
output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT],
|
|
263
346
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
264
347
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
265
348
|
deps: AgentDepsT = None,
|
|
@@ -267,13 +350,15 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
267
350
|
usage_limits: _usage.UsageLimits | None = None,
|
|
268
351
|
usage: _usage.Usage | None = None,
|
|
269
352
|
infer_name: bool = True,
|
|
270
|
-
) -> AgentRunResult[
|
|
353
|
+
) -> AgentRunResult[RunOutputDataT]: ...
|
|
271
354
|
|
|
355
|
+
@overload
|
|
356
|
+
@deprecated('`result_type` is deprecated, use `output_type` instead.')
|
|
272
357
|
async def run(
|
|
273
358
|
self,
|
|
274
359
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
275
360
|
*,
|
|
276
|
-
result_type: type[
|
|
361
|
+
result_type: type[RunOutputDataT],
|
|
277
362
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
278
363
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
279
364
|
deps: AgentDepsT = None,
|
|
@@ -281,6 +366,21 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
281
366
|
usage_limits: _usage.UsageLimits | None = None,
|
|
282
367
|
usage: _usage.Usage | None = None,
|
|
283
368
|
infer_name: bool = True,
|
|
369
|
+
) -> AgentRunResult[RunOutputDataT]: ...
|
|
370
|
+
|
|
371
|
+
async def run(
|
|
372
|
+
self,
|
|
373
|
+
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
374
|
+
*,
|
|
375
|
+
output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None = None,
|
|
376
|
+
message_history: list[_messages.ModelMessage] | None = None,
|
|
377
|
+
model: models.Model | models.KnownModelName | str | None = None,
|
|
378
|
+
deps: AgentDepsT = None,
|
|
379
|
+
model_settings: ModelSettings | None = None,
|
|
380
|
+
usage_limits: _usage.UsageLimits | None = None,
|
|
381
|
+
usage: _usage.Usage | None = None,
|
|
382
|
+
infer_name: bool = True,
|
|
383
|
+
**_deprecated_kwargs: Never,
|
|
284
384
|
) -> AgentRunResult[Any]:
|
|
285
385
|
"""Run the agent with a user prompt in async mode.
|
|
286
386
|
|
|
@@ -295,14 +395,14 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
295
395
|
|
|
296
396
|
async def main():
|
|
297
397
|
agent_run = await agent.run('What is the capital of France?')
|
|
298
|
-
print(agent_run.
|
|
398
|
+
print(agent_run.output)
|
|
299
399
|
#> Paris
|
|
300
400
|
```
|
|
301
401
|
|
|
302
402
|
Args:
|
|
303
403
|
user_prompt: User input to start/continue the conversation.
|
|
304
|
-
|
|
305
|
-
|
|
404
|
+
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
|
|
405
|
+
output validators since output validators would expect an argument that matches the agent's output type.
|
|
306
406
|
message_history: History of the conversation so far.
|
|
307
407
|
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
308
408
|
deps: Optional dependencies to use for this run.
|
|
@@ -316,9 +416,16 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
316
416
|
"""
|
|
317
417
|
if infer_name and self.name is None:
|
|
318
418
|
self._infer_name(inspect.currentframe())
|
|
419
|
+
|
|
420
|
+
if 'result_type' in _deprecated_kwargs: # pragma: no cover
|
|
421
|
+
if output_type is not str:
|
|
422
|
+
raise TypeError('`result_type` and `output_type` cannot be set at the same time.')
|
|
423
|
+
warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning)
|
|
424
|
+
output_type = _deprecated_kwargs['result_type']
|
|
425
|
+
|
|
319
426
|
async with self.iter(
|
|
320
427
|
user_prompt=user_prompt,
|
|
321
|
-
|
|
428
|
+
output_type=output_type,
|
|
322
429
|
message_history=message_history,
|
|
323
430
|
model=model,
|
|
324
431
|
deps=deps,
|
|
@@ -332,12 +439,44 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
332
439
|
assert agent_run.result is not None, 'The graph run did not finish properly'
|
|
333
440
|
return agent_run.result
|
|
334
441
|
|
|
442
|
+
@overload
|
|
443
|
+
def iter(
|
|
444
|
+
self,
|
|
445
|
+
user_prompt: str | Sequence[_messages.UserContent] | None,
|
|
446
|
+
*,
|
|
447
|
+
output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None = None,
|
|
448
|
+
message_history: list[_messages.ModelMessage] | None = None,
|
|
449
|
+
model: models.Model | models.KnownModelName | str | None = None,
|
|
450
|
+
deps: AgentDepsT = None,
|
|
451
|
+
model_settings: ModelSettings | None = None,
|
|
452
|
+
usage_limits: _usage.UsageLimits | None = None,
|
|
453
|
+
usage: _usage.Usage | None = None,
|
|
454
|
+
infer_name: bool = True,
|
|
455
|
+
**_deprecated_kwargs: Never,
|
|
456
|
+
) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ...
|
|
457
|
+
|
|
458
|
+
@overload
|
|
459
|
+
@deprecated('`result_type` is deprecated, use `output_type` instead.')
|
|
460
|
+
def iter(
|
|
461
|
+
self,
|
|
462
|
+
user_prompt: str | Sequence[_messages.UserContent] | None,
|
|
463
|
+
*,
|
|
464
|
+
result_type: type[RunOutputDataT],
|
|
465
|
+
message_history: list[_messages.ModelMessage] | None = None,
|
|
466
|
+
model: models.Model | models.KnownModelName | str | None = None,
|
|
467
|
+
deps: AgentDepsT = None,
|
|
468
|
+
model_settings: ModelSettings | None = None,
|
|
469
|
+
usage_limits: _usage.UsageLimits | None = None,
|
|
470
|
+
usage: _usage.Usage | None = None,
|
|
471
|
+
infer_name: bool = True,
|
|
472
|
+
) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ...
|
|
473
|
+
|
|
335
474
|
@asynccontextmanager
|
|
336
475
|
async def iter(
|
|
337
476
|
self,
|
|
338
477
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
339
478
|
*,
|
|
340
|
-
|
|
479
|
+
output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None = None,
|
|
341
480
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
342
481
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
343
482
|
deps: AgentDepsT = None,
|
|
@@ -345,10 +484,11 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
345
484
|
usage_limits: _usage.UsageLimits | None = None,
|
|
346
485
|
usage: _usage.Usage | None = None,
|
|
347
486
|
infer_name: bool = True,
|
|
487
|
+
**_deprecated_kwargs: Never,
|
|
348
488
|
) -> AsyncIterator[AgentRun[AgentDepsT, Any]]:
|
|
349
489
|
"""A contextmanager which can be used to iterate over the agent graph's nodes as they are executed.
|
|
350
490
|
|
|
351
|
-
This method builds an internal agent graph (using system prompts, tools and
|
|
491
|
+
This method builds an internal agent graph (using system prompts, tools and output schemas) and then returns an
|
|
352
492
|
`AgentRun` object. The `AgentRun` can be used to async-iterate over the nodes of the graph as they are
|
|
353
493
|
executed. This is the API to use if you want to consume the outputs coming from each LLM model response, or the
|
|
354
494
|
stream of events coming from the execution of tools.
|
|
@@ -374,6 +514,8 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
374
514
|
[
|
|
375
515
|
UserPromptNode(
|
|
376
516
|
user_prompt='What is the capital of France?',
|
|
517
|
+
instructions=None,
|
|
518
|
+
instructions_functions=[],
|
|
377
519
|
system_prompts=(),
|
|
378
520
|
system_prompt_functions=[],
|
|
379
521
|
system_prompt_dynamic_functions={},
|
|
@@ -387,6 +529,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
387
529
|
part_kind='user-prompt',
|
|
388
530
|
)
|
|
389
531
|
],
|
|
532
|
+
instructions=None,
|
|
390
533
|
kind='request',
|
|
391
534
|
)
|
|
392
535
|
),
|
|
@@ -398,17 +541,17 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
398
541
|
kind='response',
|
|
399
542
|
)
|
|
400
543
|
),
|
|
401
|
-
End(data=FinalResult(
|
|
544
|
+
End(data=FinalResult(output='Paris', tool_name=None, tool_call_id=None)),
|
|
402
545
|
]
|
|
403
546
|
'''
|
|
404
|
-
print(agent_run.result.
|
|
547
|
+
print(agent_run.result.output)
|
|
405
548
|
#> Paris
|
|
406
549
|
```
|
|
407
550
|
|
|
408
551
|
Args:
|
|
409
552
|
user_prompt: User input to start/continue the conversation.
|
|
410
|
-
|
|
411
|
-
|
|
553
|
+
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
|
|
554
|
+
output validators since output validators would expect an argument that matches the agent's output type.
|
|
412
555
|
message_history: History of the conversation so far.
|
|
413
556
|
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
414
557
|
deps: Optional dependencies to use for this run.
|
|
@@ -425,12 +568,22 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
425
568
|
model_used = self._get_model(model)
|
|
426
569
|
del model
|
|
427
570
|
|
|
571
|
+
if 'result_type' in _deprecated_kwargs: # pragma: no cover
|
|
572
|
+
if output_type is not str:
|
|
573
|
+
raise TypeError('`result_type` and `output_type` cannot be set at the same time.')
|
|
574
|
+
warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning)
|
|
575
|
+
output_type = _deprecated_kwargs['result_type']
|
|
576
|
+
|
|
428
577
|
deps = self._get_deps(deps)
|
|
429
578
|
new_message_index = len(message_history) if message_history else 0
|
|
430
|
-
|
|
579
|
+
output_schema = self._prepare_output_schema(output_type)
|
|
580
|
+
|
|
581
|
+
output_type_ = output_type or self.output_type
|
|
431
582
|
|
|
432
583
|
# Build the graph
|
|
433
|
-
graph =
|
|
584
|
+
graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = (
|
|
585
|
+
_agent_graph.build_agent_graph(self.name, self._deps_type, output_type_)
|
|
586
|
+
)
|
|
434
587
|
|
|
435
588
|
# Build the initial state
|
|
436
589
|
state = _agent_graph.GraphAgentState(
|
|
@@ -440,10 +593,10 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
440
593
|
run_step=0,
|
|
441
594
|
)
|
|
442
595
|
|
|
443
|
-
# We consider it a user error if a user tries to restrict the result type while having
|
|
596
|
+
# We consider it a user error if a user tries to restrict the result type while having an output validator that
|
|
444
597
|
# may change the result type from the restricted type to something else. Therefore, we consider the following
|
|
445
598
|
# typecast reasonable, even though it is possible to violate it with otherwise-type-checked code.
|
|
446
|
-
|
|
599
|
+
output_validators = cast(list[_output.OutputValidator[AgentDepsT, RunOutputDataT]], self._output_validators)
|
|
447
600
|
|
|
448
601
|
# TODO: Instead of this, copy the function tools to ensure they don't share current_retry state between agent
|
|
449
602
|
# runs. Requires some changes to `Tool` to make them copyable though.
|
|
@@ -467,7 +620,16 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
467
620
|
},
|
|
468
621
|
)
|
|
469
622
|
|
|
470
|
-
|
|
623
|
+
async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
|
|
624
|
+
if self._instructions is None and not self._instructions_functions:
|
|
625
|
+
return None
|
|
626
|
+
|
|
627
|
+
instructions = self._instructions or ''
|
|
628
|
+
for instructions_runner in self._instructions_functions:
|
|
629
|
+
instructions += await instructions_runner.run(run_context)
|
|
630
|
+
return instructions
|
|
631
|
+
|
|
632
|
+
graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
|
|
471
633
|
user_deps=deps,
|
|
472
634
|
prompt=user_prompt,
|
|
473
635
|
new_message_index=new_message_index,
|
|
@@ -476,16 +638,18 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
476
638
|
usage_limits=usage_limits,
|
|
477
639
|
max_result_retries=self._max_result_retries,
|
|
478
640
|
end_strategy=self.end_strategy,
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
result_validators=result_validators,
|
|
641
|
+
output_schema=output_schema,
|
|
642
|
+
output_validators=output_validators,
|
|
482
643
|
function_tools=self._function_tools,
|
|
483
644
|
mcp_servers=self._mcp_servers,
|
|
484
645
|
run_span=run_span,
|
|
485
646
|
tracer=tracer,
|
|
647
|
+
get_instructions=get_instructions,
|
|
486
648
|
)
|
|
487
649
|
start_node = _agent_graph.UserPromptNode[AgentDepsT](
|
|
488
650
|
user_prompt=user_prompt,
|
|
651
|
+
instructions=self._instructions,
|
|
652
|
+
instructions_functions=self._instructions_functions,
|
|
489
653
|
system_prompts=self._system_prompts,
|
|
490
654
|
system_prompt_functions=self._system_prompt_functions,
|
|
491
655
|
system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
|
|
@@ -512,14 +676,14 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
512
676
|
usage_limits: _usage.UsageLimits | None = None,
|
|
513
677
|
usage: _usage.Usage | None = None,
|
|
514
678
|
infer_name: bool = True,
|
|
515
|
-
) -> AgentRunResult[
|
|
679
|
+
) -> AgentRunResult[OutputDataT]: ...
|
|
516
680
|
|
|
517
681
|
@overload
|
|
518
682
|
def run_sync(
|
|
519
683
|
self,
|
|
520
684
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
521
685
|
*,
|
|
522
|
-
|
|
686
|
+
output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None,
|
|
523
687
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
524
688
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
525
689
|
deps: AgentDepsT = None,
|
|
@@ -527,13 +691,15 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
527
691
|
usage_limits: _usage.UsageLimits | None = None,
|
|
528
692
|
usage: _usage.Usage | None = None,
|
|
529
693
|
infer_name: bool = True,
|
|
530
|
-
) -> AgentRunResult[
|
|
694
|
+
) -> AgentRunResult[RunOutputDataT]: ...
|
|
531
695
|
|
|
696
|
+
@overload
|
|
697
|
+
@deprecated('`result_type` is deprecated, use `output_type` instead.')
|
|
532
698
|
def run_sync(
|
|
533
699
|
self,
|
|
534
700
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
535
701
|
*,
|
|
536
|
-
result_type: type[
|
|
702
|
+
result_type: type[RunOutputDataT],
|
|
537
703
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
538
704
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
539
705
|
deps: AgentDepsT = None,
|
|
@@ -541,6 +707,21 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
541
707
|
usage_limits: _usage.UsageLimits | None = None,
|
|
542
708
|
usage: _usage.Usage | None = None,
|
|
543
709
|
infer_name: bool = True,
|
|
710
|
+
) -> AgentRunResult[RunOutputDataT]: ...
|
|
711
|
+
|
|
712
|
+
def run_sync(
|
|
713
|
+
self,
|
|
714
|
+
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
715
|
+
*,
|
|
716
|
+
output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None = None,
|
|
717
|
+
message_history: list[_messages.ModelMessage] | None = None,
|
|
718
|
+
model: models.Model | models.KnownModelName | str | None = None,
|
|
719
|
+
deps: AgentDepsT = None,
|
|
720
|
+
model_settings: ModelSettings | None = None,
|
|
721
|
+
usage_limits: _usage.UsageLimits | None = None,
|
|
722
|
+
usage: _usage.Usage | None = None,
|
|
723
|
+
infer_name: bool = True,
|
|
724
|
+
**_deprecated_kwargs: Never,
|
|
544
725
|
) -> AgentRunResult[Any]:
|
|
545
726
|
"""Synchronously run the agent with a user prompt.
|
|
546
727
|
|
|
@@ -554,14 +735,14 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
554
735
|
agent = Agent('openai:gpt-4o')
|
|
555
736
|
|
|
556
737
|
result_sync = agent.run_sync('What is the capital of Italy?')
|
|
557
|
-
print(result_sync.
|
|
738
|
+
print(result_sync.output)
|
|
558
739
|
#> Rome
|
|
559
740
|
```
|
|
560
741
|
|
|
561
742
|
Args:
|
|
562
743
|
user_prompt: User input to start/continue the conversation.
|
|
563
|
-
|
|
564
|
-
|
|
744
|
+
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
|
|
745
|
+
output validators since output validators would expect an argument that matches the agent's output type.
|
|
565
746
|
message_history: History of the conversation so far.
|
|
566
747
|
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
567
748
|
deps: Optional dependencies to use for this run.
|
|
@@ -575,10 +756,17 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
575
756
|
"""
|
|
576
757
|
if infer_name and self.name is None:
|
|
577
758
|
self._infer_name(inspect.currentframe())
|
|
759
|
+
|
|
760
|
+
if 'result_type' in _deprecated_kwargs: # pragma: no cover
|
|
761
|
+
if output_type is not str:
|
|
762
|
+
raise TypeError('`result_type` and `output_type` cannot be set at the same time.')
|
|
763
|
+
warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning)
|
|
764
|
+
output_type = _deprecated_kwargs['result_type']
|
|
765
|
+
|
|
578
766
|
return get_event_loop().run_until_complete(
|
|
579
767
|
self.run(
|
|
580
768
|
user_prompt,
|
|
581
|
-
|
|
769
|
+
output_type=output_type,
|
|
582
770
|
message_history=message_history,
|
|
583
771
|
model=model,
|
|
584
772
|
deps=deps,
|
|
@@ -594,7 +782,21 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
594
782
|
self,
|
|
595
783
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
596
784
|
*,
|
|
597
|
-
|
|
785
|
+
message_history: list[_messages.ModelMessage] | None = None,
|
|
786
|
+
model: models.Model | models.KnownModelName | None = None,
|
|
787
|
+
deps: AgentDepsT = None,
|
|
788
|
+
model_settings: ModelSettings | None = None,
|
|
789
|
+
usage_limits: _usage.UsageLimits | None = None,
|
|
790
|
+
usage: _usage.Usage | None = None,
|
|
791
|
+
infer_name: bool = True,
|
|
792
|
+
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, OutputDataT]]: ...
|
|
793
|
+
|
|
794
|
+
@overload
|
|
795
|
+
def run_stream(
|
|
796
|
+
self,
|
|
797
|
+
user_prompt: str | Sequence[_messages.UserContent],
|
|
798
|
+
*,
|
|
799
|
+
output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT],
|
|
598
800
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
599
801
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
600
802
|
deps: AgentDepsT = None,
|
|
@@ -602,14 +804,15 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
602
804
|
usage_limits: _usage.UsageLimits | None = None,
|
|
603
805
|
usage: _usage.Usage | None = None,
|
|
604
806
|
infer_name: bool = True,
|
|
605
|
-
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT,
|
|
807
|
+
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ...
|
|
606
808
|
|
|
607
809
|
@overload
|
|
810
|
+
@deprecated('`result_type` is deprecated, use `output_type` instead.')
|
|
608
811
|
def run_stream(
|
|
609
812
|
self,
|
|
610
813
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
611
814
|
*,
|
|
612
|
-
result_type: type[
|
|
815
|
+
result_type: type[RunOutputDataT],
|
|
613
816
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
614
817
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
615
818
|
deps: AgentDepsT = None,
|
|
@@ -617,14 +820,14 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
617
820
|
usage_limits: _usage.UsageLimits | None = None,
|
|
618
821
|
usage: _usage.Usage | None = None,
|
|
619
822
|
infer_name: bool = True,
|
|
620
|
-
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT,
|
|
823
|
+
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ...
|
|
621
824
|
|
|
622
825
|
@asynccontextmanager
|
|
623
826
|
async def run_stream( # noqa C901
|
|
624
827
|
self,
|
|
625
828
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
626
829
|
*,
|
|
627
|
-
|
|
830
|
+
output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None = None,
|
|
628
831
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
629
832
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
630
833
|
deps: AgentDepsT = None,
|
|
@@ -632,6 +835,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
632
835
|
usage_limits: _usage.UsageLimits | None = None,
|
|
633
836
|
usage: _usage.Usage | None = None,
|
|
634
837
|
infer_name: bool = True,
|
|
838
|
+
**_deprecated_kwargs: Never,
|
|
635
839
|
) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]:
|
|
636
840
|
"""Run the agent with a user prompt in async mode, returning a streamed response.
|
|
637
841
|
|
|
@@ -643,14 +847,14 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
643
847
|
|
|
644
848
|
async def main():
|
|
645
849
|
async with agent.run_stream('What is the capital of the UK?') as response:
|
|
646
|
-
print(await response.
|
|
850
|
+
print(await response.get_output())
|
|
647
851
|
#> London
|
|
648
852
|
```
|
|
649
853
|
|
|
650
854
|
Args:
|
|
651
855
|
user_prompt: User input to start/continue the conversation.
|
|
652
|
-
|
|
653
|
-
|
|
856
|
+
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
|
|
857
|
+
output validators since output validators would expect an argument that matches the agent's output type.
|
|
654
858
|
message_history: History of the conversation so far.
|
|
655
859
|
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
656
860
|
deps: Optional dependencies to use for this run.
|
|
@@ -669,10 +873,16 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
669
873
|
if frame := inspect.currentframe(): # pragma: no branch
|
|
670
874
|
self._infer_name(frame.f_back)
|
|
671
875
|
|
|
876
|
+
if 'result_type' in _deprecated_kwargs: # pragma: no cover
|
|
877
|
+
if output_type is not str:
|
|
878
|
+
raise TypeError('`result_type` and `output_type` cannot be set at the same time.')
|
|
879
|
+
warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning)
|
|
880
|
+
output_type = _deprecated_kwargs['result_type']
|
|
881
|
+
|
|
672
882
|
yielded = False
|
|
673
883
|
async with self.iter(
|
|
674
884
|
user_prompt,
|
|
675
|
-
|
|
885
|
+
output_type=output_type,
|
|
676
886
|
message_history=message_history,
|
|
677
887
|
model=model,
|
|
678
888
|
deps=deps,
|
|
@@ -692,15 +902,15 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
692
902
|
async def stream_to_final(
|
|
693
903
|
s: models.StreamedResponse,
|
|
694
904
|
) -> FinalResult[models.StreamedResponse] | None:
|
|
695
|
-
|
|
905
|
+
output_schema = graph_ctx.deps.output_schema
|
|
696
906
|
async for maybe_part_event in streamed_response:
|
|
697
907
|
if isinstance(maybe_part_event, _messages.PartStartEvent):
|
|
698
908
|
new_part = maybe_part_event.part
|
|
699
909
|
if isinstance(new_part, _messages.TextPart):
|
|
700
|
-
if _agent_graph.
|
|
910
|
+
if _agent_graph.allow_text_output(output_schema):
|
|
701
911
|
return FinalResult(s, None, None)
|
|
702
|
-
elif isinstance(new_part, _messages.ToolCallPart) and
|
|
703
|
-
for call, _ in
|
|
912
|
+
elif isinstance(new_part, _messages.ToolCallPart) and output_schema:
|
|
913
|
+
for call, _ in output_schema.find_tool([new_part]):
|
|
704
914
|
return FinalResult(s, call.tool_name, call.tool_call_id)
|
|
705
915
|
return None
|
|
706
916
|
|
|
@@ -745,9 +955,9 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
745
955
|
graph_ctx.deps.new_message_index,
|
|
746
956
|
graph_ctx.deps.usage_limits,
|
|
747
957
|
streamed_response,
|
|
748
|
-
graph_ctx.deps.
|
|
958
|
+
graph_ctx.deps.output_schema,
|
|
749
959
|
_agent_graph.build_run_context(graph_ctx),
|
|
750
|
-
graph_ctx.deps.
|
|
960
|
+
graph_ctx.deps.output_validators,
|
|
751
961
|
final_result_details.tool_name,
|
|
752
962
|
on_complete,
|
|
753
963
|
)
|
|
@@ -796,6 +1006,73 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
796
1006
|
if _utils.is_set(override_model_before):
|
|
797
1007
|
self._override_model = override_model_before
|
|
798
1008
|
|
|
1009
|
+
@overload
|
|
1010
|
+
def instructions(
|
|
1011
|
+
self, func: Callable[[RunContext[AgentDepsT]], str], /
|
|
1012
|
+
) -> Callable[[RunContext[AgentDepsT]], str]: ...
|
|
1013
|
+
|
|
1014
|
+
@overload
|
|
1015
|
+
def instructions(
|
|
1016
|
+
self, func: Callable[[RunContext[AgentDepsT]], Awaitable[str]], /
|
|
1017
|
+
) -> Callable[[RunContext[AgentDepsT]], Awaitable[str]]: ...
|
|
1018
|
+
|
|
1019
|
+
@overload
|
|
1020
|
+
def instructions(self, func: Callable[[], str], /) -> Callable[[], str]: ...
|
|
1021
|
+
|
|
1022
|
+
@overload
|
|
1023
|
+
def instructions(self, func: Callable[[], Awaitable[str]], /) -> Callable[[], Awaitable[str]]: ...
|
|
1024
|
+
|
|
1025
|
+
@overload
|
|
1026
|
+
def instructions(
|
|
1027
|
+
self, /
|
|
1028
|
+
) -> Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]]: ...
|
|
1029
|
+
|
|
1030
|
+
def instructions(
|
|
1031
|
+
self,
|
|
1032
|
+
func: _system_prompt.SystemPromptFunc[AgentDepsT] | None = None,
|
|
1033
|
+
/,
|
|
1034
|
+
) -> (
|
|
1035
|
+
Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]]
|
|
1036
|
+
| _system_prompt.SystemPromptFunc[AgentDepsT]
|
|
1037
|
+
):
|
|
1038
|
+
"""Decorator to register an instructions function.
|
|
1039
|
+
|
|
1040
|
+
Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument.
|
|
1041
|
+
Can decorate a sync or async functions.
|
|
1042
|
+
|
|
1043
|
+
The decorator can be used bare (`agent.instructions`).
|
|
1044
|
+
|
|
1045
|
+
Overloads for every possible signature of `instructions` are included so the decorator doesn't obscure
|
|
1046
|
+
the type of the function.
|
|
1047
|
+
|
|
1048
|
+
Example:
|
|
1049
|
+
```python
|
|
1050
|
+
from pydantic_ai import Agent, RunContext
|
|
1051
|
+
|
|
1052
|
+
agent = Agent('test', deps_type=str)
|
|
1053
|
+
|
|
1054
|
+
@agent.instructions
|
|
1055
|
+
def simple_instructions() -> str:
|
|
1056
|
+
return 'foobar'
|
|
1057
|
+
|
|
1058
|
+
@agent.instructions
|
|
1059
|
+
async def async_instructions(ctx: RunContext[str]) -> str:
|
|
1060
|
+
return f'{ctx.deps} is the best'
|
|
1061
|
+
```
|
|
1062
|
+
"""
|
|
1063
|
+
if func is None:
|
|
1064
|
+
|
|
1065
|
+
def decorator(
|
|
1066
|
+
func_: _system_prompt.SystemPromptFunc[AgentDepsT],
|
|
1067
|
+
) -> _system_prompt.SystemPromptFunc[AgentDepsT]:
|
|
1068
|
+
self._instructions_functions.append(_system_prompt.SystemPromptRunner(func_))
|
|
1069
|
+
return func_
|
|
1070
|
+
|
|
1071
|
+
return decorator
|
|
1072
|
+
else:
|
|
1073
|
+
self._instructions_functions.append(_system_prompt.SystemPromptRunner(func))
|
|
1074
|
+
return func
|
|
1075
|
+
|
|
799
1076
|
@overload
|
|
800
1077
|
def system_prompt(
|
|
801
1078
|
self, func: Callable[[RunContext[AgentDepsT]], str], /
|
|
@@ -876,34 +1153,34 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
876
1153
|
return func
|
|
877
1154
|
|
|
878
1155
|
@overload
|
|
879
|
-
def
|
|
880
|
-
self, func: Callable[[RunContext[AgentDepsT],
|
|
881
|
-
) -> Callable[[RunContext[AgentDepsT],
|
|
1156
|
+
def output_validator(
|
|
1157
|
+
self, func: Callable[[RunContext[AgentDepsT], OutputDataT], OutputDataT], /
|
|
1158
|
+
) -> Callable[[RunContext[AgentDepsT], OutputDataT], OutputDataT]: ...
|
|
882
1159
|
|
|
883
1160
|
@overload
|
|
884
|
-
def
|
|
885
|
-
self, func: Callable[[RunContext[AgentDepsT],
|
|
886
|
-
) -> Callable[[RunContext[AgentDepsT],
|
|
1161
|
+
def output_validator(
|
|
1162
|
+
self, func: Callable[[RunContext[AgentDepsT], OutputDataT], Awaitable[OutputDataT]], /
|
|
1163
|
+
) -> Callable[[RunContext[AgentDepsT], OutputDataT], Awaitable[OutputDataT]]: ...
|
|
887
1164
|
|
|
888
1165
|
@overload
|
|
889
|
-
def
|
|
890
|
-
self, func: Callable[[
|
|
891
|
-
) -> Callable[[
|
|
1166
|
+
def output_validator(
|
|
1167
|
+
self, func: Callable[[OutputDataT], OutputDataT], /
|
|
1168
|
+
) -> Callable[[OutputDataT], OutputDataT]: ...
|
|
892
1169
|
|
|
893
1170
|
@overload
|
|
894
|
-
def
|
|
895
|
-
self, func: Callable[[
|
|
896
|
-
) -> Callable[[
|
|
1171
|
+
def output_validator(
|
|
1172
|
+
self, func: Callable[[OutputDataT], Awaitable[OutputDataT]], /
|
|
1173
|
+
) -> Callable[[OutputDataT], Awaitable[OutputDataT]]: ...
|
|
897
1174
|
|
|
898
|
-
def
|
|
899
|
-
self, func:
|
|
900
|
-
) ->
|
|
901
|
-
"""Decorator to register
|
|
1175
|
+
def output_validator(
|
|
1176
|
+
self, func: _output.OutputValidatorFunc[AgentDepsT, OutputDataT], /
|
|
1177
|
+
) -> _output.OutputValidatorFunc[AgentDepsT, OutputDataT]:
|
|
1178
|
+
"""Decorator to register an output validator function.
|
|
902
1179
|
|
|
903
1180
|
Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
|
|
904
1181
|
Can decorate a sync or async functions.
|
|
905
1182
|
|
|
906
|
-
Overloads for every possible signature of `
|
|
1183
|
+
Overloads for every possible signature of `output_validator` are included so the decorator doesn't obscure
|
|
907
1184
|
the type of the function, see `tests/typed_agent.py` for tests.
|
|
908
1185
|
|
|
909
1186
|
Example:
|
|
@@ -912,26 +1189,29 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
912
1189
|
|
|
913
1190
|
agent = Agent('test', deps_type=str)
|
|
914
1191
|
|
|
915
|
-
@agent.
|
|
916
|
-
def
|
|
1192
|
+
@agent.output_validator
|
|
1193
|
+
def output_validator_simple(data: str) -> str:
|
|
917
1194
|
if 'wrong' in data:
|
|
918
1195
|
raise ModelRetry('wrong response')
|
|
919
1196
|
return data
|
|
920
1197
|
|
|
921
|
-
@agent.
|
|
922
|
-
async def
|
|
1198
|
+
@agent.output_validator
|
|
1199
|
+
async def output_validator_deps(ctx: RunContext[str], data: str) -> str:
|
|
923
1200
|
if ctx.deps in data:
|
|
924
1201
|
raise ModelRetry('wrong response')
|
|
925
1202
|
return data
|
|
926
1203
|
|
|
927
1204
|
result = agent.run_sync('foobar', deps='spam')
|
|
928
|
-
print(result.
|
|
1205
|
+
print(result.output)
|
|
929
1206
|
#> success (no tool calls)
|
|
930
1207
|
```
|
|
931
1208
|
"""
|
|
932
|
-
self.
|
|
1209
|
+
self._output_validators.append(_output.OutputValidator[AgentDepsT, Any](func))
|
|
933
1210
|
return func
|
|
934
1211
|
|
|
1212
|
+
@deprecated('`result_validator` is deprecated, use `output_validator` instead.')
|
|
1213
|
+
def result_validator(self, func: Any, /) -> Any: ...
|
|
1214
|
+
|
|
935
1215
|
@overload
|
|
936
1216
|
def tool(self, func: ToolFuncContext[AgentDepsT, ToolParams], /) -> ToolFuncContext[AgentDepsT, ToolParams]: ...
|
|
937
1217
|
|
|
@@ -987,7 +1267,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
987
1267
|
return ctx.deps + y
|
|
988
1268
|
|
|
989
1269
|
result = agent.run_sync('foobar', deps=1)
|
|
990
|
-
print(result.
|
|
1270
|
+
print(result.output)
|
|
991
1271
|
#> {"foobar":1,"spam":1.0}
|
|
992
1272
|
```
|
|
993
1273
|
|
|
@@ -1096,7 +1376,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1096
1376
|
return 3.14
|
|
1097
1377
|
|
|
1098
1378
|
result = agent.run_sync('foobar', deps=1)
|
|
1099
|
-
print(result.
|
|
1379
|
+
print(result.output)
|
|
1100
1380
|
#> {"foobar":123,"spam":3.14}
|
|
1101
1381
|
```
|
|
1102
1382
|
|
|
@@ -1183,7 +1463,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1183
1463
|
if tool.name in self._function_tools:
|
|
1184
1464
|
raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}')
|
|
1185
1465
|
|
|
1186
|
-
if self.
|
|
1466
|
+
if self._output_schema and tool.name in self._output_schema.tools:
|
|
1187
1467
|
raise exceptions.UserError(f'Tool name conflicts with result schema name: {tool.name!r}')
|
|
1188
1468
|
|
|
1189
1469
|
self._function_tools[tool.name] = tool
|
|
@@ -1226,7 +1506,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1226
1506
|
|
|
1227
1507
|
return model_
|
|
1228
1508
|
|
|
1229
|
-
def _get_deps(self: Agent[T,
|
|
1509
|
+
def _get_deps(self: Agent[T, OutputDataT], deps: T) -> T:
|
|
1230
1510
|
"""Get deps for a run.
|
|
1231
1511
|
|
|
1232
1512
|
If we've overridden deps via `_override_deps`, use that, otherwise use the deps passed to the call.
|
|
@@ -1264,22 +1544,19 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1264
1544
|
def last_run_messages(self) -> list[_messages.ModelMessage]:
|
|
1265
1545
|
raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.')
|
|
1266
1546
|
|
|
1267
|
-
def
|
|
1268
|
-
self,
|
|
1269
|
-
) ->
|
|
1270
|
-
|
|
1271
|
-
|
|
1272
|
-
|
|
1273
|
-
|
|
1274
|
-
|
|
1275
|
-
|
|
1276
|
-
|
|
1277
|
-
raise exceptions.UserError('Cannot set a custom run `result_type` when the agent has result validators')
|
|
1278
|
-
return _result.ResultSchema[result_type].build(
|
|
1279
|
-
result_type, self._result_tool_name, self._result_tool_description
|
|
1547
|
+
def _prepare_output_schema(
|
|
1548
|
+
self, output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None
|
|
1549
|
+
) -> _output.OutputSchema[RunOutputDataT] | None:
|
|
1550
|
+
if output_type is not None:
|
|
1551
|
+
if self._output_validators:
|
|
1552
|
+
raise exceptions.UserError('Cannot set a custom run `output_type` when the agent has output validators')
|
|
1553
|
+
return _output.OutputSchema[RunOutputDataT].build(
|
|
1554
|
+
output_type,
|
|
1555
|
+
self._deprecated_result_tool_name,
|
|
1556
|
+
self._deprecated_result_tool_description,
|
|
1280
1557
|
)
|
|
1281
1558
|
else:
|
|
1282
|
-
return self.
|
|
1559
|
+
return self._output_schema # pyright: ignore[reportReturnType]
|
|
1283
1560
|
|
|
1284
1561
|
@staticmethod
|
|
1285
1562
|
def is_model_request_node(
|
|
@@ -1337,7 +1614,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1337
1614
|
|
|
1338
1615
|
|
|
1339
1616
|
@dataclasses.dataclass(repr=False)
|
|
1340
|
-
class AgentRun(Generic[AgentDepsT,
|
|
1617
|
+
class AgentRun(Generic[AgentDepsT, OutputDataT]):
|
|
1341
1618
|
"""A stateful, async-iterable run of an [`Agent`][pydantic_ai.agent.Agent].
|
|
1342
1619
|
|
|
1343
1620
|
You generally obtain an `AgentRun` instance by calling `async with my_agent.iter(...) as agent_run:`.
|
|
@@ -1363,6 +1640,8 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1363
1640
|
[
|
|
1364
1641
|
UserPromptNode(
|
|
1365
1642
|
user_prompt='What is the capital of France?',
|
|
1643
|
+
instructions=None,
|
|
1644
|
+
instructions_functions=[],
|
|
1366
1645
|
system_prompts=(),
|
|
1367
1646
|
system_prompt_functions=[],
|
|
1368
1647
|
system_prompt_dynamic_functions={},
|
|
@@ -1376,6 +1655,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1376
1655
|
part_kind='user-prompt',
|
|
1377
1656
|
)
|
|
1378
1657
|
],
|
|
1658
|
+
instructions=None,
|
|
1379
1659
|
kind='request',
|
|
1380
1660
|
)
|
|
1381
1661
|
),
|
|
@@ -1387,10 +1667,10 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1387
1667
|
kind='response',
|
|
1388
1668
|
)
|
|
1389
1669
|
),
|
|
1390
|
-
End(data=FinalResult(
|
|
1670
|
+
End(data=FinalResult(output='Paris', tool_name=None, tool_call_id=None)),
|
|
1391
1671
|
]
|
|
1392
1672
|
'''
|
|
1393
|
-
print(agent_run.result.
|
|
1673
|
+
print(agent_run.result.output)
|
|
1394
1674
|
#> Paris
|
|
1395
1675
|
```
|
|
1396
1676
|
|
|
@@ -1399,9 +1679,19 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1399
1679
|
"""
|
|
1400
1680
|
|
|
1401
1681
|
_graph_run: GraphRun[
|
|
1402
|
-
_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[
|
|
1682
|
+
_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[OutputDataT]
|
|
1403
1683
|
]
|
|
1404
1684
|
|
|
1685
|
+
@overload
|
|
1686
|
+
def _span(self, *, required: Literal[False]) -> AbstractSpan | None: ...
|
|
1687
|
+
@overload
|
|
1688
|
+
def _span(self) -> AbstractSpan: ...
|
|
1689
|
+
def _span(self, *, required: bool = True) -> AbstractSpan | None:
|
|
1690
|
+
span = self._graph_run._span(required=False) # type: ignore[reportPrivateUsage]
|
|
1691
|
+
if span is None and required: # pragma: no cover
|
|
1692
|
+
raise AttributeError('Span is not available for this agent run')
|
|
1693
|
+
return span
|
|
1694
|
+
|
|
1405
1695
|
@property
|
|
1406
1696
|
def ctx(self) -> GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]]:
|
|
1407
1697
|
"""The current context of the agent run."""
|
|
@@ -1412,7 +1702,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1412
1702
|
@property
|
|
1413
1703
|
def next_node(
|
|
1414
1704
|
self,
|
|
1415
|
-
) -> _agent_graph.AgentNode[AgentDepsT,
|
|
1705
|
+
) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]:
|
|
1416
1706
|
"""The next node that will be run in the agent graph.
|
|
1417
1707
|
|
|
1418
1708
|
This is the next node that will be used during async iteration, or if a node is not passed to `self.next(...)`.
|
|
@@ -1425,7 +1715,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1425
1715
|
raise exceptions.AgentRunError(f'Unexpected node type: {type(next_node)}') # pragma: no cover
|
|
1426
1716
|
|
|
1427
1717
|
@property
|
|
1428
|
-
def result(self) -> AgentRunResult[
|
|
1718
|
+
def result(self) -> AgentRunResult[OutputDataT] | None:
|
|
1429
1719
|
"""The final result of the run if it has ended, otherwise `None`.
|
|
1430
1720
|
|
|
1431
1721
|
Once the run returns an [`End`][pydantic_graph.nodes.End] node, `result` is populated
|
|
@@ -1435,21 +1725,22 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1435
1725
|
if graph_run_result is None:
|
|
1436
1726
|
return None
|
|
1437
1727
|
return AgentRunResult(
|
|
1438
|
-
graph_run_result.output.
|
|
1728
|
+
graph_run_result.output.output,
|
|
1439
1729
|
graph_run_result.output.tool_name,
|
|
1440
1730
|
graph_run_result.state,
|
|
1441
1731
|
self._graph_run.deps.new_message_index,
|
|
1732
|
+
self._graph_run._span(required=False), # type: ignore[reportPrivateUsage]
|
|
1442
1733
|
)
|
|
1443
1734
|
|
|
1444
1735
|
def __aiter__(
|
|
1445
1736
|
self,
|
|
1446
|
-
) -> AsyncIterator[_agent_graph.AgentNode[AgentDepsT,
|
|
1737
|
+
) -> AsyncIterator[_agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]]:
|
|
1447
1738
|
"""Provide async-iteration over the nodes in the agent run."""
|
|
1448
1739
|
return self
|
|
1449
1740
|
|
|
1450
1741
|
async def __anext__(
|
|
1451
1742
|
self,
|
|
1452
|
-
) -> _agent_graph.AgentNode[AgentDepsT,
|
|
1743
|
+
) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]:
|
|
1453
1744
|
"""Advance to the next node automatically based on the last returned node."""
|
|
1454
1745
|
next_node = await self._graph_run.__anext__()
|
|
1455
1746
|
if _agent_graph.is_agent_node(next_node):
|
|
@@ -1459,8 +1750,8 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1459
1750
|
|
|
1460
1751
|
async def next(
|
|
1461
1752
|
self,
|
|
1462
|
-
node: _agent_graph.AgentNode[AgentDepsT,
|
|
1463
|
-
) -> _agent_graph.AgentNode[AgentDepsT,
|
|
1753
|
+
node: _agent_graph.AgentNode[AgentDepsT, OutputDataT],
|
|
1754
|
+
) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]:
|
|
1464
1755
|
"""Manually drive the agent run by passing in the node you want to run next.
|
|
1465
1756
|
|
|
1466
1757
|
This lets you inspect or mutate the node before continuing execution, or skip certain nodes
|
|
@@ -1487,6 +1778,8 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1487
1778
|
[
|
|
1488
1779
|
UserPromptNode(
|
|
1489
1780
|
user_prompt='What is the capital of France?',
|
|
1781
|
+
instructions=None,
|
|
1782
|
+
instructions_functions=[],
|
|
1490
1783
|
system_prompts=(),
|
|
1491
1784
|
system_prompt_functions=[],
|
|
1492
1785
|
system_prompt_dynamic_functions={},
|
|
@@ -1500,6 +1793,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1500
1793
|
part_kind='user-prompt',
|
|
1501
1794
|
)
|
|
1502
1795
|
],
|
|
1796
|
+
instructions=None,
|
|
1503
1797
|
kind='request',
|
|
1504
1798
|
)
|
|
1505
1799
|
),
|
|
@@ -1511,10 +1805,10 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1511
1805
|
kind='response',
|
|
1512
1806
|
)
|
|
1513
1807
|
),
|
|
1514
|
-
End(data=FinalResult(
|
|
1808
|
+
End(data=FinalResult(output='Paris', tool_name=None, tool_call_id=None)),
|
|
1515
1809
|
]
|
|
1516
1810
|
'''
|
|
1517
|
-
print('Final result:', agent_run.result.
|
|
1811
|
+
print('Final result:', agent_run.result.output)
|
|
1518
1812
|
#> Final result: Paris
|
|
1519
1813
|
```
|
|
1520
1814
|
|
|
@@ -1544,94 +1838,150 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1544
1838
|
|
|
1545
1839
|
|
|
1546
1840
|
@dataclasses.dataclass
|
|
1547
|
-
class AgentRunResult(Generic[
|
|
1841
|
+
class AgentRunResult(Generic[OutputDataT]):
|
|
1548
1842
|
"""The final result of an agent run."""
|
|
1549
1843
|
|
|
1550
|
-
|
|
1844
|
+
output: OutputDataT
|
|
1845
|
+
"""The output data from the agent run."""
|
|
1551
1846
|
|
|
1552
|
-
|
|
1847
|
+
_output_tool_name: str | None = dataclasses.field(repr=False)
|
|
1553
1848
|
_state: _agent_graph.GraphAgentState = dataclasses.field(repr=False)
|
|
1554
1849
|
_new_message_index: int = dataclasses.field(repr=False)
|
|
1850
|
+
_span_value: AbstractSpan | None = dataclasses.field(repr=False)
|
|
1555
1851
|
|
|
1556
|
-
|
|
1557
|
-
|
|
1852
|
+
@overload
|
|
1853
|
+
def _span(self, *, required: Literal[False]) -> AbstractSpan | None: ...
|
|
1854
|
+
@overload
|
|
1855
|
+
def _span(self) -> AbstractSpan: ...
|
|
1856
|
+
def _span(self, *, required: bool = True) -> AbstractSpan | None:
|
|
1857
|
+
if self._span_value is None and required: # pragma: no cover
|
|
1858
|
+
raise AttributeError('Span is not available for this agent run')
|
|
1859
|
+
return self._span_value
|
|
1860
|
+
|
|
1861
|
+
@property
|
|
1862
|
+
@deprecated('`result.data` is deprecated, use `result.output` instead.')
|
|
1863
|
+
def data(self) -> OutputDataT:
|
|
1864
|
+
return self.output
|
|
1865
|
+
|
|
1866
|
+
def _set_output_tool_return(self, return_content: str) -> list[_messages.ModelMessage]:
|
|
1867
|
+
"""Set return content for the output tool.
|
|
1558
1868
|
|
|
1559
|
-
Useful if you want to continue the conversation and want to set the response to the
|
|
1869
|
+
Useful if you want to continue the conversation and want to set the response to the output tool call.
|
|
1560
1870
|
"""
|
|
1561
|
-
if not self.
|
|
1562
|
-
raise ValueError('Cannot set
|
|
1871
|
+
if not self._output_tool_name:
|
|
1872
|
+
raise ValueError('Cannot set output tool return content when the return type is `str`.')
|
|
1563
1873
|
messages = deepcopy(self._state.message_history)
|
|
1564
1874
|
last_message = messages[-1]
|
|
1565
1875
|
for part in last_message.parts:
|
|
1566
|
-
if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self.
|
|
1876
|
+
if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self._output_tool_name:
|
|
1567
1877
|
part.content = return_content
|
|
1568
1878
|
return messages
|
|
1569
|
-
raise LookupError(f'No tool call found with tool name {self.
|
|
1879
|
+
raise LookupError(f'No tool call found with tool name {self._output_tool_name!r}.')
|
|
1880
|
+
|
|
1881
|
+
@overload
|
|
1882
|
+
def all_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: ...
|
|
1883
|
+
|
|
1884
|
+
@overload
|
|
1885
|
+
@deprecated('`result_tool_return_content` is deprecated, use `output_tool_return_content` instead.')
|
|
1886
|
+
def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: ...
|
|
1570
1887
|
|
|
1571
|
-
def all_messages(
|
|
1888
|
+
def all_messages(
|
|
1889
|
+
self, *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None
|
|
1890
|
+
) -> list[_messages.ModelMessage]:
|
|
1572
1891
|
"""Return the history of _messages.
|
|
1573
1892
|
|
|
1574
1893
|
Args:
|
|
1575
|
-
|
|
1576
|
-
This provides a convenient way to modify the content of the
|
|
1577
|
-
the conversation and want to set the response to the
|
|
1894
|
+
output_tool_return_content: The return content of the tool call to set in the last message.
|
|
1895
|
+
This provides a convenient way to modify the content of the output tool call if you want to continue
|
|
1896
|
+
the conversation and want to set the response to the output tool call. If `None`, the last message will
|
|
1578
1897
|
not be modified.
|
|
1898
|
+
result_tool_return_content: Deprecated, use `output_tool_return_content` instead.
|
|
1579
1899
|
|
|
1580
1900
|
Returns:
|
|
1581
1901
|
List of messages.
|
|
1582
1902
|
"""
|
|
1583
|
-
|
|
1584
|
-
|
|
1903
|
+
content = result.coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content)
|
|
1904
|
+
if content is not None:
|
|
1905
|
+
return self._set_output_tool_return(content)
|
|
1585
1906
|
else:
|
|
1586
1907
|
return self._state.message_history
|
|
1587
1908
|
|
|
1588
|
-
|
|
1909
|
+
@overload
|
|
1910
|
+
def all_messages_json(self, *, output_tool_return_content: str | None = None) -> bytes: ...
|
|
1911
|
+
|
|
1912
|
+
@overload
|
|
1913
|
+
@deprecated('`result_tool_return_content` is deprecated, use `output_tool_return_content` instead.')
|
|
1914
|
+
def all_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes: ...
|
|
1915
|
+
|
|
1916
|
+
def all_messages_json(
|
|
1917
|
+
self, *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None
|
|
1918
|
+
) -> bytes:
|
|
1589
1919
|
"""Return all messages from [`all_messages`][pydantic_ai.agent.AgentRunResult.all_messages] as JSON bytes.
|
|
1590
1920
|
|
|
1591
1921
|
Args:
|
|
1592
|
-
|
|
1593
|
-
This provides a convenient way to modify the content of the
|
|
1594
|
-
the conversation and want to set the response to the
|
|
1922
|
+
output_tool_return_content: The return content of the tool call to set in the last message.
|
|
1923
|
+
This provides a convenient way to modify the content of the output tool call if you want to continue
|
|
1924
|
+
the conversation and want to set the response to the output tool call. If `None`, the last message will
|
|
1595
1925
|
not be modified.
|
|
1926
|
+
result_tool_return_content: Deprecated, use `output_tool_return_content` instead.
|
|
1596
1927
|
|
|
1597
1928
|
Returns:
|
|
1598
1929
|
JSON bytes representing the messages.
|
|
1599
1930
|
"""
|
|
1600
|
-
|
|
1601
|
-
|
|
1602
|
-
|
|
1931
|
+
content = result.coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content)
|
|
1932
|
+
return _messages.ModelMessagesTypeAdapter.dump_json(self.all_messages(output_tool_return_content=content))
|
|
1933
|
+
|
|
1934
|
+
@overload
|
|
1935
|
+
def new_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: ...
|
|
1936
|
+
|
|
1937
|
+
@overload
|
|
1938
|
+
@deprecated('`result_tool_return_content` is deprecated, use `output_tool_return_content` instead.')
|
|
1939
|
+
def new_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: ...
|
|
1603
1940
|
|
|
1604
|
-
def new_messages(
|
|
1941
|
+
def new_messages(
|
|
1942
|
+
self, *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None
|
|
1943
|
+
) -> list[_messages.ModelMessage]:
|
|
1605
1944
|
"""Return new messages associated with this run.
|
|
1606
1945
|
|
|
1607
1946
|
Messages from older runs are excluded.
|
|
1608
1947
|
|
|
1609
1948
|
Args:
|
|
1610
|
-
|
|
1611
|
-
This provides a convenient way to modify the content of the
|
|
1612
|
-
the conversation and want to set the response to the
|
|
1949
|
+
output_tool_return_content: The return content of the tool call to set in the last message.
|
|
1950
|
+
This provides a convenient way to modify the content of the output tool call if you want to continue
|
|
1951
|
+
the conversation and want to set the response to the output tool call. If `None`, the last message will
|
|
1613
1952
|
not be modified.
|
|
1953
|
+
result_tool_return_content: Deprecated, use `output_tool_return_content` instead.
|
|
1614
1954
|
|
|
1615
1955
|
Returns:
|
|
1616
1956
|
List of new messages.
|
|
1617
1957
|
"""
|
|
1618
|
-
|
|
1958
|
+
content = result.coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content)
|
|
1959
|
+
return self.all_messages(output_tool_return_content=content)[self._new_message_index :]
|
|
1619
1960
|
|
|
1620
|
-
|
|
1961
|
+
@overload
|
|
1962
|
+
def new_messages_json(self, *, output_tool_return_content: str | None = None) -> bytes: ...
|
|
1963
|
+
|
|
1964
|
+
@overload
|
|
1965
|
+
@deprecated('`result_tool_return_content` is deprecated, use `output_tool_return_content` instead.')
|
|
1966
|
+
def new_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes: ...
|
|
1967
|
+
|
|
1968
|
+
def new_messages_json(
|
|
1969
|
+
self, *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None
|
|
1970
|
+
) -> bytes:
|
|
1621
1971
|
"""Return new messages from [`new_messages`][pydantic_ai.agent.AgentRunResult.new_messages] as JSON bytes.
|
|
1622
1972
|
|
|
1623
1973
|
Args:
|
|
1624
|
-
|
|
1625
|
-
This provides a convenient way to modify the content of the
|
|
1626
|
-
the conversation and want to set the response to the
|
|
1974
|
+
output_tool_return_content: The return content of the tool call to set in the last message.
|
|
1975
|
+
This provides a convenient way to modify the content of the output tool call if you want to continue
|
|
1976
|
+
the conversation and want to set the response to the output tool call. If `None`, the last message will
|
|
1627
1977
|
not be modified.
|
|
1978
|
+
result_tool_return_content: Deprecated, use `output_tool_return_content` instead.
|
|
1628
1979
|
|
|
1629
1980
|
Returns:
|
|
1630
1981
|
JSON bytes representing the new messages.
|
|
1631
1982
|
"""
|
|
1632
|
-
|
|
1633
|
-
|
|
1634
|
-
)
|
|
1983
|
+
content = result.coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content)
|
|
1984
|
+
return _messages.ModelMessagesTypeAdapter.dump_json(self.new_messages(output_tool_return_content=content))
|
|
1635
1985
|
|
|
1636
1986
|
def usage(self) -> _usage.Usage:
|
|
1637
1987
|
"""Return the usage of the whole run."""
|