jarvis-ai-assistant 0.1.134__py3-none-any.whl → 0.1.138__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.


This version of jarvis-ai-assistant might be problematic.

Files changed (78)
  1. jarvis/__init__.py +1 -1
  2. jarvis/jarvis_agent/__init__.py +201 -79
  3. jarvis/jarvis_agent/builtin_input_handler.py +16 -6
  4. jarvis/jarvis_agent/file_input_handler.py +9 -9
  5. jarvis/jarvis_agent/jarvis.py +10 -10
  6. jarvis/jarvis_agent/main.py +12 -11
  7. jarvis/jarvis_agent/output_handler.py +3 -3
  8. jarvis/jarvis_agent/patch.py +86 -62
  9. jarvis/jarvis_agent/shell_input_handler.py +5 -3
  10. jarvis/jarvis_code_agent/code_agent.py +134 -99
  11. jarvis/jarvis_code_agent/file_select.py +24 -24
  12. jarvis/jarvis_dev/main.py +45 -51
  13. jarvis/jarvis_git_details/__init__.py +0 -0
  14. jarvis/jarvis_git_details/main.py +179 -0
  15. jarvis/jarvis_git_squash/main.py +7 -7
  16. jarvis/jarvis_lsp/base.py +11 -11
  17. jarvis/jarvis_lsp/cpp.py +14 -14
  18. jarvis/jarvis_lsp/go.py +13 -13
  19. jarvis/jarvis_lsp/python.py +8 -8
  20. jarvis/jarvis_lsp/registry.py +21 -21
  21. jarvis/jarvis_lsp/rust.py +15 -15
  22. jarvis/jarvis_methodology/main.py +101 -0
  23. jarvis/jarvis_multi_agent/__init__.py +11 -11
  24. jarvis/jarvis_multi_agent/main.py +6 -6
  25. jarvis/jarvis_platform/__init__.py +1 -1
  26. jarvis/jarvis_platform/ai8.py +67 -89
  27. jarvis/jarvis_platform/base.py +14 -13
  28. jarvis/jarvis_platform/kimi.py +25 -28
  29. jarvis/jarvis_platform/ollama.py +24 -26
  30. jarvis/jarvis_platform/openai.py +15 -19
  31. jarvis/jarvis_platform/oyi.py +48 -50
  32. jarvis/jarvis_platform/registry.py +27 -28
  33. jarvis/jarvis_platform/yuanbao.py +38 -42
  34. jarvis/jarvis_platform_manager/main.py +81 -81
  35. jarvis/jarvis_platform_manager/openai_test.py +21 -21
  36. jarvis/jarvis_rag/file_processors.py +18 -18
  37. jarvis/jarvis_rag/main.py +261 -277
  38. jarvis/jarvis_smart_shell/main.py +12 -12
  39. jarvis/jarvis_tools/ask_codebase.py +28 -28
  40. jarvis/jarvis_tools/ask_user.py +8 -8
  41. jarvis/jarvis_tools/base.py +4 -4
  42. jarvis/jarvis_tools/chdir.py +9 -9
  43. jarvis/jarvis_tools/code_review.py +19 -19
  44. jarvis/jarvis_tools/create_code_agent.py +15 -15
  45. jarvis/jarvis_tools/execute_python_script.py +3 -3
  46. jarvis/jarvis_tools/execute_shell.py +11 -11
  47. jarvis/jarvis_tools/execute_shell_script.py +3 -3
  48. jarvis/jarvis_tools/file_analyzer.py +29 -29
  49. jarvis/jarvis_tools/file_operation.py +22 -20
  50. jarvis/jarvis_tools/find_caller.py +25 -25
  51. jarvis/jarvis_tools/find_methodolopy.py +65 -0
  52. jarvis/jarvis_tools/find_symbol.py +24 -24
  53. jarvis/jarvis_tools/function_analyzer.py +27 -27
  54. jarvis/jarvis_tools/git_commiter.py +9 -9
  55. jarvis/jarvis_tools/lsp_get_diagnostics.py +19 -19
  56. jarvis/jarvis_tools/methodology.py +23 -62
  57. jarvis/jarvis_tools/project_analyzer.py +29 -33
  58. jarvis/jarvis_tools/rag.py +15 -15
  59. jarvis/jarvis_tools/read_code.py +24 -22
  60. jarvis/jarvis_tools/read_webpage.py +31 -31
  61. jarvis/jarvis_tools/registry.py +72 -52
  62. jarvis/jarvis_tools/tool_generator.py +18 -18
  63. jarvis/jarvis_utils/config.py +23 -23
  64. jarvis/jarvis_utils/embedding.py +83 -83
  65. jarvis/jarvis_utils/git_utils.py +20 -20
  66. jarvis/jarvis_utils/globals.py +18 -6
  67. jarvis/jarvis_utils/input.py +10 -9
  68. jarvis/jarvis_utils/methodology.py +140 -136
  69. jarvis/jarvis_utils/output.py +11 -11
  70. jarvis/jarvis_utils/utils.py +22 -70
  71. {jarvis_ai_assistant-0.1.134.dist-info → jarvis_ai_assistant-0.1.138.dist-info}/METADATA +1 -1
  72. jarvis_ai_assistant-0.1.138.dist-info/RECORD +85 -0
  73. {jarvis_ai_assistant-0.1.134.dist-info → jarvis_ai_assistant-0.1.138.dist-info}/entry_points.txt +2 -0
  74. jarvis/jarvis_tools/select_code_files.py +0 -62
  75. jarvis_ai_assistant-0.1.134.dist-info/RECORD +0 -82
  76. {jarvis_ai_assistant-0.1.134.dist-info → jarvis_ai_assistant-0.1.138.dist-info}/LICENSE +0 -0
  77. {jarvis_ai_assistant-0.1.134.dist-info → jarvis_ai_assistant-0.1.138.dist-info}/WHEEL +0 -0
  78. {jarvis_ai_assistant-0.1.134.dist-info → jarvis_ai_assistant-0.1.138.dist-info}/top_level.txt +0 -0
jarvis/jarvis_rag/main.py CHANGED
@@ -13,13 +13,12 @@ import lzma # 添加 lzma 导入
 from threading import Lock
 import hashlib
 
-from jarvis.jarvis_utils.config import get_max_paragraph_length, get_max_token_count, get_min_paragraph_length, get_thread_count, get_rag_ignored_paths
+from jarvis.jarvis_utils.config import get_max_paragraph_length, get_max_token_count, get_min_paragraph_length, get_rag_ignored_paths
 from jarvis.jarvis_utils.embedding import get_context_token_count, get_embedding, get_embedding_batch, load_embedding_model, rerank_results
 from jarvis.jarvis_utils.output import OutputType, PrettyOutput
-from jarvis.jarvis_utils.utils import ct, get_file_md5, init_env, init_gpu_config, ot
+from jarvis.jarvis_utils.utils import ct, get_file_md5, init_env, ot
 
 from jarvis.jarvis_rag.file_processors import TextFileProcessor, PDFProcessor, DocxProcessor, PPTProcessor, ExcelProcessor
-
 """
 Jarvis RAG (Retrieval-Augmented Generation) Module
 
@@ -118,34 +117,19 @@ class RAGTool:
             spinner.ok("✅")
 
 
-        # Add thread related configuration
-        with yaspin(text="初始化线程配置...", color="cyan") as spinner:
-            self.thread_count = get_thread_count()
-            self.vector_lock = Lock() # Protect vector list concurrency
-            spinner.text = "线程配置初始化完成"
-            spinner.ok("✅")
-
-        # 初始化 GPU 内存配置
-        with yaspin(text="初始化 GPU 内存配置...", color="cyan") as spinner:
-            with spinner.hidden():
-                self.gpu_config = init_gpu_config()
-            spinner.text = "GPU 内存配置初始化完成"
-            spinner.ok("✅")
-
-
     def _get_cache_path(self, file_path: str, cache_type: str = "doc") -> str:
         """Get cache file path for a document
-
+
         Args:
             file_path: Original file path
             cache_type: Type of cache ("doc" for documents, "vec" for vectors)
-
+
         Returns:
             str: Cache file path
         """
         # 使用文件路径的哈希作为缓存文件名
         file_hash = hashlib.md5(file_path.encode()).hexdigest()
-
+
         # 确保不同类型的缓存有不同的目录
         if cache_type == "doc":
             cache_subdir = os.path.join(self.cache_dir, "documents")
(Every other hunk in this file is a whitespace-only change: trailing spaces are removed from otherwise blank lines and from line continuations. No other code changes.)