pydantic-ai-slim 0.0.54__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/__init__.py +10 -3
- pydantic_ai/_agent_graph.py +67 -55
- pydantic_ai/_cli.py +1 -2
- pydantic_ai/{_result.py → _output.py} +69 -47
- pydantic_ai/_utils.py +20 -0
- pydantic_ai/agent.py +503 -163
- pydantic_ai/format_as_xml.py +6 -113
- pydantic_ai/format_prompt.py +116 -0
- pydantic_ai/messages.py +104 -21
- pydantic_ai/models/__init__.py +25 -5
- pydantic_ai/models/_json_schema.py +156 -0
- pydantic_ai/models/anthropic.py +14 -4
- pydantic_ai/models/bedrock.py +100 -22
- pydantic_ai/models/cohere.py +48 -44
- pydantic_ai/models/fallback.py +2 -1
- pydantic_ai/models/function.py +8 -8
- pydantic_ai/models/gemini.py +65 -75
- pydantic_ai/models/groq.py +34 -29
- pydantic_ai/models/instrumented.py +4 -4
- pydantic_ai/models/mistral.py +67 -58
- pydantic_ai/models/openai.py +113 -158
- pydantic_ai/models/test.py +45 -46
- pydantic_ai/models/wrapper.py +3 -0
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/azure.py +2 -2
- pydantic_ai/result.py +203 -90
- pydantic_ai/tools.py +3 -3
- {pydantic_ai_slim-0.0.54.dist-info → pydantic_ai_slim-0.1.0.dist-info}/METADATA +5 -5
- pydantic_ai_slim-0.1.0.dist-info/RECORD +53 -0
- pydantic_ai_slim-0.0.54.dist-info/RECORD +0 -51
- {pydantic_ai_slim-0.0.54.dist-info → pydantic_ai_slim-0.1.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.0.54.dist-info → pydantic_ai_slim-0.1.0.dist-info}/entry_points.txt +0 -0
pydantic_ai/agent.py
CHANGED
|
@@ -2,6 +2,7 @@ from __future__ import annotations as _annotations
|
|
|
2
2
|
|
|
3
3
|
import dataclasses
|
|
4
4
|
import inspect
|
|
5
|
+
import warnings
|
|
5
6
|
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
|
|
6
7
|
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
|
|
7
8
|
from copy import deepcopy
|
|
@@ -10,14 +11,14 @@ from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, final,
|
|
|
10
11
|
|
|
11
12
|
from opentelemetry.trace import NoOpTracer, use_span
|
|
12
13
|
from pydantic.json_schema import GenerateJsonSchema
|
|
13
|
-
from typing_extensions import TypeGuard, TypeVar, deprecated
|
|
14
|
+
from typing_extensions import Literal, Never, TypeGuard, TypeVar, deprecated
|
|
14
15
|
|
|
15
16
|
from pydantic_graph import End, Graph, GraphRun, GraphRunContext
|
|
16
17
|
from pydantic_graph._utils import get_event_loop
|
|
17
18
|
|
|
18
19
|
from . import (
|
|
19
20
|
_agent_graph,
|
|
20
|
-
|
|
21
|
+
_output,
|
|
21
22
|
_system_prompt,
|
|
22
23
|
_utils,
|
|
23
24
|
exceptions,
|
|
@@ -26,8 +27,9 @@ from . import (
|
|
|
26
27
|
result,
|
|
27
28
|
usage as _usage,
|
|
28
29
|
)
|
|
30
|
+
from ._utils import AbstractSpan
|
|
29
31
|
from .models.instrumented import InstrumentationSettings, InstrumentedModel
|
|
30
|
-
from .result import FinalResult,
|
|
32
|
+
from .result import FinalResult, OutputDataT, StreamedRunResult, ToolOutput
|
|
31
33
|
from .settings import ModelSettings, merge_model_settings
|
|
32
34
|
from .tools import (
|
|
33
35
|
AgentDepsT,
|
|
@@ -52,6 +54,7 @@ UserPromptNode = _agent_graph.UserPromptNode
|
|
|
52
54
|
if TYPE_CHECKING:
|
|
53
55
|
from pydantic_ai.mcp import MCPServer
|
|
54
56
|
|
|
57
|
+
|
|
55
58
|
__all__ = (
|
|
56
59
|
'Agent',
|
|
57
60
|
'AgentRun',
|
|
@@ -68,17 +71,17 @@ __all__ = (
|
|
|
68
71
|
T = TypeVar('T')
|
|
69
72
|
S = TypeVar('S')
|
|
70
73
|
NoneType = type(None)
|
|
71
|
-
|
|
72
|
-
"""Type variable for the result data of a run where `
|
|
74
|
+
RunOutputDataT = TypeVar('RunOutputDataT')
|
|
75
|
+
"""Type variable for the result data of a run where `output_type` was customized on the run call."""
|
|
73
76
|
|
|
74
77
|
|
|
75
78
|
@final
|
|
76
79
|
@dataclasses.dataclass(init=False)
|
|
77
|
-
class Agent(Generic[AgentDepsT,
|
|
80
|
+
class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
78
81
|
"""Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.
|
|
79
82
|
|
|
80
83
|
Agents are generic in the dependency type they take [`AgentDepsT`][pydantic_ai.tools.AgentDepsT]
|
|
81
|
-
and the result data type they return, [`
|
|
84
|
+
and the result data type they return, [`OutputDataT`][pydantic_ai.result.OutputDataT].
|
|
82
85
|
|
|
83
86
|
By default, if neither generic parameter is customised, agents have type `Agent[None, str]`.
|
|
84
87
|
|
|
@@ -89,7 +92,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
89
92
|
|
|
90
93
|
agent = Agent('openai:gpt-4o')
|
|
91
94
|
result = agent.run_sync('What is the capital of France?')
|
|
92
|
-
print(result.
|
|
95
|
+
print(result.output)
|
|
93
96
|
#> Paris
|
|
94
97
|
```
|
|
95
98
|
"""
|
|
@@ -115,9 +118,9 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
115
118
|
be merged with this value, with the runtime argument taking priority.
|
|
116
119
|
"""
|
|
117
120
|
|
|
118
|
-
|
|
121
|
+
output_type: type[OutputDataT] | ToolOutput[OutputDataT]
|
|
119
122
|
"""
|
|
120
|
-
The type of
|
|
123
|
+
The type of data output by agent runs, used to validate the data returned by the model, defaults to `str`.
|
|
121
124
|
"""
|
|
122
125
|
|
|
123
126
|
instrument: InstrumentationSettings | bool | None
|
|
@@ -126,10 +129,12 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
126
129
|
_instrument_default: ClassVar[InstrumentationSettings | bool] = False
|
|
127
130
|
|
|
128
131
|
_deps_type: type[AgentDepsT] = dataclasses.field(repr=False)
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
132
|
+
_deprecated_result_tool_name: str | None = dataclasses.field(repr=False)
|
|
133
|
+
_deprecated_result_tool_description: str | None = dataclasses.field(repr=False)
|
|
134
|
+
_output_schema: _output.OutputSchema[OutputDataT] | None = dataclasses.field(repr=False)
|
|
135
|
+
_output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]] = dataclasses.field(repr=False)
|
|
136
|
+
_instructions: str | None = dataclasses.field(repr=False)
|
|
137
|
+
_instructions_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False)
|
|
133
138
|
_system_prompts: tuple[str, ...] = dataclasses.field(repr=False)
|
|
134
139
|
_system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False)
|
|
135
140
|
_system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(
|
|
@@ -142,11 +147,36 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
142
147
|
_override_deps: _utils.Option[AgentDepsT] = dataclasses.field(default=None, repr=False)
|
|
143
148
|
_override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False)
|
|
144
149
|
|
|
150
|
+
@overload
|
|
151
|
+
def __init__(
|
|
152
|
+
self,
|
|
153
|
+
model: models.Model | models.KnownModelName | str | None = None,
|
|
154
|
+
*,
|
|
155
|
+
output_type: type[OutputDataT] | ToolOutput[OutputDataT] = str,
|
|
156
|
+
instructions: str | _system_prompt.SystemPromptFunc[AgentDepsT] | None = None,
|
|
157
|
+
system_prompt: str | Sequence[str] = (),
|
|
158
|
+
deps_type: type[AgentDepsT] = NoneType,
|
|
159
|
+
name: str | None = None,
|
|
160
|
+
model_settings: ModelSettings | None = None,
|
|
161
|
+
retries: int = 1,
|
|
162
|
+
output_retries: int | None = None,
|
|
163
|
+
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
|
|
164
|
+
mcp_servers: Sequence[MCPServer] = (),
|
|
165
|
+
defer_model_check: bool = False,
|
|
166
|
+
end_strategy: EndStrategy = 'early',
|
|
167
|
+
instrument: InstrumentationSettings | bool | None = None,
|
|
168
|
+
) -> None: ...
|
|
169
|
+
|
|
170
|
+
@overload
|
|
171
|
+
@deprecated(
|
|
172
|
+
'`result_type`, `result_tool_name`, `result_tool_description` & `result_retries` are deprecated, use `output_type` instead. `result_retries` is deprecated, use `output_retries` instead.'
|
|
173
|
+
)
|
|
145
174
|
def __init__(
|
|
146
175
|
self,
|
|
147
176
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
148
177
|
*,
|
|
149
|
-
result_type: type[
|
|
178
|
+
result_type: type[OutputDataT] = str,
|
|
179
|
+
instructions: str | _system_prompt.SystemPromptFunc[AgentDepsT] | None = None,
|
|
150
180
|
system_prompt: str | Sequence[str] = (),
|
|
151
181
|
deps_type: type[AgentDepsT] = NoneType,
|
|
152
182
|
name: str | None = None,
|
|
@@ -160,13 +190,37 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
160
190
|
defer_model_check: bool = False,
|
|
161
191
|
end_strategy: EndStrategy = 'early',
|
|
162
192
|
instrument: InstrumentationSettings | bool | None = None,
|
|
193
|
+
) -> None: ...
|
|
194
|
+
|
|
195
|
+
def __init__(
|
|
196
|
+
self,
|
|
197
|
+
model: models.Model | models.KnownModelName | str | None = None,
|
|
198
|
+
*,
|
|
199
|
+
# TODO change this back to `output_type: type[OutputDataT] | ToolOutput[OutputDataT] = str,` when we remove the overloads
|
|
200
|
+
output_type: Any = str,
|
|
201
|
+
instructions: str | _system_prompt.SystemPromptFunc[AgentDepsT] | None = None,
|
|
202
|
+
system_prompt: str | Sequence[str] = (),
|
|
203
|
+
deps_type: type[AgentDepsT] = NoneType,
|
|
204
|
+
name: str | None = None,
|
|
205
|
+
model_settings: ModelSettings | None = None,
|
|
206
|
+
retries: int = 1,
|
|
207
|
+
output_retries: int | None = None,
|
|
208
|
+
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
|
|
209
|
+
mcp_servers: Sequence[MCPServer] = (),
|
|
210
|
+
defer_model_check: bool = False,
|
|
211
|
+
end_strategy: EndStrategy = 'early',
|
|
212
|
+
instrument: InstrumentationSettings | bool | None = None,
|
|
213
|
+
**_deprecated_kwargs: Any,
|
|
163
214
|
):
|
|
164
215
|
"""Create an agent.
|
|
165
216
|
|
|
166
217
|
Args:
|
|
167
218
|
model: The default model to use for this agent, if not provide,
|
|
168
219
|
you must provide the model when calling it. We allow str here since the actual list of allowed models changes frequently.
|
|
169
|
-
|
|
220
|
+
output_type: The type of the output data, used to validate the data returned by the model,
|
|
221
|
+
defaults to `str`.
|
|
222
|
+
instructions: Instructions to use for this agent, you can also register instructions via a function with
|
|
223
|
+
[`instructions`][pydantic_ai.Agent.instructions].
|
|
170
224
|
system_prompt: Static system prompts to use for this agent, you can also register system
|
|
171
225
|
prompts via a function with [`system_prompt`][pydantic_ai.Agent.system_prompt].
|
|
172
226
|
deps_type: The type used for dependency injection, this parameter exists solely to allow you to fully
|
|
@@ -177,9 +231,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
177
231
|
when the agent is first run.
|
|
178
232
|
model_settings: Optional model request settings to use for this agent's runs, by default.
|
|
179
233
|
retries: The default number of retries to allow before raising an error.
|
|
180
|
-
|
|
181
|
-
result_tool_description: The description of the final result tool.
|
|
182
|
-
result_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
|
|
234
|
+
output_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
|
|
183
235
|
tools: Tools to register with the agent, you can also register tools via the decorators
|
|
184
236
|
[`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
|
|
185
237
|
mcp_servers: MCP servers to register with the agent. You should register a [`MCPServer`][pydantic_ai.mcp.MCPServer]
|
|
@@ -207,17 +259,48 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
207
259
|
self.end_strategy = end_strategy
|
|
208
260
|
self.name = name
|
|
209
261
|
self.model_settings = model_settings
|
|
210
|
-
|
|
262
|
+
|
|
263
|
+
if 'result_type' in _deprecated_kwargs: # pragma: no cover
|
|
264
|
+
if output_type is not str:
|
|
265
|
+
raise TypeError('`result_type` and `output_type` cannot be set at the same time.')
|
|
266
|
+
warnings.warn('`result_type` is deprecated, use `output_type` instead', DeprecationWarning)
|
|
267
|
+
output_type = _deprecated_kwargs['result_type']
|
|
268
|
+
|
|
269
|
+
self.output_type = output_type
|
|
270
|
+
|
|
211
271
|
self.instrument = instrument
|
|
212
272
|
|
|
213
273
|
self._deps_type = deps_type
|
|
214
274
|
|
|
215
|
-
self.
|
|
216
|
-
self.
|
|
217
|
-
|
|
218
|
-
|
|
275
|
+
self._deprecated_result_tool_name = _deprecated_kwargs.get('result_tool_name')
|
|
276
|
+
if self._deprecated_result_tool_name is not None: # pragma: no cover
|
|
277
|
+
warnings.warn(
|
|
278
|
+
'`result_tool_name` is deprecated, use `output_type` with `ToolOutput` instead',
|
|
279
|
+
DeprecationWarning,
|
|
280
|
+
)
|
|
281
|
+
|
|
282
|
+
self._deprecated_result_tool_description = _deprecated_kwargs.get('result_tool_description')
|
|
283
|
+
if self._deprecated_result_tool_description is not None: # pragma: no cover
|
|
284
|
+
warnings.warn(
|
|
285
|
+
'`result_tool_description` is deprecated, use `output_type` with `ToolOutput` instead',
|
|
286
|
+
DeprecationWarning,
|
|
287
|
+
)
|
|
288
|
+
result_retries = _deprecated_kwargs.get('result_retries')
|
|
289
|
+
if result_retries is not None: # pragma: no cover
|
|
290
|
+
if output_retries is not None:
|
|
291
|
+
raise TypeError('`output_retries` and `result_retries` cannot be set at the same time.')
|
|
292
|
+
warnings.warn('`result_retries` is deprecated, use `max_result_retries` instead', DeprecationWarning)
|
|
293
|
+
output_retries = result_retries
|
|
294
|
+
|
|
295
|
+
self._output_schema = _output.OutputSchema[OutputDataT].build(
|
|
296
|
+
output_type, self._deprecated_result_tool_name, self._deprecated_result_tool_description
|
|
297
|
+
)
|
|
298
|
+
self._output_validators = []
|
|
299
|
+
|
|
300
|
+
self._instructions_functions = (
|
|
301
|
+
[_system_prompt.SystemPromptRunner(instructions)] if callable(instructions) else []
|
|
219
302
|
)
|
|
220
|
-
self.
|
|
303
|
+
self._instructions = instructions if isinstance(instructions, str) else None
|
|
221
304
|
|
|
222
305
|
self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
|
|
223
306
|
self._system_prompt_functions = []
|
|
@@ -226,7 +309,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
226
309
|
self._function_tools = {}
|
|
227
310
|
|
|
228
311
|
self._default_retries = retries
|
|
229
|
-
self._max_result_retries =
|
|
312
|
+
self._max_result_retries = output_retries if output_retries is not None else retries
|
|
230
313
|
self._mcp_servers = mcp_servers
|
|
231
314
|
for tool in tools:
|
|
232
315
|
if isinstance(tool, Tool):
|
|
@@ -244,7 +327,22 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
244
327
|
self,
|
|
245
328
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
246
329
|
*,
|
|
247
|
-
|
|
330
|
+
output_type: None = None,
|
|
331
|
+
message_history: list[_messages.ModelMessage] | None = None,
|
|
332
|
+
model: models.Model | models.KnownModelName | str | None = None,
|
|
333
|
+
deps: AgentDepsT = None,
|
|
334
|
+
model_settings: ModelSettings | None = None,
|
|
335
|
+
usage_limits: _usage.UsageLimits | None = None,
|
|
336
|
+
usage: _usage.Usage | None = None,
|
|
337
|
+
infer_name: bool = True,
|
|
338
|
+
) -> AgentRunResult[OutputDataT]: ...
|
|
339
|
+
|
|
340
|
+
@overload
|
|
341
|
+
async def run(
|
|
342
|
+
self,
|
|
343
|
+
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
344
|
+
*,
|
|
345
|
+
output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT],
|
|
248
346
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
249
347
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
250
348
|
deps: AgentDepsT = None,
|
|
@@ -252,14 +350,15 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
252
350
|
usage_limits: _usage.UsageLimits | None = None,
|
|
253
351
|
usage: _usage.Usage | None = None,
|
|
254
352
|
infer_name: bool = True,
|
|
255
|
-
) -> AgentRunResult[
|
|
353
|
+
) -> AgentRunResult[RunOutputDataT]: ...
|
|
256
354
|
|
|
257
355
|
@overload
|
|
356
|
+
@deprecated('`result_type` is deprecated, use `output_type` instead.')
|
|
258
357
|
async def run(
|
|
259
358
|
self,
|
|
260
359
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
261
360
|
*,
|
|
262
|
-
result_type: type[
|
|
361
|
+
result_type: type[RunOutputDataT],
|
|
263
362
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
264
363
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
265
364
|
deps: AgentDepsT = None,
|
|
@@ -267,13 +366,13 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
267
366
|
usage_limits: _usage.UsageLimits | None = None,
|
|
268
367
|
usage: _usage.Usage | None = None,
|
|
269
368
|
infer_name: bool = True,
|
|
270
|
-
) -> AgentRunResult[
|
|
369
|
+
) -> AgentRunResult[RunOutputDataT]: ...
|
|
271
370
|
|
|
272
371
|
async def run(
|
|
273
372
|
self,
|
|
274
373
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
275
374
|
*,
|
|
276
|
-
|
|
375
|
+
output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None = None,
|
|
277
376
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
278
377
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
279
378
|
deps: AgentDepsT = None,
|
|
@@ -281,6 +380,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
281
380
|
usage_limits: _usage.UsageLimits | None = None,
|
|
282
381
|
usage: _usage.Usage | None = None,
|
|
283
382
|
infer_name: bool = True,
|
|
383
|
+
**_deprecated_kwargs: Never,
|
|
284
384
|
) -> AgentRunResult[Any]:
|
|
285
385
|
"""Run the agent with a user prompt in async mode.
|
|
286
386
|
|
|
@@ -295,14 +395,14 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
295
395
|
|
|
296
396
|
async def main():
|
|
297
397
|
agent_run = await agent.run('What is the capital of France?')
|
|
298
|
-
print(agent_run.
|
|
398
|
+
print(agent_run.output)
|
|
299
399
|
#> Paris
|
|
300
400
|
```
|
|
301
401
|
|
|
302
402
|
Args:
|
|
303
403
|
user_prompt: User input to start/continue the conversation.
|
|
304
|
-
|
|
305
|
-
|
|
404
|
+
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
|
|
405
|
+
output validators since output validators would expect an argument that matches the agent's output type.
|
|
306
406
|
message_history: History of the conversation so far.
|
|
307
407
|
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
308
408
|
deps: Optional dependencies to use for this run.
|
|
@@ -316,9 +416,16 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
316
416
|
"""
|
|
317
417
|
if infer_name and self.name is None:
|
|
318
418
|
self._infer_name(inspect.currentframe())
|
|
419
|
+
|
|
420
|
+
if 'result_type' in _deprecated_kwargs: # pragma: no cover
|
|
421
|
+
if output_type is not str:
|
|
422
|
+
raise TypeError('`result_type` and `output_type` cannot be set at the same time.')
|
|
423
|
+
warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning)
|
|
424
|
+
output_type = _deprecated_kwargs['result_type']
|
|
425
|
+
|
|
319
426
|
async with self.iter(
|
|
320
427
|
user_prompt=user_prompt,
|
|
321
|
-
|
|
428
|
+
output_type=output_type,
|
|
322
429
|
message_history=message_history,
|
|
323
430
|
model=model,
|
|
324
431
|
deps=deps,
|
|
@@ -332,12 +439,44 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
332
439
|
assert agent_run.result is not None, 'The graph run did not finish properly'
|
|
333
440
|
return agent_run.result
|
|
334
441
|
|
|
442
|
+
@overload
|
|
443
|
+
def iter(
|
|
444
|
+
self,
|
|
445
|
+
user_prompt: str | Sequence[_messages.UserContent] | None,
|
|
446
|
+
*,
|
|
447
|
+
output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None = None,
|
|
448
|
+
message_history: list[_messages.ModelMessage] | None = None,
|
|
449
|
+
model: models.Model | models.KnownModelName | str | None = None,
|
|
450
|
+
deps: AgentDepsT = None,
|
|
451
|
+
model_settings: ModelSettings | None = None,
|
|
452
|
+
usage_limits: _usage.UsageLimits | None = None,
|
|
453
|
+
usage: _usage.Usage | None = None,
|
|
454
|
+
infer_name: bool = True,
|
|
455
|
+
**_deprecated_kwargs: Never,
|
|
456
|
+
) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ...
|
|
457
|
+
|
|
458
|
+
@overload
|
|
459
|
+
@deprecated('`result_type` is deprecated, use `output_type` instead.')
|
|
460
|
+
def iter(
|
|
461
|
+
self,
|
|
462
|
+
user_prompt: str | Sequence[_messages.UserContent] | None,
|
|
463
|
+
*,
|
|
464
|
+
result_type: type[RunOutputDataT],
|
|
465
|
+
message_history: list[_messages.ModelMessage] | None = None,
|
|
466
|
+
model: models.Model | models.KnownModelName | str | None = None,
|
|
467
|
+
deps: AgentDepsT = None,
|
|
468
|
+
model_settings: ModelSettings | None = None,
|
|
469
|
+
usage_limits: _usage.UsageLimits | None = None,
|
|
470
|
+
usage: _usage.Usage | None = None,
|
|
471
|
+
infer_name: bool = True,
|
|
472
|
+
) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ...
|
|
473
|
+
|
|
335
474
|
@asynccontextmanager
|
|
336
475
|
async def iter(
|
|
337
476
|
self,
|
|
338
477
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
339
478
|
*,
|
|
340
|
-
|
|
479
|
+
output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None = None,
|
|
341
480
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
342
481
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
343
482
|
deps: AgentDepsT = None,
|
|
@@ -345,10 +484,11 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
345
484
|
usage_limits: _usage.UsageLimits | None = None,
|
|
346
485
|
usage: _usage.Usage | None = None,
|
|
347
486
|
infer_name: bool = True,
|
|
487
|
+
**_deprecated_kwargs: Never,
|
|
348
488
|
) -> AsyncIterator[AgentRun[AgentDepsT, Any]]:
|
|
349
489
|
"""A contextmanager which can be used to iterate over the agent graph's nodes as they are executed.
|
|
350
490
|
|
|
351
|
-
This method builds an internal agent graph (using system prompts, tools and
|
|
491
|
+
This method builds an internal agent graph (using system prompts, tools and output schemas) and then returns an
|
|
352
492
|
`AgentRun` object. The `AgentRun` can be used to async-iterate over the nodes of the graph as they are
|
|
353
493
|
executed. This is the API to use if you want to consume the outputs coming from each LLM model response, or the
|
|
354
494
|
stream of events coming from the execution of tools.
|
|
@@ -374,6 +514,8 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
374
514
|
[
|
|
375
515
|
UserPromptNode(
|
|
376
516
|
user_prompt='What is the capital of France?',
|
|
517
|
+
instructions=None,
|
|
518
|
+
instructions_functions=[],
|
|
377
519
|
system_prompts=(),
|
|
378
520
|
system_prompt_functions=[],
|
|
379
521
|
system_prompt_dynamic_functions={},
|
|
@@ -387,6 +529,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
387
529
|
part_kind='user-prompt',
|
|
388
530
|
)
|
|
389
531
|
],
|
|
532
|
+
instructions=None,
|
|
390
533
|
kind='request',
|
|
391
534
|
)
|
|
392
535
|
),
|
|
@@ -398,17 +541,17 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
398
541
|
kind='response',
|
|
399
542
|
)
|
|
400
543
|
),
|
|
401
|
-
End(data=FinalResult(
|
|
544
|
+
End(data=FinalResult(output='Paris', tool_name=None, tool_call_id=None)),
|
|
402
545
|
]
|
|
403
546
|
'''
|
|
404
|
-
print(agent_run.result.
|
|
547
|
+
print(agent_run.result.output)
|
|
405
548
|
#> Paris
|
|
406
549
|
```
|
|
407
550
|
|
|
408
551
|
Args:
|
|
409
552
|
user_prompt: User input to start/continue the conversation.
|
|
410
|
-
|
|
411
|
-
|
|
553
|
+
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
|
|
554
|
+
output validators since output validators would expect an argument that matches the agent's output type.
|
|
412
555
|
message_history: History of the conversation so far.
|
|
413
556
|
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
414
557
|
deps: Optional dependencies to use for this run.
|
|
@@ -425,12 +568,22 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
425
568
|
model_used = self._get_model(model)
|
|
426
569
|
del model
|
|
427
570
|
|
|
571
|
+
if 'result_type' in _deprecated_kwargs: # pragma: no cover
|
|
572
|
+
if output_type is not str:
|
|
573
|
+
raise TypeError('`result_type` and `output_type` cannot be set at the same time.')
|
|
574
|
+
warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning)
|
|
575
|
+
output_type = _deprecated_kwargs['result_type']
|
|
576
|
+
|
|
428
577
|
deps = self._get_deps(deps)
|
|
429
578
|
new_message_index = len(message_history) if message_history else 0
|
|
430
|
-
|
|
579
|
+
output_schema = self._prepare_output_schema(output_type)
|
|
580
|
+
|
|
581
|
+
output_type_ = output_type or self.output_type
|
|
431
582
|
|
|
432
583
|
# Build the graph
|
|
433
|
-
graph =
|
|
584
|
+
graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = (
|
|
585
|
+
_agent_graph.build_agent_graph(self.name, self._deps_type, output_type_)
|
|
586
|
+
)
|
|
434
587
|
|
|
435
588
|
# Build the initial state
|
|
436
589
|
state = _agent_graph.GraphAgentState(
|
|
@@ -440,10 +593,10 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
440
593
|
run_step=0,
|
|
441
594
|
)
|
|
442
595
|
|
|
443
|
-
# We consider it a user error if a user tries to restrict the result type while having
|
|
596
|
+
# We consider it a user error if a user tries to restrict the result type while having an output validator that
|
|
444
597
|
# may change the result type from the restricted type to something else. Therefore, we consider the following
|
|
445
598
|
# typecast reasonable, even though it is possible to violate it with otherwise-type-checked code.
|
|
446
|
-
|
|
599
|
+
output_validators = cast(list[_output.OutputValidator[AgentDepsT, RunOutputDataT]], self._output_validators)
|
|
447
600
|
|
|
448
601
|
# TODO: Instead of this, copy the function tools to ensure they don't share current_retry state between agent
|
|
449
602
|
# runs. Requires some changes to `Tool` to make them copyable though.
|
|
@@ -467,7 +620,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
467
620
|
},
|
|
468
621
|
)
|
|
469
622
|
|
|
470
|
-
graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT,
|
|
623
|
+
graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
|
|
471
624
|
user_deps=deps,
|
|
472
625
|
prompt=user_prompt,
|
|
473
626
|
new_message_index=new_message_index,
|
|
@@ -476,9 +629,8 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
476
629
|
usage_limits=usage_limits,
|
|
477
630
|
max_result_retries=self._max_result_retries,
|
|
478
631
|
end_strategy=self.end_strategy,
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
result_validators=result_validators,
|
|
632
|
+
output_schema=output_schema,
|
|
633
|
+
output_validators=output_validators,
|
|
482
634
|
function_tools=self._function_tools,
|
|
483
635
|
mcp_servers=self._mcp_servers,
|
|
484
636
|
run_span=run_span,
|
|
@@ -486,6 +638,8 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
486
638
|
)
|
|
487
639
|
start_node = _agent_graph.UserPromptNode[AgentDepsT](
|
|
488
640
|
user_prompt=user_prompt,
|
|
641
|
+
instructions=self._instructions,
|
|
642
|
+
instructions_functions=self._instructions_functions,
|
|
489
643
|
system_prompts=self._system_prompts,
|
|
490
644
|
system_prompt_functions=self._system_prompt_functions,
|
|
491
645
|
system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
|
|
@@ -512,14 +666,30 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
512
666
|
usage_limits: _usage.UsageLimits | None = None,
|
|
513
667
|
usage: _usage.Usage | None = None,
|
|
514
668
|
infer_name: bool = True,
|
|
515
|
-
) -> AgentRunResult[
|
|
669
|
+
) -> AgentRunResult[OutputDataT]: ...
|
|
670
|
+
|
|
671
|
+
@overload
|
|
672
|
+
def run_sync(
|
|
673
|
+
self,
|
|
674
|
+
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
675
|
+
*,
|
|
676
|
+
output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None,
|
|
677
|
+
message_history: list[_messages.ModelMessage] | None = None,
|
|
678
|
+
model: models.Model | models.KnownModelName | str | None = None,
|
|
679
|
+
deps: AgentDepsT = None,
|
|
680
|
+
model_settings: ModelSettings | None = None,
|
|
681
|
+
usage_limits: _usage.UsageLimits | None = None,
|
|
682
|
+
usage: _usage.Usage | None = None,
|
|
683
|
+
infer_name: bool = True,
|
|
684
|
+
) -> AgentRunResult[RunOutputDataT]: ...
|
|
516
685
|
|
|
517
686
|
@overload
|
|
687
|
+
@deprecated('`result_type` is deprecated, use `output_type` instead.')
|
|
518
688
|
def run_sync(
|
|
519
689
|
self,
|
|
520
690
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
521
691
|
*,
|
|
522
|
-
result_type: type[
|
|
692
|
+
result_type: type[RunOutputDataT],
|
|
523
693
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
524
694
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
525
695
|
deps: AgentDepsT = None,
|
|
@@ -527,13 +697,13 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
527
697
|
usage_limits: _usage.UsageLimits | None = None,
|
|
528
698
|
usage: _usage.Usage | None = None,
|
|
529
699
|
infer_name: bool = True,
|
|
530
|
-
) -> AgentRunResult[
|
|
700
|
+
) -> AgentRunResult[RunOutputDataT]: ...
|
|
531
701
|
|
|
532
702
|
def run_sync(
|
|
533
703
|
self,
|
|
534
704
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
535
705
|
*,
|
|
536
|
-
|
|
706
|
+
output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None = None,
|
|
537
707
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
538
708
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
539
709
|
deps: AgentDepsT = None,
|
|
@@ -541,6 +711,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
541
711
|
usage_limits: _usage.UsageLimits | None = None,
|
|
542
712
|
usage: _usage.Usage | None = None,
|
|
543
713
|
infer_name: bool = True,
|
|
714
|
+
**_deprecated_kwargs: Never,
|
|
544
715
|
) -> AgentRunResult[Any]:
|
|
545
716
|
"""Synchronously run the agent with a user prompt.
|
|
546
717
|
|
|
@@ -554,14 +725,14 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
554
725
|
agent = Agent('openai:gpt-4o')
|
|
555
726
|
|
|
556
727
|
result_sync = agent.run_sync('What is the capital of Italy?')
|
|
557
|
-
print(result_sync.
|
|
728
|
+
print(result_sync.output)
|
|
558
729
|
#> Rome
|
|
559
730
|
```
|
|
560
731
|
|
|
561
732
|
Args:
|
|
562
733
|
user_prompt: User input to start/continue the conversation.
|
|
563
|
-
|
|
564
|
-
|
|
734
|
+
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
|
|
735
|
+
output validators since output validators would expect an argument that matches the agent's output type.
|
|
565
736
|
message_history: History of the conversation so far.
|
|
566
737
|
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
567
738
|
deps: Optional dependencies to use for this run.
|
|
@@ -575,10 +746,17 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
575
746
|
"""
|
|
576
747
|
if infer_name and self.name is None:
|
|
577
748
|
self._infer_name(inspect.currentframe())
|
|
749
|
+
|
|
750
|
+
if 'result_type' in _deprecated_kwargs: # pragma: no cover
|
|
751
|
+
if output_type is not str:
|
|
752
|
+
raise TypeError('`result_type` and `output_type` cannot be set at the same time.')
|
|
753
|
+
warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning)
|
|
754
|
+
output_type = _deprecated_kwargs['result_type']
|
|
755
|
+
|
|
578
756
|
return get_event_loop().run_until_complete(
|
|
579
757
|
self.run(
|
|
580
758
|
user_prompt,
|
|
581
|
-
|
|
759
|
+
output_type=output_type,
|
|
582
760
|
message_history=message_history,
|
|
583
761
|
model=model,
|
|
584
762
|
deps=deps,
|
|
@@ -589,12 +767,26 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
589
767
|
)
|
|
590
768
|
)
|
|
591
769
|
|
|
770
|
+
@overload
|
|
771
|
+
def run_stream(
|
|
772
|
+
self,
|
|
773
|
+
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
774
|
+
*,
|
|
775
|
+
message_history: list[_messages.ModelMessage] | None = None,
|
|
776
|
+
model: models.Model | models.KnownModelName | None = None,
|
|
777
|
+
deps: AgentDepsT = None,
|
|
778
|
+
model_settings: ModelSettings | None = None,
|
|
779
|
+
usage_limits: _usage.UsageLimits | None = None,
|
|
780
|
+
usage: _usage.Usage | None = None,
|
|
781
|
+
infer_name: bool = True,
|
|
782
|
+
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, OutputDataT]]: ...
|
|
783
|
+
|
|
592
784
|
@overload
|
|
593
785
|
def run_stream(
|
|
594
786
|
self,
|
|
595
787
|
user_prompt: str | Sequence[_messages.UserContent],
|
|
596
788
|
*,
|
|
597
|
-
|
|
789
|
+
output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT],
|
|
598
790
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
599
791
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
600
792
|
deps: AgentDepsT = None,
|
|
@@ -602,14 +794,15 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
602
794
|
usage_limits: _usage.UsageLimits | None = None,
|
|
603
795
|
usage: _usage.Usage | None = None,
|
|
604
796
|
infer_name: bool = True,
|
|
605
|
-
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT,
|
|
797
|
+
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ...
|
|
606
798
|
|
|
607
799
|
@overload
|
|
800
|
+
@deprecated('`result_type` is deprecated, use `output_type` instead.')
|
|
608
801
|
def run_stream(
|
|
609
802
|
self,
|
|
610
|
-
user_prompt: str | Sequence[_messages.UserContent],
|
|
803
|
+
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
611
804
|
*,
|
|
612
|
-
result_type: type[
|
|
805
|
+
result_type: type[RunOutputDataT],
|
|
613
806
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
614
807
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
615
808
|
deps: AgentDepsT = None,
|
|
@@ -617,14 +810,14 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
617
810
|
usage_limits: _usage.UsageLimits | None = None,
|
|
618
811
|
usage: _usage.Usage | None = None,
|
|
619
812
|
infer_name: bool = True,
|
|
620
|
-
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT,
|
|
813
|
+
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ...
|
|
621
814
|
|
|
622
815
|
@asynccontextmanager
|
|
623
816
|
async def run_stream( # noqa C901
|
|
624
817
|
self,
|
|
625
|
-
user_prompt: str | Sequence[_messages.UserContent],
|
|
818
|
+
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
626
819
|
*,
|
|
627
|
-
|
|
820
|
+
output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None = None,
|
|
628
821
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
629
822
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
630
823
|
deps: AgentDepsT = None,
|
|
@@ -632,6 +825,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
632
825
|
usage_limits: _usage.UsageLimits | None = None,
|
|
633
826
|
usage: _usage.Usage | None = None,
|
|
634
827
|
infer_name: bool = True,
|
|
828
|
+
**_deprecated_kwargs: Never,
|
|
635
829
|
) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]:
|
|
636
830
|
"""Run the agent with a user prompt in async mode, returning a streamed response.
|
|
637
831
|
|
|
@@ -643,14 +837,14 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
643
837
|
|
|
644
838
|
async def main():
|
|
645
839
|
async with agent.run_stream('What is the capital of the UK?') as response:
|
|
646
|
-
print(await response.
|
|
840
|
+
print(await response.get_output())
|
|
647
841
|
#> London
|
|
648
842
|
```
|
|
649
843
|
|
|
650
844
|
Args:
|
|
651
845
|
user_prompt: User input to start/continue the conversation.
|
|
652
|
-
|
|
653
|
-
|
|
846
|
+
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
|
|
847
|
+
output validators since output validators would expect an argument that matches the agent's output type.
|
|
654
848
|
message_history: History of the conversation so far.
|
|
655
849
|
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
656
850
|
deps: Optional dependencies to use for this run.
|
|
@@ -669,10 +863,16 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
669
863
|
if frame := inspect.currentframe(): # pragma: no branch
|
|
670
864
|
self._infer_name(frame.f_back)
|
|
671
865
|
|
|
866
|
+
if 'result_type' in _deprecated_kwargs: # pragma: no cover
|
|
867
|
+
if output_type is not str:
|
|
868
|
+
raise TypeError('`result_type` and `output_type` cannot be set at the same time.')
|
|
869
|
+
warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning)
|
|
870
|
+
output_type = _deprecated_kwargs['result_type']
|
|
871
|
+
|
|
672
872
|
yielded = False
|
|
673
873
|
async with self.iter(
|
|
674
874
|
user_prompt,
|
|
675
|
-
|
|
875
|
+
output_type=output_type,
|
|
676
876
|
message_history=message_history,
|
|
677
877
|
model=model,
|
|
678
878
|
deps=deps,
|
|
@@ -692,15 +892,15 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
692
892
|
async def stream_to_final(
|
|
693
893
|
s: models.StreamedResponse,
|
|
694
894
|
) -> FinalResult[models.StreamedResponse] | None:
|
|
695
|
-
|
|
895
|
+
output_schema = graph_ctx.deps.output_schema
|
|
696
896
|
async for maybe_part_event in streamed_response:
|
|
697
897
|
if isinstance(maybe_part_event, _messages.PartStartEvent):
|
|
698
898
|
new_part = maybe_part_event.part
|
|
699
899
|
if isinstance(new_part, _messages.TextPart):
|
|
700
|
-
if _agent_graph.
|
|
900
|
+
if _agent_graph.allow_text_output(output_schema):
|
|
701
901
|
return FinalResult(s, None, None)
|
|
702
|
-
elif isinstance(new_part, _messages.ToolCallPart) and
|
|
703
|
-
for call, _ in
|
|
902
|
+
elif isinstance(new_part, _messages.ToolCallPart) and output_schema:
|
|
903
|
+
for call, _ in output_schema.find_tool([new_part]):
|
|
704
904
|
return FinalResult(s, call.tool_name, call.tool_call_id)
|
|
705
905
|
return None
|
|
706
906
|
|
|
@@ -745,9 +945,9 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
745
945
|
graph_ctx.deps.new_message_index,
|
|
746
946
|
graph_ctx.deps.usage_limits,
|
|
747
947
|
streamed_response,
|
|
748
|
-
graph_ctx.deps.
|
|
948
|
+
graph_ctx.deps.output_schema,
|
|
749
949
|
_agent_graph.build_run_context(graph_ctx),
|
|
750
|
-
graph_ctx.deps.
|
|
950
|
+
graph_ctx.deps.output_validators,
|
|
751
951
|
final_result_details.tool_name,
|
|
752
952
|
on_complete,
|
|
753
953
|
)
|
|
@@ -796,6 +996,73 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
796
996
|
if _utils.is_set(override_model_before):
|
|
797
997
|
self._override_model = override_model_before
|
|
798
998
|
|
|
999
|
+
@overload
|
|
1000
|
+
def instructions(
|
|
1001
|
+
self, func: Callable[[RunContext[AgentDepsT]], str], /
|
|
1002
|
+
) -> Callable[[RunContext[AgentDepsT]], str]: ...
|
|
1003
|
+
|
|
1004
|
+
@overload
|
|
1005
|
+
def instructions(
|
|
1006
|
+
self, func: Callable[[RunContext[AgentDepsT]], Awaitable[str]], /
|
|
1007
|
+
) -> Callable[[RunContext[AgentDepsT]], Awaitable[str]]: ...
|
|
1008
|
+
|
|
1009
|
+
@overload
|
|
1010
|
+
def instructions(self, func: Callable[[], str], /) -> Callable[[], str]: ...
|
|
1011
|
+
|
|
1012
|
+
@overload
|
|
1013
|
+
def instructions(self, func: Callable[[], Awaitable[str]], /) -> Callable[[], Awaitable[str]]: ...
|
|
1014
|
+
|
|
1015
|
+
@overload
|
|
1016
|
+
def instructions(
|
|
1017
|
+
self, /
|
|
1018
|
+
) -> Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]]: ...
|
|
1019
|
+
|
|
1020
|
+
def instructions(
|
|
1021
|
+
self,
|
|
1022
|
+
func: _system_prompt.SystemPromptFunc[AgentDepsT] | None = None,
|
|
1023
|
+
/,
|
|
1024
|
+
) -> (
|
|
1025
|
+
Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]]
|
|
1026
|
+
| _system_prompt.SystemPromptFunc[AgentDepsT]
|
|
1027
|
+
):
|
|
1028
|
+
"""Decorator to register an instructions function.
|
|
1029
|
+
|
|
1030
|
+
Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument.
|
|
1031
|
+
Can decorate a sync or async functions.
|
|
1032
|
+
|
|
1033
|
+
The decorator can be used bare (`agent.instructions`).
|
|
1034
|
+
|
|
1035
|
+
Overloads for every possible signature of `instructions` are included so the decorator doesn't obscure
|
|
1036
|
+
the type of the function.
|
|
1037
|
+
|
|
1038
|
+
Example:
|
|
1039
|
+
```python
|
|
1040
|
+
from pydantic_ai import Agent, RunContext
|
|
1041
|
+
|
|
1042
|
+
agent = Agent('test', deps_type=str)
|
|
1043
|
+
|
|
1044
|
+
@agent.instructions
|
|
1045
|
+
def simple_instructions() -> str:
|
|
1046
|
+
return 'foobar'
|
|
1047
|
+
|
|
1048
|
+
@agent.instructions
|
|
1049
|
+
async def async_instructions(ctx: RunContext[str]) -> str:
|
|
1050
|
+
return f'{ctx.deps} is the best'
|
|
1051
|
+
```
|
|
1052
|
+
"""
|
|
1053
|
+
if func is None:
|
|
1054
|
+
|
|
1055
|
+
def decorator(
|
|
1056
|
+
func_: _system_prompt.SystemPromptFunc[AgentDepsT],
|
|
1057
|
+
) -> _system_prompt.SystemPromptFunc[AgentDepsT]:
|
|
1058
|
+
self._instructions_functions.append(_system_prompt.SystemPromptRunner(func_))
|
|
1059
|
+
return func_
|
|
1060
|
+
|
|
1061
|
+
return decorator
|
|
1062
|
+
else:
|
|
1063
|
+
self._instructions_functions.append(_system_prompt.SystemPromptRunner(func))
|
|
1064
|
+
return func
|
|
1065
|
+
|
|
799
1066
|
@overload
|
|
800
1067
|
def system_prompt(
|
|
801
1068
|
self, func: Callable[[RunContext[AgentDepsT]], str], /
|
|
@@ -876,34 +1143,34 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
876
1143
|
return func
|
|
877
1144
|
|
|
878
1145
|
@overload
|
|
879
|
-
def
|
|
880
|
-
self, func: Callable[[RunContext[AgentDepsT],
|
|
881
|
-
) -> Callable[[RunContext[AgentDepsT],
|
|
1146
|
+
def output_validator(
|
|
1147
|
+
self, func: Callable[[RunContext[AgentDepsT], OutputDataT], OutputDataT], /
|
|
1148
|
+
) -> Callable[[RunContext[AgentDepsT], OutputDataT], OutputDataT]: ...
|
|
882
1149
|
|
|
883
1150
|
@overload
|
|
884
|
-
def
|
|
885
|
-
self, func: Callable[[RunContext[AgentDepsT],
|
|
886
|
-
) -> Callable[[RunContext[AgentDepsT],
|
|
1151
|
+
def output_validator(
|
|
1152
|
+
self, func: Callable[[RunContext[AgentDepsT], OutputDataT], Awaitable[OutputDataT]], /
|
|
1153
|
+
) -> Callable[[RunContext[AgentDepsT], OutputDataT], Awaitable[OutputDataT]]: ...
|
|
887
1154
|
|
|
888
1155
|
@overload
|
|
889
|
-
def
|
|
890
|
-
self, func: Callable[[
|
|
891
|
-
) -> Callable[[
|
|
1156
|
+
def output_validator(
|
|
1157
|
+
self, func: Callable[[OutputDataT], OutputDataT], /
|
|
1158
|
+
) -> Callable[[OutputDataT], OutputDataT]: ...
|
|
892
1159
|
|
|
893
1160
|
@overload
|
|
894
|
-
def
|
|
895
|
-
self, func: Callable[[
|
|
896
|
-
) -> Callable[[
|
|
1161
|
+
def output_validator(
|
|
1162
|
+
self, func: Callable[[OutputDataT], Awaitable[OutputDataT]], /
|
|
1163
|
+
) -> Callable[[OutputDataT], Awaitable[OutputDataT]]: ...
|
|
897
1164
|
|
|
898
|
-
def
|
|
899
|
-
self, func:
|
|
900
|
-
) ->
|
|
901
|
-
"""Decorator to register
|
|
1165
|
+
def output_validator(
|
|
1166
|
+
self, func: _output.OutputValidatorFunc[AgentDepsT, OutputDataT], /
|
|
1167
|
+
) -> _output.OutputValidatorFunc[AgentDepsT, OutputDataT]:
|
|
1168
|
+
"""Decorator to register an output validator function.
|
|
902
1169
|
|
|
903
1170
|
Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
|
|
904
1171
|
Can decorate a sync or async functions.
|
|
905
1172
|
|
|
906
|
-
Overloads for every possible signature of `
|
|
1173
|
+
Overloads for every possible signature of `output_validator` are included so the decorator doesn't obscure
|
|
907
1174
|
the type of the function, see `tests/typed_agent.py` for tests.
|
|
908
1175
|
|
|
909
1176
|
Example:
|
|
@@ -912,26 +1179,29 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
912
1179
|
|
|
913
1180
|
agent = Agent('test', deps_type=str)
|
|
914
1181
|
|
|
915
|
-
@agent.
|
|
916
|
-
def
|
|
1182
|
+
@agent.output_validator
|
|
1183
|
+
def output_validator_simple(data: str) -> str:
|
|
917
1184
|
if 'wrong' in data:
|
|
918
1185
|
raise ModelRetry('wrong response')
|
|
919
1186
|
return data
|
|
920
1187
|
|
|
921
|
-
@agent.
|
|
922
|
-
async def
|
|
1188
|
+
@agent.output_validator
|
|
1189
|
+
async def output_validator_deps(ctx: RunContext[str], data: str) -> str:
|
|
923
1190
|
if ctx.deps in data:
|
|
924
1191
|
raise ModelRetry('wrong response')
|
|
925
1192
|
return data
|
|
926
1193
|
|
|
927
1194
|
result = agent.run_sync('foobar', deps='spam')
|
|
928
|
-
print(result.
|
|
1195
|
+
print(result.output)
|
|
929
1196
|
#> success (no tool calls)
|
|
930
1197
|
```
|
|
931
1198
|
"""
|
|
932
|
-
self.
|
|
1199
|
+
self._output_validators.append(_output.OutputValidator[AgentDepsT, Any](func))
|
|
933
1200
|
return func
|
|
934
1201
|
|
|
1202
|
+
@deprecated('`result_validator` is deprecated, use `output_validator` instead.')
|
|
1203
|
+
def result_validator(self, func: Any, /) -> Any: ...
|
|
1204
|
+
|
|
935
1205
|
@overload
|
|
936
1206
|
def tool(self, func: ToolFuncContext[AgentDepsT, ToolParams], /) -> ToolFuncContext[AgentDepsT, ToolParams]: ...
|
|
937
1207
|
|
|
@@ -987,7 +1257,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
987
1257
|
return ctx.deps + y
|
|
988
1258
|
|
|
989
1259
|
result = agent.run_sync('foobar', deps=1)
|
|
990
|
-
print(result.
|
|
1260
|
+
print(result.output)
|
|
991
1261
|
#> {"foobar":1,"spam":1.0}
|
|
992
1262
|
```
|
|
993
1263
|
|
|
@@ -1096,7 +1366,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1096
1366
|
return 3.14
|
|
1097
1367
|
|
|
1098
1368
|
result = agent.run_sync('foobar', deps=1)
|
|
1099
|
-
print(result.
|
|
1369
|
+
print(result.output)
|
|
1100
1370
|
#> {"foobar":123,"spam":3.14}
|
|
1101
1371
|
```
|
|
1102
1372
|
|
|
@@ -1183,7 +1453,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1183
1453
|
if tool.name in self._function_tools:
|
|
1184
1454
|
raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}')
|
|
1185
1455
|
|
|
1186
|
-
if self.
|
|
1456
|
+
if self._output_schema and tool.name in self._output_schema.tools:
|
|
1187
1457
|
raise exceptions.UserError(f'Tool name conflicts with result schema name: {tool.name!r}')
|
|
1188
1458
|
|
|
1189
1459
|
self._function_tools[tool.name] = tool
|
|
@@ -1226,7 +1496,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1226
1496
|
|
|
1227
1497
|
return model_
|
|
1228
1498
|
|
|
1229
|
-
def _get_deps(self: Agent[T,
|
|
1499
|
+
def _get_deps(self: Agent[T, OutputDataT], deps: T) -> T:
|
|
1230
1500
|
"""Get deps for a run.
|
|
1231
1501
|
|
|
1232
1502
|
If we've overridden deps via `_override_deps`, use that, otherwise use the deps passed to the call.
|
|
@@ -1264,22 +1534,19 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1264
1534
|
def last_run_messages(self) -> list[_messages.ModelMessage]:
|
|
1265
1535
|
raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.')
|
|
1266
1536
|
|
|
1267
|
-
def
|
|
1268
|
-
self,
|
|
1269
|
-
) ->
|
|
1270
|
-
|
|
1271
|
-
|
|
1272
|
-
|
|
1273
|
-
|
|
1274
|
-
|
|
1275
|
-
|
|
1276
|
-
|
|
1277
|
-
raise exceptions.UserError('Cannot set a custom run `result_type` when the agent has result validators')
|
|
1278
|
-
return _result.ResultSchema[result_type].build(
|
|
1279
|
-
result_type, self._result_tool_name, self._result_tool_description
|
|
1537
|
+
def _prepare_output_schema(
|
|
1538
|
+
self, output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None
|
|
1539
|
+
) -> _output.OutputSchema[RunOutputDataT] | None:
|
|
1540
|
+
if output_type is not None:
|
|
1541
|
+
if self._output_validators:
|
|
1542
|
+
raise exceptions.UserError('Cannot set a custom run `output_type` when the agent has output validators')
|
|
1543
|
+
return _output.OutputSchema[RunOutputDataT].build(
|
|
1544
|
+
output_type,
|
|
1545
|
+
self._deprecated_result_tool_name,
|
|
1546
|
+
self._deprecated_result_tool_description,
|
|
1280
1547
|
)
|
|
1281
1548
|
else:
|
|
1282
|
-
return self.
|
|
1549
|
+
return self._output_schema # pyright: ignore[reportReturnType]
|
|
1283
1550
|
|
|
1284
1551
|
@staticmethod
|
|
1285
1552
|
def is_model_request_node(
|
|
@@ -1337,7 +1604,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1337
1604
|
|
|
1338
1605
|
|
|
1339
1606
|
@dataclasses.dataclass(repr=False)
|
|
1340
|
-
class AgentRun(Generic[AgentDepsT,
|
|
1607
|
+
class AgentRun(Generic[AgentDepsT, OutputDataT]):
|
|
1341
1608
|
"""A stateful, async-iterable run of an [`Agent`][pydantic_ai.agent.Agent].
|
|
1342
1609
|
|
|
1343
1610
|
You generally obtain an `AgentRun` instance by calling `async with my_agent.iter(...) as agent_run:`.
|
|
@@ -1363,6 +1630,8 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1363
1630
|
[
|
|
1364
1631
|
UserPromptNode(
|
|
1365
1632
|
user_prompt='What is the capital of France?',
|
|
1633
|
+
instructions=None,
|
|
1634
|
+
instructions_functions=[],
|
|
1366
1635
|
system_prompts=(),
|
|
1367
1636
|
system_prompt_functions=[],
|
|
1368
1637
|
system_prompt_dynamic_functions={},
|
|
@@ -1376,6 +1645,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1376
1645
|
part_kind='user-prompt',
|
|
1377
1646
|
)
|
|
1378
1647
|
],
|
|
1648
|
+
instructions=None,
|
|
1379
1649
|
kind='request',
|
|
1380
1650
|
)
|
|
1381
1651
|
),
|
|
@@ -1387,10 +1657,10 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1387
1657
|
kind='response',
|
|
1388
1658
|
)
|
|
1389
1659
|
),
|
|
1390
|
-
End(data=FinalResult(
|
|
1660
|
+
End(data=FinalResult(output='Paris', tool_name=None, tool_call_id=None)),
|
|
1391
1661
|
]
|
|
1392
1662
|
'''
|
|
1393
|
-
print(agent_run.result.
|
|
1663
|
+
print(agent_run.result.output)
|
|
1394
1664
|
#> Paris
|
|
1395
1665
|
```
|
|
1396
1666
|
|
|
@@ -1399,9 +1669,19 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1399
1669
|
"""
|
|
1400
1670
|
|
|
1401
1671
|
_graph_run: GraphRun[
|
|
1402
|
-
_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[
|
|
1672
|
+
_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[OutputDataT]
|
|
1403
1673
|
]
|
|
1404
1674
|
|
|
1675
|
+
@overload
|
|
1676
|
+
def _span(self, *, required: Literal[False]) -> AbstractSpan | None: ...
|
|
1677
|
+
@overload
|
|
1678
|
+
def _span(self) -> AbstractSpan: ...
|
|
1679
|
+
def _span(self, *, required: bool = True) -> AbstractSpan | None:
|
|
1680
|
+
span = self._graph_run._span(required=False) # type: ignore[reportPrivateUsage]
|
|
1681
|
+
if span is None and required: # pragma: no cover
|
|
1682
|
+
raise AttributeError('Span is not available for this agent run')
|
|
1683
|
+
return span
|
|
1684
|
+
|
|
1405
1685
|
@property
|
|
1406
1686
|
def ctx(self) -> GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]]:
|
|
1407
1687
|
"""The current context of the agent run."""
|
|
@@ -1412,7 +1692,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1412
1692
|
@property
|
|
1413
1693
|
def next_node(
|
|
1414
1694
|
self,
|
|
1415
|
-
) -> _agent_graph.AgentNode[AgentDepsT,
|
|
1695
|
+
) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]:
|
|
1416
1696
|
"""The next node that will be run in the agent graph.
|
|
1417
1697
|
|
|
1418
1698
|
This is the next node that will be used during async iteration, or if a node is not passed to `self.next(...)`.
|
|
@@ -1425,7 +1705,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1425
1705
|
raise exceptions.AgentRunError(f'Unexpected node type: {type(next_node)}') # pragma: no cover
|
|
1426
1706
|
|
|
1427
1707
|
@property
|
|
1428
|
-
def result(self) -> AgentRunResult[
|
|
1708
|
+
def result(self) -> AgentRunResult[OutputDataT] | None:
|
|
1429
1709
|
"""The final result of the run if it has ended, otherwise `None`.
|
|
1430
1710
|
|
|
1431
1711
|
Once the run returns an [`End`][pydantic_graph.nodes.End] node, `result` is populated
|
|
@@ -1435,21 +1715,22 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1435
1715
|
if graph_run_result is None:
|
|
1436
1716
|
return None
|
|
1437
1717
|
return AgentRunResult(
|
|
1438
|
-
graph_run_result.output.
|
|
1718
|
+
graph_run_result.output.output,
|
|
1439
1719
|
graph_run_result.output.tool_name,
|
|
1440
1720
|
graph_run_result.state,
|
|
1441
1721
|
self._graph_run.deps.new_message_index,
|
|
1722
|
+
self._graph_run._span(required=False), # type: ignore[reportPrivateUsage]
|
|
1442
1723
|
)
|
|
1443
1724
|
|
|
1444
1725
|
def __aiter__(
|
|
1445
1726
|
self,
|
|
1446
|
-
) -> AsyncIterator[_agent_graph.AgentNode[AgentDepsT,
|
|
1727
|
+
) -> AsyncIterator[_agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]]:
|
|
1447
1728
|
"""Provide async-iteration over the nodes in the agent run."""
|
|
1448
1729
|
return self
|
|
1449
1730
|
|
|
1450
1731
|
async def __anext__(
|
|
1451
1732
|
self,
|
|
1452
|
-
) -> _agent_graph.AgentNode[AgentDepsT,
|
|
1733
|
+
) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]:
|
|
1453
1734
|
"""Advance to the next node automatically based on the last returned node."""
|
|
1454
1735
|
next_node = await self._graph_run.__anext__()
|
|
1455
1736
|
if _agent_graph.is_agent_node(next_node):
|
|
@@ -1459,8 +1740,8 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1459
1740
|
|
|
1460
1741
|
async def next(
|
|
1461
1742
|
self,
|
|
1462
|
-
node: _agent_graph.AgentNode[AgentDepsT,
|
|
1463
|
-
) -> _agent_graph.AgentNode[AgentDepsT,
|
|
1743
|
+
node: _agent_graph.AgentNode[AgentDepsT, OutputDataT],
|
|
1744
|
+
) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]:
|
|
1464
1745
|
"""Manually drive the agent run by passing in the node you want to run next.
|
|
1465
1746
|
|
|
1466
1747
|
This lets you inspect or mutate the node before continuing execution, or skip certain nodes
|
|
@@ -1487,6 +1768,8 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1487
1768
|
[
|
|
1488
1769
|
UserPromptNode(
|
|
1489
1770
|
user_prompt='What is the capital of France?',
|
|
1771
|
+
instructions=None,
|
|
1772
|
+
instructions_functions=[],
|
|
1490
1773
|
system_prompts=(),
|
|
1491
1774
|
system_prompt_functions=[],
|
|
1492
1775
|
system_prompt_dynamic_functions={},
|
|
@@ -1500,6 +1783,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1500
1783
|
part_kind='user-prompt',
|
|
1501
1784
|
)
|
|
1502
1785
|
],
|
|
1786
|
+
instructions=None,
|
|
1503
1787
|
kind='request',
|
|
1504
1788
|
)
|
|
1505
1789
|
),
|
|
@@ -1511,10 +1795,10 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1511
1795
|
kind='response',
|
|
1512
1796
|
)
|
|
1513
1797
|
),
|
|
1514
|
-
End(data=FinalResult(
|
|
1798
|
+
End(data=FinalResult(output='Paris', tool_name=None, tool_call_id=None)),
|
|
1515
1799
|
]
|
|
1516
1800
|
'''
|
|
1517
|
-
print('Final result:', agent_run.result.
|
|
1801
|
+
print('Final result:', agent_run.result.output)
|
|
1518
1802
|
#> Final result: Paris
|
|
1519
1803
|
```
|
|
1520
1804
|
|
|
@@ -1544,94 +1828,150 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1544
1828
|
|
|
1545
1829
|
|
|
1546
1830
|
@dataclasses.dataclass
|
|
1547
|
-
class AgentRunResult(Generic[
|
|
1831
|
+
class AgentRunResult(Generic[OutputDataT]):
|
|
1548
1832
|
"""The final result of an agent run."""
|
|
1549
1833
|
|
|
1550
|
-
|
|
1834
|
+
output: OutputDataT
|
|
1835
|
+
"""The output data from the agent run."""
|
|
1551
1836
|
|
|
1552
|
-
|
|
1837
|
+
_output_tool_name: str | None = dataclasses.field(repr=False)
|
|
1553
1838
|
_state: _agent_graph.GraphAgentState = dataclasses.field(repr=False)
|
|
1554
1839
|
_new_message_index: int = dataclasses.field(repr=False)
|
|
1840
|
+
_span_value: AbstractSpan | None = dataclasses.field(repr=False)
|
|
1841
|
+
|
|
1842
|
+
@overload
|
|
1843
|
+
def _span(self, *, required: Literal[False]) -> AbstractSpan | None: ...
|
|
1844
|
+
@overload
|
|
1845
|
+
def _span(self) -> AbstractSpan: ...
|
|
1846
|
+
def _span(self, *, required: bool = True) -> AbstractSpan | None:
|
|
1847
|
+
if self._span_value is None and required: # pragma: no cover
|
|
1848
|
+
raise AttributeError('Span is not available for this agent run')
|
|
1849
|
+
return self._span_value
|
|
1850
|
+
|
|
1851
|
+
@property
|
|
1852
|
+
@deprecated('`result.data` is deprecated, use `result.output` instead.')
|
|
1853
|
+
def data(self) -> OutputDataT:
|
|
1854
|
+
return self.output
|
|
1555
1855
|
|
|
1556
|
-
def
|
|
1557
|
-
"""Set return content for the
|
|
1856
|
+
def _set_output_tool_return(self, return_content: str) -> list[_messages.ModelMessage]:
|
|
1857
|
+
"""Set return content for the output tool.
|
|
1558
1858
|
|
|
1559
|
-
Useful if you want to continue the conversation and want to set the response to the
|
|
1859
|
+
Useful if you want to continue the conversation and want to set the response to the output tool call.
|
|
1560
1860
|
"""
|
|
1561
|
-
if not self.
|
|
1562
|
-
raise ValueError('Cannot set
|
|
1861
|
+
if not self._output_tool_name:
|
|
1862
|
+
raise ValueError('Cannot set output tool return content when the return type is `str`.')
|
|
1563
1863
|
messages = deepcopy(self._state.message_history)
|
|
1564
1864
|
last_message = messages[-1]
|
|
1565
1865
|
for part in last_message.parts:
|
|
1566
|
-
if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self.
|
|
1866
|
+
if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self._output_tool_name:
|
|
1567
1867
|
part.content = return_content
|
|
1568
1868
|
return messages
|
|
1569
|
-
raise LookupError(f'No tool call found with tool name {self.
|
|
1869
|
+
raise LookupError(f'No tool call found with tool name {self._output_tool_name!r}.')
|
|
1570
1870
|
|
|
1571
|
-
|
|
1871
|
+
@overload
|
|
1872
|
+
def all_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: ...
|
|
1873
|
+
|
|
1874
|
+
@overload
|
|
1875
|
+
@deprecated('`result_tool_return_content` is deprecated, use `output_tool_return_content` instead.')
|
|
1876
|
+
def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: ...
|
|
1877
|
+
|
|
1878
|
+
def all_messages(
|
|
1879
|
+
self, *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None
|
|
1880
|
+
) -> list[_messages.ModelMessage]:
|
|
1572
1881
|
"""Return the history of _messages.
|
|
1573
1882
|
|
|
1574
1883
|
Args:
|
|
1575
|
-
|
|
1576
|
-
This provides a convenient way to modify the content of the
|
|
1577
|
-
the conversation and want to set the response to the
|
|
1884
|
+
output_tool_return_content: The return content of the tool call to set in the last message.
|
|
1885
|
+
This provides a convenient way to modify the content of the output tool call if you want to continue
|
|
1886
|
+
the conversation and want to set the response to the output tool call. If `None`, the last message will
|
|
1578
1887
|
not be modified.
|
|
1888
|
+
result_tool_return_content: Deprecated, use `output_tool_return_content` instead.
|
|
1579
1889
|
|
|
1580
1890
|
Returns:
|
|
1581
1891
|
List of messages.
|
|
1582
1892
|
"""
|
|
1583
|
-
|
|
1584
|
-
|
|
1893
|
+
content = result.coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content)
|
|
1894
|
+
if content is not None:
|
|
1895
|
+
return self._set_output_tool_return(content)
|
|
1585
1896
|
else:
|
|
1586
1897
|
return self._state.message_history
|
|
1587
1898
|
|
|
1588
|
-
|
|
1899
|
+
@overload
|
|
1900
|
+
def all_messages_json(self, *, output_tool_return_content: str | None = None) -> bytes: ...
|
|
1901
|
+
|
|
1902
|
+
@overload
|
|
1903
|
+
@deprecated('`result_tool_return_content` is deprecated, use `output_tool_return_content` instead.')
|
|
1904
|
+
def all_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes: ...
|
|
1905
|
+
|
|
1906
|
+
def all_messages_json(
|
|
1907
|
+
self, *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None
|
|
1908
|
+
) -> bytes:
|
|
1589
1909
|
"""Return all messages from [`all_messages`][pydantic_ai.agent.AgentRunResult.all_messages] as JSON bytes.
|
|
1590
1910
|
|
|
1591
1911
|
Args:
|
|
1592
|
-
|
|
1593
|
-
This provides a convenient way to modify the content of the
|
|
1594
|
-
the conversation and want to set the response to the
|
|
1912
|
+
output_tool_return_content: The return content of the tool call to set in the last message.
|
|
1913
|
+
This provides a convenient way to modify the content of the output tool call if you want to continue
|
|
1914
|
+
the conversation and want to set the response to the output tool call. If `None`, the last message will
|
|
1595
1915
|
not be modified.
|
|
1916
|
+
result_tool_return_content: Deprecated, use `output_tool_return_content` instead.
|
|
1596
1917
|
|
|
1597
1918
|
Returns:
|
|
1598
1919
|
JSON bytes representing the messages.
|
|
1599
1920
|
"""
|
|
1600
|
-
|
|
1601
|
-
|
|
1602
|
-
|
|
1921
|
+
content = result.coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content)
|
|
1922
|
+
return _messages.ModelMessagesTypeAdapter.dump_json(self.all_messages(output_tool_return_content=content))
|
|
1923
|
+
|
|
1924
|
+
@overload
|
|
1925
|
+
def new_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: ...
|
|
1926
|
+
|
|
1927
|
+
@overload
|
|
1928
|
+
@deprecated('`result_tool_return_content` is deprecated, use `output_tool_return_content` instead.')
|
|
1929
|
+
def new_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: ...
|
|
1603
1930
|
|
|
1604
|
-
def new_messages(
|
|
1931
|
+
def new_messages(
|
|
1932
|
+
self, *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None
|
|
1933
|
+
) -> list[_messages.ModelMessage]:
|
|
1605
1934
|
"""Return new messages associated with this run.
|
|
1606
1935
|
|
|
1607
1936
|
Messages from older runs are excluded.
|
|
1608
1937
|
|
|
1609
1938
|
Args:
|
|
1610
|
-
|
|
1611
|
-
This provides a convenient way to modify the content of the
|
|
1612
|
-
the conversation and want to set the response to the
|
|
1939
|
+
output_tool_return_content: The return content of the tool call to set in the last message.
|
|
1940
|
+
This provides a convenient way to modify the content of the output tool call if you want to continue
|
|
1941
|
+
the conversation and want to set the response to the output tool call. If `None`, the last message will
|
|
1613
1942
|
not be modified.
|
|
1943
|
+
result_tool_return_content: Deprecated, use `output_tool_return_content` instead.
|
|
1614
1944
|
|
|
1615
1945
|
Returns:
|
|
1616
1946
|
List of new messages.
|
|
1617
1947
|
"""
|
|
1618
|
-
|
|
1948
|
+
content = result.coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content)
|
|
1949
|
+
return self.all_messages(output_tool_return_content=content)[self._new_message_index :]
|
|
1950
|
+
|
|
1951
|
+
@overload
|
|
1952
|
+
def new_messages_json(self, *, output_tool_return_content: str | None = None) -> bytes: ...
|
|
1619
1953
|
|
|
1620
|
-
|
|
1954
|
+
@overload
|
|
1955
|
+
@deprecated('`result_tool_return_content` is deprecated, use `output_tool_return_content` instead.')
|
|
1956
|
+
def new_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes: ...
|
|
1957
|
+
|
|
1958
|
+
def new_messages_json(
|
|
1959
|
+
self, *, output_tool_return_content: str | None = None, result_tool_return_content: str | None = None
|
|
1960
|
+
) -> bytes:
|
|
1621
1961
|
"""Return new messages from [`new_messages`][pydantic_ai.agent.AgentRunResult.new_messages] as JSON bytes.
|
|
1622
1962
|
|
|
1623
1963
|
Args:
|
|
1624
|
-
|
|
1625
|
-
This provides a convenient way to modify the content of the
|
|
1626
|
-
the conversation and want to set the response to the
|
|
1964
|
+
output_tool_return_content: The return content of the tool call to set in the last message.
|
|
1965
|
+
This provides a convenient way to modify the content of the output tool call if you want to continue
|
|
1966
|
+
the conversation and want to set the response to the output tool call. If `None`, the last message will
|
|
1627
1967
|
not be modified.
|
|
1968
|
+
result_tool_return_content: Deprecated, use `output_tool_return_content` instead.
|
|
1628
1969
|
|
|
1629
1970
|
Returns:
|
|
1630
1971
|
JSON bytes representing the new messages.
|
|
1631
1972
|
"""
|
|
1632
|
-
|
|
1633
|
-
|
|
1634
|
-
)
|
|
1973
|
+
content = result.coalesce_deprecated_return_content(output_tool_return_content, result_tool_return_content)
|
|
1974
|
+
return _messages.ModelMessagesTypeAdapter.dump_json(self.new_messages(output_tool_return_content=content))
|
|
1635
1975
|
|
|
1636
1976
|
def usage(self) -> _usage.Usage:
|
|
1637
1977
|
"""Return the usage of the whole run."""
|