claude-self-reflect 3.2.4 → 3.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.claude/agents/claude-self-reflect-test.md +992 -510
- package/.claude/agents/reflection-specialist.md +59 -3
- package/README.md +14 -5
- package/installer/cli.js +16 -0
- package/installer/postinstall.js +14 -0
- package/installer/statusline-setup.js +289 -0
- package/mcp-server/run-mcp.sh +73 -5
- package/mcp-server/src/app_context.py +64 -0
- package/mcp-server/src/config.py +57 -0
- package/mcp-server/src/connection_pool.py +286 -0
- package/mcp-server/src/decay_manager.py +106 -0
- package/mcp-server/src/embedding_manager.py +64 -40
- package/mcp-server/src/embeddings_old.py +141 -0
- package/mcp-server/src/models.py +64 -0
- package/mcp-server/src/parallel_search.py +305 -0
- package/mcp-server/src/project_resolver.py +5 -0
- package/mcp-server/src/reflection_tools.py +211 -0
- package/mcp-server/src/rich_formatting.py +196 -0
- package/mcp-server/src/search_tools.py +874 -0
- package/mcp-server/src/server.py +127 -1720
- package/mcp-server/src/temporal_design.py +132 -0
- package/mcp-server/src/temporal_tools.py +604 -0
- package/mcp-server/src/temporal_utils.py +384 -0
- package/mcp-server/src/utils.py +150 -67
- package/package.json +15 -1
- package/scripts/add-timestamp-indexes.py +134 -0
- package/scripts/ast_grep_final_analyzer.py +325 -0
- package/scripts/ast_grep_unified_registry.py +556 -0
- package/scripts/check-collections.py +29 -0
- package/scripts/csr-status +366 -0
- package/scripts/debug-august-parsing.py +76 -0
- package/scripts/debug-import-single.py +91 -0
- package/scripts/debug-project-resolver.py +82 -0
- package/scripts/debug-temporal-tools.py +135 -0
- package/scripts/delta-metadata-update.py +547 -0
- package/scripts/import-conversations-unified.py +157 -25
- package/scripts/precompact-hook.sh +33 -0
- package/scripts/session_quality_tracker.py +481 -0
- package/scripts/streaming-watcher.py +1578 -0
- package/scripts/update_patterns.py +334 -0
- package/scripts/utils.py +39 -0
|
@@ -0,0 +1,286 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Connection pooling for Qdrant client to improve performance and resource management.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
from typing import Optional, Any
|
|
7
|
+
from contextlib import asynccontextmanager
|
|
8
|
+
import logging
|
|
9
|
+
from qdrant_client import AsyncQdrantClient
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)


class QdrantConnectionPool:
    """
    A connection pool for Qdrant clients with configurable size and timeout.

    Idle clients live in an ``asyncio.Queue`` of capacity ``pool_size``. When
    that queue is empty, up to ``max_overflow`` extra clients are created
    immediately; overflow clients are closed on release instead of being
    queued. Only when both the queue and the overflow budget are exhausted
    does ``acquire`` wait (up to ``timeout`` seconds) for a pooled client.
    """

    def __init__(
        self,
        url: str,
        pool_size: int = 10,
        max_overflow: int = 5,
        timeout: float = 30.0,
        retry_attempts: int = 3,
        retry_delay: float = 1.0,
        client_factory: Optional[Any] = None,
    ):
        """
        Initialize the connection pool.

        Args:
            url: Qdrant server URL
            pool_size: Base number of connections to maintain
            max_overflow: Additional connections that can be created if pool is exhausted
            timeout: Seconds to wait for a pooled connection once the pool and
                the overflow budget are both exhausted
            retry_attempts: Number of attempts made by ``execute_with_retry``
                (values below 1 are treated as 1)
            retry_delay: Initial delay between retry attempts (doubled after each failure)
            client_factory: Optional zero-argument callable that returns a new
                client. Defaults to ``AsyncQdrantClient(url=url)``.
        """
        self.url = url
        self.pool_size = pool_size
        self.max_overflow = max_overflow
        self.timeout = timeout
        self.retry_attempts = retry_attempts
        self.retry_delay = retry_delay
        self._client_factory = (
            client_factory if client_factory is not None else self._default_client_factory
        )

        # Idle pooled clients. Overflow clients are tracked separately: every
        # entry in that list is currently checked out and is closed on release.
        self._pool = asyncio.Queue(maxsize=pool_size)
        self._overflow_connections = []
        self._initialized = False
        self._lock = asyncio.Lock()

        # Statistics
        self.stats = {
            'connections_created': 0,
            'connections_reused': 0,
            'connections_failed': 0,
            'overflow_used': 0,
            'timeouts': 0
        }

    def _default_client_factory(self):
        """Create a new AsyncQdrantClient for ``self.url``."""
        return AsyncQdrantClient(url=self.url)

    async def initialize(self):
        """Fill the pool with base connections (no-op if already initialized)."""
        async with self._lock:
            if self._initialized:
                return

            # Top up only the free slots and use put_nowait, so this can never
            # block while holding the lock (an awaited put on a full queue would).
            missing = max(0, self.pool_size - self._pool.qsize())
            for _ in range(missing):
                try:
                    client = self._client_factory()
                except Exception as e:
                    logger.error(f"Failed to create initial connection: {e}")
                    self.stats['connections_failed'] += 1
                    continue
                self._pool.put_nowait(client)
                self.stats['connections_created'] += 1

            self._initialized = True
            logger.info(f"Connection pool initialized with {self._pool.qsize()} connections")

    @asynccontextmanager
    async def acquire(self):
        """
        Acquire a connection from the pool.

        Preference order: an idle pooled client; otherwise a new overflow
        client (while fewer than ``max_overflow`` are checked out); otherwise
        wait up to ``timeout`` seconds for a pooled client to be released.

        Yields:
            AsyncQdrantClient instance

        Raises:
            RuntimeError: if no client became available within ``timeout``.
        """
        if not self._initialized:
            await self.initialize()

        client = None
        acquired_from_overflow = False

        try:
            try:
                client = self._pool.get_nowait()
                self.stats['connections_reused'] += 1
            except asyncio.QueueEmpty:
                if len(self._overflow_connections) < self.max_overflow:
                    # Create the overflow client right away instead of first
                    # stalling for the full timeout on an empty pool.
                    logger.debug("Creating overflow connection")
                    client = self._client_factory()
                    self._overflow_connections.append(client)
                    acquired_from_overflow = True
                    self.stats['overflow_used'] += 1
                    self.stats['connections_created'] += 1
                else:
                    try:
                        client = await asyncio.wait_for(self._pool.get(), timeout=self.timeout)
                    except asyncio.TimeoutError:
                        self.stats['timeouts'] += 1
                        raise RuntimeError("Connection pool exhausted and max overflow reached") from None
                    self.stats['connections_reused'] += 1

            # Yield the client for use
            yield client

        finally:
            if client is not None:
                await self._release(client, acquired_from_overflow)

    async def _release(self, client, from_overflow: bool) -> None:
        """Return a pooled client to the queue, or close it if it cannot be kept."""
        if from_overflow:
            if client in self._overflow_connections:
                self._overflow_connections.remove(client)
            await self._close_client(client, "overflow connection")
            return

        if not self._initialized:
            # The pool was closed while this client was checked out.
            await self._close_client(client, "connection")
            return

        try:
            self._pool.put_nowait(client)
        except asyncio.QueueFull:
            logger.warning("Connection pool is full, closing extra connection")
            await self._close_client(client, "connection")

    @staticmethod
    async def _close_client(client, label: str = "connection") -> None:
        """Best-effort close of one client; errors are logged, never raised."""
        # NOTE(review): recent qdrant-client versions expose an async close();
        # older ones may not, hence the getattr guard and the coroutine check.
        close = getattr(client, "close", None)
        if close is None:
            return
        try:
            result = close()
            if asyncio.iscoroutine(result):
                await result
        except Exception as e:
            logger.error(f"Error closing {label}: {e}")

    async def execute_with_retry(self, func, *args, **kwargs):
        """
        Execute a function with retry logic and exponential backoff.

        Args:
            func: Async function to execute; it receives the acquired client as
                its first argument
            *args: Positional arguments for the function
            **kwargs: Keyword arguments for the function

        Returns:
            Result from the function

        Raises:
            The exception from the final attempt once all attempts failed.
        """
        # Guard against retry_attempts < 1, which previously ended in `raise None`.
        attempts = max(1, self.retry_attempts)
        delay = self.retry_delay

        for attempt in range(1, attempts + 1):
            try:
                async with self.acquire() as client:
                    return await func(client, *args, **kwargs)
            except Exception as e:
                if attempt == attempts:
                    logger.error(f"All {attempts} attempts failed: {e}")
                    raise
                logger.warning(f"Attempt {attempt} failed: {e}. Retrying in {delay}s...")
                await asyncio.sleep(delay)
                delay *= 2  # Exponential backoff

    async def close(self):
        """
        Close all idle pooled connections and mark the pool uninitialized.

        Clients checked out at this moment (pooled or overflow) are left
        untouched; they are closed when their holder releases them.
        """
        async with self._lock:
            while True:
                try:
                    client = self._pool.get_nowait()
                except asyncio.QueueEmpty:
                    break
                await self._close_client(client, "connection")

            self._initialized = False
            logger.info("Connection pool closed")

    def get_stats(self) -> dict:
        """Get pool statistics."""
        return {
            **self.stats,
            'current_pool_size': self._pool.qsize() if self._initialized else 0,
            'overflow_active': len(self._overflow_connections),
            'initialized': self._initialized
        }
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
# Circuit breaker implementation for additional resilience
class CircuitBreaker:
    """
    Circuit breaker pattern to prevent cascading failures.

    States:
        closed    -- calls pass through; counted failures accumulate until a success.
        open      -- calls are rejected with CircuitBreakerOpen until more than
                     ``recovery_timeout`` seconds have passed since the last failure.
        half_open -- calls pass through; a success closes the circuit and a
                     counted failure re-opens it.
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: float = 60.0,
        expected_exception: type = Exception
    ):
        """
        Initialize circuit breaker.

        Args:
            failure_threshold: Number of counted failures before opening the circuit
            recovery_timeout: Seconds to wait after the last failure before a trial call
            expected_exception: Exception type (or tuple of types) counted as a failure;
                other exceptions propagate without affecting the breaker
        """
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.expected_exception = expected_exception

        self._failure_count = 0
        self._last_failure_time: Optional[float] = None  # event-loop clock seconds
        self._state = 'closed'  # closed, open, half_open
        self._lock = asyncio.Lock()

    @property
    def state(self) -> str:
        """Current state: 'closed', 'open' or 'half_open' (read-only)."""
        return self._state

    @staticmethod
    def _now() -> float:
        """Monotonic time from the running event loop."""
        return asyncio.get_running_loop().time()

    async def call(self, func, *args, **kwargs):
        """
        Call a function through the circuit breaker.

        Args:
            func: Async function to call
            *args: Positional arguments
            **kwargs: Keyword arguments

        Returns:
            Result from function

        Raises:
            CircuitBreakerOpen: If circuit is open
            Any exception raised by ``func`` (re-raised unchanged)
        """
        async with self._lock:
            if self._state == 'open':
                last = self._last_failure_time
                # Compare against None explicitly: a falsy timestamp must not let
                # calls slip through while the state stays 'open'.
                if last is not None and self._now() - last <= self.recovery_timeout:
                    raise CircuitBreakerOpen(f"Circuit breaker is open (failures: {self._failure_count})")
                self._state = 'half_open'
                logger.info("Circuit breaker entering half-open state")

        try:
            result = await func(*args, **kwargs)
        except self.expected_exception:
            async with self._lock:
                self._failure_count += 1
                self._last_failure_time = self._now()

                if self._failure_count >= self.failure_threshold:
                    self._state = 'open'
                    logger.error(f"Circuit breaker opened after {self._failure_count} failures")
            raise

        # Success - close a half-open circuit and reset the failure streak.
        async with self._lock:
            if self._state == 'half_open':
                self._state = 'closed'
                logger.info("Circuit breaker closed after successful recovery")
            self._failure_count = 0
            self._last_failure_time = None

        return result


class CircuitBreakerOpen(Exception):
    """Exception raised when circuit breaker is open."""
    pass
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
"""Decay calculation manager for Claude Self-Reflect MCP server."""
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
from datetime import datetime, timezone
|
|
5
|
+
from typing import List, Tuple, Optional
|
|
6
|
+
try:
|
|
7
|
+
from .config import (
|
|
8
|
+
USE_DECAY,
|
|
9
|
+
DECAY_SCALE_DAYS,
|
|
10
|
+
DECAY_WEIGHT,
|
|
11
|
+
USE_NATIVE_DECAY,
|
|
12
|
+
logger
|
|
13
|
+
)
|
|
14
|
+
except ImportError:
|
|
15
|
+
# Fallback for direct execution
|
|
16
|
+
import os
|
|
17
|
+
import logging
|
|
18
|
+
USE_DECAY = os.getenv('USE_DECAY', 'false').lower() == 'true'
|
|
19
|
+
DECAY_SCALE_DAYS = float(os.getenv('DECAY_SCALE_DAYS', '90'))
|
|
20
|
+
DECAY_WEIGHT = float(os.getenv('DECAY_WEIGHT', '0.3'))
|
|
21
|
+
USE_NATIVE_DECAY = os.getenv('USE_NATIVE_DECAY', 'false').lower() == 'true'
|
|
22
|
+
logger = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
class DecayManager:
    """Manages memory decay calculations for search results.

    A score is blended as ``base * (1 - weight) + base * weight * decay`` with
    ``decay = exp(-ln(2) * age / scale)``, so the configured scale acts as a
    half-life: a result exactly one scale old keeps half of its weighted part.
    """

    def __init__(
        self,
        scale_days: Optional[float] = None,
        weight: Optional[float] = None,
        use_decay: Optional[bool] = None,
        use_native: Optional[bool] = None,
    ):
        """Create a manager; any argument left as None falls back to module config.

        Args:
            scale_days: Half-life in days (default ``DECAY_SCALE_DAYS``).
            weight: Fraction of the score subject to decay (default ``DECAY_WEIGHT``).
            use_decay: Enable client-side decay (default ``USE_DECAY``).
            use_native: Enable native Qdrant decay config (default ``USE_NATIVE_DECAY``).
        """
        if scale_days is None:
            scale_days = DECAY_SCALE_DAYS
        self.scale_ms = scale_days * 24 * 60 * 60 * 1000
        self.weight = DECAY_WEIGHT if weight is None else weight
        self.use_decay = USE_DECAY if use_decay is None else use_decay
        self.use_native = USE_NATIVE_DECAY if use_native is None else use_native

    def calculate_decay_score(
        self,
        base_score: float,
        timestamp: str
    ) -> float:
        """Calculate the decayed score for a single result.

        Args:
            base_score: Raw similarity score.
            timestamp: ISO-8601 timestamp; a trailing 'Z' is accepted, and naive
                values are treated as UTC.

        Returns:
            The decayed score, or ``base_score`` unchanged when decay is disabled
            or the timestamp cannot be processed (the error is logged).
        """
        if not self.use_decay:
            return base_score

        try:
            # datetime.fromisoformat only accepts a literal 'Z' from Python 3.11.
            if timestamp.endswith('Z'):
                timestamp = timestamp[:-1] + '+00:00'

            result_time = datetime.fromisoformat(timestamp)
            if result_time.tzinfo is None:
                result_time = result_time.replace(tzinfo=timezone.utc)

            age_ms = (datetime.now(timezone.utc) - result_time).total_seconds() * 1000
            # Future timestamps (clock skew) must not boost a score above its base.
            age_ms = max(0.0, age_ms)

            # Half-life decay: exp(-ln(2) * age / half_life)
            decay_factor = math.exp(-math.log(2) * age_ms / self.scale_ms)

            return base_score * (1 - self.weight) + base_score * self.weight * decay_factor

        except Exception as e:
            logger.error(f"Failed to calculate decay: {e}")
            return base_score

    def apply_decay_to_results(
        self,
        results: List[Tuple[float, str, dict]]
    ) -> List[Tuple[float, str, dict]]:
        """Apply decay to ``(score, id, payload)`` results and re-sort descending.

        Results whose payload has no (or an empty) 'timestamp' keep their score,
        i.e. they are treated as brand-new rather than decayed.
        """
        if not self.use_decay:
            return results

        decayed_results = []
        for score, id_str, payload in results:
            timestamp = payload.get('timestamp')
            if timestamp:
                score = self.calculate_decay_score(score, timestamp)
            decayed_results.append((score, id_str, payload))

        # Re-sort by decayed score
        decayed_results.sort(key=lambda x: x[0], reverse=True)

        return decayed_results

    def get_native_decay_config(self) -> Optional[dict]:
        """Return the native Qdrant decay settings, or None when native decay is off."""
        if not self.use_native:
            return None

        return {
            'scale_seconds': self.scale_ms / 1000,
            'weight': self.weight,
            'midpoint': 0.5  # Half-life semantics
        }

    def should_use_decay(self, explicit_setting: Optional[int] = None) -> bool:
        """Decide whether to decay a query: 1 forces on, 0 forces off, else the default."""
        if explicit_setting is not None:
            if explicit_setting == 1:
                return True
            elif explicit_setting == 0:
                return False

        return self.use_decay
|
|
@@ -16,16 +16,16 @@ class EmbeddingManager:
|
|
|
16
16
|
"""Manages embedding models with proper cache and lock handling."""
|
|
17
17
|
|
|
18
18
|
def __init__(self):
|
|
19
|
-
self.
|
|
20
|
-
self.model_type = None # 'local' or 'voyage'
|
|
19
|
+
self.local_model = None
|
|
21
20
|
self.voyage_client = None
|
|
22
|
-
|
|
21
|
+
self.model_type = None # Default model type ('local' or 'voyage')
|
|
22
|
+
|
|
23
23
|
# Configuration
|
|
24
24
|
self.prefer_local = os.getenv('PREFER_LOCAL_EMBEDDINGS', 'true').lower() == 'true'
|
|
25
25
|
self.voyage_key = os.getenv('VOYAGE_KEY') or os.getenv('VOYAGE_KEY-2')
|
|
26
26
|
self.embedding_model = os.getenv('EMBEDDING_MODEL', 'sentence-transformers/all-MiniLM-L6-v2')
|
|
27
27
|
self.download_timeout = int(os.getenv('FASTEMBED_DOWNLOAD_TIMEOUT', '30'))
|
|
28
|
-
|
|
28
|
+
|
|
29
29
|
# Set cache directory to our controlled location
|
|
30
30
|
self.cache_dir = Path(__file__).parent.parent / '.fastembed-cache'
|
|
31
31
|
|
|
@@ -50,27 +50,35 @@ class EmbeddingManager:
|
|
|
50
50
|
logger.warning(f"Error cleaning locks: {e}")
|
|
51
51
|
|
|
52
52
|
def initialize(self) -> bool:
|
|
53
|
-
"""Initialize embedding
|
|
54
|
-
logger.info("Initializing embedding manager...")
|
|
55
|
-
|
|
53
|
+
"""Initialize BOTH embedding models to support mixed collections."""
|
|
54
|
+
logger.info("Initializing embedding manager for dual-mode support...")
|
|
55
|
+
|
|
56
56
|
# Clean up any stale locks first
|
|
57
57
|
self._clean_stale_locks()
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
58
|
+
|
|
59
|
+
# Initialize both models for mixed collection support
|
|
60
|
+
local_success = self._try_initialize_local()
|
|
61
|
+
voyage_success = False
|
|
62
|
+
|
|
63
|
+
if self.voyage_key:
|
|
64
|
+
voyage_success = self._try_initialize_voyage()
|
|
65
|
+
|
|
66
|
+
# Set default model type based on preference and availability
|
|
67
|
+
if self.prefer_local and local_success:
|
|
68
|
+
self.model_type = 'local'
|
|
69
|
+
logger.info("Default model set to LOCAL embeddings")
|
|
70
|
+
elif voyage_success:
|
|
71
|
+
self.model_type = 'voyage'
|
|
72
|
+
logger.info("Default model set to VOYAGE embeddings")
|
|
73
|
+
elif local_success:
|
|
74
|
+
self.model_type = 'local'
|
|
75
|
+
logger.info("Default model set to LOCAL embeddings (fallback)")
|
|
65
76
|
else:
|
|
66
|
-
|
|
67
|
-
if self.voyage_key and self._try_initialize_voyage():
|
|
68
|
-
return True
|
|
69
|
-
logger.warning("Voyage AI failed, trying local as fallback...")
|
|
70
|
-
if self._try_initialize_local():
|
|
71
|
-
return True
|
|
72
|
-
logger.error("Both Voyage AI and local embeddings failed")
|
|
77
|
+
logger.error("Failed to initialize any embedding model")
|
|
73
78
|
return False
|
|
79
|
+
|
|
80
|
+
logger.info(f"Embedding models available - Local: {local_success}, Voyage: {voyage_success}")
|
|
81
|
+
return True
|
|
74
82
|
|
|
75
83
|
def _try_initialize_local(self) -> bool:
|
|
76
84
|
"""Try to initialize local FastEmbed model with timeout and optimizations."""
|
|
@@ -119,11 +127,10 @@ class EmbeddingManager:
|
|
|
119
127
|
from fastembed import TextEmbedding
|
|
120
128
|
# Initialize with optimized settings
|
|
121
129
|
# Note: FastEmbed uses these environment variables internally
|
|
122
|
-
self.
|
|
130
|
+
self.local_model = TextEmbedding(
|
|
123
131
|
model_name=self.embedding_model,
|
|
124
132
|
threads=1 # Single thread per worker to prevent over-subscription
|
|
125
133
|
)
|
|
126
|
-
self.model_type = 'local'
|
|
127
134
|
success = True
|
|
128
135
|
logger.info(f"Successfully initialized local model: {self.embedding_model} with single-thread mode")
|
|
129
136
|
except Exception as e:
|
|
@@ -177,39 +184,48 @@ class EmbeddingManager:
|
|
|
177
184
|
logger.error(f"Failed to initialize Voyage AI: {e}")
|
|
178
185
|
return False
|
|
179
186
|
|
|
180
|
-
def embed(self, texts: Union[str, List[str]], input_type: str = "document") -> Optional[List[List[float]]]:
|
|
181
|
-
"""Generate embeddings using the
|
|
182
|
-
|
|
183
|
-
|
|
187
|
+
def embed(self, texts: Union[str, List[str]], input_type: str = "document", force_type: str = None) -> Optional[List[List[float]]]:
|
|
188
|
+
"""Generate embeddings using the specified or default model."""
|
|
189
|
+
# Determine which model to use
|
|
190
|
+
use_type = force_type if force_type else self.model_type
|
|
191
|
+
logger.debug(f"Embedding with: force_type={force_type}, self.model_type={self.model_type}, use_type={use_type}")
|
|
192
|
+
|
|
193
|
+
if use_type == 'local' and not self.local_model:
|
|
194
|
+
logger.error("Local model not initialized")
|
|
184
195
|
return None
|
|
185
|
-
|
|
196
|
+
elif use_type == 'voyage' and not self.voyage_client:
|
|
197
|
+
logger.error("Voyage client not initialized")
|
|
198
|
+
return None
|
|
199
|
+
|
|
186
200
|
# Ensure texts is a list
|
|
187
201
|
if isinstance(texts, str):
|
|
188
202
|
texts = [texts]
|
|
189
|
-
|
|
203
|
+
|
|
190
204
|
try:
|
|
191
|
-
if
|
|
205
|
+
if use_type == 'local':
|
|
192
206
|
# FastEmbed returns a generator, convert to list
|
|
193
|
-
embeddings = list(self.
|
|
207
|
+
embeddings = list(self.local_model.embed(texts))
|
|
194
208
|
return [emb.tolist() for emb in embeddings]
|
|
195
|
-
|
|
196
|
-
elif
|
|
209
|
+
|
|
210
|
+
elif use_type == 'voyage':
|
|
211
|
+
# Always use voyage-3 for consistency with collection dimensions (1024)
|
|
197
212
|
result = self.voyage_client.embed(
|
|
198
213
|
texts=texts,
|
|
199
|
-
model="voyage-3
|
|
214
|
+
model="voyage-3",
|
|
200
215
|
input_type=input_type
|
|
201
216
|
)
|
|
202
217
|
return result.embeddings
|
|
203
|
-
|
|
218
|
+
|
|
204
219
|
except Exception as e:
|
|
205
|
-
logger.error(f"Error generating embeddings: {e}")
|
|
220
|
+
logger.error(f"Error generating embeddings with {use_type}: {e}")
|
|
206
221
|
return None
|
|
207
222
|
|
|
208
|
-
def get_vector_dimension(self) -> int:
|
|
209
|
-
"""Get the dimension of embeddings."""
|
|
210
|
-
if self.model_type
|
|
223
|
+
def get_vector_dimension(self, force_type: str = None) -> int:
|
|
224
|
+
"""Get the dimension of embeddings for a specific type."""
|
|
225
|
+
use_type = force_type if force_type else self.model_type
|
|
226
|
+
if use_type == 'local':
|
|
211
227
|
return 384 # all-MiniLM-L6-v2 dimension
|
|
212
|
-
elif
|
|
228
|
+
elif use_type == 'voyage':
|
|
213
229
|
return 1024 # voyage-3 dimension
|
|
214
230
|
return 0
|
|
215
231
|
|
|
@@ -222,6 +238,14 @@ class EmbeddingManager:
|
|
|
222
238
|
'prefer_local': self.prefer_local,
|
|
223
239
|
'has_voyage_key': bool(self.voyage_key)
|
|
224
240
|
}
|
|
241
|
+
|
|
242
|
+
async def generate_embedding(self, text: str, force_type: str = None) -> Optional[List[float]]:
    """Embed one query string and return its vector.

    Async-compatible wrapper around ``embed``: the text is embedded with
    ``input_type="query"``, using ``force_type`` when given and the default
    model otherwise. The underlying ``embed`` call is synchronous and runs
    directly inside this coroutine.

    Returns:
        The first embedding vector, or None when ``embed`` failed or returned
        no vectors.
    """
    vectors = self.embed(text, input_type="query", force_type=force_type)
    if not vectors:
        return None
    return vectors[0]
|
|
225
249
|
|
|
226
250
|
|
|
227
251
|
# Global instance
|