pydantic-ai-slim 0.0.29__py3-none-any.whl → 0.0.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/__init__.py +2 -2
- pydantic_ai/_agent_graph.py +97 -44
- pydantic_ai/_result.py +3 -3
- pydantic_ai/_utils.py +2 -0
- pydantic_ai/agent.py +95 -67
- pydantic_ai/messages.py +71 -1
- pydantic_ai/models/__init__.py +4 -0
- pydantic_ai/models/function.py +15 -4
- pydantic_ai/models/instrumented.py +70 -78
- pydantic_ai/result.py +125 -1
- pydantic_ai/usage.py +10 -0
- {pydantic_ai_slim-0.0.29.dist-info → pydantic_ai_slim-0.0.31.dist-info}/METADATA +4 -3
- {pydantic_ai_slim-0.0.29.dist-info → pydantic_ai_slim-0.0.31.dist-info}/RECORD +14 -14
- {pydantic_ai_slim-0.0.29.dist-info → pydantic_ai_slim-0.0.31.dist-info}/WHEEL +0 -0
pydantic_ai/__init__.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from importlib.metadata import version
|
|
2
2
|
|
|
3
|
-
from .agent import Agent, EndStrategy, HandleResponseNode, ModelRequestNode, UserPromptNode, capture_run_messages
|
|
3
|
+
from .agent import Agent, CallToolsNode, EndStrategy, ModelRequestNode, UserPromptNode, capture_run_messages
|
|
4
4
|
from .exceptions import (
|
|
5
5
|
AgentRunError,
|
|
6
6
|
FallbackExceptionGroup,
|
|
@@ -18,7 +18,7 @@ __all__ = (
|
|
|
18
18
|
# agent
|
|
19
19
|
'Agent',
|
|
20
20
|
'EndStrategy',
|
|
21
|
-
'HandleResponseNode',
|
|
21
|
+
'CallToolsNode',
|
|
22
22
|
'ModelRequestNode',
|
|
23
23
|
'UserPromptNode',
|
|
24
24
|
'capture_run_messages',
|
pydantic_ai/_agent_graph.py
CHANGED
|
@@ -2,7 +2,6 @@ from __future__ import annotations as _annotations
|
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
4
|
import dataclasses
|
|
5
|
-
from abc import ABC
|
|
6
5
|
from collections.abc import AsyncIterator, Iterator, Sequence
|
|
7
6
|
from contextlib import asynccontextmanager, contextmanager
|
|
8
7
|
from contextvars import ContextVar
|
|
@@ -10,7 +9,7 @@ from dataclasses import field
|
|
|
10
9
|
from typing import Any, Generic, Literal, Union, cast
|
|
11
10
|
|
|
12
11
|
import logfire_api
|
|
13
|
-
from typing_extensions import TypeVar, assert_never
|
|
12
|
+
from typing_extensions import TypeGuard, TypeVar, assert_never
|
|
14
13
|
|
|
15
14
|
from pydantic_graph import BaseNode, Graph, GraphRunContext
|
|
16
15
|
from pydantic_graph.nodes import End, NodeRunEndT
|
|
@@ -24,6 +23,7 @@ from . import (
|
|
|
24
23
|
result,
|
|
25
24
|
usage as _usage,
|
|
26
25
|
)
|
|
26
|
+
from .models.instrumented import InstrumentedModel
|
|
27
27
|
from .result import ResultDataT
|
|
28
28
|
from .settings import ModelSettings, merge_model_settings
|
|
29
29
|
from .tools import (
|
|
@@ -37,7 +37,7 @@ __all__ = (
|
|
|
37
37
|
'GraphAgentDeps',
|
|
38
38
|
'UserPromptNode',
|
|
39
39
|
'ModelRequestNode',
|
|
40
|
-
'HandleResponseNode',
|
|
40
|
+
'CallToolsNode',
|
|
41
41
|
'build_run_context',
|
|
42
42
|
'capture_run_messages',
|
|
43
43
|
)
|
|
@@ -55,6 +55,7 @@ else:
|
|
|
55
55
|
logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
|
|
56
56
|
|
|
57
57
|
T = TypeVar('T')
|
|
58
|
+
S = TypeVar('S')
|
|
58
59
|
NoneType = type(None)
|
|
59
60
|
EndStrategy = Literal['early', 'exhaustive']
|
|
60
61
|
"""The strategy for handling multiple tool calls when a final result is found.
|
|
@@ -107,8 +108,31 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
|
|
|
107
108
|
run_span: logfire_api.LogfireSpan
|
|
108
109
|
|
|
109
110
|
|
|
111
|
+
class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
|
|
112
|
+
"""The base class for all agent nodes.
|
|
113
|
+
|
|
114
|
+
Using subclass of `BaseNode` for all nodes reduces the amount of boilerplate of generics everywhere
|
|
115
|
+
"""
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def is_agent_node(
|
|
119
|
+
node: BaseNode[GraphAgentState, GraphAgentDeps[T, Any], result.FinalResult[S]] | End[result.FinalResult[S]],
|
|
120
|
+
) -> TypeGuard[AgentNode[T, S]]:
|
|
121
|
+
"""Check if the provided node is an instance of `AgentNode`.
|
|
122
|
+
|
|
123
|
+
Usage:
|
|
124
|
+
|
|
125
|
+
if is_agent_node(node):
|
|
126
|
+
# `node` is an AgentNode
|
|
127
|
+
...
|
|
128
|
+
|
|
129
|
+
This method preserves the generic parameters on the narrowed type, unlike `isinstance(node, AgentNode)`.
|
|
130
|
+
"""
|
|
131
|
+
return isinstance(node, AgentNode)
|
|
132
|
+
|
|
133
|
+
|
|
110
134
|
@dataclasses.dataclass
|
|
111
|
-
class UserPromptNode(
|
|
135
|
+
class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
|
|
112
136
|
user_prompt: str | Sequence[_messages.UserContent]
|
|
113
137
|
|
|
114
138
|
system_prompts: tuple[str, ...]
|
|
@@ -215,17 +239,17 @@ async def _prepare_request_parameters(
|
|
|
215
239
|
|
|
216
240
|
|
|
217
241
|
@dataclasses.dataclass
|
|
218
|
-
class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
|
|
242
|
+
class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
|
|
219
243
|
"""Make a request to the model using the last message in state.message_history."""
|
|
220
244
|
|
|
221
245
|
request: _messages.ModelRequest
|
|
222
246
|
|
|
223
|
-
_result: HandleResponseNode[DepsT, NodeRunEndT] | None = field(default=None, repr=False)
|
|
247
|
+
_result: CallToolsNode[DepsT, NodeRunEndT] | None = field(default=None, repr=False)
|
|
224
248
|
_did_stream: bool = field(default=False, repr=False)
|
|
225
249
|
|
|
226
250
|
async def run(
|
|
227
251
|
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
|
|
228
|
-
) -> HandleResponseNode[DepsT, NodeRunEndT]:
|
|
252
|
+
) -> CallToolsNode[DepsT, NodeRunEndT]:
|
|
229
253
|
if self._result is not None:
|
|
230
254
|
return self._result
|
|
231
255
|
|
|
@@ -236,48 +260,60 @@ class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], res
|
|
|
236
260
|
|
|
237
261
|
return await self._make_request(ctx)
|
|
238
262
|
|
|
263
|
+
@asynccontextmanager
|
|
264
|
+
async def stream(
|
|
265
|
+
self,
|
|
266
|
+
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
|
|
267
|
+
) -> AsyncIterator[result.AgentStream[DepsT, T]]:
|
|
268
|
+
async with self._stream(ctx) as streamed_response:
|
|
269
|
+
agent_stream = result.AgentStream[DepsT, T](
|
|
270
|
+
streamed_response,
|
|
271
|
+
ctx.deps.result_schema,
|
|
272
|
+
ctx.deps.result_validators,
|
|
273
|
+
build_run_context(ctx),
|
|
274
|
+
ctx.deps.usage_limits,
|
|
275
|
+
)
|
|
276
|
+
yield agent_stream
|
|
277
|
+
# In case the user didn't manually consume the full stream, ensure it is fully consumed here,
|
|
278
|
+
# otherwise usage won't be properly counted:
|
|
279
|
+
async for _ in agent_stream:
|
|
280
|
+
pass
|
|
281
|
+
|
|
239
282
|
@asynccontextmanager
|
|
240
283
|
async def _stream(
|
|
241
284
|
self,
|
|
242
285
|
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
|
|
243
286
|
) -> AsyncIterator[models.StreamedResponse]:
|
|
244
|
-
# TODO: Consider changing this to return something more similar to a `StreamedRunResult`, then make it public
|
|
245
287
|
assert not self._did_stream, 'stream() should only be called once per node'
|
|
246
288
|
|
|
247
289
|
model_settings, model_request_parameters = await self._prepare_request(ctx)
|
|
248
|
-
with
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
request_usage = streamed_response.usage()
|
|
261
|
-
span.set_attribute('response', model_response)
|
|
262
|
-
span.set_attribute('usage', request_usage)
|
|
290
|
+
async with ctx.deps.model.request_stream(
|
|
291
|
+
ctx.state.message_history, model_settings, model_request_parameters
|
|
292
|
+
) as streamed_response:
|
|
293
|
+
self._did_stream = True
|
|
294
|
+
ctx.state.usage.incr(_usage.Usage(), requests=1)
|
|
295
|
+
yield streamed_response
|
|
296
|
+
# In case the user didn't manually consume the full stream, ensure it is fully consumed here,
|
|
297
|
+
# otherwise usage won't be properly counted:
|
|
298
|
+
async for _ in streamed_response:
|
|
299
|
+
pass
|
|
300
|
+
model_response = streamed_response.get()
|
|
301
|
+
request_usage = streamed_response.usage()
|
|
263
302
|
|
|
264
303
|
self._finish_handling(ctx, model_response, request_usage)
|
|
265
304
|
assert self._result is not None # this should be set by the previous line
|
|
266
305
|
|
|
267
306
|
async def _make_request(
|
|
268
307
|
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
|
|
269
|
-
) -> HandleResponseNode[DepsT, NodeRunEndT]:
|
|
308
|
+
) -> CallToolsNode[DepsT, NodeRunEndT]:
|
|
270
309
|
if self._result is not None:
|
|
271
310
|
return self._result
|
|
272
311
|
|
|
273
312
|
model_settings, model_request_parameters = await self._prepare_request(ctx)
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
ctx.state.usage.incr(_usage.Usage(), requests=1)
|
|
279
|
-
span.set_attribute('response', model_response)
|
|
280
|
-
span.set_attribute('usage', request_usage)
|
|
313
|
+
model_response, request_usage = await ctx.deps.model.request(
|
|
314
|
+
ctx.state.message_history, model_settings, model_request_parameters
|
|
315
|
+
)
|
|
316
|
+
ctx.state.usage.incr(_usage.Usage(), requests=1)
|
|
281
317
|
|
|
282
318
|
return self._finish_handling(ctx, model_response, request_usage)
|
|
283
319
|
|
|
@@ -303,7 +339,7 @@ class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], res
|
|
|
303
339
|
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
|
|
304
340
|
response: _messages.ModelResponse,
|
|
305
341
|
usage: _usage.Usage,
|
|
306
|
-
) -> HandleResponseNode[DepsT, NodeRunEndT]:
|
|
342
|
+
) -> CallToolsNode[DepsT, NodeRunEndT]:
|
|
307
343
|
# Update usage
|
|
308
344
|
ctx.state.usage.incr(usage, requests=0)
|
|
309
345
|
if ctx.deps.usage_limits:
|
|
@@ -313,13 +349,13 @@ class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], res
|
|
|
313
349
|
ctx.state.message_history.append(response)
|
|
314
350
|
|
|
315
351
|
# Set the `_result` attribute since we can't use `return` in an async iterator
|
|
316
|
-
self._result = HandleResponseNode(response)
|
|
352
|
+
self._result = CallToolsNode(response)
|
|
317
353
|
|
|
318
354
|
return self._result
|
|
319
355
|
|
|
320
356
|
|
|
321
357
|
@dataclasses.dataclass
|
|
322
|
-
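class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):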
|
|
358
|
+
class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
|
|
323
359
|
"""Process a model response, and decide whether to end the run or make a new request."""
|
|
324
360
|
|
|
325
361
|
model_response: _messages.ModelResponse
|
|
@@ -413,8 +449,7 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], r
|
|
|
413
449
|
final_result: result.FinalResult[NodeRunEndT] | None = None
|
|
414
450
|
parts: list[_messages.ModelRequestPart] = []
|
|
415
451
|
if result_schema is not None:
|
|
416
|
-
|
|
417
|
-
call, result_tool = match
|
|
452
|
+
for call, result_tool in result_schema.find_tool(tool_calls):
|
|
418
453
|
try:
|
|
419
454
|
result_data = result_tool.validate(call)
|
|
420
455
|
result_data = await _validate_result(result_data, ctx, call)
|
|
@@ -424,12 +459,17 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], r
|
|
|
424
459
|
ctx.state.increment_retries(ctx.deps.max_result_retries)
|
|
425
460
|
parts.append(e.tool_retry)
|
|
426
461
|
else:
|
|
427
|
-
final_result = result.FinalResult(result_data, call.tool_name)
|
|
462
|
+
final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id)
|
|
463
|
+
break
|
|
428
464
|
|
|
429
465
|
# Then build the other request parts based on end strategy
|
|
430
466
|
tool_responses: list[_messages.ModelRequestPart] = self._tool_responses
|
|
431
467
|
async for event in process_function_tools(
|
|
432
|
-
tool_calls,
|
|
468
|
+
tool_calls,
|
|
469
|
+
final_result and final_result.tool_name,
|
|
470
|
+
final_result and final_result.tool_call_id,
|
|
471
|
+
ctx,
|
|
472
|
+
tool_responses,
|
|
433
473
|
):
|
|
434
474
|
yield event
|
|
435
475
|
|
|
@@ -455,7 +495,10 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], r
|
|
|
455
495
|
messages.append(_messages.ModelRequest(parts=tool_responses))
|
|
456
496
|
|
|
457
497
|
run_span.set_attribute('usage', usage)
|
|
458
|
-
run_span.set_attribute(
|
|
498
|
+
run_span.set_attribute(
|
|
499
|
+
'all_messages_events',
|
|
500
|
+
[InstrumentedModel.event_to_dict(e) for e in InstrumentedModel.messages_to_otel_events(messages)],
|
|
501
|
+
)
|
|
459
502
|
|
|
460
503
|
# End the run with self.data
|
|
461
504
|
return End(final_result)
|
|
@@ -477,7 +520,7 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], r
|
|
|
477
520
|
return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
|
|
478
521
|
else:
|
|
479
522
|
# The following cast is safe because we know `str` is an allowed result type
|
|
480
|
-
return self._handle_final_result(ctx, result.FinalResult(result_data,
|
|
523
|
+
return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), [])
|
|
481
524
|
else:
|
|
482
525
|
ctx.state.increment_retries(ctx.deps.max_result_retries)
|
|
483
526
|
return ModelRequestNode[DepsT, NodeRunEndT](
|
|
@@ -506,6 +549,7 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
|
|
|
506
549
|
async def process_function_tools(
|
|
507
550
|
tool_calls: list[_messages.ToolCallPart],
|
|
508
551
|
result_tool_name: str | None,
|
|
552
|
+
result_tool_call_id: str | None,
|
|
509
553
|
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
|
|
510
554
|
output_parts: list[_messages.ModelRequestPart],
|
|
511
555
|
) -> AsyncIterator[_messages.HandleResponseEvent]:
|
|
@@ -525,7 +569,11 @@ async def process_function_tools(
|
|
|
525
569
|
calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = []
|
|
526
570
|
call_index_to_event_id: dict[int, str] = {}
|
|
527
571
|
for call in tool_calls:
|
|
528
|
-
if
|
|
572
|
+
if (
|
|
573
|
+
call.tool_name == result_tool_name
|
|
574
|
+
and call.tool_call_id == result_tool_call_id
|
|
575
|
+
and not found_used_result_tool
|
|
576
|
+
):
|
|
529
577
|
found_used_result_tool = True
|
|
530
578
|
output_parts.append(
|
|
531
579
|
_messages.ToolReturnPart(
|
|
@@ -552,9 +600,14 @@ async def process_function_tools(
|
|
|
552
600
|
# if tool_name is in _result_schema, it means we found a result tool but an error occurred in
|
|
553
601
|
# validation, we don't add another part here
|
|
554
602
|
if result_tool_name is not None:
|
|
603
|
+
if found_used_result_tool:
|
|
604
|
+
content = 'Result tool not used - a final result was already processed.'
|
|
605
|
+
else:
|
|
606
|
+
# TODO: Include information about the validation failure, and/or merge this with the ModelRetry part
|
|
607
|
+
content = 'Result tool not used - result failed validation.'
|
|
555
608
|
part = _messages.ToolReturnPart(
|
|
556
609
|
tool_name=call.tool_name,
|
|
557
|
-
content=
|
|
610
|
+
content=content,
|
|
558
611
|
tool_call_id=call.tool_call_id,
|
|
559
612
|
)
|
|
560
613
|
output_parts.append(part)
|
|
@@ -575,7 +628,7 @@ async def process_function_tools(
|
|
|
575
628
|
for task in done:
|
|
576
629
|
index = tasks.index(task)
|
|
577
630
|
result = task.result()
|
|
578
|
-
yield _messages.FunctionToolResultEvent(result,
|
|
631
|
+
yield _messages.FunctionToolResultEvent(result, tool_call_id=call_index_to_event_id[index])
|
|
579
632
|
if isinstance(result, (_messages.ToolReturnPart, _messages.RetryPromptPart)):
|
|
580
633
|
results_by_index[index] = result
|
|
581
634
|
else:
|
|
@@ -675,7 +728,7 @@ def build_agent_graph(
|
|
|
675
728
|
nodes = (
|
|
676
729
|
UserPromptNode[DepsT],
|
|
677
730
|
ModelRequestNode[DepsT],
|
|
678
|
-
|
|
731
|
+
CallToolsNode[DepsT],
|
|
679
732
|
)
|
|
680
733
|
graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[ResultT]](
|
|
681
734
|
nodes=nodes,
|
pydantic_ai/_result.py
CHANGED
|
@@ -3,7 +3,7 @@ from __future__ import annotations as _annotations
|
|
|
3
3
|
import inspect
|
|
4
4
|
import sys
|
|
5
5
|
import types
|
|
6
|
-
from collections.abc import Awaitable, Iterable
|
|
6
|
+
from collections.abc import Awaitable, Iterable, Iterator
|
|
7
7
|
from dataclasses import dataclass, field
|
|
8
8
|
from typing import Any, Callable, Generic, Literal, Union, cast, get_args, get_origin
|
|
9
9
|
|
|
@@ -127,12 +127,12 @@ class ResultSchema(Generic[ResultDataT]):
|
|
|
127
127
|
def find_tool(
|
|
128
128
|
self,
|
|
129
129
|
parts: Iterable[_messages.ModelResponsePart],
|
|
130
|
-
) -> tuple[_messages.ToolCallPart, ResultTool[ResultDataT]]
|
|
130
|
+
) -> Iterator[tuple[_messages.ToolCallPart, ResultTool[ResultDataT]]]:
|
|
131
131
|
"""Find a tool that matches one of the calls."""
|
|
132
132
|
for part in parts:
|
|
133
133
|
if isinstance(part, _messages.ToolCallPart):
|
|
134
134
|
if result := self.tools.get(part.tool_name):
|
|
135
|
-
|
|
135
|
+
yield part, result
|
|
136
136
|
|
|
137
137
|
def tool_names(self) -> list[str]:
|
|
138
138
|
"""Return the names of the tools."""
|
pydantic_ai/_utils.py
CHANGED
|
@@ -48,6 +48,8 @@ def check_object_json_schema(schema: JsonSchemaValue) -> ObjectJsonSchema:
|
|
|
48
48
|
|
|
49
49
|
if schema.get('type') == 'object':
|
|
50
50
|
return schema
|
|
51
|
+
elif schema.get('$ref') is not None:
|
|
52
|
+
return schema.get('$defs', {}).get(schema['$ref'][8:]) # This removes the initial "#/$defs/".
|
|
51
53
|
else:
|
|
52
54
|
raise UserError('Schema must be an object')
|
|
53
55
|
|