casual-mcp 0.4.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- casual_mcp/__init__.py +12 -1
- casual_mcp/cli.py +24 -24
- casual_mcp/convert_tools.py +68 -0
- casual_mcp/logging.py +6 -2
- casual_mcp/main.py +55 -57
- casual_mcp/mcp_tool_chat.py +127 -49
- casual_mcp/models/__init__.py +21 -8
- casual_mcp/models/chat_stats.py +37 -0
- casual_mcp/models/config.py +2 -2
- casual_mcp/models/generation_error.py +1 -1
- casual_mcp/models/model_config.py +3 -3
- casual_mcp/provider_factory.py +47 -0
- casual_mcp/tool_cache.py +114 -0
- casual_mcp/utils.py +18 -11
- casual_mcp-0.6.0.dist-info/METADATA +691 -0
- casual_mcp-0.6.0.dist-info/RECORD +21 -0
- {casual_mcp-0.4.0.dist-info → casual_mcp-0.6.0.dist-info}/WHEEL +1 -1
- casual_mcp/models/messages.py +0 -31
- casual_mcp/models/tool_call.py +0 -14
- casual_mcp/providers/__init__.py +0 -0
- casual_mcp/providers/abstract_provider.py +0 -15
- casual_mcp/providers/ollama_provider.py +0 -72
- casual_mcp/providers/openai_provider.py +0 -179
- casual_mcp/providers/provider_factory.py +0 -56
- casual_mcp-0.4.0.dist-info/METADATA +0 -399
- casual_mcp-0.4.0.dist-info/RECORD +0 -24
- {casual_mcp-0.4.0.dist-info → casual_mcp-0.6.0.dist-info}/entry_points.txt +0 -0
- {casual_mcp-0.4.0.dist-info → casual_mcp-0.6.0.dist-info}/licenses/LICENSE +0 -0
- {casual_mcp-0.4.0.dist-info → casual_mcp-0.6.0.dist-info}/top_level.txt +0 -0
casual_mcp/__init__.py
CHANGED
@@ -1,13 +1,24 @@
+from importlib.metadata import version
+
 from . import models
+from .models.chat_stats import ChatStats, TokenUsageStats, ToolCallStats
+
+__version__ = version("casual-mcp")
 from .mcp_tool_chat import McpToolChat
-from .
+from .provider_factory import ProviderFactory
+from .tool_cache import ToolCache
 from .utils import load_config, load_mcp_client, render_system_prompt
 
 __all__ = [
+    "__version__",
     "McpToolChat",
     "ProviderFactory",
+    "ToolCache",
     "load_config",
     "load_mcp_client",
     "render_system_prompt",
     "models",
+    "ChatStats",
+    "TokenUsageStats",
+    "ToolCallStats",
 ]
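Taken together, the exports above define the 0.6.0 top-level API. A quick sketch of the resulting import surface, with names taken directly from the `__all__` list:

from casual_mcp import (
    ChatStats,
    McpToolChat,
    ProviderFactory,
    TokenUsageStats,
    ToolCache,
    ToolCallStats,
    __version__,
    load_config,
    load_mcp_client,
    render_system_prompt,
)

print(__version__)  # resolved from package metadata via importlib.metadata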
casual_mcp/cli.py
CHANGED
@@ -1,6 +1,10 @@
 import asyncio
+from typing import Any
+
+import mcp
 import typer
 import uvicorn
+from fastmcp import Client
 from rich.console import Console
 from rich.table import Table
 
@@ -10,63 +14,58 @@ from casual_mcp.utils import load_config, load_mcp_client
 app = typer.Typer()
 console = Console()
 
+
 @app.command()
-def serve(host: str = "0.0.0.0", port: int = 8000, reload: bool = True):
+def serve(host: str = "0.0.0.0", port: int = 8000, reload: bool = True) -> None:
     """
     Start the Casual MCP API server.
     """
-    uvicorn.run(
-
-        host=host,
-        port=port,
-        reload=reload,
-        app_dir="src"
-    )
+    uvicorn.run("casual_mcp.main:app", host=host, port=port, reload=reload, app_dir="src")
+
 
 @app.command()
-def servers():
+def servers() -> None:
     """
     Return a table of all configured servers
     """
-    config = load_config(
+    config = load_config("casual_mcp_config.json")
     table = Table("Name", "Type", "Command / Url", "Env")
 
     for name, server in config.servers.items():
-
+        server_type = "stdio"
         if isinstance(server, RemoteServerConfig):
-
+            server_type = "remote"
 
-        path =
+        path = ""
        if isinstance(server, RemoteServerConfig):
             path = server.url
         else:
             path = f"{server.command} {' '.join(server.args)}"
-        env =
+        env = ""
 
-        table.add_row(name,
+        table.add_row(name, server_type, path, env)
 
     console.print(table)
 
+
 @app.command()
-def models():
+def models() -> None:
     """
     Return a table of all configured models
     """
-    config = load_config(
+    config = load_config("casual_mcp_config.json")
     table = Table("Name", "Provider", "Model", "Endpoint")
 
     for name, model in config.models.items():
-        endpoint =
-        if model.provider == 'openai':
-            endpoint = model.endpoint or ''
-
+        endpoint = model.endpoint or ""
         table.add_row(name, model.provider, model.model, str(endpoint))
 
     console.print(table)
 
+
 @app.command()
-def tools():
-    config = load_config(
+def tools() -> None:
+    config = load_config("casual_mcp_config.json")
     mcp_client = load_mcp_client(config)
     table = Table("Name", "Description")
     # async with mcp_client:
@@ -76,9 +75,10 @@ def tools():
     console.print(table)
 
 
-async def get_tools(client):
+async def get_tools(client: Client[Any]) -> list[mcp.Tool]:
     async with client:
         return await client.list_tools()
 
+
 if __name__ == "__main__":
     app()
casual_mcp/convert_tools.py
ADDED
@@ -0,0 +1,68 @@
+import mcp
+from casual_llm import Tool
+
+from casual_mcp.logging import get_logger
+
+logger = get_logger("convert_tools")
+
+
+# MCP format converters (for interop with MCP libraries)
+def tool_from_mcp(mcp_tool: mcp.Tool) -> Tool:
+    """
+    Convert an MCP Tool to casual-llm Tool format.
+
+    Args:
+        mcp_tool: MCP Tool object (from mcp library)
+
+    Returns:
+        casual-llm Tool instance
+
+    Raises:
+        ValueError: If MCP tool is missing required fields
+
+    Examples:
+        >>> # Assuming mcp_tool is an MCP Tool object
+        >>> # tool = tool_from_mcp(mcp_tool)
+        >>> # assert tool.name == mcp_tool.name
+        pass
+    """
+    if not mcp_tool.name or not mcp_tool.description:
+        raise ValueError(
+            f"MCP tool missing required fields: "
+            f"name={mcp_tool.name}, description={mcp_tool.description}"
+        )
+
+    input_schema = getattr(mcp_tool, "inputSchema", {})
+    if not isinstance(input_schema, dict):
+        input_schema = {}
+
+    return Tool.from_input_schema(
+        name=mcp_tool.name, description=mcp_tool.description, input_schema=input_schema
+    )
+
+
+def tools_from_mcp(mcp_tools: list[mcp.Tool]) -> list[Tool]:
+    """
+    Convert multiple MCP Tools to casual-llm format.
+
+    Args:
+        mcp_tools: List of MCP Tool objects
+
+    Returns:
+        List of casual-llm Tool instances
+
+    Examples:
+        >>> # tools = tools_from_mcp(mcp_tool_list)
+        >>> # assert len(tools) == len(mcp_tool_list)
+        pass
+    """
+    tools = []
+
+    for mcp_tool in mcp_tools:
+        try:
+            tool = tool_from_mcp(mcp_tool)
+            tools.append(tool)
+        except ValueError as e:
+            logger.warning(f"Skipping invalid MCP tool: {e}")
+
+    return tools
casual_mcp/logging.py
CHANGED
@@ -1,5 +1,4 @@
 import logging
-from typing import Literal
 
 from rich.console import Console
 from rich.logging import RichHandler
@@ -10,7 +9,7 @@ def get_logger(name: str) -> logging.Logger:
 
 
 def configure_logging(
-    level:
+    level: str | int = "INFO",
     logger: logging.Logger | None = None,
 ) -> None:
     if logger is None:
@@ -27,4 +26,9 @@ def configure_logging(
         logger.removeHandler(hdlr)
 
     logger.addHandler(handler)
+
+    # Set logging level on FastMCP and MCP libraries
+    logging.getLogger("fastmcp").setLevel(level)
+    logging.getLogger("mcp").setLevel(level)
+
     logger.info("Logging Configured")
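The net effect of the added lines: a single `configure_logging` call now also applies the level to the third-party `fastmcp` and `mcp` loggers, so their output follows the same `LOG_LEVEL`. A small sketch (the logger name `demo` is an arbitrary example):

import logging

from casual_mcp.logging import configure_logging, get_logger

configure_logging("DEBUG")
assert logging.getLogger("fastmcp").level == logging.DEBUG  # set by the new lines
assert logging.getLogger("mcp").level == logging.DEBUG

logger = get_logger("demo")
logger.info("hello")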
casual_mcp/main.py
CHANGED
@@ -1,26 +1,28 @@
 import os
-import
-from pathlib import Path
+from typing import Any
 
+from casual_llm import ChatMessage
 from dotenv import load_dotenv
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel, Field
 
 from casual_mcp import McpToolChat
 from casual_mcp.logging import configure_logging, get_logger
-from casual_mcp.
-from casual_mcp.
+from casual_mcp.provider_factory import ProviderFactory
+from casual_mcp.tool_cache import ToolCache
 from casual_mcp.utils import load_config, load_mcp_client, render_system_prompt
 
+# Load environment variables
 load_dotenv()
 
 # Configure logging
-configure_logging(os.getenv("LOG_LEVEL",
+configure_logging(os.getenv("LOG_LEVEL", "INFO"))
 logger = get_logger("main")
 
 config = load_config("casual_mcp_config.json")
 mcp_client = load_mcp_client(config)
-
+tool_cache = ToolCache(mcp_client)
+provider_factory = ProviderFactory()
 
 app = FastAPI()
 
@@ -39,63 +41,62 @@ Always present information as current and factual.
 
 
 class GenerateRequest(BaseModel):
-    session_id: str | None = Field(
-
-    )
-
-
-    )
-    system_prompt: str | None = Field(
-        default=None, title="System Prompt to use"
-    )
-    prompt: str = Field(
-        title="User Prompt"
-    )
+    session_id: str | None = Field(default=None, title="Session to use")
+    model: str = Field(title="Model to use")
+    system_prompt: str | None = Field(default=None, title="System Prompt to use")
+    prompt: str = Field(title="User Prompt")
+    include_stats: bool = Field(default=False, title="Include usage statistics in response")
 
 
 class ChatRequest(BaseModel):
-    model: str = Field(
-
-    )
-
-        default=None, title="System Prompt to use"
-    )
-    messages: list[ChatMessage] = Field(
-        title="Previous messages to supply to the LLM"
-    )
-
-sys.path.append(str(Path(__file__).parent.resolve()))
-
-
+    model: str = Field(title="Model to use")
+    system_prompt: str | None = Field(default=None, title="System Prompt to use")
+    messages: list[ChatMessage] = Field(title="Previous messages to supply to the LLM")
+    include_stats: bool = Field(default=False, title="Include usage statistics in response")
 
 
 @app.post("/chat")
-async def chat(req: ChatRequest):
-
-    messages = await
+async def chat(req: ChatRequest) -> dict[str, Any]:
+    chat_instance = await get_chat(req.model, req.system_prompt)
+    messages = await chat_instance.chat(req.messages)
 
-
-        "messages":
-
-
+    if not messages:
+        error_result: dict[str, Any] = {"messages": [], "response": ""}
+        if req.include_stats:
+            error_result["stats"] = chat_instance.get_stats()
+        raise HTTPException(
+            status_code=500,
+            detail={"error": "No response generated", **error_result},
+        )
+
+    result: dict[str, Any] = {"messages": messages, "response": messages[-1].content}
+    if req.include_stats:
+        result["stats"] = chat_instance.get_stats()
+    return result
 
 
 @app.post("/generate")
-async def generate(req: GenerateRequest):
-
-    messages = await
-
-
-
+async def generate(req: GenerateRequest) -> dict[str, Any]:
+    chat_instance = await get_chat(req.model, req.system_prompt)
+    messages = await chat_instance.generate(req.prompt, req.session_id)
+
+    if not messages:
+        error_result: dict[str, Any] = {"messages": [], "response": ""}
+        if req.include_stats:
+            error_result["stats"] = chat_instance.get_stats()
+        raise HTTPException(
+            status_code=500,
+            detail={"error": "No response generated", **error_result},
+        )
 
-
-
-    "
-
+    result: dict[str, Any] = {"messages": messages, "response": messages[-1].content}
+    if req.include_stats:
+        result["stats"] = chat_instance.get_stats()
+    return result
 
 
 @app.get("/generate/session/{session_id}")
-async def get_generate_session(session_id):
+async def get_generate_session(session_id: str) -> list[ChatMessage]:
     session = McpToolChat.get_session(session_id)
     if not session:
         raise HTTPException(status_code=404, detail="Session not found")
@@ -106,17 +107,14 @@ async def get_generate_session(session_id):
 async def get_chat(model: str, system: str | None = None) -> McpToolChat:
     # Get Provider from Model Config
     model_config = config.models[model]
-    provider =
+    provider = provider_factory.get_provider(model, model_config)
 
     # Get the system prompt
     if not system:
-        if
-
-
-            f"{model_config.template}.j2",
-            await mcp_client.list_tools()
-        )
+        if model_config.template:
+            tools = await tool_cache.get_tools()
+            system = render_system_prompt(f"{model_config.template}.j2", tools)
     else:
         system = default_system_prompt
 
-    return McpToolChat(mcp_client, provider, system)
+    return McpToolChat(mcp_client, provider, system, tool_cache=tool_cache)
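With `include_stats` on both request models, clients can opt into the usage accounting added in 0.6.0. A hypothetical call against a locally running server; `httpx` and the `my-model` entry are illustrative, not requirements of this package:

import httpx

resp = httpx.post(
    "http://localhost:8000/generate",
    json={
        "model": "my-model",  # must name an entry in casual_mcp_config.json
        "prompt": "What is the weather in Paris?",
        "session_id": "demo-session",
        "include_stats": True,
    },
    timeout=60.0,
)
body = resp.json()
print(body["response"])
print(body.get("stats"))  # only present when include_stats is true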
casual_mcp/mcp_tool_chat.py
CHANGED
@@ -1,62 +1,92 @@
 import json
 import os
+from typing import Any
 
-from
-
-from casual_mcp.logging import get_logger
-from casual_mcp.models.messages import (
+from casual_llm import (
+    AssistantToolCall,
     ChatMessage,
+    LLMProvider,
     SystemMessage,
     ToolResultMessage,
     UserMessage,
 )
-from
-
+from fastmcp import Client
+
+from casual_mcp.convert_tools import tools_from_mcp
+from casual_mcp.logging import get_logger
+from casual_mcp.models.chat_stats import ChatStats
+from casual_mcp.tool_cache import ToolCache
 from casual_mcp.utils import format_tool_call_result
 
 logger = get_logger("mcp_tool_chat")
 sessions: dict[str, list[ChatMessage]] = {}
 
 
-def get_session_messages(session_id: str
+def get_session_messages(session_id: str) -> list[ChatMessage]:
     global sessions
 
-    if not sessions
+    if session_id not in sessions:
         logger.info(f"Starting new session {session_id}")
         sessions[session_id] = []
     else:
-        logger.info(
-            f"Retrieving session {session_id} of length {len(sessions[session_id])}"
-        )
+        logger.info(f"Retrieving session {session_id} of length {len(sessions[session_id])}")
     return sessions[session_id].copy()
 
 
-def add_messages_to_session(session_id: str, messages: list[ChatMessage]):
+def add_messages_to_session(session_id: str, messages: list[ChatMessage]) -> None:
     global sessions
     sessions[session_id].extend(messages.copy())
 
 
 class McpToolChat:
-    def __init__(
+    def __init__(
+        self,
+        mcp_client: Client[Any],
+        provider: LLMProvider,
+        system: str | None = None,
+        tool_cache: ToolCache | None = None,
+    ):
         self.provider = provider
         self.mcp_client = mcp_client
         self.system = system
+        self.tool_cache = tool_cache or ToolCache(mcp_client)
+        self._tool_cache_version = -1
+        self._last_stats: ChatStats | None = None
 
     @staticmethod
-    def get_session(session_id) -> list[ChatMessage] | None:
+    def get_session(session_id: str) -> list[ChatMessage] | None:
         global sessions
         return sessions.get(session_id)
 
-
-
-
-
-
+    def get_stats(self) -> ChatStats | None:
+        """
+        Get usage statistics from the last chat() or generate() call.
+
+        Returns None if no calls have been made yet.
+        Stats are reset at the start of each new chat()/generate() call.
+        """
+        return self._last_stats
+
+    def _extract_server_from_tool_name(self, tool_name: str) -> str:
+        """
+        Extract server name from a tool name.
+
+        With multiple servers, fastmcp prefixes tools as "serverName_toolName".
+        With a single server, tools are not prefixed.
+
+        Returns the server name or "default" if it cannot be determined.
+        """
+        if "_" in tool_name:
+            return tool_name.split("_", 1)[0]
+        return "default"
+
+    async def generate(self, prompt: str, session_id: str | None = None) -> list[ChatMessage]:
         # Fetch the session if we have a session ID
+        messages: list[ChatMessage]
         if session_id:
             messages = get_session_messages(session_id)
         else:
-            messages
+            messages = []
 
         # Add the prompt as a user message
         user_message = UserMessage(content=prompt)
@@ -75,56 +105,75 @@ class McpToolChat:
 
         return response
 
+    async def chat(self, messages: list[ChatMessage]) -> list[ChatMessage]:
+        tools = await self.tool_cache.get_tools()
+
+        # Reset stats at the start of each chat
+        self._last_stats = ChatStats()
 
-    async def chat(
-        self,
-        messages: list[ChatMessage]
-    ) -> list[ChatMessage]:
         # Add a system message if required
-        has_system_message = any(message.role ==
+        has_system_message = any(message.role == "system" for message in messages)
         if self.system and not has_system_message:
             # Insert the system message at the start of the messages
-            logger.debug(
+            logger.debug("Adding System Message")
             messages.insert(0, SystemMessage(content=self.system))
 
         logger.info("Start Chat")
-        async with self.mcp_client:
-            tools = await self.mcp_client.list_tools()
-
         response_messages: list[ChatMessage] = []
         while True:
             logger.info("Calling the LLM")
-            ai_message = await self.provider.
+            ai_message = await self.provider.chat(messages=messages, tools=tools_from_mcp(tools))
+
+            # Accumulate token usage stats
+            self._last_stats.llm_calls += 1
+            usage = self.provider.get_usage()
+            if usage:
+                prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
+                completion_tokens = getattr(usage, "completion_tokens", 0) or 0
+                self._last_stats.tokens.prompt_tokens += prompt_tokens
+                self._last_stats.tokens.completion_tokens += completion_tokens
 
             # Add the assistant's message
             response_messages.append(ai_message)
             messages.append(ai_message)
 
+            logger.debug(f"Assistant: {ai_message}")
             if not ai_message.tool_calls:
                 break
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            logger.info(f"Executing {len(ai_message.tool_calls)} tool calls")
+            result_count = 0
+            for tool_call in ai_message.tool_calls:
+                # Track tool call stats
+                tool_name = tool_call.function.name
+                self._last_stats.tool_calls.by_tool[tool_name] = (
+                    self._last_stats.tool_calls.by_tool.get(tool_name, 0) + 1
+                )
+                server_name = self._extract_server_from_tool_name(tool_name)
+                self._last_stats.tool_calls.by_server[server_name] = (
+                    self._last_stats.tool_calls.by_server.get(server_name, 0) + 1
+                )
+
+                try:
+                    result = await self.execute(tool_call)
+                except Exception as e:
+                    logger.error(
+                        f"Failed to execute tool '{tool_call.function.name}' "
+                        f"(id={tool_call.id}): {e}"
+                    )
+                    continue
+                if result:
+                    messages.append(result)
+                    response_messages.append(result)
+                    result_count = result_count + 1
+
+            logger.info(f"Added {result_count} tool results")
 
         logger.debug(f"Final Response: {response_messages[-1].content}")
 
         return response_messages
 
-
-    async def execute(self, tool_call: AssistantToolCall):
+    async def execute(self, tool_call: AssistantToolCall) -> ToolResultMessage:
         tool_name = tool_call.function.name
         tool_args = json.loads(tool_call.function.arguments)
         try:
@@ -144,8 +193,37 @@ class McpToolChat:
 
         logger.debug(f"Tool Call Result: {result}")
 
-        result_format = os.getenv(
-
+        result_format = os.getenv("TOOL_RESULT_FORMAT", "result")
+
+        # Prefer structuredContent when available (machine-readable format)
+        # Note: MCP types use camelCase (structuredContent), mypy stubs may differ
+        structured = getattr(result, "structuredContent", None)
+        if structured is not None:
+            try:
+                content_text = json.dumps(structured)
+            except (TypeError, ValueError):
+                content_text = str(structured)
+        elif not result.content:
+            content_text = "[No content returned]"
+        else:
+            # Fall back to processing content items
+            content_parts: list[Any] = []
+            for content_item in result.content:
+                if content_item.type == "text":
+                    try:
+                        parsed = json.loads(content_item.text)
+                        content_parts.append(parsed)
+                    except json.JSONDecodeError:
+                        content_parts.append(content_item.text)
+                elif hasattr(content_item, "mimeType"):
+                    # Image or audio content
+                    content_parts.append(f"[{content_item.type}: {content_item.mimeType}]")
+                else:
+                    content_parts.append(str(content_item))
+
+            content_text = json.dumps(content_parts)
+
+        content = format_tool_call_result(tool_call, content_text, style=result_format)
 
         return ToolResultMessage(
             name=tool_call.function.name,