jarvis-ai-assistant 0.1.91__py3-none-any.whl → 0.1.93__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of jarvis-ai-assistant might be problematic. Click here for more details.
- jarvis/__init__.py +1 -1
- jarvis/agent.py +6 -4
- jarvis/jarvis_codebase/main.py +274 -188
- jarvis/jarvis_coder/__init__.py +0 -0
- jarvis/jarvis_coder/git_utils.py +64 -0
- jarvis/jarvis_coder/main.py +630 -0
- jarvis/jarvis_coder/patch_handler.py +493 -0
- jarvis/jarvis_coder/plan_generator.py +75 -0
- jarvis/jarvis_platform/main.py +13 -2
- jarvis/jarvis_rag/main.py +185 -49
- jarvis/jarvis_smart_shell/main.py +16 -9
- jarvis/main.py +9 -0
- jarvis/models/ai8.py +4 -3
- jarvis/models/ollama.py +3 -3
- jarvis/models/openai.py +2 -2
- jarvis/models/oyi.py +13 -13
- jarvis/tools/ask_user.py +1 -2
- jarvis/tools/coder.py +69 -0
- jarvis/tools/thinker.py +25 -79
- jarvis/utils.py +30 -2
- {jarvis_ai_assistant-0.1.91.dist-info → jarvis_ai_assistant-0.1.93.dist-info}/METADATA +3 -1
- jarvis_ai_assistant-0.1.93.dist-info/RECORD +47 -0
- {jarvis_ai_assistant-0.1.91.dist-info → jarvis_ai_assistant-0.1.93.dist-info}/entry_points.txt +1 -0
- jarvis_ai_assistant-0.1.91.dist-info/RECORD +0 -41
- {jarvis_ai_assistant-0.1.91.dist-info → jarvis_ai_assistant-0.1.93.dist-info}/LICENSE +0 -0
- {jarvis_ai_assistant-0.1.91.dist-info → jarvis_ai_assistant-0.1.93.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.1.91.dist-info → jarvis_ai_assistant-0.1.93.dist-info}/top_level.txt +0 -0
jarvis/__init__.py
CHANGED
jarvis/agent.py
CHANGED
|
@@ -140,7 +140,7 @@ class Agent:
|
|
|
140
140
|
|
|
141
141
|
def _load_methodology(self, user_input: str) -> Dict[str, str]:
|
|
142
142
|
"""加载方法论并构建向量索引"""
|
|
143
|
-
PrettyOutput.print("加载方法论...", OutputType.
|
|
143
|
+
PrettyOutput.print("加载方法论...", OutputType.PROGRESS)
|
|
144
144
|
user_jarvis_methodology = os.path.expanduser("~/.jarvis_methodology")
|
|
145
145
|
if not os.path.exists(user_jarvis_methodology):
|
|
146
146
|
return {}
|
|
@@ -165,13 +165,13 @@ class Agent:
|
|
|
165
165
|
|
|
166
166
|
if vectors:
|
|
167
167
|
vectors_array = np.vstack(vectors)
|
|
168
|
-
self.methodology_index.add_with_ids(vectors_array, np.array(ids))
|
|
168
|
+
self.methodology_index.add_with_ids(vectors_array, np.array(ids)) # type: ignore
|
|
169
169
|
query_embedding = self._create_methodology_embedding(user_input)
|
|
170
170
|
k = min(5, len(self.methodology_data))
|
|
171
171
|
PrettyOutput.print(f"检索方法论...", OutputType.INFO)
|
|
172
172
|
distances, indices = self.methodology_index.search(
|
|
173
173
|
query_embedding.reshape(1, -1), k
|
|
174
|
-
)
|
|
174
|
+
) # type: ignore
|
|
175
175
|
|
|
176
176
|
relevant_methodologies = {}
|
|
177
177
|
for dist, idx in zip(distances[0], indices[0]):
|
|
@@ -208,7 +208,7 @@ class Agent:
|
|
|
208
208
|
"""
|
|
209
209
|
# 创建一个新的模型实例来做总结,避免影响主对话
|
|
210
210
|
|
|
211
|
-
PrettyOutput.print("总结对话历史,准备生成总结,开始新的对话...", OutputType.
|
|
211
|
+
PrettyOutput.print("总结对话历史,准备生成总结,开始新的对话...", OutputType.PROGRESS)
|
|
212
212
|
|
|
213
213
|
prompt = """请总结之前对话中的关键信息,包括:
|
|
214
214
|
1. 当前任务目标
|
|
@@ -259,6 +259,8 @@ class Agent:
|
|
|
259
259
|
analysis_prompt = """本次任务已结束,请分析是否需要生成方法论。
|
|
260
260
|
如果认为需要生成方法论,请先判断是创建新的方法论还是更新已有方法论。如果是更新已有方法论,使用update,否则使用add。
|
|
261
261
|
如果认为不需要生成方法论,请说明原因。
|
|
262
|
+
方法论应该适应普遍场景,不要出现本次任务特定的信息,如代码的commit信息等。
|
|
263
|
+
方法论中应该包含:问题重述、最优解决方案、注意事项(按需),除此外不要出现任何其他的信息。
|
|
262
264
|
仅输出方法论工具的调用指令,或者是不需要生成方法论的说明,除此之外不要输出任何内容。
|
|
263
265
|
"""
|
|
264
266
|
self.prompt = analysis_prompt
|
jarvis/jarvis_codebase/main.py
CHANGED
|
@@ -2,16 +2,18 @@ import hashlib
|
|
|
2
2
|
import os
|
|
3
3
|
import numpy as np
|
|
4
4
|
import faiss
|
|
5
|
-
from typing import List, Tuple, Optional
|
|
5
|
+
from typing import List, Tuple, Optional, Dict
|
|
6
6
|
from jarvis.models.registry import PlatformRegistry
|
|
7
7
|
import concurrent.futures
|
|
8
8
|
from threading import Lock
|
|
9
9
|
from concurrent.futures import ThreadPoolExecutor
|
|
10
|
-
from jarvis.utils import OutputType, PrettyOutput, find_git_root, get_max_context_length, get_thread_count, load_embedding_model, load_rerank_model
|
|
10
|
+
from jarvis.utils import OutputType, PrettyOutput, find_git_root, get_file_md5, get_max_context_length, get_thread_count, load_embedding_model, load_rerank_model
|
|
11
11
|
from jarvis.utils import load_env_from_file
|
|
12
12
|
import argparse
|
|
13
13
|
from sentence_transformers import SentenceTransformer
|
|
14
14
|
import pickle
|
|
15
|
+
import lzma # 添加 lzma 导入
|
|
16
|
+
from tqdm import tqdm
|
|
15
17
|
|
|
16
18
|
class CodeBase:
|
|
17
19
|
def __init__(self, root_dir: str):
|
|
@@ -58,7 +60,7 @@ class CodeBase:
|
|
|
58
60
|
# 加载缓存
|
|
59
61
|
if os.path.exists(self.cache_path):
|
|
60
62
|
try:
|
|
61
|
-
with open(self.cache_path, 'rb') as f:
|
|
63
|
+
with lzma.open(self.cache_path, 'rb') as f:
|
|
62
64
|
cache_data = pickle.load(f)
|
|
63
65
|
self.vector_cache = cache_data["vectors"]
|
|
64
66
|
self.file_paths = cache_data["file_paths"]
|
|
@@ -88,19 +90,13 @@ class CodeBase:
|
|
|
88
90
|
return False
|
|
89
91
|
|
|
90
92
|
def make_description(self, file_path: str, content: str) -> str:
|
|
91
|
-
model = PlatformRegistry.get_global_platform_registry().
|
|
93
|
+
model = PlatformRegistry.get_global_platform_registry().get_cheap_platform()
|
|
92
94
|
model.set_suppress_output(True)
|
|
93
95
|
prompt = f"""请分析以下代码文件,并生成一个详细的描述。描述应该包含以下要点:
|
|
96
|
+
1. 整个文件的功能描述,不超过100个字
|
|
97
|
+
2. 每个全局变量的函数、类型定义、类、方法等代码元素的一句话描述,不超过50字
|
|
94
98
|
|
|
95
|
-
|
|
96
|
-
2. 关键类和方法的作用
|
|
97
|
-
3. 重要的依赖和技术特征(如使用了什么框架、算法、设计模式等)
|
|
98
|
-
4. 代码处理的主要数据类型和数据结构
|
|
99
|
-
5. 关键业务逻辑和处理流程
|
|
100
|
-
6. 特殊功能点和亮点特性
|
|
101
|
-
|
|
102
|
-
请用简洁专业的语言描述,突出代码的技术特征和功能特点,以便后续进行关联代码检索。
|
|
103
|
-
|
|
99
|
+
请用简洁专业的语言描述,突出代码的技术功能,以便后续进行关联代码检索。
|
|
104
100
|
文件路径:{file_path}
|
|
105
101
|
代码内容:
|
|
106
102
|
{content}
|
|
@@ -108,20 +104,24 @@ class CodeBase:
|
|
|
108
104
|
response = model.chat(prompt)
|
|
109
105
|
return response
|
|
110
106
|
|
|
111
|
-
def
|
|
107
|
+
def _save_cache(self):
|
|
112
108
|
"""保存缓存数据"""
|
|
113
109
|
try:
|
|
110
|
+
# 创建缓存数据的副本
|
|
114
111
|
cache_data = {
|
|
115
|
-
"vectors": self.vector_cache,
|
|
116
|
-
"file_paths": self.file_paths
|
|
112
|
+
"vectors": dict(self.vector_cache), # 创建字典的副本
|
|
113
|
+
"file_paths": list(self.file_paths) # 创建列表的副本
|
|
117
114
|
}
|
|
118
|
-
|
|
119
|
-
|
|
115
|
+
|
|
116
|
+
# 使用 lzma 压缩存储
|
|
117
|
+
with lzma.open(self.cache_path, 'wb') as f:
|
|
118
|
+
pickle.dump(cache_data, f, protocol=pickle.HIGHEST_PROTOCOL)
|
|
120
119
|
PrettyOutput.print(f"保存了 {len(self.vector_cache)} 个向量缓存",
|
|
121
120
|
output_type=OutputType.INFO)
|
|
122
121
|
except Exception as e:
|
|
123
122
|
PrettyOutput.print(f"保存缓存失败: {str(e)}",
|
|
124
123
|
output_type=OutputType.ERROR)
|
|
124
|
+
raise # 抛出异常以便上层处理
|
|
125
125
|
|
|
126
126
|
def get_cached_vector(self, file_path: str, description: str) -> Optional[np.ndarray]:
|
|
127
127
|
"""从缓存获取文件的向量表示"""
|
|
@@ -157,24 +157,13 @@ class CodeBase:
|
|
|
157
157
|
output_type=OutputType.ERROR)
|
|
158
158
|
file_md5 = ""
|
|
159
159
|
|
|
160
|
+
# 只更新内存中的缓存
|
|
160
161
|
self.vector_cache[file_path] = {
|
|
161
162
|
"path": file_path, # 保存文件路径
|
|
162
163
|
"md5": file_md5, # 保存文件MD5
|
|
163
164
|
"description": description, # 保存文件描述
|
|
164
165
|
"vector": vector # 保存向量
|
|
165
166
|
}
|
|
166
|
-
|
|
167
|
-
# 保存缓存到文件
|
|
168
|
-
try:
|
|
169
|
-
with open(self.cache_path, 'wb') as f:
|
|
170
|
-
cache_data = {
|
|
171
|
-
"vectors": self.vector_cache,
|
|
172
|
-
"file_paths": self.file_paths
|
|
173
|
-
}
|
|
174
|
-
pickle.dump(cache_data, f)
|
|
175
|
-
except Exception as e:
|
|
176
|
-
PrettyOutput.print(f"保存向量缓存失败: {str(e)}",
|
|
177
|
-
output_type=OutputType.ERROR)
|
|
178
167
|
|
|
179
168
|
def get_embedding(self, text: str) -> np.ndarray:
|
|
180
169
|
"""使用 transformers 模型获取文本的向量表示"""
|
|
@@ -215,22 +204,34 @@ class CodeBase:
|
|
|
215
204
|
except Exception as e:
|
|
216
205
|
PrettyOutput.print(f"Error vectorizing file {file_path}: {str(e)}",
|
|
217
206
|
output_type=OutputType.ERROR)
|
|
218
|
-
return np.zeros(self.vector_dim, dtype=np.float32)
|
|
207
|
+
return np.zeros(self.vector_dim, dtype=np.float32) # type: ignore
|
|
219
208
|
|
|
220
209
|
def clean_cache(self) -> bool:
|
|
221
210
|
"""清理过期的缓存记录,返回是否有文件被删除"""
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
211
|
+
try:
|
|
212
|
+
files_to_delete = []
|
|
213
|
+
for file_path in list(self.vector_cache.keys()):
|
|
214
|
+
if file_path not in self.git_file_list:
|
|
215
|
+
del self.vector_cache[file_path]
|
|
216
|
+
files_to_delete.append(file_path)
|
|
217
|
+
|
|
218
|
+
if files_to_delete:
|
|
219
|
+
# 只在有文件被删除时保存缓存
|
|
220
|
+
self._save_cache()
|
|
221
|
+
PrettyOutput.print(f"清理了 {len(files_to_delete)} 个文件的缓存",
|
|
222
|
+
output_type=OutputType.INFO)
|
|
223
|
+
return True
|
|
224
|
+
return False
|
|
225
|
+
|
|
226
|
+
except Exception as e:
|
|
227
|
+
PrettyOutput.print(f"清理缓存失败: {str(e)}",
|
|
228
|
+
output_type=OutputType.ERROR)
|
|
229
|
+
# 发生异常时尝试保存当前状态
|
|
230
|
+
try:
|
|
231
|
+
self._save_cache()
|
|
232
|
+
except:
|
|
233
|
+
pass
|
|
234
|
+
return False
|
|
234
235
|
|
|
235
236
|
def process_file(self, file_path: str):
|
|
236
237
|
"""处理单个文件"""
|
|
@@ -241,16 +242,10 @@ class CodeBase:
|
|
|
241
242
|
|
|
242
243
|
if not self.is_text_file(file_path):
|
|
243
244
|
return None
|
|
244
|
-
|
|
245
|
-
# 读取文件内容,限制长度
|
|
246
|
-
with open(file_path, "r", encoding="utf-8") as f:
|
|
247
|
-
content = f.read()
|
|
248
|
-
if len(content) > self.max_context_length:
|
|
249
|
-
PrettyOutput.print(f"文件 {file_path} 内容超出长度限制,将截取前 {self.max_context_length} 个字符",
|
|
250
|
-
output_type=OutputType.WARNING)
|
|
251
|
-
content = content[:self.max_context_length]
|
|
252
245
|
|
|
253
|
-
md5 =
|
|
246
|
+
md5 = get_file_md5(file_path)
|
|
247
|
+
|
|
248
|
+
content = open(file_path, "r", encoding="utf-8").read()
|
|
254
249
|
|
|
255
250
|
# 检查文件是否已经处理过且内容未变
|
|
256
251
|
if file_path in self.vector_cache:
|
|
@@ -295,14 +290,14 @@ class CodeBase:
|
|
|
295
290
|
|
|
296
291
|
if vectors:
|
|
297
292
|
vectors = np.vstack(vectors)
|
|
298
|
-
self.index.add_with_ids(vectors, np.array(ids))
|
|
293
|
+
self.index.add_with_ids(vectors, np.array(ids)) # type: ignore
|
|
299
294
|
else:
|
|
300
295
|
self.index = None
|
|
301
296
|
|
|
302
297
|
def gen_vector_db_from_cache(self):
|
|
303
298
|
"""从缓存生成向量数据库"""
|
|
304
299
|
self.build_index()
|
|
305
|
-
self.
|
|
300
|
+
self._save_cache()
|
|
306
301
|
|
|
307
302
|
|
|
308
303
|
def generate_codebase(self, force: bool = False):
|
|
@@ -310,100 +305,152 @@ class CodeBase:
|
|
|
310
305
|
Args:
|
|
311
306
|
force: 是否强制重建索引,不询问用户
|
|
312
307
|
"""
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
# 检查文件变化
|
|
317
|
-
changes_detected = False
|
|
318
|
-
new_files = []
|
|
319
|
-
modified_files = []
|
|
320
|
-
deleted_files = []
|
|
321
|
-
|
|
322
|
-
# 检查删除的文件
|
|
323
|
-
files_to_delete = []
|
|
324
|
-
for file_path in list(self.vector_cache.keys()):
|
|
325
|
-
if file_path not in self.git_file_list:
|
|
326
|
-
deleted_files.append(file_path)
|
|
327
|
-
files_to_delete.append(file_path)
|
|
328
|
-
changes_detected = True
|
|
329
|
-
|
|
330
|
-
# 检查新增和修改的文件
|
|
331
|
-
for file_path in self.git_file_list:
|
|
332
|
-
if not os.path.exists(file_path) or not self.is_text_file(file_path):
|
|
333
|
-
continue
|
|
308
|
+
try:
|
|
309
|
+
# 更新 git 文件列表
|
|
310
|
+
self.git_file_list = self.get_git_file_list()
|
|
334
311
|
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
312
|
+
# 检查文件变化
|
|
313
|
+
PrettyOutput.print("\n检查文件变化...", output_type=OutputType.INFO)
|
|
314
|
+
changes_detected = False
|
|
315
|
+
new_files = []
|
|
316
|
+
modified_files = []
|
|
317
|
+
deleted_files = []
|
|
318
|
+
|
|
319
|
+
# 检查删除的文件
|
|
320
|
+
files_to_delete = []
|
|
321
|
+
for file_path in list(self.vector_cache.keys()):
|
|
322
|
+
if file_path not in self.git_file_list:
|
|
323
|
+
deleted_files.append(file_path)
|
|
324
|
+
files_to_delete.append(file_path)
|
|
343
325
|
changes_detected = True
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
326
|
+
|
|
327
|
+
# 检查新增和修改的文件
|
|
328
|
+
with tqdm(total=len(self.git_file_list), desc="检查文件状态") as pbar:
|
|
329
|
+
for file_path in self.git_file_list:
|
|
330
|
+
if not os.path.exists(file_path) or not self.is_text_file(file_path):
|
|
331
|
+
pbar.update(1)
|
|
332
|
+
continue
|
|
333
|
+
|
|
334
|
+
try:
|
|
335
|
+
current_md5 = get_file_md5(file_path)
|
|
336
|
+
|
|
337
|
+
if file_path not in self.vector_cache:
|
|
338
|
+
new_files.append(file_path)
|
|
339
|
+
changes_detected = True
|
|
340
|
+
elif self.vector_cache[file_path].get("md5") != current_md5:
|
|
341
|
+
modified_files.append(file_path)
|
|
342
|
+
changes_detected = True
|
|
343
|
+
except Exception as e:
|
|
344
|
+
PrettyOutput.print(f"检查文件失败 {file_path}: {str(e)}",
|
|
345
|
+
output_type=OutputType.ERROR)
|
|
346
|
+
pbar.update(1)
|
|
347
|
+
|
|
348
|
+
# 如果检测到变化,显示变化并询问用户
|
|
349
|
+
if changes_detected:
|
|
350
|
+
PrettyOutput.print("\n检测到以下变化:", output_type=OutputType.WARNING)
|
|
351
|
+
if new_files:
|
|
352
|
+
PrettyOutput.print("\n新增文件:", output_type=OutputType.INFO)
|
|
353
|
+
for f in new_files:
|
|
354
|
+
PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
|
|
355
|
+
if modified_files:
|
|
356
|
+
PrettyOutput.print("\n修改的文件:", output_type=OutputType.INFO)
|
|
357
|
+
for f in modified_files:
|
|
358
|
+
PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
|
|
359
|
+
if deleted_files:
|
|
360
|
+
PrettyOutput.print("\n删除的文件:", output_type=OutputType.INFO)
|
|
361
|
+
for f in deleted_files:
|
|
362
|
+
PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
|
|
363
|
+
|
|
364
|
+
# 如果force为True,直接继续
|
|
365
|
+
if not force:
|
|
366
|
+
# 询问用户是否继续
|
|
367
|
+
while True:
|
|
368
|
+
response = input("\n是否重建索引?[y/N] ").lower().strip()
|
|
369
|
+
if response in ['y', 'yes']:
|
|
370
|
+
break
|
|
371
|
+
elif response in ['', 'n', 'no']:
|
|
372
|
+
PrettyOutput.print("取消重建索引", output_type=OutputType.INFO)
|
|
373
|
+
return
|
|
374
|
+
else:
|
|
375
|
+
PrettyOutput.print("请输入 y 或 n", output_type=OutputType.WARNING)
|
|
376
|
+
|
|
377
|
+
# 清理已删除的文件
|
|
378
|
+
for file_path in files_to_delete:
|
|
379
|
+
del self.vector_cache[file_path]
|
|
380
|
+
if files_to_delete:
|
|
381
|
+
PrettyOutput.print(f"清理了 {len(files_to_delete)} 个文件的缓存",
|
|
382
|
+
output_type=OutputType.INFO)
|
|
383
|
+
|
|
384
|
+
# 处理新文件和修改的文件
|
|
385
|
+
files_to_process = new_files + modified_files
|
|
386
|
+
processed_files = []
|
|
387
|
+
|
|
388
|
+
with tqdm(total=len(files_to_process), desc="处理文件") as pbar:
|
|
389
|
+
# 使用线程池处理文件
|
|
390
|
+
with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
|
|
391
|
+
# 提交所有任务
|
|
392
|
+
future_to_file = {
|
|
393
|
+
executor.submit(self.process_file, file): file
|
|
394
|
+
for file in files_to_process
|
|
395
|
+
}
|
|
396
|
+
|
|
397
|
+
# 处理完成的任务
|
|
398
|
+
for future in concurrent.futures.as_completed(future_to_file):
|
|
399
|
+
file = future_to_file[future]
|
|
400
|
+
try:
|
|
401
|
+
result = future.result()
|
|
402
|
+
if result:
|
|
403
|
+
processed_files.append(result)
|
|
404
|
+
except Exception as e:
|
|
405
|
+
PrettyOutput.print(f"处理文件失败 {file}: {str(e)}",
|
|
406
|
+
output_type=OutputType.ERROR)
|
|
407
|
+
pbar.update(1)
|
|
408
|
+
|
|
409
|
+
if processed_files:
|
|
410
|
+
PrettyOutput.print("\n重新生成向量数据库...", output_type=OutputType.INFO)
|
|
411
|
+
self.gen_vector_db_from_cache()
|
|
412
|
+
PrettyOutput.print(f"成功为 {len(processed_files)} 个文件生成索引",
|
|
413
|
+
output_type=OutputType.SUCCESS)
|
|
414
|
+
else:
|
|
415
|
+
PrettyOutput.print("没有检测到文件变更,无需重建索引", output_type=OutputType.INFO)
|
|
416
|
+
|
|
417
|
+
except Exception as e:
|
|
418
|
+
# 发生异常时尝试保存缓存
|
|
419
|
+
try:
|
|
420
|
+
self._save_cache()
|
|
421
|
+
except Exception as save_error:
|
|
422
|
+
PrettyOutput.print(f"保存缓存失败: {str(save_error)}",
|
|
423
|
+
output_type=OutputType.ERROR)
|
|
424
|
+
raise e # 重新抛出原始异常
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
def _text_search_score(self, content: str, keywords: List[str]) -> float:
|
|
428
|
+
"""计算文本内容与关键词的匹配分数
|
|
348
429
|
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
if new_files:
|
|
353
|
-
PrettyOutput.print("\n新增文件:", output_type=OutputType.INFO)
|
|
354
|
-
for f in new_files:
|
|
355
|
-
PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
|
|
356
|
-
if modified_files:
|
|
357
|
-
PrettyOutput.print("\n修改的文件:", output_type=OutputType.INFO)
|
|
358
|
-
for f in modified_files:
|
|
359
|
-
PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
|
|
360
|
-
if deleted_files:
|
|
361
|
-
PrettyOutput.print("\n删除的文件:", output_type=OutputType.INFO)
|
|
362
|
-
for f in deleted_files:
|
|
363
|
-
PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
# 如果force为True,直接继续
|
|
367
|
-
if not force:
|
|
368
|
-
# 询问用户是否继续
|
|
369
|
-
while True:
|
|
370
|
-
response = input("\n是否重建索引?[y/N] ").lower().strip()
|
|
371
|
-
if response in ['y', 'yes']:
|
|
372
|
-
break
|
|
373
|
-
elif response in ['', 'n', 'no']:
|
|
374
|
-
PrettyOutput.print("取消重建索引", output_type=OutputType.INFO)
|
|
375
|
-
return
|
|
376
|
-
else:
|
|
377
|
-
PrettyOutput.print("请输入 y 或 n", output_type=OutputType.WARNING)
|
|
378
|
-
|
|
379
|
-
# 清理已删除的文件
|
|
380
|
-
for file_path in files_to_delete:
|
|
381
|
-
del self.vector_cache[file_path]
|
|
382
|
-
if files_to_delete:
|
|
383
|
-
PrettyOutput.print(f"清理了 {len(files_to_delete)} 个文件的缓存",
|
|
384
|
-
output_type=OutputType.INFO)
|
|
430
|
+
Args:
|
|
431
|
+
content: 文本内容
|
|
432
|
+
keywords: 关键词列表
|
|
385
433
|
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
PrettyOutput.print("没有检测到文件变更,无需重建索引", output_type=OutputType.INFO)
|
|
434
|
+
Returns:
|
|
435
|
+
float: 匹配分数 (0-1)
|
|
436
|
+
"""
|
|
437
|
+
if not keywords:
|
|
438
|
+
return 0.0
|
|
439
|
+
|
|
440
|
+
content = content.lower()
|
|
441
|
+
matched_keywords = set()
|
|
442
|
+
|
|
443
|
+
for keyword in keywords:
|
|
444
|
+
keyword = keyword.lower()
|
|
445
|
+
if keyword in content:
|
|
446
|
+
matched_keywords.add(keyword)
|
|
447
|
+
|
|
448
|
+
# 计算匹配分数
|
|
449
|
+
score = len(matched_keywords) / len(keywords)
|
|
450
|
+
return score
|
|
404
451
|
|
|
405
452
|
def rerank_results(self, query: str, initial_results: List[Tuple[str, float, str]]) -> List[Tuple[str, float, str]]:
|
|
406
|
-
"""
|
|
453
|
+
"""使用多种策略对搜索结果重新排序"""
|
|
407
454
|
if not initial_results:
|
|
408
455
|
return []
|
|
409
456
|
|
|
@@ -413,13 +460,15 @@ class CodeBase:
|
|
|
413
460
|
# 加载模型和分词器
|
|
414
461
|
model, tokenizer = load_rerank_model()
|
|
415
462
|
|
|
416
|
-
# 准备数据
|
|
463
|
+
# 准备数据
|
|
417
464
|
pairs = []
|
|
465
|
+
|
|
418
466
|
for path, _, desc in initial_results:
|
|
419
467
|
try:
|
|
420
468
|
with open(path, "r", encoding="utf-8") as f:
|
|
421
469
|
content = f.read()[:512] # 限制内容长度
|
|
422
|
-
|
|
470
|
+
|
|
471
|
+
# 组合文件信息
|
|
423
472
|
doc_content = f"文件: {path}\n描述: {desc}\n内容: {content}"
|
|
424
473
|
pairs.append([query, doc_content])
|
|
425
474
|
except Exception as e:
|
|
@@ -430,6 +479,7 @@ class CodeBase:
|
|
|
430
479
|
|
|
431
480
|
# 使用更大的batch size提高处理速度
|
|
432
481
|
batch_size = 16 # 根据GPU显存调整
|
|
482
|
+
batch_scores = []
|
|
433
483
|
|
|
434
484
|
with torch.no_grad():
|
|
435
485
|
for i in range(0, len(pairs), batch_size):
|
|
@@ -446,8 +496,7 @@ class CodeBase:
|
|
|
446
496
|
encoded = {k: v.cuda() for k, v in encoded.items()}
|
|
447
497
|
|
|
448
498
|
outputs = model(**encoded)
|
|
449
|
-
|
|
450
|
-
batch_scores = outputs.logits.squeeze(-1).cpu().numpy()
|
|
499
|
+
batch_scores.extend(outputs.logits.squeeze(-1).cpu().numpy())
|
|
451
500
|
|
|
452
501
|
# 归一化分数到 0-1 范围
|
|
453
502
|
if batch_scores:
|
|
@@ -456,61 +505,98 @@ class CodeBase:
|
|
|
456
505
|
if max_score > min_score:
|
|
457
506
|
batch_scores = [(s - min_score) / (max_score - min_score) for s in batch_scores]
|
|
458
507
|
|
|
459
|
-
#
|
|
508
|
+
# 将重排序分数与原始分数结合
|
|
460
509
|
scored_results = []
|
|
461
|
-
for (path,
|
|
462
|
-
|
|
463
|
-
|
|
510
|
+
for (path, orig_score, desc), rerank_score in zip(initial_results, batch_scores):
|
|
511
|
+
# 综合分数 = 0.3 * 原始分数 + 0.7 * 重排序分数
|
|
512
|
+
combined_score = 0.3 * float(orig_score) + 0.7 * float(rerank_score)
|
|
513
|
+
if combined_score >= 0.5: # 只保留相关度较高的结果
|
|
514
|
+
scored_results.append((path, combined_score, desc))
|
|
464
515
|
|
|
465
|
-
#
|
|
516
|
+
# 按综合分数降序排序
|
|
466
517
|
scored_results.sort(key=lambda x: x[1], reverse=True)
|
|
467
518
|
|
|
468
519
|
return scored_results
|
|
469
520
|
|
|
470
521
|
except Exception as e:
|
|
471
|
-
PrettyOutput.print(f"
|
|
472
|
-
|
|
522
|
+
PrettyOutput.print(f"重排序失败: {str(e)}",
|
|
523
|
+
output_type=OutputType.ERROR)
|
|
524
|
+
return initial_results # 发生错误时返回原始结果
|
|
525
|
+
|
|
526
|
+
def _generate_query_variants(self, query: str) -> List[str]:
|
|
527
|
+
"""生成查询的不同表述变体
|
|
528
|
+
|
|
529
|
+
Args:
|
|
530
|
+
query: 原始查询
|
|
531
|
+
|
|
532
|
+
Returns:
|
|
533
|
+
List[str]: 查询变体列表
|
|
534
|
+
"""
|
|
535
|
+
model = PlatformRegistry.get_global_platform_registry().get_normal_platform()
|
|
536
|
+
prompt = f"""请根据以下查询,生成3个不同的表述,每个表述都要完整表达原始查询的意思。这些表述将用于代码搜索,要保持专业性和准确性。
|
|
537
|
+
原始查询: {query}
|
|
538
|
+
|
|
539
|
+
请直接输出3个表述,用换行分隔,不要有编号或其他标记。
|
|
540
|
+
"""
|
|
541
|
+
variants = model.chat(prompt).strip().split('\n')
|
|
542
|
+
variants.append(query) # 添加原始查询
|
|
543
|
+
return variants
|
|
544
|
+
|
|
545
|
+
def _vector_search(self, query_variants: List[str], top_k: int) -> Dict[str, Tuple[str, float, str]]:
|
|
546
|
+
"""使用向量搜索查找相关文件
|
|
547
|
+
|
|
548
|
+
Args:
|
|
549
|
+
query_variants: 查询变体列表
|
|
550
|
+
top_k: 返回结果数量
|
|
551
|
+
|
|
552
|
+
Returns:
|
|
553
|
+
Dict[str, Tuple[str, float, str]]: 文件路径到(路径,分数,描述)的映射
|
|
554
|
+
"""
|
|
555
|
+
results = {}
|
|
556
|
+
for query in query_variants:
|
|
557
|
+
query_vector = self.get_embedding(query)
|
|
558
|
+
query_vector = query_vector.reshape(1, -1)
|
|
559
|
+
|
|
560
|
+
distances, indices = self.index.search(query_vector, top_k) # type: ignore
|
|
561
|
+
|
|
562
|
+
for i, distance in zip(indices[0], distances[0]):
|
|
563
|
+
if i == -1:
|
|
564
|
+
continue
|
|
565
|
+
|
|
566
|
+
similarity = 1.0 / (1.0 + float(distance))
|
|
567
|
+
if similarity >= 0.5:
|
|
568
|
+
file_path = self.file_paths[i]
|
|
569
|
+
# 使用最高的相似度分数
|
|
570
|
+
if file_path not in results or similarity > results[file_path][1]:
|
|
571
|
+
data = self.vector_cache[file_path]
|
|
572
|
+
results[file_path] = (file_path, similarity, data["description"])
|
|
573
|
+
|
|
574
|
+
return results
|
|
575
|
+
|
|
473
576
|
|
|
474
577
|
def search_similar(self, query: str, top_k: int = 30) -> List[Tuple[str, float, str]]:
|
|
475
578
|
"""搜索关联文件"""
|
|
476
579
|
try:
|
|
477
580
|
if self.index is None:
|
|
478
|
-
return []
|
|
479
|
-
#
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
581
|
+
return []
|
|
582
|
+
# 生成查询变体
|
|
583
|
+
query_variants = self._generate_query_variants(query)
|
|
584
|
+
|
|
585
|
+
# 进行向量搜索
|
|
586
|
+
vector_results = self._vector_search(query_variants, top_k)
|
|
483
587
|
|
|
484
|
-
|
|
485
|
-
"""
|
|
486
|
-
query_variants = model.chat(prompt).strip().split('\n')
|
|
487
|
-
query_variants.append(query) # 添加原始查询
|
|
488
|
-
|
|
489
|
-
# 对每个查询变体进行搜索
|
|
490
|
-
all_results = {}
|
|
491
|
-
for q in query_variants:
|
|
492
|
-
q_vector = self.get_embedding(q)
|
|
493
|
-
q_vector = q_vector.reshape(1, -1)
|
|
494
|
-
|
|
495
|
-
distances, indices = self.index.search(q_vector, top_k)
|
|
496
|
-
|
|
497
|
-
for i, distance in zip(indices[0], distances[0]):
|
|
498
|
-
if i == -1:
|
|
499
|
-
continue
|
|
500
|
-
|
|
501
|
-
similarity = 1.0 / (1.0 + float(distance))
|
|
502
|
-
if similarity >= 0.5:
|
|
503
|
-
file_path = self.file_paths[i]
|
|
504
|
-
# 使用最高的相似度分数
|
|
505
|
-
if file_path not in all_results or similarity > all_results[file_path][1]:
|
|
506
|
-
data = self.vector_cache[file_path]
|
|
507
|
-
all_results[file_path] = (file_path, similarity, data["description"])
|
|
508
|
-
|
|
509
|
-
# 转换为列表并排序
|
|
510
|
-
results = list(all_results.values())
|
|
588
|
+
results = list(vector_results.values())
|
|
511
589
|
results.sort(key=lambda x: x[1], reverse=True)
|
|
590
|
+
|
|
591
|
+
# 取前 top_k 个结果进行重排序
|
|
592
|
+
initial_results = results[:top_k]
|
|
512
593
|
|
|
513
|
-
|
|
594
|
+
# 如果没有找到结果,直接返回
|
|
595
|
+
if not initial_results:
|
|
596
|
+
return []
|
|
597
|
+
|
|
598
|
+
# 对初步结果进行重排序
|
|
599
|
+
return self.rerank_results(query, initial_results)
|
|
514
600
|
|
|
515
601
|
except Exception as e:
|
|
516
602
|
PrettyOutput.print(f"搜索失败: {str(e)}", output_type=OutputType.ERROR)
|
|
@@ -564,7 +650,7 @@ class CodeBase:
|
|
|
564
650
|
|
|
565
651
|
# 检查缓存是否有效
|
|
566
652
|
try:
|
|
567
|
-
with open(self.cache_path, 'rb') as f:
|
|
653
|
+
with lzma.open(self.cache_path, 'rb') as f:
|
|
568
654
|
cache_data = pickle.load(f)
|
|
569
655
|
if not cache_data.get("vectors") or not cache_data.get("file_paths"):
|
|
570
656
|
return False
|
|
@@ -625,4 +711,4 @@ def main():
|
|
|
625
711
|
|
|
626
712
|
|
|
627
713
|
if __name__ == "__main__":
|
|
628
|
-
exit(main())
|
|
714
|
+
exit(main())
|
|
File without changes
|