stravinsky 0.2.40__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mcp_bridge/__init__.py +1 -1
- mcp_bridge/auth/token_refresh.py +130 -0
- mcp_bridge/cli/__init__.py +6 -0
- mcp_bridge/cli/install_hooks.py +1265 -0
- mcp_bridge/cli/session_report.py +585 -0
- mcp_bridge/hooks/HOOKS_SETTINGS.json +175 -0
- mcp_bridge/hooks/README.md +215 -0
- mcp_bridge/hooks/__init__.py +119 -43
- mcp_bridge/hooks/edit_recovery.py +42 -37
- mcp_bridge/hooks/git_noninteractive.py +89 -0
- mcp_bridge/hooks/keyword_detector.py +30 -0
- mcp_bridge/hooks/manager.py +50 -0
- mcp_bridge/hooks/notification_hook.py +103 -0
- mcp_bridge/hooks/parallel_enforcer.py +127 -0
- mcp_bridge/hooks/parallel_execution.py +111 -0
- mcp_bridge/hooks/pre_compact.py +123 -0
- mcp_bridge/hooks/preemptive_compaction.py +81 -7
- mcp_bridge/hooks/rules_injector.py +507 -0
- mcp_bridge/hooks/session_idle.py +116 -0
- mcp_bridge/hooks/session_notifier.py +125 -0
- mcp_bridge/{native_hooks → hooks}/stravinsky_mode.py +51 -16
- mcp_bridge/hooks/subagent_stop.py +98 -0
- mcp_bridge/hooks/task_validator.py +73 -0
- mcp_bridge/hooks/tmux_manager.py +141 -0
- mcp_bridge/hooks/todo_continuation.py +90 -0
- mcp_bridge/hooks/todo_delegation.py +88 -0
- mcp_bridge/hooks/tool_messaging.py +164 -0
- mcp_bridge/hooks/truncator.py +21 -17
- mcp_bridge/notifications.py +151 -0
- mcp_bridge/prompts/__init__.py +3 -1
- mcp_bridge/prompts/dewey.py +30 -20
- mcp_bridge/prompts/explore.py +46 -8
- mcp_bridge/prompts/multimodal.py +24 -3
- mcp_bridge/prompts/planner.py +222 -0
- mcp_bridge/prompts/stravinsky.py +107 -28
- mcp_bridge/server.py +170 -10
- mcp_bridge/server_tools.py +554 -32
- mcp_bridge/tools/agent_manager.py +316 -106
- mcp_bridge/tools/background_tasks.py +2 -1
- mcp_bridge/tools/code_search.py +97 -11
- mcp_bridge/tools/lsp/__init__.py +7 -0
- mcp_bridge/tools/lsp/manager.py +448 -0
- mcp_bridge/tools/lsp/tools.py +637 -150
- mcp_bridge/tools/model_invoke.py +270 -47
- mcp_bridge/tools/semantic_search.py +2492 -0
- mcp_bridge/tools/templates.py +32 -18
- stravinsky-0.3.4.dist-info/METADATA +420 -0
- stravinsky-0.3.4.dist-info/RECORD +79 -0
- stravinsky-0.3.4.dist-info/entry_points.txt +5 -0
- mcp_bridge/native_hooks/edit_recovery.py +0 -46
- mcp_bridge/native_hooks/truncator.py +0 -23
- stravinsky-0.2.40.dist-info/METADATA +0 -204
- stravinsky-0.2.40.dist-info/RECORD +0 -57
- stravinsky-0.2.40.dist-info/entry_points.txt +0 -3
- /mcp_bridge/{native_hooks → hooks}/context.py +0 -0
- {stravinsky-0.2.40.dist-info → stravinsky-0.3.4.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,2492 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Semantic Code Search - Vector-based code understanding
|
|
3
|
+
|
|
4
|
+
Uses ChromaDB for persistent vector storage with multiple embedding providers:
|
|
5
|
+
- Ollama (local, free) - nomic-embed-text (768 dims)
|
|
6
|
+
- Mxbai (local, free) - mxbai-embed-large (1024 dims, better for code)
|
|
7
|
+
- Gemini (cloud, OAuth) - gemini-embedding-001 (768-3072 dims)
|
|
8
|
+
- OpenAI (cloud, OAuth) - text-embedding-3-small (1536 dims)
|
|
9
|
+
- HuggingFace (cloud, token) - sentence-transformers/all-mpnet-base-v2 (768 dims)
|
|
10
|
+
|
|
11
|
+
Enables natural language queries like "find authentication logic" without
|
|
12
|
+
requiring exact pattern matching.
|
|
13
|
+
|
|
14
|
+
Architecture:
|
|
15
|
+
- Per-project, per-provider ChromaDB storage at ~/.stravinsky/vectordb/<project_hash>_<provider>/
|
|
16
|
+
- Lazy initialization on first query
|
|
17
|
+
- Provider abstraction for embedding generation
|
|
18
|
+
- Chunking strategy: function/class level with context
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
import hashlib
|
|
22
|
+
import logging
|
|
23
|
+
import sys
|
|
24
|
+
import threading
|
|
25
|
+
from abc import ABC, abstractmethod
|
|
26
|
+
from pathlib import Path
|
|
27
|
+
from typing import Literal
|
|
28
|
+
|
|
29
|
+
import httpx
|
|
30
|
+
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
|
31
|
+
|
|
32
|
+
from mcp_bridge.auth.token_store import TokenStore
|
|
33
|
+
|
|
34
|
+
logger = logging.getLogger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# Lazy imports for watchdog (avoid startup cost)
|
|
38
|
+
_watchdog = None
|
|
39
|
+
_watchdog_import_lock = threading.Lock()
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def get_watchdog():
    """Return the lazily imported watchdog classes.

    The first call imports ``watchdog`` and caches a dict with the keys
    ``"Observer"`` and ``"FileSystemEventHandler"``; every later call returns
    that same cached dict.
    """
    global _watchdog
    if _watchdog is not None:
        return _watchdog

    with _watchdog_import_lock:
        # Re-check under the lock: another thread may have completed the
        # import while this one was waiting.
        if _watchdog is None:
            from watchdog.events import FileSystemEventHandler
            from watchdog.observers import Observer

            _watchdog = {
                "Observer": Observer,
                "FileSystemEventHandler": FileSystemEventHandler,
            }
    return _watchdog
|
|
53
|
+
|
|
54
|
+
# Embedding provider type
|
|
55
|
+
EmbeddingProvider = Literal["ollama", "mxbai", "gemini", "openai", "huggingface"]
|
|
56
|
+
|
|
57
|
+
# Lazy imports to avoid startup cost
|
|
58
|
+
_chromadb = None
|
|
59
|
+
_ollama = None
|
|
60
|
+
_httpx = None
|
|
61
|
+
_filelock = None
|
|
62
|
+
_import_lock = threading.Lock()
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def get_filelock():
    """Return the ``filelock`` module, importing it on first use."""
    global _filelock
    if _filelock is not None:
        return _filelock

    with _import_lock:
        # Second check guards against a concurrent caller having imported it
        # between the test above and the lock acquisition.
        if _filelock is None:
            import filelock as filelock_module

            _filelock = filelock_module
    return _filelock
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def get_chromadb():
    """Return the ``chromadb`` module, importing it on first use.

    Raises:
        ImportError: If chromadb cannot be imported. On Python 3.14+ the
            error is replaced by a message explaining that semantic search
            is not supported there, chained to the original failure.
    """
    global _chromadb
    if _chromadb is not None:
        return _chromadb

    with _import_lock:
        if _chromadb is not None:
            return _chromadb
        try:
            import chromadb as chromadb_module
        except ImportError as import_error:
            # Older interpreters: surface the original error untouched.
            if sys.version_info < (3, 14):
                raise
            raise ImportError(
                "ChromaDB is not available on Python 3.14+. "
                "Semantic search is not supported on Python 3.14 yet. "
                "Use Python 3.11-3.13 for semantic search features."
            ) from import_error
        _chromadb = chromadb_module
    return _chromadb
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def get_ollama():
    """Return the ``ollama`` module, importing it on first use."""
    global _ollama
    if _ollama is not None:
        return _ollama

    with _import_lock:
        # Re-test under the lock so only one thread performs the import.
        if _ollama is None:
            import ollama as ollama_module

            _ollama = ollama_module
    return _ollama
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def get_httpx():
    """Return the ``httpx`` module, caching it in the module-level slot."""
    global _httpx
    if _httpx is not None:
        return _httpx

    with _import_lock:
        # Re-test under the lock so only one thread fills the cache slot.
        if _httpx is None:
            import httpx as httpx_module

            _httpx = httpx_module
    return _httpx
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
# ========================
|
|
119
|
+
# EMBEDDING PROVIDERS
|
|
120
|
+
# ========================
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class BaseEmbeddingProvider(ABC):
    """Interface shared by every embedding backend.

    Concrete providers supply two coroutines (``get_embedding`` and
    ``check_available``) and two read-only properties (``dimension`` and
    ``name``).
    """

    @abstractmethod
    async def get_embedding(self, text: str) -> list[float]:
        """Return the embedding vector for ``text``."""

    @abstractmethod
    async def check_available(self) -> bool:
        """Report whether the provider is ready to serve requests."""

    @property
    @abstractmethod
    def dimension(self) -> int:
        """Length of the vectors this provider produces."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Identifier of this provider."""
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class OllamaProvider(BaseEmbeddingProvider):
    """Ollama local embedding provider using nomic-embed-text.

    The ``ollama`` module functions used here are synchronous. They are run
    through ``asyncio.to_thread`` so that awaiting this provider does not
    block the event loop, and so callers that fan out several
    ``get_embedding`` coroutines actually overlap their requests.
    """

    MODEL = "nomic-embed-text"
    DIMENSION = 768
    # Inputs longer than this many characters are cut before being embedded.
    MAX_CHARS = 2000

    def __init__(self):
        # None = not probed yet; True/False = cached result of the last probe.
        self._available: bool | None = None

    @property
    def dimension(self) -> int:
        """Embedding vector length reported for this provider."""
        return self.DIMENSION

    @property
    def name(self) -> str:
        """Provider identifier."""
        return "ollama"

    async def check_available(self) -> bool:
        """Check that Ollama answers and lists the embedding model.

        The outcome (positive or negative) is cached on the instance, and a
        hint is printed to stderr when the check fails.

        Returns:
            True if a listed model name contains ``MODEL``, otherwise False.
        """
        import asyncio

        if self._available is not None:
            return self._available

        try:
            ollama = get_ollama()
            # Blocking HTTP call: keep it off the event loop thread.
            models = await asyncio.to_thread(ollama.list)
            model_names = [m.model for m in models.models] if hasattr(models, "models") else []

            if not any(name and self.MODEL in name for name in model_names):
                print(
                    f"⚠️ Embedding model '{self.MODEL}' not found. Run: ollama pull {self.MODEL}",
                    file=sys.stderr,
                )
                self._available = False
                return False

            self._available = True
            return True
        except Exception as e:
            print(f"⚠️ Ollama not available: {e}. Start with: ollama serve", file=sys.stderr)
            self._available = False
            return False

    async def get_embedding(self, text: str) -> list[float]:
        """Embed ``text`` with the local Ollama model.

        Args:
            text: Text to embed; only the first ``MAX_CHARS`` characters are
                sent.

        Returns:
            The ``"embedding"`` entry of the Ollama response.
        """
        import asyncio

        ollama = get_ollama()
        # nomic-embed-text has 8192 token context. Code can be 1-2 chars/token,
        # so 2000 chars (~1000-2000 tokens) leaves a large safety margin.
        truncated = text[: self.MAX_CHARS]
        # Blocking HTTP call: run it in a worker thread.
        response = await asyncio.to_thread(
            ollama.embeddings, model=self.MODEL, prompt=truncated
        )
        return response["embedding"]
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
class GeminiProvider(BaseEmbeddingProvider):
    """Gemini embedding provider using OAuth authentication.

    Requests go to the Antigravity endpoints with the access token stored
    for the ``"gemini"`` provider in the token store.
    """

    MODEL = "gemini-embedding-001"
    DIMENSION = 768  # Using 768 for efficiency, can be up to 3072

    def __init__(self):
        # None = not probed yet; True/False = cached result of the last probe.
        self._available: bool | None = None
        self._token_store = None

    def _get_token_store(self):
        """Create the token store on first use and reuse it afterwards."""
        if self._token_store is None:
            from ..auth.token_store import TokenStore

            self._token_store = TokenStore()
        return self._token_store

    @property
    def dimension(self) -> int:
        """Embedding vector length requested from the API."""
        return self.DIMENSION

    @property
    def name(self) -> str:
        """Provider identifier."""
        return "gemini"

    async def check_available(self) -> bool:
        """Check that a Gemini access token is stored.

        The outcome is cached on the instance; a hint is printed to stderr
        when no token is found or the lookup fails.
        """
        if self._available is not None:
            return self._available

        try:
            token_store = self._get_token_store()
            access_token = token_store.get_access_token("gemini")

            if not access_token:
                print(
                    "⚠️ Gemini not authenticated. Run: stravinsky-auth login gemini",
                    file=sys.stderr,
                )
                self._available = False
                return False

            self._available = True
            return True
        except Exception as e:
            print(f"⚠️ Gemini not available: {e}", file=sys.stderr)
            self._available = False
            return False

    async def get_embedding(self, text: str) -> list[float]:
        """Embed ``text`` through the Antigravity ``embedContent`` endpoint.

        Each endpoint in ``ANTIGRAVITY_ENDPOINTS`` is tried in order; the
        first one that returns embedding values wins.

        Args:
            text: Text to embed (sent as-is, without truncation).

        Returns:
            The embedding values from the first successful response.

        Raises:
            ValueError: If no access token is stored, or if every endpoint
                fails. In the latter case the last underlying error is
                attached as ``__cause__``.
        """
        import os
        import uuid

        from ..auth.oauth import (
            ANTIGRAVITY_DEFAULT_PROJECT_ID,
            ANTIGRAVITY_ENDPOINTS,
            ANTIGRAVITY_HEADERS,
        )

        token_store = self._get_token_store()
        access_token = token_store.get_access_token("gemini")

        if not access_token:
            raise ValueError("Not authenticated with Gemini. Run: stravinsky-auth login gemini")

        httpx = get_httpx()

        # Use Antigravity endpoint for embeddings (same auth as invoke_gemini)
        project_id = os.getenv("STRAVINSKY_ANTIGRAVITY_PROJECT_ID", ANTIGRAVITY_DEFAULT_PROJECT_ID)

        headers = {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
            **ANTIGRAVITY_HEADERS,
        }

        inner_payload = {
            "model": f"models/{self.MODEL}",
            "content": {"parts": [{"text": text}]},
            "outputDimensionality": self.DIMENSION,
        }

        # The inner request is wrapped in an envelope for the Antigravity API.
        wrapped_payload = {
            "project": project_id,
            "model": self.MODEL,
            "userAgent": "antigravity",
            "requestId": f"embed-{uuid.uuid4()}",
            "request": inner_payload,
        }

        # Try endpoints in order, remembering why the most recent one failed.
        last_error = None
        async with httpx.AsyncClient(timeout=60.0) as client:
            for endpoint in ANTIGRAVITY_ENDPOINTS:
                api_url = f"{endpoint}/v1internal:embedContent"

                try:
                    response = await client.post(
                        api_url,
                        headers=headers,
                        json=wrapped_payload,
                    )

                    # Auth rejections may be endpoint-specific: try the next one.
                    if response.status_code in (401, 403):
                        last_error = Exception(f"{response.status_code} from {endpoint}")
                        continue

                    response.raise_for_status()
                    data = response.json()

                    # The payload may or may not be nested under "response".
                    inner_response = data.get("response", data)
                    embedding = inner_response.get("embedding", {})
                    values = embedding.get("values", [])

                    if values:
                        return values

                    raise ValueError(f"No embedding values in response: {data}")

                except Exception as e:
                    # Best-effort fallback: any failure moves on to the next
                    # endpoint; the error is reported if all of them fail.
                    last_error = e
                    continue

        # Chain the last failure so its traceback is not lost.
        raise ValueError(
            f"All Antigravity endpoints failed for embeddings: {last_error}"
        ) from last_error
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
class OpenAIProvider(BaseEmbeddingProvider):
    """Embedding provider backed by the OpenAI embeddings endpoint.

    Authenticates with the access token stored for the ``"openai"`` provider.
    """

    MODEL = "text-embedding-3-small"
    DIMENSION = 1536

    def __init__(self):
        self._available: bool | None = None
        self._token_store = None

    def _get_token_store(self):
        """Build the token store lazily and keep it for later calls."""
        if self._token_store is not None:
            return self._token_store

        from ..auth.token_store import TokenStore

        self._token_store = TokenStore()
        return self._token_store

    @property
    def dimension(self) -> int:
        return self.DIMENSION

    @property
    def name(self) -> str:
        return "openai"

    async def check_available(self) -> bool:
        """Report whether an OpenAI access token is stored (result is cached)."""
        if self._available is not None:
            return self._available

        try:
            stored_token = self._get_token_store().get_access_token("openai")
            if stored_token:
                self._available = True
            else:
                print(
                    "⚠️ OpenAI not authenticated. Run: stravinsky-auth login openai",
                    file=sys.stderr,
                )
                self._available = False
        except Exception as e:
            print(f"⚠️ OpenAI not available: {e}", file=sys.stderr)
            self._available = False
        return self._available

    async def get_embedding(self, text: str) -> list[float]:
        """POST ``text`` to the OpenAI embeddings endpoint and return the vector.

        Raises:
            ValueError: When no token is stored, the API answers 401, or the
                response carries no embedding.
        """
        access_token = self._get_token_store().get_access_token("openai")
        if not access_token:
            raise ValueError("Not authenticated with OpenAI. Run: stravinsky-auth login openai")

        httpx = get_httpx()

        # Standard OpenAI embeddings endpoint.
        api_url = "https://api.openai.com/v1/embeddings"
        request_headers = {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }
        request_body = {"model": self.MODEL, "input": text}

        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.post(api_url, headers=request_headers, json=request_body)

            if response.status_code == 401:
                raise ValueError("OpenAI authentication failed. Run: stravinsky-auth login openai")

            response.raise_for_status()
            data = response.json()

            # The vector is expected in the first entry of the "data" list.
            items = data.get("data", [])
            if not items or "embedding" not in items[0]:
                raise ValueError(f"No embedding in response: {data}")
            return items[0]["embedding"]
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
class MxbaiProvider(BaseEmbeddingProvider):
    """Ollama local embedding provider using mxbai-embed-large (better for code).

    mxbai-embed-large is a 1024-dimensional model optimized for code understanding.
    It generally outperforms nomic-embed-text on code-related retrieval tasks.

    The ``ollama`` module functions used here are synchronous. They are run
    through ``asyncio.to_thread`` so that awaiting this provider does not
    block the event loop.
    """

    MODEL = "mxbai-embed-large"
    DIMENSION = 1024
    # Inputs longer than this many characters are cut before being embedded.
    # NOTE(review): the original comment cites a 512-token context at
    # 1-2 chars/token, which a 2000-character input can exceed. The limit is
    # left unchanged here so existing indexes stay comparable; confirm
    # whether it should be lowered.
    MAX_CHARS = 2000

    def __init__(self):
        # None = not probed yet; True/False = cached result of the last probe.
        self._available: bool | None = None

    @property
    def dimension(self) -> int:
        """Embedding vector length reported for this provider."""
        return self.DIMENSION

    @property
    def name(self) -> str:
        """Provider identifier."""
        return "mxbai"

    async def check_available(self) -> bool:
        """Check that Ollama answers and lists the embedding model.

        The outcome (positive or negative) is cached on the instance, and a
        hint is printed to stderr when the check fails.

        Returns:
            True if a listed model name contains ``MODEL``, otherwise False.
        """
        import asyncio

        if self._available is not None:
            return self._available

        try:
            ollama = get_ollama()
            # Blocking HTTP call: keep it off the event loop thread.
            models = await asyncio.to_thread(ollama.list)
            model_names = [m.model for m in models.models] if hasattr(models, "models") else []

            if not any(name and self.MODEL in name for name in model_names):
                print(
                    f"⚠️ Embedding model '{self.MODEL}' not found. Run: ollama pull {self.MODEL}",
                    file=sys.stderr,
                )
                self._available = False
                return False

            self._available = True
            return True
        except Exception as e:
            print(f"⚠️ Ollama not available: {e}. Start with: ollama serve", file=sys.stderr)
            self._available = False
            return False

    async def get_embedding(self, text: str) -> list[float]:
        """Embed ``text`` with the local Ollama model.

        Args:
            text: Text to embed; only the first ``MAX_CHARS`` characters are
                sent.

        Returns:
            The ``"embedding"`` entry of the Ollama response.
        """
        import asyncio

        ollama = get_ollama()
        truncated = text[: self.MAX_CHARS]
        # Blocking HTTP call: run it in a worker thread.
        response = await asyncio.to_thread(
            ollama.embeddings, model=self.MODEL, prompt=truncated
        )
        return response["embedding"]
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
class HuggingFaceProvider(BaseEmbeddingProvider):
    """Hugging Face Inference API embedding provider.

    Uses the Hugging Face Inference API for embeddings. Requires HF_TOKEN from:
    1. Environment variable: HF_TOKEN or HUGGING_FACE_HUB_TOKEN
    2. HF CLI config: ~/.cache/huggingface/token or ~/.huggingface/token

    Default model: sentence-transformers/all-mpnet-base-v2 (768 dims, high quality)
    """

    DEFAULT_MODEL = "sentence-transformers/all-mpnet-base-v2"
    DEFAULT_DIMENSION = 768

    def __init__(self, model: str | None = None):
        # None = not probed yet; True/False = cached result of the last probe.
        self._available: bool | None = None
        self._model = model or self.DEFAULT_MODEL
        # Dimension varies by model, but we'll use default for common models
        self._dimension = self.DEFAULT_DIMENSION
        self._token: str | None = None

    @property
    def dimension(self) -> int:
        """Embedding vector length assumed for the configured model."""
        return self._dimension

    @property
    def name(self) -> str:
        """Provider identifier."""
        return "huggingface"

    def _get_hf_token(self) -> str | None:
        """Discover HF token from environment or CLI config.

        Returns:
            The first non-empty token found, or None when there is none.
        """
        import os

        # Check environment variables first
        token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
        if token:
            return token

        # Check HF CLI config locations
        hf_token_paths = [
            Path.home() / ".cache" / "huggingface" / "token",
            Path.home() / ".huggingface" / "token",
        ]

        for token_path in hf_token_paths:
            try:
                file_token = token_path.read_text().strip()
            except (OSError, UnicodeError):
                # Missing or unreadable file: try the next location.
                continue
            # An empty file must not shadow a valid token stored in a later
            # location.
            if file_token:
                return file_token

        return None

    def _require_token(self) -> str:
        """Return the cached token, discovering it if needed.

        Raises:
            ValueError: If no token can be found.
        """
        if not self._token:
            self._token = self._get_hf_token()
            if not self._token:
                raise ValueError(
                    "Hugging Face token not found. Run: huggingface-cli login or set HF_TOKEN"
                )
        return self._token

    async def check_available(self) -> bool:
        """Check that a Hugging Face token can be found (result is cached)."""
        if self._available is not None:
            return self._available

        try:
            self._token = self._get_hf_token()
            if not self._token:
                print(
                    "⚠️ Hugging Face token not found. Run: huggingface-cli login or set HF_TOKEN env var",
                    file=sys.stderr,
                )
                self._available = False
                return False

            self._available = True
            return True
        except Exception as e:
            print(f"⚠️ Hugging Face not available: {e}", file=sys.stderr)
            self._available = False
            return False

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        retry=retry_if_exception_type(httpx.HTTPStatusError),
    )
    async def get_embedding(self, text: str) -> list[float]:
        """Get embedding from HF Inference API with retry logic.

        Args:
            text: Text to embed; only the first 2000 characters are sent.

        Returns:
            The embedding vector.

        Raises:
            ValueError: If no token is available, authentication fails (401),
                the model is gone (410), or the response has an unexpected
                or empty shape.
        """
        token = self._require_token()

        httpx_client = get_httpx()

        # HF Serverless Inference API endpoint
        # Note: Free tier may have limited availability for some models
        api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{self._model}"

        headers = {
            "Authorization": f"Bearer {token}",
        }

        # Truncate text to reasonable length (most models have 512 token limit)
        # ~2000 chars ≈ 500 tokens for safety
        truncated = text[:2000]

        # HF Inference API accepts raw JSON with inputs field
        payload = {"inputs": [truncated], "options": {"wait_for_model": True}}

        async with httpx_client.AsyncClient(timeout=60.0) as client:
            response = await client.post(api_url, headers=headers, json=payload)

            # Handle specific error codes
            if response.status_code == 401:
                raise ValueError(
                    "Hugging Face authentication failed. Run: huggingface-cli login or set HF_TOKEN"
                )
            elif response.status_code == 410:
                # Model removed from free tier
                raise ValueError(
                    f"Model {self._model} is no longer available on HF free Inference API (410 Gone). "
                    "Try a different model or use Ollama for local embeddings instead."
                )
            elif response.status_code == 503:
                # Model loading - retry will handle this
                logger.info(f"Model {self._model} is loading, retrying...")
                response.raise_for_status()
            elif response.status_code == 429:
                # Rate limit - retry will handle with exponential backoff
                logger.warning("HF API rate limit hit, retrying with backoff...")
                response.raise_for_status()

            response.raise_for_status()

            embedding = response.json()

            # Accept either a flat vector or a one-item batch. The non-empty
            # check keeps an empty list from raising IndexError below.
            if isinstance(embedding, list) and embedding:
                first = embedding[0]
                if isinstance(first, (int, float)):
                    # Direct embedding
                    return embedding
                if isinstance(first, list):
                    # Batch response with single embedding
                    return first

            raise ValueError(f"Unexpected response format from HF API: {type(embedding)}")

    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Batch embedding support for HF API.

        HF API supports batch requests, so we can send multiple texts at once.

        Args:
            texts: Texts to embed; each is cut to its first 2000 characters.

        Returns:
            The list of embeddings returned by the API, or ``[]`` when
            ``texts`` is empty.

        Raises:
            ValueError: If no token is available, authentication fails (401),
                or the response is not a list of lists.
        """
        if not texts:
            return []

        token = self._require_token()

        httpx_client = get_httpx()

        # HF Serverless Inference API endpoint
        api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{self._model}"

        headers = {
            "Authorization": f"Bearer {token}",
        }

        # Truncate all texts
        truncated_texts = [text[:2000] for text in texts]

        payload = {"inputs": truncated_texts, "options": {"wait_for_model": True}}

        async with httpx_client.AsyncClient(timeout=120.0) as client:
            response = await client.post(api_url, headers=headers, json=payload)

            if response.status_code == 401:
                raise ValueError(
                    "Hugging Face authentication failed. Run: huggingface-cli login or set HF_TOKEN"
                )

            response.raise_for_status()

            embeddings = response.json()

            # Response should be a list of embeddings
            if isinstance(embeddings, list) and all(isinstance(e, list) for e in embeddings):
                return embeddings

            raise ValueError(f"Unexpected batch response format from HF API: {type(embeddings)}")
|
|
657
|
+
|
|
658
|
+
|
|
659
|
+
# Embedding provider instance cache
|
|
660
|
+
_embedding_provider_cache: dict[str, BaseEmbeddingProvider] = {}
|
|
661
|
+
_embedding_provider_lock = threading.Lock()
|
|
662
|
+
|
|
663
|
+
|
|
664
|
+
def get_embedding_provider(provider: EmbeddingProvider) -> BaseEmbeddingProvider:
    """Return the shared provider instance for ``provider``, creating it once.

    Raises:
        ValueError: If ``provider`` is not one of the known provider names.
    """
    if provider in _embedding_provider_cache:
        return _embedding_provider_cache[provider]

    with _embedding_provider_lock:
        # Look again now that the lock is held; a concurrent caller may have
        # populated the cache in the meantime.
        if provider not in _embedding_provider_cache:
            factories = {
                "ollama": OllamaProvider,
                "mxbai": MxbaiProvider,
                "gemini": GeminiProvider,
                "openai": OpenAIProvider,
                "huggingface": HuggingFaceProvider,
            }

            factory = factories.get(provider)
            if factory is None:
                raise ValueError(f"Unknown provider: {provider}. Available: {list(factories.keys())}")

            _embedding_provider_cache[provider] = factory()

    return _embedding_provider_cache[provider]
|
|
684
|
+
|
|
685
|
+
|
|
686
|
+
class CodebaseVectorStore:
|
|
687
|
+
"""
|
|
688
|
+
Persistent vector store for a single codebase.
|
|
689
|
+
|
|
690
|
+
Storage: ~/.stravinsky/vectordb/<project_hash>_<provider>/
|
|
691
|
+
Embedding: Configurable via provider (ollama, mxbai, gemini, openai, huggingface)
|
|
692
|
+
"""
|
|
693
|
+
|
|
694
|
+
CHUNK_SIZE = 50 # lines per chunk
|
|
695
|
+
CHUNK_OVERLAP = 10 # lines of overlap between chunks
|
|
696
|
+
|
|
697
|
+
# File patterns to index
|
|
698
|
+
CODE_EXTENSIONS = {
|
|
699
|
+
".py",
|
|
700
|
+
".js",
|
|
701
|
+
".ts",
|
|
702
|
+
".tsx",
|
|
703
|
+
".jsx",
|
|
704
|
+
".go",
|
|
705
|
+
".rs",
|
|
706
|
+
".rb",
|
|
707
|
+
".java",
|
|
708
|
+
".c",
|
|
709
|
+
".cpp",
|
|
710
|
+
".h",
|
|
711
|
+
".hpp",
|
|
712
|
+
".cs",
|
|
713
|
+
".swift",
|
|
714
|
+
".kt",
|
|
715
|
+
".scala",
|
|
716
|
+
".vue",
|
|
717
|
+
".svelte",
|
|
718
|
+
".md",
|
|
719
|
+
".txt",
|
|
720
|
+
".yaml",
|
|
721
|
+
".yml",
|
|
722
|
+
".json",
|
|
723
|
+
".toml",
|
|
724
|
+
}
|
|
725
|
+
|
|
726
|
+
# Directories to skip
|
|
727
|
+
SKIP_DIRS = {
|
|
728
|
+
"node_modules",
|
|
729
|
+
".git",
|
|
730
|
+
"__pycache__",
|
|
731
|
+
".venv",
|
|
732
|
+
"venv",
|
|
733
|
+
"env",
|
|
734
|
+
"dist",
|
|
735
|
+
"build",
|
|
736
|
+
".next",
|
|
737
|
+
".nuxt",
|
|
738
|
+
"target",
|
|
739
|
+
".tox",
|
|
740
|
+
".pytest_cache",
|
|
741
|
+
".mypy_cache",
|
|
742
|
+
".ruff_cache",
|
|
743
|
+
"coverage",
|
|
744
|
+
".stravinsky",
|
|
745
|
+
}
|
|
746
|
+
|
|
747
|
+
def __init__(self, project_path: str, provider: EmbeddingProvider = "ollama"):
    """Set up paths and lazy handles for one project/provider pair.

    Creates the database directory if it is missing; the ChromaDB client,
    the collection and the file lock are only created on first access.

    Args:
        project_path: Root directory of the codebase (resolved to an
            absolute path).
        provider: Name of the embedding provider to use.
    """
    self.project_path = Path(project_path).resolve()
    # MD5 is only a short directory fingerprint here, not a security
    # measure. Flagging it as such keeps it usable on FIPS-restricted
    # Python builds; the digest itself is unchanged.
    self.project_hash = hashlib.md5(
        str(self.project_path).encode(), usedforsecurity=False
    ).hexdigest()[:12]

    # Initialize embedding provider
    self.provider_name = provider
    self.provider = get_embedding_provider(provider)

    # Store in user's home directory, separate by provider to avoid dimension mismatch
    self.db_path = Path.home() / ".stravinsky" / "vectordb" / f"{self.project_hash}_{provider}"
    self.db_path.mkdir(parents=True, exist_ok=True)

    # File lock for single-process access to ChromaDB (prevents corruption)
    self._lock_path = self.db_path / ".chromadb.lock"
    self._file_lock = None

    self._client = None
    self._collection = None

    # File watcher attributes
    self._watcher: "CodebaseFileWatcher | None" = None
    self._watcher_lock = threading.Lock()
|
|
769
|
+
|
|
770
|
+
@property
def file_lock(self):
    """File lock guarding this database directory, created on first access.

    The lock is a ``filelock.FileLock`` on ``self._lock_path`` and is meant
    to keep ChromaDB access to a single process, preventing corruption from
    concurrent writes.
    """
    if self._file_lock is not None:
        return self._file_lock

    # Acquisition gives up with an error after 30 seconds.
    lock_factory = get_filelock().FileLock
    self._file_lock = lock_factory(str(self._lock_path), timeout=30)
    return self._file_lock
|
|
782
|
+
|
|
783
|
+
@property
def client(self):
    """ChromaDB persistent client for ``self.db_path``, created on first access.

    The database file lock is acquired before the client is created. If
    the lock cannot be acquired, a warning is logged and the client is
    created anyway. If creating the client fails, a lock taken here is
    released again so that a retry does not stack another acquisition on
    the re-entrant lock.
    """
    if self._client is None:
        chromadb = get_chromadb()
        # Acquire lock before creating client to prevent concurrent access
        lock_acquired = False
        try:
            self.file_lock.acquire()
            lock_acquired = True
            logger.debug(f"Acquired ChromaDB lock for {self.db_path}")
        except Exception as e:
            logger.warning(f"Could not acquire ChromaDB lock: {e}. Proceeding without lock.")
        try:
            self._client = chromadb.PersistentClient(path=str(self.db_path))
        except BaseException:
            # Do not keep the lock when there is no client to protect.
            if lock_acquired:
                self.file_lock.release()
            raise
    return self._client
|
|
795
|
+
|
|
796
|
+
@property
def collection(self):
    """Return the "codebase" collection, fetching or creating it once.

    The collection is requested from ``self.client`` with cosine distance as
    its HNSW space and then cached on the instance.
    """
    if self._collection is not None:
        return self._collection

    store_client = self.client
    self._collection = store_client.get_or_create_collection(
        name="codebase",
        metadata={"hnsw:space": "cosine"},
    )
    return self._collection
|
|
803
|
+
|
|
804
|
+
async def check_embedding_service(self) -> bool:
    """Ask the configured embedding provider whether it is available."""
    available = await self.provider.check_available()
    return available
|
|
807
|
+
|
|
808
|
+
async def get_embedding(self, text: str) -> list[float]:
    """Return the embedding vector the configured provider produces for *text*."""
    vector = await self.provider.get_embedding(text)
    return vector
|
|
811
|
+
|
|
812
|
+
async def get_embeddings_batch(
    self, texts: list[str], max_concurrent: int = 10
) -> list[list[float]]:
    """Embed several texts concurrently, bounded by a semaphore.

    Args:
        texts: Text strings to embed.
        max_concurrent: Maximum number of embedding requests in flight at
            once (default: 10). Must be at least 1.

    Returns:
        One embedding vector per input text, in the same order as *texts*.
        An empty input yields an empty list.

    Raises:
        ValueError: If *max_concurrent* is less than 1. A value of 0 would
            otherwise create a semaphore that can never be acquired and the
            call would wait forever.
    """
    import asyncio

    if not texts:
        return []

    if max_concurrent < 1:
        raise ValueError(f"max_concurrent must be >= 1, got {max_concurrent}")

    semaphore = asyncio.Semaphore(max_concurrent)

    async def embed_one(text: str) -> list[float]:
        async with semaphore:
            return await self.get_embedding(text)

    # asyncio.gather returns results in the order the awaitables were passed,
    # so no index bookkeeping or re-sorting is needed to keep input order.
    return list(await asyncio.gather(*(embed_one(text) for text in texts)))
|
|
847
|
+
|
|
848
|
+
def _chunk_file(self, file_path: Path) -> list[dict]:
    """Read one file and turn it into indexable chunks.

    Files that cannot be read, or that have fewer than five lines, produce
    no chunks. Files whose suffix is ``py`` are first offered to the
    AST-based chunker; if that yields nothing, or for any other suffix, the
    line-based chunker is used.
    """
    try:
        text = file_path.read_text(encoding="utf-8", errors="ignore")
    except Exception:
        return []

    source_lines = text.split("\n")
    if len(source_lines) < 5:
        # Too small to be worth indexing
        return []

    relative = str(file_path.relative_to(self.project_path))
    extension = file_path.suffix.lstrip(".")

    ast_chunks = self._chunk_python_ast(text, relative, extension) if extension == "py" else []
    if ast_chunks:
        return ast_chunks

    # Either not Python, or the AST chunker produced nothing
    return self._chunk_by_lines(source_lines, relative, extension)
|
|
874
|
+
|
|
875
|
+
def _chunk_python_ast(self, content: str, rel_path: str, language: str) -> list[dict]:
|
|
876
|
+
"""Parse Python file and create chunks based on function/class boundaries.
|
|
877
|
+
|
|
878
|
+
Each function, method, and class becomes its own chunk, preserving
|
|
879
|
+
semantic boundaries for better embedding quality.
|
|
880
|
+
"""
|
|
881
|
+
import ast
|
|
882
|
+
|
|
883
|
+
try:
|
|
884
|
+
tree = ast.parse(content)
|
|
885
|
+
except SyntaxError:
|
|
886
|
+
return [] # Fall back to line-based chunking
|
|
887
|
+
|
|
888
|
+
lines = content.split("\n")
|
|
889
|
+
chunks = []
|
|
890
|
+
|
|
891
|
+
def get_docstring(node: ast.AST) -> str:
|
|
892
|
+
"""Extract docstring from a node if present."""
|
|
893
|
+
if (
|
|
894
|
+
isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef))
|
|
895
|
+
and node.body
|
|
896
|
+
):
|
|
897
|
+
first = node.body[0]
|
|
898
|
+
if isinstance(first, ast.Expr) and isinstance(first.value, ast.Constant):
|
|
899
|
+
if isinstance(first.value.value, str):
|
|
900
|
+
return first.value.value
|
|
901
|
+
return ""
|
|
902
|
+
|
|
903
|
+
def get_decorators(
|
|
904
|
+
node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef,
|
|
905
|
+
) -> list[str]:
|
|
906
|
+
"""Extract decorator names from a node."""
|
|
907
|
+
decorators = []
|
|
908
|
+
for dec in node.decorator_list:
|
|
909
|
+
if isinstance(dec, ast.Name):
|
|
910
|
+
decorators.append(f"@{dec.id}")
|
|
911
|
+
elif isinstance(dec, ast.Attribute):
|
|
912
|
+
decorators.append(f"@{ast.unparse(dec)}")
|
|
913
|
+
elif isinstance(dec, ast.Call):
|
|
914
|
+
if isinstance(dec.func, ast.Name):
|
|
915
|
+
decorators.append(f"@{dec.func.id}")
|
|
916
|
+
elif isinstance(dec.func, ast.Attribute):
|
|
917
|
+
decorators.append(f"@{ast.unparse(dec.func)}")
|
|
918
|
+
return decorators
|
|
919
|
+
|
|
920
|
+
def get_base_classes(node: ast.ClassDef) -> list[str]:
|
|
921
|
+
"""Extract base class names from a class definition."""
|
|
922
|
+
bases = []
|
|
923
|
+
for base in node.bases:
|
|
924
|
+
if isinstance(base, ast.Name):
|
|
925
|
+
bases.append(base.id)
|
|
926
|
+
elif isinstance(base, ast.Attribute):
|
|
927
|
+
bases.append(ast.unparse(base))
|
|
928
|
+
else:
|
|
929
|
+
bases.append(ast.unparse(base))
|
|
930
|
+
return bases
|
|
931
|
+
|
|
932
|
+
def get_return_type(node: ast.FunctionDef | ast.AsyncFunctionDef) -> str:
|
|
933
|
+
"""Extract return type annotation from a function."""
|
|
934
|
+
if node.returns:
|
|
935
|
+
return ast.unparse(node.returns)
|
|
936
|
+
return ""
|
|
937
|
+
|
|
938
|
+
def get_parameters(node: ast.FunctionDef | ast.AsyncFunctionDef) -> list[str]:
|
|
939
|
+
"""Extract parameter signatures from a function."""
|
|
940
|
+
params = []
|
|
941
|
+
for arg in node.args.args:
|
|
942
|
+
param = arg.arg
|
|
943
|
+
if arg.annotation:
|
|
944
|
+
param += f": {ast.unparse(arg.annotation)}"
|
|
945
|
+
params.append(param)
|
|
946
|
+
return params
|
|
947
|
+
|
|
948
|
+
def add_chunk(
|
|
949
|
+
node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef,
|
|
950
|
+
node_type: str,
|
|
951
|
+
name: str,
|
|
952
|
+
parent_class: str | None = None,
|
|
953
|
+
) -> None:
|
|
954
|
+
"""Add a chunk for a function/class node."""
|
|
955
|
+
start_line = node.lineno
|
|
956
|
+
end_line = node.end_lineno or start_line
|
|
957
|
+
|
|
958
|
+
# Extract the source code for this node
|
|
959
|
+
chunk_lines = lines[start_line - 1 : end_line]
|
|
960
|
+
chunk_text = "\n".join(chunk_lines)
|
|
961
|
+
content_hash = hashlib.md5(chunk_text.encode("utf-8")).hexdigest()[:12]
|
|
962
|
+
|
|
963
|
+
# Skip very small chunks
|
|
964
|
+
if len(chunk_lines) < 3:
|
|
965
|
+
return
|
|
966
|
+
|
|
967
|
+
# Build descriptive header
|
|
968
|
+
docstring = get_docstring(node)
|
|
969
|
+
if parent_class:
|
|
970
|
+
header = f"File: {rel_path}\n{node_type}: {parent_class}.{name}\nLines: {start_line}-{end_line}"
|
|
971
|
+
else:
|
|
972
|
+
header = f"File: {rel_path}\n{node_type}: {name}\nLines: {start_line}-{end_line}"
|
|
973
|
+
|
|
974
|
+
if docstring:
|
|
975
|
+
header += f"\nDocstring: {docstring[:200]}..."
|
|
976
|
+
|
|
977
|
+
document = f"{header}\n\n{chunk_text}"
|
|
978
|
+
|
|
979
|
+
chunks.append(
|
|
980
|
+
{
|
|
981
|
+
"id": f"{rel_path}:{start_line}-{end_line}:{content_hash}",
|
|
982
|
+
"document": document,
|
|
983
|
+
"metadata": {
|
|
984
|
+
"file_path": rel_path,
|
|
985
|
+
"start_line": start_line,
|
|
986
|
+
"end_line": end_line,
|
|
987
|
+
"language": language,
|
|
988
|
+
"node_type": node_type.lower(),
|
|
989
|
+
"name": f"{parent_class}.{name}" if parent_class else name,
|
|
990
|
+
# Structural metadata for filtering
|
|
991
|
+
"decorators": ",".join(get_decorators(node)),
|
|
992
|
+
"is_async": isinstance(node, ast.AsyncFunctionDef),
|
|
993
|
+
# Class-specific metadata
|
|
994
|
+
"base_classes": ",".join(get_base_classes(node))
|
|
995
|
+
if isinstance(node, ast.ClassDef)
|
|
996
|
+
else "",
|
|
997
|
+
# Function-specific metadata
|
|
998
|
+
"return_type": get_return_type(node)
|
|
999
|
+
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
|
|
1000
|
+
else "",
|
|
1001
|
+
"parameters": ",".join(get_parameters(node))
|
|
1002
|
+
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
|
|
1003
|
+
else "",
|
|
1004
|
+
},
|
|
1005
|
+
}
|
|
1006
|
+
)
|
|
1007
|
+
|
|
1008
|
+
# Walk the AST and extract functions/classes
|
|
1009
|
+
for node in ast.walk(tree):
|
|
1010
|
+
if isinstance(node, ast.ClassDef):
|
|
1011
|
+
add_chunk(node, "Class", node.name)
|
|
1012
|
+
# Also add methods as separate chunks for granular search
|
|
1013
|
+
for item in node.body:
|
|
1014
|
+
if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
|
1015
|
+
add_chunk(item, "Method", item.name, parent_class=node.name)
|
|
1016
|
+
elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
|
1017
|
+
# Only top-level functions (not methods)
|
|
1018
|
+
# Check if this function is inside a class body
|
|
1019
|
+
is_method = False
|
|
1020
|
+
for parent in ast.walk(tree):
|
|
1021
|
+
if isinstance(parent, ast.ClassDef):
|
|
1022
|
+
body = getattr(parent, "body", None)
|
|
1023
|
+
if isinstance(body, list) and node in body:
|
|
1024
|
+
is_method = True
|
|
1025
|
+
break
|
|
1026
|
+
if not is_method:
|
|
1027
|
+
add_chunk(node, "Function", node.name)
|
|
1028
|
+
|
|
1029
|
+
# If we found no functions/classes, chunk module-level code
|
|
1030
|
+
if not chunks and len(lines) >= 5:
|
|
1031
|
+
# Add module-level chunk for imports and constants
|
|
1032
|
+
module_chunk = "\n".join(lines[: min(50, len(lines))])
|
|
1033
|
+
chunks.append(
|
|
1034
|
+
{
|
|
1035
|
+
"id": f"{rel_path}:1-{min(50, len(lines))}",
|
|
1036
|
+
"document": f"File: {rel_path}\nModule-level code\nLines: 1-{min(50, len(lines))}\n\n{module_chunk}",
|
|
1037
|
+
"metadata": {
|
|
1038
|
+
"file_path": rel_path,
|
|
1039
|
+
"start_line": 1,
|
|
1040
|
+
"end_line": min(50, len(lines)),
|
|
1041
|
+
"language": language,
|
|
1042
|
+
"node_type": "module",
|
|
1043
|
+
"name": rel_path,
|
|
1044
|
+
},
|
|
1045
|
+
}
|
|
1046
|
+
)
|
|
1047
|
+
|
|
1048
|
+
return chunks
|
|
1049
|
+
|
|
1050
|
+
def _chunk_by_lines(self, lines: list[str], rel_path: str, language: str) -> list[dict]:
    """Cut *lines* into overlapping fixed-size windows.

    Windows hold up to ``CHUNK_SIZE`` lines and start every
    ``CHUNK_SIZE - CHUNK_OVERLAP`` lines. A window with fewer than five lines
    is dropped. Each chunk's id combines the path, its 1-based line range and
    a short hash of its text.
    """
    stride = self.CHUNK_SIZE - self.CHUNK_OVERLAP
    window = self.CHUNK_SIZE
    result = []

    for offset in range(0, len(lines), stride):
        piece = lines[offset : offset + window]
        if len(piece) >= 5:  # anything shorter is a tiny trailing remainder
            body = "\n".join(piece)
            digest = hashlib.md5(body.encode("utf-8")).hexdigest()[:12]
            first, last = offset + 1, offset + len(piece)
            result.append(
                {
                    "id": f"{rel_path}:{first}-{last}:{digest}",
                    # Header lines give the embedded text its file context
                    "document": f"File: {rel_path}\nLines: {first}-{last}\n\n{body}",
                    "metadata": {
                        "file_path": rel_path,
                        "start_line": first,
                        "end_line": last,
                        "language": language,
                    },
                }
            )

    return result
|
|
1081
|
+
|
|
1082
|
+
def _get_files_to_index(self) -> list[Path]:
    """Collect the files under the project root that should be indexed.

    A file is kept when all of the following hold, judged only on the part
    of its path below the project root:

    * no path component starts with ``.``, unless the file is ``.md`` or
      ``.txt`` (so docs under e.g. ``.github`` are still allowed);
    * no path component is listed in ``SKIP_DIRS``;
    * its lower-cased suffix is in ``CODE_EXTENSIONS``.

    Returns:
        Matching file paths in the order ``rglob`` yields them.
    """
    root_depth = len(self.project_path.parts)
    files = []

    for file_path in self.project_path.rglob("*"):
        if not file_path.is_file():
            continue

        # Only components below the project root count. Checking the whole
        # absolute path would exclude every file of a project that merely
        # lives under a directory named like an entry of SKIP_DIRS.
        rel_parts = file_path.parts[root_depth:]

        is_hidden = any(part.startswith(".") for part in rel_parts)
        if is_hidden and file_path.suffix not in {".md", ".txt"}:
            continue

        if any(skip_dir in rel_parts for skip_dir in self.SKIP_DIRS):
            continue

        if file_path.suffix.lower() in self.CODE_EXTENSIONS:
            files.append(file_path)

    return files
|
|
1103
|
+
|
|
1104
|
+
async def index_codebase(self, force: bool = False) -> dict:
    """Bring the vector store in line with the files currently on disk.

    All indexable files are chunked. Chunk ids already in the collection but
    no longer produced are deleted, and chunks whose ids are new are embedded
    in batches of 50 and added. Start, completion and error notifications are
    sent on a best-effort basis; a notification failure is only logged.

    Args:
        force: If True, drop the existing collection first so that everything
            is re-embedded. Otherwise only new or changed chunks are added.

    Returns:
        A statistics dict with ``indexed``, ``pruned``, ``total_files``,
        ``db_path`` and ``time_taken`` (plus ``message`` when there was
        nothing new to index), or ``{"error": ..., "indexed": 0}`` when the
        embedding service is unavailable.

    Raises:
        Exception: Any failure while chunking, embedding or storing is
            logged, reported through the notifier and re-raised.
    """
    import time

    # Monotonic clock: elapsed time must not be skewed by wall-clock changes.
    start_time = time.monotonic()

    print(f"🔍 SEMANTIC-INDEX: {self.project_path}", file=sys.stderr)

    # Notify reindex start (non-blocking); notifier stays None when the
    # notification module is unavailable.
    notifier = None
    try:
        from mcp_bridge.notifications import get_notification_manager

        notifier = get_notification_manager()
        await notifier.notify_reindex_start(str(self.project_path))
    except Exception as e:
        logger.warning(f"Failed to send reindex start notification: {e}")

    async def notify(event: str, payload) -> None:
        """Send a ``notify_reindex_<event>`` notification; log instead of raising."""
        if not notifier:
            return
        try:
            await getattr(notifier, f"notify_reindex_{event}")(payload)
        except Exception as notify_error:
            logger.warning(f"Failed to send reindex {event} notification: {notify_error}")

    try:
        if not await self.check_embedding_service():
            error_msg = "Embedding service not available"
            await notify("error", error_msg)
            return {"error": error_msg, "indexed": 0}

        # Get existing document IDs (ids only, to minimize overhead)
        existing_ids = set()
        try:
            existing = self.collection.get(include=[])
            existing_ids = set(existing["ids"]) if existing["ids"] else set()
        except Exception as e:
            # Still best-effort, but no longer silent: with no known ids
            # every chunk is treated as new.
            logger.warning(f"Could not read existing chunk ids: {e}")

        if force:
            # Clear existing collection
            try:
                self.client.delete_collection("codebase")
                self._collection = None
                existing_ids = set()
            except Exception as e:
                logger.warning(f"Could not clear collection for forced reindex: {e}")

        files = self._get_files_to_index()
        all_chunks = []
        for file_path in files:
            all_chunks.extend(self._chunk_file(file_path))
        current_chunk_ids = {chunk["id"] for chunk in all_chunks}

        # Stale chunks: stored but no longer produced by the current files
        to_delete = existing_ids - current_chunk_ids
        # New chunks: produced now but not stored yet
        to_add_ids = current_chunk_ids - existing_ids
        chunks_to_add = [c for c in all_chunks if c["id"] in to_add_ids]

        if to_delete:
            print(f"  Pruning {len(to_delete)} stale chunks...", file=sys.stderr)
            self.collection.delete(ids=list(to_delete))

        if not chunks_to_add:
            stats = {
                "indexed": 0,
                "pruned": len(to_delete),
                "total_files": len(files),
                # Reported here as well so callers always learn where the
                # database lives, not only when something was indexed.
                "db_path": str(self.db_path),
                "message": "No new chunks to index",
                "time_taken": round(time.monotonic() - start_time, 1),
            }
            await notify("complete", stats)
            return stats

        # Batch embed and store
        batch_size = 50
        total_indexed = 0

        for i in range(0, len(chunks_to_add), batch_size):
            batch = chunks_to_add[i : i + batch_size]

            documents = [c["document"] for c in batch]
            embeddings = await self.get_embeddings_batch(documents)

            self.collection.add(
                ids=[c["id"] for c in batch],
                documents=documents,
                embeddings=embeddings,  # type: ignore[arg-type]
                metadatas=[c["metadata"] for c in batch],
            )
            total_indexed += len(batch)
            print(f"  Indexed {total_indexed}/{len(chunks_to_add)} chunks...", file=sys.stderr)

        stats = {
            "indexed": total_indexed,
            "pruned": len(to_delete),
            "total_files": len(files),
            "db_path": str(self.db_path),
            "time_taken": round(time.monotonic() - start_time, 1),
        }
        await notify("complete", stats)
        return stats

    except Exception as e:
        error_msg = str(e)
        logger.error(f"Reindexing failed: {error_msg}")
        await notify("error", error_msg)
        raise
|
|
1246
|
+
|
|
1247
|
+
async def search(
|
|
1248
|
+
self,
|
|
1249
|
+
query: str,
|
|
1250
|
+
n_results: int = 10,
|
|
1251
|
+
language: str | None = None,
|
|
1252
|
+
node_type: str | None = None,
|
|
1253
|
+
decorator: str | None = None,
|
|
1254
|
+
is_async: bool | None = None,
|
|
1255
|
+
base_class: str | None = None,
|
|
1256
|
+
) -> list[dict]:
|
|
1257
|
+
"""
|
|
1258
|
+
Search the codebase with a natural language query.
|
|
1259
|
+
|
|
1260
|
+
Args:
|
|
1261
|
+
query: Natural language search query
|
|
1262
|
+
n_results: Maximum number of results to return
|
|
1263
|
+
language: Filter by language (e.g., "py", "ts", "js")
|
|
1264
|
+
node_type: Filter by node type (e.g., "function", "class", "method")
|
|
1265
|
+
decorator: Filter by decorator (e.g., "@property", "@staticmethod")
|
|
1266
|
+
is_async: Filter by async status (True = async only, False = sync only)
|
|
1267
|
+
base_class: Filter by base class (e.g., "BaseClass")
|
|
1268
|
+
|
|
1269
|
+
Returns:
|
|
1270
|
+
List of matching code chunks with metadata.
|
|
1271
|
+
"""
|
|
1272
|
+
filters = []
|
|
1273
|
+
if language:
|
|
1274
|
+
filters.append(f"language={language}")
|
|
1275
|
+
if node_type:
|
|
1276
|
+
filters.append(f"node_type={node_type}")
|
|
1277
|
+
if decorator:
|
|
1278
|
+
filters.append(f"decorator={decorator}")
|
|
1279
|
+
if is_async is not None:
|
|
1280
|
+
filters.append(f"is_async={is_async}")
|
|
1281
|
+
if base_class:
|
|
1282
|
+
filters.append(f"base_class={base_class}")
|
|
1283
|
+
filter_str = f" [{', '.join(filters)}]" if filters else ""
|
|
1284
|
+
print(f"🔎 SEMANTIC-SEARCH: '{query[:50]}...'{filter_str}", file=sys.stderr)
|
|
1285
|
+
|
|
1286
|
+
if not await self.check_embedding_service():
|
|
1287
|
+
return [{"error": "Embedding service not available"}]
|
|
1288
|
+
|
|
1289
|
+
# Check if collection has documents
|
|
1290
|
+
try:
|
|
1291
|
+
count = self.collection.count()
|
|
1292
|
+
if count == 0:
|
|
1293
|
+
return [{"error": "No documents indexed", "hint": "Run index_codebase first"}]
|
|
1294
|
+
except Exception as e:
|
|
1295
|
+
return [{"error": f"Collection error: {e}"}]
|
|
1296
|
+
|
|
1297
|
+
# Get query embedding
|
|
1298
|
+
query_embedding = await self.get_embedding(query)
|
|
1299
|
+
|
|
1300
|
+
# Build where clause for metadata filtering
|
|
1301
|
+
where_filters = []
|
|
1302
|
+
if language:
|
|
1303
|
+
where_filters.append({"language": language})
|
|
1304
|
+
if node_type:
|
|
1305
|
+
where_filters.append({"node_type": node_type.lower()})
|
|
1306
|
+
if decorator:
|
|
1307
|
+
# ChromaDB $like for substring match in comma-separated field
|
|
1308
|
+
# Use % wildcards for pattern matching
|
|
1309
|
+
where_filters.append({"decorators": {"$like": f"%{decorator}%"}})
|
|
1310
|
+
if is_async is not None:
|
|
1311
|
+
where_filters.append({"is_async": is_async})
|
|
1312
|
+
if base_class:
|
|
1313
|
+
# Use $like for substring match
|
|
1314
|
+
where_filters.append({"base_classes": {"$like": f"%{base_class}%"}})
|
|
1315
|
+
|
|
1316
|
+
where_clause = None
|
|
1317
|
+
if len(where_filters) == 1:
|
|
1318
|
+
where_clause = where_filters[0]
|
|
1319
|
+
elif len(where_filters) > 1:
|
|
1320
|
+
where_clause = {"$and": where_filters}
|
|
1321
|
+
|
|
1322
|
+
# Search with optional filtering
|
|
1323
|
+
query_kwargs: dict = {
|
|
1324
|
+
"query_embeddings": [query_embedding],
|
|
1325
|
+
"n_results": n_results,
|
|
1326
|
+
"include": ["documents", "metadatas", "distances"],
|
|
1327
|
+
}
|
|
1328
|
+
if where_clause:
|
|
1329
|
+
query_kwargs["where"] = where_clause
|
|
1330
|
+
|
|
1331
|
+
results = self.collection.query(**query_kwargs)
|
|
1332
|
+
|
|
1333
|
+
# Format results
|
|
1334
|
+
formatted = []
|
|
1335
|
+
if results["ids"] and results["ids"][0]:
|
|
1336
|
+
for i, _doc_id in enumerate(results["ids"][0]):
|
|
1337
|
+
metadata = results["metadatas"][0][i] if results["metadatas"] else {}
|
|
1338
|
+
distance = results["distances"][0][i] if results["distances"] else 0
|
|
1339
|
+
document = results["documents"][0][i] if results["documents"] else ""
|
|
1340
|
+
|
|
1341
|
+
# Extract just the code part (skip file/line header)
|
|
1342
|
+
code_lines = document.split("\n\n", 1)
|
|
1343
|
+
code = code_lines[1] if len(code_lines) > 1 else document
|
|
1344
|
+
|
|
1345
|
+
formatted.append(
|
|
1346
|
+
{
|
|
1347
|
+
"file": metadata.get("file_path", "unknown"),
|
|
1348
|
+
"lines": f"{metadata.get('start_line', '?')}-{metadata.get('end_line', '?')}",
|
|
1349
|
+
"language": metadata.get("language", ""),
|
|
1350
|
+
"relevance": round(1 - distance, 3), # Convert distance to similarity
|
|
1351
|
+
"code_preview": code[:500] + "..." if len(code) > 500 else code,
|
|
1352
|
+
}
|
|
1353
|
+
)
|
|
1354
|
+
|
|
1355
|
+
return formatted
|
|
1356
|
+
|
|
1357
|
+
def get_stats(self) -> dict:
    """Summarise the vector store: paths, chunk count and embedding provider.

    Returns:
        A dict with ``project_path``, ``db_path``, ``chunks_indexed``,
        ``embedding_provider`` and ``embedding_dimension``; or
        ``{"error": <message>}`` if gathering any of them raises.
    """
    try:
        chunk_total = self.collection.count()
        summary = dict(
            project_path=str(self.project_path),
            db_path=str(self.db_path),
            chunks_indexed=chunk_total,
            embedding_provider=self.provider.name,
            embedding_dimension=self.provider.dimension,
        )
    except Exception as e:
        return {"error": str(e)}
    return summary
|
|
1370
|
+
|
|
1371
|
+
def start_watching(self, debounce_seconds: float = 2.0) -> "CodebaseFileWatcher":
    """Ensure a file watcher for the project directory exists and is started.

    A watcher is created and started on first use. An existing watcher that
    is not running is started again; one that is already running is left
    alone and a warning is logged.

    Args:
        debounce_seconds: Time to wait before reindexing after changes
            (default: 2.0s). Only used when a new watcher is created.

    Returns:
        The CodebaseFileWatcher instance held by this store.
    """
    with self._watcher_lock:
        watcher = self._watcher
        if watcher is None:
            watcher = CodebaseFileWatcher(
                project_path=self.project_path,
                store=self,
                debounce_seconds=debounce_seconds,
            )
            self._watcher = watcher
            watcher.start()
        elif watcher.is_running():
            logger.warning(f"Watcher for {self.project_path} is already running")
        else:
            watcher.start()
        return self._watcher
|
|
1395
|
+
|
|
1396
|
+
def stop_watching(self) -> bool:
    """Stop and discard the project's file watcher, if there is one.

    Returns:
        True if a watcher existed and was stopped, False if none was set.
    """
    with self._watcher_lock:
        active = self._watcher
        if active is None:
            return False
        active.stop()
        self._watcher = None
        return True
|
|
1408
|
+
|
|
1409
|
+
def is_watching(self) -> bool:
    """Tell whether this store has a watcher and that watcher reports running.

    Returns:
        The watcher's ``is_running()`` result, or False when no watcher is set.
    """
    with self._watcher_lock:
        watcher = self._watcher
        return False if watcher is None else watcher.is_running()
|
|
1419
|
+
|
|
1420
|
+
|
|
1421
|
+
# --- Module-level API for MCP tools ---

# Lock taken by get_store() while it adds a new entry to _stores.
_stores_lock = threading.Lock()
# Vector stores created so far, keyed by "<resolved project path>:<provider>".
_stores: dict[str, CodebaseVectorStore] = {}

# Module-level watcher management
_watchers_lock = threading.Lock()
# Registry of file watchers (presumably keyed by project path; confirm against its users).
_watchers: dict[str, "CodebaseFileWatcher"] = {}
|
|
1429
|
+
|
|
1430
|
+
|
|
1431
|
+
def get_store(project_path: str, provider: EmbeddingProvider = "ollama") -> CodebaseVectorStore:
    """Return the cached vector store for a project/provider pair, creating it if needed.

    The cache key combines the resolved project path with the provider name,
    so each provider gets its own store (different providers have different
    embedding dimensions).
    """
    resolved = str(Path(project_path).resolve())
    key = f"{resolved}:{provider}"

    store = _stores.get(key)
    if store is None:
        with _stores_lock:
            # Look again under the lock: another thread may have created the
            # store between the first lookup and acquiring the lock.
            store = _stores.get(key)
            if store is None:
                store = CodebaseVectorStore(resolved, provider)
                _stores[key] = store
    return store
|
|
1445
|
+
|
|
1446
|
+
|
|
1447
|
+
async def semantic_search(
    query: str,
    project_path: str = ".",
    n_results: int = 10,
    language: str | None = None,
    node_type: str | None = None,
    decorator: str | None = None,
    is_async: bool | None = None,
    base_class: str | None = None,
    provider: EmbeddingProvider = "ollama",
) -> str:
    """Run a natural language search over a project's index and format the hits.

    Args:
        query: Natural language search query (e.g., "find authentication logic")
        project_path: Path to the project root
        n_results: Maximum number of results to return
        language: Filter by language (e.g., "py", "ts", "js")
        node_type: Filter by node type (e.g., "function", "class", "method")
        decorator: Filter by decorator (e.g., "@property", "@staticmethod")
        is_async: Filter by async status (True = async only, False = sync only)
        base_class: Filter by base class (e.g., "BaseClass")
        provider: Embedding provider (ollama, mxbai, gemini, openai, huggingface)

    Returns:
        A numbered list of hits, each with its file, line range, relevance
        and a fenced code preview; "No results found" when the search returns
        nothing; or an "Error: ..." message with a hint when the first result
        carries an error.
    """
    store = get_store(project_path, provider)
    hits = await store.search(
        query,
        n_results,
        language,
        node_type,
        decorator=decorator,
        is_async=is_async,
        base_class=base_class,
    )

    if not hits:
        return "No results found"

    first = hits[0]
    if "error" in first:
        return f"Error: {first['error']}\nHint: {first.get('hint', 'Check Ollama is running')}"

    out = [f"Found {len(hits)} results for: '{query}'\n"]
    for rank, hit in enumerate(hits, 1):
        out.extend(
            (
                f"{rank}. {hit['file']}:{hit['lines']} (relevance: {hit['relevance']})",
                f"```{hit['language']}",
                hit["code_preview"],
                "```\n",
            )
        )
    return "\n".join(out)
|
|
1500
|
+
|
|
1501
|
+
|
|
1502
|
+
async def hybrid_search(
    query: str,
    pattern: str | None = None,
    project_path: str = ".",
    n_results: int = 10,
    language: str | None = None,
    node_type: str | None = None,
    decorator: str | None = None,
    is_async: bool | None = None,
    base_class: str | None = None,
    provider: EmbeddingProvider = "ollama",
) -> str:
    """Combine semantic search with structural (ast-grep) matching.

    Without *pattern* this is exactly ``semantic_search``. With a pattern,
    twice as many semantic candidates are fetched, candidates whose file
    also appears in the ast-grep results are moved to the front (semantic
    order is otherwise kept), and the list is cut back to *n_results*.

    Args:
        query: Natural language search query (e.g., "find authentication logic")
        pattern: Optional ast-grep pattern for structural matching (e.g., "def $FUNC($$$):")
        project_path: Path to the project root
        n_results: Maximum number of results to return
        language: Filter by language (e.g., "py", "ts", "js")
        node_type: Filter by node type (e.g., "function", "class", "method")
        decorator: Filter by decorator (e.g., "@property", "@staticmethod")
        is_async: Filter by async status (True = async only, False = sync only)
        base_class: Filter by base class (e.g., "BaseClass")
        provider: Embedding provider (ollama, gemini, openai)

    Returns:
        Formatted search results with relevance scores; results from files
        that match the structural pattern are marked and listed first.
    """
    if not pattern:
        return await semantic_search(
            query=query,
            project_path=project_path,
            n_results=n_results,
            language=language,
            node_type=node_type,
            decorator=decorator,
            is_async=is_async,
            base_class=base_class,
            provider=provider,
        )

    from mcp_bridge.tools.code_search import ast_grep_search

    # Work on the structured search results rather than re-parsing formatted
    # text: code previews may contain lines that look like result headers.
    # Over-fetch so structural matches ranked below n_results can be promoted.
    store = get_store(project_path, provider)
    candidates = await store.search(
        query,
        n_results * 2,
        language,
        node_type,
        decorator=decorator,
        is_async=is_async,
        base_class=base_class,
    )

    if not candidates:
        return "No results found"
    if "error" in candidates[0]:
        first = candidates[0]
        return f"Error: {first['error']}\nHint: {first.get('hint', 'Check Ollama is running')}"

    # Get structural matches from ast-grep
    ast_result = await ast_grep_search(
        pattern=pattern,
        directory=project_path,
        language=language or "",
    )

    # Extract file paths from ast-grep results (expected line format: "- file.py:123")
    ast_files: set[str] = set()
    if ast_result and not ast_result.startswith("Error:") and ast_result != "No matches found":
        for line in ast_result.split("\n"):
            if line.startswith("- "):
                ast_files.add(line[2:].split(":")[0])

    def render(hits: list[dict]) -> list[str]:
        """Format hits like semantic_search does, marking structural matches."""
        rendered = []
        for rank, hit in enumerate(hits, 1):
            title = f"{rank}. {hit['file']}:{hit['lines']} (relevance: {hit['relevance']})"
            if hit["file"] in ast_files:
                title = f"{title} 🎯 [structural match]"
            rendered.extend((title, f"```{hit['language']}", hit["code_preview"], "```\n"))
        return rendered

    if not ast_files:
        # No structural matches: plain semantic results, limited to n_results, with a note
        top = candidates[:n_results]
        body = "\n".join([f"Found {len(top)} results for: '{query}'\n", *render(top)])
        return f"{body}\n\n[Note: No structural matches for pattern '{pattern}']"

    # sorted() is stable: structural matches (key False) come first and the
    # semantic ranking is preserved within each group.
    ranked = sorted(candidates, key=lambda hit: hit["file"] not in ast_files)[:n_results]
    boosted_count = sum(hit["file"] in ast_files for hit in ranked)

    lines = [
        f"Found {len(ranked)} hybrid results for: '{query}'",
        f"[Structural pattern: {pattern}]\n",
        "",
    ]
    lines.extend(render(ranked))
    lines.append(
        f"\n[{boosted_count}/{len(ranked)} semantic results also match structural pattern]"
    )

    return "\n".join(lines)
|
|
1608
|
+
|
|
1609
|
+
|
|
1610
|
+
async def index_codebase(
    project_path: str = ".",
    force: bool = False,
    provider: EmbeddingProvider = "ollama",
) -> str:
    """
    Build or refresh the semantic-search index for a project.

    Args:
        project_path: Path to the project root.
        force: When True, reindex everything; otherwise only new/changed files.
        provider: Embedding provider - ollama (local/free), mxbai (local/free),
            gemini (cloud/OAuth), openai (cloud/OAuth), huggingface (cloud/token).

    Returns:
        A short multi-line summary of the indexing run, or ``"Error: ..."``
        when the store reports an ``error`` entry in its statistics.
    """
    vector_store = get_store(project_path, provider)
    outcome = await vector_store.index_codebase(force=force)

    # The store signals failure through an "error" key rather than raising.
    if "error" in outcome:
        return f"Error: {outcome['error']}"

    summary_lines = [
        f"Indexed {outcome['indexed']} chunks from {outcome['total_files']} files",
        f"Database: {outcome.get('db_path', 'unknown')}",
        f"{outcome.get('message', '')}",
    ]
    return "\n".join(summary_lines)
|
|
1638
|
+
|
|
1639
|
+
|
|
1640
|
+
async def semantic_stats(
    project_path: str = ".",
    provider: EmbeddingProvider = "ollama",
) -> str:
    """
    Report statistics about the semantic search index of a project.

    Args:
        project_path: Path to the project root.
        provider: Embedding provider - ollama (local/free), mxbai (local/free),
            gemini (cloud/OAuth), openai (cloud/OAuth), huggingface (cloud/token).

    Returns:
        Four lines (project, database, chunk count, embedding provider), or
        ``"Error: ..."`` when the store's statistics contain an ``error`` key.
    """
    info = get_store(project_path, provider).get_stats()

    if "error" in info:
        return f"Error: {info['error']}"

    report = [
        f"Project: {info['project_path']}",
        f"Database: {info['db_path']}",
        f"Chunks indexed: {info['chunks_indexed']}",
        f"Embedding provider: {info['embedding_provider']} ({info['embedding_dimension']} dims)",
    ]
    return "\n".join(report)
|
|
1667
|
+
|
|
1668
|
+
|
|
1669
|
+
async def semantic_health(project_path: str = ".", provider: EmbeddingProvider = "ollama") -> str:
    """Probe the embedding provider and the vector DB and report their state.

    Each probe is wrapped separately, so a failure of one check is reported
    as an error line without preventing the other check from running.

    Returns:
        Two lines: one for the provider, one for the vector database.
    """
    store = get_store(project_path, provider)
    report: list[str] = []

    # Embedding provider probe
    try:
        provider_online = await store.check_embedding_service()
        provider_state = "✅ Online" if provider_online else "❌ Offline"
        report.append(f"Provider ({store.provider.name}): {provider_state}")
    except Exception as exc:
        report.append(f"Provider ({store.provider.name}): ❌ Error - {exc}")

    # Vector database probe
    try:
        document_count = store.collection.count()
        report.append(f"Vector DB: ✅ Online ({document_count} documents)")
    except Exception as exc:
        report.append(f"Vector DB: ❌ Error - {exc}")

    return "\n".join(report)
|
|
1692
|
+
|
|
1693
|
+
|
|
1694
|
+
# ========================
|
|
1695
|
+
# FILE WATCHER MANAGEMENT
|
|
1696
|
+
# ========================
|
|
1697
|
+
|
|
1698
|
+
|
|
1699
|
+
def start_file_watcher(
    project_path: str,
    provider: EmbeddingProvider = "ollama",
    debounce_seconds: float = 2.0,
) -> "CodebaseFileWatcher":
    """Return a running file watcher for a project, creating it if needed.

    Watchers are kept in the module-level ``_watchers`` registry, keyed by the
    resolved project path. An already registered watcher is reused (and
    restarted if it is not running); ``provider`` and ``debounce_seconds``
    only take effect when a new watcher is created.

    Args:
        project_path: Path to the project root
        provider: Embedding provider to use for reindexing
        debounce_seconds: Time to wait before reindexing after changes

    Returns:
        The CodebaseFileWatcher registered for the project
    """
    registry_key = str(Path(project_path).resolve())
    with _watchers_lock:
        if registry_key in _watchers:
            known = _watchers[registry_key]
            if not known.is_running():
                known.start()
            return _watchers[registry_key]

        store = get_store(project_path, provider)

        # Best-effort hint only: a missing index never blocks the watcher.
        try:
            chunks_indexed = store.get_stats().get("chunks_indexed", 0)
            if chunks_indexed == 0:
                logger.warning(
                    f"No index found for {registry_key}. Consider running semantic_index() "
                    f"first for better performance. FileWatcher will still monitor changes."
                )
        except Exception as exc:
            logger.debug(f"Could not check index status: {exc}")

        _watchers[registry_key] = store.start_watching(debounce_seconds=debounce_seconds)
        return _watchers[registry_key]
|
|
1738
|
+
|
|
1739
|
+
|
|
1740
|
+
def stop_file_watcher(project_path: str) -> bool:
    """Stop and unregister the file watcher of a project.

    Args:
        project_path: Path to the project root

    Returns:
        True if a registered watcher was stopped and removed, False if no
        watcher was registered for the resolved path
    """
    registry_key = str(Path(project_path).resolve())
    with _watchers_lock:
        if registry_key not in _watchers:
            return False
        _watchers[registry_key].stop()
        del _watchers[registry_key]
        return True
|
|
1757
|
+
|
|
1758
|
+
|
|
1759
|
+
def get_file_watcher(project_path: str) -> "CodebaseFileWatcher | None":
    """Look up the running file watcher of a project.

    Args:
        project_path: Path to the project root

    Returns:
        The registered CodebaseFileWatcher when it reports itself as running;
        None when nothing is registered or the watcher is stopped
    """
    registry_key = str(Path(project_path).resolve())
    with _watchers_lock:
        candidate = _watchers.get(registry_key)
        if candidate is None or not candidate.is_running():
            return None
        return candidate
|
|
1774
|
+
|
|
1775
|
+
|
|
1776
|
+
def list_file_watchers() -> list[dict]:
    """Describe every registered file watcher.

    Returns:
        One dict per registry entry with the keys ``project_path``,
        ``debounce_seconds``, ``provider`` and ``status`` ("running"/"stopped")
    """
    with _watchers_lock:
        return [
            {
                "project_path": registered_path,
                "debounce_seconds": entry.debounce_seconds,
                "provider": entry.store.provider_name,
                "status": "running" if entry.is_running() else "stopped",
            }
            for registered_path, entry in _watchers.items()
        ]
|
|
1794
|
+
|
|
1795
|
+
|
|
1796
|
+
# ========================
|
|
1797
|
+
# MULTI-QUERY EXPANSION & DECOMPOSITION
|
|
1798
|
+
# ========================
|
|
1799
|
+
|
|
1800
|
+
|
|
1801
|
+
async def _expand_query_with_llm(query: str, num_variations: int = 3) -> list[str]:
    """
    Ask an LLM for alternative phrasings of a code-search query.

    For example: "database connection" -> ["SQLAlchemy engine setup",
    "connect to postgres", "db session management"]

    Args:
        query: Original search query
        num_variations: Upper bound on the number of variations kept (default: 3)

    Returns:
        The original query followed by at most ``num_variations`` non-empty
        response lines. Any failure is logged and yields ``[query]`` alone.
    """
    from mcp_bridge.tools.model_invoke import invoke_gemini

    prompt = f"""You are a code search query expander. Given a search query, generate {num_variations} alternative phrasings that would help find relevant code.

Original query: "{query}"

Generate {num_variations} alternative queries that:
1. Use different technical terminology (e.g., "database" -> "SQLAlchemy", "ORM", "connection pool")
2. Reference specific implementations or patterns
3. Include related concepts that might appear in code

Return ONLY the alternative queries, one per line. No numbering, no explanations.
Example output for "database connection":
SQLAlchemy engine configuration
postgres connection setup
db session factory pattern"""

    try:
        response = await invoke_gemini(
            token_store=TokenStore(),
            prompt=prompt,
            model="gemini-2.0-flash",
            temperature=0.7,
            max_tokens=200,
        )

        # One candidate per non-blank response line.
        candidates: list[str] = []
        for raw_line in response.strip().split("\n"):
            cleaned = raw_line.strip()
            if cleaned:
                candidates.append(cleaned)

        # The original query always leads the list.
        return [query, *candidates[:num_variations]]

    except Exception as e:
        logger.warning(f"Query expansion failed: {e}, using original query only")
        return [query]
|
|
1850
|
+
|
|
1851
|
+
|
|
1852
|
+
async def _decompose_query_with_llm(query: str) -> list[str]:
    """
    Ask an LLM to split a multi-concept query into focused sub-queries.

    For example: "Initialize the DB and then create a user model" ->
    ["database initialization", "user model definition"]

    Args:
        query: Complex search query

    Returns:
        The non-empty response lines as sub-queries; ``[query]`` when the
        response contains none or when the call fails (failure is logged).
    """
    from mcp_bridge.tools.model_invoke import invoke_gemini

    prompt = f"""You are a code search query analyzer. Determine if this query should be broken into sub-queries.

Query: "{query}"

If the query contains multiple distinct concepts (connected by "and", "then", "also", etc.),
break it into separate focused sub-queries.

If the query is already focused on a single concept, return just that query.

Return ONLY the sub-queries, one per line. No numbering, no explanations.

Examples:
- "Initialize the DB and then create a user model" ->
database initialization
user model definition

- "authentication logic" ->
authentication logic"""

    try:
        response = await invoke_gemini(
            token_store=TokenStore(),
            prompt=prompt,
            model="gemini-2.0-flash",
            temperature=0.3,  # Lower temperature for more consistent decomposition
            max_tokens=150,
        )

        # One sub-query per non-blank response line.
        pieces: list[str] = []
        for raw_line in response.strip().split("\n"):
            cleaned = raw_line.strip()
            if cleaned:
                pieces.append(cleaned)

        if pieces:
            return pieces
        return [query]

    except Exception as e:
        logger.warning(f"Query decomposition failed: {e}, using original query")
        return [query]
|
|
1902
|
+
|
|
1903
|
+
|
|
1904
|
+
def _aggregate_results(
|
|
1905
|
+
all_results: list[list[dict]],
|
|
1906
|
+
n_results: int = 10,
|
|
1907
|
+
) -> list[dict]:
|
|
1908
|
+
"""
|
|
1909
|
+
Aggregate and deduplicate results from multiple queries.
|
|
1910
|
+
|
|
1911
|
+
Uses reciprocal rank fusion to combine relevance scores from different queries.
|
|
1912
|
+
|
|
1913
|
+
Args:
|
|
1914
|
+
all_results: List of result lists from different queries
|
|
1915
|
+
n_results: Maximum number of results to return
|
|
1916
|
+
|
|
1917
|
+
Returns:
|
|
1918
|
+
Deduplicated and re-ranked results
|
|
1919
|
+
"""
|
|
1920
|
+
# Track seen files to avoid duplicates
|
|
1921
|
+
seen_files: dict[str, dict] = {} # file:lines -> result with best score
|
|
1922
|
+
file_scores: dict[str, float] = {} # file:lines -> aggregated score
|
|
1923
|
+
|
|
1924
|
+
# Reciprocal Rank Fusion constant
|
|
1925
|
+
k = 60
|
|
1926
|
+
|
|
1927
|
+
for query_idx, results in enumerate(all_results):
|
|
1928
|
+
for rank, result in enumerate(results):
|
|
1929
|
+
file_key = f"{result.get('file', '')}:{result.get('lines', '')}"
|
|
1930
|
+
|
|
1931
|
+
# RRF score contribution
|
|
1932
|
+
rrf_score = 1 / (k + rank + 1)
|
|
1933
|
+
|
|
1934
|
+
if file_key not in seen_files:
|
|
1935
|
+
seen_files[file_key] = result.copy()
|
|
1936
|
+
file_scores[file_key] = rrf_score
|
|
1937
|
+
else:
|
|
1938
|
+
# Aggregate scores
|
|
1939
|
+
file_scores[file_key] += rrf_score
|
|
1940
|
+
# Keep higher original relevance if available
|
|
1941
|
+
if result.get("relevance", 0) > seen_files[file_key].get("relevance", 0):
|
|
1942
|
+
seen_files[file_key] = result.copy()
|
|
1943
|
+
|
|
1944
|
+
# Sort by aggregated score and return top N
|
|
1945
|
+
sorted_keys = sorted(file_scores.keys(), key=lambda k: file_scores[k], reverse=True)
|
|
1946
|
+
|
|
1947
|
+
aggregated = []
|
|
1948
|
+
for key in sorted_keys[:n_results]:
|
|
1949
|
+
result = seen_files[key]
|
|
1950
|
+
# Update relevance to reflect aggregated score (normalized)
|
|
1951
|
+
max_score = max(file_scores.values()) if file_scores else 1
|
|
1952
|
+
result["relevance"] = round(file_scores[key] / max_score, 3)
|
|
1953
|
+
aggregated.append(result)
|
|
1954
|
+
|
|
1955
|
+
return aggregated
|
|
1956
|
+
|
|
1957
|
+
|
|
1958
|
+
async def multi_query_search(
    query: str,
    project_path: str = ".",
    n_results: int = 10,
    num_expansions: int = 3,
    language: str | None = None,
    node_type: str | None = None,
    provider: EmbeddingProvider = "ollama",
) -> str:
    """
    Search with several LLM-generated phrasings of the query and fuse the hits.

    The query is expanded via ``_expand_query_with_llm``, every phrasing is
    searched concurrently against the same store, result lists whose first
    entry carries an ``error`` key are dropped, and the remainder is merged
    with ``_aggregate_results``.

    Args:
        query: Natural language search query
        project_path: Path to the project root
        n_results: Maximum number of results to return (also requested per phrasing)
        num_expansions: Number of query variations to generate (default: 3)
        language: Filter by language (e.g., "py", "ts")
        node_type: Filter by node type (e.g., "function", "class")
        provider: Embedding provider

    Returns:
        Formatted search results with relevance scores, ``"Error: ..."`` when
        nothing usable came back and the first result list reports an error,
        or ``"No results found"``.
    """
    import asyncio

    print(f"🔍 MULTI-QUERY: Expanding '{query[:50]}...'", file=sys.stderr)

    phrasings = await _expand_query_with_llm(query, num_expansions)
    print(f" Generated {len(phrasings)} query variations", file=sys.stderr)

    # One store shared by all phrasings
    store = get_store(project_path, provider)

    # Every phrasing asks for the full n_results; fusion trims afterwards.
    pending_searches = [
        store.search(
            phrasing,
            n_results=n_results,
            language=language,
            node_type=node_type,
        )
        for phrasing in phrasings
    ]
    per_query_hits = await asyncio.gather(*pending_searches)

    usable_hits = [hits for hits in per_query_hits if hits and "error" not in hits[0]]

    if not usable_hits:
        leading = per_query_hits[0] if per_query_hits else None
        if leading and "error" in leading[0]:
            return f"Error: {leading[0]['error']}"
        return "No results found"

    aggregated = _aggregate_results(usable_hits, n_results)
    if not aggregated:
        return "No results found"

    # Phrasings longer than 30 characters are abbreviated in the header.
    shortened = [text[:30] + "..." if len(text) > 30 else text for text in phrasings]
    output = [
        f"Found {len(aggregated)} results for multi-query expansion of: '{query}'",
        f"[Expanded to: {', '.join(shortened)}]\n",
    ]

    for position, hit in enumerate(aggregated, 1):
        output.extend(
            (
                f"{position}. {hit['file']}:{hit['lines']} (relevance: {hit['relevance']})",
                f"```{hit.get('language', '')}",
                hit.get("code_preview", ""),
                "```\n",
            )
        )

    return "\n".join(output)
|
|
2034
|
+
|
|
2035
|
+
|
|
2036
|
+
async def decomposed_search(
    query: str,
    project_path: str = ".",
    n_results: int = 10,
    language: str | None = None,
    node_type: str | None = None,
    provider: EmbeddingProvider = "ollama",
) -> str:
    """
    Split a multi-part query into sub-queries and search each one separately.

    When decomposition returns exactly the original query, the call is
    forwarded unchanged to ``semantic_search``. Otherwise every sub-query is
    searched concurrently, each asking for ``n_results // len(sub_queries) + 2``
    hits, and at most five hits per sub-query are printed.

    Args:
        query: Complex search query (may contain multiple concepts)
        project_path: Path to the project root
        n_results: Result budget that is divided among the sub-queries
        language: Filter by language
        node_type: Filter by node type
        provider: Embedding provider

    Returns:
        Formatted results organized by sub-question.
    """
    import asyncio

    print(f"🔍 DECOMPOSED-SEARCH: Analyzing '{query[:50]}...'", file=sys.stderr)

    sub_queries = await _decompose_query_with_llm(query)
    print(f" Decomposed into {len(sub_queries)} sub-queries", file=sys.stderr)

    # Nothing to split: behave like a plain semantic search.
    if len(sub_queries) == 1 and sub_queries[0] == query:
        return await semantic_search(
            query=query,
            project_path=project_path,
            n_results=n_results,
            language=language,
            node_type=node_type,
            provider=provider,
        )

    store = get_store(project_path, provider)

    # Share the budget between sub-queries, with a little headroom.
    per_query_limit = n_results // len(sub_queries) + 2
    hit_lists = await asyncio.gather(
        *[
            store.search(
                part,
                n_results=per_query_limit,
                language=language,
                node_type=node_type,
            )
            for part in sub_queries
        ]
    )

    output = [
        f"Decomposed search for: '{query}'",
        f"[Split into {len(sub_queries)} sub-queries]\n",
    ]

    total_results = 0
    for part, hits in zip(sub_queries, hit_lists):
        output.append(f"### {part}")

        # Empty and error-carrying result lists are reported the same way.
        if not hits or "error" in hits[0]:
            output.append(" No results found\n")
            continue

        shown = hits[:5]  # Limit per sub-query
        for position, hit in enumerate(shown, 1):
            output.append(
                f" {position}. {hit['file']}:{hit['lines']} (relevance: {hit['relevance']})"
            )
            # Previews are cut to 200 characters for the sectioned layout.
            full_preview = hit.get("code_preview", "")
            preview = full_preview[:200]
            if len(full_preview) > 200:
                preview = preview + "..."
            output.append(f" ```{hit.get('language', '')}")
            output.append(f" {preview}")
            output.append(" ```")
        total_results += len(shown)
        output.append("")

    output.append(f"[Total: {total_results} results across {len(sub_queries)} sub-queries]")

    return "\n".join(output)
|
|
2122
|
+
|
|
2123
|
+
|
|
2124
|
+
async def enhanced_search(
    query: str,
    project_path: str = ".",
    n_results: int = 10,
    mode: str = "auto",
    language: str | None = None,
    node_type: str | None = None,
    provider: EmbeddingProvider = "ollama",
) -> str:
    """
    Single entry point for the expansion and decomposition search strategies.

    Modes:
    - "expand": delegate to ``multi_query_search``
    - "decompose": delegate to ``decomposed_search``
    - "both": decompose, expand every sub-query into up to two variations,
      search each phrasing (5 hits each) and fuse everything
    - "auto": "decompose" when the lower-cased query contains a connector
      such as " and ", " then " or "; ", otherwise "expand"

    Args:
        query: Search query (simple or complex)
        project_path: Path to the project root
        n_results: Maximum number of results
        mode: Search mode - "auto", "expand", "decompose", or "both"
        language: Filter by language
        node_type: Filter by node type
        provider: Embedding provider

    Returns:
        Formatted search results, or a message naming the valid modes when
        ``mode`` is not recognized.
    """
    connectors = (" and ", " then ", " also ", " with ", ", then", ". then", "; ")
    lowered_query = query.lower()
    looks_compound = any(connector in lowered_query for connector in connectors)

    if mode == "auto":
        mode = "decompose" if looks_compound else "expand"

    forwarded = {
        "query": query,
        "project_path": project_path,
        "n_results": n_results,
        "language": language,
        "node_type": node_type,
        "provider": provider,
    }

    if mode == "decompose":
        return await decomposed_search(**forwarded)

    if mode == "expand":
        return await multi_query_search(**forwarded)

    if mode != "both":
        return f"Unknown mode: {mode}. Use 'auto', 'expand', 'decompose', or 'both'"

    # mode == "both": decompose first, then expand each sub-query
    sub_queries = await _decompose_query_with_llm(query)
    store = get_store(project_path, provider)

    collected: list[list[dict]] = []
    for sub_query in sub_queries:
        phrasings = await _expand_query_with_llm(sub_query, num_variations=2)
        for phrasing in phrasings:
            hits = await store.search(
                phrasing,
                n_results=5,
                language=language,
                node_type=node_type,
            )
            # Skip empty lists and lists whose first entry reports an error.
            if hits and "error" not in hits[0]:
                collected.append(hits)

    aggregated = _aggregate_results(collected, n_results)
    if not aggregated:
        return "No results found"

    output = [
        f"Enhanced search (decompose+expand) for: '{query}'",
        f"[{len(sub_queries)} sub-queries × expansions]\n",
    ]
    for position, hit in enumerate(aggregated, 1):
        output.extend(
            (
                f"{position}. {hit['file']}:{hit['lines']} (relevance: {hit['relevance']})",
                f"```{hit.get('language', '')}",
                hit.get("code_preview", ""),
                "```\n",
            )
        )

    return "\n".join(output)
|
|
2216
|
+
|
|
2217
|
+
|
|
2218
|
+
# ========================
|
|
2219
|
+
# FILE WATCHER IMPLEMENTATION
|
|
2220
|
+
# ========================
|
|
2221
|
+
|
|
2222
|
+
|
|
2223
|
+
class CodebaseFileWatcher:
|
|
2224
|
+
"""Watch a project directory for file changes and trigger reindexing.
|
|
2225
|
+
|
|
2226
|
+
Features:
|
|
2227
|
+
- Watches for file create, modify, delete, move events
|
|
2228
|
+
- Filters to .py files only
|
|
2229
|
+
- Skips hidden files and directories (., .git, __pycache__, venv, etc.)
|
|
2230
|
+
- Debounces rapid changes to batch them into a single reindex
|
|
2231
|
+
- Thread-safe with daemon threads for clean shutdown
|
|
2232
|
+
- Integrates with CodebaseVectorStore for incremental indexing
|
|
2233
|
+
"""
|
|
2234
|
+
|
|
2235
|
+
# Default debounce time in seconds
|
|
2236
|
+
DEFAULT_DEBOUNCE_SECONDS = 2.0
|
|
2237
|
+
|
|
2238
|
+
def __init__(
|
|
2239
|
+
self,
|
|
2240
|
+
project_path: Path | str,
|
|
2241
|
+
store: CodebaseVectorStore,
|
|
2242
|
+
debounce_seconds: float = DEFAULT_DEBOUNCE_SECONDS,
|
|
2243
|
+
):
|
|
2244
|
+
"""Initialize the file watcher.
|
|
2245
|
+
|
|
2246
|
+
Args:
|
|
2247
|
+
project_path: Path to the project root to watch
|
|
2248
|
+
store: CodebaseVectorStore instance for reindexing
|
|
2249
|
+
debounce_seconds: Time to wait before reindexing after changes (default: 2.0s)
|
|
2250
|
+
"""
|
|
2251
|
+
self.project_path = Path(project_path).resolve()
|
|
2252
|
+
self.store = store
|
|
2253
|
+
self.debounce_seconds = debounce_seconds
|
|
2254
|
+
|
|
2255
|
+
# Observer and handler for watchdog
|
|
2256
|
+
self._observer = None
|
|
2257
|
+
self._event_handler = None
|
|
2258
|
+
|
|
2259
|
+
# Thread safety
|
|
2260
|
+
self._lock = threading.Lock()
|
|
2261
|
+
self._running = False
|
|
2262
|
+
|
|
2263
|
+
# Debouncing
|
|
2264
|
+
self._pending_reindex_timer: threading.Timer | None = None
|
|
2265
|
+
self._pending_files: set[Path] = set()
|
|
2266
|
+
self._pending_lock = threading.Lock()
|
|
2267
|
+
|
|
2268
|
+
def start(self) -> None:
|
|
2269
|
+
"""Start watching the project directory.
|
|
2270
|
+
|
|
2271
|
+
Creates and starts a watchdog observer in a daemon thread.
|
|
2272
|
+
"""
|
|
2273
|
+
with self._lock:
|
|
2274
|
+
if self._running:
|
|
2275
|
+
logger.warning(f"Watcher for {self.project_path} is already running")
|
|
2276
|
+
return
|
|
2277
|
+
|
|
2278
|
+
try:
|
|
2279
|
+
watchdog = get_watchdog()
|
|
2280
|
+
Observer = watchdog["Observer"]
|
|
2281
|
+
|
|
2282
|
+
# Create event handler class and instantiate
|
|
2283
|
+
FileChangeHandler = _create_file_change_handler_class()
|
|
2284
|
+
self._event_handler = FileChangeHandler(
|
|
2285
|
+
project_path=self.project_path,
|
|
2286
|
+
watcher=self,
|
|
2287
|
+
)
|
|
2288
|
+
|
|
2289
|
+
# Create and start observer (daemon mode for clean shutdown)
|
|
2290
|
+
self._observer = Observer()
|
|
2291
|
+
self._observer.daemon = True
|
|
2292
|
+
self._observer.schedule(
|
|
2293
|
+
self._event_handler,
|
|
2294
|
+
str(self.project_path),
|
|
2295
|
+
recursive=True,
|
|
2296
|
+
)
|
|
2297
|
+
self._observer.start()
|
|
2298
|
+
self._running = True
|
|
2299
|
+
logger.info(f"File watcher started for {self.project_path}")
|
|
2300
|
+
|
|
2301
|
+
except Exception as e:
|
|
2302
|
+
logger.error(f"Failed to start file watcher: {e}")
|
|
2303
|
+
self._running = False
|
|
2304
|
+
raise
|
|
2305
|
+
|
|
2306
|
+
def stop(self) -> None:
|
|
2307
|
+
"""Stop watching the project directory.
|
|
2308
|
+
|
|
2309
|
+
Cancels any pending reindex timers and stops the observer.
|
|
2310
|
+
"""
|
|
2311
|
+
with self._lock:
|
|
2312
|
+
# Cancel pending reindex
|
|
2313
|
+
if self._pending_reindex_timer is not None:
|
|
2314
|
+
self._pending_reindex_timer.cancel()
|
|
2315
|
+
self._pending_reindex_timer = None
|
|
2316
|
+
|
|
2317
|
+
# Stop observer
|
|
2318
|
+
if self._observer is not None:
|
|
2319
|
+
self._observer.stop()
|
|
2320
|
+
self._observer.join(timeout=5) # Wait up to 5 seconds for shutdown
|
|
2321
|
+
self._observer = None
|
|
2322
|
+
|
|
2323
|
+
self._event_handler = None
|
|
2324
|
+
self._running = False
|
|
2325
|
+
logger.info(f"File watcher stopped for {self.project_path}")
|
|
2326
|
+
|
|
2327
|
+
def is_running(self) -> bool:
|
|
2328
|
+
"""Check if the watcher is currently running.
|
|
2329
|
+
|
|
2330
|
+
Returns:
|
|
2331
|
+
True if watcher is active, False otherwise
|
|
2332
|
+
"""
|
|
2333
|
+
with self._lock:
|
|
2334
|
+
return self._running and self._observer is not None and self._observer.is_alive()
|
|
2335
|
+
|
|
2336
|
+
def _on_file_changed(self, file_path: Path) -> None:
|
|
2337
|
+
"""Called when a file changes (internal use by _FileChangeHandler).
|
|
2338
|
+
|
|
2339
|
+
Accumulates files and triggers debounced reindex.
|
|
2340
|
+
|
|
2341
|
+
Args:
|
|
2342
|
+
file_path: Path to the changed file
|
|
2343
|
+
"""
|
|
2344
|
+
with self._pending_lock:
|
|
2345
|
+
self._pending_files.add(file_path)
|
|
2346
|
+
|
|
2347
|
+
# Cancel previous timer
|
|
2348
|
+
if self._pending_reindex_timer is not None:
|
|
2349
|
+
self._pending_reindex_timer.cancel()
|
|
2350
|
+
|
|
2351
|
+
# Start new timer
|
|
2352
|
+
self._pending_reindex_timer = self._create_debounce_timer()
|
|
2353
|
+
self._pending_reindex_timer.start()
|
|
2354
|
+
|
|
2355
|
+
def _create_debounce_timer(self) -> threading.Timer:
|
|
2356
|
+
"""Create a new debounce timer for reindexing.
|
|
2357
|
+
|
|
2358
|
+
Returns:
|
|
2359
|
+
A threading.Timer configured for debounce reindexing
|
|
2360
|
+
"""
|
|
2361
|
+
return threading.Timer(
|
|
2362
|
+
self.debounce_seconds,
|
|
2363
|
+
self._trigger_reindex,
|
|
2364
|
+
)
|
|
2365
|
+
|
|
2366
|
+
def _trigger_reindex(self) -> None:
|
|
2367
|
+
"""Trigger reindexing of accumulated changed files.
|
|
2368
|
+
|
|
2369
|
+
This is called after the debounce period expires. It performs an
|
|
2370
|
+
incremental reindex focusing on the changed files.
|
|
2371
|
+
"""
|
|
2372
|
+
import asyncio
|
|
2373
|
+
|
|
2374
|
+
with self._pending_lock:
|
|
2375
|
+
if not self._pending_files:
|
|
2376
|
+
self._pending_reindex_timer = None
|
|
2377
|
+
return
|
|
2378
|
+
|
|
2379
|
+
files_to_index = list(self._pending_files)
|
|
2380
|
+
self._pending_files.clear()
|
|
2381
|
+
self._pending_reindex_timer = None
|
|
2382
|
+
|
|
2383
|
+
# Run async reindex in a new event loop
|
|
2384
|
+
try:
|
|
2385
|
+
loop = asyncio.new_event_loop()
|
|
2386
|
+
asyncio.set_event_loop(loop)
|
|
2387
|
+
try:
|
|
2388
|
+
loop.run_until_complete(self.store.index_codebase(force=False))
|
|
2389
|
+
logger.debug(f"Reindexed {len(files_to_index)} changed files")
|
|
2390
|
+
finally:
|
|
2391
|
+
loop.close()
|
|
2392
|
+
except Exception as e:
|
|
2393
|
+
logger.error(f"Error during file watcher reindex: {e}")
|
|
2394
|
+
|
|
2395
|
+
|
|
2396
|
+
def _create_file_change_handler_class():
    """Create FileChangeHandler class that inherits from FileSystemEventHandler.

    This is a factory function that creates the handler class dynamically
    after watchdog is imported, allowing for lazy loading.

    Returns:
        The ``_FileChangeHandler`` class (not an instance).
    """
    watchdog = get_watchdog()
    FileSystemEventHandler = watchdog["FileSystemEventHandler"]

    # Directory names that are skipped wherever they appear below the
    # project root (hidden directories are handled separately).
    skip_dirs = frozenset({"__pycache__", "venv", "env", "node_modules"})

    class _FileChangeHandler(FileSystemEventHandler):
        """Watchdog event handler for file system changes.

        Detects file create, modify, delete, and move events, filters them,
        and notifies the watcher of relevant changes.
        """

        def __init__(self, project_path: Path, watcher: CodebaseFileWatcher):
            """Initialize the event handler.

            Args:
                project_path: Root path of the project being watched
                watcher: CodebaseFileWatcher instance to notify
            """
            super().__init__()
            self.project_path = project_path
            self.watcher = watcher
            # Resolved once so that event paths (which are resolved too) can
            # be compared against the root even when the root was given as a
            # relative path or goes through a symlink.
            self._project_root = Path(project_path).resolve()

        def on_created(self, event) -> None:
            """Called when a file is created."""
            if not event.is_directory and self._should_index_file(event.src_path):
                logger.debug(f"File created: {event.src_path}")
                self.watcher._on_file_changed(Path(event.src_path))

        def on_modified(self, event) -> None:
            """Called when a file is modified."""
            if not event.is_directory and self._should_index_file(event.src_path):
                logger.debug(f"File modified: {event.src_path}")
                self.watcher._on_file_changed(Path(event.src_path))

        def on_deleted(self, event) -> None:
            """Called when a file is deleted."""
            if not event.is_directory and self._should_index_file(event.src_path):
                logger.debug(f"File deleted: {event.src_path}")
                self.watcher._on_file_changed(Path(event.src_path))

        def on_moved(self, event) -> None:
            """Called when a file is moved.

            The destination takes priority; the source is only reported when
            the destination is not indexable (the file moved out of scope).
            """
            if event.is_directory:
                return

            if self._should_index_file(event.dest_path):
                logger.debug(f"File moved: {event.src_path} -> {event.dest_path}")
                self.watcher._on_file_changed(Path(event.dest_path))
            elif self._should_index_file(event.src_path):
                logger.debug(f"File moved out: {event.src_path}")
                self.watcher._on_file_changed(Path(event.src_path))

        def _should_index_file(self, file_path: str) -> bool:
            """Check if a file should trigger reindexing.

            Filters based on:
            - File extension (.py only)
            - Hidden files and directories (starting with .)
            - Skip directories (venv, __pycache__, node_modules, etc.)
            - Location (the file must live under the project root)

            Only path components *below* the project root are inspected, so
            a project that itself lives under a hidden or skip-named
            directory is still watched.

            Args:
                file_path: Path to the file to check

            Returns:
                True if file should trigger reindexing, False otherwise
            """
            path = Path(file_path)

            # Only .py files
            if path.suffix != ".py":
                return False

            # Skip hidden files
            if path.name.startswith("."):
                return False

            # File must be within the project; both sides are resolved so
            # symlinked or relative roots compare correctly.
            try:
                relative = path.resolve().relative_to(self._project_root)
            except (ValueError, OSError, RuntimeError):
                # ValueError: file is outside the project.
                # OSError/RuntimeError: the path could not be resolved
                # (e.g. a symlink loop); treat it as not indexable.
                return False

            # Hidden directories (.git, .venv, ...) and known skip
            # directories anywhere between the root and the file.
            for part in relative.parts:
                if part.startswith(".") or part in skip_dirs:
                    return False

            return True

    return _FileChangeHandler
|