casual-mcp 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,12 +14,18 @@ from fastmcp import Client
14
14
 
15
15
  from casual_mcp.convert_tools import tools_from_mcp
16
16
  from casual_mcp.logging import get_logger
17
+ from casual_mcp.models.chat_stats import ChatStats
18
+ from casual_mcp.models.toolset_config import ToolSetConfig
17
19
  from casual_mcp.tool_cache import ToolCache
20
+ from casual_mcp.tool_filter import filter_tools_by_toolset
18
21
  from casual_mcp.utils import format_tool_call_result
19
22
 
20
23
  logger = get_logger("mcp_tool_chat")
21
24
  sessions: dict[str, list[ChatMessage]] = {}
22
25
 
26
+ # Type alias for metadata dictionary
27
+ MetaDict = dict[str, Any]
28
+
23
29
 
24
30
  def get_session_messages(session_id: str) -> list[ChatMessage]:
25
31
  global sessions
@@ -44,19 +50,65 @@ class McpToolChat:
44
50
  provider: LLMProvider,
45
51
  system: str | None = None,
46
52
  tool_cache: ToolCache | None = None,
53
+ server_names: set[str] | None = None,
47
54
  ):
48
55
  self.provider = provider
49
56
  self.mcp_client = mcp_client
50
57
  self.system = system
51
58
  self.tool_cache = tool_cache or ToolCache(mcp_client)
59
+ self.server_names = server_names or set()
52
60
  self._tool_cache_version = -1
61
+ self._last_stats: ChatStats | None = None
53
62
 
54
63
  @staticmethod
55
64
  def get_session(session_id: str) -> list[ChatMessage] | None:
56
65
  global sessions
57
66
  return sessions.get(session_id)
58
67
 
59
- async def generate(self, prompt: str, session_id: str | None = None) -> list[ChatMessage]:
68
+ def get_stats(self) -> ChatStats | None:
69
+ """
70
+ Get usage statistics from the last chat() or generate() call.
71
+
72
+ Returns None if no calls have been made yet.
73
+ Stats are reset at the start of each new chat()/generate() call.
74
+ """
75
+ return self._last_stats
76
+
77
+ def _extract_server_from_tool_name(self, tool_name: str) -> str:
78
+ """
79
+ Extract server name from a tool name.
80
+
81
+ With multiple servers, fastmcp prefixes tools as "serverName_toolName".
82
+ With a single server, tools are not prefixed.
83
+
84
+ Returns the server name or "default" if it cannot be determined.
85
+ """
86
+ if "_" in tool_name:
87
+ return tool_name.split("_", 1)[0]
88
+ return "default"
89
+
90
+ async def generate(
91
+ self,
92
+ prompt: str,
93
+ session_id: str | None = None,
94
+ tool_set: ToolSetConfig | None = None,
95
+ meta: MetaDict | None = None,
96
+ ) -> list[ChatMessage]:
97
+ """
98
+ Generate a response to a prompt, optionally using session history.
99
+
100
+ Args:
101
+ prompt: The user prompt to respond to
102
+ session_id: Optional session ID for conversation persistence
103
+ tool_set: Optional tool set configuration to filter available tools
104
+ meta: Optional metadata to pass through to MCP tool calls.
105
+ Useful for passing context like character_id without
106
+ exposing it to the LLM. Servers can access this via
107
+ ctx.request_context.meta.
108
+
109
+ Returns:
110
+ List of response messages including any tool calls and results
111
+ """
60
112
  # Fetch the session if we have a session ID
61
113
  messages: list[ChatMessage]
62
114
  if session_id:
@@ -73,7 +125,7 @@ class McpToolChat:
73
125
  add_messages_to_session(session_id, [user_message])
74
126
 
75
127
  # Perform Chat
76
- response = await self.chat(messages=messages)
128
+ response = await self.chat(messages=messages, tool_set=tool_set, meta=meta)
77
129
 
78
130
  # Add responses to session
79
131
  if session_id:
@@ -81,9 +133,36 @@ class McpToolChat:
81
133
 
82
134
  return response
83
135
 
84
- async def chat(self, messages: list[ChatMessage]) -> list[ChatMessage]:
136
+ async def chat(
137
+ self,
138
+ messages: list[ChatMessage],
139
+ tool_set: ToolSetConfig | None = None,
140
+ meta: MetaDict | None = None,
141
+ ) -> list[ChatMessage]:
142
+ """
143
+ Process a conversation with tool calling support.
144
+
145
+ Args:
146
+ messages: The conversation messages to process
147
+ tool_set: Optional tool set configuration to filter available tools
148
+ meta: Optional metadata to pass through to MCP tool calls.
149
+ Useful for passing context like character_id without
150
+ exposing it to the LLM. Servers can access this via
151
+ ctx.request_context.meta.
152
+
153
+ Returns:
154
+ List of response messages including any tool calls and results
155
+ """
85
156
  tools = await self.tool_cache.get_tools()
86
157
 
158
+ # Filter tools if a toolset is specified
159
+ if tool_set is not None:
160
+ tools = filter_tools_by_toolset(tools, tool_set, self.server_names, validate=True)
161
+ logger.info(f"Filtered to {len(tools)} tools using toolset")
162
+
163
+ # Reset stats at the start of each chat
164
+ self._last_stats = ChatStats()
165
+
87
166
  # Add a system message if required
88
167
  has_system_message = any(message.role == "system" for message in messages)
89
168
  if self.system and not has_system_message:
@@ -97,6 +176,15 @@ class McpToolChat:
97
176
  logger.info("Calling the LLM")
98
177
  ai_message = await self.provider.chat(messages=messages, tools=tools_from_mcp(tools))
99
178
 
179
+ # Accumulate token usage stats
180
+ self._last_stats.llm_calls += 1
181
+ usage = self.provider.get_usage()
182
+ if usage:
183
+ prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
184
+ completion_tokens = getattr(usage, "completion_tokens", 0) or 0
185
+ self._last_stats.tokens.prompt_tokens += prompt_tokens
186
+ self._last_stats.tokens.completion_tokens += completion_tokens
187
+
100
188
  # Add the assistant's message
101
189
  response_messages.append(ai_message)
102
190
  messages.append(ai_message)
@@ -108,14 +196,29 @@ class McpToolChat:
108
196
  logger.info(f"Executing {len(ai_message.tool_calls)} tool calls")
109
197
  result_count = 0
110
198
  for tool_call in ai_message.tool_calls:
199
+ # Track tool call stats
200
+ tool_name = tool_call.function.name
201
+ self._last_stats.tool_calls.by_tool[tool_name] = (
202
+ self._last_stats.tool_calls.by_tool.get(tool_name, 0) + 1
203
+ )
204
+ server_name = self._extract_server_from_tool_name(tool_name)
205
+ self._last_stats.tool_calls.by_server[server_name] = (
206
+ self._last_stats.tool_calls.by_server.get(server_name, 0) + 1
207
+ )
208
+
111
209
  try:
112
- result = await self.execute(tool_call)
210
+ result = await self.execute(tool_call, meta=meta)
113
211
  except Exception as e:
114
212
  logger.error(
115
213
  f"Failed to execute tool '{tool_call.function.name}' "
116
214
  f"(id={tool_call.id}): {e}"
117
215
  )
118
- continue
216
+ # Surface the failure to the LLM so it knows the tool failed
217
+ result = ToolResultMessage(
218
+ name=tool_call.function.name,
219
+ tool_call_id=tool_call.id,
220
+ content=f"Error executing tool: {e}",
221
+ )
119
222
  if result:
120
223
  messages.append(result)
121
224
  response_messages.append(result)
@@ -127,12 +230,27 @@ class McpToolChat:
127
230
 
128
231
  return response_messages
129
232
 
130
- async def execute(self, tool_call: AssistantToolCall) -> ToolResultMessage:
233
+ async def execute(
234
+ self,
235
+ tool_call: AssistantToolCall,
236
+ meta: MetaDict | None = None,
237
+ ) -> ToolResultMessage:
238
+ """
239
+ Execute a single tool call.
240
+
241
+ Args:
242
+ tool_call: The tool call to execute
243
+ meta: Optional metadata to pass to the MCP server.
244
+ Servers can access this via ctx.request_context.meta.
245
+
246
+ Returns:
247
+ ToolResultMessage with the tool execution result
248
+ """
131
249
  tool_name = tool_call.function.name
132
250
  tool_args = json.loads(tool_call.function.arguments)
133
251
  try:
134
252
  async with self.mcp_client:
135
- result = await self.mcp_client.call_tool(tool_name, tool_args)
253
+ result = await self.mcp_client.call_tool(tool_name, tool_args, meta=meta)
136
254
  except Exception as e:
137
255
  if isinstance(e, ValueError):
138
256
  logger.warning(e)
@@ -148,16 +266,35 @@ class McpToolChat:
148
266
  logger.debug(f"Tool Call Result: {result}")
149
267
 
150
268
  result_format = os.getenv("TOOL_RESULT_FORMAT", "result")
151
- # Extract text content from result (handle both TextContent and other content types)
152
- if not result.content:
269
+
270
+ # Prefer structuredContent when available (machine-readable format)
271
+ # Note: MCP types use camelCase (structuredContent), mypy stubs may differ
272
+ structured = getattr(result, "structuredContent", None)
273
+ if structured is not None:
274
+ try:
275
+ content_text = json.dumps(structured)
276
+ except (TypeError, ValueError):
277
+ content_text = str(structured)
278
+ elif not result.content:
153
279
  content_text = "[No content returned]"
154
280
  else:
155
- content_item = result.content[0]
156
- if hasattr(content_item, "text"):
157
- content_text = content_item.text
158
- else:
159
- # Handle non-text content (e.g., ImageContent)
160
- content_text = f"[Non-text content: {type(content_item).__name__}]"
281
+ # Fall back to processing content items
282
+ content_parts: list[Any] = []
283
+ for content_item in result.content:
284
+ if content_item.type == "text":
285
+ try:
286
+ parsed = json.loads(content_item.text)
287
+ content_parts.append(parsed)
288
+ except json.JSONDecodeError:
289
+ content_parts.append(content_item.text)
290
+ elif hasattr(content_item, "mimeType"):
291
+ # Image or audio content
292
+ content_parts.append(f"[{content_item.type}: {content_item.mimeType}]")
293
+ else:
294
+ content_parts.append(str(content_item))
295
+
296
+ content_text = json.dumps(content_parts)
297
+
161
298
  content = format_tool_call_result(tool_call, content_text, style=result_format)
162
299
 
163
300
  return ToolResultMessage(
@@ -7,6 +7,11 @@ from casual_llm import (
7
7
  UserMessage,
8
8
  )
9
9
 
10
+ from .chat_stats import (
11
+ ChatStats,
12
+ TokenUsageStats,
13
+ ToolCallStats,
14
+ )
10
15
  from .mcp_server_config import (
11
16
  McpServerConfig,
12
17
  RemoteServerConfig,
@@ -17,6 +22,11 @@ from .model_config import (
17
22
  OllamaModelConfig,
18
23
  OpenAIModelConfig,
19
24
  )
25
+ from .toolset_config import (
26
+ ExcludeSpec,
27
+ ToolSetConfig,
28
+ ToolSpec,
29
+ )
20
30
 
21
31
  __all__ = [
22
32
  "UserMessage",
@@ -25,10 +35,16 @@ __all__ = [
25
35
  "ToolResultMessage",
26
36
  "SystemMessage",
27
37
  "ChatMessage",
38
+ "ChatStats",
39
+ "TokenUsageStats",
40
+ "ToolCallStats",
28
41
  "McpModelConfig",
29
42
  "OllamaModelConfig",
30
43
  "OpenAIModelConfig",
31
44
  "McpServerConfig",
32
45
  "StdioServerConfig",
33
46
  "RemoteServerConfig",
47
+ "ExcludeSpec",
48
+ "ToolSetConfig",
49
+ "ToolSpec",
34
50
  ]
@@ -0,0 +1,37 @@
1
+ """Usage statistics models for chat sessions."""
2
+
3
+ from pydantic import BaseModel, Field, computed_field
4
+
5
+
6
+ class TokenUsageStats(BaseModel):
7
+ """Token usage statistics accumulated across all LLM calls."""
8
+
9
+ prompt_tokens: int = Field(default=0, ge=0)
10
+ completion_tokens: int = Field(default=0, ge=0)
11
+
12
+ @computed_field # type: ignore[prop-decorator]
13
+ @property
14
+ def total_tokens(self) -> int:
15
+ """Total tokens (prompt + completion)."""
16
+ return self.prompt_tokens + self.completion_tokens
17
+
18
+
19
class ToolCallStats(BaseModel):
    """Counts of tool calls made during a chat session."""

    # Invocation counts keyed by tool name and by originating server.
    by_tool: dict[str, int] = Field(default_factory=dict)
    by_server: dict[str, int] = Field(default_factory=dict)

    @computed_field  # type: ignore[prop-decorator]
    @property
    def total(self) -> int:
        """Overall number of tool calls across all tools."""
        return sum(count for count in self.by_tool.values())
30
+
31
+
32
class ChatStats(BaseModel):
    """Aggregate statistics collected over one chat session."""

    # Token usage and tool-call counters, each starting empty.
    tokens: TokenUsageStats = Field(default_factory=TokenUsageStats)
    tool_calls: ToolCallStats = Field(default_factory=ToolCallStats)
    llm_calls: int = Field(default=0, ge=0, description="Number of LLM calls made")
@@ -1,10 +1,12 @@
1
- from pydantic import BaseModel
1
+ from pydantic import BaseModel, Field
2
2
 
3
3
  from casual_mcp.models.mcp_server_config import McpServerConfig
4
4
  from casual_mcp.models.model_config import McpModelConfig
5
+ from casual_mcp.models.toolset_config import ToolSetConfig
5
6
 
6
7
 
7
8
class Config(BaseModel):
    """Top-level casual-mcp configuration.

    Holds the model and server registries plus optional named toolsets
    that restrict which tools a chat session may use.
    """

    namespace_tools: bool | None = False
    models: dict[str, McpModelConfig]
    servers: dict[str, McpServerConfig]
    tool_sets: dict[str, ToolSetConfig] = Field(default_factory=dict)
@@ -0,0 +1,40 @@
1
+ """Toolset configuration models for filtering available tools."""
2
+
3
+ from pydantic import BaseModel, Field
4
+
5
+
6
class ExcludeSpec(BaseModel):
    """Specification for excluding specific tools from a server."""

    exclude: list[str] = Field(description="List of tool names to exclude")


# Tool specification: true (all), list (include specific), or exclude object
ToolSpec = bool | list[str] | ExcludeSpec
14
+
15
+
16
class ToolSetConfig(BaseModel):
    """Configuration for a named toolset.

    A toolset restricts which tools from which servers are exposed during
    a chat session. The spec for each server is one of:
    - True: include every tool the server provides
    - list of tool names: include only those tools
    - ExcludeSpec: include everything except the listed tools

    Example:
        {
            "description": "Research tools",
            "servers": {
                "wikimedia": True,
                "search": ["brave_web_search"],
                "fetch": {"exclude": ["fetch_dangerous"]}
            }
        }
    """

    description: str = Field(default="", description="Human-readable description")
    servers: dict[str, ToolSpec] = Field(
        default_factory=dict,
        description="Mapping of server name to tool specification",
    )
@@ -0,0 +1,171 @@
1
+ """Tool filtering logic for toolsets.
2
+
3
+ This module provides functionality to filter MCP tools based on toolset
4
+ configurations, including validation to ensure referenced servers and
5
+ tools actually exist.
6
+ """
7
+
8
+ import mcp
9
+
10
+ from casual_mcp.logging import get_logger
11
+ from casual_mcp.models.toolset_config import ExcludeSpec, ToolSetConfig
12
+
13
+ logger = get_logger("tool_filter")
14
+
15
+
16
class ToolSetValidationError(Exception):
    """Raised when a toolset references invalid servers or tools."""
20
+
21
+
22
+ def extract_server_and_tool(tool_name: str, server_names: set[str]) -> tuple[str, str]:
23
+ """Extract server name and base tool name from a potentially prefixed tool name.
24
+
25
+ When multiple servers are configured, fastmcp prefixes tools as "serverName_toolName".
26
+ When a single server is configured, tools are not prefixed.
27
+
28
+ Args:
29
+ tool_name: The full tool name (possibly prefixed)
30
+ server_names: Set of configured server names
31
+
32
+ Returns:
33
+ Tuple of (server_name, base_tool_name)
34
+ """
35
+ if "_" in tool_name:
36
+ prefix = tool_name.split("_", 1)[0]
37
+ if prefix in server_names:
38
+ return prefix, tool_name.split("_", 1)[1]
39
+
40
+ # Single server case - return the single server name
41
+ if len(server_names) == 1:
42
+ return next(iter(server_names)), tool_name
43
+
44
+ # Fallback - can't determine server
45
+ return "default", tool_name
46
+
47
+
48
def _build_server_tool_map(tools: list[mcp.Tool], server_names: set[str]) -> dict[str, set[str]]:
    """Build a mapping of server names to their available tool names.

    Args:
        tools: List of MCP tools
        server_names: Set of configured server names

    Returns:
        Dict mapping server name to set of base tool names
    """
    mapping: dict[str, set[str]] = {name: set() for name in server_names}

    for tool in tools:
        owner, base_name = extract_server_and_tool(tool.name, server_names)
        # Tools whose owner cannot be resolved to a configured server
        # (e.g. "default") are simply left out of the map.
        if owner in mapping:
            mapping[owner].add(base_name)

    return mapping
66
+
67
+
68
def validate_toolset(
    toolset: ToolSetConfig,
    tools: list[mcp.Tool],
    server_names: set[str],
) -> None:
    """Validate that a toolset references only valid servers and tools.

    Every server named by the toolset must exist in the configuration, and
    every tool named in an include list or exclude list must exist on that
    server. All problems are collected and reported together.

    Args:
        toolset: The toolset configuration to validate
        tools: List of available MCP tools
        server_names: Set of configured server names

    Raises:
        ToolSetValidationError: If the toolset references non-existent servers or tools
    """
    available_by_server = _build_server_tool_map(tools, server_names)
    problems: list[str] = []

    for server_name, tool_spec in toolset.servers.items():
        # Unknown server: report it and skip tool-level checks.
        if server_name not in server_names:
            problems.append(f"Server '{server_name}' not found in configuration")
            continue

        available = available_by_server.get(server_name, set())

        if isinstance(tool_spec, list):
            # Include list: every named tool must exist on the server.
            problems.extend(
                f"Tool '{tool_name}' not found in server '{server_name}'. "
                f"Available: {sorted(available)}"
                for tool_name in tool_spec
                if tool_name not in available
            )
        elif isinstance(tool_spec, ExcludeSpec):
            # Exclude list: excluding a non-existent tool is also an error.
            problems.extend(
                f"Tool '{tool_name}' not found in server '{server_name}' "
                f"(specified in exclude list). Available: {sorted(available)}"
                for tool_name in tool_spec.exclude
                if tool_name not in available
            )

    if problems:
        raise ToolSetValidationError("\n".join(problems))
114
+
115
+
116
def filter_tools_by_toolset(
    tools: list[mcp.Tool],
    toolset: ToolSetConfig,
    server_names: set[str],
    validate: bool = True,
) -> list[mcp.Tool]:
    """Filter a list of MCP tools based on a toolset configuration.

    Args:
        tools: Full list of available MCP tools
        toolset: The toolset configuration to apply
        server_names: Set of configured server names
        validate: Whether to validate the toolset first (raises on invalid)

    Returns:
        Filtered list of tools matching the toolset

    Raises:
        ToolSetValidationError: If validate=True and toolset is invalid
    """
    if validate:
        validate_toolset(toolset, tools, server_names)

    def _spec_allows(spec: object, base_name: str) -> bool:
        # True means "all tools"; a list means "only these"; ExcludeSpec
        # means "everything except these". Anything else (e.g. False)
        # allows nothing.
        if spec is True:
            return True
        if isinstance(spec, list):
            return base_name in spec
        if isinstance(spec, ExcludeSpec):
            return base_name not in spec.exclude
        return False

    selected: list[mcp.Tool] = []
    for tool in tools:
        owner, base_name = extract_server_and_tool(tool.name, server_names)
        spec = toolset.servers.get(owner)
        # Servers absent from the toolset contribute no tools at all.
        if spec is not None and _spec_allows(spec, base_name):
            selected.append(tool)

    logger.debug(
        f"Filtered {len(tools)} tools to {len(selected)} using toolset "
        f"with {len(toolset.servers)} servers"
    )

    return selected