maque-0.2.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- maque/__init__.py +30 -0
- maque/__main__.py +926 -0
- maque/ai_platform/__init__.py +0 -0
- maque/ai_platform/crawl.py +45 -0
- maque/ai_platform/metrics.py +258 -0
- maque/ai_platform/nlp_preprocess.py +67 -0
- maque/ai_platform/webpage_screen_shot.py +195 -0
- maque/algorithms/__init__.py +78 -0
- maque/algorithms/bezier.py +15 -0
- maque/algorithms/bktree.py +117 -0
- maque/algorithms/core.py +104 -0
- maque/algorithms/hilbert.py +16 -0
- maque/algorithms/rate_function.py +92 -0
- maque/algorithms/transform.py +27 -0
- maque/algorithms/trie.py +272 -0
- maque/algorithms/utils.py +63 -0
- maque/algorithms/video.py +587 -0
- maque/api/__init__.py +1 -0
- maque/api/common.py +110 -0
- maque/api/fetch.py +26 -0
- maque/api/static/icon.png +0 -0
- maque/api/static/redoc.standalone.js +1782 -0
- maque/api/static/swagger-ui-bundle.js +3 -0
- maque/api/static/swagger-ui.css +3 -0
- maque/cli/__init__.py +1 -0
- maque/cli/clean_invisible_chars.py +324 -0
- maque/cli/core.py +34 -0
- maque/cli/groups/__init__.py +26 -0
- maque/cli/groups/config.py +205 -0
- maque/cli/groups/data.py +615 -0
- maque/cli/groups/doctor.py +259 -0
- maque/cli/groups/embedding.py +222 -0
- maque/cli/groups/git.py +29 -0
- maque/cli/groups/help.py +410 -0
- maque/cli/groups/llm.py +223 -0
- maque/cli/groups/mcp.py +241 -0
- maque/cli/groups/mllm.py +1795 -0
- maque/cli/groups/mllm_simple.py +60 -0
- maque/cli/groups/quant.py +210 -0
- maque/cli/groups/service.py +490 -0
- maque/cli/groups/system.py +570 -0
- maque/cli/mllm_run.py +1451 -0
- maque/cli/script.py +52 -0
- maque/cli/tree.py +49 -0
- maque/clustering/__init__.py +52 -0
- maque/clustering/analyzer.py +347 -0
- maque/clustering/clusterers.py +464 -0
- maque/clustering/sampler.py +134 -0
- maque/clustering/visualizer.py +205 -0
- maque/constant.py +13 -0
- maque/core.py +133 -0
- maque/cv/__init__.py +1 -0
- maque/cv/image.py +219 -0
- maque/cv/utils.py +68 -0
- maque/cv/video/__init__.py +3 -0
- maque/cv/video/keyframe_extractor.py +368 -0
- maque/embedding/__init__.py +43 -0
- maque/embedding/base.py +56 -0
- maque/embedding/multimodal.py +308 -0
- maque/embedding/server.py +523 -0
- maque/embedding/text.py +311 -0
- maque/git/__init__.py +24 -0
- maque/git/pure_git.py +912 -0
- maque/io/__init__.py +29 -0
- maque/io/core.py +38 -0
- maque/io/ops.py +194 -0
- maque/llm/__init__.py +111 -0
- maque/llm/backend.py +416 -0
- maque/llm/base.py +411 -0
- maque/llm/server.py +366 -0
- maque/mcp_server.py +1096 -0
- maque/mllm_data_processor_pipeline/__init__.py +17 -0
- maque/mllm_data_processor_pipeline/core.py +341 -0
- maque/mllm_data_processor_pipeline/example.py +291 -0
- maque/mllm_data_processor_pipeline/steps/__init__.py +56 -0
- maque/mllm_data_processor_pipeline/steps/data_alignment.py +267 -0
- maque/mllm_data_processor_pipeline/steps/data_loader.py +172 -0
- maque/mllm_data_processor_pipeline/steps/data_validation.py +304 -0
- maque/mllm_data_processor_pipeline/steps/format_conversion.py +411 -0
- maque/mllm_data_processor_pipeline/steps/mllm_annotation.py +331 -0
- maque/mllm_data_processor_pipeline/steps/mllm_refinement.py +446 -0
- maque/mllm_data_processor_pipeline/steps/result_validation.py +501 -0
- maque/mllm_data_processor_pipeline/web_app.py +317 -0
- maque/nlp/__init__.py +14 -0
- maque/nlp/ngram.py +9 -0
- maque/nlp/parser.py +63 -0
- maque/nlp/risk_matcher.py +543 -0
- maque/nlp/sentence_splitter.py +202 -0
- maque/nlp/simple_tradition_cvt.py +31 -0
- maque/performance/__init__.py +21 -0
- maque/performance/_measure_time.py +70 -0
- maque/performance/_profiler.py +367 -0
- maque/performance/_stat_memory.py +51 -0
- maque/pipelines/__init__.py +15 -0
- maque/pipelines/clustering.py +252 -0
- maque/quantization/__init__.py +42 -0
- maque/quantization/auto_round.py +120 -0
- maque/quantization/base.py +145 -0
- maque/quantization/bitsandbytes.py +127 -0
- maque/quantization/llm_compressor.py +102 -0
- maque/retriever/__init__.py +35 -0
- maque/retriever/chroma.py +654 -0
- maque/retriever/document.py +140 -0
- maque/retriever/milvus.py +1140 -0
- maque/table_ops/__init__.py +1 -0
- maque/table_ops/core.py +133 -0
- maque/table_viewer/__init__.py +4 -0
- maque/table_viewer/download_assets.py +57 -0
- maque/table_viewer/server.py +698 -0
- maque/table_viewer/static/element-plus-icons.js +5791 -0
- maque/table_viewer/static/element-plus.css +1 -0
- maque/table_viewer/static/element-plus.js +65236 -0
- maque/table_viewer/static/main.css +268 -0
- maque/table_viewer/static/main.js +669 -0
- maque/table_viewer/static/vue.global.js +18227 -0
- maque/table_viewer/templates/index.html +401 -0
- maque/utils/__init__.py +56 -0
- maque/utils/color.py +68 -0
- maque/utils/color_string.py +45 -0
- maque/utils/compress.py +66 -0
- maque/utils/constant.py +183 -0
- maque/utils/core.py +261 -0
- maque/utils/cursor.py +143 -0
- maque/utils/distance.py +58 -0
- maque/utils/docker.py +96 -0
- maque/utils/downloads.py +51 -0
- maque/utils/excel_helper.py +542 -0
- maque/utils/helper_metrics.py +121 -0
- maque/utils/helper_parser.py +168 -0
- maque/utils/net.py +64 -0
- maque/utils/nvidia_stat.py +140 -0
- maque/utils/ops.py +53 -0
- maque/utils/packages.py +31 -0
- maque/utils/path.py +57 -0
- maque/utils/tar.py +260 -0
- maque/utils/untar.py +129 -0
- maque/web/__init__.py +0 -0
- maque/web/image_downloader.py +1410 -0
- maque-0.2.1.dist-info/METADATA +450 -0
- maque-0.2.1.dist-info/RECORD +143 -0
- maque-0.2.1.dist-info/WHEEL +4 -0
- maque-0.2.1.dist-info/entry_points.txt +3 -0
- maque-0.2.1.dist-info/licenses/LICENSE +21 -0
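
Contents of the new file maque/retriever/chroma.py (654 lines added):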
@@ -0,0 +1,654 @@
#! /usr/bin/env python3
# -*- coding: utf-8 -*-

"""
ChromaDB retriever implementation
"""

from typing import List, Optional, Union, Literal

from loguru import logger
import chromadb
from chromadb.config import Settings

from ..embedding.base import BaseEmbedding
from .document import Document, SearchResult, Modality, _content_hash


DistanceMetric = Literal["cosine", "l2", "ip"]


class ChromaRetriever:
    """
    ChromaDB-based retriever.
    Supports vector search over text and images.
    """

    def __init__(
        self,
        embedding: BaseEmbedding,
        persist_dir: Optional[str] = None,
        collection_name: str = "default",
        distance_metric: DistanceMetric = "cosine",
    ):
        """
        Initialize the retriever.

        Args:
            embedding: Embedding instance (TextEmbedding or MultiModalEmbedding)
            persist_dir: Persistence directory; None for in-memory mode
            collection_name: Collection name
            distance_metric: Distance metric (cosine/l2/ip)
        """
        self.embedding = embedding
        self.persist_dir = persist_dir
        self.collection_name = collection_name
        self.distance_metric = distance_metric

        # Initialize ChromaDB
        if persist_dir:
            logger.debug(f"Initializing ChromaDB with persist_dir: {persist_dir}")
            self.client = chromadb.PersistentClient(path=persist_dir)
        else:
            logger.debug("Initializing ChromaDB in memory mode")
            self.client = chromadb.Client()

        # Get or create the collection
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            metadata={"hnsw:space": distance_metric},
        )
        logger.info(f"Collection '{collection_name}' ready, {self.count()} documents")

    def _get_input_type(self, modality: Modality) -> str:
        """Return the input_type argument for the embedding."""
        return "image" if modality == "image" else "text"

    def _embed_documents(self, documents: List[Document]) -> List[List[float]]:
        """Embed a list of documents."""
        if not documents:
            return []

        # Check for images; they require a multimodal embedding
        has_image = any(doc.is_image for doc in documents)
        if has_image and not self.embedding.supports_image:
            raise ValueError(
                "The embedding does not support images, but the documents contain images. "
                "Please use MultiModalEmbedding."
            )

        # Handle text and images separately
        if has_image:
            # Multimodal: process in order to keep indices aligned
            embeddings = []
            for doc in documents:
                input_type = self._get_input_type(doc.modality)
                vec = self.embedding.embed([doc.content], input_type=input_type)[0]
                embeddings.append(vec)
            return embeddings
        else:
            # Text only: embed in one batch
            contents = [doc.content for doc in documents]
            return self.embedding.embed(contents)

    def _embed_query(
        self,
        query: str,
        query_type: Modality = "text",
    ) -> List[float]:
        """Embed a single query."""
        if query_type == "image" and not self.embedding.supports_image:
            raise ValueError("The embedding does not support image queries")

        if self.embedding.supports_image:
            input_type = self._get_input_type(query_type)
            return self.embedding.embed([query], input_type=input_type)[0]
        else:
            return self.embedding.embed([query])[0]

    def _embed_queries(
        self,
        queries: List[str],
        query_type: Modality = "text",
    ) -> List[List[float]]:
        """Embed multiple queries in one batch."""
        if not queries:
            return []

        if query_type == "image" and not self.embedding.supports_image:
            raise ValueError("The embedding does not support image queries")

        if self.embedding.supports_image:
            input_type = self._get_input_type(query_type)
            return self.embedding.embed(queries, input_type=input_type)
        else:
            return self.embedding.embed(queries)

    # ========== Index operations ==========

    def add(
        self,
        documents: Union[Document, List[Document]],
        skip_existing: bool = False,
    ) -> List[str]:
        """
        Add documents.

        Args:
            documents: A single document or a list of documents
            skip_existing: Whether to skip documents that already exist

        Returns:
            List of added document IDs
        """
        if isinstance(documents, Document):
            documents = [documents]

        if not documents:
            return []

        # Filter out documents that already exist
        if skip_existing:
            existing_ids = self._get_existing_ids([doc.id for doc in documents])
            skipped = len([doc for doc in documents if doc.id in existing_ids])
            documents = [doc for doc in documents if doc.id not in existing_ids]
            if skipped > 0:
                logger.debug(f"Skipped {skipped} existing documents")
            if not documents:
                return []

        # Embed
        embeddings = self._embed_documents(documents)

        # Prepare payloads
        ids = [doc.id for doc in documents]
        contents = [doc.content for doc in documents]
        metadatas = [
            {**doc.metadata, "_modality": doc.modality}
            for doc in documents
        ]

        # Add to the collection
        self.collection.add(
            ids=ids,
            embeddings=embeddings,
            documents=contents,
            metadatas=metadatas,
        )
        logger.debug(f"Added {len(documents)} documents")

        return ids

    def upsert(
        self,
        documents: Union[Document, List[Document]],
        skip_existing: bool = False,
    ) -> List[str]:
        """
        Add or update documents.

        Args:
            documents: A single document or a list of documents
            skip_existing: Whether to skip documents that already exist (when True, behaves like add)

        Returns:
            List of upserted document IDs
        """
        if isinstance(documents, Document):
            documents = [documents]

        if not documents:
            return []

        # Filter out documents that already exist
        if skip_existing:
            existing_ids = self._get_existing_ids([doc.id for doc in documents])
            documents = [doc for doc in documents if doc.id not in existing_ids]
            if not documents:
                return []

        # Embed
        embeddings = self._embed_documents(documents)

        # Prepare payloads
        ids = [doc.id for doc in documents]
        contents = [doc.content for doc in documents]
        metadatas = [
            {**doc.metadata, "_modality": doc.modality}
            for doc in documents
        ]

        # Upsert into the collection
        self.collection.upsert(
            ids=ids,
            embeddings=embeddings,
            documents=contents,
            metadatas=metadatas,
        )
        logger.debug(f"Upserted {len(documents)} documents")

        return ids

    def delete(self, ids: Union[str, List[str]]) -> None:
        """
        Delete documents.

        Args:
            ids: A single ID or a list of IDs
        """
        if isinstance(ids, str):
            ids = [ids]

        self.collection.delete(ids=ids)
        logger.debug(f"Deleted {len(ids)} documents")

    def delete_by_content(self, contents: Union[str, List[str]]) -> None:
        """
        Delete documents by content.

        Args:
            contents: A single content string or a list of them
        """
        if isinstance(contents, str):
            contents = [contents]

        ids = [_content_hash(content) for content in contents]
        self.delete(ids)

    # ========== Search operations ==========

    def search(
        self,
        query: str,
        top_k: int = 5,
        query_type: Modality = "text",
        where: Optional[dict] = None,
        where_document: Optional[dict] = None,
    ) -> List[SearchResult]:
        """
        Search for similar documents.

        Args:
            query: Query content (text, or an image path/URL)
            top_k: Number of results to return
            query_type: Query type, "text" / "image"
            where: Metadata filter
            where_document: Document-content filter

        Returns:
            List of SearchResult
        """
        # Embed the query
        query_embedding = self._embed_query(query, query_type)

        # Query the collection
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
            where=where,
            where_document=where_document,
        )

        parsed = self._parse_results(results)
        return parsed[0] if parsed else []

    def search_by_vector(
        self,
        vector: List[float],
        top_k: int = 5,
        where: Optional[dict] = None,
    ) -> List[SearchResult]:
        """
        Search directly with a vector.

        Args:
            vector: Query vector
            top_k: Number of results to return
            where: Metadata filter

        Returns:
            List of SearchResult
        """
        results = self.collection.query(
            query_embeddings=[vector],
            n_results=top_k,
            where=where,
        )

        parsed = self._parse_results(results)
        return parsed[0] if parsed else []

    def search_batch(
        self,
        queries: List[str],
        top_k: int = 5,
        query_type: Modality = "text",
        where: Optional[dict] = None,
        where_document: Optional[dict] = None,
    ) -> List[List[SearchResult]]:
        """
        Batch search for similar documents.

        Args:
            queries: List of query contents (text, or image paths/URLs)
            top_k: Number of results per query
            query_type: Query type, "text" / "image"
            where: Metadata filter
            where_document: Document-content filter

        Returns:
            A list of SearchResult lists, one per query

        Example:
            >>> results = retriever.search_batch(["query1", "query2"], top_k=5)
            >>> for i, query_results in enumerate(results):
            ...     print(f"Query {i}: {len(query_results)} results")
        """
        if not queries:
            return []

        # Embed all queries in one batch
        query_embeddings = self._embed_queries(queries, query_type)

        # Batch query
        results = self.collection.query(
            query_embeddings=query_embeddings,
            n_results=top_k,
            where=where,
            where_document=where_document,
        )

        return self._parse_results(results)

    def search_by_vectors(
        self,
        vectors: List[List[float]],
        top_k: int = 5,
        where: Optional[dict] = None,
    ) -> List[List[SearchResult]]:
        """
        Batch search with vectors.

        Args:
            vectors: List of query vectors
            top_k: Number of results per query
            where: Metadata filter

        Returns:
            A list of SearchResult lists, one per vector

        Example:
            >>> vectors = [[0.1, 0.2, ...], [0.3, 0.4, ...]]
            >>> results = retriever.search_by_vectors(vectors, top_k=5)
        """
        if not vectors:
            return []

        results = self.collection.query(
            query_embeddings=vectors,
            n_results=top_k,
            where=where,
        )

        return self._parse_results(results)

    def _parse_results(self, results: dict) -> List[List[SearchResult]]:
        """
        Parse results returned by ChromaDB.

        Args:
            results: Result dict returned by a ChromaDB query

        Returns:
            A list of SearchResult lists, one per query
        """
        if not results or not results.get("ids"):
            return []

        all_results = []
        num_queries = len(results["ids"])

        for query_idx in range(num_queries):
            ids = results["ids"][query_idx]
            if not ids:
                all_results.append([])
                continue

            documents = results.get("documents", [[]] * num_queries)[query_idx]
            metadatas = results.get("metadatas", [[]] * num_queries)[query_idx]
            distances = results.get("distances", [[]] * num_queries)[query_idx]

            query_results = []
            for i, doc_id in enumerate(ids):
                metadata = dict(metadatas[i]) if metadatas and i < len(metadatas) else {}
                modality = metadata.pop("_modality", "text")

                # Convert distance to similarity (cosine: 1 - distance)
                distance = distances[i] if distances and i < len(distances) else 0
                if self.distance_metric == "cosine":
                    score = 1 - distance
                else:
                    score = -distance  # l2/ip: smaller distance is better

                query_results.append(SearchResult(
                    id=doc_id,
                    content=documents[i] if documents and i < len(documents) else "",
                    score=score,
                    modality=modality,
                    metadata=metadata,
                ))

            all_results.append(query_results)

        return all_results

    # ========== Management operations ==========

    def get(
        self,
        ids: Optional[Union[str, List[str]]] = None,
        where: Optional[dict] = None,
        limit: Optional[int] = None,
    ) -> List[Document]:
        """
        Fetch documents.

        Args:
            ids: Document ID or list of IDs
            where: Metadata filter
            limit: Maximum number of documents to return

        Returns:
            List of Document
        """
        if isinstance(ids, str):
            ids = [ids]

        results = self.collection.get(
            ids=ids,
            where=where,
            limit=limit,
        )

        documents = []
        if results and results.get("ids"):
            for i, doc_id in enumerate(results["ids"]):
                metadata = dict(results["metadatas"][i]) if results.get("metadatas") else {}
                modality = metadata.pop("_modality", "text")
                content = results["documents"][i] if results.get("documents") else ""

                documents.append(Document(
                    id=doc_id,
                    content=content,
                    modality=modality,
                    metadata=metadata,
                ))

        return documents

    def count(self) -> int:
        """Return the number of documents."""
        return self.collection.count()

    def clear(self) -> None:
        """Clear the collection."""
        logger.info(f"Clearing collection: {self.collection_name}")
        self.client.delete_collection(self.collection_name)
        self.collection = self.client.get_or_create_collection(
            name=self.collection_name,
            metadata={"hnsw:space": self.distance_metric},
        )
        logger.info(f"Collection '{self.collection_name}' cleared and recreated")

    # ========== Convenience methods ==========

    def upsert_batch(
        self,
        documents: List[Document],
        batch_size: int = 32,
        skip_existing: bool = False,
        show_progress: bool = True,
    ) -> int:
        """
        Batch-insert documents (with progress bar and incremental-update support).

        Args:
            documents: List of documents
            batch_size: Batch size
            skip_existing: Whether to skip documents that already exist
            show_progress: Whether to show a progress bar

        Returns:
            Number of documents actually inserted

        Example:
            >>> retriever = ChromaRetriever(embedding, persist_dir, collection_name)
            >>> docs = [Document.text(content=text, id=f"doc_{i}") for i, text in enumerate(texts)]
            >>> count = retriever.upsert_batch(docs, batch_size=32, skip_existing=True)
            >>> print(f"Inserted {count} documents")
        """
        if not documents:
            return 0

        total_docs = len(documents)
        logger.info(f"Starting batch upsert: {total_docs} documents, batch_size={batch_size}")

        # Filter out documents that already exist
        skipped = 0
        if skip_existing:
            existing_ids = self._get_existing_ids([doc.id for doc in documents])
            skipped = len([doc for doc in documents if doc.id in existing_ids])
            documents = [doc for doc in documents if doc.id not in existing_ids]
            if skipped > 0:
                logger.info(f"Skipped {skipped} existing documents")
            if not documents:
                return 0

        # Insert in batches
        inserted = 0
        total_batches = (len(documents) + batch_size - 1) // batch_size
        iterator = range(0, len(documents), batch_size)

        if show_progress:
            try:
                from tqdm import tqdm
                iterator = tqdm(
                    iterator,
                    desc="Upserting",
                    total=total_batches,
                    unit="batch",
                )
            except ImportError:
                logger.debug("tqdm not installed, progress bar disabled")

        for i in iterator:
            batch = documents[i:i + batch_size]
            self.upsert(batch)
            inserted += len(batch)

        logger.info(f"Batch upsert completed: {inserted} inserted, {skipped} skipped")
        return inserted

    def _get_existing_ids(self, candidate_ids: List[str]) -> set:
        """Return the set of document IDs that already exist."""
        existing_ids = set()
        batch_size = 10000

        for i in range(0, len(candidate_ids), batch_size):
            batch_ids = candidate_ids[i:i + batch_size]
            try:
                results = self.collection.get(ids=batch_ids)
                if results and results.get("ids"):
                    existing_ids.update(results["ids"])
            except Exception:
                pass  # Ignore missing IDs

        return existing_ids

    def get_all_ids(self) -> List[str]:
        """Return all document IDs."""
        results = self.collection.get(include=[])
        return results.get("ids", [])

    def migrate_to(
        self,
        target,
        batch_size: int = 100,
        skip_existing: bool = True,
        show_progress: bool = True,
    ) -> int:
        """
        Migrate all data in the current collection to a target retriever.

        Args:
            target: Target retriever (ChromaRetriever or MilvusRetriever)
            batch_size: Batch size
            skip_existing: Whether to skip documents that already exist
            show_progress: Whether to show a progress bar

        Returns:
            Number of documents migrated
        """
        all_ids = self.get_all_ids()
        if not all_ids:
            logger.info("No documents to migrate")
            return 0

        total = len(all_ids)
        logger.info(f"Starting migration: {total} documents")

        migrated = 0
        iterator = range(0, total, batch_size)

        if show_progress:
            try:
                from tqdm import tqdm
                iterator = tqdm(
                    iterator,
                    desc="Migrating",
                    total=(total + batch_size - 1) // batch_size,
                    unit="batch",
                )
            except ImportError:
                pass

        for i in iterator:
            batch_ids = all_ids[i:i + batch_size]
            documents = self.get(ids=batch_ids)
            if documents:
                migrated += target.upsert_batch(
                    documents,
                    batch_size=batch_size,
                    skip_existing=skip_existing,
                    show_progress=False,
                )

        logger.info(f"Migration completed: {migrated} documents migrated")
        return migrated

    def __repr__(self) -> str:
        return (
            f"ChromaRetriever("
            f"collection={self.collection_name!r}, "
            f"count={self.count()}, "
            f"embedding={self.embedding.__class__.__name__})"
        )