gnosys-strata 1.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,397 @@
1
+ """Strata functions for TreeShell crystallization - clean function-per-tool architecture."""
2
+
3
+ import logging
4
+ import traceback
5
+ from typing import List, Dict, Any, Optional
6
+
7
+ from strata.mcp_client_manager import MCPClientManager
8
+ from strata.utils.shared_search import UniversalToolSearcher
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ # Global client manager - instantiated at import time
13
+ client_manager = MCPClientManager()
14
+
15
+
16
async def discover_server_actions(user_query: str, server_names: Optional[List[str]] = None) -> Dict[str, Any]:
    """
    **PREFERRED STARTING POINT**: Discover available actions from servers based on user query.

    Args:
        user_query: Natural language user query to filter results.
        server_names: List of server names to discover actions from. Defaults to
            all currently connected servers when omitted.

    Returns:
        Mapping of server name to either its (optionally query-filtered) tool
        list, or an ``{"error": ...}`` dict when discovery failed for that server.
    """
    if not server_names:
        # Fall back to every currently connected server.
        server_names = list(client_manager.active_clients.keys())

    discovery_result: Dict[str, Any] = {}
    for server_name in server_names:
        try:
            client = client_manager.get_client(server_name)
            tools = await client.list_tools()

            if user_query and tools:
                # Rank this server's tools against the query; search results only
                # carry names, so map ranked names back to full tool definitions.
                searcher = UniversalToolSearcher({server_name: tools})
                search_results = searcher.search(user_query, max_results=50)

                # First-wins name map (matches the original first-match scan),
                # avoiding an O(results * tools) nested loop.
                tools_by_name: Dict[str, Any] = {}
                for tool in tools:
                    tools_by_name.setdefault(tool["name"], tool)

                discovery_result[server_name] = [
                    tools_by_name[item["name"]]
                    for item in search_results
                    if item["name"] in tools_by_name
                ]
            else:
                # No query (or no tools): return the unfiltered tool list.
                discovery_result[server_name] = tools

        except KeyError:
            discovery_result[server_name] = {"error": f"Server '{server_name}' not found or not connected"}
        except Exception as e:
            logger.error(f"Error discovering tools for {server_name}: {e}\n{traceback.format_exc()}")
            discovery_result[server_name] = {"error": str(e), "traceback": traceback.format_exc()}

    return discovery_result
55
+
56
+
57
async def get_action_details(server_name: str, action_name: str) -> Dict[str, Any]:
    """
    Get detailed information about a specific action.

    Args:
        server_name: The name of the server.
        action_name: The name of the action/operation.

    Returns:
        The action's name, description and input schema, or an error dict when
        the server or action cannot be found.
    """
    try:
        client = client_manager.get_client(server_name)
        available_tools = await client.list_tools()

        # Scan for the first tool whose name matches the requested action.
        matched = None
        for candidate in available_tools:
            if candidate["name"] == action_name:
                matched = candidate
                break

        if matched is None:
            return {"error": f"Action '{action_name}' not found on server '{server_name}'"}

        return {
            "name": matched["name"],
            "description": matched.get("description"),
            "inputSchema": matched.get("inputSchema"),
        }
    except KeyError:
        return {"error": f"Server '{server_name}' not found or not connected"}
    except Exception as e:
        logger.error(f"Error getting action details for {server_name}/{action_name}: {e}\n{traceback.format_exc()}")
        return {"error": str(e), "traceback": traceback.format_exc()}
84
+
85
+
86
async def execute_action(
    server_name: str,
    action_name: str,
    path_params: Optional[str] = None,
    query_params: Optional[str] = None,
    body_schema: Optional[str] = "{}"
) -> Dict[str, Any]:
    """
    Execute a specific action with the provided parameters.

    Args:
        server_name: The name of the server.
        action_name: The name of the action/operation to execute.
        path_params: JSON string containing path parameters.
        query_params: JSON string containing query parameters.
        body_schema: JSON string containing request body.

    Returns:
        ``{"result": ...}`` on success, otherwise an error dict.
    """
    import json

    # No just-in-time connection here: the caller must connect explicitly.
    if server_name not in client_manager.active_clients:
        if client_manager.server_list.get_server(server_name):
            return {
                "error": f"Server '{server_name}' is not connected",
                "suggestion": f"Connect first with: manage_servers.exec {{\"connect\": \"{server_name}\"}}"
            }
        return {"error": f"Server '{server_name}' not configured"}

    try:
        client = client_manager.get_client(server_name)
        if not client.is_connected():
            return {"error": f"Server '{server_name}' is not connected"}

        # Merge the three JSON parameter payloads into one flat argument dict.
        merged_params: Dict[str, Any] = {}
        raw_inputs = (
            ("path_params", path_params),
            ("query_params", query_params),
            ("body_schema", body_schema),
        )
        for param_name, raw_value in raw_inputs:
            if not raw_value or raw_value == "{}":
                continue  # skip empty / default payloads
            try:
                # Tolerate pre-parsed dicts despite the str annotation.
                decoded = json.loads(raw_value) if isinstance(raw_value, str) else raw_value
                merged_params.update(decoded)
            except json.JSONDecodeError as e:
                return {"error": f"Invalid JSON in {param_name}: {str(e)}"}

        result = await client.call_tool(action_name, merged_params)
        return {"result": result}

    except Exception as e:
        logger.error(f"Execution failed for {server_name}/{action_name}: {e}\n{traceback.format_exc()}")
        return {"error": f"Execution failed: {str(e)}", "traceback": traceback.format_exc()}
143
+
144
+
145
async def manage_servers(
    list_configured_mcps: bool = False,
    list_sets: bool = False,
    connect: Optional[str] = None,
    connect_set: Optional[str] = None,
    connect_set_exclusive: bool = False,
    search_sets: Optional[str] = None,
    upsert_set: Optional[Dict[str, Any]] = None,
    delete_set: Optional[str] = None,
    disconnect: Optional[str] = None,
    disconnect_set: Optional[str] = None,
    disconnect_all: bool = False,
    populate_catalog: bool = False
) -> str:
    """
    Manage MCP server connections and Sets.

    Multiple operations may be requested in one call; each appends its own
    section to the returned text, joined by newlines.

    Args:
        list_configured_mcps: If true, lists all configured servers with their status.
        list_sets: If true, lists all configured Sets and their servers.
        connect: Name of the server to connect (turn on).
        connect_set: Name of the Set to connect (turn on all servers in set).
        connect_set_exclusive: If true with connect_set, disconnects all other servers first.
        search_sets: Search set descriptions for matching sets.
        upsert_set: Create or update a Set (dict with name, servers, description, include_sets).
        delete_set: Name of the Set to delete.
        disconnect: Name of the server to disconnect (turn off).
        disconnect_set: Name of the Set to disconnect (turn off all servers in set).
        disconnect_all: If true, disconnects all servers.
        populate_catalog: If true, connects to all enabled servers, refreshes catalog cache, then disconnects.

    Returns:
        Human-readable status text, one section per requested operation.
    """
    import asyncio

    # One status string per requested operation, in fixed order below.
    results = []

    if list_configured_mcps:
        # "name, on|off" per configured server, based on the active-client set.
        active = client_manager.list_active_servers()
        configured = client_manager.server_list.list_servers()
        lines = [f"{s.name}, {'on' if s.name in active else 'off'}" for s in configured]
        results.append("\n".join(lines))

    if list_sets:
        # Render each set as "name: description" plus indented server/include lists.
        sets = client_manager.server_list.list_sets()
        lines = []
        for name, data in sets.items():
            desc = data.get('description', '')
            servers = data.get('servers', [])
            includes = data.get('include_sets', [])
            line = f"{name}: {desc}" if desc else f"{name}:"
            if servers:
                line += f"\n servers: {', '.join(servers)}"
            if includes:
                line += f"\n includes: {', '.join(includes)}"
            lines.append(line)
        results.append("\n".join(lines) if lines else "No sets configured")

    if search_sets:
        # Case-insensitive substring match against set names and descriptions.
        sets = client_manager.server_list.list_sets()
        query_lower = search_sets.lower()
        matches = []
        for name, data in sets.items():
            desc = data.get('description', '')
            if query_lower in name.lower() or query_lower in desc.lower():
                servers = ', '.join(data.get('servers', []))
                matches.append(f"{name}: {desc}\n {servers}" if desc else f"{name}:\n {servers}")
        results.append("\n".join(matches) if matches else f"no sets matching '{search_sets}'")

    if upsert_set:
        # Create or replace a set; requires a name plus servers or include_sets.
        try:
            name = upsert_set.get("name")
            servers = upsert_set.get("servers", [])
            desc = upsert_set.get("description", "")
            include_sets = upsert_set.get("include_sets")
            if name and (servers or include_sets):
                client_manager.server_list.add_set(name, servers, desc, include_sets)
                results.append(f"set '{name}' saved")
            else:
                results.append("error: missing name, or need servers or include_sets")
        except Exception as e:
            logger.error(f"Error upserting set: {e}\n{traceback.format_exc()}")
            results.append(f"error: {e}\n{traceback.format_exc()}")

    if delete_set:
        success = client_manager.server_list.remove_set(delete_set)
        results.append(f"set '{delete_set}' deleted" if success else f"error: set '{delete_set}' not found")

    if connect:
        # Fire-and-forget: the connection proceeds in the background while
        # this call returns "starting" immediately.
        server_config = client_manager.server_list.get_server(connect)
        if server_config:
            asyncio.create_task(client_manager._connect_server(server_config))
            results.append(f"{connect} starting")
        else:
            results.append(f"error: {connect} not configured")

    if connect_set:
        servers_in_set = client_manager.server_list.get_set(connect_set)
        if servers_in_set:
            disconnected = []
            if connect_set_exclusive:
                # Exclusive mode: tear down every active server that is not
                # part of the requested set before connecting its members.
                # NOTE: iterates a snapshot of the keys since disconnecting
                # mutates active_clients.
                for srv in list(client_manager.active_clients.keys()):
                    if srv not in servers_in_set:
                        await client_manager._disconnect_server(srv)
                        disconnected.append(srv)

            # Connect set members in the background; report per-server status.
            statuses = []
            for srv in servers_in_set:
                server_config = client_manager.server_list.get_server(srv)
                if server_config and srv not in client_manager.active_clients:
                    asyncio.create_task(client_manager._connect_server(server_config))
                    statuses.append(f"{srv}: starting")
                elif srv in client_manager.active_clients:
                    statuses.append(f"{srv}: on")
                else:
                    statuses.append(f"{srv}: not configured")
            prefix = f"connect_set '{connect_set}' (exclusive):" if connect_set_exclusive else f"connect_set '{connect_set}':"
            output = f"{prefix}\n" + "\n".join(statuses)
            if disconnected:
                output += f"\nstopped: {', '.join(disconnected)}"
            results.append(output)
        else:
            results.append(f"error: set '{connect_set}' not found")

    if disconnect:
        # Disconnects are awaited (unlike connects, which are backgrounded).
        await client_manager._disconnect_server(disconnect)
        results.append(f"{disconnect} off")

    if disconnect_set:
        servers_in_set = client_manager.server_list.get_set(disconnect_set)
        if servers_in_set:
            for srv in servers_in_set:
                await client_manager._disconnect_server(srv)
            results.append(f"disconnect_set '{disconnect_set}': {len(servers_in_set)} stopped")
        else:
            results.append(f"error: set '{disconnect_set}' not found")

    if disconnect_all:
        await client_manager.disconnect_all()
        results.append("all disconnected")

    if populate_catalog:
        # Refresh the offline tool catalog for enabled servers whose tools are
        # not yet cached, temporarily connecting where needed.
        enabled_servers = client_manager.server_list.list_servers(enabled_only=True)
        already_cached = [s for s in enabled_servers if client_manager.catalog.get_tools(s.name)]
        to_populate = [s for s in enabled_servers if not client_manager.catalog.get_tools(s.name)]

        if not to_populate:
            results.append(f"catalog: {len(already_cached)}/{len(enabled_servers)} cached, nothing to populate")
        else:
            indexed = []
            for server in to_populate:
                try:
                    # Only disconnect afterwards if we made the connection here.
                    was_connected = server.name in client_manager.active_clients
                    if not was_connected:
                        await client_manager._connect_server(server)
                    client = client_manager.get_client(server.name)
                    tools = await client.list_tools()
                    client_manager.catalog.update_server(server.name, tools)
                    if not was_connected:
                        await client_manager._disconnect_server(server.name)
                    indexed.append(f"{server.name}: {len(tools)} tools")
                except Exception as e:
                    # Best-effort per server: record the failure and continue.
                    logger.error(f"Error populating catalog for {server.name}: {e}\n{traceback.format_exc()}")
                    indexed.append(f"{server.name}: error - {e}")
            results.append(f"catalog: indexed {len(to_populate)}, skipped {len(already_cached)}\n" + "\n".join(indexed))

    return "\n".join(str(r) for r in results)
310
+
311
+
312
async def search_mcp_catalog(query: str, max_results: int = 20) -> Dict[str, Any]:
    """
    Search for tools in the offline catalog and discover Sets/Collections.

    Args:
        query: Search query for tools or collections.
        max_results: Maximum results to return. Default 20.

    Returns:
        ``{"collections": [...], "tools": [...]}`` where each tool result is
        annotated with its server's current online/offline status.
    """
    # Search the persistent catalog (covers servers that are currently offline).
    tool_results = client_manager.catalog.search(query, max_results)

    # Annotate each hit with whether its server is currently connected.
    active_servers = client_manager.list_active_servers()
    for r in tool_results:
        r["current_status"] = "online" if r.get("category_name") in active_servers else "offline"

    # Match Sets by case-insensitive substring against name or description.
    query_lower = query.lower()  # hoisted: invariant across the loop
    sets = client_manager.server_list.list_sets()
    matching_sets = []
    for set_name, set_data in sets.items():
        description = set_data.get("description", "")
        if query_lower in set_name.lower() or query_lower in description.lower():
            matching_sets.append({
                "type": "collection",
                "name": set_name,
                "description": description,
                "servers": set_data.get("servers", []),
                "status": "available"
            })

    return {"collections": matching_sets, "tools": tool_results}
343
+
344
+
345
async def search_documentation(query: str, server_name: str, max_results: int = 10) -> List[Dict[str, Any]]:
    """
    Search for server action documentations by keyword matching.

    Args:
        query: Search keywords.
        server_name: Name of the server to search within.
        max_results: Number of results to return. Default: 10.

    Returns:
        Ranked search results, or a single-element list with an error dict.
    """
    try:
        client = client_manager.get_client(server_name)
        tool_list = await client.list_tools()

        # Guard against a falsy tool list before handing it to the searcher.
        corpus = {server_name: tool_list if tool_list else []}
        return UniversalToolSearcher(corpus).search(query, max_results=max_results)
    except KeyError:
        return [{"error": f"Server '{server_name}' not found or not connected"}]
    except Exception as e:
        logger.error(f"Error searching documentation for {server_name}: {e}\n{traceback.format_exc()}")
        return [{"error": f"Error searching documentation: {str(e)}", "traceback": traceback.format_exc()}]
366
+
367
+
368
async def handle_auth_failure(
    server_name: str,
    intention: str,
    auth_data: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Handle authentication failures that occur when executing actions.

    Args:
        server_name: The name of the server.
        intention: Action to take for authentication ('get_auth_url' or 'save_auth_data').
        auth_data: Authentication data when saving.

    Returns:
        Credential-request instructions, a save confirmation, or an error dict
        for a missing payload / unknown intention.
    """
    if intention == "save_auth_data":
        # Saving requires a payload; reject the call without one.
        if not auth_data:
            return {"error": "auth_data is required when intention is 'save_auth_data'"}
        return {
            "server": server_name,
            "status": "success",
            "message": f"Authentication data saved for server '{server_name}'",
        }

    if intention == "get_auth_url":
        # Tell the caller what credentials to supply.
        return {
            "server": server_name,
            "message": f"Authentication required for server '{server_name}'",
            "instructions": "Please provide authentication credentials",
            "required_fields": {"token": "Authentication token or API key"},
        }

    return {"error": f"Invalid intention: '{intention}'"}
File without changes
@@ -0,0 +1,181 @@
1
+ """
2
+ BM25+ based search utility with field-level scoring
3
+
4
+ This implementation flattens document fields into separate documents for independent scoring,
5
+ then aggregates scores by original document ID with field weights.
6
+
7
+ Algorithm:
8
+ 1. Each field becomes a separate document: "original_id:field_key" -> field_value
9
+ 2. BM25 scores each field independently
10
+ 3. Final score = sum(field_score * field_weight) for all fields of same original_id
11
+
12
+ Installation:
13
+ pip install "bm25s"
14
+ pip install PyStemmer
15
+ """
16
+
17
+ from collections import defaultdict
18
+ from typing import List, Tuple
19
+
20
+ import bm25s
21
+ import Stemmer
22
+
23
+
24
class BM25SearchEngine:
    """
    Field-aware BM25+ search engine that scores each field independently.

    Each (document, field) pair is indexed as its own BM25 document; at query
    time, per-field scores are aggregated per original document as a weighted
    sum (score = sum(field_score * field_weight)).
    """

    def __init__(self, use_stemmer: bool = True):
        """
        Initialize the BM25+ search engine.

        Args:
            use_stemmer: Whether to use stemming for better search results.
        """
        # English Snowball stemmer, passed to bm25s.tokenize at index/query time.
        self.stemmer = Stemmer.Stemmer("english") if use_stemmer else None
        self.retriever = None
        # Maps flattened corpus index -> (original_doc_id, field_key, field_weight)
        self.corpus_metadata = None
        # Maps original_doc_id -> [(field_key, weight), ...]
        self.doc_field_weights = None

    def build_index(self, documents: List[Tuple[List[Tuple[str, str, int]], str]]):
        """
        Build BM25 index from documents by flattening fields into separate documents.

        Args:
            documents: List of (fields, doc_id) tuples
                fields: List of (field_key, field_value, weight) tuples
                doc_id: Document identifier string

        Raises:
            ValueError: If no non-empty, positively-weighted fields are found.

        Example:
            documents = [
                (
                    [
                        ("service", "projects", 30),
                        ("operation", "create_project", 30),
                        ("description", "Creates a new project", 20),
                    ],
                    "projects:create_project"
                ),
            ]

        This creates separate BM25 documents:
            - "projects:create_project:service" -> "projects"
            - "projects:create_project:operation" -> "create_project"
            - "projects:create_project:description" -> "Creates a new project"
        """
        corpus = []
        self.corpus_metadata = []
        self.doc_field_weights = defaultdict(list)

        for fields, original_doc_id in documents:
            for field_key, field_value, weight in fields:
                # Skip empty fields and non-positive weights entirely.
                if field_value and weight > 0:
                    # Preprocess field value for better tokenization
                    processed_value = self._preprocess_field_value(field_value.strip())
                    corpus.append(processed_value)

                    # Store metadata: flattened index -> (original_id, field_key, weight)
                    self.corpus_metadata.append((original_doc_id, field_key, weight))

                    # Store field weights by original document
                    self.doc_field_weights[original_doc_id].append((field_key, weight))

        if not corpus:
            raise ValueError("No documents to index")

        # Tokenize corpus (each field value separately).
        # BUGFIX: the stemmer was previously constructed but never passed here,
        # so use_stemmer=True had no effect on indexing.
        corpus_tokens = bm25s.tokenize(
            corpus,
            stopwords=[],  # Disable stopwords for better field matching
            stemmer=self.stemmer,
            show_progress=False,
        )

        # Create and index BM25+ retriever
        self.retriever = bm25s.BM25(method="bm25+")
        self.retriever.index(corpus_tokens, show_progress=False)

    def search(self, query: str, top_k: int = 10) -> List[Tuple[float, str]]:
        """
        Search indexed documents with field-level scoring and weighted aggregation.

        Args:
            query: Search query string
            top_k: Number of top results to return

        Returns:
            List of (score, doc_id) tuples sorted by score descending

        Raises:
            ValueError: If build_index() has not been called yet.

        Algorithm:
            1. Search all flattened field documents
            2. Group results by original document ID
            3. Calculate weighted sum: score = sum(field_score * field_weight)
            4. Return top_k results by final weighted score
        """
        if self.retriever is None or self.corpus_metadata is None:
            raise ValueError("No documents indexed. Call build_index() first.")

        # Tokenize query (matching build_index settings).
        # BUGFIX: use the same stemmer as indexing; query and corpus tokens
        # must be stemmed identically for terms to match.
        query_tokens = bm25s.tokenize(
            query,
            stopwords=[],  # Disable stopwords to match build_index
            stemmer=self.stemmer,
            show_progress=False,
        )

        # Search all flattened documents to ensure complete field aggregation;
        # we need all matching fields for accurate document scoring.
        search_k = len(self.corpus_metadata)
        doc_indices, scores = self.retriever.retrieve(
            query_tokens, k=search_k, show_progress=False
        )

        # Aggregate scores by original document ID.
        doc_scores = defaultdict(float)

        for i in range(doc_indices.shape[1]):
            idx = doc_indices[0, i]
            field_score = scores[0, i]

            if idx < len(self.corpus_metadata):
                original_doc_id, _, field_weight = self.corpus_metadata[idx]

                # Add weighted field score to document total
                weighted_score = float(field_score) * field_weight
                doc_scores[original_doc_id] += weighted_score

        # Sort by aggregated score and return top k
        sorted_results = sorted(doc_scores.items(), key=lambda x: x[1], reverse=True)

        # Return top_k results as (score, doc_id) tuples
        return [(score, doc_id) for doc_id, score in sorted_results[:top_k]]

    def _preprocess_field_value(self, value: str) -> str:
        """
        Preprocess field values to improve tokenization.

        Converts underscore_separated and camelCase text to space-separated words
        for better BM25 matching.

        Examples:
            "create_project" -> "create project"
            "getUserProjects" -> "get User Projects"
            "/api/v1/projects" -> "/api/v1/projects"
        """
        import re

        # Replace underscores with spaces
        value = value.replace("_", " ")

        # Replace hyphens with spaces
        value = value.replace("-", " ")

        # Split camelCase: insert space before uppercase letters
        # But preserve existing spaces and special characters
        value = re.sub(r"([a-z])([A-Z])", r"\1 \2", value)

        # Clean up multiple spaces
        value = re.sub(r"\s+", " ", value).strip()

        return value
@@ -0,0 +1,82 @@
1
+ """Persistent catalog for storing MCP tool definitions."""
2
+
3
+ import json
4
+ import logging
5
+ import os
6
+ from typing import Dict, List, Any, Optional
7
+
8
+ from platformdirs import user_cache_dir
9
+
10
+ from strata.utils.shared_search import UniversalToolSearcher
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class ToolCatalog:
    """Manages persistent storage and retrieval of tool definitions.

    Tools are keyed by server name and persisted as JSON in the user's
    platform-specific cache directory; every mutation is written through
    to disk immediately.
    """

    def __init__(self, app_name: str = "strata"):
        self.cache_dir = user_cache_dir(app_name)
        self.catalog_file = os.path.join(self.cache_dir, "tool_catalog.json")
        # In-memory view: server name -> list of tool definitions.
        self._catalog: Dict[str, List[Any]] = {}
        self._ensure_cache_dir()
        self.load()

    def _ensure_cache_dir(self):
        """Ensure the cache directory exists."""
        os.makedirs(self.cache_dir, exist_ok=True)

    def load(self) -> None:
        """Load the catalog from disk; any failure leaves an empty catalog."""
        self._catalog = {}
        if not os.path.exists(self.catalog_file):
            return
        try:
            with open(self.catalog_file, "r", encoding="utf-8") as f:
                self._catalog = json.load(f)
            logger.info(f"Loaded tool catalog from {self.catalog_file}")
        except Exception as e:
            logger.error(f"Failed to load catalog: {e}")
            self._catalog = {}

    def save(self) -> None:
        """Save the catalog to disk, logging (not raising) on failure."""
        try:
            with open(self.catalog_file, "w", encoding="utf-8") as f:
                json.dump(self._catalog, f, indent=2)
            logger.info(f"Saved tool catalog to {self.catalog_file}")
        except Exception as e:
            logger.error(f"Failed to save catalog: {e}")

    def update_server(self, server_name: str, tools: List[Any]) -> None:
        """Update the tools for a specific server (write-through to disk)."""
        self._catalog[server_name] = tools
        self.save()

    def get_tools(self, server_name: str) -> List[Any]:
        """Get tools for a specific server from the catalog ([] if unknown)."""
        return self._catalog.get(server_name, [])

    def get_all_tools(self) -> Dict[str, List[Any]]:
        """Get all tools in the catalog."""
        return self._catalog

    def search(self, query: str, max_results: int = 20) -> List[Dict[str, Any]]:
        """Search for tools across the entire catalog."""
        if not self._catalog:
            return []

        hits = UniversalToolSearcher(self._catalog).search(query, max_results=max_results)

        # Flag each hit so callers can tell catalog results from live ones.
        for hit in hits:
            hit["source"] = "catalog"
        return hits

    def remove_server(self, server_name: str) -> None:
        """Remove a server from the catalog (no-op if absent)."""
        if server_name not in self._catalog:
            return
        del self._catalog[server_name]
        self.save()
@@ -0,0 +1,29 @@
1
+ from typing import Any, Dict, Optional
2
+
3
+
4
def find_in_dict_case_insensitive(
    name: str, dictionary: Dict[str, Any]
) -> Optional[str]:
    """Helper function to find name in dictionary using case-insensitive matching.

    Args:
        name: The name to search for
        dictionary: Dictionary to search in

    Returns:
        The actual key from the dictionary if found, None otherwise
    """
    # Non-string names and empty dicts can never match.
    if not isinstance(name, str) or not dictionary:
        return None

    # An exact match wins before any case folding.
    if name in dictionary:
        return name

    # Otherwise return the first string key that matches ignoring case.
    target = name.lower()
    return next(
        (key for key in dictionary if isinstance(key, str) and key.lower() == target),
        None,
    )