pydantic-ai-slim 0.0.20__py3-none-any.whl → 0.0.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pydantic-ai-slim might be problematic.

@@ -0,0 +1,770 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ import asyncio
4
+ import dataclasses
5
+ from abc import ABC
6
+ from collections.abc import AsyncIterator, Iterator, Sequence
7
+ from contextlib import asynccontextmanager, contextmanager
8
+ from contextvars import ContextVar
9
+ from dataclasses import field
10
+ from typing import Any, Generic, Literal, Union, cast
11
+
12
+ import logfire_api
13
+ from typing_extensions import TypeVar, assert_never
14
+
15
+ from pydantic_graph import BaseNode, Graph, GraphRunContext
16
+ from pydantic_graph.nodes import End, NodeRunEndT
17
+
18
+ from . import (
19
+ _result,
20
+ _system_prompt,
21
+ exceptions,
22
+ messages as _messages,
23
+ models,
24
+ result,
25
+ usage as _usage,
26
+ )
27
+ from .result import ResultDataT
28
+ from .settings import ModelSettings, merge_model_settings
29
+ from .tools import (
30
+ RunContext,
31
+ Tool,
32
+ ToolDefinition,
33
+ )
34
+
35
+ _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
36
+
37
+ # while waiting for https://github.com/pydantic/logfire/issues/745
38
+ try:
39
+ import logfire._internal.stack_info
40
+ except ImportError:
41
+ pass
42
+ else:
43
+ from pathlib import Path
44
+
45
+ logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
46
+
47
+ T = TypeVar('T')
48
+ NoneType = type(None)
49
+ EndStrategy = Literal['early', 'exhaustive']
50
+ """The strategy for handling multiple tool calls when a final result is found.
51
+
52
+ - `'early'`: Stop processing other tool calls once a final result is found
53
+ - `'exhaustive'`: Process all tool calls even after finding a final result
54
+ """
55
+ DepsT = TypeVar('DepsT')
56
+ ResultT = TypeVar('ResultT')
57
+
58
+
59
+ @dataclasses.dataclass
60
+ class MarkFinalResult(Generic[ResultDataT]):
61
+ """Marker class to indicate that the result is the final result.
62
+
63
+ This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultDataT` directly.
64
+
65
+ It also avoids problems in the case where the result type is itself `None` but a value has been set.
66
+ """
67
+
68
+ data: ResultDataT
69
+ """The final result data."""
70
+ tool_name: str | None
71
+ """Name of the final result tool, None if the result is a string."""
72
+
73
+
74
+ @dataclasses.dataclass
75
+ class GraphAgentState:
76
+ """State kept across the execution of the agent graph."""
77
+
78
+ message_history: list[_messages.ModelMessage]
79
+ usage: _usage.Usage
80
+ retries: int
81
+ run_step: int
82
+
83
+ def increment_retries(self, max_result_retries: int) -> None:
84
+ self.retries += 1
85
+ if self.retries > max_result_retries:
86
+ raise exceptions.UnexpectedModelBehavior(
87
+ f'Exceeded maximum retries ({max_result_retries}) for result validation'
88
+ )
89
+
90
+
91
+ @dataclasses.dataclass
92
+ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
93
+ """Dependencies/config passed to the agent graph."""
94
+
95
+ user_deps: DepsT
96
+
97
+ prompt: str
98
+ new_message_index: int
99
+
100
+ model: models.Model
101
+ model_settings: ModelSettings | None
102
+ usage_limits: _usage.UsageLimits
103
+ max_result_retries: int
104
+ end_strategy: EndStrategy
105
+
106
+ result_schema: _result.ResultSchema[ResultDataT] | None
107
+ result_tools: list[ToolDefinition]
108
+ result_validators: list[_result.ResultValidator[DepsT, ResultDataT]]
109
+
110
+ function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
111
+
112
+ run_span: logfire_api.LogfireSpan
113
+
114
+
115
+ @dataclasses.dataclass
116
+ class BaseUserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC):
117
+ user_prompt: str
118
+
119
+ system_prompts: tuple[str, ...]
120
+ system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
121
+ system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]
122
+
123
+ async def _get_first_message(
124
+ self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
125
+ ) -> _messages.ModelRequest:
126
+ run_context = _build_run_context(ctx)
127
+ history, next_message = await self._prepare_messages(self.user_prompt, ctx.state.message_history, run_context)
128
+ ctx.state.message_history = history
129
+ run_context.messages = history
130
+
131
+ # TODO: We need to make it so that function_tools are not shared between runs
132
+ # See comment on the current_retry field of `Tool` for more details.
133
+ for tool in ctx.deps.function_tools.values():
134
+ tool.current_retry = 0
135
+ return next_message
136
+
137
+ async def _prepare_messages(
138
+ self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[DepsT]
139
+ ) -> tuple[list[_messages.ModelMessage], _messages.ModelRequest]:
140
+ try:
141
+ ctx_messages = get_captured_run_messages()
142
+ except LookupError:
143
+ messages: list[_messages.ModelMessage] = []
144
+ else:
145
+ if ctx_messages.used:
146
+ messages = []
147
+ else:
148
+ messages = ctx_messages.messages
149
+ ctx_messages.used = True
150
+
151
+ if message_history:
152
+ # Shallow copy messages
153
+ messages.extend(message_history)
154
+ # Reevaluate any dynamic system prompt parts
155
+ await self._reevaluate_dynamic_prompts(messages, run_context)
156
+ return messages, _messages.ModelRequest([_messages.UserPromptPart(user_prompt)])
157
+ else:
158
+ parts = await self._sys_parts(run_context)
159
+ parts.append(_messages.UserPromptPart(user_prompt))
160
+ return messages, _messages.ModelRequest(parts)
161
+
162
+ async def _reevaluate_dynamic_prompts(
163
+ self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT]
164
+ ) -> None:
165
+ """Reevaluate any `SystemPromptPart` with dynamic_ref in the provided messages by running the associated runner function."""
166
+ # Only proceed if there's at least one dynamic runner.
167
+ if self.system_prompt_dynamic_functions:
168
+ for msg in messages:
169
+ if isinstance(msg, _messages.ModelRequest):
170
+ for i, part in enumerate(msg.parts):
171
+ if isinstance(part, _messages.SystemPromptPart) and part.dynamic_ref:
172
+ # Look up the runner by its ref
173
+ if runner := self.system_prompt_dynamic_functions.get(part.dynamic_ref):
174
+ updated_part_content = await runner.run(run_context)
175
+ msg.parts[i] = _messages.SystemPromptPart(
176
+ updated_part_content, dynamic_ref=part.dynamic_ref
177
+ )
178
+
179
+ async def _sys_parts(self, run_context: RunContext[DepsT]) -> list[_messages.ModelRequestPart]:
180
+ """Build the initial messages for the conversation."""
181
+ messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self.system_prompts]
182
+ for sys_prompt_runner in self.system_prompt_functions:
183
+ prompt = await sys_prompt_runner.run(run_context)
184
+ if sys_prompt_runner.dynamic:
185
+ messages.append(_messages.SystemPromptPart(prompt, dynamic_ref=sys_prompt_runner.function.__qualname__))
186
+ else:
187
+ messages.append(_messages.SystemPromptPart(prompt))
188
+ return messages
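
`dynamic_ref` is keyed on the runner function's `__qualname__` (see `_sys_parts` above), which is how `_reevaluate_dynamic_prompts` finds the right runner when message history is reused. A hedged sketch of the user-facing feature this supports, assuming the public `Agent.system_prompt` decorator accepts a `dynamic=True` flag:

```python
from pydantic_ai import Agent, RunContext

agent = Agent('test', deps_type=str)

@agent.system_prompt(dynamic=True)
def add_user_name(ctx: RunContext[str]) -> str:
    # Re-evaluated on runs that reuse message history, instead of replaying
    # the originally rendered system prompt verbatim.
    return f"The user's name is {ctx.deps}."
```
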
189
+
190
+
191
+ @dataclasses.dataclass
192
+ class UserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]):
193
+ async def run(
194
+ self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
195
+ ) -> ModelRequestNode[DepsT, NodeRunEndT]:
196
+ return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
197
+
198
+
199
+ @dataclasses.dataclass
200
+ class StreamUserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]):
201
+ async def run(
202
+ self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
203
+ ) -> StreamModelRequestNode[DepsT, NodeRunEndT]:
204
+ return StreamModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
205
+
206
+
207
+ async def _prepare_model(
208
+ ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
209
+ ) -> models.AgentModel:
210
+ """Build tools and create an agent model."""
211
+ function_tool_defs: list[ToolDefinition] = []
212
+
213
+ run_context = _build_run_context(ctx)
214
+
215
+ async def add_tool(tool: Tool[DepsT]) -> None:
216
+ ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
217
+ if tool_def := await tool.prepare_tool_def(ctx):
218
+ function_tool_defs.append(tool_def)
219
+
220
+ await asyncio.gather(*map(add_tool, ctx.deps.function_tools.values()))
221
+
222
+ result_schema = ctx.deps.result_schema
223
+ return await run_context.model.agent_model(
224
+ function_tools=function_tool_defs,
225
+ allow_text_result=_allow_text_result(result_schema),
226
+ result_tools=result_schema.tool_defs() if result_schema is not None else [],
227
+ )
228
+
229
+
230
+ @dataclasses.dataclass
231
+ class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
232
+ """Make a request to the model using the last message in state.message_history."""
233
+
234
+ request: _messages.ModelRequest
235
+
236
+ async def run(
237
+ self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
238
+ ) -> HandleResponseNode[DepsT, NodeRunEndT]:
239
+ ctx.state.message_history.append(self.request)
240
+
241
+ # Check usage
242
+ if ctx.deps.usage_limits:
243
+ ctx.deps.usage_limits.check_before_request(ctx.state.usage)
244
+
245
+ # Increment run_step
246
+ ctx.state.run_step += 1
247
+
248
+ with _logfire.span('preparing model and tools {run_step=}', run_step=ctx.state.run_step):
249
+ agent_model = await _prepare_model(ctx)
250
+
251
+ # Actually make the model request
252
+ model_settings = merge_model_settings(ctx.deps.model_settings, None)
253
+ with _logfire.span('model request') as span:
254
+ model_response, request_usage = await agent_model.request(ctx.state.message_history, model_settings)
255
+ span.set_attribute('response', model_response)
256
+ span.set_attribute('usage', request_usage)
257
+
258
+ # Update usage
259
+ ctx.state.usage.incr(request_usage, requests=1)
260
+ if ctx.deps.usage_limits:
261
+ ctx.deps.usage_limits.check_tokens(ctx.state.usage)
262
+
263
+ # Append the model response to state.message_history
264
+ ctx.state.message_history.append(model_response)
265
+ return HandleResponseNode(model_response)
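
The `check_before_request` and `check_tokens` calls above are the enforcement points for per-run usage limits. A hedged sketch of how a caller typically opts in, assuming the public `UsageLimits` API in `pydantic_ai.usage` and that `run_sync` accepts a `usage_limits` argument:

```python
from pydantic_ai import Agent
from pydantic_ai.usage import UsageLimits

agent = Agent('openai:gpt-4o')

# request_limit is checked before each model request; token limits are checked
# after usage is incremented with the response's usage.
result = agent.run_sync(
    'Summarise the latest release notes',
    usage_limits=UsageLimits(request_limit=3, total_tokens_limit=2000),
)
```
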
266
+
267
+
268
+ @dataclasses.dataclass
269
+ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
270
+ """Process e response from a model, decide whether to end the run or make a new request."""
271
+
272
+ model_response: _messages.ModelResponse
273
+
274
+ async def run(
275
+ self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
276
+ ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], FinalResultNode[DepsT, NodeRunEndT]]: # noqa: UP007
277
+ with _logfire.span('handle model response', run_step=ctx.state.run_step) as handle_span:
278
+ texts: list[str] = []
279
+ tool_calls: list[_messages.ToolCallPart] = []
280
+ for part in self.model_response.parts:
281
+ if isinstance(part, _messages.TextPart):
282
+ # ignore empty content for text parts, see #437
283
+ if part.content:
284
+ texts.append(part.content)
285
+ elif isinstance(part, _messages.ToolCallPart):
286
+ tool_calls.append(part)
287
+ else:
288
+ assert_never(part)
289
+
290
+ # At the moment, we prioritize at least executing tool calls if they are present.
291
+ # In the future, we'd consider making this configurable at the agent or run level.
292
+ # This accounts for cases like Anthropic responses that may contain both a text part
293
+ # and a tool call part, where the text merely indicates that the tool call will happen.
294
+ if tool_calls:
295
+ return await self._handle_tool_calls_response(ctx, tool_calls, handle_span)
296
+ elif texts:
297
+ return await self._handle_text_response(ctx, texts, handle_span)
298
+ else:
299
+ raise exceptions.UnexpectedModelBehavior('Received empty model response')
300
+
301
+ async def _handle_tool_calls_response(
302
+ self,
303
+ ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
304
+ tool_calls: list[_messages.ToolCallPart],
305
+ handle_span: logfire_api.LogfireSpan,
306
+ ):
307
+ result_schema = ctx.deps.result_schema
308
+
309
+ # first look for the result tool call
310
+ final_result: MarkFinalResult[NodeRunEndT] | None = None
311
+ parts: list[_messages.ModelRequestPart] = []
312
+ if result_schema is not None:
313
+ if match := result_schema.find_tool(tool_calls):
314
+ call, result_tool = match
315
+ try:
316
+ result_data = result_tool.validate(call)
317
+ result_data = await _validate_result(result_data, ctx, call)
318
+ except _result.ToolRetryError as e:
319
+ # TODO: Should only increment retry stuff once per node execution, not for each tool call
320
+ # Also, should increment the tool-specific retry count rather than the run retry count
321
+ ctx.state.increment_retries(ctx.deps.max_result_retries)
322
+ parts.append(e.tool_retry)
323
+ else:
324
+ final_result = MarkFinalResult(result_data, call.tool_name)
325
+
326
+ # Then build the other request parts based on end strategy
327
+ tool_responses = await _process_function_tools(tool_calls, final_result and final_result.tool_name, ctx)
328
+
329
+ if final_result:
330
+ handle_span.set_attribute('result', final_result.data)
331
+ handle_span.message = 'handle model response -> final result'
332
+ return FinalResultNode[DepsT, NodeRunEndT](final_result, tool_responses)
333
+ else:
334
+ if tool_responses:
335
+ handle_span.set_attribute('tool_responses', tool_responses)
336
+ tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
337
+ handle_span.message = f'handle model response -> {tool_responses_str}'
338
+ parts.extend(tool_responses)
339
+ return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=parts))
340
+
341
+ async def _handle_text_response(
342
+ self,
343
+ ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
344
+ texts: list[str],
345
+ handle_span: logfire_api.LogfireSpan,
346
+ ):
347
+ result_schema = ctx.deps.result_schema
348
+
349
+ text = '\n\n'.join(texts)
350
+ if _allow_text_result(result_schema):
351
+ result_data_input = cast(NodeRunEndT, text)
352
+ try:
353
+ result_data = await _validate_result(result_data_input, ctx, None)
354
+ except _result.ToolRetryError as e:
355
+ ctx.state.increment_retries(ctx.deps.max_result_retries)
356
+ return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
357
+ else:
358
+ handle_span.set_attribute('result', result_data)
359
+ handle_span.message = 'handle model response -> final result'
360
+ return FinalResultNode[DepsT, NodeRunEndT](MarkFinalResult(result_data, None))
361
+ else:
362
+ ctx.state.increment_retries(ctx.deps.max_result_retries)
363
+ return ModelRequestNode[DepsT, NodeRunEndT](
364
+ _messages.ModelRequest(
365
+ parts=[
366
+ _messages.RetryPromptPart(
367
+ content='Plain text responses are not permitted, please call one of the functions instead.',
368
+ )
369
+ ]
370
+ )
371
+ )
372
+
373
+
374
+ @dataclasses.dataclass
375
+ class StreamModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
376
+ """Make a request to the model using the last message in state.message_history (or a specified request)."""
377
+
378
+ request: _messages.ModelRequest
379
+ _result: StreamModelRequestNode[DepsT, NodeRunEndT] | End[result.StreamedRunResult[DepsT, NodeRunEndT]] | None = (
380
+ field(default=None, repr=False)
381
+ )
382
+
383
+ async def run(
384
+ self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
385
+ ) -> Union[StreamModelRequestNode[DepsT, NodeRunEndT], End[result.StreamedRunResult[DepsT, NodeRunEndT]]]: # noqa: UP007
386
+ if self._result is not None:
387
+ return self._result
388
+
389
+ async with self.run_to_result(ctx) as final_node:
390
+ return final_node
391
+
392
+ @asynccontextmanager
393
+ async def run_to_result(
394
+ self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
395
+ ) -> AsyncIterator[StreamModelRequestNode[DepsT, NodeRunEndT] | End[result.StreamedRunResult[DepsT, NodeRunEndT]]]:
396
+ result_schema = ctx.deps.result_schema
397
+
398
+ ctx.state.message_history.append(self.request)
399
+
400
+ # Check usage
401
+ if ctx.deps.usage_limits:
402
+ ctx.deps.usage_limits.check_before_request(ctx.state.usage)
403
+
404
+ # Increment run_step
405
+ ctx.state.run_step += 1
406
+
407
+ with _logfire.span('preparing model and tools {run_step=}', run_step=ctx.state.run_step):
408
+ agent_model = await _prepare_model(ctx)
409
+
410
+ # Actually make the model request
411
+ model_settings = merge_model_settings(ctx.deps.model_settings, None)
412
+ with _logfire.span('model request {run_step=}', run_step=ctx.state.run_step) as model_req_span:
413
+ async with agent_model.request_stream(ctx.state.message_history, model_settings) as streamed_response:
414
+ ctx.state.usage.requests += 1
415
+ model_req_span.set_attribute('response_type', streamed_response.__class__.__name__)
416
+ # We want to end the "model request" span here, but we can't exit the context manager
417
+ # in the traditional way
418
+ model_req_span.__exit__(None, None, None)
419
+
420
+ with _logfire.span('handle model response') as handle_span:
421
+ received_text = False
422
+
423
+ async for maybe_part_event in streamed_response:
424
+ if isinstance(maybe_part_event, _messages.PartStartEvent):
425
+ new_part = maybe_part_event.part
426
+ if isinstance(new_part, _messages.TextPart):
427
+ received_text = True
428
+ if _allow_text_result(result_schema):
429
+ handle_span.message = 'handle model response -> final result'
430
+ streamed_run_result = _build_streamed_run_result(streamed_response, None, ctx)
431
+ self._result = End(streamed_run_result)
432
+ yield self._result
433
+ return
434
+ elif isinstance(new_part, _messages.ToolCallPart):
435
+ if result_schema is not None and (match := result_schema.find_tool([new_part])):
436
+ call, _ = match
437
+ handle_span.message = 'handle model response -> final result'
438
+ streamed_run_result = _build_streamed_run_result(
439
+ streamed_response, call.tool_name, ctx
440
+ )
441
+ self._result = End(streamed_run_result)
442
+ yield self._result
443
+ return
444
+ else:
445
+ assert_never(new_part)
446
+
447
+ tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
448
+ parts: list[_messages.ModelRequestPart] = []
449
+ model_response = streamed_response.get()
450
+ if not model_response.parts:
451
+ raise exceptions.UnexpectedModelBehavior('Received empty model response')
452
+ ctx.state.message_history.append(model_response)
453
+
454
+ run_context = _build_run_context(ctx)
455
+ for p in model_response.parts:
456
+ if isinstance(p, _messages.ToolCallPart):
457
+ if tool := ctx.deps.function_tools.get(p.tool_name):
458
+ tasks.append(asyncio.create_task(tool.run(p, run_context), name=p.tool_name))
459
+ else:
460
+ parts.append(_unknown_tool(p.tool_name, ctx))
461
+
462
+ if received_text and not tasks and not parts:
463
+ # Can only get here if _allow_text_result returns `False` for the provided result_schema
464
+ ctx.state.increment_retries(ctx.deps.max_result_retries)
465
+ self._result = StreamModelRequestNode[DepsT, NodeRunEndT](
466
+ _messages.ModelRequest(
467
+ parts=[
468
+ _messages.RetryPromptPart(
469
+ content='Plain text responses are not permitted, please call one of the functions instead.',
470
+ )
471
+ ]
472
+ )
473
+ )
474
+ yield self._result
475
+ return
476
+
477
+ with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
478
+ task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
479
+ parts.extend(task_results)
480
+
481
+ next_request = _messages.ModelRequest(parts=parts)
482
+ if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
483
+ try:
484
+ ctx.state.increment_retries(ctx.deps.max_result_retries)
485
+ except:
486
+ # TODO: This is janky, so I think we should probably change it, but how?
487
+ ctx.state.message_history.append(next_request)
488
+ raise
489
+
490
+ handle_span.set_attribute('tool_responses', parts)
491
+ tool_responses_str = ' '.join(r.part_kind for r in parts)
492
+ handle_span.message = f'handle model response -> {tool_responses_str}'
493
+ # the model_response should have been fully streamed by now, we can add its usage
494
+ streamed_response_usage = streamed_response.usage()
495
+ run_context.usage.incr(streamed_response_usage)
496
+ ctx.deps.usage_limits.check_tokens(run_context.usage)
497
+ self._result = StreamModelRequestNode[DepsT, NodeRunEndT](next_request)
498
+ yield self._result
499
+ return
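
This node ultimately hands back a `StreamedRunResult` (built by `_build_streamed_run_result` below). A hedged sketch of the public streaming entry point that drives it, assuming `Agent.run_stream` and `StreamedRunResult.stream_text` behave as documented:

```python
import asyncio

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')

async def main() -> None:
    # run_stream is an async context manager yielding the StreamedRunResult built above.
    async with agent.run_stream('Tell me a short joke') as streamed:
        async for chunk in streamed.stream_text():
            print(chunk)

asyncio.run(main())
```
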
500
+
501
+
502
+ @dataclasses.dataclass
503
+ class FinalResultNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[NodeRunEndT]]):
504
+ """Produce the final result of the run."""
505
+
506
+ data: MarkFinalResult[NodeRunEndT]
507
+ """The final result data."""
508
+ extra_parts: list[_messages.ModelRequestPart] = dataclasses.field(default_factory=list)
509
+
510
+ async def run(
511
+ self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
512
+ ) -> End[MarkFinalResult[NodeRunEndT]]:
513
+ run_span = ctx.deps.run_span
514
+ usage = ctx.state.usage
515
+ messages = ctx.state.message_history
516
+
517
+ # TODO: For backwards compatibility, append a new ModelRequest using the tool returns and retries
518
+ if self.extra_parts:
519
+ messages.append(_messages.ModelRequest(parts=self.extra_parts))
520
+
521
+ # TODO: Set this attribute somewhere
522
+ # handle_span = self.handle_model_response_span
523
+ # handle_span.set_attribute('final_data', self.data)
524
+ run_span.set_attribute('usage', usage)
525
+ run_span.set_attribute('all_messages', messages)
526
+
527
+ # End the run with self.data
528
+ return End(self.data)
529
+
530
+
531
+ def _build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
532
+ return RunContext[DepsT](
533
+ deps=ctx.deps.user_deps,
534
+ model=ctx.deps.model,
535
+ usage=ctx.state.usage,
536
+ prompt=ctx.deps.prompt,
537
+ messages=ctx.state.message_history,
538
+ run_step=ctx.state.run_step,
539
+ )
540
+
541
+
542
+ def _build_streamed_run_result(
543
+ result_stream: models.StreamedResponse,
544
+ result_tool_name: str | None,
545
+ ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
546
+ ) -> result.StreamedRunResult[DepsT, NodeRunEndT]:
547
+ new_message_index = ctx.deps.new_message_index
548
+ result_schema = ctx.deps.result_schema
549
+ run_span = ctx.deps.run_span
550
+ usage_limits = ctx.deps.usage_limits
551
+ messages = ctx.state.message_history
552
+ run_context = _build_run_context(ctx)
553
+
554
+ async def on_complete():
555
+ """Called when the stream has completed.
556
+
557
+ The model response will have been added to messages by now
558
+ by `StreamedRunResult._marked_completed`.
559
+ """
560
+ last_message = messages[-1]
561
+ assert isinstance(last_message, _messages.ModelResponse)
562
+ tool_calls = [part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)]
563
+ parts = await _process_function_tools(
564
+ tool_calls,
565
+ result_tool_name,
566
+ ctx,
567
+ )
568
+ # TODO: Should we do something here related to the retry count?
569
+ # Maybe we should move the incrementing of the retry count to where we actually make a request?
570
+ # if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
571
+ # ctx.state.increment_retries(ctx.deps.max_result_retries)
572
+ if parts:
573
+ messages.append(_messages.ModelRequest(parts))
574
+ run_span.set_attribute('all_messages', messages)
575
+
576
+ return result.StreamedRunResult[DepsT, NodeRunEndT](
577
+ messages,
578
+ new_message_index,
579
+ usage_limits,
580
+ result_stream,
581
+ result_schema,
582
+ run_context,
583
+ ctx.deps.result_validators,
584
+ result_tool_name,
585
+ on_complete,
586
+ )
587
+
588
+
589
+ async def _process_function_tools(
590
+ tool_calls: list[_messages.ToolCallPart],
591
+ result_tool_name: str | None,
592
+ ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
593
+ ) -> list[_messages.ModelRequestPart]:
594
+ """Process function (non-result) tool calls in parallel.
595
+
596
+ Also add stub return parts for any other tools that need it.
597
+ """
598
+ parts: list[_messages.ModelRequestPart] = []
599
+ tasks: list[asyncio.Task[_messages.ToolReturnPart | _messages.RetryPromptPart]] = []
600
+
601
+ stub_function_tools = bool(result_tool_name) and ctx.deps.end_strategy == 'early'
602
+ result_schema = ctx.deps.result_schema
603
+
604
+ # we rely on the fact that if we found a result, it's the first result tool in the last message
605
+ found_used_result_tool = False
606
+ run_context = _build_run_context(ctx)
607
+
608
+ for call in tool_calls:
609
+ if call.tool_name == result_tool_name and not found_used_result_tool:
610
+ found_used_result_tool = True
611
+ parts.append(
612
+ _messages.ToolReturnPart(
613
+ tool_name=call.tool_name,
614
+ content='Final result processed.',
615
+ tool_call_id=call.tool_call_id,
616
+ )
617
+ )
618
+ elif tool := ctx.deps.function_tools.get(call.tool_name):
619
+ if stub_function_tools:
620
+ parts.append(
621
+ _messages.ToolReturnPart(
622
+ tool_name=call.tool_name,
623
+ content='Tool not executed - a final result was already processed.',
624
+ tool_call_id=call.tool_call_id,
625
+ )
626
+ )
627
+ else:
628
+ tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
629
+ elif result_schema is not None and call.tool_name in result_schema.tools:
630
+ # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
631
+ # validation, we don't add another part here
632
+ if result_tool_name is not None:
633
+ parts.append(
634
+ _messages.ToolReturnPart(
635
+ tool_name=call.tool_name,
636
+ content='Result tool not used - a final result was already processed.',
637
+ tool_call_id=call.tool_call_id,
638
+ )
639
+ )
640
+ else:
641
+ parts.append(_unknown_tool(call.tool_name, ctx))
642
+
643
+ # Run all tool tasks in parallel
644
+ if tasks:
645
+ with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
646
+ task_results: Sequence[_messages.ToolReturnPart | _messages.RetryPromptPart] = await asyncio.gather(*tasks)
647
+ for result in task_results:
648
+ if isinstance(result, _messages.ToolReturnPart):
649
+ parts.append(result)
650
+ elif isinstance(result, _messages.RetryPromptPart):
651
+ parts.append(result)
652
+ else:
653
+ assert_never(result)
654
+ return parts
655
+
656
+
657
+ def _unknown_tool(
658
+ tool_name: str,
659
+ ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
660
+ ) -> _messages.RetryPromptPart:
661
+ ctx.state.increment_retries(ctx.deps.max_result_retries)
662
+ tool_names = list(ctx.deps.function_tools.keys())
663
+ if result_schema := ctx.deps.result_schema:
664
+ tool_names.extend(result_schema.tool_names())
665
+
666
+ if tool_names:
667
+ msg = f'Available tools: {", ".join(tool_names)}'
668
+ else:
669
+ msg = 'No tools available.'
670
+
671
+ return _messages.RetryPromptPart(content=f'Unknown tool name: {tool_name!r}. {msg}')
672
+
673
+
674
+ async def _validate_result(
675
+ result_data: T,
676
+ ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
677
+ tool_call: _messages.ToolCallPart | None,
678
+ ) -> T:
679
+ for validator in ctx.deps.result_validators:
680
+ run_context = _build_run_context(ctx)
681
+ result_data = await validator.validate(result_data, tool_call, run_context)
682
+ return result_data
683
+
684
+
685
+ def _allow_text_result(result_schema: _result.ResultSchema[Any] | None) -> bool:
686
+ return result_schema is None or result_schema.allow_text_result
687
+
688
+
689
+ @dataclasses.dataclass
690
+ class _RunMessages:
691
+ messages: list[_messages.ModelMessage]
692
+ used: bool = False
693
+
694
+
695
+ _messages_ctx_var: ContextVar[_RunMessages] = ContextVar('var')
696
+
697
+
698
+ @contextmanager
699
+ def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
700
+ """Context manager to access the messages used in a [`run`][pydantic_ai.Agent.run], [`run_sync`][pydantic_ai.Agent.run_sync], or [`run_stream`][pydantic_ai.Agent.run_stream] call.
701
+
702
+ Useful when a run may raise an exception, see [model errors](../agents.md#model-errors) for more information.
703
+
704
+ Examples:
705
+ ```python
706
+ from pydantic_ai import Agent, capture_run_messages
707
+
708
+ agent = Agent('test')
709
+
710
+ with capture_run_messages() as messages:
711
+ try:
712
+ result = agent.run_sync('foobar')
713
+ except Exception:
714
+ print(messages)
715
+ raise
716
+ ```
717
+
718
+ !!! note
719
+ If you call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context,
720
+ `messages` will represent the messages exchanged during the first call only.
721
+ """
722
+ try:
723
+ yield _messages_ctx_var.get().messages
724
+ except LookupError:
725
+ messages: list[_messages.ModelMessage] = []
726
+ token = _messages_ctx_var.set(_RunMessages(messages))
727
+ try:
728
+ yield messages
729
+ finally:
730
+ _messages_ctx_var.reset(token)
731
+
732
+
733
+ def get_captured_run_messages() -> _RunMessages:
734
+ return _messages_ctx_var.get()
735
+
736
+
737
+ def build_agent_graph(
738
+ name: str | None, deps_type: type[DepsT], result_type: type[ResultT]
739
+ ) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[ResultT]]:
740
+ # We'll define the known node classes:
741
+ nodes = (
742
+ UserPromptNode[DepsT],
743
+ ModelRequestNode[DepsT],
744
+ HandleResponseNode[DepsT],
745
+ FinalResultNode[DepsT, ResultT],
746
+ )
747
+ graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[ResultT]](
748
+ nodes=nodes,
749
+ name=name or 'Agent',
750
+ state_type=GraphAgentState,
751
+ run_end_type=MarkFinalResult[result_type],
752
+ auto_instrument=False,
753
+ )
754
+ return graph
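
A hedged sketch of what building the non-streaming graph looks like in isolation; the `NoneType` deps and `str` result type are illustrative, and actually executing the graph still requires the `GraphAgentState`/`GraphAgentDeps` instances that the agent assembles:

```python
# Build the agent graph for an agent with no deps and a plain string result.
agent_graph = build_agent_graph(name='my-agent', deps_type=NoneType, result_type=str)
# Its run end type is MarkFinalResult[str]; UserPromptNode, ModelRequestNode,
# HandleResponseNode and FinalResultNode are the nodes registered above.
```
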
755
+
756
+
757
+ def build_agent_stream_graph(
758
+ name: str | None, deps_type: type[DepsT], result_type: type[ResultT] | None
759
+ ) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.StreamedRunResult[DepsT, Any]]:
760
+ nodes = [
761
+ StreamUserPromptNode[DepsT, result.StreamedRunResult[DepsT, ResultT]],
762
+ StreamModelRequestNode[DepsT, result.StreamedRunResult[DepsT, ResultT]],
763
+ ]
764
+ graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.StreamedRunResult[DepsT, Any]](
765
+ nodes=nodes,
766
+ name=name or 'Agent',
767
+ state_type=GraphAgentState,
768
+ run_end_type=result.StreamedRunResult[DepsT, result_type],
769
+ )
770
+ return graph