jarvis-ai-assistant 0.1.90__py3-none-any.whl → 0.1.92__py3-none-any.whl
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Potentially problematic release.
This version of jarvis-ai-assistant might be problematic. Click here for more details.
- jarvis/__init__.py +1 -1
- jarvis/agent.py +2 -0
- jarvis/jarvis_codebase/main.py +268 -176
- jarvis/jarvis_platform/main.py +13 -2
- jarvis/jarvis_rag/main.py +185 -49
- jarvis/jarvis_smart_shell/main.py +16 -15
- jarvis/main.py +9 -0
- jarvis/models/ollama.py +3 -3
- jarvis/models/openai.py +6 -2
- jarvis/tools/thinker.py +149 -0
- jarvis/utils.py +5 -1
- {jarvis_ai_assistant-0.1.90.dist-info → jarvis_ai_assistant-0.1.92.dist-info}/METADATA +3 -1
- {jarvis_ai_assistant-0.1.90.dist-info → jarvis_ai_assistant-0.1.92.dist-info}/RECORD +17 -16
- {jarvis_ai_assistant-0.1.90.dist-info → jarvis_ai_assistant-0.1.92.dist-info}/LICENSE +0 -0
- {jarvis_ai_assistant-0.1.90.dist-info → jarvis_ai_assistant-0.1.92.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.1.90.dist-info → jarvis_ai_assistant-0.1.92.dist-info}/entry_points.txt +0 -0
- {jarvis_ai_assistant-0.1.90.dist-info → jarvis_ai_assistant-0.1.92.dist-info}/top_level.txt +0 -0
jarvis/__init__.py
CHANGED
jarvis/agent.py
CHANGED
|
@@ -259,6 +259,8 @@ class Agent:
|
|
|
259
259
|
analysis_prompt = """本次任务已结束,请分析是否需要生成方法论。
|
|
260
260
|
如果认为需要生成方法论,请先判断是创建新的方法论还是更新已有方法论。如果是更新已有方法论,使用update,否则使用add。
|
|
261
261
|
如果认为不需要生成方法论,请说明原因。
|
|
262
|
+
方法论应该适应普遍场景,不要出现本次任务特定的信息,如代码的commit信息等。
|
|
263
|
+
方法论中应该包含:问题重述、最优解决方案、注意事项(按需),除此外不要出现任何其他的信息。
|
|
262
264
|
仅输出方法论工具的调用指令,或者是不需要生成方法论的说明,除此之外不要输出任何内容。
|
|
263
265
|
"""
|
|
264
266
|
self.prompt = analysis_prompt
|
jarvis/jarvis_codebase/main.py
CHANGED
|
@@ -2,16 +2,18 @@ import hashlib
|
|
|
2
2
|
import os
|
|
3
3
|
import numpy as np
|
|
4
4
|
import faiss
|
|
5
|
-
from typing import List, Tuple, Optional
|
|
5
|
+
from typing import List, Tuple, Optional, Dict
|
|
6
6
|
from jarvis.models.registry import PlatformRegistry
|
|
7
7
|
import concurrent.futures
|
|
8
8
|
from threading import Lock
|
|
9
9
|
from concurrent.futures import ThreadPoolExecutor
|
|
10
|
-
from jarvis.utils import OutputType, PrettyOutput, find_git_root, get_max_context_length, get_thread_count, load_embedding_model, load_rerank_model
|
|
10
|
+
from jarvis.utils import OutputType, PrettyOutput, find_git_root, get_file_md5, get_max_context_length, get_thread_count, load_embedding_model, load_rerank_model
|
|
11
11
|
from jarvis.utils import load_env_from_file
|
|
12
12
|
import argparse
|
|
13
13
|
from sentence_transformers import SentenceTransformer
|
|
14
14
|
import pickle
|
|
15
|
+
import lzma # 添加 lzma 导入
|
|
16
|
+
from tqdm import tqdm
|
|
15
17
|
|
|
16
18
|
class CodeBase:
|
|
17
19
|
def __init__(self, root_dir: str):
|
|
@@ -58,7 +60,7 @@ class CodeBase:
|
|
|
58
60
|
# 加载缓存
|
|
59
61
|
if os.path.exists(self.cache_path):
|
|
60
62
|
try:
|
|
61
|
-
with open(self.cache_path, 'rb') as f:
|
|
63
|
+
with lzma.open(self.cache_path, 'rb') as f:
|
|
62
64
|
cache_data = pickle.load(f)
|
|
63
65
|
self.vector_cache = cache_data["vectors"]
|
|
64
66
|
self.file_paths = cache_data["file_paths"]
|
|
@@ -88,7 +90,7 @@ class CodeBase:
|
|
|
88
90
|
return False
|
|
89
91
|
|
|
90
92
|
def make_description(self, file_path: str, content: str) -> str:
|
|
91
|
-
model = PlatformRegistry.get_global_platform_registry().
|
|
93
|
+
model = PlatformRegistry.get_global_platform_registry().get_cheap_platform()
|
|
92
94
|
model.set_suppress_output(True)
|
|
93
95
|
prompt = f"""请分析以下代码文件,并生成一个详细的描述。描述应该包含以下要点:
|
|
94
96
|
|
|
@@ -108,20 +110,24 @@ class CodeBase:
|
|
|
108
110
|
response = model.chat(prompt)
|
|
109
111
|
return response
|
|
110
112
|
|
|
111
|
-
def
|
|
113
|
+
def _save_cache(self):
|
|
112
114
|
"""保存缓存数据"""
|
|
113
115
|
try:
|
|
116
|
+
# 创建缓存数据的副本
|
|
114
117
|
cache_data = {
|
|
115
|
-
"vectors": self.vector_cache,
|
|
116
|
-
"file_paths": self.file_paths
|
|
118
|
+
"vectors": dict(self.vector_cache), # 创建字典的副本
|
|
119
|
+
"file_paths": list(self.file_paths) # 创建列表的副本
|
|
117
120
|
}
|
|
118
|
-
|
|
119
|
-
|
|
121
|
+
|
|
122
|
+
# 使用 lzma 压缩存储
|
|
123
|
+
with lzma.open(self.cache_path, 'wb') as f:
|
|
124
|
+
pickle.dump(cache_data, f, protocol=pickle.HIGHEST_PROTOCOL)
|
|
120
125
|
PrettyOutput.print(f"保存了 {len(self.vector_cache)} 个向量缓存",
|
|
121
126
|
output_type=OutputType.INFO)
|
|
122
127
|
except Exception as e:
|
|
123
128
|
PrettyOutput.print(f"保存缓存失败: {str(e)}",
|
|
124
129
|
output_type=OutputType.ERROR)
|
|
130
|
+
raise # 抛出异常以便上层处理
|
|
125
131
|
|
|
126
132
|
def get_cached_vector(self, file_path: str, description: str) -> Optional[np.ndarray]:
|
|
127
133
|
"""从缓存获取文件的向量表示"""
|
|
@@ -157,24 +163,13 @@ class CodeBase:
|
|
|
157
163
|
output_type=OutputType.ERROR)
|
|
158
164
|
file_md5 = ""
|
|
159
165
|
|
|
166
|
+
# 只更新内存中的缓存
|
|
160
167
|
self.vector_cache[file_path] = {
|
|
161
168
|
"path": file_path, # 保存文件路径
|
|
162
169
|
"md5": file_md5, # 保存文件MD5
|
|
163
170
|
"description": description, # 保存文件描述
|
|
164
171
|
"vector": vector # 保存向量
|
|
165
172
|
}
|
|
166
|
-
|
|
167
|
-
# 保存缓存到文件
|
|
168
|
-
try:
|
|
169
|
-
with open(self.cache_path, 'wb') as f:
|
|
170
|
-
cache_data = {
|
|
171
|
-
"vectors": self.vector_cache,
|
|
172
|
-
"file_paths": self.file_paths
|
|
173
|
-
}
|
|
174
|
-
pickle.dump(cache_data, f)
|
|
175
|
-
except Exception as e:
|
|
176
|
-
PrettyOutput.print(f"保存向量缓存失败: {str(e)}",
|
|
177
|
-
output_type=OutputType.ERROR)
|
|
178
173
|
|
|
179
174
|
def get_embedding(self, text: str) -> np.ndarray:
|
|
180
175
|
"""使用 transformers 模型获取文本的向量表示"""
|
|
@@ -219,18 +214,30 @@ class CodeBase:
|
|
|
219
214
|
|
|
220
215
|
def clean_cache(self) -> bool:
|
|
221
216
|
"""清理过期的缓存记录,返回是否有文件被删除"""
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
217
|
+
try:
|
|
218
|
+
files_to_delete = []
|
|
219
|
+
for file_path in list(self.vector_cache.keys()):
|
|
220
|
+
if file_path not in self.git_file_list:
|
|
221
|
+
del self.vector_cache[file_path]
|
|
222
|
+
files_to_delete.append(file_path)
|
|
223
|
+
|
|
224
|
+
if files_to_delete:
|
|
225
|
+
# 只在有文件被删除时保存缓存
|
|
226
|
+
self._save_cache()
|
|
227
|
+
PrettyOutput.print(f"清理了 {len(files_to_delete)} 个文件的缓存",
|
|
228
|
+
output_type=OutputType.INFO)
|
|
229
|
+
return True
|
|
230
|
+
return False
|
|
231
|
+
|
|
232
|
+
except Exception as e:
|
|
233
|
+
PrettyOutput.print(f"清理缓存失败: {str(e)}",
|
|
234
|
+
output_type=OutputType.ERROR)
|
|
235
|
+
# 发生异常时尝试保存当前状态
|
|
236
|
+
try:
|
|
237
|
+
self._save_cache()
|
|
238
|
+
except:
|
|
239
|
+
pass
|
|
240
|
+
return False
|
|
234
241
|
|
|
235
242
|
def process_file(self, file_path: str):
|
|
236
243
|
"""处理单个文件"""
|
|
@@ -241,16 +248,10 @@ class CodeBase:
|
|
|
241
248
|
|
|
242
249
|
if not self.is_text_file(file_path):
|
|
243
250
|
return None
|
|
244
|
-
|
|
245
|
-
# 读取文件内容,限制长度
|
|
246
|
-
with open(file_path, "r", encoding="utf-8") as f:
|
|
247
|
-
content = f.read()
|
|
248
|
-
if len(content) > self.max_context_length:
|
|
249
|
-
PrettyOutput.print(f"文件 {file_path} 内容超出长度限制,将截取前 {self.max_context_length} 个字符",
|
|
250
|
-
output_type=OutputType.WARNING)
|
|
251
|
-
content = content[:self.max_context_length]
|
|
252
251
|
|
|
253
|
-
md5 =
|
|
252
|
+
md5 = get_file_md5(file_path)
|
|
253
|
+
|
|
254
|
+
content = open(file_path, "r", encoding="utf-8").read()
|
|
254
255
|
|
|
255
256
|
# 检查文件是否已经处理过且内容未变
|
|
256
257
|
if file_path in self.vector_cache:
|
|
@@ -302,7 +303,7 @@ class CodeBase:
|
|
|
302
303
|
def gen_vector_db_from_cache(self):
|
|
303
304
|
"""从缓存生成向量数据库"""
|
|
304
305
|
self.build_index()
|
|
305
|
-
self.
|
|
306
|
+
self._save_cache()
|
|
306
307
|
|
|
307
308
|
|
|
308
309
|
def generate_codebase(self, force: bool = False):
|
|
@@ -310,100 +311,152 @@ class CodeBase:
|
|
|
310
311
|
Args:
|
|
311
312
|
force: 是否强制重建索引,不询问用户
|
|
312
313
|
"""
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
# 检查文件变化
|
|
317
|
-
changes_detected = False
|
|
318
|
-
new_files = []
|
|
319
|
-
modified_files = []
|
|
320
|
-
deleted_files = []
|
|
321
|
-
|
|
322
|
-
# 检查删除的文件
|
|
323
|
-
files_to_delete = []
|
|
324
|
-
for file_path in list(self.vector_cache.keys()):
|
|
325
|
-
if file_path not in self.git_file_list:
|
|
326
|
-
deleted_files.append(file_path)
|
|
327
|
-
files_to_delete.append(file_path)
|
|
328
|
-
changes_detected = True
|
|
329
|
-
|
|
330
|
-
# 检查新增和修改的文件
|
|
331
|
-
for file_path in self.git_file_list:
|
|
332
|
-
if not os.path.exists(file_path) or not self.is_text_file(file_path):
|
|
333
|
-
continue
|
|
314
|
+
try:
|
|
315
|
+
# 更新 git 文件列表
|
|
316
|
+
self.git_file_list = self.get_git_file_list()
|
|
334
317
|
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
318
|
+
# 检查文件变化
|
|
319
|
+
PrettyOutput.print("\n检查文件变化...", output_type=OutputType.INFO)
|
|
320
|
+
changes_detected = False
|
|
321
|
+
new_files = []
|
|
322
|
+
modified_files = []
|
|
323
|
+
deleted_files = []
|
|
324
|
+
|
|
325
|
+
# 检查删除的文件
|
|
326
|
+
files_to_delete = []
|
|
327
|
+
for file_path in list(self.vector_cache.keys()):
|
|
328
|
+
if file_path not in self.git_file_list:
|
|
329
|
+
deleted_files.append(file_path)
|
|
330
|
+
files_to_delete.append(file_path)
|
|
343
331
|
changes_detected = True
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
332
|
+
|
|
333
|
+
# 检查新增和修改的文件
|
|
334
|
+
with tqdm(total=len(self.git_file_list), desc="检查文件状态") as pbar:
|
|
335
|
+
for file_path in self.git_file_list:
|
|
336
|
+
if not os.path.exists(file_path) or not self.is_text_file(file_path):
|
|
337
|
+
pbar.update(1)
|
|
338
|
+
continue
|
|
339
|
+
|
|
340
|
+
try:
|
|
341
|
+
current_md5 = get_file_md5(file_path)
|
|
342
|
+
|
|
343
|
+
if file_path not in self.vector_cache:
|
|
344
|
+
new_files.append(file_path)
|
|
345
|
+
changes_detected = True
|
|
346
|
+
elif self.vector_cache[file_path].get("md5") != current_md5:
|
|
347
|
+
modified_files.append(file_path)
|
|
348
|
+
changes_detected = True
|
|
349
|
+
except Exception as e:
|
|
350
|
+
PrettyOutput.print(f"检查文件失败 {file_path}: {str(e)}",
|
|
351
|
+
output_type=OutputType.ERROR)
|
|
352
|
+
pbar.update(1)
|
|
353
|
+
|
|
354
|
+
# 如果检测到变化,显示变化并询问用户
|
|
355
|
+
if changes_detected:
|
|
356
|
+
PrettyOutput.print("\n检测到以下变化:", output_type=OutputType.WARNING)
|
|
357
|
+
if new_files:
|
|
358
|
+
PrettyOutput.print("\n新增文件:", output_type=OutputType.INFO)
|
|
359
|
+
for f in new_files:
|
|
360
|
+
PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
|
|
361
|
+
if modified_files:
|
|
362
|
+
PrettyOutput.print("\n修改的文件:", output_type=OutputType.INFO)
|
|
363
|
+
for f in modified_files:
|
|
364
|
+
PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
|
|
365
|
+
if deleted_files:
|
|
366
|
+
PrettyOutput.print("\n删除的文件:", output_type=OutputType.INFO)
|
|
367
|
+
for f in deleted_files:
|
|
368
|
+
PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
|
|
369
|
+
|
|
370
|
+
# 如果force为True,直接继续
|
|
371
|
+
if not force:
|
|
372
|
+
# 询问用户是否继续
|
|
373
|
+
while True:
|
|
374
|
+
response = input("\n是否重建索引?[y/N] ").lower().strip()
|
|
375
|
+
if response in ['y', 'yes']:
|
|
376
|
+
break
|
|
377
|
+
elif response in ['', 'n', 'no']:
|
|
378
|
+
PrettyOutput.print("取消重建索引", output_type=OutputType.INFO)
|
|
379
|
+
return
|
|
380
|
+
else:
|
|
381
|
+
PrettyOutput.print("请输入 y 或 n", output_type=OutputType.WARNING)
|
|
382
|
+
|
|
383
|
+
# 清理已删除的文件
|
|
384
|
+
for file_path in files_to_delete:
|
|
385
|
+
del self.vector_cache[file_path]
|
|
386
|
+
if files_to_delete:
|
|
387
|
+
PrettyOutput.print(f"清理了 {len(files_to_delete)} 个文件的缓存",
|
|
388
|
+
output_type=OutputType.INFO)
|
|
389
|
+
|
|
390
|
+
# 处理新文件和修改的文件
|
|
391
|
+
files_to_process = new_files + modified_files
|
|
392
|
+
processed_files = []
|
|
393
|
+
|
|
394
|
+
with tqdm(total=len(files_to_process), desc="处理文件") as pbar:
|
|
395
|
+
# 使用线程池处理文件
|
|
396
|
+
with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
|
|
397
|
+
# 提交所有任务
|
|
398
|
+
future_to_file = {
|
|
399
|
+
executor.submit(self.process_file, file): file
|
|
400
|
+
for file in files_to_process
|
|
401
|
+
}
|
|
402
|
+
|
|
403
|
+
# 处理完成的任务
|
|
404
|
+
for future in concurrent.futures.as_completed(future_to_file):
|
|
405
|
+
file = future_to_file[future]
|
|
406
|
+
try:
|
|
407
|
+
result = future.result()
|
|
408
|
+
if result:
|
|
409
|
+
processed_files.append(result)
|
|
410
|
+
except Exception as e:
|
|
411
|
+
PrettyOutput.print(f"处理文件失败 {file}: {str(e)}",
|
|
412
|
+
output_type=OutputType.ERROR)
|
|
413
|
+
pbar.update(1)
|
|
414
|
+
|
|
415
|
+
if processed_files:
|
|
416
|
+
PrettyOutput.print("\n重新生成向量数据库...", output_type=OutputType.INFO)
|
|
417
|
+
self.gen_vector_db_from_cache()
|
|
418
|
+
PrettyOutput.print(f"成功为 {len(processed_files)} 个文件生成索引",
|
|
419
|
+
output_type=OutputType.SUCCESS)
|
|
420
|
+
else:
|
|
421
|
+
PrettyOutput.print("没有检测到文件变更,无需重建索引", output_type=OutputType.INFO)
|
|
422
|
+
|
|
423
|
+
except Exception as e:
|
|
424
|
+
# 发生异常时尝试保存缓存
|
|
425
|
+
try:
|
|
426
|
+
self._save_cache()
|
|
427
|
+
except Exception as save_error:
|
|
428
|
+
PrettyOutput.print(f"保存缓存失败: {str(save_error)}",
|
|
429
|
+
output_type=OutputType.ERROR)
|
|
430
|
+
raise e # 重新抛出原始异常
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
def _text_search_score(self, content: str, keywords: List[str]) -> float:
|
|
434
|
+
"""计算文本内容与关键词的匹配分数
|
|
348
435
|
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
if new_files:
|
|
353
|
-
PrettyOutput.print("\n新增文件:", output_type=OutputType.INFO)
|
|
354
|
-
for f in new_files:
|
|
355
|
-
PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
|
|
356
|
-
if modified_files:
|
|
357
|
-
PrettyOutput.print("\n修改的文件:", output_type=OutputType.INFO)
|
|
358
|
-
for f in modified_files:
|
|
359
|
-
PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
|
|
360
|
-
if deleted_files:
|
|
361
|
-
PrettyOutput.print("\n删除的文件:", output_type=OutputType.INFO)
|
|
362
|
-
for f in deleted_files:
|
|
363
|
-
PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
# 如果force为True,直接继续
|
|
367
|
-
if not force:
|
|
368
|
-
# 询问用户是否继续
|
|
369
|
-
while True:
|
|
370
|
-
response = input("\n是否重建索引?[y/N] ").lower().strip()
|
|
371
|
-
if response in ['y', 'yes']:
|
|
372
|
-
break
|
|
373
|
-
elif response in ['', 'n', 'no']:
|
|
374
|
-
PrettyOutput.print("取消重建索引", output_type=OutputType.INFO)
|
|
375
|
-
return
|
|
376
|
-
else:
|
|
377
|
-
PrettyOutput.print("请输入 y 或 n", output_type=OutputType.WARNING)
|
|
378
|
-
|
|
379
|
-
# 清理已删除的文件
|
|
380
|
-
for file_path in files_to_delete:
|
|
381
|
-
del self.vector_cache[file_path]
|
|
382
|
-
if files_to_delete:
|
|
383
|
-
PrettyOutput.print(f"清理了 {len(files_to_delete)} 个文件的缓存",
|
|
384
|
-
output_type=OutputType.INFO)
|
|
436
|
+
Args:
|
|
437
|
+
content: 文本内容
|
|
438
|
+
keywords: 关键词列表
|
|
385
439
|
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
PrettyOutput.print("没有检测到文件变更,无需重建索引", output_type=OutputType.INFO)
|
|
440
|
+
Returns:
|
|
441
|
+
float: 匹配分数 (0-1)
|
|
442
|
+
"""
|
|
443
|
+
if not keywords:
|
|
444
|
+
return 0.0
|
|
445
|
+
|
|
446
|
+
content = content.lower()
|
|
447
|
+
matched_keywords = set()
|
|
448
|
+
|
|
449
|
+
for keyword in keywords:
|
|
450
|
+
keyword = keyword.lower()
|
|
451
|
+
if keyword in content:
|
|
452
|
+
matched_keywords.add(keyword)
|
|
453
|
+
|
|
454
|
+
# 计算匹配分数
|
|
455
|
+
score = len(matched_keywords) / len(keywords)
|
|
456
|
+
return score
|
|
404
457
|
|
|
405
458
|
def rerank_results(self, query: str, initial_results: List[Tuple[str, float, str]]) -> List[Tuple[str, float, str]]:
|
|
406
|
-
"""
|
|
459
|
+
"""使用多种策略对搜索结果重新排序"""
|
|
407
460
|
if not initial_results:
|
|
408
461
|
return []
|
|
409
462
|
|
|
@@ -413,13 +466,15 @@ class CodeBase:
|
|
|
413
466
|
# 加载模型和分词器
|
|
414
467
|
model, tokenizer = load_rerank_model()
|
|
415
468
|
|
|
416
|
-
# 准备数据
|
|
469
|
+
# 准备数据
|
|
417
470
|
pairs = []
|
|
471
|
+
|
|
418
472
|
for path, _, desc in initial_results:
|
|
419
473
|
try:
|
|
420
474
|
with open(path, "r", encoding="utf-8") as f:
|
|
421
475
|
content = f.read()[:512] # 限制内容长度
|
|
422
|
-
|
|
476
|
+
|
|
477
|
+
# 组合文件信息
|
|
423
478
|
doc_content = f"文件: {path}\n描述: {desc}\n内容: {content}"
|
|
424
479
|
pairs.append([query, doc_content])
|
|
425
480
|
except Exception as e:
|
|
@@ -430,6 +485,7 @@ class CodeBase:
|
|
|
430
485
|
|
|
431
486
|
# 使用更大的batch size提高处理速度
|
|
432
487
|
batch_size = 16 # 根据GPU显存调整
|
|
488
|
+
batch_scores = []
|
|
433
489
|
|
|
434
490
|
with torch.no_grad():
|
|
435
491
|
for i in range(0, len(pairs), batch_size):
|
|
@@ -446,8 +502,7 @@ class CodeBase:
|
|
|
446
502
|
encoded = {k: v.cuda() for k, v in encoded.items()}
|
|
447
503
|
|
|
448
504
|
outputs = model(**encoded)
|
|
449
|
-
|
|
450
|
-
batch_scores = outputs.logits.squeeze(-1).cpu().numpy()
|
|
505
|
+
batch_scores.extend(outputs.logits.squeeze(-1).cpu().numpy())
|
|
451
506
|
|
|
452
507
|
# 归一化分数到 0-1 范围
|
|
453
508
|
if batch_scores:
|
|
@@ -456,61 +511,98 @@ class CodeBase:
|
|
|
456
511
|
if max_score > min_score:
|
|
457
512
|
batch_scores = [(s - min_score) / (max_score - min_score) for s in batch_scores]
|
|
458
513
|
|
|
459
|
-
#
|
|
514
|
+
# 将重排序分数与原始分数结合
|
|
460
515
|
scored_results = []
|
|
461
|
-
for (path,
|
|
462
|
-
|
|
463
|
-
|
|
516
|
+
for (path, orig_score, desc), rerank_score in zip(initial_results, batch_scores):
|
|
517
|
+
# 综合分数 = 0.3 * 原始分数 + 0.7 * 重排序分数
|
|
518
|
+
combined_score = 0.3 * float(orig_score) + 0.7 * float(rerank_score)
|
|
519
|
+
if combined_score >= 0.5: # 只保留相关度较高的结果
|
|
520
|
+
scored_results.append((path, combined_score, desc))
|
|
464
521
|
|
|
465
|
-
#
|
|
522
|
+
# 按综合分数降序排序
|
|
466
523
|
scored_results.sort(key=lambda x: x[1], reverse=True)
|
|
467
524
|
|
|
468
525
|
return scored_results
|
|
469
526
|
|
|
470
527
|
except Exception as e:
|
|
471
|
-
PrettyOutput.print(f"
|
|
472
|
-
|
|
528
|
+
PrettyOutput.print(f"重排序失败: {str(e)}",
|
|
529
|
+
output_type=OutputType.ERROR)
|
|
530
|
+
return initial_results # 发生错误时返回原始结果
|
|
531
|
+
|
|
532
|
+
def _generate_query_variants(self, query: str) -> List[str]:
|
|
533
|
+
"""生成查询的不同表述变体
|
|
534
|
+
|
|
535
|
+
Args:
|
|
536
|
+
query: 原始查询
|
|
537
|
+
|
|
538
|
+
Returns:
|
|
539
|
+
List[str]: 查询变体列表
|
|
540
|
+
"""
|
|
541
|
+
model = PlatformRegistry.get_global_platform_registry().get_normal_platform()
|
|
542
|
+
prompt = f"""请根据以下查询,生成3个不同的表述,每个表述都要完整表达原始查询的意思。这些表述将用于代码搜索,要保持专业性和准确性。
|
|
543
|
+
原始查询: {query}
|
|
544
|
+
|
|
545
|
+
请直接输出3个表述,用换行分隔,不要有编号或其他标记。
|
|
546
|
+
"""
|
|
547
|
+
variants = model.chat(prompt).strip().split('\n')
|
|
548
|
+
variants.append(query) # 添加原始查询
|
|
549
|
+
return variants
|
|
550
|
+
|
|
551
|
+
def _vector_search(self, query_variants: List[str], top_k: int) -> Dict[str, Tuple[str, float, str]]:
|
|
552
|
+
"""使用向量搜索查找相关文件
|
|
553
|
+
|
|
554
|
+
Args:
|
|
555
|
+
query_variants: 查询变体列表
|
|
556
|
+
top_k: 返回结果数量
|
|
557
|
+
|
|
558
|
+
Returns:
|
|
559
|
+
Dict[str, Tuple[str, float, str]]: 文件路径到(路径,分数,描述)的映射
|
|
560
|
+
"""
|
|
561
|
+
results = {}
|
|
562
|
+
for query in query_variants:
|
|
563
|
+
query_vector = self.get_embedding(query)
|
|
564
|
+
query_vector = query_vector.reshape(1, -1)
|
|
565
|
+
|
|
566
|
+
distances, indices = self.index.search(query_vector, top_k)
|
|
567
|
+
|
|
568
|
+
for i, distance in zip(indices[0], distances[0]):
|
|
569
|
+
if i == -1:
|
|
570
|
+
continue
|
|
571
|
+
|
|
572
|
+
similarity = 1.0 / (1.0 + float(distance))
|
|
573
|
+
if similarity >= 0.5:
|
|
574
|
+
file_path = self.file_paths[i]
|
|
575
|
+
# 使用最高的相似度分数
|
|
576
|
+
if file_path not in results or similarity > results[file_path][1]:
|
|
577
|
+
data = self.vector_cache[file_path]
|
|
578
|
+
results[file_path] = (file_path, similarity, data["description"])
|
|
579
|
+
|
|
580
|
+
return results
|
|
581
|
+
|
|
473
582
|
|
|
474
583
|
def search_similar(self, query: str, top_k: int = 30) -> List[Tuple[str, float, str]]:
|
|
475
584
|
"""搜索关联文件"""
|
|
476
585
|
try:
|
|
477
586
|
if self.index is None:
|
|
478
|
-
return []
|
|
479
|
-
#
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
587
|
+
return []
|
|
588
|
+
# 生成查询变体
|
|
589
|
+
query_variants = self._generate_query_variants(query)
|
|
590
|
+
|
|
591
|
+
# 进行向量搜索
|
|
592
|
+
vector_results = self._vector_search(query_variants, top_k)
|
|
483
593
|
|
|
484
|
-
|
|
485
|
-
"""
|
|
486
|
-
query_variants = model.chat(prompt).strip().split('\n')
|
|
487
|
-
query_variants.append(query) # 添加原始查询
|
|
488
|
-
|
|
489
|
-
# 对每个查询变体进行搜索
|
|
490
|
-
all_results = {}
|
|
491
|
-
for q in query_variants:
|
|
492
|
-
q_vector = self.get_embedding(q)
|
|
493
|
-
q_vector = q_vector.reshape(1, -1)
|
|
494
|
-
|
|
495
|
-
distances, indices = self.index.search(q_vector, top_k)
|
|
496
|
-
|
|
497
|
-
for i, distance in zip(indices[0], distances[0]):
|
|
498
|
-
if i == -1:
|
|
499
|
-
continue
|
|
500
|
-
|
|
501
|
-
similarity = 1.0 / (1.0 + float(distance))
|
|
502
|
-
if similarity >= 0.5:
|
|
503
|
-
file_path = self.file_paths[i]
|
|
504
|
-
# 使用最高的相似度分数
|
|
505
|
-
if file_path not in all_results or similarity > all_results[file_path][1]:
|
|
506
|
-
data = self.vector_cache[file_path]
|
|
507
|
-
all_results[file_path] = (file_path, similarity, data["description"])
|
|
508
|
-
|
|
509
|
-
# 转换为列表并排序
|
|
510
|
-
results = list(all_results.values())
|
|
594
|
+
results = list(vector_results.values())
|
|
511
595
|
results.sort(key=lambda x: x[1], reverse=True)
|
|
596
|
+
|
|
597
|
+
# 取前 top_k 个结果进行重排序
|
|
598
|
+
initial_results = results[:top_k]
|
|
512
599
|
|
|
513
|
-
|
|
600
|
+
# 如果没有找到结果,直接返回
|
|
601
|
+
if not initial_results:
|
|
602
|
+
return []
|
|
603
|
+
|
|
604
|
+
# 对初步结果进行重排序
|
|
605
|
+
return self.rerank_results(query, initial_results)
|
|
514
606
|
|
|
515
607
|
except Exception as e:
|
|
516
608
|
PrettyOutput.print(f"搜索失败: {str(e)}", output_type=OutputType.ERROR)
|
|
@@ -564,7 +656,7 @@ class CodeBase:
|
|
|
564
656
|
|
|
565
657
|
# 检查缓存是否有效
|
|
566
658
|
try:
|
|
567
|
-
with open(self.cache_path, 'rb') as f:
|
|
659
|
+
with lzma.open(self.cache_path, 'rb') as f:
|
|
568
660
|
cache_data = pickle.load(f)
|
|
569
661
|
if not cache_data.get("vectors") or not cache_data.get("file_paths"):
|
|
570
662
|
return False
|
jarvis/jarvis_platform/main.py
CHANGED
|
@@ -32,7 +32,7 @@ def list_platforms():
|
|
|
32
32
|
PrettyOutput.print(" 没有可用的模型信息", OutputType.WARNING)
|
|
33
33
|
|
|
34
34
|
except Exception as e:
|
|
35
|
-
PrettyOutput.print(f"获取 {platform_name} 平台模型列表失败: {str(e)}", OutputType.
|
|
35
|
+
PrettyOutput.print(f"获取 {platform_name} 平台模型列表失败: {str(e)}", OutputType.WARNING)
|
|
36
36
|
|
|
37
37
|
def chat_with_model(platform_name: str, model_name: str):
|
|
38
38
|
"""与指定平台和模型进行对话"""
|
|
@@ -55,13 +55,24 @@ def chat_with_model(platform_name: str, model_name: str):
|
|
|
55
55
|
user_input = get_multiline_input("")
|
|
56
56
|
|
|
57
57
|
# 检查是否取消输入
|
|
58
|
-
if user_input == "__interrupt__":
|
|
58
|
+
if user_input == "__interrupt__" or user_input.strip() == "/bye":
|
|
59
|
+
PrettyOutput.print("再见!", OutputType.SUCCESS)
|
|
59
60
|
break
|
|
60
61
|
|
|
61
62
|
# 检查是否为空输入
|
|
62
63
|
if not user_input.strip():
|
|
63
64
|
continue
|
|
64
65
|
|
|
66
|
+
# 检查是否为清除会话命令
|
|
67
|
+
if user_input.strip() == "/clear":
|
|
68
|
+
try:
|
|
69
|
+
platform.delete_chat()
|
|
70
|
+
platform.set_model_name(model_name) # 重新初始化会话
|
|
71
|
+
PrettyOutput.print("会话已清除", OutputType.SUCCESS)
|
|
72
|
+
except Exception as e:
|
|
73
|
+
PrettyOutput.print(f"清除会话失败: {str(e)}", OutputType.ERROR)
|
|
74
|
+
continue
|
|
75
|
+
|
|
65
76
|
try:
|
|
66
77
|
# 发送到模型并获取回复
|
|
67
78
|
response = platform.chat(user_input)
|