claude-self-reflect 3.3.0 → 4.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.claude/agents/claude-self-reflect-test.md +525 -11
- package/.claude/agents/quality-fixer.md +314 -0
- package/.claude/agents/reflection-specialist.md +40 -1
- package/installer/cli.js +16 -0
- package/installer/postinstall.js +14 -0
- package/installer/statusline-setup.js +289 -0
- package/mcp-server/run-mcp.sh +45 -7
- package/mcp-server/src/code_reload_tool.py +271 -0
- package/mcp-server/src/embedding_manager.py +60 -26
- package/mcp-server/src/enhanced_tool_registry.py +407 -0
- package/mcp-server/src/mode_switch_tool.py +181 -0
- package/mcp-server/src/parallel_search.py +24 -85
- package/mcp-server/src/project_resolver.py +20 -2
- package/mcp-server/src/reflection_tools.py +60 -13
- package/mcp-server/src/rich_formatting.py +103 -0
- package/mcp-server/src/search_tools.py +180 -79
- package/mcp-server/src/security_patches.py +555 -0
- package/mcp-server/src/server.py +318 -240
- package/mcp-server/src/status.py +13 -8
- package/mcp-server/src/temporal_tools.py +10 -3
- package/mcp-server/src/test_quality.py +153 -0
- package/package.json +6 -1
- package/scripts/ast_grep_final_analyzer.py +328 -0
- package/scripts/ast_grep_unified_registry.py +710 -0
- package/scripts/csr-status +511 -0
- package/scripts/import-conversations-unified.py +114 -28
- package/scripts/session_quality_tracker.py +661 -0
- package/scripts/streaming-watcher.py +140 -5
- package/scripts/update_patterns.py +334 -0
|
@@ -0,0 +1,555 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Security patches for Claude Self-Reflect v4.0
|
|
4
|
+
Addresses all critical and high priority issues from CRITICAL_HIGH_PRIORITY_ISSUES.md
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import hashlib
|
|
8
|
+
import uuid
|
|
9
|
+
import asyncio
|
|
10
|
+
import logging
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Optional, List, Set, Any
|
|
13
|
+
import re
|
|
14
|
+
import os
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
# ==================== CRITICAL FIXES ====================
|
|
19
|
+
|
|
20
|
+
class SecureHashGenerator:
    """Fix for Critical Issue #4: Replace MD5 with SHA-256+UUID"""

    @staticmethod
    def generate_id(content: str, legacy_support: bool = True) -> str:
        """
        Generate secure ID with backward compatibility for MD5.

        Args:
            content: Content to hash
            legacy_support: Enable backward compatibility for existing MD5 IDs

        Returns:
            Secure ID string (32 lowercase hex characters)
        """
        if legacy_support:
            # For backward compatibility, check if this might be an existing
            # conversation. This would need to be checked against the database
            # in production; no-op here.
            pass

        # Use SHA-256 for new IDs. For Qdrant compatibility we return the
        # first 32 hex chars (128 bits) of the digest — same length as an
        # MD5 hex digest, so existing ID handling keeps working, but with a
        # collision-resistant hash underneath.
        sha256_hash = hashlib.sha256(content.encode()).hexdigest()
        # Qdrant accepts hex strings as point IDs; no suffix needed.
        return sha256_hash[:32]

    @staticmethod
    def is_legacy_id(id_str: str) -> bool:
        """Check if an ID is using the legacy MD5 format.

        Note: both MD5 and truncated SHA-256 IDs are 32 hex chars, so this
        deliberately accepts either — all 32-char hex strings are valid IDs.
        """
        return len(id_str) == 32 and all(c in '0123456789abcdef' for c in id_str.lower())
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class PathValidator:
    """Fix for Critical Issue #2: Path Traversal Vulnerability"""

    # Only these directory trees may be read from or written to.
    ALLOWED_DIRS = [
        Path.home() / '.claude',
        Path.home() / '.claude-self-reflect',
        Path.home() / 'projects' / 'claude-self-reflect',
        Path('/tmp')  # For temporary files
    ]

    @staticmethod
    def is_safe_path(path: Path) -> bool:
        """
        Validate that a resolved path is within allowed directories.

        Args:
            path: Path to validate

        Returns:
            True if path is safe, False otherwise
        """
        log = logging.getLogger(__name__)
        try:
            resolved = path.expanduser().resolve()

            # Reject any path that textually mentions '..' — a blunt but
            # deliberate guard against traversal attempts.
            if '..' in str(path):
                log.warning(f"Path traversal attempt detected: {path}")
                return False

            # Accept the first allowed directory that contains the path.
            for base in PathValidator.ALLOWED_DIRS:
                try:
                    resolved.relative_to(base.resolve())
                except ValueError:
                    continue
                return True

            log.warning(f"Path outside allowed directories: {resolved}")
            return False

        except Exception as e:
            log.error(f"Path validation error: {e}")
            return False

    @staticmethod
    def sanitize_path(path_str: str) -> Optional[Path]:
        """
        Sanitize and validate a path string.

        Args:
            path_str: Path string to sanitize

        Returns:
            Safe Path object or None if unsafe
        """
        # Strip NUL and other control characters before validation.
        cleaned = re.sub(r'[\x00-\x1f\x7f]', '', path_str)
        candidate = Path(cleaned)

        if not PathValidator.is_safe_path(candidate):
            return None
        return candidate.expanduser().resolve()
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class ModuleWhitelist:
    """Fix for Critical Issue #1: Command Injection via module reload"""

    ALLOWED_MODULES = {
        # Core MCP modules
        'src.server',
        'src.reflection_tools',
        'src.search_tools',
        'src.temporal_tools',
        'src.embedding_manager',
        'src.project_resolver',
        'src.rich_formatting',
        'src.mode_switch_tool',
        'src.code_reload_tool',
        'src.enhanced_tool_registry',
        'src.parallel_search',  # Added for hot-reload support
        'src.security_patches',  # Added for hot-reload support
        # Standard library modules (safe to reload)
        'json',
        'logging',
        'datetime',
        'pathlib',
        'typing',
        'asyncio',
        'os',
        'sys'
    }

    @staticmethod
    def is_allowed_module(module_name: str) -> bool:
        """
        Check if a module is allowed to be reloaded.

        Args:
            module_name: Name of the module to check

        Returns:
            True if module is whitelisted
        """
        # SECURITY: Block dangerous attributes/functions immediately.
        dangerous_patterns = [
            'system', 'exec', 'eval', 'subprocess', '__import__',
            'compile', 'open', 'file', 'input', 'raw_input'
        ]

        lowered = module_name.lower()
        for pattern in dangerous_patterns:
            if pattern in lowered:
                logging.getLogger(__name__).error(
                    f"SECURITY: Dangerous module pattern blocked: {module_name}")
                return False

        # Check exact match first.
        if module_name in ModuleWhitelist.ALLOWED_MODULES:
            return True

        # Submodules of an allowed module are accepted only when EVERY dotted
        # component after the prefix is a valid Python identifier.
        # (Fixes a bug where a single-component remainder was accepted with
        # no validation at all — e.g. "src.server.bad-name" passed — and
        # multi-component remainders only had their first part checked.)
        for allowed in ModuleWhitelist.ALLOWED_MODULES:
            if module_name.startswith(f"{allowed}."):
                remaining = module_name[len(allowed) + 1:]
                if remaining and all(part.isidentifier() for part in remaining.split('.')):
                    return True

        logging.getLogger(__name__).warning(f"Module not in whitelist: {module_name}")
        return False
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
class AsyncSafetyPatterns:
    """Fix for Critical Issues #5 & #6: Thread Safety and Race Conditions"""

    def __init__(self):
        # Per-resource synchronization primitives, created lazily on demand.
        self.locks = {}
        self.semaphores = {}

    def get_lock(self, resource_name: str) -> asyncio.Lock:
        """
        Get or create a lock for a resource.

        Args:
            resource_name: Name of the resource to lock

        Returns:
            asyncio.Lock for the resource
        """
        if resource_name not in self.locks:
            self.locks[resource_name] = asyncio.Lock()
        return self.locks[resource_name]

    def get_semaphore(self, resource_name: str, limit: int = 10) -> asyncio.Semaphore:
        """
        Get or create a semaphore for concurrency limiting.

        Args:
            resource_name: Name of the resource
            limit: Maximum concurrent operations

        Returns:
            asyncio.Semaphore for the resource
        """
        # Key includes the limit so the same resource can have differently
        # sized semaphores without clashing.
        key = f"{resource_name}_{limit}"
        if key not in self.semaphores:
            self.semaphores[key] = asyncio.Semaphore(limit)
        return self.semaphores[key]

    @staticmethod
    async def run_in_executor(func, *args):
        """
        Safely run blocking code in executor.
        Replaces dangerous threading.Thread usage.

        Args:
            func: Blocking function to run
            *args: Arguments for the function

        Returns:
            Result of the function
        """
        # get_running_loop() is the documented API inside coroutines;
        # get_event_loop() is deprecated for this use and can mis-bind loops.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, func, *args)
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
class QdrantAuthManager:
    """Fix for Critical Issue #7: Unprotected Network Endpoints"""

    @staticmethod
    def get_secure_client(url: str, api_key: Optional[str] = None):
        """
        Create a Qdrant client with authentication.

        Args:
            url: Qdrant server URL
            api_key: API key for authentication; falls back to the
                QDRANT_API_KEY environment variable when omitted.

        Returns:
            Configured AsyncQdrantClient (30-second timeout).

        Raises:
            ValueError: if no API key is available and the migration
                grace period (before 2025-12-01) has expired.
        """
        # Imported lazily so this module loads even when qdrant_client
        # is not installed.
        from qdrant_client import AsyncQdrantClient

        # Get API key from environment if not provided
        if not api_key:
            api_key = os.getenv('QDRANT_API_KEY')

        if not api_key:
            logger.warning("Qdrant API key not configured - using unauthenticated connection")
            # For backward compatibility, allow unauthenticated during migration period
            # This should be removed after 2025-12-01
            # NOTE(review): hard-coded wall-clock kill switch — after this
            # date every call without a key raises; confirm this is the
            # intended rollout behavior.
            from datetime import datetime
            if datetime.now() > datetime(2025, 12, 1):
                raise ValueError("Qdrant authentication is now required")

        # api_key may still be None here (grace period) — the client then
        # connects unauthenticated.
        return AsyncQdrantClient(
            url=url,
            api_key=api_key,
            timeout=30
        )
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
# ==================== HIGH PRIORITY FIXES ====================
|
|
280
|
+
|
|
281
|
+
class LazyAsyncInitializer:
    """Fix for High Issue #8: Module-Level Async Client Initialization"""

    def __init__(self):
        self._client = None
        self._initialization_lock = asyncio.Lock()

    async def get_client(self, *args, **kwargs):
        """
        Lazy initialize client within async context.

        Returns:
            Initialized client
        """
        # Double-checked locking: cheap check first, then re-verify under
        # the lock so only one coroutine performs the actual creation.
        if self._client is not None:
            return self._client
        async with self._initialization_lock:
            if self._client is None:
                self._client = await self._create_client(*args, **kwargs)
        return self._client

    async def _create_client(self, *args, **kwargs):
        """Override in subclass to create specific client."""
        raise NotImplementedError
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
class ConcurrencyLimiter:
    """Fix for High Issue #9: Unbounded Concurrency"""

    # Default maximum number of tasks allowed to run simultaneously.
    DEFAULT_LIMIT = 10

    @staticmethod
    async def limited_gather(tasks: List, limit: int = DEFAULT_LIMIT):
        """
        Execute tasks with concurrency limit.

        Args:
            tasks: List of coroutines to execute
            limit: Maximum concurrent tasks

        Returns:
            List of results
        """
        gate = asyncio.Semaphore(limit)

        async def _guarded(coro):
            # At most `limit` coroutines hold the semaphore at once.
            async with gate:
                return await coro

        wrapped = (_guarded(coro) for coro in tasks)
        # return_exceptions=True preserves per-task failures as results
        # instead of aborting the whole batch.
        return await asyncio.gather(*wrapped, return_exceptions=True)
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
class MemoryOptimizer:
    """Fix for High Issue #10: Memory Leak - Decay Processing"""

    @staticmethod
    def calculate_safe_limit(requested_limit: int, memory_factor: float = 1.5) -> int:
        """
        Calculate safe limit to prevent memory explosion.

        Args:
            requested_limit: User-requested limit
            memory_factor: Multiplication factor (reduced from 3x to 1.5x)

        Returns:
            Safe limit value
        """
        # Never scale by more than 2x, and never exceed the absolute
        # ceiling, so oversized requests cannot cause OOM.
        capped_factor = min(memory_factor, 2.0)
        ceiling = 1000
        return min(int(requested_limit * capped_factor), ceiling)
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
class ResourceManager:
    """Fix for High Issue #11: Incomplete Resource Cleanup"""

    def __init__(self):
        # Resources to tear down when the context exits, in registration order.
        self.resources = []

    def register(self, resource):
        """Register a resource for cleanup."""
        self.resources.append(resource)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Cleanup all registered resources."""
        for resource in self.resources:
            try:
                # Prefer close() over cleanup(); skip objects with neither.
                if hasattr(resource, 'close'):
                    finalizer = resource.close
                elif hasattr(resource, 'cleanup'):
                    finalizer = resource.cleanup
                else:
                    continue
                # Await coroutine finalizers, call plain ones directly.
                if asyncio.iscoroutinefunction(finalizer):
                    await finalizer()
                else:
                    finalizer()
            except Exception as e:
                # One failing resource must not prevent the rest from closing.
                logger.error(f"Resource cleanup failed: {e}")
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
class ExceptionLogger:
    """Fix for High Issue #12: Silent Exception Handling"""

    @staticmethod
    def log_exception(e: Exception, context: str = "") -> None:
        """
        Log exception with context and metrics.

        Args:
            e: Exception to log
            context: Additional context about where the exception occurred
        """
        exc_name = type(e).__name__
        # Structured extras let downstream handlers aggregate exception
        # counts by type and location.
        details = {
            'exception_type': exc_name,
            'context': context,
            'metric': 'exception_count'
        }
        logging.getLogger(__name__).error(
            f"Exception in {context}: {exc_name}: {str(e)}",
            exc_info=True,
            extra=details
        )
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
class InputValidator:
    """Fix for High Issue #15: Missing Input Validation"""

    @staticmethod
    def validate_search_query(query: str) -> str:
        """
        Validate and sanitize search query.

        Args:
            query: Search query to validate

        Returns:
            Sanitized query
        """
        # Strip HTML/shell-significant characters, clamp the length, then
        # drop ASCII control characters (in that order).
        stripped = re.sub(r'[<>&"\'`]', '', query)
        stripped = stripped[:1000]
        return re.sub(r'[\x00-\x1f\x7f]', '', stripped)

    @staticmethod
    def validate_project_name(name: str) -> str:
        """
        Validate project name.

        Args:
            name: Project name to validate

        Returns:
            Sanitized project name
        """
        # Whitelist approach: alphanumerics, dash and underscore only,
        # capped at 100 characters.
        return re.sub(r'[^a-zA-Z0-9_-]', '', name)[:100]
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
# ==================== MIGRATION HELPERS ====================
|
|
465
|
+
|
|
466
|
+
class BackwardCompatibility:
    """Helper functions for maintaining backward compatibility during migration"""

    @staticmethod
    async def dual_id_lookup(new_id: str, old_id: str, client) -> Optional[Any]:
        """
        Try to find an item by both new and old ID formats.

        Args:
            new_id: New format ID (SHA-256)
            old_id: Old format ID (MD5)
            client: Database client

        Returns:
            Found item or None
        """
        # Prefer the SHA-256 ID; only fall back to the legacy MD5 ID when
        # the new-format lookup comes back empty.
        found = await client.get(new_id)
        if found:
            return found
        return await client.get(old_id)

    @staticmethod
    def get_collection_name(project: str, mode: str, version: str = "v4") -> str:
        """
        Get collection name with backward compatibility.

        Args:
            project: Project name
            mode: Embedding mode (local/cloud)
            version: Collection version

        Returns:
            Collection name
        """
        if version == "v3":
            # Legacy naming: "<project>_local" or "<project>_voyage".
            legacy_suffix = "_local" if mode == "local" else "_voyage"
            return f"{project}{legacy_suffix}"
        # Current naming embeds the mode and the vector dimensionality.
        dim_tag = "384d" if mode == "local" else "1024d"
        return f"csr_{project}_{mode}_{dim_tag}"
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
# ==================== TESTING UTILITIES ====================
|
|
514
|
+
|
|
515
|
+
class RegressionTester:
    """Utilities for regression testing after fixes"""

    @staticmethod
    async def test_hash_migration():
        """Test MD5 to SHA-256 migration preserves data access."""
        # This would be implemented with actual database calls
        pass

    @staticmethod
    async def test_path_traversal():
        """Test path traversal protection.

        Feeds known-dangerous path strings to PathValidator.sanitize_path
        and asserts every one of them is rejected (returns None).
        """
        dangerous_paths = [
            "../../../etc/passwd",
            "/etc/passwd",
            "~/../../../etc/passwd",
            "/tmp/../etc/passwd",
            "..\\..\\windows\\system32"
        ]

        for path in dangerous_paths:
            result = PathValidator.sanitize_path(path)
            assert result is None, f"Path traversal not blocked: {path}"

    @staticmethod
    async def test_concurrency_limits():
        """Test concurrency limiting works.

        NOTE(review): wall-clock timing assertion — on a heavily loaded
        machine 100 x 0.1s sleeps with limit 10 can exceed 1.5s and fail
        spuriously; consider a wider upper bound if this flakes.
        """
        tasks = [asyncio.sleep(0.1) for _ in range(100)]
        # get_event_loop() is safe here (we're inside a running loop), but
        # asyncio.get_running_loop().time() is the modern equivalent.
        start_time = asyncio.get_event_loop().time()

        await ConcurrencyLimiter.limited_gather(tasks, limit=10)

        elapsed = asyncio.get_event_loop().time() - start_time
        # With limit of 10, 100 tasks of 0.1s each should take ~1s
        assert 0.9 < elapsed < 1.5, f"Concurrency limit not working: {elapsed}s"
|
|
550
|
+
|
|
551
|
+
|
|
552
|
+
if __name__ == "__main__":
    # Run basic tests
    # Smoke-check: only the path-traversal suite runs here; the hash and
    # concurrency tests are exercised elsewhere.
    asyncio.run(RegressionTester.test_path_traversal())
    print("Security patches loaded successfully")
|