jarvis-ai-assistant 0.1.91__py3-none-any.whl → 0.1.93__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of jarvis-ai-assistant might be problematic. See the package's page on the registry for more details.

jarvis/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
1
  """Jarvis AI Assistant"""
2
2
 
3
- __version__ = "0.1.91"
3
+ __version__ = "0.1.93"
jarvis/agent.py CHANGED
@@ -140,7 +140,7 @@ class Agent:
140
140
 
141
141
  def _load_methodology(self, user_input: str) -> Dict[str, str]:
142
142
  """加载方法论并构建向量索引"""
143
- PrettyOutput.print("加载方法论...", OutputType.PLANNING)
143
+ PrettyOutput.print("加载方法论...", OutputType.PROGRESS)
144
144
  user_jarvis_methodology = os.path.expanduser("~/.jarvis_methodology")
145
145
  if not os.path.exists(user_jarvis_methodology):
146
146
  return {}
@@ -165,13 +165,13 @@ class Agent:
165
165
 
166
166
  if vectors:
167
167
  vectors_array = np.vstack(vectors)
168
- self.methodology_index.add_with_ids(vectors_array, np.array(ids))
168
+ self.methodology_index.add_with_ids(vectors_array, np.array(ids)) # type: ignore
169
169
  query_embedding = self._create_methodology_embedding(user_input)
170
170
  k = min(5, len(self.methodology_data))
171
171
  PrettyOutput.print(f"检索方法论...", OutputType.INFO)
172
172
  distances, indices = self.methodology_index.search(
173
173
  query_embedding.reshape(1, -1), k
174
- )
174
+ ) # type: ignore
175
175
 
176
176
  relevant_methodologies = {}
177
177
  for dist, idx in zip(distances[0], indices[0]):
@@ -208,7 +208,7 @@ class Agent:
208
208
  """
209
209
  # 创建一个新的模型实例来做总结,避免影响主对话
210
210
 
211
- PrettyOutput.print("总结对话历史,准备生成总结,开始新的对话...", OutputType.PLANNING)
211
+ PrettyOutput.print("总结对话历史,准备生成总结,开始新的对话...", OutputType.PROGRESS)
212
212
 
213
213
  prompt = """请总结之前对话中的关键信息,包括:
214
214
  1. 当前任务目标
@@ -259,6 +259,8 @@ class Agent:
259
259
  analysis_prompt = """本次任务已结束,请分析是否需要生成方法论。
260
260
  如果认为需要生成方法论,请先判断是创建新的方法论还是更新已有方法论。如果是更新已有方法论,使用update,否则使用add。
261
261
  如果认为不需要生成方法论,请说明原因。
262
+ 方法论应该适应普遍场景,不要出现本次任务特定的信息,如代码的commit信息等。
263
+ 方法论中应该包含:问题重述、最优解决方案、注意事项(按需),除此外不要出现任何其他的信息。
262
264
  仅输出方法论工具的调用指令,或者是不需要生成方法论的说明,除此之外不要输出任何内容。
263
265
  """
264
266
  self.prompt = analysis_prompt
@@ -2,16 +2,18 @@ import hashlib
2
2
  import os
3
3
  import numpy as np
4
4
  import faiss
5
- from typing import List, Tuple, Optional
5
+ from typing import List, Tuple, Optional, Dict
6
6
  from jarvis.models.registry import PlatformRegistry
7
7
  import concurrent.futures
8
8
  from threading import Lock
9
9
  from concurrent.futures import ThreadPoolExecutor
10
- from jarvis.utils import OutputType, PrettyOutput, find_git_root, get_max_context_length, get_thread_count, load_embedding_model, load_rerank_model
10
+ from jarvis.utils import OutputType, PrettyOutput, find_git_root, get_file_md5, get_max_context_length, get_thread_count, load_embedding_model, load_rerank_model
11
11
  from jarvis.utils import load_env_from_file
12
12
  import argparse
13
13
  from sentence_transformers import SentenceTransformer
14
14
  import pickle
15
+ import lzma # 添加 lzma 导入
16
+ from tqdm import tqdm
15
17
 
16
18
  class CodeBase:
17
19
  def __init__(self, root_dir: str):
@@ -58,7 +60,7 @@ class CodeBase:
58
60
  # 加载缓存
59
61
  if os.path.exists(self.cache_path):
60
62
  try:
61
- with open(self.cache_path, 'rb') as f:
63
+ with lzma.open(self.cache_path, 'rb') as f:
62
64
  cache_data = pickle.load(f)
63
65
  self.vector_cache = cache_data["vectors"]
64
66
  self.file_paths = cache_data["file_paths"]
@@ -88,19 +90,13 @@ class CodeBase:
88
90
  return False
89
91
 
90
92
  def make_description(self, file_path: str, content: str) -> str:
91
- model = PlatformRegistry.get_global_platform_registry().get_codegen_platform()
93
+ model = PlatformRegistry.get_global_platform_registry().get_cheap_platform()
92
94
  model.set_suppress_output(True)
93
95
  prompt = f"""请分析以下代码文件,并生成一个详细的描述。描述应该包含以下要点:
96
+ 1. 整个文件的功能描述,不超过100个字
97
+ 2. 每个全局变量的函数、类型定义、类、方法等代码元素的一句话描述,不超过50字
94
98
 
95
- 1. 主要功能和用途
96
- 2. 关键类和方法的作用
97
- 3. 重要的依赖和技术特征(如使用了什么框架、算法、设计模式等)
98
- 4. 代码处理的主要数据类型和数据结构
99
- 5. 关键业务逻辑和处理流程
100
- 6. 特殊功能点和亮点特性
101
-
102
- 请用简洁专业的语言描述,突出代码的技术特征和功能特点,以便后续进行关联代码检索。
103
-
99
+ 请用简洁专业的语言描述,突出代码的技术功能,以便后续进行关联代码检索。
104
100
  文件路径:{file_path}
105
101
  代码内容:
106
102
  {content}
@@ -108,20 +104,24 @@ class CodeBase:
108
104
  response = model.chat(prompt)
109
105
  return response
110
106
 
111
- def save_cache(self):
107
+ def _save_cache(self):
112
108
  """保存缓存数据"""
113
109
  try:
110
+ # 创建缓存数据的副本
114
111
  cache_data = {
115
- "vectors": self.vector_cache,
116
- "file_paths": self.file_paths
112
+ "vectors": dict(self.vector_cache), # 创建字典的副本
113
+ "file_paths": list(self.file_paths) # 创建列表的副本
117
114
  }
118
- with open(self.cache_path, 'wb') as f:
119
- pickle.dump(cache_data, f)
115
+
116
+ # 使用 lzma 压缩存储
117
+ with lzma.open(self.cache_path, 'wb') as f:
118
+ pickle.dump(cache_data, f, protocol=pickle.HIGHEST_PROTOCOL)
120
119
  PrettyOutput.print(f"保存了 {len(self.vector_cache)} 个向量缓存",
121
120
  output_type=OutputType.INFO)
122
121
  except Exception as e:
123
122
  PrettyOutput.print(f"保存缓存失败: {str(e)}",
124
123
  output_type=OutputType.ERROR)
124
+ raise # 抛出异常以便上层处理
125
125
 
126
126
  def get_cached_vector(self, file_path: str, description: str) -> Optional[np.ndarray]:
127
127
  """从缓存获取文件的向量表示"""
@@ -157,24 +157,13 @@ class CodeBase:
157
157
  output_type=OutputType.ERROR)
158
158
  file_md5 = ""
159
159
 
160
+ # 只更新内存中的缓存
160
161
  self.vector_cache[file_path] = {
161
162
  "path": file_path, # 保存文件路径
162
163
  "md5": file_md5, # 保存文件MD5
163
164
  "description": description, # 保存文件描述
164
165
  "vector": vector # 保存向量
165
166
  }
166
-
167
- # 保存缓存到文件
168
- try:
169
- with open(self.cache_path, 'wb') as f:
170
- cache_data = {
171
- "vectors": self.vector_cache,
172
- "file_paths": self.file_paths
173
- }
174
- pickle.dump(cache_data, f)
175
- except Exception as e:
176
- PrettyOutput.print(f"保存向量缓存失败: {str(e)}",
177
- output_type=OutputType.ERROR)
178
167
 
179
168
  def get_embedding(self, text: str) -> np.ndarray:
180
169
  """使用 transformers 模型获取文本的向量表示"""
@@ -215,22 +204,34 @@ class CodeBase:
215
204
  except Exception as e:
216
205
  PrettyOutput.print(f"Error vectorizing file {file_path}: {str(e)}",
217
206
  output_type=OutputType.ERROR)
218
- return np.zeros(self.vector_dim, dtype=np.float32)
207
+ return np.zeros(self.vector_dim, dtype=np.float32) # type: ignore
219
208
 
220
209
  def clean_cache(self) -> bool:
221
210
  """清理过期的缓存记录,返回是否有文件被删除"""
222
- files_to_delete = []
223
- for file_path in list(self.vector_cache.keys()):
224
- if file_path not in self.git_file_list:
225
- del self.vector_cache[file_path]
226
- files_to_delete.append(file_path)
227
-
228
- if files_to_delete:
229
- self.save_cache()
230
- PrettyOutput.print(f"清理了 {len(files_to_delete)} 个文件的缓存",
231
- output_type=OutputType.INFO)
232
- return True
233
- return False
211
+ try:
212
+ files_to_delete = []
213
+ for file_path in list(self.vector_cache.keys()):
214
+ if file_path not in self.git_file_list:
215
+ del self.vector_cache[file_path]
216
+ files_to_delete.append(file_path)
217
+
218
+ if files_to_delete:
219
+ # 只在有文件被删除时保存缓存
220
+ self._save_cache()
221
+ PrettyOutput.print(f"清理了 {len(files_to_delete)} 个文件的缓存",
222
+ output_type=OutputType.INFO)
223
+ return True
224
+ return False
225
+
226
+ except Exception as e:
227
+ PrettyOutput.print(f"清理缓存失败: {str(e)}",
228
+ output_type=OutputType.ERROR)
229
+ # 发生异常时尝试保存当前状态
230
+ try:
231
+ self._save_cache()
232
+ except:
233
+ pass
234
+ return False
234
235
 
235
236
  def process_file(self, file_path: str):
236
237
  """处理单个文件"""
@@ -241,16 +242,10 @@ class CodeBase:
241
242
 
242
243
  if not self.is_text_file(file_path):
243
244
  return None
244
-
245
- # 读取文件内容,限制长度
246
- with open(file_path, "r", encoding="utf-8") as f:
247
- content = f.read()
248
- if len(content) > self.max_context_length:
249
- PrettyOutput.print(f"文件 {file_path} 内容超出长度限制,将截取前 {self.max_context_length} 个字符",
250
- output_type=OutputType.WARNING)
251
- content = content[:self.max_context_length]
252
245
 
253
- md5 = hashlib.md5(content.encode('utf-8')).hexdigest()
246
+ md5 = get_file_md5(file_path)
247
+
248
+ content = open(file_path, "r", encoding="utf-8").read()
254
249
 
255
250
  # 检查文件是否已经处理过且内容未变
256
251
  if file_path in self.vector_cache:
@@ -295,14 +290,14 @@ class CodeBase:
295
290
 
296
291
  if vectors:
297
292
  vectors = np.vstack(vectors)
298
- self.index.add_with_ids(vectors, np.array(ids))
293
+ self.index.add_with_ids(vectors, np.array(ids)) # type: ignore
299
294
  else:
300
295
  self.index = None
301
296
 
302
297
  def gen_vector_db_from_cache(self):
303
298
  """从缓存生成向量数据库"""
304
299
  self.build_index()
305
- self.save_cache()
300
+ self._save_cache()
306
301
 
307
302
 
308
303
  def generate_codebase(self, force: bool = False):
@@ -310,100 +305,152 @@ class CodeBase:
310
305
  Args:
311
306
  force: 是否强制重建索引,不询问用户
312
307
  """
313
- # 更新 git 文件列表
314
- self.git_file_list = self.get_git_file_list()
315
-
316
- # 检查文件变化
317
- changes_detected = False
318
- new_files = []
319
- modified_files = []
320
- deleted_files = []
321
-
322
- # 检查删除的文件
323
- files_to_delete = []
324
- for file_path in list(self.vector_cache.keys()):
325
- if file_path not in self.git_file_list:
326
- deleted_files.append(file_path)
327
- files_to_delete.append(file_path)
328
- changes_detected = True
329
-
330
- # 检查新增和修改的文件
331
- for file_path in self.git_file_list:
332
- if not os.path.exists(file_path) or not self.is_text_file(file_path):
333
- continue
308
+ try:
309
+ # 更新 git 文件列表
310
+ self.git_file_list = self.get_git_file_list()
334
311
 
335
- try:
336
- current_md5 = hashlib.md5(open(file_path, "rb").read()).hexdigest()
337
-
338
- if file_path not in self.vector_cache:
339
- new_files.append(file_path)
340
- changes_detected = True
341
- elif self.vector_cache[file_path].get("md5") != current_md5:
342
- modified_files.append(file_path)
312
+ # 检查文件变化
313
+ PrettyOutput.print("\n检查文件变化...", output_type=OutputType.INFO)
314
+ changes_detected = False
315
+ new_files = []
316
+ modified_files = []
317
+ deleted_files = []
318
+
319
+ # 检查删除的文件
320
+ files_to_delete = []
321
+ for file_path in list(self.vector_cache.keys()):
322
+ if file_path not in self.git_file_list:
323
+ deleted_files.append(file_path)
324
+ files_to_delete.append(file_path)
343
325
  changes_detected = True
344
- except Exception as e:
345
- PrettyOutput.print(f"检查文件失败 {file_path}: {str(e)}",
346
- output_type=OutputType.ERROR)
347
- continue
326
+
327
+ # 检查新增和修改的文件
328
+ with tqdm(total=len(self.git_file_list), desc="检查文件状态") as pbar:
329
+ for file_path in self.git_file_list:
330
+ if not os.path.exists(file_path) or not self.is_text_file(file_path):
331
+ pbar.update(1)
332
+ continue
333
+
334
+ try:
335
+ current_md5 = get_file_md5(file_path)
336
+
337
+ if file_path not in self.vector_cache:
338
+ new_files.append(file_path)
339
+ changes_detected = True
340
+ elif self.vector_cache[file_path].get("md5") != current_md5:
341
+ modified_files.append(file_path)
342
+ changes_detected = True
343
+ except Exception as e:
344
+ PrettyOutput.print(f"检查文件失败 {file_path}: {str(e)}",
345
+ output_type=OutputType.ERROR)
346
+ pbar.update(1)
347
+
348
+ # 如果检测到变化,显示变化并询问用户
349
+ if changes_detected:
350
+ PrettyOutput.print("\n检测到以下变化:", output_type=OutputType.WARNING)
351
+ if new_files:
352
+ PrettyOutput.print("\n新增文件:", output_type=OutputType.INFO)
353
+ for f in new_files:
354
+ PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
355
+ if modified_files:
356
+ PrettyOutput.print("\n修改的文件:", output_type=OutputType.INFO)
357
+ for f in modified_files:
358
+ PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
359
+ if deleted_files:
360
+ PrettyOutput.print("\n删除的文件:", output_type=OutputType.INFO)
361
+ for f in deleted_files:
362
+ PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
363
+
364
+ # 如果force为True,直接继续
365
+ if not force:
366
+ # 询问用户是否继续
367
+ while True:
368
+ response = input("\n是否重建索引?[y/N] ").lower().strip()
369
+ if response in ['y', 'yes']:
370
+ break
371
+ elif response in ['', 'n', 'no']:
372
+ PrettyOutput.print("取消重建索引", output_type=OutputType.INFO)
373
+ return
374
+ else:
375
+ PrettyOutput.print("请输入 y 或 n", output_type=OutputType.WARNING)
376
+
377
+ # 清理已删除的文件
378
+ for file_path in files_to_delete:
379
+ del self.vector_cache[file_path]
380
+ if files_to_delete:
381
+ PrettyOutput.print(f"清理了 {len(files_to_delete)} 个文件的缓存",
382
+ output_type=OutputType.INFO)
383
+
384
+ # 处理新文件和修改的文件
385
+ files_to_process = new_files + modified_files
386
+ processed_files = []
387
+
388
+ with tqdm(total=len(files_to_process), desc="处理文件") as pbar:
389
+ # 使用线程池处理文件
390
+ with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
391
+ # 提交所有任务
392
+ future_to_file = {
393
+ executor.submit(self.process_file, file): file
394
+ for file in files_to_process
395
+ }
396
+
397
+ # 处理完成的任务
398
+ for future in concurrent.futures.as_completed(future_to_file):
399
+ file = future_to_file[future]
400
+ try:
401
+ result = future.result()
402
+ if result:
403
+ processed_files.append(result)
404
+ except Exception as e:
405
+ PrettyOutput.print(f"处理文件失败 {file}: {str(e)}",
406
+ output_type=OutputType.ERROR)
407
+ pbar.update(1)
408
+
409
+ if processed_files:
410
+ PrettyOutput.print("\n重新生成向量数据库...", output_type=OutputType.INFO)
411
+ self.gen_vector_db_from_cache()
412
+ PrettyOutput.print(f"成功为 {len(processed_files)} 个文件生成索引",
413
+ output_type=OutputType.SUCCESS)
414
+ else:
415
+ PrettyOutput.print("没有检测到文件变更,无需重建索引", output_type=OutputType.INFO)
416
+
417
+ except Exception as e:
418
+ # 发生异常时尝试保存缓存
419
+ try:
420
+ self._save_cache()
421
+ except Exception as save_error:
422
+ PrettyOutput.print(f"保存缓存失败: {str(save_error)}",
423
+ output_type=OutputType.ERROR)
424
+ raise e # 重新抛出原始异常
425
+
426
+
427
+ def _text_search_score(self, content: str, keywords: List[str]) -> float:
428
+ """计算文本内容与关键词的匹配分数
348
429
 
349
- # 如果检测到变化,显示变化并询问用户
350
- if changes_detected:
351
- PrettyOutput.print("\n检测到以下变化:", output_type=OutputType.WARNING)
352
- if new_files:
353
- PrettyOutput.print("\n新增文件:", output_type=OutputType.INFO)
354
- for f in new_files:
355
- PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
356
- if modified_files:
357
- PrettyOutput.print("\n修改的文件:", output_type=OutputType.INFO)
358
- for f in modified_files:
359
- PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
360
- if deleted_files:
361
- PrettyOutput.print("\n删除的文件:", output_type=OutputType.INFO)
362
- for f in deleted_files:
363
- PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
364
-
365
-
366
- # 如果force为True,直接继续
367
- if not force:
368
- # 询问用户是否继续
369
- while True:
370
- response = input("\n是否重建索引?[y/N] ").lower().strip()
371
- if response in ['y', 'yes']:
372
- break
373
- elif response in ['', 'n', 'no']:
374
- PrettyOutput.print("取消重建索引", output_type=OutputType.INFO)
375
- return
376
- else:
377
- PrettyOutput.print("请输入 y 或 n", output_type=OutputType.WARNING)
378
-
379
- # 清理已删除的文件
380
- for file_path in files_to_delete:
381
- del self.vector_cache[file_path]
382
- if files_to_delete:
383
- PrettyOutput.print(f"清理了 {len(files_to_delete)} 个文件的缓存",
384
- output_type=OutputType.INFO)
430
+ Args:
431
+ content: 文本内容
432
+ keywords: 关键词列表
385
433
 
386
- # 处理新文件和修改的文件
387
- processed_files = []
388
- files_to_process = new_files + modified_files
389
-
390
- # 使用线程池处理文件
391
- with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
392
- futures = [executor.submit(self.process_file, file) for file in files_to_process]
393
- for future in concurrent.futures.as_completed(futures):
394
- result = future.result()
395
- if result:
396
- processed_files.append(result)
397
- PrettyOutput.print(f"索引文件: {result}", output_type=OutputType.INFO)
398
-
399
- PrettyOutput.print("重新生成向量数据库", output_type=OutputType.INFO)
400
- self.gen_vector_db_from_cache()
401
- PrettyOutput.print(f"成功为 {len(processed_files)} 个文件生成索引", output_type=OutputType.INFO)
402
- else:
403
- PrettyOutput.print("没有检测到文件变更,无需重建索引", output_type=OutputType.INFO)
434
+ Returns:
435
+ float: 匹配分数 (0-1)
436
+ """
437
+ if not keywords:
438
+ return 0.0
439
+
440
+ content = content.lower()
441
+ matched_keywords = set()
442
+
443
+ for keyword in keywords:
444
+ keyword = keyword.lower()
445
+ if keyword in content:
446
+ matched_keywords.add(keyword)
447
+
448
+ # 计算匹配分数
449
+ score = len(matched_keywords) / len(keywords)
450
+ return score
404
451
 
405
452
  def rerank_results(self, query: str, initial_results: List[Tuple[str, float, str]]) -> List[Tuple[str, float, str]]:
406
- """使用 BAAI/bge-reranker-v2-m3 对搜索结果重新排序"""
453
+ """使用多种策略对搜索结果重新排序"""
407
454
  if not initial_results:
408
455
  return []
409
456
 
@@ -413,13 +460,15 @@ class CodeBase:
413
460
  # 加载模型和分词器
414
461
  model, tokenizer = load_rerank_model()
415
462
 
416
- # 准备数据 - 加入文件内容进行更准确的重排序
463
+ # 准备数据
417
464
  pairs = []
465
+
418
466
  for path, _, desc in initial_results:
419
467
  try:
420
468
  with open(path, "r", encoding="utf-8") as f:
421
469
  content = f.read()[:512] # 限制内容长度
422
- # 组合文件路径、描述和内容
470
+
471
+ # 组合文件信息
423
472
  doc_content = f"文件: {path}\n描述: {desc}\n内容: {content}"
424
473
  pairs.append([query, doc_content])
425
474
  except Exception as e:
@@ -430,6 +479,7 @@ class CodeBase:
430
479
 
431
480
  # 使用更大的batch size提高处理速度
432
481
  batch_size = 16 # 根据GPU显存调整
482
+ batch_scores = []
433
483
 
434
484
  with torch.no_grad():
435
485
  for i in range(0, len(pairs), batch_size):
@@ -446,8 +496,7 @@ class CodeBase:
446
496
  encoded = {k: v.cuda() for k, v in encoded.items()}
447
497
 
448
498
  outputs = model(**encoded)
449
- # 修改这里:直接使用 outputs.logits 作为分数
450
- batch_scores = outputs.logits.squeeze(-1).cpu().numpy()
499
+ batch_scores.extend(outputs.logits.squeeze(-1).cpu().numpy())
451
500
 
452
501
  # 归一化分数到 0-1 范围
453
502
  if batch_scores:
@@ -456,61 +505,98 @@ class CodeBase:
456
505
  if max_score > min_score:
457
506
  batch_scores = [(s - min_score) / (max_score - min_score) for s in batch_scores]
458
507
 
459
- # 将分数与原始结果组合并排序
508
+ # 将重排序分数与原始分数结合
460
509
  scored_results = []
461
- for (path, _, desc), score in zip(initial_results, batch_scores):
462
- if score >= 0.5: # 只保留相关度大于 0.5 的结果
463
- scored_results.append((path, float(score), desc))
510
+ for (path, orig_score, desc), rerank_score in zip(initial_results, batch_scores):
511
+ # 综合分数 = 0.3 * 原始分数 + 0.7 * 重排序分数
512
+ combined_score = 0.3 * float(orig_score) + 0.7 * float(rerank_score)
513
+ if combined_score >= 0.5: # 只保留相关度较高的结果
514
+ scored_results.append((path, combined_score, desc))
464
515
 
465
- # 按分数降序排序
516
+ # 按综合分数降序排序
466
517
  scored_results.sort(key=lambda x: x[1], reverse=True)
467
518
 
468
519
  return scored_results
469
520
 
470
521
  except Exception as e:
471
- PrettyOutput.print(f"重排序失败,使用原始排序: {str(e)}", output_type=OutputType.WARNING)
472
- return initial_results
522
+ PrettyOutput.print(f"重排序失败: {str(e)}",
523
+ output_type=OutputType.ERROR)
524
+ return initial_results # 发生错误时返回原始结果
525
+
526
+ def _generate_query_variants(self, query: str) -> List[str]:
527
+ """生成查询的不同表述变体
528
+
529
+ Args:
530
+ query: 原始查询
531
+
532
+ Returns:
533
+ List[str]: 查询变体列表
534
+ """
535
+ model = PlatformRegistry.get_global_platform_registry().get_normal_platform()
536
+ prompt = f"""请根据以下查询,生成3个不同的表述,每个表述都要完整表达原始查询的意思。这些表述将用于代码搜索,要保持专业性和准确性。
537
+ 原始查询: {query}
538
+
539
+ 请直接输出3个表述,用换行分隔,不要有编号或其他标记。
540
+ """
541
+ variants = model.chat(prompt).strip().split('\n')
542
+ variants.append(query) # 添加原始查询
543
+ return variants
544
+
545
+ def _vector_search(self, query_variants: List[str], top_k: int) -> Dict[str, Tuple[str, float, str]]:
546
+ """使用向量搜索查找相关文件
547
+
548
+ Args:
549
+ query_variants: 查询变体列表
550
+ top_k: 返回结果数量
551
+
552
+ Returns:
553
+ Dict[str, Tuple[str, float, str]]: 文件路径到(路径,分数,描述)的映射
554
+ """
555
+ results = {}
556
+ for query in query_variants:
557
+ query_vector = self.get_embedding(query)
558
+ query_vector = query_vector.reshape(1, -1)
559
+
560
+ distances, indices = self.index.search(query_vector, top_k) # type: ignore
561
+
562
+ for i, distance in zip(indices[0], distances[0]):
563
+ if i == -1:
564
+ continue
565
+
566
+ similarity = 1.0 / (1.0 + float(distance))
567
+ if similarity >= 0.5:
568
+ file_path = self.file_paths[i]
569
+ # 使用最高的相似度分数
570
+ if file_path not in results or similarity > results[file_path][1]:
571
+ data = self.vector_cache[file_path]
572
+ results[file_path] = (file_path, similarity, data["description"])
573
+
574
+ return results
575
+
473
576
 
474
577
  def search_similar(self, query: str, top_k: int = 30) -> List[Tuple[str, float, str]]:
475
578
  """搜索关联文件"""
476
579
  try:
477
580
  if self.index is None:
478
- return []
479
- # 生成多个查询变体以提高召回率
480
- model = PlatformRegistry.get_global_platform_registry().get_normal_platform()
481
- prompt = f"""请根据以下查询,生成3个不同的表述,每个表述都要完整表达原始查询的意思。这些表述将用于代码搜索,要保持专业性和准确性。
482
- 原始查询: {query}
581
+ return []
582
+ # 生成查询变体
583
+ query_variants = self._generate_query_variants(query)
584
+
585
+ # 进行向量搜索
586
+ vector_results = self._vector_search(query_variants, top_k)
483
587
 
484
- 请直接输出3个表述,用换行分隔,不要有编号或其他标记。
485
- """
486
- query_variants = model.chat(prompt).strip().split('\n')
487
- query_variants.append(query) # 添加原始查询
488
-
489
- # 对每个查询变体进行搜索
490
- all_results = {}
491
- for q in query_variants:
492
- q_vector = self.get_embedding(q)
493
- q_vector = q_vector.reshape(1, -1)
494
-
495
- distances, indices = self.index.search(q_vector, top_k)
496
-
497
- for i, distance in zip(indices[0], distances[0]):
498
- if i == -1:
499
- continue
500
-
501
- similarity = 1.0 / (1.0 + float(distance))
502
- if similarity >= 0.5:
503
- file_path = self.file_paths[i]
504
- # 使用最高的相似度分数
505
- if file_path not in all_results or similarity > all_results[file_path][1]:
506
- data = self.vector_cache[file_path]
507
- all_results[file_path] = (file_path, similarity, data["description"])
508
-
509
- # 转换为列表并排序
510
- results = list(all_results.values())
588
+ results = list(vector_results.values())
511
589
  results.sort(key=lambda x: x[1], reverse=True)
590
+
591
+ # 取前 top_k 个结果进行重排序
592
+ initial_results = results[:top_k]
512
593
 
513
- return results[:top_k]
594
+ # 如果没有找到结果,直接返回
595
+ if not initial_results:
596
+ return []
597
+
598
+ # 对初步结果进行重排序
599
+ return self.rerank_results(query, initial_results)
514
600
 
515
601
  except Exception as e:
516
602
  PrettyOutput.print(f"搜索失败: {str(e)}", output_type=OutputType.ERROR)
@@ -564,7 +650,7 @@ class CodeBase:
564
650
 
565
651
  # 检查缓存是否有效
566
652
  try:
567
- with open(self.cache_path, 'rb') as f:
653
+ with lzma.open(self.cache_path, 'rb') as f:
568
654
  cache_data = pickle.load(f)
569
655
  if not cache_data.get("vectors") or not cache_data.get("file_paths"):
570
656
  return False
@@ -625,4 +711,4 @@ def main():
625
711
 
626
712
 
627
713
  if __name__ == "__main__":
628
- exit(main())
714
+ exit(main())
File without changes