maque-0.2.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143)
  1. maque/__init__.py +30 -0
  2. maque/__main__.py +926 -0
  3. maque/ai_platform/__init__.py +0 -0
  4. maque/ai_platform/crawl.py +45 -0
  5. maque/ai_platform/metrics.py +258 -0
  6. maque/ai_platform/nlp_preprocess.py +67 -0
  7. maque/ai_platform/webpage_screen_shot.py +195 -0
  8. maque/algorithms/__init__.py +78 -0
  9. maque/algorithms/bezier.py +15 -0
  10. maque/algorithms/bktree.py +117 -0
  11. maque/algorithms/core.py +104 -0
  12. maque/algorithms/hilbert.py +16 -0
  13. maque/algorithms/rate_function.py +92 -0
  14. maque/algorithms/transform.py +27 -0
  15. maque/algorithms/trie.py +272 -0
  16. maque/algorithms/utils.py +63 -0
  17. maque/algorithms/video.py +587 -0
  18. maque/api/__init__.py +1 -0
  19. maque/api/common.py +110 -0
  20. maque/api/fetch.py +26 -0
  21. maque/api/static/icon.png +0 -0
  22. maque/api/static/redoc.standalone.js +1782 -0
  23. maque/api/static/swagger-ui-bundle.js +3 -0
  24. maque/api/static/swagger-ui.css +3 -0
  25. maque/cli/__init__.py +1 -0
  26. maque/cli/clean_invisible_chars.py +324 -0
  27. maque/cli/core.py +34 -0
  28. maque/cli/groups/__init__.py +26 -0
  29. maque/cli/groups/config.py +205 -0
  30. maque/cli/groups/data.py +615 -0
  31. maque/cli/groups/doctor.py +259 -0
  32. maque/cli/groups/embedding.py +222 -0
  33. maque/cli/groups/git.py +29 -0
  34. maque/cli/groups/help.py +410 -0
  35. maque/cli/groups/llm.py +223 -0
  36. maque/cli/groups/mcp.py +241 -0
  37. maque/cli/groups/mllm.py +1795 -0
  38. maque/cli/groups/mllm_simple.py +60 -0
  39. maque/cli/groups/quant.py +210 -0
  40. maque/cli/groups/service.py +490 -0
  41. maque/cli/groups/system.py +570 -0
  42. maque/cli/mllm_run.py +1451 -0
  43. maque/cli/script.py +52 -0
  44. maque/cli/tree.py +49 -0
  45. maque/clustering/__init__.py +52 -0
  46. maque/clustering/analyzer.py +347 -0
  47. maque/clustering/clusterers.py +464 -0
  48. maque/clustering/sampler.py +134 -0
  49. maque/clustering/visualizer.py +205 -0
  50. maque/constant.py +13 -0
  51. maque/core.py +133 -0
  52. maque/cv/__init__.py +1 -0
  53. maque/cv/image.py +219 -0
  54. maque/cv/utils.py +68 -0
  55. maque/cv/video/__init__.py +3 -0
  56. maque/cv/video/keyframe_extractor.py +368 -0
  57. maque/embedding/__init__.py +43 -0
  58. maque/embedding/base.py +56 -0
  59. maque/embedding/multimodal.py +308 -0
  60. maque/embedding/server.py +523 -0
  61. maque/embedding/text.py +311 -0
  62. maque/git/__init__.py +24 -0
  63. maque/git/pure_git.py +912 -0
  64. maque/io/__init__.py +29 -0
  65. maque/io/core.py +38 -0
  66. maque/io/ops.py +194 -0
  67. maque/llm/__init__.py +111 -0
  68. maque/llm/backend.py +416 -0
  69. maque/llm/base.py +411 -0
  70. maque/llm/server.py +366 -0
  71. maque/mcp_server.py +1096 -0
  72. maque/mllm_data_processor_pipeline/__init__.py +17 -0
  73. maque/mllm_data_processor_pipeline/core.py +341 -0
  74. maque/mllm_data_processor_pipeline/example.py +291 -0
  75. maque/mllm_data_processor_pipeline/steps/__init__.py +56 -0
  76. maque/mllm_data_processor_pipeline/steps/data_alignment.py +267 -0
  77. maque/mllm_data_processor_pipeline/steps/data_loader.py +172 -0
  78. maque/mllm_data_processor_pipeline/steps/data_validation.py +304 -0
  79. maque/mllm_data_processor_pipeline/steps/format_conversion.py +411 -0
  80. maque/mllm_data_processor_pipeline/steps/mllm_annotation.py +331 -0
  81. maque/mllm_data_processor_pipeline/steps/mllm_refinement.py +446 -0
  82. maque/mllm_data_processor_pipeline/steps/result_validation.py +501 -0
  83. maque/mllm_data_processor_pipeline/web_app.py +317 -0
  84. maque/nlp/__init__.py +14 -0
  85. maque/nlp/ngram.py +9 -0
  86. maque/nlp/parser.py +63 -0
  87. maque/nlp/risk_matcher.py +543 -0
  88. maque/nlp/sentence_splitter.py +202 -0
  89. maque/nlp/simple_tradition_cvt.py +31 -0
  90. maque/performance/__init__.py +21 -0
  91. maque/performance/_measure_time.py +70 -0
  92. maque/performance/_profiler.py +367 -0
  93. maque/performance/_stat_memory.py +51 -0
  94. maque/pipelines/__init__.py +15 -0
  95. maque/pipelines/clustering.py +252 -0
  96. maque/quantization/__init__.py +42 -0
  97. maque/quantization/auto_round.py +120 -0
  98. maque/quantization/base.py +145 -0
  99. maque/quantization/bitsandbytes.py +127 -0
  100. maque/quantization/llm_compressor.py +102 -0
  101. maque/retriever/__init__.py +35 -0
  102. maque/retriever/chroma.py +654 -0
  103. maque/retriever/document.py +140 -0
  104. maque/retriever/milvus.py +1140 -0
  105. maque/table_ops/__init__.py +1 -0
  106. maque/table_ops/core.py +133 -0
  107. maque/table_viewer/__init__.py +4 -0
  108. maque/table_viewer/download_assets.py +57 -0
  109. maque/table_viewer/server.py +698 -0
  110. maque/table_viewer/static/element-plus-icons.js +5791 -0
  111. maque/table_viewer/static/element-plus.css +1 -0
  112. maque/table_viewer/static/element-plus.js +65236 -0
  113. maque/table_viewer/static/main.css +268 -0
  114. maque/table_viewer/static/main.js +669 -0
  115. maque/table_viewer/static/vue.global.js +18227 -0
  116. maque/table_viewer/templates/index.html +401 -0
  117. maque/utils/__init__.py +56 -0
  118. maque/utils/color.py +68 -0
  119. maque/utils/color_string.py +45 -0
  120. maque/utils/compress.py +66 -0
  121. maque/utils/constant.py +183 -0
  122. maque/utils/core.py +261 -0
  123. maque/utils/cursor.py +143 -0
  124. maque/utils/distance.py +58 -0
  125. maque/utils/docker.py +96 -0
  126. maque/utils/downloads.py +51 -0
  127. maque/utils/excel_helper.py +542 -0
  128. maque/utils/helper_metrics.py +121 -0
  129. maque/utils/helper_parser.py +168 -0
  130. maque/utils/net.py +64 -0
  131. maque/utils/nvidia_stat.py +140 -0
  132. maque/utils/ops.py +53 -0
  133. maque/utils/packages.py +31 -0
  134. maque/utils/path.py +57 -0
  135. maque/utils/tar.py +260 -0
  136. maque/utils/untar.py +129 -0
  137. maque/web/__init__.py +0 -0
  138. maque/web/image_downloader.py +1410 -0
  139. maque-0.2.1.dist-info/METADATA +450 -0
  140. maque-0.2.1.dist-info/RECORD +143 -0
  141. maque-0.2.1.dist-info/WHEEL +4 -0
  142. maque-0.2.1.dist-info/entry_points.txt +3 -0
  143. maque-0.2.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,654 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+
+ """
+ ChromaDB retriever implementation.
+ """
+
+ from typing import List, Optional, Union, Literal
+
+ from loguru import logger
+ import chromadb
+
+ from ..embedding.base import BaseEmbedding
+ from .document import Document, SearchResult, Modality, _content_hash
+
+
+ DistanceMetric = Literal["cosine", "l2", "ip"]
+
+
+ class ChromaRetriever:
+     """
+     Retriever backed by ChromaDB.
+     Supports vector search over both text and images.
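+
+     A minimal usage sketch (``emb`` stands in for any ``BaseEmbedding``
+     implementation, e.g. ``TextEmbedding``):
+
+     Example:
+         >>> retriever = ChromaRetriever(emb, persist_dir="./chroma_db")
+         >>> retriever.add(Document.text(content="hello world"))
+         >>> hits = retriever.search("greeting", top_k=3)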
+     """
+
+     def __init__(
+         self,
+         embedding: BaseEmbedding,
+         persist_dir: Optional[str] = None,
+         collection_name: str = "default",
+         distance_metric: DistanceMetric = "cosine",
+     ):
+         """
+         Initialize the retriever.
+
+         Args:
+             embedding: Embedding instance (TextEmbedding or MultiModalEmbedding)
+             persist_dir: Persistence directory; None means in-memory mode
+             collection_name: Collection name
+             distance_metric: Distance metric (cosine/l2/ip)
+         """
+         self.embedding = embedding
+         self.persist_dir = persist_dir
+         self.collection_name = collection_name
+         self.distance_metric = distance_metric
+
+         # Initialize ChromaDB
+         if persist_dir:
+             logger.debug(f"Initializing ChromaDB with persist_dir: {persist_dir}")
+             self.client = chromadb.PersistentClient(path=persist_dir)
+         else:
+             logger.debug("Initializing ChromaDB in memory mode")
+             self.client = chromadb.Client()
+
+         # Get or create the collection
+         self.collection = self.client.get_or_create_collection(
+             name=collection_name,
+             metadata={"hnsw:space": distance_metric},
+         )
+         logger.info(f"Collection '{collection_name}' ready, {self.count()} documents")
+
+     def _get_input_type(self, modality: Modality) -> str:
+         """Return the input_type argument for the embedding call."""
+         return "image" if modality == "image" else "text"
+
+     def _embed_documents(self, documents: List[Document]) -> List[List[float]]:
+         """Embed a list of documents."""
+         if not documents:
+             return []
+
+         # If any document is an image, a multimodal embedding is required
+         has_image = any(doc.is_image for doc in documents)
+         if has_image and not self.embedding.supports_image:
+             raise ValueError(
+                 "Embedding does not support images, but the documents contain "
+                 "images. Please use MultiModalEmbedding."
+             )
+
+         # Split handling: text and images are embedded differently
+         if has_image:
+             # Multimodal: embed one by one to keep indices aligned
+             embeddings = []
+             for doc in documents:
+                 input_type = self._get_input_type(doc.modality)
+                 vec = self.embedding.embed([doc.content], input_type=input_type)[0]
+                 embeddings.append(vec)
+             return embeddings
+         else:
+             # Text only: embed in a single batch
+             contents = [doc.content for doc in documents]
+             return self.embedding.embed(contents)
+
+     def _embed_query(
+         self,
+         query: str,
+         query_type: Modality = "text",
+     ) -> List[float]:
+         """Embed a single query."""
+         if query_type == "image" and not self.embedding.supports_image:
+             raise ValueError("Embedding does not support image queries")
+
+         if self.embedding.supports_image:
+             input_type = self._get_input_type(query_type)
+             return self.embedding.embed([query], input_type=input_type)[0]
+         else:
+             return self.embedding.embed([query])[0]
+
+     def _embed_queries(
+         self,
+         queries: List[str],
+         query_type: Modality = "text",
+     ) -> List[List[float]]:
+         """Embed multiple queries in a single batch."""
+         if not queries:
+             return []
+
+         if query_type == "image" and not self.embedding.supports_image:
+             raise ValueError("Embedding does not support image queries")
+
+         if self.embedding.supports_image:
+             input_type = self._get_input_type(query_type)
+             return self.embedding.embed(queries, input_type=input_type)
+         else:
+             return self.embedding.embed(queries)
+
+     # ========== Index operations ==========
+
+     def add(
+         self,
+         documents: Union[Document, List[Document]],
+         skip_existing: bool = False,
+     ) -> List[str]:
+         """
+         Add documents.
+
+         Args:
+             documents: A single document or a list of documents
+             skip_existing: Whether to skip documents that already exist
+
+         Returns:
+             List of added document IDs
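+
+         Example (illustrative; ``retriever`` as constructed in the class
+         docstring):
+             >>> ids = retriever.add(Document.text(content="hello"), skip_existing=True)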
+         """
+         if isinstance(documents, Document):
+             documents = [documents]
+
+         if not documents:
+             return []
+
+         # Filter out documents that already exist
+         if skip_existing:
+             existing_ids = self._get_existing_ids([doc.id for doc in documents])
+             skipped = len([doc for doc in documents if doc.id in existing_ids])
+             documents = [doc for doc in documents if doc.id not in existing_ids]
+             if skipped > 0:
+                 logger.debug(f"Skipped {skipped} existing documents")
+             if not documents:
+                 return []
+
+         # Embed
+         embeddings = self._embed_documents(documents)
+
+         # Prepare the payload
+         ids = [doc.id for doc in documents]
+         contents = [doc.content for doc in documents]
+         metadatas = [
+             {**doc.metadata, "_modality": doc.modality}
+             for doc in documents
+         ]
+
+         # Add to the collection
+         self.collection.add(
+             ids=ids,
+             embeddings=embeddings,
+             documents=contents,
+             metadatas=metadatas,
+         )
+         logger.debug(f"Added {len(documents)} documents")
+
+         return ids
+
+     def upsert(
+         self,
+         documents: Union[Document, List[Document]],
+         skip_existing: bool = False,
+     ) -> List[str]:
+         """
+         Add or update documents.
+
+         Args:
+             documents: A single document or a list of documents
+             skip_existing: Whether to skip documents that already exist
+                 (when True, behaves the same as add)
+
+         Returns:
+             List of upserted document IDs
+         """
+         if isinstance(documents, Document):
+             documents = [documents]
+
+         if not documents:
+             return []
+
+         # Filter out documents that already exist
+         if skip_existing:
+             existing_ids = self._get_existing_ids([doc.id for doc in documents])
+             documents = [doc for doc in documents if doc.id not in existing_ids]
+             if not documents:
+                 return []
+
+         # Embed
+         embeddings = self._embed_documents(documents)
+
+         # Prepare the payload
+         ids = [doc.id for doc in documents]
+         contents = [doc.content for doc in documents]
+         metadatas = [
+             {**doc.metadata, "_modality": doc.modality}
+             for doc in documents
+         ]
+
+         # Upsert into the collection
+         self.collection.upsert(
+             ids=ids,
+             embeddings=embeddings,
+             documents=contents,
+             metadatas=metadatas,
+         )
+         logger.debug(f"Upserted {len(documents)} documents")
+
+         return ids
+
+     def delete(self, ids: Union[str, List[str]]) -> None:
+         """
+         Delete documents.
+
+         Args:
+             ids: A single ID or a list of IDs
+         """
+         if isinstance(ids, str):
+             ids = [ids]
+
+         self.collection.delete(ids=ids)
+         logger.debug(f"Deleted {len(ids)} documents")
+
+     def delete_by_content(self, contents: Union[str, List[str]]) -> None:
+         """
+         Delete documents by content.
+
+         Args:
+             contents: A single content string or a list of content strings
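+
+         Example (illustrative; assumes the documents were stored under the
+         default content-hash IDs produced by ``_content_hash``):
+             >>> retriever.delete_by_content("hello world")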
+         """
+         if isinstance(contents, str):
+             contents = [contents]
+
+         ids = [_content_hash(content) for content in contents]
+         self.delete(ids)
+
+     # ========== Search operations ==========
+
+     def search(
+         self,
+         query: str,
+         top_k: int = 5,
+         query_type: Modality = "text",
+         where: Optional[dict] = None,
+         where_document: Optional[dict] = None,
+     ) -> List[SearchResult]:
+         """
+         Search for similar documents.
+
+         Args:
+             query: Query content (text, or an image path/URL)
+             top_k: Number of results to return
+             query_type: Query type, "text" or "image"
+             where: Metadata filter
+             where_document: Document content filter
+
+         Returns:
+             List of SearchResult
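+
+         Example (illustrative; the metadata key ``lang`` is hypothetical,
+         and ``where`` follows ChromaDB's filter syntax):
+             >>> hits = retriever.search("greeting", top_k=3, where={"lang": "en"})
+             >>> for hit in hits:
+             ...     print(hit.id, hit.score)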
+         """
+         # Embed the query
+         query_embedding = self._embed_query(query, query_type)
+
+         # Search
+         results = self.collection.query(
+             query_embeddings=[query_embedding],
+             n_results=top_k,
+             where=where,
+             where_document=where_document,
+         )
+
+         parsed = self._parse_results(results)
+         return parsed[0] if parsed else []
+
+     def search_by_vector(
+         self,
+         vector: List[float],
+         top_k: int = 5,
+         where: Optional[dict] = None,
+     ) -> List[SearchResult]:
+         """
+         Search directly with a query vector.
+
+         Args:
+             vector: Query vector
+             top_k: Number of results to return
+             where: Metadata filter
+
+         Returns:
+             List of SearchResult
+         """
+         results = self.collection.query(
+             query_embeddings=[vector],
+             n_results=top_k,
+             where=where,
+         )
+
+         # Guard against empty results, mirroring search()
+         parsed = self._parse_results(results)
+         return parsed[0] if parsed else []
+
+     def search_batch(
+         self,
+         queries: List[str],
+         top_k: int = 5,
+         query_type: Modality = "text",
+         where: Optional[dict] = None,
+         where_document: Optional[dict] = None,
+     ) -> List[List[SearchResult]]:
+         """
+         Search for similar documents for multiple queries at once.
+
+         Args:
+             queries: List of query contents (text, or image paths/URLs)
+             top_k: Number of results per query
+             query_type: Query type, "text" or "image"
+             where: Metadata filter
+             where_document: Document content filter
+
+         Returns:
+             A list of SearchResult lists, one per query
+
+         Example:
+             >>> results = retriever.search_batch(["query1", "query2"], top_k=5)
+             >>> for i, query_results in enumerate(results):
+             ...     print(f"Query {i}: {len(query_results)} results")
+         """
+         if not queries:
+             return []
+
+         # Embed all queries in one batch
+         query_embeddings = self._embed_queries(queries, query_type)
+
+         # Run a single batched query
+         results = self.collection.query(
+             query_embeddings=query_embeddings,
+             n_results=top_k,
+             where=where,
+             where_document=where_document,
+         )
+
+         return self._parse_results(results)
+
+     def search_by_vectors(
+         self,
+         vectors: List[List[float]],
+         top_k: int = 5,
+         where: Optional[dict] = None,
+     ) -> List[List[SearchResult]]:
+         """
+         Search with multiple query vectors at once.
+
+         Args:
+             vectors: List of query vectors
+             top_k: Number of results per query
+             where: Metadata filter
+
+         Returns:
+             A list of SearchResult lists, one per vector
+
+         Example:
+             >>> vectors = [[0.1, 0.2, ...], [0.3, 0.4, ...]]
+             >>> results = retriever.search_by_vectors(vectors, top_k=5)
+         """
+         if not vectors:
+             return []
+
+         results = self.collection.query(
+             query_embeddings=vectors,
+             n_results=top_k,
+             where=where,
+         )
+
+         return self._parse_results(results)
+
+     def _parse_results(self, results: dict) -> List[List[SearchResult]]:
+         """
+         Parse the result dict returned by a ChromaDB query.
+
+         Args:
+             results: Result dict returned by ChromaDB's query()
+
+         Returns:
+             A list of SearchResult lists, one per query
+         """
+         if not results or not results.get("ids"):
+             return []
+
+         all_results = []
+         num_queries = len(results["ids"])
+
+         for query_idx in range(num_queries):
+             ids = results["ids"][query_idx]
+             if not ids:
+                 all_results.append([])
+                 continue
+
+             documents = results.get("documents", [[]] * num_queries)[query_idx]
+             metadatas = results.get("metadatas", [[]] * num_queries)[query_idx]
+             distances = results.get("distances", [[]] * num_queries)[query_idx]
+
+             query_results = []
+             for i, doc_id in enumerate(ids):
+                 metadata = dict(metadatas[i]) if metadatas and i < len(metadatas) else {}
+                 modality = metadata.pop("_modality", "text")
+
+                 # Convert distance to a similarity score (cosine: 1 - distance)
+                 distance = distances[i] if distances and i < len(distances) else 0
+                 if self.distance_metric == "cosine":
+                     score = 1 - distance
+                 else:
+                     score = -distance  # l2/ip: smaller distance is better
+
+                 query_results.append(SearchResult(
+                     id=doc_id,
+                     content=documents[i] if documents and i < len(documents) else "",
+                     score=score,
+                     modality=modality,
+                     metadata=metadata,
+                 ))
+
+             all_results.append(query_results)
+
+         return all_results
+
+     # ========== Management operations ==========
+
+     def get(
+         self,
+         ids: Optional[Union[str, List[str]]] = None,
+         where: Optional[dict] = None,
+         limit: Optional[int] = None,
+     ) -> List[Document]:
+         """
+         Fetch documents.
+
+         Args:
+             ids: Document ID or list of IDs
+             where: Metadata filter
+             limit: Maximum number of documents to return
+
+         Returns:
+             List of Document
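+
+         Example (illustrative; the metadata key ``lang`` is hypothetical):
+             >>> docs = retriever.get(where={"lang": "en"}, limit=10)
+             >>> ids = [d.id for d in docs]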
+         """
+         if isinstance(ids, str):
+             ids = [ids]
+
+         results = self.collection.get(
+             ids=ids,
+             where=where,
+             limit=limit,
+         )
+
+         documents = []
+         if results and results.get("ids"):
+             for i, doc_id in enumerate(results["ids"]):
+                 # Copy the metadata so popping "_modality" does not mutate the raw results
+                 metadata = dict(results["metadatas"][i]) if results.get("metadatas") else {}
+                 modality = metadata.pop("_modality", "text")
+                 content = results["documents"][i] if results.get("documents") else ""
+
+                 documents.append(Document(
+                     id=doc_id,
+                     content=content,
+                     modality=modality,
+                     metadata=metadata,
+                 ))
+
+         return documents
+
+     def count(self) -> int:
+         """Return the number of documents."""
+         return self.collection.count()
+
+     def clear(self) -> None:
+         """Clear the collection."""
+         logger.info(f"Clearing collection: {self.collection_name}")
+         self.client.delete_collection(self.collection_name)
+         self.collection = self.client.get_or_create_collection(
+             name=self.collection_name,
+             metadata={"hnsw:space": self.distance_metric},
+         )
+         logger.info(f"Collection '{self.collection_name}' cleared and recreated")
+
+     # ========== Convenience methods ==========
+
+     def upsert_batch(
+         self,
+         documents: List[Document],
+         batch_size: int = 32,
+         skip_existing: bool = False,
+         show_progress: bool = True,
+     ) -> int:
+         """
+         Upsert documents in batches (with progress bar and incremental-update support).
+
+         Args:
+             documents: List of documents
+             batch_size: Batch size
+             skip_existing: Whether to skip documents that already exist
+             show_progress: Whether to show a progress bar
+
+         Returns:
+             Number of documents actually upserted
+
+         Example:
+             >>> retriever = ChromaRetriever(embedding, persist_dir, collection_name)
+             >>> docs = [Document.text(content=text, id=f"doc_{i}") for i, text in enumerate(texts)]
+             >>> count = retriever.upsert_batch(docs, batch_size=32, skip_existing=True)
+             >>> print(f"Upserted {count} documents")
+         """
+         if not documents:
+             return 0
+
+         total_docs = len(documents)
+         logger.info(f"Starting batch upsert: {total_docs} documents, batch_size={batch_size}")
+
+         # Filter out documents that already exist
+         skipped = 0
+         if skip_existing:
+             existing_ids = self._get_existing_ids([doc.id for doc in documents])
+             skipped = len([doc for doc in documents if doc.id in existing_ids])
+             documents = [doc for doc in documents if doc.id not in existing_ids]
+             if skipped > 0:
+                 logger.info(f"Skipped {skipped} existing documents")
+             if not documents:
+                 return 0
+
+         # Upsert in batches
+         inserted = 0
+         total_batches = (len(documents) + batch_size - 1) // batch_size
+         iterator = range(0, len(documents), batch_size)
+
+         if show_progress:
+             try:
+                 from tqdm import tqdm
+                 iterator = tqdm(
+                     iterator,
+                     desc="Upserting",
+                     total=total_batches,
+                     unit="batch",
+                 )
+             except ImportError:
+                 logger.debug("tqdm not installed, progress bar disabled")
+
+         for i in iterator:
+             batch = documents[i:i + batch_size]
+             self.upsert(batch)
+             inserted += len(batch)
+
+         logger.info(f"Batch upsert completed: {inserted} inserted, {skipped} skipped")
+         return inserted
+
+     def _get_existing_ids(self, candidate_ids: List[str]) -> set:
+         """Return the set of candidate IDs that already exist in the collection."""
+         existing_ids = set()
+         batch_size = 10000
+
+         for i in range(0, len(candidate_ids), batch_size):
+             batch_ids = candidate_ids[i:i + batch_size]
+             try:
+                 results = self.collection.get(ids=batch_ids)
+                 if results and results.get("ids"):
+                     existing_ids.update(results["ids"])
+             except Exception:
+                 pass  # Ignore IDs that do not exist
+
+         return existing_ids
+
+     def get_all_ids(self) -> List[str]:
+         """Return all document IDs."""
+         results = self.collection.get(include=[])
+         return results.get("ids", [])
+
+     def migrate_to(
+         self,
+         target,
+         batch_size: int = 100,
+         skip_existing: bool = True,
+         show_progress: bool = True,
+     ) -> int:
+         """
+         Migrate all data in this collection to a target retriever.
+
+         Args:
+             target: Target retriever (ChromaRetriever or MilvusRetriever)
+             batch_size: Batch size
+             skip_existing: Whether to skip documents that already exist
+             show_progress: Whether to show a progress bar
+
+         Returns:
+             Number of documents migrated
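+
+         Example (illustrative; ``other`` is a second retriever over the
+         same embedding):
+             >>> other = ChromaRetriever(emb, persist_dir="./chroma_copy")
+             >>> migrated = retriever.migrate_to(other, batch_size=100)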
+         """
+         all_ids = self.get_all_ids()
+         if not all_ids:
+             logger.info("No documents to migrate")
+             return 0
+
+         total = len(all_ids)
+         logger.info(f"Starting migration: {total} documents")
+
+         migrated = 0
+         iterator = range(0, total, batch_size)
+
+         if show_progress:
+             try:
+                 from tqdm import tqdm
+                 iterator = tqdm(
+                     iterator,
+                     desc="Migrating",
+                     total=(total + batch_size - 1) // batch_size,
+                     unit="batch",
+                 )
+             except ImportError:
+                 pass
+
+         for i in iterator:
+             batch_ids = all_ids[i:i + batch_size]
+             documents = self.get(ids=batch_ids)
+             if documents:
+                 migrated += target.upsert_batch(
+                     documents,
+                     batch_size=batch_size,
+                     skip_existing=skip_existing,
+                     show_progress=False,
+                 )
+
+         logger.info(f"Migration completed: {migrated} documents migrated")
+         return migrated
+
+     def __repr__(self) -> str:
+         return (
+             f"ChromaRetriever("
+             f"collection={self.collection_name!r}, "
+             f"count={self.count()}, "
+             f"embedding={self.embedding.__class__.__name__})"
+         )