pydantic-ai-slim 1.0.7__py3-none-any.whl → 1.0.9__py3-none-any.whl
This diff shows the changes between publicly available package versions as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release: this version of pydantic-ai-slim might be problematic; see the registry page for more details.
- pydantic_ai/_agent_graph.py +43 -23
- pydantic_ai/_cli.py +1 -1
- pydantic_ai/_otel_messages.py +2 -0
- pydantic_ai/_parts_manager.py +82 -12
- pydantic_ai/_run_context.py +8 -1
- pydantic_ai/_tool_manager.py +1 -0
- pydantic_ai/ag_ui.py +93 -40
- pydantic_ai/agent/__init__.py +2 -4
- pydantic_ai/builtin_tools.py +12 -0
- pydantic_ai/durable_exec/temporal/_model.py +14 -6
- pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
- pydantic_ai/messages.py +69 -30
- pydantic_ai/models/__init__.py +4 -6
- pydantic_ai/models/anthropic.py +119 -45
- pydantic_ai/models/function.py +17 -8
- pydantic_ai/models/google.py +105 -16
- pydantic_ai/models/groq.py +68 -17
- pydantic_ai/models/openai.py +262 -41
- pydantic_ai/providers/__init__.py +1 -1
- pydantic_ai/result.py +24 -8
- pydantic_ai/toolsets/function.py +8 -2
- pydantic_ai/usage.py +2 -2
- {pydantic_ai_slim-1.0.7.dist-info → pydantic_ai_slim-1.0.9.dist-info}/METADATA +5 -5
- {pydantic_ai_slim-1.0.7.dist-info → pydantic_ai_slim-1.0.9.dist-info}/RECORD +27 -27
- {pydantic_ai_slim-1.0.7.dist-info → pydantic_ai_slim-1.0.9.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.0.7.dist-info → pydantic_ai_slim-1.0.9.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.0.7.dist-info → pydantic_ai_slim-1.0.9.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/_agent_graph.py
CHANGED
@@ -8,6 +8,7 @@ from collections import defaultdict, deque
 from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
+from copy import deepcopy
 from dataclasses import field, replace
 from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast

@@ -186,9 +187,8 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
 messages = ctx_messages.messages
 ctx_messages.used = True

-…
-…
-messages.extend(message_history)
+# Replace the `capture_run_messages` list with the message history
+messages[:] = _clean_message_history(ctx.state.message_history)
 # Use the `capture_run_messages` list as the message history so that new messages are added to it
 ctx.state.message_history = messages
 ctx.deps.new_message_index = len(messages)

@@ -455,7 +455,18 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
 # This will raise errors for any tool name conflicts
 ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)

-…
+original_history = ctx.state.message_history[:]
+message_history = await _process_message_history(original_history, ctx.deps.history_processors, run_context)
+# Never merge the new `ModelRequest` with the one preceding it, to keep `new_messages()` from accidentally including part of the existing message history
+message_history = [*_clean_message_history(message_history[:-1]), message_history[-1]]
+# `ctx.state.message_history` is the same list used by `capture_run_messages`, so we should replace its contents, not the reference
+ctx.state.message_history[:] = message_history
+# Update the new message index to ensure `result.new_messages()` returns the correct messages
+ctx.deps.new_message_index -= len(original_history) - len(message_history)
+
+# Do one more cleaning pass to merge possible consecutive trailing `ModelRequest`s into one, with tool call parts before user parts,
+# but don't store it in the message history on state.
+# See `tests/test_tools.py::test_parallel_tool_return_with_deferred` for an example where this is necessary
 message_history = _clean_message_history(message_history)

 model_request_parameters = await _prepare_request_parameters(ctx)

@@ -465,7 +476,7 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
 usage = ctx.state.usage
 if ctx.deps.usage_limits.count_tokens_before_request:
 # Copy to avoid modifying the original usage object with the counted usage
-usage =
+usage = deepcopy(usage)

 counted_usage = await ctx.deps.model.count_tokens(message_history, model_settings, model_request_parameters)
 usage.incr(counted_usage)

@@ -534,21 +545,22 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
 # Ensure that the stream is only run once

 async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa: C901
-…
+text = ''
 tool_calls: list[_messages.ToolCallPart] = []
 thinking_parts: list[_messages.ThinkingPart] = []

 for part in self.model_response.parts:
 if isinstance(part, _messages.TextPart):
-…
-if part.content:
-texts.append(part.content)
+text += part.content
 elif isinstance(part, _messages.ToolCallPart):
 tool_calls.append(part)
 elif isinstance(part, _messages.BuiltinToolCallPart):
-…
+# Text parts before a built-in tool call are essentially thoughts,
+# not part of the final result output, so we reset the accumulated text
+text = ''
+yield _messages.BuiltinToolCallEvent(part) # pyright: ignore[reportDeprecated]
 elif isinstance(part, _messages.BuiltinToolReturnPart):
-yield _messages.BuiltinToolResultEvent(part)
+yield _messages.BuiltinToolResultEvent(part) # pyright: ignore[reportDeprecated]
 elif isinstance(part, _messages.ThinkingPart):
 thinking_parts.append(part)
 else:

@@ -561,9 +573,9 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
 if tool_calls:
 async for event in self._handle_tool_calls(ctx, tool_calls):
 yield event
-elif
+elif text:
 # No events are emitted during the handling of text responses, so we don't need to yield anything
-self._next_node = await self._handle_text_response(ctx,
+self._next_node = await self._handle_text_response(ctx, text)
 elif thinking_parts:
 # handle thinking-only responses (responses that contain only ThinkingPart instances)
 # this can happen with models that support thinking mode when they don't provide

@@ -582,9 +594,16 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
 if isinstance(ctx.deps.output_schema, _output.TextOutputSchema):
 for message in reversed(ctx.state.message_history):
 if isinstance(message, _messages.ModelResponse):
-…
-…
-…
+text = ''
+for part in message.parts:
+if isinstance(part, _messages.TextPart):
+text += part.content
+elif isinstance(part, _messages.BuiltinToolCallPart):
+# Text parts before a built-in tool call are essentially thoughts,
+# not part of the final result output, so we reset the accumulated text
+text = '' # pragma: no cover
+if text:
+self._next_node = await self._handle_text_response(ctx, text)
 return

 raise exceptions.UnexpectedModelBehavior('Received empty model response')

@@ -644,11 +663,9 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
 async def _handle_text_response(
 self,
 ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
-…
+text: str,
 ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
 output_schema = ctx.deps.output_schema
-…
-text = '\n\n'.join(texts)
 try:
 run_context = build_run_context(ctx)
 if isinstance(output_schema, _output.TextOutputSchema):

@@ -1086,12 +1103,11 @@ def build_agent_graph(


 async def _process_message_history(
-…
+messages: list[_messages.ModelMessage],
 processors: Sequence[HistoryProcessor[DepsT]],
 run_context: RunContext[DepsT],
 ) -> list[_messages.ModelMessage]:
 """Process message history through a sequence of processors."""
-messages = state.message_history
 for processor in processors:
 takes_ctx = is_takes_ctx(processor)

@@ -1109,8 +1125,12 @@ async def _process_message_history(
 sync_processor = cast(_HistoryProcessorSync, processor)
 messages = await run_in_executor(sync_processor, messages)

-…
-…
+if len(messages) == 0:
+raise exceptions.UserError('Processed history cannot be empty.')
+
+if not isinstance(messages[-1], _messages.ModelRequest):
+raise exceptions.UserError('Processed history must end with a `ModelRequest`.')
+
 return messages
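With the new checks in `_process_message_history`, a history processor must now return a non-empty list that ends with a `ModelRequest`. As a minimal sketch of a processor that respects that contract (the model name and the trimming rule are illustrative, not part of this release):

    from pydantic_ai import Agent
    from pydantic_ai.messages import ModelMessage, ModelRequest

    def keep_recent(messages: list[ModelMessage]) -> list[ModelMessage]:
        # Keep only the most recent messages, but never return an empty list
        # and always end with a ModelRequest, as the new validation requires.
        trimmed = messages[-5:]
        if not trimmed or not isinstance(trimmed[-1], ModelRequest):
            return messages  # fall back to the full history rather than break the contract
        return trimmed

    agent = Agent('openai:gpt-4o', history_processors=[keep_recent])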
pydantic_ai/_cli.py
CHANGED
@@ -356,7 +356,7 @@ def handle_slash_command(
 except IndexError:
 console.print('[dim]No output available to copy.[/dim]')
 else:
-text_to_copy = '
+text_to_copy = ''.join(part.content for part in parts if isinstance(part, TextPart))
 text_to_copy = text_to_copy.strip()
 if text_to_copy:
 pyperclip.copy(text_to_copy)
pydantic_ai/_otel_messages.py
CHANGED
@@ -21,6 +21,7 @@ class ToolCallPart(TypedDict):
 id: str
 name: str
 arguments: NotRequired[JsonValue]
+builtin: NotRequired[bool] # Not (currently?) part of the spec, used by Logfire


 class ToolCallResponsePart(TypedDict):

@@ -28,6 +29,7 @@ class ToolCallResponsePart(TypedDict):
 id: str
 name: str
 result: NotRequired[JsonValue]
+builtin: NotRequired[bool] # Not (currently?) part of the spec, used by Logfire


 class MediaUrlPart(TypedDict):
pydantic_ai/_parts_manager.py
CHANGED
@@ -19,6 +19,8 @@ from typing import Any

 from pydantic_ai.exceptions import UnexpectedModelBehavior
 from pydantic_ai.messages import (
+BuiltinToolCallPart,
+BuiltinToolReturnPart,
 ModelResponsePart,
 ModelResponseStreamEvent,
 PartDeltaEvent,

@@ -226,11 +228,11 @@ class ModelResponsePartsManager:
 self,
 *,
 vendor_part_id: Hashable | None,
-tool_name: str | None,
-args: str | dict[str, Any] | None,
-tool_call_id: str | None,
+tool_name: str | None = None,
+args: str | dict[str, Any] | None = None,
+tool_call_id: str | None = None,
 ) -> ModelResponseStreamEvent | None:
-"""Handle or update a tool call, creating or updating a `ToolCallPart` or `ToolCallPartDelta`.
+"""Handle or update a tool call, creating or updating a `ToolCallPart`, `BuiltinToolCallPart`, or `ToolCallPartDelta`.

 Managed items remain as `ToolCallPartDelta`s until they have at least a tool_name, at which
 point they are upgraded to `ToolCallPart`s.

@@ -247,15 +249,17 @@ class ModelResponsePartsManager:
 tool_call_id: An optional string representing an identifier for this tool call.

 Returns:
-- A `PartStartEvent` if a new ToolCallPart is created.
+- A `PartStartEvent` if a new ToolCallPart or BuiltinToolCallPart is created.
 - A `PartDeltaEvent` if an existing part is updated.
 - `None` if no new event is emitted (e.g., the part is still incomplete).

 Raises:
 UnexpectedModelBehavior: If attempting to apply a tool call delta to a part that is not
-a ToolCallPart or ToolCallPartDelta.
+a ToolCallPart, BuiltinToolCallPart, or ToolCallPartDelta.
 """
-existing_matching_part_and_index: tuple[ToolCallPartDelta | ToolCallPart, int] | None =
+existing_matching_part_and_index: tuple[ToolCallPartDelta | ToolCallPart | BuiltinToolCallPart, int] | None = (
+None
+)

 if vendor_part_id is None:
 # vendor_part_id is None, so check if the latest part is a matching tool call or delta to update

@@ -264,14 +268,14 @@ class ModelResponsePartsManager:
 if tool_name is None and self._parts:
 part_index = len(self._parts) - 1
 latest_part = self._parts[part_index]
-if isinstance(latest_part, ToolCallPart | ToolCallPartDelta): # pragma: no branch
+if isinstance(latest_part, ToolCallPart | BuiltinToolCallPart | ToolCallPartDelta): # pragma: no branch
 existing_matching_part_and_index = latest_part, part_index
 else:
 # vendor_part_id is provided, so look up the corresponding part or delta
 part_index = self._vendor_id_to_part_index.get(vendor_part_id)
 if part_index is not None:
 existing_part = self._parts[part_index]
-if not isinstance(existing_part, ToolCallPartDelta | ToolCallPart):
+if not isinstance(existing_part, ToolCallPartDelta | ToolCallPart | BuiltinToolCallPart):
 raise UnexpectedModelBehavior(f'Cannot apply a tool call delta to {existing_part=}')
 existing_matching_part_and_index = existing_part, part_index

@@ -284,7 +288,7 @@ class ModelResponsePartsManager:
 new_part_index = len(self._parts)
 self._parts.append(part)
 # Only emit a PartStartEvent if we have enough information to produce a full ToolCallPart
-if isinstance(part, ToolCallPart):
+if isinstance(part, ToolCallPart | BuiltinToolCallPart):
 return PartStartEvent(index=new_part_index, part=part)
 else:
 # Update the existing part or delta with the new information

@@ -292,7 +296,7 @@ class ModelResponsePartsManager:
 delta = ToolCallPartDelta(tool_name_delta=tool_name, args_delta=args, tool_call_id=tool_call_id)
 updated_part = delta.apply(existing_part)
 self._parts[part_index] = updated_part
-if isinstance(updated_part, ToolCallPart):
+if isinstance(updated_part, ToolCallPart | BuiltinToolCallPart):
 if isinstance(existing_part, ToolCallPartDelta):
 # We just upgraded a delta to a full part, so emit a PartStartEvent
 return PartStartEvent(index=part_index, part=updated_part)

@@ -337,7 +341,7 @@ class ModelResponsePartsManager:
 else:
 # vendor_part_id is provided, so find and overwrite or create a new ToolCallPart.
 maybe_part_index = self._vendor_id_to_part_index.get(vendor_part_id)
-if maybe_part_index is not None:
+if maybe_part_index is not None and isinstance(self._parts[maybe_part_index], ToolCallPart):
 new_part_index = maybe_part_index
 self._parts[new_part_index] = new_part
 else:

@@ -345,3 +349,69 @@ class ModelResponsePartsManager:
 self._parts.append(new_part)
 self._vendor_id_to_part_index[vendor_part_id] = new_part_index
 return PartStartEvent(index=new_part_index, part=new_part)
+
+def handle_builtin_tool_call_part(
+self,
+*,
+vendor_part_id: Hashable | None,
+part: BuiltinToolCallPart,
+) -> ModelResponseStreamEvent:
+"""Create or overwrite a BuiltinToolCallPart.
+
+Args:
+vendor_part_id: The vendor's ID for this tool call part. If not
+None and an existing part is found, that part is overwritten.
+part: The BuiltinToolCallPart.
+
+Returns:
+ModelResponseStreamEvent: A `PartStartEvent` indicating that a new tool call part
+has been added to the manager, or replaced an existing part.
+"""
+if vendor_part_id is None:
+# vendor_part_id is None, so we unconditionally append a new BuiltinToolCallPart to the end of the list
+new_part_index = len(self._parts)
+self._parts.append(part)
+else:
+# vendor_part_id is provided, so find and overwrite or create a new BuiltinToolCallPart.
+maybe_part_index = self._vendor_id_to_part_index.get(vendor_part_id)
+if maybe_part_index is not None and isinstance(self._parts[maybe_part_index], BuiltinToolCallPart):
+new_part_index = maybe_part_index
+self._parts[new_part_index] = part
+else:
+new_part_index = len(self._parts)
+self._parts.append(part)
+self._vendor_id_to_part_index[vendor_part_id] = new_part_index
+return PartStartEvent(index=new_part_index, part=part)
+
+def handle_builtin_tool_return_part(
+self,
+*,
+vendor_part_id: Hashable | None,
+part: BuiltinToolReturnPart,
+) -> ModelResponseStreamEvent:
+"""Create or overwrite a BuiltinToolReturnPart.
+
+Args:
+vendor_part_id: The vendor's ID for this tool call part. If not
+None and an existing part is found, that part is overwritten.
+part: The BuiltinToolReturnPart.
+
+Returns:
+ModelResponseStreamEvent: A `PartStartEvent` indicating that a new tool call part
+has been added to the manager, or replaced an existing part.
+"""
+if vendor_part_id is None:
+# vendor_part_id is None, so we unconditionally append a new BuiltinToolReturnPart to the end of the list
+new_part_index = len(self._parts)
+self._parts.append(part)
+else:
+# vendor_part_id is provided, so find and overwrite or create a new BuiltinToolReturnPart.
+maybe_part_index = self._vendor_id_to_part_index.get(vendor_part_id)
+if maybe_part_index is not None and isinstance(self._parts[maybe_part_index], BuiltinToolReturnPart):
+new_part_index = maybe_part_index
+self._parts[new_part_index] = part
+else:
+new_part_index = len(self._parts)
+self._parts.append(part)
+self._vendor_id_to_part_index[vendor_part_id] = new_part_index
+return PartStartEvent(index=new_part_index, part=part)
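As context for the two new methods, here is a hedged sketch of how a streaming model implementation might feed built-in tool parts through the manager (this is internal API, and the part field values are illustrative):

    from pydantic_ai._parts_manager import ModelResponsePartsManager
    from pydantic_ai.messages import BuiltinToolCallPart, BuiltinToolReturnPart

    manager = ModelResponsePartsManager()
    # Each call returns a PartStartEvent for the new (or replaced) part.
    call_event = manager.handle_builtin_tool_call_part(
        vendor_part_id='ws_1',
        part=BuiltinToolCallPart(tool_name='web_search', args={'query': 'pydantic'}, tool_call_id='ws_1'),
    )
    return_event = manager.handle_builtin_tool_return_part(
        vendor_part_id='ws_1_return',
        part=BuiltinToolReturnPart(tool_name='web_search', content={'results': []}, tool_call_id='ws_1'),
    )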
pydantic_ai/_run_context.py
CHANGED
@@ -43,10 +43,17 @@ class RunContext(Generic[AgentDepsT]):
 tool_name: str | None = None
 """Name of the tool being called."""
 retry: int = 0
-"""Number of retries so far."""
+"""Number of retries of this tool so far."""
+max_retries: int = 0
+"""The maximum number of retries of this tool."""
 run_step: int = 0
 """The current step in the run."""
 tool_call_approved: bool = False
 """Whether a tool call that required approval has now been approved."""

+@property
+def last_attempt(self) -> bool:
+"""Whether this is the last attempt at running this tool before an error is raised."""
+return self.retry == self.max_retries
+
 __repr__ = _utils.dataclasses_no_defaults_repr
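With `max_retries` now carried on `RunContext`, the new `last_attempt` property lets a tool behave differently on its final try. A minimal sketch (the model string and the `lookup_price` helper are hypothetical):

    from pydantic_ai import Agent, ModelRetry, RunContext

    agent = Agent('openai:gpt-4o', retries=2)

    @agent.tool
    def fetch_price(ctx: RunContext[None], symbol: str) -> str:
        price = lookup_price(symbol)  # hypothetical lookup helper
        if price is None:
            if ctx.last_attempt:
                # Return a graceful fallback instead of raising on the final attempt.
                return 'price unavailable'
            raise ModelRetry(f'No price found for {symbol!r}, try a normalized ticker symbol.')
        return str(price)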
pydantic_ai/_tool_manager.py
CHANGED
pydantic_ai/ag_ui.py
CHANGED
@@ -9,7 +9,7 @@ from __future__ import annotations
 import json
 import uuid
 from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Mapping, Sequence
-from dataclasses import Field, dataclass, replace
+from dataclasses import Field, dataclass, field, replace
 from http import HTTPStatus
 from typing import (
 Any,

@@ -23,13 +23,15 @@ from typing import (
 )

 from pydantic import BaseModel, ValidationError
-from typing_extensions import assert_never

 from . import _utils
 from ._agent_graph import CallToolsNode, ModelRequestNode
 from .agent import AbstractAgent, AgentRun, AgentRunResult
 from .exceptions import UserError
 from .messages import (
+BaseToolCallPart,
+BuiltinToolCallPart,
+BuiltinToolReturnPart,
 FunctionToolResultEvent,
 ModelMessage,
 ModelRequest,

@@ -123,6 +125,8 @@ SSE_CONTENT_TYPE: Final[str] = 'text/event-stream'
 OnCompleteFunc: TypeAlias = Callable[[AgentRunResult[Any]], None] | Callable[[AgentRunResult[Any]], Awaitable[None]]
 """Callback function type that receives the `AgentRunResult` of the completed run. Can be sync or async."""

+_BUILTIN_TOOL_CALL_ID_PREFIX: Final[str] = 'pyd_ai_builtin'
+

 class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette):
 """ASGI application for running Pydantic AI agents with AG-UI protocol support."""

@@ -487,20 +491,37 @@ async def _handle_model_request_event( # noqa: C901
 stream_ctx.part_end = TextMessageEndEvent(
 message_id=message_id,
 )
-elif isinstance(part,
+elif isinstance(part, BaseToolCallPart):
+tool_call_id = part.tool_call_id
+if isinstance(part, BuiltinToolCallPart):
+builtin_tool_call_id = '|'.join(
+[_BUILTIN_TOOL_CALL_ID_PREFIX, part.provider_name or '', tool_call_id]
+)
+stream_ctx.builtin_tool_call_ids[tool_call_id] = builtin_tool_call_id
+tool_call_id = builtin_tool_call_id
+
 message_id = stream_ctx.message_id or stream_ctx.new_message_id()
 yield ToolCallStartEvent(
-tool_call_id=
+tool_call_id=tool_call_id,
 tool_call_name=part.tool_name,
 parent_message_id=message_id,
 )
 if part.args:
 yield ToolCallArgsEvent(
-tool_call_id=
-delta=part.
+tool_call_id=tool_call_id,
+delta=part.args_as_json_str(),
 )
 stream_ctx.part_end = ToolCallEndEvent(
-tool_call_id=
+tool_call_id=tool_call_id,
+)
+elif isinstance(part, BuiltinToolReturnPart): # pragma: no branch
+tool_call_id = stream_ctx.builtin_tool_call_ids[part.tool_call_id]
+yield ToolCallResultEvent(
+message_id=stream_ctx.new_message_id(),
+type=EventType.TOOL_CALL_RESULT,
+role='tool',
+tool_call_id=tool_call_id,
+content=part.model_response_str(),
 )

 elif isinstance(agent_event, PartDeltaEvent):

@@ -512,9 +533,12 @@ async def _handle_model_request_event( # noqa: C901
 delta=delta.content_delta,
 )
 elif isinstance(delta, ToolCallPartDelta): # pragma: no branch
-…
+tool_call_id = delta.tool_call_id
+assert tool_call_id, '`ToolCallPartDelta.tool_call_id` must be set'
+if tool_call_id in stream_ctx.builtin_tool_call_ids:
+tool_call_id = stream_ctx.builtin_tool_call_ids[tool_call_id]
 yield ToolCallArgsEvent(
-tool_call_id=
+tool_call_id=tool_call_id,
 delta=delta.args_delta if isinstance(delta.args_delta, str) else json.dumps(delta.args_delta),
 )
 elif isinstance(delta, ThinkingPartDelta): # pragma: no branch

@@ -550,24 +574,23 @@ async def _handle_tool_result_event(
 if not isinstance(result, ToolReturnPart):
 return

-message_id = stream_ctx.new_message_id()
 yield ToolCallResultEvent(
-message_id=
+message_id=stream_ctx.new_message_id(),
 type=EventType.TOOL_CALL_RESULT,
 role='tool',
 tool_call_id=result.tool_call_id,
 content=result.model_response_str(),
 )

-# Now check for
-…
-if isinstance(
-yield
-elif isinstance(
+# Now check for AG-UI events returned by the tool calls.
+possible_event = result.metadata or result.content
+if isinstance(possible_event, BaseEvent):
+yield possible_event
+elif isinstance(possible_event, str | bytes): # pragma: no branch
 # Avoid iterable check for strings and bytes.
 pass
-elif isinstance(
-for item in
+elif isinstance(possible_event, Iterable): # pragma: no branch
+for item in possible_event: # type: ignore[reportUnknownMemberType]
 if isinstance(item, BaseEvent): # pragma: no branch
 yield item

@@ -579,7 +602,9 @@ def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]:
 request_parts: list[ModelRequestPart] | None = None
 response_parts: list[ModelResponsePart] | None = None
 for msg in messages:
-if isinstance(msg, UserMessage | SystemMessage | DeveloperMessage
+if isinstance(msg, UserMessage | SystemMessage | DeveloperMessage) or (
+isinstance(msg, ToolMessage) and not msg.tool_call_id.startswith(_BUILTIN_TOOL_CALL_ID_PREFIX)
+):
 if request_parts is None:
 request_parts = []
 result.append(ModelRequest(parts=request_parts))

@@ -589,44 +614,71 @@ def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]:
 request_parts.append(UserPromptPart(content=msg.content))
 elif isinstance(msg, SystemMessage | DeveloperMessage):
 request_parts.append(SystemPromptPart(content=msg.content))
-…
-…
+else:
+tool_call_id = msg.tool_call_id
+tool_name = tool_calls.get(tool_call_id)
 if tool_name is None: # pragma: no cover
-raise _ToolCallNotFoundError(tool_call_id=
+raise _ToolCallNotFoundError(tool_call_id=tool_call_id)

 request_parts.append(
 ToolReturnPart(
 tool_name=tool_name,
 content=msg.content,
-tool_call_id=
+tool_call_id=tool_call_id,
 )
 )
-else:
-assert_never(msg)

-elif isinstance(msg, AssistantMessage):
+elif isinstance(msg, AssistantMessage) or ( # pragma: no branch
+isinstance(msg, ToolMessage) and msg.tool_call_id.startswith(_BUILTIN_TOOL_CALL_ID_PREFIX)
+):
 if response_parts is None:
 response_parts = []
 result.append(ModelResponse(parts=response_parts))
 request_parts = None

-if msg
-…
-…
-…
-…
-…
+if isinstance(msg, AssistantMessage):
+if msg.content:
+response_parts.append(TextPart(content=msg.content))
+
+if msg.tool_calls:
+for tool_call in msg.tool_calls:
+tool_call_id = tool_call.id
+tool_name = tool_call.function.name
+tool_calls[tool_call_id] = tool_name
+
+if tool_call_id.startswith(_BUILTIN_TOOL_CALL_ID_PREFIX):
+_, provider_name, tool_call_id = tool_call_id.split('|', 2)
+response_parts.append(
+BuiltinToolCallPart(
+tool_name=tool_name,
+args=tool_call.function.arguments,
+tool_call_id=tool_call_id,
+provider_name=provider_name,
+)
+)
+else:
+response_parts.append(
+ToolCallPart(
+tool_name=tool_name,
+tool_call_id=tool_call_id,
+args=tool_call.function.arguments,
+)
+)
+else:
+tool_call_id = msg.tool_call_id
+tool_name = tool_calls.get(tool_call_id)
+if tool_name is None: # pragma: no cover
+raise _ToolCallNotFoundError(tool_call_id=tool_call_id)
+_, provider_name, tool_call_id = tool_call_id.split('|', 2)

-response_parts.
-…
-tool_name=
-…
-…
+response_parts.append(
+BuiltinToolReturnPart(
+tool_name=tool_name,
+content=msg.content,
+tool_call_id=tool_call_id,
+provider_name=provider_name,
 )
-for tool_call in msg.tool_calls
 )
-else:
-assert_never(msg)

 return result

@@ -687,6 +739,7 @@ class _RequestStreamContext:
 message_id: str = ''
 part_end: BaseEvent | None = None
 thinking: bool = False
+builtin_tool_call_ids: dict[str, str] = field(default_factory=dict)

 def new_message_id(self) -> str:
 """Generate a new message ID for the request stream.
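The AG-UI changes round-trip built-in tool calls by encoding the provider name into the tool call ID with the `pyd_ai_builtin` prefix. A small sketch of the scheme implied by the diff (the helper functions are illustrative, not part of the module):

    _BUILTIN_TOOL_CALL_ID_PREFIX = 'pyd_ai_builtin'

    def encode_builtin_tool_call_id(provider_name: str | None, tool_call_id: str) -> str:
        # Mirrors the '|'.join([...]) used when emitting ToolCallStartEvent above.
        return '|'.join([_BUILTIN_TOOL_CALL_ID_PREFIX, provider_name or '', tool_call_id])

    def decode_builtin_tool_call_id(encoded: str) -> tuple[str, str]:
        # Mirrors the split('|', 2) used when rebuilding BuiltinToolCallPart/BuiltinToolReturnPart.
        _, provider_name, tool_call_id = encoded.split('|', 2)
        return provider_name, tool_call_id

    assert decode_builtin_tool_call_id(encode_builtin_tool_call_id('openai', 'call_123')) == ('openai', 'call_123')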
pydantic_ai/agent/__init__.py
CHANGED
@@ -614,12 +614,10 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
 instrumentation_settings = None
 tracer = NoOpTracer()

-graph_deps = _agent_graph.GraphAgentDeps[
-AgentDepsT, RunOutputDataT
-](
+graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
 user_deps=deps,
 prompt=user_prompt,
-new_message_index=
+new_message_index=len(message_history) if message_history else 0,
 model=model_used,
 model_settings=model_settings,
 usage_limits=usage_limits,
pydantic_ai/builtin_tools.py
CHANGED
@@ -18,6 +18,9 @@ class AbstractBuiltinTool(ABC):
 The builtin tools are passed to the model as part of the `ModelRequestParameters`.
 """

+kind: str = 'unknown_builtin_tool'
+"""Built-in tool identifier, this should be available on all built-in tools as a discriminator."""
+

 @dataclass(kw_only=True)
 class WebSearchTool(AbstractBuiltinTool):

@@ -80,6 +83,9 @@ class WebSearchTool(AbstractBuiltinTool):
 * Anthropic
 """

+kind: str = 'web_search'
+"""The kind of tool."""
+

 class WebSearchUserLocation(TypedDict, total=False):
 """Allows you to localize search results based on a user's location.

@@ -113,6 +119,9 @@ class CodeExecutionTool(AbstractBuiltinTool):
 * Google
 """

+kind: str = 'code_execution'
+"""The kind of tool."""
+

 class UrlContextTool(AbstractBuiltinTool):
 """Allows your agent to access contents from URLs.

@@ -121,3 +130,6 @@ class UrlContextTool(AbstractBuiltinTool):

 * Google
 """
+
+kind: str = 'url_context'
+"""The kind of tool."""
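The new `kind` field gives every built-in tool a stable discriminator. A quick sketch of what that looks like in practice:

    from pydantic_ai.builtin_tools import CodeExecutionTool, UrlContextTool, WebSearchTool

    for tool in (WebSearchTool(), CodeExecutionTool(), UrlContextTool()):
        print(tool.kind)
    # web_search
    # code_execution
    # url_context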