maque 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143) hide show
  1. maque/__init__.py +30 -0
  2. maque/__main__.py +926 -0
  3. maque/ai_platform/__init__.py +0 -0
  4. maque/ai_platform/crawl.py +45 -0
  5. maque/ai_platform/metrics.py +258 -0
  6. maque/ai_platform/nlp_preprocess.py +67 -0
  7. maque/ai_platform/webpage_screen_shot.py +195 -0
  8. maque/algorithms/__init__.py +78 -0
  9. maque/algorithms/bezier.py +15 -0
  10. maque/algorithms/bktree.py +117 -0
  11. maque/algorithms/core.py +104 -0
  12. maque/algorithms/hilbert.py +16 -0
  13. maque/algorithms/rate_function.py +92 -0
  14. maque/algorithms/transform.py +27 -0
  15. maque/algorithms/trie.py +272 -0
  16. maque/algorithms/utils.py +63 -0
  17. maque/algorithms/video.py +587 -0
  18. maque/api/__init__.py +1 -0
  19. maque/api/common.py +110 -0
  20. maque/api/fetch.py +26 -0
  21. maque/api/static/icon.png +0 -0
  22. maque/api/static/redoc.standalone.js +1782 -0
  23. maque/api/static/swagger-ui-bundle.js +3 -0
  24. maque/api/static/swagger-ui.css +3 -0
  25. maque/cli/__init__.py +1 -0
  26. maque/cli/clean_invisible_chars.py +324 -0
  27. maque/cli/core.py +34 -0
  28. maque/cli/groups/__init__.py +26 -0
  29. maque/cli/groups/config.py +205 -0
  30. maque/cli/groups/data.py +615 -0
  31. maque/cli/groups/doctor.py +259 -0
  32. maque/cli/groups/embedding.py +222 -0
  33. maque/cli/groups/git.py +29 -0
  34. maque/cli/groups/help.py +410 -0
  35. maque/cli/groups/llm.py +223 -0
  36. maque/cli/groups/mcp.py +241 -0
  37. maque/cli/groups/mllm.py +1795 -0
  38. maque/cli/groups/mllm_simple.py +60 -0
  39. maque/cli/groups/quant.py +210 -0
  40. maque/cli/groups/service.py +490 -0
  41. maque/cli/groups/system.py +570 -0
  42. maque/cli/mllm_run.py +1451 -0
  43. maque/cli/script.py +52 -0
  44. maque/cli/tree.py +49 -0
  45. maque/clustering/__init__.py +52 -0
  46. maque/clustering/analyzer.py +347 -0
  47. maque/clustering/clusterers.py +464 -0
  48. maque/clustering/sampler.py +134 -0
  49. maque/clustering/visualizer.py +205 -0
  50. maque/constant.py +13 -0
  51. maque/core.py +133 -0
  52. maque/cv/__init__.py +1 -0
  53. maque/cv/image.py +219 -0
  54. maque/cv/utils.py +68 -0
  55. maque/cv/video/__init__.py +3 -0
  56. maque/cv/video/keyframe_extractor.py +368 -0
  57. maque/embedding/__init__.py +43 -0
  58. maque/embedding/base.py +56 -0
  59. maque/embedding/multimodal.py +308 -0
  60. maque/embedding/server.py +523 -0
  61. maque/embedding/text.py +311 -0
  62. maque/git/__init__.py +24 -0
  63. maque/git/pure_git.py +912 -0
  64. maque/io/__init__.py +29 -0
  65. maque/io/core.py +38 -0
  66. maque/io/ops.py +194 -0
  67. maque/llm/__init__.py +111 -0
  68. maque/llm/backend.py +416 -0
  69. maque/llm/base.py +411 -0
  70. maque/llm/server.py +366 -0
  71. maque/mcp_server.py +1096 -0
  72. maque/mllm_data_processor_pipeline/__init__.py +17 -0
  73. maque/mllm_data_processor_pipeline/core.py +341 -0
  74. maque/mllm_data_processor_pipeline/example.py +291 -0
  75. maque/mllm_data_processor_pipeline/steps/__init__.py +56 -0
  76. maque/mllm_data_processor_pipeline/steps/data_alignment.py +267 -0
  77. maque/mllm_data_processor_pipeline/steps/data_loader.py +172 -0
  78. maque/mllm_data_processor_pipeline/steps/data_validation.py +304 -0
  79. maque/mllm_data_processor_pipeline/steps/format_conversion.py +411 -0
  80. maque/mllm_data_processor_pipeline/steps/mllm_annotation.py +331 -0
  81. maque/mllm_data_processor_pipeline/steps/mllm_refinement.py +446 -0
  82. maque/mllm_data_processor_pipeline/steps/result_validation.py +501 -0
  83. maque/mllm_data_processor_pipeline/web_app.py +317 -0
  84. maque/nlp/__init__.py +14 -0
  85. maque/nlp/ngram.py +9 -0
  86. maque/nlp/parser.py +63 -0
  87. maque/nlp/risk_matcher.py +543 -0
  88. maque/nlp/sentence_splitter.py +202 -0
  89. maque/nlp/simple_tradition_cvt.py +31 -0
  90. maque/performance/__init__.py +21 -0
  91. maque/performance/_measure_time.py +70 -0
  92. maque/performance/_profiler.py +367 -0
  93. maque/performance/_stat_memory.py +51 -0
  94. maque/pipelines/__init__.py +15 -0
  95. maque/pipelines/clustering.py +252 -0
  96. maque/quantization/__init__.py +42 -0
  97. maque/quantization/auto_round.py +120 -0
  98. maque/quantization/base.py +145 -0
  99. maque/quantization/bitsandbytes.py +127 -0
  100. maque/quantization/llm_compressor.py +102 -0
  101. maque/retriever/__init__.py +35 -0
  102. maque/retriever/chroma.py +654 -0
  103. maque/retriever/document.py +140 -0
  104. maque/retriever/milvus.py +1140 -0
  105. maque/table_ops/__init__.py +1 -0
  106. maque/table_ops/core.py +133 -0
  107. maque/table_viewer/__init__.py +4 -0
  108. maque/table_viewer/download_assets.py +57 -0
  109. maque/table_viewer/server.py +698 -0
  110. maque/table_viewer/static/element-plus-icons.js +5791 -0
  111. maque/table_viewer/static/element-plus.css +1 -0
  112. maque/table_viewer/static/element-plus.js +65236 -0
  113. maque/table_viewer/static/main.css +268 -0
  114. maque/table_viewer/static/main.js +669 -0
  115. maque/table_viewer/static/vue.global.js +18227 -0
  116. maque/table_viewer/templates/index.html +401 -0
  117. maque/utils/__init__.py +56 -0
  118. maque/utils/color.py +68 -0
  119. maque/utils/color_string.py +45 -0
  120. maque/utils/compress.py +66 -0
  121. maque/utils/constant.py +183 -0
  122. maque/utils/core.py +261 -0
  123. maque/utils/cursor.py +143 -0
  124. maque/utils/distance.py +58 -0
  125. maque/utils/docker.py +96 -0
  126. maque/utils/downloads.py +51 -0
  127. maque/utils/excel_helper.py +542 -0
  128. maque/utils/helper_metrics.py +121 -0
  129. maque/utils/helper_parser.py +168 -0
  130. maque/utils/net.py +64 -0
  131. maque/utils/nvidia_stat.py +140 -0
  132. maque/utils/ops.py +53 -0
  133. maque/utils/packages.py +31 -0
  134. maque/utils/path.py +57 -0
  135. maque/utils/tar.py +260 -0
  136. maque/utils/untar.py +129 -0
  137. maque/web/__init__.py +0 -0
  138. maque/web/image_downloader.py +1410 -0
  139. maque-0.2.1.dist-info/METADATA +450 -0
  140. maque-0.2.1.dist-info/RECORD +143 -0
  141. maque-0.2.1.dist-info/WHEEL +4 -0
  142. maque-0.2.1.dist-info/entry_points.txt +3 -0
  143. maque-0.2.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1140 @@
1
#! /usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Milvus vector-database retriever implementation.
"""
7
+
8
+ from typing import List, Optional, Union, Literal, TYPE_CHECKING
9
+
10
+ from loguru import logger
11
+
12
+ from ..embedding.base import BaseEmbedding
13
+ from .document import Document, SearchResult, Modality, _content_hash
14
+
15
+ if TYPE_CHECKING:
16
+ from pymilvus import Collection
17
+
18
+
19
# Distance metrics and scalar dtypes accepted by MilvusRetriever.
DistanceMetric = Literal["COSINE", "L2", "IP"]
ScalarType = Literal["VARCHAR", "INT64", "INT32", "INT16", "INT8", "FLOAT", "DOUBLE", "BOOL", "JSON", "ARRAY"]
+
22
+
23
class MilvusRetriever:
    """
    Milvus-backed retriever.

    Supports vector retrieval over text and image content through a
    pluggable embedding backend.
    """
+
29
    # Preset index configurations. Keys are the names accepted by the
    # ``index_config`` argument; each preset carries the Milvus index type,
    # its build-time parameters (index_params) and its query-time
    # parameters (search_params).
    INDEX_PRESETS = {
        "AUTOINDEX": {
            "index_type": "AUTOINDEX",
            "index_params": {},
            "search_params": {},
        },
        "HNSW": {
            "index_type": "HNSW",
            "index_params": {"M": 16, "efConstruction": 256},
            "search_params": {"ef": 128},
        },
        "IVF_FLAT": {
            "index_type": "IVF_FLAT",
            "index_params": {"nlist": 1024},
            "search_params": {"nprobe": 16},
        },
        "FLAT": {
            "index_type": "FLAT",
            "index_params": {},
            "search_params": {},
        },
    }
52
+
53
    # Scalar dtype name -> (Milvus DataType name, default constructor kwargs).
    # The keys serve as the validation set in _parse_scalar_fields; the value
    # tuples document per-type defaults (e.g. VARCHAR length, ARRAY capacity).
    SCALAR_TYPE_MAP = {
        "VARCHAR": ("VARCHAR", {"max_length": 256}),
        "INT64": ("INT64", {}),
        "INT32": ("INT32", {}),
        "INT16": ("INT16", {}),
        "INT8": ("INT8", {}),
        "FLOAT": ("FLOAT", {}),
        "DOUBLE": ("DOUBLE", {}),
        "BOOL": ("BOOL", {}),
        "JSON": ("JSON", {}),
        "ARRAY": ("ARRAY", {"element_type": "VARCHAR", "max_capacity": 256, "max_length": 256}),
    }
66
+
67
    def __init__(
        self,
        embedding: BaseEmbedding,
        host: str = "localhost",
        port: int = 19530,
        db_name: str = "default",
        collection_name: str = "default",
        distance_metric: DistanceMetric = "COSINE",
        auto_create: bool = True,
        index_config: Optional[dict] = None,
        scalar_fields: Optional[List[dict]] = None,
        primary_key: str = "id",
        field_mapping: Optional[dict] = None,
    ):
        """
        Initialize the retriever and connect to Milvus.

        Args:
            embedding: Embedding instance used to vectorize documents/queries.
            host: Milvus server address.
            port: Milvus server port.
            db_name: Database name.
            collection_name: Collection name.
            distance_metric: Distance metric (COSINE / L2 / IP).
            auto_create: Create the collection automatically when missing.
            index_config: Index configuration; one of:
                - None: use the default AUTOINDEX preset
                - "AUTOINDEX" / "HNSW" / "IVF_FLAT" / "FLAT": a named preset
                - dict: custom configuration, e.g.:
                    {
                        "index_type": "HNSW",
                        "index_params": {"M": 16, "efConstruction": 256},
                        "search_params": {"ef": 128},
                        "id_max_length": 256,
                        "content_max_length": 65535,
                    }
            scalar_fields: [create mode] extra scalar field definitions used
                for efficient filtering, e.g.:
                    [
                        {"name": "category", "dtype": "VARCHAR", "max_length": 64},
                        {"name": "timestamp", "dtype": "INT64"},
                        {"name": "score", "dtype": "FLOAT"},
                        {"name": "tags", "dtype": "ARRAY", "element_type": "VARCHAR", "max_capacity": 64, "max_length": 32},
                    ]
                Supported dtypes: VARCHAR, INT64, INT32, INT16, INT8, FLOAT,
                DOUBLE, BOOL, JSON, ARRAY. ARRAY additionally requires
                element_type and max_capacity. Field values are extracted
                from Document.metadata. Not needed when opening an existing
                collection: types are read from its schema.
            primary_key: Primary-key field name, default "id"; may be
                customized, e.g. "user_id" or "content_id".
            field_mapping: [read mode] maps the core field roles onto the
                names used by an existing collection, e.g.:
                    {
                        "primary_key": "word_id",  # primary-key field name
                        "content": "word",         # content field name
                        "embedding": "vector",     # vector field name
                        "modality": None,          # None = field absent
                        "metadata": None,          # None = field absent
                    }
                Remaining fields (and their types) are discovered from the
                schema automatically.

        Usage modes:
            - Create mode: define extra fields via scalar_fields.
            - Read mode: map core fields via field_mapping; other fields and
              their types are auto-extracted from the schema.

        Raises:
            ImportError: If pymilvus is not installed.
            ValueError: If the collection is missing and auto_create is False.
        """
        # Import the full set of pymilvus symbols up front so a missing
        # dependency fails fast with an actionable message.
        try:
            from pymilvus import (
                connections,
                Collection,
                FieldSchema,
                CollectionSchema,
                DataType,
                utility,
                db,
            )
        except ImportError:
            raise ImportError(
                "pymilvus is required for MilvusRetriever. "
                "Install it with: pip install pymilvus"
            )

        self.embedding = embedding
        self.host = host
        self.port = port
        self.db_name = db_name
        self.collection_name = collection_name
        self.distance_metric = distance_metric
        self._dimension = embedding.dimension

        # Resolve the index configuration (preset name / custom dict / None).
        config = self._parse_index_config(index_config)
        self._index_type = config["index_type"]
        self._index_params = config["index_params"]
        self._search_params = config["search_params"]
        self._id_max_length = config.get("id_max_length", 256)
        self._content_max_length = config.get("content_max_length", 65535)

        # Resolve the field mapping (read mode) / canonical names.
        self._field_mapping = self._parse_field_mapping(field_mapping, primary_key)
        self._primary_key = self._field_mapping["primary_key"]
        self._use_field_mapping = field_mapping is not None

        # Validate scalar-field definitions (only used in create mode).
        self._scalar_fields = self._parse_scalar_fields(scalar_fields or [])

        # Connect to Milvus under a per-(db, collection) alias so multiple
        # retrievers can coexist in one process.
        self._connection_alias = f"milvus_{db_name}_{collection_name}"
        logger.debug(f"Connecting to Milvus at {host}:{port}, db={db_name}")
        connections.connect(
            alias=self._connection_alias,
            host=host,
            port=port,
            db_name=db_name,
        )

        # Open the existing collection, or create it when allowed.
        if utility.has_collection(collection_name, using=self._connection_alias):
            logger.debug(f"Loading existing collection: {collection_name}")
            self.collection = Collection(
                name=collection_name,
                using=self._connection_alias,
            )
            # Discover extra fields and their types from the schema.
            self._extra_fields, self._field_types = self._extract_extra_fields_from_schema()
            # Overlay user-supplied scalar_fields types (possibly more precise).
            for sf in self._scalar_fields:
                self._field_types[sf["name"]] = sf["dtype"]
            # Load the collection into memory so it is searchable.
            self.collection.load()
            logger.info(f"Collection '{collection_name}' loaded, {self.count()} documents")
        elif auto_create:
            # Create mode: extra fields come from scalar_fields.
            logger.debug(f"Creating new collection: {collection_name}")
            self._extra_fields = [sf["name"] for sf in self._scalar_fields]
            self._field_types = {sf["name"]: sf["dtype"] for sf in self._scalar_fields}
            self.collection = self._create_collection()
            logger.info(f"Collection '{collection_name}' created")
        else:
            raise ValueError(f"Collection '{collection_name}' does not exist in database '{db_name}'")
203
+
204
+ def _parse_field_mapping(self, field_mapping: Optional[dict], primary_key: str) -> dict:
205
+ """
206
+ 解析字段映射配置
207
+
208
+ Args:
209
+ field_mapping: 用户提供的字段映射
210
+ primary_key: 默认主键名
211
+
212
+ Returns:
213
+ 标准化的字段映射
214
+ """
215
+ # 默认映射
216
+ default_mapping = {
217
+ "primary_key": primary_key,
218
+ "content": "content",
219
+ "embedding": "embedding",
220
+ "modality": "modality",
221
+ "metadata": "metadata",
222
+ }
223
+
224
+ if not field_mapping:
225
+ return default_mapping
226
+
227
+ # 合并用户提供的映射
228
+ result = default_mapping.copy()
229
+ result.update(field_mapping)
230
+ return result
231
+
232
+ def _extract_extra_fields_from_schema(self) -> tuple[List[str], dict]:
233
+ """
234
+ 从已存在的 collection schema 中提取额外字段名和类型
235
+
236
+ Returns:
237
+ (额外字段名列表, 字段类型映射)
238
+ """
239
+ from pymilvus import DataType
240
+
241
+ # DataType 到字符串的映射
242
+ dtype_to_str = {
243
+ DataType.VARCHAR: "VARCHAR",
244
+ DataType.INT64: "INT64",
245
+ DataType.INT32: "INT32",
246
+ DataType.INT16: "INT16",
247
+ DataType.INT8: "INT8",
248
+ DataType.FLOAT: "FLOAT",
249
+ DataType.DOUBLE: "DOUBLE",
250
+ DataType.BOOL: "BOOL",
251
+ DataType.JSON: "JSON",
252
+ DataType.ARRAY: "ARRAY",
253
+ }
254
+
255
+ # 已映射的字段名
256
+ mapped_fields = {
257
+ self._field_mapping["primary_key"],
258
+ self._field_mapping["content"],
259
+ self._field_mapping["embedding"],
260
+ }
261
+ # 可选的映射字段
262
+ if self._field_mapping.get("modality"):
263
+ mapped_fields.add(self._field_mapping["modality"])
264
+ if self._field_mapping.get("metadata"):
265
+ mapped_fields.add(self._field_mapping["metadata"])
266
+
267
+ extra_fields = []
268
+ field_types = {}
269
+ for field in self.collection.schema.fields:
270
+ # 跳过已映射的字段和向量字段
271
+ if field.name in mapped_fields:
272
+ continue
273
+ if field.dtype in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR):
274
+ continue
275
+ extra_fields.append(field.name)
276
+ field_types[field.name] = dtype_to_str.get(field.dtype, "JSON")
277
+
278
+ return extra_fields, field_types
279
+
280
+ def _parse_index_config(self, index_config) -> dict:
281
+ """解析索引配置"""
282
+ # 默认使用 AUTOINDEX
283
+ if index_config is None:
284
+ return self.INDEX_PRESETS["AUTOINDEX"].copy()
285
+
286
+ # 字符串预设
287
+ if isinstance(index_config, str):
288
+ if index_config not in self.INDEX_PRESETS:
289
+ raise ValueError(f"Unknown preset: {index_config}, available: {list(self.INDEX_PRESETS.keys())}")
290
+ return self.INDEX_PRESETS[index_config].copy()
291
+
292
+ # 自定义 dict
293
+ base = self.INDEX_PRESETS.get(index_config.get("index_type", "AUTOINDEX"), {}).copy()
294
+ base.update(index_config)
295
+ return base
296
+
297
+ def _parse_scalar_fields(self, scalar_fields: List[dict]) -> List[dict]:
298
+ """
299
+ 解析 scalar 字段配置
300
+
301
+ Args:
302
+ scalar_fields: 字段定义列表,如:
303
+ [
304
+ {"name": "category", "dtype": "VARCHAR", "max_length": 64},
305
+ {"name": "timestamp", "dtype": "INT64"},
306
+ {"name": "tags", "dtype": "ARRAY", "element_type": "VARCHAR", "max_capacity": 64, "max_length": 32},
307
+ ]
308
+
309
+ Returns:
310
+ 标准化的字段配置列表
311
+ """
312
+ reserved_names = {self._primary_key, "content", "modality", "metadata", "embedding"}
313
+ parsed = []
314
+
315
+ for field in scalar_fields:
316
+ name = field.get("name")
317
+ dtype = field.get("dtype", "VARCHAR").upper()
318
+
319
+ if not name:
320
+ raise ValueError("Scalar field must have 'name'")
321
+ if name in reserved_names:
322
+ raise ValueError(f"Field name '{name}' is reserved")
323
+ if dtype not in self.SCALAR_TYPE_MAP:
324
+ raise ValueError(f"Unknown dtype '{dtype}', available: {list(self.SCALAR_TYPE_MAP.keys())}")
325
+
326
+ # 构建标准化配置
327
+ parsed_field = {"name": name, "dtype": dtype}
328
+
329
+ # VARCHAR 需要 max_length
330
+ if dtype == "VARCHAR":
331
+ parsed_field["max_length"] = field.get("max_length", 256)
332
+ # ARRAY 需要 element_type, max_capacity, 以及可能的 max_length
333
+ elif dtype == "ARRAY":
334
+ element_type = field.get("element_type", "VARCHAR").upper()
335
+ parsed_field["element_type"] = element_type
336
+ parsed_field["max_capacity"] = field.get("max_capacity", 256)
337
+ # 如果元素类型是 VARCHAR,需要 max_length
338
+ if element_type == "VARCHAR":
339
+ parsed_field["max_length"] = field.get("max_length", 256)
340
+
341
+ parsed.append(parsed_field)
342
+
343
+ return parsed
344
+
345
    def _create_collection(self) -> "Collection":
        """Create the collection, build its vector index, and load it.

        Field names come from ``self._field_mapping``; extra scalar fields
        come from ``self._scalar_fields`` (already normalized).
        """
        from pymilvus import (
            Collection,
            FieldSchema,
            CollectionSchema,
            DataType,
        )

        # Core fields (names taken from the field mapping).
        fm = self._field_mapping
        fields = [
            FieldSchema(name=fm["primary_key"], dtype=DataType.VARCHAR, max_length=self._id_max_length, is_primary=True),
            FieldSchema(name=fm["content"], dtype=DataType.VARCHAR, max_length=self._content_max_length),
            FieldSchema(name=fm["embedding"], dtype=DataType.FLOAT_VECTOR, dim=self._dimension),
        ]

        # Optional fields (skipped when mapped to None).
        if fm.get("modality"):
            fields.append(FieldSchema(name=fm["modality"], dtype=DataType.VARCHAR, max_length=32))
        if fm.get("metadata"):
            fields.append(FieldSchema(name=fm["metadata"], dtype=DataType.JSON))

        # Extra user-defined scalar fields; VARCHAR and ARRAY need
        # additional constructor kwargs.
        for sf in self._scalar_fields:
            dtype = getattr(DataType, sf["dtype"])
            if sf["dtype"] == "VARCHAR":
                fields.append(FieldSchema(name=sf["name"], dtype=dtype, max_length=sf["max_length"]))
            elif sf["dtype"] == "ARRAY":
                element_type = getattr(DataType, sf["element_type"])
                if sf["element_type"] == "VARCHAR":
                    fields.append(FieldSchema(
                        name=sf["name"],
                        dtype=dtype,
                        element_type=element_type,
                        max_capacity=sf["max_capacity"],
                        max_length=sf["max_length"],
                    ))
                else:
                    fields.append(FieldSchema(
                        name=sf["name"],
                        dtype=dtype,
                        element_type=element_type,
                        max_capacity=sf["max_capacity"],
                    ))
            else:
                fields.append(FieldSchema(name=sf["name"], dtype=dtype))

        schema = CollectionSchema(
            fields=fields,
            description=f"Collection for {self.collection_name}",
        )

        collection = Collection(
            name=self.collection_name,
            schema=schema,
            using=self._connection_alias,
        )

        # Build the vector index on the embedding field.
        index_params = {
            "metric_type": self.distance_metric,
            "index_type": self._index_type,
            "params": self._index_params,
        }
        collection.create_index(field_name=fm["embedding"], index_params=index_params)

        # Load into memory so the collection is immediately searchable.
        collection.load()

        return collection
416
+
417
+ def _get_input_type(self, modality: Modality) -> str:
418
+ """获取 embedding 的 input_type 参数"""
419
+ return "image" if modality == "image" else "text"
420
+
421
+ def _embed_documents(self, documents: List[Document]) -> List[List[float]]:
422
+ """对文档进行向量化"""
423
+ if not documents:
424
+ return []
425
+
426
+ has_image = any(doc.is_image for doc in documents)
427
+ if has_image and not self.embedding.supports_image:
428
+ raise ValueError(
429
+ f"Embedding 不支持图片,但文档中包含图片。"
430
+ f"请使用 MultiModalEmbedding。"
431
+ )
432
+
433
+ if has_image:
434
+ embeddings = []
435
+ for doc in documents:
436
+ input_type = self._get_input_type(doc.modality)
437
+ vec = self.embedding.embed([doc.content], input_type=input_type)[0]
438
+ embeddings.append(vec)
439
+ return embeddings
440
+ else:
441
+ contents = [doc.content for doc in documents]
442
+ return self.embedding.embed(contents)
443
+
444
+ def _embed_query(
445
+ self,
446
+ query: str,
447
+ query_type: Modality = "text",
448
+ ) -> List[float]:
449
+ """对查询进行向量化"""
450
+ if query_type == "image" and not self.embedding.supports_image:
451
+ raise ValueError("Embedding 不支持图片查询")
452
+
453
+ if self.embedding.supports_image:
454
+ input_type = self._get_input_type(query_type)
455
+ return self.embedding.embed([query], input_type=input_type)[0]
456
+ else:
457
+ return self.embedding.embed([query])[0]
458
+
459
+ def _embed_queries(
460
+ self,
461
+ queries: List[str],
462
+ query_type: Modality = "text",
463
+ ) -> List[List[float]]:
464
+ """对多个查询进行批量向量化"""
465
+ if not queries:
466
+ return []
467
+
468
+ if query_type == "image" and not self.embedding.supports_image:
469
+ raise ValueError("Embedding 不支持图片查询")
470
+
471
+ if self.embedding.supports_image:
472
+ input_type = self._get_input_type(query_type)
473
+ return self.embedding.embed(queries, input_type=input_type)
474
+ else:
475
+ return self.embedding.embed(queries)
476
+
477
+ def _prepare_insert_data(self, documents: List[Document], embeddings: List[List[float]]) -> List[dict]:
478
+ """
479
+ 准备插入数据,包含基础字段和额外字段
480
+
481
+ Args:
482
+ documents: 文档列表
483
+ embeddings: 向量列表
484
+
485
+ Returns:
486
+ 行格式的数据列表(每个元素是一个字典,代表一行)
487
+ """
488
+ fm = self._field_mapping
489
+
490
+ rows = []
491
+ for i, doc in enumerate(documents):
492
+ row = {
493
+ fm["primary_key"]: doc.id,
494
+ fm["content"]: doc.content,
495
+ fm["embedding"]: embeddings[i],
496
+ }
497
+
498
+ # 可选字段
499
+ if fm.get("modality"):
500
+ row[fm["modality"]] = doc.modality
501
+ if fm.get("metadata"):
502
+ row[fm["metadata"]] = doc.metadata
503
+
504
+ # 添加所有额外字段值(从 metadata 中提取)
505
+ for field_name in self._extra_fields:
506
+ value = doc.metadata.get(field_name)
507
+ dtype = self._field_types.get(field_name)
508
+
509
+ # 如果值为 None,根据类型设置默认值
510
+ if value is None:
511
+ if dtype == "VARCHAR":
512
+ value = ""
513
+ elif dtype in ("INT64", "INT32", "INT16", "INT8"):
514
+ value = 0
515
+ elif dtype in ("FLOAT", "DOUBLE"):
516
+ value = 0.0
517
+ elif dtype == "BOOL":
518
+ value = False
519
+ else:
520
+ # JSON / ARRAY 等类型默认空列表
521
+ value = []
522
+
523
+ row[field_name] = value
524
+
525
+ rows.append(row)
526
+
527
+ return rows
528
+
529
+ # ========== 索引操作 ==========
530
+
531
+ def add(
532
+ self,
533
+ documents: Union[Document, List[Document]],
534
+ skip_existing: bool = False,
535
+ ) -> List[str]:
536
+ """
537
+ 添加文档
538
+
539
+ Args:
540
+ documents: 单个文档或文档列表
541
+ skip_existing: 是否跳过已存在的文档
542
+
543
+ Returns:
544
+ 添加的文档 ID 列表
545
+ """
546
+ if isinstance(documents, Document):
547
+ documents = [documents]
548
+
549
+ if not documents:
550
+ return []
551
+
552
+ # 过滤已存在的文档
553
+ if skip_existing:
554
+ existing_ids = self._get_existing_ids([doc.id for doc in documents])
555
+ skipped = len([doc for doc in documents if doc.id in existing_ids])
556
+ documents = [doc for doc in documents if doc.id not in existing_ids]
557
+ if skipped > 0:
558
+ logger.debug(f"Skipped {skipped} existing documents")
559
+ if not documents:
560
+ return []
561
+
562
+ # 向量化
563
+ embeddings = self._embed_documents(documents)
564
+
565
+ # 准备数据(包含 scalar 字段)
566
+ data = self._prepare_insert_data(documents, embeddings)
567
+
568
+ # 插入数据
569
+ self.collection.insert(data)
570
+ self.collection.flush()
571
+ logger.debug(f"Added {len(documents)} documents")
572
+
573
+ return [doc.id for doc in documents]
574
+
575
+ def upsert(
576
+ self,
577
+ documents: Union[Document, List[Document]],
578
+ skip_existing: bool = False,
579
+ ) -> List[str]:
580
+ """
581
+ 添加或更新文档
582
+
583
+ Args:
584
+ documents: 单个文档或文档列表
585
+ skip_existing: 是否跳过已存在的文档(为 True 时行为与 add 相同)
586
+
587
+ Returns:
588
+ upsert 的文档 ID 列表
589
+ """
590
+ if isinstance(documents, Document):
591
+ documents = [documents]
592
+
593
+ if not documents:
594
+ return []
595
+
596
+ # 过滤已存在的文档
597
+ if skip_existing:
598
+ existing_ids = self._get_existing_ids([doc.id for doc in documents])
599
+ documents = [doc for doc in documents if doc.id not in existing_ids]
600
+ if not documents:
601
+ return []
602
+
603
+ # 向量化
604
+ embeddings = self._embed_documents(documents)
605
+
606
+ # 准备数据(包含 scalar 字段)
607
+ data = self._prepare_insert_data(documents, embeddings)
608
+
609
+ # Milvus upsert
610
+ self.collection.upsert(data)
611
+ self.collection.flush()
612
+ logger.debug(f"Upserted {len(documents)} documents")
613
+
614
+ return [doc.id for doc in documents]
615
+
616
+ def delete(self, ids: Union[str, List[str]]) -> None:
617
+ """
618
+ 删除文档
619
+
620
+ Args:
621
+ ids: 单个 ID 或 ID 列表
622
+ """
623
+ if isinstance(ids, str):
624
+ ids = [ids]
625
+
626
+ # 构建删除表达式
627
+ ids_str = ", ".join([f'"{id}"' for id in ids])
628
+ expr = f"{self._primary_key} in [{ids_str}]"
629
+ self.collection.delete(expr)
630
+ self.collection.flush()
631
+ logger.debug(f"Deleted {len(ids)} documents")
632
+
633
+ def delete_by_content(self, contents: Union[str, List[str]]) -> None:
634
+ """
635
+ 根据内容删除文档
636
+
637
+ Args:
638
+ contents: 单个内容或内容列表
639
+ """
640
+ if isinstance(contents, str):
641
+ contents = [contents]
642
+
643
+ ids = [_content_hash(content) for content in contents]
644
+ self.delete(ids)
645
+
646
+ # ========== 检索操作 ==========
647
+
648
+ def _get_output_fields(self) -> List[str]:
649
+ """获取查询时需要返回的所有字段"""
650
+ fm = self._field_mapping
651
+ fields = [fm["primary_key"], fm["content"]]
652
+
653
+ # 可选字段(可能为 None)
654
+ if fm.get("modality"):
655
+ fields.append(fm["modality"])
656
+ if fm.get("metadata"):
657
+ fields.append(fm["metadata"])
658
+
659
+ # 额外字段(统一由 _extra_fields 管理)
660
+ fields.extend(self._extra_fields)
661
+
662
+ return fields
663
+
664
+ def search(
665
+ self,
666
+ query: str,
667
+ top_k: int = 5,
668
+ query_type: Modality = "text",
669
+ expr: Optional[str] = None,
670
+ ) -> List[SearchResult]:
671
+ """
672
+ 检索相似文档
673
+
674
+ Args:
675
+ query: 查询内容(文本或图片路径/URL)
676
+ top_k: 返回数量
677
+ query_type: 查询类型 "text" / "image"
678
+ expr: Milvus 过滤表达式 (例如: 'metadata["category"] == "tech"')
679
+
680
+ Returns:
681
+ SearchResult 列表
682
+ """
683
+ query_embedding = self._embed_query(query, query_type)
684
+
685
+ search_params = {
686
+ "metric_type": self.distance_metric,
687
+ "params": self._search_params,
688
+ }
689
+
690
+ results = self.collection.search(
691
+ data=[query_embedding],
692
+ anns_field=self._field_mapping["embedding"],
693
+ param=search_params,
694
+ limit=top_k,
695
+ expr=expr,
696
+ output_fields=self._get_output_fields(),
697
+ )
698
+
699
+ parsed = self._parse_results(results)
700
+ return parsed[0] if parsed else []
701
+
702
+ def search_by_vector(
703
+ self,
704
+ vector: List[float],
705
+ top_k: int = 5,
706
+ expr: Optional[str] = None,
707
+ ) -> List[SearchResult]:
708
+ """
709
+ 直接使用向量检索
710
+
711
+ Args:
712
+ vector: 查询向量
713
+ top_k: 返回数量
714
+ expr: Milvus 过滤表达式
715
+
716
+ Returns:
717
+ SearchResult 列表
718
+ """
719
+ search_params = {
720
+ "metric_type": self.distance_metric,
721
+ "params": self._search_params,
722
+ }
723
+
724
+ results = self.collection.search(
725
+ data=[vector],
726
+ anns_field=self._field_mapping["embedding"],
727
+ param=search_params,
728
+ limit=top_k,
729
+ expr=expr,
730
+ output_fields=self._get_output_fields(),
731
+ )
732
+
733
+ parsed = self._parse_results(results)
734
+ return parsed[0] if parsed else []
735
+
736
+ def search_batch(
737
+ self,
738
+ queries: List[str],
739
+ top_k: int = 5,
740
+ query_type: Modality = "text",
741
+ expr: Optional[str] = None,
742
+ ) -> List[List[SearchResult]]:
743
+ """
744
+ 批量检索相似文档
745
+
746
+ Args:
747
+ queries: 查询内容列表(文本或图片路径/URL)
748
+ top_k: 每个查询返回的数量
749
+ query_type: 查询类型 "text" / "image"
750
+ expr: Milvus 过滤表达式
751
+
752
+ Returns:
753
+ SearchResult 列表的列表,每个查询对应一个结果列表
754
+
755
+ Example:
756
+ >>> results = retriever.search_batch(["query1", "query2"], top_k=5)
757
+ >>> for i, query_results in enumerate(results):
758
+ ... print(f"Query {i}: {len(query_results)} results")
759
+ """
760
+ if not queries:
761
+ return []
762
+
763
+ # 批量向量化查询
764
+ query_embeddings = self._embed_queries(queries, query_type)
765
+
766
+ search_params = {
767
+ "metric_type": self.distance_metric,
768
+ "params": self._search_params,
769
+ }
770
+
771
+ results = self.collection.search(
772
+ data=query_embeddings,
773
+ anns_field=self._field_mapping["embedding"],
774
+ param=search_params,
775
+ limit=top_k,
776
+ expr=expr,
777
+ output_fields=self._get_output_fields(),
778
+ )
779
+
780
+ return self._parse_results(results)
781
+
782
+ def search_by_vectors(
783
+ self,
784
+ vectors: List[List[float]],
785
+ top_k: int = 5,
786
+ expr: Optional[str] = None,
787
+ ) -> List[List[SearchResult]]:
788
+ """
789
+ 批量使用向量检索
790
+
791
+ Args:
792
+ vectors: 查询向量列表
793
+ top_k: 每个查询返回的数量
794
+ expr: Milvus 过滤表达式
795
+
796
+ Returns:
797
+ SearchResult 列表的列表,每个向量对应一个结果列表
798
+
799
+ Example:
800
+ >>> vectors = [[0.1, 0.2, ...], [0.3, 0.4, ...]]
801
+ >>> results = retriever.search_by_vectors(vectors, top_k=5)
802
+ """
803
+ if not vectors:
804
+ return []
805
+
806
+ search_params = {
807
+ "metric_type": self.distance_metric,
808
+ "params": self._search_params,
809
+ }
810
+
811
+ results = self.collection.search(
812
+ data=vectors,
813
+ anns_field=self._field_mapping["embedding"],
814
+ param=search_params,
815
+ limit=top_k,
816
+ expr=expr,
817
+ output_fields=self._get_output_fields(),
818
+ )
819
+
820
+ return self._parse_results(results)
821
+
822
    def _parse_results(self, results) -> List[List[SearchResult]]:
        """
        Parse raw Milvus search results.

        Args:
            results: Object returned by ``Collection.search``.

        Returns:
            One list of SearchResult per query vector.
        """
        if not results or len(results) == 0:
            return []

        fm = self._field_mapping
        all_results = []

        for hits in results:
            query_results = []
            for hit in hits:
                entity = hit.entity

                # Convert distance into a "higher is better" score.
                distance = hit.distance
                if self.distance_metric == "COSINE":
                    score = distance  # Milvus COSINE already returns a similarity
                elif self.distance_metric == "IP":
                    score = distance  # inner product: larger means more similar
                else:
                    score = -distance  # L2: smaller is better, so negate

                # Merge stored metadata with the extra scalar fields.
                # NOTE: the `in` operator on pymilvus Hit entities is
                # unreliable, so use .get() directly.
                metadata_field = fm.get("metadata")
                metadata = dict(entity.get(metadata_field) or {}) if metadata_field else {}

                for name in self._extra_fields:
                    value = entity.get(name)
                    if value is not None:
                        metadata[name] = value

                # The modality field may be absent in mapped collections.
                modality_field = fm.get("modality")
                modality = entity.get(modality_field, "text") if modality_field else "text"

                query_results.append(SearchResult(
                    id=entity.get(fm["primary_key"], ""),
                    content=entity.get(fm["content"], ""),
                    score=score,
                    modality=modality,
                    metadata=metadata,
                ))

            all_results.append(query_results)

        return all_results
877
+
878
+ # ========== 管理操作 ==========
879
+
880
def get(
    self,
    ids: Optional[Union[str, List[str]]] = None,
    expr: Optional[str] = None,
    limit: Optional[int] = None,
) -> List[Document]:
    """
    Fetch documents by id and/or filter expression.

    Args:
        ids: a single document id or a list of ids
        expr: a Milvus boolean filter expression
        limit: maximum number of documents to return

    Returns:
        A list of Document objects.
    """
    if isinstance(ids, str):
        ids = [ids]

    pk = self._primary_key

    # Combine the id list and the optional filter into one expression.
    if ids:
        quoted = ", ".join(f'"{doc_id}"' for doc_id in ids)
        query_expr = f"{pk} in [{quoted}]"
        if expr:
            query_expr = f"({query_expr}) and ({expr})"
    else:
        query_expr = expr or ""

    # An empty expression is not accepted by query(); fall back to a
    # tautology on the primary key to mean "match everything".
    rows = self.collection.query(
        expr=query_expr or f"{pk} != ''",
        output_fields=self._get_output_fields(),
        limit=limit or 16384,
    )

    mapping = self._field_mapping
    meta_key = mapping.get("metadata")
    modality_key = mapping.get("modality")
    docs = []

    for row in rows:
        # Merge stored metadata with any extra scalar fields.
        meta = dict(row.get(meta_key) or {}) if meta_key else {}
        for extra in self._extra_fields:
            extra_val = row.get(extra)
            if extra_val is not None:
                meta[extra] = extra_val

        # modality may be absent on collections created without it
        modality = row.get(modality_key, "text") if modality_key else "text"

        docs.append(Document(
            id=row.get(pk, ""),
            content=row.get(mapping["content"], ""),
            modality=modality,
            metadata=meta,
        ))

    return docs
+
942
def count(self) -> int:
    """Number of entities currently stored in the collection."""
    total = self.collection.num_entities
    return total
+
946
def clear(self) -> None:
    """Empty the collection by dropping it and creating it anew."""
    # Imported lazily so the module loads without pymilvus installed.
    from pymilvus import utility

    logger.info(f"Clearing collection: {self.collection_name}")
    utility.drop_collection(self.collection_name, using=self._connection_alias)
    # Rebuild an empty collection with the same schema.
    self.collection = self._create_collection()
    logger.info(f"Collection '{self.collection_name}' cleared and recreated")
+
955
def drop(self) -> None:
    """
    Permanently delete the collection.

    Warning: this operation is irreversible and destroys the collection
    and all of its data. The retriever instance becomes unusable
    afterwards and must be recreated.
    """
    # Imported lazily so the module loads without pymilvus installed.
    from pymilvus import utility

    logger.warning(f"Dropping collection: {self.collection_name}")
    utility.drop_collection(self.collection_name, using=self._connection_alias)
    # Leave no dangling handle to the dropped collection.
    self.collection = None
    logger.info(f"Collection '{self.collection_name}' dropped")

# ========== Convenience methods ==========

def upsert_batch(
    self,
    documents: List[Document],
    batch_size: int = 100,
    skip_existing: bool = False,
    show_progress: bool = True,
) -> int:
    """
    Insert documents in batches, with optional progress bar and
    incremental (skip-existing) support.

    Args:
        documents: documents to insert
        batch_size: number of documents per upsert call
        skip_existing: when True, documents whose id already exists are skipped
        show_progress: when True, show a tqdm progress bar if available

    Returns:
        The number of documents actually inserted.
    """
    if not documents:
        return 0

    total_docs = len(documents)
    logger.info(f"Starting batch upsert: {total_docs} documents, batch_size={batch_size}")

    # Optionally drop documents that are already stored.
    skipped = 0
    if skip_existing:
        existing_ids = self._get_existing_ids([doc.id for doc in documents])
        remaining = [doc for doc in documents if doc.id not in existing_ids]
        skipped = len(documents) - len(remaining)
        documents = remaining
        if skipped > 0:
            logger.info(f"Skipped {skipped} existing documents")
        if not documents:
            return 0

    total_batches = (len(documents) + batch_size - 1) // batch_size
    batches = range(0, len(documents), batch_size)

    if show_progress:
        try:
            from tqdm import tqdm
        except ImportError:
            logger.debug("tqdm not installed, progress bar disabled")
        else:
            batches = tqdm(
                batches,
                desc="Upserting",
                total=total_batches,
                unit="batch",
            )

    inserted = 0
    for start in batches:
        chunk = documents[start:start + batch_size]
        self.upsert(chunk)
        inserted += len(chunk)

    logger.info(f"Batch upsert completed: {inserted} inserted, {skipped} skipped")
    return inserted
+
1032
def _get_existing_ids(self, candidate_ids: List[str]) -> set:
    """
    Return the subset of candidate_ids already present in the collection.

    Ids are checked in batches of 1000 to keep each query expression
    small. Lookups remain best-effort — a failed batch is treated as
    "none of these exist" — but failures are now logged instead of being
    silently swallowed, so real connectivity/schema problems surface.

    Args:
        candidate_ids: document ids to check

    Returns:
        A set containing the ids that already exist.
    """
    existing_ids = set()
    batch_size = 1000
    pk = self._primary_key

    for i in range(0, len(candidate_ids), batch_size):
        batch_ids = candidate_ids[i:i + batch_size]
        ids_str = ", ".join([f'"{id}"' for id in batch_ids])
        try:
            results = self.collection.query(
                expr=f"{pk} in [{ids_str}]",
                output_fields=[pk],
            )
        except Exception as exc:
            # Best-effort: treat the whole batch as missing, but log the
            # error so callers can notice systematic failures.
            logger.warning(
                f"Existence check failed for a batch of {len(batch_ids)} ids: {exc}"
            )
            continue
        for item in results:
            existing_ids.add(item.get(pk))

    return existing_ids
+
1053
def get_all_ids(self) -> List[str]:
    """
    Return the ids of every stored document.

    NOTE(review): results are capped at 16384 by the query limit, so very
    large collections would be truncated — confirm before relying on this
    for completeness.
    """
    pk = self._primary_key
    rows = self.collection.query(
        expr=f"{pk} != ''",
        output_fields=[pk],
        limit=16384,
    )
    return [row.get(pk, "") for row in rows]
+
1063
def migrate_to(
    self,
    target,
    batch_size: int = 100,
    skip_existing: bool = True,
    show_progress: bool = True,
) -> int:
    """
    Copy every document in this collection into a target retriever.

    Args:
        target: destination retriever (ChromaRetriever or MilvusRetriever)
        batch_size: number of documents transferred per round trip
        skip_existing: skip documents already present in the target
        show_progress: show a tqdm progress bar if available

    Returns:
        The number of documents migrated.
    """
    all_ids = self.get_all_ids()
    if not all_ids:
        logger.info("No documents to migrate")
        return 0

    total = len(all_ids)
    logger.info(f"Starting migration: {total} documents")

    batches = range(0, total, batch_size)
    if show_progress:
        try:
            from tqdm import tqdm
        except ImportError:
            pass
        else:
            batches = tqdm(
                batches,
                desc="Migrating",
                total=(total + batch_size - 1) // batch_size,
                unit="batch",
            )

    migrated = 0
    for start in batches:
        # Re-read each id batch as full documents, then hand them to the
        # target's batch upsert (its own progress bar stays off).
        docs = self.get(ids=all_ids[start:start + batch_size])
        if docs:
            migrated += target.upsert_batch(
                docs,
                batch_size=batch_size,
                skip_existing=skip_existing,
                show_progress=False,
            )

    logger.info(f"Migration completed: {migrated} documents migrated")
    return migrated
+
1119
def close(self) -> None:
    """Disconnect this retriever's Milvus connection alias."""
    # Imported lazily so the module loads without pymilvus installed.
    from pymilvus import connections

    connections.disconnect(self._connection_alias)
+
1124
def __repr__(self) -> str:
    """Concise summary: connection target, collection, size, embedder."""
    fields = ", ".join([
        f"host={self.host!r}",
        f"port={self.port}",
        f"db={self.db_name!r}",
        f"collection={self.collection_name!r}",
        f"count={self.count()}",
        f"embedding={self.embedding.__class__.__name__}",
    ])
    return f"MilvusRetriever({fields})"
+
1135
def __enter__(self):
    """Enter the context manager; yields this retriever."""
    return self
+
1138
def __exit__(self, exc_type, exc_val, exc_tb):
    """Close the connection on context exit; never suppresses exceptions."""
    self.close()
    return False