pydantic-ai-slim 0.0.41__tar.gz → 0.0.43__tar.gz
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/.gitignore +1 -1
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/PKG-INFO +5 -3
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/_agent_graph.py +61 -7
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/_cli.py +1 -1
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/_pydantic.py +6 -5
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/agent.py +55 -7
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/common_tools/duckduckgo.py +1 -1
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/common_tools/tavily.py +1 -1
- pydantic_ai_slim-0.0.43/pydantic_ai/mcp.py +198 -0
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/messages.py +3 -0
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/models/anthropic.py +1 -1
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/models/cohere.py +1 -1
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/models/groq.py +1 -1
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/models/instrumented.py +13 -7
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/models/mistral.py +1 -1
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/models/openai.py +1 -1
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/models/vertexai.py +1 -1
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/models/wrapper.py +5 -1
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/providers/anthropic.py +1 -1
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/providers/azure.py +1 -1
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/providers/bedrock.py +1 -1
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/providers/deepseek.py +1 -1
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/providers/google_vertex.py +11 -19
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/providers/groq.py +16 -12
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/providers/mistral.py +12 -12
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/providers/openai.py +1 -1
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/tools.py +23 -2
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pyproject.toml +8 -3
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/README.md +0 -0
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/__init__.py +0 -0
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/_griffe.py +0 -0
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/_parts_manager.py +0 -0
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/_result.py +0 -0
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/_system_prompt.py +0 -0
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/_utils.py +0 -0
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/common_tools/__init__.py +0 -0
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/exceptions.py +0 -0
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/format_as_xml.py +0 -0
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/models/__init__.py +0 -0
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/models/bedrock.py +0 -0
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/models/fallback.py +0 -0
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/models/function.py +0 -0
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/models/gemini.py +0 -0
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/models/test.py +0 -0
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/providers/__init__.py +0 -0
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/providers/google_gla.py +0 -0
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/py.typed +0 -0
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/result.py +0 -0
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/settings.py +0 -0
- {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/usage.py +0 -0
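
The headline change in 0.0.43 is Model Context Protocol (MCP) support: a new `pydantic_ai/mcp.py` module, an `mcp_servers` argument on `Agent`, and an `Agent.run_mcp_servers()` context manager, all visible in the hunks below. A minimal usage sketch based on the APIs added in this diff (the model name and the `@pydantic/mcp-run-python` server are illustrative; any MCP server works):

```python
import asyncio

from pydantic_ai import Agent
from pydantic_ai.mcp import MCPServerStdio

# Spawn an MCP server as a subprocess; its tools become callable by the model.
server = MCPServerStdio('npx', ['-y', '@pydantic/mcp-run-python', 'stdio'])
agent = Agent('openai:gpt-4o', mcp_servers=[server])


async def main():
    # run_mcp_servers() starts every registered server and shuts them down on exit.
    async with agent.run_mcp_servers():
        result = await agent.run('Use the MCP tools to compute 2 ** 8.')
        print(result.data)


asyncio.run(main())
```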
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.0.41
+Version: 0.0.43
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT
@@ -29,7 +29,7 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: opentelemetry-api>=1.28.0
-Requires-Dist: pydantic-graph==0.0.41
+Requires-Dist: pydantic-graph==0.0.43
 Requires-Dist: pydantic>=2.10
 Requires-Dist: typing-inspection>=0.4.0
 Provides-Extra: anthropic
@@ -45,9 +45,11 @@ Requires-Dist: cohere>=5.13.11; extra == 'cohere'
 Provides-Extra: duckduckgo
 Requires-Dist: duckduckgo-search>=7.0.0; extra == 'duckduckgo'
 Provides-Extra: groq
-Requires-Dist: groq>=0.
+Requires-Dist: groq>=0.15.0; extra == 'groq'
 Provides-Extra: logfire
 Requires-Dist: logfire>=2.3; extra == 'logfire'
+Provides-Extra: mcp
+Requires-Dist: mcp>=1.4.1; (python_version >= '3.10') and extra == 'mcp'
 Provides-Extra: mistral
 Requires-Dist: mistralai>=1.2.5; extra == 'mistral'
 Provides-Extra: openai
pydantic_ai/_agent_graph.py
@@ -7,7 +7,7 @@ from collections.abc import AsyncIterator, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
 from dataclasses import field
-from typing import Any, Generic, Literal, Union, cast
+from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast
 
 from opentelemetry.trace import Span, Tracer
 from typing_extensions import TypeGuard, TypeVar, assert_never
@@ -27,11 +27,10 @@ from . import (
 from .models.instrumented import InstrumentedModel
 from .result import ResultDataT
 from .settings import ModelSettings, merge_model_settings
-from .tools import
-
-
-
-)
+from .tools import RunContext, Tool, ToolDefinition
+
+if TYPE_CHECKING:
+    from .mcp import MCPServer
 
 __all__ = (
     'GraphAgentState',
@@ -94,6 +93,7 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
     result_validators: list[_result.ResultValidator[DepsT, ResultDataT]]
 
     function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
+    mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
 
     run_span: Span
     tracer: Tracer
@@ -219,7 +219,17 @@ async def _prepare_request_parameters(
         if tool_def := await tool.prepare_tool_def(ctx):
            function_tool_defs.append(tool_def)
 
-
+    async def add_mcp_server_tools(server: MCPServer) -> None:
+        if not server.is_running:
+            raise exceptions.UserError(f'MCP server is not running: {server}')
+        tool_defs = await server.list_tools()
+        # TODO(Marcelo): We should check if the tool names are unique. If not, we should raise an error.
+        function_tool_defs.extend(tool_defs)
+
+    await asyncio.gather(
+        *map(add_tool, ctx.deps.function_tools.values()),
+        *map(add_mcp_server_tools, ctx.deps.mcp_servers),
+    )
 
     result_schema = ctx.deps.result_schema
     return models.ModelRequestParameters(
@@ -594,6 +604,21 @@ async def process_function_tools(
            yield event
            call_index_to_event_id[len(calls_to_run)] = event.call_id
            calls_to_run.append((tool, call))
+        elif mcp_tool := await _tool_from_mcp_server(call.tool_name, ctx):
+            if stub_function_tools:
+                # TODO(Marcelo): We should add coverage for this part of the code.
+                output_parts.append(  # pragma: no cover
+                    _messages.ToolReturnPart(
+                        tool_name=call.tool_name,
+                        content='Tool not executed - a final result was already processed.',
+                        tool_call_id=call.tool_call_id,
+                    )
+                )
+            else:
+                event = _messages.FunctionToolCallEvent(call)
+                yield event
+                call_index_to_event_id[len(calls_to_run)] = event.call_id
+                calls_to_run.append((mcp_tool, call))
         elif result_schema is not None and call.tool_name in result_schema.tools:
            # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
            # validation, we don't add another part here
@@ -641,6 +666,35 @@ async def process_function_tools(
            output_parts.append(results_by_index[k])
 
 
+async def _tool_from_mcp_server(
+    tool_name: str,
+    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
+) -> Tool[DepsT] | None:
+    """Call each MCP server to find the tool with the given name.
+
+    Args:
+        tool_name: The name of the tool to find.
+        ctx: The current run context.
+
+    Returns:
+        The tool with the given name, or `None` if no tool with the given name is found.
+    """
+
+    async def run_tool(ctx: RunContext[DepsT], **args: Any) -> Any:
+        # There's no normal situation where the server will not be running at this point, we check just in case
+        # some weird edge case occurs.
+        if not server.is_running:  # pragma: no cover
+            raise exceptions.UserError(f'MCP server is not running: {server}')
+        result = await server.call_tool(tool_name, args)
+        return result
+
+    for server in ctx.deps.mcp_servers:
+        tools = await server.list_tools()
+        if tool_name in {tool.name for tool in tools}:
+            return Tool(name=tool_name, function=run_tool, takes_ctx=True)
+    return None
+
+
 def _unknown_tool(
     tool_name: str,
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
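
Two details of the `_agent_graph.py` changes are worth calling out: tool discovery now fans out concurrently over function tools and MCP servers via `asyncio.gather`, and `_tool_from_mcp_server` resolves a tool call by asking each server in registration order, so the first server that lists a matching name wins (duplicate names across servers are not yet detected, per the TODO above). A toy re-statement of that lookup order, with a hypothetical `FakeServer` standing in for a real `MCPServer`:

```python
import asyncio


class FakeServer:
    """Hypothetical stand-in exposing the one method the lookup relies on."""

    def __init__(self, name: str, tool_names: set[str]):
        self.name = name
        self.tool_names = tool_names

    async def list_tools(self) -> set[str]:
        return self.tool_names


async def find_server(servers: list[FakeServer], tool_name: str) -> FakeServer | None:
    # Mirrors the loop in _tool_from_mcp_server: the first server listing the name wins.
    for server in servers:
        if tool_name in await server.list_tools():
            return server
    return None


servers = [FakeServer('first', {'add'}), FakeServer('second', {'add', 'sub'})]
match = asyncio.run(find_server(servers, 'add'))
print(match.name)  # -> 'first', even though 'second' also exposes 'add'
```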
pydantic_ai/_cli.py
@@ -31,7 +31,7 @@ try:
 except ImportError as _import_error:
     raise ImportError(
         'Please install `rich`, `prompt-toolkit` and `argcomplete` to use the PydanticAI CLI, '
-
+        'you can use the `cli` optional group — `pip install "pydantic-ai-slim[cli]"`'
     ) from _import_error
 
 from pydantic_ai.agent import Agent
pydantic_ai/_pydantic.py
@@ -44,6 +44,7 @@ def function_schema(  # noqa: C901
     takes_ctx: bool,
     docstring_format: DocstringFormat,
     require_parameter_descriptions: bool,
+    schema_generator: type[GenerateJsonSchema],
 ) -> FunctionSchema:
     """Build a Pydantic validator and JSON schema from a tool function.
 
@@ -52,6 +53,7 @@ def function_schema(  # noqa: C901
         takes_ctx: Whether the function takes a `RunContext` first argument.
         docstring_format: The docstring format to use.
         require_parameter_descriptions: Whether to require descriptions for all tool function parameters.
+        schema_generator: The JSON schema generator class to use.
 
     Returns:
         A `FunctionSchema` instance.
@@ -150,14 +152,12 @@ def function_schema(  # noqa: C901
     )
     # PluggableSchemaValidator is api compatible with SchemaValidator
     schema_validator = cast(SchemaValidator, schema_validator)
-    json_schema =
+    json_schema = schema_generator().generate(schema)
 
     # workaround for https://github.com/pydantic/pydantic/issues/10785
-    # if we build a custom
+    # if we build a custom TypedDict schema (matches when `single_arg_name is None`), we manually set
     # `additionalProperties` in the JSON Schema
-    if single_arg_name is None:
-        json_schema['additionalProperties'] = bool(var_kwargs_schema)
-    elif not description:
+    if single_arg_name is not None and not description:
         # if the tool description is not set, and we have a single parameter, take the description from that
         # and set it on the tool
         description = json_schema.pop('description', None)
@@ -218,6 +218,7 @@ def _build_schema(
     td_schema = core_schema.typed_dict_schema(
         fields,
         config=core_config,
+        total=var_kwargs_schema is None,
         extras_schema=gen_schema.generate_schema(var_kwargs_schema) if var_kwargs_schema else None,
     )
     return td_schema, None
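
The net effect of these `_pydantic.py` changes is that `additionalProperties` is no longer hard-coded here; the core `typed_dict_schema` now carries a `total` flag and the pluggable `schema_generator` decides how to render it. A sketch of supplying a custom generator through the new parameter, with a hypothetical `StrictJsonSchema` subclass (not part of the package):

```python
from pydantic.json_schema import GenerateJsonSchema

from pydantic_ai.tools import Tool


class StrictJsonSchema(GenerateJsonSchema):
    """Hypothetical generator that forbids extra arguments on every tool schema."""

    def generate(self, schema, mode='validation'):
        json_schema = super().generate(schema, mode=mode)
        json_schema['additionalProperties'] = False
        return json_schema


def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


# The generator class (not an instance) is threaded through to function_schema above.
tool = Tool(add, schema_generator=StrictJsonSchema)
```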
pydantic_ai/agent.py
@@ -3,12 +3,13 @@ from __future__ import annotations as _annotations
 import dataclasses
 import inspect
 from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
-from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
+from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
 from copy import deepcopy
 from types import FrameType
-from typing import Any, Callable, ClassVar, Generic, cast, final, overload
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, final, overload
 
 from opentelemetry.trace import NoOpTracer, use_span
+from pydantic.json_schema import GenerateJsonSchema
 from typing_extensions import TypeGuard, TypeVar, deprecated
 
 from pydantic_graph import End, Graph, GraphRun, GraphRunContext
@@ -31,6 +32,7 @@ from .settings import ModelSettings, merge_model_settings
 from .tools import (
     AgentDepsT,
     DocstringFormat,
+    GenerateToolJsonSchema,
     RunContext,
     Tool,
     ToolFuncContext,
@@ -47,6 +49,9 @@ CallToolsNode = _agent_graph.CallToolsNode
 ModelRequestNode = _agent_graph.ModelRequestNode
 UserPromptNode = _agent_graph.UserPromptNode
 
+if TYPE_CHECKING:
+    from pydantic_ai.mcp import MCPServer
+
 __all__ = (
     'Agent',
     'AgentRun',
@@ -129,6 +134,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         repr=False
     )
     _function_tools: dict[str, Tool[AgentDepsT]] = dataclasses.field(repr=False)
+    _mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
     _default_retries: int = dataclasses.field(repr=False)
     _max_result_retries: int = dataclasses.field(repr=False)
     _override_deps: _utils.Option[AgentDepsT] = dataclasses.field(default=None, repr=False)
@@ -148,6 +154,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         result_tool_description: str | None = None,
         result_retries: int | None = None,
         tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
+        mcp_servers: Sequence[MCPServer] = (),
         defer_model_check: bool = False,
         end_strategy: EndStrategy = 'early',
         instrument: InstrumentationSettings | bool | None = None,
@@ -173,6 +180,8 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
            result_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
            tools: Tools to register with the agent, you can also register tools via the decorators
                [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
+           mcp_servers: MCP servers to register with the agent. You should register a [`MCPServer`][pydantic_ai.mcp.MCPServer]
+               for each server you want the agent to connect to.
            defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model,
                it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately,
                which checks for the necessary environment variables. Set this to `false`
@@ -186,6 +195,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
                If this isn't set, then the last value set by
                [`Agent.instrument_all()`][pydantic_ai.Agent.instrument_all]
                will be used, which defaults to False.
+               See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info.
        """
        if model is None or defer_model_check:
            self.model = model
@@ -215,6 +225,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
 
        self._default_retries = retries
        self._max_result_retries = result_retries if result_retries is not None else retries
+       self._mcp_servers = mcp_servers
        for tool in tools:
            if isinstance(tool, Tool):
                self._register_tool(tool)
@@ -435,7 +446,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
        usage_limits = usage_limits or _usage.UsageLimits()
 
        if isinstance(model_used, InstrumentedModel):
-           tracer = model_used.
+           tracer = model_used.settings.tracer
        else:
            tracer = NoOpTracer()
        agent_name = self.name or 'agent'
@@ -461,6 +472,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
            result_tools=self._result_schema.tool_defs() if self._result_schema else [],
            result_validators=result_validators,
            function_tools=self._function_tools,
+           mcp_servers=self._mcp_servers,
            run_span=run_span,
            tracer=tracer,
        )
@@ -927,6 +939,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
        docstring_format: DocstringFormat = 'auto',
        require_parameter_descriptions: bool = False,
+       schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
    ) -> Callable[[ToolFuncContext[AgentDepsT, ToolParams]], ToolFuncContext[AgentDepsT, ToolParams]]: ...
 
    def tool(
@@ -939,6 +952,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
        docstring_format: DocstringFormat = 'auto',
        require_parameter_descriptions: bool = False,
+       schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
    ) -> Any:
        """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
 
@@ -980,6 +994,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
            docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
                Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
            require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
+           schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
        """
        if func is None:
 
@@ -988,7 +1003,14 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
        ) -> ToolFuncContext[AgentDepsT, ToolParams]:
            # noinspection PyTypeChecker
            self._register_function(
-               func_, True, name, retries, prepare, docstring_format, require_parameter_descriptions
+               func_,
+               True,
+               name,
+               retries,
+               prepare,
+               docstring_format,
+               require_parameter_descriptions,
+               schema_generator,
            )
            return func_
 
@@ -996,7 +1018,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
        else:
            # noinspection PyTypeChecker
            self._register_function(
-               func, True, name, retries, prepare, docstring_format, require_parameter_descriptions
+               func, True, name, retries, prepare, docstring_format, require_parameter_descriptions, schema_generator
            )
            return func
 
@@ -1013,6 +1035,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
        docstring_format: DocstringFormat = 'auto',
        require_parameter_descriptions: bool = False,
+       schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
    ) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...
 
    def tool_plain(
@@ -1025,6 +1048,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
        prepare: ToolPrepareFunc[AgentDepsT] | None = None,
        docstring_format: DocstringFormat = 'auto',
        require_parameter_descriptions: bool = False,
+       schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
    ) -> Any:
        """Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
 
@@ -1066,20 +1090,28 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
            docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
                Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
            require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
+           schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
        """
        if func is None:
 
            def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
                # noinspection PyTypeChecker
                self._register_function(
-                   func_, False, name, retries, prepare, docstring_format, require_parameter_descriptions
+                   func_,
+                   False,
+                   name,
+                   retries,
+                   prepare,
+                   docstring_format,
+                   require_parameter_descriptions,
+                   schema_generator,
                )
                return func_
 
            return tool_decorator
        else:
            self._register_function(
-               func, False, name, retries, prepare, docstring_format, require_parameter_descriptions
+               func, False, name, retries, prepare, docstring_format, require_parameter_descriptions, schema_generator
            )
            return func
 
@@ -1092,6 +1124,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
        prepare: ToolPrepareFunc[AgentDepsT] | None,
        docstring_format: DocstringFormat,
        require_parameter_descriptions: bool,
+       schema_generator: type[GenerateJsonSchema],
    ) -> None:
        """Private utility to register a function as a tool."""
        retries_ = retries if retries is not None else self._default_retries
@@ -1103,6 +1136,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
            prepare=prepare,
            docstring_format=docstring_format,
            require_parameter_descriptions=require_parameter_descriptions,
+           schema_generator=schema_generator,
        )
        self._register_tool(tool)
 
@@ -1253,6 +1287,20 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
        """
        return isinstance(node, End)
 
+   @asynccontextmanager
+   async def run_mcp_servers(self) -> AsyncIterator[None]:
+       """Run [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] so they can be used by the agent.
+
+       Returns: a context manager to start and shutdown the servers.
+       """
+       exit_stack = AsyncExitStack()
+       try:
+           for mcp_server in self._mcp_servers:
+               await exit_stack.enter_async_context(mcp_server)
+           yield
+       finally:
+           await exit_stack.aclose()
+
 
 @dataclasses.dataclass(repr=False)
 class AgentRun(Generic[AgentDepsT, ResultDataT]):
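
The same `schema_generator` knob is threaded through the `@agent.tool` and `@agent.tool_plain` decorators, so individual tools can opt into a custom generator. A sketch using the built-in `'test'` model so it needs no API key; `NoTitlesJsonSchema` is a hypothetical placeholder subclass:

```python
from pydantic.json_schema import GenerateJsonSchema

from pydantic_ai import Agent

agent = Agent('test')  # built-in test model, no API key required


class NoTitlesJsonSchema(GenerateJsonSchema):
    """Hypothetical placeholder; override hooks as needed."""


@agent.tool_plain(schema_generator=NoTitlesJsonSchema)
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b
```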
pydantic_ai/common_tools/duckduckgo.py
@@ -13,7 +13,7 @@ try:
 except ImportError as _import_error:
     raise ImportError(
         'Please install `duckduckgo-search` to use the DuckDuckGo search tool, '
-
+        'you can use the `duckduckgo` optional group — `pip install "pydantic-ai-slim[duckduckgo]"`'
     ) from _import_error
 
 __all__ = ('duckduckgo_search_tool',)
pydantic_ai/common_tools/tavily.py
@@ -11,7 +11,7 @@ try:
 except ImportError as _import_error:
     raise ImportError(
         'Please install `tavily-python` to use the Tavily search tool, '
-
+        'you can use the `tavily` optional group — `pip install "pydantic-ai-slim[tavily]"`'
     ) from _import_error
 
 __all__ = ('tavily_search_tool',)
pydantic_ai/mcp.py (new file)
@@ -0,0 +1,198 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from collections.abc import AsyncIterator, Sequence
+from contextlib import AsyncExitStack, asynccontextmanager
+from dataclasses import dataclass
+from types import TracebackType
+from typing import Any
+
+from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
+from mcp.types import JSONRPCMessage
+from typing_extensions import Self
+
+from pydantic_ai.tools import ToolDefinition
+
+try:
+    from mcp.client.session import ClientSession
+    from mcp.client.sse import sse_client
+    from mcp.client.stdio import StdioServerParameters, stdio_client
+    from mcp.types import CallToolResult
+except ImportError as _import_error:
+    raise ImportError(
+        'Please install the `mcp` package to use the MCP server, '
+        'you can use the `mcp` optional group — `pip install "pydantic-ai-slim[mcp]"`'
+    ) from _import_error
+
+__all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP'
+
+
+class MCPServer(ABC):
+    """Base class for attaching agents to MCP servers.
+
+    See <https://modelcontextprotocol.io> for more information.
+    """
+
+    is_running: bool = False
+
+    _client: ClientSession
+    _read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
+    _write_stream: MemoryObjectSendStream[JSONRPCMessage]
+    _exit_stack: AsyncExitStack
+
+    @abstractmethod
+    @asynccontextmanager
+    async def client_streams(
+        self,
+    ) -> AsyncIterator[
+        tuple[MemoryObjectReceiveStream[JSONRPCMessage | Exception], MemoryObjectSendStream[JSONRPCMessage]]
+    ]:
+        """Create the streams for the MCP server."""
+        raise NotImplementedError('MCP Server subclasses must implement this method.')
+        yield
+
+    async def list_tools(self) -> list[ToolDefinition]:
+        """Retrieve tools that are currently active on the server.
+
+        Note:
+        - We don't cache tools as they might change.
+        - We also don't subscribe to the server to avoid complexity.
+        """
+        tools = await self._client.list_tools()
+        return [
+            ToolDefinition(
+                name=tool.name,
+                description=tool.description or '',
+                parameters_json_schema=tool.inputSchema,
+            )
+            for tool in tools.tools
+        ]
+
+    async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> CallToolResult:
+        """Call a tool on the server.
+
+        Args:
+            tool_name: The name of the tool to call.
+            arguments: The arguments to pass to the tool.
+
+        Returns:
+            The result of the tool call.
+        """
+        return await self._client.call_tool(tool_name, arguments)
+
+    async def __aenter__(self) -> Self:
+        self._exit_stack = AsyncExitStack()
+
+        self._read_stream, self._write_stream = await self._exit_stack.enter_async_context(self.client_streams())
+        client = ClientSession(read_stream=self._read_stream, write_stream=self._write_stream)
+        self._client = await self._exit_stack.enter_async_context(client)
+
+        await self._client.initialize()
+        self.is_running = True
+        return self
+
+    async def __aexit__(
+        self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None
+    ) -> bool | None:
+        await self._exit_stack.aclose()
+        self.is_running = False
+
+
+@dataclass
+class MCPServerStdio(MCPServer):
+    """Runs an MCP server in a subprocess and communicates with it over stdin/stdout.
+
+    This class implements the stdio transport from the MCP specification.
+    See <https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio> for more information.
+
+    !!! note
+        Using this class as an async context manager will start the server as a subprocess when entering the context,
+        and stop it when exiting the context.
+
+    Example:
+    ```python {py="3.10"}
+    from pydantic_ai import Agent
+    from pydantic_ai.mcp import MCPServerStdio
+
+    server = MCPServerStdio('npx', ['-y', '@pydantic/mcp-run-python', 'stdio'])  # (1)!
+    agent = Agent('openai:gpt-4o', mcp_servers=[server])
+
+    async def main():
+        async with agent.run_mcp_servers():  # (2)!
+            ...
+    ```
+
+    1. See [MCP Run Python](../mcp/run-python.md) for more information.
+    2. This will start the server as a subprocess and connect to it.
+    """
+
+    command: str
+    """The command to run."""
+
+    args: Sequence[str]
+    """The arguments to pass to the command."""
+
+    env: dict[str, str] | None = None
+    """The environment variables the CLI server will have access to.
+
+    By default the subprocess will not inherit any environment variables from the parent process.
+    If you want to inherit the environment variables from the parent process, use `env=os.environ`.
+    """
+
+    @asynccontextmanager
+    async def client_streams(
+        self,
+    ) -> AsyncIterator[
+        tuple[MemoryObjectReceiveStream[JSONRPCMessage | Exception], MemoryObjectSendStream[JSONRPCMessage]]
+    ]:
+        server = StdioServerParameters(command=self.command, args=list(self.args), env=self.env)
+        async with stdio_client(server=server) as (read_stream, write_stream):
+            yield read_stream, write_stream
+
+
+@dataclass
+class MCPServerHTTP(MCPServer):
+    """An MCP server that connects over streamable HTTP connections.
+
+    This class implements the SSE transport from the MCP specification.
+    See <https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse> for more information.
+
+    The name "HTTP" is used since this implementation will be adapted in future to use the new
+    [Streamable HTTP](https://github.com/modelcontextprotocol/specification/pull/206) currently in development.
+
+    !!! note
+        Using this class as an async context manager will create a new pool of HTTP connections to connect
+        to a server which should already be running.
+
+    Example:
+    ```python {py="3.10"}
+    from pydantic_ai import Agent
+    from pydantic_ai.mcp import MCPServerHTTP
+
+    server = MCPServerHTTP('http://localhost:3001/sse')  # (1)!
+    agent = Agent('openai:gpt-4o', mcp_servers=[server])
+
+    async def main():
+        async with agent.run_mcp_servers():  # (2)!
+            ...
+    ```
+
+    1. E.g. you might be connecting to a server run with `npx @pydantic/mcp-run-python sse`,
+       see [MCP Run Python](../mcp/run-python.md) for more information.
+    2. This will connect to a server running on `localhost:3001`.
+    """
+
+    url: str
+    """The URL of the SSE endpoint on the MCP server.
+
+    For example for a server running locally, this might be `http://localhost:3001/sse`.
+    """
+
+    @asynccontextmanager
+    async def client_streams(
+        self,
+    ) -> AsyncIterator[
+        tuple[MemoryObjectReceiveStream[JSONRPCMessage | Exception], MemoryObjectSendStream[JSONRPCMessage]]
+    ]:  # pragma: no cover
+        async with sse_client(url=self.url) as (read_stream, write_stream):
+            yield read_stream, write_stream
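
Beyond attaching servers to an agent, an `MCPServer` can be driven directly as an async context manager, which is handy for inspecting what a server exposes. A sketch built only on the methods defined above (the `npx` command is illustrative):

```python
import asyncio

from pydantic_ai.mcp import MCPServerStdio


async def main():
    server = MCPServerStdio('npx', ['-y', '@pydantic/mcp-run-python', 'stdio'])
    async with server:  # __aenter__ starts the subprocess and initializes the session
        for tool_def in await server.list_tools():
            print(tool_def.name, '-', tool_def.description)


asyncio.run(main())
```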
pydantic_ai/messages.py
@@ -26,6 +26,9 @@ class SystemPromptPart:
     content: str
     """The content of the prompt."""
 
+    timestamp: datetime = field(default_factory=_now_utc)
+    """The timestamp of the prompt."""
+
     dynamic_ref: str | None = None
     """The ref of the dynamic system prompt function that generated this part.
pydantic_ai/models/anthropic.py
@@ -65,7 +65,7 @@ try:
 except ImportError as _import_error:
     raise ImportError(
         'Please install `anthropic` to use the Anthropic model, '
-
+        'you can use the `anthropic` optional group — `pip install "pydantic-ai-slim[anthropic]"`'
     ) from _import_error
 
 LatestAnthropicModelNames = Literal[
pydantic_ai/models/cohere.py
@@ -50,7 +50,7 @@ try:
 except ImportError as _import_error:
     raise ImportError(
         'Please install `cohere` to use the Cohere model, '
-
+        'you can use the `cohere` optional group — `pip install "pydantic-ai-slim[cohere]"`'
     ) from _import_error
 
 LatestCohereModelNames = Literal[
pydantic_ai/models/groq.py
@@ -41,7 +41,7 @@ try:
 except ImportError as _import_error:
     raise ImportError(
         'Please install `groq` to use the Groq model, '
-
+        'you can use the `groq` optional group — `pip install "pydantic-ai-slim[groq]"`'
     ) from _import_error
 
 
pydantic_ai/models/instrumented.py
@@ -52,7 +52,9 @@ class InstrumentationSettings:
 
     - `Agent(instrument=...)`
     - [`Agent.instrument_all()`][pydantic_ai.agent.Agent.instrument_all]
-    - `InstrumentedModel`
+    - [`InstrumentedModel`][pydantic_ai.models.instrumented.InstrumentedModel]
+
+    See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info.
     """
 
     tracer: Tracer = field(repr=False)
@@ -94,9 +96,13 @@ GEN_AI_REQUEST_MODEL_ATTRIBUTE = 'gen_ai.request.model'
 
 @dataclass
 class InstrumentedModel(WrapperModel):
-    """Model which
+    """Model which wraps another model so that requests are instrumented with OpenTelemetry.
+
+    See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info.
+    """
 
-
+    settings: InstrumentationSettings
+    """Configuration for instrumenting requests."""
 
     def __init__(
         self,
@@ -104,7 +110,7 @@ class InstrumentedModel(WrapperModel):
         options: InstrumentationSettings | None = None,
     ) -> None:
         super().__init__(wrapped)
-        self.
+        self.settings = options or InstrumentationSettings()
 
     async def request(
         self,
@@ -156,7 +162,7 @@ class InstrumentedModel(WrapperModel):
         if isinstance(value := model_settings.get(key), (float, int)):
             attributes[f'gen_ai.request.{key}'] = value
 
-        with self.
+        with self.settings.tracer.start_as_current_span(span_name, attributes=attributes) as span:
 
             def finish(response: ModelResponse, usage: Usage):
                 if not span.is_recording():
@@ -190,9 +196,9 @@ class InstrumentedModel(WrapperModel):
             yield finish
 
     def _emit_events(self, span: Span, events: list[Event]) -> None:
-        if self.
+        if self.settings.event_mode == 'logs':
             for event in events:
-                self.
+                self.settings.event_logger.emit(event)
         else:
             attr_name = 'events'
             span.set_attributes(
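
Since `InstrumentedModel` now keeps its configuration on a public `settings` attribute, the wrapper can be constructed with defaults and inspected afterwards. A minimal sketch, assuming the `options or InstrumentationSettings()` fallback above means a no-argument construction is valid:

```python
from pydantic_ai.models.instrumented import InstrumentationSettings, InstrumentedModel
from pydantic_ai.models.test import TestModel

model = InstrumentedModel(TestModel())  # default settings applied when no options are passed
assert isinstance(model.settings, InstrumentationSettings)
```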
pydantic_ai/models/mistral.py
@@ -75,7 +75,7 @@ try:
 except ImportError as e:
     raise ImportError(
         'Please install `mistral` to use the Mistral model, '
-
+        'you can use the `mistral` optional group — `pip install "pydantic-ai-slim[mistral]"`'
     ) from e
 
 LatestMistralModelNames = Literal[
pydantic_ai/models/openai.py
@@ -57,7 +57,7 @@ try:
 except ImportError as _import_error:
     raise ImportError(
         'Please install `openai` to use the OpenAI model, '
-
+        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
     ) from _import_error
 
 OpenAIModelName = Union[str, ChatModel]
pydantic_ai/models/vertexai.py
@@ -27,7 +27,7 @@ try:
 except ImportError as _import_error:
     raise ImportError(
         'Please install `google-auth` to use the VertexAI model, '
-
+        'you can use the `vertexai` optional group — `pip install "pydantic-ai-slim[vertexai]"`'
     ) from _import_error
 
 VERTEX_AI_URL_TEMPLATE = (
pydantic_ai/models/wrapper.py
@@ -13,9 +13,13 @@ from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, i
 
 @dataclass(init=False)
 class WrapperModel(Model):
-    """Model which wraps another model."""
+    """Model which wraps another model.
+
+    Does nothing on its own, used as a base class.
+    """
 
     wrapped: Model
+    """The underlying model being wrapped."""
 
     def __init__(self, wrapped: Model | KnownModelName):
         self.wrapped = infer_model(wrapped)
pydantic_ai/providers/anthropic.py
@@ -12,7 +12,7 @@ try:
 except ImportError as _import_error:  # pragma: no cover
     raise ImportError(
         'Please install the `anthropic` package to use the Anthropic provider, '
-
+        'you can use the `anthropic` optional group — `pip install "pydantic-ai-slim[anthropic]"`'
     ) from _import_error
 
 
pydantic_ai/providers/azure.py
@@ -13,7 +13,7 @@ try:
 except ImportError as _import_error:  # pragma: no cover
     raise ImportError(
         'Please install the `openai` package to use the Azure provider, '
-
+        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
     ) from _import_error
 
 
pydantic_ai/providers/bedrock.py
@@ -11,7 +11,7 @@ try:
 except ImportError as _import_error:
     raise ImportError(
         'Please install the `boto3` package to use the Bedrock provider, '
-
+        'you can use the `bedrock` optional group — `pip install "pydantic-ai-slim[bedrock]"`'
     ) from _import_error
 
 
pydantic_ai/providers/deepseek.py
@@ -13,7 +13,7 @@ try:
 except ImportError as _import_error:  # pragma: no cover
     raise ImportError(
         'Please install the `openai` package to use the DeepSeek provider, '
-
+        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
     ) from _import_error
 
 from . import Provider
pydantic_ai/providers/google_vertex.py
@@ -2,7 +2,6 @@ from __future__ import annotations as _annotations
 
 import functools
 from collections.abc import AsyncGenerator, Mapping
-from datetime import datetime, timedelta
 from pathlib import Path
 from typing import Literal, overload
 
@@ -22,15 +21,12 @@ try:
 except ImportError as _import_error:
     raise ImportError(
         'Please install the `google-auth` package to use the Google Vertex AI provider, '
-
+        'you can use the `vertexai` optional group — `pip install "pydantic-ai-slim[vertexai]"`'
     ) from _import_error
 
 
 __all__ = ('GoogleVertexProvider',)
 
-# default expiry is 3600 seconds
-MAX_TOKEN_AGE = timedelta(seconds=3000)
-
 
 class GoogleVertexProvider(Provider[httpx.AsyncClient]):
     """Provider for Vertex AI API."""
@@ -131,19 +127,21 @@ class _VertexAIAuth(httpx.Auth):
         self.region = region
 
         self.credentials = None
-        self.token_created: datetime | None = None
 
     async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
         if self.credentials is None:
             self.credentials = await self._get_credentials()
-        if self.credentials.token is None
-            await
-            self.token_created = datetime.now()
+        if self.credentials.token is None:  # type: ignore[reportUnknownMemberType]
+            await self._refresh_token()
         request.headers['Authorization'] = f'Bearer {self.credentials.token}'  # type: ignore[reportUnknownMemberType]
-
         # NOTE: This workaround is in place because we might get the project_id from the credentials.
         request.url = httpx.URL(str(request.url).replace('projects/None', f'projects/{self.project_id}'))
-        yield request
+        response = yield request
+
+        if response.status_code == 401:
+            await self._refresh_token()
+            request.headers['Authorization'] = f'Bearer {self.credentials.token}'  # type: ignore[reportUnknownMemberType]
+            yield request
 
     async def _get_credentials(self) -> BaseCredentials | ServiceAccountCredentials:
         if self.service_account_file is not None:
@@ -166,15 +164,9 @@ class _VertexAIAuth(httpx.Auth):
             self.project_id = creds_project_id
         return creds
 
-    def
-        if self.token_created is None:
-            return True
-        else:
-            return (datetime.now() - self.token_created) > MAX_TOKEN_AGE
-
-    def _refresh_token(self) -> str:  # pragma: no cover
+    async def _refresh_token(self) -> str:  # pragma: no cover
         assert self.credentials is not None
-        self.credentials.refresh
+        await anyio.to_thread.run_sync(self.credentials.refresh, Request())  # type: ignore[reportUnknownMemberType]
         assert isinstance(self.credentials.token, str), f'Expected token to be a string, got {self.credentials.token}'  # type: ignore[reportUnknownMemberType]
         return self.credentials.token
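
The rewritten auth flow replaces age-based token expiry with httpx's generator-based `Auth` protocol: yielding a request hands back the response, so the flow can refresh the credential and retry once on a 401. A generic sketch of that pattern (the token helpers are hypothetical, not part of the package):

```python
import httpx


class BearerTokenAuth(httpx.Auth):
    def __init__(self, get_token, refresh_token):
        self._get_token = get_token          # hypothetical: returns the current token
        self._refresh_token = refresh_token  # hypothetical: fetches a fresh one

    def auth_flow(self, request: httpx.Request):
        request.headers['Authorization'] = f'Bearer {self._get_token()}'
        response = yield request  # httpx sends the request and hands back the response
        if response.status_code == 401:
            self._refresh_token()
            request.headers['Authorization'] = f'Bearer {self._get_token()}'
            yield request  # retry once with the refreshed credential
```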
pydantic_ai/providers/groq.py
@@ -12,7 +12,7 @@ try:
 except ImportError as _import_error:  # pragma: no cover
     raise ImportError(
         'Please install the `groq` package to use the Groq provider, '
-
+        'you can use the `groq` optional group — `pip install "pydantic-ai-slim[groq]"`'
     ) from _import_error
 
 
@@ -57,17 +57,21 @@ class GroqProvider(Provider[AsyncGroq]):
             client to use. If provided, `api_key` and `http_client` must be `None`.
             http_client: An existing `AsyncHTTPClient` to use for making HTTP requests.
         """
-        api_key = api_key or os.environ.get('GROQ_API_KEY')
-
-        if api_key is None and groq_client is None:
-            raise ValueError(
-                'Set the `GROQ_API_KEY` environment variable or pass it via `GroqProvider(api_key=...)`'
-                'to use the Groq provider.'
-            )
-
         if groq_client is not None:
+            assert http_client is None, 'Cannot provide both `groq_client` and `http_client`'
+            assert api_key is None, 'Cannot provide both `groq_client` and `api_key`'
             self._client = groq_client
-        elif http_client is not None:
-            self._client = AsyncGroq(base_url=self.base_url, api_key=api_key, http_client=http_client)
         else:
-
+            api_key = api_key or os.environ.get('GROQ_API_KEY')
+
+            if api_key is None:
+                raise ValueError(
+                    'Set the `GROQ_API_KEY` environment variable or pass it via `GroqProvider(api_key=...)`'
+                    'to use the Groq provider.'
+                )
+            elif http_client is not None:
+                self._client = AsyncGroq(base_url=self.base_url, api_key=api_key, http_client=http_client)
+            else:
+                self._client = AsyncGroq(
+                    base_url=self.base_url, api_key=api_key, http_client=cached_async_http_client()
+                )
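
After this restructuring, a pre-configured client bypasses the `GROQ_API_KEY` lookup entirely, which the new asserts make explicit; the key is only resolved and validated when the provider has to build its own client. A sketch of the three construction paths (the key value is a placeholder):

```python
from groq import AsyncGroq

from pydantic_ai.providers.groq import GroqProvider

GroqProvider(groq_client=AsyncGroq(api_key='placeholder'))  # own client: no env lookup at all
GroqProvider(api_key='placeholder')                         # explicit key
GroqProvider()                                              # falls back to GROQ_API_KEY, else ValueError
```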
pydantic_ai/providers/mistral.py
@@ -12,7 +12,7 @@ try:
 except ImportError as e:  # pragma: no cover
     raise ImportError(
         'Please install the `mistral` package to use the Mistral provider, '
-
+        'you can use the `mistral` optional group — `pip install "pydantic-ai-slim[mistral]"`'
     ) from e
 
 
@@ -55,19 +55,19 @@ class MistralProvider(Provider[Mistral]):
             mistral_client: An existing `Mistral` client to use, if provided, `api_key` and `http_client` must be `None`.
             http_client: An existing async client to use for making HTTP requests.
         """
-        api_key = api_key or os.environ.get('MISTRAL_API_KEY')
-
-        if api_key is None and mistral_client is None:
-            raise ValueError(
-                'Set the `MISTRAL_API_KEY` environment variable or pass it via `MistralProvider(api_key=...)`'
-                'to use the Mistral provider.'
-            )
-
         if mistral_client is not None:
             assert http_client is None, 'Cannot provide both `mistral_client` and `http_client`'
             assert api_key is None, 'Cannot provide both `mistral_client` and `api_key`'
             self._client = mistral_client
-        elif http_client is not None:
-            self._client = Mistral(api_key=api_key, async_client=http_client)
         else:
-
+            api_key = api_key or os.environ.get('MISTRAL_API_KEY')
+
+            if api_key is None:
+                raise ValueError(
+                    'Set the `MISTRAL_API_KEY` environment variable or pass it via `MistralProvider(api_key=...)`'
+                    'to use the Mistral provider.'
+                )
+            elif http_client is not None:
+                self._client = Mistral(api_key=api_key, async_client=http_client)
+            else:
+                self._client = Mistral(api_key=api_key, async_client=cached_async_http_client())
pydantic_ai/providers/openai.py
@@ -11,7 +11,7 @@ try:
 except ImportError as _import_error:  # pragma: no cover
     raise ImportError(
         'Please install the `openai` package to use the OpenAI provider, '
-
+        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
     ) from _import_error
 
 
pydantic_ai/tools.py
@@ -7,7 +7,8 @@ from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
 
 from pydantic import ValidationError
-from
+from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
+from pydantic_core import SchemaValidator, core_schema
 from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar
 
 from . import _pydantic, _utils, messages as _messages, models
@@ -142,6 +143,22 @@ DocstringFormat = Literal['google', 'numpy', 'sphinx', 'auto']
 A = TypeVar('A')
 
 
+class GenerateToolJsonSchema(GenerateJsonSchema):
+    def typed_dict_schema(self, schema: core_schema.TypedDictSchema) -> JsonSchemaValue:
+        s = super().typed_dict_schema(schema)
+        total = schema.get('total')
+        if total is not None:
+            s['additionalProperties'] = not total
+        return s
+
+    def _named_required_fields_schema(self, named_required_fields: Sequence[tuple[str, bool, Any]]) -> JsonSchemaValue:
+        # Remove largely-useless property titles
+        s = super()._named_required_fields_schema(named_required_fields)
+        for p in s.get('properties', {}):
+            s['properties'][p].pop('title', None)
+        return s
+
+
 @dataclass(init=False)
 class Tool(Generic[AgentDepsT]):
     """A tool function for an agent."""
@@ -176,6 +193,7 @@ class Tool(Generic[AgentDepsT]):
         prepare: ToolPrepareFunc[AgentDepsT] | None = None,
         docstring_format: DocstringFormat = 'auto',
         require_parameter_descriptions: bool = False,
+        schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
     ):
         """Create a new tool instance.
 
@@ -225,11 +243,14 @@ class Tool(Generic[AgentDepsT]):
             docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
                 Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
             require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
+            schema_generator: The JSON schema generator class to use. Defaults to `GenerateToolJsonSchema`.
         """
         if takes_ctx is None:
             takes_ctx = _pydantic.takes_ctx(function)
 
-        f = _pydantic.function_schema(
+        f = _pydantic.function_schema(
+            function, takes_ctx, docstring_format, require_parameter_descriptions, schema_generator
+        )
         self.function = function
         self.takes_ctx = takes_ctx
         self.max_retries = max_retries
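
`GenerateToolJsonSchema` is now the default generator for every tool: it strips the auto-generated per-property `title` entries and translates the core schema's `total` flag into `additionalProperties`. Its effect can be checked directly through pydantic's `schema_generator` hook (a quick inspection sketch, assuming `TypeAdapter.json_schema` accepts a generator class as pydantic documents):

```python
from pydantic import BaseModel, TypeAdapter

from pydantic_ai.tools import GenerateToolJsonSchema


class Params(BaseModel):
    x: int
    y: str


schema = TypeAdapter(Params).json_schema(schema_generator=GenerateToolJsonSchema)
print(schema['properties'])  # no auto-generated 'title' keys on the properties
```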
pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "pydantic-ai-slim"
-version = "0.0.41"
+version = "0.0.43"
 description = "Agent Framework / shim to use Pydantic with LLMs, slim package"
 authors = [{ name = "Samuel Colvin", email = "samuel@pydantic.dev" }]
 license = "MIT"
@@ -36,7 +36,7 @@ dependencies = [
     "griffe>=1.3.2",
     "httpx>=0.27",
     "pydantic>=2.10",
-    "pydantic-graph==0.0.41",
+    "pydantic-graph==0.0.43",
     "exceptiongroup; python_version < '3.11'",
     "opentelemetry-api>=1.28.0",
     "typing-inspection>=0.4.0",
@@ -50,7 +50,7 @@ openai = ["openai>=1.65.1"]
 cohere = ["cohere>=5.13.11"]
 vertexai = ["google-auth>=2.36.0", "requests>=2.32.3"]
 anthropic = ["anthropic>=0.49.0"]
-groq = ["groq>=0.
+groq = ["groq>=0.15.0"]
 mistral = ["mistralai>=1.2.5"]
 bedrock = ["boto3>=1.34.116"]
 # Tools
@@ -58,6 +58,8 @@ duckduckgo = ["duckduckgo-search>=7.0.0"]
 tavily = ["tavily-python>=0.5.0"]
 # CLI
 cli = ["rich>=13", "prompt-toolkit>=3", "argcomplete>=3.5.0"]
+# MCP
+mcp = ["mcp>=1.4.1; python_version >= '3.10'"]
 
 [dependency-groups]
 dev = [
@@ -75,6 +77,9 @@ dev = [
     "boto3-stubs[bedrock-runtime]",
 ]
 
+[tool.hatch.metadata]
+allow-direct-references = true
+
 [project.scripts]
 pai = "pydantic_ai._cli:app"