mcp-vector-search 0.0.3__py3-none-any.whl
This diff shows the content of a publicly available package version as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
Note: this version of mcp-vector-search has been flagged as a potentially problematic release; see the registry advisory for details.
- mcp_vector_search/__init__.py +9 -0
- mcp_vector_search/cli/__init__.py +1 -0
- mcp_vector_search/cli/commands/__init__.py +1 -0
- mcp_vector_search/cli/commands/config.py +303 -0
- mcp_vector_search/cli/commands/index.py +304 -0
- mcp_vector_search/cli/commands/init.py +212 -0
- mcp_vector_search/cli/commands/search.py +395 -0
- mcp_vector_search/cli/commands/status.py +340 -0
- mcp_vector_search/cli/commands/watch.py +288 -0
- mcp_vector_search/cli/main.py +117 -0
- mcp_vector_search/cli/output.py +242 -0
- mcp_vector_search/config/__init__.py +1 -0
- mcp_vector_search/config/defaults.py +175 -0
- mcp_vector_search/config/settings.py +108 -0
- mcp_vector_search/core/__init__.py +1 -0
- mcp_vector_search/core/database.py +431 -0
- mcp_vector_search/core/embeddings.py +250 -0
- mcp_vector_search/core/exceptions.py +66 -0
- mcp_vector_search/core/indexer.py +310 -0
- mcp_vector_search/core/models.py +174 -0
- mcp_vector_search/core/project.py +304 -0
- mcp_vector_search/core/search.py +324 -0
- mcp_vector_search/core/watcher.py +320 -0
- mcp_vector_search/mcp/__init__.py +1 -0
- mcp_vector_search/parsers/__init__.py +1 -0
- mcp_vector_search/parsers/base.py +180 -0
- mcp_vector_search/parsers/javascript.py +238 -0
- mcp_vector_search/parsers/python.py +407 -0
- mcp_vector_search/parsers/registry.py +187 -0
- mcp_vector_search/py.typed +1 -0
- mcp_vector_search-0.0.3.dist-info/METADATA +333 -0
- mcp_vector_search-0.0.3.dist-info/RECORD +35 -0
- mcp_vector_search-0.0.3.dist-info/WHEEL +4 -0
- mcp_vector_search-0.0.3.dist-info/entry_points.txt +2 -0
- mcp_vector_search-0.0.3.dist-info/licenses/LICENSE +21 -0
mcp_vector_search/core/database.py
@@ -0,0 +1,431 @@
"""Database abstraction and ChromaDB implementation for MCP Vector Search."""

import asyncio
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable

from loguru import logger

from .exceptions import (
    DatabaseError,
    DatabaseInitializationError,
    DatabaseNotInitializedError,
    DocumentAdditionError,
    SearchError,
)
from .models import CodeChunk, IndexStats, SearchResult


@runtime_checkable
class EmbeddingFunction(Protocol):
    """Protocol for embedding functions."""

    def __call__(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for input texts."""
        ...
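Because the protocol is decorated with @runtime_checkable, any object with a matching __call__ can be checked with isinstance at runtime (the check verifies method presence only, not signatures). A minimal sketch, using a made-up constant-vector embedder:

    from typing import List

    class FakeEmbedder:
        """Illustrative stand-in: maps every text to the same 3-dim vector."""

        def __call__(self, texts: List[str]) -> List[List[float]]:
            return [[0.0, 0.0, 0.0] for _ in texts]

    embedder = FakeEmbedder()
    assert isinstance(embedder, EmbeddingFunction)  # passes thanks to @runtime_checkable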
class VectorDatabase(ABC):
    """Abstract interface for vector database operations."""

    @abstractmethod
    async def initialize(self) -> None:
        """Initialize the database connection and collections."""
        ...

    @abstractmethod
    async def close(self) -> None:
        """Close database connections and cleanup resources."""
        ...

    @abstractmethod
    async def add_chunks(self, chunks: List[CodeChunk]) -> None:
        """Add code chunks to the database.

        Args:
            chunks: List of code chunks to add
        """
        ...

    @abstractmethod
    async def search(
        self,
        query: str,
        limit: int = 10,
        filters: Optional[Dict[str, Any]] = None,
        similarity_threshold: float = 0.7,
    ) -> List[SearchResult]:
        """Search for similar code chunks.

        Args:
            query: Search query
            limit: Maximum number of results
            filters: Optional filters to apply
            similarity_threshold: Minimum similarity score

        Returns:
            List of search results
        """
        ...

    @abstractmethod
    async def delete_by_file(self, file_path: Path) -> int:
        """Delete all chunks for a specific file.

        Args:
            file_path: Path to the file

        Returns:
            Number of deleted chunks
        """
        ...

    @abstractmethod
    async def get_stats(self) -> IndexStats:
        """Get database statistics.

        Returns:
            Index statistics
        """
        ...

    @abstractmethod
    async def reset(self) -> None:
        """Reset the database (delete all data)."""
        ...

    async def __aenter__(self) -> "VectorDatabase":
        """Async context manager entry."""
        await self.initialize()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Async context manager exit."""
        await self.close()
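The __aenter__/__aexit__ pair means any concrete implementation can be driven as an async context manager, so initialize() and close() are never forgotten. A usage sketch (the query string is illustrative):

    async def run_query(db: VectorDatabase) -> None:
        async with db:  # initialize() on entry, close() on exit
            results = await db.search("parse json config", limit=5)
            for r in results:
                print(r.file_path, r.similarity_score)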
class ChromaVectorDatabase(VectorDatabase):
    """ChromaDB implementation of vector database."""

    def __init__(
        self,
        persist_directory: Path,
        embedding_function: EmbeddingFunction,
        collection_name: str = "code_search",
    ) -> None:
        """Initialize ChromaDB vector database.

        Args:
            persist_directory: Directory to persist database
            embedding_function: Function to generate embeddings
            collection_name: Name of the collection
        """
        self.persist_directory = persist_directory
        self.embedding_function = embedding_function
        self.collection_name = collection_name
        self._client = None
        self._collection = None

    async def initialize(self) -> None:
        """Initialize ChromaDB client and collection."""
        try:
            import chromadb

            # Ensure directory exists
            self.persist_directory.mkdir(parents=True, exist_ok=True)

            # Create client with new API
            self._client = chromadb.PersistentClient(
                path=str(self.persist_directory),
                settings=chromadb.Settings(
                    anonymized_telemetry=False,
                    allow_reset=True,
                )
            )

            # Create or get collection
            self._collection = self._client.get_or_create_collection(
                name=self.collection_name,
                embedding_function=self.embedding_function,
                metadata={
                    "description": "Semantic code search collection",
                },
            )

            logger.info(f"ChromaDB initialized at {self.persist_directory}")

        except Exception as e:
            logger.error(f"Failed to initialize ChromaDB: {e}")
            raise DatabaseInitializationError(f"ChromaDB initialization failed: {e}") from e
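A minimal construction sketch; the directory path is hypothetical, and the embedder is anything satisfying the EmbeddingFunction protocol (see create_embedding_function in core/embeddings.py below):

    from pathlib import Path

    async def open_database(embedder: EmbeddingFunction) -> ChromaVectorDatabase:
        db = ChromaVectorDatabase(
            persist_directory=Path(".mcp-vector-search/chroma"),  # hypothetical location
            embedding_function=embedder,
        )
        await db.initialize()
        return db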
    async def remove_file_chunks(self, file_path: str) -> int:
        """Remove all chunks for a specific file.

        Args:
            file_path: Relative path to the file

        Returns:
            Number of chunks removed
        """
        if not self._collection:
            raise DatabaseNotInitializedError("Database not initialized")

        try:
            # Get all chunks for this file
            results = self._collection.get(
                where={"file_path": file_path}
            )

            if not results["ids"]:
                return 0

            # Delete the chunks
            self._collection.delete(ids=results["ids"])

            removed_count = len(results["ids"])
            logger.debug(f"Removed {removed_count} chunks for file: {file_path}")
            return removed_count

        except Exception as e:
            logger.error(f"Failed to remove chunks for file {file_path}: {e}")
            return 0

    async def close(self) -> None:
        """Close database connections."""
        if self._client:
            # ChromaDB doesn't require explicit closing
            self._client = None
            self._collection = None
            logger.debug("ChromaDB connections closed")

    async def add_chunks(self, chunks: List[CodeChunk]) -> None:
        """Add code chunks to the database."""
        if not self._collection:
            raise DatabaseNotInitializedError("Database not initialized")

        if not chunks:
            return

        try:
            documents = []
            metadatas = []
            ids = []

            for chunk in chunks:
                # Create searchable text
                searchable_text = self._create_searchable_text(chunk)
                documents.append(searchable_text)

                # Create metadata
                metadata = {
                    "file_path": str(chunk.file_path),
                    "start_line": chunk.start_line,
                    "end_line": chunk.end_line,
                    "language": chunk.language,
                    "chunk_type": chunk.chunk_type,
                    "function_name": chunk.function_name or "",
                    "class_name": chunk.class_name or "",
                    "docstring": chunk.docstring or "",
                    "complexity_score": chunk.complexity_score,
                }
                metadatas.append(metadata)

                # Use chunk ID
                ids.append(chunk.id)

            # Add to collection
            self._collection.add(
                documents=documents,
                metadatas=metadatas,
                ids=ids,
            )

            logger.debug(f"Added {len(chunks)} chunks to database")

        except Exception as e:
            logger.error(f"Failed to add chunks: {e}")
            raise DocumentAdditionError(f"Failed to add chunks: {e}") from e
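A hedged indexing sketch, assuming CodeChunk (core/models.py) accepts the fields referenced above as keyword arguments; every value here is illustrative:

    from pathlib import Path

    async def index_one(db: ChromaVectorDatabase) -> None:
        chunk = CodeChunk(
            id="example-chunk-1",  # hypothetical id
            content="def add(a, b):\n    return a + b",
            file_path=Path("src/math_utils.py"),  # hypothetical file
            start_line=1,
            end_line=2,
            language="python",
            chunk_type="function",
            function_name="add",
            class_name=None,
            docstring=None,
            complexity_score=1.0,
        )
        await db.add_chunks([chunk])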
    async def search(
        self,
        query: str,
        limit: int = 10,
        filters: Optional[Dict[str, Any]] = None,
        similarity_threshold: float = 0.7,
    ) -> List[SearchResult]:
        """Search for similar code chunks."""
        if not self._collection:
            raise DatabaseNotInitializedError("Database not initialized")

        try:
            # Build where clause
            where_clause = self._build_where_clause(filters) if filters else None

            # Perform search
            results = self._collection.query(
                query_texts=[query],
                n_results=limit,
                where=where_clause,
                include=["documents", "metadatas", "distances"],
            )

            # Process results
            search_results = []

            if results["documents"] and results["documents"][0]:
                for i, (doc, metadata, distance) in enumerate(
                    zip(
                        results["documents"][0],
                        results["metadatas"][0],
                        results["distances"][0],
                    )
                ):
                    # Convert distance to similarity (ChromaDB uses cosine distance)
                    similarity = 1.0 - distance

                    if similarity >= similarity_threshold:
                        result = SearchResult(
                            content=doc,
                            file_path=Path(metadata["file_path"]),
                            start_line=metadata["start_line"],
                            end_line=metadata["end_line"],
                            language=metadata["language"],
                            similarity_score=similarity,
                            rank=i + 1,
                            chunk_type=metadata.get("chunk_type", "code"),
                            function_name=metadata.get("function_name") or None,
                            class_name=metadata.get("class_name") or None,
                        )
                        search_results.append(result)

            logger.debug(f"Found {len(search_results)} results for query: {query}")
            return search_results

        except Exception as e:
            logger.error(f"Search failed: {e}")
            raise SearchError(f"Search failed: {e}") from e
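One caveat: get_or_create_collection above is not passed an "hnsw:space" setting, and ChromaDB collections default to squared-L2 distance, so 1.0 - distance is best read as a rough relevance score rather than a true cosine similarity unless the collection is created with metadata={"hnsw:space": "cosine"}. A usage sketch that restricts results to Python functions and keeps only strong matches:

    async def find_config_readers(db: ChromaVectorDatabase) -> None:
        results = await db.search(
            "read configuration file",
            limit=10,
            filters={"language": "python", "chunk_type": "function"},
            similarity_threshold=0.75,
        )
        for r in results:
            print(f"{r.rank}. {r.file_path}:{r.start_line}-{r.end_line} ({r.similarity_score:.2f})")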
    async def delete_by_file(self, file_path: Path) -> int:
        """Delete all chunks for a specific file."""
        if not self._collection:
            raise DatabaseNotInitializedError("Database not initialized")

        try:
            # Get all chunks for this file
            results = self._collection.get(
                where={"file_path": str(file_path)},
                include=["metadatas"],
            )

            if results["ids"]:
                self._collection.delete(ids=results["ids"])
                count = len(results["ids"])
                logger.debug(f"Deleted {count} chunks for {file_path}")
                return count

            return 0

        except Exception as e:
            logger.error(f"Failed to delete chunks for {file_path}: {e}")
            raise DatabaseError(f"Failed to delete chunks: {e}") from e
    async def get_stats(self) -> IndexStats:
        """Get database statistics."""
        if not self._collection:
            raise DatabaseNotInitializedError("Database not initialized")

        try:
            # Get total count
            count = self._collection.count()

            # Get sample for language distribution
            sample_results = self._collection.get(
                limit=min(1000, count) if count > 0 else 0,
                include=["metadatas"],
            )

            languages = {}
            file_types = {}

            if sample_results["metadatas"]:
                for metadata in sample_results["metadatas"]:
                    # Count languages
                    lang = metadata.get("language", "unknown")
                    languages[lang] = languages.get(lang, 0) + 1

                    # Count file types
                    file_path = metadata.get("file_path", "")
                    ext = Path(file_path).suffix or "no_extension"
                    file_types[ext] = file_types.get(ext, 0) + 1

            # Estimate index size (rough approximation)
            index_size_mb = count * 0.001  # Rough estimate

            return IndexStats(
                total_files=len(set(m.get("file_path", "") for m in sample_results.get("metadatas", []))),
                total_chunks=count,
                languages=languages,
                file_types=file_types,
                index_size_mb=index_size_mb,
                last_updated="unknown",  # TODO: Track this
                embedding_model="unknown",  # TODO: Track this
            )

        except Exception as e:
            logger.error(f"Failed to get stats: {e}")
            return IndexStats(
                total_files=0,
                total_chunks=0,
                languages={},
                file_types={},
                index_size_mb=0.0,
                last_updated="error",
                embedding_model="unknown",
            )
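Note that total_files and the language/file-type tallies are derived from a sample of at most 1,000 chunks, so on larger indexes they are approximations. A quick reporting sketch:

    async def report(db: ChromaVectorDatabase) -> None:
        stats = await db.get_stats()
        print(f"{stats.total_chunks} chunks across ~{stats.total_files} files")
        for lang, n in sorted(stats.languages.items(), key=lambda kv: -kv[1]):
            print(f"  {lang}: {n}")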
387
|
+
async def reset(self) -> None:
|
|
388
|
+
"""Reset the database."""
|
|
389
|
+
if self._client:
|
|
390
|
+
try:
|
|
391
|
+
self._client.reset()
|
|
392
|
+
# Recreate collection
|
|
393
|
+
await self.initialize()
|
|
394
|
+
logger.info("Database reset successfully")
|
|
395
|
+
except Exception as e:
|
|
396
|
+
logger.error(f"Failed to reset database: {e}")
|
|
397
|
+
raise DatabaseError(f"Failed to reset database: {e}") from e
|
|
398
|
+
|
    def _create_searchable_text(self, chunk: CodeChunk) -> str:
        """Create optimized searchable text from code chunk."""
        parts = [chunk.content]

        # Add contextual information
        if chunk.function_name:
            parts.append(f"Function: {chunk.function_name}")

        if chunk.class_name:
            parts.append(f"Class: {chunk.class_name}")

        if chunk.docstring:
            parts.append(f"Documentation: {chunk.docstring}")

        # Add language and file context
        parts.append(f"Language: {chunk.language}")
        parts.append(f"File: {chunk.file_path.name}")

        return "\n".join(parts)

    def _build_where_clause(self, filters: Dict[str, Any]) -> Dict[str, Any]:
        """Build ChromaDB where clause from filters."""
        where = {}

        for key, value in filters.items():
            if isinstance(value, list):
                where[key] = {"$in": value}
            elif isinstance(value, str) and value.startswith("!"):
                where[key] = {"$ne": value[1:]}
            else:
                where[key] = value

        return where
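The filter translation is small enough to show by example: list values become $in clauses and a leading "!" negates a string match, so

    {"language": ["python", "javascript"], "chunk_type": "!comment"}

becomes

    {"language": {"$in": ["python", "javascript"]}, "chunk_type": {"$ne": "comment"}}

(whether ChromaDB accepts several top-level conditions without an explicit $and depends on the ChromaDB version).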
mcp_vector_search/core/embeddings.py
@@ -0,0 +1,250 @@
"""Embedding generation for MCP Vector Search."""

import hashlib
import json
from pathlib import Path
from typing import Any, Dict, List, Optional

import aiofiles
from loguru import logger
from sentence_transformers import SentenceTransformer

from .exceptions import EmbeddingError

class EmbeddingCache:
    """LRU cache for embeddings with disk persistence."""

    def __init__(self, cache_dir: Path, max_size: int = 1000) -> None:
        """Initialize embedding cache.

        Args:
            cache_dir: Directory to store cached embeddings
            max_size: Maximum number of embeddings to keep in memory
        """
        self.cache_dir = cache_dir
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.max_size = max_size
        self._memory_cache: Dict[str, List[float]] = {}

    def _hash_content(self, content: str) -> str:
        """Generate cache key from content."""
        return hashlib.sha256(content.encode()).hexdigest()[:16]

    async def get_embedding(self, content: str) -> Optional[List[float]]:
        """Get cached embedding for content."""
        cache_key = self._hash_content(content)

        # Check memory cache first
        if cache_key in self._memory_cache:
            return self._memory_cache[cache_key]

        # Check disk cache
        cache_file = self.cache_dir / f"{cache_key}.json"
        if cache_file.exists():
            try:
                async with aiofiles.open(cache_file, "r") as f:
                    content_str = await f.read()
                    embedding = json.loads(content_str)

                # Add to memory cache if space available
                if len(self._memory_cache) < self.max_size:
                    self._memory_cache[cache_key] = embedding

                return embedding
            except Exception as e:
                logger.warning(f"Failed to load cached embedding: {e}")

        return None

    async def store_embedding(self, content: str, embedding: List[float]) -> None:
        """Store embedding in cache."""
        cache_key = self._hash_content(content)

        # Store in memory cache if space available
        if len(self._memory_cache) < self.max_size:
            self._memory_cache[cache_key] = embedding

        # Store in disk cache
        cache_file = self.cache_dir / f"{cache_key}.json"
        try:
            async with aiofiles.open(cache_file, "w") as f:
                await f.write(json.dumps(embedding))
        except Exception as e:
            logger.warning(f"Failed to cache embedding: {e}")

    def clear_memory_cache(self) -> None:
        """Clear the in-memory cache."""
        self._memory_cache.clear()

    def get_cache_stats(self) -> Dict[str, int]:
        """Get cache statistics."""
        disk_files = len(list(self.cache_dir.glob("*.json")))
        return {
            "memory_cached": len(self._memory_cache),
            "disk_cached": disk_files,
            "memory_limit": self.max_size,
        }
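Despite the docstring, the in-memory layer is a bounded dict rather than a true LRU (nothing is evicted; inserts simply stop once max_size is reached), keyed by a 16-hex-character SHA-256 prefix of the content. A small usage sketch, with an illustrative vector and a hypothetical cache directory:

    import asyncio
    from pathlib import Path

    async def demo_cache() -> None:
        cache = EmbeddingCache(Path(".mcp-vector-search/cache"), max_size=100)
        await cache.store_embedding("def foo(): pass", [0.1, 0.2, 0.3])
        print(await cache.get_embedding("def foo(): pass"))  # served from memory
        print(cache.get_cache_stats())

    asyncio.run(demo_cache())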
class CodeBERTEmbeddingFunction:
    """ChromaDB-compatible embedding function using CodeBERT."""

    def __init__(self, model_name: str = "microsoft/codebert-base") -> None:
        """Initialize CodeBERT embedding function.

        Args:
            model_name: Name of the sentence transformer model
        """
        try:
            self.model = SentenceTransformer(model_name)
            self.model_name = model_name
            logger.info(f"Loaded embedding model: {model_name}")
        except Exception as e:
            logger.error(f"Failed to load embedding model {model_name}: {e}")
            raise EmbeddingError(f"Failed to load embedding model: {e}") from e

    def __call__(self, input: List[str]) -> List[List[float]]:
        """Generate embeddings for input texts (ChromaDB interface)."""
        try:
            embeddings = self.model.encode(input, convert_to_numpy=True)
            return embeddings.tolist()
        except Exception as e:
            logger.error(f"Failed to generate embeddings: {e}")
            raise EmbeddingError(f"Failed to generate embeddings: {e}") from e
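The parameter is named input (shadowing the builtin) because that is the argument name in ChromaDB's embedding-function interface. A direct-call sketch; the model id is an example, and the first call downloads weights:

    fn = CodeBERTEmbeddingFunction("sentence-transformers/all-MiniLM-L6-v2")
    vectors = fn(["def add(a, b): return a + b"])
    print(len(vectors), len(vectors[0]))  # 1 vector; 384 dimensions for this model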
class BatchEmbeddingProcessor:
    """Batch processing for efficient embedding generation with caching."""

    def __init__(
        self,
        embedding_function: CodeBERTEmbeddingFunction,
        cache: Optional[EmbeddingCache] = None,
        batch_size: int = 32,
    ) -> None:
        """Initialize batch embedding processor.

        Args:
            embedding_function: Function to generate embeddings
            cache: Optional embedding cache
            batch_size: Size of batches for processing
        """
        self.embedding_function = embedding_function
        self.cache = cache
        self.batch_size = batch_size

    async def process_batch(self, contents: List[str]) -> List[List[float]]:
        """Process a batch of content for embeddings.

        Args:
            contents: List of text content to embed

        Returns:
            List of embeddings
        """
        if not contents:
            return []

        embeddings = []
        uncached_contents = []
        uncached_indices = []

        # Check cache for each content if cache is available
        if self.cache:
            for i, content in enumerate(contents):
                cached_embedding = await self.cache.get_embedding(content)
                if cached_embedding:
                    embeddings.append(cached_embedding)
                else:
                    embeddings.append(None)  # Placeholder
                    uncached_contents.append(content)
                    uncached_indices.append(i)
        else:
            # No cache, process all content
            uncached_contents = contents
            uncached_indices = list(range(len(contents)))
            embeddings = [None] * len(contents)

        # Generate embeddings for uncached content
        if uncached_contents:
            logger.debug(f"Generating {len(uncached_contents)} new embeddings")

            try:
                new_embeddings = []
                for i in range(0, len(uncached_contents), self.batch_size):
                    batch = uncached_contents[i : i + self.batch_size]
                    batch_embeddings = self.embedding_function(batch)
                    new_embeddings.extend(batch_embeddings)

                # Cache new embeddings and fill placeholders
                for i, (content, embedding) in enumerate(
                    zip(uncached_contents, new_embeddings)
                ):
                    if self.cache:
                        await self.cache.store_embedding(content, embedding)
                    embeddings[uncached_indices[i]] = embedding

            except Exception as e:
                logger.error(f"Failed to generate embeddings: {e}")
                raise EmbeddingError(f"Failed to generate embeddings: {e}") from e

        return embeddings

    def get_stats(self) -> Dict[str, Any]:
        """Get processor statistics."""
        stats = {
            "model_name": self.embedding_function.model_name,
            "batch_size": self.batch_size,
            "cache_enabled": self.cache is not None,
        }

        if self.cache:
            stats.update(self.cache.get_cache_stats())

        return stats
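A batching sketch with the cache wired in; the model id and cache directory are illustrative:

    from pathlib import Path
    from typing import List

    async def embed_snippets(snippets: List[str]) -> List[List[float]]:
        fn = CodeBERTEmbeddingFunction("sentence-transformers/all-MiniLM-L6-v2")
        cache = EmbeddingCache(Path(".mcp-vector-search/cache"))
        processor = BatchEmbeddingProcessor(fn, cache=cache, batch_size=16)
        return await processor.process_batch(snippets)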
def create_embedding_function(
    model_name: str = "microsoft/codebert-base",
    cache_dir: Optional[Path] = None,
    cache_size: int = 1000,
):
    """Create embedding function and cache.

    Args:
        model_name: Name of the embedding model
        cache_dir: Directory for caching embeddings
        cache_size: Maximum cache size

    Returns:
        Tuple of (embedding_function, cache)
    """
    try:
        # Use ChromaDB's built-in sentence transformer function
        from chromadb.utils import embedding_functions

        # Map our model names to sentence-transformers compatible names
        model_mapping = {
            "microsoft/codebert-base": "sentence-transformers/all-MiniLM-L6-v2",  # Fallback to working model
            "microsoft/unixcoder-base": "sentence-transformers/all-MiniLM-L6-v2",  # Fallback to working model
        }

        actual_model = model_mapping.get(model_name, model_name)

        embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=actual_model
        )

        logger.info(f"Created ChromaDB embedding function with model: {actual_model}")

    except Exception as e:
        logger.warning(f"Failed to create ChromaDB embedding function: {e}")
        # Fallback to our custom implementation
        embedding_function = CodeBERTEmbeddingFunction(model_name)

    cache = None
    if cache_dir:
        cache = EmbeddingCache(cache_dir, cache_size)

    return embedding_function, cache
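An end-to-end wiring sketch tying the two modules together; the paths are hypothetical:

    import asyncio
    from pathlib import Path

    async def main() -> None:
        embedding_function, cache = create_embedding_function(
            cache_dir=Path(".mcp-vector-search/cache"),
        )
        db = ChromaVectorDatabase(Path(".mcp-vector-search/chroma"), embedding_function)
        async with db:
            stats = await db.get_stats()
            print(f"Indexed chunks: {stats.total_chunks}")

    asyncio.run(main())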