iflow-mcp_pingcy_app_chatppt 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1017 @@
1
+ """
2
+ 多文档RAG引擎 - 支持多个PPT文档的统一索引管理
3
+ 整体架构:创建空index -> 不断增加多个ppt到index -> index缓存 -> 引擎初始化自动从缓存加载index用于查询
4
+ """
5
+
6
+ import os
7
+ import chromadb
8
+ import pickle
9
+ from pathlib import Path
10
+ from typing import List, Dict, Optional, Any, Set
11
+ import asyncio
12
+ from concurrent.futures import ThreadPoolExecutor
13
+ import logging
14
+ import hashlib
15
+
16
+ # LlamaIndex 导入
17
+ from llama_index.core import Settings, VectorStoreIndex, StorageContext, load_index_from_storage
18
+ from llama_index.core.schema import TextNode, NodeWithScore, MetadataMode
19
+ from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever
20
+ from llama_index.core.query_engine import CustomQueryEngine
21
+ from llama_index.core.prompts import PromptTemplate
22
+ from llama_index.core.base.response.schema import Response
23
+ from llama_index.core.vector_stores.types import MetadataFilter, MetadataFilters, FilterOperator
24
+ from llama_index.vector_stores.chroma import ChromaVectorStore
25
+ from llama_index.embeddings.openai import OpenAIEmbedding
26
+
27
+ from doubao import DoubaoVisionLLM
28
+ from ppt_utils import PPTUtils
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+ class MultimodalQueryEngine(CustomQueryEngine):
33
+ """结合文本检索和图像分析的自定义查询引擎。"""
34
+
35
+ def __init__(
36
+ self,
37
+ retriever: BaseRetriever,
38
+ doubao_llm: DoubaoVisionLLM,
39
+ qa_prompt: Optional[PromptTemplate] = None
40
+ ):
41
+ # 默认问答提示模板
42
+ default_prompt = """以下是PPT幻灯片中解析的Markdown文本和图片信息。Markdown文本已经尝试将相关图表转换为表格。
43
+ 优先使用图片信息来回答问题。在无法理解图像时才使用Markdown文本信息。
44
+
45
+ ---------------------
46
+ {context_str}
47
+ ---------------------
48
+
49
+ -- 根据上下文信息并且不依赖先验知识, 回答查询。
50
+ -- 解释你是从解析的markdown、还是图片中得到答案的, 如果有差异, 请说明最终答案的理由。
51
+ -- 尽可能详细的回答问题。
52
+ -- 给出你重点参考的图片路径和页码。
53
+
54
+ 查询: {query_str}
55
+ 答案: """
56
+
57
+ final_qa_prompt = qa_prompt or PromptTemplate(default_prompt)
58
+
59
+ # 调用父类构造函数
60
+ super().__init__()
61
+
62
+ # 设置属性
63
+ self._retriever = retriever
64
+ self._doubao_llm = doubao_llm
65
+ self._qa_prompt = final_qa_prompt
66
+
67
+ def custom_query(self, query_str: str) -> Response:
68
+ """执行具有多模态理解的查询。"""
69
+ # 检索相关节点
70
+ nodes = self._retriever.retrieve(query_str)
71
+
72
+ if not nodes:
73
+ return Response(
74
+ response="抱歉,没有找到相关的PPT内容来回答您的问题。",
75
+ source_nodes=[],
76
+ metadata={}
77
+ )
78
+
79
+ # 从文本节点创建上下文字符串,包含文档名称
80
+ context_str = "\n\n".join([
81
+ f"文档: {Path(node.metadata['source']).name}, 页面 {node.metadata['page_num']}: {node.get_content(metadata_mode=MetadataMode.LLM)}\n"
82
+ f"来源图片: {node.metadata['image_path']}"
83
+ for node in nodes
84
+ ])
85
+
86
+ # 格式化提示
87
+ fmt_prompt = self._qa_prompt.format(
88
+ context_str=context_str,
89
+ query_str=query_str
90
+ )
91
+
92
+ # 获取图片路径用于视觉分析
93
+ image_paths = [node.metadata["image_path"] for node in nodes]
94
+
95
+ # 使用豆包视觉LLM生成回答
96
+ try:
97
+ response_text = self._doubao_llm.generate_response(
98
+ prompt=fmt_prompt,
99
+ image_paths=image_paths
100
+ )
101
+ except Exception as e:
102
+ response_text = f"生成回答时出现错误: {str(e)}"
103
+
104
+ return Response(
105
+ response=response_text,
106
+ source_nodes=nodes,
107
+ metadata={
108
+ "num_sources": len(nodes),
109
+ "image_paths": image_paths,
110
+ "source_documents": list(set(node.metadata["source"] for node in nodes))
111
+ }
112
+ )
113
+
114
+
115
+ class MultiDocRAGEngine:
116
+ """
117
+ 多文档RAG引擎:
118
+ - 使用单个统一的向量索引管理多个PPT文档
119
+ - 支持增量添加文档到索引
120
+ - 支持从索引中删除特定文档
121
+ - 自动索引缓存和加载
122
+ - 支持全文档检索和文档特定检索
123
+ """
124
+
125
+ def __init__(
126
+ self,
127
+ persist_dir: str = "./.multi_doc_chroma_db",
128
+ cache_dir: str = "./.cache",
129
+ collection_name: str = "multi_ppt_documents",
130
+ embedding_model: str = "text-embedding-3-small",
131
+ doubao_model: str = "ep-20250205153642-hzqpj",
132
+ top_k: int = 3
133
+ ):
134
+ self.persist_dir = Path(persist_dir)
135
+ self.persist_dir.mkdir(parents=True, exist_ok=True)
136
+ self.collection_name = collection_name
137
+ self.top_k = top_k
138
+
139
+ # 索引存储路径
140
+ self.index_storage_dir = self.persist_dir / "unified_index"
141
+ self.index_storage_dir.mkdir(parents=True, exist_ok=True)
142
+
143
+ # 节点缓存目录
144
+ self.node_cache_dir = Path(cache_dir) / "multi_doc_nodes"
145
+ self.node_cache_dir.mkdir(parents=True, exist_ok=True)
146
+
147
+ # Markdown缓存目录
148
+ self.markdown_cache_dir = Path(cache_dir) / "parsed_markdown"
149
+ self.markdown_cache_dir.mkdir(parents=True, exist_ok=True)
150
+
151
+ # 初始化PPT处理工具
152
+ self.ppt_utils = PPTUtils(cache_dir=cache_dir)
153
+
154
+ # 初始化嵌入模型(始终使用OpenAI)
155
+ self.embed_model = OpenAIEmbedding(model=embedding_model)
156
+ Settings.embed_model = self.embed_model
157
+ logger.info(f"正在使用OpenAI嵌入模型: {embedding_model}")
158
+
159
+ # 初始化豆包视觉LLM
160
+ self.doubao_llm = DoubaoVisionLLM(model_name=doubao_model)
161
+
162
+ # 文档元数据缓存路径
163
+ self.docs_metadata_path = self.persist_dir / "docs_metadata.pkl"
164
+
165
+ # 向量存储和索引
166
+ self._vector_store = None
167
+ self._index = None
168
+ self._query_engine = None
169
+
170
+ # 图像解析提示
171
+ self.parse_prompt = """用中文提取图片中的详细信息,并使用Markdown格式化输出。
172
+ -- 对于其中的文字,使用OCR识别,并尽量保持原格式或类似格式输出。
173
+ -- 对于其中的表格与统计图表信息,选择表格结合文字的方式进行描述。
174
+ -- 对于其中的图形、图表、流程图等视觉元素,请用文字详细描述其内容和布局。
175
+ -- 对于其他有意义的图像部分,请使用文字描述。
176
+ -- 合理排版,使得输出内容清晰易懂。"""
177
+
178
+ # 初始化时自动加载索引
179
+ self._load_index()
180
+
181
+ def _initialize_vector_store(self):
182
+ """初始化ChromaDB向量存储。"""
183
+ if self._vector_store is None:
184
+ # 初始化ChromaDB客户端
185
+ chroma_client = chromadb.PersistentClient(path=str(self.persist_dir))
186
+ chroma_collection = chroma_client.get_or_create_collection(self.collection_name)
187
+
188
+ # 创建向量存储
189
+ self._vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
190
+ logger.info(f"已初始化ChromaDB向量存储: {self.collection_name}")
191
+
192
+ return self._vector_store
193
+
194
+ def _load_index(self):
195
+ """自动加载缓存的索引(如果存在)。"""
196
+ try:
197
+ if self.index_storage_dir.exists() and (self.index_storage_dir / "docstore.json").exists():
198
+ logger.info("正在加载缓存的统一索引")
199
+
200
+ # 初始化向量存储
201
+ vector_store = self._initialize_vector_store()
202
+
203
+ # 加载索引
204
+ storage_context = StorageContext.from_defaults(
205
+ persist_dir=str(self.index_storage_dir),
206
+ vector_store=vector_store
207
+ )
208
+
209
+ self._index = load_index_from_storage(storage_context=storage_context)
210
+ logger.info("成功加载缓存的统一索引")
211
+ else:
212
+ logger.info("未找到缓存索引,将按需创建空索引")
213
+
214
+ except Exception as e:
215
+ logger.warning(f"加载缓存索引失败: {e},将创建新索引")
216
+ self._index = None
217
+
218
+ def _ensure_index_exists(self):
219
+ """确保索引存在,如果不存在则创建空索引。"""
220
+ if self._index is None:
221
+ logger.info("正在创建新的空索引")
222
+
223
+ # 初始化向量存储
224
+ vector_store = self._initialize_vector_store()
225
+ storage_context = StorageContext.from_defaults(vector_store=vector_store)
226
+
227
+ # 创建空索引(没有节点)
228
+ self._index = VectorStoreIndex(
229
+ [],
230
+ storage_context=storage_context,
231
+ show_progress=False
232
+ )
233
+
234
+ # 立即持久化空索引
235
+ self._persist_index()
236
+ logger.info("已创建并持久化空索引")
237
+
238
+ def _persist_index(self):
239
+ """持久化索引到存储。"""
240
+ if self._index is not None:
241
+ self._index.storage_context.persist(persist_dir=str(self.index_storage_dir))
242
+ logger.info(f"索引已持久化到: {self.index_storage_dir}")
243
+
244
+ def _get_node_cache_path(self, ppt_path: str) -> Path:
245
+ """获取PPT文件的文本节点缓存路径。"""
246
+ file_hash = self._get_file_hash(ppt_path)
247
+ return self.node_cache_dir / f"{file_hash}_nodes.pkl"
248
+
249
+ def _get_markdown_cache_path(self, ppt_path: str) -> Path:
250
+ """获取PPT文件的Markdown缓存路径。"""
251
+ file_hash = self._get_file_hash(ppt_path)
252
+ ppt_name = Path(ppt_path).stem
253
+ return self.markdown_cache_dir / f"{ppt_name}_{file_hash}.md"
254
+
255
+ def _get_file_hash(self, file_path: str) -> str:
256
+ """计算文件的哈希值用于缓存。"""
257
+ with open(file_path, 'rb') as f:
258
+ return hashlib.md5(f.read()).hexdigest()
259
+
260
+ def _save_docs_metadata(self, docs_info: Dict[str, Dict]):
261
+ """保存文档元数据(以doc_id为key)。"""
262
+ with open(self.docs_metadata_path, 'wb') as f:
263
+ pickle.dump(docs_info, f)
264
+
265
+ def _load_docs_metadata(self) -> Dict[str, Dict]:
266
+ """加载文档元数据(以doc_id为key)。"""
267
+ if self.docs_metadata_path.exists():
268
+ try:
269
+ with open(self.docs_metadata_path, 'rb') as f:
270
+ return pickle.load(f)
271
+ except Exception as e:
272
+ logger.warning(f"加载文档元数据失败: {e}")
273
+ return {}
274
+
275
+ def _get_doc_id_from_path(self, ppt_path: str) -> str:
276
+ """从文件路径获取文档ID(hash值)。"""
277
+ return self._get_file_hash(ppt_path)
278
+
279
+ def _find_doc_by_id(self, doc_id: str) -> Optional[Dict[str, Any]]:
280
+ """根据doc_id查找文档信息。"""
281
+ docs_info = self._load_docs_metadata()
282
+ return docs_info.get(doc_id)
283
+
284
+ def is_document_indexed(self, ppt_path: str) -> bool:
285
+ """
286
+ 检查文档是否已经在索引中(推荐方法)。
287
+
288
+ Args:
289
+ ppt_path: PPT文件路径
290
+
291
+ Returns:
292
+ 文档是否已索引
293
+ """
294
+ try:
295
+ doc_id = self._get_doc_id_from_path(ppt_path)
296
+ return self._find_doc_by_id(doc_id) is not None
297
+ except Exception:
298
+ return False
299
+
300
+ def is_document_id_indexed(self, doc_id: str) -> bool:
301
+ """
302
+ 检查文档ID是否已经在索引中。
303
+
304
+ Args:
305
+ doc_id: 文档ID
306
+
307
+ Returns:
308
+ 文档是否已索引
309
+ """
310
+ return self._find_doc_by_id(doc_id) is not None
311
+
312
+ async def _parse_image_async(self, image_path: str, page_num: int) -> str:
313
+ """使用豆包视觉LLM异步解析单张图片。"""
314
+ try:
315
+ content = await asyncio.to_thread(
316
+ self.doubao_llm.generate_response,
317
+ self.parse_prompt,
318
+ [image_path]
319
+ )
320
+ logger.info(f"已解析第{page_num}页图片: {len(content)} 个字符")
321
+ return content
322
+ except Exception as e:
323
+ logger.error(f"解析图片失败 {image_path}: {e}")
324
+ return f"解析图片失败: {str(e)}"
325
+
326
+ async def _save_parsed_markdown(self, ppt_path: str, image_paths: List[str], parsed_contents: List[str]) -> None:
327
+ """将解析的内容保存到Markdown文件中。"""
328
+ try:
329
+ markdown_path = self._get_markdown_cache_path(ppt_path)
330
+ ppt_name = Path(ppt_path).name
331
+
332
+ # 创建Markdown内容
333
+ markdown_content = []
334
+ markdown_content.append(f"# {ppt_name} - 视觉解析结果")
335
+ markdown_content.append(f"\n生成时间: {Path(ppt_path).stat().st_mtime}")
336
+ markdown_content.append(f"总页数: {len(image_paths)}")
337
+ markdown_content.append(f"文档路径: {ppt_path}")
338
+ markdown_content.append("\n" + "="*80 + "\n")
339
+
340
+ # 添加每页的解析结果
341
+ for i, (image_path, content) in enumerate(zip(image_paths, parsed_contents), 1):
342
+ markdown_content.append(f"## 第 {i} 页")
343
+ markdown_content.append(f"**图片路径:** `{image_path}`")
344
+ markdown_content.append(f"**内容长度:** {len(content)} 字符")
345
+ markdown_content.append("\n### 解析内容:")
346
+ markdown_content.append(content)
347
+ markdown_content.append("\n" + "-"*50 + "\n")
348
+
349
+ # 写入文件
350
+ full_content = "\n".join(markdown_content)
351
+ with open(markdown_path, 'w', encoding='utf-8') as f:
352
+ f.write(full_content)
353
+
354
+ logger.info(f"已保存解析的Markdown到: {markdown_path}")
355
+
356
+ except Exception as e:
357
+ logger.error(f"保存Markdown解析结果失败 {ppt_path}: {e}")
358
+
359
+ async def _create_text_nodes(self, ppt_path: str, image_paths: List[str]) -> List[TextNode]:
360
+ """从图片创建文本节点,包含关联的元数据。"""
361
+ logger.info(f"正在为{len(image_paths)}张图片创建文本节点")
362
+
363
+ # 检查API密钥是否配置
364
+ openai_key = os.getenv("OPENAI_API_KEY", "")
365
+ ark_key = os.getenv("ARK_API_KEY", "")
366
+
367
+ if not openai_key or not ark_key:
368
+ raise RuntimeError("必须配置OPENAI_API_KEY和ARK_API_KEY环境变量")
369
+
370
+ if openai_key.startswith(("test", "sk-test")) or ark_key.startswith("test"):
371
+ raise RuntimeError("API密钥不能是测试密钥,必须使用真实的API密钥")
372
+
373
+ # 并发解析所有图片
374
+ parse_tasks = [
375
+ self._parse_image_async(image_path, i + 1)
376
+ for i, image_path in enumerate(image_paths)
377
+ ]
378
+
379
+ parsed_contents = await asyncio.gather(*parse_tasks)
380
+
381
+ # 保存解析结果到Markdown文件
382
+ await self._save_parsed_markdown(ppt_path, image_paths, parsed_contents)
383
+
384
+ # 创建文本节点,添加文档唯一标识符
385
+ doc_id = self._get_file_hash(ppt_path)
386
+ nodes = []
387
+ for i, (image_path, content) in enumerate(zip(image_paths, parsed_contents)):
388
+ node = TextNode(
389
+ text=content,
390
+ metadata={
391
+ "source": ppt_path,
392
+ "source_file_id": doc_id, # 文档唯一标识符
393
+ "doc_name": Path(ppt_path).name, # 文档名称
394
+ "page_num": i + 1,
395
+ "image_path": image_path,
396
+ "doc_type": "ppt_slide"
397
+ }
398
+ )
399
+ nodes.append(node)
400
+
401
+ logger.info(f"已创建{len(nodes)}个文本节点,文档ID: {doc_id}")
402
+ return nodes
403
+
404
+ async def add_ppt_document(self, ppt_path: str, force_reprocess: bool = False) -> Dict[str, Any]:
405
+ """
406
+ 向统一索引中添加PPT文档。
407
+
408
+ Args:
409
+ ppt_path: PPT文件路径
410
+ force_reprocess: 强制重新处理,即使已缓存
411
+
412
+ Returns:
413
+ 添加结果字典
414
+ """
415
+ ppt_path = str(Path(ppt_path).resolve())
416
+
417
+ if not Path(ppt_path).exists():
418
+ raise FileNotFoundError(f"PPT文件未找到:{ppt_path}")
419
+
420
+ # 检查文档是否已经在索引中 - 直接用doc_id检查
421
+ doc_id = self._get_doc_id_from_path(ppt_path)
422
+ existing_doc_info = self._find_doc_by_id(doc_id)
423
+
424
+ if not force_reprocess and existing_doc_info is not None:
425
+ return {
426
+ "status": "skipped",
427
+ "message": "文档已存在于索引中",
428
+ "source": ppt_path,
429
+ "doc_id": doc_id
430
+ }
431
+
432
+ logger.info(f"正在添加PPT文档到统一索引: {ppt_path}")
433
+
434
+ # 确保索引存在
435
+ self._ensure_index_exists()
436
+
437
+ # 获取节点缓存路径
438
+ node_cache_path = self._get_node_cache_path(ppt_path)
439
+
440
+ try:
441
+ # 检查是否有缓存的节点
442
+ if not force_reprocess and node_cache_path.exists():
443
+ logger.info(f"正在使用缓存的节点:{ppt_path}")
444
+ with open(node_cache_path, 'rb') as f:
445
+ nodes = pickle.load(f)
446
+ else:
447
+ # 完整处理流程
448
+ logger.info("正在通过完整流程处理PPT")
449
+
450
+ # 阶段1:PPT → 图片
451
+ image_paths = self.ppt_utils.ppt_to_images(ppt_path)
452
+ if not image_paths:
453
+ raise RuntimeError("PPT中未生成任何图片")
454
+ logger.info(f"从PPT生成了{len(image_paths)}张图片")
455
+
456
+ # 阶段2:图片 → 文本节点
457
+ nodes = await self._create_text_nodes(ppt_path, image_paths)
458
+
459
+ # 缓存节点
460
+ with open(node_cache_path, 'wb') as f:
461
+ pickle.dump(nodes, f)
462
+ logger.info(f"已缓存{len(nodes)}个节点")
463
+
464
+ logger.info(f"{nodes}")
465
+ # 阶段3:将节点添加到统一索引
466
+ logger.info(f"正在将{len(nodes)}个节点添加到统一索引")
467
+ self._index.insert_nodes(nodes)
468
+
469
+ # 持久化索引
470
+ self._persist_index()
471
+
472
+ # 更新文档元数据(使用doc_id作为key)
473
+ docs_info = self._load_docs_metadata()
474
+ docs_info[doc_id] = {
475
+ "file_path": ppt_path, # 文件路径作为字段
476
+ "doc_name": Path(ppt_path).name,
477
+ "pages": len(nodes),
478
+ "added_at": str(Path(ppt_path).stat().st_mtime),
479
+ "file_size": Path(ppt_path).stat().st_size
480
+ }
481
+ self._save_docs_metadata(docs_info)
482
+
483
+ # 重置查询引擎
484
+ self._query_engine = None
485
+
486
+ logger.info(f"成功将PPT添加到统一索引:{ppt_path}")
487
+
488
+ return {
489
+ "status": "success",
490
+ "message": "PPT已成功添加到统一索引",
491
+ "pages": len(nodes),
492
+ "source": ppt_path
493
+ }
494
+
495
+ except Exception as e:
496
+ logger.error(f"添加PPT失败 {ppt_path}: {e}")
497
+ return {
498
+ "status": "error",
499
+ "message": f"添加PPT失败:{str(e)}",
500
+ "source": ppt_path
501
+ }
502
+
503
+ def remove_ppt_document(self, ppt_path: str) -> Dict[str, Any]:
504
+ """
505
+ 从统一索引中删除PPT文档。
506
+
507
+ Args:
508
+ ppt_path: PPT文件路径
509
+
510
+ Returns:
511
+ 删除结果字典
512
+ """
513
+ ppt_path = str(Path(ppt_path).resolve())
514
+
515
+ if self._index is None:
516
+ return {
517
+ "status": "error",
518
+ "message": "索引不存在",
519
+ "source": ppt_path
520
+ }
521
+
522
+ try:
523
+ # 通过文件路径计算doc_id,然后查找文档
524
+ doc_id = self._get_doc_id_from_path(ppt_path)
525
+ doc_info = self._find_doc_by_id(doc_id)
526
+
527
+ if doc_info is None:
528
+ return {
529
+ "status": "error",
530
+ "message": "文档未在索引中找到",
531
+ "source": ppt_path,
532
+ "doc_id": doc_id
533
+ }
534
+
535
+ # 从向量存储中删除具有指定doc_id的节点
536
+ vector_store = self._index.vector_store
537
+
538
+ if hasattr(vector_store, '_collection'):
539
+
540
+ # 直接从ChromaDB集合中删除
541
+ collection = vector_store._collection
542
+ try:
543
+ collection.delete(where={"source_file_id": doc_id})
544
+ deleted_count = "unknown" # ChromaDB delete不返回计数
545
+ except Exception as e:
546
+ logger.error(f"从ChromaDB删除失败:{e}")
547
+ return {
548
+ "status": "error",
549
+ "message": f"从向量存储删除失败:{str(e)}",
550
+ "source": ppt_path
551
+ }
552
+ else:
553
+ logger.warning("无法从向量存储删除 - 不支持的操作")
554
+ return {
555
+ "status": "error",
556
+ "message": "向量存储不支持删除操作",
557
+ "source": ppt_path
558
+ }
559
+
560
+ # 更新文档元数据(删除对应的doc_id)
561
+ docs_info = self._load_docs_metadata()
562
+ if doc_id in docs_info:
563
+ del docs_info[doc_id]
564
+ self._save_docs_metadata(docs_info)
565
+
566
+ # 删除节点缓存
567
+ node_cache_path = self._get_node_cache_path(ppt_path)
568
+ if node_cache_path.exists():
569
+ node_cache_path.unlink()
570
+ logger.info(f"已删除节点缓存:{node_cache_path}")
571
+
572
+ # 删除Markdown缓存
573
+ markdown_cache_path = self._get_markdown_cache_path(ppt_path)
574
+ if markdown_cache_path.exists():
575
+ markdown_cache_path.unlink()
576
+ logger.info(f"已删除Markdown缓存:{markdown_cache_path}")
577
+
578
+ # 重新加载索引以确保一致性
579
+ logger.info("正在重新加载索引以确保一致性")
580
+ self._index = None
581
+ self._load_index()
582
+
583
+ # 重置查询引擎
584
+ self._query_engine = None
585
+
586
+ logger.info(f"成功从统一索引中删除PPT:{ppt_path}")
587
+
588
+ return {
589
+ "status": "success",
590
+ "message": f"PPT已成功从统一索引中移除",
591
+ "source": ppt_path,
592
+ "deleted_nodes": deleted_count
593
+ }
594
+
595
+ except Exception as e:
596
+ logger.error(f"删除PPT失败 {ppt_path}: {e}")
597
+ return {
598
+ "status": "error",
599
+ "message": f"删除PPT失败:{str(e)}",
600
+ "source": ppt_path
601
+ }
602
+
603
+ def get_document_info(self, doc_id: Optional[str] = None) -> Dict[str, Any]:
604
+ """
605
+ 获取文档信息。
606
+
607
+ Args:
608
+ doc_id: 特定文档ID
609
+ 如果两为None则返回所有文档信息
610
+
611
+ Returns:
612
+ 文档信息字典(以doc_id为key)
613
+ """
614
+ docs_info = self._load_docs_metadata()
615
+
616
+ if doc_id is not None:
617
+
618
+ # 通过doc_id查询
619
+ if doc_id in docs_info:
620
+ return {doc_id: docs_info[doc_id]}
621
+ else:
622
+ return {}
623
+
624
+ # 返回所有文档信息
625
+ return docs_info
626
+
627
+ def _get_query_engine(self, file_path: Optional[str] = None, doc_id: Optional[str] = None) -> MultimodalQueryEngine:
628
+ """获取或创建多模态查询引擎,支持文档过滤。
629
+
630
+ Args:
631
+ file_path: 通过文件路径过滤文档
632
+ doc_id: 通过文档ID过滤文档
633
+ 注意:file_path和doc_id只能传入其中一个,如果都传入则优先使用doc_id
634
+
635
+ Returns:
636
+ 配置好的多模态查询引擎
637
+ """
638
+ if self._index is None:
639
+ raise RuntimeError("尚未索引任何文档。请先添加PPT文档。")
640
+
641
+ # 确定过滤的doc_id
642
+ filter_doc_id = None
643
+
644
+ if doc_id is not None and file_path is not None:
645
+ logger.warning("同时提供了doc_id和file_path,将使用doc_id")
646
+
647
+ if doc_id is not None:
648
+ # 直接使用提供的doc_id
649
+ filter_doc_id = doc_id
650
+ # 验证doc_id是否存在
651
+ if self._find_doc_by_id(filter_doc_id) is None:
652
+ raise ValueError(f"未找到文档ID:{filter_doc_id}")
653
+
654
+ elif file_path is not None:
655
+ # 通过文件路径计算doc_id
656
+ try:
657
+ ppt_path = str(Path(file_path).resolve())
658
+ filter_doc_id = self._get_doc_id_from_path(ppt_path)
659
+ # 验证doc_id是否存在
660
+ if self._find_doc_by_id(filter_doc_id) is None:
661
+ raise ValueError(f"文件路径对应的文档未找到:{file_path}")
662
+ except Exception as e:
663
+ raise ValueError(f"无效的文件路径:{file_path},错误:{str(e)}")
664
+
665
+ # 创建检索器
666
+ if filter_doc_id is not None:
667
+ # 创建带文档过滤的检索器
668
+ filters = MetadataFilters(
669
+ filters=[
670
+ MetadataFilter(
671
+ key="source_file_id",
672
+ value=filter_doc_id,
673
+ operator=FilterOperator.EQ
674
+ )
675
+ ]
676
+ )
677
+ retriever = VectorIndexRetriever(
678
+ index=self._index,
679
+ similarity_top_k=self.top_k,
680
+ filters=filters
681
+ )
682
+ logger.info(f"已创建过滤检索器,文档ID:{filter_doc_id}")
683
+ else:
684
+ # 创建通用检索器(无过滤)
685
+ retriever = self._index.as_retriever(similarity_top_k=self.top_k)
686
+ logger.info("已创建通用检索器(无文档过滤)")
687
+
688
+ # 创建多模态查询引擎
689
+ query_engine = MultimodalQueryEngine(
690
+ retriever=retriever,
691
+ doubao_llm=self.doubao_llm
692
+ )
693
+
694
+ return query_engine
695
+
696
+ async def query(self, query: str, file_path: Optional[str] = None, doc_id: Optional[str] = None) -> Dict[str, Any]:
697
+ """
698
+ 查询所有文档或特定文档。
699
+
700
+ Args:
701
+ query: 用户查询
702
+ file_path: 可选的文档文件路径过滤器,如果指定则只在该文档中搜索
703
+ doc_id: 可选的文档ID过滤器,如果指定则只在该文档中搜索
704
+ 注意:file_path和doc_id只能传入其中一个,如果都传入则优先使用doc_id
705
+
706
+ Returns:
707
+ 查询结果字典
708
+ """
709
+ try:
710
+ # 记录过滤参数
711
+ filter_info = []
712
+ if doc_id:
713
+ filter_info.append(f"doc_id={doc_id}")
714
+ if file_path:
715
+ filter_info.append(f"file_path={file_path}")
716
+
717
+ filter_desc = ", ".join(filter_info) if filter_info else "无过滤器"
718
+ logger.info(f"正在使用{filter_desc}进行查询")
719
+
720
+ # 获取查询引擎
721
+ query_engine = self._get_query_engine(file_path=file_path, doc_id=doc_id)
722
+
723
+ # 执行查询
724
+ response = await asyncio.to_thread(query_engine.query, query)
725
+
726
+ # 提取信息
727
+ result = {
728
+ "status": "success",
729
+ "query": query,
730
+ "file_path": file_path,
731
+ "doc_id": doc_id,
732
+ "answer": response.response,
733
+ "sources": [
734
+ {
735
+ "doc_name": node.metadata.get("doc_name"),
736
+ "doc_id": node.metadata.get("source_file_id"),
737
+ "page_num": node.metadata.get("page_num"),
738
+ "image_path": node.metadata.get("image_path"),
739
+ "source": node.metadata.get("source")
740
+ }
741
+ for node in response.source_nodes
742
+ ],
743
+ "metadata": response.metadata or {}
744
+ }
745
+
746
+ logger.info(f"使用{len(response.source_nodes)}个来源生成了答案")
747
+ return result
748
+
749
+ except Exception as e:
750
+ logger.error(f"查询失败:{e}")
751
+ return {
752
+ "status": "error",
753
+ "query": query,
754
+ "file_path": file_path,
755
+ "doc_id": doc_id,
756
+ "message": f"查询失败:{str(e)}"
757
+ }
758
+
759
+ def get_index_status(self) -> Dict[str, Any]:
760
+ """获取索引状态。"""
761
+ try:
762
+ docs_info = self.get_document_info()
763
+
764
+ total_pages = sum(info.get("pages", 0) for info in docs_info.values())
765
+ indexed_docs = [info["file_path"] for info in docs_info.values()]
766
+
767
+ if self._index is None:
768
+ status = "empty"
769
+ elif not docs_info:
770
+ status = "empty"
771
+ else:
772
+ status = "ready"
773
+
774
+ return {
775
+ "status": status,
776
+ "total_documents": len(docs_info),
777
+ "total_pages": total_pages,
778
+ "documents": indexed_docs,
779
+ "document_details": docs_info,
780
+ "collection_name": self.collection_name,
781
+ "index_path": str(self.index_storage_dir)
782
+ }
783
+
784
+ except Exception as e:
785
+ return {
786
+ "status": "error",
787
+ "message": str(e)
788
+ }
789
+
790
+ def print_vectorstore_info(self) -> None:
791
+ """
792
+ 打印向量存储中集合的详细信息。
793
+ 包括节点数量、文档数量、文档ID和路径等信息。
794
+ """
795
+ try:
796
+ print("\n🔍 向量存储集合信息")
797
+ print("=" * 60)
798
+
799
+ # 确保向量存储已初始化
800
+ if self._vector_store is None:
801
+ vector_store = self._initialize_vector_store()
802
+ else:
803
+ vector_store = self._vector_store
804
+
805
+ # 获取ChromaDB集合
806
+ if hasattr(vector_store, '_collection'):
807
+ collection = vector_store._collection
808
+
809
+ # 获取集合基本信息
810
+ collection_name = collection.name
811
+ print(f"📚 集合名称: {collection_name}")
812
+
813
+ # 获取所有文档
814
+ try:
815
+ result = collection.get(include=['metadatas', 'documents'])
816
+
817
+ if not result['ids']:
818
+ print("📭 集合为空,没有任何节点")
819
+ return
820
+
821
+ total_nodes = len(result['ids'])
822
+ print(f"📄 总节点数: {total_nodes}")
823
+
824
+ # 分析文档信息
825
+ doc_stats = {}
826
+ for i, metadata in enumerate(result['metadatas']):
827
+ if metadata:
828
+ doc_id = metadata.get('source_file_id', 'unknown')
829
+ doc_name = metadata.get('doc_name', 'unknown')
830
+ page_num = metadata.get('page_num', 0)
831
+ print(f"🔍 处理文档 {i + 1}/{len(result['metadatas'])}: {doc_name} (ID: {doc_id}, 页面: {page_num})")
832
+
833
+ if doc_id not in doc_stats:
834
+ doc_stats[doc_id] = {
835
+ 'doc_name': doc_name,
836
+ 'pages': set(),
837
+ 'node_count': 0
838
+ }
839
+
840
+ doc_stats[doc_id]['pages'].add(page_num)
841
+ doc_stats[doc_id]['node_count'] += 1
842
+
843
+ print(f"📊 包含文档数: {len(doc_stats)}")
844
+ print("\n📋 文档详情:")
845
+ print("-" * 60)
846
+
847
+ # 按文档显示详细信息
848
+ for i, (doc_id, stats) in enumerate(doc_stats.items(), 1):
849
+ print(f"\n{i}. 📄 {stats['doc_name']}")
850
+ print(f" 🆔 文档ID: {doc_id}")
851
+ print(f" 📍 节点数: {stats['node_count']}")
852
+ print(f" 📃 页面数: {len(stats['pages'])}")
853
+ if stats['pages']:
854
+ page_range = f"{min(stats['pages'])}-{max(stats['pages'])}"
855
+ print(f" 📖 页面范围: {page_range}")
856
+
857
+ print("\n" + "=" * 60)
858
+
859
+ # 验证与元数据的一致性
860
+ metadata_docs = self.get_document_info()
861
+ if len(metadata_docs) != len(doc_stats):
862
+ print("⚠️ 警告: 向量存储中的文档数与元数据不一致!")
863
+ print(f" 向量存储: {len(doc_stats)} 个文档")
864
+ print(f" 元数据: {len(metadata_docs)} 个文档")
865
+ else:
866
+ print("✅ 向量存储与元数据一致")
867
+
868
+ except Exception as e:
869
+ print(f"❌ 无法获取集合数据: {e}")
870
+
871
+ else:
872
+ print("❌ 向量存储不支持此操作")
873
+
874
+ except Exception as e:
875
+ print(f"❌ 获取向量存储信息失败: {e}")
876
+
877
+ def get_vectorstore_stats(self) -> Dict[str, Any]:
878
+ """
879
+ 获取向量存储统计信息(返回字典格式)。
880
+
881
+ Returns:
882
+ 包含向量存储统计信息的字典
883
+ """
884
+ try:
885
+ # 确保向量存储已初始化
886
+ if self._vector_store is None:
887
+ vector_store = self._initialize_vector_store()
888
+ else:
889
+ vector_store = self._vector_store
890
+
891
+ # 获取ChromaDB集合
892
+ if hasattr(vector_store, '_collection'):
893
+ collection = vector_store._collection
894
+
895
+ # 获取所有文档
896
+ result = collection.get(include=['metadatas', 'documents'])
897
+
898
+ if not result['ids']:
899
+ return {
900
+ "status": "empty",
901
+ "collection_name": collection.name,
902
+ "total_nodes": 0,
903
+ "total_documents": 0,
904
+ "documents": []
905
+ }
906
+
907
+ total_nodes = len(result['ids'])
908
+
909
+ # 分析文档信息
910
+ doc_stats = {}
911
+ for metadata in result['metadatas']:
912
+ if metadata:
913
+ doc_id = metadata.get('doc_id', 'unknown')
914
+ doc_name = metadata.get('doc_name', 'unknown')
915
+ source = metadata.get('source', 'unknown')
916
+ page_num = metadata.get('page_num', 0)
917
+
918
+ if doc_id not in doc_stats:
919
+ doc_stats[doc_id] = {
920
+ 'doc_id': doc_id,
921
+ 'doc_name': doc_name,
922
+ 'source': source,
923
+ 'pages': set(),
924
+ 'node_count': 0
925
+ }
926
+
927
+ doc_stats[doc_id]['pages'].add(page_num)
928
+ doc_stats[doc_id]['node_count'] += 1
929
+
930
+ # 转换页面集合为列表并排序
931
+ documents = []
932
+ for stats in doc_stats.values():
933
+ pages_list = sorted(list(stats['pages']))
934
+ documents.append({
935
+ 'doc_id': stats['doc_id'],
936
+ 'doc_name': stats['doc_name'],
937
+ 'source': stats['source'],
938
+ 'node_count': stats['node_count'],
939
+ 'page_count': len(pages_list),
940
+ 'page_range': f"{min(pages_list)}-{max(pages_list)}" if pages_list else "0-0"
941
+ })
942
+
943
+ return {
944
+ "status": "success",
945
+ "collection_name": collection.name,
946
+ "total_nodes": total_nodes,
947
+ "total_documents": len(doc_stats),
948
+ "documents": documents
949
+ }
950
+
951
+ else:
952
+ return {
953
+ "status": "error",
954
+ "message": "向量存储不支持此操作"
955
+ }
956
+
957
+ except Exception as e:
958
+ return {
959
+ "status": "error",
960
+ "message": f"获取向量存储统计信息失败:{str(e)}"
961
+ }
962
+
963
+ def clear_all_documents(self) -> Dict[str, Any]:
964
+ """
965
+ 清除所有文档和索引,通过循环调用 remove_ppt_document 来实现。
966
+
967
+ Returns:
968
+ 清理结果
969
+ """
970
+ try:
971
+ logger.info("开始逐个删除所有文档以清空索引")
972
+ docs_info = self.get_document_info()
973
+
974
+ if not docs_info:
975
+ logger.info("未找到需要清除的文档")
976
+ return {
977
+ "status": "success",
978
+ "message": "未找到需要清除的文档"
979
+ }
980
+
981
+ # 创建要删除的文档路径列表的副本,以避免在迭代时修改字典
982
+ doc_paths_to_remove = [doc['file_path'] for doc in docs_info.values()]
983
+ total_docs = len(doc_paths_to_remove)
984
+ logger.info(f"找到{total_docs}个文档需要删除")
985
+
986
+ all_successful = True
987
+ errors = []
988
+
989
+ for i, ppt_path in enumerate(doc_paths_to_remove):
990
+ print(f"正在删除文档 ({i + 1}/{total_docs}): {Path(ppt_path).name}")
991
+ logger.info(f"正在删除文档:{ppt_path}")
992
+
993
+ result = self.remove_ppt_document(ppt_path)
994
+
995
+ if result['status'] == 'error':
996
+ all_successful = False
997
+ error_message = f"删除失败 {ppt_path}: {result['message']}"
998
+ logger.warning(error_message)
999
+ errors.append(error_message)
1000
+
1001
+ # 最终状态检查
1002
+ final_docs_info = self.get_document_info()
1003
+ if not final_docs_info and all_successful:
1004
+ message = "所有文档清除成功"
1005
+ logger.info(message)
1006
+ return {"status": "success", "message": message}
1007
+ else:
1008
+ message = f"清除所有文档完成。成功:{all_successful}。剩余文档:{len(final_docs_info)}。错误:{errors}"
1009
+ logger.warning(message)
1010
+ return {"status": "error" if not all_successful else "success", "message": message}
1011
+
1012
+ except Exception as e:
1013
+ logger.error(f"清除所有文档失败:{e}", exc_info=True)
1014
+ return {
1015
+ "status": "error",
1016
+ "message": f"清空所有文档时发生意外错误:{str(e)}"
1017
+ }