pydantic-ai-slim 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

Files changed (55) hide show
  1. pydantic_ai/_agent_graph.py +219 -315
  2. pydantic_ai/_cli.py +9 -7
  3. pydantic_ai/_output.py +296 -226
  4. pydantic_ai/_parts_manager.py +2 -2
  5. pydantic_ai/_run_context.py +8 -14
  6. pydantic_ai/_tool_manager.py +190 -0
  7. pydantic_ai/_utils.py +18 -1
  8. pydantic_ai/ag_ui.py +675 -0
  9. pydantic_ai/agent.py +369 -155
  10. pydantic_ai/common_tools/duckduckgo.py +5 -2
  11. pydantic_ai/exceptions.py +14 -2
  12. pydantic_ai/ext/aci.py +12 -3
  13. pydantic_ai/ext/langchain.py +9 -1
  14. pydantic_ai/mcp.py +147 -84
  15. pydantic_ai/messages.py +19 -9
  16. pydantic_ai/models/__init__.py +43 -19
  17. pydantic_ai/models/anthropic.py +2 -2
  18. pydantic_ai/models/bedrock.py +1 -1
  19. pydantic_ai/models/cohere.py +1 -1
  20. pydantic_ai/models/function.py +50 -24
  21. pydantic_ai/models/gemini.py +3 -11
  22. pydantic_ai/models/google.py +3 -12
  23. pydantic_ai/models/groq.py +2 -1
  24. pydantic_ai/models/huggingface.py +463 -0
  25. pydantic_ai/models/instrumented.py +1 -1
  26. pydantic_ai/models/mistral.py +3 -3
  27. pydantic_ai/models/openai.py +5 -5
  28. pydantic_ai/output.py +21 -7
  29. pydantic_ai/profiles/google.py +1 -1
  30. pydantic_ai/profiles/moonshotai.py +8 -0
  31. pydantic_ai/providers/__init__.py +4 -0
  32. pydantic_ai/providers/google.py +2 -2
  33. pydantic_ai/providers/google_vertex.py +10 -5
  34. pydantic_ai/providers/grok.py +13 -1
  35. pydantic_ai/providers/groq.py +2 -0
  36. pydantic_ai/providers/huggingface.py +88 -0
  37. pydantic_ai/result.py +57 -33
  38. pydantic_ai/tools.py +26 -119
  39. pydantic_ai/toolsets/__init__.py +22 -0
  40. pydantic_ai/toolsets/abstract.py +155 -0
  41. pydantic_ai/toolsets/combined.py +88 -0
  42. pydantic_ai/toolsets/deferred.py +38 -0
  43. pydantic_ai/toolsets/filtered.py +24 -0
  44. pydantic_ai/toolsets/function.py +238 -0
  45. pydantic_ai/toolsets/prefixed.py +37 -0
  46. pydantic_ai/toolsets/prepared.py +36 -0
  47. pydantic_ai/toolsets/renamed.py +42 -0
  48. pydantic_ai/toolsets/wrapper.py +37 -0
  49. pydantic_ai/usage.py +14 -8
  50. {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/METADATA +13 -8
  51. pydantic_ai_slim-0.4.4.dist-info/RECORD +98 -0
  52. pydantic_ai_slim-0.4.2.dist-info/RECORD +0 -83
  53. {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/WHEEL +0 -0
  54. {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/entry_points.txt +0 -0
  55. {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/licenses/LICENSE +0 -0
@@ -2,10 +2,10 @@
2
2
 
3
3
  The manager tracks which parts (in particular, text and tool calls) correspond to which
4
4
  vendor-specific identifiers (e.g., `index`, `tool_call_id`, etc., as appropriate for a given model),
5
- and produces PydanticAI-format events as appropriate for consumers of the streaming APIs.
5
+ and produces Pydantic AI-format events as appropriate for consumers of the streaming APIs.
6
6
 
7
7
  The "vendor-specific identifiers" to use depend on the semantics of the responses from the vendor,
8
- and are tightly coupled to the specific model being used, and the PydanticAI Model subclass implementation.
8
+ and are tightly coupled to the specific model being used, and the Pydantic AI Model subclass implementation.
9
9
 
10
10
  This `ModelResponsePartsManager` is used in each of the subclasses of `StreamedResponse` as a way to consolidate
11
11
  event-emitting logic.
@@ -5,6 +5,7 @@ from collections.abc import Sequence
5
5
  from dataclasses import field
6
6
  from typing import TYPE_CHECKING, Generic
7
7
 
8
+ from opentelemetry.trace import NoOpTracer, Tracer
8
9
  from typing_extensions import TypeVar
9
10
 
10
11
  from . import _utils, messages as _messages
@@ -27,10 +28,16 @@ class RunContext(Generic[AgentDepsT]):
27
28
  """The model used in this run."""
28
29
  usage: Usage
29
30
  """LLM usage associated with the run."""
30
- prompt: str | Sequence[_messages.UserContent] | None
31
+ prompt: str | Sequence[_messages.UserContent] | None = None
31
32
  """The original user prompt passed to the run."""
32
33
  messages: list[_messages.ModelMessage] = field(default_factory=list)
33
34
  """Messages exchanged in the conversation so far."""
35
+ tracer: Tracer = field(default_factory=NoOpTracer)
36
+ """The tracer to use for tracing the run."""
37
+ trace_include_content: bool = False
38
+ """Whether to include the content of the messages in the trace."""
39
+ retries: dict[str, int] = field(default_factory=dict)
40
+ """Number of retries for each tool so far."""
34
41
  tool_call_id: str | None = None
35
42
  """The ID of the tool call."""
36
43
  tool_name: str | None = None
@@ -40,17 +47,4 @@ class RunContext(Generic[AgentDepsT]):
40
47
  run_step: int = 0
41
48
  """The current step in the run."""
42
49
 
43
- def replace_with(
44
- self,
45
- retry: int | None = None,
46
- tool_name: str | None | _utils.Unset = _utils.UNSET,
47
- ) -> RunContext[AgentDepsT]:
48
- # Create a new `RunContext` with a new `retry` value and `tool_name`.
49
- kwargs = {}
50
- if retry is not None:
51
- kwargs['retry'] = retry
52
- if tool_name is not _utils.UNSET: # pragma: no branch
53
- kwargs['tool_name'] = tool_name
54
- return dataclasses.replace(self, **kwargs)
55
-
56
50
  __repr__ = _utils.dataclasses_no_defaults_repr
@@ -0,0 +1,190 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from collections.abc import Iterable
5
+ from dataclasses import dataclass, replace
6
+ from typing import Any, Generic
7
+
8
+ from pydantic import ValidationError
9
+ from typing_extensions import assert_never
10
+
11
+ from pydantic_ai.output import DeferredToolCalls
12
+
13
+ from . import messages as _messages
14
+ from ._run_context import AgentDepsT, RunContext
15
+ from .exceptions import ModelRetry, ToolRetryError, UnexpectedModelBehavior
16
+ from .messages import ToolCallPart
17
+ from .tools import ToolDefinition
18
+ from .toolsets.abstract import AbstractToolset, ToolsetTool
19
+
20
+
21
+ @dataclass
22
+ class ToolManager(Generic[AgentDepsT]):
23
+ """Manages tools for an agent run step. It caches the tool definitions of the agent run's toolset and handles calling tools and retries."""
24
+
25
+ ctx: RunContext[AgentDepsT]
26
+ """The agent run context for a specific run step."""
27
+ toolset: AbstractToolset[AgentDepsT]
28
+ """The toolset that provides the tools for this run step."""
29
+ tools: dict[str, ToolsetTool[AgentDepsT]]
30
+ """The cached tools for this run step."""
31
+
32
+ @classmethod
33
+ async def build(cls, toolset: AbstractToolset[AgentDepsT], ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]:
34
+ """Build a new tool manager for a specific run step."""
35
+ return cls(
36
+ ctx=ctx,
37
+ toolset=toolset,
38
+ tools=await toolset.get_tools(ctx),
39
+ )
40
+
41
+ async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]:
42
+ """Build a new tool manager for the next run step, carrying over the retries from the current run step."""
43
+ return await self.__class__.build(self.toolset, replace(ctx, retries=self.ctx.retries))
44
+
45
+ @property
46
+ def tool_defs(self) -> list[ToolDefinition]:
47
+ """The tool definitions for the tools in this tool manager."""
48
+ return [tool.tool_def for tool in self.tools.values()]
49
+
50
+ def get_tool_def(self, name: str) -> ToolDefinition | None:
51
+ """Get the tool definition for a given tool name, or `None` if the tool is unknown."""
52
+ try:
53
+ return self.tools[name].tool_def
54
+ except KeyError:
55
+ return None
56
+
57
+ async def handle_call(self, call: ToolCallPart, allow_partial: bool = False) -> Any:
58
+ """Handle a tool call by validating the arguments, calling the tool, and handling retries.
59
+
60
+ Args:
61
+ call: The tool call part to handle.
62
+ allow_partial: Whether to allow partial validation of the tool arguments.
63
+ """
64
+ if (tool := self.tools.get(call.tool_name)) and tool.tool_def.kind == 'output':
65
+ # Output tool calls are not traced
66
+ return await self._call_tool(call, allow_partial)
67
+ else:
68
+ return await self._call_tool_traced(call, allow_partial)
69
+
70
+ async def _call_tool(self, call: ToolCallPart, allow_partial: bool = False) -> Any:
71
+ name = call.tool_name
72
+ tool = self.tools.get(name)
73
+ try:
74
+ if tool is None:
75
+ if self.tools:
76
+ msg = f'Available tools: {", ".join(f"{name!r}" for name in self.tools.keys())}'
77
+ else:
78
+ msg = 'No tools available.'
79
+ raise ModelRetry(f'Unknown tool name: {name!r}. {msg}')
80
+
81
+ ctx = replace(
82
+ self.ctx,
83
+ tool_name=name,
84
+ tool_call_id=call.tool_call_id,
85
+ retry=self.ctx.retries.get(name, 0),
86
+ )
87
+
88
+ pyd_allow_partial = 'trailing-strings' if allow_partial else 'off'
89
+ validator = tool.args_validator
90
+ if isinstance(call.args, str):
91
+ args_dict = validator.validate_json(call.args or '{}', allow_partial=pyd_allow_partial)
92
+ else:
93
+ args_dict = validator.validate_python(call.args or {}, allow_partial=pyd_allow_partial)
94
+
95
+ output = await self.toolset.call_tool(name, args_dict, ctx, tool)
96
+ except (ValidationError, ModelRetry) as e:
97
+ max_retries = tool.max_retries if tool is not None else 1
98
+ current_retry = self.ctx.retries.get(name, 0)
99
+
100
+ if current_retry == max_retries:
101
+ raise UnexpectedModelBehavior(f'Tool {name!r} exceeded max retries count of {max_retries}') from e
102
+ else:
103
+ if isinstance(e, ValidationError):
104
+ m = _messages.RetryPromptPart(
105
+ tool_name=name,
106
+ content=e.errors(include_url=False, include_context=False),
107
+ tool_call_id=call.tool_call_id,
108
+ )
109
+ e = ToolRetryError(m)
110
+ elif isinstance(e, ModelRetry):
111
+ m = _messages.RetryPromptPart(
112
+ tool_name=name,
113
+ content=e.message,
114
+ tool_call_id=call.tool_call_id,
115
+ )
116
+ e = ToolRetryError(m)
117
+ else:
118
+ assert_never(e)
119
+
120
+ self.ctx.retries[name] = current_retry + 1
121
+ raise e
122
+ else:
123
+ self.ctx.retries.pop(name, None)
124
+ return output
125
+
126
+ async def _call_tool_traced(self, call: ToolCallPart, allow_partial: bool = False) -> Any:
127
+ """See <https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span>."""
128
+ span_attributes = {
129
+ 'gen_ai.tool.name': call.tool_name,
130
+ # NOTE: this means `gen_ai.tool.call.id` will be included even if it was generated by pydantic-ai
131
+ 'gen_ai.tool.call.id': call.tool_call_id,
132
+ **({'tool_arguments': call.args_as_json_str()} if self.ctx.trace_include_content else {}),
133
+ 'logfire.msg': f'running tool: {call.tool_name}',
134
+ # add the JSON schema so these attributes are formatted nicely in Logfire
135
+ 'logfire.json_schema': json.dumps(
136
+ {
137
+ 'type': 'object',
138
+ 'properties': {
139
+ **(
140
+ {
141
+ 'tool_arguments': {'type': 'object'},
142
+ 'tool_response': {'type': 'object'},
143
+ }
144
+ if self.ctx.trace_include_content
145
+ else {}
146
+ ),
147
+ 'gen_ai.tool.name': {},
148
+ 'gen_ai.tool.call.id': {},
149
+ },
150
+ }
151
+ ),
152
+ }
153
+ with self.ctx.tracer.start_as_current_span('running tool', attributes=span_attributes) as span:
154
+ try:
155
+ tool_result = await self._call_tool(call, allow_partial)
156
+ except ToolRetryError as e:
157
+ part = e.tool_retry
158
+ if self.ctx.trace_include_content and span.is_recording():
159
+ span.set_attribute('tool_response', part.model_response())
160
+ raise e
161
+
162
+ if self.ctx.trace_include_content and span.is_recording():
163
+ span.set_attribute(
164
+ 'tool_response',
165
+ tool_result
166
+ if isinstance(tool_result, str)
167
+ else _messages.tool_return_ta.dump_json(tool_result).decode(),
168
+ )
169
+
170
+ return tool_result
171
+
172
+ def get_deferred_tool_calls(self, parts: Iterable[_messages.ModelResponsePart]) -> DeferredToolCalls | None:
173
+ """Get the deferred tool calls from the model response parts."""
174
+ deferred_calls_and_defs = [
175
+ (part, tool_def)
176
+ for part in parts
177
+ if isinstance(part, _messages.ToolCallPart)
178
+ and (tool_def := self.get_tool_def(part.tool_name))
179
+ and tool_def.kind == 'deferred'
180
+ ]
181
+ if not deferred_calls_and_defs:
182
+ return None
183
+
184
+ deferred_calls: list[_messages.ToolCallPart] = []
185
+ deferred_tool_defs: dict[str, ToolDefinition] = {}
186
+ for part, tool_def in deferred_calls_and_defs:
187
+ deferred_calls.append(part)
188
+ deferred_tool_defs[part.tool_name] = tool_def
189
+
190
+ return DeferredToolCalls(deferred_calls, deferred_tool_defs)
pydantic_ai/_utils.py CHANGED
@@ -4,8 +4,10 @@ import asyncio
4
4
  import functools
5
5
  import inspect
6
6
  import re
7
+ import sys
7
8
  import time
8
9
  import uuid
10
+ import warnings
9
11
  from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterator
10
12
  from contextlib import asynccontextmanager, suppress
11
13
  from dataclasses import dataclass, fields, is_dataclass
@@ -29,7 +31,7 @@ from typing_extensions import (
29
31
  from typing_inspection import typing_objects
30
32
  from typing_inspection.introspection import is_union_origin
31
33
 
32
- from pydantic_graph._utils import AbstractSpan
34
+ from pydantic_graph._utils import AbstractSpan, get_event_loop
33
35
 
34
36
  from . import exceptions
35
37
 
@@ -461,3 +463,18 @@ def get_union_args(tp: Any) -> tuple[Any, ...]:
461
463
  return get_args(tp)
462
464
  else:
463
465
  return ()
466
+
467
+
468
+ # The `asyncio.Lock` `loop` argument was deprecated in 3.8 and removed in 3.10,
469
+ # but 3.9 still needs it to have the intended behavior.
470
+
471
+ if sys.version_info < (3, 10):
472
+
473
+ def get_async_lock() -> asyncio.Lock: # pragma: lax no cover
474
+ with warnings.catch_warnings():
475
+ warnings.simplefilter('ignore', DeprecationWarning)
476
+ return asyncio.Lock(loop=get_event_loop())
477
+ else:
478
+
479
+ def get_async_lock() -> asyncio.Lock: # pragma: lax no cover
480
+ return asyncio.Lock()