claude-self-reflect 4.0.3 → 5.0.2
@@ -5,11 +5,22 @@ import sys
 import importlib
 import logging
 from pathlib import Path
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional
 from fastmcp import Context
 from pydantic import Field
 import hashlib
 import json
+import asyncio
+
+# Import security module - handle both relative and absolute imports
+try:
+    from .security_patches import ModuleWhitelist
+except ImportError:
+    try:
+        from security_patches import ModuleWhitelist
+    except ImportError:
+        # Security module is required - fail closed, not open
+        raise RuntimeError("Security module 'security_patches' is required for code reload functionality")
 
 logger = logging.getLogger(__name__)
 
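The security_patches module is not part of this diff, so only the is_allowed_module call further down is visible. A minimal sketch of a fail-closed whitelist consistent with that call (the class body and the allowed names are assumptions, not the real module):

    # Hypothetical stand-in for security_patches.ModuleWhitelist
    class ModuleWhitelist:
        # Illustrative allow-list; the real entries are not shown in this diff
        _ALLOWED = frozenset({"src.code_reload", "src.search_tools", "src.safe_getters"})

        @classmethod
        def is_allowed_module(cls, module_name: str) -> bool:
            # Fail closed: anything not explicitly listed is rejected
            return module_name in cls._ALLOWED

The nested try/except keeps the import working both package-relative and as a flat script import, and the final RuntimeError guarantees the reloader cannot start without the whitelist.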
@@ -19,20 +30,36 @@ class CodeReloader:
 
     def __init__(self):
         """Initialize the code reloader."""
-        self.module_hashes: Dict[str, str] = {}
-        self.reload_history: List[Dict] = []
         self.cache_dir = Path.home() / '.claude-self-reflect' / 'reload_cache'
         self.cache_dir.mkdir(parents=True, exist_ok=True)
-
-
+        self.hash_file = self.cache_dir / 'module_hashes.json'
+        self._lock = asyncio.Lock()  # Thread safety for async operations
+
+        # Load persisted hashes from disk with error handling
+        if self.hash_file.exists():
+            try:
+                with open(self.hash_file, 'r') as f:
+                    self.module_hashes: Dict[str, str] = json.load(f)
+            except (json.JSONDecodeError, IOError) as e:
+                logger.error(f"Failed to load module hashes: {e}. Starting fresh.")
+                self.module_hashes: Dict[str, str] = {}
+        else:
+            self.module_hashes: Dict[str, str] = {}
+
+        self.reload_history: List[Dict] = []
+        logger.info(f"CodeReloader initialized with {len(self.module_hashes)} cached hashes")
 
     def _get_file_hash(self, filepath: Path) -> str:
         """Get SHA256 hash of a file."""
         with open(filepath, 'rb') as f:
             return hashlib.sha256(f.read()).hexdigest()
 
-    def
-        """Detect which modules have changed since last check.
+    def _detect_changed_modules(self) -> List[str]:
+        """Detect which modules have changed since last check.
+
+        This method ONLY detects changes, it does NOT update the stored hashes.
+        Use _update_module_hashes() to update hashes after successful reload.
+        """
         changed = []
         src_dir = Path(__file__).parent
 
@@ -43,13 +70,61 @@ class CodeReloader:
             module_name = f"src.{py_file.stem}"
             current_hash = self._get_file_hash(py_file)
 
+            # Only detect changes, DO NOT update hashes here
             if module_name in self.module_hashes:
                 if self.module_hashes[module_name] != current_hash:
                     changed.append(module_name)
+                    logger.debug(f"Change detected in {module_name}: {self.module_hashes[module_name][:8]} -> {current_hash[:8]}")
+            else:
+                # New module not seen before
+                changed.append(module_name)
+                logger.debug(f"New module detected: {module_name}")
+
+        return changed
+
+    def _update_module_hashes(self, modules: Optional[List[str]] = None) -> None:
+        """Update the stored hashes for specified modules or all modules.
+
+        This should be called AFTER successful reload to mark modules as up-to-date.
+
+        Args:
+            modules: List of module names to update. If None, updates all modules.
+        """
+        src_dir = Path(__file__).parent
+        updated = []
+
+        for py_file in src_dir.glob("*.py"):
+            if py_file.name == "__pycache__":
+                continue
 
+            module_name = f"src.{py_file.stem}"
+
+            # If specific modules provided, only update those
+            if modules is not None and module_name not in modules:
+                continue
+
+            current_hash = self._get_file_hash(py_file)
+            old_hash = self.module_hashes.get(module_name, "new")
             self.module_hashes[module_name] = current_hash
+
+            if old_hash != current_hash:
+                updated.append(module_name)
+                logger.debug(f"Updated hash for {module_name}: {old_hash[:8] if old_hash != 'new' else 'new'} -> {current_hash[:8]}")
 
-
+        # Persist the updated hashes to disk using atomic write
+        temp_file = Path(str(self.hash_file) + '.tmp')
+        try:
+            with open(temp_file, 'w') as f:
+                json.dump(self.module_hashes, f, indent=2)
+            # Atomic rename on POSIX systems
+            temp_file.replace(self.hash_file)
+        except Exception as e:
+            logger.error(f"Failed to persist module hashes: {e}")
+            if temp_file.exists():
+                temp_file.unlink()  # Clean up temp file on failure
+
+        if updated:
+            logger.info(f"Updated hashes for {len(updated)} modules: {', '.join(updated)}")
 
     async def reload_modules(
         self,
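The persistence block above uses the standard write-to-temp-then-rename idiom: Path.replace calls os.replace, which renames atomically on POSIX filesystems, so a concurrent reader of module_hashes.json sees either the old or the new content, never a torn write. The same pattern in isolation (file name and data are illustrative):

    import json
    from pathlib import Path

    def atomic_write_json(target: Path, data: dict) -> None:
        # Write to a sibling temp file, then atomically swap it into place
        tmp = Path(str(target) + '.tmp')
        tmp.write_text(json.dumps(data, indent=2))
        tmp.replace(target)  # os.replace under the hood; atomic on POSIX

    atomic_write_json(Path('module_hashes.json'), {'src.example': 'deadbeef'})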
@@ -61,93 +136,98 @@ class CodeReloader:
 
         await ctx.debug("Starting code reload process...")
 
-        [old reload_modules implementation (lines 64-150) not preserved in this view, apart from a truncated "response +=" fragment]
+        async with self._lock:  # Ensure thread safety for reload operations
+            try:
+                # Track what we're reloading
+                reload_targets = []
+
+                if auto_detect:
+                    # Detect changed modules (without updating hashes)
+                    changed = self._detect_changed_modules()
+                    if changed:
+                        reload_targets.extend(changed)
+                        await ctx.debug(f"Auto-detected changes in: {changed}")
+
+                if modules:
+                    # Add explicitly requested modules
+                    reload_targets.extend(modules)
+
+                if not reload_targets:
+                    return "📊 No modules to reload. All code is up to date!"
+
+                # Perform the reload
+                reloaded = []
+                failed = []
+
+                for module_name in reload_targets:
+                    try:
+                        # SECURITY FIX: Validate module is in whitelist
+                        if not ModuleWhitelist.is_allowed_module(module_name):
+                            logger.warning(f"Module not in whitelist, skipping: {module_name}")
+                            failed.append((module_name, "Module not in whitelist"))
+                            continue
+
+                        if module_name in sys.modules:
+                            # Store old module reference for rollback
+                            old_module = sys.modules[module_name]
+
+                            # Reload the module
+                            logger.info(f"Reloading module: {module_name}")
+                            reloaded_module = importlib.reload(sys.modules[module_name])
+
+                            # Update any global references if needed
+                            self._update_global_references(module_name, reloaded_module)
+
+                            reloaded.append(module_name)
+                            await ctx.debug(f"✅ Reloaded: {module_name}")
+                        else:
+                            # Module not loaded yet, import it
+                            importlib.import_module(module_name)
+                            reloaded.append(module_name)
+                            await ctx.debug(f"✅ Imported: {module_name}")
+
+                    except Exception as e:
+                        logger.error(f"Failed to reload {module_name}: {e}", exc_info=True)
+                        failed.append((module_name, str(e)))
+                        await ctx.debug(f"❌ Failed: {module_name} - {e}")
+
+                # Update hashes ONLY for successfully reloaded modules
+                if reloaded:
+                    self._update_module_hashes(reloaded)
+                    await ctx.debug(f"Updated hashes for {len(reloaded)} successfully reloaded modules")
+
+                # Record reload history
+                self.reload_history.append({
+                    "timestamp": os.environ.get('MCP_REQUEST_ID', 'unknown'),
+                    "reloaded": reloaded,
+                    "failed": failed
+                })
+
+                # Build response
+                response = "🔄 **Code Reload Results**\n\n"
+
+                if reloaded:
+                    response += f"**Successfully Reloaded ({len(reloaded)}):**\n"
+                    for module in reloaded:
+                        response += f"- ✅ {module}\n"
+                    response += "\n"
+
+                if failed:
+                    response += f"**Failed to Reload ({len(failed)}):**\n"
+                    for module, error in failed:
+                        response += f"- ❌ {module}: {error}\n"
+                    response += "\n"
+
+                response += "**Important Notes:**\n"
+                response += "- Class instances created before reload keep old code\n"
+                response += "- New requests will use the reloaded code\n"
+                response += "- Some changes may require full restart (e.g., new tools)\n"
+
+                return response
+
+            except Exception as e:
+                logger.error(f"Code reload failed: {e}", exc_info=True)
+                return f"❌ Code reload failed: {str(e)}"
 
     def _update_global_references(self, module_name: str, new_module):
         """Update global references after module reload."""
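The "Important Notes" lines reflect a standard importlib.reload caveat: reloading rebinds the names inside the module object, but instances created before the reload still reference the old class objects. A small illustration (mymodule and Widget are hypothetical):

    import importlib
    import sys

    import mymodule                    # hypothetical module defining class Widget
    w_old = mymodule.Widget()

    importlib.reload(sys.modules['mymodule'])

    w_new = mymodule.Widget()
    print(type(w_old) is type(w_new))  # False: w_old still uses the pre-reload class

This is also why the hunk stores old_module before reloading: keeping a reference to the previous module object makes a rollback possible if the reloaded code misbehaves.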
@@ -171,8 +251,8 @@ class CodeReloader:
         """Get the current reload status and history."""
 
         try:
-            # Check for changed files
-            changed = self.
+            # Check for changed files (WITHOUT updating hashes)
+            changed = self._detect_changed_modules()
 
             response = "📊 **Code Reload Status**\n\n"
 
@@ -224,6 +304,24 @@ class CodeReloader:
             logger.error(f"Failed to clear cache: {e}", exc_info=True)
             return f"❌ Failed to clear cache: {str(e)}"
 
+    async def force_update_hashes(self, ctx: Context) -> str:
+        """Force update all module hashes to current state.
+
+        This is useful when you want to mark all current code as 'baseline'
+        without actually reloading anything.
+        """
+        try:
+            await ctx.debug("Force updating all module hashes...")
+
+            # Update all module hashes
+            self._update_module_hashes(modules=None)
+
+            return f"✅ Force updated hashes for all {len(self.module_hashes)} tracked modules"
+
+        except Exception as e:
+            logger.error(f"Failed to force update hashes: {e}", exc_info=True)
+            return f"❌ Failed to force update hashes: {str(e)}"
+
 
 def register_code_reload_tool(mcp, get_embedding_manager):
     """Register the code reloading tool with the MCP server."""
@@ -257,6 +355,8 @@ def register_code_reload_tool(mcp, get_embedding_manager):
 
         Shows which files have been modified since last reload and
         the history of recent reload operations.
+
+        Note: This only checks for changes, it does not update the stored hashes.
         """
         return await reloader.get_reload_status(ctx)
 
@@ -267,5 +367,14 @@ def register_code_reload_tool(mcp, get_embedding_manager):
         Useful when reload isn't working due to cached bytecode.
         """
         return await reloader.clear_python_cache(ctx)
+
+    @mcp.tool()
+    async def force_update_module_hashes(ctx: Context) -> str:
+        """Force update all module hashes to mark current code as baseline.
+
+        Use this when you want to ignore current changes and treat
+        the current state as the new baseline without reloading.
+        """
+        return await reloader.force_update_hashes(ctx)
 
-logger.info("Code reload tools registered successfully")
+logger.info("Code reload tools registered successfully")
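A sketch of driving the new baseline method directly, outside the MCP server, with a stub standing in for FastMCP's Context (the stub and the module name in the import are assumptions; these methods only await ctx.debug):

    import asyncio
    from code_reload import CodeReloader  # module name assumed; the diff does not show file names

    class StubCtx:
        async def debug(self, msg: str) -> None:
            print(f"[debug] {msg}")

    async def main() -> None:
        reloader = CodeReloader()
        # Mark everything currently on disk as the baseline without reloading
        print(await reloader.force_update_hashes(StubCtx()))

    asyncio.run(main())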
@@ -8,6 +8,7 @@ import time
 from typing import List, Dict, Any, Optional, Tuple
 from datetime import datetime
 import logging
+from .safe_getters import safe_get_list, safe_get_str
 
 logger = logging.getLogger(__name__)
 
@@ -176,9 +177,9 @@ async def search_single_collection(
                     'collection_name': collection_name,
                     'raw_payload': point.payload,  # Renamed from 'payload' for consistency
                     'code_patterns': point.payload.get('code_patterns'),
-                    'files_analyzed': point.payload
-                    'tools_used':
-                    'concepts': point.payload
+                    'files_analyzed': safe_get_list(point.payload, 'files_analyzed'),
+                    'tools_used': safe_get_list(point.payload, 'tools_used'),
+                    'concepts': safe_get_list(point.payload, 'concepts')
                 }
                 results.append(search_result)
             else:
@@ -219,9 +220,9 @@ async def search_single_collection(
                     'collection_name': collection_name,
                     'raw_payload': point.payload,
                     'code_patterns': point.payload.get('code_patterns'),
-                    'files_analyzed': point.payload
-                    'tools_used':
-                    'concepts': point.payload
+                    'files_analyzed': safe_get_list(point.payload, 'files_analyzed'),
+                    'tools_used': safe_get_list(point.payload, 'tools_used'),
+                    'concepts': safe_get_list(point.payload, 'concepts')
                 }
                 results.append(search_result)
 
@@ -5,6 +5,7 @@ import time
 from datetime import datetime, timezone
 from typing import List, Dict, Any, Optional
 import logging
+from .safe_getters import safe_get_list, safe_get_str
 
 logger = logging.getLogger(__name__)
 
@@ -114,16 +115,19 @@ def format_search_results_rich(
     concept_frequency = {}
 
     for result in results:
-        # Count file modifications
-        [old loop line not preserved]
+        # Count file modifications - using safe_get_list for consistency
+        files = safe_get_list(result, 'files_analyzed')
+        for file in files:
             file_frequency[file] = file_frequency.get(file, 0) + 1
 
-        # Count tool usage
-        [old loop line not preserved]
+        # Count tool usage - using safe_get_list for consistency
+        tools = safe_get_list(result, 'tools_used')
+        for tool in tools:
            tool_frequency[tool] = tool_frequency.get(tool, 0) + 1
 
-        # Count concepts
-        [old loop line not preserved]
+        # Count concepts - using safe_get_list for consistency
+        concepts = safe_get_list(result, 'concepts')
+        for concept in concepts:
             concept_frequency[concept] = concept_frequency.get(concept, 0) + 1
 
     # Show most frequently modified files
@@ -0,0 +1,217 @@
+"""Safe getter utilities for handling None values consistently."""
+
+import logging
+from typing import Any, Dict, List, Optional, Set, Union
+
+logger = logging.getLogger(__name__)
+
+
+def safe_get_list(
+    data: Optional[Dict[str, Any]],
+    key: str,
+    default: Optional[List] = None
+) -> List[Any]:
+    """
+    Safely get a list field from a dictionary, handling None and non-list values.
+
+    Args:
+        data: Dictionary to get value from (can be None)
+        key: Key to retrieve
+        default: Default value if key not found or value is None
+
+    Returns:
+        A list, either the value, converted value, or default/empty list
+    """
+    if data is None:
+        return default if default is not None else []
+
+    value = data.get(key)
+
+    if value is None:
+        return default if default is not None else []
+
+    # Handle sets and tuples by converting to list
+    if isinstance(value, (set, tuple)):
+        return list(value)
+
+    # If it's already a list, return it
+    if isinstance(value, list):
+        return value
+
+    # If it's not a list-like type, log warning and return empty list
+    logger.warning(
+        f"Expected list-like type for key '{key}', got {type(value).__name__}. "
+        f"Value: {repr(value)[:100]}"
+    )
+    return default if default is not None else []
+
+
+def safe_get_str(
+    data: Optional[Dict[str, Any]],
+    key: str,
+    default: str = ""
+) -> str:
+    """
+    Safely get a string field from a dictionary.
+
+    Args:
+        data: Dictionary to get value from (can be None)
+        key: Key to retrieve
+        default: Default value if key not found or value is None
+
+    Returns:
+        A string, either the value or the default
+    """
+    if data is None:
+        return default
+
+    value = data.get(key)
+
+    if value is None:
+        return default
+
+    # Convert to string if needed
+    return str(value)
+
+
+def safe_get_dict(
+    data: Optional[Dict[str, Any]],
+    key: str,
+    default: Optional[Dict] = None
+) -> Dict[str, Any]:
+    """
+    Safely get a dictionary field from another dictionary.
+
+    Args:
+        data: Dictionary to get value from (can be None)
+        key: Key to retrieve
+        default: Default value if key not found or value is None
+
+    Returns:
+        A dictionary, either the value or the default/empty dict
+    """
+    if data is None:
+        return default if default is not None else {}
+
+    value = data.get(key)
+
+    if value is None:
+        return default if default is not None else {}
+
+    if isinstance(value, dict):
+        return value
+
+    logger.warning(
+        f"Expected dict for key '{key}', got {type(value).__name__}. "
+        f"Value: {repr(value)[:100]}"
+    )
+    return default if default is not None else {}
+
+
+def safe_get_float(
+    data: Optional[Dict[str, Any]],
+    key: str,
+    default: float = 0.0
+) -> float:
+    """
+    Safely get a float field from a dictionary.
+
+    Args:
+        data: Dictionary to get value from (can be None)
+        key: Key to retrieve
+        default: Default value if key not found or value is None/non-numeric
+
+    Returns:
+        A float, either the converted value or the default
+    """
+    if data is None:
+        return default
+
+    value = data.get(key)
+
+    if value is None:
+        return default
+
+    try:
+        return float(value)
+    except (TypeError, ValueError) as e:
+        logger.warning(
+            f"Could not convert key '{key}' value to float: {repr(value)[:100]}. "
+            f"Error: {e}"
+        )
+        return default
+
+
+def safe_get_int(
+    data: Optional[Dict[str, Any]],
+    key: str,
+    default: int = 0
+) -> int:
+    """
+    Safely get an integer field from a dictionary.
+
+    Args:
+        data: Dictionary to get value from (can be None)
+        key: Key to retrieve
+        default: Default value if key not found or value is None/non-numeric
+
+    Returns:
+        An integer, either the converted value or the default
+    """
+    if data is None:
+        return default
+
+    value = data.get(key)
+
+    if value is None:
+        return default
+
+    try:
+        return int(value)
+    except (TypeError, ValueError) as e:
+        logger.warning(
+            f"Could not convert key '{key}' value to int: {repr(value)[:100]}. "
+            f"Error: {e}"
+        )
+        return default
+
+
+def safe_get_bool(
+    data: Optional[Dict[str, Any]],
+    key: str,
+    default: bool = False
+) -> bool:
+    """
+    Safely get a boolean field from a dictionary.
+
+    Args:
+        data: Dictionary to get value from (can be None)
+        key: Key to retrieve
+        default: Default value if key not found or value is None
+
+    Returns:
+        A boolean, either the value or the default
+    """
+    if data is None:
+        return default
+
+    value = data.get(key)
+
+    if value is None:
+        return default
+
+    if isinstance(value, bool):
+        return value
+
+    # Handle string booleans
+    if isinstance(value, str):
+        return value.lower() in ('true', '1', 'yes', 'on')
+
+    # Handle numeric booleans
+    try:
+        return bool(int(value))
+    except (TypeError, ValueError):
+        logger.warning(
+            f"Could not convert key '{key}' value to bool: {repr(value)[:100]}"
+        )
+        return default
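A quick usage sketch of the new helpers against a payload mixing missing, None, and wrongly-typed fields (the payload is made up; inside the package the import is relative):

    from safe_getters import safe_get_list, safe_get_str, safe_get_bool

    payload = {'tools_used': None, 'concepts': ('a', 'b'), 'score': 42}

    safe_get_list(payload, 'tools_used')      # [] - None collapses to the default
    safe_get_list(payload, 'concepts')        # ['a', 'b'] - tuple converted to list
    safe_get_list(payload, 'files_analyzed')  # [] - missing key
    safe_get_str(payload, 'score')            # '42' - coerced to str
    safe_get_bool({'flag': 'yes'}, 'flag')    # True - string booleans recognized
    safe_get_list(None, 'anything')           # [] - a None payload is safe too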
@@ -20,6 +20,26 @@ from .rich_formatting import format_search_results_rich
 logger = logging.getLogger(__name__)
 
 
+def is_searchable_collection(name: str) -> bool:
+    """
+    Check if collection name matches searchable patterns.
+    Supports both v3 and v4 collection naming conventions.
+    """
+    return (
+        # v3 patterns
+        name.endswith('_local')
+        or name.endswith('_voyage')
+        # v4 patterns
+        or name.endswith('_384d')   # Local v4 collections
+        or name.endswith('_1024d')  # Cloud v4 collections
+        or '_cloud_' in name        # Cloud v4 intermediate naming
+        # Reflections
+        or name.startswith('reflections')
+        # CSR prefixed collections
+        or name.startswith('csr_')
+    )
+
+
 class SearchTools:
     """Handles all search operations for the MCP server."""
 
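The helper replaces a name filter that the removed lines below show was duplicated across four call sites. For example (collection names are illustrative):

    is_searchable_collection('conv_abc_local')      # True  (v3 local)
    is_searchable_collection('conv_abc_384d')       # True  (v4 local)
    is_searchable_collection('reflections_voyage')  # True  (reflections)
    is_searchable_collection('csr_proj_1024d')      # True  (csr_ prefix)
    is_searchable_collection('internal_metadata')   # False (no pattern matches)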
@@ -114,6 +134,11 @@
         # Convert results to dict format
         results = []
         for result in search_results:
+            # Guard against None payload
+            if result.payload is None:
+                logger.warning(f"Result in {collection_name} has None payload, skipping")
+                continue
+
             results.append({
                 'conversation_id': result.payload.get('conversation_id'),
                 'timestamp': result.payload.get('timestamp'),
@@ -274,10 +299,10 @@
             return "<search_results><message>No collections available</message></search_results>"
 
         # Include both conversation collections and reflection collections
+        # Use module-level function for consistency
        filtered_collections = [
             c for c in collections
-            if (c.name
-                c.name.startswith('reflections'))
+            if is_searchable_collection(c.name)
         ]
         await ctx.debug(f"Searching across {len(filtered_collections)} collections")
 
@@ -403,8 +428,7 @@
         # Include both conversation collections and reflection collections
         filtered_collections = [
             c for c in collections
-            if (c.name
-                c.name.startswith('reflections'))
+            if is_searchable_collection(c.name)
         ]
 
         # Quick PARALLEL count across collections
@@ -493,8 +517,7 @@
         # Include both conversation collections and reflection collections
         filtered_collections = [
             c for c in collections
-            if (c.name
-                c.name.startswith('reflections'))
+            if is_searchable_collection(c.name)
         ]
 
         # Gather results for summary using PARALLEL search
@@ -590,8 +613,7 @@
         # Include both conversation collections and reflection collections
         filtered_collections = [
             c for c in collections
-            if (c.name
-                c.name.startswith('reflections'))
+            if is_searchable_collection(c.name)
         ]
 
         # Gather all results using PARALLEL search