xparse-client 0.2.20__py3-none-any.whl → 0.3.0b2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- example/1_basic_api_usage.py +198 -0
- example/2_async_job.py +210 -0
- example/3_local_workflow.py +300 -0
- example/4_advanced_workflow.py +327 -0
- example/README.md +128 -0
- example/config_example.json +95 -0
- tests/conftest.py +310 -0
- tests/unit/__init__.py +1 -0
- tests/unit/api/__init__.py +1 -0
- tests/unit/api/test_extract.py +232 -0
- tests/unit/api/test_local.py +231 -0
- tests/unit/api/test_parse.py +374 -0
- tests/unit/api/test_pipeline.py +369 -0
- tests/unit/api/test_workflows.py +108 -0
- tests/unit/connectors/test_ftp.py +525 -0
- tests/unit/connectors/test_local_connectors.py +324 -0
- tests/unit/connectors/test_milvus.py +368 -0
- tests/unit/connectors/test_qdrant.py +399 -0
- tests/unit/connectors/test_s3.py +598 -0
- tests/unit/connectors/test_smb.py +442 -0
- tests/unit/connectors/test_utils.py +335 -0
- tests/unit/models/test_local.py +54 -0
- tests/unit/models/test_pipeline_stages.py +144 -0
- tests/unit/models/test_workflows.py +55 -0
- tests/unit/test_base.py +437 -0
- tests/unit/test_client.py +110 -0
- tests/unit/test_config.py +160 -0
- tests/unit/test_exceptions.py +182 -0
- tests/unit/test_http.py +562 -0
- xparse_client/__init__.py +110 -20
- xparse_client/_base.py +179 -0
- xparse_client/_client.py +218 -0
- xparse_client/_config.py +221 -0
- xparse_client/_http.py +350 -0
- xparse_client/api/__init__.py +14 -0
- xparse_client/api/extract.py +109 -0
- xparse_client/api/local.py +188 -0
- xparse_client/api/parse.py +209 -0
- xparse_client/api/pipeline.py +132 -0
- xparse_client/api/workflows.py +204 -0
- xparse_client/connectors/__init__.py +45 -0
- xparse_client/connectors/_utils.py +138 -0
- xparse_client/connectors/destinations/__init__.py +45 -0
- xparse_client/connectors/destinations/base.py +116 -0
- xparse_client/connectors/destinations/local.py +91 -0
- xparse_client/connectors/destinations/milvus.py +229 -0
- xparse_client/connectors/destinations/qdrant.py +238 -0
- xparse_client/connectors/destinations/s3.py +163 -0
- xparse_client/connectors/sources/__init__.py +45 -0
- xparse_client/connectors/sources/base.py +74 -0
- xparse_client/connectors/sources/ftp.py +278 -0
- xparse_client/connectors/sources/local.py +176 -0
- xparse_client/connectors/sources/s3.py +232 -0
- xparse_client/connectors/sources/smb.py +259 -0
- xparse_client/exceptions.py +398 -0
- xparse_client/models/__init__.py +60 -0
- xparse_client/models/chunk.py +39 -0
- xparse_client/models/embed.py +62 -0
- xparse_client/models/extract.py +41 -0
- xparse_client/models/local.py +38 -0
- xparse_client/models/parse.py +136 -0
- xparse_client/models/pipeline.py +132 -0
- xparse_client/models/workflows.py +74 -0
- xparse_client-0.3.0b2.dist-info/METADATA +1075 -0
- xparse_client-0.3.0b2.dist-info/RECORD +68 -0
- {xparse_client-0.2.20.dist-info → xparse_client-0.3.0b2.dist-info}/WHEEL +1 -1
- {xparse_client-0.2.20.dist-info → xparse_client-0.3.0b2.dist-info}/licenses/LICENSE +1 -1
- {xparse_client-0.2.20.dist-info → xparse_client-0.3.0b2.dist-info}/top_level.txt +2 -0
- xparse_client/pipeline/__init__.py +0 -3
- xparse_client/pipeline/config.py +0 -163
- xparse_client/pipeline/destinations.py +0 -489
- xparse_client/pipeline/pipeline.py +0 -860
- xparse_client/pipeline/sources.py +0 -583
- xparse_client-0.2.20.dist-info/METADATA +0 -1050
- xparse_client-0.2.20.dist-info/RECORD +0 -11
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
"""连接器公共工具函数"""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
from datetime import datetime, timezone
|
|
8
|
+
from fnmatch import fnmatch
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)


def normalize_wildcard_patterns(pattern: list[str] | None) -> list[str] | None:
    """Normalize a list of wildcard patterns.

    Surrounding whitespace is stripped from every pattern; falsy entries
    (``None``, ``""``) and whitespace-only strings are dropped.

    Args:
        pattern: Wildcard patterns such as ``["*.pdf", "*.docx"]``.

    Returns:
        The cleaned pattern list, or ``None`` meaning "match every file".
        ``None`` is returned when ``pattern`` is ``None``, when nothing is left
        after cleaning, or when the cleaned list contains ``"*"``.

    Raises:
        ValueError: ``pattern`` is not a list, or one of its non-empty entries
            is not a string.
    """
    if pattern is None:
        return None

    if not isinstance(pattern, list):
        raise ValueError(f"pattern 必须是列表类型,当前类型: {type(pattern).__name__}")

    normalized: list[str] = []
    for p in pattern:
        # Falsy entries (None, "") were always skipped silently; keep that.
        if not p:
            continue
        # A truthy non-string used to crash with AttributeError on .strip();
        # report it as the documented ValueError instead.
        if not isinstance(p, str):
            raise ValueError(f"pattern 中的元素必须是字符串,当前类型: {type(p).__name__}")
        stripped = p.strip()
        if stripped:
            normalized.append(stripped)

    # An empty list or a bare "*" both mean "match everything".
    if not normalized or "*" in normalized:
        return None

    return normalized
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def match_file_pattern(file_path: str, patterns: list[str] | None) -> bool:
    """Tell whether the file-name part of ``file_path`` matches any pattern.

    Only the component after the last ``/`` is compared; directory parts are
    ignored and ``\\`` is not treated as a separator. Matching uses
    ``fnmatch.fnmatch``.

    Args:
        file_path: File path (or object key) to test.
        patterns: Wildcard patterns; ``None`` accepts every file.

    Returns:
        ``True`` if ``patterns`` is ``None`` or at least one pattern matches
        the file name, otherwise ``False``.
    """
    if patterns is None:
        return True

    _, _, name = file_path.rpartition("/")
    for candidate in patterns:
        if fnmatch(name, candidate):
            return True
    return False
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def to_millis_timestamp(timestamp: float | None) -> str:
    """Convert a Unix timestamp to a millisecond timestamp string.

    Args:
        timestamp: Unix timestamp in seconds or milliseconds. Values greater
            than ``1e12`` are treated as already being in milliseconds.

    Returns:
        The millisecond timestamp as a decimal string (sub-millisecond parts
        are truncated), or ``""`` when ``timestamp`` is ``None``.
    """
    if timestamp is None:
        return ""

    # Already in milliseconds (greater than 1e12): return unchanged.
    if timestamp > 1e12:
        return str(int(timestamp))

    # Seconds -> milliseconds. Most decimal fractions are not exactly
    # representable in binary floating point (e.g. 1.001 * 1000 evaluates to
    # 1000.9999999999999), so plain int() would drop a whole millisecond.
    # Rounding to microsecond precision first removes that noise while still
    # truncating genuine sub-millisecond parts.
    return str(int(round(timestamp * 1000, 3)))
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def get_current_millis_timestamp() -> str:
    """Return the current time as a Unix timestamp in milliseconds, as a string."""
    now = datetime.now(timezone.utc)
    millis = int(now.timestamp() * 1000)
    return str(millis)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def flatten_dict(
    data: dict[str, Any],
    prefix: str = "",
    exclude_fields: set | None = None,
) -> dict[str, Any]:
    """Recursively flatten a nested dictionary.

    Used to turn nested metadata structures into a flat vector-database
    payload. Nested keys are joined with ``_`` (and prefixed with ``prefix``
    when it is non-empty). Lists are serialized to JSON strings
    (``ensure_ascii=False``); every other non-dict value is copied as-is.

    Args:
        data: Dictionary to flatten.
        prefix: Prefix for the generated keys.
        exclude_fields: Flattened keys (prefix included) to drop; a dropped
            dict key removes its whole subtree.

    Returns:
        The flattened dictionary.

    Example:
        >>> flatten_dict({"a": {"b": 1, "c": 2}, "d": 3}, "prefix")
        {'prefix_a_b': 1, 'prefix_a_c': 2, 'prefix_d': 3}
    """
    excluded = set() if exclude_fields is None else exclude_fields

    flat: dict[str, Any] = {}
    for key, value in data.items():
        full_key = key if not prefix else f"{prefix}_{key}"
        if full_key in excluded:
            continue

        if isinstance(value, dict):
            # Descend, carrying the accumulated key as the new prefix.
            flat.update(flatten_dict(value, full_key, excluded))
        elif isinstance(value, list):
            flat[full_key] = json.dumps(value, ensure_ascii=False)
        else:
            flat[full_key] = value

    return flat
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
# Public helpers exported by this module.
__all__ = [
    "normalize_wildcard_patterns", "match_file_pattern", "to_millis_timestamp",
    "get_current_millis_timestamp", "flatten_dict",
]
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""目的地连接器
|
|
2
|
+
|
|
3
|
+
提供将处理结果写入各种目的地的能力。
|
|
4
|
+
|
|
5
|
+
可用的目的地:
|
|
6
|
+
- LocalDestination: 本地文件系统
|
|
7
|
+
- S3Destination: S3/MinIO 对象存储(需要 pip install xparse-client[s3])
|
|
8
|
+
- MilvusDestination: Milvus/Zilliz 向量数据库(需要 pip install xparse-client[milvus])
|
|
9
|
+
- QdrantDestination: Qdrant 向量数据库(需要 pip install xparse-client[qdrant])
|
|
10
|
+
|
|
11
|
+
Example:
|
|
12
|
+
>>> from xparse_client.connectors.destinations import LocalDestination
|
|
13
|
+
>>> dest = LocalDestination(output_dir="./output")
|
|
14
|
+
>>> dest.write(elements, {"filename": "doc.pdf", "record_id": "xxx"})
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from .base import Destination, VectorDestinationMixin
|
|
18
|
+
from .local import LocalDestination
|
|
19
|
+
|
|
20
|
+
# Public names of this package. S3Destination, MilvusDestination and
# QdrantDestination are not imported eagerly (to avoid hard dependencies);
# they are resolved lazily by the module-level __getattr__. Because they are
# listed here, ``from ... import *`` imports each of those submodules.
__all__ = [
    "Destination", "VectorDestinationMixin", "LocalDestination",
    "S3Destination", "MilvusDestination", "QdrantDestination",
]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def __getattr__(name: str):
    """Lazily resolve the optional destination classes (PEP 562 module hook).

    ``S3Destination``, ``MilvusDestination`` and ``QdrantDestination`` live in
    submodules that are imported only when one of these names is accessed.

    Raises:
        AttributeError: ``name`` is not one of the lazily provided classes.
    """
    if name == "S3Destination":
        from .s3 import S3Destination as resolved
    elif name == "MilvusDestination":
        from .milvus import MilvusDestination as resolved
    elif name == "QdrantDestination":
        from .qdrant import QdrantDestination as resolved
    else:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    return resolved
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
"""目的地抽象基类"""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)


class Destination(ABC):
    """Abstract base class for all destinations.

    Concrete destinations must implement :meth:`write`. :meth:`close` is a
    no-op hook that subclasses may override to release resources; instances
    are context managers that call :meth:`close` on exit.

    Example:
        >>> class MyDestination(Destination):
        ...     def write(self, data, metadata):
        ...         # write logic goes here
        ...         return True
    """

    @abstractmethod
    def write(self, data: list[dict[str, Any]], metadata: dict[str, Any]) -> bool:
        """Write data to the destination.

        Args:
            data: Records to write (a list of elements or structured data).
            metadata: Metadata such as filename, record_id, processed_at.

        Returns:
            Whether the write succeeded.

        Raises:
            DestinationError: The write failed.
        """
        raise NotImplementedError

    def close(self) -> None:  # noqa: B027
        """Release resources held by this destination.

        Does nothing by default; subclasses override it when needed.
        """

    def __enter__(self) -> Destination:
        """Enter the context manager and return the destination itself."""
        return self

    def __exit__(self, *exc_info: Any) -> None:
        """Call :meth:`close` on exit; exceptions are not suppressed."""
        self.close()

    def __repr__(self) -> str:
        return f"<{type(self).__name__}>"
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class VectorDestinationMixin:
    """Shared payload-building logic for vector database destinations.

    Provides the payload handling shared by the Milvus and Qdrant destinations.
    """

    @staticmethod
    def prepare_payload(
        item: dict[str, Any],
        metadata: dict[str, Any],
        record_id: str,
        fixed_fields: set | None = None,
    ) -> dict[str, Any]:
        """Build the payload for one vector-database record.

        Element-level metadata (``item["metadata"]``) and file-level
        ``metadata`` are merged, file-level values winning on key conflicts.
        For each merged key:

        * keys in ``fixed_fields`` are skipped;
        * a dict under ``data_source`` is flattened into ``data_source_*`` keys;
        * a list under ``coordinates`` is kept as-is;
        * any other dict or list value is dropped;
        * everything else is copied unchanged.

        The payload always contains ``text`` (``item["text"]``, default ``""``)
        and ``record_id``.

        Args:
            item: One record (with embeddings, text, metadata, ...).
            metadata: File-level metadata; ``None`` is treated as empty.
            record_id: Record ID.
            fixed_fields: Keys to exclude; defaults to
                ``{"embeddings", "text", "element_id", "record_id", "metadata"}``.

        Returns:
            The payload dictionary.
        """
        if fixed_fields is None:
            fixed_fields = {"embeddings", "text", "element_id", "record_id", "metadata"}

        # ``or {}`` guards against an explicit None on either level, which would
        # otherwise make the ``**`` merge raise TypeError.
        element_metadata = item.get("metadata") or {}
        merged_metadata = {**element_metadata, **(metadata or {})}

        payload: dict[str, Any] = {
            "text": item.get("text", ""),
            "record_id": record_id,
        }

        for key, value in merged_metadata.items():
            if key in fixed_fields:
                continue

            if key == "data_source" and isinstance(value, dict):
                # Imported only when there is a data_source dict to flatten.
                from .._utils import flatten_dict

                payload.update(flatten_dict(value, "data_source", fixed_fields))
            elif key == "coordinates" and isinstance(value, list):
                payload[key] = value
            elif isinstance(value, (dict, list)):
                # Other nested structures are not stored in the payload.
                continue
            else:
                payload[key] = value

        return payload
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
# Names exported by this module.
__all__ = [
    "Destination",
    "VectorDestinationMixin",
]
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
"""本地文件系统目的地"""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
from ...exceptions import DestinationError
|
|
11
|
+
from .base import Destination
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)


class LocalDestination(Destination):
    """Local file-system destination.

    Writes processing results to JSON files.

    Attributes:
        output_dir: Output directory (a ``Path``).

    Example:
        >>> dest = LocalDestination(output_dir="./output")
        >>> dest.write(elements, {"filename": "doc.pdf"})
    """

    def __init__(self, output_dir: str) -> None:
        """Initialize the destination, creating ``output_dir`` if needed.

        Args:
            output_dir: Output directory path.

        Raises:
            DestinationError: The directory could not be created.
        """
        self.output_dir = Path(output_dir)

        try:
            self.output_dir.mkdir(parents=True, exist_ok=True)
            logger.info(f"本地输出目录: {self.output_dir}")
        except Exception as e:
            raise DestinationError(
                f"创建输出目录失败: {e}",
                connector_type="local",
                operation="init",
            ) from e

    def write(self, data: list[dict[str, Any]], metadata: dict[str, Any]) -> bool:
        """Write ``data`` to a pretty-printed UTF-8 JSON file.

        The file is ``<output_dir>/<stem>.json``, or
        ``<output_dir>/<stem>_<stage>.json`` when ``metadata["stage"]`` is set.
        ``<stem>`` is the stem of ``metadata["filename"]``, or ``"output"``
        when the filename is missing or empty. An existing file of the same
        name is overwritten; note that inputs sharing a stem (e.g.
        ``a/doc.pdf`` and ``b/doc.docx``) map to the same output file.

        Args:
            data: Records to write.
            metadata: Metadata with ``filename`` and optional ``stage``.

        Returns:
            ``True`` on success.

        Raises:
            DestinationError: Building the path, serializing or writing failed.
        """
        try:
            # ``or`` also covers an explicit None (TypeError in Path()) and an
            # empty string (which would produce a hidden ".json" file).
            filename = metadata.get("filename") or "output"
            base_name = Path(filename).stem
            stage = metadata.get("stage")  # distinguishes intermediate results

            if stage:
                output_file = self.output_dir / f"{base_name}_{stage}.json"
            else:
                output_file = self.output_dir / f"{base_name}.json"

            # Serialize before touching the file so that non-serializable data
            # fails without truncating an existing file or leaving partial output.
            content = json.dumps(data, ensure_ascii=False, indent=2)
            output_file.write_text(content, encoding="utf-8")

            logger.info(f"写入本地文件: {output_file}")
            return True

        except Exception as e:
            raise DestinationError(
                f"写入本地文件失败: {e}",
                connector_type="local",
                operation="write",
            ) from e

    def __repr__(self) -> str:
        return f"<LocalDestination output_dir={self.output_dir}>"


__all__ = ["LocalDestination"]
|
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
"""Milvus/Zilliz 向量数据库目的地(懒加载 pymilvus)"""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import uuid
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from ...exceptions import DestinationError
|
|
10
|
+
from .base import Destination, VectorDestinationMixin
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)


def _get_pymilvus():
    """Import pymilvus on demand.

    Returns:
        The ``(MilvusClient, DataType)`` pair from ``pymilvus``.

    Raises:
        ImportError: pymilvus (or one of those names) is unavailable; the
            message points at the ``xparse-client[milvus]`` extra.
    """
    try:
        from pymilvus import DataType, MilvusClient
    except ImportError as exc:
        raise ImportError(
            "使用 MilvusDestination 需要安装 pymilvus: pip install xparse-client[milvus]"
        ) from exc
    return MilvusClient, DataType
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class MilvusDestination(Destination, VectorDestinationMixin):
    """Milvus / Zilliz vector database destination.

    Supports local Milvus (Milvus Lite) and Zilliz Cloud.

    Attributes:
        db_path: Database path or Zilliz Cloud URL.
        collection_name: Collection name.
        dimension: Vector dimension.
        client: The ``pymilvus.MilvusClient`` instance.

    Example:
        >>> # Milvus Lite (local)
        >>> dest = MilvusDestination(
        ...     db_path="./milvus.db",
        ...     collection_name="documents",
        ...     dimension=1024,
        ... )
        >>>
        >>> # Zilliz Cloud
        >>> dest = MilvusDestination(
        ...     db_path="https://xxx.zillizcloud.com",
        ...     collection_name="documents",
        ...     dimension=1024,
        ...     token="your-token",
        ... )
    """

    def __init__(
        self,
        db_path: str,
        collection_name: str,
        dimension: int,
        api_key: str | None = None,
        token: str | None = None,
    ) -> None:
        """Connect to Milvus and create the collection if it does not exist.

        A newly created collection gets the fields ``element_id`` (VARCHAR,
        max_length 128, primary key), ``embeddings`` (FLOAT_VECTOR of
        ``dimension``), ``text`` (VARCHAR, max_length 65535) and ``record_id``
        (VARCHAR, max_length 200), with dynamic fields enabled and an
        ``AUTOINDEX`` / ``COSINE`` index on ``embeddings``. An existing
        collection is used as-is.

        Args:
            db_path: Local database path or Zilliz Cloud URL (passed as ``uri``).
            collection_name: Collection name.
            dimension: Vector dimension (used when creating the collection).
            api_key: Passed to the client as ``token``; wins over ``token``.
            token: Passed to the client as ``token`` when ``api_key`` is unset.

        Raises:
            ImportError: pymilvus is missing, or the client raised ImportError.
            DestinationError: Connecting or creating the collection failed.
        """
        MilvusClient, DataType = _get_pymilvus()

        self.db_path = db_path
        self.collection_name = collection_name
        self.dimension = dimension

        client_kwargs = {"uri": db_path}
        if api_key:
            client_kwargs["token"] = api_key
        elif token:
            client_kwargs["token"] = token

        try:
            self.client = MilvusClient(**client_kwargs)

            # Create the collection if it does not exist yet.
            if not self.client.has_collection(collection_name):
                schema = self.client.create_schema(
                    auto_id=False,
                    enable_dynamic_field=True,
                )
                schema.add_field(
                    field_name="element_id",
                    datatype=DataType.VARCHAR,
                    max_length=128,
                    is_primary=True,
                )
                schema.add_field(
                    field_name="embeddings",
                    datatype=DataType.FLOAT_VECTOR,
                    dim=dimension,
                )
                schema.add_field(
                    field_name="text",
                    datatype=DataType.VARCHAR,
                    max_length=65535,
                )
                schema.add_field(
                    field_name="record_id",
                    datatype=DataType.VARCHAR,
                    max_length=200,
                )

                index_params = self.client.prepare_index_params()
                index_params.add_index(
                    field_name="embeddings",
                    index_type="AUTOINDEX",
                    metric_type="COSINE",
                )

                self.client.create_collection(
                    collection_name=collection_name,
                    schema=schema,
                    index_params=index_params,
                )
                logger.info(f"Milvus Collection 创建: {collection_name}")
            else:
                logger.info(f"Milvus Collection 已存在: {collection_name}")

        except ImportError:
            # Propagate import problems unchanged instead of wrapping them.
            raise
        except Exception as e:
            raise DestinationError(
                f"Milvus 连接失败: {e}",
                connector_type="milvus",
                operation="connect",
                details={"db_path": db_path, "collection_name": collection_name},
            ) from e

    @staticmethod
    def _quote_filter_string(value: Any) -> str:
        """Return ``value`` as a double-quoted Milvus filter string literal.

        Backslashes and double quotes are escaped so the value cannot end the
        literal early and change the meaning of the filter expression.
        """
        escaped = str(value).replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}"'

    def write(self, data: list[dict[str, Any]], metadata: dict[str, Any]) -> bool:
        """Replace the records of ``metadata["record_id"]`` with ``data``.

        Existing rows with the same ``record_id`` are deleted first (a failed
        delete is only logged), then every item with non-empty ``embeddings``
        is inserted. Items without ``element_id``/``id`` get a random UUID.

        Args:
            data: Elements carrying ``embeddings``.
            metadata: Metadata; must contain ``record_id``.

        Returns:
            ``True`` if rows were inserted; ``False`` if ``record_id`` is
            missing or no item had embeddings.

        Raises:
            DestinationError: The insert failed.
        """
        record_id = metadata.get("record_id")
        if not record_id:
            logger.warning("没有 record_id,跳过写入")
            return False

        try:
            # Remove old rows for this record. NOTE: delete and insert are
            # separate operations; if the insert below fails, the old rows are
            # already gone.
            try:
                result = self.client.delete(
                    collection_name=self.collection_name,
                    # record_id is escaped: a raw quote in it would otherwise
                    # break or rewrite the filter expression.
                    filter=f"record_id == {self._quote_filter_string(record_id)}",
                )
                if isinstance(result, int):
                    deleted = result
                elif isinstance(result, dict):
                    deleted = result.get("delete_count", 0)
                else:
                    deleted = 0
                if deleted > 0:
                    logger.info(f"删除 Milvus 旧记录: record_id={record_id}, 数量={deleted}")
            except Exception as e:
                logger.warning(f"删除旧记录失败: {e}")

            # Build the rows to insert.
            fixed_fields = {"embeddings", "text", "element_id", "record_id", "metadata"}
            insert_data = []

            for item in data:
                if "embeddings" not in item or not item["embeddings"]:
                    continue

                element_id = (
                    item.get("element_id") or item.get("id") or str(uuid.uuid4())
                )

                insert_item = {
                    "element_id": element_id,
                    "embeddings": item["embeddings"],
                    "text": item.get("text", ""),
                    "record_id": record_id,
                }

                # Extra payload fields go into dynamic fields; fixed columns win.
                payload = self.prepare_payload(item, metadata, record_id, fixed_fields)
                for k, v in payload.items():
                    if k not in insert_item:
                        insert_item[k] = v

                insert_data.append(insert_item)

            if not insert_data:
                logger.warning("没有有效的向量数据")
                return False

            self.client.insert(
                collection_name=self.collection_name,
                data=insert_data,
            )
            logger.info(f"写入 Milvus: {len(insert_data)} 条")
            return True

        except Exception as e:
            raise DestinationError(
                f"写入 Milvus 失败: {e}",
                connector_type="milvus",
                operation="write",
            ) from e

    def close(self) -> None:
        """Close the underlying Milvus client, if one was created.

        Errors raised while closing are logged and not propagated, so that
        leaving a ``with`` block cannot mask an earlier exception.
        """
        client = getattr(self, "client", None)
        close_client = getattr(client, "close", None)
        if callable(close_client):
            try:
                close_client()
            except Exception as e:
                logger.warning(f"关闭 Milvus 连接失败: {e}")

    def __repr__(self) -> str:
        return f"<MilvusDestination db_path={self.db_path} collection={self.collection_name}>"


__all__ = ["MilvusDestination"]
|