jarvis-ai-assistant 0.1.91 (py3-none-any.whl) → 0.1.92 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of jarvis-ai-assistant might be problematic. Click here for more details.

jarvis/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
1
  """Jarvis AI Assistant"""
2
2
 
3
- __version__ = "0.1.91"
3
+ __version__ = "0.1.92"
jarvis/agent.py CHANGED
@@ -259,6 +259,8 @@ class Agent:
259
259
  analysis_prompt = """本次任务已结束,请分析是否需要生成方法论。
260
260
  如果认为需要生成方法论,请先判断是创建新的方法论还是更新已有方法论。如果是更新已有方法论,使用update,否则使用add。
261
261
  如果认为不需要生成方法论,请说明原因。
262
+ 方法论应该适应普遍场景,不要出现本次任务特定的信息,如代码的commit信息等。
263
+ 方法论中应该包含:问题重述、最优解决方案、注意事项(按需),除此外不要出现任何其他的信息。
262
264
  仅输出方法论工具的调用指令,或者是不需要生成方法论的说明,除此之外不要输出任何内容。
263
265
  """
264
266
  self.prompt = analysis_prompt
@@ -2,16 +2,18 @@ import hashlib
2
2
  import os
3
3
  import numpy as np
4
4
  import faiss
5
- from typing import List, Tuple, Optional
5
+ from typing import List, Tuple, Optional, Dict
6
6
  from jarvis.models.registry import PlatformRegistry
7
7
  import concurrent.futures
8
8
  from threading import Lock
9
9
  from concurrent.futures import ThreadPoolExecutor
10
- from jarvis.utils import OutputType, PrettyOutput, find_git_root, get_max_context_length, get_thread_count, load_embedding_model, load_rerank_model
10
+ from jarvis.utils import OutputType, PrettyOutput, find_git_root, get_file_md5, get_max_context_length, get_thread_count, load_embedding_model, load_rerank_model
11
11
  from jarvis.utils import load_env_from_file
12
12
  import argparse
13
13
  from sentence_transformers import SentenceTransformer
14
14
  import pickle
15
+ import lzma # 添加 lzma 导入
16
+ from tqdm import tqdm
15
17
 
16
18
  class CodeBase:
17
19
  def __init__(self, root_dir: str):
@@ -58,7 +60,7 @@ class CodeBase:
58
60
  # 加载缓存
59
61
  if os.path.exists(self.cache_path):
60
62
  try:
61
- with open(self.cache_path, 'rb') as f:
63
+ with lzma.open(self.cache_path, 'rb') as f:
62
64
  cache_data = pickle.load(f)
63
65
  self.vector_cache = cache_data["vectors"]
64
66
  self.file_paths = cache_data["file_paths"]
@@ -88,7 +90,7 @@ class CodeBase:
88
90
  return False
89
91
 
90
92
  def make_description(self, file_path: str, content: str) -> str:
91
- model = PlatformRegistry.get_global_platform_registry().get_codegen_platform()
93
+ model = PlatformRegistry.get_global_platform_registry().get_cheap_platform()
92
94
  model.set_suppress_output(True)
93
95
  prompt = f"""请分析以下代码文件,并生成一个详细的描述。描述应该包含以下要点:
94
96
 
@@ -108,20 +110,24 @@ class CodeBase:
108
110
  response = model.chat(prompt)
109
111
  return response
110
112
 
111
- def save_cache(self):
113
+ def _save_cache(self):
112
114
  """保存缓存数据"""
113
115
  try:
116
+ # 创建缓存数据的副本
114
117
  cache_data = {
115
- "vectors": self.vector_cache,
116
- "file_paths": self.file_paths
118
+ "vectors": dict(self.vector_cache), # 创建字典的副本
119
+ "file_paths": list(self.file_paths) # 创建列表的副本
117
120
  }
118
- with open(self.cache_path, 'wb') as f:
119
- pickle.dump(cache_data, f)
121
+
122
+ # 使用 lzma 压缩存储
123
+ with lzma.open(self.cache_path, 'wb') as f:
124
+ pickle.dump(cache_data, f, protocol=pickle.HIGHEST_PROTOCOL)
120
125
  PrettyOutput.print(f"保存了 {len(self.vector_cache)} 个向量缓存",
121
126
  output_type=OutputType.INFO)
122
127
  except Exception as e:
123
128
  PrettyOutput.print(f"保存缓存失败: {str(e)}",
124
129
  output_type=OutputType.ERROR)
130
+ raise # 抛出异常以便上层处理
125
131
 
126
132
  def get_cached_vector(self, file_path: str, description: str) -> Optional[np.ndarray]:
127
133
  """从缓存获取文件的向量表示"""
@@ -157,24 +163,13 @@ class CodeBase:
157
163
  output_type=OutputType.ERROR)
158
164
  file_md5 = ""
159
165
 
166
+ # 只更新内存中的缓存
160
167
  self.vector_cache[file_path] = {
161
168
  "path": file_path, # 保存文件路径
162
169
  "md5": file_md5, # 保存文件MD5
163
170
  "description": description, # 保存文件描述
164
171
  "vector": vector # 保存向量
165
172
  }
166
-
167
- # 保存缓存到文件
168
- try:
169
- with open(self.cache_path, 'wb') as f:
170
- cache_data = {
171
- "vectors": self.vector_cache,
172
- "file_paths": self.file_paths
173
- }
174
- pickle.dump(cache_data, f)
175
- except Exception as e:
176
- PrettyOutput.print(f"保存向量缓存失败: {str(e)}",
177
- output_type=OutputType.ERROR)
178
173
 
179
174
  def get_embedding(self, text: str) -> np.ndarray:
180
175
  """使用 transformers 模型获取文本的向量表示"""
@@ -219,18 +214,30 @@ class CodeBase:
219
214
 
220
215
  def clean_cache(self) -> bool:
221
216
  """清理过期的缓存记录,返回是否有文件被删除"""
222
- files_to_delete = []
223
- for file_path in list(self.vector_cache.keys()):
224
- if file_path not in self.git_file_list:
225
- del self.vector_cache[file_path]
226
- files_to_delete.append(file_path)
227
-
228
- if files_to_delete:
229
- self.save_cache()
230
- PrettyOutput.print(f"清理了 {len(files_to_delete)} 个文件的缓存",
231
- output_type=OutputType.INFO)
232
- return True
233
- return False
217
+ try:
218
+ files_to_delete = []
219
+ for file_path in list(self.vector_cache.keys()):
220
+ if file_path not in self.git_file_list:
221
+ del self.vector_cache[file_path]
222
+ files_to_delete.append(file_path)
223
+
224
+ if files_to_delete:
225
+ # 只在有文件被删除时保存缓存
226
+ self._save_cache()
227
+ PrettyOutput.print(f"清理了 {len(files_to_delete)} 个文件的缓存",
228
+ output_type=OutputType.INFO)
229
+ return True
230
+ return False
231
+
232
+ except Exception as e:
233
+ PrettyOutput.print(f"清理缓存失败: {str(e)}",
234
+ output_type=OutputType.ERROR)
235
+ # 发生异常时尝试保存当前状态
236
+ try:
237
+ self._save_cache()
238
+ except:
239
+ pass
240
+ return False
234
241
 
235
242
  def process_file(self, file_path: str):
236
243
  """处理单个文件"""
@@ -241,16 +248,10 @@ class CodeBase:
241
248
 
242
249
  if not self.is_text_file(file_path):
243
250
  return None
244
-
245
- # 读取文件内容,限制长度
246
- with open(file_path, "r", encoding="utf-8") as f:
247
- content = f.read()
248
- if len(content) > self.max_context_length:
249
- PrettyOutput.print(f"文件 {file_path} 内容超出长度限制,将截取前 {self.max_context_length} 个字符",
250
- output_type=OutputType.WARNING)
251
- content = content[:self.max_context_length]
252
251
 
253
- md5 = hashlib.md5(content.encode('utf-8')).hexdigest()
252
+ md5 = get_file_md5(file_path)
253
+
254
+ content = open(file_path, "r", encoding="utf-8").read()
254
255
 
255
256
  # 检查文件是否已经处理过且内容未变
256
257
  if file_path in self.vector_cache:
@@ -302,7 +303,7 @@ class CodeBase:
302
303
  def gen_vector_db_from_cache(self):
303
304
  """从缓存生成向量数据库"""
304
305
  self.build_index()
305
- self.save_cache()
306
+ self._save_cache()
306
307
 
307
308
 
308
309
  def generate_codebase(self, force: bool = False):
@@ -310,100 +311,152 @@ class CodeBase:
310
311
  Args:
311
312
  force: 是否强制重建索引,不询问用户
312
313
  """
313
- # 更新 git 文件列表
314
- self.git_file_list = self.get_git_file_list()
315
-
316
- # 检查文件变化
317
- changes_detected = False
318
- new_files = []
319
- modified_files = []
320
- deleted_files = []
321
-
322
- # 检查删除的文件
323
- files_to_delete = []
324
- for file_path in list(self.vector_cache.keys()):
325
- if file_path not in self.git_file_list:
326
- deleted_files.append(file_path)
327
- files_to_delete.append(file_path)
328
- changes_detected = True
329
-
330
- # 检查新增和修改的文件
331
- for file_path in self.git_file_list:
332
- if not os.path.exists(file_path) or not self.is_text_file(file_path):
333
- continue
314
+ try:
315
+ # 更新 git 文件列表
316
+ self.git_file_list = self.get_git_file_list()
334
317
 
335
- try:
336
- current_md5 = hashlib.md5(open(file_path, "rb").read()).hexdigest()
337
-
338
- if file_path not in self.vector_cache:
339
- new_files.append(file_path)
340
- changes_detected = True
341
- elif self.vector_cache[file_path].get("md5") != current_md5:
342
- modified_files.append(file_path)
318
+ # 检查文件变化
319
+ PrettyOutput.print("\n检查文件变化...", output_type=OutputType.INFO)
320
+ changes_detected = False
321
+ new_files = []
322
+ modified_files = []
323
+ deleted_files = []
324
+
325
+ # 检查删除的文件
326
+ files_to_delete = []
327
+ for file_path in list(self.vector_cache.keys()):
328
+ if file_path not in self.git_file_list:
329
+ deleted_files.append(file_path)
330
+ files_to_delete.append(file_path)
343
331
  changes_detected = True
344
- except Exception as e:
345
- PrettyOutput.print(f"检查文件失败 {file_path}: {str(e)}",
346
- output_type=OutputType.ERROR)
347
- continue
332
+
333
+ # 检查新增和修改的文件
334
+ with tqdm(total=len(self.git_file_list), desc="检查文件状态") as pbar:
335
+ for file_path in self.git_file_list:
336
+ if not os.path.exists(file_path) or not self.is_text_file(file_path):
337
+ pbar.update(1)
338
+ continue
339
+
340
+ try:
341
+ current_md5 = get_file_md5(file_path)
342
+
343
+ if file_path not in self.vector_cache:
344
+ new_files.append(file_path)
345
+ changes_detected = True
346
+ elif self.vector_cache[file_path].get("md5") != current_md5:
347
+ modified_files.append(file_path)
348
+ changes_detected = True
349
+ except Exception as e:
350
+ PrettyOutput.print(f"检查文件失败 {file_path}: {str(e)}",
351
+ output_type=OutputType.ERROR)
352
+ pbar.update(1)
353
+
354
+ # 如果检测到变化,显示变化并询问用户
355
+ if changes_detected:
356
+ PrettyOutput.print("\n检测到以下变化:", output_type=OutputType.WARNING)
357
+ if new_files:
358
+ PrettyOutput.print("\n新增文件:", output_type=OutputType.INFO)
359
+ for f in new_files:
360
+ PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
361
+ if modified_files:
362
+ PrettyOutput.print("\n修改的文件:", output_type=OutputType.INFO)
363
+ for f in modified_files:
364
+ PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
365
+ if deleted_files:
366
+ PrettyOutput.print("\n删除的文件:", output_type=OutputType.INFO)
367
+ for f in deleted_files:
368
+ PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
369
+
370
+ # 如果force为True,直接继续
371
+ if not force:
372
+ # 询问用户是否继续
373
+ while True:
374
+ response = input("\n是否重建索引?[y/N] ").lower().strip()
375
+ if response in ['y', 'yes']:
376
+ break
377
+ elif response in ['', 'n', 'no']:
378
+ PrettyOutput.print("取消重建索引", output_type=OutputType.INFO)
379
+ return
380
+ else:
381
+ PrettyOutput.print("请输入 y 或 n", output_type=OutputType.WARNING)
382
+
383
+ # 清理已删除的文件
384
+ for file_path in files_to_delete:
385
+ del self.vector_cache[file_path]
386
+ if files_to_delete:
387
+ PrettyOutput.print(f"清理了 {len(files_to_delete)} 个文件的缓存",
388
+ output_type=OutputType.INFO)
389
+
390
+ # 处理新文件和修改的文件
391
+ files_to_process = new_files + modified_files
392
+ processed_files = []
393
+
394
+ with tqdm(total=len(files_to_process), desc="处理文件") as pbar:
395
+ # 使用线程池处理文件
396
+ with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
397
+ # 提交所有任务
398
+ future_to_file = {
399
+ executor.submit(self.process_file, file): file
400
+ for file in files_to_process
401
+ }
402
+
403
+ # 处理完成的任务
404
+ for future in concurrent.futures.as_completed(future_to_file):
405
+ file = future_to_file[future]
406
+ try:
407
+ result = future.result()
408
+ if result:
409
+ processed_files.append(result)
410
+ except Exception as e:
411
+ PrettyOutput.print(f"处理文件失败 {file}: {str(e)}",
412
+ output_type=OutputType.ERROR)
413
+ pbar.update(1)
414
+
415
+ if processed_files:
416
+ PrettyOutput.print("\n重新生成向量数据库...", output_type=OutputType.INFO)
417
+ self.gen_vector_db_from_cache()
418
+ PrettyOutput.print(f"成功为 {len(processed_files)} 个文件生成索引",
419
+ output_type=OutputType.SUCCESS)
420
+ else:
421
+ PrettyOutput.print("没有检测到文件变更,无需重建索引", output_type=OutputType.INFO)
422
+
423
+ except Exception as e:
424
+ # 发生异常时尝试保存缓存
425
+ try:
426
+ self._save_cache()
427
+ except Exception as save_error:
428
+ PrettyOutput.print(f"保存缓存失败: {str(save_error)}",
429
+ output_type=OutputType.ERROR)
430
+ raise e # 重新抛出原始异常
431
+
432
+
433
+ def _text_search_score(self, content: str, keywords: List[str]) -> float:
434
+ """计算文本内容与关键词的匹配分数
348
435
 
349
- # 如果检测到变化,显示变化并询问用户
350
- if changes_detected:
351
- PrettyOutput.print("\n检测到以下变化:", output_type=OutputType.WARNING)
352
- if new_files:
353
- PrettyOutput.print("\n新增文件:", output_type=OutputType.INFO)
354
- for f in new_files:
355
- PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
356
- if modified_files:
357
- PrettyOutput.print("\n修改的文件:", output_type=OutputType.INFO)
358
- for f in modified_files:
359
- PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
360
- if deleted_files:
361
- PrettyOutput.print("\n删除的文件:", output_type=OutputType.INFO)
362
- for f in deleted_files:
363
- PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
364
-
365
-
366
- # 如果force为True,直接继续
367
- if not force:
368
- # 询问用户是否继续
369
- while True:
370
- response = input("\n是否重建索引?[y/N] ").lower().strip()
371
- if response in ['y', 'yes']:
372
- break
373
- elif response in ['', 'n', 'no']:
374
- PrettyOutput.print("取消重建索引", output_type=OutputType.INFO)
375
- return
376
- else:
377
- PrettyOutput.print("请输入 y 或 n", output_type=OutputType.WARNING)
378
-
379
- # 清理已删除的文件
380
- for file_path in files_to_delete:
381
- del self.vector_cache[file_path]
382
- if files_to_delete:
383
- PrettyOutput.print(f"清理了 {len(files_to_delete)} 个文件的缓存",
384
- output_type=OutputType.INFO)
436
+ Args:
437
+ content: 文本内容
438
+ keywords: 关键词列表
385
439
 
386
- # 处理新文件和修改的文件
387
- processed_files = []
388
- files_to_process = new_files + modified_files
389
-
390
- # 使用线程池处理文件
391
- with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
392
- futures = [executor.submit(self.process_file, file) for file in files_to_process]
393
- for future in concurrent.futures.as_completed(futures):
394
- result = future.result()
395
- if result:
396
- processed_files.append(result)
397
- PrettyOutput.print(f"索引文件: {result}", output_type=OutputType.INFO)
398
-
399
- PrettyOutput.print("重新生成向量数据库", output_type=OutputType.INFO)
400
- self.gen_vector_db_from_cache()
401
- PrettyOutput.print(f"成功为 {len(processed_files)} 个文件生成索引", output_type=OutputType.INFO)
402
- else:
403
- PrettyOutput.print("没有检测到文件变更,无需重建索引", output_type=OutputType.INFO)
440
+ Returns:
441
+ float: 匹配分数 (0-1)
442
+ """
443
+ if not keywords:
444
+ return 0.0
445
+
446
+ content = content.lower()
447
+ matched_keywords = set()
448
+
449
+ for keyword in keywords:
450
+ keyword = keyword.lower()
451
+ if keyword in content:
452
+ matched_keywords.add(keyword)
453
+
454
+ # 计算匹配分数
455
+ score = len(matched_keywords) / len(keywords)
456
+ return score
404
457
 
405
458
  def rerank_results(self, query: str, initial_results: List[Tuple[str, float, str]]) -> List[Tuple[str, float, str]]:
406
- """使用 BAAI/bge-reranker-v2-m3 对搜索结果重新排序"""
459
+ """使用多种策略对搜索结果重新排序"""
407
460
  if not initial_results:
408
461
  return []
409
462
 
@@ -413,13 +466,15 @@ class CodeBase:
413
466
  # 加载模型和分词器
414
467
  model, tokenizer = load_rerank_model()
415
468
 
416
- # 准备数据 - 加入文件内容进行更准确的重排序
469
+ # 准备数据
417
470
  pairs = []
471
+
418
472
  for path, _, desc in initial_results:
419
473
  try:
420
474
  with open(path, "r", encoding="utf-8") as f:
421
475
  content = f.read()[:512] # 限制内容长度
422
- # 组合文件路径、描述和内容
476
+
477
+ # 组合文件信息
423
478
  doc_content = f"文件: {path}\n描述: {desc}\n内容: {content}"
424
479
  pairs.append([query, doc_content])
425
480
  except Exception as e:
@@ -430,6 +485,7 @@ class CodeBase:
430
485
 
431
486
  # 使用更大的batch size提高处理速度
432
487
  batch_size = 16 # 根据GPU显存调整
488
+ batch_scores = []
433
489
 
434
490
  with torch.no_grad():
435
491
  for i in range(0, len(pairs), batch_size):
@@ -446,8 +502,7 @@ class CodeBase:
446
502
  encoded = {k: v.cuda() for k, v in encoded.items()}
447
503
 
448
504
  outputs = model(**encoded)
449
- # 修改这里:直接使用 outputs.logits 作为分数
450
- batch_scores = outputs.logits.squeeze(-1).cpu().numpy()
505
+ batch_scores.extend(outputs.logits.squeeze(-1).cpu().numpy())
451
506
 
452
507
  # 归一化分数到 0-1 范围
453
508
  if batch_scores:
@@ -456,61 +511,98 @@ class CodeBase:
456
511
  if max_score > min_score:
457
512
  batch_scores = [(s - min_score) / (max_score - min_score) for s in batch_scores]
458
513
 
459
- # 将分数与原始结果组合并排序
514
+ # 将重排序分数与原始分数结合
460
515
  scored_results = []
461
- for (path, _, desc), score in zip(initial_results, batch_scores):
462
- if score >= 0.5: # 只保留相关度大于 0.5 的结果
463
- scored_results.append((path, float(score), desc))
516
+ for (path, orig_score, desc), rerank_score in zip(initial_results, batch_scores):
517
+ # 综合分数 = 0.3 * 原始分数 + 0.7 * 重排序分数
518
+ combined_score = 0.3 * float(orig_score) + 0.7 * float(rerank_score)
519
+ if combined_score >= 0.5: # 只保留相关度较高的结果
520
+ scored_results.append((path, combined_score, desc))
464
521
 
465
- # 按分数降序排序
522
+ # 按综合分数降序排序
466
523
  scored_results.sort(key=lambda x: x[1], reverse=True)
467
524
 
468
525
  return scored_results
469
526
 
470
527
  except Exception as e:
471
- PrettyOutput.print(f"重排序失败,使用原始排序: {str(e)}", output_type=OutputType.WARNING)
472
- return initial_results
528
+ PrettyOutput.print(f"重排序失败: {str(e)}",
529
+ output_type=OutputType.ERROR)
530
+ return initial_results # 发生错误时返回原始结果
531
+
532
+ def _generate_query_variants(self, query: str) -> List[str]:
533
+ """生成查询的不同表述变体
534
+
535
+ Args:
536
+ query: 原始查询
537
+
538
+ Returns:
539
+ List[str]: 查询变体列表
540
+ """
541
+ model = PlatformRegistry.get_global_platform_registry().get_normal_platform()
542
+ prompt = f"""请根据以下查询,生成3个不同的表述,每个表述都要完整表达原始查询的意思。这些表述将用于代码搜索,要保持专业性和准确性。
543
+ 原始查询: {query}
544
+
545
+ 请直接输出3个表述,用换行分隔,不要有编号或其他标记。
546
+ """
547
+ variants = model.chat(prompt).strip().split('\n')
548
+ variants.append(query) # 添加原始查询
549
+ return variants
550
+
551
+ def _vector_search(self, query_variants: List[str], top_k: int) -> Dict[str, Tuple[str, float, str]]:
552
+ """使用向量搜索查找相关文件
553
+
554
+ Args:
555
+ query_variants: 查询变体列表
556
+ top_k: 返回结果数量
557
+
558
+ Returns:
559
+ Dict[str, Tuple[str, float, str]]: 文件路径到(路径,分数,描述)的映射
560
+ """
561
+ results = {}
562
+ for query in query_variants:
563
+ query_vector = self.get_embedding(query)
564
+ query_vector = query_vector.reshape(1, -1)
565
+
566
+ distances, indices = self.index.search(query_vector, top_k)
567
+
568
+ for i, distance in zip(indices[0], distances[0]):
569
+ if i == -1:
570
+ continue
571
+
572
+ similarity = 1.0 / (1.0 + float(distance))
573
+ if similarity >= 0.5:
574
+ file_path = self.file_paths[i]
575
+ # 使用最高的相似度分数
576
+ if file_path not in results or similarity > results[file_path][1]:
577
+ data = self.vector_cache[file_path]
578
+ results[file_path] = (file_path, similarity, data["description"])
579
+
580
+ return results
581
+
473
582
 
474
583
  def search_similar(self, query: str, top_k: int = 30) -> List[Tuple[str, float, str]]:
475
584
  """搜索关联文件"""
476
585
  try:
477
586
  if self.index is None:
478
- return []
479
- # 生成多个查询变体以提高召回率
480
- model = PlatformRegistry.get_global_platform_registry().get_normal_platform()
481
- prompt = f"""请根据以下查询,生成3个不同的表述,每个表述都要完整表达原始查询的意思。这些表述将用于代码搜索,要保持专业性和准确性。
482
- 原始查询: {query}
587
+ return []
588
+ # 生成查询变体
589
+ query_variants = self._generate_query_variants(query)
590
+
591
+ # 进行向量搜索
592
+ vector_results = self._vector_search(query_variants, top_k)
483
593
 
484
- 请直接输出3个表述,用换行分隔,不要有编号或其他标记。
485
- """
486
- query_variants = model.chat(prompt).strip().split('\n')
487
- query_variants.append(query) # 添加原始查询
488
-
489
- # 对每个查询变体进行搜索
490
- all_results = {}
491
- for q in query_variants:
492
- q_vector = self.get_embedding(q)
493
- q_vector = q_vector.reshape(1, -1)
494
-
495
- distances, indices = self.index.search(q_vector, top_k)
496
-
497
- for i, distance in zip(indices[0], distances[0]):
498
- if i == -1:
499
- continue
500
-
501
- similarity = 1.0 / (1.0 + float(distance))
502
- if similarity >= 0.5:
503
- file_path = self.file_paths[i]
504
- # 使用最高的相似度分数
505
- if file_path not in all_results or similarity > all_results[file_path][1]:
506
- data = self.vector_cache[file_path]
507
- all_results[file_path] = (file_path, similarity, data["description"])
508
-
509
- # 转换为列表并排序
510
- results = list(all_results.values())
594
+ results = list(vector_results.values())
511
595
  results.sort(key=lambda x: x[1], reverse=True)
596
+
597
+ # 取前 top_k 个结果进行重排序
598
+ initial_results = results[:top_k]
512
599
 
513
- return results[:top_k]
600
+ # 如果没有找到结果,直接返回
601
+ if not initial_results:
602
+ return []
603
+
604
+ # 对初步结果进行重排序
605
+ return self.rerank_results(query, initial_results)
514
606
 
515
607
  except Exception as e:
516
608
  PrettyOutput.print(f"搜索失败: {str(e)}", output_type=OutputType.ERROR)
@@ -564,7 +656,7 @@ class CodeBase:
564
656
 
565
657
  # 检查缓存是否有效
566
658
  try:
567
- with open(self.cache_path, 'rb') as f:
659
+ with lzma.open(self.cache_path, 'rb') as f:
568
660
  cache_data = pickle.load(f)
569
661
  if not cache_data.get("vectors") or not cache_data.get("file_paths"):
570
662
  return False
@@ -32,7 +32,7 @@ def list_platforms():
32
32
  PrettyOutput.print(" 没有可用的模型信息", OutputType.WARNING)
33
33
 
34
34
  except Exception as e:
35
- PrettyOutput.print(f"获取 {platform_name} 平台模型列表失败: {str(e)}", OutputType.ERROR)
35
+ PrettyOutput.print(f"获取 {platform_name} 平台模型列表失败: {str(e)}", OutputType.WARNING)
36
36
 
37
37
  def chat_with_model(platform_name: str, model_name: str):
38
38
  """与指定平台和模型进行对话"""
@@ -55,13 +55,24 @@ def chat_with_model(platform_name: str, model_name: str):
55
55
  user_input = get_multiline_input("")
56
56
 
57
57
  # 检查是否取消输入
58
- if user_input == "__interrupt__":
58
+ if user_input == "__interrupt__" or user_input.strip() == "/bye":
59
+ PrettyOutput.print("再见!", OutputType.SUCCESS)
59
60
  break
60
61
 
61
62
  # 检查是否为空输入
62
63
  if not user_input.strip():
63
64
  continue
64
65
 
66
+ # 检查是否为清除会话命令
67
+ if user_input.strip() == "/clear":
68
+ try:
69
+ platform.delete_chat()
70
+ platform.set_model_name(model_name) # 重新初始化会话
71
+ PrettyOutput.print("会话已清除", OutputType.SUCCESS)
72
+ except Exception as e:
73
+ PrettyOutput.print(f"清除会话失败: {str(e)}", OutputType.ERROR)
74
+ continue
75
+
65
76
  try:
66
77
  # 发送到模型并获取回复
67
78
  response = platform.chat(user_input)