casual-mcp 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- casual_mcp/__init__.py +13 -0
- casual_mcp/cli.py +68 -0
- casual_mcp/logging.py +30 -0
- casual_mcp/main.py +118 -0
- casual_mcp/mcp_tool_chat.py +90 -0
- casual_mcp/models/__init__.py +33 -0
- casual_mcp/models/config.py +10 -0
- casual_mcp/models/generation_error.py +10 -0
- casual_mcp/models/mcp_server_config.py +39 -0
- casual_mcp/models/messages.py +31 -0
- casual_mcp/models/model_config.py +21 -0
- casual_mcp/models/tool_call.py +14 -0
- casual_mcp/multi_server_mcp_client.py +170 -0
- casual_mcp/providers/__init__.py +0 -0
- casual_mcp/providers/abstract_provider.py +15 -0
- casual_mcp/providers/ollama_provider.py +72 -0
- casual_mcp/providers/openai_provider.py +178 -0
- casual_mcp/providers/provider_factory.py +48 -0
- casual_mcp/utils.py +90 -0
- casual_mcp-0.1.0.dist-info/METADATA +352 -0
- casual_mcp-0.1.0.dist-info/RECORD +25 -0
- casual_mcp-0.1.0.dist-info/WHEEL +5 -0
- casual_mcp-0.1.0.dist-info/entry_points.txt +2 -0
- casual_mcp-0.1.0.dist-info/licenses/LICENSE +7 -0
- casual_mcp-0.1.0.dist-info/top_level.txt +1 -0
casual_mcp/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from . import models
|
|
2
|
+
from .mcp_tool_chat import McpToolChat
|
|
3
|
+
from .multi_server_mcp_client import MultiServerMCPClient
|
|
4
|
+
from .providers.provider_factory import ProviderFactory
|
|
5
|
+
from .utils import load_config
|
|
6
|
+
|
|
7
|
+
# Public names re-exported at package level (what `from casual_mcp import *` exposes).
__all__ = [
    "McpToolChat",
    "MultiServerMCPClient",
    "ProviderFactory",
    "load_config",
    "models",
]
|
casual_mcp/cli.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
import typer
|
|
2
|
+
import uvicorn
|
|
3
|
+
from rich.console import Console
|
|
4
|
+
from rich.table import Table
|
|
5
|
+
|
|
6
|
+
from casual_mcp.utils import load_config
|
|
7
|
+
|
|
8
|
+
# Typer application; every function decorated with @app.command() below is a subcommand.
app = typer.Typer()
# Shared Rich console used to print the configuration tables.
console = Console()
|
|
10
|
+
|
|
11
|
+
@app.command()
def serve(host: str = "0.0.0.0", port: int = 8000, reload: bool = True):
    """
    Start the Casual MCP API server.
    """
    # The FastAPI app is handed to uvicorn as an import string rather than an
    # object; uvicorn requires this form when reload is enabled.
    app_import_path = "casual_mcp.main:app"
    uvicorn.run(app_import_path, host=host, port=port, reload=reload, app_dir="src")
|
|
23
|
+
|
|
24
|
+
@app.command()
|
|
25
|
+
def servers():
|
|
26
|
+
"""
|
|
27
|
+
Return a table of all configured servers
|
|
28
|
+
"""
|
|
29
|
+
config = load_config('config.json')
|
|
30
|
+
table = Table("Name", "Type", "Path / Package / Url", "Env")
|
|
31
|
+
|
|
32
|
+
for name, server in config.servers.items():
|
|
33
|
+
path = ''
|
|
34
|
+
match server.type:
|
|
35
|
+
case 'python':
|
|
36
|
+
path = server.path
|
|
37
|
+
case 'node':
|
|
38
|
+
path = server.path
|
|
39
|
+
case 'http':
|
|
40
|
+
path = server.url
|
|
41
|
+
case 'uvx':
|
|
42
|
+
path = server.package
|
|
43
|
+
env = ''
|
|
44
|
+
|
|
45
|
+
table.add_row(name, server.type, path, env)
|
|
46
|
+
|
|
47
|
+
console.print(table)
|
|
48
|
+
|
|
49
|
+
@app.command()
def models():
    """
    Return a table of all configured models
    """
    config = load_config('config.json')
    table = Table("Name", "Provider", "Model", "Endpoint")

    # Build every row first; only OpenAI-backed models get their endpoint
    # shown, all other providers display an empty cell.
    rows = [
        (
            model_name,
            model_cfg.provider,
            model_cfg.model,
            str((model_cfg.endpoint or '') if model_cfg.provider == 'openai' else ''),
        )
        for model_name, model_cfg in config.models.items()
    ]
    for row in rows:
        table.add_row(*row)

    console.print(table)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
# Allow running this module directly as a script to invoke the Typer CLI.
if __name__ == "__main__":
    app()
|
casual_mcp/logging.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Literal
|
|
3
|
+
|
|
4
|
+
from rich.console import Console
|
|
5
|
+
from rich.logging import RichHandler
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def get_logger(name: str) -> logging.Logger:
    """Return the logger for *name* inside the ``casual_mcp`` logger hierarchy.

    Because the result is a child of ``casual_mcp``, it picks up whatever
    handlers ``configure_logging`` attached to that parent logger.
    """
    qualified_name = "casual_mcp.{}".format(name)
    return logging.getLogger(qualified_name)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def configure_logging(
    level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] | int = "INFO",
    logger: logging.Logger | None = None,
) -> None:
    """Attach a single Rich stderr handler to *logger* and set its level.

    Args:
        level: Level name or numeric level accepted by ``Logger.setLevel``.
        logger: Logger to configure; defaults to the package root ``casual_mcp``.
    """
    target = logger if logger is not None else logging.getLogger("casual_mcp")

    rich_handler = RichHandler(console=Console(stderr=True), rich_tracebacks=True)
    rich_handler.setFormatter(logging.Formatter("%(name)s: %(message)s"))

    target.setLevel(level)

    # Drop any handlers from an earlier call so reconfiguring never duplicates output.
    while target.handlers:
        target.removeHandler(target.handlers[0])

    target.addHandler(rich_handler)
    target.info("Logging Configured")
|
casual_mcp/main.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
import asyncio
import os
import sys
from pathlib import Path

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

from casual_mcp import McpToolChat, MultiServerMCPClient
from casual_mcp.logging import configure_logging, get_logger
from casual_mcp.models.messages import CasualMcpMessage
from casual_mcp.providers.provider_factory import ProviderFactory
from casual_mcp.utils import load_config, render_system_prompt
|
|
14
|
+
|
|
15
|
+
# Load variables from a .env file into the environment before anything reads them
# (e.g. LOG_LEVEL further down).
load_dotenv()
# Configuration is read once at import time from the relative path "config.json"
# (presumably resolved against the current working directory).
config = load_config("config.json")
# Shared MCP client; servers are connected lazily from the /chat handler.
mcp_client = MultiServerMCPClient(namespace_tools=config.namespace_tools)
provider_factory = ProviderFactory()

app = FastAPI()
|
|
21
|
+
|
|
22
|
+
default_system_prompt = """You are a helpful assistant.
|
|
23
|
+
|
|
24
|
+
You have access to up-to-date information through the tools, but you must never mention that tools were used.
|
|
25
|
+
|
|
26
|
+
Respond naturally and confidently, as if you already know all the facts.
|
|
27
|
+
|
|
28
|
+
**Never mention your knowledge cutoff, training data, or when you were last updated.**
|
|
29
|
+
|
|
30
|
+
You must not speculate or guess about dates — if a date is given to you by a tool, assume it is correct and respond accordingly without disclaimers.
|
|
31
|
+
|
|
32
|
+
Always present information as current and factual.
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
class GenerateRequest(BaseModel):
    """Request body for the ``/chat`` (and legacy ``/generate``) endpoint."""

    # Optional session id; when given, history is loaded from and saved to the
    # in-memory session store and ``messages`` is ignored.
    session_id: str | None = Field(
        default=None, title="Session to use"
    )
    # Name of a model defined under ``models`` in the configuration.
    model: str = Field(
        title="Model to use"
    )
    # When non-empty, overrides the template/default system prompt.
    system_prompt: str | None = Field(
        default=None, title="System Prompt to use"
    )
    user_prompt: str = Field(
        title="User Prompt"
    )
    # Prior conversation to continue when no session is used.
    messages: list[CasualMcpMessage] | None = Field(
        default=None, title="Previous messages to supply to the LLM"
    )
|
|
51
|
+
|
|
52
|
+
# Append this module's own directory to the import path.
# NOTE(review): the imports above are all absolute (casual_mcp.*), so it is
# unclear what depends on this; confirm before removing.
sys.path.append(str(Path(__file__).parent.resolve()))

# Configure logging
# LOG_LEVEL comes from the environment (possibly via .env); defaults to INFO.
configure_logging(os.getenv("LOG_LEVEL", 'INFO'))
logger = get_logger("main")
|
|
57
|
+
|
|
58
|
+
async def perform_chat(
    model,
    user,
    system: str | None = None,
    messages: list[CasualMcpMessage] | None = None,
    session_id: str | None = None
) -> list[CasualMcpMessage]:
    """Run one chat turn against a configured model and return the new messages.

    Args:
        model: Key into ``config.models``.
        user: User prompt for this turn.
        system: System prompt override. When empty, the model's template
            (rendered with the available tools) or ``default_system_prompt``
            is used.
        messages: Prior conversation to continue (ignored when a session is used).
        session_id: Optional session whose stored history is used and extended.

    Raises:
        HTTPException: 400 if ``model`` is not in the configuration. Previously
            this surfaced as an unhandled ``KeyError`` (HTTP 500).
    """
    # Get Provider from Model Config
    model_config = config.models.get(model)
    if model_config is None:
        raise HTTPException(status_code=400, detail=f"Model not found: {model}")
    provider = provider_factory.get_provider(model, model_config)

    if not system:
        if model_config.template:
            system = render_system_prompt(
                f"{model_config.template}.j2",
                await mcp_client.list_tools()
            )
        else:
            system = default_system_prompt

    tool_chat = McpToolChat(mcp_client, provider, system)
    return await tool_chat.chat(
        prompt=user,
        messages=messages,
        session_id=session_id
    )
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
# Guards the lazy MCP server initialisation. Without it, concurrent first
# requests could each call mcp_client.load_config, registering every server
# twice and tripping the tool-name collision check.
_servers_lock = asyncio.Lock()
_servers_loaded = False


async def _ensure_servers_loaded() -> None:
    """Connect to the configured MCP servers once and share their tools with providers."""
    global _servers_loaded
    if _servers_loaded:
        return
    async with _servers_lock:
        if _servers_loaded:  # another request finished loading while we waited
            return
        await mcp_client.load_config(config.servers)
        provider_factory.set_tools(await mcp_client.list_tools())
        _servers_loaded = True


@app.post("/chat")
async def chat(req: GenerateRequest):
    """Run a chat turn and return the new messages plus the final reply text.

    Returns:
        ``{"messages": [...], "response": <content of the last message or None>}``
    """
    await _ensure_servers_loaded()

    messages = await perform_chat(
        req.model,
        system=req.system_prompt,
        user=req.user_prompt,
        messages=req.messages,
        session_id=req.session_id
    )

    return {
        "messages": messages,
        # Guard against an empty result instead of raising IndexError.
        "response": messages[-1].content if messages else None
    }
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
# This endpoint will either go away or be used for something else, don't use it
@app.post("/generate")
async def generate_response(req: GenerateRequest):
    """Legacy alias: delegate the request unchanged to the ``/chat`` handler."""
    chat_result = await chat(req)
    return chat_result
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
@app.get("/chat/session/{session_id}")
async def get_chat_session(session_id: str):
    """Return the stored message history for ``session_id``.

    Raises:
        HTTPException: 404 when no session with that id exists.
    """
    session = McpToolChat.get_session(session_id)
    # Compare with None: a session that exists but holds no messages yet
    # (e.g. its first turn failed) is still a valid session, not "not found".
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")

    return session
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
|
|
2
|
+
from casual_mcp.logging import get_logger
|
|
3
|
+
from casual_mcp.models.messages import CasualMcpMessage, SystemMessage, UserMessage
|
|
4
|
+
from casual_mcp.multi_server_mcp_client import MultiServerMCPClient
|
|
5
|
+
from casual_mcp.providers.provider_factory import LLMProvider
|
|
6
|
+
|
|
7
|
+
logger = get_logger("mcp_tool_chat")
# In-memory conversation store: session id -> message history.
# Process-local with no eviction, so it grows unbounded and is lost on restart.
sessions: dict[str, list[CasualMcpMessage]] = {}
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class McpToolChat:
    """Drive a tool-using conversation between an LLM provider and MCP servers.

    ``chat`` sends the conversation to the provider, executes any tool calls
    the model requests through ``tool_client`` and repeats until the model
    replies without tool calls.
    """

    def __init__(self, tool_client: MultiServerMCPClient, provider: LLMProvider, system: str):
        """
        Args:
            tool_client: Client used to list and execute MCP tools.
            provider: LLM provider whose ``generate`` produces assistant messages.
            system: System prompt inserted when a conversation has no prior messages.
        """
        self.provider = provider
        self.tool_client = tool_client
        self.system = system

    @staticmethod
    def get_session(session_id) -> list[CasualMcpMessage] | None:
        """Return the stored history for ``session_id``, or ``None`` if unknown."""
        return sessions.get(session_id)

    async def chat(
        self,
        prompt: str | None = None,
        messages: list[CasualMcpMessage] | None = None,
        session_id: str | None = None
    ) -> list[CasualMcpMessage]:
        """Run the model/tool loop for one turn and return only the new messages.

        Args:
            prompt: Optional user message appended before calling the model.
            messages: Prior conversation. Ignored when ``session_id`` is given
                (the stored session history is used instead). When empty, the
                conversation starts with this chat's system prompt. The
                caller's list is never modified.
            session_id: When set, history is loaded from and saved to the
                in-memory session store.

        Returns:
            Messages produced during this call: the system prompt for a new
            conversation, the user prompt, assistant replies and tool results.
            If executing a tool raises, the new messages so far are returned
            and the session is left unchanged.
        """
        # todo: check that we have a prompt or that there is a user message in messages

        # If we have a session ID then create it if new and use its history
        if session_id:
            if not sessions.get(session_id):
                logger.info(f"Starting new session {session_id}")
                sessions[session_id] = []
            else:
                logger.info(
                    f"Retrieving session {session_id} of length {len(sessions[session_id])}"
                )
            messages = sessions[session_id].copy()

        logger.info("Start Chat")
        tools = await self.tool_client.list_tools()

        if not messages:
            message_history = []
            messages = [SystemMessage(content=self.system)]
        else:
            message_history = list(messages)
            # Work on a copy so the caller's list (e.g. a request body) is not mutated.
            messages = list(messages)

        if prompt:
            messages.append(UserMessage(content=prompt))

        response: str | None = None
        while True:
            logger.info("Calling the LLM")
            ai_message = await self.provider.generate(messages, tools)
            response = ai_message.content

            # Add the assistant's message
            messages.append(ai_message)

            if not ai_message.tool_calls:
                break

            logger.info(f"Executing {len(ai_message.tool_calls)} tool calls")
            result_count = 0
            for tool_call in ai_message.tool_calls:
                try:
                    result = await self.tool_client.execute(tool_call)
                except Exception as e:
                    logger.error(e)
                    # Return only this call's messages (consistent with the
                    # normal path) and, as before, do not persist them to the
                    # session: presumably because the history would contain
                    # tool calls without results.
                    return messages[len(message_history):]
                if result:
                    messages.append(result)
                    result_count += 1

            logger.info(f"Added {result_count} tool results")

        logger.debug(f"""Final Response:
{response} """)

        # `messages` starts as a copy of `message_history` and is only appended
        # to, so everything past that prefix is new. Slicing (instead of
        # filtering by equality) keeps messages that happen to equal an earlier
        # one, such as the same prompt sent twice in a session, and is O(n).
        new_messages = messages[len(message_history):]
        if session_id:
            sessions[session_id].extend(new_messages)

        return new_messages
|
|
90
|
+
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""Pydantic models for casual-mcp configuration, messages and tool calls."""

from .mcp_server_config import (
    HttpMcpServerConfig,
    McpServerConfig,
    NodeMcpServerConfig,
    PythonMcpServerConfig,
    UvxMcpServerConfig,
)
from .messages import (
    AssistantMessage,
    CasualMcpMessage,
    SystemMessage,
    ToolResultMessage,
    UserMessage,
)
from .model_config import (
    ModelConfig,
    OllamaModelConfig,
    OpenAIModelConfig,
)

# OllamaModelConfig was previously missing here even though its sibling
# OpenAIModelConfig (and the ModelConfig union containing both) were exported.
__all__ = [
    "UserMessage",
    "AssistantMessage",
    "ToolResultMessage",
    "SystemMessage",
    "CasualMcpMessage",
    "ModelConfig",
    "OpenAIModelConfig",
    "OllamaModelConfig",
    "McpServerConfig",
    "PythonMcpServerConfig",
    "UvxMcpServerConfig",
    "NodeMcpServerConfig",
    "HttpMcpServerConfig",
]
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
from pydantic import BaseModel
|
|
2
|
+
|
|
3
|
+
from casual_mcp.models.mcp_server_config import McpServerConfig
|
|
4
|
+
from casual_mcp.models.model_config import ModelConfig
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class Config(BaseModel):
    """Application configuration: LLM models, MCP servers and tool-naming options."""

    # When true, tool names are prefixed with their server name ("<server>-<tool>")
    # so servers may expose identically named tools.
    namespace_tools: bool | None = False
    # Model name -> model settings; the name is what API clients pass as `model`.
    models: dict[str, ModelConfig]
    # Server name -> MCP server connection settings.
    servers: dict[str, McpServerConfig]
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from typing import Literal
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class BaseMcpServerConfig(BaseModel):
    """Settings common to every MCP server entry in the configuration."""

    # Transport kind; each subclass narrows this to a single literal.
    type: Literal["python", "node", "http", "uvx"]
    # Optional name of a server prompt that is fetched on connect and collected
    # as a system prompt. (Was the redundant `str | None | None`.)
    system_prompt: str | None = None
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class PythonMcpServerConfig(BaseMcpServerConfig):
    """MCP server launched from a local Python script over stdio."""

    type: Literal["python"] = "python"
    # Path to the server script.
    path: str
    # Environment passed to the stdio transport as `env`.
    # (Was the redundant `dict[str, str] | None | None`.)
    env: dict[str, str] | None = None
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class UvxMcpServerConfig(BaseMcpServerConfig):
    """MCP server run as a package through ``uvx`` over stdio."""

    type: Literal["uvx"] = "uvx"
    # Package/tool name handed to the uvx transport as `tool_name`.
    package: str
    # Environment passed to the uvx transport as `env_vars`.
    # (Was the redundant `dict[str, str] | None | None`.)
    env: dict[str, str] | None = None
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class NodeMcpServerConfig(BaseMcpServerConfig):
    """MCP server launched from a local Node.js script over stdio."""

    type: Literal["node"] = "node"
    # Path to the server script.
    path: str
    # Environment passed to the stdio transport as `env`.
    # (Was the redundant `dict[str, str] | None | None`.)
    env: dict[str, str] | None = None
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class HttpMcpServerConfig(BaseMcpServerConfig):
    """Remote MCP server reached over streamable HTTP."""

    type: Literal["http"] = "http"
    # URL handed to the streamable HTTP transport.
    url: str
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# Any supported MCP server configuration; the `type` field identifies which
# kind of transport the client builds for it.
McpServerConfig = (
    PythonMcpServerConfig
    | NodeMcpServerConfig
    | HttpMcpServerConfig
    | UvxMcpServerConfig
)
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from typing import Literal, TypeAlias
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
from casual_mcp.models.tool_call import AssistantToolCall
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class AssistantMessage(BaseModel):
    """A reply produced by the model, optionally requesting tool calls."""

    role: Literal["assistant"] = "assistant"
    # Reply text; may be None (presumably when the model only requests tools).
    content: str | None
    # Tool invocations requested by the model. Required field (no default),
    # though its value may be None.
    tool_calls: list[AssistantToolCall] | None
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SystemMessage(BaseModel):
    """System prompt message that opens a conversation."""

    role: Literal["system"] = "system"
    content: str
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ToolResultMessage(BaseModel):
    """Result (or error text) of executing one tool call, sent back to the model."""

    role: Literal["tool"] = "tool"
    # Name of the tool as the model referred to it.
    name: str
    # Id of the AssistantToolCall this message answers.
    tool_call_id: str
    content: str
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class UserMessage(BaseModel):
    """A message authored by the end user."""

    role: Literal["user"] = "user"
    content: str | None
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# Any message that can appear in a conversation history.
CasualMcpMessage: TypeAlias = AssistantMessage | SystemMessage | ToolResultMessage | UserMessage
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
from typing import Literal
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, HttpUrl
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class BaseModelConfig(BaseModel):
    """Settings shared by every configured LLM model."""

    # Provider backend that serves this model.
    provider: Literal["openai", "ollama"]
    # Model identifier (presumably passed through to the provider's API).
    model: str
    # Optional API base URL.
    endpoint: HttpUrl | None = None
    # Optional system-prompt template name; rendered as "<template>.j2".
    template: str | None = None
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class OpenAIModelConfig(BaseModelConfig):
    """Model served through the OpenAI provider (`"provider": "openai"`)."""

    provider: Literal["openai"]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class OllamaModelConfig(BaseModelConfig):
    """Model served through the Ollama provider (`"provider": "ollama"`)."""

    provider: Literal["ollama"]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Any model configuration; the `provider` literal distinguishes the variants.
ModelConfig = OpenAIModelConfig | OllamaModelConfig
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from typing import Literal
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class AssistantToolCallFunction(BaseModel):
    """Function name and arguments of a tool call requested by the model."""

    # Tool name as exposed to the model (server-prefixed when namespacing is on).
    name: str
    # JSON-encoded argument object; decoded with json.loads before the call.
    arguments: str
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class AssistantToolCall(BaseModel):
    """A single tool call inside an assistant message."""

    # Echoed back as ToolResultMessage.tool_call_id to pair result and call.
    id: str
    type: Literal["function"] = "function"
    function: AssistantToolCallFunction
|
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
import mcp
|
|
5
|
+
from fastmcp import Client
|
|
6
|
+
from fastmcp.client.logging import LogMessage
|
|
7
|
+
from fastmcp.client.transports import (
|
|
8
|
+
ClientTransport,
|
|
9
|
+
NodeStdioTransport,
|
|
10
|
+
PythonStdioTransport,
|
|
11
|
+
StreamableHttpTransport,
|
|
12
|
+
UvxStdioTransport,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
from casual_mcp.logging import get_logger
|
|
16
|
+
from casual_mcp.models.mcp_server_config import McpServerConfig
|
|
17
|
+
from casual_mcp.models.messages import ToolResultMessage
|
|
18
|
+
from casual_mcp.models.tool_call import AssistantToolCall, AssistantToolCallFunction
|
|
19
|
+
from casual_mcp.utils import format_tool_call_result
|
|
20
|
+
|
|
21
|
+
# Module logger; also receives log messages forwarded from MCP servers.
logger = get_logger("multi_server_mcp_client")
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# MCP log levels (syslog-style severities) mapped onto stdlib Logger methods.
# Logger.log() only accepts integer levels, so forwarding the MCP string level
# directly raised TypeError("level must be an integer") for every server log.
_MCP_LOG_METHODS = {
    "debug": "debug",
    "info": "info",
    "notice": "info",
    "warning": "warning",
    "error": "error",
    "critical": "critical",
    "alert": "critical",
    "emergency": "critical",
}


async def my_log_handler(params: LogMessage):
    """Forward a log message emitted by an MCP server to this module's logger.

    Integer levels are passed through unchanged; unknown string levels are
    logged at INFO.
    """
    level = params.level
    if isinstance(level, int):
        logger.log(level, params.data)
        return
    method_name = _MCP_LOG_METHODS.get(str(level).lower(), "info")
    getattr(logger, method_name)(params.data)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_server_transport(config: McpServerConfig) -> ClientTransport:
|
|
29
|
+
match config.type:
|
|
30
|
+
case 'python':
|
|
31
|
+
return PythonStdioTransport(
|
|
32
|
+
script_path=config.path,
|
|
33
|
+
env=config.env
|
|
34
|
+
)
|
|
35
|
+
case 'node':
|
|
36
|
+
return NodeStdioTransport(
|
|
37
|
+
script_path=config.path,
|
|
38
|
+
env=config.env
|
|
39
|
+
)
|
|
40
|
+
case 'http':
|
|
41
|
+
return StreamableHttpTransport(
|
|
42
|
+
url=config.url
|
|
43
|
+
)
|
|
44
|
+
case 'uvx':
|
|
45
|
+
return UvxStdioTransport(
|
|
46
|
+
tool_name=config.package,
|
|
47
|
+
env_vars=config.env
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class MultiServerMCPClient:
    """Connect to several MCP servers and expose their tools as one collection.

    Tools are registered in ``tools_map`` under the name the model sees. With
    ``namespace_tools`` enabled that name is ``"<server>-<tool>"`` so servers
    may expose identically named tools. Otherwise a duplicate tool name raises
    ``SystemError`` while connecting.
    """

    def __init__(self, namespace_tools: bool = False):
        self.servers: dict[str, Client] = {}  # Map server names to client connections
        self.tools_map = {}  # Map (possibly namespaced) tool names to server names
        self.tools: list[mcp.types.Tool] = []
        self.system_prompts: list[str] = []
        self.namespace_tools = namespace_tools

    async def load_config(self, config: dict[str, McpServerConfig]):
        """Connect to every server in ``config`` (server name -> server settings)."""
        logger.info("Loading server config")
        for name, server_config in config.items():
            transport = get_server_transport(server_config)
            await self.connect_to_server(
                transport,
                name,
                system_prompt=server_config.system_prompt
            )

    async def connect_to_server_script(self, path, name, env=None):
        """Connect via stdio to a local Python script and register its tools.

        Args:
            path: Path of the server script.
            name: Name to register the server under.
            env: Environment for the transport. Omitting it still passes an
                empty dict, as the old (mutable) ``{}`` default did.
        """
        transport = PythonStdioTransport(
            script_path=path,
            env={} if env is None else env,
        )

        return await self.connect_to_server(transport, name)

    async def connect_to_server(self, server, name, system_prompt: str | None = None):
        """Connect to an MCP server and register its tools.

        Args:
            server: Transport (or other target) accepted by ``fastmcp.Client``.
            name: Server name, used for routing and as the namespacing prefix.
            system_prompt: Optional name of a server prompt to fetch and append
                to ``system_prompts``.

        Returns:
            The server's tools (renamed if namespacing is enabled).

        Raises:
            SystemError: a tool name is already registered by another server
                (only when ``namespace_tools`` is False).
        """
        logger.debug(f"Connecting to server {name}")

        # The connection is only held open while registering; call_tool
        # re-enters the stored client for each call.
        async with Client(
            server,
            log_handler=my_log_handler,
        ) as server_client:
            # Store the connection
            self.servers[name] = server_client

            # Fetch tools and map them to this server
            tools = await server_client.list_tools()

            # If we are namespacing servers then change the tool names
            for tool in tools:
                if self.namespace_tools:
                    tool.name = f"{name}-{tool.name}"
                elif tool.name in self.tools_map:
                    raise SystemError(
                        f"Tool name collision {name}:{tool.name} already added by {self.tools_map[tool.name]}"  # noqa: E501
                    )

                self.tools_map[tool.name] = name
            self.tools.extend(tools)

            if system_prompt:
                prompt = await server_client.get_prompt(system_prompt)
                if prompt:
                    # NOTE(review): get_prompt's result is stored as-is although
                    # system_prompts is annotated list[str]; confirm its type.
                    self.system_prompts.append(prompt)

            return tools

    async def list_tools(self):
        """Return the tools registered by all connected servers.

        This returns the cached list built while connecting; it does not query
        the servers again.
        """
        return self.tools

    async def call_tool(self, function: AssistantToolCallFunction):
        """Route a tool call to the server that registered the tool.

        Raises:
            ValueError: if no connected server registered ``function.name``, or
                if ``function.arguments`` is not valid JSON (JSONDecodeError).
        """
        requested_name = function.name
        tool_args = json.loads(function.arguments)

        # tools_map is keyed by the name the model sees (namespaced if enabled),
        # so the lookup must use the requested name, not the stripped one.
        # Previously the stripped name was checked, which made every namespaced
        # call fail with "Tool not found".
        server_name = self.tools_map.get(requested_name)
        if server_name is None:
            raise ValueError(f"Tool not found: {requested_name}")

        # Remove the server prefix to recover the tool's name on its own server.
        if self.namespace_tools:
            tool_name = requested_name.removeprefix(f"{server_name}-")
        else:
            tool_name = requested_name

        logger.info(f"Calling tool {tool_name}")

        server_client = self.servers[server_name]
        async with server_client:
            return await server_client.call_tool(tool_name, tool_args)

    async def execute(self, tool_call: AssistantToolCall):
        """Execute ``tool_call`` and wrap the outcome in a ToolResultMessage.

        Failures are not raised. The error text becomes the message content,
        so the model can see what went wrong.
        """
        try:
            result = await self.call_tool(tool_call.function)
        except Exception as e:
            if isinstance(e, ValueError):
                logger.warning(e)
            else:
                logger.error(f"Error calling tool: {e}")

            return ToolResultMessage(
                name=tool_call.function.name,
                tool_call_id=tool_call.id,
                content=str(e),
            )

        logger.debug(f"Tool Call Result: {result}")

        result_format = os.getenv('TOOL_RESULT_FORMAT', 'result')
        # Only the first content item's text is forwarded. A tool that returns
        # no content now yields an empty string instead of raising IndexError.
        result_text = result[0].text if result else ""
        content = format_tool_call_result(tool_call, result_text, style=result_format)

        return ToolResultMessage(
            name=tool_call.function.name,
            tool_call_id=tool_call.id,
            content=content,
        )

    def get_system_prompts(self) -> list[str]:
        """Return the prompts collected from servers that configured ``system_prompt``."""
        return self.system_prompts
|
|
File without changes
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
import mcp
|
|
4
|
+
|
|
5
|
+
from casual_mcp.models.messages import CasualMcpMessage
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class CasualMcpProvider(ABC):
    """Abstract interface for the LLM backends used by casual-mcp."""

    @abstractmethod
    async def generate(
        self,
        messages: list[CasualMcpMessage],
        tools: list[mcp.Tool]
    ) -> CasualMcpMessage:
        """Send the conversation and available tools to the model and return its reply.

        Args:
            messages: Conversation so far, in order.
            tools: MCP tool definitions the model may call.

        Returns:
            The model's reply. NOTE(review): callers read ``.content`` and
            ``.tool_calls`` from it, so in practice this must be an
            AssistantMessage; consider narrowing the annotation.
        """
        pass
|