pydantic-ai-slim 0.0.21__py3-none-any.whl → 0.0.22__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of pydantic-ai-slim might be problematic.
- pydantic_ai/_agent_graph.py +770 -0
- pydantic_ai/agent.py +182 -554
- pydantic_ai/models/__init__.py +4 -0
- pydantic_ai/models/gemini.py +7 -1
- pydantic_ai/models/openai.py +6 -1
- pydantic_ai/settings.py +5 -0
- {pydantic_ai_slim-0.0.21.dist-info → pydantic_ai_slim-0.0.22.dist-info}/METADATA +2 -3
- {pydantic_ai_slim-0.0.21.dist-info → pydantic_ai_slim-0.0.22.dist-info}/RECORD +9 -8
- {pydantic_ai_slim-0.0.21.dist-info → pydantic_ai_slim-0.0.22.dist-info}/WHEEL +0 -0
@@ -0,0 +1,770 @@
from __future__ import annotations as _annotations

import asyncio
import dataclasses
from abc import ABC
from collections.abc import AsyncIterator, Iterator, Sequence
from contextlib import asynccontextmanager, contextmanager
from contextvars import ContextVar
from dataclasses import field
from typing import Any, Generic, Literal, Union, cast

import logfire_api
from typing_extensions import TypeVar, assert_never

from pydantic_graph import BaseNode, Graph, GraphRunContext
from pydantic_graph.nodes import End, NodeRunEndT

from . import (
    _result,
    _system_prompt,
    exceptions,
    messages as _messages,
    models,
    result,
    usage as _usage,
)
from .result import ResultDataT
from .settings import ModelSettings, merge_model_settings
from .tools import (
    RunContext,
    Tool,
    ToolDefinition,
)

_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')

# while waiting for https://github.com/pydantic/logfire/issues/745
try:
    import logfire._internal.stack_info
except ImportError:
    pass
else:
    from pathlib import Path

    logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)

T = TypeVar('T')
NoneType = type(None)
EndStrategy = Literal['early', 'exhaustive']
"""The strategy for handling multiple tool calls when a final result is found.

- `'early'`: Stop processing other tool calls once a final result is found
- `'exhaustive'`: Process all tool calls even after finding a final result
"""
DepsT = TypeVar('DepsT')
ResultT = TypeVar('ResultT')


@dataclasses.dataclass
class MarkFinalResult(Generic[ResultDataT]):
    """Marker class to indicate that the result is the final result.

    This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultDataT` directly.

    It also avoids problems in the case where the result type is itself `None`, but is set.
    """

    data: ResultDataT
    """The final result data."""
    tool_name: str | None
    """Name of the final result tool, None if the result is a string."""

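As an aside (not part of the diff): a minimal sketch of the pattern this marker enables, using only the `MarkFinalResult` dataclass defined above.

```python
# Illustrative sketch only, not part of the released file.
# Wrapping the result lets callers check "did we get a final result?" with
# isinstance, which still works when the result type itself is None.
final: MarkFinalResult[None] | None = MarkFinalResult(data=None, tool_name=None)

if isinstance(final, MarkFinalResult):
    # Unambiguously a final result, even though final.data is None.
    print(final.data, final.tool_name)
```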
@dataclasses.dataclass
class GraphAgentState:
    """State kept across the execution of the agent graph."""

    message_history: list[_messages.ModelMessage]
    usage: _usage.Usage
    retries: int
    run_step: int

    def increment_retries(self, max_result_retries: int) -> None:
        self.retries += 1
        if self.retries > max_result_retries:
            raise exceptions.UnexpectedModelBehavior(
                f'Exceeded maximum retries ({max_result_retries}) for result validation'
            )


@dataclasses.dataclass
class GraphAgentDeps(Generic[DepsT, ResultDataT]):
    """Dependencies/config passed to the agent graph."""

    user_deps: DepsT

    prompt: str
    new_message_index: int

    model: models.Model
    model_settings: ModelSettings | None
    usage_limits: _usage.UsageLimits
    max_result_retries: int
    end_strategy: EndStrategy

    result_schema: _result.ResultSchema[ResultDataT] | None
    result_tools: list[ToolDefinition]
    result_validators: list[_result.ResultValidator[DepsT, ResultDataT]]

    function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)

    run_span: logfire_api.LogfireSpan


@dataclasses.dataclass
class BaseUserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC):
    user_prompt: str

    system_prompts: tuple[str, ...]
    system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
    system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]

    async def _get_first_message(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
    ) -> _messages.ModelRequest:
        run_context = _build_run_context(ctx)
        history, next_message = await self._prepare_messages(self.user_prompt, ctx.state.message_history, run_context)
        ctx.state.message_history = history
        run_context.messages = history

        # TODO: We need to make it so that function_tools are not shared between runs
        # See comment on the current_retry field of `Tool` for more details.
        for tool in ctx.deps.function_tools.values():
            tool.current_retry = 0
        return next_message

    async def _prepare_messages(
        self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[DepsT]
    ) -> tuple[list[_messages.ModelMessage], _messages.ModelRequest]:
        try:
            ctx_messages = get_captured_run_messages()
        except LookupError:
            messages: list[_messages.ModelMessage] = []
        else:
            if ctx_messages.used:
                messages = []
            else:
                messages = ctx_messages.messages
                ctx_messages.used = True

        if message_history:
            # Shallow copy messages
            messages.extend(message_history)
            # Reevaluate any dynamic system prompt parts
            await self._reevaluate_dynamic_prompts(messages, run_context)
            return messages, _messages.ModelRequest([_messages.UserPromptPart(user_prompt)])
        else:
            parts = await self._sys_parts(run_context)
            parts.append(_messages.UserPromptPart(user_prompt))
            return messages, _messages.ModelRequest(parts)

    async def _reevaluate_dynamic_prompts(
        self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT]
    ) -> None:
        """Reevaluate any `SystemPromptPart` with dynamic_ref in the provided messages by running the associated runner function."""
        # Only proceed if there's at least one dynamic runner.
        if self.system_prompt_dynamic_functions:
            for msg in messages:
                if isinstance(msg, _messages.ModelRequest):
                    for i, part in enumerate(msg.parts):
                        if isinstance(part, _messages.SystemPromptPart) and part.dynamic_ref:
                            # Look up the runner by its ref
                            if runner := self.system_prompt_dynamic_functions.get(part.dynamic_ref):
                                updated_part_content = await runner.run(run_context)
                                msg.parts[i] = _messages.SystemPromptPart(
                                    updated_part_content, dynamic_ref=part.dynamic_ref
                                )

    async def _sys_parts(self, run_context: RunContext[DepsT]) -> list[_messages.ModelRequestPart]:
        """Build the initial messages for the conversation."""
        messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self.system_prompts]
        for sys_prompt_runner in self.system_prompt_functions:
            prompt = await sys_prompt_runner.run(run_context)
            if sys_prompt_runner.dynamic:
                messages.append(_messages.SystemPromptPart(prompt, dynamic_ref=sys_prompt_runner.function.__qualname__))
            else:
                messages.append(_messages.SystemPromptPart(prompt))
        return messages


@dataclasses.dataclass
class UserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]):
    async def run(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
    ) -> ModelRequestNode[DepsT, NodeRunEndT]:
        return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))


@dataclasses.dataclass
class StreamUserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]):
    async def run(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
    ) -> StreamModelRequestNode[DepsT, NodeRunEndT]:
        return StreamModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))


async def _prepare_model(
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
) -> models.AgentModel:
    """Build tools and create an agent model."""
    function_tool_defs: list[ToolDefinition] = []

    run_context = _build_run_context(ctx)

    async def add_tool(tool: Tool[DepsT]) -> None:
        ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
        if tool_def := await tool.prepare_tool_def(ctx):
            function_tool_defs.append(tool_def)

    await asyncio.gather(*map(add_tool, ctx.deps.function_tools.values()))

    result_schema = ctx.deps.result_schema
    return await run_context.model.agent_model(
        function_tools=function_tool_defs,
        allow_text_result=_allow_text_result(result_schema),
        result_tools=result_schema.tool_defs() if result_schema is not None else [],
    )


@dataclasses.dataclass
class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
    """Make a request to the model using the last message in state.message_history."""

    request: _messages.ModelRequest

    async def run(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> HandleResponseNode[DepsT, NodeRunEndT]:
        ctx.state.message_history.append(self.request)

        # Check usage
        if ctx.deps.usage_limits:
            ctx.deps.usage_limits.check_before_request(ctx.state.usage)

        # Increment run_step
        ctx.state.run_step += 1

        with _logfire.span('preparing model and tools {run_step=}', run_step=ctx.state.run_step):
            agent_model = await _prepare_model(ctx)

        # Actually make the model request
        model_settings = merge_model_settings(ctx.deps.model_settings, None)
        with _logfire.span('model request') as span:
            model_response, request_usage = await agent_model.request(ctx.state.message_history, model_settings)
            span.set_attribute('response', model_response)
            span.set_attribute('usage', request_usage)

        # Update usage
        ctx.state.usage.incr(request_usage, requests=1)
        if ctx.deps.usage_limits:
            ctx.deps.usage_limits.check_tokens(ctx.state.usage)

        # Append the model response to state.message_history
        ctx.state.message_history.append(model_response)
        return HandleResponseNode(model_response)

@dataclasses.dataclass
class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
    """Process a response from a model, decide whether to end the run or make a new request."""

    model_response: _messages.ModelResponse

    async def run(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], FinalResultNode[DepsT, NodeRunEndT]]:  # noqa UP007
        with _logfire.span('handle model response', run_step=ctx.state.run_step) as handle_span:
            texts: list[str] = []
            tool_calls: list[_messages.ToolCallPart] = []
            for part in self.model_response.parts:
                if isinstance(part, _messages.TextPart):
                    # ignore empty content for text parts, see #437
                    if part.content:
                        texts.append(part.content)
                elif isinstance(part, _messages.ToolCallPart):
                    tool_calls.append(part)
                else:
                    assert_never(part)

            # At the moment, we prioritize at least executing tool calls if they are present.
            # In the future, we'd consider making this configurable at the agent or run level.
            # This accounts for cases like anthropic returns that might contain a text response
            # and a tool call response, where the text response just indicates the tool call will happen.
            if tool_calls:
                return await self._handle_tool_calls_response(ctx, tool_calls, handle_span)
            elif texts:
                return await self._handle_text_response(ctx, texts, handle_span)
            else:
                raise exceptions.UnexpectedModelBehavior('Received empty model response')

    async def _handle_tool_calls_response(
        self,
        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
        tool_calls: list[_messages.ToolCallPart],
        handle_span: logfire_api.LogfireSpan,
    ):
        result_schema = ctx.deps.result_schema

        # first look for the result tool call
        final_result: MarkFinalResult[NodeRunEndT] | None = None
        parts: list[_messages.ModelRequestPart] = []
        if result_schema is not None:
            if match := result_schema.find_tool(tool_calls):
                call, result_tool = match
                try:
                    result_data = result_tool.validate(call)
                    result_data = await _validate_result(result_data, ctx, call)
                except _result.ToolRetryError as e:
                    # TODO: Should only increment retry stuff once per node execution, not for each tool call
                    # Also, should increment the tool-specific retry count rather than the run retry count
                    ctx.state.increment_retries(ctx.deps.max_result_retries)
                    parts.append(e.tool_retry)
                else:
                    final_result = MarkFinalResult(result_data, call.tool_name)

        # Then build the other request parts based on end strategy
        tool_responses = await _process_function_tools(tool_calls, final_result and final_result.tool_name, ctx)

        if final_result:
            handle_span.set_attribute('result', final_result.data)
            handle_span.message = 'handle model response -> final result'
            return FinalResultNode[DepsT, NodeRunEndT](final_result, tool_responses)
        else:
            if tool_responses:
                handle_span.set_attribute('tool_responses', tool_responses)
                tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
                handle_span.message = f'handle model response -> {tool_responses_str}'
                parts.extend(tool_responses)
            return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=parts))

    async def _handle_text_response(
        self,
        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
        texts: list[str],
        handle_span: logfire_api.LogfireSpan,
    ):
        result_schema = ctx.deps.result_schema

        text = '\n\n'.join(texts)
        if _allow_text_result(result_schema):
            result_data_input = cast(NodeRunEndT, text)
            try:
                result_data = await _validate_result(result_data_input, ctx, None)
            except _result.ToolRetryError as e:
                ctx.state.increment_retries(ctx.deps.max_result_retries)
                return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
            else:
                handle_span.set_attribute('result', result_data)
                handle_span.message = 'handle model response -> final result'
                return FinalResultNode[DepsT, NodeRunEndT](MarkFinalResult(result_data, None))
        else:
            ctx.state.increment_retries(ctx.deps.max_result_retries)
            return ModelRequestNode[DepsT, NodeRunEndT](
                _messages.ModelRequest(
                    parts=[
                        _messages.RetryPromptPart(
                            content='Plain text responses are not permitted, please call one of the functions instead.',
                        )
                    ]
                )
            )


@dataclasses.dataclass
class StreamModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
    """Make a request to the model using the last message in state.message_history (or a specified request)."""

    request: _messages.ModelRequest
    _result: StreamModelRequestNode[DepsT, NodeRunEndT] | End[result.StreamedRunResult[DepsT, NodeRunEndT]] | None = (
        field(default=None, repr=False)
    )

    async def run(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> Union[StreamModelRequestNode[DepsT, NodeRunEndT], End[result.StreamedRunResult[DepsT, NodeRunEndT]]]:  # noqa UP007
        if self._result is not None:
            return self._result

        async with self.run_to_result(ctx) as final_node:
            return final_node

    @asynccontextmanager
    async def run_to_result(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> AsyncIterator[StreamModelRequestNode[DepsT, NodeRunEndT] | End[result.StreamedRunResult[DepsT, NodeRunEndT]]]:
        result_schema = ctx.deps.result_schema

        ctx.state.message_history.append(self.request)

        # Check usage
        if ctx.deps.usage_limits:
            ctx.deps.usage_limits.check_before_request(ctx.state.usage)

        # Increment run_step
        ctx.state.run_step += 1

        with _logfire.span('preparing model and tools {run_step=}', run_step=ctx.state.run_step):
            agent_model = await _prepare_model(ctx)

        # Actually make the model request
        model_settings = merge_model_settings(ctx.deps.model_settings, None)
        with _logfire.span('model request {run_step=}', run_step=ctx.state.run_step) as model_req_span:
            async with agent_model.request_stream(ctx.state.message_history, model_settings) as streamed_response:
                ctx.state.usage.requests += 1
                model_req_span.set_attribute('response_type', streamed_response.__class__.__name__)
                # We want to end the "model request" span here, but we can't exit the context manager
                # in the traditional way
                model_req_span.__exit__(None, None, None)

                with _logfire.span('handle model response') as handle_span:
                    received_text = False

                    async for maybe_part_event in streamed_response:
                        if isinstance(maybe_part_event, _messages.PartStartEvent):
                            new_part = maybe_part_event.part
                            if isinstance(new_part, _messages.TextPart):
                                received_text = True
                                if _allow_text_result(result_schema):
                                    handle_span.message = 'handle model response -> final result'
                                    streamed_run_result = _build_streamed_run_result(streamed_response, None, ctx)
                                    self._result = End(streamed_run_result)
                                    yield self._result
                                    return
                            elif isinstance(new_part, _messages.ToolCallPart):
                                if result_schema is not None and (match := result_schema.find_tool([new_part])):
                                    call, _ = match
                                    handle_span.message = 'handle model response -> final result'
                                    streamed_run_result = _build_streamed_run_result(
                                        streamed_response, call.tool_name, ctx
                                    )
                                    self._result = End(streamed_run_result)
                                    yield self._result
                                    return
                            else:
                                assert_never(new_part)

                    tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
                    parts: list[_messages.ModelRequestPart] = []
                    model_response = streamed_response.get()
                    if not model_response.parts:
                        raise exceptions.UnexpectedModelBehavior('Received empty model response')
                    ctx.state.message_history.append(model_response)

                    run_context = _build_run_context(ctx)
                    for p in model_response.parts:
                        if isinstance(p, _messages.ToolCallPart):
                            if tool := ctx.deps.function_tools.get(p.tool_name):
                                tasks.append(asyncio.create_task(tool.run(p, run_context), name=p.tool_name))
                            else:
                                parts.append(_unknown_tool(p.tool_name, ctx))

                    if received_text and not tasks and not parts:
                        # Can only get here if self._allow_text_result returns `False` for the provided result_schema
                        ctx.state.increment_retries(ctx.deps.max_result_retries)
                        self._result = StreamModelRequestNode[DepsT, NodeRunEndT](
                            _messages.ModelRequest(
                                parts=[
                                    _messages.RetryPromptPart(
                                        content='Plain text responses are not permitted, please call one of the functions instead.',
                                    )
                                ]
                            )
                        )
                        yield self._result
                        return

                    with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
                        task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
                        parts.extend(task_results)

                    next_request = _messages.ModelRequest(parts=parts)
                    if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
                        try:
                            ctx.state.increment_retries(ctx.deps.max_result_retries)
                        except:
                            # TODO: This is janky, so I think we should probably change it, but how?
                            ctx.state.message_history.append(next_request)
                            raise

                    handle_span.set_attribute('tool_responses', parts)
                    tool_responses_str = ' '.join(r.part_kind for r in parts)
                    handle_span.message = f'handle model response -> {tool_responses_str}'
                    # the model_response should have been fully streamed by now, we can add its usage
                    streamed_response_usage = streamed_response.usage()
                    run_context.usage.incr(streamed_response_usage)
                    ctx.deps.usage_limits.check_tokens(run_context.usage)
                    self._result = StreamModelRequestNode[DepsT, NodeRunEndT](next_request)
                    yield self._result
                    return


@dataclasses.dataclass
class FinalResultNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[NodeRunEndT]]):
    """Produce the final result of the run."""

    data: MarkFinalResult[NodeRunEndT]
    """The final result data."""
    extra_parts: list[_messages.ModelRequestPart] = dataclasses.field(default_factory=list)

    async def run(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> End[MarkFinalResult[NodeRunEndT]]:
        run_span = ctx.deps.run_span
        usage = ctx.state.usage
        messages = ctx.state.message_history

        # TODO: For backwards compatibility, append a new ModelRequest using the tool returns and retries
        if self.extra_parts:
            messages.append(_messages.ModelRequest(parts=self.extra_parts))

        # TODO: Set this attribute somewhere
        # handle_span = self.handle_model_response_span
        # handle_span.set_attribute('final_data', self.data)
        run_span.set_attribute('usage', usage)
        run_span.set_attribute('all_messages', messages)

        # End the run with self.data
        return End(self.data)


def _build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
    return RunContext[DepsT](
        deps=ctx.deps.user_deps,
        model=ctx.deps.model,
        usage=ctx.state.usage,
        prompt=ctx.deps.prompt,
        messages=ctx.state.message_history,
        run_step=ctx.state.run_step,
    )


def _build_streamed_run_result(
    result_stream: models.StreamedResponse,
    result_tool_name: str | None,
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
) -> result.StreamedRunResult[DepsT, NodeRunEndT]:
    new_message_index = ctx.deps.new_message_index
    result_schema = ctx.deps.result_schema
    run_span = ctx.deps.run_span
    usage_limits = ctx.deps.usage_limits
    messages = ctx.state.message_history
    run_context = _build_run_context(ctx)

    async def on_complete():
        """Called when the stream has completed.

        The model response will have been added to messages by now
        by `StreamedRunResult._marked_completed`.
        """
        last_message = messages[-1]
        assert isinstance(last_message, _messages.ModelResponse)
        tool_calls = [part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)]
        parts = await _process_function_tools(
            tool_calls,
            result_tool_name,
            ctx,
        )
        # TODO: Should we do something here related to the retry count?
        # Maybe we should move the incrementing of the retry count to where we actually make a request?
        # if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
        #     ctx.state.increment_retries(ctx.deps.max_result_retries)
        if parts:
            messages.append(_messages.ModelRequest(parts))
        run_span.set_attribute('all_messages', messages)

    return result.StreamedRunResult[DepsT, NodeRunEndT](
        messages,
        new_message_index,
        usage_limits,
        result_stream,
        result_schema,
        run_context,
        ctx.deps.result_validators,
        result_tool_name,
        on_complete,
    )


async def _process_function_tools(
    tool_calls: list[_messages.ToolCallPart],
    result_tool_name: str | None,
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
) -> list[_messages.ModelRequestPart]:
    """Process function (non-result) tool calls in parallel.

    Also add stub return parts for any other tools that need it.
    """
    parts: list[_messages.ModelRequestPart] = []
    tasks: list[asyncio.Task[_messages.ToolReturnPart | _messages.RetryPromptPart]] = []

    stub_function_tools = bool(result_tool_name) and ctx.deps.end_strategy == 'early'
    result_schema = ctx.deps.result_schema

    # we rely on the fact that if we found a result, it's the first result tool in the last
    found_used_result_tool = False
    run_context = _build_run_context(ctx)

    for call in tool_calls:
        if call.tool_name == result_tool_name and not found_used_result_tool:
            found_used_result_tool = True
            parts.append(
                _messages.ToolReturnPart(
                    tool_name=call.tool_name,
                    content='Final result processed.',
                    tool_call_id=call.tool_call_id,
                )
            )
        elif tool := ctx.deps.function_tools.get(call.tool_name):
            if stub_function_tools:
                parts.append(
                    _messages.ToolReturnPart(
                        tool_name=call.tool_name,
                        content='Tool not executed - a final result was already processed.',
                        tool_call_id=call.tool_call_id,
                    )
                )
            else:
                tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
        elif result_schema is not None and call.tool_name in result_schema.tools:
            # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
            # validation, we don't add another part here
            if result_tool_name is not None:
                parts.append(
                    _messages.ToolReturnPart(
                        tool_name=call.tool_name,
                        content='Result tool not used - a final result was already processed.',
                        tool_call_id=call.tool_call_id,
                    )
                )
        else:
            parts.append(_unknown_tool(call.tool_name, ctx))

    # Run all tool tasks in parallel
    if tasks:
        with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
            task_results: Sequence[_messages.ToolReturnPart | _messages.RetryPromptPart] = await asyncio.gather(*tasks)
            for result in task_results:
                if isinstance(result, _messages.ToolReturnPart):
                    parts.append(result)
                elif isinstance(result, _messages.RetryPromptPart):
                    parts.append(result)
                else:
                    assert_never(result)
    return parts

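As an aside (not part of the diff): a rough sketch of the stub return part `_process_function_tools` appends when `end_strategy == 'early'` and a final result has already been handled; the tool name and call id below are hypothetical.

```python
# Illustrative sketch only, not part of the released file.
# With end_strategy == 'early', once the result tool has produced a final result,
# remaining function tool calls are answered with a stub instead of being run:
stub = _messages.ToolReturnPart(
    tool_name='get_weather',  # hypothetical tool name
    content='Tool not executed - a final result was already processed.',
    tool_call_id='call_123',  # hypothetical id copied from the model's tool call
)
```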
def _unknown_tool(
    tool_name: str,
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
) -> _messages.RetryPromptPart:
    ctx.state.increment_retries(ctx.deps.max_result_retries)
    tool_names = list(ctx.deps.function_tools.keys())
    if result_schema := ctx.deps.result_schema:
        tool_names.extend(result_schema.tool_names())

    if tool_names:
        msg = f'Available tools: {", ".join(tool_names)}'
    else:
        msg = 'No tools available.'

    return _messages.RetryPromptPart(content=f'Unknown tool name: {tool_name!r}. {msg}')


async def _validate_result(
    result_data: T,
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
    tool_call: _messages.ToolCallPart | None,
) -> T:
    for validator in ctx.deps.result_validators:
        run_context = _build_run_context(ctx)
        result_data = await validator.validate(result_data, tool_call, run_context)
    return result_data


def _allow_text_result(result_schema: _result.ResultSchema[Any] | None) -> bool:
    return result_schema is None or result_schema.allow_text_result


@dataclasses.dataclass
class _RunMessages:
    messages: list[_messages.ModelMessage]
    used: bool = False


_messages_ctx_var: ContextVar[_RunMessages] = ContextVar('var')


@contextmanager
def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
    """Context manager to access the messages used in a [`run`][pydantic_ai.Agent.run], [`run_sync`][pydantic_ai.Agent.run_sync], or [`run_stream`][pydantic_ai.Agent.run_stream] call.

    Useful when a run may raise an exception, see [model errors](../agents.md#model-errors) for more information.

    Examples:
    ```python
    from pydantic_ai import Agent, capture_run_messages

    agent = Agent('test')

    with capture_run_messages() as messages:
        try:
            result = agent.run_sync('foobar')
        except Exception:
            print(messages)
            raise
    ```

    !!! note
        If you call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context,
        `messages` will represent the messages exchanged during the first call only.
    """
    try:
        yield _messages_ctx_var.get().messages
    except LookupError:
        messages: list[_messages.ModelMessage] = []
        token = _messages_ctx_var.set(_RunMessages(messages))
        try:
            yield messages
        finally:
            _messages_ctx_var.reset(token)


def get_captured_run_messages() -> _RunMessages:
    return _messages_ctx_var.get()

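As an aside (not part of the diff): a minimal sketch of how the `used` flag on `_RunMessages` gives `capture_run_messages` its "first call only" behaviour, assuming the definitions above are in scope.

```python
# Illustrative sketch only, not part of the released file.
# capture_run_messages() publishes a _RunMessages via the ContextVar; the first
# run that calls get_captured_run_messages() claims the shared list and flips
# `used`, so later runs in the same context fall back to a fresh list.
with capture_run_messages() as messages:
    first = get_captured_run_messages()
    assert first.messages is messages and not first.used
    first.used = True  # what _prepare_messages does after claiming the list

    second = get_captured_run_messages()
    assert second.used  # a second run would now start its own empty list
```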
def build_agent_graph(
    name: str | None, deps_type: type[DepsT], result_type: type[ResultT]
) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[ResultT]]:
    # We'll define the known node classes:
    nodes = (
        UserPromptNode[DepsT],
        ModelRequestNode[DepsT],
        HandleResponseNode[DepsT],
        FinalResultNode[DepsT, ResultT],
    )
    graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[ResultT]](
        nodes=nodes,
        name=name or 'Agent',
        state_type=GraphAgentState,
        run_end_type=MarkFinalResult[result_type],
        auto_instrument=False,
    )
    return graph


def build_agent_stream_graph(
    name: str | None, deps_type: type[DepsT], result_type: type[ResultT] | None
) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.StreamedRunResult[DepsT, Any]]:
    nodes = [
        StreamUserPromptNode[DepsT, result.StreamedRunResult[DepsT, ResultT]],
        StreamModelRequestNode[DepsT, result.StreamedRunResult[DepsT, ResultT]],
    ]
    graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.StreamedRunResult[DepsT, Any]](
        nodes=nodes,
        name=name or 'Agent',
        state_type=GraphAgentState,
        run_end_type=result.StreamedRunResult[DepsT, result_type],
    )
    return graph
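As an aside (not part of the diff): a sketch of the node flow these two factories encode, summarising the classes defined above.

```python
# Node flow of the graph returned by build_agent_graph (sketch):
#
#   UserPromptNode          # builds the first ModelRequest from prompts/history
#     -> ModelRequestNode   # calls the model, tracks usage and run_step
#     -> HandleResponseNode
#          -> ModelRequestNode  # tool calls / retries loop back for another request
#          -> FinalResultNode   # a result tool or allowed text ended the run
#               -> End(MarkFinalResult(data, tool_name))
#
# build_agent_stream_graph is the streaming analogue: StreamUserPromptNode ->
# StreamModelRequestNode, ending with End(StreamedRunResult(...)).
```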