jarvis-ai-assistant 0.1.131__py3-none-any.whl → 0.1.134__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of jarvis-ai-assistant might be problematic.
Files changed (75)
  1. jarvis/__init__.py +1 -1
  2. jarvis/jarvis_agent/__init__.py +165 -285
  3. jarvis/jarvis_agent/jarvis.py +143 -0
  4. jarvis/jarvis_agent/main.py +0 -2
  5. jarvis/jarvis_agent/patch.py +70 -48
  6. jarvis/jarvis_agent/shell_input_handler.py +1 -1
  7. jarvis/jarvis_code_agent/code_agent.py +169 -117
  8. jarvis/jarvis_dev/main.py +327 -626
  9. jarvis/jarvis_git_squash/main.py +10 -31
  10. jarvis/jarvis_lsp/base.py +0 -42
  11. jarvis/jarvis_lsp/cpp.py +0 -15
  12. jarvis/jarvis_lsp/go.py +0 -15
  13. jarvis/jarvis_lsp/python.py +0 -19
  14. jarvis/jarvis_lsp/registry.py +0 -62
  15. jarvis/jarvis_lsp/rust.py +0 -15
  16. jarvis/jarvis_multi_agent/__init__.py +19 -69
  17. jarvis/jarvis_multi_agent/main.py +43 -0
  18. jarvis/jarvis_platform/ai8.py +7 -32
  19. jarvis/jarvis_platform/base.py +2 -7
  20. jarvis/jarvis_platform/kimi.py +3 -144
  21. jarvis/jarvis_platform/ollama.py +54 -68
  22. jarvis/jarvis_platform/openai.py +0 -4
  23. jarvis/jarvis_platform/oyi.py +0 -75
  24. jarvis/jarvis_platform/registry.py +2 -16
  25. jarvis/jarvis_platform/yuanbao.py +264 -0
  26. jarvis/jarvis_rag/file_processors.py +138 -0
  27. jarvis/jarvis_rag/main.py +1305 -425
  28. jarvis/jarvis_tools/ask_codebase.py +216 -43
  29. jarvis/jarvis_tools/code_review.py +158 -113
  30. jarvis/jarvis_tools/create_sub_agent.py +0 -1
  31. jarvis/jarvis_tools/execute_python_script.py +58 -0
  32. jarvis/jarvis_tools/execute_shell.py +13 -26
  33. jarvis/jarvis_tools/execute_shell_script.py +1 -1
  34. jarvis/jarvis_tools/file_analyzer.py +282 -0
  35. jarvis/jarvis_tools/file_operation.py +1 -1
  36. jarvis/jarvis_tools/find_caller.py +278 -0
  37. jarvis/jarvis_tools/find_symbol.py +295 -0
  38. jarvis/jarvis_tools/function_analyzer.py +331 -0
  39. jarvis/jarvis_tools/git_commiter.py +5 -5
  40. jarvis/jarvis_tools/methodology.py +88 -53
  41. jarvis/jarvis_tools/project_analyzer.py +308 -0
  42. jarvis/jarvis_tools/rag.py +0 -5
  43. jarvis/jarvis_tools/read_code.py +24 -3
  44. jarvis/jarvis_tools/read_webpage.py +195 -81
  45. jarvis/jarvis_tools/registry.py +132 -11
  46. jarvis/jarvis_tools/search_web.py +22 -307
  47. jarvis/jarvis_tools/tool_generator.py +8 -10
  48. jarvis/jarvis_utils/__init__.py +1 -0
  49. jarvis/jarvis_utils/config.py +80 -76
  50. jarvis/jarvis_utils/embedding.py +344 -45
  51. jarvis/jarvis_utils/git_utils.py +9 -1
  52. jarvis/jarvis_utils/input.py +7 -6
  53. jarvis/jarvis_utils/methodology.py +384 -15
  54. jarvis/jarvis_utils/output.py +5 -3
  55. jarvis/jarvis_utils/utils.py +60 -8
  56. {jarvis_ai_assistant-0.1.131.dist-info → jarvis_ai_assistant-0.1.134.dist-info}/METADATA +8 -16
  57. jarvis_ai_assistant-0.1.134.dist-info/RECORD +82 -0
  58. {jarvis_ai_assistant-0.1.131.dist-info → jarvis_ai_assistant-0.1.134.dist-info}/entry_points.txt +4 -3
  59. jarvis/jarvis_codebase/__init__.py +0 -0
  60. jarvis/jarvis_codebase/main.py +0 -1011
  61. jarvis/jarvis_tools/lsp_find_definition.py +0 -150
  62. jarvis/jarvis_tools/lsp_find_references.py +0 -127
  63. jarvis/jarvis_tools/treesitter_analyzer.py +0 -331
  64. jarvis/jarvis_treesitter/README.md +0 -104
  65. jarvis/jarvis_treesitter/__init__.py +0 -20
  66. jarvis/jarvis_treesitter/database.py +0 -258
  67. jarvis/jarvis_treesitter/example.py +0 -115
  68. jarvis/jarvis_treesitter/grammar_builder.py +0 -182
  69. jarvis/jarvis_treesitter/language.py +0 -117
  70. jarvis/jarvis_treesitter/symbol.py +0 -31
  71. jarvis/jarvis_treesitter/tools_usage.md +0 -121
  72. jarvis_ai_assistant-0.1.131.dist-info/RECORD +0 -85
  73. {jarvis_ai_assistant-0.1.131.dist-info → jarvis_ai_assistant-0.1.134.dist-info}/LICENSE +0 -0
  74. {jarvis_ai_assistant-0.1.131.dist-info → jarvis_ai_assistant-0.1.134.dist-info}/WHEEL +0 -0
  75. {jarvis_ai_assistant-0.1.131.dist-info → jarvis_ai_assistant-0.1.134.dist-info}/top_level.txt +0 -0
jarvis/jarvis_rag/main.py CHANGED
@@ -1,12 +1,10 @@
  import os
+ import re
  import numpy as np
  import faiss
  from typing import List, Tuple, Optional, Dict
  import pickle
  from dataclasses import dataclass
- from tqdm import tqdm
- import fitz # PyMuPDF for PDF files
- from docx import Document as DocxDocument # python-docx for DOCX files
  from pathlib import Path

  from yaspin import yaspin
@@ -15,10 +13,30 @@ import lzma # 添加 lzma 导入
  from threading import Lock
  import hashlib

- from jarvis.jarvis_utils.config import get_max_paragraph_length, get_max_token_count, get_min_paragraph_length, get_thread_count
- from jarvis.jarvis_utils.embedding import get_context_token_count, get_embedding, get_embedding_batch, load_embedding_model
+ from jarvis.jarvis_utils.config import get_max_paragraph_length, get_max_token_count, get_min_paragraph_length, get_thread_count, get_rag_ignored_paths
+ from jarvis.jarvis_utils.embedding import get_context_token_count, get_embedding, get_embedding_batch, load_embedding_model, rerank_results
  from jarvis.jarvis_utils.output import OutputType, PrettyOutput
- from jarvis.jarvis_utils.utils import get_file_md5, init_env, init_gpu_config
+ from jarvis.jarvis_utils.utils import ct, get_file_md5, init_env, init_gpu_config, ot
+
+ from jarvis.jarvis_rag.file_processors import TextFileProcessor, PDFProcessor, DocxProcessor, PPTProcessor, ExcelProcessor
+
+ """
+ Jarvis RAG (Retrieval-Augmented Generation) Module
+
+ 这个模块实现了高效的本地RAG系统,具有以下特性:
+ 1. 多格式文档处理(文本、PDF、Word、PPT、Excel)
+ 2. 高效向量检索与关键词匹配相结合的混合搜索
+ 3. 交叉编码器重排序,大幅提升检索准确性
+ 4. 增量更新检测,避免重复处理
+ 5. 自动上下文扩展,提供更完整信息
+ 6. 针对RAG优化的文本分割,保持语义完整性
+ 7. 缓存机制,提高反复查询性能
+ 8. 批处理向量化,优化内存和计算资源使用
+ 9. 多线程处理能力
+ 10. GPU加速(如果可用)
+
+ 适用于:代码库文档检索、知识库问答、本地资料分析等场景
+ """

  @dataclass
  class Document:
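
The module docstring added above summarizes the reworked pipeline: multi-format extraction (text, PDF, Word, PPT, Excel), hybrid vector-plus-keyword retrieval with cross-encoder reranking, incremental update detection via MD5, automatic context expansion, token-aware splitting, caching, batched embedding, multi-threading, and optional GPU acceleration. Based only on the `RAGTool` interface visible in this diff (`RAGTool(root_dir)`, `build_index(dir)`, `search(query, top_k)`, `ask(question)`), a minimal usage sketch might look like this; the arguments and query strings are illustrative, not taken from the package:

```python
# Hypothetical usage sketch; class and method names follow this diff, values are illustrative.
from jarvis.jarvis_rag.main import RAGTool

rag = RAGTool(root_dir=".")      # point the tool at a local code base or document folder
rag.build_index(".")             # incremental: unchanged files are served from cache
for doc, score in rag.search("vector cache layout", top_k=15):
    print(f"{score:.2f}  {doc.metadata['file_path']}")
print(rag.ask("How are document and vector caches stored?"))
```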
@@ -27,111 +45,7 @@ class Document:
  metadata: Dict # Metadata (file path, position, etc.)
  md5: str = "" # File MD5 value, for incremental update detection

- class FileProcessor:
- """Base class for file processor"""
- @staticmethod
- def can_handle(file_path: str) -> bool:
- """Determine if the file can be processed"""
- raise NotImplementedError
-
- @staticmethod
- def extract_text(file_path: str) -> str:
- """Extract file text content"""
- raise NotImplementedError
-
- class TextFileProcessor(FileProcessor):
- """Text file processor"""
- ENCODINGS = ['utf-8', 'gbk', 'gb2312', 'latin1']
- SAMPLE_SIZE = 8192 # Read the first 8KB to detect encoding
-
- @staticmethod
- def can_handle(file_path: str) -> bool:
- """Determine if the file is a text file by trying to decode it"""
- try:
- # Read the first part of the file to detect encoding
- with open(file_path, 'rb') as f:
- sample = f.read(TextFileProcessor.SAMPLE_SIZE)
-
- # Check if it contains null bytes (usually represents a binary file)
- if b'\x00' in sample:
- return False
-
- # Check if it contains too many non-printable characters (usually represents a binary file)
- non_printable = sum(1 for byte in sample if byte < 32 and byte not in (9, 10, 13)) # tab, newline, carriage return
- if non_printable / len(sample) > 0.3: # If non-printable characters exceed 30%, it is considered a binary file
- return False
-
- # Try to decode with different encodings
- for encoding in TextFileProcessor.ENCODINGS:
- try:
- sample.decode(encoding)
- return True
- except UnicodeDecodeError:
- continue
-
- return False
-
- except Exception:
- return False
-
- @staticmethod
- def extract_text(file_path: str) -> str:
- """Extract text content, using the detected correct encoding"""
- detected_encoding = None
- try:
- # First try to detect encoding
- with open(file_path, 'rb') as f:
- raw_data = f.read()
-
- # Try different encodings
- for encoding in TextFileProcessor.ENCODINGS:
- try:
- raw_data.decode(encoding)
- detected_encoding = encoding
- break
- except UnicodeDecodeError:
- continue
-
- if not detected_encoding:
- raise UnicodeDecodeError(f"Failed to decode file with supported encodings: {file_path}") # type: ignore
-
- # Use the detected encoding to read the file
- with open(file_path, 'r', encoding=detected_encoding, errors='ignore') as f:
- content = f.read()
-
- # Normalize Unicode characters
- import unicodedata
- content = unicodedata.normalize('NFKC', content)
-
- return content
-
- except Exception as e:
- raise Exception(f"Failed to read file: {str(e)}")

- class PDFProcessor(FileProcessor):
- """PDF file processor"""
- @staticmethod
- def can_handle(file_path: str) -> bool:
- return Path(file_path).suffix.lower() == '.pdf'
-
- @staticmethod
- def extract_text(file_path: str) -> str:
- text_parts = []
- with fitz.open(file_path) as doc: # type: ignore
- for page in doc:
- text_parts.append(page.get_text()) # type: ignore
- return "\n".join(text_parts)
-
- class DocxProcessor(FileProcessor):
- """DOCX file processor"""
- @staticmethod
- def can_handle(file_path: str) -> bool:
- return Path(file_path).suffix.lower() == '.docx'
-
- @staticmethod
- def extract_text(file_path: str) -> str:
- doc = DocxDocument(file_path)
- return "\n".join([paragraph.text for paragraph in doc.paragraphs])

  class RAGTool:
  def __init__(self, root_dir: str):
@@ -196,7 +110,9 @@ class RAGTool:
  self.file_processors = [
  TextFileProcessor(),
  PDFProcessor(),
- DocxProcessor()
+ DocxProcessor(),
+ PPTProcessor(),
+ ExcelProcessor()
  ]
  spinner.text = "文件处理器初始化完成"
  spinner.ok("✅")
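
The two processors newly registered above (`PPTProcessor`, `ExcelProcessor`) are implemented in the new `jarvis/jarvis_rag/file_processors.py` (+138 lines, not shown in this diff), alongside the `TextFileProcessor`, `PDFProcessor`, and `DocxProcessor` classes removed from `main.py` in the previous hunk. As a hedged sketch only, a processor conforming to the removed `can_handle`/`extract_text` interface could look like the following, assuming `python-pptx` as the backing library:

```python
# Sketch only: the shipped PPTProcessor lives in jarvis/jarvis_rag/file_processors.py,
# which this diff does not display; python-pptx is an assumed dependency.
from pathlib import Path
from pptx import Presentation  # python-pptx


class PPTProcessor:
    """PPTX file processor following the FileProcessor interface shown above."""

    @staticmethod
    def can_handle(file_path: str) -> bool:
        return Path(file_path).suffix.lower() == '.pptx'

    @staticmethod
    def extract_text(file_path: str) -> str:
        text_parts = []
        for slide in Presentation(file_path).slides:
            for shape in slide.shapes:
                if shape.has_text_frame:
                    text_parts.append(shape.text_frame.text)
        return "\n".join(text_parts)
```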
@@ -211,23 +127,38 @@ class RAGTool:

  # 初始化 GPU 内存配置
  with yaspin(text="初始化 GPU 内存配置...", color="cyan") as spinner:
- self.gpu_config = init_gpu_config()
+ with spinner.hidden():
+ self.gpu_config = init_gpu_config()
  spinner.text = "GPU 内存配置初始化完成"
  spinner.ok("✅")


- def _get_cache_path(self, file_path: str) -> str:
+ def _get_cache_path(self, file_path: str, cache_type: str = "doc") -> str:
  """Get cache file path for a document

  Args:
  file_path: Original file path
+ cache_type: Type of cache ("doc" for documents, "vec" for vectors)

  Returns:
  str: Cache file path
  """
  # 使用文件路径的哈希作为缓存文件名
  file_hash = hashlib.md5(file_path.encode()).hexdigest()
- return os.path.join(self.cache_dir, f"{file_hash}.cache")
+
+ # 确保不同类型的缓存有不同的目录
+ if cache_type == "doc":
+ cache_subdir = os.path.join(self.cache_dir, "documents")
+ elif cache_type == "vec":
+ cache_subdir = os.path.join(self.cache_dir, "vectors")
+ else:
+ cache_subdir = self.cache_dir
+
+ # 确保子目录存在
+ if not os.path.exists(cache_subdir):
+ os.makedirs(cache_subdir)
+
+ return os.path.join(cache_subdir, f"{file_hash}.cache")

  def _load_cache_index(self):
  """Load cache index"""
@@ -244,37 +175,64 @@ class RAGTool:
244
175
  # 从各个缓存文件加载文档
245
176
  with yaspin(text="加载缓存文件...", color="cyan") as spinner:
246
177
  for file_path in self.file_md5_cache:
247
- cache_path = self._get_cache_path(file_path)
248
- if os.path.exists(cache_path):
178
+ doc_cache_path = self._get_cache_path(file_path, "doc")
179
+ if os.path.exists(doc_cache_path):
249
180
  try:
250
- with lzma.open(cache_path, 'rb') as f:
251
- file_cache = pickle.load(f)
252
- self.documents.extend(file_cache["documents"])
253
- spinner.write(f" 加载缓存文件: {file_path}")
181
+ with lzma.open(doc_cache_path, 'rb') as f:
182
+ doc_cache_data = pickle.load(f)
183
+ self.documents.extend(doc_cache_data["documents"])
184
+ spinner.text = f"加载文档缓存: {file_path}"
254
185
  except Exception as e:
255
- spinner.write(f"❌ 加载缓存文件失败: {file_path}: {str(e)}")
256
- spinner.text = "缓存文件加载完成"
186
+ spinner.write(f"❌ 加载文档缓存失败: {file_path}: {str(e)}")
187
+ spinner.text = "文档缓存加载完成"
257
188
  spinner.ok("✅")
258
189
 
259
190
  # 重建向量索引
260
-
261
191
  if self.documents:
262
192
  with yaspin(text="重建向量索引...", color="cyan") as spinner:
263
193
  vectors = []
194
+
195
+ # 按照文档列表顺序加载向量
196
+ processed_files = set()
264
197
  for doc in self.documents:
265
- cache_path = self._get_cache_path(doc.metadata['file_path'])
266
- if os.path.exists(cache_path):
267
- with lzma.open(cache_path, 'rb') as f:
268
- file_cache = pickle.load(f)
269
- doc_idx = next((i for i, d in enumerate(file_cache["documents"])
270
- if d.metadata['chunk_index'] == doc.metadata['chunk_index']), None)
271
- if doc_idx is not None:
272
- vectors.append(file_cache["vectors"][doc_idx])
198
+ file_path = doc.metadata['file_path']
199
+
200
+ # 避免重复处理同一个文件
201
+ if file_path in processed_files:
202
+ continue
203
+
204
+ processed_files.add(file_path)
205
+ vec_cache_path = self._get_cache_path(file_path, "vec")
206
+
207
+ if os.path.exists(vec_cache_path):
208
+ try:
209
+ # 加载该文件的向量缓存
210
+ with lzma.open(vec_cache_path, 'rb') as f:
211
+ vec_cache_data = pickle.load(f)
212
+ file_vectors = vec_cache_data["vectors"]
213
+
214
+ # 按照文档的chunk_index检索对应向量
215
+ doc_indices = [d.metadata['chunk_index'] for d in self.documents
216
+ if d.metadata['file_path'] == file_path]
217
+
218
+ # 检查向量数量与文档块数量是否匹配
219
+ if len(doc_indices) <= file_vectors.shape[0]:
220
+ for idx in doc_indices:
221
+ if idx < file_vectors.shape[0]:
222
+ vectors.append(file_vectors[idx].reshape(1, -1))
223
+ else:
224
+ spinner.write(f"⚠️ 向量缓存不匹配: {file_path}")
225
+
226
+ spinner.text = f"加载向量缓存: {file_path}"
227
+ except Exception as e:
228
+ spinner.write(f"❌ 加载向量缓存失败: {file_path}: {str(e)}")
229
+ else:
230
+ spinner.write(f"⚠️ 缺少向量缓存: {file_path}")
273
231
 
274
232
  if vectors:
275
233
  vectors = np.vstack(vectors)
276
- self._build_index(vectors)
277
- spinner.text = "向量索引重建完成,加载 {len(self.documents)} 个文档片段"
234
+ self._build_index(vectors, spinner)
235
+ spinner.text = f"向量索引重建完成,加载 {len(self.documents)} 个文档片段"
278
236
  spinner.ok("✅")
279
237
 
280
238
  except Exception as e:
@@ -285,67 +243,126 @@ class RAGTool:
285
243
  self.flat_index = None
286
244
  self.file_md5_cache = {}
287
245
 
288
- def _save_cache(self, file_path: str, documents: List[Document], vectors: np.ndarray):
246
+ def _save_cache(self, file_path: str, documents: List[Document], vectors: np.ndarray, spinner=None):
289
247
  """Save cache for a single file
290
248
 
291
249
  Args:
292
250
  file_path: File path
293
251
  documents: List of documents
294
252
  vectors: Document vectors
253
+ spinner: Optional spinner for progress display
295
254
  """
296
255
  try:
297
- # 保存文件缓存
298
- cache_path = self._get_cache_path(file_path)
299
- cache_data = {
300
- "documents": documents,
256
+ # 保存文档缓存
257
+ if spinner:
258
+ spinner.text = f"保存 {file_path} 的文档缓存..."
259
+ doc_cache_path = self._get_cache_path(file_path, "doc")
260
+ doc_cache_data = {
261
+ "documents": documents
262
+ }
263
+ with lzma.open(doc_cache_path, 'wb') as f:
264
+ pickle.dump(doc_cache_data, f)
265
+
266
+ # 保存向量缓存
267
+ if spinner:
268
+ spinner.text = f"保存 {file_path} 的向量缓存..."
269
+ vec_cache_path = self._get_cache_path(file_path, "vec")
270
+ vec_cache_data = {
301
271
  "vectors": vectors
302
272
  }
303
- with lzma.open(cache_path, 'wb') as f:
304
- pickle.dump(cache_data, f)
273
+ with lzma.open(vec_cache_path, 'wb') as f:
274
+ pickle.dump(vec_cache_data, f)
305
275
 
306
276
  # 更新并保存索引
277
+ if spinner:
278
+ spinner.text = f"更新 {file_path} 的索引缓存..."
307
279
  index_path = os.path.join(self.data_dir, "index.pkl")
308
280
  index_data = {
309
281
  "file_md5_cache": self.file_md5_cache
310
282
  }
311
283
  with lzma.open(index_path, 'wb') as f:
312
284
  pickle.dump(index_data, f)
285
+
286
+ if spinner:
287
+ spinner.text = f"{file_path} 的缓存保存完成"
313
288
 
314
289
  except Exception as e:
290
+ if spinner:
291
+ spinner.text = f"保存 {file_path} 的缓存失败: {str(e)}"
315
292
  PrettyOutput.print(f"保存缓存失败: {str(e)}", output_type=OutputType.ERROR)
316
293
 
317
- def _build_index(self, vectors: np.ndarray):
294
+ def _build_index(self, vectors: np.ndarray, spinner=None):
318
295
  """Build FAISS index"""
319
296
  if vectors.shape[0] == 0:
297
+ if spinner:
298
+ spinner.text = "向量为空,跳过索引构建"
320
299
  self.index = None
321
300
  self.flat_index = None
322
301
  return
323
302
 
324
303
  # Create a flat index to store original vectors, for reconstruction
304
+ if spinner:
305
+ spinner.text = "创建平面索引用于向量重建..."
325
306
  self.flat_index = faiss.IndexFlatIP(self.vector_dim)
326
307
  self.flat_index.add(vectors) # type: ignore
327
308
 
328
309
  # Create an IVF index for fast search
329
- nlist = max(4, int(vectors.shape[0] / 1000)) # 每1000个向量一个聚类中心
310
+ if spinner:
311
+ spinner.text = "创建IVF索引用于快速搜索..."
312
+ # 修改聚类中心的计算方式,小数据量时使用更少的聚类中心
313
+ # 避免"WARNING clustering X points to Y centroids: please provide at least Z training points"警告
314
+ num_vectors = vectors.shape[0]
315
+ if num_vectors < 100:
316
+ # 对于小于100个向量的情况,使用更少的聚类中心
317
+ nlist = 1 # 只用1个聚类中心
318
+ elif num_vectors < 1000:
319
+ # 对于100-1000个向量的情况,使用较少的聚类中心
320
+ nlist = max(1, int(num_vectors / 100)) # 每100个向量一个聚类中心
321
+ else:
322
+ # 原始逻辑:每1000个向量一个聚类中心,最少4个
323
+ nlist = max(4, int(num_vectors / 1000))
324
+
330
325
  quantizer = faiss.IndexFlatIP(self.vector_dim)
331
326
  self.index = faiss.IndexIVFFlat(quantizer, self.vector_dim, nlist, faiss.METRIC_INNER_PRODUCT)
332
327
 
333
328
  # Train and add vectors
329
+ if spinner:
330
+ spinner.text = f"训练索引({vectors.shape[0]}个向量,{nlist}个聚类中心)..."
334
331
  self.index.train(vectors) # type: ignore
332
+
333
+ if spinner:
334
+ spinner.text = "添加向量到索引..."
335
335
  self.index.add(vectors) # type: ignore
336
+
336
337
  # Set the number of clusters to probe during search
338
+ if spinner:
339
+ spinner.text = "设置搜索参数..."
337
340
  self.index.nprobe = min(nlist, 10)
341
+
342
+ if spinner:
343
+ spinner.text = f"索引构建完成,共 {vectors.shape[0]} 个向量"
338
344
 
339
345
  def _split_text(self, text: str) -> List[str]:
340
- """Use a more intelligent splitting strategy"""
341
- # Add overlapping blocks to maintain context consistency
342
- overlap_size = min(200, self.max_paragraph_length // 4)
346
+ """使用基于token计数的更智能的分割策略
347
+
348
+ Args:
349
+ text: 要分割的文本
350
+
351
+ Returns:
352
+ List[str]: 分割后的段落列表
353
+ """
354
+ from jarvis.jarvis_utils.embedding import get_context_token_count
355
+
356
+ # 计算可用的最大和最小token数
357
+ max_tokens = int(self.max_paragraph_length * 0.25) # 字符长度转换为大致token数
358
+ min_tokens = int(self.min_paragraph_length * 0.25) # 字符长度转换为大致token数
343
359
 
360
+ # 添加重叠块以保持上下文一致性
344
361
  paragraphs = []
345
362
  current_chunk = []
346
- current_length = 0
363
+ current_token_count = 0
347
364
 
348
- # First split by sentence
365
+ # 首先按句子分割
349
366
  sentences = []
350
367
  current_sentence = []
351
368
  sentence_ends = {'。', '!', '?', '…', '.', '!', '?'}
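
The `_build_index` changes in the hunk above size the FAISS IVF index adaptively instead of always using one centroid per 1,000 vectors: a single centroid below 100 vectors, roughly one per 100 vectors up to 1,000, and one per 1,000 (minimum 4) beyond that, with `nprobe` capped at 10. A standalone sketch of that pattern is below; the dimension and data are illustrative, not values from the package.

```python
import numpy as np
import faiss

def build_ivf_index(vectors: np.ndarray, dim: int) -> faiss.IndexIVFFlat:
    """Adaptive IVF sizing in the style of _build_index (sketch)."""
    n = vectors.shape[0]
    if n < 100:
        nlist = 1                      # too few points to train many centroids
    elif n < 1000:
        nlist = max(1, n // 100)       # roughly one centroid per 100 vectors
    else:
        nlist = max(4, n // 1000)      # roughly one centroid per 1000 vectors
    quantizer = faiss.IndexFlatIP(dim)
    index = faiss.IndexIVFFlat(quantizer, dim, nlist, faiss.METRIC_INNER_PRODUCT)
    index.train(vectors)               # IVF indexes must be trained before adding
    index.add(vectors)
    index.nprobe = min(nlist, 10)      # clusters probed per query
    return index

# Example: 500 random 384-dimensional vectors (illustrative)
vecs = np.random.rand(500, 384).astype("float32")
index = build_ivf_index(vecs, 384)
```

Keeping `nlist` small for tiny corpora avoids the FAISS warning about having too few training points per centroid, which is exactly the situation the comment in the diff calls out.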
@@ -363,77 +380,73 @@ class RAGTool:
363
380
  if sentence.strip():
364
381
  sentences.append(sentence)
365
382
 
366
- # Build overlapping blocks based on sentences
383
+ # 基于句子构建重叠块
367
384
  for sentence in sentences:
368
- if current_length + len(sentence) > self.max_paragraph_length:
385
+ # 计算当前句子的token数
386
+ sentence_token_count = get_context_token_count(sentence)
387
+
388
+ # 检查添加此句子是否会超过最大token限制
389
+ if current_token_count + sentence_token_count > max_tokens:
369
390
  if current_chunk:
370
391
  chunk_text = ' '.join(current_chunk)
371
- if len(chunk_text) >= self.min_paragraph_length:
392
+ chunk_token_count = get_context_token_count(chunk_text)
393
+
394
+ if chunk_token_count >= min_tokens:
372
395
  paragraphs.append(chunk_text)
396
+
397
+ # 保留一些内容作为重叠
398
+ # 保留最后两个句子作为重叠部分
399
+ if len(current_chunk) >= 2:
400
+ overlap_text = ' '.join(current_chunk[-2:])
401
+ overlap_token_count = get_context_token_count(overlap_text)
373
402
 
374
- # Keep some content as overlap
375
- overlap_text = ' '.join(current_chunk[-2:]) # Keep the last two sentences
376
- current_chunk = []
377
- if overlap_text:
378
- current_chunk.append(overlap_text)
379
- current_length = len(overlap_text)
403
+ current_chunk = []
404
+ if overlap_text:
405
+ current_chunk.append(overlap_text)
406
+ current_token_count = overlap_token_count
407
+ else:
408
+ current_token_count = 0
380
409
  else:
381
- current_length = 0
382
-
410
+ # 如果当前块中句子不足两个,就重置
411
+ current_chunk = []
412
+ current_token_count = 0
413
+
414
+ # 添加当前句子到块中
383
415
  current_chunk.append(sentence)
384
- current_length += len(sentence)
416
+ current_token_count += sentence_token_count
385
417
 
386
- # Process the last chunk
418
+ # 处理最后一个块
387
419
  if current_chunk:
388
420
  chunk_text = ' '.join(current_chunk)
389
- if len(chunk_text) >= self.min_paragraph_length:
421
+ chunk_token_count = get_context_token_count(chunk_text)
422
+
423
+ if chunk_token_count >= min_tokens:
390
424
  paragraphs.append(chunk_text)
391
425
 
392
426
  return paragraphs
393
427
 
394
428
 
395
- def _process_document_batch(self, documents: List[Document]) -> np.ndarray:
396
- """Process a batch of documents using shared memory"""
397
- try:
398
- texts = []
399
- self.documents = [] # Reset documents to store chunks
400
-
401
- for doc in documents:
402
- # Split original document into chunks
403
- chunks = self._split_text(doc.content)
404
- for chunk_idx, chunk in enumerate(chunks):
405
- # Create new Document for each chunk
406
- new_metadata = doc.metadata.copy()
407
- new_metadata.update({
408
- 'chunk_index': chunk_idx,
409
- 'total_chunks': len(chunks),
410
- 'original_length': len(doc.content)
411
- })
412
- self.documents.append(Document(
413
- content=chunk,
414
- metadata=new_metadata,
415
- md5=doc.md5
416
- ))
417
- texts.append(f"File:{doc.metadata['file_path']} Chunk:{chunk_idx} Content:{chunk}")
418
-
419
- return get_embedding_batch(self.embedding_model, texts)
420
- except Exception as e:
421
- PrettyOutput.print(f"批量处理失败: {str(e)}", OutputType.ERROR)
422
- return np.zeros((0, self.vector_dim), dtype=np.float32) # type: ignore
423
-
424
- def _process_file(self, file_path: str) -> List[Document]:
429
+ def _process_file(self, file_path: str, spinner=None) -> List[Document]:
425
430
  """Process a single file"""
426
431
  try:
427
432
  # Calculate file MD5
433
+ if spinner:
434
+ spinner.text = f"计算文件 {file_path} 的MD5..."
428
435
  current_md5 = get_file_md5(file_path)
429
436
  if not current_md5:
437
+ if spinner:
438
+ spinner.text = f"文件 {file_path} 计算MD5失败"
430
439
  return []
431
440
 
432
441
  # Check if the file needs to be reprocessed
433
442
  if file_path in self.file_md5_cache and self.file_md5_cache[file_path] == current_md5:
443
+ if spinner:
444
+ spinner.text = f"文件 {file_path} 未发生变化,跳过处理"
434
445
  return []
435
446
 
436
447
  # Find the appropriate processor
448
+ if spinner:
449
+ spinner.text = f"查找适用于 {file_path} 的处理器..."
437
450
  processor = None
438
451
  for p in self.file_processors:
439
452
  if p.can_handle(file_path):
@@ -442,17 +455,27 @@ class RAGTool:
442
455
 
443
456
  if not processor:
444
457
  # If no appropriate processor is found, return an empty document
458
+ if spinner:
459
+ spinner.text = f"没有找到适用于 {file_path} 的处理器,跳过处理"
445
460
  return []
446
461
 
447
462
  # Extract text content
463
+ if spinner:
464
+ spinner.text = f"提取 {file_path} 的文本内容..."
448
465
  content = processor.extract_text(file_path)
449
466
  if not content.strip():
467
+ if spinner:
468
+ spinner.text = f"文件 {file_path} 没有文本内容,跳过处理"
450
469
  return []
451
470
 
452
471
  # Split text
472
+ if spinner:
473
+ spinner.text = f"分割 {file_path} 的文本..."
453
474
  chunks = self._split_text(content)
454
475
 
455
476
  # Create document objects
477
+ if spinner:
478
+ spinner.text = f"为 {file_path} 创建 {len(chunks)} 个文档对象..."
456
479
  documents = []
457
480
  for i, chunk in enumerate(chunks):
458
481
  doc = Document(
@@ -469,209 +492,726 @@ class RAGTool:
469
492
 
470
493
  # Update MD5 cache
471
494
  self.file_md5_cache[file_path] = current_md5
495
+ if spinner:
496
+ spinner.text = f"文件 {file_path} 处理完成,共创建 {len(documents)} 个文档对象"
472
497
  return documents
473
498
 
474
499
  except Exception as e:
500
+ if spinner:
501
+ spinner.text = f"处理文件失败: {file_path}: {str(e)}"
475
502
  PrettyOutput.print(f"处理文件失败: {file_path}: {str(e)}",
476
503
  output_type=OutputType.ERROR)
477
504
  return []
478
505
 
479
- def build_index(self, dir: str):
480
- """Build document index with optimized processing"""
481
- # Get all files
482
- with yaspin(text="获取所有文件...", color="cyan") as spinner:
483
- all_files = []
484
- for root, _, files in os.walk(dir):
485
- # Skip .jarvis directories and other ignored paths
486
- if any(ignored in root for ignored in ['.git', '__pycache__', 'node_modules', '.jarvis']) or \
487
- any(part.startswith('.jarvis-') for part in root.split(os.sep)):
488
- continue
506
+ def _should_ignore_path(self, path: str, ignored_paths: List[str]) -> bool:
507
+ """
508
+ 检查路径是否应该被忽略
509
+
510
+ Args:
511
+ path: 文件或目录路径
512
+ ignored_paths: 忽略模式列表
513
+
514
+ Returns:
515
+ bool: 如果路径应该被忽略则返回True
516
+ """
517
+ import fnmatch
518
+ import os
519
+
520
+ # 获取相对路径
521
+ rel_path = path
522
+ if os.path.isabs(path):
523
+ try:
524
+ rel_path = os.path.relpath(path, self.root_dir)
525
+ except ValueError:
526
+ # 如果不能计算相对路径,使用原始路径
527
+ pass
528
+
529
+ path_parts = rel_path.split(os.sep)
530
+
531
+ # 检查路径的每一部分是否匹配任意忽略模式
532
+ for part in path_parts:
533
+ for pattern in ignored_paths:
534
+ if fnmatch.fnmatch(part, pattern):
535
+ return True
489
536
 
490
- for file in files:
491
- # Skip .jarvis files
492
- if '.jarvis' in root:
493
- continue
494
-
495
- file_path = os.path.join(root, file)
496
- if os.path.getsize(file_path) > 100 * 1024 * 1024: # 100MB
497
- PrettyOutput.print(f"Skip large file: {file_path}",
498
- output_type=OutputType.WARNING)
499
- continue
500
- all_files.append(file_path)
501
- spinner.text = f"获取所有文件完成,共 {len(all_files)} 个文件"
502
- spinner.ok("✅")
503
-
504
- # Clean up cache for deleted files
505
- with yaspin(text="清理缓存...", color="cyan") as spinner:
506
- deleted_files = set(self.file_md5_cache.keys()) - set(all_files)
507
- for file_path in deleted_files:
508
- del self.file_md5_cache[file_path]
509
- # Remove related documents
510
- self.documents = [doc for doc in self.documents if doc.metadata['file_path'] != file_path]
511
- spinner.text = f"清理缓存完成,共 {len(deleted_files)} 个文件"
512
- spinner.ok("✅")
513
-
514
- # Check file changes
515
- with yaspin(text="检查文件变化...", color="cyan") as spinner:
516
- files_to_process = []
517
- unchanged_files = []
518
- for file_path in all_files:
519
- current_md5 = get_file_md5(file_path)
520
- if current_md5: # Only process files that can successfully calculate MD5
521
- if file_path in self.file_md5_cache and self.file_md5_cache[file_path] == current_md5:
522
- # File未变化,记录但不重新处理
523
- unchanged_files.append(file_path)
524
- else:
525
- # New file or modified file
526
- files_to_process.append(file_path)
527
- spinner.write(f"⚠️ 文件变化: {file_path}")
528
- spinner.text = f"检查文件变化完成,共 {len(files_to_process)} 个文件需要处理"
529
- spinner.ok("✅")
530
-
531
- # Keep documents for unchanged files
532
- unchanged_documents = [doc for doc in self.documents
533
- if doc.metadata['file_path'] in unchanged_files]
534
-
535
- # Process files one by one with optimized vectorization
536
- if files_to_process:
537
- new_documents = []
538
- new_vectors = []
537
+ # 检查完整路径是否匹配任意忽略模式
538
+ for pattern in ignored_paths:
539
+ if fnmatch.fnmatch(rel_path, pattern):
540
+ return True
541
+
542
+ return False
543
+
544
+ def _is_git_repo(self) -> bool:
545
+ """
546
+ 检查当前目录是否为Git仓库
547
+
548
+ Returns:
549
+ bool: 如果是Git仓库则返回True
550
+ """
551
+ import subprocess
552
+
553
+ try:
554
+ result = subprocess.run(
555
+ ["git", "rev-parse", "--is-inside-work-tree"],
556
+ cwd=self.root_dir,
557
+ stdout=subprocess.PIPE,
558
+ stderr=subprocess.PIPE,
559
+ text=True,
560
+ check=False
561
+ )
562
+ return result.returncode == 0 and result.stdout.strip() == "true"
563
+ except Exception:
564
+ return False
565
+
566
+ def _get_git_managed_files(self) -> List[str]:
567
+ """
568
+ 获取Git仓库中被管理的文件列表
569
+
570
+ Returns:
571
+ List[str]: 被Git管理的文件路径列表(相对路径)
572
+ """
573
+ import subprocess
574
+
575
+ try:
576
+ # 获取git索引中的文件
577
+ result = subprocess.run(
578
+ ["git", "ls-files"],
579
+ cwd=self.root_dir,
580
+ stdout=subprocess.PIPE,
581
+ stderr=subprocess.PIPE,
582
+ text=True,
583
+ check=False
584
+ )
585
+
586
+ if result.returncode != 0:
587
+ return []
588
+
589
+ git_files = [line.strip() for line in result.stdout.splitlines() if line.strip()]
539
590
 
591
+ # 添加未暂存但已跟踪的修改文件
592
+ result = subprocess.run(
593
+ ["git", "ls-files", "--modified"],
594
+ cwd=self.root_dir,
595
+ stdout=subprocess.PIPE,
596
+ stderr=subprocess.PIPE,
597
+ text=True,
598
+ check=False
599
+ )
540
600
 
541
- for file_path in files_to_process:
542
- with yaspin(text=f"处理文件 {file_path} ...", color="cyan") as spinner:
543
- try:
544
- # Process single file
545
- file_docs = self._process_file(file_path)
546
- if file_docs:
547
- # Vectorize documents from this file
548
- texts_to_vectorize = [
549
- f"File:{doc.metadata['file_path']} Content:{doc.content}"
550
- for doc in file_docs
551
- ]
552
- file_vectors = get_embedding_batch(self.embedding_model, texts_to_vectorize)
601
+ if result.returncode == 0:
602
+ modified_files = [line.strip() for line in result.stdout.splitlines() if line.strip()]
603
+ git_files.extend([f for f in modified_files if f not in git_files])
604
+
605
+ # 转换为绝对路径
606
+ return [os.path.join(self.root_dir, file) for file in git_files]
607
+
608
+ except Exception as e:
609
+ PrettyOutput.print(f"获取Git管理的文件失败: {str(e)}", output_type=OutputType.WARNING)
610
+ return []
611
+
612
+ def build_index(self, dir: str):
613
+ try:
614
+ """Build document index with optimized processing"""
615
+ # Get all files
616
+ with yaspin(text="获取所有文件...", color="cyan") as spinner:
617
+ all_files = []
618
+
619
+ # 获取需要忽略的路径列表
620
+ ignored_paths = get_rag_ignored_paths()
621
+
622
+ # 检查是否为Git仓库
623
+ is_git_repo = self._is_git_repo()
624
+ if is_git_repo:
625
+ git_files = self._get_git_managed_files()
626
+ # 过滤掉被忽略的文件
627
+ for file_path in git_files:
628
+ if self._should_ignore_path(file_path, ignored_paths):
629
+ continue
553
630
 
554
- # Save cache for this file
555
- self._save_cache(file_path, file_docs, file_vectors)
631
+ if os.path.getsize(file_path) > 100 * 1024 * 1024: # 100MB
632
+ PrettyOutput.print(f"跳过大文件: {file_path}",
633
+ output_type=OutputType.WARNING)
634
+ continue
635
+ all_files.append(file_path)
636
+ else:
637
+ # 非Git仓库,使用常规文件遍历
638
+ for root, _, files in os.walk(dir):
639
+ # 检查目录是否匹配忽略模式
640
+ if self._should_ignore_path(root, ignored_paths):
641
+ continue
556
642
 
557
- # Accumulate documents and vectors
558
- new_documents.extend(file_docs)
559
- new_vectors.append(file_vectors)
560
-
561
- spinner.text = f"处理文件 {file_path} 完成"
562
- spinner.ok("✅")
643
+ for file in files:
644
+ file_path = os.path.join(root, file)
563
645
 
564
- except Exception as e:
565
- spinner.text = f"处理文件失败: {file_path}: {str(e)}"
566
- spinner.fail("❌")
646
+ # 检查文件是否匹配忽略模式
647
+ if self._should_ignore_path(file_path, ignored_paths):
648
+ continue
649
+
650
+ if os.path.getsize(file_path) > 100 * 1024 * 1024: # 100MB
651
+ PrettyOutput.print(f"跳过大文件: {file_path}",
652
+ output_type=OutputType.WARNING)
653
+ continue
654
+ all_files.append(file_path)
655
+
656
+ spinner.text = f"获取所有文件完成,共 {len(all_files)} 个文件"
657
+ spinner.ok("✅")
658
+
659
+ # Clean up cache for deleted files
660
+ with yaspin(text="清理缓存...", color="cyan") as spinner:
661
+ deleted_files = set(self.file_md5_cache.keys()) - set(all_files)
662
+ deleted_count = len(deleted_files)
663
+
664
+ if deleted_count > 0:
665
+ spinner.write(f"🗑️ 删除不存在文件的缓存: {deleted_count} 个")
567
666
 
667
+ for file_path in deleted_files:
668
+ # Remove from MD5 cache
669
+ del self.file_md5_cache[file_path]
670
+ # Remove related documents
671
+ self.documents = [doc for doc in self.documents if doc.metadata['file_path'] != file_path]
672
+ # Delete cache files
673
+ self._delete_file_cache(file_path, None) # Pass None as spinner to not show individual deletions
568
674
 
569
- # Update documents list
570
- self.documents.extend(new_documents)
675
+ spinner.text = f"清理缓存完成,共删除 {deleted_count} 个不存在文件的缓存"
676
+ spinner.ok("✅")
571
677
 
572
- # Build final index
573
- if new_vectors:
574
- with yaspin(text="构建最终索引...", color="cyan") as spinner:
575
- all_new_vectors = np.vstack(new_vectors)
678
+ # Check file changes
679
+ with yaspin(text="检查文件变化...", color="cyan") as spinner:
680
+ files_to_process = []
681
+ unchanged_files = []
682
+ new_files_count = 0
683
+ modified_files_count = 0
684
+
685
+ for file_path in all_files:
686
+ current_md5 = get_file_md5(file_path)
687
+ if current_md5: # Only process files that can successfully calculate MD5
688
+ if file_path in self.file_md5_cache and self.file_md5_cache[file_path] == current_md5:
689
+ # File未变化,记录但不重新处理
690
+ unchanged_files.append(file_path)
691
+ else:
692
+ # New file or modified file
693
+ files_to_process.append(file_path)
694
+
695
+ # 如果是修改的文件,删除旧缓存
696
+ if file_path in self.file_md5_cache:
697
+ modified_files_count += 1
698
+ # 删除旧缓存
699
+ self._delete_file_cache(file_path, spinner)
700
+ # 从文档列表中移除
701
+ self.documents = [doc for doc in self.documents if doc.metadata['file_path'] != file_path]
702
+ else:
703
+ new_files_count += 1
704
+
705
+ # 输出汇总信息
706
+ if unchanged_files:
707
+ spinner.write(f"📚 已缓存文件: {len(unchanged_files)} 个")
708
+ if new_files_count > 0:
709
+ spinner.write(f"🆕 新增文件: {new_files_count} 个")
710
+ if modified_files_count > 0:
711
+ spinner.write(f"📝 修改文件: {modified_files_count} 个")
576
712
 
577
- if self.flat_index is not None:
578
- # Get vectors for unchanged documents
579
- unchanged_vectors = self._get_unchanged_vectors(unchanged_documents)
580
- if unchanged_vectors is not None:
581
- final_vectors = np.vstack([unchanged_vectors, all_new_vectors])
713
+ spinner.text = f"检查文件变化完成,共 {len(files_to_process)} 个文件需要处理"
714
+ spinner.ok("✅")
715
+
716
+ # Keep documents for unchanged files
717
+ unchanged_documents = [doc for doc in self.documents
718
+ if doc.metadata['file_path'] in unchanged_files]
719
+
720
+ # Process files one by one with optimized vectorization
721
+ if files_to_process:
722
+ new_documents = []
723
+ new_vectors = []
724
+ success_count = 0
725
+ skipped_count = 0
726
+ failed_count = 0
727
+
728
+ with yaspin(text=f"处理文件中 (0/{len(files_to_process)})...", color="cyan") as spinner:
729
+ for index, file_path in enumerate(files_to_process):
730
+ spinner.text = f"处理文件中 ({index+1}/{len(files_to_process)}): {file_path}"
731
+ try:
732
+ # Process single file
733
+ file_docs = self._process_file(file_path, spinner)
734
+ if file_docs:
735
+ # Vectorize documents from this file
736
+ spinner.text = f"处理文件中 ({index+1}/{len(files_to_process)}): 为 {file_path} 生成向量嵌入..."
737
+ texts_to_vectorize = [
738
+ f"File:{doc.metadata['file_path']} Content:{doc.content}"
739
+ for doc in file_docs
740
+ ]
741
+
742
+ file_vectors = get_embedding_batch(self.embedding_model, f"({index+1}/{len(files_to_process)}){file_path}", texts_to_vectorize, spinner)
743
+
744
+ # Save cache for this file
745
+ spinner.text = f"处理文件中 ({index+1}/{len(files_to_process)}): 保存 {file_path} 的缓存..."
746
+ self._save_cache(file_path, file_docs, file_vectors, spinner)
747
+
748
+ # Accumulate documents and vectors
749
+ new_documents.extend(file_docs)
750
+ new_vectors.append(file_vectors)
751
+ success_count += 1
752
+ else:
753
+ # 文件跳过处理
754
+ skipped_count += 1
755
+
756
+ except Exception as e:
757
+ spinner.write(f"❌ 处理失败: {file_path}: {str(e)}")
758
+ failed_count += 1
759
+
760
+ # 输出处理统计
761
+ spinner.text = f"文件处理完成: 成功 {success_count} 个, 跳过 {skipped_count} 个, 失败 {failed_count} 个"
762
+ spinner.ok("✅")
763
+
764
+ # Update documents list
765
+ self.documents.extend(new_documents)
766
+
767
+ # Build final index
768
+ if new_vectors:
769
+ with yaspin(text="构建最终索引...", color="cyan") as spinner:
770
+ spinner.text = "合并新向量..."
771
+ all_new_vectors = np.vstack(new_vectors)
772
+
773
+ unchanged_vector_count = 0
774
+ if self.flat_index is not None:
775
+ # Get vectors for unchanged documents
776
+ spinner.text = "获取未变化文档的向量..."
777
+ unchanged_vectors = self._get_unchanged_vectors(unchanged_documents, spinner)
778
+ if unchanged_vectors is not None:
779
+ unchanged_vector_count = unchanged_vectors.shape[0]
780
+ spinner.text = f"合并新旧向量(新:{all_new_vectors.shape[0]},旧:{unchanged_vector_count})..."
781
+ final_vectors = np.vstack([unchanged_vectors, all_new_vectors])
782
+ else:
783
+ spinner.text = f"仅使用新向量({all_new_vectors.shape[0]})..."
784
+ final_vectors = all_new_vectors
582
785
  else:
786
+ spinner.text = f"仅使用新向量({all_new_vectors.shape[0]})..."
583
787
  final_vectors = all_new_vectors
584
- else:
585
- final_vectors = all_new_vectors
586
788
 
587
- # Build index
588
- spinner.text = f"构建索引..."
589
- self._build_index(final_vectors)
590
- spinner.text = f"索引构建完成,共 {len(self.documents)} 个文档 "
591
- spinner.ok("✅")
789
+ # Build index
790
+ spinner.text = f"构建索引(向量数量:{final_vectors.shape[0]})..."
791
+ self._build_index(final_vectors, spinner)
792
+ spinner.text = f"索引构建完成,共 {len(self.documents)} 个文档片段"
793
+ spinner.ok("✅")
592
794
 
593
- PrettyOutput.print(
594
- f"索引 {len(self.documents)} 个文档 "
595
- f"(新/修改: {len(new_documents)}, "
596
- f"不变: {len(unchanged_documents)})",
597
- OutputType.SUCCESS
598
- )
795
+ # 输出最终统计信息
796
+ PrettyOutput.print(
797
+ f"📊 索引统计:\n"
798
+ f" 总文档数: {len(self.documents)} 个文档片段\n"
799
+ f" • 已缓存文件: {len(unchanged_files)} 个\n"
800
+ f" • 处理文件: {len(files_to_process)} 个\n"
801
+ f" - 成功: {success_count} 个\n"
802
+ f" - 跳过: {skipped_count} 个\n"
803
+ f" - 失败: {failed_count} 个",
804
+ OutputType.SUCCESS
805
+ )
806
+ except Exception as e:
807
+ PrettyOutput.print(f"索引构建失败: {str(e)}",
808
+ output_type=OutputType.ERROR)
599
809
 
600
- def _get_unchanged_vectors(self, unchanged_documents: List[Document]) -> Optional[np.ndarray]:
810
+ def _get_unchanged_vectors(self, unchanged_documents: List[Document], spinner=None) -> Optional[np.ndarray]:
601
811
  """Get vectors for unchanged documents from existing index"""
602
812
  try:
603
- if not unchanged_documents or self.flat_index is None:
813
+ if not unchanged_documents:
814
+ if spinner:
815
+ spinner.text = "没有未变化的文档"
604
816
  return None
605
817
 
818
+ if spinner:
819
+ spinner.text = f"加载 {len(unchanged_documents)} 个未变化文档的向量..."
820
+
821
+ # 按文件分组处理
822
+ unchanged_files = set(doc.metadata['file_path'] for doc in unchanged_documents)
606
823
  unchanged_vectors = []
607
- for doc in unchanged_documents:
608
- doc_idx = next((i for i, d in enumerate(self.documents)
609
- if d.metadata['file_path'] == doc.metadata['file_path']), None)
610
- if doc_idx is not None:
611
- vector = np.zeros((1, self.vector_dim), dtype=np.float32) # type: ignore
612
- self.flat_index.reconstruct(doc_idx, vector.ravel())
613
- unchanged_vectors.append(vector)
824
+
825
+ for file_path in unchanged_files:
826
+ if spinner:
827
+ spinner.text = f"加载 {file_path} 的向量..."
828
+
829
+ # 获取该文件所有文档的chunk索引
830
+ doc_indices = [(i, doc.metadata['chunk_index'])
831
+ for i, doc in enumerate(unchanged_documents)
832
+ if doc.metadata['file_path'] == file_path]
833
+
834
+ if not doc_indices:
835
+ continue
836
+
837
+ # 加载该文件的向量
838
+ vec_cache_path = self._get_cache_path(file_path, "vec")
839
+ if os.path.exists(vec_cache_path):
840
+ try:
841
+ with lzma.open(vec_cache_path, 'rb') as f:
842
+ vec_cache_data = pickle.load(f)
843
+ file_vectors = vec_cache_data["vectors"]
844
+
845
+ # 按照chunk_index加载对应的向量
846
+ for _, chunk_idx in doc_indices:
847
+ if chunk_idx < file_vectors.shape[0]:
848
+ unchanged_vectors.append(file_vectors[chunk_idx].reshape(1, -1))
849
+
850
+ if spinner:
851
+ spinner.text = f"成功加载 {file_path} 的向量"
852
+ except Exception as e:
853
+ if spinner:
854
+ spinner.text = f"加载 {file_path} 向量失败: {str(e)}"
855
+ else:
856
+ if spinner:
857
+ spinner.text = f"未找到 {file_path} 的向量缓存"
858
+
859
+ # 从flat_index重建向量
860
+ if self.flat_index is not None:
861
+ if spinner:
862
+ spinner.text = f"从索引重建 {file_path} 的向量..."
863
+
864
+ for doc_idx, chunk_idx in doc_indices:
865
+ idx = next((i for i, d in enumerate(self.documents)
866
+ if d.metadata['file_path'] == file_path and
867
+ d.metadata['chunk_index'] == chunk_idx), None)
868
+
869
+ if idx is not None:
870
+ vector = np.zeros((1, self.vector_dim), dtype=np.float32) # type: ignore
871
+ self.flat_index.reconstruct(idx, vector.ravel())
872
+ unchanged_vectors.append(vector)
614
873
 
615
- return np.vstack(unchanged_vectors) if unchanged_vectors else None
874
+ if not unchanged_vectors:
875
+ if spinner:
876
+ spinner.text = "未能加载任何未变化文档的向量"
877
+ return None
878
+
879
+ if spinner:
880
+ spinner.text = f"未变化文档向量加载完成,共 {len(unchanged_vectors)} 个"
881
+
882
+ return np.vstack(unchanged_vectors)
616
883
 
617
884
  except Exception as e:
885
+ if spinner:
886
+ spinner.text = f"获取不变向量失败: {str(e)}"
618
887
  PrettyOutput.print(f"获取不变向量失败: {str(e)}", OutputType.ERROR)
619
888
  return None
620
889
 
621
- def search(self, query: str, top_k: int = 30) -> List[Tuple[Document, float]]:
890
+ def _perform_keyword_search(self, query: str, limit: int = 15) -> List[Tuple[int, float]]:
891
+ """执行基于关键词的文本搜索
892
+
893
+ Args:
894
+ query: 查询字符串
895
+ limit: 返回结果数量限制
896
+
897
+ Returns:
898
+ List[Tuple[int, float]]: 文档索引和得分的列表
899
+ """
900
+ # 使用大模型生成关键词
901
+ keywords = self._generate_keywords_with_llm(query)
902
+
903
+ # 如果大模型生成失败,回退到简单的关键词提取
904
+ if not keywords:
905
+ # 简单的关键词预处理
906
+ keywords = query.lower().split()
907
+ # 移除停用词和过短的词
908
+ stop_words = {'的', '了', '和', '是', '在', '有', '与', '对', '为', 'a', 'an', 'the', 'and', 'is', 'in', 'of', 'to', 'with'}
909
+ keywords = [k for k in keywords if k not in stop_words and len(k) > 1]
910
+
911
+ if not keywords:
912
+ return []
913
+
914
+ # 使用TF-IDF思想的简单实现
915
+ doc_scores = []
916
+
917
+ # 计算IDF(逆文档频率)
918
+ doc_count = len(self.documents)
919
+ keyword_doc_count = {}
920
+
921
+ for keyword in keywords:
922
+ count = 0
923
+ for doc in self.documents:
924
+ if keyword in doc.content.lower():
925
+ count += 1
926
+ keyword_doc_count[keyword] = max(1, count) # 避免除零错误
927
+
928
+ # 计算每个关键词的IDF值
929
+ keyword_idf = {
930
+ keyword: np.log(doc_count / count)
931
+ for keyword, count in keyword_doc_count.items()
932
+ }
933
+
934
+ # 为每个文档计算得分
935
+ for i, doc in enumerate(self.documents):
936
+ doc_content = doc.content.lower()
937
+ score = 0
938
+
939
+ # 计算每个关键词的TF(词频)
940
+ for keyword in keywords:
941
+ # 简单的TF:关键词在文档中出现的次数
942
+ tf = doc_content.count(keyword)
943
+ # TF-IDF得分
944
+ if tf > 0:
945
+ score += tf * keyword_idf[keyword]
946
+
947
+ # 添加额外权重:标题匹配、完整短语匹配等
948
+ if query.lower() in doc_content:
949
+ score *= 2.0 # 完整查询匹配加倍得分
950
+
951
+ # 文件路径匹配也加分
952
+ file_path = doc.metadata['file_path'].lower()
953
+ for keyword in keywords:
954
+ if keyword in file_path:
955
+ score += 0.5 * keyword_idf.get(keyword, 1.0)
956
+
957
+ if score > 0:
958
+ # 归一化得分(0-1范围)
959
+ doc_scores.append((i, score))
960
+
961
+ # 排序并限制结果数量
962
+ doc_scores.sort(key=lambda x: x[1], reverse=True)
963
+
964
+ # 归一化分数到0-1之间
965
+ if doc_scores:
966
+ max_score = max(score for _, score in doc_scores)
967
+ if max_score > 0:
968
+ doc_scores = [(idx, score/max_score) for idx, score in doc_scores]
969
+
970
+ return doc_scores[:limit]
971
+
972
+ def _generate_keywords_with_llm(self, query: str) -> List[str]:
973
+ """
974
+ 使用大语言模型从查询中提取关键词
975
+
976
+ Args:
977
+ query: 用户查询
978
+
979
+ Returns:
980
+ List[str]: 提取的关键词列表
981
+ """
982
+ try:
983
+ from jarvis.jarvis_utils.output import PrettyOutput, OutputType
984
+ from jarvis.jarvis_platform.registry import PlatformRegistry
985
+
986
+ # 获取平台注册表和模型
987
+ registry = PlatformRegistry.get_global_platform_registry()
988
+ model = registry.get_normal_platform()
989
+
990
+ # 构建关键词提取提示词
991
+ prompt = f"""
992
+ 请分析以下查询,提取用于文档检索的关键词。你的任务是:
993
+
994
+ 1. 识别核心概念、主题和实体,包括:
995
+ - 技术术语和专业名词
996
+ - 代码相关的函数名、类名、变量名和库名
997
+ - 重要的业务领域词汇
998
+ - 操作和动作相关的词汇
999
+
1000
+ 2. 优先提取与以下场景相关的关键词:
1001
+ - 代码搜索: 编程语言、框架、API、特定功能
1002
+ - 文档检索: 主题、标题词汇、专业名词
1003
+ - 错误排查: 错误信息、异常名称、问题症状
1004
+
1005
+ 3. 同时包含:
1006
+ - 中英文关键词 (尤其是技术领域常用英文术语)
1007
+ - 完整的专业术语和缩写形式
1008
+ - 可能的同义词或相关概念
1009
+
1010
+ 4. 关键词应当精准、具体,数量控制在3-10个之间。
1011
+
1012
+ 输出格式:
1013
+ {ot("KEYWORD")}
1014
+ 关键词1
1015
+ 关键词2
1016
+ ...
1017
+ {ct("KEYWORD")}
1018
+
1019
+ 查询文本:
1020
+ {query}
1021
+
1022
+ 仅返回提取的关键词,不要包含其他内容。
1023
+ """
1024
+
1025
+ # 调用大模型获取响应
1026
+ response = model.chat_until_success(prompt)
1027
+
1028
+ if response:
1029
+ # 清理响应,提取关键词
1030
+ sm = re.search(ot('KEYWORD') + r"(.*?)" + ct('KEYWORD'), response, re.DOTALL)
1031
+ if sm:
1032
+ extracted_keywords = sm[1]
1033
+
1034
+ if extracted_keywords:
1035
+ # 记录检测到的关键词
1036
+ ret = extracted_keywords.strip().splitlines()
1037
+ return ret
1038
+
1039
+ # 如果处理失败,返回空列表
1040
+ return []
1041
+
1042
+ except Exception as e:
1043
+ from jarvis.jarvis_utils.output import PrettyOutput, OutputType
1044
+ PrettyOutput.print(f"使用大模型生成关键词失败: {str(e)}", OutputType.WARNING)
1045
+ return []
1046
+
1047
+ def _hybrid_search(self, query: str, top_k: int = 15) -> List[Tuple[int, float]]:
1048
+ """混合搜索方法,综合向量相似度和关键词匹配
1049
+
1050
+ Args:
1051
+ query: 查询字符串
1052
+ top_k: 返回结果数量限制
1053
+
1054
+ Returns:
1055
+ List[Tuple[int, float]]: 文档索引和得分的列表
1056
+ """
1057
+ # 获取向量搜索结果
1058
+ query_vector = get_embedding(self.embedding_model, query)
1059
+ query_vector = query_vector.reshape(1, -1)
1060
+
1061
+ # 进行向量搜索
1062
+ vector_limit = min(top_k * 3, len(self.documents))
1063
+ if self.index and vector_limit > 0:
1064
+ distances, indices = self.index.search(query_vector, vector_limit) # type: ignore
1065
+ vector_results = [(int(idx), 1.0 / (1.0 + float(dist)))
1066
+ for idx, dist in zip(indices[0], distances[0])
1067
+ if idx != -1 and idx < len(self.documents)]
1068
+ else:
1069
+ vector_results = []
1070
+
1071
+ # 进行关键词搜索
1072
+ keyword_results = self._perform_keyword_search(query, top_k * 2)
1073
+
1074
+ # 合并结果集
1075
+ combined_results = {}
1076
+
1077
+ # 加入向量结果,权重为0.7
1078
+ for idx, score in vector_results:
1079
+ combined_results[idx] = score * 0.7
1080
+
1081
+ # 加入关键词结果,权重为0.3,如果文档已存在则取加权平均
1082
+ for idx, score in keyword_results:
1083
+ if idx in combined_results:
1084
+ # 已有向量得分,取加权平均
1085
+ combined_results[idx] = combined_results[idx] + score * 0.3
1086
+ else:
1087
+ # 新文档,直接添加关键词得分(权重稍低)
1088
+ combined_results[idx] = score * 0.3
1089
+
1090
+ # 转换成列表并排序
1091
+ result_list = [(idx, score) for idx, score in combined_results.items()]
1092
+ result_list.sort(key=lambda x: x[1], reverse=True)
1093
+
1094
+ return result_list[:top_k]
1095
+
1096
+
1097
+ def search(self, query: str, top_k: int = 15) -> List[Tuple[Document, float]]:
622
1098
  """Search documents with context window"""
623
- if not self.index:
1099
+ if not self.is_index_built():
1100
+ PrettyOutput.print("索引未建立,自动建立索引中...", OutputType.INFO)
624
1101
  self.build_index(self.root_dir)
625
1102
 
626
- # Get query vector
627
- with yaspin(text="获取查询向量...", color="cyan") as spinner:
628
- query_vector = get_embedding(self.embedding_model, query)
629
- query_vector = query_vector.reshape(1, -1)
630
- spinner.text = "查询向量获取完成"
1103
+ # 如果索引建立失败或文档列表为空,返回空结果
1104
+ if not self.is_index_built():
1105
+ PrettyOutput.print("索引建立失败或文档列表为空", OutputType.WARNING)
1106
+ return []
1107
+
1108
+ # 使用混合搜索获取候选文档
1109
+ with yaspin(text="执行混合搜索...", color="cyan") as spinner:
1110
+ # 获取初始候选结果
1111
+ search_results = self._hybrid_search(query, top_k * 2)
1112
+
1113
+ if not search_results:
1114
+ spinner.text = "搜索结果为空"
1115
+ spinner.fail("❌")
1116
+ return []
1117
+
1118
+ # 准备重排序
1119
+ initial_indices = [idx for idx, _ in search_results]
1120
+ spinner.text = f"检索完成,获取 {len(initial_indices)} 个候选文档"
631
1121
  spinner.ok("✅")
632
1122
 
633
- # Search with more candidates
634
- with yaspin(text="搜索...", color="cyan") as spinner:
635
- initial_k = min(top_k * 4, len(self.documents))
636
- distances, indices = self.index.search(query_vector, initial_k) # type: ignore
637
- spinner.text = "搜索完成"
1123
+ indices_list = [idx for idx, _ in search_results if idx < len(self.documents)]
1124
+
1125
+ # 应用重排序优化检索结果
1126
+ with yaspin(text="执行重排序...", color="cyan") as spinner:
1127
+ # 准备重排序所需文档内容和初始分数
1128
+ docs_to_rerank = []
1129
+ initial_scores = []
1130
+
1131
+ for idx, score in search_results:
1132
+ if idx < len(self.documents):
1133
+ doc = self.documents[idx]
1134
+ # 获取原始文档内容
1135
+ doc_content = f"File:{doc.metadata['file_path']} Content:{doc.content}"
1136
+ docs_to_rerank.append(doc_content)
1137
+ initial_scores.append(score)
1138
+
1139
+ if not docs_to_rerank:
1140
+ spinner.text = "没有可重排序的文档"
1141
+ spinner.fail("❌")
1142
+ return []
1143
+
1144
+ # 执行重排序
1145
+ spinner.text = f"重排序 {len(docs_to_rerank)} 个文档..."
1146
+ reranked_scores = rerank_results(
1147
+ query=query,
1148
+ documents=docs_to_rerank,
1149
+ initial_scores=initial_scores,
1150
+ spinner=spinner
1151
+ )
1152
+
1153
+ # 更新搜索结果的分数
1154
+ search_results = []
1155
+ for i, idx in enumerate(indices_list):
1156
+ if i < len(reranked_scores):
1157
+ search_results.append((idx, reranked_scores[i]))
1158
+
1159
+ # 按分数重新排序
1160
+ search_results.sort(key=lambda x: x[1], reverse=True)
1161
+
1162
+ spinner.text = "重排序完成"
638
1163
  spinner.ok("✅")
639
1164
 
1165
+ # 重新获取排序后的索引列表
1166
+ indices_list = [idx for idx, _ in search_results if idx < len(self.documents)]
1167
+
640
1168
  # Process results with context window
641
1169
  with yaspin(text="处理结果...", color="cyan") as spinner:
642
1170
  results = []
643
1171
  seen_files = set()
644
1172
 
645
- for idx, dist in zip(indices[0], distances[0]):
646
- if idx != -1:
1173
+ # 检查索引列表是否为空
1174
+ if not indices_list:
1175
+ spinner.text = "搜索结果为空"
1176
+ spinner.fail("❌")
1177
+ return []
1178
+
1179
+ for idx in indices_list:
1180
+ if idx < len(self.documents): # 确保索引有效
647
1181
  doc = self.documents[idx]
648
- similarity = 1.0 / (1.0 + float(dist))
649
- if similarity > 0.3:
650
- file_path = doc.metadata['file_path']
651
- if file_path not in seen_files:
652
- seen_files.add(file_path)
653
-
654
- # Get full context from original document
655
- original_doc = next((d for d in self.documents
656
- if d.metadata['file_path'] == file_path), None)
657
- if original_doc:
658
- window_docs = [] # Add this line to initialize the list
659
- full_content = original_doc.content
660
- # Find all chunks from this file
661
- file_chunks = [d for d in self.documents
662
- if d.metadata['file_path'] == file_path]
663
- # Add all related chunks
664
- for chunk_doc in file_chunks:
665
- window_docs.append((chunk_doc, similarity * 0.9))
666
-
667
- results.extend(window_docs)
668
- if len(results) >= top_k * (2 * self.context_window + 1):
669
- break
1182
+
1183
+ # 使用重排序得分或基于原始相似度的得分
1184
+ similarity = next((score for i, score in search_results if i == idx), 0.5) if search_results else 0.5
1185
+
1186
+ file_path = doc.metadata['file_path']
1187
+ if file_path not in seen_files:
1188
+ seen_files.add(file_path)
1189
+
1190
+ # Get full context from original document
1191
+ original_doc = next((d for d in self.documents
1192
+ if d.metadata['file_path'] == file_path), None)
1193
+ if original_doc:
1194
+ window_docs = [] # Add this line to initialize the list
1195
+ # Find all chunks from this file
1196
+ file_chunks = [d for d in self.documents
1197
+ if d.metadata['file_path'] == file_path]
1198
+ # Add all related chunks
1199
+ for chunk_doc in file_chunks:
1200
+ window_docs.append((chunk_doc, similarity * 0.9))
1201
+
1202
+ results.extend(window_docs)
1203
+ if len(results) >= top_k * (2 * self.context_window + 1):
1204
+ break
670
1205
  spinner.text = "处理结果完成"
671
1206
  spinner.ok("✅")
672
1207
 
673
1208
  # Sort by similarity and deduplicate
674
1209
  with yaspin(text="排序...", color="cyan") as spinner:
1210
+ if not results:
1211
+ spinner.text = "无有效结果"
1212
+ spinner.fail("❌")
1213
+ return []
1214
+
675
1215
  results.sort(key=lambda x: x[1], reverse=True)
676
1216
  seen = set()
677
1217
  final_results = []
@@ -702,97 +1242,403 @@ class RAGTool:
702
1242
  def ask(self, question: str) -> Optional[str]:
703
1243
  """Ask questions about documents with enhanced context building"""
704
1244
  try:
705
- results = self.search(question)
1245
+ # 检查索引是否已建立,如果没有则自动建立
1246
+ if not self.is_index_built():
1247
+ PrettyOutput.print("索引未建立,自动建立索引中...", OutputType.INFO)
1248
+ self.build_index(self.root_dir)
1249
+
1250
+ # 如果建立索引后仍未成功,返回错误信息
1251
+ if not self.is_index_built():
1252
+ PrettyOutput.print("无法建立索引,请检查文档和配置", OutputType.ERROR)
1253
+ return "无法建立索引,请检查文档和配置。可能的原因:文档目录为空、权限不足或格式不支持。"
1254
+
1255
+ # 增强查询预处理 - 提取关键词和语义信息
1256
+ enhanced_query = self._enhance_query(question)
1257
+
1258
+ # 使用增强的查询进行搜索
1259
+ results = self.search(enhanced_query)
706
1260
  if not results:
707
- return None
1261
+ return "未找到与问题相关的文档。请尝试重新表述问题或确认问题相关内容已包含在索引中。"
1262
+
1263
+ # 模型实例
1264
+ model = PlatformRegistry.get_global_platform_registry().get_normal_platform()
708
1265
 
709
- prompt = f"""
710
- # 🤖 Role Definition
711
- You are a document analysis expert who provides accurate and comprehensive answers based on provided documents.
712
-
713
- # 🎯 Core Responsibilities
714
- - Analyze document fragments thoroughly
715
- - Answer questions accurately
716
- - Reference source documents
717
- - Identify missing information
718
- - Maintain professional tone
719
-
720
- # 📋 Answer Requirements
721
- ## Content Quality
722
- - Base answers strictly on provided documents
723
- - Be specific and precise
724
- - Include relevant quotes when helpful
725
- - Indicate any information gaps
726
- - Use professional language
727
-
728
- ## Answer Structure
729
- 1. Direct Answer
730
- - Clear and concise response
731
- - Based on document evidence
732
- - Professional terminology
733
-
734
- 2. Supporting Details
735
- - Relevant document quotes
736
- - File references
737
- - Context explanation
738
-
739
- 3. Information Gaps (if any)
740
- - Missing information
741
- - Additional context needed
742
- - Potential limitations
743
-
744
- # 🔍 Analysis Context
745
- Question: {question}
746
-
747
- Relevant Documents (by relevance):
1266
+ # 计算基础提示词的token数量
1267
+ base_prompt = f"""
1268
+ # 🤖 角色定义
1269
+ 您是一位文档分析专家,能够基于提供的文档提供准确且全面的回答。
1270
+
1271
+ # 🎯 核心职责
1272
+ - 全面分析文档片段
1273
+ - 准确回答问题
1274
+ - 引用源文档
1275
+ - 识别缺失信息
1276
+ - 保持专业语气
1277
+
1278
+ # 📋 回答要求
1279
+ ## 内容质量
1280
+ - 严格基于提供的文档作答
1281
+ - 具体且精确
1282
+ - 在有帮助时引用相关内容
1283
+ - 指出任何信息缺口
1284
+ - 使用专业语言
1285
+
1286
+ ## 回答结构
1287
+ 1. 直接回答
1288
+ - 清晰简洁的回应
1289
+ - 基于文档证据
1290
+ - 专业术语
1291
+
1292
+ 2. 支持细节
1293
+ - 相关文档引用
1294
+ - 文件参考
1295
+ - 上下文解释
1296
+
1297
+ 3. 信息缺口(如有)
1298
+ - 缺失信息
1299
+ - 需要的额外上下文
1300
+ - 潜在限制
1301
+
1302
+ # 🔍 分析上下文
1303
+ 问题: {question}
748
1304
  """
+ base_token_count = get_context_token_count(base_prompt)
+ footer_prompt = """
+ # ❗ 重要规则
+ 1. 仅使用提供的文档
+ 2. 保持精确和准确
+ 3. 在相关时引用来源
+ 4. 指出缺失的信息
+ 5. 保持专业语气
+ 6. 使用用户的语言回答
+ """
+ footer_token_count = get_context_token_count(footer_prompt)
+
+ # 每批可用的token数,减去一些安全余量
+ available_tokens_per_batch = self.max_token_count - base_token_count - footer_token_count - 1000
+
+ # 确定是否需要分批处理
+ with yaspin(text="计算文档上下文大小...", color="cyan") as spinner:
+ # 将结果按文件分组
+ file_groups = {}
+ for doc, score in results:
+ file_path = doc.metadata['file_path']
+ if file_path not in file_groups:
+ file_groups[file_path] = []
+ file_groups[file_path].append((doc, score))
+
+ # 计算所有文档的总token数
+ total_docs_tokens = 0
+ total_len = 0
+ for file_path, docs in file_groups.items():
+ file_header = f"\n## 文件: {file_path}\n"
+ file_tokens = get_context_token_count(file_header)
+
+ # 处理所有相关性足够高的文档
+ for doc, score in docs:
+ if score < 0.2:
+ continue
+ doc_content = f"""
+ ### 片段 {doc.metadata['chunk_index'] + 1}/{doc.metadata['total_chunks']} [相关度: {score:.2f}]
+ ```
+ {doc.content}
+ ```
+ """
+ file_tokens += get_context_token_count(doc_content)
+ total_len += len(doc_content)
+ total_docs_tokens += file_tokens
+
+ # 确定是否需要分批处理及分几批
+ need_batching = total_docs_tokens > available_tokens_per_batch
+ batch_count = 1
+ if need_batching:
+ batch_count = (total_docs_tokens + available_tokens_per_batch - 1) // available_tokens_per_batch
+
+ if need_batching:
+ spinner.text = f"文档需要分 {batch_count} 批处理 (总计 {total_docs_tokens} tokens), 总长度 {total_len} 字符"
+ else:
+ spinner.text = f"文档无需分批 (总计 {total_docs_tokens} tokens), 总长度 {total_len} 字符"
+ spinner.ok("✅")
+
+ # 单批处理直接使用原方法
+ if not need_batching:
+ with yaspin(text="添加上下文...", color="cyan") as spinner:
+ prompt = base_prompt
+ current_count = base_token_count
+
+ # 保存已添加的内容指纹,避免重复
+ added_content_hashes = set()
+
+ # 按文件添加文档片段
+ for file_path, docs in file_groups.items():
+ # 按相关性排序
+ docs.sort(key=lambda x: x[1], reverse=True)
+
+ # 添加文件信息
+ file_header = f"\n## 文件: {file_path}\n"
+ if current_count + get_context_token_count(file_header) > available_tokens_per_batch:
+ break
+
+ prompt += file_header
+ current_count += get_context_token_count(file_header)
+
+ # 添加相关的文档片段,不限制每个文件的片段数量
+ for doc, score in docs:
+ # 计算内容指纹以避免重复
+ content_hash = hash(doc.content)
+ if content_hash in added_content_hashes:
+ continue
+
+ # 格式化文档片段
+ doc_content = f"""
+ ### 片段 {doc.metadata['chunk_index'] + 1}/{doc.metadata['total_chunks']} [相关度: {score:.2f}]
+ ```
+ {doc.content}
+ ```
+ """
+ if current_count + get_context_token_count(doc_content) > available_tokens_per_batch:
+ break
+
+ prompt += doc_content
+ current_count += get_context_token_count(doc_content)
+ added_content_hashes.add(content_hash)
+
+ prompt += footer_prompt
+ spinner.text = "添加上下文完成"
+ spinner.ok("✅")
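The single-batch branch above packs retrieved chunks into one prompt, skipping exact duplicates by content hash and stopping once a token budget is exhausted. The snippet below is a minimal, self-contained sketch of that packing step, not the package's code: `count_tokens` is a crude stand-in for the package's `get_context_token_count`, and the chunk structure is simplified to `(content, score)` pairs.

```python
# Illustrative sketch only (not from the package): pack retrieved chunks into a
# single prompt, de-duplicating by content hash and respecting a token budget.
from typing import List, Tuple


def count_tokens(text: str) -> int:
    # Crude stand-in for a real tokenizer: roughly one token per 4 characters.
    return max(1, len(text) // 4)


def pack_prompt(base: str, chunks: List[Tuple[str, float]], budget: int) -> str:
    """Append (content, score) chunks to `base`, highest score first."""
    prompt = base
    used = count_tokens(base)
    seen = set()
    for content, score in sorted(chunks, key=lambda c: c[1], reverse=True):
        fingerprint = hash(content)
        if fingerprint in seen:      # skip exact duplicates
            continue
        piece = f"\n### [relevance: {score:.2f}]\n{content}\n"
        cost = count_tokens(piece)
        if used + cost > budget:     # stop once the budget is exhausted
            break
        prompt += piece
        used += cost
        seen.add(fingerprint)
    return prompt


if __name__ == "__main__":
    demo = [("alpha " * 50, 0.9), ("alpha " * 50, 0.9), ("beta " * 50, 0.4)]
    print(count_tokens(pack_prompt("# Question: ...\n", demo, budget=200)))
```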
 
- # Add context with length control
- with yaspin(text="添加上下文...", color="cyan") as spinner:
- available_count = self.max_token_count - get_context_token_count(prompt) - 1000
- current_count = 0
+ # 直接生成答案
+ with yaspin(text="正在生成答案...", color="cyan") as spinner:
+ response = model.chat_until_success(prompt)
+ spinner.text = "答案生成完成"
+ spinner.ok("✅")
+ return response
+
+ # 分批处理文档
+ else:
+ batch_responses = []
 
- for doc, score in results:
- doc_content = f"""
- ## Document Fragment [Score: {score:.3f}]
- Source: {doc.metadata['file_path']}
- ```
- {doc.content}
- ```
- ---
- """
- if current_count + get_context_token_count(doc_content) > available_count:
- PrettyOutput.print(
- "由于上下文长度限制,部分内容被省略",
- output_type=OutputType.WARNING
- )
- break
+ # 准备批次
+ with yaspin(text=f"准备分批处理 (共{batch_count}批)...", color="cyan") as spinner:
+ batches = []
+ current_batch = []
+ current_batch_tokens = 0
+
+ # 按相关性排序处理文件
+ sorted_files = sorted(file_groups.items(),
+ key=lambda x: max(score for _, score in x[1]) if x[1] else 0,
+ reverse=True)
+
+ for file_path, docs in sorted_files:
+ # 按相关性排序文档
+ docs.sort(key=lambda x: x[1], reverse=True)
 
- prompt += doc_content
- current_count += get_context_token_count(doc_content)
-
- prompt += """
- # ❗ Important Rules
- 1. Only use provided documents
- 2. Be precise and accurate
- 3. Quote sources when relevant
- 4. Indicate missing information
- 5. Maintain professional tone
- 6. Answer in user's language
- """
- spinner.text = "添加上下文完成"
- spinner.ok("")
+ # 处理每个文件的文档
+ file_header = f"\n## 文件: {file_path}\n"
+ file_header_tokens = get_context_token_count(file_header)
+
+ # 如果当前批次添加这个文件会超过限制,创建新批次
+ file_docs = []
+ file_docs_tokens = 0
+
+ # 计算此文件要添加的所有文档,不限制片段数量
+ for doc, score in docs:
+ if score < 0.2: # 过滤低相关性文档
+ continue
+
+ doc_content = f"""
+ ### 片段 {doc.metadata['chunk_index'] + 1}/{doc.metadata['total_chunks']} [相关度: {score:.2f}]
+ ```
+ {doc.content}
+ ```
+ """
+ doc_tokens = get_context_token_count(doc_content)
+ file_docs.append((doc, score, doc_content, doc_tokens))
+ file_docs_tokens += doc_tokens
+
+ # 如果此文件的内容加上文件头会导致当前批次超限,创建新批次
+ if current_batch and (current_batch_tokens + file_header_tokens + file_docs_tokens > available_tokens_per_batch):
+ batches.append(current_batch)
+ current_batch = []
+ current_batch_tokens = 0
+
+ # 将文件及其文档添加到当前批次
+ if file_docs: # 如果有要添加的文档
+ current_batch.append((file_path, file_header, file_docs))
+ current_batch_tokens += file_header_tokens + file_docs_tokens
+
+ # 添加最后一个批次
+ if current_batch:
+ batches.append(current_batch)
+
+ spinner.text = f"分批准备完成,共 {len(batches)} 批"
+ spinner.ok("✅")
+
+ # 处理每个批次
+ for batch_idx, batch in enumerate(batches):
+ with yaspin(text=f"处理批次 {batch_idx+1}/{len(batches)}...", color="cyan") as spinner:
+ # 构建批次提示词
+ batch_prompt = base_prompt + f"\n\n## 批次 {batch_idx+1}/{len(batches)} 的相关文档:\n"
+
+ # 添加批次中的文档
+ for file_path, file_header, file_docs in batch:
+ batch_prompt += file_header
+
+ for doc, score, doc_content, _ in file_docs:
+ batch_prompt += doc_content
+
+ # 为最后一个批次添加总结指令,为中间批次添加部分分析指令
+ if batch_idx == len(batches) - 1:
+ # 最后一个批次,添加总结所有批次的指令
+ if len(batches) > 1:
+ batch_prompt += f"""
+ # 📊 汇总分析
+ 这是最后一批文档。请基于此批次和之前批次的分析,提供一个全面的最终回答。
+ """
+ batch_prompt += footer_prompt
+ else:
+ # 中间批次,添加部分分析指令
+ batch_prompt += f"""
+ # 📝 批次分析
+ 这是第 {batch_idx+1}/{len(batches)} 批文档。请分析这批文档中与问题相关的信息。
+ 在你的分析中:
+ 1. 提取关键信息点
+ 2. 识别可能对最终答案有帮助的内容
+ 3. 简明扼要,重点关注与问题直接相关的内容
+ 4. 忽略与问题无关的内容
+ """
+
+ spinner.text = f"正在分析批次 {batch_idx+1}/{len(batches)}..."
+
+ # 调用模型处理当前批次
+ batch_response = model.chat_until_success(batch_prompt)
+ batch_responses.append(batch_response)
+
+ spinner.text = f"批次 {batch_idx+1}/{len(batches)} 分析完成"
+ spinner.ok("✅")
+
+ # 如果只有一个批次,直接返回结果
+ if len(batch_responses) == 1:
+ return batch_responses[0]
+
+ # 如果有多个批次,需要汇总结果
+ with yaspin(text="汇总多批次分析结果...", color="cyan") as spinner:
+ # 构建汇总提示词
+ summary_prompt = f"""
+ # 🔄 批次汇总任务
 
- with yaspin(text="回答...", color="cyan") as spinner:
- model = PlatformRegistry.get_global_platform_registry().get_normal_platform()
- response = model.chat_until_success(prompt)
- spinner.text = "回答完成"
- spinner.ok("✅")
- return response
+ ## 原始问题
+ {question}
+
+ ## 多批次分析结果
+ 你已经对相关文档进行了多批次分析,现在需要将这些分析结果汇总成一个连贯、全面的回答。
+
+ 以下是各批次的分析结果:
+
+ """
+
+ # 添加每个批次的分析结果
+ for i, response in enumerate(batch_responses):
+ summary_prompt += f"""
+ ### 批次 {i+1} 分析结果
+ {response}
+
+ """
+
+ # 添加汇总指导
+ summary_prompt += """
+ ## 汇总要求
+ 请基于以上所有批次的分析结果,提供一个综合、连贯的最终回答。
+
+ # 🎯 核心职责
+ - 全面分析文档片段
+ - 准确回答问题
+ - 引用源文档
+ - 识别缺失信息
+ - 保持专业语气
+
+ # 📋 回答要求
+ ## 内容质量
+ - 严格基于提供的文档作答
+ - 具体且精确
+ - 在有帮助时引用相关内容
+ - 指出任何信息缺口
+ - 使用专业语言
+
+ ## 回答结构
+ 1. 直接回答
+ - 清晰简洁的回应
+ - 基于文档证据
+ - 专业术语
+
+ 2. 支持细节
+ - 相关文档引用
+ - 文件参考
+ - 上下文解释
+
+ 3. 信息缺口(如有)
+ - 缺失信息
+ - 需要的额外上下文
+ - 潜在限制
+
+ 请直接提供最终回答,不需要解释你的汇总过程。
+ """
+
+ spinner.text = "正在生成最终汇总答案..."
+
+ # 调用模型生成最终汇总
+ final_response = model.chat_until_success(summary_prompt)
+
+ spinner.text = "汇总答案生成完成"
+ spinner.ok("✅")
+
+ return final_response
 
  except Exception as e:
  PrettyOutput.print(f"回答失败:{str(e)}", OutputType.ERROR)
  return None
+
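When the retrieved context exceeds a single prompt, the new `ask` path groups chunks by file, packs the groups into token-bounded batches, asks the model to analyze each batch, and then asks it once more to merge the per-batch analyses into a final answer. The sketch below shows that map-reduce shape in isolation; it is an illustration under stated assumptions, not the package's implementation: `ChatFn` stands in for `model.chat_until_success`, and token costs are passed in precomputed.

```python
# Illustrative map-reduce sketch (not from the package): split token-heavy
# context into batches, analyze each batch, then merge the partial answers.
from typing import Callable, List, Tuple

ChatFn = Callable[[str], str]  # stand-in for model.chat_until_success


def split_into_batches(items: List[Tuple[str, int]], budget: int) -> List[List[str]]:
    """Greedily pack (text, token_cost) items into batches under `budget` tokens."""
    batches: List[List[str]] = []
    current: List[str] = []
    used = 0
    for text, cost in items:
        if current and used + cost > budget:
            batches.append(current)
            current, used = [], 0
        current.append(text)
        used += cost
    if current:
        batches.append(current)
    return batches


def ask_in_batches(question: str, items: List[Tuple[str, int]], budget: int, chat: ChatFn) -> str:
    batches = split_into_batches(items, budget)
    if len(batches) == 1:
        # Single batch: one direct question over the full context.
        return chat(f"Question: {question}\n\n" + "\n".join(batches[0]))
    # Map step: one partial analysis per batch.
    partials = [
        chat(f"Question: {question}\n\nBatch {i + 1}/{len(batches)}:\n" + "\n".join(batch))
        for i, batch in enumerate(batches)
    ]
    # Reduce step: merge the partial analyses into one answer.
    merged = "\n\n".join(f"Batch {i + 1} analysis:\n{p}" for i, p in enumerate(partials))
    return chat(f"Question: {question}\n\nMerge these partial analyses into one answer:\n{merged}")


if __name__ == "__main__":
    fake_chat: ChatFn = lambda prompt: f"<answer derived from {len(prompt)} chars>"
    docs = [(f"chunk-{i} " * 40, 60) for i in range(10)]
    print(ask_in_batches("What does the module do?", docs, budget=150, chat=fake_chat))
```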
+ def _enhance_query(self, query: str) -> str:
+ """增强查询以提高检索质量
+
+ Args:
+ query: 原始查询
+
+ Returns:
+ str: 增强后的查询
+ """
+ # 简单的查询预处理
+ query = query.strip()
+
+ # 如果查询太短,返回原始查询
+ if len(query) < 10:
+ return query
+
+ try:
+ # 尝试使用大模型增强查询(如果可用)
+ model = PlatformRegistry.get_global_platform_registry().get_normal_platform()
+ enhance_prompt = f"""请分析以下查询,提取关键概念、关键词和主题。
+
+ 查询:"{query}"
+
+ 输出格式:对原始查询的改写版本,专注于提取关键信息,保留原始语义,以提高检索相关度。
+ 仅输出改写后的查询文本,不要输出其他内容。
+ 只对信息进行最小必要的增强,不要过度添加与原始查询无关的内容。
+ """
+
+ enhanced_query = model.chat_until_success(enhance_prompt)
+ # 清理增强的查询结果
+ enhanced_query = enhanced_query.strip().strip('"')
+
+ # 如果增强查询有效且不是完全相同的,使用它
+ if enhanced_query and len(enhanced_query) >= len(query) / 2 and enhanced_query != query:
+ return enhanced_query
+
+ except Exception:
+ # 如果增强失败,使用原始查询
+ pass
+
+ return query
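`_enhance_query` rewrites the user's question with the LLM before retrieval and quietly falls back to the original text when the rewrite fails or looks degenerate (empty, unchanged, or much shorter than the input). A small wrapper with the same guard rails might look like the sketch below; `rewrite` is a hypothetical callable standing in for the platform model, so treat this as an assumption-laden illustration rather than the package API.

```python
# Illustrative sketch (not from the package): rewrite a query for retrieval,
# falling back to the original text if the rewrite fails or looks degenerate.
from typing import Callable


def enhance_query(query: str, rewrite: Callable[[str], str]) -> str:
    query = query.strip()
    if len(query) < 10:          # too short to benefit from rewriting
        return query
    try:
        candidate = rewrite(
            "Rewrite the query below so it keeps its meaning but surfaces "
            f"key concepts for retrieval. Output only the rewritten query.\n\n{query}"
        ).strip().strip('"')
    except Exception:
        return query             # any failure: keep the original query
    # Reject rewrites that are empty, unchanged, or suspiciously short.
    if candidate and candidate != query and len(candidate) >= len(query) / 2:
        return candidate
    return query


if __name__ == "__main__":
    print(enhance_query("how does the indexer cache embeddings?",
                        lambda p: "indexer embedding cache mechanism"))
```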
 
  def is_index_built(self) -> bool:
  """Check if the index is built and valid
@@ -802,6 +1648,33 @@ Relevant Documents (by relevance):
  """
  return self.index is not None and len(self.documents) > 0
 
+ def _delete_file_cache(self, file_path: str, spinner=None):
+ """Delete cache files for a specific file
+
+ Args:
+ file_path: Path to the original file
+ spinner: Optional spinner for progress information. If None, runs silently.
+ """
+ try:
+ # Delete document cache
+ doc_cache_path = self._get_cache_path(file_path, "doc")
+ if os.path.exists(doc_cache_path):
+ os.remove(doc_cache_path)
+ if spinner is not None:
+ spinner.write(f"🗑️ 删除文档缓存: {file_path}")
+
+ # Delete vector cache
+ vec_cache_path = self._get_cache_path(file_path, "vec")
+ if os.path.exists(vec_cache_path):
+ os.remove(vec_cache_path)
+ if spinner is not None:
+ spinner.write(f"🗑️ 删除向量缓存: {file_path}")
+
+ except Exception as e:
+ if spinner is not None:
+ spinner.write(f"❌ 删除缓存失败: {file_path}: {str(e)}")
+ PrettyOutput.print(f"删除缓存失败: {file_path}: {str(e)}", output_type=OutputType.ERROR)
+
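`_delete_file_cache` drops both the document cache and the vector cache derived from a source file, so a changed file gets re-processed on the next index build. The sketch below shows that invalidation pattern with a hypothetical cache layout (one `<sha256>.doc` / `<sha256>.vec` pair per source file under a cache directory); the real `_get_cache_path` naming scheme is defined elsewhere in the module and may differ.

```python
# Illustrative sketch (not from the package): remove the cached artifacts that
# were derived from a source file, using a hypothetical <hash>.doc/.vec layout.
import hashlib
import os


def cache_path(cache_dir: str, file_path: str, kind: str) -> str:
    digest = hashlib.sha256(file_path.encode("utf-8")).hexdigest()
    return os.path.join(cache_dir, f"{digest}.{kind}")


def delete_file_cache(cache_dir: str, file_path: str) -> None:
    for kind in ("doc", "vec"):
        path = cache_path(cache_dir, file_path, kind)
        try:
            if os.path.exists(path):
                os.remove(path)
        except OSError as exc:
            # Log and continue: a stale cache entry is not fatal.
            print(f"failed to delete {path}: {exc}")


if __name__ == "__main__":
    delete_file_cache("/tmp/rag_cache", "docs/readme.md")
```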
  def main():
  """Main function"""
  import argparse
@@ -828,11 +1701,18 @@ def main():
  args.dir = current_dir
 
  if args.dir and args.build:
- PrettyOutput.print(f"正在处理目录: {args.dir}", output_type=OutputType.INFO)
  rag.build_index(args.dir)
  return 0
 
  if args.search or args.ask:
+ # 当需要搜索或提问时,自动检查并建立索引
+ if not rag.is_index_built():
+ PrettyOutput.print(f"索引未建立,自动为目录 '{args.dir}' 建立索引...", OutputType.INFO)
+ rag.build_index(args.dir)
+
+ if not rag.is_index_built():
+ PrettyOutput.print("索引建立失败,请检查目录和文件格式", OutputType.ERROR)
+ return 1
 
  if args.search:
  results = rag.query(args.search)
@@ -855,7 +1735,7 @@ def main():
  return 1
 
  # Display answer
- output = f"""答案:\n{response}"""
+ output = f"""{response}"""
  PrettyOutput.print(output, output_type=OutputType.INFO)
  return 0
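With this change the CLI no longer requires an explicit build step: before serving a search or ask request, `main()` builds the index on demand and aborts with a non-zero exit code if the build still leaves the index empty. The sketch below condenses that guard; `RagLike` is a hypothetical stand-in protocol for the object the CLI constructs, using only the calls visible in the diff.

```python
# Illustrative sketch (not from the package): ensure an index exists before
# answering, mirroring the guard added to main(). Only the calls seen in the
# diff (is_index_built, build_index) are assumed.
from typing import Protocol


class RagLike(Protocol):
    def is_index_built(self) -> bool: ...
    def build_index(self, directory: str) -> None: ...


def ensure_index(rag: RagLike, directory: str) -> int:
    """Return 0 if an index is available (building it if needed), else 1."""
    if not rag.is_index_built():
        print(f"index not built, building it for '{directory}'...")
        rag.build_index(directory)
    if not rag.is_index_built():
        print("index build failed, check the directory and file formats")
        return 1
    return 0
```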