pydantic-ai-slim 0.0.24__py3-none-any.whl → 0.0.26__py3-none-any.whl
This diff shows the changes between publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- pydantic_ai/__init__.py +16 -4
- pydantic_ai/_agent_graph.py +264 -351
- pydantic_ai/_utils.py +1 -1
- pydantic_ai/agent.py +581 -156
- pydantic_ai/messages.py +121 -1
- pydantic_ai/models/__init__.py +12 -1
- pydantic_ai/models/anthropic.py +67 -50
- pydantic_ai/models/cohere.py +5 -2
- pydantic_ai/models/function.py +15 -6
- pydantic_ai/models/gemini.py +73 -5
- pydantic_ai/models/groq.py +35 -8
- pydantic_ai/models/instrumented.py +225 -0
- pydantic_ai/models/mistral.py +29 -4
- pydantic_ai/models/openai.py +59 -13
- pydantic_ai/models/test.py +6 -6
- pydantic_ai/models/wrapper.py +45 -0
- pydantic_ai/result.py +106 -144
- pydantic_ai/tools.py +2 -2
- {pydantic_ai_slim-0.0.24.dist-info → pydantic_ai_slim-0.0.26.dist-info}/METADATA +2 -2
- pydantic_ai_slim-0.0.26.dist-info/RECORD +32 -0
- pydantic_ai_slim-0.0.24.dist-info/RECORD +0 -30
- {pydantic_ai_slim-0.0.24.dist-info → pydantic_ai_slim-0.0.26.dist-info}/WHEEL +0 -0
pydantic_ai/_agent_graph.py
CHANGED
@@ -32,6 +32,16 @@ from .tools import (
     ToolDefinition,
 )

+__all__ = (
+    'GraphAgentState',
+    'GraphAgentDeps',
+    'UserPromptNode',
+    'ModelRequestNode',
+    'HandleResponseNode',
+    'build_run_context',
+    'capture_run_messages',
+)
+
 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')

 # while waiting for https://github.com/pydantic/logfire/issues/745
@@ -56,21 +66,6 @@ DepsT = TypeVar('DepsT')
 ResultT = TypeVar('ResultT')


-@dataclasses.dataclass
-class MarkFinalResult(Generic[ResultDataT]):
-    """Marker class to indicate that the result is the final result.
-
-    This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultDataT` directly.
-
-    It also avoids problems in the case where the result type is itself `None`, but is set.
-    """
-
-    data: ResultDataT
-    """The final result data."""
-    tool_name: str | None
-    """Name of the final result tool, None if the result is a string."""
-
-
 @dataclasses.dataclass
 class GraphAgentState:
     """State kept across the execution of the agent graph."""
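The removed `MarkFinalResult` docstring above spells out why the final value was wrapped rather than returned directly: wrapping makes completion detectable with `isinstance`, even when the result data is legitimately `None`. A minimal standalone sketch of that wrapper pattern (illustrative names only, not the library's actual types):

    from dataclasses import dataclass
    from typing import Generic, Optional, TypeVar

    T = TypeVar('T')


    @dataclass
    class FinalValue(Generic[T]):
        """Marker wrapper: its presence means a final result exists, even if its data is None."""

        data: T

    # A run whose final result is genuinely None is still distinguishable from "no result yet".
    maybe_done: Optional[FinalValue[Optional[str]]] = FinalValue(None)
    if maybe_done is None:
        print('no final result yet')
    elif isinstance(maybe_done, FinalValue):
        print('final result produced, data =', maybe_done.data)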
@@ -94,7 +89,7 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):

     user_deps: DepsT

-    prompt: str
+    prompt: str | Sequence[_messages.UserContent]
     new_message_index: int

     model: models.Model
@@ -113,17 +108,22 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):


 @dataclasses.dataclass
-class BaseUserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC):
-    user_prompt: str
+class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]], ABC):
+    user_prompt: str | Sequence[_messages.UserContent]

     system_prompts: tuple[str, ...]
     system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
     system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]

+    async def run(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
+    ) -> ModelRequestNode[DepsT, NodeRunEndT]:
+        return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
+
     async def _get_first_message(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> _messages.ModelRequest:
-        run_context = _build_run_context(ctx)
+        run_context = build_run_context(ctx)
         history, next_message = await self._prepare_messages(self.user_prompt, ctx.state.message_history, run_context)
         ctx.state.message_history = history
         run_context.messages = history
@@ -135,7 +135,10 @@ class BaseUserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], N
         return next_message

     async def _prepare_messages(
-        self,
+        self,
+        user_prompt: str | Sequence[_messages.UserContent],
+        message_history: list[_messages.ModelMessage] | None,
+        run_context: RunContext[DepsT],
     ) -> tuple[list[_messages.ModelMessage], _messages.ModelRequest]:
         try:
             ctx_messages = get_captured_run_messages()
@@ -188,29 +191,13 @@ class BaseUserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], N
         return messages


-@dataclasses.dataclass
-class UserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]):
-    async def run(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
-    ) -> ModelRequestNode[DepsT, NodeRunEndT]:
-        return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
-
-
-@dataclasses.dataclass
-class StreamUserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]):
-    async def run(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
-    ) -> StreamModelRequestNode[DepsT, NodeRunEndT]:
-        return StreamModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
-
-
 async def _prepare_request_parameters(
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
 ) -> models.ModelRequestParameters:
     """Build tools and create an agent model."""
     function_tool_defs: list[ToolDefinition] = []

-    run_context = _build_run_context(ctx)
+    run_context = build_run_context(ctx)

     async def add_tool(tool: Tool[DepsT]) -> None:
         ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
@@ -222,20 +209,81 @@ async def _prepare_request_parameters(
     result_schema = ctx.deps.result_schema
     return models.ModelRequestParameters(
         function_tools=function_tool_defs,
-        allow_text_result=_allow_text_result(result_schema),
+        allow_text_result=allow_text_result(result_schema),
         result_tools=result_schema.tool_defs() if result_schema is not None else [],
     )


 @dataclasses.dataclass
-class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
+class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
     """Make a request to the model using the last message in state.message_history."""

     request: _messages.ModelRequest

+    _result: HandleResponseNode[DepsT, NodeRunEndT] | None = field(default=None, repr=False)
+    _did_stream: bool = field(default=False, repr=False)
+
     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> HandleResponseNode[DepsT, NodeRunEndT]:
+        if self._result is not None:
+            return self._result
+
+        if self._did_stream:
+            # `self._result` gets set when exiting the `stream` contextmanager, so hitting this
+            # means that the stream was started but not finished before `run()` was called
+            raise exceptions.AgentRunError('You must finish streaming before calling run()')
+
+        return await self._make_request(ctx)
+
+    @asynccontextmanager
+    async def _stream(
+        self,
+        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
+    ) -> AsyncIterator[models.StreamedResponse]:
+        # TODO: Consider changing this to return something more similar to a `StreamedRunResult`, then make it public
+        assert not self._did_stream, 'stream() should only be called once per node'
+
+        model_settings, model_request_parameters = await self._prepare_request(ctx)
+        with _logfire.span('model request', run_step=ctx.state.run_step) as span:
+            async with ctx.deps.model.request_stream(
+                ctx.state.message_history, model_settings, model_request_parameters
+            ) as streamed_response:
+                self._did_stream = True
+                ctx.state.usage.incr(_usage.Usage(), requests=1)
+                yield streamed_response
+                # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
+                # otherwise usage won't be properly counted:
+                async for _ in streamed_response:
+                    pass
+            model_response = streamed_response.get()
+            request_usage = streamed_response.usage()
+            span.set_attribute('response', model_response)
+            span.set_attribute('usage', request_usage)
+
+        self._finish_handling(ctx, model_response, request_usage)
+        assert self._result is not None  # this should be set by the previous line
+
+    async def _make_request(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
+    ) -> HandleResponseNode[DepsT, NodeRunEndT]:
+        if self._result is not None:
+            return self._result
+
+        model_settings, model_request_parameters = await self._prepare_request(ctx)
+        with _logfire.span('model request', run_step=ctx.state.run_step) as span:
+            model_response, request_usage = await ctx.deps.model.request(
+                ctx.state.message_history, model_settings, model_request_parameters
+            )
+            ctx.state.usage.incr(_usage.Usage(), requests=1)
+            span.set_attribute('response', model_response)
+            span.set_attribute('usage', request_usage)
+
+        return self._finish_handling(ctx, model_response, request_usage)
+
+    async def _prepare_request(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
+    ) -> tuple[ModelSettings | None, models.ModelRequestParameters]:
         ctx.state.message_history.append(self.request)

         # Check usage
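The new `ModelRequestNode._stream` context manager above yields the live `StreamedResponse` and, on exit, drains whatever the caller left unread so that the final response and usage can still be recorded. A generic sketch of that drain-on-exit pattern (standalone illustration under assumed names, not the library's API):

    import asyncio
    from collections.abc import AsyncIterator
    from contextlib import asynccontextmanager


    async def fake_stream() -> AsyncIterator[str]:
        for chunk in ('hello', ' ', 'world'):
            yield chunk


    @asynccontextmanager
    async def stream_and_drain(source: AsyncIterator[str]) -> AsyncIterator[AsyncIterator[str]]:
        yield source  # hand the live stream to the caller
        # On exit: if the caller stopped early, finish consuming so end-of-stream bookkeeping still runs.
        async for _ in source:
            pass


    async def main() -> None:
        async with stream_and_drain(fake_stream()) as stream:
            async for first in stream:
                print(first)  # deliberately read only the first chunk
                break


    asyncio.run(main())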
@@ -245,71 +293,124 @@ class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], Nod
         # Increment run_step
         ctx.state.run_step += 1

+        model_settings = merge_model_settings(ctx.deps.model_settings, None)
         with _logfire.span('preparing model request params {run_step=}', run_step=ctx.state.run_step):
             model_request_parameters = await _prepare_request_parameters(ctx)
+        return model_settings, model_request_parameters

-
-
-
-
-
-
-            span.set_attribute('response', model_response)
-            span.set_attribute('usage', request_usage)
-
+    def _finish_handling(
+        self,
+        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
+        response: _messages.ModelResponse,
+        usage: _usage.Usage,
+    ) -> HandleResponseNode[DepsT, NodeRunEndT]:
         # Update usage
-        ctx.state.usage.incr(
+        ctx.state.usage.incr(usage, requests=0)
         if ctx.deps.usage_limits:
             ctx.deps.usage_limits.check_tokens(ctx.state.usage)

         # Append the model response to state.message_history
-        ctx.state.message_history.append(
-
+        ctx.state.message_history.append(response)
+
+        # Set the `_result` attribute since we can't use `return` in an async iterator
+        self._result = HandleResponseNode(response)
+
+        return self._result


 @dataclasses.dataclass
-class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
-    """Process
+class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
+    """Process a model response, and decide whether to end the run or make a new request."""

     model_response: _messages.ModelResponse

+    _events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, repr=False)
+    _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field(
+        default=None, repr=False
+    )
+    _tool_responses: list[_messages.ModelRequestPart] = field(default_factory=list, repr=False)
+
     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> Union[ModelRequestNode[DepsT, NodeRunEndT],
+    ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], End[result.FinalResult[NodeRunEndT]]]: # noqa UP007
+        async with self.stream(ctx):
+            pass
+
+        assert (next_node := self._next_node) is not None, 'the stream should set `self._next_node` before it ends'
+        return next_node
+
+    @asynccontextmanager
+    async def stream(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
+    ) -> AsyncIterator[AsyncIterator[_messages.HandleResponseEvent]]:
+        """Process the model response and yield events for the start and end of each function tool call."""
         with _logfire.span('handle model response', run_step=ctx.state.run_step) as handle_span:
-
-
-
-
-
-
-
-
-
+            stream = self._run_stream(ctx)
+            yield stream
+
+            # Run the stream to completion if it was not finished:
+            async for _event in stream:
+                pass
+
+            # Set the next node based on the final state of the stream
+            next_node = self._next_node
+            if isinstance(next_node, End):
+                handle_span.set_attribute('result', next_node.data)
+                handle_span.message = 'handle model response -> final result'
+            elif tool_responses := self._tool_responses:
+                # TODO: We could drop `self._tool_responses` if we drop this set_attribute
+                # I'm thinking it might be better to just create a span for the handling of each tool
+                # than to set an attribute here.
+                handle_span.set_attribute('tool_responses', tool_responses)
+                tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
+                handle_span.message = f'handle model response -> {tool_responses_str}'
+
+    async def _run_stream(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
+    ) -> AsyncIterator[_messages.HandleResponseEvent]:
+        if self._events_iterator is None:
+            # Ensure that the stream is only run once
+
+            async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
+                texts: list[str] = []
+                tool_calls: list[_messages.ToolCallPart] = []
+                for part in self.model_response.parts:
+                    if isinstance(part, _messages.TextPart):
+                        # ignore empty content for text parts, see #437
+                        if part.content:
+                            texts.append(part.content)
+                    elif isinstance(part, _messages.ToolCallPart):
+                        tool_calls.append(part)
+                    else:
+                        assert_never(part)
+
+                # At the moment, we prioritize at least executing tool calls if they are present.
+                # In the future, we'd consider making this configurable at the agent or run level.
+                # This accounts for cases like anthropic returns that might contain a text response
+                # and a tool call response, where the text response just indicates the tool call will happen.
+                if tool_calls:
+                    async for event in self._handle_tool_calls(ctx, tool_calls):
+                        yield event
+                elif texts:
+                    # No events are emitted during the handling of text responses, so we don't need to yield anything
+                    self._next_node = await self._handle_text_response(ctx, texts)
                 else:
-
-
-            # At the moment, we prioritize at least executing tool calls if they are present.
-            # In the future, we'd consider making this configurable at the agent or run level.
-            # This accounts for cases like anthropic returns that might contain a text response
-            # and a tool call response, where the text response just indicates the tool call will happen.
-            if tool_calls:
-                return await self._handle_tool_calls_response(ctx, tool_calls, handle_span)
-            elif texts:
-                return await self._handle_text_response(ctx, texts, handle_span)
-            else:
-                raise exceptions.UnexpectedModelBehavior('Received empty model response')
+                    raise exceptions.UnexpectedModelBehavior('Received empty model response')

-
+            self._events_iterator = _run_stream()
+
+        async for event in self._events_iterator:
+            yield event
+
+    async def _handle_tool_calls(
         self,
         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
         tool_calls: list[_messages.ToolCallPart],
-
-    ):
+    ) -> AsyncIterator[_messages.HandleResponseEvent]:
         result_schema = ctx.deps.result_schema

         # first look for the result tool call
-        final_result:
+        final_result: result.FinalResult[NodeRunEndT] | None = None
         parts: list[_messages.ModelRequestPart] = []
         if result_schema is not None:
             if match := result_schema.find_tool(tool_calls):
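The comment above — "we can't use `return` in an async iterator" — is why `HandleResponseNode` stashes its outcome on `self._next_node`/`self._result` instead of returning it: an async generator may only use a bare `return`, so any final value computed while streaming has to be stored where the caller can read it afterwards. A small standalone illustration (hypothetical names, not the library's API):

    import asyncio
    from collections.abc import AsyncIterator
    from typing import Optional


    class EventProcessor:
        """Streams events and records its final outcome as a side effect."""

        def __init__(self) -> None:
            self.outcome: Optional[str] = None

        async def stream(self) -> AsyncIterator[str]:
            yield 'event-1'
            yield 'event-2'
            # `return 'done'` here would be a SyntaxError, so stash the outcome instead.
            self.outcome = 'done'


    async def main() -> None:
        processor = EventProcessor()
        async for event in processor.stream():
            print(event)
        assert processor.outcome == 'done'


    asyncio.run(main())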
@@ -323,33 +424,51 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], N
                     ctx.state.increment_retries(ctx.deps.max_result_retries)
                     parts.append(e.tool_retry)
                 else:
-                    final_result =
+                    final_result = result.FinalResult(result_data, call.tool_name)

         # Then build the other request parts based on end strategy
-        tool_responses =
+        tool_responses: list[_messages.ModelRequestPart] = self._tool_responses
+        async for event in process_function_tools(
+            tool_calls, final_result and final_result.tool_name, ctx, tool_responses
+        ):
+            yield event

         if final_result:
-
-            handle_span.message = 'handle model response -> final result'
-            return FinalResultNode[DepsT, NodeRunEndT](final_result, tool_responses)
+            self._next_node = self._handle_final_result(ctx, final_result, tool_responses)
         else:
             if tool_responses:
-                handle_span.set_attribute('tool_responses', tool_responses)
-                tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
-                handle_span.message = f'handle model response -> {tool_responses_str}'
                 parts.extend(tool_responses)
-
+            self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=parts))
+
+    def _handle_final_result(
+        self,
+        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
+        final_result: result.FinalResult[NodeRunEndT],
+        tool_responses: list[_messages.ModelRequestPart],
+    ) -> End[result.FinalResult[NodeRunEndT]]:
+        run_span = ctx.deps.run_span
+        usage = ctx.state.usage
+        messages = ctx.state.message_history
+
+        # For backwards compatibility, append a new ModelRequest using the tool returns and retries
+        if tool_responses:
+            messages.append(_messages.ModelRequest(parts=tool_responses))
+
+        run_span.set_attribute('usage', usage)
+        run_span.set_attribute('all_messages', messages)
+
+        # End the run with self.data
+        return End(final_result)

     async def _handle_text_response(
         self,
         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
         texts: list[str],
-
-    ):
+    ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
         result_schema = ctx.deps.result_schema

         text = '\n\n'.join(texts)
-        if _allow_text_result(result_schema):
+        if allow_text_result(result_schema):
             result_data_input = cast(NodeRunEndT, text)
             try:
                 result_data = await _validate_result(result_data_input, ctx, None)
@@ -357,9 +476,8 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], N
                 ctx.state.increment_retries(ctx.deps.max_result_retries)
                 return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
             else:
-
-
-                return FinalResultNode[DepsT, NodeRunEndT](MarkFinalResult(result_data, None))
+                # The following cast is safe because we know `str` is an allowed result type
+                return self._handle_final_result(ctx, result.FinalResult(result_data, tool_name=None), [])
         else:
             ctx.state.increment_retries(ctx.deps.max_result_retries)
             return ModelRequestNode[DepsT, NodeRunEndT](
@@ -373,166 +491,8 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], N
             )


-
-
-    """Make a request to the model using the last message in state.message_history (or a specified request)."""
-
-    request: _messages.ModelRequest
-    _result: StreamModelRequestNode[DepsT, NodeRunEndT] | End[result.StreamedRunResult[DepsT, NodeRunEndT]] | None = (
-        field(default=None, repr=False)
-    )
-
-    async def run(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> Union[StreamModelRequestNode[DepsT, NodeRunEndT], End[result.StreamedRunResult[DepsT, NodeRunEndT]]]: # noqa UP007
-        if self._result is not None:
-            return self._result
-
-        async with self.run_to_result(ctx) as final_node:
-            return final_node
-
-    @asynccontextmanager
-    async def run_to_result(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> AsyncIterator[StreamModelRequestNode[DepsT, NodeRunEndT] | End[result.StreamedRunResult[DepsT, NodeRunEndT]]]:
-        result_schema = ctx.deps.result_schema
-
-        ctx.state.message_history.append(self.request)
-
-        # Check usage
-        if ctx.deps.usage_limits:
-            ctx.deps.usage_limits.check_before_request(ctx.state.usage)
-
-        # Increment run_step
-        ctx.state.run_step += 1
-
-        with _logfire.span('preparing model and tools {run_step=}', run_step=ctx.state.run_step):
-            model_request_parameters = await _prepare_request_parameters(ctx)
-
-        # Actually make the model request
-        model_settings = merge_model_settings(ctx.deps.model_settings, None)
-        with _logfire.span('model request {run_step=}', run_step=ctx.state.run_step) as model_req_span:
-            async with ctx.deps.model.request_stream(
-                ctx.state.message_history, model_settings, model_request_parameters
-            ) as streamed_response:
-                ctx.state.usage.requests += 1
-                model_req_span.set_attribute('response_type', streamed_response.__class__.__name__)
-                # We want to end the "model request" span here, but we can't exit the context manager
-                # in the traditional way
-                model_req_span.__exit__(None, None, None)
-
-                with _logfire.span('handle model response') as handle_span:
-                    received_text = False
-
-                    async for maybe_part_event in streamed_response:
-                        if isinstance(maybe_part_event, _messages.PartStartEvent):
-                            new_part = maybe_part_event.part
-                            if isinstance(new_part, _messages.TextPart):
-                                received_text = True
-                                if _allow_text_result(result_schema):
-                                    handle_span.message = 'handle model response -> final result'
-                                    streamed_run_result = _build_streamed_run_result(streamed_response, None, ctx)
-                                    self._result = End(streamed_run_result)
-                                    yield self._result
-                                    return
-                            elif isinstance(new_part, _messages.ToolCallPart):
-                                if result_schema is not None and (match := result_schema.find_tool([new_part])):
-                                    call, _ = match
-                                    handle_span.message = 'handle model response -> final result'
-                                    streamed_run_result = _build_streamed_run_result(
-                                        streamed_response, call.tool_name, ctx
-                                    )
-                                    self._result = End(streamed_run_result)
-                                    yield self._result
-                                    return
-                            else:
-                                assert_never(new_part)
-
-                    tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
-                    parts: list[_messages.ModelRequestPart] = []
-                    model_response = streamed_response.get()
-                    if not model_response.parts:
-                        raise exceptions.UnexpectedModelBehavior('Received empty model response')
-                    ctx.state.message_history.append(model_response)
-
-                    run_context = _build_run_context(ctx)
-                    for p in model_response.parts:
-                        if isinstance(p, _messages.ToolCallPart):
-                            if tool := ctx.deps.function_tools.get(p.tool_name):
-                                tasks.append(asyncio.create_task(tool.run(p, run_context), name=p.tool_name))
-                            else:
-                                parts.append(_unknown_tool(p.tool_name, ctx))
-
-                    if received_text and not tasks and not parts:
-                        # Can only get here if self._allow_text_result returns `False` for the provided result_schema
-                        ctx.state.increment_retries(ctx.deps.max_result_retries)
-                        self._result = StreamModelRequestNode[DepsT, NodeRunEndT](
-                            _messages.ModelRequest(
-                                parts=[
-                                    _messages.RetryPromptPart(
-                                        content='Plain text responses are not permitted, please call one of the functions instead.',
-                                    )
-                                ]
-                            )
-                        )
-                        yield self._result
-                        return
-
-                    with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
-                        task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
-                        parts.extend(task_results)
-
-                    next_request = _messages.ModelRequest(parts=parts)
-                    if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
-                        try:
-                            ctx.state.increment_retries(ctx.deps.max_result_retries)
-                        except:
-                            # TODO: This is janky, so I think we should probably change it, but how?
-                            ctx.state.message_history.append(next_request)
-                            raise
-
-                    handle_span.set_attribute('tool_responses', parts)
-                    tool_responses_str = ' '.join(r.part_kind for r in parts)
-                    handle_span.message = f'handle model response -> {tool_responses_str}'
-                    # the model_response should have been fully streamed by now, we can add its usage
-                    streamed_response_usage = streamed_response.usage()
-                    run_context.usage.incr(streamed_response_usage)
-                    ctx.deps.usage_limits.check_tokens(run_context.usage)
-                    self._result = StreamModelRequestNode[DepsT, NodeRunEndT](next_request)
-                    yield self._result
-                    return
-
-
-@dataclasses.dataclass
-class FinalResultNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[NodeRunEndT]]):
-    """Produce the final result of the run."""
-
-    data: MarkFinalResult[NodeRunEndT]
-    """The final result data."""
-    extra_parts: list[_messages.ModelRequestPart] = dataclasses.field(default_factory=list)
-
-    async def run(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> End[MarkFinalResult[NodeRunEndT]]:
-        run_span = ctx.deps.run_span
-        usage = ctx.state.usage
-        messages = ctx.state.message_history
-
-        # TODO: For backwards compatibility, append a new ModelRequest using the tool returns and retries
-        if self.extra_parts:
-            messages.append(_messages.ModelRequest(parts=self.extra_parts))
-
-        # TODO: Set this attribute somewhere
-        # handle_span = self.handle_model_response_span
-        # handle_span.set_attribute('final_data', self.data)
-        run_span.set_attribute('usage', usage)
-        run_span.set_attribute('all_messages', messages)
-
-        # End the run with self.data
-        return End(self.data)
-
-
-def _build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
+def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
+    """Build a `RunContext` object from the current agent graph run context."""
     return RunContext[DepsT](
         deps=ctx.deps.user_deps,
         model=ctx.deps.model,
@@ -543,76 +503,31 @@ def _build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[Deps
     )


-def _build_streamed_run_result(
-    result_stream: models.StreamedResponse,
-    result_tool_name: str | None,
-    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
-) -> result.StreamedRunResult[DepsT, NodeRunEndT]:
-    new_message_index = ctx.deps.new_message_index
-    result_schema = ctx.deps.result_schema
-    run_span = ctx.deps.run_span
-    usage_limits = ctx.deps.usage_limits
-    messages = ctx.state.message_history
-    run_context = _build_run_context(ctx)
-
-    async def on_complete():
-        """Called when the stream has completed.
-
-        The model response will have been added to messages by now
-        by `StreamedRunResult._marked_completed`.
-        """
-        last_message = messages[-1]
-        assert isinstance(last_message, _messages.ModelResponse)
-        tool_calls = [part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)]
-        parts = await _process_function_tools(
-            tool_calls,
-            result_tool_name,
-            ctx,
-        )
-        # TODO: Should we do something here related to the retry count?
-        # Maybe we should move the incrementing of the retry count to where we actually make a request?
-        # if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
-        #     ctx.state.increment_retries(ctx.deps.max_result_retries)
-        if parts:
-            messages.append(_messages.ModelRequest(parts))
-        run_span.set_attribute('all_messages', messages)
-
-    return result.StreamedRunResult[DepsT, NodeRunEndT](
-        messages,
-        new_message_index,
-        usage_limits,
-        result_stream,
-        result_schema,
-        run_context,
-        ctx.deps.result_validators,
-        result_tool_name,
-        on_complete,
-    )
-
-
-async def _process_function_tools(
+async def process_function_tools(
     tool_calls: list[_messages.ToolCallPart],
     result_tool_name: str | None,
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
-
-
+    output_parts: list[_messages.ModelRequestPart],
+) -> AsyncIterator[_messages.HandleResponseEvent]:
+    """Process function (i.e., non-result) tool calls in parallel.

     Also add stub return parts for any other tools that need it.
-    """
-    parts: list[_messages.ModelRequestPart] = []
-    tasks: list[asyncio.Task[_messages.ToolReturnPart | _messages.RetryPromptPart]] = []

+    Because async iterators can't have return values, we use `output_parts` as an output argument.
+    """
     stub_function_tools = bool(result_tool_name) and ctx.deps.end_strategy == 'early'
     result_schema = ctx.deps.result_schema

     # we rely on the fact that if we found a result, it's the first result tool in the last
     found_used_result_tool = False
-    run_context = _build_run_context(ctx)
+    run_context = build_run_context(ctx)

+    calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = []
+    call_index_to_event_id: dict[int, str] = {}
     for call in tool_calls:
         if call.tool_name == result_tool_name and not found_used_result_tool:
             found_used_result_tool = True
-
+            output_parts.append(
                 _messages.ToolReturnPart(
                     tool_name=call.tool_name,
                     content='Final result processed.',
@@ -621,7 +536,7 @@ async def _process_function_tools(
             )
         elif tool := ctx.deps.function_tools.get(call.tool_name):
             if stub_function_tools:
-
+                output_parts.append(
                     _messages.ToolReturnPart(
                         tool_name=call.tool_name,
                         content='Tool not executed - a final result was already processed.',
@@ -629,33 +544,47 @@ async def _process_function_tools(
                     )
                 )
             else:
-
+                event = _messages.FunctionToolCallEvent(call)
+                yield event
+                call_index_to_event_id[len(calls_to_run)] = event.call_id
+                calls_to_run.append((tool, call))
         elif result_schema is not None and call.tool_name in result_schema.tools:
             # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
             # validation, we don't add another part here
             if result_tool_name is not None:
-
-
-
-
-                        tool_call_id=call.tool_call_id,
-                    )
+                part = _messages.ToolReturnPart(
+                    tool_name=call.tool_name,
+                    content='Result tool not used - a final result was already processed.',
+                    tool_call_id=call.tool_call_id,
                 )
+                output_parts.append(part)
         else:
-
+            output_parts.append(_unknown_tool(call.tool_name, ctx))
+
+    if not calls_to_run:
+        return

     # Run all tool tasks in parallel
-
-
-
-
-
-
-
-
+    results_by_index: dict[int, _messages.ModelRequestPart] = {}
+    with _logfire.span('running {tools=}', tools=[call.tool_name for _, call in calls_to_run]):
+        # TODO: Should we wrap each individual tool call in a dedicated span?
+        tasks = [asyncio.create_task(tool.run(call, run_context), name=call.tool_name) for tool, call in calls_to_run]
+        pending = tasks
+        while pending:
+            done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
+            for task in done:
+                index = tasks.index(task)
+                result = task.result()
+                yield _messages.FunctionToolResultEvent(result, call_id=call_index_to_event_id[index])
+                if isinstance(result, (_messages.ToolReturnPart, _messages.RetryPromptPart)):
+                    results_by_index[index] = result
                 else:
                     assert_never(result)
-
+
+    # We append the results at the end, rather than as they are received, to retain a consistent ordering
+    # This is mostly just to simplify testing
+    for k in sorted(results_by_index):
+        output_parts.append(results_by_index[k])


 def _unknown_tool(
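`process_function_tools` above runs the matched tool calls concurrently, yields a `FunctionToolCallEvent` before scheduling each call and a `FunctionToolResultEvent` as each one finishes, and, because an async generator cannot return a value, collects the resulting request parts into the `output_parts` output argument in the original call order. A generic sketch of that pattern (illustrative names, not the library's API):

    import asyncio
    from collections.abc import AsyncIterator, Awaitable, Callable


    async def run_all(
        jobs: list[Callable[[], Awaitable[str]]],
        output_parts: list[str],  # output argument: an async generator cannot `return` a value
    ) -> AsyncIterator[str]:
        tasks = [asyncio.create_task(job()) for job in jobs]
        results_by_index: dict[int, str] = {}
        pending = set(tasks)
        while pending:
            done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
            for task in done:
                index = tasks.index(task)
                results_by_index[index] = task.result()
                yield f'job {index} finished'  # progress event, emitted in completion order
        # Append results only at the end, in submission order, for a consistent ordering.
        for index in sorted(results_by_index):
            output_parts.append(results_by_index[index])


    async def main() -> None:
        async def slow() -> str:
            await asyncio.sleep(0.02)
            return 'slow'

        async def fast() -> str:
            return 'fast'

        parts: list[str] = []
        async for event in run_all([slow, fast], parts):
            print(event)  # 'job 1 finished' is expected first
        print(parts)  # ['slow', 'fast'] - submission order preserved


    asyncio.run(main())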
@@ -681,12 +610,13 @@ async def _validate_result(
     tool_call: _messages.ToolCallPart | None,
 ) -> T:
     for validator in ctx.deps.result_validators:
-        run_context = _build_run_context(ctx)
+        run_context = build_run_context(ctx)
         result_data = await validator.validate(result_data, tool_call, run_context)
     return result_data


-def _allow_text_result(result_schema: _result.ResultSchema[Any] | None) -> bool:
+def allow_text_result(result_schema: _result.ResultSchema[Any] | None) -> bool:
+    """Check if the result schema allows text results."""
     return result_schema is None or result_schema.allow_text_result


@@ -740,35 +670,18 @@ def get_captured_run_messages() -> _RunMessages:

 def build_agent_graph(
     name: str | None, deps_type: type[DepsT], result_type: type[ResultT]
-) -> Graph[GraphAgentState, GraphAgentDeps[DepsT,
-
+) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[ResultT]], result.FinalResult[ResultT]]:
+    """Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
     nodes = (
         UserPromptNode[DepsT],
         ModelRequestNode[DepsT],
         HandleResponseNode[DepsT],
-        FinalResultNode[DepsT, ResultT],
     )
-    graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any],
+    graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[ResultT]](
         nodes=nodes,
         name=name or 'Agent',
         state_type=GraphAgentState,
-        run_end_type=
+        run_end_type=result.FinalResult[result_type],
         auto_instrument=False,
     )
     return graph
-
-
-def build_agent_stream_graph(
-    name: str | None, deps_type: type[DepsT], result_type: type[ResultT] | None
-) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.StreamedRunResult[DepsT, Any]]:
-    nodes = [
-        StreamUserPromptNode[DepsT, result.StreamedRunResult[DepsT, ResultT]],
-        StreamModelRequestNode[DepsT, result.StreamedRunResult[DepsT, ResultT]],
-    ]
-    graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.StreamedRunResult[DepsT, Any]](
-        nodes=nodes,
-        name=name or 'Agent',
-        state_type=GraphAgentState,
-        run_end_type=result.StreamedRunResult[DepsT, result_type],
-    )
-    return graph