pydantic-ai-slim 0.0.29__py3-none-any.whl → 0.0.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/__init__.py +2 -2
- pydantic_ai/_agent_graph.py +97 -44
- pydantic_ai/_result.py +3 -3
- pydantic_ai/_utils.py +2 -0
- pydantic_ai/agent.py +95 -67
- pydantic_ai/messages.py +71 -1
- pydantic_ai/models/__init__.py +4 -0
- pydantic_ai/models/function.py +15 -4
- pydantic_ai/models/instrumented.py +70 -78
- pydantic_ai/result.py +125 -1
- pydantic_ai/usage.py +10 -0
- {pydantic_ai_slim-0.0.29.dist-info → pydantic_ai_slim-0.0.31.dist-info}/METADATA +4 -3
- {pydantic_ai_slim-0.0.29.dist-info → pydantic_ai_slim-0.0.31.dist-info}/RECORD +14 -14
- {pydantic_ai_slim-0.0.29.dist-info → pydantic_ai_slim-0.0.31.dist-info}/WHEEL +0 -0
pydantic_ai/agent.py
CHANGED
|
@@ -9,9 +9,9 @@ from types import FrameType
|
|
|
9
9
|
from typing import Any, Callable, Generic, cast, final, overload
|
|
10
10
|
|
|
11
11
|
import logfire_api
|
|
12
|
-
from typing_extensions import TypeVar, deprecated
|
|
12
|
+
from typing_extensions import TypeGuard, TypeVar, deprecated
|
|
13
13
|
|
|
14
|
-
from pydantic_graph import
|
|
14
|
+
from pydantic_graph import End, Graph, GraphRun, GraphRunContext
|
|
15
15
|
from pydantic_graph._utils import get_event_loop
|
|
16
16
|
|
|
17
17
|
from . import (
|
|
@@ -25,6 +25,7 @@ from . import (
|
|
|
25
25
|
result,
|
|
26
26
|
usage as _usage,
|
|
27
27
|
)
|
|
28
|
+
from .models.instrumented import InstrumentedModel
|
|
28
29
|
from .result import FinalResult, ResultDataT, StreamedRunResult
|
|
29
30
|
from .settings import ModelSettings, merge_model_settings
|
|
30
31
|
from .tools import (
|
|
@@ -42,18 +43,17 @@ from .tools import (
|
|
|
42
43
|
# Re-exporting like this improves auto-import behavior in PyCharm
|
|
43
44
|
capture_run_messages = _agent_graph.capture_run_messages
|
|
44
45
|
EndStrategy = _agent_graph.EndStrategy
|
|
45
|
-
|
|
46
|
+
CallToolsNode = _agent_graph.CallToolsNode
|
|
46
47
|
ModelRequestNode = _agent_graph.ModelRequestNode
|
|
47
48
|
UserPromptNode = _agent_graph.UserPromptNode
|
|
48
49
|
|
|
49
|
-
|
|
50
50
|
__all__ = (
|
|
51
51
|
'Agent',
|
|
52
52
|
'AgentRun',
|
|
53
53
|
'AgentRunResult',
|
|
54
54
|
'capture_run_messages',
|
|
55
55
|
'EndStrategy',
|
|
56
|
-
'
|
|
56
|
+
'CallToolsNode',
|
|
57
57
|
'ModelRequestNode',
|
|
58
58
|
'UserPromptNode',
|
|
59
59
|
)
|
|
@@ -71,6 +71,7 @@ else:
|
|
|
71
71
|
logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
|
|
72
72
|
|
|
73
73
|
T = TypeVar('T')
|
|
74
|
+
S = TypeVar('S')
|
|
74
75
|
NoneType = type(None)
|
|
75
76
|
RunResultDataT = TypeVar('RunResultDataT')
|
|
76
77
|
"""Type variable for the result data of a run where `result_type` was customized on the run call."""
|
|
@@ -294,7 +295,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
294
295
|
"""
|
|
295
296
|
if infer_name and self.name is None:
|
|
296
297
|
self._infer_name(inspect.currentframe())
|
|
297
|
-
with self.iter(
|
|
298
|
+
async with self.iter(
|
|
298
299
|
user_prompt=user_prompt,
|
|
299
300
|
result_type=result_type,
|
|
300
301
|
message_history=message_history,
|
|
@@ -310,8 +311,8 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
310
311
|
assert (final_result := agent_run.result) is not None, 'The graph run did not finish properly'
|
|
311
312
|
return final_result
|
|
312
313
|
|
|
313
|
-
@
|
|
314
|
-
def iter(
|
|
314
|
+
@asynccontextmanager
|
|
315
|
+
async def iter(
|
|
315
316
|
self,
|
|
316
317
|
user_prompt: str | Sequence[_messages.UserContent],
|
|
317
318
|
*,
|
|
@@ -323,7 +324,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
323
324
|
usage_limits: _usage.UsageLimits | None = None,
|
|
324
325
|
usage: _usage.Usage | None = None,
|
|
325
326
|
infer_name: bool = True,
|
|
326
|
-
) ->
|
|
327
|
+
) -> AsyncIterator[AgentRun[AgentDepsT, Any]]:
|
|
327
328
|
"""A contextmanager which can be used to iterate over the agent graph's nodes as they are executed.
|
|
328
329
|
|
|
329
330
|
This method builds an internal agent graph (using system prompts, tools and result schemas) and then returns an
|
|
@@ -344,7 +345,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
344
345
|
|
|
345
346
|
async def main():
|
|
346
347
|
nodes = []
|
|
347
|
-
with agent.iter('What is the capital of France?') as agent_run:
|
|
348
|
+
async with agent.iter('What is the capital of France?') as agent_run:
|
|
348
349
|
async for node in agent_run:
|
|
349
350
|
nodes.append(node)
|
|
350
351
|
print(nodes)
|
|
@@ -362,7 +363,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
362
363
|
kind='request',
|
|
363
364
|
)
|
|
364
365
|
),
|
|
365
|
-
|
|
366
|
+
CallToolsNode(
|
|
366
367
|
model_response=ModelResponse(
|
|
367
368
|
parts=[TextPart(content='Paris', part_kind='text')],
|
|
368
369
|
model_name='gpt-4o',
|
|
@@ -370,7 +371,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
370
371
|
kind='response',
|
|
371
372
|
)
|
|
372
373
|
),
|
|
373
|
-
End(data=FinalResult(data='Paris', tool_name=None)),
|
|
374
|
+
End(data=FinalResult(data='Paris', tool_name=None, tool_call_id=None)),
|
|
374
375
|
]
|
|
375
376
|
'''
|
|
376
377
|
print(agent_run.result.data)
|
|
@@ -454,7 +455,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
454
455
|
system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
|
|
455
456
|
)
|
|
456
457
|
|
|
457
|
-
with graph.iter(
|
|
458
|
+
async with graph.iter(
|
|
458
459
|
start_node,
|
|
459
460
|
state=state,
|
|
460
461
|
deps=graph_deps,
|
|
@@ -633,7 +634,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
633
634
|
self._infer_name(frame.f_back)
|
|
634
635
|
|
|
635
636
|
yielded = False
|
|
636
|
-
with self.iter(
|
|
637
|
+
async with self.iter(
|
|
637
638
|
user_prompt,
|
|
638
639
|
result_type=result_type,
|
|
639
640
|
message_history=message_history,
|
|
@@ -646,10 +647,9 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
646
647
|
) as agent_run:
|
|
647
648
|
first_node = agent_run.next_node # start with the first node
|
|
648
649
|
assert isinstance(first_node, _agent_graph.UserPromptNode) # the first node should be a user prompt node
|
|
649
|
-
node
|
|
650
|
+
node = first_node
|
|
650
651
|
while True:
|
|
651
|
-
if
|
|
652
|
-
node = cast(_agent_graph.ModelRequestNode[AgentDepsT, Any], node)
|
|
652
|
+
if self.is_model_request_node(node):
|
|
653
653
|
graph_ctx = agent_run.ctx
|
|
654
654
|
async with node._stream(graph_ctx) as streamed_response: # pyright: ignore[reportPrivateUsage]
|
|
655
655
|
|
|
@@ -662,11 +662,10 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
662
662
|
new_part = maybe_part_event.part
|
|
663
663
|
if isinstance(new_part, _messages.TextPart):
|
|
664
664
|
if _agent_graph.allow_text_result(result_schema):
|
|
665
|
-
return FinalResult(s, None)
|
|
666
|
-
elif isinstance(new_part, _messages.ToolCallPart):
|
|
667
|
-
|
|
668
|
-
call,
|
|
669
|
-
return FinalResult(s, call.tool_name)
|
|
665
|
+
return FinalResult(s, None, None)
|
|
666
|
+
elif isinstance(new_part, _messages.ToolCallPart) and result_schema:
|
|
667
|
+
for call, _ in result_schema.find_tool([new_part]):
|
|
668
|
+
return FinalResult(s, call.tool_name, call.tool_call_id)
|
|
670
669
|
return None
|
|
671
670
|
|
|
672
671
|
final_result_details = await stream_to_final(streamed_response)
|
|
@@ -693,6 +692,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
693
692
|
async for _event in _agent_graph.process_function_tools(
|
|
694
693
|
tool_calls,
|
|
695
694
|
final_result_details.tool_name,
|
|
695
|
+
final_result_details.tool_call_id,
|
|
696
696
|
graph_ctx,
|
|
697
697
|
parts,
|
|
698
698
|
):
|
|
@@ -717,9 +717,9 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
717
717
|
)
|
|
718
718
|
break
|
|
719
719
|
next_node = await agent_run.next(node)
|
|
720
|
-
if not isinstance(next_node,
|
|
720
|
+
if not isinstance(next_node, _agent_graph.AgentNode):
|
|
721
721
|
raise exceptions.AgentRunError('Should have produced a StreamedRunResult before getting here')
|
|
722
|
-
node = cast(
|
|
722
|
+
node = cast(_agent_graph.AgentNode[Any, Any], next_node)
|
|
723
723
|
|
|
724
724
|
if not yielded:
|
|
725
725
|
raise exceptions.AgentRunError('Agent run finished without producing a final result')
|
|
@@ -1116,6 +1116,9 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1116
1116
|
else:
|
|
1117
1117
|
raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')
|
|
1118
1118
|
|
|
1119
|
+
if not isinstance(model_, InstrumentedModel):
|
|
1120
|
+
model_ = InstrumentedModel(model_)
|
|
1121
|
+
|
|
1119
1122
|
return model_
|
|
1120
1123
|
|
|
1121
1124
|
def _get_deps(self: Agent[T, ResultDataT], deps: T) -> T:
|
|
@@ -1173,12 +1176,52 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1173
1176
|
else:
|
|
1174
1177
|
return self._result_schema # pyright: ignore[reportReturnType]
|
|
1175
1178
|
|
|
1179
|
+
@staticmethod
|
|
1180
|
+
def is_model_request_node(
|
|
1181
|
+
node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],
|
|
1182
|
+
) -> TypeGuard[_agent_graph.ModelRequestNode[T, S]]:
|
|
1183
|
+
"""Check if the node is a `ModelRequestNode`, narrowing the type if it is.
|
|
1184
|
+
|
|
1185
|
+
This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.
|
|
1186
|
+
"""
|
|
1187
|
+
return isinstance(node, _agent_graph.ModelRequestNode)
|
|
1188
|
+
|
|
1189
|
+
@staticmethod
|
|
1190
|
+
def is_call_tools_node(
|
|
1191
|
+
node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],
|
|
1192
|
+
) -> TypeGuard[_agent_graph.CallToolsNode[T, S]]:
|
|
1193
|
+
"""Check if the node is a `CallToolsNode`, narrowing the type if it is.
|
|
1194
|
+
|
|
1195
|
+
This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.
|
|
1196
|
+
"""
|
|
1197
|
+
return isinstance(node, _agent_graph.CallToolsNode)
|
|
1198
|
+
|
|
1199
|
+
@staticmethod
|
|
1200
|
+
def is_user_prompt_node(
|
|
1201
|
+
node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],
|
|
1202
|
+
) -> TypeGuard[_agent_graph.UserPromptNode[T, S]]:
|
|
1203
|
+
"""Check if the node is a `UserPromptNode`, narrowing the type if it is.
|
|
1204
|
+
|
|
1205
|
+
This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.
|
|
1206
|
+
"""
|
|
1207
|
+
return isinstance(node, _agent_graph.UserPromptNode)
|
|
1208
|
+
|
|
1209
|
+
@staticmethod
|
|
1210
|
+
def is_end_node(
|
|
1211
|
+
node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],
|
|
1212
|
+
) -> TypeGuard[End[result.FinalResult[S]]]:
|
|
1213
|
+
"""Check if the node is a `End`, narrowing the type if it is.
|
|
1214
|
+
|
|
1215
|
+
This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.
|
|
1216
|
+
"""
|
|
1217
|
+
return isinstance(node, End)
|
|
1218
|
+
|
|
1176
1219
|
|
|
1177
1220
|
@dataclasses.dataclass(repr=False)
|
|
1178
1221
|
class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
1179
1222
|
"""A stateful, async-iterable run of an [`Agent`][pydantic_ai.agent.Agent].
|
|
1180
1223
|
|
|
1181
|
-
You generally obtain an `AgentRun` instance by calling `with my_agent.iter(...) as agent_run:`.
|
|
1224
|
+
You generally obtain an `AgentRun` instance by calling `async with my_agent.iter(...) as agent_run:`.
|
|
1182
1225
|
|
|
1183
1226
|
Once you have an instance, you can use it to iterate through the run's nodes as they execute. When an
|
|
1184
1227
|
[`End`][pydantic_graph.nodes.End] is reached, the run finishes and [`result`][pydantic_ai.agent.AgentRun.result]
|
|
@@ -1193,7 +1236,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1193
1236
|
async def main():
|
|
1194
1237
|
nodes = []
|
|
1195
1238
|
# Iterate through the run, recording each node along the way:
|
|
1196
|
-
with agent.iter('What is the capital of France?') as agent_run:
|
|
1239
|
+
async with agent.iter('What is the capital of France?') as agent_run:
|
|
1197
1240
|
async for node in agent_run:
|
|
1198
1241
|
nodes.append(node)
|
|
1199
1242
|
print(nodes)
|
|
@@ -1211,7 +1254,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1211
1254
|
kind='request',
|
|
1212
1255
|
)
|
|
1213
1256
|
),
|
|
1214
|
-
|
|
1257
|
+
CallToolsNode(
|
|
1215
1258
|
model_response=ModelResponse(
|
|
1216
1259
|
parts=[TextPart(content='Paris', part_kind='text')],
|
|
1217
1260
|
model_name='gpt-4o',
|
|
@@ -1219,7 +1262,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1219
1262
|
kind='response',
|
|
1220
1263
|
)
|
|
1221
1264
|
),
|
|
1222
|
-
End(data=FinalResult(data='Paris', tool_name=None)),
|
|
1265
|
+
End(data=FinalResult(data='Paris', tool_name=None, tool_call_id=None)),
|
|
1223
1266
|
]
|
|
1224
1267
|
'''
|
|
1225
1268
|
print(agent_run.result.data)
|
|
@@ -1244,15 +1287,17 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1244
1287
|
@property
|
|
1245
1288
|
def next_node(
|
|
1246
1289
|
self,
|
|
1247
|
-
) ->
|
|
1248
|
-
BaseNode[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[ResultDataT]]
|
|
1249
|
-
| End[FinalResult[ResultDataT]]
|
|
1250
|
-
):
|
|
1290
|
+
) -> _agent_graph.AgentNode[AgentDepsT, ResultDataT] | End[FinalResult[ResultDataT]]:
|
|
1251
1291
|
"""The next node that will be run in the agent graph.
|
|
1252
1292
|
|
|
1253
1293
|
This is the next node that will be used during async iteration, or if a node is not passed to `self.next(...)`.
|
|
1254
1294
|
"""
|
|
1255
|
-
|
|
1295
|
+
next_node = self._graph_run.next_node
|
|
1296
|
+
if isinstance(next_node, End):
|
|
1297
|
+
return next_node
|
|
1298
|
+
if _agent_graph.is_agent_node(next_node):
|
|
1299
|
+
return next_node
|
|
1300
|
+
raise exceptions.AgentRunError(f'Unexpected node type: {type(next_node)}') # pragma: no cover
|
|
1256
1301
|
|
|
1257
1302
|
@property
|
|
1258
1303
|
def result(self) -> AgentRunResult[ResultDataT] | None:
|
|
@@ -1273,45 +1318,24 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1273
1318
|
|
|
1274
1319
|
def __aiter__(
|
|
1275
1320
|
self,
|
|
1276
|
-
) -> AsyncIterator[
|
|
1277
|
-
BaseNode[
|
|
1278
|
-
_agent_graph.GraphAgentState,
|
|
1279
|
-
_agent_graph.GraphAgentDeps[AgentDepsT, Any],
|
|
1280
|
-
FinalResult[ResultDataT],
|
|
1281
|
-
]
|
|
1282
|
-
| End[FinalResult[ResultDataT]]
|
|
1283
|
-
]:
|
|
1321
|
+
) -> AsyncIterator[_agent_graph.AgentNode[AgentDepsT, ResultDataT] | End[FinalResult[ResultDataT]]]:
|
|
1284
1322
|
"""Provide async-iteration over the nodes in the agent run."""
|
|
1285
1323
|
return self
|
|
1286
1324
|
|
|
1287
1325
|
async def __anext__(
|
|
1288
1326
|
self,
|
|
1289
|
-
) ->
|
|
1290
|
-
BaseNode[
|
|
1291
|
-
_agent_graph.GraphAgentState,
|
|
1292
|
-
_agent_graph.GraphAgentDeps[AgentDepsT, Any],
|
|
1293
|
-
FinalResult[ResultDataT],
|
|
1294
|
-
]
|
|
1295
|
-
| End[FinalResult[ResultDataT]]
|
|
1296
|
-
):
|
|
1327
|
+
) -> _agent_graph.AgentNode[AgentDepsT, ResultDataT] | End[FinalResult[ResultDataT]]:
|
|
1297
1328
|
"""Advance to the next node automatically based on the last returned node."""
|
|
1298
|
-
|
|
1329
|
+
next_node = await self._graph_run.__anext__()
|
|
1330
|
+
if _agent_graph.is_agent_node(next_node):
|
|
1331
|
+
return next_node
|
|
1332
|
+
assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}'
|
|
1333
|
+
return next_node
|
|
1299
1334
|
|
|
1300
1335
|
async def next(
|
|
1301
1336
|
self,
|
|
1302
|
-
node:
|
|
1303
|
-
|
|
1304
|
-
_agent_graph.GraphAgentDeps[AgentDepsT, Any],
|
|
1305
|
-
FinalResult[ResultDataT],
|
|
1306
|
-
],
|
|
1307
|
-
) -> (
|
|
1308
|
-
BaseNode[
|
|
1309
|
-
_agent_graph.GraphAgentState,
|
|
1310
|
-
_agent_graph.GraphAgentDeps[AgentDepsT, Any],
|
|
1311
|
-
FinalResult[ResultDataT],
|
|
1312
|
-
]
|
|
1313
|
-
| End[FinalResult[ResultDataT]]
|
|
1314
|
-
):
|
|
1337
|
+
node: _agent_graph.AgentNode[AgentDepsT, ResultDataT],
|
|
1338
|
+
) -> _agent_graph.AgentNode[AgentDepsT, ResultDataT] | End[FinalResult[ResultDataT]]:
|
|
1315
1339
|
"""Manually drive the agent run by passing in the node you want to run next.
|
|
1316
1340
|
|
|
1317
1341
|
This lets you inspect or mutate the node before continuing execution, or skip certain nodes
|
|
@@ -1326,7 +1350,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1326
1350
|
agent = Agent('openai:gpt-4o')
|
|
1327
1351
|
|
|
1328
1352
|
async def main():
|
|
1329
|
-
with agent.iter('What is the capital of France?') as agent_run:
|
|
1353
|
+
async with agent.iter('What is the capital of France?') as agent_run:
|
|
1330
1354
|
next_node = agent_run.next_node # start with the first node
|
|
1331
1355
|
nodes = [next_node]
|
|
1332
1356
|
while not isinstance(next_node, End):
|
|
@@ -1354,7 +1378,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1354
1378
|
kind='request',
|
|
1355
1379
|
)
|
|
1356
1380
|
),
|
|
1357
|
-
|
|
1381
|
+
CallToolsNode(
|
|
1358
1382
|
model_response=ModelResponse(
|
|
1359
1383
|
parts=[TextPart(content='Paris', part_kind='text')],
|
|
1360
1384
|
model_name='gpt-4o',
|
|
@@ -1362,7 +1386,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1362
1386
|
kind='response',
|
|
1363
1387
|
)
|
|
1364
1388
|
),
|
|
1365
|
-
End(data=FinalResult(data='Paris', tool_name=None)),
|
|
1389
|
+
End(data=FinalResult(data='Paris', tool_name=None, tool_call_id=None)),
|
|
1366
1390
|
]
|
|
1367
1391
|
'''
|
|
1368
1392
|
print('Final result:', agent_run.result.data)
|
|
@@ -1378,7 +1402,11 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1378
1402
|
"""
|
|
1379
1403
|
# Note: It might be nice to expose a synchronous interface for iteration, but we shouldn't do it
|
|
1380
1404
|
# on this class, or else IDEs won't warn you if you accidentally use `for` instead of `async for` to iterate.
|
|
1381
|
-
|
|
1405
|
+
next_node = await self._graph_run.next(node)
|
|
1406
|
+
if _agent_graph.is_agent_node(next_node):
|
|
1407
|
+
return next_node
|
|
1408
|
+
assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}'
|
|
1409
|
+
return next_node
|
|
1382
1410
|
|
|
1383
1411
|
def usage(self) -> _usage.Usage:
|
|
1384
1412
|
"""Get usage statistics for the run so far, including token usage, model requests, and so on."""
|
pydantic_ai/messages.py
CHANGED
|
@@ -8,6 +8,7 @@ from typing import Annotated, Any, Literal, Union, cast, overload
|
|
|
8
8
|
|
|
9
9
|
import pydantic
|
|
10
10
|
import pydantic_core
|
|
11
|
+
from opentelemetry._events import Event
|
|
11
12
|
from typing_extensions import TypeAlias
|
|
12
13
|
|
|
13
14
|
from ._utils import now_utc as _now_utc
|
|
@@ -33,6 +34,9 @@ class SystemPromptPart:
|
|
|
33
34
|
part_kind: Literal['system-prompt'] = 'system-prompt'
|
|
34
35
|
"""Part type identifier, this is available on all parts as a discriminator."""
|
|
35
36
|
|
|
37
|
+
def otel_event(self) -> Event:
|
|
38
|
+
return Event('gen_ai.system.message', body={'content': self.content, 'role': 'system'})
|
|
39
|
+
|
|
36
40
|
|
|
37
41
|
@dataclass
|
|
38
42
|
class AudioUrl:
|
|
@@ -138,6 +142,14 @@ class UserPromptPart:
|
|
|
138
142
|
part_kind: Literal['user-prompt'] = 'user-prompt'
|
|
139
143
|
"""Part type identifier, this is available on all parts as a discriminator."""
|
|
140
144
|
|
|
145
|
+
def otel_event(self) -> Event:
|
|
146
|
+
if isinstance(self.content, str):
|
|
147
|
+
content = self.content
|
|
148
|
+
else:
|
|
149
|
+
# TODO figure out what to record for images and audio
|
|
150
|
+
content = [part if isinstance(part, str) else {'kind': part.kind} for part in self.content]
|
|
151
|
+
return Event('gen_ai.user.message', body={'content': content, 'role': 'user'})
|
|
152
|
+
|
|
141
153
|
|
|
142
154
|
tool_return_ta: pydantic.TypeAdapter[Any] = pydantic.TypeAdapter(Any, config=pydantic.ConfigDict(defer_build=True))
|
|
143
155
|
|
|
@@ -176,6 +188,9 @@ class ToolReturnPart:
|
|
|
176
188
|
else:
|
|
177
189
|
return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}
|
|
178
190
|
|
|
191
|
+
def otel_event(self) -> Event:
|
|
192
|
+
return Event('gen_ai.tool.message', body={'content': self.content, 'role': 'tool', 'id': self.tool_call_id})
|
|
193
|
+
|
|
179
194
|
|
|
180
195
|
error_details_ta = pydantic.TypeAdapter(list[pydantic_core.ErrorDetails], config=pydantic.ConfigDict(defer_build=True))
|
|
181
196
|
|
|
@@ -224,6 +239,14 @@ class RetryPromptPart:
|
|
|
224
239
|
description = f'{len(self.content)} validation errors: {json_errors.decode()}'
|
|
225
240
|
return f'{description}\n\nFix the errors and try again.'
|
|
226
241
|
|
|
242
|
+
def otel_event(self) -> Event:
|
|
243
|
+
if self.tool_name is None:
|
|
244
|
+
return Event('gen_ai.user.message', body={'content': self.model_response(), 'role': 'user'})
|
|
245
|
+
else:
|
|
246
|
+
return Event(
|
|
247
|
+
'gen_ai.tool.message', body={'content': self.model_response(), 'role': 'tool', 'id': self.tool_call_id}
|
|
248
|
+
)
|
|
249
|
+
|
|
227
250
|
|
|
228
251
|
ModelRequestPart = Annotated[
|
|
229
252
|
Union[SystemPromptPart, UserPromptPart, ToolReturnPart, RetryPromptPart], pydantic.Discriminator('part_kind')
|
|
@@ -329,6 +352,36 @@ class ModelResponse:
|
|
|
329
352
|
kind: Literal['response'] = 'response'
|
|
330
353
|
"""Message type identifier, this is available on all parts as a discriminator."""
|
|
331
354
|
|
|
355
|
+
def otel_events(self) -> list[Event]:
|
|
356
|
+
"""Return OpenTelemetry events for the response."""
|
|
357
|
+
result: list[Event] = []
|
|
358
|
+
|
|
359
|
+
def new_event_body():
|
|
360
|
+
new_body: dict[str, Any] = {'role': 'assistant'}
|
|
361
|
+
ev = Event('gen_ai.assistant.message', body=new_body)
|
|
362
|
+
result.append(ev)
|
|
363
|
+
return new_body
|
|
364
|
+
|
|
365
|
+
body = new_event_body()
|
|
366
|
+
for part in self.parts:
|
|
367
|
+
if isinstance(part, ToolCallPart):
|
|
368
|
+
body.setdefault('tool_calls', []).append(
|
|
369
|
+
{
|
|
370
|
+
'id': part.tool_call_id,
|
|
371
|
+
'type': 'function', # TODO https://github.com/pydantic/pydantic-ai/issues/888
|
|
372
|
+
'function': {
|
|
373
|
+
'name': part.tool_name,
|
|
374
|
+
'arguments': part.args,
|
|
375
|
+
},
|
|
376
|
+
}
|
|
377
|
+
)
|
|
378
|
+
elif isinstance(part, TextPart):
|
|
379
|
+
if body.get('content'):
|
|
380
|
+
body = new_event_body()
|
|
381
|
+
body['content'] = part.content
|
|
382
|
+
|
|
383
|
+
return result
|
|
384
|
+
|
|
332
385
|
|
|
333
386
|
ModelMessage = Annotated[Union[ModelRequest, ModelResponse], pydantic.Discriminator('kind')]
|
|
334
387
|
"""Any message sent to or returned by a model."""
|
|
@@ -533,9 +586,26 @@ class PartDeltaEvent:
|
|
|
533
586
|
"""Event type identifier, used as a discriminator."""
|
|
534
587
|
|
|
535
588
|
|
|
589
|
+
@dataclass
|
|
590
|
+
class FinalResultEvent:
|
|
591
|
+
"""An event indicating the response to the current model request matches the result schema."""
|
|
592
|
+
|
|
593
|
+
tool_name: str | None
|
|
594
|
+
"""The name of the result tool that was called. `None` if the result is from text content and not from a tool."""
|
|
595
|
+
tool_call_id: str | None
|
|
596
|
+
"""The tool call ID, if any, that this result is associated with."""
|
|
597
|
+
event_kind: Literal['final_result'] = 'final_result'
|
|
598
|
+
"""Event type identifier, used as a discriminator."""
|
|
599
|
+
|
|
600
|
+
|
|
536
601
|
ModelResponseStreamEvent = Annotated[Union[PartStartEvent, PartDeltaEvent], pydantic.Discriminator('event_kind')]
|
|
537
602
|
"""An event in the model response stream, either starting a new part or applying a delta to an existing one."""
|
|
538
603
|
|
|
604
|
+
AgentStreamEvent = Annotated[
|
|
605
|
+
Union[PartStartEvent, PartDeltaEvent, FinalResultEvent], pydantic.Discriminator('event_kind')
|
|
606
|
+
]
|
|
607
|
+
"""An event in the agent stream."""
|
|
608
|
+
|
|
539
609
|
|
|
540
610
|
@dataclass
|
|
541
611
|
class FunctionToolCallEvent:
|
|
@@ -558,7 +628,7 @@ class FunctionToolResultEvent:
|
|
|
558
628
|
|
|
559
629
|
result: ToolReturnPart | RetryPromptPart
|
|
560
630
|
"""The result of the call to the function tool."""
|
|
561
|
-
|
|
631
|
+
tool_call_id: str
|
|
562
632
|
"""An ID used to match the result to its original call."""
|
|
563
633
|
event_kind: Literal['function_tool_result'] = 'function_tool_result'
|
|
564
634
|
"""Event type identifier, used as a discriminator."""
|
pydantic_ai/models/__init__.py
CHANGED
|
@@ -84,6 +84,8 @@ KnownModelName = Literal[
|
|
|
84
84
|
'gpt-4-turbo-2024-04-09',
|
|
85
85
|
'gpt-4-turbo-preview',
|
|
86
86
|
'gpt-4-vision-preview',
|
|
87
|
+
'gpt-4.5-preview',
|
|
88
|
+
'gpt-4.5-preview-2025-02-27',
|
|
87
89
|
'gpt-4o',
|
|
88
90
|
'gpt-4o-2024-05-13',
|
|
89
91
|
'gpt-4o-2024-08-06',
|
|
@@ -138,6 +140,8 @@ KnownModelName = Literal[
|
|
|
138
140
|
'openai:gpt-4-turbo-2024-04-09',
|
|
139
141
|
'openai:gpt-4-turbo-preview',
|
|
140
142
|
'openai:gpt-4-vision-preview',
|
|
143
|
+
'openai:gpt-4.5-preview',
|
|
144
|
+
'openai:gpt-4.5-preview-2025-02-27',
|
|
141
145
|
'openai:gpt-4o',
|
|
142
146
|
'openai:gpt-4o-2024-05-13',
|
|
143
147
|
'openai:gpt-4o-2024-08-06',
|
pydantic_ai/models/function.py
CHANGED
|
@@ -177,6 +177,8 @@ class DeltaToolCall:
|
|
|
177
177
|
"""Incremental change to the name of the tool."""
|
|
178
178
|
json_args: str | None = None
|
|
179
179
|
"""Incremental change to the arguments as JSON"""
|
|
180
|
+
tool_call_id: str | None = None
|
|
181
|
+
"""Incremental change to the tool call ID."""
|
|
180
182
|
|
|
181
183
|
|
|
182
184
|
DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
|
|
@@ -224,7 +226,7 @@ class FunctionStreamedResponse(StreamedResponse):
|
|
|
224
226
|
vendor_part_id=dtc_index,
|
|
225
227
|
tool_name=delta_tool_call.name,
|
|
226
228
|
args=delta_tool_call.json_args,
|
|
227
|
-
tool_call_id=
|
|
229
|
+
tool_call_id=delta_tool_call.tool_call_id,
|
|
228
230
|
)
|
|
229
231
|
if maybe_event is not None:
|
|
230
232
|
yield maybe_event
|
|
@@ -280,7 +282,16 @@ def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
|
|
|
280
282
|
return 0
|
|
281
283
|
if isinstance(content, str):
|
|
282
284
|
return len(re.split(r'[\s",.:]+', content.strip()))
|
|
283
|
-
# TODO(Marcelo): We need to study how we can estimate the tokens for these types of content.
|
|
284
285
|
else: # pragma: no cover
|
|
285
|
-
|
|
286
|
-
|
|
286
|
+
tokens = 0
|
|
287
|
+
for part in content:
|
|
288
|
+
if isinstance(part, str):
|
|
289
|
+
tokens += len(re.split(r'[\s",.:]+', part.strip()))
|
|
290
|
+
# TODO(Marcelo): We need to study how we can estimate the tokens for these types of content.
|
|
291
|
+
if isinstance(part, (AudioUrl, ImageUrl)):
|
|
292
|
+
tokens += 0
|
|
293
|
+
elif isinstance(part, BinaryContent):
|
|
294
|
+
tokens += len(part.data)
|
|
295
|
+
else:
|
|
296
|
+
tokens += 0
|
|
297
|
+
return tokens
|