casual-mcp 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- casual_mcp/__init__.py +4 -0
- casual_mcp/cli.py +379 -6
- casual_mcp/main.py +93 -8
- casual_mcp/mcp_tool_chat.py +152 -15
- casual_mcp/models/__init__.py +16 -0
- casual_mcp/models/chat_stats.py +37 -0
- casual_mcp/models/config.py +3 -1
- casual_mcp/models/toolset_config.py +40 -0
- casual_mcp/tool_filter.py +171 -0
- casual_mcp-0.7.0.dist-info/METADATA +193 -0
- casual_mcp-0.7.0.dist-info/RECORD +23 -0
- casual_mcp-0.5.0.dist-info/METADATA +0 -630
- casual_mcp-0.5.0.dist-info/RECORD +0 -20
- {casual_mcp-0.5.0.dist-info → casual_mcp-0.7.0.dist-info}/WHEEL +0 -0
- {casual_mcp-0.5.0.dist-info → casual_mcp-0.7.0.dist-info}/entry_points.txt +0 -0
- {casual_mcp-0.5.0.dist-info → casual_mcp-0.7.0.dist-info}/licenses/LICENSE +0 -0
- {casual_mcp-0.5.0.dist-info → casual_mcp-0.7.0.dist-info}/top_level.txt +0 -0
casual_mcp/mcp_tool_chat.py
CHANGED
|
@@ -14,12 +14,18 @@ from fastmcp import Client
|
|
|
14
14
|
|
|
15
15
|
from casual_mcp.convert_tools import tools_from_mcp
|
|
16
16
|
from casual_mcp.logging import get_logger
|
|
17
|
+
from casual_mcp.models.chat_stats import ChatStats
|
|
18
|
+
from casual_mcp.models.toolset_config import ToolSetConfig
|
|
17
19
|
from casual_mcp.tool_cache import ToolCache
|
|
20
|
+
from casual_mcp.tool_filter import filter_tools_by_toolset
|
|
18
21
|
from casual_mcp.utils import format_tool_call_result
|
|
19
22
|
|
|
20
23
|
logger = get_logger("mcp_tool_chat")
|
|
21
24
|
sessions: dict[str, list[ChatMessage]] = {}
|
|
22
25
|
|
|
26
|
+
# Type alias for metadata dictionary
|
|
27
|
+
MetaDict = dict[str, Any]
|
|
28
|
+
|
|
23
29
|
|
|
24
30
|
def get_session_messages(session_id: str) -> list[ChatMessage]:
|
|
25
31
|
global sessions
|
|
@@ -44,19 +50,65 @@ class McpToolChat:
|
|
|
44
50
|
provider: LLMProvider,
|
|
45
51
|
system: str | None = None,
|
|
46
52
|
tool_cache: ToolCache | None = None,
|
|
53
|
+
server_names: set[str] | None = None,
|
|
47
54
|
):
|
|
48
55
|
self.provider = provider
|
|
49
56
|
self.mcp_client = mcp_client
|
|
50
57
|
self.system = system
|
|
51
58
|
self.tool_cache = tool_cache or ToolCache(mcp_client)
|
|
59
|
+
self.server_names = server_names or set()
|
|
52
60
|
self._tool_cache_version = -1
|
|
61
|
+
self._last_stats: ChatStats | None = None
|
|
53
62
|
|
|
54
63
|
@staticmethod
def get_session(session_id: str) -> list[ChatMessage] | None:
    """Look up the stored message history for ``session_id``.

    Returns the list held in the module-level ``sessions`` mapping, or
    ``None`` when no entry exists for that ID.
    """
    # Reading a module-level name needs no ``global`` statement; that is
    # only required when rebinding it.
    history = sessions.get(session_id)
    return history
|
|
58
67
|
|
|
59
|
-
|
|
68
|
+
def get_stats(self) -> ChatStats | None:
|
|
69
|
+
"""
|
|
70
|
+
Get usage statistics from the last chat() or generate() call.
|
|
71
|
+
|
|
72
|
+
Returns None if no calls have been made yet.
|
|
73
|
+
Stats are reset at the start of each new chat()/generate() call.
|
|
74
|
+
"""
|
|
75
|
+
return self._last_stats
|
|
76
|
+
|
|
77
|
+
def _extract_server_from_tool_name(self, tool_name: str) -> str:
|
|
78
|
+
"""
|
|
79
|
+
Extract server name from a tool name.
|
|
80
|
+
|
|
81
|
+
With multiple servers, fastmcp prefixes tools as "serverName_toolName".
|
|
82
|
+
With a single server, tools are not prefixed.
|
|
83
|
+
|
|
84
|
+
Returns the server name or "default" if it cannot be determined.
|
|
85
|
+
"""
|
|
86
|
+
if "_" in tool_name:
|
|
87
|
+
return tool_name.split("_", 1)[0]
|
|
88
|
+
return "default"
|
|
89
|
+
|
|
90
|
+
async def generate(
|
|
91
|
+
self,
|
|
92
|
+
prompt: str,
|
|
93
|
+
session_id: str | None = None,
|
|
94
|
+
tool_set: ToolSetConfig | None = None,
|
|
95
|
+
meta: MetaDict | None = None,
|
|
96
|
+
) -> list[ChatMessage]:
|
|
97
|
+
"""
|
|
98
|
+
Generate a response to a prompt, optionally using session history.
|
|
99
|
+
|
|
100
|
+
Args:
|
|
101
|
+
prompt: The user prompt to respond to
|
|
102
|
+
session_id: Optional session ID for conversation persistence
|
|
103
|
+
tool_set: Optional tool set configuration to filter available tools
|
|
104
|
+
meta: Optional metadata to pass through to MCP tool calls.
|
|
105
|
+
Useful for passing context like character_id without
|
|
106
|
+
exposing it to the LLM. Servers can access this via
|
|
107
|
+
ctx.request_context.meta.
|
|
108
|
+
|
|
109
|
+
Returns:
|
|
110
|
+
List of response messages including any tool calls and results
|
|
111
|
+
"""
|
|
60
112
|
# Fetch the session if we have a session ID
|
|
61
113
|
messages: list[ChatMessage]
|
|
62
114
|
if session_id:
|
|
@@ -73,7 +125,7 @@ class McpToolChat:
|
|
|
73
125
|
add_messages_to_session(session_id, [user_message])
|
|
74
126
|
|
|
75
127
|
# Perform Chat
|
|
76
|
-
response = await self.chat(messages=messages)
|
|
128
|
+
response = await self.chat(messages=messages, tool_set=tool_set, meta=meta)
|
|
77
129
|
|
|
78
130
|
# Add responses to session
|
|
79
131
|
if session_id:
|
|
@@ -81,9 +133,36 @@ class McpToolChat:
|
|
|
81
133
|
|
|
82
134
|
return response
|
|
83
135
|
|
|
84
|
-
async def chat(
|
|
136
|
+
async def chat(
|
|
137
|
+
self,
|
|
138
|
+
messages: list[ChatMessage],
|
|
139
|
+
tool_set: ToolSetConfig | None = None,
|
|
140
|
+
meta: MetaDict | None = None,
|
|
141
|
+
) -> list[ChatMessage]:
|
|
142
|
+
"""
|
|
143
|
+
Process a conversation with tool calling support.
|
|
144
|
+
|
|
145
|
+
Args:
|
|
146
|
+
messages: The conversation messages to process
|
|
147
|
+
tool_set: Optional tool set configuration to filter available tools
|
|
148
|
+
meta: Optional metadata to pass through to MCP tool calls.
|
|
149
|
+
Useful for passing context like character_id without
|
|
150
|
+
exposing it to the LLM. Servers can access this via
|
|
151
|
+
ctx.request_context.meta.
|
|
152
|
+
|
|
153
|
+
Returns:
|
|
154
|
+
List of response messages including any tool calls and results
|
|
155
|
+
"""
|
|
85
156
|
tools = await self.tool_cache.get_tools()
|
|
86
157
|
|
|
158
|
+
# Filter tools if a toolset is specified
|
|
159
|
+
if tool_set is not None:
|
|
160
|
+
tools = filter_tools_by_toolset(tools, tool_set, self.server_names, validate=True)
|
|
161
|
+
logger.info(f"Filtered to {len(tools)} tools using toolset")
|
|
162
|
+
|
|
163
|
+
# Reset stats at the start of each chat
|
|
164
|
+
self._last_stats = ChatStats()
|
|
165
|
+
|
|
87
166
|
# Add a system message if required
|
|
88
167
|
has_system_message = any(message.role == "system" for message in messages)
|
|
89
168
|
if self.system and not has_system_message:
|
|
@@ -97,6 +176,15 @@ class McpToolChat:
|
|
|
97
176
|
logger.info("Calling the LLM")
|
|
98
177
|
ai_message = await self.provider.chat(messages=messages, tools=tools_from_mcp(tools))
|
|
99
178
|
|
|
179
|
+
# Accumulate token usage stats
|
|
180
|
+
self._last_stats.llm_calls += 1
|
|
181
|
+
usage = self.provider.get_usage()
|
|
182
|
+
if usage:
|
|
183
|
+
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
|
|
184
|
+
completion_tokens = getattr(usage, "completion_tokens", 0) or 0
|
|
185
|
+
self._last_stats.tokens.prompt_tokens += prompt_tokens
|
|
186
|
+
self._last_stats.tokens.completion_tokens += completion_tokens
|
|
187
|
+
|
|
100
188
|
# Add the assistant's message
|
|
101
189
|
response_messages.append(ai_message)
|
|
102
190
|
messages.append(ai_message)
|
|
@@ -108,14 +196,29 @@ class McpToolChat:
|
|
|
108
196
|
logger.info(f"Executing {len(ai_message.tool_calls)} tool calls")
|
|
109
197
|
result_count = 0
|
|
110
198
|
for tool_call in ai_message.tool_calls:
|
|
199
|
+
# Track tool call stats
|
|
200
|
+
tool_name = tool_call.function.name
|
|
201
|
+
self._last_stats.tool_calls.by_tool[tool_name] = (
|
|
202
|
+
self._last_stats.tool_calls.by_tool.get(tool_name, 0) + 1
|
|
203
|
+
)
|
|
204
|
+
server_name = self._extract_server_from_tool_name(tool_name)
|
|
205
|
+
self._last_stats.tool_calls.by_server[server_name] = (
|
|
206
|
+
self._last_stats.tool_calls.by_server.get(server_name, 0) + 1
|
|
207
|
+
)
|
|
208
|
+
|
|
111
209
|
try:
|
|
112
|
-
result = await self.execute(tool_call)
|
|
210
|
+
result = await self.execute(tool_call, meta=meta)
|
|
113
211
|
except Exception as e:
|
|
114
212
|
logger.error(
|
|
115
213
|
f"Failed to execute tool '{tool_call.function.name}' "
|
|
116
214
|
f"(id={tool_call.id}): {e}"
|
|
117
215
|
)
|
|
118
|
-
|
|
216
|
+
# Surface the failure to the LLM so it knows the tool failed
|
|
217
|
+
result = ToolResultMessage(
|
|
218
|
+
name=tool_call.function.name,
|
|
219
|
+
tool_call_id=tool_call.id,
|
|
220
|
+
content=f"Error executing tool: {e}",
|
|
221
|
+
)
|
|
119
222
|
if result:
|
|
120
223
|
messages.append(result)
|
|
121
224
|
response_messages.append(result)
|
|
@@ -127,12 +230,27 @@ class McpToolChat:
|
|
|
127
230
|
|
|
128
231
|
return response_messages
|
|
129
232
|
|
|
130
|
-
async def execute(
|
|
233
|
+
async def execute(
|
|
234
|
+
self,
|
|
235
|
+
tool_call: AssistantToolCall,
|
|
236
|
+
meta: MetaDict | None = None,
|
|
237
|
+
) -> ToolResultMessage:
|
|
238
|
+
"""
|
|
239
|
+
Execute a single tool call.
|
|
240
|
+
|
|
241
|
+
Args:
|
|
242
|
+
tool_call: The tool call to execute
|
|
243
|
+
meta: Optional metadata to pass to the MCP server.
|
|
244
|
+
Servers can access this via ctx.request_context.meta.
|
|
245
|
+
|
|
246
|
+
Returns:
|
|
247
|
+
ToolResultMessage with the tool execution result
|
|
248
|
+
"""
|
|
131
249
|
tool_name = tool_call.function.name
|
|
132
250
|
tool_args = json.loads(tool_call.function.arguments)
|
|
133
251
|
try:
|
|
134
252
|
async with self.mcp_client:
|
|
135
|
-
result = await self.mcp_client.call_tool(tool_name, tool_args)
|
|
253
|
+
result = await self.mcp_client.call_tool(tool_name, tool_args, meta=meta)
|
|
136
254
|
except Exception as e:
|
|
137
255
|
if isinstance(e, ValueError):
|
|
138
256
|
logger.warning(e)
|
|
@@ -148,16 +266,35 @@ class McpToolChat:
|
|
|
148
266
|
logger.debug(f"Tool Call Result: {result}")
|
|
149
267
|
|
|
150
268
|
result_format = os.getenv("TOOL_RESULT_FORMAT", "result")
|
|
151
|
-
|
|
152
|
-
|
|
269
|
+
|
|
270
|
+
# Prefer structuredContent when available (machine-readable format)
|
|
271
|
+
# Note: MCP types use camelCase (structuredContent), mypy stubs may differ
|
|
272
|
+
structured = getattr(result, "structuredContent", None)
|
|
273
|
+
if structured is not None:
|
|
274
|
+
try:
|
|
275
|
+
content_text = json.dumps(structured)
|
|
276
|
+
except (TypeError, ValueError):
|
|
277
|
+
content_text = str(structured)
|
|
278
|
+
elif not result.content:
|
|
153
279
|
content_text = "[No content returned]"
|
|
154
280
|
else:
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
281
|
+
# Fall back to processing content items
|
|
282
|
+
content_parts: list[Any] = []
|
|
283
|
+
for content_item in result.content:
|
|
284
|
+
if content_item.type == "text":
|
|
285
|
+
try:
|
|
286
|
+
parsed = json.loads(content_item.text)
|
|
287
|
+
content_parts.append(parsed)
|
|
288
|
+
except json.JSONDecodeError:
|
|
289
|
+
content_parts.append(content_item.text)
|
|
290
|
+
elif hasattr(content_item, "mimeType"):
|
|
291
|
+
# Image or audio content
|
|
292
|
+
content_parts.append(f"[{content_item.type}: {content_item.mimeType}]")
|
|
293
|
+
else:
|
|
294
|
+
content_parts.append(str(content_item))
|
|
295
|
+
|
|
296
|
+
content_text = json.dumps(content_parts)
|
|
297
|
+
|
|
161
298
|
content = format_tool_call_result(tool_call, content_text, style=result_format)
|
|
162
299
|
|
|
163
300
|
return ToolResultMessage(
|
casual_mcp/models/__init__.py
CHANGED
|
@@ -7,6 +7,11 @@ from casual_llm import (
|
|
|
7
7
|
UserMessage,
|
|
8
8
|
)
|
|
9
9
|
|
|
10
|
+
from .chat_stats import (
|
|
11
|
+
ChatStats,
|
|
12
|
+
TokenUsageStats,
|
|
13
|
+
ToolCallStats,
|
|
14
|
+
)
|
|
10
15
|
from .mcp_server_config import (
|
|
11
16
|
McpServerConfig,
|
|
12
17
|
RemoteServerConfig,
|
|
@@ -17,6 +22,11 @@ from .model_config import (
|
|
|
17
22
|
OllamaModelConfig,
|
|
18
23
|
OpenAIModelConfig,
|
|
19
24
|
)
|
|
25
|
+
from .toolset_config import (
|
|
26
|
+
ExcludeSpec,
|
|
27
|
+
ToolSetConfig,
|
|
28
|
+
ToolSpec,
|
|
29
|
+
)
|
|
20
30
|
|
|
21
31
|
__all__ = [
|
|
22
32
|
"UserMessage",
|
|
@@ -25,10 +35,16 @@ __all__ = [
|
|
|
25
35
|
"ToolResultMessage",
|
|
26
36
|
"SystemMessage",
|
|
27
37
|
"ChatMessage",
|
|
38
|
+
"ChatStats",
|
|
39
|
+
"TokenUsageStats",
|
|
40
|
+
"ToolCallStats",
|
|
28
41
|
"McpModelConfig",
|
|
29
42
|
"OllamaModelConfig",
|
|
30
43
|
"OpenAIModelConfig",
|
|
31
44
|
"McpServerConfig",
|
|
32
45
|
"StdioServerConfig",
|
|
33
46
|
"RemoteServerConfig",
|
|
47
|
+
"ExcludeSpec",
|
|
48
|
+
"ToolSetConfig",
|
|
49
|
+
"ToolSpec",
|
|
34
50
|
]
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""Usage statistics models for chat sessions."""
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, Field, computed_field
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class TokenUsageStats(BaseModel):
    """Token usage statistics accumulated across all LLM calls."""

    # NOTE(review): docstrings on this model are kept verbatim because
    # pydantic may surface them as JSON-schema descriptions - confirm
    # before rewording.

    # Running totals; both start at zero and are declared non-negative.
    prompt_tokens: int = Field(default=0, ge=0)
    completion_tokens: int = Field(default=0, ge=0)

    @computed_field  # type: ignore[prop-decorator]
    @property
    def total_tokens(self) -> int:
        """Total tokens (prompt + completion)."""
        # Derived on every access so it always matches the two counters.
        return sum((self.prompt_tokens, self.completion_tokens))
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ToolCallStats(BaseModel):
    """Statistics about tool calls during a chat session."""

    # Call counts keyed by tool name and by server name respectively.
    # default_factory gives every instance its own dict.
    by_tool: dict[str, int] = Field(default_factory=dict)
    by_server: dict[str, int] = Field(default_factory=dict)

    @computed_field  # type: ignore[prop-decorator]
    @property
    def total(self) -> int:
        """Total number of tool calls made."""
        # Summed from the per-tool counts only; by_server is not consulted.
        per_tool_counts = self.by_tool.values()
        return sum(per_tool_counts)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class ChatStats(BaseModel):
    """Combined statistics from a chat session."""

    # Nested models are built per instance through default_factory, so a
    # fresh ChatStats() starts with empty token and tool-call stats.
    tokens: TokenUsageStats = Field(default_factory=TokenUsageStats)
    tool_calls: ToolCallStats = Field(default_factory=ToolCallStats)
    # Non-negative counter of LLM calls.
    llm_calls: int = Field(default=0, ge=0, description="Number of LLM calls made")
|
casual_mcp/models/config.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
|
1
|
-
from pydantic import BaseModel
|
|
1
|
+
from pydantic import BaseModel, Field
|
|
2
2
|
|
|
3
3
|
from casual_mcp.models.mcp_server_config import McpServerConfig
|
|
4
4
|
from casual_mcp.models.model_config import McpModelConfig
|
|
5
|
+
from casual_mcp.models.toolset_config import ToolSetConfig
|
|
5
6
|
|
|
6
7
|
|
|
7
8
|
class Config(BaseModel):
    # Documented with comments rather than a docstring so that no
    # description is introduced into the generated schema.

    # Optional flag, False by default; its consumer is not visible here.
    namespace_tools: bool | None = False
    # Required mappings of name -> model / server configuration.
    models: dict[str, McpModelConfig]
    servers: dict[str, McpServerConfig]
    # Optional named toolsets; defaults to a new empty dict per instance.
    tool_sets: dict[str, ToolSetConfig] = Field(default_factory=dict)
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
"""Toolset configuration models for filtering available tools."""
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, Field
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ExcludeSpec(BaseModel):
    """Specification for excluding specific tools from a server."""

    # Required field (no default): the names listed here are left out,
    # every other tool of the server stays available.
    exclude: list[str] = Field(description="List of tool names to exclude")
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Tool specification: true (all), list (include specific), or exclude object
|
|
13
|
+
ToolSpec = bool | list[str] | ExcludeSpec
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ToolSetConfig(BaseModel):
    """Configuration for a named toolset.

    A toolset defines which tools from which servers should be available
    during a chat session. Each server can be configured to:
    - Include all tools (True)
    - Include specific tools (list of tool names)
    - Include all except specific tools (ExcludeSpec)

    Example:
        {
            "description": "Research tools",
            "servers": {
                "wikimedia": True,
                "search": ["brave_web_search"],
                "fetch": {"exclude": ["fetch_dangerous"]}
            }
        }
    """

    # Free-form label; empty when not provided.
    description: str = Field(default="", description="Human-readable description")
    # Server name -> ToolSpec (bool, list of names, or ExcludeSpec).
    # Servers absent from this mapping contribute no tools when filtering.
    servers: dict[str, ToolSpec] = Field(
        default_factory=dict,
        description="Mapping of server name to tool specification",
    )
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
"""Tool filtering logic for toolsets.
|
|
2
|
+
|
|
3
|
+
This module provides functionality to filter MCP tools based on toolset
|
|
4
|
+
configurations, including validation to ensure referenced servers and
|
|
5
|
+
tools actually exist.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import mcp
|
|
9
|
+
|
|
10
|
+
from casual_mcp.logging import get_logger
|
|
11
|
+
from casual_mcp.models.toolset_config import ExcludeSpec, ToolSetConfig
|
|
12
|
+
|
|
13
|
+
logger = get_logger("tool_filter")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ToolSetValidationError(Exception):
    """Signals that a toolset names a server or tool that does not exist.

    Raised by ``validate_toolset`` with one line per problem found.
    """
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def extract_server_and_tool(tool_name: str, server_names: set[str]) -> tuple[str, str]:
    """Extract server name and base tool name from a potentially prefixed tool name.

    When multiple servers are configured, fastmcp prefixes tools as "serverName_toolName".
    When a single server is configured, tools are not prefixed.

    Resolution order:
      1. the text before the first underscore, if it is a configured server;
      2. otherwise the longest configured server name that prefixes the tool
         name followed by "_" - this covers server names that themselves
         contain underscores, such as "my_server";
      3. otherwise the only configured server, with the name left untouched;
      4. otherwise the placeholder server "default".

    Args:
        tool_name: The full tool name (possibly prefixed)
        server_names: Set of configured server names

    Returns:
        Tuple of (server_name, base_tool_name)
    """
    prefix, separator, remainder = tool_name.partition("_")
    if separator and prefix in server_names:
        return prefix, remainder

    # Only names containing "_" can still match here; shorter ones were
    # handled by the first-underscore check above.
    candidates = [name for name in server_names if tool_name.startswith(f"{name}_")]
    if candidates:
        server = max(candidates, key=len)
        return server, tool_name[len(server) + 1 :]

    # Single server case - return the single server name
    if len(server_names) == 1:
        return next(iter(server_names)), tool_name

    # Fallback - can't determine server
    return "default", tool_name
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _build_server_tool_map(tools: list[mcp.Tool], server_names: set[str]) -> dict[str, set[str]]:
    """Group the base names of the available tools by owning server.

    Every configured server gets an entry, even when it has no tools.
    Tools resolved to a server outside ``server_names`` (for instance the
    "default" placeholder) are left out.

    Args:
        tools: List of MCP tools
        server_names: Set of configured server names

    Returns:
        Dict mapping server name to set of base tool names
    """
    grouped: dict[str, set[str]] = {}
    for configured in server_names:
        grouped[configured] = set()

    for candidate in tools:
        owner, short_name = extract_server_and_tool(candidate.name, server_names)
        bucket = grouped.get(owner)
        if bucket is not None:
            bucket.add(short_name)

    return grouped
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def validate_toolset(
    toolset: ToolSetConfig,
    tools: list[mcp.Tool],
    server_names: set[str],
) -> None:
    """Check that every server and tool named by a toolset actually exists.

    All problems are collected first and reported together, one per line,
    so a single failure lists everything that needs fixing. Boolean specs
    name no tools and are therefore only checked for the server itself.

    Args:
        toolset: The toolset configuration to validate
        tools: List of available MCP tools
        server_names: Set of configured server names

    Raises:
        ToolSetValidationError: If the toolset references non-existent servers or tools
    """
    known_tools = _build_server_tool_map(tools, server_names)
    problems: list[str] = []

    for server_name, tool_spec in toolset.servers.items():
        # An unknown server makes checking its tools pointless.
        if server_name not in server_names:
            problems.append(f"Server '{server_name}' not found in configuration")
            continue

        available = known_tools.get(server_name, set())

        if isinstance(tool_spec, list):
            # Explicit include list: each named tool must exist.
            problems.extend(
                f"Tool '{tool_name}' not found in server '{server_name}'. "
                f"Available: {sorted(available)}"
                for tool_name in tool_spec
                if tool_name not in available
            )
        elif isinstance(tool_spec, ExcludeSpec):
            # Exclude list: excluding a tool that does not exist is an error too.
            problems.extend(
                f"Tool '{tool_name}' not found in server '{server_name}' "
                f"(specified in exclude list). Available: {sorted(available)}"
                for tool_name in tool_spec.exclude
                if tool_name not in available
            )

    if problems:
        raise ToolSetValidationError("\n".join(problems))
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def filter_tools_by_toolset(
    tools: list[mcp.Tool],
    toolset: ToolSetConfig,
    server_names: set[str],
    validate: bool = True,
) -> list[mcp.Tool]:
    """Keep only the tools that a toolset configuration allows.

    A tool survives when its server appears in ``toolset.servers`` and the
    server's spec selects it: ``True`` selects everything, a list selects
    the named tools, an ``ExcludeSpec`` selects everything not excluded.
    Any other spec (such as ``False``) selects nothing.

    Args:
        tools: Full list of available MCP tools
        toolset: The toolset configuration to apply
        server_names: Set of configured server names
        validate: Whether to validate the toolset first (raises on invalid)

    Returns:
        Filtered list of tools matching the toolset, in their original order

    Raises:
        ToolSetValidationError: If validate=True and toolset is invalid
    """

    def _is_selected(spec: object, base_name: str) -> bool:
        # Mirrors the three supported spec shapes; anything else is a "no".
        if spec is True:
            return True
        if isinstance(spec, list):
            return base_name in spec
        if isinstance(spec, ExcludeSpec):
            return base_name not in spec.exclude
        return False

    if validate:
        validate_toolset(toolset, tools, server_names)

    filtered: list[mcp.Tool] = []
    for tool in tools:
        server_name, base_name = extract_server_and_tool(tool.name, server_names)
        # Servers the toolset does not mention contribute nothing.
        if server_name in toolset.servers and _is_selected(
            toolset.servers[server_name], base_name
        ):
            filtered.append(tool)

    logger.debug(
        f"Filtered {len(tools)} tools to {len(filtered)} using toolset "
        f"with {len(toolset.servers)} servers"
    )

    return filtered
|