isage-middleware 0.2.4.3 (cp311-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isage_middleware-0.2.4.3.dist-info/METADATA +266 -0
- isage_middleware-0.2.4.3.dist-info/RECORD +94 -0
- isage_middleware-0.2.4.3.dist-info/WHEEL +5 -0
- isage_middleware-0.2.4.3.dist-info/top_level.txt +1 -0
- sage/middleware/__init__.py +59 -0
- sage/middleware/_version.py +6 -0
- sage/middleware/components/__init__.py +30 -0
- sage/middleware/components/extensions_compat.py +141 -0
- sage/middleware/components/sage_db/__init__.py +116 -0
- sage/middleware/components/sage_db/backend.py +136 -0
- sage/middleware/components/sage_db/service.py +15 -0
- sage/middleware/components/sage_flow/__init__.py +76 -0
- sage/middleware/components/sage_flow/python/__init__.py +14 -0
- sage/middleware/components/sage_flow/python/micro_service/__init__.py +4 -0
- sage/middleware/components/sage_flow/python/micro_service/sage_flow_service.py +88 -0
- sage/middleware/components/sage_flow/python/sage_flow.py +30 -0
- sage/middleware/components/sage_flow/service.py +14 -0
- sage/middleware/components/sage_mem/__init__.py +83 -0
- sage/middleware/components/sage_sias/__init__.py +59 -0
- sage/middleware/components/sage_sias/continual_learner.py +184 -0
- sage/middleware/components/sage_sias/coreset_selector.py +302 -0
- sage/middleware/components/sage_sias/types.py +94 -0
- sage/middleware/components/sage_tsdb/__init__.py +81 -0
- sage/middleware/components/sage_tsdb/python/__init__.py +21 -0
- sage/middleware/components/sage_tsdb/python/_sage_tsdb.pyi +17 -0
- sage/middleware/components/sage_tsdb/python/algorithms/__init__.py +17 -0
- sage/middleware/components/sage_tsdb/python/algorithms/base.py +51 -0
- sage/middleware/components/sage_tsdb/python/algorithms/out_of_order_join.py +248 -0
- sage/middleware/components/sage_tsdb/python/algorithms/window_aggregator.py +296 -0
- sage/middleware/components/sage_tsdb/python/micro_service/__init__.py +7 -0
- sage/middleware/components/sage_tsdb/python/micro_service/sage_tsdb_service.py +365 -0
- sage/middleware/components/sage_tsdb/python/sage_tsdb.py +523 -0
- sage/middleware/components/sage_tsdb/service.py +17 -0
- sage/middleware/components/vector_stores/__init__.py +25 -0
- sage/middleware/components/vector_stores/chroma.py +483 -0
- sage/middleware/components/vector_stores/chroma_adapter.py +185 -0
- sage/middleware/components/vector_stores/milvus.py +677 -0
- sage/middleware/operators/__init__.py +56 -0
- sage/middleware/operators/agent/__init__.py +24 -0
- sage/middleware/operators/agent/planning/__init__.py +5 -0
- sage/middleware/operators/agent/planning/llm_adapter.py +41 -0
- sage/middleware/operators/agent/planning/planner_adapter.py +98 -0
- sage/middleware/operators/agent/planning/router.py +107 -0
- sage/middleware/operators/agent/runtime.py +296 -0
- sage/middleware/operators/agentic/__init__.py +41 -0
- sage/middleware/operators/agentic/config.py +254 -0
- sage/middleware/operators/agentic/planning_operator.py +125 -0
- sage/middleware/operators/agentic/refined_searcher.py +132 -0
- sage/middleware/operators/agentic/runtime.py +241 -0
- sage/middleware/operators/agentic/timing_operator.py +125 -0
- sage/middleware/operators/agentic/tool_selection_operator.py +127 -0
- sage/middleware/operators/context/__init__.py +17 -0
- sage/middleware/operators/context/critic_evaluation.py +16 -0
- sage/middleware/operators/context/model_context.py +565 -0
- sage/middleware/operators/context/quality_label.py +12 -0
- sage/middleware/operators/context/search_query_results.py +61 -0
- sage/middleware/operators/context/search_result.py +42 -0
- sage/middleware/operators/context/search_session.py +79 -0
- sage/middleware/operators/filters/__init__.py +26 -0
- sage/middleware/operators/filters/context_sink.py +387 -0
- sage/middleware/operators/filters/context_source.py +376 -0
- sage/middleware/operators/filters/evaluate_filter.py +83 -0
- sage/middleware/operators/filters/tool_filter.py +74 -0
- sage/middleware/operators/llm/__init__.py +18 -0
- sage/middleware/operators/llm/sagellm_generator.py +432 -0
- sage/middleware/operators/rag/__init__.py +147 -0
- sage/middleware/operators/rag/arxiv.py +331 -0
- sage/middleware/operators/rag/chunk.py +13 -0
- sage/middleware/operators/rag/document_loaders.py +23 -0
- sage/middleware/operators/rag/evaluate.py +658 -0
- sage/middleware/operators/rag/generator.py +340 -0
- sage/middleware/operators/rag/index_builder/__init__.py +48 -0
- sage/middleware/operators/rag/index_builder/builder.py +363 -0
- sage/middleware/operators/rag/index_builder/manifest.py +101 -0
- sage/middleware/operators/rag/index_builder/storage.py +131 -0
- sage/middleware/operators/rag/pipeline.py +46 -0
- sage/middleware/operators/rag/profiler.py +59 -0
- sage/middleware/operators/rag/promptor.py +400 -0
- sage/middleware/operators/rag/refiner.py +231 -0
- sage/middleware/operators/rag/reranker.py +364 -0
- sage/middleware/operators/rag/retriever.py +1308 -0
- sage/middleware/operators/rag/searcher.py +37 -0
- sage/middleware/operators/rag/types.py +28 -0
- sage/middleware/operators/rag/writer.py +80 -0
- sage/middleware/operators/tools/__init__.py +71 -0
- sage/middleware/operators/tools/arxiv_paper_searcher.py +175 -0
- sage/middleware/operators/tools/arxiv_searcher.py +102 -0
- sage/middleware/operators/tools/duckduckgo_searcher.py +105 -0
- sage/middleware/operators/tools/image_captioner.py +104 -0
- sage/middleware/operators/tools/nature_news_fetcher.py +224 -0
- sage/middleware/operators/tools/searcher_tool.py +514 -0
- sage/middleware/operators/tools/text_detector.py +185 -0
- sage/middleware/operators/tools/url_text_extractor.py +104 -0
- sage/middleware/py.typed +2 -0
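
The diff body below is the full content of sage/middleware/operators/rag/retriever.py (new file, 1,308 lines). It defines four retrieval operators for the RAG pipeline: ChromaRetriever, MilvusDenseRetriever, MilvusSparseRetriever, and Wiki18FAISSRetriever.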
@@ -0,0 +1,1308 @@
import json
import os
import time
from typing import Any

import numpy as np

from sage.common.components.sage_embedding.embedding_model import EmbeddingModel
from sage.common.config.output_paths import get_states_file
from sage.common.core.functions import MapFunction as MapOperator
from sage.middleware.components.vector_stores.chroma import ChromaBackend, ChromaUtils
from sage.middleware.components.vector_stores.milvus import MilvusBackend, MilvusUtils

# ChromaDB dense retriever
class ChromaRetriever(MapOperator):
    def __init__(self, config, enable_profile=False, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.enable_profile = enable_profile

        # Only the ChromaDB backend is supported
        self.backend_type = "chroma"

        # Common configuration
        self.vector_dimension = config.get("dimension", 384)
        self.top_k = config.get("top_k", 10)
        self.embedding_config = config.get("embedding", {})

        # Initialize the embedding model first
        self._init_embedding_model()

        # Then initialize the ChromaDB backend (so the embedding model is
        # already available when the knowledge base is loaded)
        self.chroma_config = config.get("chroma", {})
        self._init_chroma_backend()

        # Set up the data storage path only when profiling is enabled
        if self.enable_profile:
            # Use the unified output path system
            self.data_base_path = str(get_states_file("dummy", "retriever_data").parent)
            os.makedirs(self.data_base_path, exist_ok=True)
            self.data_records = []

    def _init_chroma_backend(self):
        """Initialize the ChromaDB backend."""
        try:
            # Check whether ChromaDB is available
            if not ChromaUtils.check_chromadb_availability():
                raise ImportError(
                    "ChromaDB dependencies not available. Install with: pip install chromadb"
                )

            # Validate the configuration
            if not ChromaUtils.validate_chroma_config(self.chroma_config):
                raise ValueError("Invalid ChromaDB configuration")

            # Create the ChromaDB backend instance
            self.chroma_backend = ChromaBackend(self.chroma_config, self.logger)

            # Automatically load the knowledge base file
            knowledge_file = self.chroma_config.get("knowledge_file")
            if knowledge_file:
                # For a relative path, try resolving against the current working
                # directory and then the project root
                if not os.path.isabs(knowledge_file):
                    # Try the current directory first
                    if os.path.exists(knowledge_file):
                        resolved_path = knowledge_file
                    else:
                        # Walk up to the project root (marked by pyproject.toml)
                        project_root = os.getcwd()
                        while project_root != "/" and not os.path.exists(
                            os.path.join(project_root, "pyproject.toml")
                        ):
                            project_root = os.path.dirname(project_root)

                        potential_path = os.path.join(project_root, knowledge_file)
                        if os.path.exists(potential_path):
                            resolved_path = potential_path
                        else:
                            resolved_path = knowledge_file
                else:
                    resolved_path = knowledge_file

                if os.path.exists(resolved_path):
                    self._load_knowledge_from_file(resolved_path)
                else:
                    self.logger.warning(f"Knowledge file not found: {resolved_path}")

        except Exception as e:
            self.logger.error(f"Failed to initialize ChromaDB: {e}")
            raise

    def _load_knowledge_from_file(self, file_path: str):
        """Load the knowledge base from a file."""
        try:
            # Load through the ChromaDB backend
            success = self.chroma_backend.load_knowledge_from_file(file_path, self.embedding_model)
            if not success:
                self.logger.error(f"Failed to load knowledge from file: {file_path}")

        except Exception as e:
            self.logger.error(f"Failed to load knowledge from file {file_path}: {e}")

    def _init_embedding_model(self):
        """Initialize the HuggingFace embedding model (via sentence-transformers)."""
        embedding_method = self.embedding_config.get("method", "default")
        model = self.embedding_config.get("model", "sentence-transformers/all-MiniLM-L6-v2")

        self.logger.info(f"Initializing embedding model with method: {embedding_method}")
        self.embedding_model = EmbeddingModel(method=embedding_method, model=model)

        # Validate the vector dimension
        if hasattr(self.embedding_model, "get_dim"):
            model_dim = self.embedding_model.get_dim()
            if model_dim != self.vector_dimension:
                self.logger.warning(
                    f"Embedding model dimension ({model_dim}) != configured dimension ({self.vector_dimension})"
                )
                # Update the configured dimension to match the model
                self.vector_dimension = model_dim

    def add_documents(self, documents: list[str], doc_ids: list[str] | None = None) -> list[str]:
        """
        Add documents to the index.

        Args:
            documents: List of document contents.
            doc_ids: List of document IDs; generated automatically if None.

        Returns:
            List of IDs of the added documents.
        """
        if not documents:
            return []

        # Generate document IDs
        if doc_ids is None:
            doc_ids = [f"doc_{int(time.time() * 1000)}_{i}" for i in range(len(documents))]
        elif len(doc_ids) != len(documents):
            raise ValueError("doc_ids length must match documents length")

        # Generate embeddings
        embeddings = []
        for doc in documents:
            embedding = self.embedding_model.embed(doc)
            embeddings.append(np.array(embedding, dtype=np.float32))

        # Add the documents via the ChromaDB backend
        return self.chroma_backend.add_documents(documents, embeddings, doc_ids)

    def _save_data_record(self, query, retrieved_docs):
        """Save a retrieval data record."""
        if not self.enable_profile:
            return

        record = {
            "timestamp": time.time(),
            "query": query,
            "retrieval_results": retrieved_docs,
            "backend_type": self.backend_type,
            "backend_config": getattr(self, f"{self.backend_type}_config", {}),
            "embedding_config": self.embedding_config,
        }
        self.data_records.append(record)
        self._persist_data_records()

    def _persist_data_records(self):
        """Persist the data records to a file."""
        if not self.enable_profile or not self.data_records:
            return

        timestamp = int(time.time())
        filename = f"retriever_data_{timestamp}.json"
        path = os.path.join(self.data_base_path, filename)

        try:
            with open(path, "w", encoding="utf-8") as f:
                json.dump(self.data_records, f, ensure_ascii=False, indent=2)
            self.data_records = []
        except Exception as e:
            self.logger.error(f"Failed to persist data records: {e}")

    def execute(self, data: str | dict[str, Any] | tuple) -> dict[str, Any]:
        """
        Run retrieval.

        Args:
            data: A query string, tuple, or dict.

        Returns:
            dict: {"query": ..., "retrieval_results": ..., "input": original input, ...}
        """
        is_dict_input = isinstance(data, dict)
        if is_dict_input:
            input_query = data.get("query", "")
        elif isinstance(data, tuple) and len(data) > 0:
            input_query = data[0]
        else:
            input_query = data

        if not isinstance(input_query, str):
            self.logger.error(f"Invalid input query type: {type(input_query)}")
            if is_dict_input:
                data["retrieval_results"] = []
                return data
            else:
                return {"query": str(input_query), "retrieval_results": [], "input": data}

        self.logger.info(
            f"[{self.__class__.__name__}]: Starting {self.backend_type.upper()} retrieval for query: {input_query}"
        )
        self.logger.info(f"[{self.__class__.__name__}]: Using top_k = {self.top_k}")

        try:
            # Generate the query vector
            query_embedding = self.embedding_model.embed(input_query)
            query_vector = np.array(query_embedding, dtype=np.float32)

            # Run the retrieval against ChromaDB
            retrieved_docs = self.chroma_backend.search(query_vector, input_query, self.top_k)

            self.logger.info(
                f"\033[32m[{self.__class__.__name__}]: Retrieved {len(retrieved_docs)} documents from ChromaDB\033[0m"
            )
            self.logger.debug(
                f"Retrieved documents: {retrieved_docs[:3]}..."
            )  # Preview only the first three documents

            # Normalize string results into dicts so downstream components can use them
            standardized_docs = []
            for doc in retrieved_docs:
                if isinstance(doc, str):
                    standardized_docs.append({"text": doc})
                elif isinstance(doc, dict):
                    # Already a dict; make sure it has a "text" field
                    if "text" not in doc and "content" in doc:
                        doc["text"] = doc["content"]
                    elif "text" not in doc:
                        # Use the whole dict as the text
                        doc["text"] = str(doc)
                    standardized_docs.append(doc)
                else:
                    # Convert any other type to a string
                    standardized_docs.append({"text": str(doc)})

            # Save the data record (only when enable_profile=True)
            if self.enable_profile:
                self._save_data_record(input_query, standardized_docs)

            if is_dict_input:
                # Keep any original retrieval results (used for computing the
                # compression ratio); only set the field if it is absent
                if "retrieval_results" not in data:
                    data["retrieval_results"] = standardized_docs
                return data
            else:
                return {
                    "query": input_query,
                    "retrieval_results": standardized_docs,
                    "input": data,
                }

        except Exception as e:
            self.logger.error(f"ChromaDB retrieval failed: {str(e)}")
            if is_dict_input:
                data["retrieval_results"] = []
                return data
            else:
                return {"query": input_query, "retrieval_results": [], "input": data}

    def save_index(self, save_path: str) -> bool:
        """
        Save the index to disk.

        Args:
            save_path: Path to save to.

        Returns:
            Whether the save succeeded.
        """
        return self.chroma_backend.save_config(save_path)

    def load_index(self, load_path: str) -> bool:
        """
        Load the index from disk.

        Args:
            load_path: Path to load from.

        Returns:
            Whether the load succeeded.
        """
        return self.chroma_backend.load_config(load_path)

    def get_collection_info(self) -> dict[str, Any]:
        """Get collection information."""
        return self.chroma_backend.get_collection_info()

    def __del__(self):
        """Make sure any unsaved records are persisted when the object is destroyed."""
        if hasattr(self, "enable_profile") and self.enable_profile:
            try:
                self._persist_data_records()
            except Exception:
                pass

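# ---------------------------------------------------------------------------
# Editor's illustration (not part of the packaged file): a minimal sketch of
# driving ChromaRetriever directly, assuming the operator is constructed the
# way the pipeline builds MapOperator subclasses (so self.logger is available)
# and that the knowledge-file path exists. All values below are hypothetical.
#
#   config = {
#       "dimension": 384,
#       "top_k": 5,
#       "embedding": {"method": "default",
#                     "model": "sentence-transformers/all-MiniLM-L6-v2"},
#       "chroma": {"knowledge_file": "data/knowledge.txt"},
#   }
#   retriever = ChromaRetriever(config)
#   out = retriever.execute({"query": "What does the middleware package do?"})
#   for doc in out["retrieval_results"]:
#       print(doc["text"])
# ---------------------------------------------------------------------------
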
# Milvus dense-vector retrieval
class MilvusDenseRetriever(MapOperator):
    """
    Dense-vector retrieval using the Milvus backend.
    """

    def __init__(self, config, enable_profile=False, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.enable_profile = enable_profile

        # Only the Milvus backend is supported
        self.backend_type = "milvus"

        # Common configuration
        self.vector_dimension = self.config.get("dimension", 384)
        self.top_k = self.config.get("top_k", 5)
        self.embedding_config = self.config.get("embedding", {})

        # Initialize the embedding model first: the backend init below may
        # auto-load a knowledge file, which needs self.embedding_model
        self._init_embedding_model()

        # Initialize the Milvus backend
        self.milvus_config = config.get("milvus_dense", {})
        self._init_milvus_backend()

        # Set up the data storage path only when profiling is enabled
        if self.enable_profile:
            if self.ctx is not None and hasattr(self.ctx, "env_base_dir") and self.ctx.env_base_dir:
                self.data_base_path = os.path.join(
                    self.ctx.env_base_dir, ".sage_states", "retriever_data"
                )
            else:
                # Fall back to a default path
                self.data_base_path = os.path.join(os.getcwd(), ".sage_states", "retriever_data")

            os.makedirs(self.data_base_path, exist_ok=True)
            self.data_records = []

    def _init_milvus_backend(self):
        """Initialize the Milvus backend."""
        try:
            # Check whether Milvus is available
            if not MilvusUtils.check_milvus_available():
                raise ImportError(
                    "Milvus dependencies not available. Install with: pip install pymilvus"
                )

            # Validate the configuration
            if not MilvusUtils.validate_milvus_config(self.milvus_config):
                raise ValueError("Invalid Milvus configuration")

            # Initialize the backend
            self.milvus_backend = MilvusBackend(config=self.milvus_config, logger=self.logger)

            # Automatically load the knowledge base file
            knowledge_file = self.milvus_config.get("knowledge_file")
            if knowledge_file and os.path.exists(knowledge_file):
                self._load_knowledge_from_file_dense(knowledge_file)

        except Exception as e:
            self.logger.error(f"Failed to initialize Milvus: {e}")
            raise

    def _load_knowledge_from_file_dense(self, file_path: str):
        """Load the knowledge base from a file."""
        try:
            # Load through the Milvus backend
            success = self.milvus_backend.load_knowledge_from_file_dense(
                file_path, self.embedding_model
            )
            if not success:
                self.logger.error(f"Failed to load knowledge from file: {file_path}")
        except Exception as e:
            self.logger.error(f"Failed to load knowledge from file: {e}")

    def _init_embedding_model(self):
        """Initialize the embedding model."""
        embedding_method = self.embedding_config.get("method", "default")
        model = self.embedding_config.get("model", "sentence-transformers/all-MiniLM-L6-v2")

        self.logger.info(f"Initializing embedding model with method: {embedding_method}")
        self.embedding_model = EmbeddingModel(method=embedding_method, model=model)

        # Validate the vector dimension
        if hasattr(self.embedding_model, "get_dim"):
            model_dim = self.embedding_model.get_dim()
            if model_dim != self.vector_dimension:
                self.logger.warning(
                    f"Embedding model dimension ({model_dim}) != configured dimension ({self.vector_dimension})"
                )
                # Update the configured dimension to match the model
                self.vector_dimension = model_dim

    def add_documents(self, documents: list[str], doc_ids: list[str] | None = None) -> list[str]:
        """
        Add documents to Milvus.

        Args:
            documents: List of document contents.
            doc_ids: List of document IDs; generated automatically if None.

        Returns:
            List of IDs of the added documents.
        """
        if not documents:
            self.logger.warning("No documents to add")
            return []

        if doc_ids is None:
            doc_ids = [f"doc_{int(time.time() * 1000)}_{i}" for i in range(len(documents))]
        elif len(doc_ids) != len(documents):
            raise ValueError("doc_ids length must match documents length")

        # Generate embeddings
        embeddings = []
        for doc in documents:
            embedding = self.embedding_model.embed(doc)
            embeddings.append(np.array(embedding, dtype=np.float32))

        # Add the documents via the Milvus backend
        return self.milvus_backend.add_dense_documents(documents, embeddings, doc_ids)

    def _save_data_record(self, query, retrieved_docs):
        """Save a retrieval data record."""
        if not self.enable_profile:
            return

        record = {
            "timestamp": time.time(),
            "query": query,
            "retrieval_results": retrieved_docs,
            "backend_type": self.backend_type,
            "backend_config": getattr(self, f"{self.backend_type}_config", {}),
            "embedding_config": self.embedding_config,
        }

        self.data_records.append(record)
        self._persist_data_records()

    def _persist_data_records(self):
        """Persist the data records to a file."""
        if not self.enable_profile or not self.data_records:
            return

        timestamp = int(time.time())
        filename = f"milvus_dense_retriever_data_{timestamp}.json"
        path = os.path.join(self.data_base_path, filename)

        try:
            with open(path, "w", encoding="utf-8") as f:
                json.dump(self.data_records, f, ensure_ascii=False, indent=2)
            self.data_records = []
        except Exception as e:
            self.logger.error(f"Failed to persist data records: {e}")

    def execute(self, data: str | dict[str, Any] | tuple) -> dict[str, Any]:
        """
        Run retrieval.

        Args:
            data: A query string, tuple, or dict.

        Returns:
            dict: {"query": ..., "retrieval_results": ..., "input": original input, ...}
        """
        # Dict inputs are supported; the "question" field is used
        is_dict_input = isinstance(data, dict)
        if is_dict_input:
            input_query = data.get("question", "")
        elif isinstance(data, tuple) and len(data) > 0:
            input_query = data[0]
        else:
            input_query = data

        if not isinstance(input_query, str):
            self.logger.error(f"Invalid input query type: {type(input_query)}")
            if is_dict_input:
                data["retrieval_results"] = []
                return data
            else:
                return {
                    "query": str(input_query),
                    "retrieval_results": [],
                    "input": data,
                }

        self.logger.info(
            f"[{self.__class__.__name__}]: Starting {self.backend_type.upper()} retrieval for query: {input_query}"
        )
        self.logger.info(f"[{self.__class__.__name__}]: Using top_k = {self.top_k}")

        try:
            # Generate the query vector (embed, matching add_documents above)
            query_embedding = self.embedding_model.embed(input_query)
            query_vector = np.array(query_embedding, dtype=np.float32)

            # Run the dense search against Milvus
            retrieved_docs = self.milvus_backend.dense_search(
                query_vector=query_vector,
                top_k=self.top_k,
            )

            self.logger.info(
                f"\033[32m[{self.__class__.__name__}]: Retrieved {len(retrieved_docs)} documents from Milvus\033[0m"
            )
            self.logger.debug(
                f"Retrieved documents: {retrieved_docs[:3]}..."
            )  # Preview only the first three documents

            # Save the data record (only when enable_profile=True)
            if self.enable_profile:
                self._save_data_record(input_query, retrieved_docs)

            if is_dict_input:
                data["retrieval_results"] = retrieved_docs
                return data
            else:
                return {
                    "query": input_query,
                    "retrieval_results": retrieved_docs,
                    "input": data,
                }

        except Exception as e:
            self.logger.error(f"Milvus retrieval failed: {str(e)}")
            if is_dict_input:
                data["retrieval_results"] = []
                return data
            else:
                return {
                    "query": input_query,
                    "retrieval_results": [],
                    "input": data,
                }

    def save_config(self, save_path: str) -> bool:
        """
        Save the configuration to disk.

        Args:
            save_path: Path to save to.

        Returns:
            Whether the save succeeded.
        """
        return self.milvus_backend.save_config(save_path)

    def load_config(self, load_path: str) -> bool:
        """
        Load the configuration from disk.

        Args:
            load_path: Path to load from.

        Returns:
            Whether the load succeeded.
        """
        return self.milvus_backend.load_config(load_path)

    def get_collection_info(self) -> dict[str, Any]:
        """Get collection information."""
        return self.milvus_backend.get_collection_info()

    def delete_collection(self, collection_name: str) -> bool:
        """Delete a collection."""
        return self.milvus_backend.delete_collection(collection_name)

    def __del__(self):
        """Make sure any unsaved records are persisted when the object is destroyed."""
        if hasattr(self, "enable_profile") and self.enable_profile:
            try:
                self._persist_data_records()
            except Exception:
                pass

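# ---------------------------------------------------------------------------
# Editor's illustration (not part of the packaged file): MilvusDenseRetriever
# reads its backend settings from the "milvus_dense" key. The inner keys shown
# here are hypothetical; whatever is supplied must pass
# MilvusUtils.validate_milvus_config, and a Milvus server must be reachable.
#
#   config = {
#       "dimension": 384,
#       "top_k": 5,
#       "embedding": {"model": "sentence-transformers/all-MiniLM-L6-v2"},
#       "milvus_dense": {"host": "localhost", "port": 19530,
#                        "collection_name": "rag_dense"},
#   }
#   retriever = MilvusDenseRetriever(config)
#   out = retriever.execute({"question": "What is dense vector search?"})
# ---------------------------------------------------------------------------
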
# Milvus sparse-vector retrieval
class MilvusSparseRetriever(MapOperator):
    """
    Sparse-vector retrieval using the Milvus backend.
    """

    def __init__(self, config, enable_profile=False, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.enable_profile = enable_profile

        # Only the Milvus backend is supported
        self.backend_type = "milvus"

        # Common configuration
        self.top_k = self.config.get("top_k", 10)

        # Initialize the Milvus backend and the embedding model
        self.milvus_config = config.get("milvus_sparse", {})
        self._init_milvus_backend()
        self._init_embedding_model()

        # Set up the data storage path only when profiling is enabled
        if self.enable_profile:
            if self.ctx is not None and hasattr(self.ctx, "env_base_dir") and self.ctx.env_base_dir:
                self.data_base_path = os.path.join(
                    self.ctx.env_base_dir, ".sage_states", "retriever_data"
                )
            else:
                # Fall back to a default path
                self.data_base_path = os.path.join(os.getcwd(), ".sage_states", "retriever_data")

            os.makedirs(self.data_base_path, exist_ok=True)
            self.data_records = []

    def _init_milvus_backend(self):
        """Initialize the Milvus backend."""
        try:
            # Check whether Milvus is available
            if not MilvusUtils.check_milvus_available():
                raise ImportError(
                    "Milvus dependencies not available. Install with: pip install pymilvus"
                )

            # Validate the configuration
            if not MilvusUtils.validate_milvus_config(self.milvus_config):
                raise ValueError("Invalid Milvus configuration")

            # Initialize the backend
            self.milvus_backend = MilvusBackend(config=self.milvus_config, logger=self.logger)

            # Automatically load the knowledge base file
            knowledge_file = self.milvus_config.get("knowledge_file")
            if knowledge_file and os.path.exists(knowledge_file):
                self._load_knowledge_from_file(knowledge_file)

        except Exception as e:
            self.logger.error(f"Failed to initialize Milvus: {e}")
            raise

    def _init_embedding_model(self):
        """Initialize the embedding model."""
        try:
            # Try the newer import path first (PyMilvus 2.6.0+)
            try:
                from pymilvus.model.hybrid import (
                    BGEM3EmbeddingFunction,  # type: ignore[import-not-found]
                )
            except ImportError:
                # Fall back to importing directly from pymilvus.model
                try:
                    from pymilvus.model import (
                        BGEM3EmbeddingFunction,  # type: ignore[import-not-found]
                    )
                except ImportError:
                    # Finally, suggest installing the optional extra
                    self.logger.error(
                        "Please install: pip install 'pymilvus[model]'"
                    )
                    raise ImportError("Embedding model dependencies not available")

            self.embedding_model = BGEM3EmbeddingFunction(use_fp16=False, device="cpu")

        except ImportError as e:
            self.logger.error(f"Failed to import EmbeddingModel: {e}")
            raise ImportError("Embedding model dependencies not available")

    def _load_knowledge_from_file(self, file_path: str):
        """Load the knowledge base from a file."""
        try:
            # Load through the Milvus backend
            success = self.milvus_backend.load_knowledge_from_file_sparse(file_path)
            self.logger.info(f"Loaded {success} documents from {file_path}")
            if not success:
                self.logger.error(f"Failed to load knowledge from file: {file_path}")
        except Exception as e:
            self.logger.error(f"Failed to load knowledge from file: {e}")

    def add_documents(self, documents: list[str], doc_ids: list[str] | None = None) -> list[str]:
        """
        Add documents to Milvus.

        Args:
            documents: List of document contents.
            doc_ids: List of document IDs; generated automatically if None.

        Returns:
            List of IDs of the added documents.
        """
        if not documents:
            self.logger.warning("No documents to add")
            return []

        # Generate embeddings; keep only the sparse part of the BGE-M3 output
        embedding = self.embedding_model.encode_documents(documents)
        embeddings = embedding["sparse"]

        if doc_ids is None:
            doc_ids = [f"doc_{int(time.time() * 1000)}_{i}" for i in range(len(documents))]
        elif len(doc_ids) != len(documents):
            raise ValueError("doc_ids length must match documents length")

        # Add the documents via the Milvus backend
        return self.milvus_backend.add_sparse_documents(documents, embeddings, doc_ids)

    def _save_data_record(self, query, retrieved_docs):
        """Save a retrieval data record."""
        if not self.enable_profile:
            return

        record = {
            "timestamp": time.time(),
            "query": query,
            "retrieval_results": retrieved_docs,
            "backend_type": self.backend_type,
            "backend_config": getattr(self, f"{self.backend_type}_config", {}),
        }

        self.data_records.append(record)
        self._persist_data_records()

    def _persist_data_records(self):
        """Persist the data records to a file."""
        if not self.enable_profile or not self.data_records:
            return

        timestamp = int(time.time())
        filename = f"milvus_sparse_retriever_data_{timestamp}.json"
        path = os.path.join(self.data_base_path, filename)

        try:
            with open(path, "w", encoding="utf-8") as f:
                json.dump(self.data_records, f, ensure_ascii=False, indent=2)
            self.data_records = []
        except Exception as e:
            self.logger.error(f"Failed to persist data records: {e}")

    def execute(self, data: str | dict[str, Any] | tuple) -> dict[str, Any]:
        """
        Run retrieval.

        Args:
            data: A query string, tuple, or dict.

        Returns:
            dict: {"query": ..., "retrieval_results": ..., "input": original input, ...}
        """
        # Dict inputs are supported; the "question" field is used
        is_dict_input = isinstance(data, dict)
        if is_dict_input:
            input_query = data.get("question", "")
        elif isinstance(data, tuple) and len(data) > 0:
            input_query = data[0]
        else:
            input_query = data

        if not isinstance(input_query, str):
            self.logger.error(f"Invalid input query type: {type(input_query)}")
            if is_dict_input:
                data["retrieval_results"] = []
                return data
            else:
                return {
                    "query": str(input_query),
                    "retrieval_results": [],
                    "input": data,
                }

        self.logger.info(
            f"[{self.__class__.__name__}]: Starting {self.backend_type.upper()} retrieval for query: {input_query}"
        )
        self.logger.info(f"[{self.__class__.__name__}]: Using top_k = {self.top_k}")

        try:
            # Run the sparse search against Milvus; pass the query text directly
            # and let sparse_search generate the sparse vector
            retrieved_docs = self.milvus_backend.sparse_search(
                query_text=input_query,
                top_k=self.top_k,
            )

            self.logger.info(
                f"\033[32m[{self.__class__.__name__}]: Retrieved {len(retrieved_docs)} documents from Milvus\033[0m"
            )
            self.logger.debug(
                f"Retrieved documents: {retrieved_docs[:3]}..."
            )  # Preview only the first three documents

            # Save the data record (only when enable_profile=True)
            if self.enable_profile:
                self._save_data_record(input_query, retrieved_docs)

            if is_dict_input:
                data["retrieval_results"] = retrieved_docs
                return data
            else:
                return {
                    "query": input_query,
                    "retrieval_results": retrieved_docs,
                    "input": data,
                }

        except Exception as e:
            self.logger.error(f"Milvus retrieval failed: {str(e)}")
            if is_dict_input:
                data["retrieval_results"] = []
                return data
            else:
                return {
                    "query": input_query,
                    "retrieval_results": [],
                    "input": data,
                }

    def save_config(self, save_path: str) -> bool:
        """
        Save the configuration to disk.

        Args:
            save_path: Path to save to.

        Returns:
            Whether the save succeeded.
        """
        return self.milvus_backend.save_config(save_path)

    def load_config(self, load_path: str) -> bool:
        """
        Load the configuration from disk.

        Args:
            load_path: Path to load from.

        Returns:
            Whether the load succeeded.
        """
        return self.milvus_backend.load_config(load_path)

    def get_collection_info(self) -> dict[str, Any]:
        """Get collection information."""
        return self.milvus_backend.get_collection_info()

    def __del__(self):
        """Make sure any unsaved records are persisted when the object is destroyed."""
        if hasattr(self, "enable_profile") and self.enable_profile:
            try:
                self._persist_data_records()
            except Exception:
                pass

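# ---------------------------------------------------------------------------
# Editor's note (not part of the packaged file): the sparse retriever above
# differs from the dense one in two ways. Documents are embedded with
# BGEM3EmbeddingFunction, and add_documents keeps only the "sparse" entry of
# the dict returned by encode_documents; queries are not embedded client-side
# at all, since sparse_search accepts raw query text. A minimal sketch,
# assuming a reachable Milvus server and hypothetical config values:
#
#   retriever = MilvusSparseRetriever(
#       {"top_k": 10, "milvus_sparse": {"host": "localhost", "port": 19530}}
#   )
#   out = retriever.execute("sparse lexical matching")
#   print(out["retrieval_results"])
# ---------------------------------------------------------------------------
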
# Wiki18 FAISS retriever
class Wiki18FAISSRetriever(MapOperator):
    """
    FAISS-based retriever for the Wiki18 dataset, using a HuggingFace embedding
    model (e.g., BGE-Large-EN-v1.5).
    """

    def __init__(self, config, enable_profile=False, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.enable_profile = enable_profile

        # Configuration parameters
        self.top_k = config.get("top_k", 5)
        self.embedding_config = config.get("embedding", {})
        self.faiss_config = config.get("faiss", {})

        # Initialize the embedding model
        self._init_bge_m3_model()

        # Initialize the FAISS index
        self._init_faiss_index()

        # Profile data storage
        if self.enable_profile:
            if self.ctx is not None and hasattr(self.ctx, "env_base_dir") and self.ctx.env_base_dir:
                self.data_base_path = os.path.join(
                    self.ctx.env_base_dir, ".sage_states", "retriever_data"
                )
            else:
                self.data_base_path = os.path.join(os.getcwd(), ".sage_states", "retriever_data")

            os.makedirs(self.data_base_path, exist_ok=True)
            self.data_records = []

    def _init_bge_m3_model(self):
        """Initialize the embedding model via sentence-transformers (despite the
        method name, the default model is BGE-Large-EN-v1.5)."""
        try:
            import torch
            from sentence_transformers import SentenceTransformer

            # Read the model path from the config; default to BGE-Large-EN-v1.5
            model_path = self.embedding_config.get("model", "BAAI/bge-large-en-v1.5")

            # Read the GPU device from the config; default to GPU 0
            gpu_device = self.embedding_config.get("gpu_device", 0)

            # Pin the model to an explicit device
            if torch.cuda.is_available():
                device = f"cuda:{gpu_device}"
                self.logger.info(f"Embedding model will use GPU {gpu_device}")
            else:
                device = "cpu"
                self.logger.info("Embedding model will use CPU")

            # Initialize the embedding model
            self.embedding_model = SentenceTransformer(model_path, device=device)

            self.logger.info(f"Embedding model initialized: {model_path} on device {device}")

        except ImportError as e:
            self.logger.error(f"Failed to import sentence-transformers: {e}")
            self.logger.error("Please install: pip install sentence-transformers")
            raise
        except Exception as e:
            self.logger.error(f"Embedding model initialization failed: {e}")
            raise

    def _init_faiss_index(self):
        """Initialize the FAISS index."""
        try:
            import faiss

            # FAISS configuration: read the paths from the config file
            index_path = self.faiss_config.get("index_path")
            documents_path = self.faiss_config.get("documents_path")
            mapping_path = self.faiss_config.get("mapping_path")  # optional passage-to-document mapping

            # Check the required config entries
            if not index_path:
                raise ValueError("faiss.index_path is a required config entry")
            if not documents_path:
                raise ValueError("faiss.documents_path is a required config entry")

            # Expand environment variables (supports ${HOME}, ${USER}, $HOME, etc.)
            index_path = os.path.expandvars(index_path)
            documents_path = os.path.expandvars(documents_path)
            if mapping_path:
                mapping_path = os.path.expandvars(mapping_path)

            # Try to load an existing index
            if os.path.exists(index_path) and os.path.exists(documents_path):
                self.logger.info(f"Loading existing FAISS index: {index_path}")
                self.faiss_index = faiss.read_index(index_path)

                # Load the passage-to-document mapping, if any
                self.passage_to_doc_mapping = None
                if mapping_path and os.path.exists(mapping_path):
                    try:
                        with open(mapping_path, encoding="utf-8") as f:
                            self.passage_to_doc_mapping = json.load(f)
                        self.logger.info(
                            f"Loaded passage mapping: {len(self.passage_to_doc_mapping)} passages mapped to documents"
                        )
                    except Exception as e:
                        self.logger.warning(
                            f"Failed to load the passage mapping: {e}; using the retrieval index directly"
                        )

                # Load the JSONL document data
                self.documents = []
                try:
                    with open(documents_path, encoding="utf-8") as f:
                        for line in f:
                            line = line.strip()
                            if line:
                                try:
                                    doc = json.loads(line)
                                    self.documents.append(doc)
                                except json.JSONDecodeError as e:
                                    self.logger.warning(
                                        f"Skipping invalid JSON line: {line[:100]}... error: {e}"
                                    )

                except Exception as e:
                    self.logger.error(f"Failed to load the documents file: {e}")
                    self.documents = []

                self.logger.info(f"Loaded {len(self.documents)} documents")
                self.logger.info(f"FAISS index size: {self.faiss_index.ntotal} vectors")

            else:
                # Without a prebuilt index, one must be built from the Wiki18 data
                self.logger.warning(f"Prebuilt FAISS index not found: {index_path}")
                self.logger.warning("The Wiki18 FAISS index must be built first")

                # Create an empty index and document list as placeholders
                dimension = 1024  # embedding dimension (BGE-large family)
                self.faiss_index = faiss.IndexFlatIP(dimension)  # inner-product similarity
                self.documents = []

        except ImportError as e:
            self.logger.error(f"Failed to import FAISS: {e}")
            self.logger.error("Please install FAISS: pip install faiss-cpu or pip install faiss-gpu")
            raise
        except Exception as e:
            self.logger.error(f"FAISS index initialization failed: {e}")
            raise

    def _encode_query(self, query: str) -> np.ndarray:
        """
        Encode the query with the embedding model.

        Args:
            query: Query text.

        Returns:
            The vector representation of the query.
        """
        try:
            # Use the sentence-transformers encode method
            embeddings = self.embedding_model.encode([query])
            return embeddings[0]  # return the vector of the first (only) query

        except Exception as e:
            self.logger.error(f"Query encoding failed: {e}")
            raise

    def _search_faiss(self, query_vector: np.ndarray, top_k: int) -> tuple[list[float], list[int]]:
        """
        Search the FAISS index.

        Args:
            query_vector: Query vector.
            top_k: Number of results to return.

        Returns:
            (scores, indices): similarity scores and document indices.
        """
        try:
            if self.faiss_index.ntotal == 0:
                self.logger.warning("FAISS index is empty; nothing to retrieve")
                return [], []

            # FAISS search
            query_vector = query_vector.reshape(1, -1).astype("float32")
            scores, indices = self.faiss_index.search(query_vector, top_k)  # type: ignore[call-overload]

            return scores[0].tolist(), indices[0].tolist()

        except Exception as e:
            self.logger.error(f"FAISS search failed: {e}")
            return [], []

    def _format_retrieved_documents(
        self, scores: list[float], indices: list[int]
    ) -> list[dict[str, Any]]:
        """
        Format the retrieved documents.

        Args:
            scores: List of similarity scores.
            indices: List of document indices.

        Returns:
            The formatted document list.
        """
        retrieved_docs = []

        for score, idx in zip(scores, indices, strict=False):
            # Use the passage-to-document mapping when one is available
            if hasattr(self, "passage_to_doc_mapping") and self.passage_to_doc_mapping is not None:
                if 0 <= idx < len(self.passage_to_doc_mapping):
                    doc_idx = self.passage_to_doc_mapping[idx]
                    if 0 <= doc_idx < len(self.documents):
                        original_doc = self.documents[doc_idx]

                        # Build the standardized document format
                        standardized_doc = {
                            "text": original_doc.get("contents", str(original_doc)),
                            "similarity_score": float(score),
                            "document_index": int(doc_idx),
                            "passage_index": int(idx),  # keep the passage index
                        }

                        # Keep other useful metadata
                        if "title" in original_doc:
                            standardized_doc["title"] = original_doc["title"]
                        if "id" in original_doc:
                            standardized_doc["id"] = original_doc["id"]
                        if "doc_size" in original_doc:
                            standardized_doc["doc_size"] = original_doc["doc_size"]

                        retrieved_docs.append(standardized_doc)
                    else:
                        self.logger.warning(
                            f"Mapped document index out of range: {doc_idx} >= {len(self.documents)}"
                        )
                else:
                    self.logger.warning(
                        f"Passage index outside the mapping range: {idx} >= {len(self.passage_to_doc_mapping)}"
                    )
            else:
                # Without a mapping, use the index directly
                if 0 <= idx < len(self.documents):
                    original_doc = self.documents[idx]

                    # Build the standardized document format, consistent with ChromaRetriever
                    standardized_doc = {
                        "text": original_doc.get(
                            "contents", str(original_doc)
                        ),  # map the contents field to text
                        "similarity_score": float(score),
                        "document_index": int(idx),
                    }

                    # Keep other useful metadata
                    if "title" in original_doc:
                        standardized_doc["title"] = original_doc["title"]
                    if "id" in original_doc:
                        standardized_doc["id"] = original_doc["id"]
                    if "doc_size" in original_doc:
                        standardized_doc["doc_size"] = original_doc["doc_size"]

                    retrieved_docs.append(standardized_doc)

        return retrieved_docs

    def _save_data_record(self, query: str, retrieved_docs: list[dict[str, Any]]):
        """Save retrieval records for analysis."""
        if not self.enable_profile:
            return

        record = {
            "timestamp": time.time(),
            "query": query,
            "retrieved_count": len(retrieved_docs),
            "documents": retrieved_docs,
        }

        self.data_records.append(record)

        # Persist once every 100 records
        if len(self.data_records) >= 100:
            self._persist_data_records()

    def _persist_data_records(self):
        """Persist the data records."""
        if not self.enable_profile or not self.data_records:
            return

        try:
            timestamp = int(time.time())
            filename = f"wiki18_faiss_retrieval_records_{timestamp}.json"
            filepath = os.path.join(self.data_base_path, filename)

            with open(filepath, "w", encoding="utf-8") as f:
                json.dump(self.data_records, f, ensure_ascii=False, indent=2)

            self.logger.info(f"Saved {len(self.data_records)} retrieval records to {filepath}")
            self.data_records = []  # clear the buffer

        except Exception as e:
            self.logger.error(f"Failed to save retrieval records: {e}")

    def execute(self, data: str | dict[str, Any] | tuple) -> dict[str, Any]:
        """
        Run retrieval.

        Args:
            data: A query string, tuple, or dict.

        Returns:
            dict: {"query": ..., "retrieval_results": ..., "input": original input, ...}
        """
        # Dict inputs are supported; "query" is preferred, then "question"
        is_dict_input = isinstance(data, dict)
        if is_dict_input:
            if "query" in data:
                input_query = data["query"]
            elif "question" in data:
                input_query = data["question"]
            else:
                self.logger.error("A dict input must contain a 'query' or 'question' field")
                data["retrieval_results"] = []
                return data
        elif isinstance(data, tuple) and len(data) > 0:
            input_query = data[0]
        else:
            input_query = data

        if not isinstance(input_query, str):
            self.logger.error(f"Invalid input query type: {type(input_query)}")
            if is_dict_input:
                data["retrieval_results"] = []
                return data
            else:
                return {"query": str(input_query), "retrieval_results": [], "input": data}

        if not input_query or not input_query.strip():
            self.logger.error("The query must not be empty")
            if is_dict_input:
                data["retrieval_results"] = []
                return data
            else:
                return {"query": "", "retrieval_results": [], "input": data}

        input_query = input_query.strip()
        self.logger.info(
            f"[{self.__class__.__name__}]: Starting FAISS retrieval for query: {input_query}"
        )
        self.logger.info(f"[{self.__class__.__name__}]: Using top_k = {self.top_k}")

        try:
            # Encode the query
            query_vector = self._encode_query(input_query)

            # FAISS search
            scores, indices = self._search_faiss(query_vector, self.top_k)

            # Format the results
            retrieved_docs = self._format_retrieved_documents(scores, indices)

            self.logger.info(
                f"\033[32m[{self.__class__.__name__}]: Retrieved {len(retrieved_docs)} documents from FAISS\033[0m"
            )
            self.logger.debug(
                f"Retrieved documents: {retrieved_docs[:3]}..."
            )  # Preview only the first three documents

            # Save the data record (only when enable_profile=True)
            if self.enable_profile:
                self._save_data_record(input_query, retrieved_docs)

            if is_dict_input:
                data["retrieval_results"] = retrieved_docs
                # retrieve_time is added automatically by MapOperator
                return data
            else:
                return {
                    "query": input_query,
                    "retrieval_results": retrieved_docs,
                    # retrieve_time is added automatically by MapOperator
                    "input": data,
                }

        except Exception as e:
            self.logger.error(f"FAISS retrieval failed: {str(e)}")
            if is_dict_input:
                data["retrieval_results"] = []
                return data
            else:
                return {"query": input_query, "retrieval_results": [], "input": data}

    def build_index_from_wiki18(self, wiki18_data_path: str, save_path: str | None = None):
        """
        Build a FAISS index from the Wiki18 dataset.

        Args:
            wiki18_data_path: Path to the Wiki18 dataset.
            save_path: Path to save the index to.
        """
        try:
            import faiss

            self.logger.info(f"Building FAISS index from Wiki18 data: {wiki18_data_path}")

            # Load the Wiki18 data
            documents = []
            with open(wiki18_data_path, encoding="utf-8") as f:
                for line in f:
                    doc = json.loads(line.strip())
                    documents.append(doc)

            self.logger.info(f"Loaded {len(documents)} documents")

            # Extract the document texts for encoding
            doc_texts = [doc.get("text", "") for doc in documents]

            # Batch-encode all documents; SentenceTransformer.encode returns an
            # ndarray of dense vectors
            self.logger.info("Encoding documents...")
            doc_vectors = self.embedding_model.encode(doc_texts)

            # Create the FAISS index
            dimension = doc_vectors.shape[1]
            self.faiss_index = faiss.IndexFlatIP(dimension)  # inner-product similarity

            # Add the vectors to the index
            self.faiss_index.add(doc_vectors.astype("float32"))  # type: ignore[call-overload]
            self.documents = documents

            self.logger.info(f"FAISS index built with {self.faiss_index.ntotal} vectors")

            # Save the index and the documents
            if save_path:
                index_save_path = save_path + "_index"
                docs_save_path = save_path + "_documents.json"

                faiss.write_index(self.faiss_index, index_save_path)

                with open(docs_save_path, "w", encoding="utf-8") as f:
                    json.dump(self.documents, f, ensure_ascii=False, indent=2)

                self.logger.info(f"Index saved to: {index_save_path}")
                self.logger.info(f"Documents saved to: {docs_save_path}")

        except Exception as e:
            self.logger.error(f"Failed to build the FAISS index: {e}")
            raise

    def __del__(self):
        """Make sure any unsaved records are persisted when the object is destroyed."""
        if hasattr(self, "enable_profile") and self.enable_profile:
            try:
                self._persist_data_records()
            except Exception:
                pass