casual-mcp 0.4.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
casual_mcp/__init__.py CHANGED
@@ -1,13 +1,24 @@
1
+ from importlib.metadata import version
2
+
1
3
  from . import models
4
+ from .models.chat_stats import ChatStats, TokenUsageStats, ToolCallStats
5
+
6
# Resolve the installed distribution's version. importlib.metadata.version()
# raises PackageNotFoundError when the package is imported without being
# installed (e.g. straight from a source checkout); without this fallback that
# made `import casual_mcp` itself fail.
from importlib.metadata import PackageNotFoundError

try:
    __version__ = version("casual-mcp")
except PackageNotFoundError:
    __version__ = "0.0.0+unknown"
2
7
  from .mcp_tool_chat import McpToolChat
3
- from .providers.provider_factory import ProviderFactory
8
+ from .provider_factory import ProviderFactory
9
+ from .tool_cache import ToolCache
4
10
  from .utils import load_config, load_mcp_client, render_system_prompt
5
11
 
6
12
  __all__ = [
13
+ "__version__",
7
14
  "McpToolChat",
8
15
  "ProviderFactory",
16
+ "ToolCache",
9
17
  "load_config",
10
18
  "load_mcp_client",
11
19
  "render_system_prompt",
12
20
  "models",
21
+ "ChatStats",
22
+ "TokenUsageStats",
23
+ "ToolCallStats",
13
24
  ]
casual_mcp/cli.py CHANGED
@@ -1,6 +1,10 @@
1
1
  import asyncio
2
+ from typing import Any
3
+
4
+ import mcp
2
5
  import typer
3
6
  import uvicorn
7
+ from fastmcp import Client
4
8
  from rich.console import Console
5
9
  from rich.table import Table
6
10
 
@@ -10,63 +14,58 @@ from casual_mcp.utils import load_config, load_mcp_client
10
14
  app = typer.Typer()
11
15
  console = Console()
12
16
 
17
+
13
18
  @app.command()
14
- def serve(host: str = "0.0.0.0", port: int = 8000, reload: bool = True):
19
+ def serve(host: str = "0.0.0.0", port: int = 8000, reload: bool = True) -> None:
15
20
  """
16
21
  Start the Casual MCP API server.
17
22
  """
18
- uvicorn.run(
19
- "casual_mcp.main:app",
20
- host=host,
21
- port=port,
22
- reload=reload,
23
- app_dir="src"
24
- )
23
+ uvicorn.run("casual_mcp.main:app", host=host, port=port, reload=reload, app_dir="src")
24
+
25
25
 
26
26
  @app.command()
27
- def servers():
27
+ def servers() -> None:
28
28
  """
29
29
  Return a table of all configured servers
30
30
  """
31
- config = load_config('casual_mcp_config.json')
31
+ config = load_config("casual_mcp_config.json")
32
32
  table = Table("Name", "Type", "Command / Url", "Env")
33
33
 
34
34
  for name, server in config.servers.items():
35
- type = 'stdio'
35
+ server_type = "stdio"
36
36
  if isinstance(server, RemoteServerConfig):
37
- type = 'remote'
37
+ server_type = "remote"
38
38
 
39
- path = ''
39
+ path = ""
40
40
  if isinstance(server, RemoteServerConfig):
41
41
  path = server.url
42
42
  else:
43
43
  path = f"{server.command} {' '.join(server.args)}"
44
- env = ''
44
+ env = ""
45
45
 
46
- table.add_row(name, type, path, env)
46
+ table.add_row(name, server_type, path, env)
47
47
 
48
48
  console.print(table)
49
49
 
50
+
50
51
  @app.command()
51
- def models():
52
+ def models() -> None:
52
53
  """
53
54
  Return a table of all configured models
54
55
  """
55
- config = load_config('casual_mcp_config.json')
56
+ config = load_config("casual_mcp_config.json")
56
57
  table = Table("Name", "Provider", "Model", "Endpoint")
57
58
 
58
59
  for name, model in config.models.items():
59
- endpoint = ''
60
- if model.provider == 'openai':
61
- endpoint = model.endpoint or ''
62
-
60
+ endpoint = model.endpoint or ""
63
61
  table.add_row(name, model.provider, model.model, str(endpoint))
64
62
 
65
63
  console.print(table)
66
64
 
65
+
67
66
  @app.command()
68
- def tools():
69
- config = load_config('casual_mcp_config.json')
67
+ def tools() -> None:
68
+ config = load_config("casual_mcp_config.json")
70
69
  mcp_client = load_mcp_client(config)
71
70
  table = Table("Name", "Description")
72
71
  # async with mcp_client:
@@ -76,9 +75,10 @@ def tools():
76
75
  console.print(table)
77
76
 
78
77
 
79
async def get_tools(client: Client[Any]) -> list[mcp.Tool]:
    """Open ``client`` for a single ``list_tools`` request and return its result."""
    async with client:
        listed = await client.list_tools()
    return listed


if __name__ == "__main__":
    # Allow running this module directly as a script.
    app()
casual_mcp/convert_tools.py ADDED
@@ -0,0 +1,68 @@
1
+ import mcp
2
+ from casual_llm import Tool
3
+
4
+ from casual_mcp.logging import get_logger
5
+
6
logger = get_logger("convert_tools")


# MCP format converters (for interop with MCP libraries)
def tool_from_mcp(mcp_tool: mcp.Tool) -> Tool:
    """
    Convert an MCP Tool to casual-llm Tool format.

    The tool's ``inputSchema`` is forwarded as the casual-llm input schema.
    Only a dict schema is forwarded; a missing attribute or any non-dict
    value is replaced by an empty dict.

    Args:
        mcp_tool: MCP Tool object (from mcp library)

    Returns:
        casual-llm Tool instance

    Raises:
        ValueError: If the MCP tool has an empty or missing ``name`` or
            ``description``.

    Example (illustrative, not a doctest; needs a real MCP Tool)::

        tool = tool_from_mcp(mcp_tool)
        assert tool.name == mcp_tool.name
    """
    if not mcp_tool.name or not mcp_tool.description:
        raise ValueError(
            f"MCP tool missing required fields: "
            f"name={mcp_tool.name}, description={mcp_tool.description}"
        )

    # Fall back to an empty schema for tools that omit it or send a non-object.
    input_schema = getattr(mcp_tool, "inputSchema", {})
    if not isinstance(input_schema, dict):
        input_schema = {}

    return Tool.from_input_schema(
        name=mcp_tool.name, description=mcp_tool.description, input_schema=input_schema
    )
42
+
43
+
44
def tools_from_mcp(mcp_tools: list[mcp.Tool]) -> list[Tool]:
    """
    Convert multiple MCP Tools to casual-llm format.

    Tools rejected by ``tool_from_mcp`` (``ValueError``) are skipped with a
    warning rather than aborting the whole conversion. The result can
    therefore be shorter than the input.

    Args:
        mcp_tools: List of MCP Tool objects

    Returns:
        List of casual-llm Tool instances, in input order, without the
        skipped ones.

    Example (illustrative, not a doctest)::

        tools = tools_from_mcp(mcp_tool_list)
        assert len(tools) <= len(mcp_tool_list)
    """
    tools: list[Tool] = []

    for mcp_tool in mcp_tools:
        try:
            tool = tool_from_mcp(mcp_tool)
        except ValueError as e:
            logger.warning(f"Skipping invalid MCP tool: {e}")
            continue
        tools.append(tool)

    return tools
casual_mcp/logging.py CHANGED
@@ -1,5 +1,4 @@
1
1
  import logging
2
- from typing import Literal
3
2
 
4
3
  from rich.console import Console
5
4
  from rich.logging import RichHandler
@@ -10,7 +9,7 @@ def get_logger(name: str) -> logging.Logger:
10
9
 
11
10
 
12
11
  def configure_logging(
13
- level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] | int = "INFO",
12
+ level: str | int = "INFO",
14
13
  logger: logging.Logger | None = None,
15
14
  ) -> None:
16
15
  if logger is None:
@@ -27,4 +26,9 @@ def configure_logging(
27
26
  logger.removeHandler(hdlr)
28
27
 
29
28
  logger.addHandler(handler)
29
+
30
+ # Set logging level on FastMCP and MCP libraries
31
+ logging.getLogger("fastmcp").setLevel(level)
32
+ logging.getLogger("mcp").setLevel(level)
33
+
30
34
  logger.info("Logging Configured")
casual_mcp/main.py CHANGED
@@ -1,26 +1,28 @@
1
1
  import os
2
- import sys
3
- from pathlib import Path
2
+ from typing import Any
4
3
 
4
+ from casual_llm import ChatMessage
5
5
  from dotenv import load_dotenv
6
6
  from fastapi import FastAPI, HTTPException
7
7
  from pydantic import BaseModel, Field
8
8
 
9
9
  from casual_mcp import McpToolChat
10
10
  from casual_mcp.logging import configure_logging, get_logger
11
- from casual_mcp.models.messages import ChatMessage
12
- from casual_mcp.providers.provider_factory import ProviderFactory
11
+ from casual_mcp.provider_factory import ProviderFactory
12
+ from casual_mcp.tool_cache import ToolCache
13
13
  from casual_mcp.utils import load_config, load_mcp_client, render_system_prompt
14
14
 
15
# Load environment variables (e.g. from a .env file) first: the os.getenv
# call below depends on them, so this must stay at the top.
load_dotenv()

# Configure logging; the level comes from LOG_LEVEL and defaults to INFO.
configure_logging(os.getenv("LOG_LEVEL", "INFO"))
logger = get_logger("main")

# Module-level objects shared by all requests (get_chat uses each of them):
# the parsed config, one MCP client, a ToolCache wrapping that client, and
# the provider factory.
config = load_config("casual_mcp_config.json")
mcp_client = load_mcp_client(config)
tool_cache = ToolCache(mcp_client)
provider_factory = ProviderFactory()

app = FastAPI()
26
28
 
@@ -39,63 +41,62 @@ Always present information as current and factual.
39
41
 
40
42
 
41
43
  class GenerateRequest(BaseModel):
42
- session_id: str | None = Field(
43
- default=None, title="Session to use"
44
- )
45
- model: str = Field(
46
- title="Model to user"
47
- )
48
- system_prompt: str | None = Field(
49
- default=None, title="System Prompt to use"
50
- )
51
- prompt: str = Field(
52
- title="User Prompt"
53
- )
44
+ session_id: str | None = Field(default=None, title="Session to use")
45
+ model: str = Field(title="Model to use")
46
+ system_prompt: str | None = Field(default=None, title="System Prompt to use")
47
+ prompt: str = Field(title="User Prompt")
48
+ include_stats: bool = Field(default=False, title="Include usage statistics in response")
54
49
 
55
50
 
56
51
  class ChatRequest(BaseModel):
57
- model: str = Field(
58
- title="Model to user"
59
- )
60
- system_prompt: str | None = Field(
61
- default=None, title="System Prompt to use"
62
- )
63
- messages: list[ChatMessage] = Field(
64
- title="Previous messages to supply to the LLM"
65
- )
66
-
67
- sys.path.append(str(Path(__file__).parent.resolve()))
68
-
69
-
52
+ model: str = Field(title="Model to use")
53
+ system_prompt: str | None = Field(default=None, title="System Prompt to use")
54
+ messages: list[ChatMessage] = Field(title="Previous messages to supply to the LLM")
55
+ include_stats: bool = Field(default=False, title="Include usage statistics in response")
70
56
 
71
57
 
72
58
  @app.post("/chat")
73
- async def chat(req: ChatRequest):
74
- chat = await get_chat(req.model, req.system_prompt)
75
- messages = await chat.chat(req.messages)
59
+ async def chat(req: ChatRequest) -> dict[str, Any]:
60
+ chat_instance = await get_chat(req.model, req.system_prompt)
61
+ messages = await chat_instance.chat(req.messages)
76
62
 
77
- return {
78
- "messages": messages,
79
- "response": messages[len(messages) - 1].content
80
- }
63
+ if not messages:
64
+ error_result: dict[str, Any] = {"messages": [], "response": ""}
65
+ if req.include_stats:
66
+ error_result["stats"] = chat_instance.get_stats()
67
+ raise HTTPException(
68
+ status_code=500,
69
+ detail={"error": "No response generated", **error_result},
70
+ )
71
+
72
+ result: dict[str, Any] = {"messages": messages, "response": messages[-1].content}
73
+ if req.include_stats:
74
+ result["stats"] = chat_instance.get_stats()
75
+ return result
81
76
 
82
77
 
83
78
  @app.post("/generate")
84
- async def generate(req: GenerateRequest):
85
- chat = await get_chat(req.model, req.system_prompt)
86
- messages = await chat.generate(
87
- req.prompt,
88
- req.session_id
89
- )
79
+ async def generate(req: GenerateRequest) -> dict[str, Any]:
80
+ chat_instance = await get_chat(req.model, req.system_prompt)
81
+ messages = await chat_instance.generate(req.prompt, req.session_id)
82
+
83
+ if not messages:
84
+ error_result: dict[str, Any] = {"messages": [], "response": ""}
85
+ if req.include_stats:
86
+ error_result["stats"] = chat_instance.get_stats()
87
+ raise HTTPException(
88
+ status_code=500,
89
+ detail={"error": "No response generated", **error_result},
90
+ )
90
91
 
91
- return {
92
- "messages": messages,
93
- "response": messages[len(messages) - 1].content
94
- }
92
+ result: dict[str, Any] = {"messages": messages, "response": messages[-1].content}
93
+ if req.include_stats:
94
+ result["stats"] = chat_instance.get_stats()
95
+ return result
95
96
 
96
97
 
97
98
  @app.get("/generate/session/{session_id}")
98
- async def get_generate_session(session_id):
99
+ async def get_generate_session(session_id: str) -> list[ChatMessage]:
99
100
  session = McpToolChat.get_session(session_id)
100
101
  if not session:
101
102
  raise HTTPException(status_code=404, detail="Session not found")
@@ -106,17 +107,14 @@ async def get_generate_session(session_id):
106
107
  async def get_chat(model: str, system: str | None = None) -> McpToolChat:
107
108
  # Get Provider from Model Config
108
109
  model_config = config.models[model]
109
- provider = await provider_factory.get_provider(model, model_config)
110
+ provider = provider_factory.get_provider(model, model_config)
110
111
 
111
112
  # Get the system prompt
112
113
  if not system:
113
- if (model_config.template):
114
- async with mcp_client:
115
- system = render_system_prompt(
116
- f"{model_config.template}.j2",
117
- await mcp_client.list_tools()
118
- )
114
+ if model_config.template:
115
+ tools = await tool_cache.get_tools()
116
+ system = render_system_prompt(f"{model_config.template}.j2", tools)
119
117
  else:
120
118
  system = default_system_prompt
121
119
 
122
- return McpToolChat(mcp_client, provider, system)
120
+ return McpToolChat(mcp_client, provider, system, tool_cache=tool_cache)
casual_mcp/mcp_tool_chat.py CHANGED
@@ -1,62 +1,92 @@
1
1
  import json
2
2
  import os
3
+ from typing import Any
3
4
 
4
- from fastmcp import Client
5
-
6
- from casual_mcp.logging import get_logger
7
- from casual_mcp.models.messages import (
5
+ from casual_llm import (
6
+ AssistantToolCall,
8
7
  ChatMessage,
8
+ LLMProvider,
9
9
  SystemMessage,
10
10
  ToolResultMessage,
11
11
  UserMessage,
12
12
  )
13
- from casual_mcp.models.tool_call import AssistantToolCall
14
- from casual_mcp.providers.provider_factory import LLMProvider
13
+ from fastmcp import Client
14
+
15
+ from casual_mcp.convert_tools import tools_from_mcp
16
+ from casual_mcp.logging import get_logger
17
+ from casual_mcp.models.chat_stats import ChatStats
18
+ from casual_mcp.tool_cache import ToolCache
15
19
  from casual_mcp.utils import format_tool_call_result
16
20
 
17
21
  logger = get_logger("mcp_tool_chat")
18
22
  sessions: dict[str, list[ChatMessage]] = {}
19
23
 
20
24
 
21
- def get_session_messages(session_id: str | None):
25
+ def get_session_messages(session_id: str) -> list[ChatMessage]:
22
26
  global sessions
23
27
 
24
- if not sessions.get(session_id):
28
+ if session_id not in sessions:
25
29
  logger.info(f"Starting new session {session_id}")
26
30
  sessions[session_id] = []
27
31
  else:
28
- logger.info(
29
- f"Retrieving session {session_id} of length {len(sessions[session_id])}"
30
- )
32
+ logger.info(f"Retrieving session {session_id} of length {len(sessions[session_id])}")
31
33
  return sessions[session_id].copy()
32
34
 
33
35
 
34
def add_messages_to_session(session_id: str, messages: list[ChatMessage]) -> None:
    """
    Append ``messages`` to the stored history of ``session_id``.

    The session is created if it does not exist yet. Previously this raised
    KeyError unless get_session_messages() had been called first.
    """
    # list.extend copies the element references into the stored list, so the
    # caller's list is never aliased into the store (no extra .copy() needed).
    sessions.setdefault(session_id, []).extend(messages)
37
39
 
38
40
 
39
41
  class McpToolChat:
40
def __init__(
    self,
    mcp_client: Client[Any],
    provider: LLMProvider,
    system: str | None = None,
    tool_cache: ToolCache | None = None,
):
    """
    Args:
        mcp_client: fastmcp Client for the MCP servers (also what the default
            ToolCache wraps).
        provider: casual-llm provider used for every LLM call in ``chat()``.
        system: Optional system prompt that ``chat()`` inserts when the
            conversation has no system message.
        tool_cache: Shared ToolCache. A private one wrapping ``mcp_client``
            is created when omitted (None).
    """
    self.provider = provider
    self.mcp_client = mcp_client
    self.system = system
    # Explicit None check: a supplied cache must not be replaced just because
    # it might evaluate as falsy.
    self.tool_cache = tool_cache if tool_cache is not None else ToolCache(mcp_client)
    # NOTE(review): not read in the visible code — confirm it is still needed.
    self._tool_cache_version = -1
    self._last_stats: ChatStats | None = None

@staticmethod
def get_session(session_id: str) -> list[ChatMessage] | None:
    """Return the stored history for ``session_id`` (the list itself, not a copy), or None."""
    # Read-only access to the module-level store; no `global` declaration needed.
    return sessions.get(session_id)
49
60
 
50
- async def generate(
51
- self,
52
- prompt: str,
53
- session_id: str | None = None
54
- ) -> list[ChatMessage]:
61
def get_stats(self) -> ChatStats | None:
    """
    Return the usage statistics collected by the most recent chat()/generate().

    None until the first call. Each new call starts from fresh statistics.
    """
    return self._last_stats

def _extract_server_from_tool_name(self, tool_name: str) -> str:
    """
    Guess which server a tool belongs to from its name.

    fastmcp prefixes tool names as "serverName_toolName" when several servers
    are configured, and leaves them unprefixed for a single server. The text
    before the first "_" is returned, or "default" when there is no "_".

    NOTE: the name alone cannot distinguish the two cases. An unprefixed
    single-server tool such as "get_weather" is attributed to "get", and a
    server name that itself contains "_" is truncated.
    """
    prefix, separator, _ = tool_name.partition("_")
    return prefix if separator else "default"
82
+
83
+ async def generate(self, prompt: str, session_id: str | None = None) -> list[ChatMessage]:
55
84
  # Fetch the session if we have a session ID
85
+ messages: list[ChatMessage]
56
86
  if session_id:
57
87
  messages = get_session_messages(session_id)
58
88
  else:
59
- messages: list[ChatMessage] = []
89
+ messages = []
60
90
 
61
91
  # Add the prompt as a user message
62
92
  user_message = UserMessage(content=prompt)
@@ -75,56 +105,75 @@ class McpToolChat:
75
105
 
76
106
  return response
77
107
 
108
async def chat(self, messages: list[ChatMessage]) -> list[ChatMessage]:
    """
    Run the LLM / tool-call loop until the model replies without tool calls.

    ``messages`` is modified in place. A system message is inserted at the
    front when ``self.system`` is set and none is present, and every
    assistant message and tool result produced is appended.

    Returns:
        Only the messages produced by this call (assistant messages and tool
        results), in order. The last one is the final assistant reply.
    """
    # Reset stats first, so a failure below never leaves get_stats()
    # reporting the previous call's numbers.
    stats = ChatStats()
    self._last_stats = stats

    # The tool list is fixed for the whole loop: convert it once instead of
    # on every LLM round-trip (which also repeated the skip warnings).
    tools = tools_from_mcp(await self.tool_cache.get_tools())

    # Add a system message if required
    has_system_message = any(message.role == "system" for message in messages)
    if self.system and not has_system_message:
        # Insert the system message at the start of the messages
        logger.debug("Adding System Message")
        messages.insert(0, SystemMessage(content=self.system))

    logger.info("Start Chat")
    response_messages: list[ChatMessage] = []
    while True:
        logger.info("Calling the LLM")
        ai_message = await self.provider.chat(messages=messages, tools=tools)
        self._record_llm_usage(stats)

        # Add the assistant's message
        response_messages.append(ai_message)
        messages.append(ai_message)

        logger.debug(f"Assistant: {ai_message}")
        if not ai_message.tool_calls:
            break

        logger.info(f"Executing {len(ai_message.tool_calls)} tool calls")
        result_count = 0
        for tool_call in ai_message.tool_calls:
            self._record_tool_call(stats, tool_call.function.name)
            try:
                result = await self.execute(tool_call)
            except Exception as e:
                logger.error(
                    f"Failed to execute tool '{tool_call.function.name}' "
                    f"(id={tool_call.id}): {e}"
                )
                # Every tool call in an assistant message must be answered by
                # a tool result; OpenAI-style APIs reject the next request
                # otherwise. Report the failure to the model instead of
                # silently dropping it.
                result = ToolResultMessage(
                    name=tool_call.function.name,
                    tool_call_id=tool_call.id,
                    content=f"Error executing tool: {e}",
                )
            if result:
                messages.append(result)
                response_messages.append(result)
                result_count += 1

        logger.info(f"Added {result_count} tool results")

    logger.debug(f"Final Response: {response_messages[-1].content}")

    return response_messages

def _record_llm_usage(self, stats: ChatStats) -> None:
    """Count one LLM call and add the provider's last reported token usage, if any."""
    stats.llm_calls += 1
    usage = self.provider.get_usage()
    if usage:
        # Either counter may be missing or None; count those as 0.
        stats.tokens.prompt_tokens += getattr(usage, "prompt_tokens", 0) or 0
        stats.tokens.completion_tokens += getattr(usage, "completion_tokens", 0) or 0

def _record_tool_call(self, stats: ChatStats, tool_name: str) -> None:
    """Count one call of ``tool_name`` in the per-tool and per-server tallies."""
    by_tool = stats.tool_calls.by_tool
    by_tool[tool_name] = by_tool.get(tool_name, 0) + 1
    server_name = self._extract_server_from_tool_name(tool_name)
    by_server = stats.tool_calls.by_server
    by_server[server_name] = by_server.get(server_name, 0) + 1
125
175
 
126
-
127
- async def execute(self, tool_call: AssistantToolCall):
176
+ async def execute(self, tool_call: AssistantToolCall) -> ToolResultMessage:
128
177
  tool_name = tool_call.function.name
129
178
  tool_args = json.loads(tool_call.function.arguments)
130
179
  try:
@@ -144,8 +193,37 @@ class McpToolChat:
144
193
 
145
194
  logger.debug(f"Tool Call Result: {result}")
146
195
 
147
- result_format = os.getenv('TOOL_RESULT_FORMAT', 'result')
148
- content = format_tool_call_result(tool_call, result.content[0].text, style=result_format)
196
+ result_format = os.getenv("TOOL_RESULT_FORMAT", "result")
197
+
198
+ # Prefer structuredContent when available (machine-readable format)
199
+ # Note: MCP types use camelCase (structuredContent), mypy stubs may differ
200
+ structured = getattr(result, "structuredContent", None)
201
+ if structured is not None:
202
+ try:
203
+ content_text = json.dumps(structured)
204
+ except (TypeError, ValueError):
205
+ content_text = str(structured)
206
+ elif not result.content:
207
+ content_text = "[No content returned]"
208
+ else:
209
+ # Fall back to processing content items
210
+ content_parts: list[Any] = []
211
+ for content_item in result.content:
212
+ if content_item.type == "text":
213
+ try:
214
+ parsed = json.loads(content_item.text)
215
+ content_parts.append(parsed)
216
+ except json.JSONDecodeError:
217
+ content_parts.append(content_item.text)
218
+ elif hasattr(content_item, "mimeType"):
219
+ # Image or audio content
220
+ content_parts.append(f"[{content_item.type}: {content_item.mimeType}]")
221
+ else:
222
+ content_parts.append(str(content_item))
223
+
224
+ content_text = json.dumps(content_parts)
225
+
226
+ content = format_tool_call_result(tool_call, content_text, style=result_format)
149
227
 
150
228
  return ToolResultMessage(
151
229
  name=tool_call.function.name,