xparse-client 0.2.20__py3-none-any.whl → 0.3.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. example/1_basic_api_usage.py +198 -0
  2. example/2_async_job.py +210 -0
  3. example/3_local_workflow.py +300 -0
  4. example/4_advanced_workflow.py +327 -0
  5. example/README.md +128 -0
  6. example/config_example.json +95 -0
  7. tests/conftest.py +310 -0
  8. tests/unit/__init__.py +1 -0
  9. tests/unit/api/__init__.py +1 -0
  10. tests/unit/api/test_extract.py +232 -0
  11. tests/unit/api/test_local.py +231 -0
  12. tests/unit/api/test_parse.py +374 -0
  13. tests/unit/api/test_pipeline.py +369 -0
  14. tests/unit/api/test_workflows.py +108 -0
  15. tests/unit/connectors/test_ftp.py +525 -0
  16. tests/unit/connectors/test_local_connectors.py +324 -0
  17. tests/unit/connectors/test_milvus.py +368 -0
  18. tests/unit/connectors/test_qdrant.py +399 -0
  19. tests/unit/connectors/test_s3.py +598 -0
  20. tests/unit/connectors/test_smb.py +442 -0
  21. tests/unit/connectors/test_utils.py +335 -0
  22. tests/unit/models/test_local.py +54 -0
  23. tests/unit/models/test_pipeline_stages.py +144 -0
  24. tests/unit/models/test_workflows.py +55 -0
  25. tests/unit/test_base.py +437 -0
  26. tests/unit/test_client.py +110 -0
  27. tests/unit/test_config.py +160 -0
  28. tests/unit/test_exceptions.py +182 -0
  29. tests/unit/test_http.py +562 -0
  30. xparse_client/__init__.py +110 -20
  31. xparse_client/_base.py +179 -0
  32. xparse_client/_client.py +218 -0
  33. xparse_client/_config.py +221 -0
  34. xparse_client/_http.py +350 -0
  35. xparse_client/api/__init__.py +14 -0
  36. xparse_client/api/extract.py +109 -0
  37. xparse_client/api/local.py +185 -0
  38. xparse_client/api/parse.py +209 -0
  39. xparse_client/api/pipeline.py +132 -0
  40. xparse_client/api/workflows.py +204 -0
  41. xparse_client/connectors/__init__.py +45 -0
  42. xparse_client/connectors/_utils.py +138 -0
  43. xparse_client/connectors/destinations/__init__.py +45 -0
  44. xparse_client/connectors/destinations/base.py +116 -0
  45. xparse_client/connectors/destinations/local.py +91 -0
  46. xparse_client/connectors/destinations/milvus.py +229 -0
  47. xparse_client/connectors/destinations/qdrant.py +238 -0
  48. xparse_client/connectors/destinations/s3.py +163 -0
  49. xparse_client/connectors/sources/__init__.py +45 -0
  50. xparse_client/connectors/sources/base.py +74 -0
  51. xparse_client/connectors/sources/ftp.py +278 -0
  52. xparse_client/connectors/sources/local.py +176 -0
  53. xparse_client/connectors/sources/s3.py +232 -0
  54. xparse_client/connectors/sources/smb.py +259 -0
  55. xparse_client/exceptions.py +398 -0
  56. xparse_client/models/__init__.py +60 -0
  57. xparse_client/models/chunk.py +39 -0
  58. xparse_client/models/embed.py +62 -0
  59. xparse_client/models/extract.py +41 -0
  60. xparse_client/models/local.py +38 -0
  61. xparse_client/models/parse.py +136 -0
  62. xparse_client/models/pipeline.py +132 -0
  63. xparse_client/models/workflows.py +74 -0
  64. xparse_client-0.3.0b1.dist-info/METADATA +1075 -0
  65. xparse_client-0.3.0b1.dist-info/RECORD +68 -0
  66. {xparse_client-0.2.20.dist-info → xparse_client-0.3.0b1.dist-info}/WHEEL +1 -1
  67. {xparse_client-0.2.20.dist-info → xparse_client-0.3.0b1.dist-info}/licenses/LICENSE +1 -1
  68. {xparse_client-0.2.20.dist-info → xparse_client-0.3.0b1.dist-info}/top_level.txt +2 -0
  69. xparse_client/pipeline/__init__.py +0 -3
  70. xparse_client/pipeline/config.py +0 -163
  71. xparse_client/pipeline/destinations.py +0 -489
  72. xparse_client/pipeline/pipeline.py +0 -860
  73. xparse_client/pipeline/sources.py +0 -583
  74. xparse_client-0.2.20.dist-info/METADATA +0 -1050
  75. xparse_client-0.2.20.dist-info/RECORD +0 -11
@@ -0,0 +1,238 @@
1
+ """Qdrant 向量数据库目的地(懒加载 qdrant-client)"""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import uuid
7
+ from typing import Any
8
+
9
+ from ...exceptions import DestinationError
10
+ from .base import Destination, VectorDestinationMixin
11
+
12
logger = logging.getLogger(__name__)


def _get_qdrant():
    """Import qdrant-client on first use instead of at module import time.

    Returns:
        ``(QdrantClient, Distance, PayloadSchemaType, PointStruct, VectorParams)``.

    Raises:
        ImportError: qdrant-client is not installed; the message carries the
            install hint and the original error is chained as ``__cause__``.
    """
    try:
        from qdrant_client import QdrantClient
        from qdrant_client.models import (
            Distance,
            PayloadSchemaType,
            PointStruct,
            VectorParams,
        )
    except ImportError as missing:
        raise ImportError(
            "使用 QdrantDestination 需要安装 qdrant-client: pip install xparse-client[qdrant]"
        ) from missing
    return QdrantClient, Distance, PayloadSchemaType, PointStruct, VectorParams
31
+
32
+
33
class QdrantDestination(Destination, VectorDestinationMixin):
    """Qdrant vector-database destination.

    Supports self-hosted Qdrant as well as Qdrant Cloud. Construction connects
    to the server, creates the collection (cosine distance) when it does not
    exist yet, and makes a best-effort attempt to create a keyword payload
    index on ``record_id``.

    Attributes:
        url: Qdrant service URL.
        collection_name: Target collection name.
        dimension: Vector dimension used when the collection is created.

    Example:
        >>> # Local Qdrant
        >>> dest = QdrantDestination(
        ...     url="http://localhost:6333",
        ...     collection_name="documents",
        ...     dimension=1024,
        ... )
        >>>
        >>> # Qdrant Cloud
        >>> dest = QdrantDestination(
        ...     url="https://xxx.qdrant.io",
        ...     collection_name="documents",
        ...     dimension=1024,
        ...     api_key="your-api-key",
        ... )
    """

    def __init__(
        self,
        url: str,
        collection_name: str,
        dimension: int,
        api_key: str | None = None,
        prefer_grpc: bool = False,
    ) -> None:
        """Connect to Qdrant and make sure the target collection exists.

        Args:
            url: Qdrant service URL.
            collection_name: Collection name.
            dimension: Vector dimension. Only used when the collection has to
                be created; an existing collection is not checked against it.
            api_key: API key (for Qdrant Cloud).
            prefer_grpc: Prefer gRPC over HTTP. Defaults to False.

        Raises:
            ImportError: qdrant-client is not installed.
            DestinationError: Connecting or preparing the collection failed.
        """
        (
            QdrantClient,
            Distance,
            PayloadSchemaType,
            PointStruct,
            VectorParams,
        ) = _get_qdrant()

        self.url = url
        self.collection_name = collection_name
        self.dimension = dimension
        self._PointStruct = PointStruct

        client_kwargs: dict[str, Any] = {"url": url}
        if api_key:
            client_kwargs["api_key"] = api_key
        if prefer_grpc:
            client_kwargs["prefer_grpc"] = True

        try:
            self.client = QdrantClient(**client_kwargs)

            existing = {col.name for col in self.client.get_collections().collections}
            if collection_name not in existing:
                self.client.create_collection(
                    collection_name=collection_name,
                    vectors_config=VectorParams(size=dimension, distance=Distance.COSINE),
                )
                # A brand-new collection has no index yet, so a failure here is
                # unexpected and worth a warning.
                self._ensure_record_id_index(
                    PayloadSchemaType.KEYWORD, warn_on_failure=True
                )
                logger.info(f"Qdrant Collection 创建: {collection_name}")
            else:
                logger.info(f"Qdrant Collection 已存在: {collection_name}")
                # The index is most likely already present; a failure is
                # expected and only logged at debug level.
                self._ensure_record_id_index(
                    PayloadSchemaType.KEYWORD, warn_on_failure=False
                )

        except ImportError:
            raise
        except Exception as e:
            raise DestinationError(
                f"Qdrant 连接失败: {e}",
                connector_type="qdrant",
                operation="connect",
                details={"url": url, "collection_name": collection_name},
            ) from e

    def _ensure_record_id_index(self, keyword_schema: Any, *, warn_on_failure: bool) -> None:
        """Best-effort creation of a keyword payload index on ``record_id``.

        Errors never propagate. They are logged at warning level when
        *warn_on_failure* is true, otherwise at debug level.
        """
        try:
            self.client.create_payload_index(
                collection_name=self.collection_name,
                field_name="record_id",
                field_schema=keyword_schema,
            )
        except Exception as e:
            if warn_on_failure:
                logger.warning(f"创建 record_id 索引失败: {e}")
            else:
                logger.debug(f"record_id 索引未创建(可能已存在): {e}")

    def _delete_existing_points(self, record_id: Any) -> None:
        """Delete every point whose payload ``record_id`` equals *record_id*.

        A filter-based delete removes all matches in one request, with no cap
        on the number of points. A typed Filter model is used so the request
        also works over gRPC. Failures are logged and not raised (best-effort),
        matching the previous behaviour.
        """
        try:
            from qdrant_client.models import (
                FieldCondition,
                Filter,
                FilterSelector,
                MatchValue,
            )

            self.client.delete(
                collection_name=self.collection_name,
                points_selector=FilterSelector(
                    filter=Filter(
                        must=[
                            FieldCondition(
                                key="record_id", match=MatchValue(value=record_id)
                            )
                        ]
                    )
                ),
            )
            logger.info(f"删除 Qdrant 旧记录: record_id={record_id}")
        except Exception as e:
            logger.warning(f"删除旧记录失败: {e}")

    @staticmethod
    def _to_point_id(element_id: Any) -> str:
        """Map an element id to a valid Qdrant point id (a UUID string).

        UUID-formatted ids are normalised to the canonical dashed form. Any
        other value, including integers, is mapped deterministically with
        uuid5, so re-writing the same element hits the same point.
        """
        text = str(element_id)
        try:
            return str(uuid.UUID(text))
        except ValueError:
            return str(uuid.uuid5(uuid.NAMESPACE_URL, text))

    def write(self, data: list[dict[str, Any]], metadata: dict[str, Any]) -> bool:
        """Replace the points of one record with freshly embedded elements.

        Elements without a non-empty ``embeddings`` value are skipped. The
        record's existing points are deleted only after at least one valid
        point has been built. A write with nothing to store therefore leaves
        the previously stored data untouched.

        Args:
            data: Elements carrying ``embeddings``. ``element_id`` (or ``id``)
                determines the point id; a random UUID is used if both are
                missing.
            metadata: Record metadata; must contain ``record_id``.

        Returns:
            True when points were upserted. False when ``record_id`` is missing
            or no element had an embedding.

        Raises:
            DestinationError: Building the points or the upsert failed.
        """
        record_id = metadata.get("record_id")
        if not record_id:
            logger.warning("没有 record_id,跳过写入")
            return False

        try:
            fixed_fields = {"embeddings", "text", "element_id", "record_id", "metadata"}
            points = []

            for item in data:
                vector = item.get("embeddings")
                if not vector:
                    continue

                element_id = item.get("element_id") or item.get("id") or str(uuid.uuid4())
                points.append(
                    self._PointStruct(
                        id=self._to_point_id(element_id),
                        vector=vector,
                        payload=self.prepare_payload(item, metadata, record_id, fixed_fields),
                    )
                )

            if not points:
                logger.warning("没有有效的向量数据")
                return False

            # Remove the record's previous points before writing the new set,
            # so elements that disappeared from the document do not linger.
            self._delete_existing_points(record_id)

            self.client.upsert(
                collection_name=self.collection_name,
                points=points,
            )
            logger.info(f"写入 Qdrant: {len(points)} 条")
            return True

        except Exception as e:
            raise DestinationError(
                f"写入 Qdrant 失败: {e}",
                connector_type="qdrant",
                operation="write",
            ) from e

    def __repr__(self) -> str:
        return f"<QdrantDestination url={self.url} collection={self.collection_name}>"


__all__ = ["QdrantDestination"]
@@ -0,0 +1,163 @@
1
+ """S3/MinIO 目的地(懒加载 boto3)"""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import logging
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ from ...exceptions import DestinationError
11
+ from .base import Destination
12
+
13
logger = logging.getLogger(__name__)


def _get_boto3():
    """Import boto3 on first use instead of at module import time.

    Returns:
        ``(boto3, Config)``: the boto3 module and ``botocore.config.Config``.

    Raises:
        ImportError: boto3/botocore is not installed; the message carries the
            install hint and the original error is chained as ``__cause__``.
    """
    try:
        import boto3
        from botocore.config import Config
    except ImportError as missing:
        raise ImportError(
            "使用 S3Destination 需要安装 boto3: pip install xparse-client[s3]"
        ) from missing
    return boto3, Config
27
+
28
+
29
class S3Destination(Destination):
    """S3/MinIO destination.

    Writes processing results as JSON objects to S3 or an S3-compatible store.
    Each call to :meth:`write` stores one object named after the stem of
    ``metadata["filename"]`` (``doc.pdf`` -> ``<prefix>/doc.json``). Inputs
    that share a stem therefore overwrite each other.

    Attributes:
        endpoint: S3 endpoint URL.
        bucket: Bucket name.
        prefix: Object key prefix, with leading/trailing slashes stripped.

    Example:
        >>> dest = S3Destination(
        ...     endpoint="https://s3.amazonaws.com",
        ...     access_key="...",
        ...     secret_key="...",
        ...     bucket="my-bucket",
        ...     prefix="output/",
        ... )
        >>> dest.write(elements, {"filename": "doc.pdf"})
    """

    def __init__(
        self,
        endpoint: str,
        access_key: str,
        secret_key: str,
        bucket: str,
        prefix: str = "",
        region: str = "us-east-1",
    ) -> None:
        """Create the S3 client and verify bucket access and write permission.

        Args:
            endpoint: S3 endpoint URL.
            access_key: Access key ID.
            secret_key: Secret access key.
            bucket: Bucket name.
            prefix: Object key prefix.
            region: Region name.

        Raises:
            ImportError: boto3 is not installed.
            DestinationError: Creating the client, reaching the bucket, or the
                write-permission probe failed.
        """
        boto3, Config = _get_boto3()

        self.endpoint = endpoint
        self.bucket = bucket
        self.prefix = prefix.strip("/") if prefix else ""

        # Provider-specific signing / addressing. Compare without a trailing
        # slash so "https://host/" and "https://host" are treated alike.
        normalized = endpoint.rstrip("/")
        if normalized == "https://textin-minio-api.ai.intsig.net":
            config = Config(signature_version="s3v4")
        elif normalized.endswith(("aliyuncs.com", "myhuaweicloud.com")):
            # Alibaba Cloud / Huawei Cloud endpoints: legacy "s3" signature.
            config = Config(signature_version="s3", s3={"addressing_style": "virtual"})
        else:
            config = Config(signature_version="s3v4", s3={"addressing_style": "virtual"})

        try:
            # Client creation is inside the try so e.g. an invalid endpoint URL
            # surfaces as the documented DestinationError.
            self.client = boto3.client(
                "s3",
                endpoint_url=endpoint,
                aws_access_key_id=access_key,
                aws_secret_access_key=secret_key,
                region_name=region,
                config=config,
            )

            self.client.head_bucket(Bucket=bucket)

            # Probe write permission with an empty object, then remove it.
            probe_key = self._object_key("empty.tmp")
            self.client.put_object(Bucket=bucket, Key=probe_key, Body=b"")
            try:
                self.client.delete_object(Bucket=bucket, Key=probe_key)
            except Exception as e:
                # Write access is already confirmed; a leftover probe object is
                # not fatal, but it should not go unnoticed either.
                logger.warning(f"删除写入测试对象失败: {bucket}/{probe_key}: {e}")

            logger.info(f"S3 目的地连接成功: {endpoint}/{bucket}")

        except Exception as e:
            raise DestinationError(
                f"S3 连接或写入测试失败: {e}",
                connector_type="s3",
                operation="connect",
                details={"endpoint": endpoint, "bucket": bucket},
            ) from e

    def _object_key(self, name: str) -> str:
        """Return *name* joined onto the configured prefix (if any)."""
        return f"{self.prefix}/{name}" if self.prefix else name

    def write(self, data: list[dict[str, Any]], metadata: dict[str, Any]) -> bool:
        """Serialize *data* to JSON and upload it as a single object.

        Args:
            data: JSON-serialisable list of elements to store.
            metadata: Metadata; ``filename`` names the object. A missing,
                ``None`` or empty filename falls back to ``"output"``.

        Returns:
            True once the object has been uploaded.

        Raises:
            DestinationError: Serialisation or the upload failed.
        """
        try:
            filename = metadata.get("filename") or "output"
            base_name = Path(filename).stem or "output"
            object_key = self._object_key(f"{base_name}.json")

            json_bytes = json.dumps(data, ensure_ascii=False, indent=2).encode("utf-8")

            self.client.put_object(
                Bucket=self.bucket,
                Key=object_key,
                Body=json_bytes,
                ContentType="application/json",
            )

            logger.info(f"写入 S3: {self.endpoint}/{self.bucket}/{object_key}")
            return True

        except Exception as e:
            raise DestinationError(
                f"写入 S3 失败: {e}",
                connector_type="s3",
                operation="write",
            ) from e

    def __repr__(self) -> str:
        return f"<S3Destination endpoint={self.endpoint} bucket={self.bucket}>"


__all__ = ["S3Destination"]
@@ -0,0 +1,45 @@
1
+ """数据源连接器
2
+
3
+ 提供从各种数据源读取文件的能力。
4
+
5
+ 可用的数据源:
6
+ - LocalSource: 本地文件系统
7
+ - S3Source: S3/MinIO 对象存储(需要 pip install xparse-client[s3])
8
+ - FtpSource: FTP 服务器
9
+ - SmbSource: SMB/CIFS 共享(需要 pip install xparse-client[smb])
10
+
11
+ Example:
12
+ >>> from xparse_client.connectors.sources import LocalSource
13
+ >>> source = LocalSource(directory="./docs", pattern=["*.pdf"])
14
+ >>> files = source.list_files()
15
+ >>> content, metadata = source.read_file(files[0])
16
+ """
17
+
18
+ from .base import Source
19
+ from .local import LocalSource
20
+
21
# Source and LocalSource are imported eagerly above. The remaining sources are
# resolved lazily by the module-level __getattr__, to avoid a hard dependency
# on their optional extras until those classes are actually used.
__all__ = ["Source", "LocalSource"]
__all__ += ["S3Source", "FtpSource", "SmbSource"]
29
+
30
+
31
# Lazily exported Source classes -> submodule (relative to this package) that
# defines them.
_LAZY_SOURCES = {
    "S3Source": "s3",
    "FtpSource": "ftp",
    "SmbSource": "smb",
}


def __getattr__(name: str):
    """Resolve the optional Source classes on first access (PEP 562).

    The submodule defining *name* is imported only when the class is first
    requested. The class is then cached in this module's globals, so later
    lookups find it directly without re-entering this hook.

    Raises:
        AttributeError: *name* is not one of the lazily exported sources.
    """
    module_name = _LAZY_SOURCES.get(name)
    if module_name is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    from importlib import import_module

    value = getattr(import_module(f".{module_name}", __name__), name)
    globals()[name] = value
    return value


def __dir__():
    """List module attributes, including the not-yet-loaded lazy sources."""
    return sorted({*globals(), *_LAZY_SOURCES})
@@ -0,0 +1,74 @@
1
+ """数据源抽象基类"""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from abc import ABC, abstractmethod
7
+ from typing import Any
8
+
9
logger = logging.getLogger(__name__)


class Source(ABC):
    """Abstract base class for data sources.

    Every data source must implement :meth:`list_files` and :meth:`read_file`.
    Instances are context managers: leaving a ``with`` block calls
    :meth:`close`.

    Example:
        >>> class MySource(Source):
        ...     def list_files(self) -> list[str]:
        ...         return ["file1.pdf", "file2.pdf"]
        ...
        ...     def read_file(self, file_path: str) -> tuple[bytes, dict[str, Any]]:
        ...         return b"content", {"url": file_path}
    """

    @abstractmethod
    def list_files(self) -> list[str]:
        """List all matching files.

        Returns:
            List of file paths.

        Raises:
            SourceError: Listing the files failed.
        """
        raise NotImplementedError

    @abstractmethod
    def read_file(self, file_path: str) -> tuple[bytes, dict[str, Any]]:
        """Read the content of one file.

        Args:
            file_path: File path (relative path).

        Returns:
            A ``(file_bytes, data_source)`` tuple:
            - file_bytes: the raw file content.
            - data_source: provenance metadata such as url, version and
              date_created.

        Raises:
            SourceError: Reading the file failed.
        """
        raise NotImplementedError

    def close(self) -> None:  # noqa: B027
        """Release resources held by the source.

        The base implementation does nothing. Subclasses holding resources
        (e.g. network connections) should override it.
        """
        pass

    def __enter__(self) -> Source:
        """Enter the context manager and return the source itself."""
        return self

    def __exit__(self, *args: Any) -> None:
        """Exit the context manager by calling :meth:`close`.

        Returns None, so exceptions raised inside the block propagate.
        """
        self.close()

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}>"


__all__ = ["Source"]