maque-0.2.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- maque/__init__.py +30 -0
- maque/__main__.py +926 -0
- maque/ai_platform/__init__.py +0 -0
- maque/ai_platform/crawl.py +45 -0
- maque/ai_platform/metrics.py +258 -0
- maque/ai_platform/nlp_preprocess.py +67 -0
- maque/ai_platform/webpage_screen_shot.py +195 -0
- maque/algorithms/__init__.py +78 -0
- maque/algorithms/bezier.py +15 -0
- maque/algorithms/bktree.py +117 -0
- maque/algorithms/core.py +104 -0
- maque/algorithms/hilbert.py +16 -0
- maque/algorithms/rate_function.py +92 -0
- maque/algorithms/transform.py +27 -0
- maque/algorithms/trie.py +272 -0
- maque/algorithms/utils.py +63 -0
- maque/algorithms/video.py +587 -0
- maque/api/__init__.py +1 -0
- maque/api/common.py +110 -0
- maque/api/fetch.py +26 -0
- maque/api/static/icon.png +0 -0
- maque/api/static/redoc.standalone.js +1782 -0
- maque/api/static/swagger-ui-bundle.js +3 -0
- maque/api/static/swagger-ui.css +3 -0
- maque/cli/__init__.py +1 -0
- maque/cli/clean_invisible_chars.py +324 -0
- maque/cli/core.py +34 -0
- maque/cli/groups/__init__.py +26 -0
- maque/cli/groups/config.py +205 -0
- maque/cli/groups/data.py +615 -0
- maque/cli/groups/doctor.py +259 -0
- maque/cli/groups/embedding.py +222 -0
- maque/cli/groups/git.py +29 -0
- maque/cli/groups/help.py +410 -0
- maque/cli/groups/llm.py +223 -0
- maque/cli/groups/mcp.py +241 -0
- maque/cli/groups/mllm.py +1795 -0
- maque/cli/groups/mllm_simple.py +60 -0
- maque/cli/groups/quant.py +210 -0
- maque/cli/groups/service.py +490 -0
- maque/cli/groups/system.py +570 -0
- maque/cli/mllm_run.py +1451 -0
- maque/cli/script.py +52 -0
- maque/cli/tree.py +49 -0
- maque/clustering/__init__.py +52 -0
- maque/clustering/analyzer.py +347 -0
- maque/clustering/clusterers.py +464 -0
- maque/clustering/sampler.py +134 -0
- maque/clustering/visualizer.py +205 -0
- maque/constant.py +13 -0
- maque/core.py +133 -0
- maque/cv/__init__.py +1 -0
- maque/cv/image.py +219 -0
- maque/cv/utils.py +68 -0
- maque/cv/video/__init__.py +3 -0
- maque/cv/video/keyframe_extractor.py +368 -0
- maque/embedding/__init__.py +43 -0
- maque/embedding/base.py +56 -0
- maque/embedding/multimodal.py +308 -0
- maque/embedding/server.py +523 -0
- maque/embedding/text.py +311 -0
- maque/git/__init__.py +24 -0
- maque/git/pure_git.py +912 -0
- maque/io/__init__.py +29 -0
- maque/io/core.py +38 -0
- maque/io/ops.py +194 -0
- maque/llm/__init__.py +111 -0
- maque/llm/backend.py +416 -0
- maque/llm/base.py +411 -0
- maque/llm/server.py +366 -0
- maque/mcp_server.py +1096 -0
- maque/mllm_data_processor_pipeline/__init__.py +17 -0
- maque/mllm_data_processor_pipeline/core.py +341 -0
- maque/mllm_data_processor_pipeline/example.py +291 -0
- maque/mllm_data_processor_pipeline/steps/__init__.py +56 -0
- maque/mllm_data_processor_pipeline/steps/data_alignment.py +267 -0
- maque/mllm_data_processor_pipeline/steps/data_loader.py +172 -0
- maque/mllm_data_processor_pipeline/steps/data_validation.py +304 -0
- maque/mllm_data_processor_pipeline/steps/format_conversion.py +411 -0
- maque/mllm_data_processor_pipeline/steps/mllm_annotation.py +331 -0
- maque/mllm_data_processor_pipeline/steps/mllm_refinement.py +446 -0
- maque/mllm_data_processor_pipeline/steps/result_validation.py +501 -0
- maque/mllm_data_processor_pipeline/web_app.py +317 -0
- maque/nlp/__init__.py +14 -0
- maque/nlp/ngram.py +9 -0
- maque/nlp/parser.py +63 -0
- maque/nlp/risk_matcher.py +543 -0
- maque/nlp/sentence_splitter.py +202 -0
- maque/nlp/simple_tradition_cvt.py +31 -0
- maque/performance/__init__.py +21 -0
- maque/performance/_measure_time.py +70 -0
- maque/performance/_profiler.py +367 -0
- maque/performance/_stat_memory.py +51 -0
- maque/pipelines/__init__.py +15 -0
- maque/pipelines/clustering.py +252 -0
- maque/quantization/__init__.py +42 -0
- maque/quantization/auto_round.py +120 -0
- maque/quantization/base.py +145 -0
- maque/quantization/bitsandbytes.py +127 -0
- maque/quantization/llm_compressor.py +102 -0
- maque/retriever/__init__.py +35 -0
- maque/retriever/chroma.py +654 -0
- maque/retriever/document.py +140 -0
- maque/retriever/milvus.py +1140 -0
- maque/table_ops/__init__.py +1 -0
- maque/table_ops/core.py +133 -0
- maque/table_viewer/__init__.py +4 -0
- maque/table_viewer/download_assets.py +57 -0
- maque/table_viewer/server.py +698 -0
- maque/table_viewer/static/element-plus-icons.js +5791 -0
- maque/table_viewer/static/element-plus.css +1 -0
- maque/table_viewer/static/element-plus.js +65236 -0
- maque/table_viewer/static/main.css +268 -0
- maque/table_viewer/static/main.js +669 -0
- maque/table_viewer/static/vue.global.js +18227 -0
- maque/table_viewer/templates/index.html +401 -0
- maque/utils/__init__.py +56 -0
- maque/utils/color.py +68 -0
- maque/utils/color_string.py +45 -0
- maque/utils/compress.py +66 -0
- maque/utils/constant.py +183 -0
- maque/utils/core.py +261 -0
- maque/utils/cursor.py +143 -0
- maque/utils/distance.py +58 -0
- maque/utils/docker.py +96 -0
- maque/utils/downloads.py +51 -0
- maque/utils/excel_helper.py +542 -0
- maque/utils/helper_metrics.py +121 -0
- maque/utils/helper_parser.py +168 -0
- maque/utils/net.py +64 -0
- maque/utils/nvidia_stat.py +140 -0
- maque/utils/ops.py +53 -0
- maque/utils/packages.py +31 -0
- maque/utils/path.py +57 -0
- maque/utils/tar.py +260 -0
- maque/utils/untar.py +129 -0
- maque/web/__init__.py +0 -0
- maque/web/image_downloader.py +1410 -0
- maque-0.2.1.dist-info/METADATA +450 -0
- maque-0.2.1.dist-info/RECORD +143 -0
- maque-0.2.1.dist-info/WHEEL +4 -0
- maque-0.2.1.dist-info/entry_points.txt +3 -0
- maque-0.2.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1140 @@
+#! /usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""
+Milvus vector database retriever implementation
+"""
+
+from typing import List, Optional, Union, Literal, TYPE_CHECKING
+
+from loguru import logger
+
+from ..embedding.base import BaseEmbedding
+from .document import Document, SearchResult, Modality, _content_hash
+
+if TYPE_CHECKING:
+    from pymilvus import Collection
+
+
+DistanceMetric = Literal["COSINE", "L2", "IP"]
+ScalarType = Literal["VARCHAR", "INT64", "INT32", "INT16", "INT8", "FLOAT", "DOUBLE", "BOOL", "JSON", "ARRAY"]
+
+
+class MilvusRetriever:
+    """
+    Milvus-based retriever
+    Supports vector retrieval over both text and images
+    """
+
+    # Preset index configurations
+    INDEX_PRESETS = {
+        "AUTOINDEX": {
+            "index_type": "AUTOINDEX",
+            "index_params": {},
+            "search_params": {},
+        },
+        "HNSW": {
+            "index_type": "HNSW",
+            "index_params": {"M": 16, "efConstruction": 256},
+            "search_params": {"ef": 128},
+        },
+        "IVF_FLAT": {
+            "index_type": "IVF_FLAT",
+            "index_params": {"nlist": 1024},
+            "search_params": {"nprobe": 16},
+        },
+        "FLAT": {
+            "index_type": "FLAT",
+            "index_params": {},
+            "search_params": {},
+        },
+    }
+
+    # Scalar type mapping
+    SCALAR_TYPE_MAP = {
+        "VARCHAR": ("VARCHAR", {"max_length": 256}),
+        "INT64": ("INT64", {}),
+        "INT32": ("INT32", {}),
+        "INT16": ("INT16", {}),
+        "INT8": ("INT8", {}),
+        "FLOAT": ("FLOAT", {}),
+        "DOUBLE": ("DOUBLE", {}),
+        "BOOL": ("BOOL", {}),
+        "JSON": ("JSON", {}),
+        "ARRAY": ("ARRAY", {"element_type": "VARCHAR", "max_capacity": 256, "max_length": 256}),
+    }
+
+    def __init__(
+        self,
+        embedding: BaseEmbedding,
+        host: str = "localhost",
+        port: int = 19530,
+        db_name: str = "default",
+        collection_name: str = "default",
+        distance_metric: DistanceMetric = "COSINE",
+        auto_create: bool = True,
+        index_config: Optional[dict] = None,
+        scalar_fields: Optional[List[dict]] = None,
+        primary_key: str = "id",
+        field_mapping: Optional[dict] = None,
+    ):
+        """
+        Initialize the retriever
+
+        Args:
+            embedding: Embedding instance
+            host: Milvus server address
+            port: Milvus server port
+            db_name: Database name
+            collection_name: Collection name
+            distance_metric: Distance metric (COSINE/L2/IP)
+            auto_create: Whether to create the collection automatically
+            index_config: Index configuration, one of:
+                - None: use the default AUTOINDEX configuration
+                - "AUTOINDEX" / "HNSW" / "IVF_FLAT" / "FLAT": use a preset configuration
+                - dict: custom configuration, e.g.:
+                    {
+                        "index_type": "HNSW",
+                        "index_params": {"M": 16, "efConstruction": 256},
+                        "search_params": {"ef": 128},
+                        "id_max_length": 256,
+                        "content_max_length": 65535,
+                    }
+            scalar_fields: [Create mode] Extra scalar field definitions used for efficient filtering, e.g.:
+                [
+                    {"name": "category", "dtype": "VARCHAR", "max_length": 64},
+                    {"name": "timestamp", "dtype": "INT64"},
+                    {"name": "score", "dtype": "FLOAT"},
+                    {"name": "tags", "dtype": "ARRAY", "element_type": "VARCHAR", "max_capacity": 64, "max_length": 32},
+                ]
+                Supported types: VARCHAR, INT64, INT32, INT16, INT8, FLOAT, DOUBLE, BOOL, JSON, ARRAY
+                ARRAY requires extra parameters: element_type and max_capacity
+                Field values are extracted automatically from Document.metadata
+                Note: not needed when reading an existing collection; field types are extracted from the schema automatically
+            primary_key: Primary key field name, "id" by default; customizable, e.g. "user_id" or "content_id"
+            field_mapping: [Read mode] Field mapping used to read an existing collection, e.g.:
+                {
+                    "primary_key": "word_id",  # primary key field name
+                    "content": "word",         # content field name
+                    "embedding": "vector",     # vector field name
+                    "modality": None,          # modality field name (optional; None means absent)
+                    "metadata": None,          # metadata field name (optional)
+                }
+                Other fields are extracted from the schema automatically, including their types
+
+        Usage modes:
+            - Create mode: use scalar_fields to define extra fields
+            - Read mode: use field_mapping to map the core fields; other fields and their types are extracted from the schema automatically
+        """
+        try:
+            from pymilvus import (
+                connections,
+                Collection,
+                FieldSchema,
+                CollectionSchema,
+                DataType,
+                utility,
+                db,
+            )
+        except ImportError:
+            raise ImportError(
+                "pymilvus is required for MilvusRetriever. "
+                "Install it with: pip install pymilvus"
+            )
+
+        self.embedding = embedding
+        self.host = host
+        self.port = port
+        self.db_name = db_name
+        self.collection_name = collection_name
+        self.distance_metric = distance_metric
+        self._dimension = embedding.dimension
+
+        # Parse the index configuration
+        config = self._parse_index_config(index_config)
+        self._index_type = config["index_type"]
+        self._index_params = config["index_params"]
+        self._search_params = config["search_params"]
+        self._id_max_length = config.get("id_max_length", 256)
+        self._content_max_length = config.get("content_max_length", 65535)
+
+        # Parse the field mapping
+        self._field_mapping = self._parse_field_mapping(field_mapping, primary_key)
+        self._primary_key = self._field_mapping["primary_key"]
+        self._use_field_mapping = field_mapping is not None
+
+        # Parse the scalar field configuration (used in create mode only)
+        self._scalar_fields = self._parse_scalar_fields(scalar_fields or [])
+
+        # Connect to Milvus
+        self._connection_alias = f"milvus_{db_name}_{collection_name}"
+        logger.debug(f"Connecting to Milvus at {host}:{port}, db={db_name}")
+        connections.connect(
+            alias=self._connection_alias,
+            host=host,
+            port=port,
+            db_name=db_name,
+        )
+
+        # Get or create the collection
+        if utility.has_collection(collection_name, using=self._connection_alias):
+            logger.debug(f"Loading existing collection: {collection_name}")
+            self.collection = Collection(
+                name=collection_name,
+                using=self._connection_alias,
+            )
+            # Extract extra fields and their types from the schema
+            self._extra_fields, self._field_types = self._extract_extra_fields_from_schema()
+            # Overlay type info from scalar_fields (the user may supply more precise types)
+            for sf in self._scalar_fields:
+                self._field_types[sf["name"]] = sf["dtype"]
+            # Load the collection into memory
+            self.collection.load()
+            logger.info(f"Collection '{collection_name}' loaded, {self.count()} documents")
+        elif auto_create:
+            # Create mode: use scalar_fields
+            logger.debug(f"Creating new collection: {collection_name}")
+            self._extra_fields = [sf["name"] for sf in self._scalar_fields]
+            self._field_types = {sf["name"]: sf["dtype"] for sf in self._scalar_fields}
+            self.collection = self._create_collection()
+            logger.info(f"Collection '{collection_name}' created")
+        else:
+            raise ValueError(f"Collection '{collection_name}' does not exist in database '{db_name}'")
+
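The constructor supports two distinct modes. A minimal sketch of both, assuming a Milvus server on localhost:19530 and assuming `BaseEmbedding` can be stubbed with just the three members this module actually touches (`dimension`, `supports_image`, `embed`); the stub class and every name below are illustrative, not part of the package:

```python
from maque.retriever.milvus import MilvusRetriever
from maque.embedding.base import BaseEmbedding

class DummyEmbedding(BaseEmbedding):
    """Stand-in embedding: fixed 4-dim vectors, text only."""
    dimension = 4
    supports_image = False

    def embed(self, inputs, **kwargs):
        return [[0.0, 0.1, 0.2, 0.3] for _ in inputs]

# Create mode: declare extra scalar fields up front.
retriever = MilvusRetriever(
    embedding=DummyEmbedding(),
    collection_name="articles",
    scalar_fields=[
        {"name": "category", "dtype": "VARCHAR", "max_length": 64},
        {"name": "timestamp", "dtype": "INT64"},
    ],
)

# Read mode: attach to an existing collection whose fields use other names.
existing = MilvusRetriever(
    embedding=DummyEmbedding(),
    collection_name="words",
    auto_create=False,
    field_mapping={
        "primary_key": "word_id",
        "content": "word",
        "embedding": "vector",
        "modality": None,
        "metadata": None,
    },
)
```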
+    def _parse_field_mapping(self, field_mapping: Optional[dict], primary_key: str) -> dict:
+        """
+        Parse the field mapping configuration
+
+        Args:
+            field_mapping: User-provided field mapping
+            primary_key: Default primary key name
+
+        Returns:
+            Normalized field mapping
+        """
+        # Default mapping
+        default_mapping = {
+            "primary_key": primary_key,
+            "content": "content",
+            "embedding": "embedding",
+            "modality": "modality",
+            "metadata": "metadata",
+        }
+
+        if not field_mapping:
+            return default_mapping
+
+        # Merge in the user-provided mapping
+        result = default_mapping.copy()
+        result.update(field_mapping)
+        return result
+
+    def _extract_extra_fields_from_schema(self) -> tuple[List[str], dict]:
+        """
+        Extract extra field names and types from an existing collection schema
+
+        Returns:
+            (list of extra field names, field type mapping)
+        """
+        from pymilvus import DataType
+
+        # DataType-to-string mapping
+        dtype_to_str = {
+            DataType.VARCHAR: "VARCHAR",
+            DataType.INT64: "INT64",
+            DataType.INT32: "INT32",
+            DataType.INT16: "INT16",
+            DataType.INT8: "INT8",
+            DataType.FLOAT: "FLOAT",
+            DataType.DOUBLE: "DOUBLE",
+            DataType.BOOL: "BOOL",
+            DataType.JSON: "JSON",
+            DataType.ARRAY: "ARRAY",
+        }
+
+        # Field names that are already mapped
+        mapped_fields = {
+            self._field_mapping["primary_key"],
+            self._field_mapping["content"],
+            self._field_mapping["embedding"],
+        }
+        # Optional mapped fields
+        if self._field_mapping.get("modality"):
+            mapped_fields.add(self._field_mapping["modality"])
+        if self._field_mapping.get("metadata"):
+            mapped_fields.add(self._field_mapping["metadata"])
+
+        extra_fields = []
+        field_types = {}
+        for field in self.collection.schema.fields:
+            # Skip mapped fields and vector fields
+            if field.name in mapped_fields:
+                continue
+            if field.dtype in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR):
+                continue
+            extra_fields.append(field.name)
+            field_types[field.name] = dtype_to_str.get(field.dtype, "JSON")
+
+        return extra_fields, field_types
+
+    def _parse_index_config(self, index_config) -> dict:
+        """Parse the index configuration"""
+        # Default to AUTOINDEX
+        if index_config is None:
+            return self.INDEX_PRESETS["AUTOINDEX"].copy()
+
+        # String preset
+        if isinstance(index_config, str):
+            if index_config not in self.INDEX_PRESETS:
+                raise ValueError(f"Unknown preset: {index_config}, available: {list(self.INDEX_PRESETS.keys())}")
+            return self.INDEX_PRESETS[index_config].copy()
+
+        # Custom dict: layer it over the preset it names
+        base = self.INDEX_PRESETS.get(index_config.get("index_type", "AUTOINDEX"), {}).copy()
+        base.update(index_config)
+        return base
+
+    def _parse_scalar_fields(self, scalar_fields: List[dict]) -> List[dict]:
+        """
+        Parse the scalar field configuration
+
+        Args:
+            scalar_fields: List of field definitions, e.g.:
+                [
+                    {"name": "category", "dtype": "VARCHAR", "max_length": 64},
+                    {"name": "timestamp", "dtype": "INT64"},
+                    {"name": "tags", "dtype": "ARRAY", "element_type": "VARCHAR", "max_capacity": 64, "max_length": 32},
+                ]
+
+        Returns:
+            Normalized list of field configurations
+        """
+        reserved_names = {self._primary_key, "content", "modality", "metadata", "embedding"}
+        parsed = []
+
+        for field in scalar_fields:
+            name = field.get("name")
+            dtype = field.get("dtype", "VARCHAR").upper()
+
+            if not name:
+                raise ValueError("Scalar field must have 'name'")
+            if name in reserved_names:
+                raise ValueError(f"Field name '{name}' is reserved")
+            if dtype not in self.SCALAR_TYPE_MAP:
+                raise ValueError(f"Unknown dtype '{dtype}', available: {list(self.SCALAR_TYPE_MAP.keys())}")
+
+            # Build the normalized configuration
+            parsed_field = {"name": name, "dtype": dtype}
+
+            # VARCHAR requires max_length
+            if dtype == "VARCHAR":
+                parsed_field["max_length"] = field.get("max_length", 256)
+            # ARRAY requires element_type, max_capacity, and possibly max_length
+            elif dtype == "ARRAY":
+                element_type = field.get("element_type", "VARCHAR").upper()
+                parsed_field["element_type"] = element_type
+                parsed_field["max_capacity"] = field.get("max_capacity", 256)
+                # If the element type is VARCHAR, max_length is required
+                if element_type == "VARCHAR":
+                    parsed_field["max_length"] = field.get("max_length", 256)
+
+            parsed.append(parsed_field)
+
+        return parsed
+
+    def _create_collection(self) -> "Collection":
+        """Create the collection"""
+        from pymilvus import (
+            Collection,
+            FieldSchema,
+            CollectionSchema,
+            DataType,
+        )
+
+        # Base fields (using the field mapping)
+        fm = self._field_mapping
+        fields = [
+            FieldSchema(name=fm["primary_key"], dtype=DataType.VARCHAR, max_length=self._id_max_length, is_primary=True),
+            FieldSchema(name=fm["content"], dtype=DataType.VARCHAR, max_length=self._content_max_length),
+            FieldSchema(name=fm["embedding"], dtype=DataType.FLOAT_VECTOR, dim=self._dimension),
+        ]
+
+        # Optional fields
+        if fm.get("modality"):
+            fields.append(FieldSchema(name=fm["modality"], dtype=DataType.VARCHAR, max_length=32))
+        if fm.get("metadata"):
+            fields.append(FieldSchema(name=fm["metadata"], dtype=DataType.JSON))
+
+        # Add the extra scalar fields
+        for sf in self._scalar_fields:
+            dtype = getattr(DataType, sf["dtype"])
+            if sf["dtype"] == "VARCHAR":
+                fields.append(FieldSchema(name=sf["name"], dtype=dtype, max_length=sf["max_length"]))
+            elif sf["dtype"] == "ARRAY":
+                element_type = getattr(DataType, sf["element_type"])
+                if sf["element_type"] == "VARCHAR":
+                    fields.append(FieldSchema(
+                        name=sf["name"],
+                        dtype=dtype,
+                        element_type=element_type,
+                        max_capacity=sf["max_capacity"],
+                        max_length=sf["max_length"],
+                    ))
+                else:
+                    fields.append(FieldSchema(
+                        name=sf["name"],
+                        dtype=dtype,
+                        element_type=element_type,
+                        max_capacity=sf["max_capacity"],
+                    ))
+            else:
+                fields.append(FieldSchema(name=sf["name"], dtype=dtype))
+
+        schema = CollectionSchema(
+            fields=fields,
+            description=f"Collection for {self.collection_name}",
+        )
+
+        collection = Collection(
+            name=self.collection_name,
+            schema=schema,
+            using=self._connection_alias,
+        )
+
+        # Create the index
+        index_params = {
+            "metric_type": self.distance_metric,
+            "index_type": self._index_type,
+            "params": self._index_params,
+        }
+        collection.create_index(field_name=fm["embedding"], index_params=index_params)
+
+        # Load into memory
+        collection.load()
+
+        return collection
+
+    def _get_input_type(self, modality: Modality) -> str:
+        """Get the input_type argument for the embedding"""
+        return "image" if modality == "image" else "text"
+
+    def _embed_documents(self, documents: List[Document]) -> List[List[float]]:
+        """Embed the documents"""
+        if not documents:
+            return []
+
+        has_image = any(doc.is_image for doc in documents)
+        if has_image and not self.embedding.supports_image:
+            raise ValueError(
+                "The embedding does not support images, but the documents contain images. "
+                "Please use MultiModalEmbedding."
+            )
+
+        if has_image:
+            embeddings = []
+            for doc in documents:
+                input_type = self._get_input_type(doc.modality)
+                vec = self.embedding.embed([doc.content], input_type=input_type)[0]
+                embeddings.append(vec)
+            return embeddings
+        else:
+            contents = [doc.content for doc in documents]
+            return self.embedding.embed(contents)
+
+    def _embed_query(
+        self,
+        query: str,
+        query_type: Modality = "text",
+    ) -> List[float]:
+        """Embed the query"""
+        if query_type == "image" and not self.embedding.supports_image:
+            raise ValueError("The embedding does not support image queries")
+
+        if self.embedding.supports_image:
+            input_type = self._get_input_type(query_type)
+            return self.embedding.embed([query], input_type=input_type)[0]
+        else:
+            return self.embedding.embed([query])[0]
+
+    def _embed_queries(
+        self,
+        queries: List[str],
+        query_type: Modality = "text",
+    ) -> List[List[float]]:
+        """Embed multiple queries in one batch"""
+        if not queries:
+            return []
+
+        if query_type == "image" and not self.embedding.supports_image:
+            raise ValueError("The embedding does not support image queries")
+
+        if self.embedding.supports_image:
+            input_type = self._get_input_type(query_type)
+            return self.embedding.embed(queries, input_type=input_type)
+        else:
+            return self.embedding.embed(queries)
+
+    def _prepare_insert_data(self, documents: List[Document], embeddings: List[List[float]]) -> List[dict]:
+        """
+        Prepare the insert data, covering both base fields and extra fields
+
+        Args:
+            documents: List of documents
+            embeddings: List of vectors
+
+        Returns:
+            Row-formatted data (each element is a dict representing one row)
+        """
+        fm = self._field_mapping
+
+        rows = []
+        for i, doc in enumerate(documents):
+            row = {
+                fm["primary_key"]: doc.id,
+                fm["content"]: doc.content,
+                fm["embedding"]: embeddings[i],
+            }
+
+            # Optional fields
+            if fm.get("modality"):
+                row[fm["modality"]] = doc.modality
+            if fm.get("metadata"):
+                row[fm["metadata"]] = doc.metadata
+
+            # Add every extra field value (extracted from metadata)
+            for field_name in self._extra_fields:
+                value = doc.metadata.get(field_name)
+                dtype = self._field_types.get(field_name)
+
+                # If the value is None, fall back to a type-appropriate default
+                if value is None:
+                    if dtype == "VARCHAR":
+                        value = ""
+                    elif dtype in ("INT64", "INT32", "INT16", "INT8"):
+                        value = 0
+                    elif dtype in ("FLOAT", "DOUBLE"):
+                        value = 0.0
+                    elif dtype == "BOOL":
+                        value = False
+                    else:
+                        # JSON / ARRAY and similar types default to an empty list
+                        value = []
+
+                row[field_name] = value
+
+            rows.append(row)
+
+        return rows
+
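One behavior of `_prepare_insert_data` worth calling out: the rows it builds are dense, so a document whose metadata lacks a declared scalar field gets a type-appropriate zero value rather than a NULL. A sketch with hypothetical documents (assuming `Document` accepts these keyword arguments, as the `get()` method later in this file suggests):

```python
from maque.retriever.document import Document

docs = [
    Document(id="a", content="hello",
             metadata={"category": "greeting", "timestamp": 1700000000}),
    Document(id="b", content="world",
             metadata={"category": "noun"}),  # no timestamp given
]
# With scalar fields category (VARCHAR) and timestamp (INT64), the rows become
# (embedding vectors elided):
#   {"id": "a", "content": "hello", "category": "greeting", "timestamp": 1700000000}
#   {"id": "b", "content": "world", "category": "noun",     "timestamp": 0}  # INT64 default
```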
+    # ========== Indexing operations ==========
+
+    def add(
+        self,
+        documents: Union[Document, List[Document]],
+        skip_existing: bool = False,
+    ) -> List[str]:
+        """
+        Add documents
+
+        Args:
+            documents: A single document or a list of documents
+            skip_existing: Whether to skip documents that already exist
+
+        Returns:
+            List of added document IDs
+        """
+        if isinstance(documents, Document):
+            documents = [documents]
+
+        if not documents:
+            return []
+
+        # Filter out documents that already exist
+        if skip_existing:
+            existing_ids = self._get_existing_ids([doc.id for doc in documents])
+            skipped = len([doc for doc in documents if doc.id in existing_ids])
+            documents = [doc for doc in documents if doc.id not in existing_ids]
+            if skipped > 0:
+                logger.debug(f"Skipped {skipped} existing documents")
+            if not documents:
+                return []
+
+        # Embed
+        embeddings = self._embed_documents(documents)
+
+        # Prepare the data (including scalar fields)
+        data = self._prepare_insert_data(documents, embeddings)
+
+        # Insert the data
+        self.collection.insert(data)
+        self.collection.flush()
+        logger.debug(f"Added {len(documents)} documents")
+
+        return [doc.id for doc in documents]
+
+    def upsert(
+        self,
+        documents: Union[Document, List[Document]],
+        skip_existing: bool = False,
+    ) -> List[str]:
+        """
+        Add or update documents
+
+        Args:
+            documents: A single document or a list of documents
+            skip_existing: Whether to skip documents that already exist (when True, behaves like add)
+
+        Returns:
+            List of upserted document IDs
+        """
+        if isinstance(documents, Document):
+            documents = [documents]
+
+        if not documents:
+            return []
+
+        # Filter out documents that already exist
+        if skip_existing:
+            existing_ids = self._get_existing_ids([doc.id for doc in documents])
+            documents = [doc for doc in documents if doc.id not in existing_ids]
+            if not documents:
+                return []
+
+        # Embed
+        embeddings = self._embed_documents(documents)
+
+        # Prepare the data (including scalar fields)
+        data = self._prepare_insert_data(documents, embeddings)
+
+        # Milvus upsert
+        self.collection.upsert(data)
+        self.collection.flush()
+        logger.debug(f"Upserted {len(documents)} documents")
+
+        return [doc.id for doc in documents]
+
+    def delete(self, ids: Union[str, List[str]]) -> None:
+        """
+        Delete documents
+
+        Args:
+            ids: A single ID or a list of IDs
+        """
+        if isinstance(ids, str):
+            ids = [ids]
+
+        # Build the delete expression
+        ids_str = ", ".join([f'"{id}"' for id in ids])
+        expr = f"{self._primary_key} in [{ids_str}]"
+        self.collection.delete(expr)
+        self.collection.flush()
+        logger.debug(f"Deleted {len(ids)} documents")
+
+    def delete_by_content(self, contents: Union[str, List[str]]) -> None:
+        """
+        Delete documents by content
+
+        Args:
+            contents: A single content string or a list of content strings
+        """
+        if isinstance(contents, str):
+            contents = [contents]
+
+        ids = [_content_hash(content) for content in contents]
+        self.delete(ids)
+
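A short round trip over the four mutation methods, reusing the `retriever` from the constructor sketch. One caveat: `delete_by_content` addresses rows by `_content_hash(content)`, so it only matches documents whose IDs were derived that way (presumably the default when no explicit id is passed; `document.py` is not shown in this section, so treat the constructor call below as an assumption):

```python
from maque.retriever.document import Document

doc = Document(content="hello milvus")       # id assumed to default to _content_hash(content)
retriever.add(doc)                           # plain insert; returns [doc.id]
retriever.add(doc, skip_existing=True)       # no-op: the id is already present
retriever.upsert(doc)                        # overwrite the same row in place
retriever.delete(doc.id)                     # delete by primary key
retriever.add(doc)
retriever.delete_by_content("hello milvus")  # same row, addressed via the content hash
```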
+    # ========== Retrieval operations ==========
+
+    def _get_output_fields(self) -> List[str]:
+        """Get all fields that queries should return"""
+        fm = self._field_mapping
+        fields = [fm["primary_key"], fm["content"]]
+
+        # Optional fields (may be None)
+        if fm.get("modality"):
+            fields.append(fm["modality"])
+        if fm.get("metadata"):
+            fields.append(fm["metadata"])
+
+        # Extra fields (managed uniformly via _extra_fields)
+        fields.extend(self._extra_fields)
+
+        return fields
+
+    def search(
+        self,
+        query: str,
+        top_k: int = 5,
+        query_type: Modality = "text",
+        expr: Optional[str] = None,
+    ) -> List[SearchResult]:
+        """
+        Retrieve similar documents
+
+        Args:
+            query: Query content (text, or an image path/URL)
+            top_k: Number of results to return
+            query_type: Query type, "text" / "image"
+            expr: Milvus filter expression (e.g. 'metadata["category"] == "tech"')
+
+        Returns:
+            List of SearchResult
+        """
+        query_embedding = self._embed_query(query, query_type)
+
+        search_params = {
+            "metric_type": self.distance_metric,
+            "params": self._search_params,
+        }
+
+        results = self.collection.search(
+            data=[query_embedding],
+            anns_field=self._field_mapping["embedding"],
+            param=search_params,
+            limit=top_k,
+            expr=expr,
+            output_fields=self._get_output_fields(),
+        )
+
+        parsed = self._parse_results(results)
+        return parsed[0] if parsed else []
+
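Querying the retriever built earlier. `expr` is handed to Milvus untouched, so scalar fields declared at creation time filter natively, while keys kept inside the JSON `metadata` column need the bracket syntax from the docstring (field names here continue the earlier hypothetical setup):

```python
# Filter on a declared scalar field:
hits = retriever.search("greetings in english", top_k=3, expr='category == "greeting"')
for hit in hits:
    print(hit.id, round(hit.score, 3), hit.content)

# Filter on a key stored inside the JSON metadata column instead:
hits = retriever.search("greetings", top_k=3, expr='metadata["lang"] == "en"')
```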
+    def search_by_vector(
+        self,
+        vector: List[float],
+        top_k: int = 5,
+        expr: Optional[str] = None,
+    ) -> List[SearchResult]:
+        """
+        Search directly with a vector
+
+        Args:
+            vector: Query vector
+            top_k: Number of results to return
+            expr: Milvus filter expression
+
+        Returns:
+            List of SearchResult
+        """
+        search_params = {
+            "metric_type": self.distance_metric,
+            "params": self._search_params,
+        }
+
+        results = self.collection.search(
+            data=[vector],
+            anns_field=self._field_mapping["embedding"],
+            param=search_params,
+            limit=top_k,
+            expr=expr,
+            output_fields=self._get_output_fields(),
+        )
+
+        parsed = self._parse_results(results)
+        return parsed[0] if parsed else []
+
+    def search_batch(
+        self,
+        queries: List[str],
+        top_k: int = 5,
+        query_type: Modality = "text",
+        expr: Optional[str] = None,
+    ) -> List[List[SearchResult]]:
+        """
+        Retrieve similar documents for multiple queries in one call
+
+        Args:
+            queries: List of query contents (text, or image paths/URLs)
+            top_k: Number of results per query
+            query_type: Query type, "text" / "image"
+            expr: Milvus filter expression
+
+        Returns:
+            List of SearchResult lists, one result list per query
+
+        Example:
+            >>> results = retriever.search_batch(["query1", "query2"], top_k=5)
+            >>> for i, query_results in enumerate(results):
+            ...     print(f"Query {i}: {len(query_results)} results")
+        """
+        if not queries:
+            return []
+
+        # Embed the queries in one batch
+        query_embeddings = self._embed_queries(queries, query_type)
+
+        search_params = {
+            "metric_type": self.distance_metric,
+            "params": self._search_params,
+        }
+
+        results = self.collection.search(
+            data=query_embeddings,
+            anns_field=self._field_mapping["embedding"],
+            param=search_params,
+            limit=top_k,
+            expr=expr,
+            output_fields=self._get_output_fields(),
+        )
+
+        return self._parse_results(results)
+
+    def search_by_vectors(
+        self,
+        vectors: List[List[float]],
+        top_k: int = 5,
+        expr: Optional[str] = None,
+    ) -> List[List[SearchResult]]:
+        """
+        Search with multiple vectors in one call
+
+        Args:
+            vectors: List of query vectors
+            top_k: Number of results per query
+            expr: Milvus filter expression
+
+        Returns:
+            List of SearchResult lists, one result list per vector
+
+        Example:
+            >>> vectors = [[0.1, 0.2, ...], [0.3, 0.4, ...]]
+            >>> results = retriever.search_by_vectors(vectors, top_k=5)
+        """
+        if not vectors:
+            return []
+
+        search_params = {
+            "metric_type": self.distance_metric,
+            "params": self._search_params,
+        }
+
+        results = self.collection.search(
+            data=vectors,
+            anns_field=self._field_mapping["embedding"],
+            param=search_params,
+            limit=top_k,
+            expr=expr,
+            output_fields=self._get_output_fields(),
+        )
+
+        return self._parse_results(results)
+
+    def _parse_results(self, results) -> List[List[SearchResult]]:
+        """
+        Parse the results returned by Milvus
+
+        Args:
+            results: Results returned by a Milvus search
+
+        Returns:
+            List of SearchResult lists, one result list per query
+        """
+        if not results or len(results) == 0:
+            return []
+
+        fm = self._field_mapping
+        all_results = []
+
+        for hits in results:
+            query_results = []
+            for hit in hits:
+                entity = hit.entity
+
+                # Convert distance to similarity
+                distance = hit.distance
+                if self.distance_metric == "COSINE":
+                    score = distance  # Milvus COSINE already returns a similarity
+                elif self.distance_metric == "IP":
+                    score = distance  # IP: a larger inner product means more similar
+                else:
+                    score = -distance  # L2: a smaller distance is better
+
+                # Merge metadata and the extra fields
+                # Note: the `in` operator on pymilvus Hit objects is unreliable; use get directly
+                metadata_field = fm.get("metadata")
+                metadata = dict(entity.get(metadata_field) or {}) if metadata_field else {}
+
+                for name in self._extra_fields:
+                    value = entity.get(name)
+                    if value is not None:
+                        metadata[name] = value
+
+                # Get the modality (may be absent)
+                modality_field = fm.get("modality")
+                modality = entity.get(modality_field, "text") if modality_field else "text"
+
+                query_results.append(SearchResult(
+                    id=entity.get(fm["primary_key"], ""),
+                    content=entity.get(fm["content"], ""),
+                    score=score,
+                    modality=modality,
+                    metadata=metadata,
+                ))
+
+            all_results.append(query_results)
+
+        return all_results
+
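`_parse_results` normalizes scores so that "larger is better" holds for every metric: COSINE and IP values pass through unchanged (Milvus already reports them similarity-style) while L2 distances are negated, so an L2 distance of 0.25 outranks 0.5. A quick check mirroring the branch above:

```python
def to_score(metric: str, distance: float) -> float:
    # Same rule as the branch in _parse_results.
    return distance if metric in ("COSINE", "IP") else -distance

assert to_score("L2", 0.25) > to_score("L2", 0.5)         # smaller distance wins
assert to_score("COSINE", 0.9) > to_score("COSINE", 0.2)  # higher similarity wins
```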
+    # ========== Management operations ==========
+
+    def get(
+        self,
+        ids: Optional[Union[str, List[str]]] = None,
+        expr: Optional[str] = None,
+        limit: Optional[int] = None,
+    ) -> List[Document]:
+        """
+        Get documents
+
+        Args:
+            ids: Document ID or list of IDs
+            expr: Milvus filter expression
+            limit: Maximum number of results
+
+        Returns:
+            List of Document
+        """
+        if isinstance(ids, str):
+            ids = [ids]
+
+        # Build the query expression
+        pk = self._primary_key
+        if ids:
+            ids_str = ", ".join([f'"{id}"' for id in ids])
+            query_expr = f"{pk} in [{ids_str}]"
+            if expr:
+                query_expr = f"({query_expr}) and ({expr})"
+        else:
+            query_expr = expr or ""
+
+        results = self.collection.query(
+            expr=query_expr if query_expr else f"{pk} != ''",
+            output_fields=self._get_output_fields(),
+            limit=limit or 16384,
+        )
+
+        fm = self._field_mapping
+        documents = []
+
+        for item in results:
+            # Merge metadata and the extra fields
+            metadata_field = fm.get("metadata")
+            metadata = dict(item.get(metadata_field) or {}) if metadata_field else {}
+
+            for name in self._extra_fields:
+                value = item.get(name)
+                if value is not None:
+                    metadata[name] = value
+
+            # Get the modality (may be absent)
+            modality_field = fm.get("modality")
+            modality = item.get(modality_field, "text") if modality_field else "text"
+
+            documents.append(Document(
+                id=item.get(pk, ""),
+                content=item.get(fm["content"], ""),
+                modality=modality,
+                metadata=metadata,
+            ))
+
+        return documents
+
+    def count(self) -> int:
+        """Return the number of documents"""
+        return self.collection.num_entities
+
+    def clear(self) -> None:
+        """Clear the collection (drop and recreate)"""
+        from pymilvus import utility
+
+        logger.info(f"Clearing collection: {self.collection_name}")
+        utility.drop_collection(self.collection_name, using=self._connection_alias)
+        self.collection = self._create_collection()
+        logger.info(f"Collection '{self.collection_name}' cleared and recreated")
+
+    def drop(self) -> None:
+        """
+        Drop the collection entirely
+
+        Warning: this operation is irreversible and permanently deletes the collection and all of its data.
+        After dropping, the retriever instance is unusable and must be recreated.
+        """
+        from pymilvus import utility
+
+        logger.warning(f"Dropping collection: {self.collection_name}")
+        utility.drop_collection(self.collection_name, using=self._connection_alias)
+        self.collection = None
+        logger.info(f"Collection '{self.collection_name}' dropped")
+
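The management surface in use, continuing the sketch. With neither `ids` nor `expr`, `get` falls back to the catch-all expression `{pk} != ''`, i.e. a full scan capped at 16384 rows:

```python
print(retriever.count())               # number of entities in the collection
docs = retriever.get(ids=["a", "b"])   # fetch by primary key
recent = retriever.get(expr="timestamp > 1690000000", limit=100)

# Destructive operations:
# retriever.clear()  # drop and recreate the collection
# retriever.drop()   # drop permanently; the instance is unusable afterwards
```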
+    # ========== Convenience methods ==========
+
+    def upsert_batch(
+        self,
+        documents: List[Document],
+        batch_size: int = 100,
+        skip_existing: bool = False,
+        show_progress: bool = True,
+    ) -> int:
+        """
+        Insert documents in batches (with progress bar and incremental-update support)
+
+        Args:
+            documents: List of documents
+            batch_size: Batch size
+            skip_existing: Whether to skip documents that already exist
+            show_progress: Whether to show a progress bar
+
+        Returns:
+            Number of documents actually inserted
+        """
+        if not documents:
+            return 0
+
+        total_docs = len(documents)
+        logger.info(f"Starting batch upsert: {total_docs} documents, batch_size={batch_size}")
+
+        # Filter out documents that already exist
+        skipped = 0
+        if skip_existing:
+            existing_ids = self._get_existing_ids([doc.id for doc in documents])
+            skipped = len([doc for doc in documents if doc.id in existing_ids])
+            documents = [doc for doc in documents if doc.id not in existing_ids]
+            if skipped > 0:
+                logger.info(f"Skipped {skipped} existing documents")
+            if not documents:
+                return 0
+
+        # Insert in batches
+        inserted = 0
+        total_batches = (len(documents) + batch_size - 1) // batch_size
+        iterator = range(0, len(documents), batch_size)
+
+        if show_progress:
+            try:
+                from tqdm import tqdm
+                iterator = tqdm(
+                    iterator,
+                    desc="Upserting",
+                    total=total_batches,
+                    unit="batch",
+                )
+            except ImportError:
+                logger.debug("tqdm not installed, progress bar disabled")
+
+        for i in iterator:
+            batch = documents[i:i + batch_size]
+            self.upsert(batch)
+            inserted += len(batch)
+
+        logger.info(f"Batch upsert completed: {inserted} inserted, {skipped} skipped")
+        return inserted
+
+    def _get_existing_ids(self, candidate_ids: List[str]) -> set:
+        """Get the set of document IDs that already exist"""
+        existing_ids = set()
+        batch_size = 1000
+        pk = self._primary_key
+
+        for i in range(0, len(candidate_ids), batch_size):
+            batch_ids = candidate_ids[i:i + batch_size]
+            ids_str = ", ".join([f'"{id}"' for id in batch_ids])
+            try:
+                results = self.collection.query(
+                    expr=f"{pk} in [{ids_str}]",
+                    output_fields=[pk],
+                )
+                for item in results:
+                    existing_ids.add(item.get(pk))
+            except Exception:
+                pass
+
+        return existing_ids
+
+    def get_all_ids(self) -> List[str]:
+        """Get all document IDs"""
+        pk = self._primary_key
+        results = self.collection.query(
+            expr=f"{pk} != ''",
+            output_fields=[pk],
+            limit=16384,
+        )
+        return [item.get(pk, "") for item in results]
+
+    def migrate_to(
+        self,
+        target,
+        batch_size: int = 100,
+        skip_existing: bool = True,
+        show_progress: bool = True,
+    ) -> int:
+        """
+        Migrate all data in the current collection to a target retriever
+
+        Args:
+            target: Target retriever (ChromaRetriever or MilvusRetriever)
+            batch_size: Batch size
+            skip_existing: Whether to skip documents that already exist
+            show_progress: Whether to show a progress bar
+
+        Returns:
+            Number of documents migrated
+        """
+        all_ids = self.get_all_ids()
+        if not all_ids:
+            logger.info("No documents to migrate")
+            return 0
+
+        total = len(all_ids)
+        logger.info(f"Starting migration: {total} documents")
+
+        migrated = 0
+        iterator = range(0, total, batch_size)
+
+        if show_progress:
+            try:
+                from tqdm import tqdm
+                iterator = tqdm(
+                    iterator,
+                    desc="Migrating",
+                    total=(total + batch_size - 1) // batch_size,
+                    unit="batch",
+                )
+            except ImportError:
+                pass
+
+        for i in iterator:
+            batch_ids = all_ids[i:i + batch_size]
+            documents = self.get(ids=batch_ids)
+            if documents:
+                migrated += target.upsert_batch(
+                    documents,
+                    batch_size=batch_size,
+                    skip_existing=skip_existing,
+                    show_progress=False,
+                )
+
+        logger.info(f"Migration completed: {migrated} documents migrated")
+        return migrated
+
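Bulk loading and migration with hypothetical documents. Note that `migrate_to` copies documents, not vectors: the target's `upsert_batch` re-embeds everything with its own model, so source and target are free to use different embeddings and dimensions. Also, `get_all_ids` is capped at 16384 rows, which bounds what a single migration pass can move:

```python
docs = [Document(content=f"doc {i}") for i in range(1000)]  # hypothetical corpus
retriever.upsert_batch(docs, batch_size=200, skip_existing=True)

# Copy everything into a second collection; the target re-embeds on ingest.
other = MilvusRetriever(embedding=DummyEmbedding(), collection_name="articles_copy")
retriever.migrate_to(other, batch_size=200)
```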
+    def close(self) -> None:
+        """Close the connection"""
+        from pymilvus import connections
+        connections.disconnect(self._connection_alias)
+
+    def __repr__(self) -> str:
+        return (
+            f"MilvusRetriever("
+            f"host={self.host!r}, "
+            f"port={self.port}, "
+            f"db={self.db_name!r}, "
+            f"collection={self.collection_name!r}, "
+            f"count={self.count()}, "
+            f"embedding={self.embedding.__class__.__name__})"
+        )
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+        return False