pydantic-ai-slim 0.0.28.tar.gz → 0.0.30.tar.gz
This diff shows the changes between publicly released versions of the package as they appear in their public registries; it is provided for informational purposes only.
Potentially problematic release: this version of pydantic-ai-slim might be problematic.
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/PKG-INFO +3 -3
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/_agent_graph.py +48 -7
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/agent.py +70 -46
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/common_tools/duckduckgo.py +14 -4
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/messages.py +16 -1
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/models/__init__.py +4 -0
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/models/function.py +15 -4
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/models/instrumented.py +38 -19
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/result.py +121 -1
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/usage.py +10 -0
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pyproject.toml +3 -3
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/.gitignore +0 -0
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/README.md +0 -0
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/__init__.py +0 -0
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/_griffe.py +0 -0
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/_parts_manager.py +0 -0
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/_pydantic.py +0 -0
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/_result.py +0 -0
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/_system_prompt.py +0 -0
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/_utils.py +0 -0
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/common_tools/__init__.py +0 -0
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/common_tools/tavily.py +0 -0
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/exceptions.py +0 -0
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/format_as_xml.py +0 -0
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/models/anthropic.py +0 -0
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/models/cohere.py +0 -0
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/models/fallback.py +0 -0
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/models/gemini.py +0 -0
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/models/groq.py +0 -0
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/models/mistral.py +0 -0
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/models/openai.py +0 -0
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/models/test.py +0 -0
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/models/vertexai.py +0 -0
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/models/wrapper.py +0 -0
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/py.typed +0 -0
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/settings.py +0 -0
- {pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/tools.py +0 -0
{pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.0.28
+Version: 0.0.30
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT
@@ -29,7 +29,7 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: logfire-api>=1.2.0
-Requires-Dist: pydantic-graph==0.0.28
+Requires-Dist: pydantic-graph==0.0.30
 Requires-Dist: pydantic>=2.10
 Provides-Extra: anthropic
 Requires-Dist: anthropic>=0.40.0; extra == 'anthropic'
@@ -44,7 +44,7 @@ Requires-Dist: logfire>=2.3; extra == 'logfire'
 Provides-Extra: mistral
 Requires-Dist: mistralai>=1.2.5; extra == 'mistral'
 Provides-Extra: openai
-Requires-Dist: openai>=1.…
+Requires-Dist: openai>=1.65.1; extra == 'openai'
 Provides-Extra: tavily
 Requires-Dist: tavily-python>=0.5.0; extra == 'tavily'
 Provides-Extra: vertexai
{pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/_agent_graph.py
@@ -2,7 +2,6 @@ from __future__ import annotations as _annotations
 
 import asyncio
 import dataclasses
-from abc import ABC
 from collections.abc import AsyncIterator, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
@@ -10,7 +9,7 @@ from dataclasses import field
 from typing import Any, Generic, Literal, Union, cast
 
 import logfire_api
-from typing_extensions import TypeVar, assert_never
+from typing_extensions import TypeGuard, TypeVar, assert_never
 
 from pydantic_graph import BaseNode, Graph, GraphRunContext
 from pydantic_graph.nodes import End, NodeRunEndT
@@ -55,6 +54,7 @@ else:
     logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
 
 T = TypeVar('T')
+S = TypeVar('S')
 NoneType = type(None)
 EndStrategy = Literal['early', 'exhaustive']
 """The strategy for handling multiple tool calls when a final result is found.
@@ -107,8 +107,31 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
     run_span: logfire_api.LogfireSpan
 
 
+class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
+    """The base class for all agent nodes.
+
+    Using subclass of `BaseNode` for all nodes reduces the amount of boilerplate of generics everywhere
+    """
+
+
+def is_agent_node(
+    node: BaseNode[GraphAgentState, GraphAgentDeps[T, Any], result.FinalResult[S]] | End[result.FinalResult[S]],
+) -> TypeGuard[AgentNode[T, S]]:
+    """Check if the provided node is an instance of `AgentNode`.
+
+    Usage:
+
+        if is_agent_node(node):
+            # `node` is an AgentNode
+            ...
+
+    This method preserves the generic parameters on the narrowed type, unlike `isinstance(node, AgentNode)`.
+    """
+    return isinstance(node, AgentNode)
+
+
 @dataclasses.dataclass
-class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]], ABC):
+class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
     user_prompt: str | Sequence[_messages.UserContent]
 
     system_prompts: tuple[str, ...]
@@ -215,7 +238,7 @@ async def _prepare_request_parameters(
 
 
 @dataclasses.dataclass
-class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
+class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
     """Make a request to the model using the last message in state.message_history."""
 
     request: _messages.ModelRequest
@@ -236,12 +259,30 @@ class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
 
         return await self._make_request(ctx)
 
+    @asynccontextmanager
+    async def stream(
+        self,
+        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
+    ) -> AsyncIterator[result.AgentStream[DepsT, T]]:
+        async with self._stream(ctx) as streamed_response:
+            agent_stream = result.AgentStream[DepsT, T](
+                streamed_response,
+                ctx.deps.result_schema,
+                ctx.deps.result_validators,
+                build_run_context(ctx),
+                ctx.deps.usage_limits,
+            )
+            yield agent_stream
+            # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
+            # otherwise usage won't be properly counted:
+            async for _ in agent_stream:
+                pass
+
     @asynccontextmanager
     async def _stream(
         self,
         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
     ) -> AsyncIterator[models.StreamedResponse]:
-        # TODO: Consider changing this to return something more similar to a `StreamedRunResult`, then make it public
         assert not self._did_stream, 'stream() should only be called once per node'
 
         model_settings, model_request_parameters = await self._prepare_request(ctx)
@@ -319,7 +360,7 @@ class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
 
 
 @dataclasses.dataclass
-class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
+class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]):
    """Process a model response, and decide whether to end the run or make a new request."""

    model_response: _messages.ModelResponse
@@ -575,7 +616,7 @@ async def process_function_tools(
         for task in done:
             index = tasks.index(task)
             result = task.result()
-            yield _messages.FunctionToolResultEvent(result, …
+            yield _messages.FunctionToolResultEvent(result, tool_call_id=call_index_to_event_id[index])
             if isinstance(result, (_messages.ToolReturnPart, _messages.RetryPromptPart)):
                 results_by_index[index] = result
             else:
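
The new `AgentNode` base class and `is_agent_node` type guard exist so callers can narrow node types without losing generic parameters. A minimal, self-contained Python sketch (not pydantic-ai code) of the general technique:

# Why a TypeGuard function beats a bare `isinstance` check: it lets the
# type checker keep the generic parameters on the narrowed type.
from __future__ import annotations

from typing import Generic, TypeVar

from typing_extensions import TypeGuard

T = TypeVar('T')


class Node(Generic[T]):
    def value(self) -> T: ...


class SpecialNode(Node[T]):
    def special(self) -> T: ...


def is_special_node(node: Node[T]) -> TypeGuard[SpecialNode[T]]:
    # `isinstance(node, SpecialNode)` alone narrows to `SpecialNode[Any]`;
    # the TypeGuard return type re-attaches `T` for the caller.
    return isinstance(node, SpecialNode)


def demo(node: Node[int]) -> None:
    if is_special_node(node):
        # type checkers now see `node: SpecialNode[int]`
        node.special()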
{pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/agent.py
@@ -1,6 +1,5 @@
 from __future__ import annotations as _annotations
 
-import asyncio
 import dataclasses
 import inspect
 from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
@@ -10,9 +9,10 @@ from types import FrameType
 from typing import Any, Callable, Generic, cast, final, overload
 
 import logfire_api
-from typing_extensions import TypeVar, deprecated
+from typing_extensions import TypeGuard, TypeVar, deprecated
 
-from pydantic_graph import BaseNode, End, Graph, GraphRun, GraphRunContext
+from pydantic_graph import End, Graph, GraphRun, GraphRunContext
+from pydantic_graph._utils import get_event_loop
 
 from . import (
     _agent_graph,
@@ -46,7 +46,6 @@ HandleResponseNode = _agent_graph.HandleResponseNode
 ModelRequestNode = _agent_graph.ModelRequestNode
 UserPromptNode = _agent_graph.UserPromptNode
 
-
 __all__ = (
     'Agent',
     'AgentRun',
@@ -71,6 +70,7 @@ else:
     logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
 
 T = TypeVar('T')
+S = TypeVar('S')
 NoneType = type(None)
 RunResultDataT = TypeVar('RunResultDataT')
 """Type variable for the result data of a run where `result_type` was customized on the run call."""
@@ -538,7 +538,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         """
         if infer_name and self.name is None:
             self._infer_name(inspect.currentframe())
-        return asyncio.get_event_loop().run_until_complete(
+        return get_event_loop().run_until_complete(
             self.run(
                 user_prompt,
                 result_type=result_type,
@@ -646,10 +646,9 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         ) as agent_run:
             first_node = agent_run.next_node  # start with the first node
             assert isinstance(first_node, _agent_graph.UserPromptNode)  # the first node should be a user prompt node
-            node…
+            node = first_node
             while True:
-                if …
-                    node = cast(_agent_graph.ModelRequestNode[AgentDepsT, Any], node)
+                if self.is_model_request_node(node):
                     graph_ctx = agent_run.ctx
                     async with node._stream(graph_ctx) as streamed_response:  # pyright: ignore[reportPrivateUsage]
 
@@ -717,9 +716,9 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
                     )
                     break
                 next_node = await agent_run.next(node)
-                if not isinstance(next_node, …
+                if not isinstance(next_node, _agent_graph.AgentNode):
                     raise exceptions.AgentRunError('Should have produced a StreamedRunResult before getting here')
-                node = cast(…
+                node = cast(_agent_graph.AgentNode[Any, Any], next_node)
 
             if not yielded:
                 raise exceptions.AgentRunError('Agent run finished without producing a final result')
@@ -1173,6 +1172,46 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         else:
             return self._result_schema  # pyright: ignore[reportReturnType]
 
+    @staticmethod
+    def is_model_request_node(
+        node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],
+    ) -> TypeGuard[_agent_graph.ModelRequestNode[T, S]]:
+        """Check if the node is a `ModelRequestNode`, narrowing the type if it is.
+
+        This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.
+        """
+        return isinstance(node, _agent_graph.ModelRequestNode)
+
+    @staticmethod
+    def is_handle_response_node(
+        node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],
+    ) -> TypeGuard[_agent_graph.HandleResponseNode[T, S]]:
+        """Check if the node is a `HandleResponseNode`, narrowing the type if it is.
+
+        This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.
+        """
+        return isinstance(node, _agent_graph.HandleResponseNode)
+
+    @staticmethod
+    def is_user_prompt_node(
+        node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],
+    ) -> TypeGuard[_agent_graph.UserPromptNode[T, S]]:
+        """Check if the node is a `UserPromptNode`, narrowing the type if it is.
+
+        This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.
+        """
+        return isinstance(node, _agent_graph.UserPromptNode)
+
+    @staticmethod
+    def is_end_node(
+        node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],
+    ) -> TypeGuard[End[result.FinalResult[S]]]:
+        """Check if the node is a `End`, narrowing the type if it is.
+
+        This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.
+        """
+        return isinstance(node, End)
+
 
 @dataclasses.dataclass(repr=False)
 class AgentRun(Generic[AgentDepsT, ResultDataT]):
@@ -1244,15 +1283,17 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
     @property
     def next_node(
         self,
-    ) -> (
-        BaseNode[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[ResultDataT]]
-        | End[FinalResult[ResultDataT]]
-    ):
+    ) -> _agent_graph.AgentNode[AgentDepsT, ResultDataT] | End[FinalResult[ResultDataT]]:
         """The next node that will be run in the agent graph.
 
         This is the next node that will be used during async iteration, or if a node is not passed to `self.next(...)`.
         """
-        return self._graph_run.next_node
+        next_node = self._graph_run.next_node
+        if isinstance(next_node, End):
+            return next_node
+        if _agent_graph.is_agent_node(next_node):
+            return next_node
+        raise exceptions.AgentRunError(f'Unexpected node type: {type(next_node)}')  # pragma: no cover
 
     @property
     def result(self) -> AgentRunResult[ResultDataT] | None:
@@ -1273,45 +1314,24 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
 
     def __aiter__(
         self,
-    ) -> AsyncIterator[
-        BaseNode[
-            _agent_graph.GraphAgentState,
-            _agent_graph.GraphAgentDeps[AgentDepsT, Any],
-            FinalResult[ResultDataT],
-        ]
-        | End[FinalResult[ResultDataT]]
-    ]:
+    ) -> AsyncIterator[_agent_graph.AgentNode[AgentDepsT, ResultDataT] | End[FinalResult[ResultDataT]]]:
         """Provide async-iteration over the nodes in the agent run."""
         return self
 
     async def __anext__(
         self,
-    ) -> (
-        BaseNode[
-            _agent_graph.GraphAgentState,
-            _agent_graph.GraphAgentDeps[AgentDepsT, Any],
-            FinalResult[ResultDataT],
-        ]
-        | End[FinalResult[ResultDataT]]
-    ):
+    ) -> _agent_graph.AgentNode[AgentDepsT, ResultDataT] | End[FinalResult[ResultDataT]]:
         """Advance to the next node automatically based on the last returned node."""
-        return await self._graph_run.__anext__()
+        next_node = await self._graph_run.__anext__()
+        if _agent_graph.is_agent_node(next_node):
+            return next_node
+        assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}'
+        return next_node
 
     async def next(
         self,
-        node: BaseNode[
-            _agent_graph.GraphAgentState,
-            _agent_graph.GraphAgentDeps[AgentDepsT, Any],
-            FinalResult[ResultDataT],
-        ],
-    ) -> (
-        BaseNode[
-            _agent_graph.GraphAgentState,
-            _agent_graph.GraphAgentDeps[AgentDepsT, Any],
-            FinalResult[ResultDataT],
-        ]
-        | End[FinalResult[ResultDataT]]
-    ):
+        node: _agent_graph.AgentNode[AgentDepsT, ResultDataT],
+    ) -> _agent_graph.AgentNode[AgentDepsT, ResultDataT] | End[FinalResult[ResultDataT]]:
         """Manually drive the agent run by passing in the node you want to run next.
 
         This lets you inspect or mutate the node before continuing execution, or skip certain nodes
@@ -1378,7 +1398,11 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
         """
         # Note: It might be nice to expose a synchronous interface for iteration, but we shouldn't do it
         # on this class, or else IDEs won't warn you if you accidentally use `for` instead of `async for` to iterate.
-        return await self._graph_run.next(node)
+        next_node = await self._graph_run.next(node)
+        if _agent_graph.is_agent_node(next_node):
+            return next_node
+        assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}'
+        return next_node
 
     def usage(self) -> _usage.Usage:
         """Get usage statistics for the run so far, including token usage, model requests, and so on."""
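
Together with `AgentNode`, the new `is_model_request_node` / `is_handle_response_node` / `is_user_prompt_node` / `is_end_node` static methods give a typed way to branch on nodes while manually iterating a run. A sketch of the intended usage, assuming the `Agent.iter` API shown above (model name and prompt are illustrative):

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')


async def main() -> None:
    async with agent.iter('What is the capital of France?') as agent_run:
        async for node in agent_run:
            if Agent.is_user_prompt_node(node):
                print('prompt submitted')
            elif Agent.is_model_request_node(node):
                print('model request about to be made')
            elif Agent.is_end_node(node):
                print('final result:', node.data.data)  # End wraps a FinalResult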
{pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/common_tools/duckduckgo.py
@@ -1,3 +1,4 @@
+import functools
 from dataclasses import dataclass
 from typing import TypedDict
 
@@ -39,6 +40,9 @@ class DuckDuckGoSearchTool:
     client: DDGS
     """The DuckDuckGo search client."""
 
+    max_results: int | None = None
+    """The maximum number of results. If None, returns results only from the first response."""
+
     async def __call__(self, query: str) -> list[DuckDuckGoResult]:
         """Searches DuckDuckGo for the given query and returns the results.
 
@@ -48,16 +52,22 @@ class DuckDuckGoSearchTool:
         Returns:
             The search results.
         """
-        …
+        search = functools.partial(self.client.text, max_results=self.max_results)
+        results = await anyio.to_thread.run_sync(search, query)
         if len(results) == 0:
             raise RuntimeError('No search results found.')
         return duckduckgo_ta.validate_python(results)
 
 
-def duckduckgo_search_tool(duckduckgo_client: DDGS | None = None):
-    """Creates a DuckDuckGo search tool.…
+def duckduckgo_search_tool(duckduckgo_client: DDGS | None = None, max_results: int | None = None):
+    """Creates a DuckDuckGo search tool.
+
+    Args:
+        duckduckgo_client: The DuckDuckGo search client.
+        max_results: The maximum number of results. If None, returns results only from the first response.
+    """
     return Tool(
-        DuckDuckGoSearchTool(client=duckduckgo_client or DDGS()).__call__,
+        DuckDuckGoSearchTool(client=duckduckgo_client or DDGS(), max_results=max_results).__call__,
         name='duckduckgo_search',
        description='Searches DuckDuckGo for the given query and returns the results.',
     )
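
The new `max_results` parameter caps how many search results the tool returns; `None` keeps the old behaviour of returning only the first response. A sketch of wiring the tool into an agent (the model name is illustrative; requires the `duckduckgo-search` package):

from pydantic_ai import Agent
from pydantic_ai.common_tools.duckduckgo import duckduckgo_search_tool

agent = Agent(
    'openai:gpt-4o',
    tools=[duckduckgo_search_tool(max_results=5)],  # cap each search at five results
)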
{pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/messages.py
@@ -533,9 +533,24 @@ class PartDeltaEvent:
     """Event type identifier, used as a discriminator."""
 
 
+@dataclass
+class FinalResultEvent:
+    """An event indicating the response to the current model request matches the result schema."""
+
+    tool_name: str | None
+    """The name of the result tool that was called. `None` if the result is from text content and not from a tool."""
+    event_kind: Literal['final_result'] = 'final_result'
+    """Event type identifier, used as a discriminator."""
+
+
 ModelResponseStreamEvent = Annotated[Union[PartStartEvent, PartDeltaEvent], pydantic.Discriminator('event_kind')]
 """An event in the model response stream, either starting a new part or applying a delta to an existing one."""
 
+AgentStreamEvent = Annotated[
+    Union[PartStartEvent, PartDeltaEvent, FinalResultEvent], pydantic.Discriminator('event_kind')
+]
+"""An event in the agent stream."""
+
 
 @dataclass
 class FunctionToolCallEvent:
@@ -558,7 +573,7 @@ class FunctionToolResultEvent:
 
     result: ToolReturnPart | RetryPromptPart
     """The result of the call to the function tool."""
-    …
+    tool_call_id: str
     """An ID used to match the result to its original call."""
     event_kind: Literal['function_tool_result'] = 'function_tool_result'
     """Event type identifier, used as a discriminator."""
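
`AgentStreamEvent` extends the response-stream union with `FinalResultEvent`, discriminated on `event_kind`, so serialized events can be validated back into the right dataclass. A sketch, assuming pydantic's standard handling of discriminated unions over dataclasses:

import pydantic

from pydantic_ai.messages import AgentStreamEvent

event_adapter = pydantic.TypeAdapter(AgentStreamEvent)

# the discriminator routes this payload to FinalResultEvent
event = event_adapter.validate_python({'event_kind': 'final_result', 'tool_name': None})
assert event.tool_name is None  # a final result from plain text content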
{pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/models/__init__.py
@@ -84,6 +84,8 @@ KnownModelName = Literal[
     'gpt-4-turbo-2024-04-09',
     'gpt-4-turbo-preview',
     'gpt-4-vision-preview',
+    'gpt-4.5-preview',
+    'gpt-4.5-preview-2025-02-27',
     'gpt-4o',
     'gpt-4o-2024-05-13',
     'gpt-4o-2024-08-06',
@@ -138,6 +140,8 @@ KnownModelName = Literal[
     'openai:gpt-4-turbo-2024-04-09',
     'openai:gpt-4-turbo-preview',
     'openai:gpt-4-vision-preview',
+    'openai:gpt-4.5-preview',
+    'openai:gpt-4.5-preview-2025-02-27',
     'openai:gpt-4o',
     'openai:gpt-4o-2024-05-13',
     'openai:gpt-4o-2024-08-06',
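
The new entries are plain `KnownModelName` literals, so the GPT-4.5 preview models can be passed anywhere a model name string is accepted, e.g.:

from pydantic_ai import Agent

agent = Agent('openai:gpt-4.5-preview')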
{pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/models/function.py
@@ -177,6 +177,8 @@ class DeltaToolCall:
     """Incremental change to the name of the tool."""
     json_args: str | None = None
     """Incremental change to the arguments as JSON"""
+    tool_call_id: str | None = None
+    """Incremental change to the tool call ID."""
 
 
 DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
@@ -224,7 +226,7 @@ class FunctionStreamedResponse(StreamedResponse):
                     vendor_part_id=dtc_index,
                     tool_name=delta_tool_call.name,
                     args=delta_tool_call.json_args,
-                    tool_call_id=…
+                    tool_call_id=delta_tool_call.tool_call_id,
                 )
                 if maybe_event is not None:
                     yield maybe_event
@@ -280,7 +282,16 @@ def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
         return 0
     if isinstance(content, str):
         return len(re.split(r'[\s",.:]+', content.strip()))
-    # TODO(Marcelo): We need to study how we can estimate the tokens for these types of content.
     else:  # pragma: no cover
-        …
-        …
+        tokens = 0
+        for part in content:
+            if isinstance(part, str):
+                tokens += len(re.split(r'[\s",.:]+', part.strip()))
+            # TODO(Marcelo): We need to study how we can estimate the tokens for these types of content.
+            if isinstance(part, (AudioUrl, ImageUrl)):
+                tokens += 0
+            elif isinstance(part, BinaryContent):
+                tokens += len(part.data)
+            else:
+                tokens += 0
+        return tokens
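
`DeltaToolCall.tool_call_id` lets a streamed `FunctionModel` attribute argument deltas to a specific tool call. A sketch of constructing such deltas (the tool name, arguments, and id are illustrative):

from pydantic_ai.models.function import DeltaToolCall, DeltaToolCalls

# keyed by the index of the tool call within the streamed response
deltas: DeltaToolCalls = {
    0: DeltaToolCall(name='final_result', json_args='{"answer": ', tool_call_id='call_1'),
}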
{pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/models/instrumented.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import json
 from collections.abc import AsyncIterator, Iterator
 from contextlib import asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
@@ -9,6 +10,7 @@ from typing import Any, Callable, Literal
 import logfire_api
 from opentelemetry._events import Event, EventLogger, EventLoggerProvider, get_event_logger_provider
 from opentelemetry.trace import Tracer, TracerProvider, get_tracer_provider
+from opentelemetry.util.types import AttributeValue
 
 from ..messages import (
     ModelMessage,
@@ -46,40 +48,42 @@ MODEL_SETTING_ATTRIBUTES: tuple[
     'frequency_penalty',
 )
 
-NOT_GIVEN = object()
-
 
 @dataclass
 class InstrumentedModel(WrapperModel):
-    """Model which is instrumented with …
+    """Model which is instrumented with OpenTelemetry."""
 
     tracer: Tracer = field(repr=False)
     event_logger: EventLogger = field(repr=False)
+    event_mode: Literal['attributes', 'logs'] = 'attributes'
 
     def __init__(
         self,
         wrapped: Model | KnownModelName,
         tracer_provider: TracerProvider | None = None,
         event_logger_provider: EventLoggerProvider | None = None,
+        event_mode: Literal['attributes', 'logs'] = 'attributes',
     ):
         super().__init__(wrapped)
         tracer_provider = tracer_provider or get_tracer_provider()
         event_logger_provider = event_logger_provider or get_event_logger_provider()
         self.tracer = tracer_provider.get_tracer('pydantic-ai')
         self.event_logger = event_logger_provider.get_event_logger('pydantic-ai')
+        self.event_mode = event_mode
 
     @classmethod
     def from_logfire(
         cls,
         wrapped: Model | KnownModelName,
         logfire_instance: logfire_api.Logfire = logfire_api.DEFAULT_LOGFIRE_INSTANCE,
+        event_mode: Literal['attributes', 'logs'] = 'attributes',
     ) -> InstrumentedModel:
         if hasattr(logfire_instance.config, 'get_event_logger_provider'):
             event_provider = logfire_instance.config.get_event_logger_provider()
         else:
             event_provider = None
         tracer_provider = logfire_instance.config.get_tracer_provider()
-        return cls(wrapped, tracer_provider, event_provider)
+        return cls(wrapped, tracer_provider, event_provider, event_mode)
 
     async def request(
         self,
@@ -111,7 +115,7 @@ class InstrumentedModel(WrapperModel):
             finish(response_stream.get(), response_stream.usage())
 
     @contextmanager
-    def _instrument(
+    def _instrument(  # noqa: C901
         self,
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
@@ -126,7 +130,7 @@ class InstrumentedModel(WrapperModel):
         # - server.port: to parse from the base_url
         # - error.type: unclear if we should do something here or just always rely on span exceptions
         # - gen_ai.request.stop_sequences/top_k: model_settings doesn't include these
-        attributes: dict[str, …
+        attributes: dict[str, AttributeValue] = {
             'gen_ai.operation.name': operation,
             'gen_ai.system': system,
             'gen_ai.request.model': model_name,
@@ -134,10 +138,11 @@ class InstrumentedModel(WrapperModel):
 
         if model_settings:
             for key in MODEL_SETTING_ATTRIBUTES:
-                if (value := model_settings.get(key, NOT_GIVEN)) is not NOT_GIVEN:
+                if isinstance(value := model_settings.get(key), (float, int)):
                     attributes[f'gen_ai.request.{key}'] = value
 
-        …
+        events_list = []
+        emit_event = partial(self._emit_event, system, events_list)
 
         with self.tracer.start_as_current_span(span_name, attributes=attributes) as span:
             if span.is_recording():
@@ -167,22 +172,36 @@ class InstrumentedModel(WrapperModel):
                 )
                 span.set_attributes(
                     {
-                        k: v
-                        for k, v in {
-                            …
-                            …
-                            'gen_ai.response.model': response.model_name or model_name,
-                            'gen_ai.usage.input_tokens': usage.request_tokens,
-                            'gen_ai.usage.output_tokens': usage.response_tokens,
-                        }.items()
-                        if v is not None
+                        # TODO finish_reason (https://github.com/open-telemetry/semantic-conventions/issues/1277), id
+                        #  https://github.com/pydantic/pydantic-ai/issues/886
+                        'gen_ai.response.model': response.model_name or model_name,
+                        **usage.opentelemetry_attributes(),
                     }
                 )
+                if events_list:
+                    attr_name = 'events'
+                    span.set_attributes(
+                        {
+                            attr_name: json.dumps(events_list),
+                            'logfire.json_schema': json.dumps(
+                                {
+                                    'type': 'object',
+                                    'properties': {attr_name: {'type': 'array'}},
+                                }
+                            ),
+                        }
+                    )
 
             yield finish
 
-    def _emit_event(…
-        self…
+    def _emit_event(
+        self, system: str, events_list: list[dict[str, Any]], event_name: str, body: dict[str, Any]
+    ) -> None:
+        attributes = {'gen_ai.system': system}
+        if self.event_mode == 'logs':
+            self.event_logger.emit(Event(event_name, body=body, attributes=attributes))
+        else:
+            events_list.append({'event.name': event_name, **body, **attributes})
 
 
 def _request_part_body(part: ModelRequestPart) -> tuple[str, dict[str, Any]]:
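
`event_mode` controls where instrumentation events go: 'attributes' (the default) collects them into a JSON `events` span attribute, while 'logs' emits them through the OpenTelemetry event logger. A sketch of opting into log-based events (the wrapped model name is illustrative):

from pydantic_ai.models.instrumented import InstrumentedModel

model = InstrumentedModel('openai:gpt-4o', event_mode='logs')
# the same switch is also accepted by InstrumentedModel.from_logfire(...)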
{pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/result.py
@@ -7,9 +7,10 @@ from datetime import datetime
 from typing import Generic, Union, cast
 
 import logfire_api
-from typing_extensions import TypeVar
+from typing_extensions import TypeVar, assert_type
 
 from . import _result, _utils, exceptions, messages as _messages, models
+from .messages import AgentStreamEvent, FinalResultEvent
 from .tools import AgentDepsT, RunContext
 from .usage import Usage, UsageLimits
 
@@ -51,6 +52,125 @@ Usage `ResultValidatorFunc[AgentDepsT, T]`.
 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
 
 
+@dataclass
+class AgentStream(Generic[AgentDepsT, ResultDataT]):
+    _raw_stream_response: models.StreamedResponse
+    _result_schema: _result.ResultSchema[ResultDataT] | None
+    _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]]
+    _run_ctx: RunContext[AgentDepsT]
+    _usage_limits: UsageLimits | None
+
+    _agent_stream_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False)
+    _final_result_event: FinalResultEvent | None = field(default=None, init=False)
+    _initial_run_ctx_usage: Usage = field(init=False)
+
+    def __post_init__(self):
+        self._initial_run_ctx_usage = copy(self._run_ctx.usage)
+
+    async def stream_output(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[ResultDataT]:
+        """Asynchronously stream the (validated) agent outputs."""
+        async for response in self.stream_responses(debounce_by=debounce_by):
+            if self._final_result_event is not None:
+                yield await self._validate_response(response, self._final_result_event.tool_name, allow_partial=True)
+        if self._final_result_event is not None:
+            yield await self._validate_response(
+                self._raw_stream_response.get(), self._final_result_event.tool_name, allow_partial=False
+            )
+
+    async def stream_responses(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]:
+        """Asynchronously stream the (unvalidated) model responses for the agent."""
+        # if the message currently has any parts with content, yield before streaming
+        msg = self._raw_stream_response.get()
+        for part in msg.parts:
+            if part.has_content():
+                yield msg
+                break
+
+        async with _utils.group_by_temporal(self, debounce_by) as group_iter:
+            async for _items in group_iter:
+                yield self._raw_stream_response.get()  # current state of the response
+
+    def usage(self) -> Usage:
+        """Return the usage of the whole run.
+
+        !!! note
+            This won't return the full usage until the stream is finished.
+        """
+        return self._initial_run_ctx_usage + self._raw_stream_response.usage()
+
+    async def _validate_response(
+        self, message: _messages.ModelResponse, result_tool_name: str | None, *, allow_partial: bool = False
+    ) -> ResultDataT:
+        """Validate a structured result message."""
+        if self._result_schema is not None and result_tool_name is not None:
+            match = self._result_schema.find_named_tool(message.parts, result_tool_name)
+            if match is None:
+                raise exceptions.UnexpectedModelBehavior(
+                    f'Invalid response, unable to find tool: {self._result_schema.tool_names()}'
+                )
+
+            call, result_tool = match
+            result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)
+
+            for validator in self._result_validators:
+                result_data = await validator.validate(result_data, call, self._run_ctx)
+            return result_data
+        else:
+            text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
+            for validator in self._result_validators:
+                text = await validator.validate(
+                    text,
+                    None,
+                    self._run_ctx,
+                )
+            # Since there is no result tool, we can assume that str is compatible with ResultDataT
+            return cast(ResultDataT, text)
+
+    def __aiter__(self) -> AsyncIterator[AgentStreamEvent]:
+        """Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s.
+
+        This proxies the _raw_stream_response and sends all events to the agent stream, while also checking for matches
+        on the result schema and emitting a [`FinalResultEvent`][pydantic_ai.messages.FinalResultEvent] if/when the
+        first match is found.
+        """
+        if self._agent_stream_iterator is not None:
+            return self._agent_stream_iterator
+
+        async def aiter():
+            result_schema = self._result_schema
+            allow_text_result = result_schema is None or result_schema.allow_text_result
+
+            def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages.FinalResultEvent | None:
+                """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result."""
+                if isinstance(e, _messages.PartStartEvent):
+                    new_part = e.part
+                    if isinstance(new_part, _messages.ToolCallPart):
+                        if result_schema is not None and (match := result_schema.find_tool([new_part])):
+                            call, _ = match
+                            return _messages.FinalResultEvent(tool_name=call.tool_name)
+                    elif allow_text_result:
+                        assert_type(e, _messages.PartStartEvent)
+                        return _messages.FinalResultEvent(tool_name=None)
+
+            usage_checking_stream = _get_usage_checking_stream_response(
+                self._raw_stream_response, self._usage_limits, self.usage
+            )
+            async for event in usage_checking_stream:
+                yield event
+                if (final_result_event := _get_final_result_event(event)) is not None:
+                    self._final_result_event = final_result_event
+                    yield final_result_event
+                    break
+
+            # If we broke out of the above loop, we need to yield the rest of the events
+            # If we didn't, this will just be a no-op
+            async for event in usage_checking_stream:
+                yield event
+
+        self._agent_stream_iterator = aiter()
+        return self._agent_stream_iterator
+
+
 @dataclass
 class StreamedRunResult(Generic[AgentDepsT, ResultDataT]):
     """Result of a streamed run that returns structured data via a tool call."""
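
`AgentStream` is what the new public `ModelRequestNode.stream()` yields: iterating it produces `AgentStreamEvent`s, including the `FinalResultEvent` added in messages.py, while `stream_output()` yields validated results. A sketch combining it with the iteration API above (model name and prompt are illustrative):

from pydantic_ai import Agent
from pydantic_ai.messages import FinalResultEvent

agent = Agent('openai:gpt-4o')


async def main() -> None:
    async with agent.iter('Tell me a joke.') as agent_run:
        async for node in agent_run:
            if Agent.is_model_request_node(node):
                async with node.stream(agent_run.ctx) as stream:
                    async for event in stream:
                        if isinstance(event, FinalResultEvent):
                            # tool_name is None when the result comes from text content
                            print('final result starts, tool:', event.tool_name)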
{pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pydantic_ai/usage.py
@@ -56,6 +56,16 @@ class Usage:
         new_usage.incr(other)
         return new_usage
 
+    def opentelemetry_attributes(self) -> dict[str, int]:
+        """Get the token limits as OpenTelemetry attributes."""
+        result = {
+            'gen_ai.usage.input_tokens': self.request_tokens,
+            'gen_ai.usage.output_tokens': self.response_tokens,
+        }
+        for key, value in (self.details or {}).items():
+            result[f'gen_ai.usage.details.{key}'] = value
+        return {k: v for k, v in result.items() if v is not None}
+
 
 @dataclass
 class UsageLimits:
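
`Usage.opentelemetry_attributes()` drops `None` token counts and namespaces `details` entries. A small worked example based on the implementation above:

from pydantic_ai.usage import Usage

usage = Usage(request_tokens=10, response_tokens=25, details={'reasoning_tokens': 5})
assert usage.opentelemetry_attributes() == {
    'gen_ai.usage.input_tokens': 10,
    'gen_ai.usage.output_tokens': 25,
    'gen_ai.usage.details.reasoning_tokens': 5,
}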
{pydantic_ai_slim-0.0.28 → pydantic_ai_slim-0.0.30}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "pydantic-ai-slim"
-version = "0.0.28"
+version = "0.0.30"
 description = "Agent Framework / shim to use Pydantic with LLMs, slim package"
 authors = [{ name = "Samuel Colvin", email = "samuel@pydantic.dev" }]
 license = "MIT"
@@ -37,7 +37,7 @@ dependencies = [
     "httpx>=0.27",
     "logfire-api>=1.2.0",
     "pydantic>=2.10",
-    "pydantic-graph==0.0.28",
+    "pydantic-graph==0.0.30",
     "exceptiongroup; python_version < '3.11'",
 ]
@@ -45,7 +45,7 @@ dependencies = [
 # WARNING if you add optional groups, please update docs/install.md
 logfire = ["logfire>=2.3"]
 # Models
-openai = ["openai>=1.…
+openai = ["openai>=1.65.1"]
 cohere = ["cohere>=5.13.11"]
 vertexai = ["google-auth>=2.36.0", "requests>=2.32.3"]
 anthropic = ["anthropic>=0.40.0"]