pydantic-ai-slim 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_agent_graph.py +18 -5
- pydantic_ai/_output.py +6 -2
- pydantic_ai/_parts_manager.py +5 -5
- pydantic_ai/_pydantic.py +8 -1
- pydantic_ai/_utils.py +9 -1
- pydantic_ai/agent.py +24 -49
- pydantic_ai/direct.py +7 -31
- pydantic_ai/mcp.py +59 -7
- pydantic_ai/messages.py +72 -29
- pydantic_ai/models/anthropic.py +80 -87
- pydantic_ai/models/bedrock.py +2 -2
- pydantic_ai/models/google.py +1 -1
- pydantic_ai/models/mistral.py +1 -1
- pydantic_ai/models/openai.py +0 -1
- pydantic_ai/result.py +5 -3
- pydantic_ai/tools.py +7 -3
- pydantic_ai/usage.py +7 -2
- {pydantic_ai_slim-0.2.6.dist-info → pydantic_ai_slim-0.2.8.dist-info}/METADATA +4 -4
- {pydantic_ai_slim-0.2.6.dist-info → pydantic_ai_slim-0.2.8.dist-info}/RECORD +22 -22
- {pydantic_ai_slim-0.2.6.dist-info → pydantic_ai_slim-0.2.8.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.2.6.dist-info → pydantic_ai_slim-0.2.8.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.2.6.dist-info → pydantic_ai_slim-0.2.8.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/_agent_graph.py
CHANGED
|
@@ -222,27 +222,40 @@ async def _prepare_request_parameters(
|
|
|
222
222
|
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
|
|
223
223
|
) -> models.ModelRequestParameters:
|
|
224
224
|
"""Build tools and create an agent model."""
|
|
225
|
-
|
|
225
|
+
function_tool_defs_map: dict[str, ToolDefinition] = {}
|
|
226
226
|
|
|
227
227
|
run_context = build_run_context(ctx)
|
|
228
228
|
|
|
229
229
|
async def add_tool(tool: Tool[DepsT]) -> None:
|
|
230
230
|
ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
|
|
231
231
|
if tool_def := await tool.prepare_tool_def(ctx):
|
|
232
|
-
|
|
232
|
+
# prepare_tool_def may change tool_def.name
|
|
233
|
+
if tool_def.name in function_tool_defs_map:
|
|
234
|
+
if tool_def.name != tool.name:
|
|
235
|
+
# Prepare tool def may have renamed the tool
|
|
236
|
+
raise exceptions.UserError(
|
|
237
|
+
f"Renaming tool '{tool.name}' to '{tool_def.name}' conflicts with existing tool."
|
|
238
|
+
)
|
|
239
|
+
else:
|
|
240
|
+
raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}.')
|
|
241
|
+
function_tool_defs_map[tool_def.name] = tool_def
|
|
233
242
|
|
|
234
243
|
async def add_mcp_server_tools(server: MCPServer) -> None:
|
|
235
244
|
if not server.is_running:
|
|
236
245
|
raise exceptions.UserError(f'MCP server is not running: {server}')
|
|
237
246
|
tool_defs = await server.list_tools()
|
|
238
|
-
|
|
239
|
-
|
|
247
|
+
for tool_def in tool_defs:
|
|
248
|
+
if tool_def.name in function_tool_defs_map:
|
|
249
|
+
raise exceptions.UserError(
|
|
250
|
+
f"MCP Server '{server}' defines a tool whose name conflicts with existing tool: {tool_def.name!r}. Consider using `tool_prefix` to avoid name conflicts."
|
|
251
|
+
)
|
|
252
|
+
function_tool_defs_map[tool_def.name] = tool_def
|
|
240
253
|
|
|
241
254
|
await asyncio.gather(
|
|
242
255
|
*map(add_tool, ctx.deps.function_tools.values()),
|
|
243
256
|
*map(add_mcp_server_tools, ctx.deps.mcp_servers),
|
|
244
257
|
)
|
|
245
|
-
|
|
258
|
+
function_tool_defs = list(function_tool_defs_map.values())
|
|
246
259
|
if ctx.deps.prepare_tools:
|
|
247
260
|
# Prepare the tools using the provided function
|
|
248
261
|
# This also acts over tool definitions pulled from MCP servers
|
pydantic_ai/_output.py
CHANGED
|
@@ -231,9 +231,13 @@ class OutputSchemaTool(Generic[OutputDataT]):
|
|
|
231
231
|
try:
|
|
232
232
|
pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
|
|
233
233
|
if isinstance(tool_call.args, str):
|
|
234
|
-
output = self.type_adapter.validate_json(
|
|
234
|
+
output = self.type_adapter.validate_json(
|
|
235
|
+
tool_call.args or '{}', experimental_allow_partial=pyd_allow_partial
|
|
236
|
+
)
|
|
235
237
|
else:
|
|
236
|
-
output = self.type_adapter.validate_python(
|
|
238
|
+
output = self.type_adapter.validate_python(
|
|
239
|
+
tool_call.args or {}, experimental_allow_partial=pyd_allow_partial
|
|
240
|
+
)
|
|
237
241
|
except ValidationError as e:
|
|
238
242
|
if wrap_validation_errors:
|
|
239
243
|
m = _messages.RetryPromptPart(
|
pydantic_ai/_parts_manager.py
CHANGED
|
@@ -132,7 +132,7 @@ class ModelResponsePartsManager:
|
|
|
132
132
|
) -> ModelResponseStreamEvent | None:
|
|
133
133
|
"""Handle or update a tool call, creating or updating a `ToolCallPart` or `ToolCallPartDelta`.
|
|
134
134
|
|
|
135
|
-
Managed items remain as `ToolCallPartDelta`s until they have
|
|
135
|
+
Managed items remain as `ToolCallPartDelta`s until they have at least a tool_name, at which
|
|
136
136
|
point they are upgraded to `ToolCallPart`s.
|
|
137
137
|
|
|
138
138
|
If `vendor_part_id` is None, updates the latest matching ToolCallPart (or ToolCallPartDelta)
|
|
@@ -143,11 +143,11 @@ class ModelResponsePartsManager:
|
|
|
143
143
|
If None, the latest matching tool call may be updated.
|
|
144
144
|
tool_name: The name of the tool. If None, the manager does not enforce
|
|
145
145
|
a name match when `vendor_part_id` is None.
|
|
146
|
-
args: Arguments for the tool call, either as a string
|
|
146
|
+
args: Arguments for the tool call, either as a string, a dictionary of key-value pairs, or None.
|
|
147
147
|
tool_call_id: An optional string representing an identifier for this tool call.
|
|
148
148
|
|
|
149
149
|
Returns:
|
|
150
|
-
- A `PartStartEvent` if a new
|
|
150
|
+
- A `PartStartEvent` if a new ToolCallPart is created.
|
|
151
151
|
- A `PartDeltaEvent` if an existing part is updated.
|
|
152
152
|
- `None` if no new event is emitted (e.g., the part is still incomplete).
|
|
153
153
|
|
|
@@ -207,7 +207,7 @@ class ModelResponsePartsManager:
|
|
|
207
207
|
*,
|
|
208
208
|
vendor_part_id: Hashable | None,
|
|
209
209
|
tool_name: str,
|
|
210
|
-
args: str | dict[str, Any],
|
|
210
|
+
args: str | dict[str, Any] | None,
|
|
211
211
|
tool_call_id: str | None = None,
|
|
212
212
|
) -> ModelResponseStreamEvent:
|
|
213
213
|
"""Immediately create or fully-overwrite a ToolCallPart with the given information.
|
|
@@ -218,7 +218,7 @@ class ModelResponsePartsManager:
|
|
|
218
218
|
vendor_part_id: The vendor's ID for this tool call part. If not
|
|
219
219
|
None and an existing part is found, that part is overwritten.
|
|
220
220
|
tool_name: The name of the tool being invoked.
|
|
221
|
-
args: The arguments for the tool call, either as a string
|
|
221
|
+
args: The arguments for the tool call, either as a string, a dictionary, or None.
|
|
222
222
|
tool_call_id: An optional string identifier for this tool call.
|
|
223
223
|
|
|
224
224
|
Returns:
|
pydantic_ai/_pydantic.py
CHANGED
|
@@ -76,8 +76,15 @@ def function_schema( # noqa: C901
|
|
|
76
76
|
description, field_descriptions = doc_descriptions(function, sig, docstring_format=docstring_format)
|
|
77
77
|
|
|
78
78
|
if require_parameter_descriptions:
|
|
79
|
-
if
|
|
79
|
+
if takes_ctx:
|
|
80
|
+
parameters_without_ctx = set(
|
|
81
|
+
name for name in sig.parameters if not _is_call_ctx(sig.parameters[name].annotation)
|
|
82
|
+
)
|
|
83
|
+
missing_params = parameters_without_ctx - set(field_descriptions)
|
|
84
|
+
else:
|
|
80
85
|
missing_params = set(sig.parameters) - set(field_descriptions)
|
|
86
|
+
|
|
87
|
+
if missing_params:
|
|
81
88
|
errors.append(f'Missing parameter descriptions for {", ".join(missing_params)}')
|
|
82
89
|
|
|
83
90
|
for index, (name, p) in enumerate(sig.parameters.items()):
|
pydantic_ai/_utils.py
CHANGED
|
@@ -5,7 +5,7 @@ import time
|
|
|
5
5
|
import uuid
|
|
6
6
|
from collections.abc import AsyncIterable, AsyncIterator, Iterator
|
|
7
7
|
from contextlib import asynccontextmanager, suppress
|
|
8
|
-
from dataclasses import dataclass, is_dataclass
|
|
8
|
+
from dataclasses import dataclass, fields, is_dataclass
|
|
9
9
|
from datetime import datetime, timezone
|
|
10
10
|
from functools import partial
|
|
11
11
|
from types import GenericAlias
|
|
@@ -290,3 +290,11 @@ class PeekableAsyncStream(Generic[T]):
|
|
|
290
290
|
|
|
291
291
|
def get_traceparent(x: AgentRun | AgentRunResult | GraphRun | GraphRunResult) -> str:
|
|
292
292
|
return x._traceparent(required=False) or '' # type: ignore[reportPrivateUsage]
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def dataclasses_no_defaults_repr(self: Any) -> str:
|
|
296
|
+
"""Exclude fields with values equal to the field default."""
|
|
297
|
+
kv_pairs = (
|
|
298
|
+
f'{f.name}={getattr(self, f.name)!r}' for f in fields(self) if f.repr and getattr(self, f.name) != f.default
|
|
299
|
+
)
|
|
300
|
+
return f'{self.__class__.__qualname__}({", ".join(kv_pairs)})'
|
pydantic_ai/agent.py
CHANGED
|
@@ -574,30 +574,21 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
574
574
|
UserPromptPart(
|
|
575
575
|
content='What is the capital of France?',
|
|
576
576
|
timestamp=datetime.datetime(...),
|
|
577
|
-
part_kind='user-prompt',
|
|
578
577
|
)
|
|
579
|
-
]
|
|
580
|
-
instructions=None,
|
|
581
|
-
kind='request',
|
|
578
|
+
]
|
|
582
579
|
)
|
|
583
580
|
),
|
|
584
581
|
CallToolsNode(
|
|
585
582
|
model_response=ModelResponse(
|
|
586
|
-
parts=[TextPart(content='Paris'
|
|
583
|
+
parts=[TextPart(content='Paris')],
|
|
587
584
|
usage=Usage(
|
|
588
|
-
requests=1,
|
|
589
|
-
request_tokens=56,
|
|
590
|
-
response_tokens=1,
|
|
591
|
-
total_tokens=57,
|
|
592
|
-
details=None,
|
|
585
|
+
requests=1, request_tokens=56, response_tokens=1, total_tokens=57
|
|
593
586
|
),
|
|
594
587
|
model_name='gpt-4o',
|
|
595
588
|
timestamp=datetime.datetime(...),
|
|
596
|
-
kind='response',
|
|
597
|
-
vendor_id=None,
|
|
598
589
|
)
|
|
599
590
|
),
|
|
600
|
-
End(data=FinalResult(output='Paris'
|
|
591
|
+
End(data=FinalResult(output='Paris')),
|
|
601
592
|
]
|
|
602
593
|
'''
|
|
603
594
|
print(agent_run.result.output)
|
|
@@ -1760,9 +1751,13 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1760
1751
|
lifespan=lifespan,
|
|
1761
1752
|
)
|
|
1762
1753
|
|
|
1763
|
-
async def to_cli(self: Self, deps: AgentDepsT = None) -> None:
|
|
1754
|
+
async def to_cli(self: Self, deps: AgentDepsT = None, prog_name: str = 'pydantic-ai') -> None:
|
|
1764
1755
|
"""Run the agent in a CLI chat interface.
|
|
1765
1756
|
|
|
1757
|
+
Args:
|
|
1758
|
+
deps: The dependencies to pass to the agent.
|
|
1759
|
+
prog_name: The name of the program to use for the CLI. Defaults to 'pydantic-ai'.
|
|
1760
|
+
|
|
1766
1761
|
Example:
|
|
1767
1762
|
```python {title="agent_to_cli.py" test="skip"}
|
|
1768
1763
|
from pydantic_ai import Agent
|
|
@@ -1777,29 +1772,24 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1777
1772
|
|
|
1778
1773
|
from pydantic_ai._cli import run_chat
|
|
1779
1774
|
|
|
1780
|
-
|
|
1781
|
-
# `prog_name` from here.
|
|
1775
|
+
await run_chat(stream=True, agent=self, deps=deps, console=Console(), code_theme='monokai', prog_name=prog_name)
|
|
1782
1776
|
|
|
1783
|
-
|
|
1784
|
-
stream=True,
|
|
1785
|
-
agent=self,
|
|
1786
|
-
deps=deps,
|
|
1787
|
-
console=Console(),
|
|
1788
|
-
code_theme='monokai',
|
|
1789
|
-
prog_name='pydantic-ai',
|
|
1790
|
-
)
|
|
1791
|
-
|
|
1792
|
-
def to_cli_sync(self: Self, deps: AgentDepsT = None) -> None:
|
|
1777
|
+
def to_cli_sync(self: Self, deps: AgentDepsT = None, prog_name: str = 'pydantic-ai') -> None:
|
|
1793
1778
|
"""Run the agent in a CLI chat interface with the non-async interface.
|
|
1794
1779
|
|
|
1780
|
+
Args:
|
|
1781
|
+
deps: The dependencies to pass to the agent.
|
|
1782
|
+
prog_name: The name of the program to use for the CLI. Defaults to 'pydantic-ai'.
|
|
1783
|
+
|
|
1795
1784
|
```python {title="agent_to_cli_sync.py" test="skip"}
|
|
1796
1785
|
from pydantic_ai import Agent
|
|
1797
1786
|
|
|
1798
1787
|
agent = Agent('openai:gpt-4o', instructions='You always respond in Italian.')
|
|
1799
1788
|
agent.to_cli_sync()
|
|
1789
|
+
agent.to_cli_sync(prog_name='assistant')
|
|
1800
1790
|
```
|
|
1801
1791
|
"""
|
|
1802
|
-
return get_event_loop().run_until_complete(self.to_cli(deps=deps))
|
|
1792
|
+
return get_event_loop().run_until_complete(self.to_cli(deps=deps, prog_name=prog_name))
|
|
1803
1793
|
|
|
1804
1794
|
|
|
1805
1795
|
@dataclasses.dataclass(repr=False)
|
|
@@ -1841,30 +1831,21 @@ class AgentRun(Generic[AgentDepsT, OutputDataT]):
|
|
|
1841
1831
|
UserPromptPart(
|
|
1842
1832
|
content='What is the capital of France?',
|
|
1843
1833
|
timestamp=datetime.datetime(...),
|
|
1844
|
-
part_kind='user-prompt',
|
|
1845
1834
|
)
|
|
1846
|
-
]
|
|
1847
|
-
instructions=None,
|
|
1848
|
-
kind='request',
|
|
1835
|
+
]
|
|
1849
1836
|
)
|
|
1850
1837
|
),
|
|
1851
1838
|
CallToolsNode(
|
|
1852
1839
|
model_response=ModelResponse(
|
|
1853
|
-
parts=[TextPart(content='Paris'
|
|
1840
|
+
parts=[TextPart(content='Paris')],
|
|
1854
1841
|
usage=Usage(
|
|
1855
|
-
requests=1,
|
|
1856
|
-
request_tokens=56,
|
|
1857
|
-
response_tokens=1,
|
|
1858
|
-
total_tokens=57,
|
|
1859
|
-
details=None,
|
|
1842
|
+
requests=1, request_tokens=56, response_tokens=1, total_tokens=57
|
|
1860
1843
|
),
|
|
1861
1844
|
model_name='gpt-4o',
|
|
1862
1845
|
timestamp=datetime.datetime(...),
|
|
1863
|
-
kind='response',
|
|
1864
|
-
vendor_id=None,
|
|
1865
1846
|
)
|
|
1866
1847
|
),
|
|
1867
|
-
End(data=FinalResult(output='Paris'
|
|
1848
|
+
End(data=FinalResult(output='Paris')),
|
|
1868
1849
|
]
|
|
1869
1850
|
'''
|
|
1870
1851
|
print(agent_run.result.output)
|
|
@@ -1987,30 +1968,24 @@ class AgentRun(Generic[AgentDepsT, OutputDataT]):
|
|
|
1987
1968
|
UserPromptPart(
|
|
1988
1969
|
content='What is the capital of France?',
|
|
1989
1970
|
timestamp=datetime.datetime(...),
|
|
1990
|
-
part_kind='user-prompt',
|
|
1991
1971
|
)
|
|
1992
|
-
]
|
|
1993
|
-
instructions=None,
|
|
1994
|
-
kind='request',
|
|
1972
|
+
]
|
|
1995
1973
|
)
|
|
1996
1974
|
),
|
|
1997
1975
|
CallToolsNode(
|
|
1998
1976
|
model_response=ModelResponse(
|
|
1999
|
-
parts=[TextPart(content='Paris'
|
|
1977
|
+
parts=[TextPart(content='Paris')],
|
|
2000
1978
|
usage=Usage(
|
|
2001
1979
|
requests=1,
|
|
2002
1980
|
request_tokens=56,
|
|
2003
1981
|
response_tokens=1,
|
|
2004
1982
|
total_tokens=57,
|
|
2005
|
-
details=None,
|
|
2006
1983
|
),
|
|
2007
1984
|
model_name='gpt-4o',
|
|
2008
1985
|
timestamp=datetime.datetime(...),
|
|
2009
|
-
kind='response',
|
|
2010
|
-
vendor_id=None,
|
|
2011
1986
|
)
|
|
2012
1987
|
),
|
|
2013
|
-
End(data=FinalResult(output='Paris'
|
|
1988
|
+
End(data=FinalResult(output='Paris')),
|
|
2014
1989
|
]
|
|
2015
1990
|
'''
|
|
2016
1991
|
print('Final result:', agent_run.result.output)
|
pydantic_ai/direct.py
CHANGED
|
@@ -41,18 +41,10 @@ async def model_request(
|
|
|
41
41
|
print(model_response)
|
|
42
42
|
'''
|
|
43
43
|
ModelResponse(
|
|
44
|
-
parts=[TextPart(content='Paris'
|
|
45
|
-
usage=Usage(
|
|
46
|
-
requests=1,
|
|
47
|
-
request_tokens=56,
|
|
48
|
-
response_tokens=1,
|
|
49
|
-
total_tokens=57,
|
|
50
|
-
details=None,
|
|
51
|
-
),
|
|
44
|
+
parts=[TextPart(content='Paris')],
|
|
45
|
+
usage=Usage(requests=1, request_tokens=56, response_tokens=1, total_tokens=57),
|
|
52
46
|
model_name='claude-3-5-haiku-latest',
|
|
53
47
|
timestamp=datetime.datetime(...),
|
|
54
|
-
kind='response',
|
|
55
|
-
vendor_id=None,
|
|
56
48
|
)
|
|
57
49
|
'''
|
|
58
50
|
```
|
|
@@ -102,14 +94,10 @@ def model_request_sync(
|
|
|
102
94
|
print(model_response)
|
|
103
95
|
'''
|
|
104
96
|
ModelResponse(
|
|
105
|
-
parts=[TextPart(content='Paris'
|
|
106
|
-
usage=Usage(
|
|
107
|
-
requests=1, request_tokens=56, response_tokens=1, total_tokens=57, details=None
|
|
108
|
-
),
|
|
97
|
+
parts=[TextPart(content='Paris')],
|
|
98
|
+
usage=Usage(requests=1, request_tokens=56, response_tokens=1, total_tokens=57),
|
|
109
99
|
model_name='claude-3-5-haiku-latest',
|
|
110
100
|
timestamp=datetime.datetime(...),
|
|
111
|
-
kind='response',
|
|
112
|
-
vendor_id=None,
|
|
113
101
|
)
|
|
114
102
|
'''
|
|
115
103
|
```
|
|
@@ -163,23 +151,11 @@ def model_request_stream(
|
|
|
163
151
|
print(chunks)
|
|
164
152
|
'''
|
|
165
153
|
[
|
|
166
|
-
PartStartEvent(
|
|
167
|
-
index=0,
|
|
168
|
-
part=TextPart(content='Albert Einstein was ', part_kind='text'),
|
|
169
|
-
event_kind='part_start',
|
|
170
|
-
),
|
|
171
|
-
PartDeltaEvent(
|
|
172
|
-
index=0,
|
|
173
|
-
delta=TextPartDelta(
|
|
174
|
-
content_delta='a German-born theoretical ', part_delta_kind='text'
|
|
175
|
-
),
|
|
176
|
-
event_kind='part_delta',
|
|
177
|
-
),
|
|
154
|
+
PartStartEvent(index=0, part=TextPart(content='Albert Einstein was ')),
|
|
178
155
|
PartDeltaEvent(
|
|
179
|
-
index=0,
|
|
180
|
-
delta=TextPartDelta(content_delta='physicist.', part_delta_kind='text'),
|
|
181
|
-
event_kind='part_delta',
|
|
156
|
+
index=0, delta=TextPartDelta(content_delta='a German-born theoretical ')
|
|
182
157
|
),
|
|
158
|
+
PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='physicist.')),
|
|
183
159
|
]
|
|
184
160
|
'''
|
|
185
161
|
```
|
pydantic_ai/mcp.py
CHANGED
|
@@ -46,6 +46,13 @@ class MCPServer(ABC):
|
|
|
46
46
|
"""
|
|
47
47
|
|
|
48
48
|
is_running: bool = False
|
|
49
|
+
tool_prefix: str | None = None
|
|
50
|
+
"""A prefix to add to all tools that are registered with the server.
|
|
51
|
+
|
|
52
|
+
If not empty, will include a trailing underscore(`_`).
|
|
53
|
+
|
|
54
|
+
e.g. if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
|
|
55
|
+
"""
|
|
49
56
|
|
|
50
57
|
_client: ClientSession
|
|
51
58
|
_read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
|
|
@@ -57,7 +64,10 @@ class MCPServer(ABC):
|
|
|
57
64
|
async def client_streams(
|
|
58
65
|
self,
|
|
59
66
|
) -> AsyncIterator[
|
|
60
|
-
tuple[
|
|
67
|
+
tuple[
|
|
68
|
+
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
|
69
|
+
MemoryObjectSendStream[JSONRPCMessage],
|
|
70
|
+
]
|
|
61
71
|
]:
|
|
62
72
|
"""Create the streams for the MCP server."""
|
|
63
73
|
raise NotImplementedError('MCP Server subclasses must implement this method.')
|
|
@@ -68,6 +78,14 @@ class MCPServer(ABC):
|
|
|
68
78
|
"""Get the log level for the MCP server."""
|
|
69
79
|
raise NotImplementedError('MCP Server subclasses must implement this method.')
|
|
70
80
|
|
|
81
|
+
def get_prefixed_tool_name(self, tool_name: str) -> str:
|
|
82
|
+
"""Get the tool name with prefix if `tool_prefix` is set."""
|
|
83
|
+
return f'{self.tool_prefix}_{tool_name}' if self.tool_prefix else tool_name
|
|
84
|
+
|
|
85
|
+
def get_unprefixed_tool_name(self, tool_name: str) -> str:
|
|
86
|
+
"""Get original tool name without prefix for calling tools."""
|
|
87
|
+
return tool_name.removeprefix(f'{self.tool_prefix}_') if self.tool_prefix else tool_name
|
|
88
|
+
|
|
71
89
|
async def list_tools(self) -> list[ToolDefinition]:
|
|
72
90
|
"""Retrieve tools that are currently active on the server.
|
|
73
91
|
|
|
@@ -78,7 +96,7 @@ class MCPServer(ABC):
|
|
|
78
96
|
tools = await self._client.list_tools()
|
|
79
97
|
return [
|
|
80
98
|
ToolDefinition(
|
|
81
|
-
name=tool.name,
|
|
99
|
+
name=self.get_prefixed_tool_name(tool.name),
|
|
82
100
|
description=tool.description or '',
|
|
83
101
|
parameters_json_schema=tool.inputSchema,
|
|
84
102
|
)
|
|
@@ -100,7 +118,7 @@ class MCPServer(ABC):
|
|
|
100
118
|
Raises:
|
|
101
119
|
ModelRetry: If the tool call fails.
|
|
102
120
|
"""
|
|
103
|
-
result = await self._client.call_tool(tool_name, arguments)
|
|
121
|
+
result = await self._client.call_tool(self.get_unprefixed_tool_name(tool_name), arguments)
|
|
104
122
|
|
|
105
123
|
content = [self._map_tool_result_part(part) for part in result.content]
|
|
106
124
|
|
|
@@ -126,7 +144,10 @@ class MCPServer(ABC):
|
|
|
126
144
|
return self
|
|
127
145
|
|
|
128
146
|
async def __aexit__(
|
|
129
|
-
self,
|
|
147
|
+
self,
|
|
148
|
+
exc_type: type[BaseException] | None,
|
|
149
|
+
exc_value: BaseException | None,
|
|
150
|
+
traceback: TracebackType | None,
|
|
130
151
|
) -> bool | None:
|
|
131
152
|
await self._exit_stack.aclose()
|
|
132
153
|
self.is_running = False
|
|
@@ -223,11 +244,22 @@ class MCPServerStdio(MCPServer):
|
|
|
223
244
|
cwd: str | Path | None = None
|
|
224
245
|
"""The working directory to use when spawning the process."""
|
|
225
246
|
|
|
247
|
+
tool_prefix: str | None = None
|
|
248
|
+
"""A prefix to add to all tools that are registered with the server.
|
|
249
|
+
|
|
250
|
+
If not empty, will include a trailing underscore(`_`).
|
|
251
|
+
|
|
252
|
+
e.g. if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
|
|
253
|
+
"""
|
|
254
|
+
|
|
226
255
|
@asynccontextmanager
|
|
227
256
|
async def client_streams(
|
|
228
257
|
self,
|
|
229
258
|
) -> AsyncIterator[
|
|
230
|
-
tuple[
|
|
259
|
+
tuple[
|
|
260
|
+
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
|
261
|
+
MemoryObjectSendStream[JSONRPCMessage],
|
|
262
|
+
]
|
|
231
263
|
]:
|
|
232
264
|
server = StdioServerParameters(command=self.command, args=list(self.args), env=self.env, cwd=self.cwd)
|
|
233
265
|
async with stdio_client(server=server) as (read_stream, write_stream):
|
|
@@ -236,6 +268,9 @@ class MCPServerStdio(MCPServer):
|
|
|
236
268
|
def _get_log_level(self) -> LoggingLevel | None:
|
|
237
269
|
return self.log_level
|
|
238
270
|
|
|
271
|
+
def __repr__(self) -> str:
|
|
272
|
+
return f'MCPServerStdio(command={self.command!r}, args={self.args!r}, tool_prefix={self.tool_prefix!r})'
|
|
273
|
+
|
|
239
274
|
|
|
240
275
|
@dataclass
|
|
241
276
|
class MCPServerHTTP(MCPServer):
|
|
@@ -303,16 +338,33 @@ class MCPServerHTTP(MCPServer):
|
|
|
303
338
|
If `None`, no log level will be set.
|
|
304
339
|
"""
|
|
305
340
|
|
|
341
|
+
tool_prefix: str | None = None
|
|
342
|
+
"""A prefix to add to all tools that are registered with the server.
|
|
343
|
+
|
|
344
|
+
If not empty, will include a trailing underscore (`_`).
|
|
345
|
+
|
|
346
|
+
For example, if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
|
|
347
|
+
"""
|
|
348
|
+
|
|
306
349
|
@asynccontextmanager
|
|
307
350
|
async def client_streams(
|
|
308
351
|
self,
|
|
309
352
|
) -> AsyncIterator[
|
|
310
|
-
tuple[
|
|
353
|
+
tuple[
|
|
354
|
+
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
|
355
|
+
MemoryObjectSendStream[JSONRPCMessage],
|
|
356
|
+
]
|
|
311
357
|
]: # pragma: no cover
|
|
312
358
|
async with sse_client(
|
|
313
|
-
url=self.url,
|
|
359
|
+
url=self.url,
|
|
360
|
+
headers=self.headers,
|
|
361
|
+
timeout=self.timeout,
|
|
362
|
+
sse_read_timeout=self.sse_read_timeout,
|
|
314
363
|
) as (read_stream, write_stream):
|
|
315
364
|
yield read_stream, write_stream
|
|
316
365
|
|
|
317
366
|
def _get_log_level(self) -> LoggingLevel | None:
|
|
318
367
|
return self.log_level
|
|
368
|
+
|
|
369
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
370
|
+
return f'MCPServerHTTP(url={self.url!r}, tool_prefix={self.tool_prefix!r})'
|