gnosys-strata 1.1.4 (gnosys_strata-1.1.4-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gnosys_strata-1.1.4.dist-info/METADATA +140 -0
- gnosys_strata-1.1.4.dist-info/RECORD +28 -0
- gnosys_strata-1.1.4.dist-info/WHEEL +4 -0
- gnosys_strata-1.1.4.dist-info/entry_points.txt +2 -0
- strata/__init__.py +6 -0
- strata/__main__.py +6 -0
- strata/cli.py +364 -0
- strata/config.py +310 -0
- strata/logging_config.py +109 -0
- strata/main.py +6 -0
- strata/mcp_client_manager.py +282 -0
- strata/mcp_proxy/__init__.py +7 -0
- strata/mcp_proxy/auth_provider.py +200 -0
- strata/mcp_proxy/client.py +162 -0
- strata/mcp_proxy/transport/__init__.py +7 -0
- strata/mcp_proxy/transport/base.py +104 -0
- strata/mcp_proxy/transport/http.py +80 -0
- strata/mcp_proxy/transport/stdio.py +69 -0
- strata/server.py +216 -0
- strata/tools.py +714 -0
- strata/treeshell_functions.py +397 -0
- strata/utils/__init__.py +0 -0
- strata/utils/bm25_search.py +181 -0
- strata/utils/catalog.py +82 -0
- strata/utils/dict_utils.py +29 -0
- strata/utils/field_search.py +233 -0
- strata/utils/shared_search.py +202 -0
- strata/utils/tool_integration.py +269 -0
strata/treeshell_functions.py
ADDED

@@ -0,0 +1,397 @@
+"""Strata functions for TreeShell crystallization - clean function-per-tool architecture."""
+
+import logging
+import traceback
+from typing import List, Dict, Any, Optional
+
+from strata.mcp_client_manager import MCPClientManager
+from strata.utils.shared_search import UniversalToolSearcher
+
+logger = logging.getLogger(__name__)
+
+# Global client manager - instantiated at import time
+client_manager = MCPClientManager()
+
+
+async def discover_server_actions(user_query: str, server_names: List[str] = None) -> Dict[str, Any]:
+    """
+    **PREFERRED STARTING POINT**: Discover available actions from servers based on user query.
+
+    Args:
+        user_query: Natural language user query to filter results.
+        server_names: List of server names to discover actions from.
+    """
+    if not server_names:
+        server_names = list(client_manager.active_clients.keys())
+
+    discovery_result = {}
+    for server_name in server_names:
+        try:
+            client = client_manager.get_client(server_name)
+            tools = await client.list_tools()
+
+            if user_query and tools:
+                tools_map = {server_name: tools}
+                searcher = UniversalToolSearcher(tools_map)
+                search_results = searcher.search(user_query, max_results=50)
+
+                filtered_action_names = []
+                for result_item in search_results:
+                    for tool in tools:
+                        if tool["name"] == result_item["name"]:
+                            filtered_action_names.append(tool)
+                            break
+                discovery_result[server_name] = filtered_action_names
+            else:
+                discovery_result[server_name] = tools
+
+        except KeyError:
+            discovery_result[server_name] = {"error": f"Server '{server_name}' not found or not connected"}
+        except Exception as e:
+            logger.error(f"Error discovering tools for {server_name}: {e}\n{traceback.format_exc()}")
+            discovery_result[server_name] = {"error": str(e), "traceback": traceback.format_exc()}
+
+    return discovery_result
+
+
+async def get_action_details(server_name: str, action_name: str) -> Dict[str, Any]:
+    """
+    Get detailed information about a specific action.
+
+    Args:
+        server_name: The name of the server.
+        action_name: The name of the action/operation.
+    """
+    try:
+        client = client_manager.get_client(server_name)
+        tools = await client.list_tools()
+
+        tool = next((t for t in tools if t["name"] == action_name), None)
+
+        if tool:
+            return {
+                "name": tool["name"],
+                "description": tool.get("description"),
+                "inputSchema": tool.get("inputSchema"),
+            }
+        else:
+            return {"error": f"Action '{action_name}' not found on server '{server_name}'"}
+    except KeyError:
+        return {"error": f"Server '{server_name}' not found or not connected"}
+    except Exception as e:
+        logger.error(f"Error getting action details for {server_name}/{action_name}: {e}\n{traceback.format_exc()}")
+        return {"error": str(e), "traceback": traceback.format_exc()}
+
+
+async def execute_action(
+    server_name: str,
+    action_name: str,
+    path_params: Optional[str] = None,
+    query_params: Optional[str] = None,
+    body_schema: Optional[str] = "{}"
+) -> Dict[str, Any]:
+    """
+    Execute a specific action with the provided parameters.
+
+    Args:
+        server_name: The name of the server.
+        action_name: The name of the action/operation to execute.
+        path_params: JSON string containing path parameters.
+        query_params: JSON string containing query parameters.
+        body_schema: JSON string containing request body.
+    """
+    import json
+
+    # Check if server is connected (no JIT - explicit connect required)
+    if server_name not in client_manager.active_clients:
+        server_config = client_manager.server_list.get_server(server_name)
+        if server_config:
+            return {
+                "error": f"Server '{server_name}' is not connected",
+                "suggestion": f"Connect first with: manage_servers.exec {{\"connect\": \"{server_name}\"}}"
+            }
+        else:
+            return {"error": f"Server '{server_name}' not configured"}
+
+    try:
+        client = client_manager.get_client(server_name)
+
+        if not client.is_connected():
+            return {"error": f"Server '{server_name}' is not connected"}
+
+        action_params = {}
+        for param_name, param_value in [
+            ("path_params", path_params),
+            ("query_params", query_params),
+            ("body_schema", body_schema),
+        ]:
+            if param_value and param_value != "{}":
+                try:
+                    if isinstance(param_value, str):
+                        action_params.update(json.loads(param_value))
+                    else:
+                        action_params.update(param_value)
+                except json.JSONDecodeError as e:
+                    return {"error": f"Invalid JSON in {param_name}: {str(e)}"}
+
+        result = await client.call_tool(action_name, action_params)
+        return {"result": result}
+
+    except Exception as e:
+        logger.error(f"Execution failed for {server_name}/{action_name}: {e}\n{traceback.format_exc()}")
+        return {"error": f"Execution failed: {str(e)}", "traceback": traceback.format_exc()}
+
+
+async def manage_servers(
+    list_configured_mcps: bool = False,
+    list_sets: bool = False,
+    connect: Optional[str] = None,
+    connect_set: Optional[str] = None,
+    connect_set_exclusive: bool = False,
+    search_sets: Optional[str] = None,
+    upsert_set: Optional[Dict[str, Any]] = None,
+    delete_set: Optional[str] = None,
+    disconnect: Optional[str] = None,
+    disconnect_set: Optional[str] = None,
+    disconnect_all: bool = False,
+    populate_catalog: bool = False
+) -> str:
+    """
+    Manage MCP server connections and Sets.
+
+    Args:
+        list_configured_mcps: If true, lists all configured servers with their status.
+        list_sets: If true, lists all configured Sets and their servers.
+        connect: Name of the server to connect (turn on).
+        connect_set: Name of the Set to connect (turn on all servers in set).
+        connect_set_exclusive: If true with connect_set, disconnects all other servers first.
+        search_sets: Search set descriptions for matching sets.
+        upsert_set: Create or update a Set (dict with name, servers, description, include_sets).
+        delete_set: Name of the Set to delete.
+        disconnect: Name of the server to disconnect (turn off).
+        disconnect_set: Name of the Set to disconnect (turn off all servers in set).
+        disconnect_all: If true, disconnects all servers.
+        populate_catalog: If true, connects to all enabled servers, refreshes catalog cache, then disconnects.
+    """
+    import asyncio
+
+    results = []
+
+    if list_configured_mcps:
+        active = client_manager.list_active_servers()
+        configured = client_manager.server_list.list_servers()
+        lines = [f"{s.name}, {'on' if s.name in active else 'off'}" for s in configured]
+        results.append("\n".join(lines))
+
+    if list_sets:
+        sets = client_manager.server_list.list_sets()
+        lines = []
+        for name, data in sets.items():
+            desc = data.get('description', '')
+            servers = data.get('servers', [])
+            includes = data.get('include_sets', [])
+            line = f"{name}: {desc}" if desc else f"{name}:"
+            if servers:
+                line += f"\n servers: {', '.join(servers)}"
+            if includes:
+                line += f"\n includes: {', '.join(includes)}"
+            lines.append(line)
+        results.append("\n".join(lines) if lines else "No sets configured")
+
+    if search_sets:
+        sets = client_manager.server_list.list_sets()
+        query_lower = search_sets.lower()
+        matches = []
+        for name, data in sets.items():
+            desc = data.get('description', '')
+            if query_lower in name.lower() or query_lower in desc.lower():
+                servers = ', '.join(data.get('servers', []))
+                matches.append(f"{name}: {desc}\n {servers}" if desc else f"{name}:\n {servers}")
+        results.append("\n".join(matches) if matches else f"no sets matching '{search_sets}'")
+
+    if upsert_set:
+        try:
+            name = upsert_set.get("name")
+            servers = upsert_set.get("servers", [])
+            desc = upsert_set.get("description", "")
+            include_sets = upsert_set.get("include_sets")
+            if name and (servers or include_sets):
+                client_manager.server_list.add_set(name, servers, desc, include_sets)
+                results.append(f"set '{name}' saved")
+            else:
+                results.append("error: missing name, or need servers or include_sets")
+        except Exception as e:
+            logger.error(f"Error upserting set: {e}\n{traceback.format_exc()}")
+            results.append(f"error: {e}\n{traceback.format_exc()}")
+
+    if delete_set:
+        success = client_manager.server_list.remove_set(delete_set)
+        results.append(f"set '{delete_set}' deleted" if success else f"error: set '{delete_set}' not found")
+
+    if connect:
+        server_config = client_manager.server_list.get_server(connect)
+        if server_config:
+            asyncio.create_task(client_manager._connect_server(server_config))
+            results.append(f"{connect} starting")
+        else:
+            results.append(f"error: {connect} not configured")
+
+    if connect_set:
+        servers_in_set = client_manager.server_list.get_set(connect_set)
+        if servers_in_set:
+            disconnected = []
+            if connect_set_exclusive:
+                for srv in list(client_manager.active_clients.keys()):
+                    if srv not in servers_in_set:
+                        await client_manager._disconnect_server(srv)
+                        disconnected.append(srv)
+
+            statuses = []
+            for srv in servers_in_set:
+                server_config = client_manager.server_list.get_server(srv)
+                if server_config and srv not in client_manager.active_clients:
+                    asyncio.create_task(client_manager._connect_server(server_config))
+                    statuses.append(f"{srv}: starting")
+                elif srv in client_manager.active_clients:
+                    statuses.append(f"{srv}: on")
+                else:
+                    statuses.append(f"{srv}: not configured")
+            prefix = f"connect_set '{connect_set}' (exclusive):" if connect_set_exclusive else f"connect_set '{connect_set}':"
+            output = f"{prefix}\n" + "\n".join(statuses)
+            if disconnected:
+                output += f"\nstopped: {', '.join(disconnected)}"
+            results.append(output)
+        else:
+            results.append(f"error: set '{connect_set}' not found")
+
+    if disconnect:
+        await client_manager._disconnect_server(disconnect)
+        results.append(f"{disconnect} off")
+
+    if disconnect_set:
+        servers_in_set = client_manager.server_list.get_set(disconnect_set)
+        if servers_in_set:
+            for srv in servers_in_set:
+                await client_manager._disconnect_server(srv)
+            results.append(f"disconnect_set '{disconnect_set}': {len(servers_in_set)} stopped")
+        else:
+            results.append(f"error: set '{disconnect_set}' not found")
+
+    if disconnect_all:
+        await client_manager.disconnect_all()
+        results.append("all disconnected")
+
+    if populate_catalog:
+        enabled_servers = client_manager.server_list.list_servers(enabled_only=True)
+        already_cached = [s for s in enabled_servers if client_manager.catalog.get_tools(s.name)]
+        to_populate = [s for s in enabled_servers if not client_manager.catalog.get_tools(s.name)]
+
+        if not to_populate:
+            results.append(f"catalog: {len(already_cached)}/{len(enabled_servers)} cached, nothing to populate")
+        else:
+            indexed = []
+            for server in to_populate:
+                try:
+                    was_connected = server.name in client_manager.active_clients
+                    if not was_connected:
+                        await client_manager._connect_server(server)
+                    client = client_manager.get_client(server.name)
+                    tools = await client.list_tools()
+                    client_manager.catalog.update_server(server.name, tools)
+                    if not was_connected:
+                        await client_manager._disconnect_server(server.name)
+                    indexed.append(f"{server.name}: {len(tools)} tools")
+                except Exception as e:
+                    logger.error(f"Error populating catalog for {server.name}: {e}\n{traceback.format_exc()}")
+                    indexed.append(f"{server.name}: error - {e}")
+            results.append(f"catalog: indexed {len(to_populate)}, skipped {len(already_cached)}\n" + "\n".join(indexed))
+
+    return "\n".join(str(r) for r in results)
+
+
+async def search_mcp_catalog(query: str, max_results: int = 20) -> Dict[str, Any]:
+    """
+    Search for tools in the offline catalog and discover Sets/Collections.
+
+    Args:
+        query: Search query for tools or collections.
+        max_results: Maximum results to return. Default 20.
+    """
+    # Search tools
+    tool_results = client_manager.catalog.search(query, max_results)
+
+    # Annotate with current status
+    active_servers = client_manager.list_active_servers()
+    for r in tool_results:
+        r["current_status"] = "online" if r.get("category_name") in active_servers else "offline"
+
+    # Search sets
+    sets = client_manager.server_list.list_sets()
+    matching_sets = []
+    for set_name, set_data in sets.items():
+        description = set_data.get("description", "")
+        if (query.lower() in set_name.lower()) or (query.lower() in description.lower()):
+            matching_sets.append({
+                "type": "collection",
+                "name": set_name,
+                "description": description,
+                "servers": set_data.get("servers", []),
+                "status": "available"
+            })
+
+    return {"collections": matching_sets, "tools": tool_results}
+
+
+async def search_documentation(query: str, server_name: str, max_results: int = 10) -> List[Dict[str, Any]]:
+    """
+    Search for server action documentations by keyword matching.
+
+    Args:
+        query: Search keywords.
+        server_name: Name of the server to search within.
+        max_results: Number of results to return. Default: 10.
+    """
+    try:
+        client = client_manager.get_client(server_name)
+        tools = await client.list_tools()
+
+        tools_map = {server_name: tools if tools else []}
+        searcher = UniversalToolSearcher(tools_map)
+        return searcher.search(query, max_results=max_results)
+    except KeyError:
+        return [{"error": f"Server '{server_name}' not found or not connected"}]
+    except Exception as e:
+        logger.error(f"Error searching documentation for {server_name}: {e}\n{traceback.format_exc()}")
+        return [{"error": f"Error searching documentation: {str(e)}", "traceback": traceback.format_exc()}]
+
+
+async def handle_auth_failure(
+    server_name: str,
+    intention: str,
+    auth_data: Optional[Dict[str, Any]] = None
+) -> Dict[str, Any]:
+    """
+    Handle authentication failures that occur when executing actions.
+
+    Args:
+        server_name: The name of the server.
+        intention: Action to take for authentication ('get_auth_url' or 'save_auth_data').
+        auth_data: Authentication data when saving.
+    """
+    if intention == "get_auth_url":
+        return {
+            "server": server_name,
+            "message": f"Authentication required for server '{server_name}'",
+            "instructions": "Please provide authentication credentials",
+            "required_fields": {"token": "Authentication token or API key"},
+        }
+    elif intention == "save_auth_data":
+        if not auth_data:
+            return {"error": "auth_data is required when intention is 'save_auth_data'"}
+        return {
+            "server": server_name,
+            "status": "success",
+            "message": f"Authentication data saved for server '{server_name}'",
+        }
+    else:
+        return {"error": f"Invalid intention: '{intention}'"}
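Taken together, these functions form a connect -> discover -> inspect -> execute loop. Below is a minimal usage sketch of that flow, not part of the package; the server name "github" and the tool name "create_issue" are hypothetical placeholders, and the sleep only compensates for the fire-and-forget connect.

# Minimal sketch of the intended call flow. The server "github" and the
# tool "create_issue" are hypothetical; adjust to your configured servers.
import asyncio

from strata.treeshell_functions import (
    discover_server_actions,
    execute_action,
    get_action_details,
    manage_servers,
)


async def demo() -> None:
    # Turn the server on first: execute_action refuses to run against
    # servers that are not connected (no just-in-time connect).
    print(await manage_servers(connect="github"))
    await asyncio.sleep(1)  # connect runs via create_task; give it a moment

    # Discover candidate actions for a natural-language query.
    actions = await discover_server_actions("create an issue", ["github"])
    print(actions)

    # Inspect one action's input schema before calling it.
    details = await get_action_details("github", "create_issue")
    print(details.get("inputSchema"))

    # Parameters are passed as JSON strings, per the signature above.
    result = await execute_action(
        "github", "create_issue", body_schema='{"title": "Bug report"}'
    )
    print(result)


asyncio.run(demo())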
strata/utils/__init__.py
ADDED
File without changes

strata/utils/bm25_search.py
ADDED

@@ -0,0 +1,181 @@
+"""
+BM25+ based search utility with field-level scoring
+
+This implementation flattens document fields into separate documents for independent scoring,
+then aggregates scores by original document ID with field weights.
+
+Algorithm:
+1. Each field becomes a separate document: "original_id:field_key" -> field_value
+2. BM25 scores each field independently
+3. Final score = sum(field_score * field_weight) for all fields of same original_id
+
+Installation:
+    pip install "bm25s"
+    pip install PyStemmer
+"""
+
+from collections import defaultdict
+from typing import List, Tuple
+
+import bm25s
+import Stemmer
+
+
+class BM25SearchEngine:
+    """
+    Field-aware BM25+ search engine that scores each field independently
+    """
+
+    def __init__(self, use_stemmer: bool = True):
+        """
+        Initialize the BM25+ search engine
+
+        Args:
+            use_stemmer: Whether to use stemming for better search results
+        """
+        self.stemmer = Stemmer.Stemmer("english") if use_stemmer else None
+        self.retriever = None
+        # Maps flattened_doc_id -> (original_doc_id, field_key, field_weight)
+        self.corpus_metadata = None
+        # Maps original_doc_id -> [(field_key, weight), ...]
+        self.doc_field_weights = None
+
+    def build_index(self, documents: List[Tuple[List[Tuple[str, str, int]], str]]):
+        """
+        Build BM25 index from documents by flattening fields into separate documents
+
+        Args:
+            documents: List of (fields, doc_id) tuples
+                fields: List of (field_key, field_value, weight) tuples
+                doc_id: Document identifier string
+
+        Example:
+            documents = [
+                (
+                    [
+                        ("service", "projects", 30),
+                        ("operation", "create_project", 30),
+                        ("description", "Creates a new project", 20),
+                    ],
+                    "projects:create_project"
+                ),
+            ]
+
+        This creates separate BM25 documents:
+        - "projects:create_project:service" -> "projects"
+        - "projects:create_project:operation" -> "create_project"
+        - "projects:create_project:description" -> "Creates a new project"
+        """
+        corpus = []
+        self.corpus_metadata = []
+        self.doc_field_weights = defaultdict(list)
+
+        for fields, original_doc_id in documents:
+            for field_key, field_value, weight in fields:
+                if field_value and weight > 0:
+                    # Preprocess field value for better tokenization
+                    processed_value = self._preprocess_field_value(field_value.strip())
+                    corpus.append(processed_value)
+
+                    # Store metadata: flattened_id -> (original_id, field_key, weight)
+                    self.corpus_metadata.append((original_doc_id, field_key, weight))
+
+                    # Store field weights by original document
+                    self.doc_field_weights[original_doc_id].append((field_key, weight))
+
+        if not corpus:
+            raise ValueError("No documents to index")
+
+        # Tokenize corpus (each field value separately)
+        corpus_tokens = bm25s.tokenize(
+            corpus,
+            stopwords=[],  # Disable stopwords for better field matching
+            show_progress=False,
+        )
+
+        # Create and index BM25+ retriever
+        self.retriever = bm25s.BM25(method="bm25+")
+        self.retriever.index(corpus_tokens, show_progress=False)
+
+    def search(self, query: str, top_k: int = 10) -> List[Tuple[float, str]]:
+        """
+        Search indexed documents with field-level scoring and weighted aggregation
+
+        Args:
+            query: Search query string
+            top_k: Number of top results to return
+
+        Returns:
+            List of (score, doc_id) tuples sorted by score descending
+
+        Algorithm:
+            1. Search all flattened field documents
+            2. Group results by original document ID
+            3. Calculate weighted sum: score = sum(field_score * field_weight)
+            4. Return top_k results by final weighted score
+        """
+        if self.retriever is None or self.corpus_metadata is None:
+            raise ValueError("No documents indexed. Call build_index() first.")
+
+        # Tokenize query (matching build_index settings)
+        query_tokens = bm25s.tokenize(
+            query,
+            stopwords=[],  # Disable stopwords to match build_index
+            show_progress=False,
+        )
+
+        # Search all flattened documents to ensure complete field aggregation
+        # We need all matching fields for accurate document scoring
+        search_k = len(self.corpus_metadata)
+        doc_indices, scores = self.retriever.retrieve(
+            query_tokens, k=search_k, show_progress=False
+        )
+
+        # Aggregate scores by original document ID
+        doc_scores = defaultdict(float)
+
+        for i in range(doc_indices.shape[1]):
+            idx = doc_indices[0, i]
+            field_score = scores[0, i]
+
+            if idx < len(self.corpus_metadata):
+                original_doc_id, _, field_weight = self.corpus_metadata[idx]
+
+                # Add weighted field score to document total
+                weighted_score = float(field_score) * field_weight
+                doc_scores[original_doc_id] += weighted_score
+
+        # Sort by aggregated score and return top k
+        sorted_results = sorted(doc_scores.items(), key=lambda x: x[1], reverse=True)
+
+        # Return top_k results as (score, doc_id) tuples
+        return [(score, doc_id) for doc_id, score in sorted_results[:top_k]]
+
+    def _preprocess_field_value(self, value: str) -> str:
+        """
+        Preprocess field values to improve tokenization
+
+        Converts underscore_separated and camelCase text to space-separated words
+        for better BM25 matching.
+
+        Examples:
+            "create_project" -> "create project"
+            "getUserProjects" -> "get User Projects"
+            "/api/v1/projects" -> "/api/v1/projects"
+        """
+        import re
+
+        # Replace underscores with spaces
+        value = value.replace("_", " ")
+
+        # Replace hyphens with spaces
+        value = value.replace("-", " ")
+
+        # Split camelCase: insert space before uppercase letters
+        # But preserve existing spaces and special characters
+        value = re.sub(r"([a-z])([A-Z])", r"\1 \2", value)
+
+        # Clean up multiple spaces
+        value = re.sub(r"\s+", " ", value).strip()
+
+        return value
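A short usage sketch for the engine above, not part of the package, reusing the document shape from the build_index docstring; the second document and its field values are illustrative.

# Usage sketch for BM25SearchEngine; documents and weights are illustrative.
from strata.utils.bm25_search import BM25SearchEngine

engine = BM25SearchEngine(use_stemmer=True)
engine.build_index([
    (
        [
            ("service", "projects", 30),
            ("operation", "create_project", 30),
            ("description", "Creates a new project", 20),
        ],
        "projects:create_project",
    ),
    (
        [
            ("service", "users", 30),
            ("operation", "delete_user", 30),
            ("description", "Removes a user account", 20),
        ],
        "users:delete_user",
    ),
])

# Each hit is (weighted_score, doc_id); "create project" should rank
# projects:create_project first because two high-weight fields match.
for score, doc_id in engine.search("create project", top_k=2):
    print(f"{score:.2f}  {doc_id}")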
strata/utils/catalog.py
ADDED

@@ -0,0 +1,82 @@
+"""Persistent catalog for storing MCP tool definitions."""
+
+import json
+import logging
+import os
+from typing import Dict, List, Any, Optional
+
+from platformdirs import user_cache_dir
+
+from strata.utils.shared_search import UniversalToolSearcher
+
+logger = logging.getLogger(__name__)
+
+
+class ToolCatalog:
+    """Manages persistent storage and retrieval of tool definitions."""
+
+    def __init__(self, app_name: str = "strata"):
+        self.cache_dir = user_cache_dir(app_name)
+        self.catalog_file = os.path.join(self.cache_dir, "tool_catalog.json")
+        self._catalog: Dict[str, List[Any]] = {}
+        self._ensure_cache_dir()
+        self.load()
+
+    def _ensure_cache_dir(self):
+        """Ensure the cache directory exists."""
+        os.makedirs(self.cache_dir, exist_ok=True)
+
+    def load(self) -> None:
+        """Load the catalog from disk."""
+        if os.path.exists(self.catalog_file):
+            try:
+                with open(self.catalog_file, "r", encoding="utf-8") as f:
+                    self._catalog = json.load(f)
+                logger.info(f"Loaded tool catalog from {self.catalog_file}")
+            except Exception as e:
+                logger.error(f"Failed to load catalog: {e}")
+                self._catalog = {}
+        else:
+            self._catalog = {}
+
+    def save(self) -> None:
+        """Save the catalog to disk."""
+        try:
+            with open(self.catalog_file, "w", encoding="utf-8") as f:
+                json.dump(self._catalog, f, indent=2)
+            logger.info(f"Saved tool catalog to {self.catalog_file}")
+        except Exception as e:
+            logger.error(f"Failed to save catalog: {e}")
+
+    def update_server(self, server_name: str, tools: List[Any]) -> None:
+        """Update the tools for a specific server."""
+        self._catalog[server_name] = tools
+        self.save()
+
+    def get_tools(self, server_name: str) -> List[Any]:
+        """Get tools for a specific server from the catalog."""
+        return self._catalog.get(server_name, [])
+
+    def get_all_tools(self) -> Dict[str, List[Any]]:
+        """Get all tools in the catalog."""
+        return self._catalog
+
+    def search(self, query: str, max_results: int = 20) -> List[Dict[str, Any]]:
+        """Search for tools across the entire catalog."""
+        if not self._catalog:
+            return []
+
+        searcher = UniversalToolSearcher(self._catalog)
+        results = searcher.search(query, max_results=max_results)
+
+        # Add a flag to indicate these are catalog results
+        for result in results:
+            result["source"] = "catalog"
+
+        return results
+
+    def remove_server(self, server_name: str) -> None:
+        """Remove a server from the catalog."""
+        if server_name in self._catalog:
+            del self._catalog[server_name]
+            self.save()
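A usage sketch for ToolCatalog, not part of the package; the tool-dict shape ("name", "description") is an assumption based on how list_tools() results are consumed elsewhere in this release, and the "github" entry is a hypothetical example.

# Usage sketch for ToolCatalog; the tool-dict shape is an assumption.
from strata.utils.catalog import ToolCatalog

catalog = ToolCatalog(app_name="strata")

# update_server() persists immediately to tool_catalog.json in the
# platform-specific user cache directory.
catalog.update_server("github", [
    {"name": "create_issue", "description": "Create a new issue"},
])

print(catalog.get_tools("github"))

# Offline search across every cached server; each result is tagged
# with source="catalog".
for hit in catalog.search("issue", max_results=5):
    print(hit)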
strata/utils/dict_utils.py
ADDED

@@ -0,0 +1,29 @@
+from typing import Any, Dict, Optional
+
+
+def find_in_dict_case_insensitive(
+    name: str, dictionary: Dict[str, Any]
+) -> Optional[str]:
+    """Helper function to find name in dictionary using case-insensitive matching.
+
+    Args:
+        name: The name to search for
+        dictionary: Dictionary to search in
+
+    Returns:
+        The actual key from the dictionary if found, None otherwise
+    """
+    if not isinstance(name, str) or not dictionary:
+        return None
+
+    # First try exact match
+    if name in dictionary:
+        return name
+
+    # Then try case-insensitive match
+    name_lower = name.lower()
+    for key in dictionary.keys():
+        if isinstance(key, str) and key.lower() == name_lower:
+            return key
+
+    return None
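A usage sketch for the helper, not part of the package; the dictionary contents are illustrative.

# Usage sketch: resolve a user-typed server name against configured keys.
from strata.utils.dict_utils import find_in_dict_case_insensitive

servers = {"GitHub": 1, "slack": 2}

print(find_in_dict_case_insensitive("github", servers))  # -> "GitHub"
print(find_in_dict_case_insensitive("Slack", servers))   # -> "slack"
print(find_in_dict_case_insensitive("jira", servers))    # -> None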