jarvis-ai-assistant 0.1.97__py3-none-any.whl → 0.1.99__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of jarvis-ai-assistant might be problematic. Click here for more details.
- jarvis/__init__.py +1 -1
- jarvis/agent.py +199 -157
- jarvis/jarvis_code_agent/__init__.py +0 -0
- jarvis/jarvis_code_agent/main.py +203 -0
- jarvis/jarvis_codebase/main.py +412 -284
- jarvis/jarvis_coder/file_select.py +209 -0
- jarvis/jarvis_coder/git_utils.py +81 -19
- jarvis/jarvis_coder/main.py +68 -446
- jarvis/jarvis_coder/patch_handler.py +117 -47
- jarvis/jarvis_coder/plan_generator.py +69 -27
- jarvis/jarvis_platform/main.py +38 -38
- jarvis/jarvis_rag/main.py +189 -189
- jarvis/jarvis_smart_shell/main.py +22 -24
- jarvis/models/base.py +6 -1
- jarvis/models/ollama.py +2 -2
- jarvis/models/registry.py +3 -6
- jarvis/tools/ask_user.py +6 -6
- jarvis/tools/codebase_qa.py +5 -7
- jarvis/tools/create_code_sub_agent.py +55 -0
- jarvis/tools/{sub_agent.py → create_sub_agent.py} +4 -1
- jarvis/tools/execute_code_modification.py +72 -0
- jarvis/tools/{file_ops.py → file_operation.py} +13 -14
- jarvis/tools/find_related_files.py +86 -0
- jarvis/tools/methodology.py +25 -25
- jarvis/tools/rag.py +32 -32
- jarvis/tools/registry.py +72 -36
- jarvis/tools/search.py +1 -1
- jarvis/tools/select_code_files.py +64 -0
- jarvis/utils.py +153 -49
- {jarvis_ai_assistant-0.1.97.dist-info → jarvis_ai_assistant-0.1.99.dist-info}/METADATA +1 -1
- jarvis_ai_assistant-0.1.99.dist-info/RECORD +52 -0
- {jarvis_ai_assistant-0.1.97.dist-info → jarvis_ai_assistant-0.1.99.dist-info}/entry_points.txt +2 -1
- jarvis/main.py +0 -155
- jarvis/tools/coder.py +0 -69
- jarvis_ai_assistant-0.1.97.dist-info/RECORD +0 -47
- /jarvis/tools/{shell.py → execute_shell.py} +0 -0
- /jarvis/tools/{generator.py → generate_tool.py} +0 -0
- /jarvis/tools/{webpage.py → read_webpage.py} +0 -0
- {jarvis_ai_assistant-0.1.97.dist-info → jarvis_ai_assistant-0.1.99.dist-info}/LICENSE +0 -0
- {jarvis_ai_assistant-0.1.97.dist-info → jarvis_ai_assistant-0.1.99.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.1.97.dist-info → jarvis_ai_assistant-0.1.99.dist-info}/top_level.txt +0 -0
jarvis/jarvis_codebase/main.py
CHANGED
|
@@ -9,7 +9,7 @@ from jarvis.models.registry import PlatformRegistry
|
|
|
9
9
|
import concurrent.futures
|
|
10
10
|
from threading import Lock
|
|
11
11
|
from concurrent.futures import ThreadPoolExecutor
|
|
12
|
-
from jarvis.utils import OutputType, PrettyOutput, find_git_root, get_file_md5, get_max_context_length, get_thread_count, load_embedding_model, load_rerank_model
|
|
12
|
+
from jarvis.utils import OutputType, PrettyOutput, find_git_root, get_file_md5, get_max_context_length, get_single_line_input, get_thread_count, load_embedding_model, load_rerank_model
|
|
13
13
|
from jarvis.utils import load_env_from_file
|
|
14
14
|
import argparse
|
|
15
15
|
import pickle
|
|
@@ -28,59 +28,37 @@ class CodeBase:
|
|
|
28
28
|
|
|
29
29
|
# 初始化数据目录
|
|
30
30
|
self.data_dir = os.path.join(self.root_dir, ".jarvis-codebase")
|
|
31
|
-
|
|
32
|
-
|
|
31
|
+
self.cache_dir = os.path.join(self.data_dir, "cache")
|
|
32
|
+
if not os.path.exists(self.cache_dir):
|
|
33
|
+
os.makedirs(self.cache_dir)
|
|
33
34
|
|
|
34
|
-
#
|
|
35
|
+
# 初始化嵌入模型
|
|
35
36
|
try:
|
|
36
37
|
self.embedding_model = load_embedding_model()
|
|
37
|
-
|
|
38
|
-
# 强制完全加载所有模型组件
|
|
39
|
-
test_text = """
|
|
40
|
-
这是一段测试文本,用于确保模型完全加载。
|
|
41
|
-
包含多行内容,以模拟实际使用场景。
|
|
42
|
-
"""
|
|
43
|
-
# 预热模型,确保所有组件都被加载
|
|
38
|
+
test_text = """This is a test text"""
|
|
44
39
|
self.embedding_model.encode([test_text],
|
|
45
40
|
convert_to_tensor=True,
|
|
46
41
|
normalize_embeddings=True)
|
|
47
|
-
PrettyOutput.print("
|
|
42
|
+
PrettyOutput.print("Model loaded successfully", output_type=OutputType.SUCCESS)
|
|
48
43
|
except Exception as e:
|
|
49
|
-
PrettyOutput.print(f"
|
|
44
|
+
PrettyOutput.print(f"Failed to load model: {str(e)}", output_type=OutputType.ERROR)
|
|
50
45
|
raise
|
|
51
46
|
|
|
52
47
|
self.vector_dim = self.embedding_model.get_sentence_embedding_dimension()
|
|
53
|
-
|
|
54
48
|
self.git_file_list = self.get_git_file_list()
|
|
55
49
|
self.platform_registry = PlatformRegistry.get_global_platform_registry()
|
|
56
50
|
|
|
57
51
|
# 初始化缓存和索引
|
|
58
|
-
self.cache_path = os.path.join(self.data_dir, "cache.pkl")
|
|
59
52
|
self.vector_cache = {}
|
|
60
53
|
self.file_paths = []
|
|
61
54
|
|
|
62
|
-
#
|
|
63
|
-
|
|
64
|
-
try:
|
|
65
|
-
with lzma.open(self.cache_path, 'rb') as f:
|
|
66
|
-
cache_data = pickle.load(f)
|
|
67
|
-
self.vector_cache = cache_data["vectors"]
|
|
68
|
-
self.file_paths = cache_data["file_paths"]
|
|
69
|
-
PrettyOutput.print(f"加载了 {len(self.vector_cache)} 个向量缓存",
|
|
70
|
-
output_type=OutputType.INFO)
|
|
71
|
-
# 从缓存重建索引
|
|
72
|
-
self.build_index()
|
|
73
|
-
except Exception as e:
|
|
74
|
-
PrettyOutput.print(f"加载缓存失败: {str(e)}",
|
|
75
|
-
output_type=OutputType.WARNING)
|
|
76
|
-
self.vector_cache = {}
|
|
77
|
-
self.file_paths = []
|
|
78
|
-
self.index = None
|
|
55
|
+
# 加载所有缓存文件
|
|
56
|
+
self._load_all_cache()
|
|
79
57
|
|
|
80
58
|
def get_git_file_list(self):
|
|
81
|
-
"""
|
|
59
|
+
"""Get the list of files in the git repository, excluding the .jarvis-codebase directory"""
|
|
82
60
|
files = os.popen("git ls-files").read().splitlines()
|
|
83
|
-
#
|
|
61
|
+
# Filter out files in the .jarvis-codebase directory
|
|
84
62
|
return [f for f in files if not f.startswith(".jarvis-")]
|
|
85
63
|
|
|
86
64
|
def is_text_file(self, file_path: str):
|
|
@@ -95,10 +73,11 @@ class CodeBase:
|
|
|
95
73
|
model = PlatformRegistry.get_global_platform_registry().get_cheap_platform()
|
|
96
74
|
if self.thread_count > 1:
|
|
97
75
|
model.set_suppress_output(True)
|
|
76
|
+
else:
|
|
77
|
+
PrettyOutput.print(f"Make description for {file_path} ...", output_type=OutputType.PROGRESS)
|
|
98
78
|
prompt = f"""Please analyze the following code file and generate a detailed description. The description should include:
|
|
99
|
-
1. Overall file functionality description
|
|
100
|
-
2.
|
|
101
|
-
3. 5 potential questions users might ask about this file
|
|
79
|
+
1. Overall file functionality description
|
|
80
|
+
2. description for each global variable, function, type definition, class, method, and other code elements
|
|
102
81
|
|
|
103
82
|
Please use concise and professional language, emphasizing technical functionality to facilitate subsequent code retrieval.
|
|
104
83
|
File path: {file_path}
|
|
@@ -109,42 +88,117 @@ Code content:
|
|
|
109
88
|
return response
|
|
110
89
|
|
|
111
90
|
def export(self):
|
|
112
|
-
"""
|
|
91
|
+
"""Export the current index data to standard output"""
|
|
113
92
|
for file_path, data in self.vector_cache.items():
|
|
114
93
|
print(f"## {file_path}")
|
|
115
94
|
print(f"- path: {file_path}")
|
|
116
95
|
print(f"- description: {data['description']}")
|
|
117
96
|
|
|
118
|
-
def
|
|
119
|
-
"""
|
|
97
|
+
def _get_cache_path(self, file_path: str) -> str:
|
|
98
|
+
"""Get cache file path for a source file
|
|
99
|
+
|
|
100
|
+
Args:
|
|
101
|
+
file_path: Source file path
|
|
102
|
+
|
|
103
|
+
Returns:
|
|
104
|
+
str: Cache file path
|
|
105
|
+
"""
|
|
106
|
+
# 处理文件路径:
|
|
107
|
+
# 1. 移除开头的 ./ 或 /
|
|
108
|
+
# 2. 将 / 替换为 --
|
|
109
|
+
# 3. 添加 .cache 后缀
|
|
110
|
+
clean_path = file_path.lstrip('./').lstrip('/')
|
|
111
|
+
cache_name = clean_path.replace('/', '--') + '.cache'
|
|
112
|
+
return os.path.join(self.cache_dir, cache_name)
|
|
113
|
+
|
|
114
|
+
def _load_all_cache(self):
|
|
115
|
+
"""Load all cache files"""
|
|
120
116
|
try:
|
|
121
|
-
#
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
117
|
+
# 清空现有缓存和文件路径
|
|
118
|
+
self.vector_cache = {}
|
|
119
|
+
self.file_paths = []
|
|
120
|
+
vectors = []
|
|
121
|
+
|
|
122
|
+
for cache_file in os.listdir(self.cache_dir):
|
|
123
|
+
if not cache_file.endswith('.cache'):
|
|
124
|
+
continue
|
|
125
|
+
|
|
126
|
+
cache_path = os.path.join(self.cache_dir, cache_file)
|
|
127
|
+
try:
|
|
128
|
+
with lzma.open(cache_path, 'rb') as f:
|
|
129
|
+
cache_data = pickle.load(f)
|
|
130
|
+
file_path = cache_data["path"]
|
|
131
|
+
self.vector_cache[file_path] = cache_data
|
|
132
|
+
self.file_paths.append(file_path)
|
|
133
|
+
vectors.append(cache_data["vector"])
|
|
134
|
+
except Exception as e:
|
|
135
|
+
PrettyOutput.print(f"Failed to load cache file {cache_file}: {str(e)}",
|
|
136
|
+
output_type=OutputType.WARNING)
|
|
137
|
+
continue
|
|
126
138
|
|
|
127
|
-
|
|
128
|
-
|
|
139
|
+
if vectors:
|
|
140
|
+
# 重建索引
|
|
141
|
+
vectors_array = np.vstack(vectors)
|
|
142
|
+
hnsw_index = faiss.IndexHNSWFlat(self.vector_dim, 16)
|
|
143
|
+
hnsw_index.hnsw.efConstruction = 40
|
|
144
|
+
hnsw_index.hnsw.efSearch = 16
|
|
145
|
+
self.index = faiss.IndexIDMap(hnsw_index)
|
|
146
|
+
self.index.add_with_ids(vectors_array, np.array(range(len(vectors)))) # type: ignore
|
|
147
|
+
|
|
148
|
+
PrettyOutput.print(f"Loaded {len(self.vector_cache)} vector cache and rebuilt index",
|
|
149
|
+
output_type=OutputType.INFO)
|
|
150
|
+
else:
|
|
151
|
+
self.index = None
|
|
152
|
+
PrettyOutput.print("No valid cache files found", output_type=OutputType.WARNING)
|
|
153
|
+
|
|
154
|
+
except Exception as e:
|
|
155
|
+
PrettyOutput.print(f"Failed to load cache directory: {str(e)}",
|
|
156
|
+
output_type=OutputType.WARNING)
|
|
157
|
+
self.vector_cache = {}
|
|
158
|
+
self.file_paths = []
|
|
159
|
+
self.index = None
|
|
160
|
+
|
|
161
|
+
def cache_vector(self, file_path: str, vector: np.ndarray, description: str):
|
|
162
|
+
"""Cache the vector representation of a file"""
|
|
163
|
+
try:
|
|
164
|
+
with open(file_path, "rb") as f:
|
|
165
|
+
file_md5 = hashlib.md5(f.read()).hexdigest()
|
|
166
|
+
except Exception as e:
|
|
167
|
+
PrettyOutput.print(f"Failed to calculate MD5 for {file_path}: {str(e)}",
|
|
168
|
+
output_type=OutputType.ERROR)
|
|
169
|
+
file_md5 = ""
|
|
170
|
+
|
|
171
|
+
# 准备缓存数据
|
|
172
|
+
cache_data = {
|
|
173
|
+
"path": file_path, # 保存文件路径
|
|
174
|
+
"md5": file_md5, # 保存文件MD5
|
|
175
|
+
"description": description, # 保存文件描述
|
|
176
|
+
"vector": vector # 保存向量
|
|
177
|
+
}
|
|
178
|
+
|
|
179
|
+
# 更新内存缓存
|
|
180
|
+
self.vector_cache[file_path] = cache_data
|
|
181
|
+
|
|
182
|
+
# 保存到单独的缓存文件
|
|
183
|
+
cache_path = self._get_cache_path(file_path)
|
|
184
|
+
try:
|
|
185
|
+
with lzma.open(cache_path, 'wb') as f:
|
|
129
186
|
pickle.dump(cache_data, f, protocol=pickle.HIGHEST_PROTOCOL)
|
|
130
|
-
PrettyOutput.print(f"保存了 {len(self.vector_cache)} 个向量缓存",
|
|
131
|
-
output_type=OutputType.INFO)
|
|
132
187
|
except Exception as e:
|
|
133
|
-
PrettyOutput.print(f"
|
|
188
|
+
PrettyOutput.print(f"Failed to save cache for {file_path}: {str(e)}",
|
|
134
189
|
output_type=OutputType.ERROR)
|
|
135
|
-
raise # 抛出异常以便上层处理
|
|
136
190
|
|
|
137
191
|
def get_cached_vector(self, file_path: str, description: str) -> Optional[np.ndarray]:
|
|
138
|
-
"""
|
|
192
|
+
"""Get the vector representation of a file from the cache"""
|
|
139
193
|
if file_path not in self.vector_cache:
|
|
140
194
|
return None
|
|
141
195
|
|
|
142
|
-
#
|
|
196
|
+
# Check if the file has been modified
|
|
143
197
|
try:
|
|
144
198
|
with open(file_path, "rb") as f:
|
|
145
199
|
current_md5 = hashlib.md5(f.read()).hexdigest()
|
|
146
200
|
except Exception as e:
|
|
147
|
-
PrettyOutput.print(f"
|
|
201
|
+
PrettyOutput.print(f"Failed to calculate MD5 for {file_path}: {str(e)}",
|
|
148
202
|
output_type=OutputType.ERROR)
|
|
149
203
|
return None
|
|
150
204
|
|
|
@@ -152,63 +206,45 @@ Code content:
|
|
|
152
206
|
if cached_data["md5"] != current_md5:
|
|
153
207
|
return None
|
|
154
208
|
|
|
155
|
-
#
|
|
209
|
+
# Check if the description has changed
|
|
156
210
|
if cached_data["description"] != description:
|
|
157
211
|
return None
|
|
158
212
|
|
|
159
213
|
return cached_data["vector"]
|
|
160
214
|
|
|
161
|
-
def cache_vector(self, file_path: str, vector: np.ndarray, description: str):
|
|
162
|
-
"""缓存文件的向量表示"""
|
|
163
|
-
try:
|
|
164
|
-
with open(file_path, "rb") as f:
|
|
165
|
-
file_md5 = hashlib.md5(f.read()).hexdigest()
|
|
166
|
-
except Exception as e:
|
|
167
|
-
PrettyOutput.print(f"计算文件MD5失败 {file_path}: {str(e)}",
|
|
168
|
-
output_type=OutputType.ERROR)
|
|
169
|
-
file_md5 = ""
|
|
170
|
-
|
|
171
|
-
# 只更新内存中的缓存
|
|
172
|
-
self.vector_cache[file_path] = {
|
|
173
|
-
"path": file_path, # 保存文件路径
|
|
174
|
-
"md5": file_md5, # 保存文件MD5
|
|
175
|
-
"description": description, # 保存文件描述
|
|
176
|
-
"vector": vector # 保存向量
|
|
177
|
-
}
|
|
178
|
-
|
|
179
215
|
def get_embedding(self, text: str) -> np.ndarray:
|
|
180
|
-
"""
|
|
181
|
-
#
|
|
182
|
-
max_length = 512 #
|
|
216
|
+
"""Use the transformers model to get the vector representation of text"""
|
|
217
|
+
# Truncate long text
|
|
218
|
+
max_length = 512 # Or other suitable length
|
|
183
219
|
text = ' '.join(text.split()[:max_length])
|
|
184
220
|
|
|
185
|
-
#
|
|
221
|
+
# Get the embedding vector
|
|
186
222
|
embedding = self.embedding_model.encode(text,
|
|
187
|
-
normalize_embeddings=True, # L2
|
|
223
|
+
normalize_embeddings=True, # L2 normalization
|
|
188
224
|
show_progress_bar=False)
|
|
189
225
|
vector = np.array(embedding, dtype=np.float32)
|
|
190
226
|
return vector
|
|
191
227
|
|
|
192
228
|
def vectorize_file(self, file_path: str, description: str) -> np.ndarray:
|
|
193
|
-
"""
|
|
229
|
+
"""Vectorize the file content and description"""
|
|
194
230
|
try:
|
|
195
|
-
#
|
|
231
|
+
# Try to get the vector from the cache first
|
|
196
232
|
cached_vector = self.get_cached_vector(file_path, description)
|
|
197
233
|
if cached_vector is not None:
|
|
198
234
|
return cached_vector
|
|
199
235
|
|
|
200
|
-
#
|
|
201
|
-
content = open(file_path, "r", encoding="utf-8").read()[:self.max_context_length] #
|
|
236
|
+
# Read the file content and combine information
|
|
237
|
+
content = open(file_path, "r", encoding="utf-8").read()[:self.max_context_length] # Limit the file content length
|
|
202
238
|
|
|
203
|
-
#
|
|
239
|
+
# Combine file information, including file content
|
|
204
240
|
combined_text = f"""
|
|
205
|
-
{file_path}
|
|
206
|
-
{description}
|
|
207
|
-
{content}
|
|
241
|
+
File path: {file_path}
|
|
242
|
+
Description: {description}
|
|
243
|
+
Content: {content}
|
|
208
244
|
"""
|
|
209
245
|
vector = self.get_embedding(combined_text)
|
|
210
246
|
|
|
211
|
-
#
|
|
247
|
+
# Save to cache
|
|
212
248
|
self.cache_vector(file_path, vector, description)
|
|
213
249
|
return vector
|
|
214
250
|
except Exception as e:
|
|
@@ -217,36 +253,34 @@ Code content:
|
|
|
217
253
|
return np.zeros(self.vector_dim, dtype=np.float32) # type: ignore
|
|
218
254
|
|
|
219
255
|
def clean_cache(self) -> bool:
|
|
220
|
-
"""
|
|
256
|
+
"""Clean expired cache records"""
|
|
221
257
|
try:
|
|
222
258
|
files_to_delete = []
|
|
223
259
|
for file_path in list(self.vector_cache.keys()):
|
|
224
|
-
if
|
|
225
|
-
del self.vector_cache[file_path]
|
|
260
|
+
if not os.path.exists(file_path):
|
|
226
261
|
files_to_delete.append(file_path)
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
262
|
+
cache_path = self._get_cache_path(file_path)
|
|
263
|
+
try:
|
|
264
|
+
os.remove(cache_path)
|
|
265
|
+
except Exception:
|
|
266
|
+
pass
|
|
267
|
+
|
|
268
|
+
for file_path in files_to_delete:
|
|
269
|
+
del self.vector_cache[file_path]
|
|
270
|
+
if file_path in self.file_paths:
|
|
271
|
+
self.file_paths.remove(file_path)
|
|
272
|
+
|
|
273
|
+
return bool(files_to_delete)
|
|
235
274
|
|
|
236
275
|
except Exception as e:
|
|
237
|
-
PrettyOutput.print(f"
|
|
238
|
-
|
|
239
|
-
# 发生异常时尝试保存当前状态
|
|
240
|
-
try:
|
|
241
|
-
self._save_cache()
|
|
242
|
-
except:
|
|
243
|
-
pass
|
|
276
|
+
PrettyOutput.print(f"Failed to clean cache: {str(e)}",
|
|
277
|
+
output_type=OutputType.ERROR)
|
|
244
278
|
return False
|
|
245
279
|
|
|
246
280
|
def process_file(self, file_path: str):
|
|
247
|
-
"""
|
|
281
|
+
"""Process a single file"""
|
|
248
282
|
try:
|
|
249
|
-
#
|
|
283
|
+
# Skip non-existent files
|
|
250
284
|
if not os.path.exists(file_path):
|
|
251
285
|
return None
|
|
252
286
|
|
|
@@ -257,15 +291,15 @@ Code content:
|
|
|
257
291
|
|
|
258
292
|
content = open(file_path, "r", encoding="utf-8").read()
|
|
259
293
|
|
|
260
|
-
#
|
|
294
|
+
# Check if the file has already been processed and the content has not changed
|
|
261
295
|
if file_path in self.vector_cache:
|
|
262
296
|
if self.vector_cache[file_path].get("md5") == md5:
|
|
263
297
|
return None
|
|
264
298
|
|
|
265
|
-
description = self.make_description(file_path, content) #
|
|
299
|
+
description = self.make_description(file_path, content) # Pass the truncated content
|
|
266
300
|
vector = self.vectorize_file(file_path, description)
|
|
267
301
|
|
|
268
|
-
#
|
|
302
|
+
# Save to cache, using the actual file path as the key
|
|
269
303
|
self.vector_cache[file_path] = {
|
|
270
304
|
"vector": vector,
|
|
271
305
|
"description": description,
|
|
@@ -275,58 +309,94 @@ Code content:
|
|
|
275
309
|
return file_path
|
|
276
310
|
|
|
277
311
|
except Exception as e:
|
|
278
|
-
PrettyOutput.print(f"
|
|
312
|
+
PrettyOutput.print(f"Failed to process file {file_path}: {str(e)}",
|
|
279
313
|
output_type=OutputType.ERROR)
|
|
280
314
|
return None
|
|
281
315
|
|
|
282
316
|
def build_index(self):
|
|
283
|
-
"""
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
vectors
|
|
298
|
-
ids
|
|
299
|
-
self.file_paths
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
317
|
+
"""Build a faiss index from the vector cache"""
|
|
318
|
+
try:
|
|
319
|
+
if not self.vector_cache:
|
|
320
|
+
self.index = None
|
|
321
|
+
return
|
|
322
|
+
|
|
323
|
+
# Create the underlying HNSW index
|
|
324
|
+
hnsw_index = faiss.IndexHNSWFlat(self.vector_dim, 16)
|
|
325
|
+
hnsw_index.hnsw.efConstruction = 40
|
|
326
|
+
hnsw_index.hnsw.efSearch = 16
|
|
327
|
+
|
|
328
|
+
# Wrap the HNSW index with IndexIDMap
|
|
329
|
+
self.index = faiss.IndexIDMap(hnsw_index)
|
|
330
|
+
|
|
331
|
+
vectors = []
|
|
332
|
+
ids = []
|
|
333
|
+
self.file_paths = [] # Reset the file path list
|
|
334
|
+
|
|
335
|
+
for i, (file_path, data) in enumerate(self.vector_cache.items()):
|
|
336
|
+
if "vector" not in data:
|
|
337
|
+
PrettyOutput.print(f"Invalid cache data for {file_path}: missing vector",
|
|
338
|
+
output_type=OutputType.WARNING)
|
|
339
|
+
continue
|
|
340
|
+
|
|
341
|
+
vector = data["vector"]
|
|
342
|
+
if not isinstance(vector, np.ndarray):
|
|
343
|
+
PrettyOutput.print(f"Invalid vector type for {file_path}: {type(vector)}",
|
|
344
|
+
output_type=OutputType.WARNING)
|
|
345
|
+
continue
|
|
346
|
+
|
|
347
|
+
vectors.append(vector.reshape(1, -1))
|
|
348
|
+
ids.append(i)
|
|
349
|
+
self.file_paths.append(file_path)
|
|
350
|
+
|
|
351
|
+
if vectors:
|
|
352
|
+
vectors = np.vstack(vectors)
|
|
353
|
+
if len(vectors) != len(ids):
|
|
354
|
+
PrettyOutput.print(f"Vector count mismatch: {len(vectors)} vectors vs {len(ids)} ids",
|
|
355
|
+
output_type=OutputType.ERROR)
|
|
356
|
+
self.index = None
|
|
357
|
+
return
|
|
358
|
+
|
|
359
|
+
try:
|
|
360
|
+
self.index.add_with_ids(vectors, np.array(ids)) # type: ignore
|
|
361
|
+
PrettyOutput.print(f"Successfully built index with {len(vectors)} vectors",
|
|
362
|
+
output_type=OutputType.SUCCESS)
|
|
363
|
+
except Exception as e:
|
|
364
|
+
PrettyOutput.print(f"Failed to add vectors to index: {str(e)}",
|
|
365
|
+
output_type=OutputType.ERROR)
|
|
366
|
+
self.index = None
|
|
367
|
+
else:
|
|
368
|
+
PrettyOutput.print("No valid vectors found, index not built",
|
|
369
|
+
output_type=OutputType.WARNING)
|
|
370
|
+
self.index = None
|
|
371
|
+
|
|
372
|
+
except Exception as e:
|
|
373
|
+
PrettyOutput.print(f"Failed to build index: {str(e)}",
|
|
374
|
+
output_type=OutputType.ERROR)
|
|
305
375
|
self.index = None
|
|
306
376
|
|
|
307
377
|
def gen_vector_db_from_cache(self):
|
|
308
|
-
"""
|
|
378
|
+
"""Generate a vector database from the cache"""
|
|
309
379
|
self.build_index()
|
|
310
|
-
self.
|
|
380
|
+
self._load_all_cache()
|
|
311
381
|
|
|
312
382
|
|
|
313
383
|
def generate_codebase(self, force: bool = False):
|
|
314
|
-
"""
|
|
384
|
+
"""Generate the codebase index
|
|
315
385
|
Args:
|
|
316
|
-
force:
|
|
386
|
+
force: Whether to force rebuild the index, without asking the user
|
|
317
387
|
"""
|
|
318
388
|
try:
|
|
319
|
-
#
|
|
389
|
+
# Update the git file list
|
|
320
390
|
self.git_file_list = self.get_git_file_list()
|
|
321
391
|
|
|
322
|
-
#
|
|
323
|
-
PrettyOutput.print("\
|
|
392
|
+
# Check file changes
|
|
393
|
+
PrettyOutput.print("\nCheck file changes...", output_type=OutputType.INFO)
|
|
324
394
|
changes_detected = False
|
|
325
395
|
new_files = []
|
|
326
396
|
modified_files = []
|
|
327
397
|
deleted_files = []
|
|
328
398
|
|
|
329
|
-
#
|
|
399
|
+
# Check deleted files
|
|
330
400
|
files_to_delete = []
|
|
331
401
|
for file_path in list(self.vector_cache.keys()):
|
|
332
402
|
if file_path not in self.git_file_list:
|
|
@@ -334,8 +404,8 @@ Code content:
|
|
|
334
404
|
files_to_delete.append(file_path)
|
|
335
405
|
changes_detected = True
|
|
336
406
|
|
|
337
|
-
#
|
|
338
|
-
with tqdm(total=len(self.git_file_list), desc="
|
|
407
|
+
# Check new and modified files
|
|
408
|
+
with tqdm(total=len(self.git_file_list), desc="Check file status") as pbar:
|
|
339
409
|
for file_path in self.git_file_list:
|
|
340
410
|
if not os.path.exists(file_path) or not self.is_text_file(file_path):
|
|
341
411
|
pbar.update(1)
|
|
@@ -351,60 +421,60 @@ Code content:
|
|
|
351
421
|
modified_files.append(file_path)
|
|
352
422
|
changes_detected = True
|
|
353
423
|
except Exception as e:
|
|
354
|
-
PrettyOutput.print(f"
|
|
424
|
+
PrettyOutput.print(f"Failed to check file {file_path}: {str(e)}",
|
|
355
425
|
output_type=OutputType.ERROR)
|
|
356
426
|
pbar.update(1)
|
|
357
427
|
|
|
358
|
-
#
|
|
428
|
+
# If changes are detected, display changes and ask the user
|
|
359
429
|
if changes_detected:
|
|
360
|
-
PrettyOutput.print("\
|
|
430
|
+
PrettyOutput.print("\nDetected the following changes:", output_type=OutputType.WARNING)
|
|
361
431
|
if new_files:
|
|
362
|
-
PrettyOutput.print("\
|
|
432
|
+
PrettyOutput.print("\nNew files:", output_type=OutputType.INFO)
|
|
363
433
|
for f in new_files:
|
|
364
434
|
PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
|
|
365
435
|
if modified_files:
|
|
366
|
-
PrettyOutput.print("\
|
|
436
|
+
PrettyOutput.print("\nModified files:", output_type=OutputType.INFO)
|
|
367
437
|
for f in modified_files:
|
|
368
438
|
PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
|
|
369
439
|
if deleted_files:
|
|
370
|
-
PrettyOutput.print("\
|
|
440
|
+
PrettyOutput.print("\nDeleted files:", output_type=OutputType.INFO)
|
|
371
441
|
for f in deleted_files:
|
|
372
442
|
PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
|
|
373
443
|
|
|
374
|
-
#
|
|
444
|
+
# If force is True, continue directly
|
|
375
445
|
if not force:
|
|
376
|
-
#
|
|
446
|
+
# Ask the user whether to continue
|
|
377
447
|
while True:
|
|
378
|
-
response =
|
|
448
|
+
response = get_single_line_input("\nRebuild the index? [y/N]").lower().strip()
|
|
379
449
|
if response in ['y', 'yes']:
|
|
380
450
|
break
|
|
381
451
|
elif response in ['', 'n', 'no']:
|
|
382
|
-
PrettyOutput.print("
|
|
452
|
+
PrettyOutput.print("Cancel rebuilding the index", output_type=OutputType.INFO)
|
|
383
453
|
return
|
|
384
454
|
else:
|
|
385
|
-
PrettyOutput.print("
|
|
455
|
+
PrettyOutput.print("Please input y or n", output_type=OutputType.WARNING)
|
|
386
456
|
|
|
387
|
-
#
|
|
457
|
+
# Clean deleted files
|
|
388
458
|
for file_path in files_to_delete:
|
|
389
459
|
del self.vector_cache[file_path]
|
|
390
460
|
if files_to_delete:
|
|
391
|
-
PrettyOutput.print(f"
|
|
461
|
+
PrettyOutput.print(f"Cleaned the cache of {len(files_to_delete)} files",
|
|
392
462
|
output_type=OutputType.INFO)
|
|
393
463
|
|
|
394
|
-
#
|
|
464
|
+
# Process new and modified files
|
|
395
465
|
files_to_process = new_files + modified_files
|
|
396
466
|
processed_files = []
|
|
397
467
|
|
|
398
|
-
with tqdm(total=len(files_to_process), desc="
|
|
399
|
-
#
|
|
468
|
+
with tqdm(total=len(files_to_process), desc="Processing files") as pbar:
|
|
469
|
+
# Use a thread pool to process files
|
|
400
470
|
with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
|
|
401
|
-
#
|
|
471
|
+
# Submit all tasks
|
|
402
472
|
future_to_file = {
|
|
403
473
|
executor.submit(self.process_file, file): file
|
|
404
474
|
for file in files_to_process
|
|
405
475
|
}
|
|
406
476
|
|
|
407
|
-
#
|
|
477
|
+
# Process completed tasks
|
|
408
478
|
for future in concurrent.futures.as_completed(future_to_file):
|
|
409
479
|
file = future_to_file[future]
|
|
410
480
|
try:
|
|
@@ -412,37 +482,37 @@ Code content:
|
|
|
412
482
|
if result:
|
|
413
483
|
processed_files.append(result)
|
|
414
484
|
except Exception as e:
|
|
415
|
-
PrettyOutput.print(f"
|
|
485
|
+
PrettyOutput.print(f"Failed to process file {file}: {str(e)}",
|
|
416
486
|
output_type=OutputType.ERROR)
|
|
417
487
|
pbar.update(1)
|
|
418
488
|
|
|
419
489
|
if processed_files:
|
|
420
|
-
PrettyOutput.print("\
|
|
490
|
+
PrettyOutput.print("\nRebuilding the vector database...", output_type=OutputType.INFO)
|
|
421
491
|
self.gen_vector_db_from_cache()
|
|
422
|
-
PrettyOutput.print(f"
|
|
492
|
+
PrettyOutput.print(f"Successfully generated the index for {len(processed_files)} files",
|
|
423
493
|
output_type=OutputType.SUCCESS)
|
|
424
494
|
else:
|
|
425
|
-
PrettyOutput.print("
|
|
495
|
+
PrettyOutput.print("No file changes detected, no need to rebuild the index", output_type=OutputType.INFO)
|
|
426
496
|
|
|
427
497
|
except Exception as e:
|
|
428
|
-
#
|
|
498
|
+
# Try to save the cache when an exception occurs
|
|
429
499
|
try:
|
|
430
|
-
self.
|
|
500
|
+
self._load_all_cache()
|
|
431
501
|
except Exception as save_error:
|
|
432
|
-
PrettyOutput.print(f"
|
|
502
|
+
PrettyOutput.print(f"Failed to save cache: {str(save_error)}",
|
|
433
503
|
output_type=OutputType.ERROR)
|
|
434
|
-
raise e #
|
|
504
|
+
raise e # Re-raise the original exception
|
|
435
505
|
|
|
436
506
|
|
|
437
507
|
def _text_search_score(self, content: str, keywords: List[str]) -> float:
|
|
438
|
-
"""
|
|
508
|
+
"""Calculate the matching score between the text content and the keywords
|
|
439
509
|
|
|
440
510
|
Args:
|
|
441
|
-
content:
|
|
442
|
-
keywords:
|
|
511
|
+
content: Text content
|
|
512
|
+
keywords: List of keywords
|
|
443
513
|
|
|
444
514
|
Returns:
|
|
445
|
-
float:
|
|
515
|
+
float: Matching score (0-1)
|
|
446
516
|
"""
|
|
447
517
|
if not keywords:
|
|
448
518
|
return 0.0
|
|
@@ -455,89 +525,128 @@ Code content:
|
|
|
455
525
|
if keyword in content:
|
|
456
526
|
matched_keywords.add(keyword)
|
|
457
527
|
|
|
458
|
-
#
|
|
528
|
+
# Calculate the matching score
|
|
459
529
|
score = len(matched_keywords) / len(keywords)
|
|
460
530
|
return score
|
|
461
531
|
|
|
462
|
-
def
|
|
463
|
-
"""
|
|
532
|
+
def pick_results(self, query: str, initial_results: List[str]) -> List[str]:
|
|
533
|
+
"""Use a large model to pick the search results
|
|
534
|
+
|
|
535
|
+
Args:
|
|
536
|
+
query: Search query
|
|
537
|
+
initial_results: Initial results list of file paths
|
|
538
|
+
|
|
539
|
+
Returns:
|
|
540
|
+
List[str]: The picked results list, each item is a file path
|
|
541
|
+
"""
|
|
464
542
|
if not initial_results:
|
|
465
543
|
return []
|
|
466
544
|
|
|
467
545
|
try:
|
|
468
|
-
|
|
546
|
+
PrettyOutput.print(f"Picking results for query: {query}", output_type=OutputType.INFO)
|
|
469
547
|
|
|
470
|
-
#
|
|
471
|
-
|
|
548
|
+
# Maximum content length per batch
|
|
549
|
+
max_batch_length = self.max_context_length - 1000 # Reserve space for prompt
|
|
550
|
+
max_file_length = max_batch_length // 3 # Limit individual file size
|
|
472
551
|
|
|
473
|
-
#
|
|
474
|
-
|
|
552
|
+
# Process files in batches
|
|
553
|
+
all_selected_files = set()
|
|
554
|
+
current_batch = []
|
|
555
|
+
current_length = 0
|
|
475
556
|
|
|
476
|
-
for path
|
|
557
|
+
for path in initial_results:
|
|
477
558
|
try:
|
|
478
|
-
content = open(path, "r", encoding="utf-8").read()
|
|
559
|
+
content = open(path, "r", encoding="utf-8").read()
|
|
560
|
+
# Truncate large files
|
|
561
|
+
if len(content) > max_file_length:
|
|
562
|
+
PrettyOutput.print(f"Truncating large file: {path}", OutputType.WARNING)
|
|
563
|
+
content = content[:max_file_length] + "\n... (content truncated)"
|
|
479
564
|
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
pairs.append([query, doc_content])
|
|
483
|
-
except Exception as e:
|
|
484
|
-
PrettyOutput.print(f"读取文件失败 {path}: {str(e)}",
|
|
485
|
-
output_type=OutputType.ERROR)
|
|
486
|
-
doc_content = f"File path: {path}\nDescription: {desc}"
|
|
487
|
-
pairs.append([query, doc_content])
|
|
488
|
-
|
|
489
|
-
# 使用更大的batch size提高处理速度
|
|
490
|
-
batch_size = 16 # 根据GPU显存调整
|
|
491
|
-
batch_scores = []
|
|
492
|
-
|
|
493
|
-
with torch.no_grad():
|
|
494
|
-
for i in range(0, len(pairs), batch_size):
|
|
495
|
-
batch_pairs = pairs[i:i + batch_size]
|
|
496
|
-
encoded = tokenizer(
|
|
497
|
-
batch_pairs,
|
|
498
|
-
padding=True,
|
|
499
|
-
truncation=True,
|
|
500
|
-
max_length=512,
|
|
501
|
-
return_tensors='pt'
|
|
502
|
-
)
|
|
565
|
+
file_info = f"File: {path}\nContent: {content}\n\n"
|
|
566
|
+
file_length = len(file_info)
|
|
503
567
|
|
|
504
|
-
|
|
505
|
-
|
|
568
|
+
# If adding this file would exceed batch limit
|
|
569
|
+
if current_length + file_length > max_batch_length:
|
|
570
|
+
# Process current batch
|
|
571
|
+
if current_batch:
|
|
572
|
+
selected = self._process_batch(query, current_batch)
|
|
573
|
+
all_selected_files.update(selected)
|
|
574
|
+
# Start new batch
|
|
575
|
+
current_batch = [file_info]
|
|
576
|
+
current_length = file_length
|
|
577
|
+
else:
|
|
578
|
+
current_batch.append(file_info)
|
|
579
|
+
current_length += file_length
|
|
506
580
|
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
# 归一化分数到 0-1 范围
|
|
511
|
-
if batch_scores:
|
|
512
|
-
min_score = min(batch_scores)
|
|
513
|
-
max_score = max(batch_scores)
|
|
514
|
-
if max_score > min_score:
|
|
515
|
-
batch_scores = [(s - min_score) / (max_score - min_score) for s in batch_scores]
|
|
516
|
-
|
|
517
|
-
# 将重排序分数与原始分数结合
|
|
518
|
-
scored_results = []
|
|
519
|
-
for (path,_, desc), rerank_score in zip(initial_results, batch_scores):
|
|
520
|
-
if rerank_score >= 0.5: # 只保留相关度较高的结果
|
|
521
|
-
scored_results.append((path, rerank_score))
|
|
522
|
-
|
|
523
|
-
# 按综合分数降序排序
|
|
524
|
-
scored_results.sort(key=lambda x: x[1], reverse=True)
|
|
581
|
+
except Exception as e:
|
|
582
|
+
PrettyOutput.print(f"Failed to read file {path}: {str(e)}", OutputType.ERROR)
|
|
583
|
+
continue
|
|
525
584
|
|
|
526
|
-
|
|
585
|
+
# Process final batch
|
|
586
|
+
if current_batch:
|
|
587
|
+
selected = self._process_batch(query, current_batch)
|
|
588
|
+
all_selected_files.update(selected)
|
|
527
589
|
|
|
590
|
+
# Convert set to list and maintain original order
|
|
591
|
+
final_results = [path for path in initial_results if path in all_selected_files]
|
|
592
|
+
return final_results
|
|
593
|
+
|
|
528
594
|
except Exception as e:
|
|
529
|
-
PrettyOutput.print(f"
|
|
530
|
-
|
|
531
|
-
|
|
595
|
+
PrettyOutput.print(f"Failed to pick: {str(e)}", OutputType.ERROR)
|
|
596
|
+
return initial_results
|
|
597
|
+
|
|
598
|
+
def _process_batch(self, query: str, files_info: List[str]) -> List[str]:
|
|
599
|
+
"""Process a batch of files
|
|
600
|
+
|
|
601
|
+
Args:
|
|
602
|
+
query: Search query
|
|
603
|
+
files_info: List of file information strings
|
|
604
|
+
|
|
605
|
+
Returns:
|
|
606
|
+
List[str]: Selected file paths from this batch
|
|
607
|
+
"""
|
|
608
|
+
prompt = f"""Please analyze the following code files and determine which files are most relevant to the given query. Consider file paths and code content to make your judgment.
|
|
609
|
+
|
|
610
|
+
Query: {query}
|
|
611
|
+
|
|
612
|
+
Available files:
|
|
613
|
+
{''.join(files_info)}
|
|
614
|
+
|
|
615
|
+
Please output a YAML list of relevant file paths, ordered by relevance (most relevant first). Only include files that are truly relevant to the query.
|
|
616
|
+
Output format:
|
|
617
|
+
<FILES>
|
|
618
|
+
- path/to/file1.py
|
|
619
|
+
- path/to/file2.py
|
|
620
|
+
</FILES>
|
|
621
|
+
|
|
622
|
+
Note: Only include files that have a strong connection to the query."""
|
|
623
|
+
|
|
624
|
+
# Use a large model to evaluate
|
|
625
|
+
model = PlatformRegistry.get_global_platform_registry().get_normal_platform()
|
|
626
|
+
response = model.chat_until_success(prompt)
|
|
627
|
+
|
|
628
|
+
# Parse the response
|
|
629
|
+
import yaml
|
|
630
|
+
files_match = re.search(r'<FILES>\n(.*?)</FILES>', response, re.DOTALL)
|
|
631
|
+
if not files_match:
|
|
632
|
+
return []
|
|
633
|
+
|
|
634
|
+
# Extract the file list
|
|
635
|
+
try:
|
|
636
|
+
selected_files = yaml.safe_load(files_match.group(1))
|
|
637
|
+
return selected_files if selected_files else []
|
|
638
|
+
except Exception as e:
|
|
639
|
+
PrettyOutput.print(f"Failed to parse response: {str(e)}", OutputType.ERROR)
|
|
640
|
+
return []
|
|
532
641
|
|
|
533
642
|
def _generate_query_variants(self, query: str) -> List[str]:
|
|
534
|
-
"""
|
|
643
|
+
"""Generate different expressions of the query
|
|
535
644
|
|
|
536
645
|
Args:
|
|
537
|
-
query:
|
|
646
|
+
query: Original query
|
|
538
647
|
|
|
539
648
|
Returns:
|
|
540
|
-
List[str]:
|
|
649
|
+
List[str]: The query variants list
|
|
541
650
|
"""
|
|
542
651
|
model = PlatformRegistry.get_global_platform_registry().get_normal_platform()
|
|
543
652
|
prompt = f"""Please generate 3 different expressions based on the following query, each expression should fully convey the meaning of the original query. These expressions will be used for code search, maintain professionalism and accuracy.
|
|
@@ -546,18 +655,18 @@ Original query: {query}
|
|
|
546
655
|
Please output 3 expressions directly, separated by two line breaks, without numbering or other markers.
|
|
547
656
|
"""
|
|
548
657
|
variants = model.chat_until_success(prompt).strip().split('\n\n')
|
|
549
|
-
variants.append(query) #
|
|
658
|
+
variants.append(query) # Add the original query
|
|
550
659
|
return variants
|
|
551
660
|
|
|
552
661
|
def _vector_search(self, query_variants: List[str], top_k: int) -> Dict[str, Tuple[str, float, str]]:
|
|
553
|
-
"""
|
|
662
|
+
"""Use vector search to find related files
|
|
554
663
|
|
|
555
664
|
Args:
|
|
556
|
-
query_variants:
|
|
557
|
-
top_k:
|
|
665
|
+
query_variants: The query variants list
|
|
666
|
+
top_k: The number of results to return
|
|
558
667
|
|
|
559
668
|
Returns:
|
|
560
|
-
Dict[str, Tuple[str, float, str]]:
|
|
669
|
+
Dict[str, Tuple[str, float, str]]: The mapping from file path to (file path, score, description)
|
|
561
670
|
"""
|
|
562
671
|
results = {}
|
|
563
672
|
for query in query_variants:
|
|
@@ -571,75 +680,78 @@ Please output 3 expressions directly, separated by two line breaks, without numb
|
|
|
571
680
|
continue
|
|
572
681
|
|
|
573
682
|
similarity = 1.0 / (1.0 + float(distance))
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
if
|
|
683
|
+
file_path = self.file_paths[i]
|
|
684
|
+
# Use the highest similarity score
|
|
685
|
+
if file_path not in results:
|
|
686
|
+
if similarity > 0.5:
|
|
578
687
|
data = self.vector_cache[file_path]
|
|
579
688
|
results[file_path] = (file_path, similarity, data["description"])
|
|
580
689
|
|
|
581
690
|
return results
|
|
582
691
|
|
|
583
692
|
|
|
584
|
-
def search_similar(self, query: str, top_k: int = 30) -> List[
|
|
585
|
-
"""
|
|
693
|
+
def search_similar(self, query: str, top_k: int = 30) -> List[str]:
|
|
694
|
+
"""Search related files"""
|
|
586
695
|
try:
|
|
587
696
|
if self.index is None:
|
|
588
697
|
return []
|
|
589
|
-
#
|
|
698
|
+
# Generate the query variants
|
|
590
699
|
query_variants = self._generate_query_variants(query)
|
|
591
700
|
|
|
592
|
-
#
|
|
701
|
+
# Perform vector search
|
|
593
702
|
vector_results = self._vector_search(query_variants, top_k)
|
|
594
703
|
|
|
595
704
|
results = list(vector_results.values())
|
|
596
705
|
results.sort(key=lambda x: x[1], reverse=True)
|
|
597
706
|
|
|
598
|
-
#
|
|
707
|
+
# Take the top top_k results for reordering
|
|
599
708
|
initial_results = results[:top_k]
|
|
600
709
|
|
|
601
|
-
#
|
|
710
|
+
# If no results are found, return directly
|
|
602
711
|
if not initial_results:
|
|
603
712
|
return []
|
|
604
713
|
|
|
605
|
-
#
|
|
714
|
+
# Filter low-scoring results
|
|
606
715
|
initial_results = [(path, score, desc) for path, score, desc in initial_results if score >= 0.5]
|
|
716
|
+
|
|
717
|
+
for path, score, desc in initial_results:
|
|
718
|
+
PrettyOutput.print(f"File: {path} Similarity: {score:.3f}", output_type=OutputType.INFO)
|
|
607
719
|
|
|
608
|
-
#
|
|
609
|
-
return self.
|
|
720
|
+
# Reorder the preliminary results
|
|
721
|
+
return self.pick_results(query, [path for path, _, _ in initial_results])
|
|
610
722
|
|
|
611
723
|
except Exception as e:
|
|
612
|
-
PrettyOutput.print(f"
|
|
724
|
+
PrettyOutput.print(f"Failed to search: {str(e)}", output_type=OutputType.ERROR)
|
|
613
725
|
return []
|
|
614
726
|
|
|
615
727
|
def ask_codebase(self, query: str, top_k: int=20) -> str:
|
|
616
|
-
"""
|
|
728
|
+
"""Query the codebase"""
|
|
617
729
|
results = self.search_similar(query, top_k)
|
|
618
730
|
if not results:
|
|
619
|
-
PrettyOutput.print("
|
|
731
|
+
PrettyOutput.print("No related files found", output_type=OutputType.WARNING)
|
|
620
732
|
return ""
|
|
621
733
|
|
|
622
|
-
PrettyOutput.print(f"
|
|
623
|
-
for path
|
|
624
|
-
PrettyOutput.print(f"
|
|
734
|
+
PrettyOutput.print(f"Found related files: ", output_type=OutputType.SUCCESS)
|
|
735
|
+
for path in results:
|
|
736
|
+
PrettyOutput.print(f"File: {path}",
|
|
625
737
|
output_type=OutputType.INFO)
|
|
626
738
|
|
|
627
|
-
prompt = f"""
|
|
739
|
+
prompt = f"""You are a code expert, please answer the user's question based on the following file information:
|
|
628
740
|
"""
|
|
629
|
-
for path
|
|
741
|
+
for path in results:
|
|
630
742
|
try:
|
|
631
743
|
if len(prompt) > self.max_context_length:
|
|
632
|
-
PrettyOutput.print(f"
|
|
744
|
+
PrettyOutput.print(f"Avoid context overflow, discard low-related file: {path}", OutputType.WARNING)
|
|
633
745
|
continue
|
|
634
746
|
content = open(path, "r", encoding="utf-8").read()
|
|
635
747
|
prompt += f"""
|
|
636
|
-
File path: {path}
|
|
748
|
+
File path: {path}
|
|
637
749
|
File content:
|
|
638
750
|
{content}
|
|
639
751
|
========================================
|
|
640
752
|
"""
|
|
641
753
|
except Exception as e:
|
|
642
|
-
PrettyOutput.print(f"
|
|
754
|
+
PrettyOutput.print(f"Failed to read file {path}: {str(e)}",
|
|
643
755
|
output_type=OutputType.ERROR)
|
|
644
756
|
continue
|
|
645
757
|
|
|
@@ -653,29 +765,46 @@ Please answer the user's question in Chinese using professional language. If the
|
|
|
653
765
|
return response
|
|
654
766
|
|
|
655
767
|
def is_index_generated(self) -> bool:
|
|
656
|
-
"""
|
|
657
|
-
# 检查缓存文件是否存在
|
|
658
|
-
if not os.path.exists(self.cache_path):
|
|
659
|
-
return False
|
|
660
|
-
|
|
661
|
-
# 检查缓存是否有效
|
|
768
|
+
"""Check if the index has been generated"""
|
|
662
769
|
try:
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
770
|
+
# 1. 检查基本条件
|
|
771
|
+
if not self.vector_cache or not self.file_paths:
|
|
772
|
+
return False
|
|
773
|
+
|
|
774
|
+
if not hasattr(self, 'index') or self.index is None:
|
|
775
|
+
return False
|
|
776
|
+
|
|
777
|
+
# 2. 检查索引是否可用
|
|
778
|
+
# 创建测试向量
|
|
779
|
+
test_vector = np.zeros((1, self.vector_dim), dtype=np.float32) # type: ignore
|
|
780
|
+
try:
|
|
781
|
+
self.index.search(test_vector, 1) # type: ignore
|
|
782
|
+
except Exception:
|
|
783
|
+
return False
|
|
784
|
+
|
|
785
|
+
# 3. 验证向量缓存和文件路径的一致性
|
|
786
|
+
if len(self.vector_cache) != len(self.file_paths):
|
|
787
|
+
return False
|
|
788
|
+
|
|
789
|
+
# 4. 验证所有缓存文件
|
|
790
|
+
for file_path in self.file_paths:
|
|
791
|
+
if file_path not in self.vector_cache:
|
|
666
792
|
return False
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
793
|
+
|
|
794
|
+
cache_path = self._get_cache_path(file_path)
|
|
795
|
+
if not os.path.exists(cache_path):
|
|
796
|
+
return False
|
|
797
|
+
|
|
798
|
+
cache_data = self.vector_cache[file_path]
|
|
799
|
+
if not isinstance(cache_data.get("vector"), np.ndarray):
|
|
800
|
+
return False
|
|
801
|
+
|
|
802
|
+
return True
|
|
803
|
+
|
|
804
|
+
except Exception as e:
|
|
805
|
+
PrettyOutput.print(f"Error checking index status: {str(e)}",
|
|
806
|
+
output_type=OutputType.ERROR)
|
|
676
807
|
return False
|
|
677
|
-
|
|
678
|
-
return True
|
|
679
808
|
|
|
680
809
|
|
|
681
810
|
|
|
@@ -729,10 +858,9 @@ def main():
|
|
|
729
858
|
return
|
|
730
859
|
|
|
731
860
|
PrettyOutput.print("\nSearch Results:", output_type=OutputType.INFO)
|
|
732
|
-
for path
|
|
861
|
+
for path in results:
|
|
733
862
|
PrettyOutput.print("\n" + "="*50, output_type=OutputType.INFO)
|
|
734
863
|
PrettyOutput.print(f"File: {path}", output_type=OutputType.INFO)
|
|
735
|
-
PrettyOutput.print(f"Similarity: {score:.3f}", output_type=OutputType.INFO)
|
|
736
864
|
|
|
737
865
|
elif args.command == 'ask':
|
|
738
866
|
response = codebase.ask_codebase(args.question, args.top_k)
|