isage-middleware 0.2.4.3__cp311-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94):
  1. isage_middleware-0.2.4.3.dist-info/METADATA +266 -0
  2. isage_middleware-0.2.4.3.dist-info/RECORD +94 -0
  3. isage_middleware-0.2.4.3.dist-info/WHEEL +5 -0
  4. isage_middleware-0.2.4.3.dist-info/top_level.txt +1 -0
  5. sage/middleware/__init__.py +59 -0
  6. sage/middleware/_version.py +6 -0
  7. sage/middleware/components/__init__.py +30 -0
  8. sage/middleware/components/extensions_compat.py +141 -0
  9. sage/middleware/components/sage_db/__init__.py +116 -0
  10. sage/middleware/components/sage_db/backend.py +136 -0
  11. sage/middleware/components/sage_db/service.py +15 -0
  12. sage/middleware/components/sage_flow/__init__.py +76 -0
  13. sage/middleware/components/sage_flow/python/__init__.py +14 -0
  14. sage/middleware/components/sage_flow/python/micro_service/__init__.py +4 -0
  15. sage/middleware/components/sage_flow/python/micro_service/sage_flow_service.py +88 -0
  16. sage/middleware/components/sage_flow/python/sage_flow.py +30 -0
  17. sage/middleware/components/sage_flow/service.py +14 -0
  18. sage/middleware/components/sage_mem/__init__.py +83 -0
  19. sage/middleware/components/sage_sias/__init__.py +59 -0
  20. sage/middleware/components/sage_sias/continual_learner.py +184 -0
  21. sage/middleware/components/sage_sias/coreset_selector.py +302 -0
  22. sage/middleware/components/sage_sias/types.py +94 -0
  23. sage/middleware/components/sage_tsdb/__init__.py +81 -0
  24. sage/middleware/components/sage_tsdb/python/__init__.py +21 -0
  25. sage/middleware/components/sage_tsdb/python/_sage_tsdb.pyi +17 -0
  26. sage/middleware/components/sage_tsdb/python/algorithms/__init__.py +17 -0
  27. sage/middleware/components/sage_tsdb/python/algorithms/base.py +51 -0
  28. sage/middleware/components/sage_tsdb/python/algorithms/out_of_order_join.py +248 -0
  29. sage/middleware/components/sage_tsdb/python/algorithms/window_aggregator.py +296 -0
  30. sage/middleware/components/sage_tsdb/python/micro_service/__init__.py +7 -0
  31. sage/middleware/components/sage_tsdb/python/micro_service/sage_tsdb_service.py +365 -0
  32. sage/middleware/components/sage_tsdb/python/sage_tsdb.py +523 -0
  33. sage/middleware/components/sage_tsdb/service.py +17 -0
  34. sage/middleware/components/vector_stores/__init__.py +25 -0
  35. sage/middleware/components/vector_stores/chroma.py +483 -0
  36. sage/middleware/components/vector_stores/chroma_adapter.py +185 -0
  37. sage/middleware/components/vector_stores/milvus.py +677 -0
  38. sage/middleware/operators/__init__.py +56 -0
  39. sage/middleware/operators/agent/__init__.py +24 -0
  40. sage/middleware/operators/agent/planning/__init__.py +5 -0
  41. sage/middleware/operators/agent/planning/llm_adapter.py +41 -0
  42. sage/middleware/operators/agent/planning/planner_adapter.py +98 -0
  43. sage/middleware/operators/agent/planning/router.py +107 -0
  44. sage/middleware/operators/agent/runtime.py +296 -0
  45. sage/middleware/operators/agentic/__init__.py +41 -0
  46. sage/middleware/operators/agentic/config.py +254 -0
  47. sage/middleware/operators/agentic/planning_operator.py +125 -0
  48. sage/middleware/operators/agentic/refined_searcher.py +132 -0
  49. sage/middleware/operators/agentic/runtime.py +241 -0
  50. sage/middleware/operators/agentic/timing_operator.py +125 -0
  51. sage/middleware/operators/agentic/tool_selection_operator.py +127 -0
  52. sage/middleware/operators/context/__init__.py +17 -0
  53. sage/middleware/operators/context/critic_evaluation.py +16 -0
  54. sage/middleware/operators/context/model_context.py +565 -0
  55. sage/middleware/operators/context/quality_label.py +12 -0
  56. sage/middleware/operators/context/search_query_results.py +61 -0
  57. sage/middleware/operators/context/search_result.py +42 -0
  58. sage/middleware/operators/context/search_session.py +79 -0
  59. sage/middleware/operators/filters/__init__.py +26 -0
  60. sage/middleware/operators/filters/context_sink.py +387 -0
  61. sage/middleware/operators/filters/context_source.py +376 -0
  62. sage/middleware/operators/filters/evaluate_filter.py +83 -0
  63. sage/middleware/operators/filters/tool_filter.py +74 -0
  64. sage/middleware/operators/llm/__init__.py +18 -0
  65. sage/middleware/operators/llm/sagellm_generator.py +432 -0
  66. sage/middleware/operators/rag/__init__.py +147 -0
  67. sage/middleware/operators/rag/arxiv.py +331 -0
  68. sage/middleware/operators/rag/chunk.py +13 -0
  69. sage/middleware/operators/rag/document_loaders.py +23 -0
  70. sage/middleware/operators/rag/evaluate.py +658 -0
  71. sage/middleware/operators/rag/generator.py +340 -0
  72. sage/middleware/operators/rag/index_builder/__init__.py +48 -0
  73. sage/middleware/operators/rag/index_builder/builder.py +363 -0
  74. sage/middleware/operators/rag/index_builder/manifest.py +101 -0
  75. sage/middleware/operators/rag/index_builder/storage.py +131 -0
  76. sage/middleware/operators/rag/pipeline.py +46 -0
  77. sage/middleware/operators/rag/profiler.py +59 -0
  78. sage/middleware/operators/rag/promptor.py +400 -0
  79. sage/middleware/operators/rag/refiner.py +231 -0
  80. sage/middleware/operators/rag/reranker.py +364 -0
  81. sage/middleware/operators/rag/retriever.py +1308 -0
  82. sage/middleware/operators/rag/searcher.py +37 -0
  83. sage/middleware/operators/rag/types.py +28 -0
  84. sage/middleware/operators/rag/writer.py +80 -0
  85. sage/middleware/operators/tools/__init__.py +71 -0
  86. sage/middleware/operators/tools/arxiv_paper_searcher.py +175 -0
  87. sage/middleware/operators/tools/arxiv_searcher.py +102 -0
  88. sage/middleware/operators/tools/duckduckgo_searcher.py +105 -0
  89. sage/middleware/operators/tools/image_captioner.py +104 -0
  90. sage/middleware/operators/tools/nature_news_fetcher.py +224 -0
  91. sage/middleware/operators/tools/searcher_tool.py +514 -0
  92. sage/middleware/operators/tools/text_detector.py +185 -0
  93. sage/middleware/operators/tools/url_text_extractor.py +104 -0
  94. sage/middleware/py.typed +2 -0
@@ -0,0 +1,1308 @@
1
+ import json
2
+ import os
3
+ import time
4
+ from typing import Any
5
+
6
+ import numpy as np
7
+
8
+ from sage.common.components.sage_embedding.embedding_model import EmbeddingModel
9
+ from sage.common.config.output_paths import get_states_file
10
+ from sage.common.core.functions import MapFunction as MapOperator
11
+ from sage.middleware.components.vector_stores.chroma import ChromaBackend, ChromaUtils
12
+ from sage.middleware.components.vector_stores.milvus import MilvusBackend, MilvusUtils
13
+
14
+
# ChromaDB dense retriever
class ChromaRetriever(MapOperator):
    """Dense retriever backed by ChromaDB.

    The query is embedded with an ``EmbeddingModel`` and searched through a
    ``ChromaBackend``. Every hit is normalised to a dict carrying a ``"text"``
    field so downstream operators can rely on a single shape.

    Config keys read here:
        dimension: expected embedding dimension (default 384); replaced by the
            model's ``get_dim()`` when the two disagree.
        top_k: number of documents to retrieve (default 10).
        embedding: dict with ``method`` (default ``"default"``) and ``model``
            (default ``"sentence-transformers/all-MiniLM-L6-v2"``).
        chroma: backend config passed to ``ChromaBackend``; may contain
            ``knowledge_file`` to load at start-up.
    """

    def __init__(self, config, enable_profile=False, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.enable_profile = enable_profile

        # This operator only supports the ChromaDB backend.
        self.backend_type = "chroma"

        # Generic settings
        self.vector_dimension = config.get("dimension", 384)
        self.top_k = config.get("top_k", 10)
        self.embedding_config = config.get("embedding", {})

        # Build the embedding model first: backend initialisation may load the
        # knowledge file, which needs the model to embed its documents.
        self._init_embedding_model()

        self.chroma_config = config.get("chroma", {})
        self._init_chroma_backend()

        # The profiling output directory is only needed when profiling is on.
        if self.enable_profile:
            # Use the unified output path system.
            self.data_base_path = str(get_states_file("dummy", "retriever_data").parent)
            os.makedirs(self.data_base_path, exist_ok=True)
            self.data_records = []

    @staticmethod
    def _resolve_knowledge_path(knowledge_file: str) -> str:
        """Resolve a possibly relative knowledge-file path.

        Absolute paths, and relative paths that exist relative to the current
        working directory, are returned unchanged. Otherwise the nearest
        ancestor of the CWD containing ``pyproject.toml`` is used as the base.
        If the file is not found there either, the original string is returned.
        """
        if os.path.isabs(knowledge_file) or os.path.exists(knowledge_file):
            return knowledge_file

        project_root = os.getcwd()
        while not os.path.exists(os.path.join(project_root, "pyproject.toml")):
            parent = os.path.dirname(project_root)
            # dirname() is a fixed point at the filesystem root ("/" on POSIX,
            # e.g. "C:\\" on Windows). Stop there instead of looping forever.
            if parent == project_root:
                break
            project_root = parent

        potential_path = os.path.join(project_root, knowledge_file)
        return potential_path if os.path.exists(potential_path) else knowledge_file

    def _init_chroma_backend(self):
        """Create the ChromaDB backend and auto-load ``knowledge_file`` if configured.

        Raises:
            ImportError: if ChromaDB is not available.
            ValueError: if ``ChromaUtils.validate_chroma_config`` rejects the config.
        """
        try:
            if not ChromaUtils.check_chromadb_availability():
                raise ImportError(
                    "ChromaDB dependencies not available. Install with: pip install chromadb"
                )

            if not ChromaUtils.validate_chroma_config(self.chroma_config):
                raise ValueError("Invalid ChromaDB configuration")

            self.chroma_backend = ChromaBackend(self.chroma_config, self.logger)

            knowledge_file = self.chroma_config.get("knowledge_file")
            if knowledge_file:
                resolved_path = self._resolve_knowledge_path(knowledge_file)
                if os.path.exists(resolved_path):
                    self._load_knowledge_from_file(resolved_path)
                else:
                    self.logger.warning(f"Knowledge file not found: {resolved_path}")

        except Exception as e:
            self.logger.error(f"Failed to initialize ChromaDB: {e}")
            raise

    def _load_knowledge_from_file(self, file_path: str):
        """Load a knowledge file into the backend; failures are logged, not raised."""
        try:
            success = self.chroma_backend.load_knowledge_from_file(file_path, self.embedding_model)
            if not success:
                self.logger.error(f"Failed to load knowledge from file: {file_path}")

        except Exception as e:
            self.logger.error(f"Failed to load knowledge from file {file_path}: {e}")

    def _init_embedding_model(self):
        """Create the ``EmbeddingModel`` described by the ``embedding`` config.

        If the model exposes ``get_dim()`` and it disagrees with the configured
        dimension, a warning is logged and ``self.vector_dimension`` is updated
        to the model's value.
        """
        embedding_method = self.embedding_config.get("method", "default")
        model = self.embedding_config.get("model", "sentence-transformers/all-MiniLM-L6-v2")

        self.logger.info(f"Initializing embedding model with method: {embedding_method}")
        self.embedding_model = EmbeddingModel(method=embedding_method, model=model)

        if hasattr(self.embedding_model, "get_dim"):
            model_dim = self.embedding_model.get_dim()
            if model_dim != self.vector_dimension:
                self.logger.warning(
                    f"Embedding model dimension ({model_dim}) != configured dimension ({self.vector_dimension})"
                )
                # Trust the model's actual dimension.
                self.vector_dimension = model_dim

    def add_documents(self, documents: list[str], doc_ids: list[str] | None = None) -> list[str]:
        """Embed *documents* and add them to the ChromaDB index.

        Args:
            documents: document texts.
            doc_ids: document IDs; generated as ``doc_<ms-timestamp>_<i>`` when None.

        Returns:
            Whatever ``ChromaBackend.add_documents`` returns (the added IDs).

        Raises:
            ValueError: if ``doc_ids`` is given with a different length than ``documents``.
        """
        if not documents:
            return []

        if doc_ids is None:
            doc_ids = [f"doc_{int(time.time() * 1000)}_{i}" for i in range(len(documents))]
        elif len(doc_ids) != len(documents):
            raise ValueError("doc_ids length must match documents length")

        embeddings = [
            np.array(self.embedding_model.embed(doc), dtype=np.float32) for doc in documents
        ]

        return self.chroma_backend.add_documents(documents, embeddings, doc_ids)

    def _save_data_record(self, query, retrieved_docs):
        """Append one profiling record and flush it to disk (profiling only)."""
        if not self.enable_profile:
            return

        record = {
            "timestamp": time.time(),
            "query": query,
            "retrieval_results": retrieved_docs,
            "backend_type": self.backend_type,
            "backend_config": getattr(self, f"{self.backend_type}_config", {}),
            "embedding_config": self.embedding_config,
        }
        self.data_records.append(record)
        self._persist_data_records()

    def _unique_record_path(self, prefix: str) -> str:
        """Return ``<data_base_path>/<prefix>_<unix_seconds>.json``, adding a
        ``_<n>`` suffix when that file already exists.

        Records are flushed after every query, so several flushes can happen
        within one second; without the suffix each one would overwrite the last.
        """
        stem = f"{prefix}_{int(time.time())}"
        path = os.path.join(self.data_base_path, f"{stem}.json")
        counter = 1
        while os.path.exists(path):
            path = os.path.join(self.data_base_path, f"{stem}_{counter}.json")
            counter += 1
        return path

    def _persist_data_records(self):
        """Write pending profiling records to a new JSON file and clear them.

        Write errors are logged; the records are then kept for a later attempt.
        """
        if not self.enable_profile or not self.data_records:
            return

        path = self._unique_record_path("retriever_data")

        try:
            with open(path, "w", encoding="utf-8") as f:
                json.dump(self.data_records, f, ensure_ascii=False, indent=2)
            self.data_records = []
        except Exception as e:
            self.logger.error(f"Failed to persist data records: {e}")

    @staticmethod
    def _standardize_docs(retrieved_docs) -> list[dict[str, Any]]:
        """Normalise backend hits into dicts that all carry a ``"text"`` key.

        - Strings become ``{"text": s}``.
        - Dicts are shallow-copied, so backend-owned objects are not mutated.
          ``"text"`` is filled from ``"content"``, or from ``str(doc)`` when
          neither key is present.
        - Anything else becomes ``{"text": str(doc)}``.
        """
        standardized = []
        for doc in retrieved_docs:
            if isinstance(doc, str):
                standardized.append({"text": doc})
            elif isinstance(doc, dict):
                doc = dict(doc)
                if "text" not in doc:
                    doc["text"] = doc["content"] if "content" in doc else str(doc)
                standardized.append(doc)
            else:
                standardized.append({"text": str(doc)})
        return standardized

    def execute(self, data: Any) -> dict[str, Any]:
        """Retrieve documents for a single query.

        Args:
            data: one of the following:
                - a query string;
                - a tuple whose first element is the query;
                - a dict holding the query under ``"query"``, falling back to
                  ``"question"`` when ``"query"`` is absent.

        Returns:
            For dict input, the same dict with ``"retrieval_results"`` set. On
            success an existing ``"retrieval_results"`` value is preserved.

            Otherwise ``{"query": ..., "retrieval_results": [...], "input": data}``.

            Retrieval errors are logged and produce an empty result list.
        """
        is_dict_input = isinstance(data, dict)
        if is_dict_input:
            input_query = data["query"] if "query" in data else data.get("question", "")
        elif isinstance(data, tuple) and len(data) > 0:
            input_query = data[0]
        else:
            input_query = data

        if not isinstance(input_query, str):
            self.logger.error(f"Invalid input query type: {type(input_query)}")
            if is_dict_input:
                data["retrieval_results"] = []
                return data
            return {"query": str(input_query), "retrieval_results": [], "input": data}

        self.logger.info(
            f"[ {self.__class__.__name__}]: Starting {self.backend_type.upper()} retrieval for query: {input_query}"
        )
        self.logger.info(f"[ {self.__class__.__name__}]: Using top_k = {self.top_k}")

        try:
            query_vector = np.array(self.embedding_model.embed(input_query), dtype=np.float32)

            retrieved_docs = self.chroma_backend.search(query_vector, input_query, self.top_k)

            self.logger.info(
                f"\033[32m[ {self.__class__.__name__}]: Retrieved {len(retrieved_docs)} documents from ChromaDB\033[0m"
            )
            # Preview only the first three documents.
            self.logger.debug(f"Retrieved documents: {retrieved_docs[:3]}...")

            standardized_docs = self._standardize_docs(retrieved_docs)

            if self.enable_profile:
                self._save_data_record(input_query, standardized_docs)

            if is_dict_input:
                # Keep results already present in the dict: the original
                # retrieval results are used for compression-ratio calculation.
                if "retrieval_results" not in data:
                    data["retrieval_results"] = standardized_docs
                return data
            return {
                "query": input_query,
                "retrieval_results": standardized_docs,
                "input": data,
            }

        except Exception as e:
            self.logger.error(f"ChromaDB retrieval failed: {str(e)}")
            if is_dict_input:
                data["retrieval_results"] = []
                return data
            return {"query": input_query, "retrieval_results": [], "input": data}

    def save_index(self, save_path: str) -> bool:
        """Save the backend configuration to *save_path*; returns success."""
        return self.chroma_backend.save_config(save_path)

    def load_index(self, load_path: str) -> bool:
        """Load the backend configuration from *load_path*; returns success."""
        return self.chroma_backend.load_config(load_path)

    def get_collection_info(self) -> dict[str, Any]:
        """Return the backend's collection information."""
        return self.chroma_backend.get_collection_info()

    def __del__(self):
        """Best-effort flush of any unsaved profiling records on destruction."""
        if hasattr(self, "enable_profile") and self.enable_profile:
            try:
                self._persist_data_records()
            except Exception:
                # Destructors may run during interpreter shutdown, when logging
                # or file I/O can fail; never let that escape __del__.
                pass
# Milvus dense-vector retriever
class MilvusDenseRetriever(MapOperator):
    """Dense-vector retrieval against a Milvus backend.

    Config keys read here:
        dimension: expected embedding dimension (default 384).
        top_k: number of documents to retrieve (default 5).
        embedding: dict with ``method`` and ``model`` for ``EmbeddingModel``.
        milvus_dense: backend config passed to ``MilvusBackend``; may contain
            ``knowledge_file`` to load at start-up.
    """

    def __init__(self, config, enable_profile=False, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.enable_profile = enable_profile

        # This operator only supports the Milvus backend.
        self.backend_type = "milvus"

        # Generic settings
        self.vector_dimension = self.config.get("dimension", 384)
        self.top_k = self.config.get("top_k", 5)
        self.embedding_config = self.config.get("embedding", {})

        # Create the embedding model BEFORE the backend. Backend initialisation
        # may load ``knowledge_file`` via ``_load_knowledge_from_file_dense``,
        # which needs ``self.embedding_model``. In the reverse order that load
        # failed (AttributeError, swallowed and logged) and nothing was ingested.
        self._init_embedding_model()

        self.milvus_config = config.get("milvus_dense", {})
        self._init_milvus_backend()

        # The profiling output directory is only needed when profiling is on.
        if self.enable_profile:
            if self.ctx is not None and getattr(self.ctx, "env_base_dir", None):
                self.data_base_path = os.path.join(
                    self.ctx.env_base_dir, ".sage_states", "retriever_data"
                )
            else:
                # Fall back to the current working directory.
                self.data_base_path = os.path.join(os.getcwd(), ".sage_states", "retriever_data")

            os.makedirs(self.data_base_path, exist_ok=True)
            self.data_records = []

    def _init_milvus_backend(self):
        """Create the Milvus backend and auto-load ``knowledge_file`` if configured.

        Raises:
            ImportError: if pymilvus is not available.
            ValueError: if ``MilvusUtils.validate_milvus_config`` rejects the config.
        """
        try:
            if not MilvusUtils.check_milvus_available():
                raise ImportError(
                    "Milvus dependencies not available. Install with: pip install pymilvus"
                )

            if not MilvusUtils.validate_milvus_config(self.milvus_config):
                raise ValueError("Invalid Milvus configuration")

            self.milvus_backend = MilvusBackend(config=self.milvus_config, logger=self.logger)

            knowledge_file = self.milvus_config.get("knowledge_file")
            if knowledge_file:
                if os.path.exists(knowledge_file):
                    self._load_knowledge_from_file_dense(knowledge_file)
                else:
                    self.logger.warning(f"Knowledge file not found: {knowledge_file}")

        except Exception as e:
            self.logger.error(f"Failed to initialize milvus: {e}")
            raise

    def _load_knowledge_from_file_dense(self, file_path: str):
        """Load a knowledge file into the dense collection; failures are logged, not raised."""
        try:
            success = self.milvus_backend.load_knowledge_from_file_dense(
                file_path, self.embedding_model
            )
            if not success:
                self.logger.error(f"Failed to load knowledge from file: {file_path}")
        except Exception as e:
            self.logger.error(f"Failed to load knowledge from file {file_path}: {e}")

    def _init_embedding_model(self):
        """Create the ``EmbeddingModel`` described by the ``embedding`` config.

        If the model exposes ``get_dim()`` and it disagrees with the configured
        dimension, a warning is logged and ``self.vector_dimension`` is updated.
        """
        embedding_method = self.embedding_config.get("method", "default")
        model = self.embedding_config.get("model", "sentence-transformers/all-MiniLM-L6-v2")

        self.logger.info(f"Initializing embedding model with method: {embedding_method}")
        self.embedding_model = EmbeddingModel(method=embedding_method, model=model)

        if hasattr(self.embedding_model, "get_dim"):
            model_dim = self.embedding_model.get_dim()
            if model_dim != self.vector_dimension:
                self.logger.warning(
                    f"Embedding model dimension ({model_dim}) != configured dimension ({self.vector_dimension})"
                )
                # Trust the model's actual dimension.
                self.vector_dimension = model_dim

    def add_documents(self, documents: list[str], doc_ids: list[str] | None = None) -> list[str]:
        """Embed *documents* and add them to the dense Milvus collection.

        Args:
            documents: document texts.
            doc_ids: document IDs; generated as ``doc_<ms-timestamp>_<i>`` when None.

        Returns:
            Whatever ``MilvusBackend.add_dense_documents`` returns (the added IDs).

        Raises:
            ValueError: if ``doc_ids`` is given with a different length than ``documents``.
        """
        if not documents:
            self.logger.warning("No documents to add")
            return []

        if doc_ids is None:
            doc_ids = [f"doc_{int(time.time() * 1000)}_{i}" for i in range(len(documents))]
        elif len(doc_ids) != len(documents):
            raise ValueError("doc_ids length must match documents length")

        embeddings = [
            np.array(self.embedding_model.embed(doc), dtype=np.float32) for doc in documents
        ]

        return self.milvus_backend.add_dense_documents(documents, embeddings, doc_ids)

    def _save_data_record(self, query, retrieved_docs):
        """Append one profiling record and flush it to disk (profiling only)."""
        if not self.enable_profile:
            return

        record = {
            "timestamp": time.time(),
            "query": query,
            "retrieval_results": retrieved_docs,
            "backend_type": self.backend_type,
            "backend_config": getattr(self, f"{self.backend_type}_config", {}),
            "embedding_config": self.embedding_config,
        }

        self.data_records.append(record)
        self._persist_data_records()

    def _unique_record_path(self, prefix: str) -> str:
        """Return ``<data_base_path>/<prefix>_<unix_seconds>.json``, adding a
        ``_<n>`` suffix when that file already exists.

        Records are flushed after every query, so several flushes can happen
        within one second; without the suffix each one would overwrite the last.
        """
        stem = f"{prefix}_{int(time.time())}"
        path = os.path.join(self.data_base_path, f"{stem}.json")
        counter = 1
        while os.path.exists(path):
            path = os.path.join(self.data_base_path, f"{stem}_{counter}.json")
            counter += 1
        return path

    def _persist_data_records(self):
        """Write pending profiling records to a new JSON file and clear them.

        Write errors are logged; the records are then kept for a later attempt.
        """
        if not self.enable_profile or not self.data_records:
            return

        path = self._unique_record_path("milvus_dense_retriever_data")

        try:
            with open(path, "w", encoding="utf-8") as f:
                json.dump(self.data_records, f, ensure_ascii=False, indent=2)
            self.data_records = []
        except Exception as e:
            self.logger.error(f"Failed to persist data records: {e}")

    def execute(self, data: Any) -> dict[str, Any]:
        """Run dense retrieval for a single query.

        Args:
            data: one of the following:
                - a query string;
                - a tuple whose first element is the query;
                - a dict holding the query under ``"question"``, falling back
                  to ``"query"`` when ``"question"`` is absent.

        Returns:
            For dict input, the same dict with ``"retrieval_results"`` set.
            Otherwise ``{"query": ..., "retrieval_results": [...], "input": data}``.
            Retrieval errors are logged and produce an empty result list.
        """
        is_dict_input = isinstance(data, dict)
        if is_dict_input:
            input_query = data["question"] if "question" in data else data.get("query", "")
        elif isinstance(data, tuple) and len(data) > 0:
            input_query = data[0]
        else:
            input_query = data

        if not isinstance(input_query, str):
            self.logger.error(f"Invalid input query type: {type(input_query)}")
            if is_dict_input:
                data["retrieval_results"] = []
                return data
            return {
                "query": str(input_query),
                "retrieval_results": [],
                "input": data,
            }

        self.logger.info(
            f"[ {self.__class__.__name__}]: Starting {self.backend_type.upper()} retrieval for query: {input_query}"
        )
        self.logger.info(f"[ {self.__class__.__name__}]: Using top_k = {self.top_k}")

        try:
            # Embed the query with the same method used for documents in
            # add_documents, so query and document vectors are comparable.
            query_vector = np.array(self.embedding_model.embed(input_query), dtype=np.float32)

            retrieved_docs = self.milvus_backend.dense_search(
                query_vector=query_vector,
                top_k=self.top_k,
            )

            self.logger.info(
                f"\033[32m[ {self.__class__.__name__}]: Retrieved {len(retrieved_docs)} documents from Milvus\033[0m"
            )
            # Preview only the first three documents.
            self.logger.debug(f"Retrieved documents: {retrieved_docs[:3]}...")

            if self.enable_profile:
                self._save_data_record(input_query, retrieved_docs)

            if is_dict_input:
                data["retrieval_results"] = retrieved_docs
                return data
            return {
                "query": input_query,
                "retrieval_results": retrieved_docs,
                "input": data,
            }

        except Exception as e:
            self.logger.error(f"Milvus dense retrieval failed: {str(e)}")
            if is_dict_input:
                data["retrieval_results"] = []
                return data
            return {
                "query": input_query,
                "retrieval_results": [],
                "input": data,
            }

    def save_config(self, save_path: str) -> bool:
        """Save the backend configuration to *save_path*; returns success."""
        return self.milvus_backend.save_config(save_path)

    def load_config(self, load_path: str) -> bool:
        """Load the backend configuration from *load_path*; returns success."""
        return self.milvus_backend.load_config(load_path)

    def get_collection_info(self) -> dict[str, Any]:
        """Return the backend's collection information."""
        return self.milvus_backend.get_collection_info()

    def delete_collection(self, collection_name: str) -> bool:
        """Delete *collection_name* through the backend; returns success."""
        return self.milvus_backend.delete_collection(collection_name)

    def __del__(self):
        """Best-effort flush of any unsaved profiling records on destruction."""
        if hasattr(self, "enable_profile") and self.enable_profile:
            try:
                self._persist_data_records()
            except Exception:
                # Destructors may run during interpreter shutdown, when logging
                # or file I/O can fail; never let that escape __del__.
                pass
583
+ # Milvus稀疏向量检索
584
+ class MilvusSparseRetriever(MapOperator):
585
+ """
586
+ 使用 Milvus 后端进行稀疏向量检索。
587
+ """
588
+
589
+ def __init__(self, config, enable_profile=False, **kwargs):
590
+ super().__init__(**kwargs)
591
+ self.config = config
592
+ self.enable_profile = enable_profile
593
+
594
+ # 只支持Milvus后端
595
+ self.backend_type = "milvus"
596
+
597
+ # 通用配置
598
+ self.top_k = self.config.get("top_k", 10)
599
+
600
+ # 初始化Milvus后端
601
+ self.milvus_config = config.get("milvus_sparse", {})
602
+ self._init_milvus_backend()
603
+ self._init_embedding_model()
604
+
605
+ # 只有启用profile时才设置数据存储路径
606
+ if self.enable_profile:
607
+ if self.ctx is not None and hasattr(self.ctx, "env_base_dir") and self.ctx.env_base_dir:
608
+ self.data_base_path = os.path.join(
609
+ self.ctx.env_base_dir, ".sage_states", "retriever_data"
610
+ )
611
+ else:
612
+ # 使用默认路径
613
+ self.data_base_path = os.path.join(os.getcwd(), ".sage_states", "retriever_data")
614
+
615
+ os.makedirs(self.data_base_path, exist_ok=True)
616
+ self.data_records = []
617
+
618
+ def _init_milvus_backend(self):
619
+ """初始化milvus后端"""
620
+ try:
621
+ # 检查 milvus 是否可用
622
+ if not MilvusUtils.check_milvus_available():
623
+ raise ImportError(
624
+ "Milvus dependencies not available. Install with: pip install pymilvus"
625
+ )
626
+
627
+ # 验证配置
628
+ if not MilvusUtils.validate_milvus_config(self.milvus_config):
629
+ raise ValueError("Invalid Milvus configuration")
630
+
631
+ # 初始化后端
632
+ self.milvus_backend = MilvusBackend(config=self.milvus_config, logger=self.logger)
633
+
634
+ # 自动加载知识库文件
635
+ knowledge_file = self.milvus_config.get("knowledge_file")
636
+ if knowledge_file and os.path.exists(knowledge_file):
637
+ self._load_knowledge_from_file(knowledge_file)
638
+
639
+ except Exception as e:
640
+ self.logger.error(f"Failed to initialize milvus: {e}")
641
+ raise
642
+
643
+ def _init_embedding_model(self):
644
+ """初始化embedding模型"""
645
+ try:
646
+ # 尝试新的导入路径(PyMilvus 2.6.0+)
647
+ try:
648
+ from pymilvus.model.hybrid import (
649
+ BGEM3EmbeddingFunction, # type: ignore[import-not-found]
650
+ )
651
+ except ImportError:
652
+ # 如果失败,尝试直接从 model 导入
653
+ try:
654
+ from pymilvus.model import (
655
+ BGEM3EmbeddingFunction, # type: ignore[import-not-found]
656
+ )
657
+ except ImportError:
658
+ # 最后尝试安装单独的包
659
+ self.logger.error(
660
+ "Please install: pip install 'pymilvus[model]' or pip install pymilvus.model"
661
+ )
662
+ raise ImportError("Embedding model dependencies not available")
663
+
664
+ self.embedding_model = BGEM3EmbeddingFunction(use_fp16=False, device="cpu")
665
+
666
+ except ImportError as e:
667
+ self.logger.error(f"Failed to import EmbeddingModel: {e}")
668
+ raise ImportError("Embedding model dependencies not available")
669
+
670
+ def _load_knowledge_from_file(self, file_path: str):
671
+ """从文件中加载知识库"""
672
+ try:
673
+ # 使用Milvus后端加载
674
+ success = self.milvus_backend.load_knowledge_from_file_sparse(file_path)
675
+ self.logger.info(f"Loaded {success} documents from {file_path}")
676
+ if not success:
677
+ self.logger.error(f"Failed to load knowledge from file: {file_path}")
678
+ except Exception as e:
679
+ self.logger.error(f"Failed to load knowledge from file: {e}")
680
+
681
+ def add_documents(self, documents: list[str], doc_ids: list[str] | None = None) -> list[str]:
682
+ """
683
+ 添加文档到milvus
684
+ Args:
685
+ documents: 文档内容列表
686
+ doc_ids: 文档ID列表,如果为None则自动生成
687
+ Returns:
688
+ 添加的文档ID列表
689
+ """
690
+ if not documents:
691
+ self.logger.warning("No documents to add")
692
+ return []
693
+
694
+ # 生成 embedding
695
+ embedding = self.embedding_model.encode_documents(documents)
696
+ embeddings = embedding["sparse"]
697
+
698
+ if doc_ids is None:
699
+ doc_ids = [f"doc_{int(time.time() * 1000)}_{i}" for i in range(len(documents))]
700
+ elif len(doc_ids) != len(documents):
701
+ raise ValueError("doc_ids length must match documents length")
702
+
703
+ # 使用 milvus 后端添加文档
704
+ return self.milvus_backend.add_sparse_documents(documents, embeddings, doc_ids)
705
+
706
+ def _save_data_record(self, query, retrieved_docs):
707
+ """
708
+ 保存检索数据记录
709
+ """
710
+ if not self.enable_profile:
711
+ return
712
+
713
+ record = {
714
+ "timestamp": time.time(),
715
+ "query": query,
716
+ "retrieval_results": retrieved_docs,
717
+ "backend_type": self.backend_type,
718
+ "backend_config": getattr(self, f"{self.backend_type}_config", {}),
719
+ }
720
+
721
+ self.data_records.append(record)
722
+ self._persist_data_records()
723
+
724
+ def _persist_data_records(self):
725
+ """
726
+ 将数据记录持久化到文件
727
+ """
728
+ if not self.enable_profile or not self.data_records:
729
+ return
730
+
731
+ timestamp = int(time.time())
732
+ filename = f"milvus_dense_retriever_data_{timestamp}.json"
733
+ path = os.path.join(self.data_base_path, filename)
734
+
735
+ try:
736
+ with open(path, "w", encoding="utf-8") as f:
737
+ json.dump(self.data_records, f, ensure_ascii=False, indent=2)
738
+ self.data_records = []
739
+ except Exception as e:
740
+ self.logger.error(f"Failed to persist data records: {e}")
741
+
742
+ def execute(self, data: str) -> dict[str, Any]:
743
+ """
744
+ 执行检索
745
+ Args:
746
+ data: 查询字符串、元组或字典
747
+ Returns:
748
+ dict: {"query": ..., "retrieval_results": ..., "input": 原始输入, ...}
749
+ """
750
+ # 支持字典类型输入,优先取 question 字段
751
+ is_dict_input = isinstance(data, dict)
752
+ if is_dict_input:
753
+ input_query = data.get("question", "")
754
+ elif isinstance(data, tuple) and len(data) > 0:
755
+ input_query = data[0]
756
+ else:
757
+ input_query = data
758
+
759
+ if not isinstance(input_query, str):
760
+ self.logger.error(f"Invalid input query type: {type(input_query)}")
761
+ if is_dict_input:
762
+ data["retrieval_results"] = []
763
+ return data
764
+ else:
765
+ return {
766
+ "query": str(input_query),
767
+ "retrieval_results": [],
768
+ "input": data,
769
+ }
770
+
771
+ self.logger.info(
772
+ f"[ {self.__class__.__name__}]: Starting {self.backend_type.upper()} retrieval for query: {input_query}"
773
+ )
774
+ self.logger.info(f"[ {self.__class__.__name__}]: Using top_k = {self.top_k}")
775
+
776
+ try:
777
+ # 使用Milvus执行稀疏检索 - 直接传递查询文本,让sparse_search方法处理向量生成
778
+ retrieved_docs = self.milvus_backend.sparse_search(
779
+ query_text=input_query,
780
+ top_k=self.top_k,
781
+ )
782
+
783
+ self.logger.info(
784
+ f"\033[32m[ {self.__class__.__name__}]: Retrieved {len(retrieved_docs)} documents from Milvus\033[0m"
785
+ )
786
+ self.logger.debug(
787
+ f"Retrieved documents: {retrieved_docs[:3]}..."
788
+ ) # 只显示前3个文档的预览
789
+
790
+ print(f"Query: {input_query}")
791
+ print(f"Configured top_k: {self.top_k}")
792
+ print(f"Retrieved {len(retrieved_docs)} documents from Milvus")
793
+ print(retrieved_docs)
794
+
795
+ # 保存数据记录(只有enable_profile=True时才保存)
796
+ if self.enable_profile:
797
+ self._save_data_record(input_query, retrieved_docs)
798
+
799
+ if is_dict_input:
800
+ data["retrieval_results"] = retrieved_docs
801
+ return data
802
+ else:
803
+ return {
804
+ "query": input_query,
805
+ "retrieval_results": retrieved_docs,
806
+ "input": data,
807
+ }
808
+
809
+ except Exception as e:
810
+ self.logger.error(f" retrieval failed: {str(e)}")
811
+ if is_dict_input:
812
+ data["retrieval_results"] = []
813
+ return data
814
+ else:
815
+ return {
816
+ "query": input_query,
817
+ "retrieval_results": [],
818
+ "input": data,
819
+ }
820
+
821
def save_config(self, save_path: str) -> bool:
    """Persist the retriever configuration to disk.

    The work is delegated entirely to the Milvus backend's ``save_config``.

    Args:
        save_path: Destination path, passed through to the backend unchanged.

    Returns:
        The backend's result (declared ``bool``: whether saving succeeded).
    """
    backend = self.milvus_backend
    return backend.save_config(save_path)
830
+
831
def load_config(self, load_path: str) -> bool:
    """Load a previously saved retriever configuration from disk.

    The work is delegated entirely to the Milvus backend's ``load_config``.

    Args:
        load_path: Source path, passed through to the backend unchanged.

    Returns:
        The backend's result (declared ``bool``: whether loading succeeded).
    """
    backend = self.milvus_backend
    return backend.load_config(load_path)
840
+
841
def get_collection_info(self) -> dict[str, Any]:
    """Return metadata about the backing collection, as reported by the Milvus backend."""
    backend = self.milvus_backend
    info = backend.get_collection_info()
    return info
846
+
847
def __del__(self):
    """Flush buffered profiling records when the retriever is finalized.

    Only acts when an ``enable_profile`` attribute exists and is truthy.
    Any exception raised by the flush is swallowed on purpose: an exception
    escaping a finalizer cannot be handled by any caller.
    """
    profiling_on = getattr(self, "enable_profile", False)
    if not profiling_on:
        return
    try:
        self._persist_data_records()
    except Exception:
        # Best-effort flush; nothing useful can be done with the error here.
        pass
854
+
855
+
856
+ # Wiki18 FAISS 检索器
857
+ class Wiki18FAISSRetriever(MapOperator):
858
+ """
859
+ 基于FAISS的Wiki18数据集检索器,使用HuggingFace嵌入模型(如BGE-Large-EN-v1.5)
860
+ """
861
+
862
def __init__(self, config, enable_profile=False, **kwargs):
    """Set up the FAISS-backed Wiki18 retriever.

    Args:
        config: Mapping read with ``.get``. Recognised keys are ``top_k``
            (default 5), ``embedding`` and ``faiss`` (both default ``{}``).
        enable_profile: When true, retrieval records are buffered in
            ``self.data_records``. They are written under
            ``<env_base_dir or cwd>/.sage_states/retriever_data``.
        **kwargs: Forwarded unchanged to the parent operator constructor.
    """
    super().__init__(**kwargs)
    self.config = config
    self.enable_profile = enable_profile

    self.top_k = config.get("top_k", 5)
    self.embedding_config = config.get("embedding", {})
    self.faiss_config = config.get("faiss", {})

    # Embedding model first, then the FAISS index and document store.
    self._init_bge_m3_model()
    self._init_faiss_index()

    if not self.enable_profile:
        return

    # Profile output goes under the environment's base dir when the context
    # provides a non-empty one, otherwise under the current working directory.
    ctx = self.ctx
    env_dir = getattr(ctx, "env_base_dir", None) if ctx is not None else None
    base_dir = env_dir if env_dir else os.getcwd()
    self.data_base_path = os.path.join(base_dir, ".sage_states", "retriever_data")
    os.makedirs(self.data_base_path, exist_ok=True)
    self.data_records = []
889
+
890
def _init_bge_m3_model(self):
    """Load the sentence-transformers embedding model named in the config.

    Despite the method name, the default model is ``BAAI/bge-large-en-v1.5``.
    Any model id or path can be supplied via ``embedding.model``. The model
    is placed on ``cuda:<embedding.gpu_device>`` (default device 0) when CUDA
    is available, otherwise on the CPU. The result is stored in
    ``self.embedding_model``.

    Raises:
        ImportError: torch or sentence-transformers is not installed (logged, re-raised).
        Exception: Any other construction error (logged, re-raised).
    """
    try:
        import torch
        from sentence_transformers import SentenceTransformer

        cfg = self.embedding_config
        model_path = cfg.get("model", "BAAI/bge-large-en-v1.5")
        gpu_device = cfg.get("gpu_device", 0)

        use_gpu = torch.cuda.is_available()
        device = f"cuda:{gpu_device}" if use_gpu else "cpu"
        self.logger.info(f"嵌入模型将使用GPU {gpu_device}" if use_gpu else "嵌入模型将使用CPU")

        self.embedding_model = SentenceTransformer(model_path, device=device)
        self.logger.info(f"嵌入模型初始化成功: {model_path} 在设备 {device}")
    except ImportError as e:
        self.logger.error(f"无法导入sentence-transformers: {e}")
        self.logger.error("请安装: pip install sentence-transformers")
        raise
    except Exception as e:
        self.logger.error(f"嵌入模型初始化失败: {e}")
        raise
922
+
923
def _init_faiss_index(self):
    """Load the FAISS index, the JSONL document store and the optional passage mapping.

    Reads from ``self.faiss_config``:
      - ``index_path`` (required): file read with ``faiss.read_index``.
      - ``documents_path`` (required): JSONL file, one JSON document per line.
        Invalid lines are skipped with a warning.
      - ``mapping_path`` (optional): JSON file loaded into
        ``self.passage_to_doc_mapping``. It is indexed by FAISS result
        position to obtain a document position (expected to be a list of
        ints; TODO confirm against the mapping's producer).

    Environment variables (``$HOME``, ``${USER}``, ...) and a leading ``~``
    are expanded in every path. If the index or the documents file is missing,
    an empty ``IndexFlatIP`` placeholder is created so the operator can still
    start; every search on it returns nothing.

    Sets ``self.faiss_index``, ``self.documents`` and ``self.passage_to_doc_mapping``
    (``None`` unless a mapping was loaded successfully).

    Raises:
        ImportError: faiss is not installed (logged, re-raised).
        ValueError: ``index_path`` or ``documents_path`` is not configured (logged, re-raised).
        Exception: Any other initialisation error (logged, re-raised).
    """
    # Define the attribute up front so it exists whichever branch runs below.
    self.passage_to_doc_mapping = None
    try:
        import faiss

        index_path = self.faiss_config.get("index_path")
        documents_path = self.faiss_config.get("documents_path")
        mapping_path = self.faiss_config.get("mapping_path")  # optional passage -> document mapping

        if not index_path:
            raise ValueError("faiss.index_path 配置项是必需的")
        if not documents_path:
            raise ValueError("faiss.documents_path 配置项是必需的")

        def _expand(path):
            # expandvars handles $VAR / ${VAR}; expanduser handles a leading "~".
            return os.path.expanduser(os.path.expandvars(path))

        index_path = _expand(index_path)
        documents_path = _expand(documents_path)
        if mapping_path:
            mapping_path = _expand(mapping_path)

        if os.path.exists(index_path) and os.path.exists(documents_path):
            self.logger.info(f"加载已有FAISS索引: {index_path}")
            self.faiss_index = faiss.read_index(index_path)

            if mapping_path and os.path.exists(mapping_path):
                try:
                    with open(mapping_path, encoding="utf-8") as f:
                        self.passage_to_doc_mapping = json.load(f)
                    self.logger.info(
                        f"加载了段落映射: {len(self.passage_to_doc_mapping)} 个段落映射到文档"
                    )
                except Exception as e:
                    # A broken mapping is not fatal; reset it so results fall
                    # back to direct indexing instead of using a half-valid value.
                    self.passage_to_doc_mapping = None
                    self.logger.warning(f"加载段落映射失败: {e},将直接使用检索索引")

            # Documents are JSONL: one JSON object per non-blank line.
            self.documents = []
            try:
                with open(documents_path, encoding="utf-8") as f:
                    for line in f:
                        line = line.strip()
                        if not line:
                            continue
                        try:
                            self.documents.append(json.loads(line))
                        except json.JSONDecodeError as e:
                            self.logger.warning(f"跳过无效的JSON行: {line[:100]}... 错误: {e}")
            except Exception as e:
                self.logger.error(f"加载文档文件失败: {e}")
                self.documents = []

            self.logger.info(f"加载了 {len(self.documents)} 个文档")
            self.logger.info(f"FAISS索引大小: {self.faiss_index.ntotal} 个向量")

        else:
            # No prebuilt index: it must be built first (see build_index_from_wiki18).
            self.logger.warning(f"未找到预构建的FAISS索引: {index_path}")
            self.logger.warning("需要先构建Wiki18 FAISS索引")

            # Size the empty placeholder to the loaded model when it reports its
            # width; 1024 (the default BGE-large width) is the fallback.
            dimension = 1024
            model = getattr(self, "embedding_model", None)
            get_dim = getattr(model, "get_sentence_embedding_dimension", None)
            if callable(get_dim):
                dimension = get_dim() or dimension
            self.faiss_index = faiss.IndexFlatIP(int(dimension))  # inner-product similarity
            self.documents = []

    except ImportError as e:
        self.logger.error(f"无法导入FAISS: {e}")
        self.logger.error("请安装FAISS: pip install faiss-cpu 或 pip install faiss-gpu")
        raise
    except Exception as e:
        self.logger.error(f"FAISS索引初始化失败: {e}")
        raise
1001
+
1002
def _encode_query(self, query: str) -> np.ndarray:
    """Embed a single query string with ``self.embedding_model``.

    The model is called with a one-element batch, and the first row of the
    result is returned.

    Args:
        query: The query text.

    Returns:
        The query's embedding vector.

    Raises:
        Exception: Whatever the model (or indexing its output) raises; it is
            logged first, then re-raised.
    """
    try:
        vectors = self.embedding_model.encode([query])
        query_vector = vectors[0]
    except Exception as e:
        self.logger.error(f"查询编码失败: {e}")
        raise
    return query_vector
1020
+
1021
def _search_faiss(self, query_vector: np.ndarray, top_k: int) -> tuple[list[float], list[int]]:
    """Run a nearest-neighbour search against ``self.faiss_index``.

    Args:
        query_vector: A single query embedding. It is reshaped to ``(1, -1)``
            and cast to ``float32`` before searching.
        top_k: Maximum number of hits. It is clamped to the index size, and
            non-positive values yield no hits.

    Returns:
        ``(scores, indices)`` as parallel Python lists, in the order FAISS
        returned them. FAISS marks result slots it could not fill with index
        ``-1``; those placeholders are dropped, so the lists may be shorter
        than ``top_k``. An empty index, or any search error (logged), gives
        ``([], [])``.
    """
    try:
        ntotal = self.faiss_index.ntotal
        if ntotal == 0:
            self.logger.warning("FAISS索引为空,无法检索")
            return [], []

        # Never ask for more neighbours than the index holds.
        k = min(int(top_k), int(ntotal))
        if k <= 0:
            return [], []

        query_matrix = np.asarray(query_vector).reshape(1, -1).astype("float32")
        scores, indices = self.faiss_index.search(query_matrix, k)  # type: ignore[call-overload]

        # Drop FAISS's -1 "no result" padding so callers only see real hits.
        hits = [
            (score, idx)
            for score, idx in zip(scores[0].tolist(), indices[0].tolist())
            if idx >= 0
        ]
        return [score for score, _ in hits], [idx for _, idx in hits]

    except Exception as e:
        self.logger.error(f"FAISS搜索失败: {e}")
        return [], []
1046
+
1047
+ def _format_retrieved_documents(
1048
+ self, scores: list[float], indices: list[int]
1049
+ ) -> list[dict[str, Any]]:
1050
+ """
1051
+ 格式化检索到的文档
1052
+
1053
+ Args:
1054
+ scores: 相似度分数列表
1055
+ indices: 文档索引列表
1056
+
1057
+ Returns:
1058
+ 格式化后的文档列表
1059
+ """
1060
+ retrieved_docs = []
1061
+
1062
+ for score, idx in zip(scores, indices, strict=False):
1063
+ # 如果有段落到文档的映射,使用映射
1064
+ if hasattr(self, "passage_to_doc_mapping") and self.passage_to_doc_mapping is not None:
1065
+ if idx >= 0 and idx < len(self.passage_to_doc_mapping):
1066
+ doc_idx = self.passage_to_doc_mapping[idx]
1067
+ if doc_idx >= 0 and doc_idx < len(self.documents):
1068
+ original_doc = self.documents[doc_idx]
1069
+
1070
+ # 创建标准化的文档格式
1071
+ standardized_doc = {
1072
+ "text": original_doc.get("contents", str(original_doc)),
1073
+ "similarity_score": float(score),
1074
+ "document_index": int(doc_idx),
1075
+ "passage_index": int(idx), # 保存段落索引
1076
+ }
1077
+
1078
+ # 保留其他有用的元数据
1079
+ if "title" in original_doc:
1080
+ standardized_doc["title"] = original_doc["title"]
1081
+ if "id" in original_doc:
1082
+ standardized_doc["id"] = original_doc["id"]
1083
+ if "doc_size" in original_doc:
1084
+ standardized_doc["doc_size"] = original_doc["doc_size"]
1085
+
1086
+ retrieved_docs.append(standardized_doc)
1087
+ else:
1088
+ self.logger.warning(
1089
+ f"映射的文档索引超出范围: {doc_idx} >= {len(self.documents)}"
1090
+ )
1091
+ else:
1092
+ self.logger.warning(
1093
+ f"段落索引超出映射范围: {idx} >= {len(self.passage_to_doc_mapping)}"
1094
+ )
1095
+ else:
1096
+ # 没有映射时,直接使用索引
1097
+ if idx >= 0 and idx < len(self.documents):
1098
+ original_doc = self.documents[idx]
1099
+
1100
+ # 创建标准化的文档格式,与ChromaRetriever保持一致
1101
+ standardized_doc = {
1102
+ "text": original_doc.get(
1103
+ "contents", str(original_doc)
1104
+ ), # 将contents字段映射为text
1105
+ "similarity_score": float(score),
1106
+ "document_index": int(idx),
1107
+ }
1108
+
1109
+ # 保留其他有用的元数据
1110
+ if "title" in original_doc:
1111
+ standardized_doc["title"] = original_doc["title"]
1112
+ if "id" in original_doc:
1113
+ standardized_doc["id"] = original_doc["id"]
1114
+ if "doc_size" in original_doc:
1115
+ standardized_doc["doc_size"] = original_doc["doc_size"]
1116
+
1117
+ retrieved_docs.append(standardized_doc)
1118
+
1119
+ return retrieved_docs
1120
+
1121
def _save_data_record(self, query: str, retrieved_docs: list[dict[str, Any]]):
    """Buffer one retrieval record for later analysis.

    Does nothing unless ``self.enable_profile`` is true. The record holds the
    wall-clock time, the query, the hit count and the documents. Once
    ``self.data_records`` reaches 100 entries, it is flushed with
    ``_persist_data_records``.
    """
    if self.enable_profile:
        self.data_records.append(
            {
                "timestamp": time.time(),
                "query": query,
                "retrieved_count": len(retrieved_docs),
                "documents": retrieved_docs,
            }
        )
        # Flush in batches of 100 to bound memory use.
        if len(self.data_records) >= 100:
            self._persist_data_records()
1138
+
1139
def _persist_data_records(self):
    """Write buffered profiling records to a new JSON file, then clear the buffer.

    No-op when profiling is disabled or the buffer is empty. The file is
    ``wiki18_faiss_retrieval_records_<unix-seconds>.json`` inside
    ``self.data_base_path``. If that name is already taken (e.g. two flushes
    within the same second, such as the ``__del__`` flush right after a
    100-record flush), a ``_<n>`` suffix is added instead of overwriting the
    earlier file.

    The buffer is cleared only after a successful write, so a failed flush is
    retried on the next call. Errors are logged, not raised.
    """
    if not self.enable_profile or not self.data_records:
        return

    try:
        stem = f"wiki18_faiss_retrieval_records_{int(time.time())}"
        filepath = os.path.join(self.data_base_path, f"{stem}.json")
        attempt = 0
        while True:
            try:
                # Mode "x" refuses to reuse an existing file, so an earlier
                # flush from the same second is never truncated.
                with open(filepath, "x", encoding="utf-8") as f:
                    json.dump(self.data_records, f, ensure_ascii=False, indent=2)
                break
            except FileExistsError:
                attempt += 1
                filepath = os.path.join(self.data_base_path, f"{stem}_{attempt}.json")

        self.logger.info(f"保存了 {len(self.data_records)} 条检索记录到 {filepath}")
        self.data_records = []  # clear the buffer only after a successful write

    except Exception as e:
        self.logger.error(f"保存检索记录失败: {e}")
1157
+
1158
+ def execute(self, data: str | dict[str, Any] | tuple) -> dict[str, Any]:
1159
+ """
1160
+ 执行检索
1161
+ Args:
1162
+ data: 查询字符串、元组或字典
1163
+ Returns:
1164
+ dict: {"query": ..., "results": ..., "input": 原始输入, ...}
1165
+ """
1166
+ # 支持字典类型输入,优先取 question 字段
1167
+ is_dict_input = isinstance(data, dict)
1168
+ if is_dict_input:
1169
+ if "query" in data:
1170
+ input_query = data["query"]
1171
+ elif "question" in data:
1172
+ input_query = data["question"]
1173
+ else:
1174
+ self.logger.error("输入字典必须包含 'query' 或 'question' 字段")
1175
+ data["retrieval_results"] = []
1176
+ return data
1177
+ elif isinstance(data, tuple) and len(data) > 0:
1178
+ input_query = data[0]
1179
+ else:
1180
+ input_query = data
1181
+
1182
+ if not isinstance(input_query, str):
1183
+ self.logger.error(f"Invalid input query type: {type(input_query)}")
1184
+ if is_dict_input:
1185
+ data["retrieval_results"] = []
1186
+ return data
1187
+ else:
1188
+ return {"query": str(input_query), "retrieval_results": [], "input": data}
1189
+
1190
+ if not input_query or not input_query.strip():
1191
+ self.logger.error("查询不能为空")
1192
+ if is_dict_input:
1193
+ data["retrieval_results"] = []
1194
+ return data
1195
+ else:
1196
+ return {"query": "", "retrieval_results": [], "input": data}
1197
+
1198
+ input_query = input_query.strip()
1199
+ self.logger.info(
1200
+ f"[ {self.__class__.__name__}]: Starting FAISS retrieval for query: {input_query}"
1201
+ )
1202
+ self.logger.info(f"[ {self.__class__.__name__}]: Using top_k = {self.top_k}")
1203
+
1204
+ try:
1205
+ # 编码查询
1206
+ query_vector = self._encode_query(input_query)
1207
+
1208
+ # FAISS搜索
1209
+ scores, indices = self._search_faiss(query_vector, self.top_k)
1210
+
1211
+ # 格式化结果
1212
+ retrieved_docs = self._format_retrieved_documents(scores, indices)
1213
+
1214
+ self.logger.info(
1215
+ f"\033[32m[ {self.__class__.__name__}]: Retrieved {len(retrieved_docs)} documents from FAISS\033[0m"
1216
+ )
1217
+ self.logger.debug(
1218
+ f"Retrieved documents: {retrieved_docs[:3]}..."
1219
+ ) # 只显示前3个文档的预览
1220
+
1221
+ # 保存数据记录(只有enable_profile=True时才保存)
1222
+ if self.enable_profile:
1223
+ self._save_data_record(input_query, retrieved_docs)
1224
+
1225
+ if is_dict_input:
1226
+ data["retrieval_results"] = retrieved_docs
1227
+ # retrieve_time 由 MapOperator 自动添加
1228
+ return data
1229
+ else:
1230
+ return {
1231
+ "query": input_query,
1232
+ "retrieval_results": retrieved_docs,
1233
+ # retrieve_time 由 MapOperator 自动添加
1234
+ "input": data,
1235
+ }
1236
+
1237
+ except Exception as e:
1238
+ self.logger.error(f"FAISS retrieval failed: {str(e)}")
1239
+ if is_dict_input:
1240
+ data["retrieval_results"] = []
1241
+ return data
1242
+ else:
1243
+ return {"query": input_query, "retrieval_results": [], "input": data}
1244
+
1245
def build_index_from_wiki18(self, wiki18_data_path: str, save_path: str | None = None):
    """Build a FAISS inner-product index from a Wiki18 JSONL corpus.

    Each non-blank line of ``wiki18_data_path`` must be a JSON document. The
    text embedded for a document is its ``"text"`` field, falling back to
    ``"contents"`` (the field ``_format_retrieved_documents`` reads) and then
    to ``""``. Vectors are added to a fresh ``faiss.IndexFlatIP`` in file
    order, so result positions map 1:1 onto ``self.documents``. For the same
    reason, any previously loaded passage mapping is cleared.

    Args:
        wiki18_data_path: Path to the JSONL corpus.
        save_path: Optional path prefix. When given, the index is written to
            ``<save_path>_index`` and the documents to
            ``<save_path>_documents.json``. The documents file uses one JSON
            object per line, the format ``_init_faiss_index`` reads back.

    Raises:
        ValueError: The corpus contains no documents (logged, re-raised).
        Exception: Any other failure (logged, re-raised).
    """
    try:
        import faiss

        self.logger.info(f"开始从Wiki18数据构建FAISS索引: {wiki18_data_path}")

        documents = []
        with open(wiki18_data_path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:  # tolerate blank lines such as a trailing newline
                    documents.append(json.loads(line))

        self.logger.info(f"加载了 {len(documents)} 个文档")
        if not documents:
            raise ValueError(f"No documents found in {wiki18_data_path}")

        doc_texts = [doc.get("text") or doc.get("contents") or "" for doc in documents]

        self.logger.info("开始编码文档...")
        embeddings = self.embedding_model.encode(doc_texts)
        # SentenceTransformer.encode returns an ndarray. A BGE-M3 style model
        # returns a dict with a "dense_vecs" entry instead; accept both.
        if isinstance(embeddings, dict):
            embeddings = embeddings["dense_vecs"]
        # FAISS needs a C-contiguous float32 matrix.
        doc_vectors = np.ascontiguousarray(embeddings, dtype="float32")

        dimension = doc_vectors.shape[1]
        self.faiss_index = faiss.IndexFlatIP(dimension)  # inner-product similarity
        self.faiss_index.add(doc_vectors)  # type: ignore[call-overload]
        self.documents = documents
        # Positions now index self.documents directly; an old mapping would be stale.
        self.passage_to_doc_mapping = None

        self.logger.info(f"FAISS索引构建完成,包含 {self.faiss_index.ntotal} 个向量")

        if save_path:
            index_save_path = save_path + "_index"
            docs_save_path = save_path + "_documents.json"

            faiss.write_index(self.faiss_index, index_save_path)

            # JSONL (one document per line) so _init_faiss_index can reload it.
            with open(docs_save_path, "w", encoding="utf-8") as f:
                f.writelines(
                    json.dumps(doc, ensure_ascii=False) + "\n" for doc in self.documents
                )

            self.logger.info(f"索引已保存到: {index_save_path}")
            self.logger.info(f"文档已保存到: {docs_save_path}")

    except Exception as e:
        self.logger.error(f"构建FAISS索引失败: {e}")
        raise
1301
+
1302
def __del__(self):
    """Flush any still-buffered profiling records at finalization time.

    Acts only when ``enable_profile`` exists and is truthy. The flush is
    best-effort: ``__init__`` sets ``enable_profile`` before it loads the model
    and index, so a constructor that failed part-way can leave
    ``data_records`` undefined. Any error is therefore swallowed, since
    nothing can catch an exception raised from a finalizer.
    """
    if getattr(self, "enable_profile", False):
        try:
            self._persist_data_records()
        except Exception:
            pass  # best-effort: errors cannot usefully propagate out of __del__