jarvis-ai-assistant 0.1.100__py3-none-any.whl → 0.1.102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of jarvis-ai-assistant might be problematic.
- jarvis/__init__.py +1 -1
- jarvis/agent.py +3 -24
- jarvis/jarvis_code_agent/main.py +1 -3
- jarvis/jarvis_coder/patch_handler.py +5 -5
- jarvis/jarvis_github/main.py +232 -0
- jarvis/models/ai8.py +2 -3
- jarvis/models/oyi.py +1 -3
- jarvis/tools/registry.py +48 -40
- jarvis/utils.py +9 -124
- {jarvis_ai_assistant-0.1.100.dist-info → jarvis_ai_assistant-0.1.102.dist-info}/METADATA +1 -47
- {jarvis_ai_assistant-0.1.100.dist-info → jarvis_ai_assistant-0.1.102.dist-info}/RECORD +16 -21
- {jarvis_ai_assistant-0.1.100.dist-info → jarvis_ai_assistant-0.1.102.dist-info}/entry_points.txt +0 -3
- jarvis/jarvis_codebase/main.py +0 -875
- jarvis/jarvis_coder/main.py +0 -241
- jarvis/jarvis_coder/plan_generator.py +0 -145
- jarvis/jarvis_rag/__init__.py +0 -0
- jarvis/jarvis_rag/main.py +0 -822
- jarvis/tools/rag.py +0 -138
- /jarvis/{jarvis_codebase → jarvis_github}/__init__.py +0 -0
- {jarvis_ai_assistant-0.1.100.dist-info → jarvis_ai_assistant-0.1.102.dist-info}/LICENSE +0 -0
- {jarvis_ai_assistant-0.1.100.dist-info → jarvis_ai_assistant-0.1.102.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.1.100.dist-info → jarvis_ai_assistant-0.1.102.dist-info}/top_level.txt +0 -0
jarvis/jarvis_codebase/main.py
DELETED
@@ -1,875 +0,0 @@
-import hashlib
-import os
-import numpy as np
-import faiss
-from typing import List, Tuple, Optional, Dict
-
-import yaml
-from jarvis.models.registry import PlatformRegistry
-import concurrent.futures
-from threading import Lock
-from concurrent.futures import ThreadPoolExecutor
-from jarvis.utils import OutputType, PrettyOutput, find_git_root, get_file_md5, get_max_context_length, get_single_line_input, get_thread_count, load_embedding_model, load_rerank_model
-from jarvis.utils import init_env
-import argparse
-import pickle
-import lzma # add lzma import
-from tqdm import tqdm
-import re
-
-class CodeBase:
-    def __init__(self, root_dir: str):
-        init_env()
-        self.root_dir = root_dir
-        os.chdir(self.root_dir)
-        self.thread_count = get_thread_count()
-        self.max_context_length = get_max_context_length()
-        self.index = None
-
-        # Initialize the data directory
-        self.data_dir = os.path.join(self.root_dir, ".jarvis/codebase")
-        self.cache_dir = os.path.join(self.data_dir, "cache")
-        if not os.path.exists(self.cache_dir):
-            os.makedirs(self.cache_dir)
-
-        # Initialize the embedding model
-        try:
-            self.embedding_model = load_embedding_model()
-            test_text = """This is a test text"""
-            self.embedding_model.encode([test_text],
-                                        convert_to_tensor=True,
-                                        normalize_embeddings=True)
-            PrettyOutput.print("Model loaded successfully", output_type=OutputType.SUCCESS)
-        except Exception as e:
-            PrettyOutput.print(f"Failed to load model: {str(e)}", output_type=OutputType.ERROR)
-            raise
-
-        self.vector_dim = self.embedding_model.get_sentence_embedding_dimension()
-        self.git_file_list = self.get_git_file_list()
-        self.platform_registry = PlatformRegistry.get_global_platform_registry()
-
-        # Initialize the cache and index
-        self.vector_cache = {}
-        self.file_paths = []
-
-        # Load all cache files
-        self._load_all_cache()
-
-    def get_git_file_list(self):
-        """Get the list of files in the git repository, excluding the .jarvis-codebase directory"""
-        files = os.popen("git ls-files").read().splitlines()
-        # Filter out files in the .jarvis-codebase directory
-        return [f for f in files if not f.startswith(".jarvis-")]
-
-    def is_text_file(self, file_path: str):
-        with open(file_path, "r", encoding="utf-8") as f:
-            try:
-                f.read()
-                return True
-            except UnicodeDecodeError:
-                return False
-
-    def make_description(self, file_path: str, content: str) -> str:
-        model = PlatformRegistry.get_global_platform_registry().get_cheap_platform()
-        if self.thread_count > 1:
-            model.set_suppress_output(True)
-        else:
-            PrettyOutput.print(f"Make description for {file_path} ...", output_type=OutputType.PROGRESS)
-        prompt = f"""Please analyze the following code file and generate a detailed description. The description should include:
-1. Overall file functionality description
-2. description for each global variable, function, type definition, class, method, and other code elements
-
-Please use concise and professional language, emphasizing technical functionality to facilitate subsequent code retrieval.
-File path: {file_path}
-Code content:
-{content}
-"""
-        response = model.chat_until_success(prompt)
-        return response
-
-    def export(self):
-        """Export the current index data to standard output"""
-        for file_path, data in self.vector_cache.items():
-            print(f"## {file_path}")
-            print(f"- path: {file_path}")
-            print(f"- description: {data['description']}")
-
-    def _get_cache_path(self, file_path: str) -> str:
-        """Get cache file path for a source file
-
-        Args:
-            file_path: Source file path
-
-        Returns:
-            str: Cache file path
-        """
-        # Process the file path:
-        # 1. Remove the leading ./ or /
-        # 2. Replace / with --
-        # 3. Append the .cache suffix
-        clean_path = file_path.lstrip('./').lstrip('/')
-        cache_name = clean_path.replace('/', '--') + '.cache'
-        return os.path.join(self.cache_dir, cache_name)
-
-    def _load_all_cache(self):
-        """Load all cache files"""
-        try:
-            # Clear the existing cache and file paths
-            self.vector_cache = {}
-            self.file_paths = []
-            vectors = []
-
-            for cache_file in os.listdir(self.cache_dir):
-                if not cache_file.endswith('.cache'):
-                    continue
-
-                cache_path = os.path.join(self.cache_dir, cache_file)
-                try:
-                    with lzma.open(cache_path, 'rb') as f:
-                        cache_data = pickle.load(f)
-                        file_path = cache_data["path"]
-                        self.vector_cache[file_path] = cache_data
-                        self.file_paths.append(file_path)
-                        vectors.append(cache_data["vector"])
-                except Exception as e:
-                    PrettyOutput.print(f"Failed to load cache file {cache_file}: {str(e)}",
-                                       output_type=OutputType.WARNING)
-                    continue
-
-            if vectors:
-                # Rebuild the index
-                vectors_array = np.vstack(vectors)
-                hnsw_index = faiss.IndexHNSWFlat(self.vector_dim, 16)
-                hnsw_index.hnsw.efConstruction = 40
-                hnsw_index.hnsw.efSearch = 16
-                self.index = faiss.IndexIDMap(hnsw_index)
-                self.index.add_with_ids(vectors_array, np.array(range(len(vectors)))) # type: ignore
-
-                PrettyOutput.print(f"Loaded {len(self.vector_cache)} vector cache and rebuilt index",
-                                   output_type=OutputType.INFO)
-            else:
-                self.index = None
-                PrettyOutput.print("No valid cache files found", output_type=OutputType.WARNING)
-
-        except Exception as e:
-            PrettyOutput.print(f"Failed to load cache directory: {str(e)}",
-                               output_type=OutputType.WARNING)
-            self.vector_cache = {}
-            self.file_paths = []
-            self.index = None
-
-    def cache_vector(self, file_path: str, vector: np.ndarray, description: str):
-        """Cache the vector representation of a file"""
-        try:
-            with open(file_path, "rb") as f:
-                file_md5 = hashlib.md5(f.read()).hexdigest()
-        except Exception as e:
-            PrettyOutput.print(f"Failed to calculate MD5 for {file_path}: {str(e)}",
-                               output_type=OutputType.ERROR)
-            file_md5 = ""
-
-        # Prepare the cache data
-        cache_data = {
-            "path": file_path, # save the file path
-            "md5": file_md5, # save the file MD5
-            "description": description, # save the file description
-            "vector": vector # save the vector
-        }
-
-        # Update the in-memory cache
-        self.vector_cache[file_path] = cache_data
-
-        # Save to a separate cache file
-        cache_path = self._get_cache_path(file_path)
-        try:
-            with lzma.open(cache_path, 'wb') as f:
-                pickle.dump(cache_data, f, protocol=pickle.HIGHEST_PROTOCOL)
-        except Exception as e:
-            PrettyOutput.print(f"Failed to save cache for {file_path}: {str(e)}",
-                               output_type=OutputType.ERROR)
-
-    def get_cached_vector(self, file_path: str, description: str) -> Optional[np.ndarray]:
-        """Get the vector representation of a file from the cache"""
-        if file_path not in self.vector_cache:
-            return None
-
-        # Check if the file has been modified
-        try:
-            with open(file_path, "rb") as f:
-                current_md5 = hashlib.md5(f.read()).hexdigest()
-        except Exception as e:
-            PrettyOutput.print(f"Failed to calculate MD5 for {file_path}: {str(e)}",
-                               output_type=OutputType.ERROR)
-            return None
-
-        cached_data = self.vector_cache[file_path]
-        if cached_data["md5"] != current_md5:
-            return None
-
-        # Check if the description has changed
-        if cached_data["description"] != description:
-            return None
-
-        return cached_data["vector"]
-
-    def get_embedding(self, text: str) -> np.ndarray:
-        """Use the transformers model to get the vector representation of text"""
-        # Truncate long text
-        max_length = 512 # Or other suitable length
-        text = ' '.join(text.split()[:max_length])
-
-        # Get the embedding vector
-        embedding = self.embedding_model.encode(text,
-                                                normalize_embeddings=True, # L2 normalization
-                                                show_progress_bar=False)
-        vector = np.array(embedding, dtype=np.float32)
-        return vector
-
-    def vectorize_file(self, file_path: str, description: str) -> np.ndarray:
-        """Vectorize the file content and description"""
-        try:
-            # Try to get the vector from the cache first
-            cached_vector = self.get_cached_vector(file_path, description)
-            if cached_vector is not None:
-                return cached_vector
-
-            # Read the file content and combine information
-            content = open(file_path, "r", encoding="utf-8").read()[:self.max_context_length] # Limit the file content length
-
-            # Combine file information, including file content
-            combined_text = f"""
-File path: {file_path}
-Description: {description}
-Content: {content}
-"""
-            vector = self.get_embedding(combined_text)
-
-            # Save to cache
-            self.cache_vector(file_path, vector, description)
-            return vector
-        except Exception as e:
-            PrettyOutput.print(f"Error vectorizing file {file_path}: {str(e)}",
-                               output_type=OutputType.ERROR)
-            return np.zeros(self.vector_dim, dtype=np.float32) # type: ignore
-
-    def clean_cache(self) -> bool:
-        """Clean expired cache records"""
-        try:
-            files_to_delete = []
-            for file_path in list(self.vector_cache.keys()):
-                if not os.path.exists(file_path):
-                    files_to_delete.append(file_path)
-                    cache_path = self._get_cache_path(file_path)
-                    try:
-                        os.remove(cache_path)
-                    except Exception:
-                        pass
-
-            for file_path in files_to_delete:
-                del self.vector_cache[file_path]
-                if file_path in self.file_paths:
-                    self.file_paths.remove(file_path)
-
-            return bool(files_to_delete)
-
-        except Exception as e:
-            PrettyOutput.print(f"Failed to clean cache: {str(e)}",
-                               output_type=OutputType.ERROR)
-            return False
-
-    def process_file(self, file_path: str):
-        """Process a single file"""
-        try:
-            # Skip non-existent files
-            if not os.path.exists(file_path):
-                return None
-
-            if not self.is_text_file(file_path):
-                return None
-
-            md5 = get_file_md5(file_path)
-
-            content = open(file_path, "r", encoding="utf-8").read()
-
-            # Check if the file has already been processed and the content has not changed
-            if file_path in self.vector_cache:
-                if self.vector_cache[file_path].get("md5") == md5:
-                    return None
-
-            description = self.make_description(file_path, content) # Pass the truncated content
-            vector = self.vectorize_file(file_path, description)
-
-            # Save to cache, using the actual file path as the key
-            self.vector_cache[file_path] = {
-                "vector": vector,
-                "description": description,
-                "md5": md5
-            }
-
-            return file_path
-
-        except Exception as e:
-            PrettyOutput.print(f"Failed to process file {file_path}: {str(e)}",
-                               output_type=OutputType.ERROR)
-            return None
-
-    def build_index(self):
-        """Build a faiss index from the vector cache"""
-        try:
-            if not self.vector_cache:
-                self.index = None
-                return
-
-            # Create the underlying HNSW index
-            hnsw_index = faiss.IndexHNSWFlat(self.vector_dim, 16)
-            hnsw_index.hnsw.efConstruction = 40
-            hnsw_index.hnsw.efSearch = 16
-
-            # Wrap the HNSW index with IndexIDMap
-            self.index = faiss.IndexIDMap(hnsw_index)
-
-            vectors = []
-            ids = []
-            self.file_paths = [] # Reset the file path list
-
-            for i, (file_path, data) in enumerate(self.vector_cache.items()):
-                if "vector" not in data:
-                    PrettyOutput.print(f"Invalid cache data for {file_path}: missing vector",
-                                       output_type=OutputType.WARNING)
-                    continue
-
-                vector = data["vector"]
-                if not isinstance(vector, np.ndarray):
-                    PrettyOutput.print(f"Invalid vector type for {file_path}: {type(vector)}",
-                                       output_type=OutputType.WARNING)
-                    continue
-
-                vectors.append(vector.reshape(1, -1))
-                ids.append(i)
-                self.file_paths.append(file_path)
-
-            if vectors:
-                vectors = np.vstack(vectors)
-                if len(vectors) != len(ids):
-                    PrettyOutput.print(f"Vector count mismatch: {len(vectors)} vectors vs {len(ids)} ids",
-                                       output_type=OutputType.ERROR)
-                    self.index = None
-                    return
-
-                try:
-                    self.index.add_with_ids(vectors, np.array(ids)) # type: ignore
-                    PrettyOutput.print(f"Successfully built index with {len(vectors)} vectors",
-                                       output_type=OutputType.SUCCESS)
-                except Exception as e:
-                    PrettyOutput.print(f"Failed to add vectors to index: {str(e)}",
-                                       output_type=OutputType.ERROR)
-                    self.index = None
-            else:
-                PrettyOutput.print("No valid vectors found, index not built",
-                                   output_type=OutputType.WARNING)
-                self.index = None
-
-        except Exception as e:
-            PrettyOutput.print(f"Failed to build index: {str(e)}",
-                               output_type=OutputType.ERROR)
-            self.index = None
-
-    def gen_vector_db_from_cache(self):
-        """Generate a vector database from the cache"""
-        self.build_index()
-        self._load_all_cache()
-
-
-    def generate_codebase(self, force: bool = False):
-        """Generate the codebase index
-        Args:
-            force: Whether to force rebuild the index, without asking the user
-        """
-        try:
-            # Update the git file list
-            self.git_file_list = self.get_git_file_list()
-
-            # Check file changes
-            PrettyOutput.print("\nCheck file changes...", output_type=OutputType.INFO)
-            changes_detected = False
-            new_files = []
-            modified_files = []
-            deleted_files = []
-
-            # Check deleted files
-            files_to_delete = []
-            for file_path in list(self.vector_cache.keys()):
-                if file_path not in self.git_file_list:
-                    deleted_files.append(file_path)
-                    files_to_delete.append(file_path)
-                    changes_detected = True
-
-            # Check new and modified files
-            with tqdm(total=len(self.git_file_list), desc="Check file status") as pbar:
-                for file_path in self.git_file_list:
-                    if not os.path.exists(file_path) or not self.is_text_file(file_path):
-                        pbar.update(1)
-                        continue
-
-                    try:
-                        current_md5 = get_file_md5(file_path)
-
-                        if file_path not in self.vector_cache:
-                            new_files.append(file_path)
-                            changes_detected = True
-                        elif self.vector_cache[file_path].get("md5") != current_md5:
-                            modified_files.append(file_path)
-                            changes_detected = True
-                    except Exception as e:
-                        PrettyOutput.print(f"Failed to check file {file_path}: {str(e)}",
-                                           output_type=OutputType.ERROR)
-                    pbar.update(1)
-
-            # If changes are detected, display changes and ask the user
-            if changes_detected:
-                PrettyOutput.print("\nDetected the following changes:", output_type=OutputType.WARNING)
-                if new_files:
-                    PrettyOutput.print("\nNew files:", output_type=OutputType.INFO)
-                    for f in new_files:
-                        PrettyOutput.print(f"  {f}", output_type=OutputType.INFO)
-                if modified_files:
-                    PrettyOutput.print("\nModified files:", output_type=OutputType.INFO)
-                    for f in modified_files:
-                        PrettyOutput.print(f"  {f}", output_type=OutputType.INFO)
-                if deleted_files:
-                    PrettyOutput.print("\nDeleted files:", output_type=OutputType.INFO)
-                    for f in deleted_files:
-                        PrettyOutput.print(f"  {f}", output_type=OutputType.INFO)
-
-                # If force is True, continue directly
-                if not force:
-                    # Ask the user whether to continue
-                    while True:
-                        response = get_single_line_input("\nRebuild the index? [y/N]").lower().strip()
-                        if response in ['y', 'yes']:
-                            break
-                        elif response in ['', 'n', 'no']:
-                            PrettyOutput.print("Cancel rebuilding the index", output_type=OutputType.INFO)
-                            return
-                        else:
-                            PrettyOutput.print("Please input y or n", output_type=OutputType.WARNING)
-
-                # Clean deleted files
-                for file_path in files_to_delete:
-                    del self.vector_cache[file_path]
-                if files_to_delete:
-                    PrettyOutput.print(f"Cleaned the cache of {len(files_to_delete)} files",
-                                       output_type=OutputType.INFO)
-
-                # Process new and modified files
-                files_to_process = new_files + modified_files
-                processed_files = []
-
-                with tqdm(total=len(files_to_process), desc="Processing files") as pbar:
-                    # Use a thread pool to process files
-                    with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
-                        # Submit all tasks
-                        future_to_file = {
-                            executor.submit(self.process_file, file): file
-                            for file in files_to_process
-                        }
-
-                        # Process completed tasks
-                        for future in concurrent.futures.as_completed(future_to_file):
-                            file = future_to_file[future]
-                            try:
-                                result = future.result()
-                                if result:
-                                    processed_files.append(result)
-                            except Exception as e:
-                                PrettyOutput.print(f"Failed to process file {file}: {str(e)}",
-                                                   output_type=OutputType.ERROR)
-                            pbar.update(1)
-
-                if processed_files:
-                    PrettyOutput.print("\nRebuilding the vector database...", output_type=OutputType.INFO)
-                    self.gen_vector_db_from_cache()
-                    PrettyOutput.print(f"Successfully generated the index for {len(processed_files)} files",
-                                       output_type=OutputType.SUCCESS)
-            else:
-                PrettyOutput.print("No file changes detected, no need to rebuild the index", output_type=OutputType.INFO)
-
-        except Exception as e:
-            # Try to save the cache when an exception occurs
-            try:
-                self._load_all_cache()
-            except Exception as save_error:
-                PrettyOutput.print(f"Failed to save cache: {str(save_error)}",
-                                   output_type=OutputType.ERROR)
-            raise e # Re-raise the original exception
-
-
-    def _text_search_score(self, content: str, keywords: List[str]) -> float:
-        """Calculate the matching score between the text content and the keywords
-
-        Args:
-            content: Text content
-            keywords: List of keywords
-
-        Returns:
-            float: Matching score (0-1)
-        """
-        if not keywords:
-            return 0.0
-
-        content = content.lower()
-        matched_keywords = set()
-
-        for keyword in keywords:
-            keyword = keyword.lower()
-            if keyword in content:
-                matched_keywords.add(keyword)
-
-        # Calculate the matching score
-        score = len(matched_keywords) / len(keywords)
-        return score
-
-    def pick_results(self, query: str, initial_results: List[str]) -> List[str]:
-        """Use a large model to pick the search results
-
-        Args:
-            query: Search query
-            initial_results: Initial results list of file paths
-
-        Returns:
-            List[str]: The picked results list, each item is a file path
-        """
-        if not initial_results:
-            return []
-
-        try:
-            PrettyOutput.print(f"Picking results for query: {query}", output_type=OutputType.INFO)
-
-            # Maximum content length per batch
-            max_batch_length = self.max_context_length - 1000 # Reserve space for prompt
-            max_file_length = max_batch_length // 3 # Limit individual file size
-
-            # Process files in batches
-            all_selected_files = set()
-            current_batch = []
-            current_length = 0
-
-            for path in initial_results:
-                try:
-                    content = open(path, "r", encoding="utf-8").read()
-                    # Truncate large files
-                    if len(content) > max_file_length:
-                        PrettyOutput.print(f"Truncating large file: {path}", OutputType.WARNING)
-                        content = content[:max_file_length] + "\n... (content truncated)"
-
-                    file_info = f"File: {path}\nContent: {content}\n\n"
-                    file_length = len(file_info)
-
-                    # If adding this file would exceed batch limit
-                    if current_length + file_length > max_batch_length:
-                        # Process current batch
-                        if current_batch:
-                            selected = self._process_batch(query, current_batch)
-                            all_selected_files.update(selected)
-                        # Start new batch
-                        current_batch = [file_info]
-                        current_length = file_length
-                    else:
-                        current_batch.append(file_info)
-                        current_length += file_length
-
-                except Exception as e:
-                    PrettyOutput.print(f"Failed to read file {path}: {str(e)}", OutputType.ERROR)
-                    continue
-
-            # Process final batch
-            if current_batch:
-                selected = self._process_batch(query, current_batch)
-                all_selected_files.update(selected)
-
-            # Convert set to list and maintain original order
-            final_results = [path for path in initial_results if path in all_selected_files]
-            return final_results
-
-        except Exception as e:
-            PrettyOutput.print(f"Failed to pick: {str(e)}", OutputType.ERROR)
-            return initial_results
-
-    def _process_batch(self, query: str, files_info: List[str]) -> List[str]:
-        """Process a batch of files
-
-        Args:
-            query: Search query
-            files_info: List of file information strings
-
-        Returns:
-            List[str]: Selected file paths from this batch
-        """
-        prompt = f"""Please analyze the following code files and determine which files are most relevant to the given query. Consider file paths and code content to make your judgment.
-
-Query: {query}
-
-Available files:
-{''.join(files_info)}
-
-Please output a YAML list of relevant file paths, ordered by relevance (most relevant first). Only include files that are truly relevant to the query.
-Output format:
-<FILES>
-- path/to/file1.py
-- path/to/file2.py
-</FILES>
-
-Note: Only include files that have a strong connection to the query."""
-
-        # Use a large model to evaluate
-        model = PlatformRegistry.get_global_platform_registry().get_normal_platform()
-        response = model.chat_until_success(prompt)
-
-        # Parse the response
-        import yaml
-        files_match = re.search(r'<FILES>\n(.*?)</FILES>', response, re.DOTALL)
-        if not files_match:
-            return []
-
-        # Extract the file list
-        try:
-            selected_files = yaml.safe_load(files_match.group(1))
-            return selected_files if selected_files else []
-        except Exception as e:
-            PrettyOutput.print(f"Failed to parse response: {str(e)}", OutputType.ERROR)
-            return []
-
-    def _generate_query_variants(self, query: str) -> List[str]:
-        """Generate different expressions of the query
-
-        Args:
-            query: Original query
-
-        Returns:
-            List[str]: The query variants list
-        """
-        model = PlatformRegistry.get_global_platform_registry().get_normal_platform()
-        prompt = f"""Please generate 3 different expressions based on the following query, each expression should fully convey the meaning of the original query. These expressions will be used for code search, maintain professionalism and accuracy.
-Original query: {query}
-
-Please output 3 expressions directly, separated by two line breaks, without numbering or other markers.
-"""
-        variants = model.chat_until_success(prompt).strip().split('\n\n')
-        variants.append(query) # Add the original query
-        return variants
-
-    def _vector_search(self, query_variants: List[str], top_k: int) -> Dict[str, Tuple[str, float, str]]:
-        """Use vector search to find related files
-
-        Args:
-            query_variants: The query variants list
-            top_k: The number of results to return
-
-        Returns:
-            Dict[str, Tuple[str, float, str]]: The mapping from file path to (file path, score, description)
-        """
-        results = {}
-        for query in query_variants:
-            query_vector = self.get_embedding(query)
-            query_vector = query_vector.reshape(1, -1)
-
-            distances, indices = self.index.search(query_vector, top_k) # type: ignore
-
-            for i, distance in zip(indices[0], distances[0]):
-                if i == -1:
-                    continue
-
-                similarity = 1.0 / (1.0 + float(distance))
-                file_path = self.file_paths[i]
-                # Use the highest similarity score
-                if file_path not in results:
-                    if similarity > 0.5:
-                        data = self.vector_cache[file_path]
-                        results[file_path] = (file_path, similarity, data["description"])
-
-        return results
-
-
-    def search_similar(self, query: str, top_k: int = 30) -> List[str]:
-        """Search related files"""
-        try:
-            if self.index is None:
-                return []
-            # Generate the query variants
-            query_variants = self._generate_query_variants(query)
-
-            # Perform vector search
-            vector_results = self._vector_search(query_variants, top_k)
-
-            results = list(vector_results.values())
-            results.sort(key=lambda x: x[1], reverse=True)
-
-            # Take the top top_k results for reordering
-            initial_results = results[:top_k]
-
-            # If no results are found, return directly
-            if not initial_results:
-                return []
-
-            # Filter low-scoring results
-            initial_results = [(path, score, desc) for path, score, desc in initial_results if score >= 0.5]
-
-            for path, score, desc in initial_results:
-                PrettyOutput.print(f"File: {path} Similarity: {score:.3f}", output_type=OutputType.INFO)
-
-            # Reorder the preliminary results
-            return self.pick_results(query, [path for path, _, _ in initial_results])
-
-        except Exception as e:
-            PrettyOutput.print(f"Failed to search: {str(e)}", output_type=OutputType.ERROR)
-            return []
-
-    def ask_codebase(self, query: str, top_k: int=20) -> str:
-        """Query the codebase"""
-        results = self.search_similar(query, top_k)
-        if not results:
-            PrettyOutput.print("No related files found", output_type=OutputType.WARNING)
-            return ""
-
-        PrettyOutput.print(f"Found related files: ", output_type=OutputType.SUCCESS)
-        for path in results:
-            PrettyOutput.print(f"File: {path}",
-                               output_type=OutputType.INFO)
-
-        prompt = f"""You are a code expert, please answer the user's question based on the following file information:
-"""
-        for path in results:
-            try:
-                if len(prompt) > self.max_context_length:
-                    PrettyOutput.print(f"Avoid context overflow, discard low-related file: {path}", OutputType.WARNING)
-                    continue
-                content = open(path, "r", encoding="utf-8").read()
-                prompt += f"""
-File path: {path}
-File content:
-{content}
-========================================
-"""
-            except Exception as e:
-                PrettyOutput.print(f"Failed to read file {path}: {str(e)}",
-                                   output_type=OutputType.ERROR)
-                continue
-
-        prompt += f"""
-User question: {query}
-
-Please answer the user's question in Chinese using professional language. If the provided file content is insufficient to answer the user's question, please inform the user. Never make up information.
-"""
-        model = PlatformRegistry.get_global_platform_registry().get_codegen_platform()
-        response = model.chat_until_success(prompt)
-        return response
-
-    def is_index_generated(self) -> bool:
-        """Check if the index has been generated"""
-        try:
-            # 1. Check the basic conditions
-            if not self.vector_cache or not self.file_paths:
-                return False
-
-            if not hasattr(self, 'index') or self.index is None:
-                return False
-
-            # 2. Check whether the index is usable
-            # Create a test vector
-            test_vector = np.zeros((1, self.vector_dim), dtype=np.float32) # type: ignore
-            try:
-                self.index.search(test_vector, 1) # type: ignore
-            except Exception:
-                return False
-
-            # 3. Verify consistency between the vector cache and the file paths
-            if len(self.vector_cache) != len(self.file_paths):
-                return False
-
-            # 4. Verify all cache files
-            for file_path in self.file_paths:
-                if file_path not in self.vector_cache:
-                    return False
-
-                cache_path = self._get_cache_path(file_path)
-                if not os.path.exists(cache_path):
-                    return False
-
-                cache_data = self.vector_cache[file_path]
-                if not isinstance(cache_data.get("vector"), np.ndarray):
-                    return False
-
-            return True
-
-        except Exception as e:
-            PrettyOutput.print(f"Error checking index status: {str(e)}",
-                               output_type=OutputType.ERROR)
-            return False
-
-
-
-
-
-def main():
-
-    parser = argparse.ArgumentParser(description='Codebase management and search tool')
-    subparsers = parser.add_subparsers(dest='command', help='Available commands')
-
-    # Generate command
-    generate_parser = subparsers.add_parser('generate', help='Generate codebase index')
-    generate_parser.add_argument('--force', action='store_true', help='Force rebuild index')
-
-    # Search command
-    search_parser = subparsers.add_parser('search', help='Search similar code files')
-    search_parser.add_argument('query', type=str, help='Search query')
-    search_parser.add_argument('--top-k', type=int, default=20, help='Number of results to return (default: 20)')
-
-    # Ask command
-    ask_parser = subparsers.add_parser('ask', help='Ask a question about the codebase')
-    ask_parser.add_argument('question', type=str, help='Question to ask')
-    ask_parser.add_argument('--top-k', type=int, default=20, help='Number of results to use (default: 20)')
-
-    export_parser = subparsers.add_parser('export', help='Export current index data')
-    args = parser.parse_args()
-
-    current_dir = find_git_root()
-    codebase = CodeBase(current_dir)
-
-    if args.command == 'export':
-        codebase.export()
-        return
-
-    # If the index has not been generated and this is not the generate command, prompt the user to generate it first
-    if not codebase.is_index_generated() and args.command != 'generate':
-        PrettyOutput.print("The index has not been generated yet, please run the 'generate' command first", output_type=OutputType.WARNING)
-        return
-
-    if args.command == 'generate':
-        try:
-            codebase.generate_codebase(force=args.force)
-            PrettyOutput.print("\nCodebase generation completed", output_type=OutputType.SUCCESS)
-        except Exception as e:
-            PrettyOutput.print(f"Error during codebase generation: {str(e)}", output_type=OutputType.ERROR)
-
-    elif args.command == 'search':
-        results = codebase.search_similar(args.query, args.top_k)
-        if not results:
-            PrettyOutput.print("No similar files found", output_type=OutputType.WARNING)
-            return
-
-        PrettyOutput.print("\nSearch Results:", output_type=OutputType.INFO)
-        for path in results:
-            PrettyOutput.print("\n" + "="*50, output_type=OutputType.INFO)
-            PrettyOutput.print(f"File: {path}", output_type=OutputType.INFO)
-
-    elif args.command == 'ask':
-        response = codebase.ask_codebase(args.question, args.top_k)
-        PrettyOutput.print("\nAnswer:", output_type=OutputType.INFO)
-        PrettyOutput.print(response, output_type=OutputType.INFO)
-
-    else:
-        parser.print_help()
-
-
-if __name__ == "__main__":
-    exit(main())
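
For context on the removed functionality: the deleted CodeBase built its search index by wrapping a FAISS HNSW index in an IndexIDMap, so each file embedding kept a stable integer id that mapped back into self.file_paths. Below is a minimal, self-contained sketch of that pattern; the dimension and the random vectors are illustrative placeholders, not values from the package.

import numpy as np
import faiss

dim = 8  # the deleted code used the embedding model's output dimension instead
vectors = np.random.rand(5, dim).astype(np.float32)

hnsw = faiss.IndexHNSWFlat(dim, 16)  # HNSW graph, 16 neighbors per node, as in build_index
hnsw.hnsw.efConstruction = 40        # same construction/search breadth the deleted code set
hnsw.hnsw.efSearch = 16

index = faiss.IndexIDMap(hnsw)       # allows attaching explicit int64 ids to each vector
index.add_with_ids(vectors, np.arange(len(vectors), dtype=np.int64))

# search returns (L2 distances, ids); unfilled result slots come back as id -1
distances, ids = index.search(vectors[:1], 3)

The deleted _vector_search then converted each L2 distance to a similarity with 1.0 / (1.0 + distance) and kept only hits scoring above 0.5.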