pydantic-ai-slim 0.0.21__py3-none-any.whl → 0.0.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this release of pydantic-ai-slim has been flagged as potentially problematic.
- pydantic_ai/_agent_graph.py +770 -0
- pydantic_ai/agent.py +182 -554
- pydantic_ai/models/__init__.py +4 -0
- pydantic_ai/models/gemini.py +7 -1
- pydantic_ai/models/openai.py +6 -1
- pydantic_ai/settings.py +5 -0
- {pydantic_ai_slim-0.0.21.dist-info → pydantic_ai_slim-0.0.22.dist-info}/METADATA +2 -3
- {pydantic_ai_slim-0.0.21.dist-info → pydantic_ai_slim-0.0.22.dist-info}/RECORD +9 -8
- {pydantic_ai_slim-0.0.21.dist-info → pydantic_ai_slim-0.0.22.dist-info}/WHEEL +0 -0
pydantic_ai/agent.py  CHANGED
@@ -5,14 +5,17 @@ import dataclasses
 import inspect
 from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
 from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
-from contextvars import ContextVar
 from types import FrameType
-from typing import Any, Callable, Generic,
+from typing import Any, Callable, Generic, cast, final, overload
 
 import logfire_api
-from typing_extensions import TypeVar,
+from typing_extensions import TypeVar, deprecated
+
+from pydantic_graph import Graph, GraphRunContext, HistoryStep
+from pydantic_graph.nodes import End
 
 from . import (
+    _agent_graph,
     _result,
     _system_prompt,
     _utils,
@@ -22,6 +25,7 @@ from . import (
     result,
     usage as _usage,
 )
+from ._agent_graph import EndStrategy, capture_run_messages  # imported for re-export
 from .result import ResultDataT
 from .settings import ModelSettings, merge_model_settings
 from .tools import (
@@ -29,7 +33,6 @@ from .tools import (
     DocstringFormat,
     RunContext,
     Tool,
-    ToolDefinition,
     ToolFuncContext,
     ToolFuncEither,
     ToolFuncPlain,
@@ -52,14 +55,7 @@ else:
     logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
 
 T = TypeVar('T')
-"""An invariant TypeVar."""
 NoneType = type(None)
-EndStrategy = Literal['early', 'exhaustive']
-"""The strategy for handling multiple tool calls when a final result is found.
-
-- `'early'`: Stop processing other tool calls once a final result is found
-- `'exhaustive'`: Process all tool calls even after finding a final result
-"""
 RunResultDataT = TypeVar('RunResultDataT')
 """Type variable for the result data of a run where `result_type` was customized on the run call."""
 
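The hunks above remove the module-level `EndStrategy` alias (and its docstring) from `agent.py` and instead re-export it, together with `capture_run_messages`, from the new `._agent_graph` module. A minimal sketch of how downstream code can keep importing it; the import path is taken from the re-export line in the diff, the variable is purely illustrative:

```python
# EndStrategy is now defined in pydantic_ai._agent_graph and re-exported from
# pydantic_ai.agent, so existing imports keep working.
from pydantic_ai.agent import EndStrategy

strategy: EndStrategy = 'early'  # or 'exhaustive', per the removed docstring
```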
@@ -104,18 +100,24 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     Note, if `model_settings` is provided by `run`, `run_sync`, or `run_stream`, those settings will
     be merged with this value, with the runtime argument taking priority.
     """
+
+    result_type: type[ResultDataT] = dataclasses.field(repr=False)
+    """
+    The type of the result data, used to validate the result data, defaults to `str`.
+    """
+
+    _deps_type: type[AgentDepsT] = dataclasses.field(repr=False)
     _result_tool_name: str = dataclasses.field(repr=False)
     _result_tool_description: str | None = dataclasses.field(repr=False)
     _result_schema: _result.ResultSchema[ResultDataT] | None = dataclasses.field(repr=False)
     _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]] = dataclasses.field(repr=False)
     _system_prompts: tuple[str, ...] = dataclasses.field(repr=False)
-    _function_tools: dict[str, Tool[AgentDepsT]] = dataclasses.field(repr=False)
-    _default_retries: int = dataclasses.field(repr=False)
     _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False)
     _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(
         repr=False
     )
-
+    _function_tools: dict[str, Tool[AgentDepsT]] = dataclasses.field(repr=False)
+    _default_retries: int = dataclasses.field(repr=False)
     _max_result_retries: int = dataclasses.field(repr=False)
     _override_deps: _utils.Option[AgentDepsT] = dataclasses.field(default=None, repr=False)
     _override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False)
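`result_type` is now exposed as a public (repr-suppressed) dataclass field alongside the private `_deps_type`, so the type used to validate run results can be read back from the agent. A small sketch, assuming the constructor assignment shown in the next hunk; the `int` result type is just an example:

```python
from pydantic_ai import Agent

# result_type is set in __init__ (see the following hunk) and stored as a field.
agent = Agent('test', result_type=int)
assert agent.result_type is int
```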
@@ -174,25 +176,30 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         self.end_strategy = end_strategy
         self.name = name
         self.model_settings = model_settings
+        self.result_type = result_type
+
+        self._deps_type = deps_type
+
         self._result_tool_name = result_tool_name
         self._result_tool_description = result_tool_description
-        self._result_schema = _result.ResultSchema[result_type].build(
+        self._result_schema: _result.ResultSchema[ResultDataT] | None = _result.ResultSchema[result_type].build(
             result_type, result_tool_name, result_tool_description
         )
+        self._result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]] = []
 
         self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
-        self.
+        self._system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = []
+        self._system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = {}
+
+        self._function_tools: dict[str, Tool[AgentDepsT]] = {}
+
         self._default_retries = retries
+        self._max_result_retries = result_retries if result_retries is not None else retries
         for tool in tools:
             if isinstance(tool, Tool):
                 self._register_tool(tool)
             else:
                 self._register_tool(Tool(tool))
-        self._deps_type = deps_type
-        self._system_prompt_functions = []
-        self._system_prompt_dynamic_functions = {}
-        self._max_result_retries = result_retries if result_retries is not None else retries
-        self._result_validators = []
 
     @overload
     async def run(
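The constructor still wraps plain callables in `Tool` before registering them, as the `for tool in tools:` loop above shows; only the initialisation order and type annotations changed. A hedged usage sketch — the `roll_die` helper is made up for illustration:

```python
import random

from pydantic_ai import Agent


def roll_die() -> int:
    """Roll a six-sided die."""  # the docstring is used for the tool description
    return random.randint(1, 6)


# A bare function is wrapped via Tool(roll_die) inside __init__, exactly as in
# the loop shown above; a Tool instance would be registered directly.
agent = Agent('test', tools=[roll_die])
```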
@@ -272,66 +279,80 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
 
         deps = self._get_deps(deps)
         new_message_index = len(message_history) if message_history else 0
-        result_schema = self._prepare_result_schema(result_type)
+        result_schema: _result.ResultSchema[RunResultDataT] | None = self._prepare_result_schema(result_type)
+
+        # Build the graph
+        graph = self._build_graph(result_type)
+
+        # Build the initial state
+        state = _agent_graph.GraphAgentState(
+            message_history=message_history[:] if message_history else [],
+            usage=usage or _usage.Usage(),
+            retries=0,
+            run_step=0,
+        )
+
+        # We consider it a user error if a user tries to restrict the result type while having a result validator that
+        # may change the result type from the restricted type to something else. Therefore, we consider the following
+        # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code.
+        result_validators = cast(list[_result.ResultValidator[AgentDepsT, RunResultDataT]], self._result_validators)
+
+        # TODO: Instead of this, copy the function tools to ensure they don't share current_retry state between agent
+        # runs. Requires some changes to `Tool` to make them copyable though.
+        for v in self._function_tools.values():
+            v.current_retry = 0
+
+        model_settings = merge_model_settings(self.model_settings, model_settings)
+        usage_limits = usage_limits or _usage.UsageLimits()
 
         with _logfire.span(
             '{agent_name} run {prompt=}',
             prompt=user_prompt,
             agent=self,
-            model_name=model_used.name(),
+            model_name=model_used.name() if model_used else 'no-model',
             agent_name=self.name or 'agent',
         ) as run_span:
-            [old lines 284-300 not shown in this extract]
-                with _logfire.span('model request', run_step=run_context.run_step) as model_req_span:
-                    model_response, request_usage = await agent_model.request(messages, model_settings)
-                    model_req_span.set_attribute('response', model_response)
-                    model_req_span.set_attribute('usage', request_usage)
+            # Build the deps object for the graph
+            graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT](
+                user_deps=deps,
+                prompt=user_prompt,
+                new_message_index=new_message_index,
+                model=model_used,
+                model_settings=model_settings,
+                usage_limits=usage_limits,
+                max_result_retries=self._max_result_retries,
+                end_strategy=self.end_strategy,
+                result_schema=result_schema,
+                result_tools=self._result_schema.tool_defs() if self._result_schema else [],
+                result_validators=result_validators,
+                function_tools=self._function_tools,
+                run_span=run_span,
+            )
 
-            [old lines 306-308 not shown in this extract]
+            start_node = _agent_graph.UserPromptNode[AgentDepsT](
+                user_prompt=user_prompt,
+                system_prompts=self._system_prompts,
+                system_prompt_functions=self._system_prompt_functions,
+                system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
+            )
 
-            [old lines 310-313 not shown in this extract]
+            # Actually run
+            end_result, _ = await graph.run(
+                start_node,
+                state=state,
+                deps=graph_deps,
+                infer_name=False,
+            )
 
-            [old lines 315-323 not shown in this extract]
-                        run_span.set_attribute('usage', run_context.usage)
-                        handle_span.set_attribute('result', result_data)
-                        handle_span.message = 'handle model response -> final result'
-                        return result.RunResult(
-                            messages, new_message_index, result_data, result_tool_name, run_context.usage
-                        )
-                    else:
-                        # continue the conversation
-                        handle_span.set_attribute('tool_responses', tool_responses)
-                        tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
-                        handle_span.message = f'handle model response -> {tool_responses_str}'
+            # Build final run result
+            # We don't do any advanced checking if the data is actually from a final result or not
+            return result.RunResult(
+                state.message_history,
+                new_message_index,
+                end_result.data,
+                end_result.tool_name,
+                state.usage,
+            )
 
     @overload
     def run_sync(
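After this hunk, `run` is a thin wrapper: it builds a `pydantic_graph` graph, seeds `GraphAgentState` and `GraphAgentDeps`, drives it with `graph.run(...)`, and packages the `End` data into the same `result.RunResult` as before. From the caller's side nothing changes; a minimal sketch of the unchanged public surface, using the built-in `'test'` model so no provider is assumed:

```python
import asyncio

from pydantic_ai import Agent

agent = Agent('test', system_prompt='Be concise.')


async def main() -> None:
    # run() still returns a RunResult with `.data`, even though it now
    # executes a graph internally rather than an inline request loop.
    result = await agent.run('What is the capital of France?')
    print(result.data)


asyncio.run(main())
```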
@@ -503,7 +524,31 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
 
         deps = self._get_deps(deps)
         new_message_index = len(message_history) if message_history else 0
-        result_schema = self._prepare_result_schema(result_type)
+        result_schema: _result.ResultSchema[RunResultDataT] | None = self._prepare_result_schema(result_type)
+
+        # Build the graph
+        graph = self._build_stream_graph(result_type)
+
+        # Build the initial state
+        graph_state = _agent_graph.GraphAgentState(
+            message_history=message_history[:] if message_history else [],
+            usage=usage or _usage.Usage(),
+            retries=0,
+            run_step=0,
+        )
+
+        # We consider it a user error if a user tries to restrict the result type while having a result validator that
+        # may change the result type from the restricted type to something else. Therefore, we consider the following
+        # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code.
+        result_validators = cast(list[_result.ResultValidator[AgentDepsT, RunResultDataT]], self._result_validators)
+
+        # TODO: Instead of this, copy the function tools to ensure they don't share current_retry state between agent
+        # runs. Requires some changes to `Tool` to make them copyable though.
+        for v in self._function_tools.values():
+            v.current_retry = 0
+
+        model_settings = merge_model_settings(self.model_settings, model_settings)
+        usage_limits = usage_limits or _usage.UsageLimits()
 
         with _logfire.span(
             '{agent_name} run stream {prompt=}',
@@ -512,97 +557,53 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
             model_name=model_used.name(),
             agent_name=self.name or 'agent',
         ) as run_span:
-            [old lines 515-520 not shown in this extract]
+            # Build the deps object for the graph
+            graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT](
+                user_deps=deps,
+                prompt=user_prompt,
+                new_message_index=new_message_index,
+                model=model_used,
+                model_settings=model_settings,
+                usage_limits=usage_limits,
+                max_result_retries=self._max_result_retries,
+                end_strategy=self.end_strategy,
+                result_schema=result_schema,
+                result_tools=self._result_schema.tool_defs() if self._result_schema else [],
+                result_validators=result_validators,
+                function_tools=self._function_tools,
+                run_span=run_span,
+            )
 
-            [old lines 522-523 not shown in this extract]
+            start_node = _agent_graph.StreamUserPromptNode[AgentDepsT](
+                user_prompt=user_prompt,
+                system_prompts=self._system_prompts,
+                system_prompt_functions=self._system_prompt_functions,
+                system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
+            )
 
+            # Actually run
+            node = start_node
+            history: list[HistoryStep[_agent_graph.GraphAgentState, RunResultDataT]] = []
             while True:
-                [old lines 526-532 not shown in this extract]
-                    async with
-                        [old lines 534-544 not shown in this extract]
-                            # Check if we got a final result
-                            if isinstance(maybe_final_result, _MarkFinalResult):
-                                result_stream = maybe_final_result.data
-                                result_tool_name = maybe_final_result.tool_name
-                                handle_span.message = 'handle model response -> final result'
-
-                                async def on_complete():
-                                    """Called when the stream has completed.
-
-                                    The model response will have been added to messages by now
-                                    by `StreamedRunResult._marked_completed`.
-                                    """
-                                    last_message = messages[-1]
-                                    assert isinstance(last_message, _messages.ModelResponse)
-                                    tool_calls = [
-                                        part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)
-                                    ]
-                                    parts = await self._process_function_tools(
-                                        tool_calls, result_tool_name, run_context, result_schema
-                                    )
-                                    if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
-                                        self._incr_result_retry(run_context)
-                                    if parts:
-                                        messages.append(_messages.ModelRequest(parts))
-                                    run_span.set_attribute('all_messages', messages)
-
-                                # The following is not guaranteed to be true, but we consider it a user error if
-                                # there are result validators that might convert the result data from an overridden
-                                # `result_type` to a type that is not valid as such.
-                                result_validators = cast(
-                                    list[_result.ResultValidator[AgentDepsT, RunResultDataT]], self._result_validators
-                                )
-
-                                yield result.StreamedRunResult(
-                                    messages,
-                                    new_message_index,
-                                    usage_limits,
-                                    result_stream,
-                                    result_schema,
-                                    run_context,
-                                    result_validators,
-                                    result_tool_name,
-                                    on_complete,
-                                )
-                                return
-                            else:
-                                # continue the conversation
-                                model_response_msg, tool_responses = maybe_final_result
-                                # if we got a model response add that to messages
-                                messages.append(model_response_msg)
-                                if tool_responses:
-                                    # if we got one or more tool response parts, add a model request message
-                                    messages.append(_messages.ModelRequest(tool_responses))
-
-                                handle_span.set_attribute('tool_responses', tool_responses)
-                                tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
-                                handle_span.message = f'handle model response -> {tool_responses_str}'
-                                # the model_response should have been fully streamed by now, we can add its usage
-                                model_response_usage = model_response.usage()
-                                run_context.usage.incr(model_response_usage)
-                                usage_limits.check_tokens(run_context.usage)
+                if isinstance(node, _agent_graph.StreamModelRequestNode):
+                    node = cast(
+                        _agent_graph.StreamModelRequestNode[
+                            AgentDepsT, result.StreamedRunResult[AgentDepsT, RunResultDataT]
+                        ],
+                        node,
+                    )
+                    async with node.run_to_result(GraphRunContext(graph_state, graph_deps)) as r:
+                        if isinstance(r, End):
+                            yield r.data
+                            break
+                assert not isinstance(node, End)  # the previous line should be hit first
+                node = await graph.next(
+                    node,
+                    history,
+                    state=graph_state,
+                    deps=graph_deps,
+                    infer_name=False,
+                )
 
     @contextmanager
     def override(
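With this hunk, `run_stream` steps the stream graph manually: it starts at `StreamUserPromptNode`, and whenever it reaches a `StreamModelRequestNode` it enters `run_to_result(...)` and yields the `StreamedRunResult` carried by the `End` node. The caller-facing shape stays an async context manager; a hedged sketch, again using the `'test'` model so nothing provider-specific is assumed:

```python
import asyncio

from pydantic_ai import Agent

agent = Agent('test')


async def main() -> None:
    # Externally unchanged: run_stream() is still an async context manager
    # yielding a StreamedRunResult, even though it now walks graph nodes.
    async with agent.run_stream('Tell me a joke.') as response:
        async for text in response.stream_text():
            print(text)


asyncio.run(main())
```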
@@ -718,7 +719,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
             return decorator
         else:
             assert not dynamic, "dynamic can't be True in this case"
-            self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func, dynamic=dynamic))
+            self._system_prompt_functions.append(_system_prompt.SystemPromptRunner[AgentDepsT](func, dynamic=dynamic))
             return func
 
     @overload
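The only change here is typing: the runner is now constructed as `SystemPromptRunner[AgentDepsT](...)` so the deps type parameter is carried through. Decorator behaviour is unchanged; for reference, a small sketch of the non-dynamic path this branch handles (the prompt function and deps value are illustrative):

```python
from pydantic_ai import Agent, RunContext

agent = Agent('test', deps_type=str)


@agent.system_prompt
def add_user_name(ctx: RunContext[str]) -> str:
    # Registered through the branch above (dynamic=False), now with the
    # AgentDepsT type parameter made explicit on SystemPromptRunner.
    return f"The user's name is {ctx.deps!r}."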
@@ -998,335 +999,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
 
         return model_
 
-    [old line 1001 not shown in this extract]
-        self, run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[RunResultDataT] | None
-    ) -> models.AgentModel:
-        """Build tools and create an agent model."""
-        function_tools: list[ToolDefinition] = []
-
-        async def add_tool(tool: Tool[AgentDepsT]) -> None:
-            ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
-            if tool_def := await tool.prepare_tool_def(ctx):
-                function_tools.append(tool_def)
-
-        await asyncio.gather(*map(add_tool, self._function_tools.values()))
-
-        return await run_context.model.agent_model(
-            function_tools=function_tools,
-            allow_text_result=self._allow_text_result(result_schema),
-            result_tools=result_schema.tool_defs() if result_schema is not None else [],
-        )
-
-    async def _reevaluate_dynamic_prompts(
-        self, messages: list[_messages.ModelMessage], run_context: RunContext[AgentDepsT]
-    ) -> None:
-        """Reevaluate any `SystemPromptPart` with dynamic_ref in the provided messages by running the associated runner function."""
-        # Only proceed if there's at least one dynamic runner.
-        if self._system_prompt_dynamic_functions:
-            for msg in messages:
-                if isinstance(msg, _messages.ModelRequest):
-                    for i, part in enumerate(msg.parts):
-                        if isinstance(part, _messages.SystemPromptPart) and part.dynamic_ref:
-                            # Look up the runner by its ref
-                            if runner := self._system_prompt_dynamic_functions.get(part.dynamic_ref):
-                                updated_part_content = await runner.run(run_context)
-                                msg.parts[i] = _messages.SystemPromptPart(
-                                    updated_part_content, dynamic_ref=part.dynamic_ref
-                                )
-
-    def _prepare_result_schema(
-        self, result_type: type[RunResultDataT] | None
-    ) -> _result.ResultSchema[RunResultDataT] | None:
-        if result_type is not None:
-            if self._result_validators:
-                raise exceptions.UserError('Cannot set a custom run `result_type` when the agent has result validators')
-            return _result.ResultSchema[result_type].build(
-                result_type, self._result_tool_name, self._result_tool_description
-            )
-        else:
-            return self._result_schema  # pyright: ignore[reportReturnType]
-
-    async def _prepare_messages(
-        self,
-        user_prompt: str,
-        message_history: list[_messages.ModelMessage] | None,
-        run_context: RunContext[AgentDepsT],
-    ) -> list[_messages.ModelMessage]:
-        try:
-            ctx_messages = get_captured_run_messages()
-        except LookupError:
-            messages: list[_messages.ModelMessage] = []
-        else:
-            if ctx_messages.used:
-                messages = []
-            else:
-                messages = ctx_messages.messages
-                ctx_messages.used = True
-
-        if message_history:
-            # Shallow copy messages
-            messages.extend(message_history)
-            # Reevaluate any dynamic system prompt parts
-            await self._reevaluate_dynamic_prompts(messages, run_context)
-            messages.append(_messages.ModelRequest([_messages.UserPromptPart(user_prompt)]))
-        else:
-            parts = await self._sys_parts(run_context)
-            parts.append(_messages.UserPromptPart(user_prompt))
-            messages.append(_messages.ModelRequest(parts))
-
-        return messages
-
-    async def _handle_model_response(
-        self,
-        model_response: _messages.ModelResponse,
-        run_context: RunContext[AgentDepsT],
-        result_schema: _result.ResultSchema[RunResultDataT] | None,
-    ) -> tuple[_MarkFinalResult[RunResultDataT] | None, list[_messages.ModelRequestPart]]:
-        """Process a non-streamed response from the model.
-
-        Returns:
-            A tuple of `(final_result, request parts)`. If `final_result` is not `None`, the conversation should end.
-        """
-        texts: list[str] = []
-        tool_calls: list[_messages.ToolCallPart] = []
-        for part in model_response.parts:
-            if isinstance(part, _messages.TextPart):
-                # ignore empty content for text parts, see #437
-                if part.content:
-                    texts.append(part.content)
-            else:
-                tool_calls.append(part)
-
-        # At the moment, we prioritize at least executing tool calls if they are present.
-        # In the future, we'd consider making this configurable at the agent or run level.
-        # This accounts for cases like anthropic returns that might contain a text response
-        # and a tool call response, where the text response just indicates the tool call will happen.
-        if tool_calls:
-            return await self._handle_structured_response(tool_calls, run_context, result_schema)
-        elif texts:
-            text = '\n\n'.join(texts)
-            return await self._handle_text_response(text, run_context, result_schema)
-        else:
-            raise exceptions.UnexpectedModelBehavior('Received empty model response')
-
-    async def _handle_text_response(
-        self, text: str, run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[RunResultDataT] | None
-    ) -> tuple[_MarkFinalResult[RunResultDataT] | None, list[_messages.ModelRequestPart]]:
-        """Handle a plain text response from the model for non-streaming responses."""
-        if self._allow_text_result(result_schema):
-            result_data_input = cast(RunResultDataT, text)
-            try:
-                result_data = await self._validate_result(result_data_input, run_context, None)
-            except _result.ToolRetryError as e:
-                self._incr_result_retry(run_context)
-                return None, [e.tool_retry]
-            else:
-                return _MarkFinalResult(result_data, None), []
-        else:
-            self._incr_result_retry(run_context)
-            response = _messages.RetryPromptPart(
-                content='Plain text responses are not permitted, please call one of the functions instead.',
-            )
-            return None, [response]
-
-    async def _handle_structured_response(
-        self,
-        tool_calls: list[_messages.ToolCallPart],
-        run_context: RunContext[AgentDepsT],
-        result_schema: _result.ResultSchema[RunResultDataT] | None,
-    ) -> tuple[_MarkFinalResult[RunResultDataT] | None, list[_messages.ModelRequestPart]]:
-        """Handle a structured response containing tool calls from the model for non-streaming responses."""
-        assert tool_calls, 'Expected at least one tool call'
-
-        # first look for the result tool call
-        final_result: _MarkFinalResult[RunResultDataT] | None = None
-
-        parts: list[_messages.ModelRequestPart] = []
-        if result_schema is not None:
-            if match := result_schema.find_tool(tool_calls):
-                call, result_tool = match
-                try:
-                    result_data = result_tool.validate(call)
-                    result_data = await self._validate_result(result_data, run_context, call)
-                except _result.ToolRetryError as e:
-                    parts.append(e.tool_retry)
-                else:
-                    final_result = _MarkFinalResult(result_data, call.tool_name)
-
-        # Then build the other request parts based on end strategy
-        parts += await self._process_function_tools(
-            tool_calls, final_result and final_result.tool_name, run_context, result_schema
-        )
-
-        if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
-            self._incr_result_retry(run_context)
-
-        return final_result, parts
-
-    async def _process_function_tools(
-        self,
-        tool_calls: list[_messages.ToolCallPart],
-        result_tool_name: str | None,
-        run_context: RunContext[AgentDepsT],
-        result_schema: _result.ResultSchema[RunResultDataT] | None,
-    ) -> list[_messages.ModelRequestPart]:
-        """Process function (non-result) tool calls in parallel.
-
-        Also add stub return parts for any other tools that need it.
-        """
-        parts: list[_messages.ModelRequestPart] = []
-        tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
-
-        stub_function_tools = bool(result_tool_name) and self.end_strategy == 'early'
-
-        # we rely on the fact that if we found a result, it's the first result tool in the last
-        found_used_result_tool = False
-        for call in tool_calls:
-            if call.tool_name == result_tool_name and not found_used_result_tool:
-                found_used_result_tool = True
-                parts.append(
-                    _messages.ToolReturnPart(
-                        tool_name=call.tool_name,
-                        content='Final result processed.',
-                        tool_call_id=call.tool_call_id,
-                    )
-                )
-            elif tool := self._function_tools.get(call.tool_name):
-                if stub_function_tools:
-                    parts.append(
-                        _messages.ToolReturnPart(
-                            tool_name=call.tool_name,
-                            content='Tool not executed - a final result was already processed.',
-                            tool_call_id=call.tool_call_id,
-                        )
-                    )
-                else:
-                    tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
-            elif result_schema is not None and call.tool_name in result_schema.tools:
-                # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
-                # validation, we don't add another part here
-                if result_tool_name is not None:
-                    parts.append(
-                        _messages.ToolReturnPart(
-                            tool_name=call.tool_name,
-                            content='Result tool not used - a final result was already processed.',
-                            tool_call_id=call.tool_call_id,
-                        )
-                    )
-            else:
-                parts.append(self._unknown_tool(call.tool_name, result_schema))
-
-        # Run all tool tasks in parallel
-        if tasks:
-            with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
-                task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
-            parts.extend(task_results)
-        return parts
-
-    async def _handle_streamed_response(
-        self,
-        streamed_response: models.StreamedResponse,
-        run_context: RunContext[AgentDepsT],
-        result_schema: _result.ResultSchema[RunResultDataT] | None,
-    ) -> _MarkFinalResult[models.StreamedResponse] | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]:
-        """Process a streamed response from the model.
-
-        Returns:
-            Either a final result or a tuple of the model response and the tool responses for the next request.
-            If a final result is returned, the conversation should end.
-        """
-        received_text = False
-
-        async for maybe_part_event in streamed_response:
-            if isinstance(maybe_part_event, _messages.PartStartEvent):
-                new_part = maybe_part_event.part
-                if isinstance(new_part, _messages.TextPart):
-                    received_text = True
-                    if self._allow_text_result(result_schema):
-                        return _MarkFinalResult(streamed_response, None)
-                elif isinstance(new_part, _messages.ToolCallPart):
-                    if result_schema is not None and (match := result_schema.find_tool([new_part])):
-                        call, _ = match
-                        return _MarkFinalResult(streamed_response, call.tool_name)
-                else:
-                    assert_never(new_part)
-
-        tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
-        parts: list[_messages.ModelRequestPart] = []
-        model_response = streamed_response.get()
-        if not model_response.parts:
-            raise exceptions.UnexpectedModelBehavior('Received empty model response')
-        for p in model_response.parts:
-            if isinstance(p, _messages.ToolCallPart):
-                if tool := self._function_tools.get(p.tool_name):
-                    tasks.append(asyncio.create_task(tool.run(p, run_context), name=p.tool_name))
-                else:
-                    parts.append(self._unknown_tool(p.tool_name, result_schema))
-
-        if received_text and not tasks and not parts:
-            # Can only get here if self._allow_text_result returns `False` for the provided result_schema
-            self._incr_result_retry(run_context)
-            model_response = _messages.RetryPromptPart(
-                content='Plain text responses are not permitted, please call one of the functions instead.',
-            )
-            return streamed_response.get(), [model_response]
-
-        with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
-            task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
-        parts.extend(task_results)
-
-        if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
-            self._incr_result_retry(run_context)
-
-        return model_response, parts
-
-    async def _validate_result(
-        self,
-        result_data: RunResultDataT,
-        run_context: RunContext[AgentDepsT],
-        tool_call: _messages.ToolCallPart | None,
-    ) -> RunResultDataT:
-        if self._result_validators:
-            agent_result_data = cast(ResultDataT, result_data)
-            for validator in self._result_validators:
-                agent_result_data = await validator.validate(agent_result_data, tool_call, run_context)
-            return cast(RunResultDataT, agent_result_data)
-        else:
-            return result_data
-
-    def _incr_result_retry(self, run_context: RunContext[AgentDepsT]) -> None:
-        run_context.retry += 1
-        if run_context.retry > self._max_result_retries:
-            raise exceptions.UnexpectedModelBehavior(
-                f'Exceeded maximum retries ({self._max_result_retries}) for result validation'
-            )
-
-    async def _sys_parts(self, run_context: RunContext[AgentDepsT]) -> list[_messages.ModelRequestPart]:
-        """Build the initial messages for the conversation."""
-        messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self._system_prompts]
-        for sys_prompt_runner in self._system_prompt_functions:
-            prompt = await sys_prompt_runner.run(run_context)
-            if sys_prompt_runner.dynamic:
-                messages.append(_messages.SystemPromptPart(prompt, dynamic_ref=sys_prompt_runner.function.__qualname__))
-            else:
-                messages.append(_messages.SystemPromptPart(prompt))
-        return messages
-
-    def _unknown_tool(
-        self,
-        tool_name: str,
-        result_schema: _result.ResultSchema[RunResultDataT] | None,
-    ) -> _messages.RetryPromptPart:
-        names = list(self._function_tools.keys())
-        if result_schema:
-            names.extend(result_schema.tool_names())
-        if names:
-            msg = f'Available tools: {", ".join(names)}'
-        else:
-            msg = 'No tools available.'
-        return _messages.RetryPromptPart(content=f'Unknown tool name: {tool_name!r}. {msg}')
-
-    def _get_deps(self: Agent[T, Any], deps: T) -> T:
+    def _get_deps(self: Agent[T, ResultDataT], deps: T) -> T:
         """Get deps for a run.
 
         If we've overridden deps via `_override_deps`, use that, otherwise use the deps passed to the call.
@@ -1357,10 +1030,6 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
                         self.name = name
                         return
 
-    @staticmethod
-    def _allow_text_result(result_schema: _result.ResultSchema[RunResultDataT] | None) -> bool:
-        return result_schema is None or result_schema.allow_text_result
-
     @property
     @deprecated(
         'The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.', category=None
@@ -1368,65 +1037,24 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     def last_run_messages(self) -> list[_messages.ModelMessage]:
         raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.')
 
+    def _build_graph(
+        self, result_type: type[RunResultDataT] | None
+    ) -> Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], Any]:
+        return _agent_graph.build_agent_graph(self.name, self._deps_type, result_type or self.result_type)
 
-[old lines 1372-1377 not shown in this extract]
-_messages_ctx_var: ContextVar[_RunMessages] = ContextVar('var')
-
-
-@contextmanager
-def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
-    """Context manager to access the messages used in a [`run`][pydantic_ai.Agent.run], [`run_sync`][pydantic_ai.Agent.run_sync], or [`run_stream`][pydantic_ai.Agent.run_stream] call.
-
-    Useful when a run may raise an exception, see [model errors](../agents.md#model-errors) for more information.
-
-    Examples:
-    ```python
-    from pydantic_ai import Agent, capture_run_messages
-
-    agent = Agent('test')
-
-    with capture_run_messages() as messages:
-        try:
-            result = agent.run_sync('foobar')
-        except Exception:
-            print(messages)
-            raise
-    ```
-
-    !!! note
-        If you call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context,
-        `messages` will represent the messages exchanged during the first call only.
-    """
-    try:
-        yield _messages_ctx_var.get().messages
-    except LookupError:
-        messages: list[_messages.ModelMessage] = []
-        token = _messages_ctx_var.set(_RunMessages(messages))
-        try:
-            yield messages
-        finally:
-            _messages_ctx_var.reset(token)
-
-
-def get_captured_run_messages() -> _RunMessages:
-    return _messages_ctx_var.get()
-
-
-@dataclasses.dataclass
-class _MarkFinalResult(Generic[ResultDataT]):
-    """Marker class to indicate that the result is the final result.
-
-    This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultDataT` directly.
-
-    It also avoids problems in the case where the result type is itself `None`, but is set.
-    """
+    def _build_stream_graph(
+        self, result_type: type[RunResultDataT] | None
+    ) -> Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], Any]:
+        return _agent_graph.build_agent_stream_graph(self.name, self._deps_type, result_type or self.result_type)
 
-[old lines 1429-1432 not shown in this extract]
+    def _prepare_result_schema(
+        self, result_type: type[RunResultDataT] | None
+    ) -> _result.ResultSchema[RunResultDataT] | None:
+        if result_type is not None:
+            if self._result_validators:
+                raise exceptions.UserError('Cannot set a custom run `result_type` when the agent has result validators')
+            return _result.ResultSchema[result_type].build(
+                result_type, self._result_tool_name, self._result_tool_description
+            )
+        else:
+            return self._result_schema  # pyright: ignore[reportReturnType]
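This final hunk removes `capture_run_messages`, `get_captured_run_messages`, `_messages_ctx_var`, and `_MarkFinalResult` from `agent.py` and adds the graph-building helpers in their place; the import hunk near the top re-exports `capture_run_messages` from `._agent_graph`, so the usage shown in the removed docstring still applies unchanged:

```python
from pydantic_ai import Agent, capture_run_messages

agent = Agent('test')

# Same example as the removed docstring above; only the defining module moved.
with capture_run_messages() as messages:
    try:
        result = agent.run_sync('foobar')
    except Exception:
        print(messages)
        raise
```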