casual-mcp 0.3.1__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- casual_mcp/__init__.py +8 -1
- casual_mcp/cli.py +24 -24
- casual_mcp/convert_tools.py +68 -0
- casual_mcp/logging.py +6 -2
- casual_mcp/main.py +30 -55
- casual_mcp/mcp_tool_chat.py +62 -49
- casual_mcp/models/__init__.py +13 -8
- casual_mcp/models/config.py +2 -2
- casual_mcp/models/generation_error.py +1 -1
- casual_mcp/models/model_config.py +3 -3
- casual_mcp/provider_factory.py +47 -0
- casual_mcp/tool_cache.py +114 -0
- casual_mcp/utils.py +18 -11
- casual_mcp-0.5.0.dist-info/METADATA +630 -0
- casual_mcp-0.5.0.dist-info/RECORD +20 -0
- {casual_mcp-0.3.1.dist-info → casual_mcp-0.5.0.dist-info}/WHEEL +1 -1
- casual_mcp/models/messages.py +0 -31
- casual_mcp/models/tool_call.py +0 -14
- casual_mcp/providers/__init__.py +0 -0
- casual_mcp/providers/abstract_provider.py +0 -15
- casual_mcp/providers/ollama_provider.py +0 -72
- casual_mcp/providers/openai_provider.py +0 -178
- casual_mcp/providers/provider_factory.py +0 -56
- casual_mcp-0.3.1.dist-info/METADATA +0 -398
- casual_mcp-0.3.1.dist-info/RECORD +0 -24
- {casual_mcp-0.3.1.dist-info → casual_mcp-0.5.0.dist-info}/entry_points.txt +0 -0
- {casual_mcp-0.3.1.dist-info → casual_mcp-0.5.0.dist-info}/licenses/LICENSE +0 -0
- {casual_mcp-0.3.1.dist-info → casual_mcp-0.5.0.dist-info}/top_level.txt +0 -0
casual_mcp/__init__.py
CHANGED
@@ -1,11 +1,18 @@
+from importlib.metadata import version
+
 from . import models
+
+__version__ = version("casual-mcp")
 from .mcp_tool_chat import McpToolChat
-from .providers.provider_factory import ProviderFactory
+from .provider_factory import ProviderFactory
+from .tool_cache import ToolCache
 from .utils import load_config, load_mcp_client, render_system_prompt
 
 __all__ = [
+    "__version__",
     "McpToolChat",
     "ProviderFactory",
+    "ToolCache",
     "load_config",
     "load_mcp_client",
     "render_system_prompt",
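For orientation, a sketch (not part of the wheel) of the 0.5.0 import surface this change produces; the config file name mirrors the load_config("casual_mcp_config.json") calls elsewhere in this diff.

# Sketch: exercising the 0.5.0 public API re-exported above.
from casual_mcp import (
    McpToolChat,
    ProviderFactory,
    ToolCache,
    __version__,
    load_config,
    load_mcp_client,
)

print(__version__)                              # "0.5.0" once this wheel is installed
config = load_config("casual_mcp_config.json")  # same file name the CLI and API use
mcp_client = load_mcp_client(config)
tool_cache = ToolCache(mcp_client)              # new export in 0.5.0
provider_factory = ProviderFactory()            # new export in 0.5.0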
casual_mcp/cli.py
CHANGED
@@ -1,6 +1,10 @@
 import asyncio
+from typing import Any
+
+import mcp
 import typer
 import uvicorn
+from fastmcp import Client
 from rich.console import Console
 from rich.table import Table
 
@@ -10,63 +14,58 @@ from casual_mcp.utils import load_config, load_mcp_client
 app = typer.Typer()
 console = Console()
 
+
 @app.command()
-def serve(host: str = "0.0.0.0", port: int = 8000, reload: bool = True):
+def serve(host: str = "0.0.0.0", port: int = 8000, reload: bool = True) -> None:
     """
     Start the Casual MCP API server.
     """
-    uvicorn.run(
-
-        host=host,
-        port=port,
-        reload=reload,
-        app_dir="src"
-    )
+    uvicorn.run("casual_mcp.main:app", host=host, port=port, reload=reload, app_dir="src")
+
 
 @app.command()
-def servers():
+def servers() -> None:
     """
     Return a table of all configured servers
     """
-    config = load_config(
+    config = load_config("casual_mcp_config.json")
     table = Table("Name", "Type", "Command / Url", "Env")
 
     for name, server in config.servers.items():
-
+        server_type = "stdio"
         if isinstance(server, RemoteServerConfig):
-
+            server_type = "remote"
 
-        path =
+        path = ""
         if isinstance(server, RemoteServerConfig):
             path = server.url
         else:
             path = f"{server.command} {' '.join(server.args)}"
-        env =
+        env = ""
 
-        table.add_row(name,
+        table.add_row(name, server_type, path, env)
 
     console.print(table)
 
+
 @app.command()
-def models():
+def models() -> None:
     """
     Return a table of all configured models
     """
-    config = load_config(
+    config = load_config("casual_mcp_config.json")
     table = Table("Name", "Provider", "Model", "Endpoint")
 
     for name, model in config.models.items():
-        endpoint =
-        if model.provider == 'openai':
-            endpoint = model.endpoint or ''
-
+        endpoint = model.endpoint or ""
         table.add_row(name, model.provider, model.model, str(endpoint))
 
     console.print(table)
 
+
 @app.command()
-def tools():
-    config = load_config(
+def tools() -> None:
+    config = load_config("casual_mcp_config.json")
     mcp_client = load_mcp_client(config)
     table = Table("Name", "Description")
     # async with mcp_client:
@@ -76,9 +75,10 @@ def tools():
     console.print(table)
 
 
-async def get_tools(client):
+async def get_tools(client: Client[Any]) -> list[mcp.Tool]:
     async with client:
         return await client.list_tools()
 
+
 if __name__ == "__main__":
     app()
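The console-script name lives in entry_points.txt, which this diff does not show, so a hedged way to exercise the commands above is Typer's test runner against the app object directly (assumes a casual_mcp_config.json in the working directory):

# Sketch: invoking the Typer commands without the installed console script.
from typer.testing import CliRunner

from casual_mcp.cli import app

runner = CliRunner()
print(runner.invoke(app, ["servers"]).output)  # table of configured servers
print(runner.invoke(app, ["models"]).output)   # table of configured models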
casual_mcp/convert_tools.py
ADDED
@@ -0,0 +1,68 @@
+import mcp
+from casual_llm import Tool
+
+from casual_mcp.logging import get_logger
+
+logger = get_logger("convert_tools")
+
+
+# MCP format converters (for interop with MCP libraries)
+def tool_from_mcp(mcp_tool: mcp.Tool) -> Tool:
+    """
+    Convert an MCP Tool to casual-llm Tool format.
+
+    Args:
+        mcp_tool: MCP Tool object (from mcp library)
+
+    Returns:
+        casual-llm Tool instance
+
+    Raises:
+        ValueError: If MCP tool is missing required fields
+
+    Examples:
+        >>> # Assuming mcp_tool is an MCP Tool object
+        >>> # tool = tool_from_mcp(mcp_tool)
+        >>> # assert tool.name == mcp_tool.name
+        pass
+    """
+    if not mcp_tool.name or not mcp_tool.description:
+        raise ValueError(
+            f"MCP tool missing required fields: "
+            f"name={mcp_tool.name}, description={mcp_tool.description}"
+        )
+
+    input_schema = getattr(mcp_tool, "inputSchema", {})
+    if not isinstance(input_schema, dict):
+        input_schema = {}
+
+    return Tool.from_input_schema(
+        name=mcp_tool.name, description=mcp_tool.description, input_schema=input_schema
+    )
+
+
+def tools_from_mcp(mcp_tools: list[mcp.Tool]) -> list[Tool]:
+    """
+    Convert multiple MCP Tools to casual-llm format.
+
+    Args:
+        mcp_tools: List of MCP Tool objects
+
+    Returns:
+        List of casual-llm Tool instances
+
+    Examples:
+        >>> # tools = tools_from_mcp(mcp_tool_list)
+        >>> # assert len(tools) == len(mcp_tool_list)
+        pass
+    """
+    tools = []
+
+    for mcp_tool in mcp_tools:
+        try:
+            tool = tool_from_mcp(mcp_tool)
+            tools.append(tool)
+        except ValueError as e:
+            logger.warning(f"Skipping invalid MCP tool: {e}")
+
+    return tools
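A minimal sketch of the new converters in use, assuming the mcp package's Tool model accepts the name/description/inputSchema fields the code above reads:

# Sketch: converting MCP tool definitions to casual-llm Tools.
import mcp

from casual_mcp.convert_tools import tool_from_mcp, tools_from_mcp

weather = mcp.Tool(
    name="get_weather",
    description="Look up current weather for a city",
    inputSchema={
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
)

tool = tool_from_mcp(weather)  # casual_llm.Tool built via Tool.from_input_schema
assert tool.name == "get_weather"

# A tool with a missing description is logged and skipped, not raised:
broken = mcp.Tool(name="broken", description=None, inputSchema={})
assert len(tools_from_mcp([weather, broken])) == 1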
casual_mcp/logging.py
CHANGED
@@ -1,5 +1,4 @@
 import logging
-from typing import Literal
 
 from rich.console import Console
 from rich.logging import RichHandler
@@ -10,7 +9,7 @@ def get_logger(name: str) -> logging.Logger:
 
 
 def configure_logging(
-    level:
+    level: str | int = "INFO",
     logger: logging.Logger | None = None,
 ) -> None:
     if logger is None:
@@ -27,4 +26,9 @@ def configure_logging(
         logger.removeHandler(hdlr)
 
     logger.addHandler(handler)
+
+    # Set logging level on FastMCP and MCP libraries
+    logging.getLogger("fastmcp").setLevel(level)
+    logging.getLogger("mcp").setLevel(level)
+
     logger.info("Logging Configured")
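The practical effect of the new lines: one configure_logging call now also pins the fastmcp and mcp library loggers to the requested level. A small sketch:

# Sketch: configure_logging now accepts str or int and aligns library loggers.
import logging

from casual_mcp.logging import configure_logging, get_logger

configure_logging("DEBUG")
assert logging.getLogger("fastmcp").level == logging.DEBUG
assert logging.getLogger("mcp").level == logging.DEBUG

get_logger("demo").debug("now visible")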
casual_mcp/main.py
CHANGED
@@ -1,21 +1,28 @@
 import os
-import sys
-from pathlib import Path
+from typing import Any
 
+from casual_llm import ChatMessage
 from dotenv import load_dotenv
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel, Field
 
 from casual_mcp import McpToolChat
 from casual_mcp.logging import configure_logging, get_logger
-from casual_mcp.
-from casual_mcp.
+from casual_mcp.provider_factory import ProviderFactory
+from casual_mcp.tool_cache import ToolCache
 from casual_mcp.utils import load_config, load_mcp_client, render_system_prompt
 
+# Load environment variables
 load_dotenv()
+
+# Configure logging
+configure_logging(os.getenv("LOG_LEVEL", "INFO"))
+logger = get_logger("main")
+
 config = load_config("casual_mcp_config.json")
 mcp_client = load_mcp_client(config)
-
+tool_cache = ToolCache(mcp_client)
+provider_factory = ProviderFactory()
 
 app = FastAPI()
 
@@ -34,65 +41,36 @@ Always present information as current and factual.
 
 
 class GenerateRequest(BaseModel):
-    session_id: str | None = Field(
-
-    )
-    model: str = Field(
-        title="Model to user"
-    )
-    system_prompt: str | None = Field(
-        default=None, title="System Prompt to use"
-    )
-    prompt: str = Field(
-        title="User Prompt"
-    )
+    session_id: str | None = Field(default=None, title="Session to use")
+    model: str = Field(title="Model to use")
+    system_prompt: str | None = Field(default=None, title="System Prompt to use")
+    prompt: str = Field(title="User Prompt")
 
 
 class ChatRequest(BaseModel):
-    model: str = Field(
-
-    )
-    system_prompt: str | None = Field(
-        default=None, title="System Prompt to use"
-    )
-    messages: list[ChatMessage] = Field(
-        title="Previous messages to supply to the LLM"
-    )
-
-sys.path.append(str(Path(__file__).parent.resolve()))
-
-# Configure logging
-configure_logging(os.getenv("LOG_LEVEL", 'INFO'))
-logger = get_logger("main")
+    model: str = Field(title="Model to use")
+    system_prompt: str | None = Field(default=None, title="System Prompt to use")
+    messages: list[ChatMessage] = Field(title="Previous messages to supply to the LLM")
 
 
 @app.post("/chat")
-async def chat(req: ChatRequest):
+async def chat(req: ChatRequest) -> dict[str, Any]:
     chat = await get_chat(req.model, req.system_prompt)
     messages = await chat.chat(req.messages)
 
-    return {
-        "messages": messages,
-        "response": messages[len(messages) - 1].content
-    }
+    return {"messages": messages, "response": messages[-1].content}
 
 
 @app.post("/generate")
-async def generate(req: GenerateRequest):
+async def generate(req: GenerateRequest) -> dict[str, Any]:
     chat = await get_chat(req.model, req.system_prompt)
-    messages = await chat.generate(
-        req.prompt,
-        req.session_id
-    )
+    messages = await chat.generate(req.prompt, req.session_id)
 
-    return {
-        "messages": messages,
-        "response": messages[len(messages) - 1].content
-    }
+    return {"messages": messages, "response": messages[-1].content}
 
 
 @app.get("/generate/session/{session_id}")
-async def get_generate_session(session_id):
+async def get_generate_session(session_id: str) -> list[ChatMessage]:
     session = McpToolChat.get_session(session_id)
     if not session:
         raise HTTPException(status_code=404, detail="Session not found")
@@ -103,17 +81,14 @@ async def get_generate_session(session_id):
 async def get_chat(model: str, system: str | None = None) -> McpToolChat:
     # Get Provider from Model Config
     model_config = config.models[model]
-    provider =
+    provider = provider_factory.get_provider(model, model_config)
 
     # Get the system prompt
     if not system:
-        if
-
-
-                f"{model_config.template}.j2",
-                await mcp_client.list_tools()
-            )
+        if model_config.template:
+            tools = await tool_cache.get_tools()
+            system = render_system_prompt(f"{model_config.template}.j2", tools)
         else:
            system = default_system_prompt
 
-    return McpToolChat(mcp_client, provider, system)
+    return McpToolChat(mcp_client, provider, system, tool_cache=tool_cache)
casual_mcp/mcp_tool_chat.py
CHANGED
@@ -1,62 +1,68 @@
 import json
 import os
+from typing import Any
 
-from
-
-from casual_mcp.logging import get_logger
-from casual_mcp.models.messages import (
+from casual_llm import (
+    AssistantToolCall,
     ChatMessage,
+    LLMProvider,
     SystemMessage,
     ToolResultMessage,
     UserMessage,
 )
-from
-
+from fastmcp import Client
+
+from casual_mcp.convert_tools import tools_from_mcp
+from casual_mcp.logging import get_logger
+from casual_mcp.tool_cache import ToolCache
 from casual_mcp.utils import format_tool_call_result
 
 logger = get_logger("mcp_tool_chat")
 sessions: dict[str, list[ChatMessage]] = {}
 
 
-def get_session_messages(session_id: str
+def get_session_messages(session_id: str) -> list[ChatMessage]:
     global sessions
 
-    if not sessions
+    if session_id not in sessions:
        logger.info(f"Starting new session {session_id}")
        sessions[session_id] = []
    else:
-        logger.info(
-            f"Retrieving session {session_id} of length {len(sessions[session_id])}"
-        )
+        logger.info(f"Retrieving session {session_id} of length {len(sessions[session_id])}")
     return sessions[session_id].copy()
 
 
-def add_messages_to_session(session_id: str, messages: list[ChatMessage]):
+def add_messages_to_session(session_id: str, messages: list[ChatMessage]) -> None:
     global sessions
     sessions[session_id].extend(messages.copy())
 
 
 class McpToolChat:
-    def __init__(
+    def __init__(
+        self,
+        mcp_client: Client[Any],
+        provider: LLMProvider,
+        system: str | None = None,
+        tool_cache: ToolCache | None = None,
+    ):
         self.provider = provider
         self.mcp_client = mcp_client
         self.system = system
+        self.tool_cache = tool_cache or ToolCache(mcp_client)
+        self._tool_cache_version = -1
 
     @staticmethod
-    def get_session(session_id) -> list[ChatMessage] | None:
+    def get_session(session_id: str) -> list[ChatMessage] | None:
         global sessions
         return sessions.get(session_id)
 
-    async def generate(
-        self,
-        prompt: str,
-        session_id: str | None = None
-    ) -> list[ChatMessage]:
+    async def generate(self, prompt: str, session_id: str | None = None) -> list[ChatMessage]:
         # Fetch the session if we have a session ID
+        messages: list[ChatMessage]
         if session_id:
             messages = get_session_messages(session_id)
         else:
-            messages
+            messages = []
 
         # Add the prompt as a user message
         user_message = UserMessage(content=prompt)
@@ -75,56 +81,53 @@ class McpToolChat:
 
         return response
 
+    async def chat(self, messages: list[ChatMessage]) -> list[ChatMessage]:
+        tools = await self.tool_cache.get_tools()
 
-    async def chat(
-        self,
-        messages: list[ChatMessage]
-    ) -> list[ChatMessage]:
         # Add a system message if required
-        has_system_message = any(message.role ==
+        has_system_message = any(message.role == "system" for message in messages)
         if self.system and not has_system_message:
             # Insert the system message at the start of the messages
-            logger.debug(
+            logger.debug("Adding System Message")
             messages.insert(0, SystemMessage(content=self.system))
 
         logger.info("Start Chat")
-        async with self.mcp_client:
-            tools = await self.mcp_client.list_tools()
-
         response_messages: list[ChatMessage] = []
         while True:
             logger.info("Calling the LLM")
-            ai_message = await self.provider.
+            ai_message = await self.provider.chat(messages=messages, tools=tools_from_mcp(tools))
 
             # Add the assistant's message
             response_messages.append(ai_message)
             messages.append(ai_message)
 
+            logger.debug(f"Assistant: {ai_message}")
             if not ai_message.tool_calls:
                 break
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            logger.info(f"Executing {len(ai_message.tool_calls)} tool calls")
+            result_count = 0
+            for tool_call in ai_message.tool_calls:
+                try:
+                    result = await self.execute(tool_call)
+                except Exception as e:
+                    logger.error(
+                        f"Failed to execute tool '{tool_call.function.name}' "
+                        f"(id={tool_call.id}): {e}"
+                    )
+                    continue
+                if result:
+                    messages.append(result)
+                    response_messages.append(result)
+                    result_count = result_count + 1
+
+            logger.info(f"Added {result_count} tool results")
 
         logger.debug(f"Final Response: {response_messages[-1].content}")
 
         return response_messages
 
-
-    async def execute(self, tool_call: AssistantToolCall):
+    async def execute(self, tool_call: AssistantToolCall) -> ToolResultMessage:
         tool_name = tool_call.function.name
         tool_args = json.loads(tool_call.function.arguments)
         try:
@@ -144,8 +147,18 @@
 
         logger.debug(f"Tool Call Result: {result}")
 
-        result_format = os.getenv(
-        content
+        result_format = os.getenv("TOOL_RESULT_FORMAT", "result")
+        # Extract text content from result (handle both TextContent and other content types)
+        if not result.content:
+            content_text = "[No content returned]"
+        else:
+            content_item = result.content[0]
+            if hasattr(content_item, "text"):
+                content_text = content_item.text
+            else:
+                # Handle non-text content (e.g., ImageContent)
+                content_text = f"[Non-text content: {type(content_item).__name__}]"
+        content = format_tool_call_result(tool_call, content_text, style=result_format)
 
         return ToolResultMessage(
             name=tool_call.function.name,
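Putting the class together outside FastAPI, mirroring what main.py's get_chat does; "my-model" is a placeholder for a key in casual_mcp_config.json:

# Sketch: driving McpToolChat directly with the shared ToolCache.
import asyncio

from casual_mcp import McpToolChat, ProviderFactory, ToolCache, load_config, load_mcp_client


async def main() -> None:
    config = load_config("casual_mcp_config.json")
    client = load_mcp_client(config)
    cache = ToolCache(client)
    provider = ProviderFactory().get_provider("my-model", config.models["my-model"])

    chat = McpToolChat(client, provider, "You are terse.", tool_cache=cache)
    messages = await chat.generate("What time is it in Tokyo?", session_id="demo")
    print(messages[-1].content)


asyncio.run(main())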
casual_mcp/models/__init__.py
CHANGED
@@ -1,27 +1,32 @@
-from .mcp_server_config import (
-    McpServerConfig,
-    RemoteServerConfig,
-    StdioServerConfig,
-)
-from .messages import (
+from casual_llm import (
     AssistantMessage,
+    AssistantToolCall,
     ChatMessage,
     SystemMessage,
     ToolResultMessage,
     UserMessage,
 )
+
+from .mcp_server_config import (
+    McpServerConfig,
+    RemoteServerConfig,
+    StdioServerConfig,
+)
 from .model_config import (
-
+    McpModelConfig,
+    OllamaModelConfig,
     OpenAIModelConfig,
 )
 
 __all__ = [
     "UserMessage",
     "AssistantMessage",
+    "AssistantToolCall",
     "ToolResultMessage",
     "SystemMessage",
     "ChatMessage",
-    "
+    "McpModelConfig",
+    "OllamaModelConfig",
     "OpenAIModelConfig",
     "McpServerConfig",
     "StdioServerConfig",
casual_mcp/models/config.py
CHANGED
@@ -1,10 +1,10 @@
 from pydantic import BaseModel
 
 from casual_mcp.models.mcp_server_config import McpServerConfig
-from casual_mcp.models.model_config import
+from casual_mcp.models.model_config import McpModelConfig
 
 
 class Config(BaseModel):
     namespace_tools: bool | None = False
-    models: dict[str,
+    models: dict[str, McpModelConfig]
     servers: dict[str, McpServerConfig]
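For reference, a config shape that lines up with the fields this diff reads (provider/model/template on models; command/args or url on servers). A sketch only; the exact server-config schema lives in mcp_server_config.py, which this diff does not show:

# Sketch: validating a casual_mcp_config.json-style payload against Config.
from casual_mcp.models.config import Config

raw = {
    "namespace_tools": False,
    "models": {
        "local": {"provider": "ollama", "model": "llama3.1"},
        "gpt": {"provider": "openai", "model": "gpt-4o-mini", "template": "openai"},
    },
    "servers": {
        "time": {"command": "uvx", "args": ["mcp-server-time"]},  # stdio-style
        "search": {"url": "https://example.com/mcp"},             # remote-style
    },
}

config = Config.model_validate(raw)  # pydantic v2 assumed
print(list(config.models), list(config.servers))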
casual_mcp/models/model_config.py
CHANGED
@@ -1,12 +1,12 @@
 from typing import Literal
 
-from pydantic import BaseModel
+from pydantic import BaseModel
 
 
 class BaseModelConfig(BaseModel):
     provider: Literal["openai", "ollama"]
     model: str
-    endpoint:
+    endpoint: str | None = None
     template: str | None = None
 
 
@@ -18,4 +18,4 @@ class OllamaModelConfig(BaseModelConfig):
     provider: Literal["ollama"]
 
 
-
+McpModelConfig = OpenAIModelConfig | OllamaModelConfig
casual_mcp/provider_factory.py
ADDED
@@ -0,0 +1,47 @@
+import os
+
+from casual_llm import (
+    LLMProvider,
+    ModelConfig,
+    Provider,
+    create_provider,
+)
+
+from casual_mcp.logging import get_logger
+from casual_mcp.models.model_config import McpModelConfig
+
+logger = get_logger("providers.factory")
+
+
+class ProviderFactory:
+    PROVIDER_MAP = {
+        "openai": Provider.OPENAI,
+        "ollama": Provider.OLLAMA,
+    }
+
+    def __init__(self) -> None:
+        self.providers: dict[str, LLMProvider] = {}
+
+    def get_provider(self, name: str, config: McpModelConfig) -> LLMProvider:
+        existing = self.providers.get(name)
+        if existing:
+            return existing
+
+        provider = self.PROVIDER_MAP.get(config.provider)
+        if provider is None:
+            raise ValueError(f"Unknown provider: {config.provider}")
+
+        # Use casual-llm create provider
+        api_key = os.getenv("OPENAI_API_KEY") if provider == Provider.OPENAI else None
+        llm_provider = create_provider(
+            ModelConfig(
+                provider=provider,
+                name=config.model,
+                base_url=config.endpoint,
+                api_key=api_key,
+            )
+        )
+
+        # add to providers and return
+        self.providers[name] = llm_provider
+        return llm_provider