xparse-client 0.2.19__py3-none-any.whl → 0.3.0b8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- example/1_basic_api_usage.py +198 -0
- example/2_async_job.py +210 -0
- example/3_local_workflow.py +300 -0
- example/4_advanced_workflow.py +327 -0
- example/README.md +128 -0
- example/config_example.json +95 -0
- tests/conftest.py +310 -0
- tests/unit/__init__.py +1 -0
- tests/unit/api/__init__.py +1 -0
- tests/unit/api/test_extract.py +232 -0
- tests/unit/api/test_local.py +231 -0
- tests/unit/api/test_parse.py +374 -0
- tests/unit/api/test_pipeline.py +369 -0
- tests/unit/api/test_workflows.py +108 -0
- tests/unit/connectors/test_ftp.py +525 -0
- tests/unit/connectors/test_local_connectors.py +324 -0
- tests/unit/connectors/test_milvus.py +368 -0
- tests/unit/connectors/test_qdrant.py +399 -0
- tests/unit/connectors/test_s3.py +598 -0
- tests/unit/connectors/test_smb.py +442 -0
- tests/unit/connectors/test_utils.py +335 -0
- tests/unit/models/test_local.py +54 -0
- tests/unit/models/test_pipeline_stages.py +144 -0
- tests/unit/models/test_workflows.py +55 -0
- tests/unit/test_base.py +437 -0
- tests/unit/test_client.py +110 -0
- tests/unit/test_config.py +160 -0
- tests/unit/test_exceptions.py +182 -0
- tests/unit/test_http.py +562 -0
- xparse_client/__init__.py +111 -20
- xparse_client/_base.py +188 -0
- xparse_client/_client.py +218 -0
- xparse_client/_config.py +221 -0
- xparse_client/_http.py +351 -0
- xparse_client/api/__init__.py +14 -0
- xparse_client/api/extract.py +109 -0
- xparse_client/api/local.py +225 -0
- xparse_client/api/parse.py +209 -0
- xparse_client/api/pipeline.py +134 -0
- xparse_client/api/workflows.py +204 -0
- xparse_client/connectors/__init__.py +45 -0
- xparse_client/connectors/_utils.py +138 -0
- xparse_client/connectors/destinations/__init__.py +45 -0
- xparse_client/connectors/destinations/base.py +116 -0
- xparse_client/connectors/destinations/local.py +91 -0
- xparse_client/connectors/destinations/milvus.py +229 -0
- xparse_client/connectors/destinations/qdrant.py +238 -0
- xparse_client/connectors/destinations/s3.py +163 -0
- xparse_client/connectors/sources/__init__.py +45 -0
- xparse_client/connectors/sources/base.py +74 -0
- xparse_client/connectors/sources/ftp.py +278 -0
- xparse_client/connectors/sources/local.py +176 -0
- xparse_client/connectors/sources/s3.py +232 -0
- xparse_client/connectors/sources/smb.py +259 -0
- xparse_client/exceptions.py +398 -0
- xparse_client/models/__init__.py +60 -0
- xparse_client/models/chunk.py +39 -0
- xparse_client/models/embed.py +62 -0
- xparse_client/models/extract.py +41 -0
- xparse_client/models/local.py +38 -0
- xparse_client/models/parse.py +132 -0
- xparse_client/models/pipeline.py +134 -0
- xparse_client/models/workflows.py +74 -0
- xparse_client-0.3.0b8.dist-info/METADATA +1075 -0
- xparse_client-0.3.0b8.dist-info/RECORD +68 -0
- {xparse_client-0.2.19.dist-info → xparse_client-0.3.0b8.dist-info}/WHEEL +1 -1
- {xparse_client-0.2.19.dist-info → xparse_client-0.3.0b8.dist-info}/licenses/LICENSE +1 -1
- {xparse_client-0.2.19.dist-info → xparse_client-0.3.0b8.dist-info}/top_level.txt +2 -0
- xparse_client/pipeline/__init__.py +0 -3
- xparse_client/pipeline/config.py +0 -129
- xparse_client/pipeline/destinations.py +0 -489
- xparse_client/pipeline/pipeline.py +0 -690
- xparse_client/pipeline/sources.py +0 -583
- xparse_client-0.2.19.dist-info/METADATA +0 -1050
- xparse_client-0.2.19.dist-info/RECORD +0 -11
|
@@ -1,489 +0,0 @@
|
|
|
1
|
-
#!/usr/bin/env python
|
|
2
|
-
# -*- encoding: utf-8 -*-
|
|
3
|
-
|
|
4
|
-
import json
|
|
5
|
-
import logging
|
|
6
|
-
import uuid
|
|
7
|
-
import boto3
|
|
8
|
-
|
|
9
|
-
from abc import ABC, abstractmethod
|
|
10
|
-
from datetime import datetime
|
|
11
|
-
from pathlib import Path
|
|
12
|
-
from typing import List, Dict, Any
|
|
13
|
-
|
|
14
|
-
from botocore.config import Config
|
|
15
|
-
from pymilvus import MilvusClient
|
|
16
|
-
from qdrant_client import QdrantClient
|
|
17
|
-
from qdrant_client.models import Distance, VectorParams, PointStruct, PayloadSchemaType
|
|
18
|
-
|
|
19
|
-
logger = logging.getLogger(__name__)
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
def _flatten_dict(data: Dict[str, Any], prefix: str = '', fixed_fields: set = None) -> Dict[str, Any]:
|
|
23
|
-
"""递归展平嵌套字典
|
|
24
|
-
|
|
25
|
-
Args:
|
|
26
|
-
data: 要展平的字典
|
|
27
|
-
prefix: 键的前缀
|
|
28
|
-
fixed_fields: 需要排除的字段集合
|
|
29
|
-
|
|
30
|
-
Returns:
|
|
31
|
-
展平后的字典
|
|
32
|
-
"""
|
|
33
|
-
if fixed_fields is None:
|
|
34
|
-
fixed_fields = set()
|
|
35
|
-
|
|
36
|
-
result = {}
|
|
37
|
-
for key, value in data.items():
|
|
38
|
-
flat_key = f'{prefix}_{key}' if prefix else key
|
|
39
|
-
|
|
40
|
-
if flat_key in fixed_fields:
|
|
41
|
-
continue
|
|
42
|
-
|
|
43
|
-
if isinstance(value, dict):
|
|
44
|
-
# 递归展平嵌套字典
|
|
45
|
-
nested = _flatten_dict(value, flat_key, fixed_fields)
|
|
46
|
-
result.update(nested)
|
|
47
|
-
elif isinstance(value, list):
|
|
48
|
-
# 列表转换为 JSON 字符串
|
|
49
|
-
result[flat_key] = json.dumps(value, ensure_ascii=False)
|
|
50
|
-
else:
|
|
51
|
-
# 其他类型直接使用
|
|
52
|
-
result[flat_key] = value
|
|
53
|
-
|
|
54
|
-
return result
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
class Destination(ABC):
    """Abstract base class for data destinations (sinks)."""

    @abstractmethod
    def write(self, data: List[Dict[str, Any]], metadata: Dict[str, Any]) -> bool:
        """Write *data* (a list of element dicts) with file-level *metadata*.

        Returns:
            True on success, False on failure.
        """
        raise NotImplementedError
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
class MilvusDestination(Destination):
    """Milvus/Zilliz vector database destination."""

    def __init__(self, db_path: str, collection_name: str, dimension: int, api_key: str = None, token: str = None):
        """Connect to Milvus/Zilliz and ensure the target collection exists.

        Args:
            db_path: Milvus URI (local file path or remote endpoint).
            collection_name: Target collection name.
            dimension: Dimension of the 'embeddings' vector field.
            api_key: Auth token; takes precedence over ``token``.
            token: Alternative auth token, used only when ``api_key`` is empty.
        """
        from pymilvus import DataType

        self.db_path = db_path
        self.collection_name = collection_name
        self.dimension = dimension

        client_kwargs = {'uri': db_path}
        if api_key:
            client_kwargs['token'] = api_key
        elif token:
            client_kwargs['token'] = token

        self.client = MilvusClient(**client_kwargs)

        if not self.client.has_collection(collection_name):
            # Custom schema with dynamic fields enabled so flattened
            # metadata can be stored as extra top-level fields.
            schema = self.client.create_schema(
                auto_id=False,
                enable_dynamic_field=True
            )

            schema.add_field(field_name="element_id", datatype=DataType.VARCHAR, max_length=128, is_primary=True)
            schema.add_field(field_name="embeddings", datatype=DataType.FLOAT_VECTOR, dim=dimension)
            schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=65535)
            schema.add_field(field_name="record_id", datatype=DataType.VARCHAR, max_length=200)

            index_params = self.client.prepare_index_params()
            index_params.add_index(
                field_name="embeddings",
                index_type="AUTOINDEX",
                metric_type="COSINE"
            )

            self.client.create_collection(
                collection_name=collection_name,
                schema=schema,
                index_params=index_params
            )
            print(f"✓ Milvus/Zilliz 集合创建: {collection_name} (自定义 Schema)")
        else:
            print(f"✓ Milvus/Zilliz 集合存在: {collection_name}")

        logger.info(f"Milvus/Zilliz 连接成功: {db_path}")

    def write(self, data: List[Dict[str, Any]], metadata: Dict[str, Any]) -> bool:
        """Insert embedded elements, replacing rows with the same record_id.

        Args:
            data: Elements; only items with a non-empty 'embeddings' key are written.
            metadata: File-level metadata; must contain 'record_id'.

        Returns:
            True on success, False otherwise (including missing record_id).
        """
        try:
            # Delete existing rows with the same record_id first so a
            # re-run replaces a file's records instead of duplicating them.
            record_id = metadata.get('record_id')
            if record_id:
                try:
                    # MilvusClient.delete may return an int or a dict
                    # depending on client/server version.
                    result = self.client.delete(
                        collection_name=self.collection_name,
                        filter=f'record_id == "{record_id}"'
                    )
                    deleted_count = result if isinstance(result, int) else result.get('delete_count', 0) if isinstance(result, dict) else 0
                    if deleted_count > 0:
                        print(f" ✓ 删除现有记录: record_id={record_id}, 删除 {deleted_count} 条")
                        logger.info(f"删除 Milvus 现有记录: record_id={record_id}, 删除 {deleted_count} 条")
                    else:
                        print(f" → 准备写入记录: record_id={record_id}")
                except Exception as e:
                    print(f" ! 删除现有记录失败: {str(e)}")
                    logger.warning(f"删除 Milvus 现有记录失败: record_id={record_id}, {str(e)}")
                    # Best-effort cleanup: a failed delete must not block the insert.
            else:
                print(f" → 没有 record_id")
                logger.warning(f"没有 record_id")
                # BUGFIX: was a bare `return` (None). Return False to honor
                # the declared bool contract and match QdrantDestination.
                return False

            insert_data = []
            for item in data:
                # Element-level metadata attached to this item, if any.
                element_metadata = item.get('metadata', {})

                if 'embeddings' in item and item['embeddings']:
                    element_id = item.get('element_id') or item.get('id') or str(uuid.uuid4())

                    insert_item = {
                        'embeddings': item['embeddings'],
                        'text': item.get('text', ''),
                        'element_id': element_id,
                        'record_id': record_id
                    }

                    # File-level metadata wins over element-level metadata.
                    merged_metadata = {**element_metadata, **metadata}

                    # Flatten metadata into top-level dynamic fields,
                    # skipping the fixed schema fields to avoid clashes.
                    fixed_fields = {'embeddings', 'text', 'element_id', 'record_id', 'created_at', 'metadata'}
                    for key, value in merged_metadata.items():
                        if key not in fixed_fields:
                            if key == 'data_source' and isinstance(value, dict):
                                # data_source dicts are recursively flattened.
                                insert_item.update(_flatten_dict(value, 'data_source', fixed_fields))
                            elif key == 'coordinates' and isinstance(value, list):
                                insert_item[key] = value
                            elif isinstance(value, (dict, list)):
                                # Other containers are dropped; dynamic fields
                                # here are kept scalar.
                                continue
                            else:
                                insert_item[key] = value

                    insert_data.append(insert_item)

            if not insert_data:
                print(f" ! 警告: 没有有效的向量数据")
                return False

            self.client.insert(
                collection_name=self.collection_name,
                data=insert_data
            )
            print(f" ✓ 写入 Milvus: {len(insert_data)} 条")
            logger.info(f"写入 Milvus 成功: {len(insert_data)} 条")
            return True
        except Exception as e:
            print(f" ✗ 写入 Milvus 失败: {str(e)}")
            logger.error(f"写入 Milvus 失败: {str(e)}")
            return False
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
class LocalDestination(Destination):
    """Destination that writes results as JSON files to a local directory."""

    def __init__(self, output_dir: str):
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        print(f"✓ 输出目录: {self.output_dir}")
        logger.info(f"输出目录: {self.output_dir}")

    def write(self, data: List[Dict[str, Any]], metadata: Dict[str, Any]) -> bool:
        """Dump *data* to '<stem>[_<stage>].json' under the output directory."""
        try:
            base_name = Path(metadata.get('filename', 'output')).stem
            # Intermediate results carry a 'stage' marker that is appended
            # to the file name to keep them apart from the final output.
            stage = metadata.get('stage')
            suffix = f"_{stage}" if stage else ""
            output_file = self.output_dir / f"{base_name}{suffix}.json"

            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=2)

            print(f" ✓ 写入本地: {output_file}")
            logger.info(f"写入本地成功: {output_file}")
            return True
        except Exception as e:
            print(f" ✗ 写入本地失败: {str(e)}")
            logger.error(f"写入本地失败: {str(e)}")
            return False
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
class S3Destination(Destination):
    """S3/MinIO data destination.

    Writes each result as '<prefix>/<stem>.json' into the configured bucket.
    """

    def __init__(self, endpoint: str, access_key: str, secret_key: str,
                 bucket: str, prefix: str = '', region: str = 'us-east-1'):
        """Create the S3 client and verify the bucket is reachable and writable.

        Args:
            endpoint: S3-compatible endpoint URL.
            access_key: Access key id.
            secret_key: Secret access key.
            bucket: Target bucket name.
            prefix: Optional key prefix (surrounding '/' stripped).
            region: AWS region name (default 'us-east-1').

        Raises:
            Exception: If the connectivity / write probe fails.
        """
        self.endpoint = endpoint
        self.bucket = bucket
        self.prefix = prefix.strip('/') if prefix else ''

        # Pick a botocore Config per provider: some S3-compatible services
        # require a specific signature version / addressing style.
        # NOTE(review): the first branch hard-codes one specific MinIO
        # endpoint — presumably a known deployment; confirm before reuse.
        if self.endpoint == 'https://textin-minio-api.ai.intsig.net':
            config = Config(signature_version='s3v4')
        elif self.endpoint.endswith('aliyuncs.com'):
            config = Config(signature_version='s3', s3={'addressing_style': 'virtual'})
        elif self.endpoint.endswith('myhuaweicloud.com'):
            config = Config(signature_version='s3', s3={'addressing_style': 'virtual'})
        else:
            config = Config(signature_version='s3v4', s3={'addressing_style': 'virtual'})

        self.client = boto3.client(
            's3',
            endpoint_url=endpoint,
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
            region_name=region,
            config=config
        )

        # Probe: head the bucket, then write and remove an empty marker
        # object to prove write permission up front (fail fast).
        try:
            self.client.head_bucket(Bucket=bucket)
            test_key = f"{self.prefix}/empty.tmp" if self.prefix else f"empty.tmp"
            self.client.put_object(
                Bucket=bucket,
                Key=test_key,
                Body=b''
            )
            try:
                self.client.delete_object(Bucket=bucket, Key=test_key)
            except Exception:
                # Best-effort cleanup of the probe object; a leftover
                # empty.tmp is harmless.
                pass

            print(f"✓ S3 连接成功且可写: {endpoint}/{bucket}")
            logger.info(f"S3 连接成功且可写: {endpoint}/{bucket}")
        except Exception as e:
            print(f"✗ S3 连接或写入测试失败: {str(e)}")
            logger.error(f"S3 连接或写入测试失败: {str(e)}")
            raise

    def write(self, data: List[Dict[str, Any]], metadata: Dict[str, Any]) -> bool:
        """Serialize *data* to JSON and upload it as one object.

        The object key is derived from metadata['filename'] (stem only),
        under the configured prefix.

        Returns:
            True on success, False on failure.
        """
        try:
            filename = metadata.get('filename', 'output')
            base_name = Path(filename).stem
            object_key = f"{self.prefix}/{base_name}.json" if self.prefix else f"{base_name}.json"

            json_data = json.dumps(data, ensure_ascii=False, indent=2)
            json_bytes = json_data.encode('utf-8')

            self.client.put_object(
                Bucket=self.bucket,
                Key=object_key,
                Body=json_bytes,
                ContentType='application/json'
            )

            print(f" ✓ 写入 S3: {self.endpoint}/{self.bucket}/{object_key}")
            logger.info(f"写入 S3 成功: {self.endpoint}/{self.bucket}/{object_key}")
            return True
        except Exception as e:
            print(f" ✗ 写入 S3 失败: {str(e)}")
            logger.error(f"写入 S3 失败: {str(e)}")
            return False
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
class QdrantDestination(Destination):
    """Qdrant vector database destination."""

    def __init__(self, url: str, collection_name: str, dimension: int, api_key: str = None, prefer_grpc: bool = False):
        """Initialize the Qdrant destination.

        Args:
            url: Qdrant endpoint (e.g. 'http://localhost:6333' or 'https://xxx.qdrant.io').
            collection_name: Collection name.
            dimension: Vector dimension.
            api_key: API key (optional, for Qdrant Cloud).
            prefer_grpc: Prefer gRPC over HTTP (default False).

        Raises:
            Exception: Propagates any connection / collection-setup failure.
        """
        self.url = url
        self.collection_name = collection_name
        self.dimension = dimension

        client_kwargs = {'url': url}
        if api_key:
            client_kwargs['api_key'] = api_key
        if prefer_grpc:
            client_kwargs['prefer_grpc'] = True

        self.client = QdrantClient(**client_kwargs)

        # Ensure the collection and the record_id payload index exist.
        try:
            collections = self.client.get_collections()
            collection_exists = any(col.name == collection_name for col in collections.collections)

            if not collection_exists:
                self.client.create_collection(
                    collection_name=collection_name,
                    vectors_config=VectorParams(
                        size=dimension,
                        distance=Distance.COSINE
                    )
                )
                # Index record_id so delete-by-record_id filters work fast.
                try:
                    self.client.create_payload_index(
                        collection_name=collection_name,
                        field_name="record_id",
                        field_schema=PayloadSchemaType.KEYWORD
                    )
                    print(f"✓ Qdrant Collection 创建: {collection_name} (维度: {dimension})")
                except Exception as e:
                    logger.warning(f"创建 record_id 索引失败(可能已存在): {str(e)}")
                    print(f"✓ Qdrant Collection 创建: {collection_name} (维度: {dimension})")
            else:
                print(f"✓ Qdrant Collection 存在: {collection_name}")
                # Make sure the record_id index also exists on pre-existing
                # collections (created before this code added it).
                try:
                    self.client.create_payload_index(
                        collection_name=collection_name,
                        field_name="record_id",
                        field_schema=PayloadSchemaType.KEYWORD
                    )
                except Exception as e:
                    # The index probably exists already; not an error.
                    logger.debug(f"record_id 索引可能已存在: {str(e)}")

            logger.info(f"Qdrant 连接成功: {url}/{collection_name}")
        except Exception as e:
            print(f"✗ Qdrant 连接失败: {str(e)}")
            logger.error(f"Qdrant 连接失败: {str(e)}")
            raise

    def write(self, data: List[Dict[str, Any]], metadata: Dict[str, Any]) -> bool:
        """Upsert embedded elements, replacing points with the same record_id.

        Args:
            data: Elements; only items with a non-empty 'embeddings' key are written.
            metadata: File-level metadata; must contain 'record_id'.

        Returns:
            True on success, False otherwise (including missing record_id).
        """
        try:
            record_id = metadata.get('record_id')
            if record_id:
                try:
                    # Collect ALL existing point ids with this record_id.
                    # BUGFIX: the old code scrolled once with limit=10000 and
                    # silently missed anything beyond the first page; paginate
                    # with next_page_offset until exhausted instead.
                    record_filter = {
                        "must": [
                            {
                                "key": "record_id",
                                "match": {"value": record_id}
                            }
                        ]
                    }
                    point_ids = []
                    next_offset = None
                    while True:
                        page, next_offset = self.client.scroll(
                            collection_name=self.collection_name,
                            scroll_filter=record_filter,
                            limit=10000,
                            offset=next_offset
                        )
                        point_ids.extend(point.id for point in page)
                        if next_offset is None:
                            break

                    if point_ids:
                        self.client.delete(
                            collection_name=self.collection_name,
                            points_selector=point_ids
                        )
                        print(f" ✓ 删除现有记录: record_id={record_id}, 删除 {len(point_ids)} 条")
                        logger.info(f"删除 Qdrant 现有记录: record_id={record_id}, 删除 {len(point_ids)} 条")
                    else:
                        print(f" → 准备写入记录: record_id={record_id}")
                except Exception as e:
                    print(f" ! 删除现有记录失败: {str(e)}")
                    logger.warning(f"删除 Qdrant 现有记录失败: record_id={record_id}, {str(e)}")
                    # Best-effort cleanup: a failed delete must not block the insert.
            else:
                print(f" → 没有 record_id")
                logger.warning(f"没有 record_id")
                return False

            points = []
            for item in data:
                # Element-level metadata attached to this item, if any.
                element_metadata = item.get('metadata', {})

                if 'embeddings' in item and item['embeddings']:
                    element_id = item.get('element_id') or item.get('id') or str(uuid.uuid4())

                    payload = {
                        'text': item.get('text', ''),
                        'record_id': record_id,
                    }

                    # File-level metadata wins over element-level metadata.
                    merged_metadata = {**element_metadata, **metadata}

                    # Copy metadata into the payload, skipping fixed fields.
                    fixed_fields = {'embeddings', 'text', 'element_id', 'record_id', 'created_at', 'metadata'}
                    for key, value in merged_metadata.items():
                        if key not in fixed_fields:
                            if key == 'data_source' and isinstance(value, dict):
                                # data_source dicts are recursively flattened.
                                payload.update(_flatten_dict(value, 'data_source', fixed_fields))
                            else:
                                # Qdrant payloads accept nested JSON, so every
                                # other value (scalar, list or dict) is stored
                                # as-is. (Collapses three branches of the old
                                # code that all did the same assignment.)
                                payload[key] = value

                    # Qdrant point ids must be ints or UUID strings: reuse
                    # element_id when it already is a UUID, otherwise derive
                    # a stable UUID5 from it.
                    try:
                        point_id = str(uuid.UUID(element_id))
                    except (ValueError, TypeError):
                        point_id = str(uuid.uuid5(uuid.NAMESPACE_URL, str(element_id)))

                    point = PointStruct(
                        id=point_id,
                        vector=item['embeddings'],
                        payload=payload
                    )
                    points.append(point)

            if not points:
                print(f" ! 警告: 没有有效的向量数据")
                return False

            # Bulk upsert all points in one call.
            self.client.upsert(
                collection_name=self.collection_name,
                points=points
            )
            print(f" ✓ 写入 Qdrant: {len(points)} 条")
            logger.info(f"写入 Qdrant 成功: {len(points)} 条")
            return True
        except Exception as e:
            print(f" ✗ 写入 Qdrant 失败: {str(e)}")
            logger.error(f"写入 Qdrant 失败: {str(e)}")
            return False
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
# Public API of this module.
__all__ = ['Destination', 'MilvusDestination', 'QdrantDestination', 'LocalDestination', 'S3Destination']
|
|
489
|
-
|