jarvis-ai-assistant 0.1.131__py3-none-any.whl → 0.1.132__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jarvis/__init__.py +1 -1
- jarvis/jarvis_agent/__init__.py +48 -29
- jarvis/jarvis_agent/patch.py +61 -43
- jarvis/jarvis_agent/shell_input_handler.py +1 -1
- jarvis/jarvis_code_agent/code_agent.py +87 -86
- jarvis/jarvis_dev/main.py +335 -626
- jarvis/jarvis_git_squash/main.py +10 -31
- jarvis/jarvis_multi_agent/__init__.py +19 -28
- jarvis/jarvis_platform/ai8.py +7 -32
- jarvis/jarvis_platform/base.py +2 -7
- jarvis/jarvis_platform/kimi.py +3 -144
- jarvis/jarvis_platform/ollama.py +54 -68
- jarvis/jarvis_platform/openai.py +0 -4
- jarvis/jarvis_platform/oyi.py +0 -75
- jarvis/jarvis_platform/yuanbao.py +264 -0
- jarvis/jarvis_rag/file_processors.py +138 -0
- jarvis/jarvis_rag/main.py +1305 -425
- jarvis/jarvis_tools/ask_codebase.py +205 -39
- jarvis/jarvis_tools/code_review.py +125 -99
- jarvis/jarvis_tools/execute_python_script.py +58 -0
- jarvis/jarvis_tools/execute_shell.py +13 -26
- jarvis/jarvis_tools/execute_shell_script.py +1 -1
- jarvis/jarvis_tools/file_analyzer.py +271 -0
- jarvis/jarvis_tools/file_operation.py +1 -1
- jarvis/jarvis_tools/find_caller.py +213 -0
- jarvis/jarvis_tools/find_symbol.py +211 -0
- jarvis/jarvis_tools/function_analyzer.py +248 -0
- jarvis/jarvis_tools/git_commiter.py +4 -4
- jarvis/jarvis_tools/methodology.py +89 -48
- jarvis/jarvis_tools/project_analyzer.py +220 -0
- jarvis/jarvis_tools/read_code.py +23 -2
- jarvis/jarvis_tools/read_webpage.py +195 -81
- jarvis/jarvis_tools/registry.py +132 -11
- jarvis/jarvis_tools/search_web.py +55 -10
- jarvis/jarvis_tools/tool_generator.py +6 -8
- jarvis/jarvis_utils/__init__.py +1 -0
- jarvis/jarvis_utils/config.py +67 -3
- jarvis/jarvis_utils/embedding.py +344 -45
- jarvis/jarvis_utils/git_utils.py +9 -1
- jarvis/jarvis_utils/input.py +7 -6
- jarvis/jarvis_utils/methodology.py +379 -7
- jarvis/jarvis_utils/output.py +5 -3
- jarvis/jarvis_utils/utils.py +59 -7
- {jarvis_ai_assistant-0.1.131.dist-info → jarvis_ai_assistant-0.1.132.dist-info}/METADATA +3 -2
- jarvis_ai_assistant-0.1.132.dist-info/RECORD +82 -0
- {jarvis_ai_assistant-0.1.131.dist-info → jarvis_ai_assistant-0.1.132.dist-info}/entry_points.txt +2 -0
- jarvis/jarvis_codebase/__init__.py +0 -0
- jarvis/jarvis_codebase/main.py +0 -1011
- jarvis/jarvis_tools/treesitter_analyzer.py +0 -331
- jarvis/jarvis_treesitter/README.md +0 -104
- jarvis/jarvis_treesitter/__init__.py +0 -20
- jarvis/jarvis_treesitter/database.py +0 -258
- jarvis/jarvis_treesitter/example.py +0 -115
- jarvis/jarvis_treesitter/grammar_builder.py +0 -182
- jarvis/jarvis_treesitter/language.py +0 -117
- jarvis/jarvis_treesitter/symbol.py +0 -31
- jarvis/jarvis_treesitter/tools_usage.md +0 -121
- jarvis_ai_assistant-0.1.131.dist-info/RECORD +0 -85
- {jarvis_ai_assistant-0.1.131.dist-info → jarvis_ai_assistant-0.1.132.dist-info}/LICENSE +0 -0
- {jarvis_ai_assistant-0.1.131.dist-info → jarvis_ai_assistant-0.1.132.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.1.131.dist-info → jarvis_ai_assistant-0.1.132.dist-info}/top_level.txt +0 -0
jarvis/jarvis_codebase/main.py
DELETED
@@ -1,1011 +0,0 @@

````python
import hashlib
import os
import numpy as np
import faiss
from typing import List, Tuple, Optional, Dict

from yaspin import yaspin

from jarvis.jarvis_platform.registry import PlatformRegistry
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
import argparse
import pickle
import lzma  # add lzma import
from tqdm import tqdm
import re

from jarvis.jarvis_utils.config import get_max_token_count, get_thread_count
from jarvis.jarvis_utils.embedding import get_embedding, load_embedding_model, get_context_token_count
from jarvis.jarvis_utils.git_utils import find_git_root
from jarvis.jarvis_utils.output import OutputType, PrettyOutput
from jarvis.jarvis_utils.utils import get_file_md5, init_env, user_confirm


class CodeBase:
    def __init__(self, root_dir: str):
        with yaspin(text="正在初始化环境...", color="cyan") as spinner:
            init_env()
            spinner.text = "环境初始化完成"
            spinner.ok("✅")

        self.root_dir = root_dir
        os.chdir(self.root_dir)
        self.thread_count = get_thread_count()
        self.max_token_count = get_max_token_count()
        self.index = None

        # Initialize the data directory
        with yaspin(text="正在初始化数据目录...", color="cyan") as spinner:
            self.data_dir = os.path.join(self.root_dir, ".jarvis/codebase")
            self.cache_dir = os.path.join(self.data_dir, "cache")
            if not os.path.exists(self.cache_dir):
                os.makedirs(self.cache_dir)
            spinner.text = "数据目录初始化完成"
            spinner.ok("✅")

        with yaspin("正在初始化嵌入模型...", color="cyan") as spinner:
            # Initialize the embedding model
            try:
                self.embedding_model = load_embedding_model()
                test_text = """This is a test text"""
                self.embedding_model.encode([test_text],
                                            convert_to_tensor=True,
                                            normalize_embeddings=True)
                spinner.text = "嵌入模型初始化完成"
                spinner.ok("✅")
            except Exception as e:
                spinner.text = "嵌入模型初始化失败"
                spinner.fail("❌")
                raise

            self.vector_dim = self.embedding_model.get_sentence_embedding_dimension()
            self.git_file_list = self.get_git_file_list()
            self.platform_registry = PlatformRegistry.get_global_platform_registry()

            # Initialize the cache and index
            self.vector_cache = {}
            self.file_paths = []

            # Load all cache files
            with spinner.hidden():
                self._load_all_cache()

    def get_git_file_list(self):
        """Get the list of files in the git repository, excluding the .jarvis-codebase directory"""
        files = os.popen("git ls-files").read().splitlines()
        # Filter out files in the .jarvis-codebase directory
        return [f for f in files if not f.startswith(".jarvis")]

    def is_text_file(self, file_path: str):
        try:
            open(file_path, "r", encoding="utf-8", errors="ignore").read()
            return True
        except Exception:
            return False

    def make_description(self, file_path: str, content: str) -> str:
        model = PlatformRegistry.get_global_platform_registry().get_cheap_platform()
        prompt = f"""请分析以下代码文件并生成详细描述。描述应包含:
1. 文件整体功能描述
2. 对每个全局变量、函数、类型定义、类、方法和其他代码元素的描述

请使用简洁专业的语言,强调技术功能,以便于后续代码检索。
文件路径: {file_path}
代码内容:
{content}
"""
        response = model.chat_until_success(prompt)
        return response

    def export(self):
        """Export the current index data to standard output"""
        for file_path, data in self.vector_cache.items():
            print(f"## {file_path}")
            print(f"- path: {file_path}")
            print(f"- description: {data['description']}")

    def _get_cache_path(self, file_path: str) -> str:
        """Get cache file path for a source file

        Args:
            file_path: Source file path

        Returns:
            str: Cache file path
        """
        # Process the file path:
        # 1. Strip the leading ./ or /
        # 2. Replace / with --
        # 3. Append the .cache suffix
        clean_path = file_path.lstrip('./').lstrip('/')
        cache_name = clean_path.replace('/', '--') + '.cache'
        return os.path.join(self.cache_dir, cache_name)

    def _load_all_cache(self):
        """Load all cache files"""
        with yaspin(text="正在加载缓存文件...", color="cyan") as spinner:
            try:
                # Clear the existing cache and file paths
                self.vector_cache = {}
                self.file_paths = []
                vectors = []

                for cache_file in os.listdir(self.cache_dir):
                    if not cache_file.endswith('.cache'):
                        continue

                    cache_path = os.path.join(self.cache_dir, cache_file)
                    try:
                        with lzma.open(cache_path, 'rb') as f:
                            cache_data = pickle.load(f)
                            file_path = cache_data["path"]
                            self.vector_cache[file_path] = cache_data
                            self.file_paths.append(file_path)
                            vectors.append(cache_data["vector"])
                            spinner.write(f"✅ 加载缓存文件成功 {file_path}")
                    except Exception as e:
                        spinner.write(f"❌ 加载缓存文件失败 {cache_file} {str(e)}")
                        continue

                if vectors:
                    # Rebuild the index
                    vectors_array = np.vstack(vectors)
                    hnsw_index = faiss.IndexHNSWFlat(self.vector_dim, 16)
                    hnsw_index.hnsw.efConstruction = 40
                    hnsw_index.hnsw.efSearch = 16
                    self.index = faiss.IndexIDMap(hnsw_index)
                    self.index.add_with_ids(vectors_array, np.array(range(len(vectors))))  # type: ignore

                    spinner.text = f"加载 {len(self.vector_cache)} 个向量缓存并重建索引"
                    spinner.ok("✅")
                else:
                    self.index = None
                    spinner.text = "没有找到有效的缓存文件"
                    spinner.ok("✅")

            except Exception as e:
                spinner.text = f"加载缓存目录失败: {str(e)}"
                spinner.fail("❌")
                self.vector_cache = {}
                self.file_paths = []
                self.index = None

    def cache_vector(self, file_path: str, vector: np.ndarray, description: str):
        """Cache the vector representation of a file"""
        try:
            with open(file_path, "rb") as f:
                file_md5 = hashlib.md5(f.read()).hexdigest()
        except Exception as e:
            PrettyOutput.print(f"计算 {file_path} 的MD5失败: {str(e)}",
                               output_type=OutputType.ERROR)
            file_md5 = ""

        # Prepare the cache data
        cache_data = {
            "path": file_path,  # save the file path
            "md5": file_md5,  # save the file MD5
            "description": description,  # save the file description
            "vector": vector  # save the vector
        }

        # Update the in-memory cache
        self.vector_cache[file_path] = cache_data

        # Save to a separate cache file
        cache_path = self._get_cache_path(file_path)
        try:
            with lzma.open(cache_path, 'wb') as f:
                pickle.dump(cache_data, f, protocol=pickle.HIGHEST_PROTOCOL)
        except Exception as e:
            PrettyOutput.print(f"保存 {file_path} 的缓存失败: {str(e)}",
                               output_type=OutputType.ERROR)

    def get_cached_vector(self, file_path: str, description: str) -> Optional[np.ndarray]:
        """Get the vector representation of a file from the cache"""
        if file_path not in self.vector_cache:
            return None

        # Check if the file has been modified
        try:
            with open(file_path, "rb") as f:
                current_md5 = hashlib.md5(f.read()).hexdigest()
        except Exception as e:
            PrettyOutput.print(f"计算 {file_path} 的MD5失败: {str(e)}",
                               output_type=OutputType.ERROR)
            return None

        cached_data = self.vector_cache[file_path]
        if cached_data["md5"] != current_md5:
            return None

        # Check if the description has changed
        if cached_data["description"] != description:
            return None

        return cached_data["vector"]

    def vectorize_file(self, file_path: str, description: str) -> np.ndarray:
        """Vectorize the file content and description"""
        try:
            # Try to get the vector from the cache first
            cached_vector = self.get_cached_vector(file_path, description)
            if cached_vector is not None:
                return cached_vector

            # Read the file content and combine information
            content = open(file_path, "r", encoding="utf-8", errors="ignore").read()[:self.max_token_count]  # Limit the file content length

            # Combine file information, including file content
            combined_text = f"""
File path: {file_path}
Description: {description}
Content: {content}
"""
            vector = get_embedding(self.embedding_model, combined_text)

            # Save to cache
            self.cache_vector(file_path, vector, description)
            return vector
        except Exception as e:
            PrettyOutput.print(f"向量化 {file_path} 失败: {str(e)}",
                               output_type=OutputType.ERROR)
            return np.zeros(self.vector_dim, dtype=np.float32)  # type: ignore

    def clean_cache(self) -> bool:
        """Clean expired cache records"""
        try:
            files_to_delete = []
            for file_path in list(self.vector_cache.keys()):
                if not os.path.exists(file_path):
                    files_to_delete.append(file_path)
                    cache_path = self._get_cache_path(file_path)
                    try:
                        os.remove(cache_path)
                    except Exception:
                        pass

            for file_path in files_to_delete:
                del self.vector_cache[file_path]
                if file_path in self.file_paths:
                    self.file_paths.remove(file_path)

            return bool(files_to_delete)

        except Exception as e:
            PrettyOutput.print(f"清理缓存失败: {str(e)}",
                               output_type=OutputType.ERROR)
            return False

    def process_file(self, file_path: str):
        """Process a single file"""
        try:
            # Skip non-existent files
            if not os.path.exists(file_path):
                return None

            if not self.is_text_file(file_path):
                return None

            md5 = get_file_md5(file_path)

            content = open(file_path, "r", encoding="utf-8", errors="ignore").read()

            # Check if the file has already been processed and the content has not changed
            if file_path in self.vector_cache:
                if self.vector_cache[file_path].get("md5") == md5:
                    return None

            description = self.make_description(file_path, content)  # Pass the truncated content
            vector = self.vectorize_file(file_path, description)

            # Save to cache, using the actual file path as the key
            self.vector_cache[file_path] = {
                "vector": vector,
                "description": description,
                "md5": md5
            }

            return file_path

        except Exception as e:
            PrettyOutput.print(f"处理 {file_path} 失败: {str(e)}",
                               output_type=OutputType.ERROR)
            return None

    def build_index(self):
        """Build a faiss index from the vector cache"""
        try:
            if not self.vector_cache:
                self.index = None
                return

            # Create the underlying HNSW index
            hnsw_index = faiss.IndexHNSWFlat(self.vector_dim, 16)
            hnsw_index.hnsw.efConstruction = 40
            hnsw_index.hnsw.efSearch = 16

            # Wrap the HNSW index with IndexIDMap
            self.index = faiss.IndexIDMap(hnsw_index)

            vectors = []
            ids = []
            self.file_paths = []  # Reset the file path list

            for i, (file_path, data) in enumerate(self.vector_cache.items()):
                if "vector" not in data:
                    PrettyOutput.print(f"无效的缓存数据 {file_path}: 缺少向量",
                                       output_type=OutputType.WARNING)
                    continue

                vector = data["vector"]
                if not isinstance(vector, np.ndarray):
                    PrettyOutput.print(f"无效的向量类型 {file_path}: {type(vector)}",
                                       output_type=OutputType.WARNING)
                    continue

                vectors.append(vector.reshape(1, -1))
                ids.append(i)
                self.file_paths.append(file_path)

            if vectors:
                vectors = np.vstack(vectors)
                if len(vectors) != len(ids):
                    PrettyOutput.print(f"向量数量不匹配: {len(vectors)} 个向量 vs {len(ids)} 个ID",
                                       output_type=OutputType.WARNING)
                    self.index = None
                    return

                try:
                    self.index.add_with_ids(vectors, np.array(ids))  # type: ignore
                    PrettyOutput.print(f"成功构建包含 {len(vectors)} 个向量的索引",
                                       output_type=OutputType.SUCCESS)
                except Exception as e:
                    PrettyOutput.print(f"添加向量到索引失败: {str(e)}",
                                       output_type=OutputType.ERROR)
                    self.index = None
            else:
                PrettyOutput.print("没有找到有效的向量, 索引未构建",
                                   output_type=OutputType.WARNING)
                self.index = None

        except Exception as e:
            PrettyOutput.print(f"构建索引失败: {str(e)}",
                               output_type=OutputType.ERROR)
            self.index = None

    def gen_vector_db_from_cache(self):
        """Generate a vector database from the cache"""
        self.build_index()
        self._load_all_cache()

    def generate_codebase(self, force: bool = False):
        """Generate the codebase index
        Args:
            force: Whether to force rebuild the index, without asking the user
        """
        try:
            # Clean up cache for non-existent files
            files_to_delete = []
            for cached_file in list(self.vector_cache.keys()):
                if not os.path.exists(cached_file) or not self.is_text_file(cached_file):
                    files_to_delete.append(cached_file)
                    cache_path = self._get_cache_path(cached_file)
                    try:
                        os.remove(cache_path)
                    except Exception as e:
                        PrettyOutput.print(f"删除缓存文件 {cached_file} 失败: {str(e)}",
                                           output_type=OutputType.WARNING)

            if files_to_delete:
                for file_path in files_to_delete:
                    del self.vector_cache[file_path]
                PrettyOutput.print(f"清理了 {len(files_to_delete)} 个不存在的文件的缓存",
                                   output_type=OutputType.INFO)

            # Update the git file list
            self.git_file_list = self.get_git_file_list()

            # Check file changes
            PrettyOutput.print("检查文件变化...", output_type=OutputType.INFO)
            changes_detected = False
            new_files = []
            modified_files = []
            deleted_files = []

            # Check deleted files
            files_to_delete = []
            for file_path in list(self.vector_cache.keys()):
                if file_path not in self.git_file_list:
                    deleted_files.append(file_path)
                    files_to_delete.append(file_path)
                    changes_detected = True
            # Check new and modified files
            from rich.progress import Progress
            with Progress() as progress:
                task = progress.add_task("Check file status", total=len(self.git_file_list))
                for file_path in self.git_file_list:
                    if not os.path.exists(file_path) or not self.is_text_file(file_path):
                        progress.advance(task)
                        continue

                    try:
                        current_md5 = get_file_md5(file_path)

                        if file_path not in self.vector_cache:
                            new_files.append(file_path)
                            changes_detected = True
                        elif self.vector_cache[file_path].get("md5") != current_md5:
                            modified_files.append(file_path)
                            changes_detected = True
                    except Exception as e:
                        PrettyOutput.print(f"检查 {file_path} 失败: {str(e)}",
                                           output_type=OutputType.ERROR)
                    progress.advance(task)

            # If changes are detected, display changes and ask the user
            if changes_detected:
                output_lines = ["检测到以下变化:"]
                if new_files:
                    output_lines.append("新文件:")
                    output_lines.extend(f"  {f}" for f in new_files)
                if modified_files:
                    output_lines.append("修改的文件:")
                    output_lines.extend(f"  {f}" for f in modified_files)
                if deleted_files:
                    output_lines.append("删除的文件:")
                    output_lines.extend(f"  {f}" for f in deleted_files)

                PrettyOutput.print("\n".join(output_lines), output_type=OutputType.INFO)

                # If force is True, continue directly
                if not force:
                    if not user_confirm("重建索引?", False):
                        return

                # Clean deleted files
                for file_path in files_to_delete:
                    del self.vector_cache[file_path]
                if files_to_delete:
                    PrettyOutput.print(f"清理了 {len(files_to_delete)} 个文件的缓存",
                                       output_type=OutputType.INFO)

                # Process new and modified files
                files_to_process = new_files + modified_files
                processed_files = []

                with yaspin(text="正在处理文件...", color="cyan") as spinner:
                    # Use a thread pool to process files
                    with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
                        # Submit all tasks
                        future_to_file = {
                            executor.submit(self.process_file, file): file
                            for file in files_to_process
                        }

                        # Process completed tasks
                        for future in concurrent.futures.as_completed(future_to_file):
                            file = future_to_file[future]
                            try:
                                result = future.result()
                                if result:
                                    processed_files.append(result)
                                    spinner.write(f"✅ 处理文件成功 {file}")
                            except Exception as e:
                                spinner.write(f"❌ 处理文件失败 {file}: {str(e)}")

                    spinner.text = f"处理完成"
                    spinner.ok("✅")

                if processed_files:
                    with yaspin(text="重建向量数据库...", color="cyan") as spinner:
                        self.gen_vector_db_from_cache()
                        spinner.text = f"成功生成了 {len(processed_files)} 个文件的索引"
                        spinner.ok("✅")
            else:
                PrettyOutput.print("没有检测到文件变化, 不需要重建索引", output_type=OutputType.INFO)

        except Exception as e:
            # Try to save the cache when an exception occurs
            try:
                self._load_all_cache()
            except Exception as save_error:
                PrettyOutput.print(f"保存缓存失败: {str(save_error)}",
                                   output_type=OutputType.ERROR)
            raise e  # Re-raise the original exception

    def _text_search_score(self, content: str, keywords: List[str]) -> float:
        """Calculate the matching score between the text content and the keywords

        Args:
            content: Text content
            keywords: List of keywords

        Returns:
            float: Matching score (0-1)
        """
        if not keywords:
            return 0.0

        content = content.lower()
        matched_keywords = set()

        for keyword in keywords:
            keyword = keyword.lower()
            if keyword in content:
                matched_keywords.add(keyword)

        # Calculate the matching score
        score = len(matched_keywords) / len(keywords)
        return score

    def pick_results(self, query: List[str], initial_results: List[str]) -> List[Dict[str, str]]:
        """Use a large model to pick the search results

        Args:
            query: Search query
            initial_results: Initial results list of file paths

        Returns:
            List[str]: The picked results list, each item is a file path
        """
        if not initial_results:
            return []
        with yaspin(text="正在筛选结果...", color="cyan") as spinner:
            try:
                # Maximum content length per batch
                max_batch_length = self.max_token_count - 1000  # Reserve space for prompt
                max_file_length = max_batch_length // 3  # Limit individual file size

                # Process files in batches
                all_selected_files = []
                current_batch = []
                current_token_count = 0

                for path in initial_results:
                    try:
                        content = open(path, "r", encoding="utf-8", errors="ignore").read()
                        # Truncate large files
                        if get_context_token_count(content) > max_file_length:
                            spinner.write(f"❌ 截断大文件: {path}")
                            content = content[:max_file_length] + "\n... (content truncated)"

                        file_info = f"File: {path}\nContent: {content}\n\n"
                        tokens_count = get_context_token_count(file_info)

                        # If adding this file would exceed batch limit
                        if current_token_count + tokens_count > max_batch_length:
                            # Process current batch
                            if current_batch:
                                selected = self._process_batch('\n'.join(query), current_batch)
                                all_selected_files.extend(selected)
                            # Start new batch
                            current_batch = [file_info]
                            current_token_count = tokens_count
                        else:
                            current_batch.append(file_info)
                            current_token_count += tokens_count

                    except Exception as e:
                        spinner.write(f"❌ 读取 {path} 失败: {str(e)}")
                        continue

                # Process final batch
                if current_batch:
                    selected = self._process_batch('\n'.join(query), current_batch)
                    all_selected_files.extend(selected)

                spinner.write("✅ 结果筛选完成")
                # Convert set to list and maintain original order
                return all_selected_files

            except Exception as e:
                spinner.text = f"选择失败: {str(e)}"
                spinner.fail("❌")
                return [{"file": f, "reason": ""} for f in initial_results]

    def _process_batch(self, query: str, files_info: List[str]) -> List[Dict[str, str]]:
        """Process a batch of files"""
        prompt = f"""作为一名代码分析专家,请使用链式思维推理帮助识别与给定查询最相关的文件。

查询: {query}

可用文件:
{''.join(files_info)}

请按以下步骤思考:
1. 首先,分析查询以识别关键需求和技术概念
2. 对于每个文件:
   - 检查其路径和内容
   - 评估其与查询需求的关系
   - 考虑直接和间接关系
   - 评估其相关性(高/中/低)
3. 仅选择与查询明确相关的文件
4. 按相关性排序,最相关的文件在前

请以YAML格式输出您的选择:
<FILES>
- file: path/to/most/relevant.py
  reason: xxxxxxxxxx
- path/to/next/relevant.py
  reason: yyyyyyyyyy
</FILES>

重要提示:
- 仅包含真正相关的文件
- 排除连接不明确或较弱的文件
- 重点关注实现文件而非测试文件
- 同时考虑文件路径和内容
- 仅输出文件路径,不要包含其他文本
"""

        # Use a large model to evaluate
        model = PlatformRegistry.get_global_platform_registry().get_normal_platform()
        response = model.chat_until_success(prompt)

        # Parse the response
        import yaml
        files_match = re.search(r'<FILES>\n(.*?)</FILES>', response, re.DOTALL)
        if not files_match:
            return []

        try:
            selected_files = yaml.safe_load(files_match.group(1))
            return selected_files if selected_files else []
        except Exception as e:
            PrettyOutput.print(f"解析响应失败: {str(e)}", OutputType.ERROR)
            return []

    def _generate_query_variants(self, query: str) -> List[str]:
        """Generate different expressions of the query optimized for vector search

        Args:
            query: Original query

        Returns:
            List[str]: The query variants list
        """
        model = PlatformRegistry.get_global_platform_registry().get_normal_platform()
        prompt = f"""请基于以下查询生成10个针对向量搜索优化的不同表达。每个表达应满足:
1. 聚焦关键技术概念和术语
2. 使用清晰明确的语言
3. 包含重要的上下文术语
4. 避免使用通用或模糊的词语
5. 保持与原始查询的语义相似性
6. 适合基于嵌入的搜索

原始查询:
{query}

示例转换:
查询: "如何处理用户登录?"
输出格式:
<QUESTION>
- 用户认证的实现与流程
- 登录系统架构与组件
- 凭证验证与会话管理
- ...
</QUESTION>

请以指定格式提供10个搜索优化的表达。
"""
        response = model.chat_until_success(prompt)

        # Parse the response using YAML format
        import yaml
        variants = []
        question_match = re.search(r'<QUESTION>\n(.*?)</QUESTION>', response, re.DOTALL)
        if question_match:
            try:
                variants = yaml.safe_load(question_match.group(1))
                if not isinstance(variants, list):
                    variants = [str(variants)]
            except Exception as e:
                PrettyOutput.print(f"解析变体失败: {str(e)}", OutputType.ERROR)

        # Add original query
        variants.append(query)
        return variants if variants else [query]

    def _vector_search(self, query_variants: List[str], top_k: int) -> Dict[str, Tuple[str, float, str]]:
        """Use vector search to find related files

        Args:
            query_variants: The query variants list
            top_k: The number of results to return

        Returns:
            Dict[str, Tuple[str, float, str]]: The mapping from file path to (file path, score, description)
        """
        results = {}
        for query in query_variants:
            query_vector = get_embedding(self.embedding_model, query)
            query_vector = query_vector.reshape(1, -1)

            distances, indices = self.index.search(query_vector, top_k)  # type: ignore

            for i, distance in zip(indices[0], distances[0]):
                if i == -1:
                    continue

                similarity = 1.0 / (1.0 + float(distance))
                file_path = self.file_paths[i]
                # Use the highest similarity score
                if file_path not in results:
                    if similarity > 0.5:
                        data = self.vector_cache[file_path]
                        results[file_path] = (file_path, similarity, data["description"])

        return results

    def search_similar(self, query: str, top_k: int = 30) -> List[Dict[str, str]]:
        """Search related files with optimized retrieval"""
        with yaspin(text="正在搜索相关文件...", color="cyan") as spinner:
            try:
                with spinner.hidden():
                    self.generate_codebase()
                if self.index is None:
                    spinner.text = "没有找到有效的缓存文件"
                    spinner.ok("✅")
                    return []

                # Generate query variants for better coverage
                spinner.text = "生成查询变体..."
                query_variants = self._generate_query_variants(query)
                spinner.write("✅ 查询变体生成完成")

                # Collect results from all variants
                spinner.text = "收集结果..."
                all_results = []
                seen_files = set()

                for variant in query_variants:
                    # Get vector for each variant
                    query_vector = get_embedding(self.embedding_model, variant)
                    query_vector = query_vector.reshape(1, -1)

                    # Search with current variant
                    initial_k = min(top_k * 2, len(self.file_paths))
                    distances, indices = self.index.search(query_vector, initial_k)  # type: ignore

                    # Process results
                    for idx, dist in zip(indices[0], distances[0]):
                        if idx != -1:
                            file_path = self.file_paths[idx]
                            if file_path not in seen_files:
                                similarity = 1.0 / (1.0 + float(dist))
                                if similarity > 0.3:  # Lower threshold for better recall
                                    seen_files.add(file_path)
                                    all_results.append((file_path, similarity, self.vector_cache[file_path]["description"]))
                spinner.write("✅ 结果收集完成")
                if not all_results:
                    spinner.text = "没有找到相关文件"
                    spinner.ok("✅")
                    return []

                spinner.text = "排序..."
                # Sort by similarity and take top_k
                all_results.sort(key=lambda x: x[1], reverse=True)
                results = all_results[:top_k]
                spinner.write("✅ 排序完成")

                with spinner.hidden():
                    results = self.pick_results(query_variants, [path for path, _, _ in results])

                output = "Found related files:\n"
                for file in results:
                    output += f'''- {file['file']} ({file['reason']})\n'''

                spinner.text = "结果输出完成"
                spinner.ok("✅")
                return results

            except Exception as e:
                spinner.text = f"搜索失败: {str(e)}"
                spinner.fail("❌")
                return []

    def ask_codebase(self, query: str, top_k: int = 20) -> Tuple[List[Dict[str, str]], str]:
        """Query the codebase with enhanced context building"""
        files_from_codebase = self.search_similar(query, top_k)

        if not files_from_codebase:
            PrettyOutput.print("没有找到相关文件", output_type=OutputType.WARNING)
            return [], ""

        prompt = f"""
# 🤖 角色定义
您是一位代码分析专家,能够提供关于代码库的全面且准确的回答。

# 🎯 核心职责
- 深入分析代码文件
- 清晰解释技术概念
- 提供相关代码示例
- 识别缺失的信息
- 使用用户的语言进行回答

# 📋 回答要求
## 内容质量
- 关注实现细节
- 保持技术准确性
- 包含相关代码片段
- 指出任何缺失的信息
- 使用专业术语

## 回答格式
- question: [重述问题]
  answer: |
    [详细的技术回答,包含:
    - 实现细节
    - 代码示例(如果相关)
    - 缺失的信息(如果有)
    - 相关技术概念]

- question: [如果需要,提出后续问题]
  answer: |
    [额外的技术细节]

# 🔍 分析上下文
问题: {query}

相关代码文件(按相关性排序):
"""

        with yaspin(text="正在生成回答...", color="cyan") as spinner:
            # Add context, controlling the length
            spinner.text = "添加上下文..."
            available_count = self.max_token_count - get_context_token_count(prompt) - 1000  # Reserve space for the answer
            current_count = 0

            for path in files_from_codebase:
                try:
                    content = open(path["file"], "r", encoding="utf-8", errors="ignore").read()
                    file_content = f"""
## 文件: {path["file"]}
```
{content}
```
---
"""
                    if current_count + get_context_token_count(file_content) > available_count:
                        spinner.write("⚠️ 由于上下文长度限制, 一些文件被省略")
                        break

                    prompt += file_content
                    current_count += get_context_token_count(file_content)

                except Exception as e:
                    spinner.write(f"❌ 读取 {path} 失败: {str(e)}")
                    continue

            prompt += """
# ❗ 重要规则
1. 始终基于提供的代码进行回答
2. 保持技术准确性
3. 在相关时包含代码示例
4. 指出任何缺失的信息
5. 保持专业语言
6. 使用用户的语言进行回答
"""

            model = PlatformRegistry.get_global_platform_registry().get_thinking_platform()
            spinner.text = "生成回答..."
            ret = files_from_codebase, model.chat_until_success(prompt)
            spinner.text = "回答生成完成"
            spinner.ok("✅")
            return ret

    def is_index_generated(self) -> bool:
        """Check if the index has been generated"""
        try:
            # 1. Check basic conditions
            if not self.vector_cache or not self.file_paths:
                return False

            if not hasattr(self, 'index') or self.index is None:
                return False

            # 2. Check whether the index is usable
            # Create a test vector
            test_vector = np.zeros((1, self.vector_dim), dtype=np.float32)  # type: ignore
            try:
                self.index.search(test_vector, 1)  # type: ignore
            except Exception:
                return False

            # 3. Verify consistency between the vector cache and the file paths
            if len(self.vector_cache) != len(self.file_paths):
                return False

            # 4. Verify all cache files
            for file_path in self.file_paths:
                if file_path not in self.vector_cache:
                    return False

                cache_path = self._get_cache_path(file_path)
                if not os.path.exists(cache_path):
                    return False

                cache_data = self.vector_cache[file_path]
                if not isinstance(cache_data.get("vector"), np.ndarray):
                    return False

            return True

        except Exception as e:
            PrettyOutput.print(f"检查索引状态失败: {str(e)}",
                               output_type=OutputType.ERROR)
            return False


def main():

    parser = argparse.ArgumentParser(description='Codebase management and search tool')
    subparsers = parser.add_subparsers(dest='command', help='Available commands')

    # Generate command
    generate_parser = subparsers.add_parser('generate', help='Generate codebase index')
    generate_parser.add_argument('--force', action='store_true', help='Force rebuild index')

    # Search command
    search_parser = subparsers.add_parser('search', help='Search similar code files')
    search_parser.add_argument('query', type=str, help='Search query')
    search_parser.add_argument('--top-k', type=int, default=20, help='Number of results to return (default: 20)')

    # Ask command
    ask_parser = subparsers.add_parser('ask', help='Ask a question about the codebase')
    ask_parser.add_argument('question', type=str, help='Question to ask')
    ask_parser.add_argument('--top-k', type=int, default=20, help='Number of results to use (default: 20)')

    export_parser = subparsers.add_parser('export', help='Export current index data')
    args = parser.parse_args()

    current_dir = find_git_root()
    codebase = CodeBase(current_dir)

    if args.command == 'export':
        codebase.export()
        return

    # If the index has not been generated and this is not the generate command, prompt the user to generate it first
    if not codebase.is_index_generated() and args.command != 'generate':
        PrettyOutput.print("索引尚未生成,请先运行 'generate' 命令生成索引", output_type=OutputType.WARNING)
        return

    if args.command == 'generate':
        try:
            codebase.generate_codebase(force=args.force)
            PrettyOutput.print("代码库生成完成", output_type=OutputType.SUCCESS)
        except Exception as e:
            PrettyOutput.print(f"代码库生成失败: {str(e)}", output_type=OutputType.ERROR)

    elif args.command == 'search':
        results = codebase.search_similar(args.query, args.top_k)
        if not results:
            PrettyOutput.print("没有找到相似的文件", output_type=OutputType.WARNING)
            return

        output = "搜索结果:\n"
        for path in results:
            output += f"""- {path}\n"""
        PrettyOutput.print(output, output_type=OutputType.INFO, lang="markdown")

    elif args.command == 'ask':
        files, answer = codebase.ask_codebase(args.question, args.top_k)
        output = f"# 相关文件:\n"
        for file in files:
            output += f"""- {file['file']} ({file['reason']})\n"""
        output += f"# 回答:\n{answer}"
        PrettyOutput.print(output, output_type=OutputType.SYSTEM, lang="markdown")

    else:
        parser.print_help()


if __name__ == "__main__":
    exit(main())
````