isage-middleware 0.2.4.3__cp311-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isage_middleware-0.2.4.3.dist-info/METADATA +266 -0
- isage_middleware-0.2.4.3.dist-info/RECORD +94 -0
- isage_middleware-0.2.4.3.dist-info/WHEEL +5 -0
- isage_middleware-0.2.4.3.dist-info/top_level.txt +1 -0
- sage/middleware/__init__.py +59 -0
- sage/middleware/_version.py +6 -0
- sage/middleware/components/__init__.py +30 -0
- sage/middleware/components/extensions_compat.py +141 -0
- sage/middleware/components/sage_db/__init__.py +116 -0
- sage/middleware/components/sage_db/backend.py +136 -0
- sage/middleware/components/sage_db/service.py +15 -0
- sage/middleware/components/sage_flow/__init__.py +76 -0
- sage/middleware/components/sage_flow/python/__init__.py +14 -0
- sage/middleware/components/sage_flow/python/micro_service/__init__.py +4 -0
- sage/middleware/components/sage_flow/python/micro_service/sage_flow_service.py +88 -0
- sage/middleware/components/sage_flow/python/sage_flow.py +30 -0
- sage/middleware/components/sage_flow/service.py +14 -0
- sage/middleware/components/sage_mem/__init__.py +83 -0
- sage/middleware/components/sage_sias/__init__.py +59 -0
- sage/middleware/components/sage_sias/continual_learner.py +184 -0
- sage/middleware/components/sage_sias/coreset_selector.py +302 -0
- sage/middleware/components/sage_sias/types.py +94 -0
- sage/middleware/components/sage_tsdb/__init__.py +81 -0
- sage/middleware/components/sage_tsdb/python/__init__.py +21 -0
- sage/middleware/components/sage_tsdb/python/_sage_tsdb.pyi +17 -0
- sage/middleware/components/sage_tsdb/python/algorithms/__init__.py +17 -0
- sage/middleware/components/sage_tsdb/python/algorithms/base.py +51 -0
- sage/middleware/components/sage_tsdb/python/algorithms/out_of_order_join.py +248 -0
- sage/middleware/components/sage_tsdb/python/algorithms/window_aggregator.py +296 -0
- sage/middleware/components/sage_tsdb/python/micro_service/__init__.py +7 -0
- sage/middleware/components/sage_tsdb/python/micro_service/sage_tsdb_service.py +365 -0
- sage/middleware/components/sage_tsdb/python/sage_tsdb.py +523 -0
- sage/middleware/components/sage_tsdb/service.py +17 -0
- sage/middleware/components/vector_stores/__init__.py +25 -0
- sage/middleware/components/vector_stores/chroma.py +483 -0
- sage/middleware/components/vector_stores/chroma_adapter.py +185 -0
- sage/middleware/components/vector_stores/milvus.py +677 -0
- sage/middleware/operators/__init__.py +56 -0
- sage/middleware/operators/agent/__init__.py +24 -0
- sage/middleware/operators/agent/planning/__init__.py +5 -0
- sage/middleware/operators/agent/planning/llm_adapter.py +41 -0
- sage/middleware/operators/agent/planning/planner_adapter.py +98 -0
- sage/middleware/operators/agent/planning/router.py +107 -0
- sage/middleware/operators/agent/runtime.py +296 -0
- sage/middleware/operators/agentic/__init__.py +41 -0
- sage/middleware/operators/agentic/config.py +254 -0
- sage/middleware/operators/agentic/planning_operator.py +125 -0
- sage/middleware/operators/agentic/refined_searcher.py +132 -0
- sage/middleware/operators/agentic/runtime.py +241 -0
- sage/middleware/operators/agentic/timing_operator.py +125 -0
- sage/middleware/operators/agentic/tool_selection_operator.py +127 -0
- sage/middleware/operators/context/__init__.py +17 -0
- sage/middleware/operators/context/critic_evaluation.py +16 -0
- sage/middleware/operators/context/model_context.py +565 -0
- sage/middleware/operators/context/quality_label.py +12 -0
- sage/middleware/operators/context/search_query_results.py +61 -0
- sage/middleware/operators/context/search_result.py +42 -0
- sage/middleware/operators/context/search_session.py +79 -0
- sage/middleware/operators/filters/__init__.py +26 -0
- sage/middleware/operators/filters/context_sink.py +387 -0
- sage/middleware/operators/filters/context_source.py +376 -0
- sage/middleware/operators/filters/evaluate_filter.py +83 -0
- sage/middleware/operators/filters/tool_filter.py +74 -0
- sage/middleware/operators/llm/__init__.py +18 -0
- sage/middleware/operators/llm/sagellm_generator.py +432 -0
- sage/middleware/operators/rag/__init__.py +147 -0
- sage/middleware/operators/rag/arxiv.py +331 -0
- sage/middleware/operators/rag/chunk.py +13 -0
- sage/middleware/operators/rag/document_loaders.py +23 -0
- sage/middleware/operators/rag/evaluate.py +658 -0
- sage/middleware/operators/rag/generator.py +340 -0
- sage/middleware/operators/rag/index_builder/__init__.py +48 -0
- sage/middleware/operators/rag/index_builder/builder.py +363 -0
- sage/middleware/operators/rag/index_builder/manifest.py +101 -0
- sage/middleware/operators/rag/index_builder/storage.py +131 -0
- sage/middleware/operators/rag/pipeline.py +46 -0
- sage/middleware/operators/rag/profiler.py +59 -0
- sage/middleware/operators/rag/promptor.py +400 -0
- sage/middleware/operators/rag/refiner.py +231 -0
- sage/middleware/operators/rag/reranker.py +364 -0
- sage/middleware/operators/rag/retriever.py +1308 -0
- sage/middleware/operators/rag/searcher.py +37 -0
- sage/middleware/operators/rag/types.py +28 -0
- sage/middleware/operators/rag/writer.py +80 -0
- sage/middleware/operators/tools/__init__.py +71 -0
- sage/middleware/operators/tools/arxiv_paper_searcher.py +175 -0
- sage/middleware/operators/tools/arxiv_searcher.py +102 -0
- sage/middleware/operators/tools/duckduckgo_searcher.py +105 -0
- sage/middleware/operators/tools/image_captioner.py +104 -0
- sage/middleware/operators/tools/nature_news_fetcher.py +224 -0
- sage/middleware/operators/tools/searcher_tool.py +514 -0
- sage/middleware/operators/tools/text_detector.py +185 -0
- sage/middleware/operators/tools/url_text_extractor.py +104 -0
- sage/middleware/py.typed +2 -0
|
@@ -0,0 +1,483 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ChromaDB 后端管理工具
|
|
3
|
+
提供 ChromaDB 向量数据库的初始化、文档管理和检索功能
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import json
|
|
7
|
+
import logging
|
|
8
|
+
import os
|
|
9
|
+
import time
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ChromaBackend:
|
|
16
|
+
"""ChromaDB 后端管理器"""
|
|
17
|
+
|
|
18
|
+
def __init__(self, config: dict[str, Any], logger: logging.Logger | Any = None):
|
|
19
|
+
"""
|
|
20
|
+
初始化 ChromaDB 后端
|
|
21
|
+
|
|
22
|
+
Args:
|
|
23
|
+
config: ChromaDB 配置字典
|
|
24
|
+
logger: 日志记录器
|
|
25
|
+
"""
|
|
26
|
+
self.config = config
|
|
27
|
+
self.logger = logger or logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
# ChromaDB 基本配置
|
|
30
|
+
self.host = config.get("host", "localhost")
|
|
31
|
+
self.port = config.get("port", 8000)
|
|
32
|
+
self.persistence_path = config.get("persistence_path", "./chroma_db")
|
|
33
|
+
self.collection_name = config.get("collection_name", "dense_retriever_collection")
|
|
34
|
+
self.use_embedding_query = config.get("use_embedding_query", True)
|
|
35
|
+
self.metadata_config = config.get("metadata", {"hnsw:space": "cosine"})
|
|
36
|
+
|
|
37
|
+
# 初始化客户端和集合
|
|
38
|
+
self.client: Any = None # Will be initialized by _init_client
|
|
39
|
+
self.collection: Any = None # Will be initialized by _init_collection
|
|
40
|
+
self._init_client()
|
|
41
|
+
self._init_collection()
|
|
42
|
+
|
|
43
|
+
def _init_client(self):
|
|
44
|
+
"""初始化 ChromaDB 客户端"""
|
|
45
|
+
try:
|
|
46
|
+
import chromadb
|
|
47
|
+
from chromadb.config import Settings # noqa: F401
|
|
48
|
+
|
|
49
|
+
# 判断使用本地还是远程模式
|
|
50
|
+
if self.host in ["localhost", "127.0.0.1"] and not self.config.get("force_http", False):
|
|
51
|
+
# 本地持久化模式
|
|
52
|
+
self.client = chromadb.PersistentClient(path=self.persistence_path)
|
|
53
|
+
self.logger.info(
|
|
54
|
+
f"Initialized ChromaDB persistent client at: {self.persistence_path}"
|
|
55
|
+
)
|
|
56
|
+
else:
|
|
57
|
+
# 远程服务器模式
|
|
58
|
+
full_host = (
|
|
59
|
+
f"http://{self.host}:{self.port}"
|
|
60
|
+
if not self.host.startswith("http")
|
|
61
|
+
else self.host
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
# 处理认证
|
|
65
|
+
auth_config = self.config.get("auth", {})
|
|
66
|
+
if auth_config:
|
|
67
|
+
# 如果需要认证,可以在这里添加认证逻辑
|
|
68
|
+
pass
|
|
69
|
+
|
|
70
|
+
self.client = chromadb.HttpClient(host=full_host)
|
|
71
|
+
self.logger.info(f"Initialized ChromaDB HTTP client at: {full_host}")
|
|
72
|
+
|
|
73
|
+
except ImportError as e:
|
|
74
|
+
self.logger.error(f"Failed to import ChromaDB: {e}")
|
|
75
|
+
raise ImportError(
|
|
76
|
+
"ChromaDB dependencies not available. Install with: pip install chromadb"
|
|
77
|
+
)
|
|
78
|
+
except Exception as e:
|
|
79
|
+
self.logger.error(f"Failed to initialize ChromaDB client: {e}")
|
|
80
|
+
raise
|
|
81
|
+
|
|
82
|
+
def _init_collection(self):
|
|
83
|
+
"""初始化或获取 ChromaDB 集合"""
|
|
84
|
+
try:
|
|
85
|
+
# 尝试获取已存在的集合
|
|
86
|
+
try:
|
|
87
|
+
self.collection = self.client.get_collection(name=self.collection_name)
|
|
88
|
+
self.logger.info(f"Retrieved existing ChromaDB collection: {self.collection_name}")
|
|
89
|
+
except Exception:
|
|
90
|
+
# 集合不存在,创建新集合
|
|
91
|
+
self.collection = self.client.create_collection(
|
|
92
|
+
name=self.collection_name, metadata=self.metadata_config
|
|
93
|
+
)
|
|
94
|
+
self.logger.info(f"Created new ChromaDB collection: {self.collection_name}")
|
|
95
|
+
|
|
96
|
+
except Exception as e:
|
|
97
|
+
self.logger.error(f"Failed to initialize ChromaDB collection: {e}")
|
|
98
|
+
raise
|
|
99
|
+
|
|
100
|
+
def add_documents(
|
|
101
|
+
self, documents: list[str], embeddings: list[np.ndarray], doc_ids: list[str]
|
|
102
|
+
) -> list[str]:
|
|
103
|
+
"""
|
|
104
|
+
添加文档到 ChromaDB 集合
|
|
105
|
+
|
|
106
|
+
Args:
|
|
107
|
+
documents: 文档内容列表
|
|
108
|
+
embeddings: 向量嵌入列表
|
|
109
|
+
doc_ids: 文档ID列表
|
|
110
|
+
|
|
111
|
+
Returns:
|
|
112
|
+
成功添加的文档ID列表
|
|
113
|
+
"""
|
|
114
|
+
try:
|
|
115
|
+
# 转换 embedding 格式(ChromaDB 需要 list 格式)
|
|
116
|
+
embeddings_list = [embedding.tolist() for embedding in embeddings]
|
|
117
|
+
|
|
118
|
+
# 准备元数据
|
|
119
|
+
metadatas = []
|
|
120
|
+
for i, doc_id in enumerate(doc_ids):
|
|
121
|
+
metadata = {
|
|
122
|
+
"doc_id": doc_id,
|
|
123
|
+
"length": len(documents[i]),
|
|
124
|
+
"added_time": time.time(),
|
|
125
|
+
}
|
|
126
|
+
metadatas.append(metadata)
|
|
127
|
+
|
|
128
|
+
# 添加到 ChromaDB
|
|
129
|
+
self.collection.add(
|
|
130
|
+
embeddings=embeddings_list,
|
|
131
|
+
documents=documents,
|
|
132
|
+
metadatas=metadatas,
|
|
133
|
+
ids=doc_ids,
|
|
134
|
+
)
|
|
135
|
+
|
|
136
|
+
self.logger.info(f"Added {len(documents)} documents to ChromaDB collection")
|
|
137
|
+
return doc_ids
|
|
138
|
+
|
|
139
|
+
except Exception as e:
|
|
140
|
+
self.logger.error(f"Error adding documents to ChromaDB: {e}")
|
|
141
|
+
return []
|
|
142
|
+
|
|
143
|
+
def search(self, query_vector: np.ndarray, query_text: str, top_k: int) -> list[str]:
|
|
144
|
+
"""
|
|
145
|
+
在 ChromaDB 中执行搜索
|
|
146
|
+
|
|
147
|
+
Args:
|
|
148
|
+
query_vector: 查询向量
|
|
149
|
+
query_text: 查询文本
|
|
150
|
+
top_k: 返回的文档数量
|
|
151
|
+
|
|
152
|
+
Returns:
|
|
153
|
+
检索到的文档内容列表
|
|
154
|
+
"""
|
|
155
|
+
try:
|
|
156
|
+
print(f"ChromaBackend.search: using top_k = {top_k}")
|
|
157
|
+
|
|
158
|
+
if self.use_embedding_query:
|
|
159
|
+
# 使用向量查询
|
|
160
|
+
results = self.collection.query(
|
|
161
|
+
query_embeddings=[query_vector.tolist()],
|
|
162
|
+
n_results=top_k,
|
|
163
|
+
include=["documents", "metadatas", "distances"],
|
|
164
|
+
)
|
|
165
|
+
else:
|
|
166
|
+
# 使用文本查询(如果 ChromaDB 支持内建的 embedding 函数)
|
|
167
|
+
results = self.collection.query(
|
|
168
|
+
query_texts=[query_text],
|
|
169
|
+
n_results=top_k,
|
|
170
|
+
include=["documents", "metadatas", "distances"],
|
|
171
|
+
)
|
|
172
|
+
|
|
173
|
+
# 提取文档内容
|
|
174
|
+
if results["documents"] and len(results["documents"]) > 0:
|
|
175
|
+
documents = results["documents"][0] # 返回第一个查询的结果
|
|
176
|
+
print(f"ChromaBackend.search: returned {len(documents)} documents")
|
|
177
|
+
return documents
|
|
178
|
+
else:
|
|
179
|
+
return []
|
|
180
|
+
|
|
181
|
+
except Exception as e:
|
|
182
|
+
self.logger.error(f"Error executing ChromaDB search: {e}")
|
|
183
|
+
return []
|
|
184
|
+
|
|
185
|
+
def delete_collection(self):
|
|
186
|
+
"""删除当前集合"""
|
|
187
|
+
try:
|
|
188
|
+
self.client.delete_collection(name=self.collection_name)
|
|
189
|
+
self.logger.info(f"Deleted ChromaDB collection: {self.collection_name}")
|
|
190
|
+
return True
|
|
191
|
+
except Exception as e:
|
|
192
|
+
self.logger.error(f"Error deleting ChromaDB collection: {e}")
|
|
193
|
+
return False
|
|
194
|
+
|
|
195
|
+
def get_collection_info(self) -> dict[str, Any]:
|
|
196
|
+
"""
|
|
197
|
+
获取集合信息
|
|
198
|
+
|
|
199
|
+
Returns:
|
|
200
|
+
包含集合信息的字典
|
|
201
|
+
"""
|
|
202
|
+
try:
|
|
203
|
+
return {
|
|
204
|
+
"backend": "chroma",
|
|
205
|
+
"collection_name": self.collection.name,
|
|
206
|
+
"document_count": self.collection.count(),
|
|
207
|
+
"metadata": self.metadata_config,
|
|
208
|
+
"persistence_path": (
|
|
209
|
+
self.persistence_path if hasattr(self, "persistence_path") else None
|
|
210
|
+
),
|
|
211
|
+
}
|
|
212
|
+
except Exception as e:
|
|
213
|
+
self.logger.error(f"Failed to get ChromaDB collection info: {e}")
|
|
214
|
+
return {"backend": "chroma", "error": str(e)}
|
|
215
|
+
|
|
216
|
+
def save_config(self, save_path: str) -> bool:
|
|
217
|
+
"""
|
|
218
|
+
保存 ChromaDB 配置信息
|
|
219
|
+
|
|
220
|
+
Args:
|
|
221
|
+
save_path: 保存路径
|
|
222
|
+
|
|
223
|
+
Returns:
|
|
224
|
+
是否保存成功
|
|
225
|
+
"""
|
|
226
|
+
try:
|
|
227
|
+
os.makedirs(save_path, exist_ok=True)
|
|
228
|
+
|
|
229
|
+
# ChromaDB 本身会处理持久化,这里只需要保存配置信息
|
|
230
|
+
config_path = os.path.join(save_path, "chroma_config.json")
|
|
231
|
+
config_info = {
|
|
232
|
+
"collection_name": self.collection.name,
|
|
233
|
+
"collection_count": self.collection.count(),
|
|
234
|
+
"backend_type": "chroma",
|
|
235
|
+
"chroma_config": self.config,
|
|
236
|
+
"saved_time": time.time(),
|
|
237
|
+
}
|
|
238
|
+
|
|
239
|
+
with open(config_path, "w", encoding="utf-8") as f:
|
|
240
|
+
json.dump(config_info, f, ensure_ascii=False, indent=2)
|
|
241
|
+
|
|
242
|
+
self.logger.info(f"Successfully saved ChromaDB config to: {save_path}")
|
|
243
|
+
self.logger.info(
|
|
244
|
+
f"ChromaDB collection '{self.collection.name}' contains {config_info['collection_count']} documents"
|
|
245
|
+
)
|
|
246
|
+
return True
|
|
247
|
+
|
|
248
|
+
except Exception as e:
|
|
249
|
+
self.logger.error(f"Failed to save ChromaDB config: {e}")
|
|
250
|
+
return False
|
|
251
|
+
|
|
252
|
+
def load_config(self, load_path: str) -> bool:
|
|
253
|
+
"""
|
|
254
|
+
从配置文件重新连接到 ChromaDB 集合
|
|
255
|
+
|
|
256
|
+
Args:
|
|
257
|
+
load_path: 配置文件路径
|
|
258
|
+
|
|
259
|
+
Returns:
|
|
260
|
+
是否加载成功
|
|
261
|
+
"""
|
|
262
|
+
try:
|
|
263
|
+
config_path = os.path.join(load_path, "chroma_config.json")
|
|
264
|
+
if os.path.exists(config_path):
|
|
265
|
+
with open(config_path, encoding="utf-8") as f:
|
|
266
|
+
config_info = json.load(f)
|
|
267
|
+
|
|
268
|
+
collection_name = config_info.get("collection_name")
|
|
269
|
+
if collection_name:
|
|
270
|
+
# 尝试连接到已存在的集合
|
|
271
|
+
self.collection = self.client.get_collection(name=collection_name)
|
|
272
|
+
self.collection_name = collection_name
|
|
273
|
+
self.logger.info(
|
|
274
|
+
f"Successfully connected to ChromaDB collection: {collection_name}"
|
|
275
|
+
)
|
|
276
|
+
self.logger.info(f"Collection contains {self.collection.count()} documents")
|
|
277
|
+
return True
|
|
278
|
+
else:
|
|
279
|
+
self.logger.error("No collection name found in config")
|
|
280
|
+
return False
|
|
281
|
+
else:
|
|
282
|
+
self.logger.error(f"ChromaDB config not found at: {config_path}")
|
|
283
|
+
return False
|
|
284
|
+
|
|
285
|
+
except Exception as e:
|
|
286
|
+
self.logger.error(f"Failed to load ChromaDB config: {e}")
|
|
287
|
+
return False
|
|
288
|
+
|
|
289
|
+
def load_knowledge_from_file(self, file_path: str, embedding_model) -> bool:
|
|
290
|
+
"""
|
|
291
|
+
从文件加载知识库到 ChromaDB
|
|
292
|
+
|
|
293
|
+
Args:
|
|
294
|
+
file_path: 知识库文件路径
|
|
295
|
+
embedding_model: 嵌入模型实例
|
|
296
|
+
|
|
297
|
+
Returns:
|
|
298
|
+
是否加载成功
|
|
299
|
+
"""
|
|
300
|
+
try:
|
|
301
|
+
self.logger.info(f"Loading knowledge from file: {file_path}")
|
|
302
|
+
with open(file_path, encoding="utf-8") as f:
|
|
303
|
+
content = f.read()
|
|
304
|
+
|
|
305
|
+
# 将知识库按段落分割
|
|
306
|
+
documents = [doc.strip() for doc in content.split("\n\n") if doc.strip()]
|
|
307
|
+
|
|
308
|
+
if documents:
|
|
309
|
+
# 生成文档ID
|
|
310
|
+
doc_ids = [f"doc_{int(time.time() * 1000)}_{i}" for i in range(len(documents))]
|
|
311
|
+
|
|
312
|
+
# 生成 embedding
|
|
313
|
+
embeddings = []
|
|
314
|
+
for doc in documents:
|
|
315
|
+
embedding = embedding_model.embed(doc)
|
|
316
|
+
embeddings.append(np.array(embedding, dtype=np.float32))
|
|
317
|
+
|
|
318
|
+
# 添加到 ChromaDB
|
|
319
|
+
added_ids = self.add_documents(documents, embeddings, doc_ids)
|
|
320
|
+
|
|
321
|
+
if added_ids:
|
|
322
|
+
self.logger.info(f"Loaded {len(added_ids)} documents from {file_path}")
|
|
323
|
+
return True
|
|
324
|
+
else:
|
|
325
|
+
self.logger.error(f"Failed to add documents from {file_path}")
|
|
326
|
+
return False
|
|
327
|
+
else:
|
|
328
|
+
self.logger.warning(f"No valid documents found in {file_path}")
|
|
329
|
+
return False
|
|
330
|
+
|
|
331
|
+
except Exception as e:
|
|
332
|
+
self.logger.error(f"Failed to load knowledge from file {file_path}: {e}")
|
|
333
|
+
return False
|
|
334
|
+
|
|
335
|
+
def clear_collection(self) -> bool:
|
|
336
|
+
"""
|
|
337
|
+
清空集合中的所有文档
|
|
338
|
+
|
|
339
|
+
Returns:
|
|
340
|
+
是否清空成功
|
|
341
|
+
"""
|
|
342
|
+
try:
|
|
343
|
+
# 获取所有文档ID
|
|
344
|
+
all_docs = self.collection.get()
|
|
345
|
+
if all_docs["ids"]:
|
|
346
|
+
# 删除所有文档
|
|
347
|
+
self.collection.delete(ids=all_docs["ids"])
|
|
348
|
+
self.logger.info(f"Cleared {len(all_docs['ids'])} documents from collection")
|
|
349
|
+
return True
|
|
350
|
+
except Exception as e:
|
|
351
|
+
self.logger.error(f"Failed to clear collection: {e}")
|
|
352
|
+
return False
|
|
353
|
+
|
|
354
|
+
def update_document(self, doc_id: str, new_content: str, new_embedding: np.ndarray) -> bool:
|
|
355
|
+
"""
|
|
356
|
+
更新指定文档
|
|
357
|
+
|
|
358
|
+
Args:
|
|
359
|
+
doc_id: 文档ID
|
|
360
|
+
new_content: 新的文档内容
|
|
361
|
+
new_embedding: 新的向量嵌入
|
|
362
|
+
|
|
363
|
+
Returns:
|
|
364
|
+
是否更新成功
|
|
365
|
+
"""
|
|
366
|
+
try:
|
|
367
|
+
# ChromaDB 的 update 方法
|
|
368
|
+
self.collection.update(
|
|
369
|
+
ids=[doc_id],
|
|
370
|
+
documents=[new_content],
|
|
371
|
+
embeddings=[new_embedding.tolist()],
|
|
372
|
+
metadatas=[
|
|
373
|
+
{
|
|
374
|
+
"doc_id": doc_id,
|
|
375
|
+
"length": len(new_content),
|
|
376
|
+
"updated_time": time.time(),
|
|
377
|
+
}
|
|
378
|
+
],
|
|
379
|
+
)
|
|
380
|
+
|
|
381
|
+
self.logger.info(f"Updated document: {doc_id}")
|
|
382
|
+
return True
|
|
383
|
+
|
|
384
|
+
except Exception as e:
|
|
385
|
+
self.logger.error(f"Failed to update document {doc_id}: {e}")
|
|
386
|
+
return False
|
|
387
|
+
|
|
388
|
+
def delete_document(self, doc_id: str) -> bool:
|
|
389
|
+
"""
|
|
390
|
+
删除指定文档
|
|
391
|
+
|
|
392
|
+
Args:
|
|
393
|
+
doc_id: 文档ID
|
|
394
|
+
|
|
395
|
+
Returns:
|
|
396
|
+
是否删除成功
|
|
397
|
+
"""
|
|
398
|
+
try:
|
|
399
|
+
self.collection.delete(ids=[doc_id])
|
|
400
|
+
self.logger.info(f"Deleted document: {doc_id}")
|
|
401
|
+
return True
|
|
402
|
+
except Exception as e:
|
|
403
|
+
self.logger.error(f"Failed to delete document {doc_id}: {e}")
|
|
404
|
+
return False
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
class ChromaUtils:
    """ChromaDB helper utilities (config construction/validation, availability check)."""

    @staticmethod
    def create_chroma_config(
        persistence_path: str = "./chroma_db",
        collection_name: str = "default_collection",
        distance_metric: str = "cosine",
        host: str = "localhost",
        port: int = 8000,
    ) -> dict[str, Any]:
        """Build a standard ChromaDB config dict.

        Args:
            persistence_path: Local persistence directory.
            collection_name: Collection name.
            distance_metric: HNSW distance space ("cosine", "l2" or "ip").
            host: Server host.
            port: Server port.

        Returns:
            Config dict suitable for ``ChromaBackend``. The HNSW metadata uses
            ChromaDB's parameter names (``hnsw:construction_ef``,
            ``hnsw:search_ef``); ChromaDB rejects unknown ``hnsw:*`` keys.
        """
        return {
            "host": host,
            "port": port,
            "persistence_path": persistence_path,
            "collection_name": collection_name,
            "use_embedding_query": True,
            "metadata": {
                "hnsw:space": distance_metric,
                "hnsw:M": 16,
                "hnsw:construction_ef": 200,
                "hnsw:search_ef": 10,
            },
        }

    @staticmethod
    def validate_chroma_config(config: dict[str, Any]) -> bool:
        """Check that a ChromaDB config is usable.

        Requires ``collection_name``; if ``metadata`` is present it must be a
        dict, and its ``hnsw:space`` (if set) must be "cosine", "l2" or "ip".

        Args:
            config: ChromaDB config dict.

        Returns:
            True if the config passes the checks above.
        """
        required_keys = ["collection_name"]
        if any(key not in config for key in required_keys):
            return False

        metadata = config.get("metadata")
        if metadata is None:
            return True
        if not isinstance(metadata, dict):
            return False

        if "hnsw:space" in metadata:
            valid_metrics = ["cosine", "l2", "ip"]
            if metadata["hnsw:space"] not in valid_metrics:
                return False

        return True

    @staticmethod
    def check_chromadb_availability() -> bool:
        """Return True if the ``chromadb`` package can be imported."""
        try:
            import chromadb  # noqa: F401

            return True
        except ImportError:
            return False
|
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
"""ChromaDB VectorStore Adapter
|
|
2
|
+
|
|
3
|
+
Adapter that wraps ChromaBackend to implement the VectorStore protocol,
|
|
4
|
+
enabling it to work with IndexBuilder.
|
|
5
|
+
|
|
6
|
+
Layer: L3 (sage-libs/integrations)
|
|
7
|
+
Dependencies: sage.middleware.operators.rag.index_builder (L4 Protocol only - runtime_checkable)
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
from sage.middleware.components.vector_stores.chroma import ChromaBackend
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ChromaVectorStoreAdapter:
|
|
18
|
+
"""Adapter wrapping ChromaBackend to implement VectorStore Protocol.
|
|
19
|
+
|
|
20
|
+
This adapter enables ChromaBackend to work with IndexBuilder by
|
|
21
|
+
implementing the VectorStore interface.
|
|
22
|
+
|
|
23
|
+
Note: We don't formally implement the Protocol here (that would create
|
|
24
|
+
L3→L4 dependency). Instead, we provide duck-typing compatibility.
|
|
25
|
+
The Protocol is only for type checking at runtime.
|
|
26
|
+
|
|
27
|
+
Args:
|
|
28
|
+
persist_path: Directory to store ChromaDB data
|
|
29
|
+
dim: Vector dimension (unused for Chroma, but required by Protocol)
|
|
30
|
+
collection_name: Name of the Chroma collection
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
def __init__(
|
|
34
|
+
self,
|
|
35
|
+
persist_path: Path,
|
|
36
|
+
dim: int,
|
|
37
|
+
collection_name: str = "sage_index",
|
|
38
|
+
):
|
|
39
|
+
"""Initialize ChromaDB adapter.
|
|
40
|
+
|
|
41
|
+
Args:
|
|
42
|
+
persist_path: Path to persist ChromaDB data
|
|
43
|
+
dim: Vector dimension (recorded but not enforced by Chroma)
|
|
44
|
+
collection_name: Name of collection to use
|
|
45
|
+
"""
|
|
46
|
+
self.persist_path = persist_path
|
|
47
|
+
self.dim = dim
|
|
48
|
+
self.collection_name = collection_name
|
|
49
|
+
|
|
50
|
+
# Create parent directory
|
|
51
|
+
persist_path.mkdir(parents=True, exist_ok=True)
|
|
52
|
+
|
|
53
|
+
# Initialize ChromaBackend with local persistence
|
|
54
|
+
config = {
|
|
55
|
+
"persistence_path": str(persist_path),
|
|
56
|
+
"collection_name": collection_name,
|
|
57
|
+
"metadata": {"hnsw:space": "cosine"},
|
|
58
|
+
}
|
|
59
|
+
|
|
60
|
+
self.backend = ChromaBackend(config)
|
|
61
|
+
|
|
62
|
+
# Track documents for count
|
|
63
|
+
self._doc_count = 0
|
|
64
|
+
|
|
65
|
+
def add(self, vector: list[float], metadata: dict[str, Any]) -> None:
|
|
66
|
+
"""Add a single vector with metadata.
|
|
67
|
+
|
|
68
|
+
Args:
|
|
69
|
+
vector: Vector embedding
|
|
70
|
+
metadata: Metadata dictionary
|
|
71
|
+
"""
|
|
72
|
+
# ChromaBackend.add_documents expects batch format
|
|
73
|
+
# We'll accumulate and flush periodically, or add one at a time
|
|
74
|
+
doc_id = f"doc_{self._doc_count}"
|
|
75
|
+
|
|
76
|
+
self.backend.add_documents(
|
|
77
|
+
ids=[doc_id],
|
|
78
|
+
embeddings=[vector],
|
|
79
|
+
metadatas=[metadata],
|
|
80
|
+
documents=[metadata.get("text", "")], # Use 'text' field if available
|
|
81
|
+
)
|
|
82
|
+
|
|
83
|
+
self._doc_count += 1
|
|
84
|
+
|
|
85
|
+
def build_index(self) -> None:
|
|
86
|
+
"""Build/optimize the index.
|
|
87
|
+
|
|
88
|
+
ChromaDB builds indices automatically, so this is a no-op.
|
|
89
|
+
"""
|
|
90
|
+
# ChromaDB automatically maintains indices
|
|
91
|
+
pass
|
|
92
|
+
|
|
93
|
+
def save(self, path: str) -> None:
|
|
94
|
+
"""Persist the vector store to disk.
|
|
95
|
+
|
|
96
|
+
Args:
|
|
97
|
+
path: Path to save (unused for Chroma - uses persistence_path from config)
|
|
98
|
+
"""
|
|
99
|
+
# ChromaDB with PersistentClient automatically persists
|
|
100
|
+
# Save metadata about the index
|
|
101
|
+
manifest_path = Path(path).parent / "chroma_manifest.json"
|
|
102
|
+
manifest = {
|
|
103
|
+
"collection_name": self.collection_name,
|
|
104
|
+
"persistence_path": str(self.persist_path),
|
|
105
|
+
"dim": self.dim,
|
|
106
|
+
"count": self._doc_count,
|
|
107
|
+
}
|
|
108
|
+
|
|
109
|
+
with open(manifest_path, "w") as f:
|
|
110
|
+
json.dump(manifest, f, indent=2)
|
|
111
|
+
|
|
112
|
+
def load(self, path: str) -> None:
|
|
113
|
+
"""Load vector store from disk.
|
|
114
|
+
|
|
115
|
+
Args:
|
|
116
|
+
path: Path to load from
|
|
117
|
+
"""
|
|
118
|
+
# ChromaDB automatically loads from persistence_path
|
|
119
|
+
# Try to load manifest for metadata
|
|
120
|
+
manifest_path = Path(path).parent / "chroma_manifest.json"
|
|
121
|
+
if manifest_path.exists():
|
|
122
|
+
with open(manifest_path) as f:
|
|
123
|
+
manifest = json.load(f)
|
|
124
|
+
self._doc_count = manifest.get("count", 0)
|
|
125
|
+
|
|
126
|
+
def search(
|
|
127
|
+
self,
|
|
128
|
+
query_vector: list[float],
|
|
129
|
+
top_k: int = 5,
|
|
130
|
+
filter_dict: dict[str, Any] | None = None,
|
|
131
|
+
) -> list[dict]:
|
|
132
|
+
"""Search for similar vectors.
|
|
133
|
+
|
|
134
|
+
Args:
|
|
135
|
+
query_vector: Query embedding
|
|
136
|
+
top_k: Number of results to return
|
|
137
|
+
filter_dict: Optional metadata filters
|
|
138
|
+
|
|
139
|
+
Returns:
|
|
140
|
+
List of result dictionaries with 'id', 'score', 'metadata'
|
|
141
|
+
"""
|
|
142
|
+
# Use ChromaBackend.query
|
|
143
|
+
results = self.backend.query(
|
|
144
|
+
query_embeddings=[query_vector],
|
|
145
|
+
n_results=top_k,
|
|
146
|
+
where=filter_dict, # ChromaDB uses 'where' for metadata filtering
|
|
147
|
+
)
|
|
148
|
+
|
|
149
|
+
# Convert ChromaDB results to standard format
|
|
150
|
+
formatted_results = []
|
|
151
|
+
if results and "ids" in results:
|
|
152
|
+
ids = results["ids"][0] if results["ids"] else []
|
|
153
|
+
distances = results["distances"][0] if results["distances"] else []
|
|
154
|
+
metadatas = results["metadatas"][0] if results["metadatas"] else []
|
|
155
|
+
|
|
156
|
+
for i, doc_id in enumerate(ids):
|
|
157
|
+
formatted_results.append(
|
|
158
|
+
{
|
|
159
|
+
"id": doc_id,
|
|
160
|
+
"score": float(distances[i]) if i < len(distances) else 0.0,
|
|
161
|
+
"metadata": metadatas[i] if i < len(metadatas) else {},
|
|
162
|
+
}
|
|
163
|
+
)
|
|
164
|
+
|
|
165
|
+
return formatted_results
|
|
166
|
+
|
|
167
|
+
def get_dim(self) -> int:
|
|
168
|
+
"""Get vector dimension.
|
|
169
|
+
|
|
170
|
+
Returns:
|
|
171
|
+
Vector dimension
|
|
172
|
+
"""
|
|
173
|
+
return self.dim
|
|
174
|
+
|
|
175
|
+
def count(self) -> int:
|
|
176
|
+
"""Get number of vectors in store.
|
|
177
|
+
|
|
178
|
+
Returns:
|
|
179
|
+
Number of stored vectors
|
|
180
|
+
"""
|
|
181
|
+
# ChromaDB collection has a count method
|
|
182
|
+
try:
|
|
183
|
+
return self.backend.collection.count()
|
|
184
|
+
except Exception:
|
|
185
|
+
return self._doc_count
|