pydantic-ai-slim 0.0.24__py3-none-any.whl → 0.0.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pydantic-ai-slim might be problematic; see the registry listing for details.
- pydantic_ai/__init__.py +5 -1
- pydantic_ai/_agent_graph.py +256 -346
- pydantic_ai/_utils.py +1 -1
- pydantic_ai/agent.py +572 -147
- pydantic_ai/messages.py +31 -0
- pydantic_ai/models/__init__.py +12 -1
- pydantic_ai/models/anthropic.py +41 -49
- pydantic_ai/models/cohere.py +1 -1
- pydantic_ai/models/function.py +3 -3
- pydantic_ai/models/gemini.py +18 -2
- pydantic_ai/models/instrumented.py +225 -0
- pydantic_ai/models/mistral.py +0 -3
- pydantic_ai/models/openai.py +2 -5
- pydantic_ai/models/test.py +6 -6
- pydantic_ai/models/wrapper.py +45 -0
- pydantic_ai/result.py +106 -144
- {pydantic_ai_slim-0.0.24.dist-info → pydantic_ai_slim-0.0.25.dist-info}/METADATA +2 -2
- pydantic_ai_slim-0.0.25.dist-info/RECORD +32 -0
- pydantic_ai_slim-0.0.24.dist-info/RECORD +0 -30
- {pydantic_ai_slim-0.0.24.dist-info → pydantic_ai_slim-0.0.25.dist-info}/WHEEL +0 -0
pydantic_ai/_agent_graph.py
CHANGED
```diff
@@ -3,7 +3,7 @@ from __future__ import annotations as _annotations
 import asyncio
 import dataclasses
 from abc import ABC
-from collections.abc import AsyncIterator, Iterator
+from collections.abc import AsyncIterator, Iterator
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
 from dataclasses import field
@@ -32,6 +32,16 @@ from .tools import (
     ToolDefinition,
 )
 
+__all__ = (
+    'GraphAgentState',
+    'GraphAgentDeps',
+    'UserPromptNode',
+    'ModelRequestNode',
+    'HandleResponseNode',
+    'build_run_context',
+    'capture_run_messages',
+)
+
 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
 
 # while waiting for https://github.com/pydantic/logfire/issues/745
@@ -56,21 +66,6 @@ DepsT = TypeVar('DepsT')
 ResultT = TypeVar('ResultT')
 
 
-@dataclasses.dataclass
-class MarkFinalResult(Generic[ResultDataT]):
-    """Marker class to indicate that the result is the final result.
-
-    This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultDataT` directly.
-
-    It also avoids problems in the case where the result type is itself `None`, but is set.
-    """
-
-    data: ResultDataT
-    """The final result data."""
-    tool_name: str | None
-    """Name of the final result tool, None if the result is a string."""
-
-
 @dataclasses.dataclass
 class GraphAgentState:
     """State kept across the execution of the agent graph."""
@@ -113,17 +108,22 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
 
 
 @dataclasses.dataclass
-class
+class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC):
     user_prompt: str
 
     system_prompts: tuple[str, ...]
     system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
     system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]
 
+    async def run(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
+    ) -> ModelRequestNode[DepsT, NodeRunEndT]:
+        return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
+
     async def _get_first_message(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
     ) -> _messages.ModelRequest:
-        run_context =
+        run_context = build_run_context(ctx)
         history, next_message = await self._prepare_messages(self.user_prompt, ctx.state.message_history, run_context)
         ctx.state.message_history = history
         run_context.messages = history
@@ -188,29 +188,13 @@ class BaseUserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], N
         return messages
 
 
-@dataclasses.dataclass
-class UserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]):
-    async def run(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
-    ) -> ModelRequestNode[DepsT, NodeRunEndT]:
-        return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
-
-
-@dataclasses.dataclass
-class StreamUserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]):
-    async def run(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
-    ) -> StreamModelRequestNode[DepsT, NodeRunEndT]:
-        return StreamModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
-
-
 async def _prepare_request_parameters(
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
 ) -> models.ModelRequestParameters:
     """Build tools and create an agent model."""
     function_tool_defs: list[ToolDefinition] = []
 
-    run_context =
+    run_context = build_run_context(ctx)
 
     async def add_tool(tool: Tool[DepsT]) -> None:
         ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
@@ -222,7 +206,7 @@ async def _prepare_request_parameters(
     result_schema = ctx.deps.result_schema
     return models.ModelRequestParameters(
         function_tools=function_tool_defs,
-        allow_text_result=
+        allow_text_result=allow_text_result(result_schema),
         result_tools=result_schema.tool_defs() if result_schema is not None else [],
     )
 
@@ -233,9 +217,70 @@ class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], Nod
 
     request: _messages.ModelRequest
 
+    _result: HandleResponseNode[DepsT, NodeRunEndT] | None = field(default=None, repr=False)
+    _did_stream: bool = field(default=False, repr=False)
+
     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> HandleResponseNode[DepsT, NodeRunEndT]:
+        if self._result is not None:
+            return self._result
+
+        if self._did_stream:
+            # `self._result` gets set when exiting the `stream` contextmanager, so hitting this
+            # means that the stream was started but not finished before `run()` was called
+            raise exceptions.AgentRunError('You must finish streaming before calling run()')
+
+        return await self._make_request(ctx)
+
+    @asynccontextmanager
+    async def _stream(
+        self,
+        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
+    ) -> AsyncIterator[models.StreamedResponse]:
+        # TODO: Consider changing this to return something more similar to a `StreamedRunResult`, then make it public
+        assert not self._did_stream, 'stream() should only be called once per node'
+
+        model_settings, model_request_parameters = await self._prepare_request(ctx)
+        with _logfire.span('model request', run_step=ctx.state.run_step) as span:
+            async with ctx.deps.model.request_stream(
+                ctx.state.message_history, model_settings, model_request_parameters
+            ) as streamed_response:
+                self._did_stream = True
+                ctx.state.usage.incr(_usage.Usage(), requests=1)
+                yield streamed_response
+                # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
+                # otherwise usage won't be properly counted:
+                async for _ in streamed_response:
+                    pass
+            model_response = streamed_response.get()
+            request_usage = streamed_response.usage()
+            span.set_attribute('response', model_response)
+            span.set_attribute('usage', request_usage)
+
+        self._finish_handling(ctx, model_response, request_usage)
+        assert self._result is not None  # this should be set by the previous line
+
+    async def _make_request(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
+    ) -> HandleResponseNode[DepsT, NodeRunEndT]:
+        if self._result is not None:
+            return self._result
+
+        model_settings, model_request_parameters = await self._prepare_request(ctx)
+        with _logfire.span('model request', run_step=ctx.state.run_step) as span:
+            model_response, request_usage = await ctx.deps.model.request(
+                ctx.state.message_history, model_settings, model_request_parameters
+            )
+            ctx.state.usage.incr(_usage.Usage(), requests=1)
+            span.set_attribute('response', model_response)
+            span.set_attribute('usage', request_usage)
+
+        return self._finish_handling(ctx, model_response, request_usage)
+
+    async def _prepare_request(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
+    ) -> tuple[ModelSettings | None, models.ModelRequestParameters]:
         ctx.state.message_history.append(self.request)
 
         # Check usage
@@ -245,71 +290,124 @@ class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], Nod
         # Increment run_step
         ctx.state.run_step += 1
 
+        model_settings = merge_model_settings(ctx.deps.model_settings, None)
         with _logfire.span('preparing model request params {run_step=}', run_step=ctx.state.run_step):
            model_request_parameters = await _prepare_request_parameters(ctx)
+        return model_settings, model_request_parameters
 
-
-
-
-
-
-
-            span.set_attribute('response', model_response)
-            span.set_attribute('usage', request_usage)
-
+    def _finish_handling(
+        self,
+        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
+        response: _messages.ModelResponse,
+        usage: _usage.Usage,
+    ) -> HandleResponseNode[DepsT, NodeRunEndT]:
         # Update usage
-        ctx.state.usage.incr(
+        ctx.state.usage.incr(usage, requests=0)
         if ctx.deps.usage_limits:
             ctx.deps.usage_limits.check_tokens(ctx.state.usage)
 
         # Append the model response to state.message_history
-        ctx.state.message_history.append(
-
+        ctx.state.message_history.append(response)
+
+        # Set the `_result` attribute since we can't use `return` in an async iterator
+        self._result = HandleResponseNode(response)
+
+        return self._result
 
 
 @dataclasses.dataclass
 class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
-    """Process
+    """Process a model response, and decide whether to end the run or make a new request."""
 
     model_response: _messages.ModelResponse
 
+    _events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, repr=False)
+    _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field(
+        default=None, repr=False
+    )
+    _tool_responses: list[_messages.ModelRequestPart] = field(default_factory=list, repr=False)
+
     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> Union[ModelRequestNode[DepsT, NodeRunEndT],
+    ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], End[result.FinalResult[NodeRunEndT]]]:  # noqa UP007
+        async with self.stream(ctx):
+            pass
+
+        assert (next_node := self._next_node) is not None, 'the stream should set `self._next_node` before it ends'
+        return next_node
+
+    @asynccontextmanager
+    async def stream(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
+    ) -> AsyncIterator[AsyncIterator[_messages.HandleResponseEvent]]:
+        """Process the model response and yield events for the start and end of each function tool call."""
         with _logfire.span('handle model response', run_step=ctx.state.run_step) as handle_span:
-
-
-
-
-
-
-
-
-
+            stream = self._run_stream(ctx)
+            yield stream
+
+            # Run the stream to completion if it was not finished:
+            async for _event in stream:
+                pass
+
+            # Set the next node based on the final state of the stream
+            next_node = self._next_node
+            if isinstance(next_node, End):
+                handle_span.set_attribute('result', next_node.data)
+                handle_span.message = 'handle model response -> final result'
+            elif tool_responses := self._tool_responses:
+                # TODO: We could drop `self._tool_responses` if we drop this set_attribute
+                # I'm thinking it might be better to just create a span for the handling of each tool
+                # than to set an attribute here.
+                handle_span.set_attribute('tool_responses', tool_responses)
+                tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
+                handle_span.message = f'handle model response -> {tool_responses_str}'
+
+    async def _run_stream(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
+    ) -> AsyncIterator[_messages.HandleResponseEvent]:
+        if self._events_iterator is None:
+            # Ensure that the stream is only run once
+
+            async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
+                texts: list[str] = []
+                tool_calls: list[_messages.ToolCallPart] = []
+                for part in self.model_response.parts:
+                    if isinstance(part, _messages.TextPart):
+                        # ignore empty content for text parts, see #437
+                        if part.content:
+                            texts.append(part.content)
+                    elif isinstance(part, _messages.ToolCallPart):
+                        tool_calls.append(part)
+                    else:
+                        assert_never(part)
+
+                # At the moment, we prioritize at least executing tool calls if they are present.
+                # In the future, we'd consider making this configurable at the agent or run level.
+                # This accounts for cases like anthropic returns that might contain a text response
+                # and a tool call response, where the text response just indicates the tool call will happen.
+                if tool_calls:
+                    async for event in self._handle_tool_calls(ctx, tool_calls):
+                        yield event
+                elif texts:
+                    # No events are emitted during the handling of text responses, so we don't need to yield anything
+                    self._next_node = await self._handle_text_response(ctx, texts)
                 else:
-
-
-            # At the moment, we prioritize at least executing tool calls if they are present.
-            # In the future, we'd consider making this configurable at the agent or run level.
-            # This accounts for cases like anthropic returns that might contain a text response
-            # and a tool call response, where the text response just indicates the tool call will happen.
-            if tool_calls:
-                return await self._handle_tool_calls_response(ctx, tool_calls, handle_span)
-            elif texts:
-                return await self._handle_text_response(ctx, texts, handle_span)
-            else:
-                raise exceptions.UnexpectedModelBehavior('Received empty model response')
+                    raise exceptions.UnexpectedModelBehavior('Received empty model response')
 
-
+            self._events_iterator = _run_stream()
+
+        async for event in self._events_iterator:
+            yield event
+
+    async def _handle_tool_calls(
         self,
         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
         tool_calls: list[_messages.ToolCallPart],
-
-    ):
+    ) -> AsyncIterator[_messages.HandleResponseEvent]:
         result_schema = ctx.deps.result_schema
 
         # first look for the result tool call
-        final_result:
+        final_result: result.FinalResult[NodeRunEndT] | None = None
         parts: list[_messages.ModelRequestPart] = []
         if result_schema is not None:
             if match := result_schema.find_tool(tool_calls):
@@ -323,33 +421,51 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], N
                 ctx.state.increment_retries(ctx.deps.max_result_retries)
                 parts.append(e.tool_retry)
             else:
-                final_result =
+                final_result = result.FinalResult(result_data, call.tool_name)
 
         # Then build the other request parts based on end strategy
-        tool_responses =
+        tool_responses: list[_messages.ModelRequestPart] = self._tool_responses
+        async for event in process_function_tools(
+            tool_calls, final_result and final_result.tool_name, ctx, tool_responses
+        ):
+            yield event
 
         if final_result:
-
-            handle_span.message = 'handle model response -> final result'
-            return FinalResultNode[DepsT, NodeRunEndT](final_result, tool_responses)
+            self._next_node = self._handle_final_result(ctx, final_result, tool_responses)
         else:
             if tool_responses:
-                handle_span.set_attribute('tool_responses', tool_responses)
-                tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
-                handle_span.message = f'handle model response -> {tool_responses_str}'
                 parts.extend(tool_responses)
-
+            self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=parts))
+
+    def _handle_final_result(
+        self,
+        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
+        final_result: result.FinalResult[NodeRunEndT],
+        tool_responses: list[_messages.ModelRequestPart],
+    ) -> End[result.FinalResult[NodeRunEndT]]:
+        run_span = ctx.deps.run_span
+        usage = ctx.state.usage
+        messages = ctx.state.message_history
+
+        # For backwards compatibility, append a new ModelRequest using the tool returns and retries
+        if tool_responses:
+            messages.append(_messages.ModelRequest(parts=tool_responses))
+
+        run_span.set_attribute('usage', usage)
+        run_span.set_attribute('all_messages', messages)
+
+        # End the run with self.data
+        return End(final_result)
 
     async def _handle_text_response(
         self,
         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
         texts: list[str],
-
-    ):
+    ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
         result_schema = ctx.deps.result_schema
 
         text = '\n\n'.join(texts)
-        if
+        if allow_text_result(result_schema):
             result_data_input = cast(NodeRunEndT, text)
             try:
                 result_data = await _validate_result(result_data_input, ctx, None)
@@ -357,9 +473,8 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], N
                 ctx.state.increment_retries(ctx.deps.max_result_retries)
                 return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
             else:
-
-
-                return FinalResultNode[DepsT, NodeRunEndT](MarkFinalResult(result_data, None))
+                # The following cast is safe because we know `str` is an allowed result type
+                return self._handle_final_result(ctx, result.FinalResult(result_data, tool_name=None), [])
         else:
             ctx.state.increment_retries(ctx.deps.max_result_retries)
             return ModelRequestNode[DepsT, NodeRunEndT](
@@ -373,166 +488,8 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], N
             )
 
 
-
-
-    """Make a request to the model using the last message in state.message_history (or a specified request)."""
-
-    request: _messages.ModelRequest
-    _result: StreamModelRequestNode[DepsT, NodeRunEndT] | End[result.StreamedRunResult[DepsT, NodeRunEndT]] | None = (
-        field(default=None, repr=False)
-    )
-
-    async def run(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> Union[StreamModelRequestNode[DepsT, NodeRunEndT], End[result.StreamedRunResult[DepsT, NodeRunEndT]]]:  # noqa UP007
-        if self._result is not None:
-            return self._result
-
-        async with self.run_to_result(ctx) as final_node:
-            return final_node
-
-    @asynccontextmanager
-    async def run_to_result(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> AsyncIterator[StreamModelRequestNode[DepsT, NodeRunEndT] | End[result.StreamedRunResult[DepsT, NodeRunEndT]]]:
-        result_schema = ctx.deps.result_schema
-
-        ctx.state.message_history.append(self.request)
-
-        # Check usage
-        if ctx.deps.usage_limits:
-            ctx.deps.usage_limits.check_before_request(ctx.state.usage)
-
-        # Increment run_step
-        ctx.state.run_step += 1
-
-        with _logfire.span('preparing model and tools {run_step=}', run_step=ctx.state.run_step):
-            model_request_parameters = await _prepare_request_parameters(ctx)
-
-        # Actually make the model request
-        model_settings = merge_model_settings(ctx.deps.model_settings, None)
-        with _logfire.span('model request {run_step=}', run_step=ctx.state.run_step) as model_req_span:
-            async with ctx.deps.model.request_stream(
-                ctx.state.message_history, model_settings, model_request_parameters
-            ) as streamed_response:
-                ctx.state.usage.requests += 1
-                model_req_span.set_attribute('response_type', streamed_response.__class__.__name__)
-                # We want to end the "model request" span here, but we can't exit the context manager
-                # in the traditional way
-                model_req_span.__exit__(None, None, None)
-
-                with _logfire.span('handle model response') as handle_span:
-                    received_text = False
-
-                    async for maybe_part_event in streamed_response:
-                        if isinstance(maybe_part_event, _messages.PartStartEvent):
-                            new_part = maybe_part_event.part
-                            if isinstance(new_part, _messages.TextPart):
-                                received_text = True
-                                if _allow_text_result(result_schema):
-                                    handle_span.message = 'handle model response -> final result'
-                                    streamed_run_result = _build_streamed_run_result(streamed_response, None, ctx)
-                                    self._result = End(streamed_run_result)
-                                    yield self._result
-                                    return
-                            elif isinstance(new_part, _messages.ToolCallPart):
-                                if result_schema is not None and (match := result_schema.find_tool([new_part])):
-                                    call, _ = match
-                                    handle_span.message = 'handle model response -> final result'
-                                    streamed_run_result = _build_streamed_run_result(
-                                        streamed_response, call.tool_name, ctx
-                                    )
-                                    self._result = End(streamed_run_result)
-                                    yield self._result
-                                    return
-                            else:
-                                assert_never(new_part)
-
-                    tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
-                    parts: list[_messages.ModelRequestPart] = []
-                    model_response = streamed_response.get()
-                    if not model_response.parts:
-                        raise exceptions.UnexpectedModelBehavior('Received empty model response')
-                    ctx.state.message_history.append(model_response)
-
-                    run_context = _build_run_context(ctx)
-                    for p in model_response.parts:
-                        if isinstance(p, _messages.ToolCallPart):
-                            if tool := ctx.deps.function_tools.get(p.tool_name):
-                                tasks.append(asyncio.create_task(tool.run(p, run_context), name=p.tool_name))
-                            else:
-                                parts.append(_unknown_tool(p.tool_name, ctx))
-
-                    if received_text and not tasks and not parts:
-                        # Can only get here if self._allow_text_result returns `False` for the provided result_schema
-                        ctx.state.increment_retries(ctx.deps.max_result_retries)
-                        self._result = StreamModelRequestNode[DepsT, NodeRunEndT](
-                            _messages.ModelRequest(
-                                parts=[
-                                    _messages.RetryPromptPart(
-                                        content='Plain text responses are not permitted, please call one of the functions instead.',
-                                    )
-                                ]
-                            )
-                        )
-                        yield self._result
-                        return
-
-                    with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
-                        task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
-                        parts.extend(task_results)
-
-                    next_request = _messages.ModelRequest(parts=parts)
-                    if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
-                        try:
-                            ctx.state.increment_retries(ctx.deps.max_result_retries)
-                        except:
-                            # TODO: This is janky, so I think we should probably change it, but how?
-                            ctx.state.message_history.append(next_request)
-                            raise
-
-                    handle_span.set_attribute('tool_responses', parts)
-                    tool_responses_str = ' '.join(r.part_kind for r in parts)
-                    handle_span.message = f'handle model response -> {tool_responses_str}'
-                    # the model_response should have been fully streamed by now, we can add its usage
-                    streamed_response_usage = streamed_response.usage()
-                    run_context.usage.incr(streamed_response_usage)
-                    ctx.deps.usage_limits.check_tokens(run_context.usage)
-                    self._result = StreamModelRequestNode[DepsT, NodeRunEndT](next_request)
-                    yield self._result
-                    return
-
-
-@dataclasses.dataclass
-class FinalResultNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[NodeRunEndT]]):
-    """Produce the final result of the run."""
-
-    data: MarkFinalResult[NodeRunEndT]
-    """The final result data."""
-    extra_parts: list[_messages.ModelRequestPart] = dataclasses.field(default_factory=list)
-
-    async def run(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> End[MarkFinalResult[NodeRunEndT]]:
-        run_span = ctx.deps.run_span
-        usage = ctx.state.usage
-        messages = ctx.state.message_history
-
-        # TODO: For backwards compatibility, append a new ModelRequest using the tool returns and retries
-        if self.extra_parts:
-            messages.append(_messages.ModelRequest(parts=self.extra_parts))
-
-        # TODO: Set this attribute somewhere
-        # handle_span = self.handle_model_response_span
-        # handle_span.set_attribute('final_data', self.data)
-        run_span.set_attribute('usage', usage)
-        run_span.set_attribute('all_messages', messages)
-
-        # End the run with self.data
-        return End(self.data)
-
-
-def _build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
+def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
+    """Build a `RunContext` object from the current agent graph run context."""
     return RunContext[DepsT](
         deps=ctx.deps.user_deps,
         model=ctx.deps.model,
@@ -543,76 +500,31 @@ def _build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[Deps
     )
 
 
-def
-    result_stream: models.StreamedResponse,
-    result_tool_name: str | None,
-    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
-) -> result.StreamedRunResult[DepsT, NodeRunEndT]:
-    new_message_index = ctx.deps.new_message_index
-    result_schema = ctx.deps.result_schema
-    run_span = ctx.deps.run_span
-    usage_limits = ctx.deps.usage_limits
-    messages = ctx.state.message_history
-    run_context = _build_run_context(ctx)
-
-    async def on_complete():
-        """Called when the stream has completed.
-
-        The model response will have been added to messages by now
-        by `StreamedRunResult._marked_completed`.
-        """
-        last_message = messages[-1]
-        assert isinstance(last_message, _messages.ModelResponse)
-        tool_calls = [part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)]
-        parts = await _process_function_tools(
-            tool_calls,
-            result_tool_name,
-            ctx,
-        )
-        # TODO: Should we do something here related to the retry count?
-        # Maybe we should move the incrementing of the retry count to where we actually make a request?
-        # if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
-        #     ctx.state.increment_retries(ctx.deps.max_result_retries)
-        if parts:
-            messages.append(_messages.ModelRequest(parts))
-        run_span.set_attribute('all_messages', messages)
-
-    return result.StreamedRunResult[DepsT, NodeRunEndT](
-        messages,
-        new_message_index,
-        usage_limits,
-        result_stream,
-        result_schema,
-        run_context,
-        ctx.deps.result_validators,
-        result_tool_name,
-        on_complete,
-    )
-
-
-async def _process_function_tools(
+async def process_function_tools(
     tool_calls: list[_messages.ToolCallPart],
     result_tool_name: str | None,
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
-
-
+    output_parts: list[_messages.ModelRequestPart],
+) -> AsyncIterator[_messages.HandleResponseEvent]:
+    """Process function (i.e., non-result) tool calls in parallel.
 
     Also add stub return parts for any other tools that need it.
-    """
-    parts: list[_messages.ModelRequestPart] = []
-    tasks: list[asyncio.Task[_messages.ToolReturnPart | _messages.RetryPromptPart]] = []
 
+    Because async iterators can't have return values, we use `output_parts` as an output argument.
+    """
     stub_function_tools = bool(result_tool_name) and ctx.deps.end_strategy == 'early'
     result_schema = ctx.deps.result_schema
 
     # we rely on the fact that if we found a result, it's the first result tool in the last
     found_used_result_tool = False
-    run_context =
+    run_context = build_run_context(ctx)
 
+    calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = []
+    call_index_to_event_id: dict[int, str] = {}
     for call in tool_calls:
         if call.tool_name == result_tool_name and not found_used_result_tool:
             found_used_result_tool = True
-
+            output_parts.append(
                 _messages.ToolReturnPart(
                     tool_name=call.tool_name,
                     content='Final result processed.',
@@ -621,7 +533,7 @@ async def _process_function_tools(
             )
         elif tool := ctx.deps.function_tools.get(call.tool_name):
             if stub_function_tools:
-
+                output_parts.append(
                     _messages.ToolReturnPart(
                         tool_name=call.tool_name,
                         content='Tool not executed - a final result was already processed.',
@@ -629,33 +541,47 @@ async def _process_function_tools(
                     )
                 )
             else:
-
+                event = _messages.FunctionToolCallEvent(call)
+                yield event
+                call_index_to_event_id[len(calls_to_run)] = event.call_id
+                calls_to_run.append((tool, call))
         elif result_schema is not None and call.tool_name in result_schema.tools:
             # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
             # validation, we don't add another part here
             if result_tool_name is not None:
-
-
-
-
-                        tool_call_id=call.tool_call_id,
-                    )
+                part = _messages.ToolReturnPart(
+                    tool_name=call.tool_name,
+                    content='Result tool not used - a final result was already processed.',
+                    tool_call_id=call.tool_call_id,
                 )
+                output_parts.append(part)
         else:
-
+            output_parts.append(_unknown_tool(call.tool_name, ctx))
+
+    if not calls_to_run:
+        return
 
     # Run all tool tasks in parallel
-
-
-
-
-
-
-
-
+    results_by_index: dict[int, _messages.ModelRequestPart] = {}
+    with _logfire.span('running {tools=}', tools=[call.tool_name for _, call in calls_to_run]):
+        # TODO: Should we wrap each individual tool call in a dedicated span?
+        tasks = [asyncio.create_task(tool.run(call, run_context), name=call.tool_name) for tool, call in calls_to_run]
+        pending = tasks
+        while pending:
+            done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
+            for task in done:
+                index = tasks.index(task)
+                result = task.result()
+                yield _messages.FunctionToolResultEvent(result, call_id=call_index_to_event_id[index])
+                if isinstance(result, (_messages.ToolReturnPart, _messages.RetryPromptPart)):
+                    results_by_index[index] = result
                 else:
                     assert_never(result)
-
+
+    # We append the results at the end, rather than as they are received, to retain a consistent ordering
+    # This is mostly just to simplify testing
+    for k in sorted(results_by_index):
+        output_parts.append(results_by_index[k])
 
 
 def _unknown_tool(
@@ -681,12 +607,13 @@ async def _validate_result(
     tool_call: _messages.ToolCallPart | None,
 ) -> T:
     for validator in ctx.deps.result_validators:
-        run_context =
+        run_context = build_run_context(ctx)
         result_data = await validator.validate(result_data, tool_call, run_context)
     return result_data
 
 
-def
+def allow_text_result(result_schema: _result.ResultSchema[Any] | None) -> bool:
+    """Check if the result schema allows text results."""
     return result_schema is None or result_schema.allow_text_result
 
 
@@ -740,35 +667,18 @@ def get_captured_run_messages() -> _RunMessages:
 
 def build_agent_graph(
     name: str | None, deps_type: type[DepsT], result_type: type[ResultT]
-) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any],
-
+) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[ResultT]]:
+    """Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
     nodes = (
         UserPromptNode[DepsT],
         ModelRequestNode[DepsT],
         HandleResponseNode[DepsT],
-        FinalResultNode[DepsT, ResultT],
     )
-    graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any],
+    graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[ResultT]](
         nodes=nodes,
         name=name or 'Agent',
         state_type=GraphAgentState,
-        run_end_type=
+        run_end_type=result.FinalResult[result_type],
         auto_instrument=False,
     )
     return graph
-
-
-def build_agent_stream_graph(
-    name: str | None, deps_type: type[DepsT], result_type: type[ResultT] | None
-) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.StreamedRunResult[DepsT, Any]]:
-    nodes = [
-        StreamUserPromptNode[DepsT, result.StreamedRunResult[DepsT, ResultT]],
-        StreamModelRequestNode[DepsT, result.StreamedRunResult[DepsT, ResultT]],
-    ]
-    graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.StreamedRunResult[DepsT, Any]](
-        nodes=nodes,
-        name=name or 'Agent',
-        state_type=GraphAgentState,
-        run_end_type=result.StreamedRunResult[DepsT, result_type],
-    )
-    return graph
```