jarvis-ai-assistant 0.1.132__py3-none-any.whl → 0.1.138__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of jarvis-ai-assistant might be problematic.
- jarvis/__init__.py +1 -1
- jarvis/jarvis_agent/__init__.py +330 -347
- jarvis/jarvis_agent/builtin_input_handler.py +16 -6
- jarvis/jarvis_agent/file_input_handler.py +9 -9
- jarvis/jarvis_agent/jarvis.py +143 -0
- jarvis/jarvis_agent/main.py +12 -13
- jarvis/jarvis_agent/output_handler.py +3 -3
- jarvis/jarvis_agent/patch.py +92 -64
- jarvis/jarvis_agent/shell_input_handler.py +5 -3
- jarvis/jarvis_code_agent/code_agent.py +263 -177
- jarvis/jarvis_code_agent/file_select.py +24 -24
- jarvis/jarvis_dev/main.py +45 -59
- jarvis/jarvis_git_details/__init__.py +0 -0
- jarvis/jarvis_git_details/main.py +179 -0
- jarvis/jarvis_git_squash/main.py +7 -7
- jarvis/jarvis_lsp/base.py +11 -53
- jarvis/jarvis_lsp/cpp.py +13 -28
- jarvis/jarvis_lsp/go.py +13 -28
- jarvis/jarvis_lsp/python.py +8 -27
- jarvis/jarvis_lsp/registry.py +21 -83
- jarvis/jarvis_lsp/rust.py +15 -30
- jarvis/jarvis_methodology/main.py +101 -0
- jarvis/jarvis_multi_agent/__init__.py +10 -51
- jarvis/jarvis_multi_agent/main.py +43 -0
- jarvis/jarvis_platform/__init__.py +1 -1
- jarvis/jarvis_platform/ai8.py +67 -89
- jarvis/jarvis_platform/base.py +14 -13
- jarvis/jarvis_platform/kimi.py +25 -28
- jarvis/jarvis_platform/ollama.py +24 -26
- jarvis/jarvis_platform/openai.py +15 -19
- jarvis/jarvis_platform/oyi.py +48 -50
- jarvis/jarvis_platform/registry.py +29 -44
- jarvis/jarvis_platform/yuanbao.py +39 -43
- jarvis/jarvis_platform_manager/main.py +81 -81
- jarvis/jarvis_platform_manager/openai_test.py +21 -21
- jarvis/jarvis_rag/file_processors.py +18 -18
- jarvis/jarvis_rag/main.py +262 -278
- jarvis/jarvis_smart_shell/main.py +12 -12
- jarvis/jarvis_tools/ask_codebase.py +85 -78
- jarvis/jarvis_tools/ask_user.py +8 -8
- jarvis/jarvis_tools/base.py +4 -4
- jarvis/jarvis_tools/chdir.py +9 -9
- jarvis/jarvis_tools/code_review.py +40 -21
- jarvis/jarvis_tools/create_code_agent.py +15 -15
- jarvis/jarvis_tools/create_sub_agent.py +0 -1
- jarvis/jarvis_tools/execute_python_script.py +3 -3
- jarvis/jarvis_tools/execute_shell.py +11 -11
- jarvis/jarvis_tools/execute_shell_script.py +3 -3
- jarvis/jarvis_tools/file_analyzer.py +116 -105
- jarvis/jarvis_tools/file_operation.py +22 -20
- jarvis/jarvis_tools/find_caller.py +105 -40
- jarvis/jarvis_tools/find_methodolopy.py +65 -0
- jarvis/jarvis_tools/find_symbol.py +123 -39
- jarvis/jarvis_tools/function_analyzer.py +140 -57
- jarvis/jarvis_tools/git_commiter.py +10 -10
- jarvis/jarvis_tools/lsp_get_diagnostics.py +19 -19
- jarvis/jarvis_tools/methodology.py +22 -67
- jarvis/jarvis_tools/project_analyzer.py +137 -53
- jarvis/jarvis_tools/rag.py +15 -20
- jarvis/jarvis_tools/read_code.py +25 -23
- jarvis/jarvis_tools/read_webpage.py +31 -31
- jarvis/jarvis_tools/registry.py +72 -52
- jarvis/jarvis_tools/search_web.py +23 -353
- jarvis/jarvis_tools/tool_generator.py +19 -19
- jarvis/jarvis_utils/config.py +36 -96
- jarvis/jarvis_utils/embedding.py +83 -83
- jarvis/jarvis_utils/git_utils.py +20 -20
- jarvis/jarvis_utils/globals.py +18 -6
- jarvis/jarvis_utils/input.py +10 -9
- jarvis/jarvis_utils/methodology.py +141 -140
- jarvis/jarvis_utils/output.py +13 -13
- jarvis/jarvis_utils/utils.py +23 -71
- {jarvis_ai_assistant-0.1.132.dist-info → jarvis_ai_assistant-0.1.138.dist-info}/METADATA +6 -15
- jarvis_ai_assistant-0.1.138.dist-info/RECORD +85 -0
- {jarvis_ai_assistant-0.1.132.dist-info → jarvis_ai_assistant-0.1.138.dist-info}/entry_points.txt +4 -3
- jarvis/jarvis_tools/lsp_find_definition.py +0 -150
- jarvis/jarvis_tools/lsp_find_references.py +0 -127
- jarvis/jarvis_tools/select_code_files.py +0 -62
- jarvis_ai_assistant-0.1.132.dist-info/RECORD +0 -82
- {jarvis_ai_assistant-0.1.132.dist-info → jarvis_ai_assistant-0.1.138.dist-info}/LICENSE +0 -0
- {jarvis_ai_assistant-0.1.132.dist-info → jarvis_ai_assistant-0.1.138.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.1.132.dist-info → jarvis_ai_assistant-0.1.138.dist-info}/top_level.txt +0 -0
jarvis/jarvis_rag/main.py
CHANGED

@@ -13,13 +13,12 @@ import lzma # 添加 lzma 导入
 from threading import Lock
 import hashlib
 
-from jarvis.jarvis_utils.config import get_max_paragraph_length, get_max_token_count, get_min_paragraph_length,
+from jarvis.jarvis_utils.config import get_max_paragraph_length, get_max_token_count, get_min_paragraph_length, get_rag_ignored_paths
 from jarvis.jarvis_utils.embedding import get_context_token_count, get_embedding, get_embedding_batch, load_embedding_model, rerank_results
 from jarvis.jarvis_utils.output import OutputType, PrettyOutput
-from jarvis.jarvis_utils.utils import ct, get_file_md5, init_env,
-
-from .file_processors import TextFileProcessor, PDFProcessor, DocxProcessor, PPTProcessor, ExcelProcessor
+from jarvis.jarvis_utils.utils import ct, get_file_md5, init_env, ot
 
+from jarvis.jarvis_rag.file_processors import TextFileProcessor, PDFProcessor, DocxProcessor, PPTProcessor, ExcelProcessor
 """
 Jarvis RAG (Retrieval-Augmented Generation) Module
 

@@ -118,34 +117,19 @@ class RAGTool:
             spinner.ok("✅")
 
 
-        # Add thread related configuration
-        with yaspin(text="初始化线程配置...", color="cyan") as spinner:
-            self.thread_count = get_thread_count()
-            self.vector_lock = Lock() # Protect vector list concurrency
-            spinner.text = "线程配置初始化完成"
-            spinner.ok("✅")
-
-        # 初始化 GPU 内存配置
-        with yaspin(text="初始化 GPU 内存配置...", color="cyan") as spinner:
-            with spinner.hidden():
-                self.gpu_config = init_gpu_config()
-            spinner.text = "GPU 内存配置初始化完成"
-            spinner.ok("✅")
-
-
     def _get_cache_path(self, file_path: str, cache_type: str = "doc") -> str:
         """Get cache file path for a document
-
+
         Args:
             file_path: Original file path
             cache_type: Type of cache ("doc" for documents, "vec" for vectors)
-
+
         Returns:
             str: Cache file path
         """
         # 使用文件路径的哈希作为缓存文件名
         file_hash = hashlib.md5(file_path.encode()).hexdigest()
-
+
         # 确保不同类型的缓存有不同的目录
         if cache_type == "doc":
             cache_subdir = os.path.join(self.cache_dir, "documents")

All other hunks in this file (@@ -50,7 +49,7 @@, @@ -60,7 +59,7 @@, @@ -69,7 +68,7 @@, @@ -77,7 +76,7 @@, @@ -94,14 +93,14 @@, and every hunk from @@ -153,11 +137,11 @@ through @@ -1733,7 +1717,7 @@) contain only whitespace normalization: trailing spaces are stripped from blank lines and from a handful of wrapped continuation lines (multi-line PrettyOutput.print(...) calls, list comprehensions, and sorted(...)/next(...) arguments split across lines). They introduce no functional changes. The two import lines removed in the first hunk appear truncated (trailing commas) in the registry's diff rendering.