honeymcp 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- honeymcp/__init__.py +34 -0
- honeymcp/cli.py +205 -0
- honeymcp/core/__init__.py +20 -0
- honeymcp/core/dynamic_ghost_tools.py +443 -0
- honeymcp/core/fingerprinter.py +273 -0
- honeymcp/core/ghost_tools.py +624 -0
- honeymcp/core/middleware.py +573 -0
- honeymcp/dashboard/__init__.py +0 -0
- honeymcp/dashboard/app.py +228 -0
- honeymcp/integrations/__init__.py +3 -0
- honeymcp/llm/__init__.py +6 -0
- honeymcp/llm/analyzers.py +278 -0
- honeymcp/llm/clients/__init__.py +102 -0
- honeymcp/llm/clients/provider_type.py +11 -0
- honeymcp/llm/prompts/__init__.py +81 -0
- honeymcp/llm/prompts/dynamic_ghost_tools.yaml +88 -0
- honeymcp/models/__init__.py +8 -0
- honeymcp/models/config.py +187 -0
- honeymcp/models/events.py +60 -0
- honeymcp/models/ghost_tool_spec.py +31 -0
- honeymcp/models/protection_mode.py +17 -0
- honeymcp/storage/__init__.py +5 -0
- honeymcp/storage/event_store.py +176 -0
- honeymcp-0.1.0.dist-info/METADATA +699 -0
- honeymcp-0.1.0.dist-info/RECORD +28 -0
- honeymcp-0.1.0.dist-info/WHEEL +4 -0
- honeymcp-0.1.0.dist-info/entry_points.txt +2 -0
- honeymcp-0.1.0.dist-info/licenses/LICENSE +17 -0
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
"""HoneyMCP Dashboard - Real-time attack visualization with Streamlit."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import sys
|
|
5
|
+
from datetime import date, datetime, timedelta
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import List
|
|
8
|
+
|
|
9
|
+
import streamlit as st
|
|
10
|
+
|
|
11
|
+
# Add parent directory to path for imports
|
|
12
|
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
|
13
|
+
|
|
14
|
+
# pylint: disable=wrong-import-position
|
|
15
|
+
from honeymcp.models.events import AttackFingerprint
|
|
16
|
+
from honeymcp.storage.event_store import list_events
|
|
17
|
+
|
|
18
|
+
# Page configuration: wide layout, honey-pot icon, sidebar opened on load
st.set_page_config(
    initial_sidebar_state="expanded",
    layout="wide",
    page_icon="🍯",
    page_title="HoneyMCP Dashboard",
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def load_events() -> List[AttackFingerprint]:
    """Fetch the stored attack events for the dashboard.

    The asynchronous ``list_events`` coroutine is driven to completion with
    ``asyncio.run``.

    Returns:
        The events returned by ``list_events``; an empty list when loading
        raised, in which case the error is also shown via ``st.error``.
    """
    try:
        return asyncio.run(list_events())
    except Exception as exc:
        # Surface the failure in the UI instead of crashing the page.
        st.error(f"Failed to load events: {exc}")
    return []
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def get_threat_emoji(threat_level: str) -> str:
    """Return the coloured-circle emoji that represents a threat level.

    Args:
        threat_level: Threat level name; matched case-insensitively.

    Returns:
        The emoji for ``critical``, ``high``, ``medium`` or ``low``; a white
        circle for any other value.
    """
    level = threat_level.lower()
    if level == "critical":
        return "🔴"
    if level == "high":
        return "🟠"
    if level == "medium":
        return "🟡"
    if level == "low":
        return "🟢"
    return "⚪"
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def format_timestamp(dt: datetime) -> str:
    """Render a datetime as ``YYYY-MM-DD HH:MM:SS`` for display.

    Args:
        dt: The datetime to render.

    Returns:
        The formatted timestamp, truncated to whole seconds.
    """
    return f"{dt:%Y-%m-%d %H:%M:%S}"
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _render_sidebar_filters(events: List[AttackFingerprint]):
    """Draw the sidebar filter widgets and return the current selections.

    Args:
        events: All loaded events; used to derive the selectable date bounds
            and the list of attack categories.

    Returns:
        ``(date_range, threat_filter, category_filter)`` as returned by the
        corresponding Streamlit widgets.
    """
    st.sidebar.header("Filters")

    # Date range filter: bounded by the oldest/newest event, or by the last
    # seven days when there is nothing to show yet.
    if events:
        min_date = min(e.timestamp for e in events).date()
        max_date = max(e.timestamp for e in events).date()
    else:
        max_date = date.today()
        min_date = max_date - timedelta(days=7)

    date_range = st.sidebar.date_input(
        "Date Range",
        value=(min_date, max_date),
        min_value=min_date,
        max_value=max_date,
    )

    # Threat level filter
    threat_filter = st.sidebar.selectbox(
        "Threat Level",
        ["All", "Critical", "High", "Medium", "Low"],
    )

    # Attack category filter
    categories = sorted({e.attack_category for e in events})
    category_filter = st.sidebar.selectbox(
        "Attack Category",
        ["All"] + categories,
    )

    return date_range, threat_filter, category_filter


def _filter_events(
    events: List[AttackFingerprint],
    date_range,
    threat_filter: str,
    category_filter: str,
) -> List[AttackFingerprint]:
    """Return a new list holding the events that match every active filter."""
    filtered = list(events)

    # Only filter on a complete (start, end) pair; the widget may hand back
    # fewer than two dates while the user is still picking the range.
    if isinstance(date_range, (tuple, list)) and len(date_range) == 2:
        start, end = date_range
        filtered = [e for e in filtered if start <= e.timestamp.date() <= end]

    if threat_filter != "All":
        wanted_level = threat_filter.lower()
        filtered = [e for e in filtered if e.threat_level.lower() == wanted_level]

    if category_filter != "All":
        filtered = [e for e in filtered if e.attack_category == category_filter]

    return filtered


def _render_metrics(events: List[AttackFingerprint]) -> None:
    """Render the four headline metrics, computed over all (unfiltered) events."""
    st.header("📊 Attack Metrics")
    col1, col2, col3, col4 = st.columns(4)

    with col1:
        # NOTE(review): assumes event timestamps are naive UTC datetimes; an
        # aware timestamp would make this subtraction raise TypeError -
        # confirm against AttackFingerprint.
        now = datetime.utcnow()
        today_attacks = sum(1 for e in events if (now - e.timestamp).days < 1)
        st.metric(
            "Total Attacks",
            len(events),
            delta=f"+{today_attacks} today",
        )

    with col2:
        # Compare case-insensitively, like the threat filter and emoji lookup.
        critical_count = sum(1 for e in events if e.threat_level.lower() == "critical")
        st.metric("Critical Threats", critical_count)

    with col3:
        st.metric("Unique Ghost Tools", len({e.ghost_tool_called for e in events}))

    with col4:
        st.metric("Unique Sessions", len({e.session_id for e in events}))


def _render_breakdown(events: List[AttackFingerprint]) -> None:
    """Render per-threat-level and per-category bar charts (skipped when empty)."""
    if not events:
        return

    st.header("🎯 Attack Breakdown")
    col1, col2 = st.columns(2)

    with col1:
        st.subheader("By Threat Level")
        threat_counts = {}
        for e in events:
            threat_counts[e.threat_level] = threat_counts.get(e.threat_level, 0) + 1
        st.bar_chart(threat_counts)

    with col2:
        st.subheader("By Category")
        category_counts = {}
        for e in events:
            category_counts[e.attack_category] = category_counts.get(e.attack_category, 0) + 1
        st.bar_chart(category_counts)

    st.markdown("---")


def _render_event(event: AttackFingerprint) -> None:
    """Render one attack event as a collapsible entry in the feed."""
    threat_emoji = get_threat_emoji(event.threat_level)

    # Expander header with key info
    header = (
        f"{threat_emoji} **{event.ghost_tool_called}** | "
        f"{format_timestamp(event.timestamp)} | "
        f"Session: {event.session_id[:8]}... | "
        f"Threat: {event.threat_level.upper()}"
    )

    with st.expander(header):
        col1, col2 = st.columns(2)

        with col1:
            st.markdown("**Event Details**")
            st.text(f"Event ID: {event.event_id}")
            st.text(f"Timestamp: {format_timestamp(event.timestamp)}")
            st.text(f"Session ID: {event.session_id}")
            st.text(f"Threat Level: {event.threat_level}")
            st.text(f"Category: {event.attack_category}")

        with col2:
            st.markdown("**Tool Call Sequence**")
            for i, tool in enumerate(event.tool_call_sequence, 1):
                if tool == event.ghost_tool_called:
                    # Highlight the honeypot tool that triggered the event.
                    st.markdown(f"{i}. **{tool}** ⚠️ (honeypot)")
                else:
                    st.text(f"{i}. {tool}")

        # Arguments
        if event.arguments:
            st.markdown("**Arguments Passed**")
            st.json(event.arguments)

        # Response sent
        st.markdown("**Fake Response Sent to Attacker**")
        st.code(event.response_sent, language="text")

        # Full event data
        # NOTE(review): some Streamlit releases reject an expander nested
        # inside another expander - confirm the minimum supported version.
        with st.expander("View Full Event JSON"):
            st.json(event.model_dump(mode="json"))


def main():
    """Main dashboard application.

    Loads the stored attack events, renders the sidebar filters, headline
    metrics and breakdown charts (all computed over every event), then the
    event feed restricted to the events matching the sidebar filters.
    """
    # Header
    st.title("🍯 HoneyMCP Dashboard")
    st.markdown("**Real-time AI Agent Attack Detection & Intelligence**")
    st.markdown("---")

    # Load events and apply the sidebar filters
    events = load_events()
    date_range, threat_filter, category_filter = _render_sidebar_filters(events)
    filtered_events = _filter_events(events, date_range, threat_filter, category_filter)

    # Metrics row
    _render_metrics(events)
    st.markdown("---")

    # Attack breakdown
    _render_breakdown(events)

    # Event feed
    st.header("🚨 Recent Attacks")

    if not events:
        st.info("No attacks detected yet. Ghost tools are active and monitoring.")
    elif not filtered_events:
        # Events exist but the sidebar filters excluded all of them.
        st.info("No attacks match the current filters.")
    else:
        # Newest first; sorted() leaves the loaded list untouched.
        for event in sorted(filtered_events, key=lambda e: e.timestamp, reverse=True):
            _render_event(event)

    # Footer
    st.markdown("---")
    st.markdown("🍯 **HoneyMCP** - Deception Middleware for AI Agents")

    # Manual refresh button
    if st.button("🔄 Refresh", key="refresh_btn"):
        st.rerun()

    # Refresh hint
    st.sidebar.markdown("---")
    st.sidebar.info("💡 Click 'Refresh' to reload events")
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
if __name__ == "__main__":
|
|
228
|
+
main()
|
honeymcp/llm/__init__.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
1
|
+
"""Tool and server analysis utilities."""
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
import logging
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Any, Dict, List, Optional
|
|
7
|
+
|
|
8
|
+
from fastmcp import FastMCP
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class ToolInfo:
    """Metadata describing one tool extracted from an MCP server.

    Attributes:
        name: Tool name.
        description: Tool description.
        parameters: JSON schema for the tool parameters.
        category: Optional category classification; ``None`` until assigned.
    """

    name: str
    description: str
    parameters: Dict[str, Any]
    category: Optional[str] = None
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _schema_from_signature(fn: Any) -> Dict[str, Any]:
    """Build a minimal JSON schema from a callable's signature.

    Returns an empty dict when the callable takes no parameters (other than
    ``self``/``cls``).
    """
    # Checked in order; any other (or missing) annotation maps to "string".
    json_types = (
        (int, "integer"),
        (float, "number"),
        (bool, "boolean"),
        (list, "array"),
        (dict, "object"),
    )
    properties = {}
    required = []

    for param_name, param in inspect.signature(fn).parameters.items():
        if param_name in ("self", "cls"):
            continue

        param_type = next(
            (json_name for py_type, json_name in json_types if param.annotation == py_type),
            "string",
        )
        properties[param_name] = {
            "type": param_type,
            "description": f"Parameter {param_name}",
        }

        # Identity check: the sentinel must not be confused with a default
        # value that merely compares equal to it (or overrides ``==``).
        if param.default is inspect.Parameter.empty:
            required.append(param_name)

    if not properties:
        return {}
    return {
        "type": "object",
        "properties": properties,
        "required": required,
    }


def _tool_info_from_listing(tool: Any) -> ToolInfo:
    """Convert one ``list_tools`` entry (dict or object) into a ToolInfo."""
    if isinstance(tool, dict):
        name = tool.get("name", "unknown")
        description = tool.get("description")
        parameters = tool.get("inputSchema", {})
    else:
        # Handle FunctionTool or similar objects
        name = getattr(tool, "name", "unknown")
        description = getattr(tool, "description", None)
        parameters = getattr(tool, "inputSchema", {})
        if not parameters and hasattr(tool, "parameters"):
            parameters = tool.parameters

    # A present-but-None description/schema falls back to the same defaults
    # as a missing one, so ToolInfo always carries a str and a dict.
    return ToolInfo(
        name=name,
        description=description or "No description",
        parameters=parameters or {},
        category=None,
    )


async def extract_tool_info(  # pylint: disable=too-many-branches,protected-access
    server: FastMCP,
) -> List[ToolInfo]:
    """Extract tool information from a FastMCP server.

    Three strategies are tried in order, for compatibility with different
    FastMCP versions: the ``list_tools`` method, the internal ``_tools``
    mapping, and the internal ``_docket.tools`` mapping. The first strategy
    that yields at least one tool wins. A strategy that raises is logged at
    warning level and its partial results are discarded before the next
    strategy runs.

    Args:
        server: FastMCP server instance

    Returns:
        List of ToolInfo objects containing tool metadata

    Raises:
        ValueError: If no tools can be extracted from the server
    """
    # Method 1: Try using the public list_tools method if available
    if hasattr(server, "list_tools"):
        try:
            tools = [_tool_info_from_listing(tool) for tool in await server.list_tools()]
            if tools:
                logger.info("Extracted %s tools using list_tools method", len(tools))
                return tools
        except Exception as e:
            logger.warning("Failed to extract tools using list_tools: %s", e)

    # Method 2: Try accessing internal _tools dictionary
    if hasattr(server, "_tools"):
        try:
            tools = []
            for tool_name, tool_obj in server._tools.items():
                description = "No description"
                parameters = {}

                # Extract description
                if hasattr(tool_obj, "description"):
                    description = tool_obj.description or "No description"
                elif hasattr(tool_obj, "__doc__") and tool_obj.__doc__:
                    description = tool_obj.__doc__.strip()

                # Extract parameters from function signature if available
                if hasattr(tool_obj, "fn"):
                    parameters = _schema_from_signature(tool_obj.fn)

                tools.append(
                    ToolInfo(
                        name=tool_name,
                        description=description,
                        parameters=parameters,
                        category=None,
                    )
                )

            if tools:
                logger.info("Extracted %s tools using _tools dictionary", len(tools))
                return tools
        except Exception as e:
            logger.warning("Failed to extract tools using _tools: %s", e)

    # Method 3: Try accessing internal docket
    if hasattr(server, "_docket") and hasattr(server._docket, "tools"):
        try:
            tools = []
            for tool_name, tool_obj in server._docket.tools.items():
                description = "No description"
                parameters = {}

                if hasattr(tool_obj, "description"):
                    description = tool_obj.description or "No description"

                if hasattr(tool_obj, "parameters"):
                    parameters = tool_obj.parameters or {}
                elif hasattr(tool_obj, "input_schema"):
                    parameters = tool_obj.input_schema or {}

                tools.append(
                    ToolInfo(
                        name=tool_name,
                        description=description,
                        parameters=parameters,
                        category=None,
                    )
                )

            if tools:
                logger.info("Extracted %s tools using _docket", len(tools))
                return tools
        except Exception as e:
            logger.warning("Failed to extract tools using _docket: %s", e)

    # Every strategy was unavailable, failed, or produced nothing.
    raise ValueError(
        "Could not extract tools from FastMCP server. "
        "The server may not have any tools registered, or the FastMCP version "
        "may not be compatible with the extraction methods."
    )
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def categorize_tools(tools: List[ToolInfo]) -> Dict[str, List[ToolInfo]]:
    """Group tools into coarse categories by keyword matching.

    Each tool's lower-cased ``"<name> <description>"`` text is searched for
    the keywords of every category in a fixed order; the first category with
    any keyword occurring as a substring wins, and tools matching nothing go
    to ``"other"``. The chosen category is also written to ``tool.category``.

    Args:
        tools: List of ToolInfo objects

    Returns:
        Dictionary mapping category names to lists of tools; categories that
        received no tool are omitted.
    """
    # Ordered (category, keywords) table - earlier rows take precedence.
    keyword_table = (
        (
            "file_system",
            ("file", "read", "write", "directory", "path", "folder", "upload", "download"),
        ),
        (
            "database",
            ("database", "query", "sql", "table", "record", "insert", "update", "delete"),
        ),
        ("api", ("api", "request", "response", "endpoint", "http", "rest", "graphql")),
        (
            "security",
            ("auth", "token", "credential", "password", "key", "secret", "permission"),
        ),
        (
            "development",
            ("build", "deploy", "test", "debug", "compile", "run", "execute"),
        ),
        ("communication", ("send", "message", "email", "notify", "alert", "webhook")),
        (
            "data_processing",
            ("process", "transform", "parse", "convert", "analyze", "calculate"),
        ),
    )

    # Pre-create every bucket so the result keeps the fixed category order.
    buckets = {label: [] for label, _ in keyword_table}
    buckets["other"] = []

    for tool in tools:
        haystack = f"{tool.name} {tool.description}".lower()
        label = next(
            (
                candidate
                for candidate, words in keyword_table
                if any(word in haystack for word in words)
            ),
            "other",
        )
        buckets[label].append(tool)
        tool.category = label

    # Remove empty categories
    return {label: members for label, members in buckets.items() if members}
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
"""LLM client module for different providers."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Dict, Optional, Any
|
|
6
|
+
from dotenv import load_dotenv
|
|
7
|
+
from honeymcp.llm.clients.provider_type import LLMProviderType
|
|
8
|
+
|
|
9
|
+
# Prefer a HoneyMCP-specific ".env.honeymcp" in the current working directory
# so HoneyMCP settings do not interfere with the host project's ".env"; the
# default ".env" lookup is used only when that file does not exist.
_honeymcp_env = Path.cwd().joinpath(".env.honeymcp")
if not _honeymcp_env.exists():
    load_dotenv()  # Fall back to .env
else:
    load_dotenv(_honeymcp_env)

# Provider selected through the LLM_PROVIDER environment variable (WatsonX
# when unset).
LLM_PROVIDER = LLMProviderType(os.environ.get("LLM_PROVIDER", LLMProviderType.WATSONX.value))
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _get_base_llm_settings(model_name: str, model_parameters: Optional[Dict]) -> Dict:
    """Build the client constructor keyword arguments for the active provider.

    Credentials and endpoints are read from environment variables; generation
    parameters come from ``model_parameters`` with per-provider defaults.

    Args:
        model_name: Model identifier passed through to the provider.
        model_parameters: Optional overrides (``max_tokens``, ``temperature``,
            ``top_p``, ``stop_sequences``, ...); ``None`` means all defaults.

    Returns:
        Keyword arguments for the provider's chat client.

    Raises:
        ValueError: If ``LLM_PROVIDER`` is not a handled provider, or the RITS
            provider is selected without ``RITS_API_BASE_URL`` being set.
    """
    if model_parameters is None:
        model_parameters = {}

    if LLM_PROVIDER == LLMProviderType.WATSONX:
        parameters = {
            "max_new_tokens": model_parameters.get("max_tokens", 100),
            "decoding_method": model_parameters.get("decoding_method", "greedy"),
            "temperature": model_parameters.get("temperature", 0.9),
            "repetition_penalty": model_parameters.get("repetition_penalty", 1.0),
            "top_k": model_parameters.get("top_k", 50),
            "top_p": model_parameters.get("top_p", 1.0),
            "stop_sequences": model_parameters.get("stop_sequences", []),
        }

        return {
            "url": os.getenv("WATSONX_API_ENDPOINT"),
            "project_id": os.getenv("WATSONX_PROJECT_ID"),
            "apikey": os.getenv("WATSONX_API_KEY"),
            "model_id": model_name,
            "params": parameters,
        }

    if LLM_PROVIDER == LLMProviderType.OPENAI:
        parameters = {
            "max_tokens": model_parameters.get("max_tokens", 100),
            "temperature": model_parameters.get("temperature", 0),
            "stop": model_parameters.get("stop_sequences", []),
        }
        return {
            "api_key": os.getenv("OPENAI_API_KEY"),
            "model": model_name,
            **parameters,
        }

    if LLM_PROVIDER == LLMProviderType.RITS:
        rits_base_url = os.getenv("RITS_API_BASE_URL")
        if not rits_base_url:
            # Without this check the base URL would silently become "None/v1".
            raise ValueError("RITS_API_BASE_URL must be set when using the RITS LLM provider")

        parameters = {
            "max_tokens": model_parameters.get("max_tokens", 100),
            "temperature": model_parameters.get("temperature", 0.9),
            "top_p": model_parameters.get("top_p", 1.0),
            "stop": model_parameters.get("stop_sequences", []),
        }

        return {
            # Tolerate a trailing slash in the configured base URL.
            "base_url": f"{rits_base_url.rstrip('/')}/v1",
            "model": model_name,
            "api_key": os.getenv("RITS_API_KEY"),
            "extra_body": parameters,
        }

    raise ValueError(f"Incorrect LLM provider: {LLM_PROVIDER}")
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def get_chat_llm_client(
    model_name: str = "rits/openai/gpt-oss-120b",
    model_parameters: Optional[Dict] = None,
) -> Any:
    """Create the chat client that matches the configured LLM provider.

    The provider SDK is imported lazily, so only the package for the active
    provider has to be installed.

    Args:
        model_name: The name of the model to use.
        model_parameters: Optional model parameters.

    Returns:
        A ``ChatWatsonx`` instance for WatsonX, or a ``ChatOpenAI`` instance
        for OpenAI and RITS.

    Raises:
        ValueError: If the configured provider is not supported.
    """
    # Pick the client class first; the settings are built afterwards so the
    # SDK import still happens before any settings lookup.
    if LLM_PROVIDER == LLMProviderType.WATSONX:
        from langchain_ibm import ChatWatsonx as client_cls  # pylint: disable=import-outside-toplevel
    elif LLM_PROVIDER in (LLMProviderType.OPENAI, LLMProviderType.RITS):
        from langchain_openai import ChatOpenAI as client_cls  # pylint: disable=import-outside-toplevel
    else:
        raise ValueError(f"Unsupported LLM provider: {LLM_PROVIDER}")

    settings = _get_base_llm_settings(model_name=model_name, model_parameters=model_parameters)
    return client_cls(**settings)
|