pydantic-ai-slim 0.0.21__py3-none-any.whl → 0.0.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -0,0 +1,774 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ import asyncio
4
+ import dataclasses
5
+ from abc import ABC
6
+ from collections.abc import AsyncIterator, Iterator, Sequence
7
+ from contextlib import asynccontextmanager, contextmanager
8
+ from contextvars import ContextVar
9
+ from dataclasses import field
10
+ from typing import Any, Generic, Literal, Union, cast
11
+
12
+ import logfire_api
13
+ from typing_extensions import TypeVar, assert_never
14
+
15
+ from pydantic_graph import BaseNode, Graph, GraphRunContext
16
+ from pydantic_graph.nodes import End, NodeRunEndT
17
+
18
+ from . import (
19
+ _result,
20
+ _system_prompt,
21
+ exceptions,
22
+ messages as _messages,
23
+ models,
24
+ result,
25
+ usage as _usage,
26
+ )
27
+ from .result import ResultDataT
28
+ from .settings import ModelSettings, merge_model_settings
29
+ from .tools import (
30
+ RunContext,
31
+ Tool,
32
+ ToolDefinition,
33
+ )
34
+
35
+ _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
36
+
37
+ # while waiting for https://github.com/pydantic/logfire/issues/745
38
+ try:
39
+ import logfire._internal.stack_info
40
+ except ImportError:
41
+ pass
42
+ else:
43
+ from pathlib import Path
44
+
45
+ logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
46
+
47
+ T = TypeVar('T')
48
+ NoneType = type(None)
49
+ EndStrategy = Literal['early', 'exhaustive']
50
+ """The strategy for handling multiple tool calls when a final result is found.
51
+
52
+ - `'early'`: Stop processing other tool calls once a final result is found
53
+ - `'exhaustive'`: Process all tool calls even after finding a final result
54
+ """
55
+ DepsT = TypeVar('DepsT')
56
+ ResultT = TypeVar('ResultT')
57
+
58
+
59
+ @dataclasses.dataclass
60
+ class MarkFinalResult(Generic[ResultDataT]):
61
+ """Marker class to indicate that the result is the final result.
62
+
63
+ This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultDataT` directly.
64
+
65
+ It also avoids problems in the case where the result type is itself `None`, but is set.
66
+ """
67
+
68
+ data: ResultDataT
69
+ """The final result data."""
70
+ tool_name: str | None
71
+ """Name of the final result tool, None if the result is a string."""
72
+
73
+
74
+ @dataclasses.dataclass
75
+ class GraphAgentState:
76
+ """State kept across the execution of the agent graph."""
77
+
78
+ message_history: list[_messages.ModelMessage]
79
+ usage: _usage.Usage
80
+ retries: int
81
+ run_step: int
82
+
83
+ def increment_retries(self, max_result_retries: int) -> None:
84
+ self.retries += 1
85
+ if self.retries > max_result_retries:
86
+ raise exceptions.UnexpectedModelBehavior(
87
+ f'Exceeded maximum retries ({max_result_retries}) for result validation'
88
+ )
89
+
90
+
91
+ @dataclasses.dataclass
92
+ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
93
+ """Dependencies/config passed to the agent graph."""
94
+
95
+ user_deps: DepsT
96
+
97
+ prompt: str
98
+ new_message_index: int
99
+
100
+ model: models.Model
101
+ model_settings: ModelSettings | None
102
+ usage_limits: _usage.UsageLimits
103
+ max_result_retries: int
104
+ end_strategy: EndStrategy
105
+
106
+ result_schema: _result.ResultSchema[ResultDataT] | None
107
+ result_tools: list[ToolDefinition]
108
+ result_validators: list[_result.ResultValidator[DepsT, ResultDataT]]
109
+
110
+ function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
111
+
112
+ run_span: logfire_api.LogfireSpan
113
+
114
+
115
+ @dataclasses.dataclass
116
+ class BaseUserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC):
117
+ user_prompt: str
118
+
119
+ system_prompts: tuple[str, ...]
120
+ system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
121
+ system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]
122
+
123
+ async def _get_first_message(
124
+ self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
125
+ ) -> _messages.ModelRequest:
126
+ run_context = _build_run_context(ctx)
127
+ history, next_message = await self._prepare_messages(self.user_prompt, ctx.state.message_history, run_context)
128
+ ctx.state.message_history = history
129
+ run_context.messages = history
130
+
131
+ # TODO: We need to make it so that function_tools are not shared between runs
132
+ # See comment on the current_retry field of `Tool` for more details.
133
+ for tool in ctx.deps.function_tools.values():
134
+ tool.current_retry = 0
135
+ return next_message
136
+
137
+ async def _prepare_messages(
138
+ self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[DepsT]
139
+ ) -> tuple[list[_messages.ModelMessage], _messages.ModelRequest]:
140
+ try:
141
+ ctx_messages = get_captured_run_messages()
142
+ except LookupError:
143
+ messages: list[_messages.ModelMessage] = []
144
+ else:
145
+ if ctx_messages.used:
146
+ messages = []
147
+ else:
148
+ messages = ctx_messages.messages
149
+ ctx_messages.used = True
150
+
151
+ if message_history:
152
+ # Shallow copy messages
153
+ messages.extend(message_history)
154
+ # Reevaluate any dynamic system prompt parts
155
+ await self._reevaluate_dynamic_prompts(messages, run_context)
156
+ return messages, _messages.ModelRequest([_messages.UserPromptPart(user_prompt)])
157
+ else:
158
+ parts = await self._sys_parts(run_context)
159
+ parts.append(_messages.UserPromptPart(user_prompt))
160
+ return messages, _messages.ModelRequest(parts)
161
+
162
+ async def _reevaluate_dynamic_prompts(
163
+ self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT]
164
+ ) -> None:
165
+ """Reevaluate any `SystemPromptPart` with dynamic_ref in the provided messages by running the associated runner function."""
166
+ # Only proceed if there's at least one dynamic runner.
167
+ if self.system_prompt_dynamic_functions:
168
+ for msg in messages:
169
+ if isinstance(msg, _messages.ModelRequest):
170
+ for i, part in enumerate(msg.parts):
171
+ if isinstance(part, _messages.SystemPromptPart) and part.dynamic_ref:
172
+ # Look up the runner by its ref
173
+ if runner := self.system_prompt_dynamic_functions.get(part.dynamic_ref):
174
+ updated_part_content = await runner.run(run_context)
175
+ msg.parts[i] = _messages.SystemPromptPart(
176
+ updated_part_content, dynamic_ref=part.dynamic_ref
177
+ )
178
+
179
+ async def _sys_parts(self, run_context: RunContext[DepsT]) -> list[_messages.ModelRequestPart]:
180
+ """Build the initial messages for the conversation."""
181
+ messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self.system_prompts]
182
+ for sys_prompt_runner in self.system_prompt_functions:
183
+ prompt = await sys_prompt_runner.run(run_context)
184
+ if sys_prompt_runner.dynamic:
185
+ messages.append(_messages.SystemPromptPart(prompt, dynamic_ref=sys_prompt_runner.function.__qualname__))
186
+ else:
187
+ messages.append(_messages.SystemPromptPart(prompt))
188
+ return messages
189
+
190
+
191
+ @dataclasses.dataclass
192
+ class UserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]):
193
+ async def run(
194
+ self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
195
+ ) -> ModelRequestNode[DepsT, NodeRunEndT]:
196
+ return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
197
+
198
+
199
+ @dataclasses.dataclass
200
+ class StreamUserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]):
201
+ async def run(
202
+ self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
203
+ ) -> StreamModelRequestNode[DepsT, NodeRunEndT]:
204
+ return StreamModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
205
+
206
+
207
+ async def _prepare_request_parameters(
208
+ ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
209
+ ) -> models.ModelRequestParameters:
210
+ """Build tools and create an agent model."""
211
+ function_tool_defs: list[ToolDefinition] = []
212
+
213
+ run_context = _build_run_context(ctx)
214
+
215
+ async def add_tool(tool: Tool[DepsT]) -> None:
216
+ ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
217
+ if tool_def := await tool.prepare_tool_def(ctx):
218
+ function_tool_defs.append(tool_def)
219
+
220
+ await asyncio.gather(*map(add_tool, ctx.deps.function_tools.values()))
221
+
222
+ result_schema = ctx.deps.result_schema
223
+ return models.ModelRequestParameters(
224
+ function_tools=function_tool_defs,
225
+ allow_text_result=_allow_text_result(result_schema),
226
+ result_tools=result_schema.tool_defs() if result_schema is not None else [],
227
+ )
228
+
229
+
230
+ @dataclasses.dataclass
231
+ class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
232
+ """Make a request to the model using the last message in state.message_history."""
233
+
234
+ request: _messages.ModelRequest
235
+
236
+ async def run(
237
+ self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
238
+ ) -> HandleResponseNode[DepsT, NodeRunEndT]:
239
+ ctx.state.message_history.append(self.request)
240
+
241
+ # Check usage
242
+ if ctx.deps.usage_limits:
243
+ ctx.deps.usage_limits.check_before_request(ctx.state.usage)
244
+
245
+ # Increment run_step
246
+ ctx.state.run_step += 1
247
+
248
+ with _logfire.span('preparing model request params {run_step=}', run_step=ctx.state.run_step):
249
+ model_request_parameters = await _prepare_request_parameters(ctx)
250
+
251
+ # Actually make the model request
252
+ model_settings = merge_model_settings(ctx.deps.model_settings, None)
253
+ with _logfire.span('model request') as span:
254
+ model_response, request_usage = await ctx.deps.model.request(
255
+ ctx.state.message_history, model_settings, model_request_parameters
256
+ )
257
+ span.set_attribute('response', model_response)
258
+ span.set_attribute('usage', request_usage)
259
+
260
+ # Update usage
261
+ ctx.state.usage.incr(request_usage, requests=1)
262
+ if ctx.deps.usage_limits:
263
+ ctx.deps.usage_limits.check_tokens(ctx.state.usage)
264
+
265
+ # Append the model response to state.message_history
266
+ ctx.state.message_history.append(model_response)
267
+ return HandleResponseNode(model_response)
268
+
269
+
270
+ @dataclasses.dataclass
271
+ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
272
+ """Process e response from a model, decide whether to end the run or make a new request."""
273
+
274
+ model_response: _messages.ModelResponse
275
+
276
+ async def run(
277
+ self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
278
+ ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], FinalResultNode[DepsT, NodeRunEndT]]: # noqa UP007
279
+ with _logfire.span('handle model response', run_step=ctx.state.run_step) as handle_span:
280
+ texts: list[str] = []
281
+ tool_calls: list[_messages.ToolCallPart] = []
282
+ for part in self.model_response.parts:
283
+ if isinstance(part, _messages.TextPart):
284
+ # ignore empty content for text parts, see #437
285
+ if part.content:
286
+ texts.append(part.content)
287
+ elif isinstance(part, _messages.ToolCallPart):
288
+ tool_calls.append(part)
289
+ else:
290
+ assert_never(part)
291
+
292
+ # At the moment, we prioritize executing tool calls if they are present.
293
+ # In the future, we may consider making this configurable at the agent or run level.
294
+ # This accounts for cases like Anthropic responses that contain both a text part and a tool
295
+ # call part, where the text simply indicates that the tool call will happen.
296
+ if tool_calls:
297
+ return await self._handle_tool_calls_response(ctx, tool_calls, handle_span)
298
+ elif texts:
299
+ return await self._handle_text_response(ctx, texts, handle_span)
300
+ else:
301
+ raise exceptions.UnexpectedModelBehavior('Received empty model response')
302
+
303
+ async def _handle_tool_calls_response(
304
+ self,
305
+ ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
306
+ tool_calls: list[_messages.ToolCallPart],
307
+ handle_span: logfire_api.LogfireSpan,
308
+ ):
309
+ result_schema = ctx.deps.result_schema
310
+
311
+ # first look for the result tool call
312
+ final_result: MarkFinalResult[NodeRunEndT] | None = None
313
+ parts: list[_messages.ModelRequestPart] = []
314
+ if result_schema is not None:
315
+ if match := result_schema.find_tool(tool_calls):
316
+ call, result_tool = match
317
+ try:
318
+ result_data = result_tool.validate(call)
319
+ result_data = await _validate_result(result_data, ctx, call)
320
+ except _result.ToolRetryError as e:
321
+ # TODO: Should only increment retry stuff once per node execution, not for each tool call
322
+ # Also, should increment the tool-specific retry count rather than the run retry count
323
+ ctx.state.increment_retries(ctx.deps.max_result_retries)
324
+ parts.append(e.tool_retry)
325
+ else:
326
+ final_result = MarkFinalResult(result_data, call.tool_name)
327
+
328
+ # Then build the other request parts based on end strategy
329
+ tool_responses = await _process_function_tools(tool_calls, final_result and final_result.tool_name, ctx)
330
+
331
+ if final_result:
332
+ handle_span.set_attribute('result', final_result.data)
333
+ handle_span.message = 'handle model response -> final result'
334
+ return FinalResultNode[DepsT, NodeRunEndT](final_result, tool_responses)
335
+ else:
336
+ if tool_responses:
337
+ handle_span.set_attribute('tool_responses', tool_responses)
338
+ tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
339
+ handle_span.message = f'handle model response -> {tool_responses_str}'
340
+ parts.extend(tool_responses)
341
+ return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=parts))
342
+
343
+ async def _handle_text_response(
344
+ self,
345
+ ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
346
+ texts: list[str],
347
+ handle_span: logfire_api.LogfireSpan,
348
+ ):
349
+ result_schema = ctx.deps.result_schema
350
+
351
+ text = '\n\n'.join(texts)
352
+ if _allow_text_result(result_schema):
353
+ result_data_input = cast(NodeRunEndT, text)
354
+ try:
355
+ result_data = await _validate_result(result_data_input, ctx, None)
356
+ except _result.ToolRetryError as e:
357
+ ctx.state.increment_retries(ctx.deps.max_result_retries)
358
+ return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
359
+ else:
360
+ handle_span.set_attribute('result', result_data)
361
+ handle_span.message = 'handle model response -> final result'
362
+ return FinalResultNode[DepsT, NodeRunEndT](MarkFinalResult(result_data, None))
363
+ else:
364
+ ctx.state.increment_retries(ctx.deps.max_result_retries)
365
+ return ModelRequestNode[DepsT, NodeRunEndT](
366
+ _messages.ModelRequest(
367
+ parts=[
368
+ _messages.RetryPromptPart(
369
+ content='Plain text responses are not permitted, please call one of the functions instead.',
370
+ )
371
+ ]
372
+ )
373
+ )
374
+
375
+
376
+ @dataclasses.dataclass
377
+ class StreamModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
378
+ """Make a request to the model using the last message in state.message_history (or a specified request)."""
379
+
380
+ request: _messages.ModelRequest
381
+ _result: StreamModelRequestNode[DepsT, NodeRunEndT] | End[result.StreamedRunResult[DepsT, NodeRunEndT]] | None = (
382
+ field(default=None, repr=False)
383
+ )
384
+
385
+ async def run(
386
+ self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
387
+ ) -> Union[StreamModelRequestNode[DepsT, NodeRunEndT], End[result.StreamedRunResult[DepsT, NodeRunEndT]]]: # noqa UP007
388
+ if self._result is not None:
389
+ return self._result
390
+
391
+ async with self.run_to_result(ctx) as final_node:
392
+ return final_node
393
+
394
+ @asynccontextmanager
395
+ async def run_to_result(
396
+ self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
397
+ ) -> AsyncIterator[StreamModelRequestNode[DepsT, NodeRunEndT] | End[result.StreamedRunResult[DepsT, NodeRunEndT]]]:
398
+ result_schema = ctx.deps.result_schema
399
+
400
+ ctx.state.message_history.append(self.request)
401
+
402
+ # Check usage
403
+ if ctx.deps.usage_limits:
404
+ ctx.deps.usage_limits.check_before_request(ctx.state.usage)
405
+
406
+ # Increment run_step
407
+ ctx.state.run_step += 1
408
+
409
+ with _logfire.span('preparing model and tools {run_step=}', run_step=ctx.state.run_step):
410
+ model_request_parameters = await _prepare_request_parameters(ctx)
411
+
412
+ # Actually make the model request
413
+ model_settings = merge_model_settings(ctx.deps.model_settings, None)
414
+ with _logfire.span('model request {run_step=}', run_step=ctx.state.run_step) as model_req_span:
415
+ async with ctx.deps.model.request_stream(
416
+ ctx.state.message_history, model_settings, model_request_parameters
417
+ ) as streamed_response:
418
+ ctx.state.usage.requests += 1
419
+ model_req_span.set_attribute('response_type', streamed_response.__class__.__name__)
420
+ # We want to end the "model request" span here, but we can't exit the context manager
421
+ # in the traditional way
422
+ model_req_span.__exit__(None, None, None)
423
+
424
+ with _logfire.span('handle model response') as handle_span:
425
+ received_text = False
426
+
427
+ async for maybe_part_event in streamed_response:
428
+ if isinstance(maybe_part_event, _messages.PartStartEvent):
429
+ new_part = maybe_part_event.part
430
+ if isinstance(new_part, _messages.TextPart):
431
+ received_text = True
432
+ if _allow_text_result(result_schema):
433
+ handle_span.message = 'handle model response -> final result'
434
+ streamed_run_result = _build_streamed_run_result(streamed_response, None, ctx)
435
+ self._result = End(streamed_run_result)
436
+ yield self._result
437
+ return
438
+ elif isinstance(new_part, _messages.ToolCallPart):
439
+ if result_schema is not None and (match := result_schema.find_tool([new_part])):
440
+ call, _ = match
441
+ handle_span.message = 'handle model response -> final result'
442
+ streamed_run_result = _build_streamed_run_result(
443
+ streamed_response, call.tool_name, ctx
444
+ )
445
+ self._result = End(streamed_run_result)
446
+ yield self._result
447
+ return
448
+ else:
449
+ assert_never(new_part)
450
+
451
+ tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
452
+ parts: list[_messages.ModelRequestPart] = []
453
+ model_response = streamed_response.get()
454
+ if not model_response.parts:
455
+ raise exceptions.UnexpectedModelBehavior('Received empty model response')
456
+ ctx.state.message_history.append(model_response)
457
+
458
+ run_context = _build_run_context(ctx)
459
+ for p in model_response.parts:
460
+ if isinstance(p, _messages.ToolCallPart):
461
+ if tool := ctx.deps.function_tools.get(p.tool_name):
462
+ tasks.append(asyncio.create_task(tool.run(p, run_context), name=p.tool_name))
463
+ else:
464
+ parts.append(_unknown_tool(p.tool_name, ctx))
465
+
466
+ if received_text and not tasks and not parts:
467
+ # Can only get here if _allow_text_result returns `False` for the provided result_schema
468
+ ctx.state.increment_retries(ctx.deps.max_result_retries)
469
+ self._result = StreamModelRequestNode[DepsT, NodeRunEndT](
470
+ _messages.ModelRequest(
471
+ parts=[
472
+ _messages.RetryPromptPart(
473
+ content='Plain text responses are not permitted, please call one of the functions instead.',
474
+ )
475
+ ]
476
+ )
477
+ )
478
+ yield self._result
479
+ return
480
+
481
+ with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
482
+ task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
483
+ parts.extend(task_results)
484
+
485
+ next_request = _messages.ModelRequest(parts=parts)
486
+ if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
487
+ try:
488
+ ctx.state.increment_retries(ctx.deps.max_result_retries)
489
+ except:
490
+ # TODO: This is janky, so I think we should probably change it, but how?
491
+ ctx.state.message_history.append(next_request)
492
+ raise
493
+
494
+ handle_span.set_attribute('tool_responses', parts)
495
+ tool_responses_str = ' '.join(r.part_kind for r in parts)
496
+ handle_span.message = f'handle model response -> {tool_responses_str}'
497
+ # the model_response should have been fully streamed by now, so we can add its usage
498
+ streamed_response_usage = streamed_response.usage()
499
+ run_context.usage.incr(streamed_response_usage)
500
+ ctx.deps.usage_limits.check_tokens(run_context.usage)
501
+ self._result = StreamModelRequestNode[DepsT, NodeRunEndT](next_request)
502
+ yield self._result
503
+ return
504
+
505
+
506
+ @dataclasses.dataclass
507
+ class FinalResultNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[NodeRunEndT]]):
508
+ """Produce the final result of the run."""
509
+
510
+ data: MarkFinalResult[NodeRunEndT]
511
+ """The final result data."""
512
+ extra_parts: list[_messages.ModelRequestPart] = dataclasses.field(default_factory=list)
513
+
514
+ async def run(
515
+ self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
516
+ ) -> End[MarkFinalResult[NodeRunEndT]]:
517
+ run_span = ctx.deps.run_span
518
+ usage = ctx.state.usage
519
+ messages = ctx.state.message_history
520
+
521
+ # TODO: For backwards compatibility, append a new ModelRequest using the tool returns and retries
522
+ if self.extra_parts:
523
+ messages.append(_messages.ModelRequest(parts=self.extra_parts))
524
+
525
+ # TODO: Set this attribute somewhere
526
+ # handle_span = self.handle_model_response_span
527
+ # handle_span.set_attribute('final_data', self.data)
528
+ run_span.set_attribute('usage', usage)
529
+ run_span.set_attribute('all_messages', messages)
530
+
531
+ # End the run with self.data
532
+ return End(self.data)
533
+
534
+
535
+ def _build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
536
+ return RunContext[DepsT](
537
+ deps=ctx.deps.user_deps,
538
+ model=ctx.deps.model,
539
+ usage=ctx.state.usage,
540
+ prompt=ctx.deps.prompt,
541
+ messages=ctx.state.message_history,
542
+ run_step=ctx.state.run_step,
543
+ )
544
+
545
+
546
+ def _build_streamed_run_result(
547
+ result_stream: models.StreamedResponse,
548
+ result_tool_name: str | None,
549
+ ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
550
+ ) -> result.StreamedRunResult[DepsT, NodeRunEndT]:
551
+ new_message_index = ctx.deps.new_message_index
552
+ result_schema = ctx.deps.result_schema
553
+ run_span = ctx.deps.run_span
554
+ usage_limits = ctx.deps.usage_limits
555
+ messages = ctx.state.message_history
556
+ run_context = _build_run_context(ctx)
557
+
558
+ async def on_complete():
559
+ """Called when the stream has completed.
560
+
561
+ The model response will have been added to messages by now
562
+ by `StreamedRunResult._marked_completed`.
563
+ """
564
+ last_message = messages[-1]
565
+ assert isinstance(last_message, _messages.ModelResponse)
566
+ tool_calls = [part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)]
567
+ parts = await _process_function_tools(
568
+ tool_calls,
569
+ result_tool_name,
570
+ ctx,
571
+ )
572
+ # TODO: Should we do something here related to the retry count?
573
+ # Maybe we should move the incrementing of the retry count to where we actually make a request?
574
+ # if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
575
+ # ctx.state.increment_retries(ctx.deps.max_result_retries)
576
+ if parts:
577
+ messages.append(_messages.ModelRequest(parts))
578
+ run_span.set_attribute('all_messages', messages)
579
+
580
+ return result.StreamedRunResult[DepsT, NodeRunEndT](
581
+ messages,
582
+ new_message_index,
583
+ usage_limits,
584
+ result_stream,
585
+ result_schema,
586
+ run_context,
587
+ ctx.deps.result_validators,
588
+ result_tool_name,
589
+ on_complete,
590
+ )
591
+
592
+
593
+ async def _process_function_tools(
594
+ tool_calls: list[_messages.ToolCallPart],
595
+ result_tool_name: str | None,
596
+ ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
597
+ ) -> list[_messages.ModelRequestPart]:
598
+ """Process function (non-result) tool calls in parallel.
599
+
600
+ Also add stub return parts for any other tools that need it.
601
+ """
602
+ parts: list[_messages.ModelRequestPart] = []
603
+ tasks: list[asyncio.Task[_messages.ToolReturnPart | _messages.RetryPromptPart]] = []
604
+
605
+ stub_function_tools = bool(result_tool_name) and ctx.deps.end_strategy == 'early'
606
+ result_schema = ctx.deps.result_schema
607
+
608
+ # we rely on the fact that if we found a result, it's the first result tool in the last response
609
+ found_used_result_tool = False
610
+ run_context = _build_run_context(ctx)
611
+
612
+ for call in tool_calls:
613
+ if call.tool_name == result_tool_name and not found_used_result_tool:
614
+ found_used_result_tool = True
615
+ parts.append(
616
+ _messages.ToolReturnPart(
617
+ tool_name=call.tool_name,
618
+ content='Final result processed.',
619
+ tool_call_id=call.tool_call_id,
620
+ )
621
+ )
622
+ elif tool := ctx.deps.function_tools.get(call.tool_name):
623
+ if stub_function_tools:
624
+ parts.append(
625
+ _messages.ToolReturnPart(
626
+ tool_name=call.tool_name,
627
+ content='Tool not executed - a final result was already processed.',
628
+ tool_call_id=call.tool_call_id,
629
+ )
630
+ )
631
+ else:
632
+ tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
633
+ elif result_schema is not None and call.tool_name in result_schema.tools:
634
+ # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
635
+ # validation, so we don't add another part here
636
+ if result_tool_name is not None:
637
+ parts.append(
638
+ _messages.ToolReturnPart(
639
+ tool_name=call.tool_name,
640
+ content='Result tool not used - a final result was already processed.',
641
+ tool_call_id=call.tool_call_id,
642
+ )
643
+ )
644
+ else:
645
+ parts.append(_unknown_tool(call.tool_name, ctx))
646
+
647
+ # Run all tool tasks in parallel
648
+ if tasks:
649
+ with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
650
+ task_results: Sequence[_messages.ToolReturnPart | _messages.RetryPromptPart] = await asyncio.gather(*tasks)
651
+ for result in task_results:
652
+ if isinstance(result, _messages.ToolReturnPart):
653
+ parts.append(result)
654
+ elif isinstance(result, _messages.RetryPromptPart):
655
+ parts.append(result)
656
+ else:
657
+ assert_never(result)
658
+ return parts
659
+
660
+
661
+ def _unknown_tool(
662
+ tool_name: str,
663
+ ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
664
+ ) -> _messages.RetryPromptPart:
665
+ ctx.state.increment_retries(ctx.deps.max_result_retries)
666
+ tool_names = list(ctx.deps.function_tools.keys())
667
+ if result_schema := ctx.deps.result_schema:
668
+ tool_names.extend(result_schema.tool_names())
669
+
670
+ if tool_names:
671
+ msg = f'Available tools: {", ".join(tool_names)}'
672
+ else:
673
+ msg = 'No tools available.'
674
+
675
+ return _messages.RetryPromptPart(content=f'Unknown tool name: {tool_name!r}. {msg}')
676
+
677
+
678
+ async def _validate_result(
679
+ result_data: T,
680
+ ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
681
+ tool_call: _messages.ToolCallPart | None,
682
+ ) -> T:
683
+ for validator in ctx.deps.result_validators:
684
+ run_context = _build_run_context(ctx)
685
+ result_data = await validator.validate(result_data, tool_call, run_context)
686
+ return result_data
687
+
688
+
689
+ def _allow_text_result(result_schema: _result.ResultSchema[Any] | None) -> bool:
690
+ return result_schema is None or result_schema.allow_text_result
691
+
692
+
693
+ @dataclasses.dataclass
694
+ class _RunMessages:
695
+ messages: list[_messages.ModelMessage]
696
+ used: bool = False
697
+
698
+
699
+ _messages_ctx_var: ContextVar[_RunMessages] = ContextVar('var')
700
+
701
+
702
+ @contextmanager
703
+ def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
704
+ """Context manager to access the messages used in a [`run`][pydantic_ai.Agent.run], [`run_sync`][pydantic_ai.Agent.run_sync], or [`run_stream`][pydantic_ai.Agent.run_stream] call.
705
+
706
+ Useful when a run may raise an exception, see [model errors](../agents.md#model-errors) for more information.
707
+
708
+ Examples:
709
+ ```python
710
+ from pydantic_ai import Agent, capture_run_messages
711
+
712
+ agent = Agent('test')
713
+
714
+ with capture_run_messages() as messages:
715
+ try:
716
+ result = agent.run_sync('foobar')
717
+ except Exception:
718
+ print(messages)
719
+ raise
720
+ ```
721
+
722
+ !!! note
723
+ If you call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context,
724
+ `messages` will represent the messages exchanged during the first call only.
725
+ """
726
+ try:
727
+ yield _messages_ctx_var.get().messages
728
+ except LookupError:
729
+ messages: list[_messages.ModelMessage] = []
730
+ token = _messages_ctx_var.set(_RunMessages(messages))
731
+ try:
732
+ yield messages
733
+ finally:
734
+ _messages_ctx_var.reset(token)
735
+
736
+
737
+ def get_captured_run_messages() -> _RunMessages:
738
+ return _messages_ctx_var.get()
739
+
740
+
741
+ def build_agent_graph(
742
+ name: str | None, deps_type: type[DepsT], result_type: type[ResultT]
743
+ ) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[ResultT]]:
744
+ # We'll define the known node classes:
745
+ nodes = (
746
+ UserPromptNode[DepsT],
747
+ ModelRequestNode[DepsT],
748
+ HandleResponseNode[DepsT],
749
+ FinalResultNode[DepsT, ResultT],
750
+ )
751
+ graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[ResultT]](
752
+ nodes=nodes,
753
+ name=name or 'Agent',
754
+ state_type=GraphAgentState,
755
+ run_end_type=MarkFinalResult[result_type],
756
+ auto_instrument=False,
757
+ )
758
+ return graph
759
+
760
+
761
+ def build_agent_stream_graph(
762
+ name: str | None, deps_type: type[DepsT], result_type: type[ResultT] | None
763
+ ) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.StreamedRunResult[DepsT, Any]]:
764
+ nodes = [
765
+ StreamUserPromptNode[DepsT, result.StreamedRunResult[DepsT, ResultT]],
766
+ StreamModelRequestNode[DepsT, result.StreamedRunResult[DepsT, ResultT]],
767
+ ]
768
+ graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.StreamedRunResult[DepsT, Any]](
769
+ nodes=nodes,
770
+ name=name or 'Agent',
771
+ state_type=GraphAgentState,
772
+ run_end_type=result.StreamedRunResult[DepsT, result_type],
773
+ )
774
+ return graph
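
For readers skimming the diff, the trickiest piece above is probably the `capture_run_messages` / `get_captured_run_messages` handshake: the context manager publishes a `_RunMessages` holder in a `ContextVar`, and the first run inside the context claims it by flipping `used`, which is why only the first call's messages are captured. The following is a minimal, self-contained sketch of that pattern; the names (`capture_messages`, `fake_run`, `_CapturedMessages`) are illustrative only and are not code from the package:

```python
from __future__ import annotations

from collections.abc import Iterator
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field


@dataclass
class _CapturedMessages:
    # stand-in for _RunMessages above; the real messages are ModelMessage objects
    messages: list[str] = field(default_factory=list)
    used: bool = False


_ctx: ContextVar[_CapturedMessages] = ContextVar('_ctx')


@contextmanager
def capture_messages() -> Iterator[list[str]]:
    """Simplified capture_run_messages: publish a fresh holder for the duration of the block."""
    token = _ctx.set(_CapturedMessages())
    try:
        yield _ctx.get().messages
    finally:
        _ctx.reset(token)


def fake_run(prompt: str) -> None:
    """Stand-in for an agent run: only the first run inside the context fills the captured list."""
    try:
        captured = _ctx.get()
    except LookupError:
        messages: list[str] = []  # no capture context active
    else:
        if captured.used:
            messages = []  # a previous run already claimed the captured list
        else:
            messages = captured.messages
            captured.used = True
    messages.append(f'request: {prompt}')
    messages.append('response: ...')


with capture_messages() as messages:
    fake_run('first')
    fake_run('second')

print(messages)
#> ['request: first', 'response: ...']
```

Using a `ContextVar` rather than a module-level global keeps the captured list isolated per async task and thread, which matters because agent runs may execute concurrently.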