jarvis-ai-assistant 0.1.130__py3-none-any.whl → 0.1.132__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- jarvis/__init__.py +1 -1
- jarvis/jarvis_agent/__init__.py +71 -38
- jarvis/jarvis_agent/builtin_input_handler.py +73 -0
- jarvis/{jarvis_code_agent → jarvis_agent}/file_input_handler.py +1 -1
- jarvis/jarvis_agent/main.py +1 -1
- jarvis/{jarvis_code_agent → jarvis_agent}/patch.py +77 -55
- jarvis/{jarvis_code_agent → jarvis_agent}/shell_input_handler.py +1 -2
- jarvis/jarvis_code_agent/code_agent.py +93 -88
- jarvis/jarvis_dev/main.py +335 -626
- jarvis/jarvis_git_squash/main.py +11 -32
- jarvis/jarvis_lsp/base.py +2 -26
- jarvis/jarvis_lsp/cpp.py +2 -14
- jarvis/jarvis_lsp/go.py +0 -13
- jarvis/jarvis_lsp/python.py +1 -30
- jarvis/jarvis_lsp/registry.py +10 -14
- jarvis/jarvis_lsp/rust.py +0 -12
- jarvis/jarvis_multi_agent/__init__.py +20 -29
- jarvis/jarvis_platform/ai8.py +7 -32
- jarvis/jarvis_platform/base.py +2 -7
- jarvis/jarvis_platform/kimi.py +3 -144
- jarvis/jarvis_platform/ollama.py +54 -68
- jarvis/jarvis_platform/openai.py +0 -4
- jarvis/jarvis_platform/oyi.py +0 -75
- jarvis/jarvis_platform/registry.py +1 -1
- jarvis/jarvis_platform/yuanbao.py +264 -0
- jarvis/jarvis_platform_manager/main.py +3 -3
- jarvis/jarvis_rag/file_processors.py +138 -0
- jarvis/jarvis_rag/main.py +1305 -425
- jarvis/jarvis_tools/ask_codebase.py +227 -41
- jarvis/jarvis_tools/code_review.py +229 -166
- jarvis/jarvis_tools/create_code_agent.py +76 -72
- jarvis/jarvis_tools/create_sub_agent.py +32 -15
- jarvis/jarvis_tools/execute_python_script.py +58 -0
- jarvis/jarvis_tools/execute_shell.py +15 -28
- jarvis/jarvis_tools/execute_shell_script.py +2 -2
- jarvis/jarvis_tools/file_analyzer.py +271 -0
- jarvis/jarvis_tools/file_operation.py +3 -3
- jarvis/jarvis_tools/find_caller.py +213 -0
- jarvis/jarvis_tools/find_symbol.py +211 -0
- jarvis/jarvis_tools/function_analyzer.py +248 -0
- jarvis/jarvis_tools/git_commiter.py +89 -70
- jarvis/jarvis_tools/lsp_find_definition.py +83 -67
- jarvis/jarvis_tools/lsp_find_references.py +62 -46
- jarvis/jarvis_tools/lsp_get_diagnostics.py +90 -74
- jarvis/jarvis_tools/methodology.py +89 -48
- jarvis/jarvis_tools/project_analyzer.py +220 -0
- jarvis/jarvis_tools/read_code.py +24 -3
- jarvis/jarvis_tools/read_webpage.py +195 -81
- jarvis/jarvis_tools/registry.py +132 -11
- jarvis/jarvis_tools/search_web.py +73 -30
- jarvis/jarvis_tools/tool_generator.py +7 -9
- jarvis/jarvis_utils/__init__.py +1 -0
- jarvis/jarvis_utils/config.py +67 -3
- jarvis/jarvis_utils/embedding.py +344 -45
- jarvis/jarvis_utils/git_utils.py +18 -2
- jarvis/jarvis_utils/input.py +7 -4
- jarvis/jarvis_utils/methodology.py +379 -7
- jarvis/jarvis_utils/output.py +5 -3
- jarvis/jarvis_utils/utils.py +62 -10
- {jarvis_ai_assistant-0.1.130.dist-info → jarvis_ai_assistant-0.1.132.dist-info}/METADATA +3 -4
- jarvis_ai_assistant-0.1.132.dist-info/RECORD +82 -0
- {jarvis_ai_assistant-0.1.130.dist-info → jarvis_ai_assistant-0.1.132.dist-info}/entry_points.txt +2 -0
- jarvis/jarvis_c2rust/c2rust.yaml +0 -734
- jarvis/jarvis_code_agent/builtin_input_handler.py +0 -43
- jarvis/jarvis_codebase/__init__.py +0 -0
- jarvis/jarvis_codebase/main.py +0 -1011
- jarvis/jarvis_tools/lsp_get_document_symbols.py +0 -87
- jarvis/jarvis_tools/lsp_prepare_rename.py +0 -130
- jarvis_ai_assistant-0.1.130.dist-info/RECORD +0 -79
- {jarvis_ai_assistant-0.1.130.dist-info → jarvis_ai_assistant-0.1.132.dist-info}/LICENSE +0 -0
- {jarvis_ai_assistant-0.1.130.dist-info → jarvis_ai_assistant-0.1.132.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.1.130.dist-info → jarvis_ai_assistant-0.1.132.dist-info}/top_level.txt +0 -0
jarvis/jarvis_rag/main.py
CHANGED
@@ -1,12 +1,10 @@
 import os
+import re
 import numpy as np
 import faiss
 from typing import List, Tuple, Optional, Dict
 import pickle
 from dataclasses import dataclass
-from tqdm import tqdm
-import fitz  # PyMuPDF for PDF files
-from docx import Document as DocxDocument  # python-docx for DOCX files
 from pathlib import Path
 
 from yaspin import yaspin
@@ -15,10 +13,30 @@ import lzma  # 添加 lzma 导入
 from threading import Lock
 import hashlib
 
-from jarvis.jarvis_utils.config import get_max_paragraph_length, get_max_token_count, get_min_paragraph_length, get_thread_count
-from jarvis.jarvis_utils.embedding import get_context_token_count, get_embedding, get_embedding_batch, load_embedding_model
+from jarvis.jarvis_utils.config import get_max_paragraph_length, get_max_token_count, get_min_paragraph_length, get_thread_count, get_rag_ignored_paths
+from jarvis.jarvis_utils.embedding import get_context_token_count, get_embedding, get_embedding_batch, load_embedding_model, rerank_results
 from jarvis.jarvis_utils.output import OutputType, PrettyOutput
-from jarvis.jarvis_utils.utils import get_file_md5, init_env, init_gpu_config
+from jarvis.jarvis_utils.utils import ct, get_file_md5, init_env, init_gpu_config, ot
+
+from .file_processors import TextFileProcessor, PDFProcessor, DocxProcessor, PPTProcessor, ExcelProcessor
+
+"""
+Jarvis RAG (Retrieval-Augmented Generation) Module
+
+这个模块实现了高效的本地RAG系统,具有以下特性:
+1. 多格式文档处理(文本、PDF、Word、PPT、Excel)
+2. 高效向量检索与关键词匹配相结合的混合搜索
+3. 交叉编码器重排序,大幅提升检索准确性
+4. 增量更新检测,避免重复处理
+5. 自动上下文扩展,提供更完整信息
+6. 针对RAG优化的文本分割,保持语义完整性
+7. 缓存机制,提高反复查询性能
+8. 批处理向量化,优化内存和计算资源使用
+9. 多线程处理能力
+10. GPU加速(如果可用)
+
+适用于:代码库文档检索、知识库问答、本地资料分析等场景
+"""
 
 @dataclass
 class Document:
@@ -27,111 +45,7 @@ class Document:
     metadata: Dict  # Metadata (file path, position, etc.)
     md5: str = ""  # File MD5 value, for incremental update detection
 
-class FileProcessor:
-    """Base class for file processor"""
-    @staticmethod
-    def can_handle(file_path: str) -> bool:
-        """Determine if the file can be processed"""
-        raise NotImplementedError
-
-    @staticmethod
-    def extract_text(file_path: str) -> str:
-        """Extract file text content"""
-        raise NotImplementedError
-
-class TextFileProcessor(FileProcessor):
-    """Text file processor"""
-    ENCODINGS = ['utf-8', 'gbk', 'gb2312', 'latin1']
-    SAMPLE_SIZE = 8192  # Read the first 8KB to detect encoding
-
-    @staticmethod
-    def can_handle(file_path: str) -> bool:
-        """Determine if the file is a text file by trying to decode it"""
-        try:
-            # Read the first part of the file to detect encoding
-            with open(file_path, 'rb') as f:
-                sample = f.read(TextFileProcessor.SAMPLE_SIZE)
-
-            # Check if it contains null bytes (usually represents a binary file)
-            if b'\x00' in sample:
-                return False
-
-            # Check if it contains too many non-printable characters (usually represents a binary file)
-            non_printable = sum(1 for byte in sample if byte < 32 and byte not in (9, 10, 13))  # tab, newline, carriage return
-            if non_printable / len(sample) > 0.3:  # If non-printable characters exceed 30%, it is considered a binary file
-                return False
-
-            # Try to decode with different encodings
-            for encoding in TextFileProcessor.ENCODINGS:
-                try:
-                    sample.decode(encoding)
-                    return True
-                except UnicodeDecodeError:
-                    continue
-
-            return False
-
-        except Exception:
-            return False
-
-    @staticmethod
-    def extract_text(file_path: str) -> str:
-        """Extract text content, using the detected correct encoding"""
-        detected_encoding = None
-        try:
-            # First try to detect encoding
-            with open(file_path, 'rb') as f:
-                raw_data = f.read()
-
-            # Try different encodings
-            for encoding in TextFileProcessor.ENCODINGS:
-                try:
-                    raw_data.decode(encoding)
-                    detected_encoding = encoding
-                    break
-                except UnicodeDecodeError:
-                    continue
-
-            if not detected_encoding:
-                raise UnicodeDecodeError(f"Failed to decode file with supported encodings: {file_path}")  # type: ignore
-
-            # Use the detected encoding to read the file
-            with open(file_path, 'r', encoding=detected_encoding, errors='replace') as f:
-                content = f.read()
-
-            # Normalize Unicode characters
-            import unicodedata
-            content = unicodedata.normalize('NFKC', content)
-
-            return content
-
-        except Exception as e:
-            raise Exception(f"Failed to read file: {str(e)}")
 
-class PDFProcessor(FileProcessor):
-    """PDF file processor"""
-    @staticmethod
-    def can_handle(file_path: str) -> bool:
-        return Path(file_path).suffix.lower() == '.pdf'
-
-    @staticmethod
-    def extract_text(file_path: str) -> str:
-        text_parts = []
-        with fitz.open(file_path) as doc:  # type: ignore
-            for page in doc:
-                text_parts.append(page.get_text())  # type: ignore
-        return "\n".join(text_parts)
-
-class DocxProcessor(FileProcessor):
-    """DOCX file processor"""
-    @staticmethod
-    def can_handle(file_path: str) -> bool:
-        return Path(file_path).suffix.lower() == '.docx'
-
-    @staticmethod
-    def extract_text(file_path: str) -> str:
-        doc = DocxDocument(file_path)
-        return "\n".join([paragraph.text for paragraph in doc.paragraphs])
 
 class RAGTool:
     def __init__(self, root_dir: str):
@@ -196,7 +110,9 @@ class RAGTool:
             self.file_processors = [
                 TextFileProcessor(),
                 PDFProcessor(),
-                DocxProcessor()
+                DocxProcessor(),
+                PPTProcessor(),
+                ExcelProcessor()
             ]
             spinner.text = "文件处理器初始化完成"
             spinner.ok("✅")
@@ -211,23 +127,38 @@ class RAGTool:
 
         # 初始化 GPU 内存配置
         with yaspin(text="初始化 GPU 内存配置...", color="cyan") as spinner:
-
+            with spinner.hidden():
+                self.gpu_config = init_gpu_config()
             spinner.text = "GPU 内存配置初始化完成"
             spinner.ok("✅")
 
 
-    def _get_cache_path(self, file_path: str) -> str:
+    def _get_cache_path(self, file_path: str, cache_type: str = "doc") -> str:
         """Get cache file path for a document
 
         Args:
             file_path: Original file path
+            cache_type: Type of cache ("doc" for documents, "vec" for vectors)
 
         Returns:
             str: Cache file path
         """
         # 使用文件路径的哈希作为缓存文件名
         file_hash = hashlib.md5(file_path.encode()).hexdigest()
-
+
+        # 确保不同类型的缓存有不同的目录
+        if cache_type == "doc":
+            cache_subdir = os.path.join(self.cache_dir, "documents")
+        elif cache_type == "vec":
+            cache_subdir = os.path.join(self.cache_dir, "vectors")
+        else:
+            cache_subdir = self.cache_dir
+
+        # 确保子目录存在
+        if not os.path.exists(cache_subdir):
+            os.makedirs(cache_subdir)
+
+        return os.path.join(cache_subdir, f"{file_hash}.cache")
 
     def _load_cache_index(self):
         """Load cache index"""
@@ -244,37 +175,64 @@ class RAGTool:
         # 从各个缓存文件加载文档
         with yaspin(text="加载缓存文件...", color="cyan") as spinner:
             for file_path in self.file_md5_cache:
-
-                if os.path.exists(
+                doc_cache_path = self._get_cache_path(file_path, "doc")
+                if os.path.exists(doc_cache_path):
                     try:
-                        with lzma.open(
-
-                            self.documents.extend(
-                        spinner.
+                        with lzma.open(doc_cache_path, 'rb') as f:
+                            doc_cache_data = pickle.load(f)
+                            self.documents.extend(doc_cache_data["documents"])
+                        spinner.text = f"加载文档缓存: {file_path}"
                     except Exception as e:
-                        spinner.write(f"❌
-            spinner.text = "
+                        spinner.write(f"❌ 加载文档缓存失败: {file_path}: {str(e)}")
+            spinner.text = "文档缓存加载完成"
             spinner.ok("✅")
 
         # 重建向量索引
-
         if self.documents:
             with yaspin(text="重建向量索引...", color="cyan") as spinner:
                 vectors = []
+
+                # 按照文档列表顺序加载向量
+                processed_files = set()
                 for doc in self.documents:
-
-
-
-
-
-
-
-
+                    file_path = doc.metadata['file_path']
+
+                    # 避免重复处理同一个文件
+                    if file_path in processed_files:
+                        continue
+
+                    processed_files.add(file_path)
+                    vec_cache_path = self._get_cache_path(file_path, "vec")
+
+                    if os.path.exists(vec_cache_path):
+                        try:
+                            # 加载该文件的向量缓存
+                            with lzma.open(vec_cache_path, 'rb') as f:
+                                vec_cache_data = pickle.load(f)
+                                file_vectors = vec_cache_data["vectors"]
+
+                            # 按照文档的chunk_index检索对应向量
+                            doc_indices = [d.metadata['chunk_index'] for d in self.documents
+                                          if d.metadata['file_path'] == file_path]
+
+                            # 检查向量数量与文档块数量是否匹配
+                            if len(doc_indices) <= file_vectors.shape[0]:
+                                for idx in doc_indices:
+                                    if idx < file_vectors.shape[0]:
+                                        vectors.append(file_vectors[idx].reshape(1, -1))
+                            else:
+                                spinner.write(f"⚠️ 向量缓存不匹配: {file_path}")
+
+                            spinner.text = f"加载向量缓存: {file_path}"
+                        except Exception as e:
+                            spinner.write(f"❌ 加载向量缓存失败: {file_path}: {str(e)}")
+                    else:
+                        spinner.write(f"⚠️ 缺少向量缓存: {file_path}")
 
                 if vectors:
                     vectors = np.vstack(vectors)
-                    self._build_index(vectors)
-                    spinner.text = "向量索引重建完成,加载 {len(self.documents)} 个文档片段"
+                    self._build_index(vectors, spinner)
+                    spinner.text = f"向量索引重建完成,加载 {len(self.documents)} 个文档片段"
                 spinner.ok("✅")
 
         except Exception as e:
@@ -285,67 +243,126 @@ class RAGTool:
             self.flat_index = None
             self.file_md5_cache = {}
 
-    def _save_cache(self, file_path: str, documents: List[Document], vectors: np.ndarray):
+    def _save_cache(self, file_path: str, documents: List[Document], vectors: np.ndarray, spinner=None):
         """Save cache for a single file
 
         Args:
             file_path: File path
             documents: List of documents
             vectors: Document vectors
+            spinner: Optional spinner for progress display
         """
         try:
-            #
-
-
-
+            # 保存文档缓存
+            if spinner:
+                spinner.text = f"保存 {file_path} 的文档缓存..."
+            doc_cache_path = self._get_cache_path(file_path, "doc")
+            doc_cache_data = {
+                "documents": documents
+            }
+            with lzma.open(doc_cache_path, 'wb') as f:
+                pickle.dump(doc_cache_data, f)
+
+            # 保存向量缓存
+            if spinner:
+                spinner.text = f"保存 {file_path} 的向量缓存..."
+            vec_cache_path = self._get_cache_path(file_path, "vec")
+            vec_cache_data = {
                 "vectors": vectors
             }
-            with lzma.open(
-                pickle.dump(
+            with lzma.open(vec_cache_path, 'wb') as f:
+                pickle.dump(vec_cache_data, f)
 
             # 更新并保存索引
+            if spinner:
+                spinner.text = f"更新 {file_path} 的索引缓存..."
             index_path = os.path.join(self.data_dir, "index.pkl")
             index_data = {
                 "file_md5_cache": self.file_md5_cache
             }
             with lzma.open(index_path, 'wb') as f:
                 pickle.dump(index_data, f)
+
+            if spinner:
+                spinner.text = f"{file_path} 的缓存保存完成"
 
         except Exception as e:
+            if spinner:
+                spinner.text = f"保存 {file_path} 的缓存失败: {str(e)}"
             PrettyOutput.print(f"保存缓存失败: {str(e)}", output_type=OutputType.ERROR)
 
-    def _build_index(self, vectors: np.ndarray):
+    def _build_index(self, vectors: np.ndarray, spinner=None):
         """Build FAISS index"""
         if vectors.shape[0] == 0:
+            if spinner:
+                spinner.text = "向量为空,跳过索引构建"
            self.index = None
            self.flat_index = None
            return
 
         # Create a flat index to store original vectors, for reconstruction
+        if spinner:
+            spinner.text = "创建平面索引用于向量重建..."
         self.flat_index = faiss.IndexFlatIP(self.vector_dim)
         self.flat_index.add(vectors)  # type: ignore
 
         # Create an IVF index for fast search
-
+        if spinner:
+            spinner.text = "创建IVF索引用于快速搜索..."
+        # 修改聚类中心的计算方式,小数据量时使用更少的聚类中心
+        # 避免"WARNING clustering X points to Y centroids: please provide at least Z training points"警告
+        num_vectors = vectors.shape[0]
+        if num_vectors < 100:
+            # 对于小于100个向量的情况,使用更少的聚类中心
+            nlist = 1  # 只用1个聚类中心
+        elif num_vectors < 1000:
+            # 对于100-1000个向量的情况,使用较少的聚类中心
+            nlist = max(1, int(num_vectors / 100))  # 每100个向量一个聚类中心
+        else:
+            # 原始逻辑:每1000个向量一个聚类中心,最少4个
+            nlist = max(4, int(num_vectors / 1000))
+
         quantizer = faiss.IndexFlatIP(self.vector_dim)
         self.index = faiss.IndexIVFFlat(quantizer, self.vector_dim, nlist, faiss.METRIC_INNER_PRODUCT)
 
         # Train and add vectors
+        if spinner:
+            spinner.text = f"训练索引({vectors.shape[0]}个向量,{nlist}个聚类中心)..."
         self.index.train(vectors)  # type: ignore
+
+        if spinner:
+            spinner.text = "添加向量到索引..."
         self.index.add(vectors)  # type: ignore
+
         # Set the number of clusters to probe during search
+        if spinner:
+            spinner.text = "设置搜索参数..."
         self.index.nprobe = min(nlist, 10)
+
+        if spinner:
+            spinner.text = f"索引构建完成,共 {vectors.shape[0]} 个向量"
 
     def _split_text(self, text: str) -> List[str]:
-        """
-
-
+        """使用基于token计数的更智能的分割策略
+
+        Args:
+            text: 要分割的文本
+
+        Returns:
+            List[str]: 分割后的段落列表
+        """
+        from jarvis.jarvis_utils.embedding import get_context_token_count
+
+        # 计算可用的最大和最小token数
+        max_tokens = int(self.max_paragraph_length * 0.25)  # 字符长度转换为大致token数
+        min_tokens = int(self.min_paragraph_length * 0.25)  # 字符长度转换为大致token数
 
+        # 添加重叠块以保持上下文一致性
         paragraphs = []
         current_chunk = []
-
+        current_token_count = 0
 
-        #
+        # 首先按句子分割
         sentences = []
         current_sentence = []
         sentence_ends = {'。', '!', '?', '…', '.', '!', '?'}
@@ -363,77 +380,73 @@ class RAGTool:
             if sentence.strip():
                 sentences.append(sentence)
 
-        #
+        # 基于句子构建重叠块
         for sentence in sentences:
-
+            # 计算当前句子的token数
+            sentence_token_count = get_context_token_count(sentence)
+
+            # 检查添加此句子是否会超过最大token限制
+            if current_token_count + sentence_token_count > max_tokens:
                 if current_chunk:
                     chunk_text = ' '.join(current_chunk)
-
+                    chunk_token_count = get_context_token_count(chunk_text)
+
+                    if chunk_token_count >= min_tokens:
                         paragraphs.append(chunk_text)
+
+                    # 保留一些内容作为重叠
+                    # 保留最后两个句子作为重叠部分
+                    if len(current_chunk) >= 2:
+                        overlap_text = ' '.join(current_chunk[-2:])
+                        overlap_token_count = get_context_token_count(overlap_text)
 
-
-
-
-
-
-
+                        current_chunk = []
+                        if overlap_text:
+                            current_chunk.append(overlap_text)
+                            current_token_count = overlap_token_count
+                        else:
+                            current_token_count = 0
                    else:
-
-
+                        # 如果当前块中句子不足两个,就重置
+                        current_chunk = []
+                        current_token_count = 0
+
+            # 添加当前句子到块中
            current_chunk.append(sentence)
-
+            current_token_count += sentence_token_count
 
-        #
+        # 处理最后一个块
        if current_chunk:
            chunk_text = ' '.join(current_chunk)
-
+            chunk_token_count = get_context_token_count(chunk_text)
+
+            if chunk_token_count >= min_tokens:
                paragraphs.append(chunk_text)
 
        return paragraphs
 
 
-    def
-        """Process a batch of documents using shared memory"""
-        try:
-            texts = []
-            self.documents = []  # Reset documents to store chunks
-
-            for doc in documents:
-                # Split original document into chunks
-                chunks = self._split_text(doc.content)
-                for chunk_idx, chunk in enumerate(chunks):
-                    # Create new Document for each chunk
-                    new_metadata = doc.metadata.copy()
-                    new_metadata.update({
-                        'chunk_index': chunk_idx,
-                        'total_chunks': len(chunks),
-                        'original_length': len(doc.content)
-                    })
-                    self.documents.append(Document(
-                        content=chunk,
-                        metadata=new_metadata,
-                        md5=doc.md5
-                    ))
-                    texts.append(f"File:{doc.metadata['file_path']} Chunk:{chunk_idx} Content:{chunk}")
-
-            return get_embedding_batch(self.embedding_model, texts)
-        except Exception as e:
-            PrettyOutput.print(f"批量处理失败: {str(e)}", OutputType.ERROR)
-            return np.zeros((0, self.vector_dim), dtype=np.float32)  # type: ignore
-
-    def _process_file(self, file_path: str) -> List[Document]:
+    def _process_file(self, file_path: str, spinner=None) -> List[Document]:
         """Process a single file"""
         try:
             # Calculate file MD5
+            if spinner:
+                spinner.text = f"计算文件 {file_path} 的MD5..."
             current_md5 = get_file_md5(file_path)
             if not current_md5:
+                if spinner:
+                    spinner.text = f"文件 {file_path} 计算MD5失败"
                 return []
 
             # Check if the file needs to be reprocessed
             if file_path in self.file_md5_cache and self.file_md5_cache[file_path] == current_md5:
+                if spinner:
+                    spinner.text = f"文件 {file_path} 未发生变化,跳过处理"
                 return []
 
             # Find the appropriate processor
+            if spinner:
+                spinner.text = f"查找适用于 {file_path} 的处理器..."
             processor = None
             for p in self.file_processors:
                 if p.can_handle(file_path):
@@ -442,17 +455,27 @@ class RAGTool:
 
             if not processor:
                 # If no appropriate processor is found, return an empty document
+                if spinner:
+                    spinner.text = f"没有找到适用于 {file_path} 的处理器,跳过处理"
                 return []
 
             # Extract text content
+            if spinner:
+                spinner.text = f"提取 {file_path} 的文本内容..."
             content = processor.extract_text(file_path)
             if not content.strip():
+                if spinner:
+                    spinner.text = f"文件 {file_path} 没有文本内容,跳过处理"
                 return []
 
             # Split text
+            if spinner:
+                spinner.text = f"分割 {file_path} 的文本..."
             chunks = self._split_text(content)
 
             # Create document objects
+            if spinner:
+                spinner.text = f"为 {file_path} 创建 {len(chunks)} 个文档对象..."
             documents = []
             for i, chunk in enumerate(chunks):
                 doc = Document(
@@ -469,209 +492,726 @@ class RAGTool:
 
             # Update MD5 cache
             self.file_md5_cache[file_path] = current_md5
+            if spinner:
+                spinner.text = f"文件 {file_path} 处理完成,共创建 {len(documents)} 个文档对象"
             return documents
 
         except Exception as e:
+            if spinner:
+                spinner.text = f"处理文件失败: {file_path}: {str(e)}"
             PrettyOutput.print(f"处理文件失败: {file_path}: {str(e)}",
                               output_type=OutputType.ERROR)
             return []
 
-    def
-        """
-
-
-
-
-
-
-
-
+    def _should_ignore_path(self, path: str, ignored_paths: List[str]) -> bool:
+        """
+        检查路径是否应该被忽略
+
+        Args:
+            path: 文件或目录路径
+            ignored_paths: 忽略模式列表
+
+        Returns:
+            bool: 如果路径应该被忽略则返回True
+        """
+        import fnmatch
+        import os
+
+        # 获取相对路径
+        rel_path = path
+        if os.path.isabs(path):
+            try:
+                rel_path = os.path.relpath(path, self.root_dir)
+            except ValueError:
+                # 如果不能计算相对路径,使用原始路径
+                pass
+
+        path_parts = rel_path.split(os.sep)
+
+        # 检查路径的每一部分是否匹配任意忽略模式
+        for part in path_parts:
+            for pattern in ignored_paths:
+                if fnmatch.fnmatch(part, pattern):
+                    return True
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # 检查完整路径是否匹配任意忽略模式
+        for pattern in ignored_paths:
+            if fnmatch.fnmatch(rel_path, pattern):
+                return True
+
+        return False
+
+    def _is_git_repo(self) -> bool:
+        """
+        检查当前目录是否为Git仓库
+
+        Returns:
+            bool: 如果是Git仓库则返回True
+        """
+        import subprocess
+
+        try:
+            result = subprocess.run(
+                ["git", "rev-parse", "--is-inside-work-tree"],
+                cwd=self.root_dir,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+                text=True,
+                check=False
+            )
+            return result.returncode == 0 and result.stdout.strip() == "true"
+        except Exception:
+            return False
+
+    def _get_git_managed_files(self) -> List[str]:
+        """
+        获取Git仓库中被管理的文件列表
+
+        Returns:
+            List[str]: 被Git管理的文件路径列表(相对路径)
+        """
+        import subprocess
+
+        try:
+            # 获取git索引中的文件
+            result = subprocess.run(
+                ["git", "ls-files"],
+                cwd=self.root_dir,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+                text=True,
+                check=False
+            )
+
+            if result.returncode != 0:
+                return []
+
+            git_files = [line.strip() for line in result.stdout.splitlines() if line.strip()]
 
+            # 添加未暂存但已跟踪的修改文件
+            result = subprocess.run(
+                ["git", "ls-files", "--modified"],
+                cwd=self.root_dir,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+                text=True,
+                check=False
+            )
 
-
-
-
-
-
-
-
-
-
-
-
-
+            if result.returncode == 0:
+                modified_files = [line.strip() for line in result.stdout.splitlines() if line.strip()]
+                git_files.extend([f for f in modified_files if f not in git_files])
+
+            # 转换为绝对路径
+            return [os.path.join(self.root_dir, file) for file in git_files]
+
+        except Exception as e:
+            PrettyOutput.print(f"获取Git管理的文件失败: {str(e)}", output_type=OutputType.WARNING)
+            return []
+
+    def build_index(self, dir: str):
+        try:
+            """Build document index with optimized processing"""
+            # Get all files
+            with yaspin(text="获取所有文件...", color="cyan") as spinner:
+                all_files = []
+
+                # 获取需要忽略的路径列表
+                ignored_paths = get_rag_ignored_paths()
+
+                # 检查是否为Git仓库
+                is_git_repo = self._is_git_repo()
+                if is_git_repo:
+                    git_files = self._get_git_managed_files()
+                    # 过滤掉被忽略的文件
+                    for file_path in git_files:
+                        if self._should_ignore_path(file_path, ignored_paths):
+                            continue
 
-
-
+                        if os.path.getsize(file_path) > 100 * 1024 * 1024:  # 100MB
+                            PrettyOutput.print(f"跳过大文件: {file_path}",
+                                              output_type=OutputType.WARNING)
+                            continue
+                        all_files.append(file_path)
+                else:
+                    # 非Git仓库,使用常规文件遍历
+                    for root, _, files in os.walk(dir):
+                        # 检查目录是否匹配忽略模式
+                        if self._should_ignore_path(root, ignored_paths):
+                            continue
 
-
-
-                    new_vectors.append(file_vectors)
-
-                spinner.text = f"处理文件 {file_path} 完成"
-                spinner.ok("✅")
+                        for file in files:
+                            file_path = os.path.join(root, file)
 
-
-
-
+                            # 检查文件是否匹配忽略模式
+                            if self._should_ignore_path(file_path, ignored_paths):
+                                continue
+
+                            if os.path.getsize(file_path) > 100 * 1024 * 1024:  # 100MB
+                                PrettyOutput.print(f"跳过大文件: {file_path}",
+                                                  output_type=OutputType.WARNING)
+                                continue
+                            all_files.append(file_path)
+
+                spinner.text = f"获取所有文件完成,共 {len(all_files)} 个文件"
+                spinner.ok("✅")
+
+            # Clean up cache for deleted files
+            with yaspin(text="清理缓存...", color="cyan") as spinner:
+                deleted_files = set(self.file_md5_cache.keys()) - set(all_files)
+                deleted_count = len(deleted_files)
+
+                if deleted_count > 0:
+                    spinner.write(f"🗑️ 删除不存在文件的缓存: {deleted_count} 个")
 
+                    for file_path in deleted_files:
+                        # Remove from MD5 cache
+                        del self.file_md5_cache[file_path]
+                        # Remove related documents
+                        self.documents = [doc for doc in self.documents if doc.metadata['file_path'] != file_path]
+                        # Delete cache files
+                        self._delete_file_cache(file_path, None)  # Pass None as spinner to not show individual deletions
 
-
-
+                spinner.text = f"清理缓存完成,共删除 {deleted_count} 个不存在文件的缓存"
+                spinner.ok("✅")
 
-            #
-
-
-
+            # Check file changes
+            with yaspin(text="检查文件变化...", color="cyan") as spinner:
+                files_to_process = []
+                unchanged_files = []
+                new_files_count = 0
+                modified_files_count = 0
+
+                for file_path in all_files:
+                    current_md5 = get_file_md5(file_path)
+                    if current_md5:  # Only process files that can successfully calculate MD5
+                        if file_path in self.file_md5_cache and self.file_md5_cache[file_path] == current_md5:
+                            # File未变化,记录但不重新处理
+                            unchanged_files.append(file_path)
+                        else:
+                            # New file or modified file
+                            files_to_process.append(file_path)
+
+                            # 如果是修改的文件,删除旧缓存
+                            if file_path in self.file_md5_cache:
+                                modified_files_count += 1
+                                # 删除旧缓存
+                                self._delete_file_cache(file_path, spinner)
+                                # 从文档列表中移除
+                                self.documents = [doc for doc in self.documents if doc.metadata['file_path'] != file_path]
+                            else:
+                                new_files_count += 1
+
+                # 输出汇总信息
+                if unchanged_files:
+                    spinner.write(f"📚 已缓存文件: {len(unchanged_files)} 个")
+                if new_files_count > 0:
+                    spinner.write(f"🆕 新增文件: {new_files_count} 个")
+                if modified_files_count > 0:
+                    spinner.write(f"📝 修改文件: {modified_files_count} 个")
 
-
-
-
-
-
+                spinner.text = f"检查文件变化完成,共 {len(files_to_process)} 个文件需要处理"
+                spinner.ok("✅")
+
+            # Keep documents for unchanged files
+            unchanged_documents = [doc for doc in self.documents
+                                   if doc.metadata['file_path'] in unchanged_files]
+
+            # Process files one by one with optimized vectorization
+            if files_to_process:
+                new_documents = []
+                new_vectors = []
+                success_count = 0
+                skipped_count = 0
+                failed_count = 0
+
+                with yaspin(text=f"处理文件中 (0/{len(files_to_process)})...", color="cyan") as spinner:
+                    for index, file_path in enumerate(files_to_process):
+                        spinner.text = f"处理文件中 ({index+1}/{len(files_to_process)}): {file_path}"
+                        try:
+                            # Process single file
+                            file_docs = self._process_file(file_path, spinner)
+                            if file_docs:
+                                # Vectorize documents from this file
+                                spinner.text = f"处理文件中 ({index+1}/{len(files_to_process)}): 为 {file_path} 生成向量嵌入..."
+                                texts_to_vectorize = [
+                                    f"File:{doc.metadata['file_path']} Content:{doc.content}"
+                                    for doc in file_docs
+                                ]
+
+                                file_vectors = get_embedding_batch(self.embedding_model, f"({index+1}/{len(files_to_process)}){file_path}", texts_to_vectorize, spinner)
+
+                                # Save cache for this file
+                                spinner.text = f"处理文件中 ({index+1}/{len(files_to_process)}): 保存 {file_path} 的缓存..."
+                                self._save_cache(file_path, file_docs, file_vectors, spinner)
+
+                                # Accumulate documents and vectors
+                                new_documents.extend(file_docs)
+                                new_vectors.append(file_vectors)
+                                success_count += 1
+                            else:
+                                # 文件跳过处理
+                                skipped_count += 1
+
+                        except Exception as e:
+                            spinner.write(f"❌ 处理失败: {file_path}: {str(e)}")
+                            failed_count += 1
+
+                    # 输出处理统计
+                    spinner.text = f"文件处理完成: 成功 {success_count} 个, 跳过 {skipped_count} 个, 失败 {failed_count} 个"
+                    spinner.ok("✅")
+
+                # Update documents list
+                self.documents.extend(new_documents)
+
+                # Build final index
+                if new_vectors:
+                    with yaspin(text="构建最终索引...", color="cyan") as spinner:
+                        spinner.text = "合并新向量..."
+                        all_new_vectors = np.vstack(new_vectors)
+
+                        unchanged_vector_count = 0
+                        if self.flat_index is not None:
+                            # Get vectors for unchanged documents
+                            spinner.text = "获取未变化文档的向量..."
+                            unchanged_vectors = self._get_unchanged_vectors(unchanged_documents, spinner)
+                            if unchanged_vectors is not None:
+                                unchanged_vector_count = unchanged_vectors.shape[0]
+                                spinner.text = f"合并新旧向量(新:{all_new_vectors.shape[0]},旧:{unchanged_vector_count})..."
+                                final_vectors = np.vstack([unchanged_vectors, all_new_vectors])
+                            else:
+                                spinner.text = f"仅使用新向量({all_new_vectors.shape[0]})..."
+                                final_vectors = all_new_vectors
                         else:
+                            spinner.text = f"仅使用新向量({all_new_vectors.shape[0]})..."
                             final_vectors = all_new_vectors
-                    else:
-                        final_vectors = all_new_vectors
 
-
-
-
-
-
+                        # Build index
+                        spinner.text = f"构建索引(向量数量:{final_vectors.shape[0]})..."
+                        self._build_index(final_vectors, spinner)
+                        spinner.text = f"索引构建完成,共 {len(self.documents)} 个文档片段"
+                        spinner.ok("✅")
 
-
-
-
-
-
-
+            # 输出最终统计信息
+            PrettyOutput.print(
+                f"📊 索引统计:\n"
+                f" • 总文档数: {len(self.documents)} 个文档片段\n"
+                f" • 已缓存文件: {len(unchanged_files)} 个\n"
+                f" • 处理文件: {len(files_to_process)} 个\n"
+                f" - 成功: {success_count} 个\n"
+                f" - 跳过: {skipped_count} 个\n"
+                f" - 失败: {failed_count} 个",
+                OutputType.SUCCESS
+            )
+        except Exception as e:
+            PrettyOutput.print(f"索引构建失败: {str(e)}",
+                               output_type=OutputType.ERROR)
 
-    def _get_unchanged_vectors(self, unchanged_documents: List[Document]) -> Optional[np.ndarray]:
+    def _get_unchanged_vectors(self, unchanged_documents: List[Document], spinner=None) -> Optional[np.ndarray]:
         """Get vectors for unchanged documents from existing index"""
         try:
-            if not unchanged_documents
+            if not unchanged_documents:
+                if spinner:
+                    spinner.text = "没有未变化的文档"
                 return None
 
+            if spinner:
+                spinner.text = f"加载 {len(unchanged_documents)} 个未变化文档的向量..."
+
+            # 按文件分组处理
+            unchanged_files = set(doc.metadata['file_path'] for doc in unchanged_documents)
             unchanged_vectors = []
-
-
-
-
-
-
-
+
+            for file_path in unchanged_files:
+                if spinner:
+                    spinner.text = f"加载 {file_path} 的向量..."
+
+                # 获取该文件所有文档的chunk索引
+                doc_indices = [(i, doc.metadata['chunk_index'])
+                              for i, doc in enumerate(unchanged_documents)
+                              if doc.metadata['file_path'] == file_path]
+
+                if not doc_indices:
+                    continue
+
+                # 加载该文件的向量
+                vec_cache_path = self._get_cache_path(file_path, "vec")
+                if os.path.exists(vec_cache_path):
+                    try:
+                        with lzma.open(vec_cache_path, 'rb') as f:
+                            vec_cache_data = pickle.load(f)
+                            file_vectors = vec_cache_data["vectors"]
+
+                        # 按照chunk_index加载对应的向量
+                        for _, chunk_idx in doc_indices:
+                            if chunk_idx < file_vectors.shape[0]:
+                                unchanged_vectors.append(file_vectors[chunk_idx].reshape(1, -1))
+
+                        if spinner:
+                            spinner.text = f"成功加载 {file_path} 的向量"
+                    except Exception as e:
+                        if spinner:
+                            spinner.text = f"加载 {file_path} 向量失败: {str(e)}"
+                else:
+                    if spinner:
+                        spinner.text = f"未找到 {file_path} 的向量缓存"
+
+                    # 从flat_index重建向量
+                    if self.flat_index is not None:
+                        if spinner:
+                            spinner.text = f"从索引重建 {file_path} 的向量..."
+
+                        for doc_idx, chunk_idx in doc_indices:
+                            idx = next((i for i, d in enumerate(self.documents)
+                                       if d.metadata['file_path'] == file_path and
+                                       d.metadata['chunk_index'] == chunk_idx), None)
+
+                            if idx is not None:
+                                vector = np.zeros((1, self.vector_dim), dtype=np.float32)  # type: ignore
+                                self.flat_index.reconstruct(idx, vector.ravel())
+                                unchanged_vectors.append(vector)
 
-
+            if not unchanged_vectors:
+                if spinner:
+                    spinner.text = "未能加载任何未变化文档的向量"
+                return None
+
+            if spinner:
+                spinner.text = f"未变化文档向量加载完成,共 {len(unchanged_vectors)} 个"
+
+            return np.vstack(unchanged_vectors)
 
         except Exception as e:
+            if spinner:
+                spinner.text = f"获取不变向量失败: {str(e)}"
             PrettyOutput.print(f"获取不变向量失败: {str(e)}", OutputType.ERROR)
             return None
 
-    def
+    def _perform_keyword_search(self, query: str, limit: int = 15) -> List[Tuple[int, float]]:
+        """执行基于关键词的文本搜索
+
+        Args:
+            query: 查询字符串
+            limit: 返回结果数量限制
+
+        Returns:
+            List[Tuple[int, float]]: 文档索引和得分的列表
+        """
+        # 使用大模型生成关键词
+        keywords = self._generate_keywords_with_llm(query)
+
+        # 如果大模型生成失败,回退到简单的关键词提取
+        if not keywords:
+            # 简单的关键词预处理
+            keywords = query.lower().split()
+            # 移除停用词和过短的词
+            stop_words = {'的', '了', '和', '是', '在', '有', '与', '对', '为', 'a', 'an', 'the', 'and', 'is', 'in', 'of', 'to', 'with'}
+            keywords = [k for k in keywords if k not in stop_words and len(k) > 1]
+
+        if not keywords:
+            return []
+
+        # 使用TF-IDF思想的简单实现
+        doc_scores = []
+
+        # 计算IDF(逆文档频率)
+        doc_count = len(self.documents)
+        keyword_doc_count = {}
+
+        for keyword in keywords:
+            count = 0
+            for doc in self.documents:
+                if keyword in doc.content.lower():
+                    count += 1
+            keyword_doc_count[keyword] = max(1, count)  # 避免除零错误
+
+        # 计算每个关键词的IDF值
+        keyword_idf = {
+            keyword: np.log(doc_count / count)
+            for keyword, count in keyword_doc_count.items()
+        }
+
+        # 为每个文档计算得分
+        for i, doc in enumerate(self.documents):
+            doc_content = doc.content.lower()
+            score = 0
+
+            # 计算每个关键词的TF(词频)
+            for keyword in keywords:
+                # 简单的TF:关键词在文档中出现的次数
+                tf = doc_content.count(keyword)
+                # TF-IDF得分
+                if tf > 0:
+                    score += tf * keyword_idf[keyword]
+
+            # 添加额外权重:标题匹配、完整短语匹配等
+            if query.lower() in doc_content:
+                score *= 2.0  # 完整查询匹配加倍得分
+
+            # 文件路径匹配也加分
+            file_path = doc.metadata['file_path'].lower()
+            for keyword in keywords:
+                if keyword in file_path:
+                    score += 0.5 * keyword_idf.get(keyword, 1.0)
+
+            if score > 0:
+                # 归一化得分(0-1范围)
+                doc_scores.append((i, score))
+
+        # 排序并限制结果数量
+        doc_scores.sort(key=lambda x: x[1], reverse=True)
+
+        # 归一化分数到0-1之间
+        if doc_scores:
+            max_score = max(score for _, score in doc_scores)
+            if max_score > 0:
+                doc_scores = [(idx, score/max_score) for idx, score in doc_scores]
+
+        return doc_scores[:limit]
+
+    def _generate_keywords_with_llm(self, query: str) -> List[str]:
+        """
+        使用大语言模型从查询中提取关键词
+
+        Args:
+            query: 用户查询
+
+        Returns:
+            List[str]: 提取的关键词列表
+        """
+        try:
+            from jarvis.jarvis_utils.output import PrettyOutput, OutputType
+            from jarvis.jarvis_platform.registry import PlatformRegistry
+
+            # 获取平台注册表和模型
+            registry = PlatformRegistry.get_global_platform_registry()
+            model = registry.get_normal_platform()
+
+            # 构建关键词提取提示词
+            prompt = f"""
+请分析以下查询,提取用于文档检索的关键词。你的任务是:
+
+1. 识别核心概念、主题和实体,包括:
+   - 技术术语和专业名词
+   - 代码相关的函数名、类名、变量名和库名
+   - 重要的业务领域词汇
+   - 操作和动作相关的词汇
+
+2. 优先提取与以下场景相关的关键词:
+   - 代码搜索: 编程语言、框架、API、特定功能
+   - 文档检索: 主题、标题词汇、专业名词
+   - 错误排查: 错误信息、异常名称、问题症状
+
+3. 同时包含:
+   - 中英文关键词 (尤其是技术领域常用英文术语)
+   - 完整的专业术语和缩写形式
+   - 可能的同义词或相关概念
+
+4. 关键词应当精准、具体,数量控制在3-10个之间。
+
+输出格式:
+{ot("KEYWORD")}
+关键词1
+关键词2
+...
+{ct("KEYWORD")}
+
+查询文本:
+{query}
+
+仅返回提取的关键词,不要包含其他内容。
+"""
+
+            # 调用大模型获取响应
+            response = model.chat_until_success(prompt)
+
+            if response:
+                # 清理响应,提取关键词
+                sm = re.search(ot('KEYWORD') + r"(.*?)" + ct('KEYWORD'), response, re.DOTALL)
+                if sm:
+                    extracted_keywords = sm[1]
+
+                    if extracted_keywords:
+                        # 记录检测到的关键词
+                        ret = extracted_keywords.strip().splitlines()
+                        return ret
+
+            # 如果处理失败,返回空列表
+            return []
+
+        except Exception as e:
+            from jarvis.jarvis_utils.output import PrettyOutput, OutputType
+            PrettyOutput.print(f"使用大模型生成关键词失败: {str(e)}", OutputType.WARNING)
+            return []
+
+    def _hybrid_search(self, query: str, top_k: int = 15) -> List[Tuple[int, float]]:
+        """混合搜索方法,综合向量相似度和关键词匹配
+
+        Args:
+            query: 查询字符串
+            top_k: 返回结果数量限制
+
+        Returns:
+            List[Tuple[int, float]]: 文档索引和得分的列表
+        """
+        # 获取向量搜索结果
+        query_vector = get_embedding(self.embedding_model, query)
+        query_vector = query_vector.reshape(1, -1)
+
+        # 进行向量搜索
+        vector_limit = min(top_k * 3, len(self.documents))
+        if self.index and vector_limit > 0:
+            distances, indices = self.index.search(query_vector, vector_limit)  # type: ignore
+            vector_results = [(int(idx), 1.0 / (1.0 + float(dist)))
+                             for idx, dist in zip(indices[0], distances[0])
+                             if idx != -1 and idx < len(self.documents)]
+        else:
+            vector_results = []
+
+        # 进行关键词搜索
+        keyword_results = self._perform_keyword_search(query, top_k * 2)
+
+        # 合并结果集
+        combined_results = {}
+
+        # 加入向量结果,权重为0.7
+        for idx, score in vector_results:
+            combined_results[idx] = score * 0.7
+
+        # 加入关键词结果,权重为0.3,如果文档已存在则取加权平均
+        for idx, score in keyword_results:
+            if idx in combined_results:
+                # 已有向量得分,取加权平均
+                combined_results[idx] = combined_results[idx] + score * 0.3
+            else:
+                # 新文档,直接添加关键词得分(权重稍低)
+                combined_results[idx] = score * 0.3
+
+        # 转换成列表并排序
+        result_list = [(idx, score) for idx, score in combined_results.items()]
+        result_list.sort(key=lambda x: x[1], reverse=True)
+
+        return result_list[:top_k]
+
+
+    def search(self, query: str, top_k: int = 15) -> List[Tuple[Document, float]]:
         """Search documents with context window"""
-        if not self.
+        if not self.is_index_built():
+            PrettyOutput.print("索引未建立,自动建立索引中...", OutputType.INFO)
             self.build_index(self.root_dir)
 
-        #
-
-
-
-
+        # 如果索引建立失败或文档列表为空,返回空结果
+        if not self.is_index_built():
+            PrettyOutput.print("索引建立失败或文档列表为空", OutputType.WARNING)
+            return []
+
+        # 使用混合搜索获取候选文档
+        with yaspin(text="执行混合搜索...", color="cyan") as spinner:
+            # 获取初始候选结果
+            search_results = self._hybrid_search(query, top_k * 2)
+
+            if not search_results:
+                spinner.text = "搜索结果为空"
+                spinner.fail("❌")
+                return []
+
+            # 准备重排序
+            initial_indices = [idx for idx, _ in search_results]
+            spinner.text = f"检索完成,获取 {len(initial_indices)} 个候选文档"
             spinner.ok("✅")
 
-
-
-
-
-
+        indices_list = [idx for idx, _ in search_results if idx < len(self.documents)]
+
+        # 应用重排序优化检索结果
+        with yaspin(text="执行重排序...", color="cyan") as spinner:
+            # 准备重排序所需文档内容和初始分数
+            docs_to_rerank = []
+            initial_scores = []
+
+            for idx, score in search_results:
+                if idx < len(self.documents):
+                    doc = self.documents[idx]
+                    # 获取原始文档内容
+                    doc_content = f"File:{doc.metadata['file_path']} Content:{doc.content}"
+                    docs_to_rerank.append(doc_content)
+                    initial_scores.append(score)
+
+            if not docs_to_rerank:
+                spinner.text = "没有可重排序的文档"
+                spinner.fail("❌")
+                return []
+
+            # 执行重排序
+            spinner.text = f"重排序 {len(docs_to_rerank)} 个文档..."
+            reranked_scores = rerank_results(
+                query=query,
+                documents=docs_to_rerank,
+                initial_scores=initial_scores,
+                spinner=spinner
+            )
+
+            # 更新搜索结果的分数
+            search_results = []
+            for i, idx in enumerate(indices_list):
+                if i < len(reranked_scores):
+                    search_results.append((idx, reranked_scores[i]))
+
+            # 按分数重新排序
+            search_results.sort(key=lambda x: x[1], reverse=True)
+
+            spinner.text = "重排序完成"
             spinner.ok("✅")
 
+        # 重新获取排序后的索引列表
+        indices_list = [idx for idx, _ in search_results if idx < len(self.documents)]
+
         # Process results with context window
         with yaspin(text="处理结果...", color="cyan") as spinner:
             results = []
             seen_files = set()
 
-
-
+            # 检查索引列表是否为空
+            if not indices_list:
+                spinner.text = "搜索结果为空"
+                spinner.fail("❌")
+                return []
+
+            for idx in indices_list:
+                if idx < len(self.documents):  # 确保索引有效
                     doc = self.documents[idx]
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+                    # 使用重排序得分或基于原始相似度的得分
+                    similarity = next((score for i, score in search_results if i == idx), 0.5) if search_results else 0.5
+
+                    file_path = doc.metadata['file_path']
+                    if file_path not in seen_files:
+                        seen_files.add(file_path)
+
+                        # Get full context from original document
+                        original_doc = next((d for d in self.documents
+                                            if d.metadata['file_path'] == file_path), None)
+                        if original_doc:
+                            window_docs = []  # Add this line to initialize the list
+                            # Find all chunks from this file
+                            file_chunks = [d for d in self.documents
+                                          if d.metadata['file_path'] == file_path]
+                            # Add all related chunks
+                            for chunk_doc in file_chunks:
+                                window_docs.append((chunk_doc, similarity * 0.9))
+
+                            results.extend(window_docs)
+                            if len(results) >= top_k * (2 * self.context_window + 1):
+                                break
             spinner.text = "处理结果完成"
             spinner.ok("✅")
 
         # Sort by similarity and deduplicate
         with yaspin(text="排序...", color="cyan") as spinner:
+            if not results:
+                spinner.text = "无有效结果"
+                spinner.fail("❌")
+                return []
+
             results.sort(key=lambda x: x[1], reverse=True)
             seen = set()
             final_results = []
@@ -702,97 +1242,403 @@ class RAGTool:
     def ask(self, question: str) -> Optional[str]:
         """Ask questions about documents with enhanced context building"""
         try:
- … (1 removed line not preserved in this diff view)
+            # Check whether the index is built; if not, build it automatically
+            if not self.is_index_built():
+                PrettyOutput.print("索引未建立,自动建立索引中...", OutputType.INFO)
+                self.build_index(self.root_dir)
+
+                # If the index still could not be built, return an error message
+                if not self.is_index_built():
+                    PrettyOutput.print("无法建立索引,请检查文档和配置", OutputType.ERROR)
+                    return "无法建立索引,请检查文档和配置。可能的原因:文档目录为空、权限不足或格式不支持。"
+
+            # Enhanced query preprocessing: extract keywords and semantic information
+            enhanced_query = self._enhance_query(question)
+
+            # Search with the enhanced query
+            results = self.search(enhanced_query)
             if not results:
-                return
+                return "未找到与问题相关的文档。请尝试重新表述问题或确认问题相关内容已包含在索引中。"
+
+            # Model instance
+            model = PlatformRegistry.get_global_platform_registry().get_normal_platform()
 
- … (38 removed lines not preserved in this diff view)
-Relevant Documents (by relevance):
+            # Count the tokens of the base prompt
+            base_prompt = f"""
+# 🤖 角色定义
+您是一位文档分析专家,能够基于提供的文档提供准确且全面的回答。
+
+# 🎯 核心职责
+- 全面分析文档片段
+- 准确回答问题
+- 引用源文档
+- 识别缺失信息
+- 保持专业语气
+
+# 📋 回答要求
+## 内容质量
+- 严格基于提供的文档作答
+- 具体且精确
+- 在有帮助时引用相关内容
+- 指出任何信息缺口
+- 使用专业语言
+
+## 回答结构
+1. 直接回答
+- 清晰简洁的回应
+- 基于文档证据
+- 专业术语
+
+2. 支持细节
+- 相关文档引用
+- 文件参考
+- 上下文解释
+
+3. 信息缺口(如有)
+- 缺失信息
+- 需要的额外上下文
+- 潜在限制
+
+# 🔍 分析上下文
+问题: {question}
 """
+            base_token_count = get_context_token_count(base_prompt)
+            footer_prompt = """
+# ❗ 重要规则
+1. 仅使用提供的文档
+2. 保持精确和准确
+3. 在相关时引用来源
+4. 指出缺失的信息
+5. 保持专业语气
+6. 使用用户的语言回答
+"""
+            footer_token_count = get_context_token_count(footer_prompt)
+
+            # Tokens available per batch, with a safety margin subtracted
+            available_tokens_per_batch = self.max_token_count - base_token_count - footer_token_count - 1000
+
+            # Determine whether batching is needed
+            with yaspin(text="计算文档上下文大小...", color="cyan") as spinner:
+                # Group the results by file
+                file_groups = {}
+                for doc, score in results:
+                    file_path = doc.metadata['file_path']
+                    if file_path not in file_groups:
+                        file_groups[file_path] = []
+                    file_groups[file_path].append((doc, score))
+
+                # Count the total tokens of all documents
+                total_docs_tokens = 0
+                total_len = 0
+                for file_path, docs in file_groups.items():
+                    file_header = f"\n## 文件: {file_path}\n"
+                    file_tokens = get_context_token_count(file_header)
+
+                    # Include every document whose relevance is high enough
+                    for doc, score in docs:
+                        if score < 0.2:
+                            continue
+                        doc_content = f"""
+### 片段 {doc.metadata['chunk_index'] + 1}/{doc.metadata['total_chunks']} [相关度: {score:.2f}]
+```
+{doc.content}
+```
+"""
+                        file_tokens += get_context_token_count(doc_content)
+                        total_len += len(doc_content)
+                    total_docs_tokens += file_tokens
+
+                # Decide whether batching is needed and into how many batches
+                need_batching = total_docs_tokens > available_tokens_per_batch
+                batch_count = 1
+                if need_batching:
+                    batch_count = (total_docs_tokens + available_tokens_per_batch - 1) // available_tokens_per_batch
+
+                if need_batching:
+                    spinner.text = f"文档需要分 {batch_count} 批处理 (总计 {total_docs_tokens} tokens), 总长度 {total_len} 字符"
+                else:
+                    spinner.text = f"文档无需分批 (总计 {total_docs_tokens} tokens), 总长度 {total_len} 字符"
+                spinner.ok("✅")
+
+            # Single batch: use the original approach directly
+            if not need_batching:
+                with yaspin(text="添加上下文...", color="cyan") as spinner:
+                    prompt = base_prompt
+                    current_count = base_token_count
+
+                    # Keep fingerprints of added content to avoid duplicates
+                    added_content_hashes = set()
+
+                    # Add document chunks file by file
+                    for file_path, docs in file_groups.items():
+                        # Sort by relevance
+                        docs.sort(key=lambda x: x[1], reverse=True)
+
+                        # Add the file information
+                        file_header = f"\n## 文件: {file_path}\n"
+                        if current_count + get_context_token_count(file_header) > available_tokens_per_batch:
+                            break
+
+                        prompt += file_header
+                        current_count += get_context_token_count(file_header)
+
+                        # Add the related chunks without limiting the number per file
+                        for doc, score in docs:
+                            # Compute a content fingerprint to avoid duplicates
+                            content_hash = hash(doc.content)
+                            if content_hash in added_content_hashes:
+                                continue
+
+                            # Format the document chunk
+                            doc_content = f"""
+### 片段 {doc.metadata['chunk_index'] + 1}/{doc.metadata['total_chunks']} [相关度: {score:.2f}]
+```
+{doc.content}
+```
+"""
+                            if current_count + get_context_token_count(doc_content) > available_tokens_per_batch:
+                                break
+
+                            prompt += doc_content
+                            current_count += get_context_token_count(doc_content)
+                            added_content_hashes.add(content_hash)
+
+                    prompt += footer_prompt
+                    spinner.text = "添加上下文完成"
+                    spinner.ok("✅")
 
- … (4 removed lines not preserved in this diff view)
+                # Generate the answer directly
+                with yaspin(text="正在生成答案...", color="cyan") as spinner:
+                    response = model.chat_until_success(prompt)
+                    spinner.text = "答案生成完成"
+                    spinner.ok("✅")
+                    return response
+
+            # Process the documents in batches
+            else:
+                batch_responses = []
 
- … (13 removed lines not preserved in this diff view)
-                )
-            break
+                # Prepare the batches
+                with yaspin(text=f"准备分批处理 (共{batch_count}批)...", color="cyan") as spinner:
+                    batches = []
+                    current_batch = []
+                    current_batch_tokens = 0
+
+                    # Process files in order of relevance
+                    sorted_files = sorted(file_groups.items(),
+                                          key=lambda x: max(score for _, score in x[1]) if x[1] else 0,
+                                          reverse=True)
+
+                    for file_path, docs in sorted_files:
+                        # Sort the documents by relevance
+                        docs.sort(key=lambda x: x[1], reverse=True)
 
- … (14 removed lines not preserved in this diff view)
+                        # Process the documents of each file
+                        file_header = f"\n## 文件: {file_path}\n"
+                        file_header_tokens = get_context_token_count(file_header)
+
+                        # If adding this file would exceed the limit, start a new batch
+                        file_docs = []
+                        file_docs_tokens = 0
+
+                        # Collect every document of this file without limiting the number of chunks
+                        for doc, score in docs:
+                            if score < 0.2:  # Filter out documents with low relevance
+                                continue
+
+                            doc_content = f"""
+### 片段 {doc.metadata['chunk_index'] + 1}/{doc.metadata['total_chunks']} [相关度: {score:.2f}]
+```
+{doc.content}
+```
+"""
+                            doc_tokens = get_context_token_count(doc_content)
+                            file_docs.append((doc, score, doc_content, doc_tokens))
+                            file_docs_tokens += doc_tokens
+
+                        # If this file plus its header would overflow the current batch, start a new batch
+                        if current_batch and (current_batch_tokens + file_header_tokens + file_docs_tokens > available_tokens_per_batch):
+                            batches.append(current_batch)
+                            current_batch = []
+                            current_batch_tokens = 0
+
+                        # Add the file and its documents to the current batch
+                        if file_docs:  # If there are documents to add
+                            current_batch.append((file_path, file_header, file_docs))
+                            current_batch_tokens += file_header_tokens + file_docs_tokens
+
+                    # Add the last batch
+                    if current_batch:
+                        batches.append(current_batch)
+
+                    spinner.text = f"分批准备完成,共 {len(batches)} 批"
+                    spinner.ok("✅")
+
+                # Process each batch
+                for batch_idx, batch in enumerate(batches):
+                    with yaspin(text=f"处理批次 {batch_idx+1}/{len(batches)}...", color="cyan") as spinner:
+                        # Build the prompt for this batch
+                        batch_prompt = base_prompt + f"\n\n## 批次 {batch_idx+1}/{len(batches)} 的相关文档:\n"
+
+                        # Add the documents of this batch
+                        for file_path, file_header, file_docs in batch:
+                            batch_prompt += file_header
+
+                            for doc, score, doc_content, _ in file_docs:
+                                batch_prompt += doc_content
+
+                        # Add a summary instruction for the last batch and a partial-analysis instruction for the others
+                        if batch_idx == len(batches) - 1:
+                            # Last batch: instruct the model to summarize all batches
+                            if len(batches) > 1:
+                                batch_prompt += f"""
+# 📊 汇总分析
+这是最后一批文档。请基于此批次和之前批次的分析,提供一个全面的最终回答。
+"""
+                            batch_prompt += footer_prompt
+                        else:
+                            # Intermediate batch: instruct the model to do a partial analysis
+                            batch_prompt += f"""
+# 📝 批次分析
+这是第 {batch_idx+1}/{len(batches)} 批文档。请分析这批文档中与问题相关的信息。
+在你的分析中:
+1. 提取关键信息点
+2. 识别可能对最终答案有帮助的内容
+3. 简明扼要,重点关注与问题直接相关的内容
+4. 忽略与问题无关的内容
+"""
+
+                        spinner.text = f"正在分析批次 {batch_idx+1}/{len(batches)}..."
+
+                        # Call the model on the current batch
+                        batch_response = model.chat_until_success(batch_prompt)
+                        batch_responses.append(batch_response)
+
+                        spinner.text = f"批次 {batch_idx+1}/{len(batches)} 分析完成"
+                        spinner.ok("✅")
+
+                # If there is only one batch, return its result directly
+                if len(batch_responses) == 1:
+                    return batch_responses[0]
+
+                # Multiple batches: the results need to be summarized
+                with yaspin(text="汇总多批次分析结果...", color="cyan") as spinner:
+                    # Build the summary prompt
+                    summary_prompt = f"""
+# 🔄 批次汇总任务
 
- … (6 removed lines not preserved in this diff view)
+## 原始问题
+{question}
+
+## 多批次分析结果
+你已经对相关文档进行了多批次分析,现在需要将这些分析结果汇总成一个连贯、全面的回答。
+
+以下是各批次的分析结果:
+
+"""
+
+                    # Append the analysis result of each batch
+                    for i, response in enumerate(batch_responses):
+                        summary_prompt += f"""
+### 批次 {i+1} 分析结果
+{response}
+
+"""
+
+                    # Add the summary guidance
+                    summary_prompt += """
+## 汇总要求
+请基于以上所有批次的分析结果,提供一个综合、连贯的最终回答。
+
+# 🎯 核心职责
+- 全面分析文档片段
+- 准确回答问题
+- 引用源文档
+- 识别缺失信息
+- 保持专业语气
+
+# 📋 回答要求
+## 内容质量
+- 严格基于提供的文档作答
+- 具体且精确
+- 在有帮助时引用相关内容
+- 指出任何信息缺口
+- 使用专业语言
+
+## 回答结构
+1. 直接回答
+- 清晰简洁的回应
+- 基于文档证据
+- 专业术语
+
+2. 支持细节
+- 相关文档引用
+- 文件参考
+- 上下文解释
+
+3. 信息缺口(如有)
+- 缺失信息
+- 需要的额外上下文
+- 潜在限制
+
+请直接提供最终回答,不需要解释你的汇总过程。
+"""
+
+                    spinner.text = "正在生成最终汇总答案..."
+
+                    # Call the model to generate the final summary
+                    final_response = model.chat_until_success(summary_prompt)
+
+                    spinner.text = "汇总答案生成完成"
+                    spinner.ok("✅")
+
+                    return final_response
 
         except Exception as e:
             PrettyOutput.print(f"回答失败:{str(e)}", OutputType.ERROR)
             return None
+
+    def _enhance_query(self, query: str) -> str:
+        """Enhance the query to improve retrieval quality
+
+        Args:
+            query: the original query
+
+        Returns:
+            str: the enhanced query
+        """
+        # Simple query preprocessing
+        query = query.strip()
+
+        # If the query is too short, return it unchanged
+        if len(query) < 10:
+            return query
+
+        try:
+            # Try to enhance the query with the large model (if available)
+            model = PlatformRegistry.get_global_platform_registry().get_normal_platform()
+            enhance_prompt = f"""请分析以下查询,提取关键概念、关键词和主题。
+
+查询:"{query}"
+
+输出格式:对原始查询的改写版本,专注于提取关键信息,保留原始语义,以提高检索相关度。
+仅输出改写后的查询文本,不要输出其他内容。
+只对信息进行最小必要的增强,不要过度添加与原始查询无关的内容。
+"""
+
+            enhanced_query = model.chat_until_success(enhance_prompt)
+            # Clean up the enhanced query
+            enhanced_query = enhanced_query.strip().strip('"')
+
+            # Use the enhanced query if it is valid and not identical to the original
+            if enhanced_query and len(enhanced_query) >= len(query) / 2 and enhanced_query != query:
+                return enhanced_query
+
+        except Exception:
+            # Fall back to the original query if enhancement fails
+            pass
+
+        return query
 
     def is_index_built(self) -> bool:
         """Check if the index is built and valid
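The core of the hunk above is a token-budget planner for the ask() flow: the per-batch budget is max_token_count minus the base prompt, the footer prompt, and a 1000-token safety margin; the batch count is estimated with the ceiling division (total + budget - 1) // budget; whole files are then packed greedily into batches, and chunks scoring below 0.2 are dropped. The sketch below shows the packing side under stated assumptions; count_tokens is a crude stand-in for get_context_token_count, and plan_batches is an illustrative name rather than the package's API.

```python
# Minimal sketch of the batching arithmetic used above, with assumed inputs.
# count_tokens approximates jarvis's get_context_token_count; only the budgeting
# logic mirrors the diff, and the numbers in __main__ are made up.
from typing import Dict, List, Tuple


def count_tokens(text: str) -> int:
    # crude stand-in: roughly one token per 4 characters
    return max(1, len(text) // 4)


def plan_batches(
    file_chunks: Dict[str, List[str]],
    max_token_count: int,
    base_tokens: int,
    footer_tokens: int,
    safety_margin: int = 1000,
) -> List[List[Tuple[str, str]]]:
    """Greedily pack whole files into batches that fit the per-batch token budget."""
    budget = max_token_count - base_tokens - footer_tokens - safety_margin
    batches: List[List[Tuple[str, str]]] = []
    current: List[Tuple[str, str]] = []
    current_tokens = 0
    for path, chunks in file_chunks.items():
        # a file costs its header plus all of its chunks
        file_tokens = count_tokens(f"\n## 文件: {path}\n") + sum(count_tokens(c) for c in chunks)
        if current and current_tokens + file_tokens > budget:
            batches.append(current)
            current, current_tokens = [], 0
        current.extend((path, c) for c in chunks)
        current_tokens += file_tokens
    if current:
        batches.append(current)
    return batches


if __name__ == "__main__":
    chunks = {"a.md": ["x" * 4000], "b.md": ["y" * 4000, "z" * 2000]}
    print([len(b) for b in plan_batches(chunks, max_token_count=4096, base_tokens=600, footer_tokens=100)])
```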
@@ -802,6 +1648,33 @@ Relevant Documents (by relevance):
         """
         return self.index is not None and len(self.documents) > 0
 
+    def _delete_file_cache(self, file_path: str, spinner=None):
+        """Delete cache files for a specific file
+
+        Args:
+            file_path: Path to the original file
+            spinner: Optional spinner for progress information. If None, runs silently.
+        """
+        try:
+            # Delete document cache
+            doc_cache_path = self._get_cache_path(file_path, "doc")
+            if os.path.exists(doc_cache_path):
+                os.remove(doc_cache_path)
+                if spinner is not None:
+                    spinner.write(f"🗑️ 删除文档缓存: {file_path}")
+
+            # Delete vector cache
+            vec_cache_path = self._get_cache_path(file_path, "vec")
+            if os.path.exists(vec_cache_path):
+                os.remove(vec_cache_path)
+                if spinner is not None:
+                    spinner.write(f"🗑️ 删除向量缓存: {file_path}")
+
+        except Exception as e:
+            if spinner is not None:
+                spinner.write(f"❌ 删除缓存失败: {file_path}: {str(e)}")
+            PrettyOutput.print(f"删除缓存失败: {file_path}: {str(e)}", output_type=OutputType.ERROR)
+
 def main():
     """Main function"""
     import argparse
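The new _delete_file_cache helper above removes a per-file document cache and vector cache via self._get_cache_path(file_path, "doc") and (file_path, "vec"), which this diff does not show. The sketch below assumes a hash-based cache file naming scheme purely so the example is self-contained; the real cache layout may differ.

```python
# Hypothetical sketch of the per-file cache layout that _delete_file_cache cleans up.
# cache_path and its md5-based naming are assumptions, not the package's real code.
import hashlib
import os


def cache_path(cache_dir: str, file_path: str, kind: str) -> str:
    """Map a source file to its cache file for a given kind ("doc" or "vec")."""
    digest = hashlib.md5(file_path.encode("utf-8")).hexdigest()
    return os.path.join(cache_dir, f"{digest}.{kind}")


def delete_file_cache(cache_dir: str, file_path: str) -> None:
    """Remove both the document cache and the vector cache, if they exist."""
    for kind in ("doc", "vec"):
        path = cache_path(cache_dir, file_path, kind)
        if os.path.exists(path):
            os.remove(path)


if __name__ == "__main__":
    # no-op if nothing is cached for this file
    delete_file_cache("/tmp/jarvis_rag_cache", "docs/readme.md")
```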
@@ -828,11 +1701,18 @@ def main():
         args.dir = current_dir
 
     if args.dir and args.build:
-        PrettyOutput.print(f"正在处理目录: {args.dir}", output_type=OutputType.INFO)
         rag.build_index(args.dir)
         return 0
 
     if args.search or args.ask:
+        # When searching or asking, check the index automatically and build it if needed
+        if not rag.is_index_built():
+            PrettyOutput.print(f"索引未建立,自动为目录 '{args.dir}' 建立索引...", OutputType.INFO)
+            rag.build_index(args.dir)
+
+            if not rag.is_index_built():
+                PrettyOutput.print("索引建立失败,请检查目录和文件格式", OutputType.ERROR)
+                return 1
 
     if args.search:
         results = rag.query(args.search)
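The hunk above makes main() build the index lazily: when a search or ask request arrives and no index exists, it builds one for args.dir and returns an error code if the build still fails. A small sketch of the same ensure-then-query guard follows; RagLike and ensure_index are illustrative stand-ins, not the real RAGTool interface.

```python
# Sketch of the lazy "ensure index before querying" guard added to main() above.
# RagLike is an illustrative stand-in for RAGTool, not the real class.
from typing import List


class RagLike:
    def __init__(self) -> None:
        self.documents: List[str] = []

    def is_index_built(self) -> bool:
        return bool(self.documents)

    def build_index(self, root_dir: str) -> None:
        # the real code scans root_dir; here we just pretend one document was indexed
        self.documents = [f"{root_dir}/example.md"]

    def query(self, text: str) -> List[str]:
        return self.documents


def ensure_index(rag: RagLike, root_dir: str) -> bool:
    """Return True once an index exists, building it on demand like the diff does."""
    if not rag.is_index_built():
        rag.build_index(root_dir)
    return rag.is_index_built()


if __name__ == "__main__":
    rag = RagLike()
    if ensure_index(rag, "."):
        print(rag.query("anything"))
    else:
        print("index could not be built")  # corresponds to `return 1` in main()
```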
@@ -855,7 +1735,7 @@ def main():
             return 1
 
         # Display answer
-        output = f"""
+        output = f"""{response}"""
         PrettyOutput.print(output, output_type=OutputType.INFO)
         return 0
 