isage-middleware 0.2.4.3__cp311-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94) hide show
  1. isage_middleware-0.2.4.3.dist-info/METADATA +266 -0
  2. isage_middleware-0.2.4.3.dist-info/RECORD +94 -0
  3. isage_middleware-0.2.4.3.dist-info/WHEEL +5 -0
  4. isage_middleware-0.2.4.3.dist-info/top_level.txt +1 -0
  5. sage/middleware/__init__.py +59 -0
  6. sage/middleware/_version.py +6 -0
  7. sage/middleware/components/__init__.py +30 -0
  8. sage/middleware/components/extensions_compat.py +141 -0
  9. sage/middleware/components/sage_db/__init__.py +116 -0
  10. sage/middleware/components/sage_db/backend.py +136 -0
  11. sage/middleware/components/sage_db/service.py +15 -0
  12. sage/middleware/components/sage_flow/__init__.py +76 -0
  13. sage/middleware/components/sage_flow/python/__init__.py +14 -0
  14. sage/middleware/components/sage_flow/python/micro_service/__init__.py +4 -0
  15. sage/middleware/components/sage_flow/python/micro_service/sage_flow_service.py +88 -0
  16. sage/middleware/components/sage_flow/python/sage_flow.py +30 -0
  17. sage/middleware/components/sage_flow/service.py +14 -0
  18. sage/middleware/components/sage_mem/__init__.py +83 -0
  19. sage/middleware/components/sage_sias/__init__.py +59 -0
  20. sage/middleware/components/sage_sias/continual_learner.py +184 -0
  21. sage/middleware/components/sage_sias/coreset_selector.py +302 -0
  22. sage/middleware/components/sage_sias/types.py +94 -0
  23. sage/middleware/components/sage_tsdb/__init__.py +81 -0
  24. sage/middleware/components/sage_tsdb/python/__init__.py +21 -0
  25. sage/middleware/components/sage_tsdb/python/_sage_tsdb.pyi +17 -0
  26. sage/middleware/components/sage_tsdb/python/algorithms/__init__.py +17 -0
  27. sage/middleware/components/sage_tsdb/python/algorithms/base.py +51 -0
  28. sage/middleware/components/sage_tsdb/python/algorithms/out_of_order_join.py +248 -0
  29. sage/middleware/components/sage_tsdb/python/algorithms/window_aggregator.py +296 -0
  30. sage/middleware/components/sage_tsdb/python/micro_service/__init__.py +7 -0
  31. sage/middleware/components/sage_tsdb/python/micro_service/sage_tsdb_service.py +365 -0
  32. sage/middleware/components/sage_tsdb/python/sage_tsdb.py +523 -0
  33. sage/middleware/components/sage_tsdb/service.py +17 -0
  34. sage/middleware/components/vector_stores/__init__.py +25 -0
  35. sage/middleware/components/vector_stores/chroma.py +483 -0
  36. sage/middleware/components/vector_stores/chroma_adapter.py +185 -0
  37. sage/middleware/components/vector_stores/milvus.py +677 -0
  38. sage/middleware/operators/__init__.py +56 -0
  39. sage/middleware/operators/agent/__init__.py +24 -0
  40. sage/middleware/operators/agent/planning/__init__.py +5 -0
  41. sage/middleware/operators/agent/planning/llm_adapter.py +41 -0
  42. sage/middleware/operators/agent/planning/planner_adapter.py +98 -0
  43. sage/middleware/operators/agent/planning/router.py +107 -0
  44. sage/middleware/operators/agent/runtime.py +296 -0
  45. sage/middleware/operators/agentic/__init__.py +41 -0
  46. sage/middleware/operators/agentic/config.py +254 -0
  47. sage/middleware/operators/agentic/planning_operator.py +125 -0
  48. sage/middleware/operators/agentic/refined_searcher.py +132 -0
  49. sage/middleware/operators/agentic/runtime.py +241 -0
  50. sage/middleware/operators/agentic/timing_operator.py +125 -0
  51. sage/middleware/operators/agentic/tool_selection_operator.py +127 -0
  52. sage/middleware/operators/context/__init__.py +17 -0
  53. sage/middleware/operators/context/critic_evaluation.py +16 -0
  54. sage/middleware/operators/context/model_context.py +565 -0
  55. sage/middleware/operators/context/quality_label.py +12 -0
  56. sage/middleware/operators/context/search_query_results.py +61 -0
  57. sage/middleware/operators/context/search_result.py +42 -0
  58. sage/middleware/operators/context/search_session.py +79 -0
  59. sage/middleware/operators/filters/__init__.py +26 -0
  60. sage/middleware/operators/filters/context_sink.py +387 -0
  61. sage/middleware/operators/filters/context_source.py +376 -0
  62. sage/middleware/operators/filters/evaluate_filter.py +83 -0
  63. sage/middleware/operators/filters/tool_filter.py +74 -0
  64. sage/middleware/operators/llm/__init__.py +18 -0
  65. sage/middleware/operators/llm/sagellm_generator.py +432 -0
  66. sage/middleware/operators/rag/__init__.py +147 -0
  67. sage/middleware/operators/rag/arxiv.py +331 -0
  68. sage/middleware/operators/rag/chunk.py +13 -0
  69. sage/middleware/operators/rag/document_loaders.py +23 -0
  70. sage/middleware/operators/rag/evaluate.py +658 -0
  71. sage/middleware/operators/rag/generator.py +340 -0
  72. sage/middleware/operators/rag/index_builder/__init__.py +48 -0
  73. sage/middleware/operators/rag/index_builder/builder.py +363 -0
  74. sage/middleware/operators/rag/index_builder/manifest.py +101 -0
  75. sage/middleware/operators/rag/index_builder/storage.py +131 -0
  76. sage/middleware/operators/rag/pipeline.py +46 -0
  77. sage/middleware/operators/rag/profiler.py +59 -0
  78. sage/middleware/operators/rag/promptor.py +400 -0
  79. sage/middleware/operators/rag/refiner.py +231 -0
  80. sage/middleware/operators/rag/reranker.py +364 -0
  81. sage/middleware/operators/rag/retriever.py +1308 -0
  82. sage/middleware/operators/rag/searcher.py +37 -0
  83. sage/middleware/operators/rag/types.py +28 -0
  84. sage/middleware/operators/rag/writer.py +80 -0
  85. sage/middleware/operators/tools/__init__.py +71 -0
  86. sage/middleware/operators/tools/arxiv_paper_searcher.py +175 -0
  87. sage/middleware/operators/tools/arxiv_searcher.py +102 -0
  88. sage/middleware/operators/tools/duckduckgo_searcher.py +105 -0
  89. sage/middleware/operators/tools/image_captioner.py +104 -0
  90. sage/middleware/operators/tools/nature_news_fetcher.py +224 -0
  91. sage/middleware/operators/tools/searcher_tool.py +514 -0
  92. sage/middleware/operators/tools/text_detector.py +185 -0
  93. sage/middleware/operators/tools/url_text_extractor.py +104 -0
  94. sage/middleware/py.typed +2 -0
@@ -0,0 +1,483 @@
1
+ """
2
+ ChromaDB 后端管理工具
3
+ 提供 ChromaDB 向量数据库的初始化、文档管理和检索功能
4
+ """
5
+
6
+ import json
7
+ import logging
8
+ import os
9
+ import time
10
+ from typing import Any
11
+
12
+ import numpy as np
13
+
14
+
15
+ class ChromaBackend:
16
+ """ChromaDB 后端管理器"""
17
+
18
+ def __init__(self, config: dict[str, Any], logger: logging.Logger | Any = None):
19
+ """
20
+ 初始化 ChromaDB 后端
21
+
22
+ Args:
23
+ config: ChromaDB 配置字典
24
+ logger: 日志记录器
25
+ """
26
+ self.config = config
27
+ self.logger = logger or logging.getLogger(__name__)
28
+
29
+ # ChromaDB 基本配置
30
+ self.host = config.get("host", "localhost")
31
+ self.port = config.get("port", 8000)
32
+ self.persistence_path = config.get("persistence_path", "./chroma_db")
33
+ self.collection_name = config.get("collection_name", "dense_retriever_collection")
34
+ self.use_embedding_query = config.get("use_embedding_query", True)
35
+ self.metadata_config = config.get("metadata", {"hnsw:space": "cosine"})
36
+
37
+ # 初始化客户端和集合
38
+ self.client: Any = None # Will be initialized by _init_client
39
+ self.collection: Any = None # Will be initialized by _init_collection
40
+ self._init_client()
41
+ self._init_collection()
42
+
43
+ def _init_client(self):
44
+ """初始化 ChromaDB 客户端"""
45
+ try:
46
+ import chromadb
47
+ from chromadb.config import Settings # noqa: F401
48
+
49
+ # 判断使用本地还是远程模式
50
+ if self.host in ["localhost", "127.0.0.1"] and not self.config.get("force_http", False):
51
+ # 本地持久化模式
52
+ self.client = chromadb.PersistentClient(path=self.persistence_path)
53
+ self.logger.info(
54
+ f"Initialized ChromaDB persistent client at: {self.persistence_path}"
55
+ )
56
+ else:
57
+ # 远程服务器模式
58
+ full_host = (
59
+ f"http://{self.host}:{self.port}"
60
+ if not self.host.startswith("http")
61
+ else self.host
62
+ )
63
+
64
+ # 处理认证
65
+ auth_config = self.config.get("auth", {})
66
+ if auth_config:
67
+ # 如果需要认证,可以在这里添加认证逻辑
68
+ pass
69
+
70
+ self.client = chromadb.HttpClient(host=full_host)
71
+ self.logger.info(f"Initialized ChromaDB HTTP client at: {full_host}")
72
+
73
+ except ImportError as e:
74
+ self.logger.error(f"Failed to import ChromaDB: {e}")
75
+ raise ImportError(
76
+ "ChromaDB dependencies not available. Install with: pip install chromadb"
77
+ )
78
+ except Exception as e:
79
+ self.logger.error(f"Failed to initialize ChromaDB client: {e}")
80
+ raise
81
+
82
+ def _init_collection(self):
83
+ """初始化或获取 ChromaDB 集合"""
84
+ try:
85
+ # 尝试获取已存在的集合
86
+ try:
87
+ self.collection = self.client.get_collection(name=self.collection_name)
88
+ self.logger.info(f"Retrieved existing ChromaDB collection: {self.collection_name}")
89
+ except Exception:
90
+ # 集合不存在,创建新集合
91
+ self.collection = self.client.create_collection(
92
+ name=self.collection_name, metadata=self.metadata_config
93
+ )
94
+ self.logger.info(f"Created new ChromaDB collection: {self.collection_name}")
95
+
96
+ except Exception as e:
97
+ self.logger.error(f"Failed to initialize ChromaDB collection: {e}")
98
+ raise
99
+
100
+ def add_documents(
101
+ self, documents: list[str], embeddings: list[np.ndarray], doc_ids: list[str]
102
+ ) -> list[str]:
103
+ """
104
+ 添加文档到 ChromaDB 集合
105
+
106
+ Args:
107
+ documents: 文档内容列表
108
+ embeddings: 向量嵌入列表
109
+ doc_ids: 文档ID列表
110
+
111
+ Returns:
112
+ 成功添加的文档ID列表
113
+ """
114
+ try:
115
+ # 转换 embedding 格式(ChromaDB 需要 list 格式)
116
+ embeddings_list = [embedding.tolist() for embedding in embeddings]
117
+
118
+ # 准备元数据
119
+ metadatas = []
120
+ for i, doc_id in enumerate(doc_ids):
121
+ metadata = {
122
+ "doc_id": doc_id,
123
+ "length": len(documents[i]),
124
+ "added_time": time.time(),
125
+ }
126
+ metadatas.append(metadata)
127
+
128
+ # 添加到 ChromaDB
129
+ self.collection.add(
130
+ embeddings=embeddings_list,
131
+ documents=documents,
132
+ metadatas=metadatas,
133
+ ids=doc_ids,
134
+ )
135
+
136
+ self.logger.info(f"Added {len(documents)} documents to ChromaDB collection")
137
+ return doc_ids
138
+
139
+ except Exception as e:
140
+ self.logger.error(f"Error adding documents to ChromaDB: {e}")
141
+ return []
142
+
143
+ def search(self, query_vector: np.ndarray, query_text: str, top_k: int) -> list[str]:
144
+ """
145
+ 在 ChromaDB 中执行搜索
146
+
147
+ Args:
148
+ query_vector: 查询向量
149
+ query_text: 查询文本
150
+ top_k: 返回的文档数量
151
+
152
+ Returns:
153
+ 检索到的文档内容列表
154
+ """
155
+ try:
156
+ print(f"ChromaBackend.search: using top_k = {top_k}")
157
+
158
+ if self.use_embedding_query:
159
+ # 使用向量查询
160
+ results = self.collection.query(
161
+ query_embeddings=[query_vector.tolist()],
162
+ n_results=top_k,
163
+ include=["documents", "metadatas", "distances"],
164
+ )
165
+ else:
166
+ # 使用文本查询(如果 ChromaDB 支持内建的 embedding 函数)
167
+ results = self.collection.query(
168
+ query_texts=[query_text],
169
+ n_results=top_k,
170
+ include=["documents", "metadatas", "distances"],
171
+ )
172
+
173
+ # 提取文档内容
174
+ if results["documents"] and len(results["documents"]) > 0:
175
+ documents = results["documents"][0] # 返回第一个查询的结果
176
+ print(f"ChromaBackend.search: returned {len(documents)} documents")
177
+ return documents
178
+ else:
179
+ return []
180
+
181
+ except Exception as e:
182
+ self.logger.error(f"Error executing ChromaDB search: {e}")
183
+ return []
184
+
185
+ def delete_collection(self):
186
+ """删除当前集合"""
187
+ try:
188
+ self.client.delete_collection(name=self.collection_name)
189
+ self.logger.info(f"Deleted ChromaDB collection: {self.collection_name}")
190
+ return True
191
+ except Exception as e:
192
+ self.logger.error(f"Error deleting ChromaDB collection: {e}")
193
+ return False
194
+
195
+ def get_collection_info(self) -> dict[str, Any]:
196
+ """
197
+ 获取集合信息
198
+
199
+ Returns:
200
+ 包含集合信息的字典
201
+ """
202
+ try:
203
+ return {
204
+ "backend": "chroma",
205
+ "collection_name": self.collection.name,
206
+ "document_count": self.collection.count(),
207
+ "metadata": self.metadata_config,
208
+ "persistence_path": (
209
+ self.persistence_path if hasattr(self, "persistence_path") else None
210
+ ),
211
+ }
212
+ except Exception as e:
213
+ self.logger.error(f"Failed to get ChromaDB collection info: {e}")
214
+ return {"backend": "chroma", "error": str(e)}
215
+
216
+ def save_config(self, save_path: str) -> bool:
217
+ """
218
+ 保存 ChromaDB 配置信息
219
+
220
+ Args:
221
+ save_path: 保存路径
222
+
223
+ Returns:
224
+ 是否保存成功
225
+ """
226
+ try:
227
+ os.makedirs(save_path, exist_ok=True)
228
+
229
+ # ChromaDB 本身会处理持久化,这里只需要保存配置信息
230
+ config_path = os.path.join(save_path, "chroma_config.json")
231
+ config_info = {
232
+ "collection_name": self.collection.name,
233
+ "collection_count": self.collection.count(),
234
+ "backend_type": "chroma",
235
+ "chroma_config": self.config,
236
+ "saved_time": time.time(),
237
+ }
238
+
239
+ with open(config_path, "w", encoding="utf-8") as f:
240
+ json.dump(config_info, f, ensure_ascii=False, indent=2)
241
+
242
+ self.logger.info(f"Successfully saved ChromaDB config to: {save_path}")
243
+ self.logger.info(
244
+ f"ChromaDB collection '{self.collection.name}' contains {config_info['collection_count']} documents"
245
+ )
246
+ return True
247
+
248
+ except Exception as e:
249
+ self.logger.error(f"Failed to save ChromaDB config: {e}")
250
+ return False
251
+
252
+ def load_config(self, load_path: str) -> bool:
253
+ """
254
+ 从配置文件重新连接到 ChromaDB 集合
255
+
256
+ Args:
257
+ load_path: 配置文件路径
258
+
259
+ Returns:
260
+ 是否加载成功
261
+ """
262
+ try:
263
+ config_path = os.path.join(load_path, "chroma_config.json")
264
+ if os.path.exists(config_path):
265
+ with open(config_path, encoding="utf-8") as f:
266
+ config_info = json.load(f)
267
+
268
+ collection_name = config_info.get("collection_name")
269
+ if collection_name:
270
+ # 尝试连接到已存在的集合
271
+ self.collection = self.client.get_collection(name=collection_name)
272
+ self.collection_name = collection_name
273
+ self.logger.info(
274
+ f"Successfully connected to ChromaDB collection: {collection_name}"
275
+ )
276
+ self.logger.info(f"Collection contains {self.collection.count()} documents")
277
+ return True
278
+ else:
279
+ self.logger.error("No collection name found in config")
280
+ return False
281
+ else:
282
+ self.logger.error(f"ChromaDB config not found at: {config_path}")
283
+ return False
284
+
285
+ except Exception as e:
286
+ self.logger.error(f"Failed to load ChromaDB config: {e}")
287
+ return False
288
+
289
+ def load_knowledge_from_file(self, file_path: str, embedding_model) -> bool:
290
+ """
291
+ 从文件加载知识库到 ChromaDB
292
+
293
+ Args:
294
+ file_path: 知识库文件路径
295
+ embedding_model: 嵌入模型实例
296
+
297
+ Returns:
298
+ 是否加载成功
299
+ """
300
+ try:
301
+ self.logger.info(f"Loading knowledge from file: {file_path}")
302
+ with open(file_path, encoding="utf-8") as f:
303
+ content = f.read()
304
+
305
+ # 将知识库按段落分割
306
+ documents = [doc.strip() for doc in content.split("\n\n") if doc.strip()]
307
+
308
+ if documents:
309
+ # 生成文档ID
310
+ doc_ids = [f"doc_{int(time.time() * 1000)}_{i}" for i in range(len(documents))]
311
+
312
+ # 生成 embedding
313
+ embeddings = []
314
+ for doc in documents:
315
+ embedding = embedding_model.embed(doc)
316
+ embeddings.append(np.array(embedding, dtype=np.float32))
317
+
318
+ # 添加到 ChromaDB
319
+ added_ids = self.add_documents(documents, embeddings, doc_ids)
320
+
321
+ if added_ids:
322
+ self.logger.info(f"Loaded {len(added_ids)} documents from {file_path}")
323
+ return True
324
+ else:
325
+ self.logger.error(f"Failed to add documents from {file_path}")
326
+ return False
327
+ else:
328
+ self.logger.warning(f"No valid documents found in {file_path}")
329
+ return False
330
+
331
+ except Exception as e:
332
+ self.logger.error(f"Failed to load knowledge from file {file_path}: {e}")
333
+ return False
334
+
335
+ def clear_collection(self) -> bool:
336
+ """
337
+ 清空集合中的所有文档
338
+
339
+ Returns:
340
+ 是否清空成功
341
+ """
342
+ try:
343
+ # 获取所有文档ID
344
+ all_docs = self.collection.get()
345
+ if all_docs["ids"]:
346
+ # 删除所有文档
347
+ self.collection.delete(ids=all_docs["ids"])
348
+ self.logger.info(f"Cleared {len(all_docs['ids'])} documents from collection")
349
+ return True
350
+ except Exception as e:
351
+ self.logger.error(f"Failed to clear collection: {e}")
352
+ return False
353
+
354
+ def update_document(self, doc_id: str, new_content: str, new_embedding: np.ndarray) -> bool:
355
+ """
356
+ 更新指定文档
357
+
358
+ Args:
359
+ doc_id: 文档ID
360
+ new_content: 新的文档内容
361
+ new_embedding: 新的向量嵌入
362
+
363
+ Returns:
364
+ 是否更新成功
365
+ """
366
+ try:
367
+ # ChromaDB 的 update 方法
368
+ self.collection.update(
369
+ ids=[doc_id],
370
+ documents=[new_content],
371
+ embeddings=[new_embedding.tolist()],
372
+ metadatas=[
373
+ {
374
+ "doc_id": doc_id,
375
+ "length": len(new_content),
376
+ "updated_time": time.time(),
377
+ }
378
+ ],
379
+ )
380
+
381
+ self.logger.info(f"Updated document: {doc_id}")
382
+ return True
383
+
384
+ except Exception as e:
385
+ self.logger.error(f"Failed to update document {doc_id}: {e}")
386
+ return False
387
+
388
+ def delete_document(self, doc_id: str) -> bool:
389
+ """
390
+ 删除指定文档
391
+
392
+ Args:
393
+ doc_id: 文档ID
394
+
395
+ Returns:
396
+ 是否删除成功
397
+ """
398
+ try:
399
+ self.collection.delete(ids=[doc_id])
400
+ self.logger.info(f"Deleted document: {doc_id}")
401
+ return True
402
+ except Exception as e:
403
+ self.logger.error(f"Failed to delete document {doc_id}: {e}")
404
+ return False
405
+
406
+
407
class ChromaUtils:
    """Stateless helpers for building and checking ChromaDB configuration."""

    @staticmethod
    def create_chroma_config(
        persistence_path: str = "./chroma_db",
        collection_name: str = "default_collection",
        distance_metric: str = "cosine",
        host: str = "localhost",
        port: int = 8000,
    ) -> dict[str, Any]:
        """Build a standard ChromaDB configuration dictionary.

        Args:
            persistence_path: Directory used for on-disk persistence.
            collection_name: Name of the collection.
            distance_metric: Value stored under ``hnsw:space``.
            host: Server host.
            port: Server port.

        Returns:
            Configuration dict with embedding queries enabled and fixed HNSW
            parameters (``M=16``, ``ef_construction=200``, ``ef=10``).
        """
        hnsw_settings = {
            "hnsw:space": distance_metric,
            "hnsw:M": 16,
            "hnsw:ef_construction": 200,
            "hnsw:ef": 10,
        }
        chroma_config: dict[str, Any] = {}
        chroma_config["host"] = host
        chroma_config["port"] = port
        chroma_config["persistence_path"] = persistence_path
        chroma_config["collection_name"] = collection_name
        chroma_config["use_embedding_query"] = True
        chroma_config["metadata"] = hnsw_settings
        return chroma_config

    @staticmethod
    def validate_chroma_config(config: dict[str, Any]) -> bool:
        """Check that a ChromaDB configuration dictionary is usable.

        Args:
            config: ChromaDB configuration dictionary.

        Returns:
            False if ``collection_name`` is missing or ``metadata`` carries an
            ``hnsw:space`` other than ``cosine``, ``l2`` or ``ip``; True
            otherwise.
        """
        mandatory = ("collection_name",)
        if not all(name in config for name in mandatory):
            return False

        # The distance metric is only checked when it is actually present
        if "metadata" in config and "hnsw:space" in config["metadata"]:
            return config["metadata"]["hnsw:space"] in ("cosine", "l2", "ip")

        return True

    @staticmethod
    def check_chromadb_availability() -> bool:
        """Report whether the ``chromadb`` package can be imported.

        Returns:
            True if ``import chromadb`` succeeds, False on ``ImportError``.
        """
        try:
            import chromadb  # noqa: F401
        except ImportError:
            return False
        else:
            return True
@@ -0,0 +1,185 @@
1
+ """ChromaDB VectorStore Adapter
2
+
3
+ Adapter that wraps ChromaBackend to implement the VectorStore protocol,
4
+ enabling it to work with IndexBuilder.
5
+
6
+ Layer: L3 (sage-libs/integrations)
7
+ Dependencies: sage.middleware.operators.rag.index_builder (L4 Protocol only - runtime_checkable)
8
+ """
9
+
10
+ import json
11
+ from pathlib import Path
12
+ from typing import Any
13
+
14
+ from sage.middleware.components.vector_stores.chroma import ChromaBackend
15
+
16
+
17
+ class ChromaVectorStoreAdapter:
18
+ """Adapter wrapping ChromaBackend to implement VectorStore Protocol.
19
+
20
+ This adapter enables ChromaBackend to work with IndexBuilder by
21
+ implementing the VectorStore interface.
22
+
23
+ Note: We don't formally implement the Protocol here (that would create
24
+ L3→L4 dependency). Instead, we provide duck-typing compatibility.
25
+ The Protocol is only for type checking at runtime.
26
+
27
+ Args:
28
+ persist_path: Directory to store ChromaDB data
29
+ dim: Vector dimension (unused for Chroma, but required by Protocol)
30
+ collection_name: Name of the Chroma collection
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ persist_path: Path,
36
+ dim: int,
37
+ collection_name: str = "sage_index",
38
+ ):
39
+ """Initialize ChromaDB adapter.
40
+
41
+ Args:
42
+ persist_path: Path to persist ChromaDB data
43
+ dim: Vector dimension (recorded but not enforced by Chroma)
44
+ collection_name: Name of collection to use
45
+ """
46
+ self.persist_path = persist_path
47
+ self.dim = dim
48
+ self.collection_name = collection_name
49
+
50
+ # Create parent directory
51
+ persist_path.mkdir(parents=True, exist_ok=True)
52
+
53
+ # Initialize ChromaBackend with local persistence
54
+ config = {
55
+ "persistence_path": str(persist_path),
56
+ "collection_name": collection_name,
57
+ "metadata": {"hnsw:space": "cosine"},
58
+ }
59
+
60
+ self.backend = ChromaBackend(config)
61
+
62
+ # Track documents for count
63
+ self._doc_count = 0
64
+
65
+ def add(self, vector: list[float], metadata: dict[str, Any]) -> None:
66
+ """Add a single vector with metadata.
67
+
68
+ Args:
69
+ vector: Vector embedding
70
+ metadata: Metadata dictionary
71
+ """
72
+ # ChromaBackend.add_documents expects batch format
73
+ # We'll accumulate and flush periodically, or add one at a time
74
+ doc_id = f"doc_{self._doc_count}"
75
+
76
+ self.backend.add_documents(
77
+ ids=[doc_id],
78
+ embeddings=[vector],
79
+ metadatas=[metadata],
80
+ documents=[metadata.get("text", "")], # Use 'text' field if available
81
+ )
82
+
83
+ self._doc_count += 1
84
+
85
+ def build_index(self) -> None:
86
+ """Build/optimize the index.
87
+
88
+ ChromaDB builds indices automatically, so this is a no-op.
89
+ """
90
+ # ChromaDB automatically maintains indices
91
+ pass
92
+
93
+ def save(self, path: str) -> None:
94
+ """Persist the vector store to disk.
95
+
96
+ Args:
97
+ path: Path to save (unused for Chroma - uses persistence_path from config)
98
+ """
99
+ # ChromaDB with PersistentClient automatically persists
100
+ # Save metadata about the index
101
+ manifest_path = Path(path).parent / "chroma_manifest.json"
102
+ manifest = {
103
+ "collection_name": self.collection_name,
104
+ "persistence_path": str(self.persist_path),
105
+ "dim": self.dim,
106
+ "count": self._doc_count,
107
+ }
108
+
109
+ with open(manifest_path, "w") as f:
110
+ json.dump(manifest, f, indent=2)
111
+
112
+ def load(self, path: str) -> None:
113
+ """Load vector store from disk.
114
+
115
+ Args:
116
+ path: Path to load from
117
+ """
118
+ # ChromaDB automatically loads from persistence_path
119
+ # Try to load manifest for metadata
120
+ manifest_path = Path(path).parent / "chroma_manifest.json"
121
+ if manifest_path.exists():
122
+ with open(manifest_path) as f:
123
+ manifest = json.load(f)
124
+ self._doc_count = manifest.get("count", 0)
125
+
126
+ def search(
127
+ self,
128
+ query_vector: list[float],
129
+ top_k: int = 5,
130
+ filter_dict: dict[str, Any] | None = None,
131
+ ) -> list[dict]:
132
+ """Search for similar vectors.
133
+
134
+ Args:
135
+ query_vector: Query embedding
136
+ top_k: Number of results to return
137
+ filter_dict: Optional metadata filters
138
+
139
+ Returns:
140
+ List of result dictionaries with 'id', 'score', 'metadata'
141
+ """
142
+ # Use ChromaBackend.query
143
+ results = self.backend.query(
144
+ query_embeddings=[query_vector],
145
+ n_results=top_k,
146
+ where=filter_dict, # ChromaDB uses 'where' for metadata filtering
147
+ )
148
+
149
+ # Convert ChromaDB results to standard format
150
+ formatted_results = []
151
+ if results and "ids" in results:
152
+ ids = results["ids"][0] if results["ids"] else []
153
+ distances = results["distances"][0] if results["distances"] else []
154
+ metadatas = results["metadatas"][0] if results["metadatas"] else []
155
+
156
+ for i, doc_id in enumerate(ids):
157
+ formatted_results.append(
158
+ {
159
+ "id": doc_id,
160
+ "score": float(distances[i]) if i < len(distances) else 0.0,
161
+ "metadata": metadatas[i] if i < len(metadatas) else {},
162
+ }
163
+ )
164
+
165
+ return formatted_results
166
+
167
+ def get_dim(self) -> int:
168
+ """Get vector dimension.
169
+
170
+ Returns:
171
+ Vector dimension
172
+ """
173
+ return self.dim
174
+
175
+ def count(self) -> int:
176
+ """Get number of vectors in store.
177
+
178
+ Returns:
179
+ Number of stored vectors
180
+ """
181
+ # ChromaDB collection has a count method
182
+ try:
183
+ return self.backend.collection.count()
184
+ except Exception:
185
+ return self._doc_count