xparse-client 0.2.20__py3-none-any.whl → 0.3.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- example/1_basic_api_usage.py +198 -0
- example/2_async_job.py +210 -0
- example/3_local_workflow.py +300 -0
- example/4_advanced_workflow.py +327 -0
- example/README.md +128 -0
- example/config_example.json +95 -0
- tests/conftest.py +310 -0
- tests/unit/__init__.py +1 -0
- tests/unit/api/__init__.py +1 -0
- tests/unit/api/test_extract.py +232 -0
- tests/unit/api/test_local.py +231 -0
- tests/unit/api/test_parse.py +374 -0
- tests/unit/api/test_pipeline.py +369 -0
- tests/unit/api/test_workflows.py +108 -0
- tests/unit/connectors/test_ftp.py +525 -0
- tests/unit/connectors/test_local_connectors.py +324 -0
- tests/unit/connectors/test_milvus.py +368 -0
- tests/unit/connectors/test_qdrant.py +399 -0
- tests/unit/connectors/test_s3.py +598 -0
- tests/unit/connectors/test_smb.py +442 -0
- tests/unit/connectors/test_utils.py +335 -0
- tests/unit/models/test_local.py +54 -0
- tests/unit/models/test_pipeline_stages.py +144 -0
- tests/unit/models/test_workflows.py +55 -0
- tests/unit/test_base.py +437 -0
- tests/unit/test_client.py +110 -0
- tests/unit/test_config.py +160 -0
- tests/unit/test_exceptions.py +182 -0
- tests/unit/test_http.py +562 -0
- xparse_client/__init__.py +110 -20
- xparse_client/_base.py +179 -0
- xparse_client/_client.py +218 -0
- xparse_client/_config.py +221 -0
- xparse_client/_http.py +350 -0
- xparse_client/api/__init__.py +14 -0
- xparse_client/api/extract.py +109 -0
- xparse_client/api/local.py +185 -0
- xparse_client/api/parse.py +209 -0
- xparse_client/api/pipeline.py +132 -0
- xparse_client/api/workflows.py +204 -0
- xparse_client/connectors/__init__.py +45 -0
- xparse_client/connectors/_utils.py +138 -0
- xparse_client/connectors/destinations/__init__.py +45 -0
- xparse_client/connectors/destinations/base.py +116 -0
- xparse_client/connectors/destinations/local.py +91 -0
- xparse_client/connectors/destinations/milvus.py +229 -0
- xparse_client/connectors/destinations/qdrant.py +238 -0
- xparse_client/connectors/destinations/s3.py +163 -0
- xparse_client/connectors/sources/__init__.py +45 -0
- xparse_client/connectors/sources/base.py +74 -0
- xparse_client/connectors/sources/ftp.py +278 -0
- xparse_client/connectors/sources/local.py +176 -0
- xparse_client/connectors/sources/s3.py +232 -0
- xparse_client/connectors/sources/smb.py +259 -0
- xparse_client/exceptions.py +398 -0
- xparse_client/models/__init__.py +60 -0
- xparse_client/models/chunk.py +39 -0
- xparse_client/models/embed.py +62 -0
- xparse_client/models/extract.py +41 -0
- xparse_client/models/local.py +38 -0
- xparse_client/models/parse.py +136 -0
- xparse_client/models/pipeline.py +132 -0
- xparse_client/models/workflows.py +74 -0
- xparse_client-0.3.0b1.dist-info/METADATA +1075 -0
- xparse_client-0.3.0b1.dist-info/RECORD +68 -0
- {xparse_client-0.2.20.dist-info → xparse_client-0.3.0b1.dist-info}/WHEEL +1 -1
- {xparse_client-0.2.20.dist-info → xparse_client-0.3.0b1.dist-info}/licenses/LICENSE +1 -1
- {xparse_client-0.2.20.dist-info → xparse_client-0.3.0b1.dist-info}/top_level.txt +2 -0
- xparse_client/pipeline/__init__.py +0 -3
- xparse_client/pipeline/config.py +0 -163
- xparse_client/pipeline/destinations.py +0 -489
- xparse_client/pipeline/pipeline.py +0 -860
- xparse_client/pipeline/sources.py +0 -583
- xparse_client-0.2.20.dist-info/METADATA +0 -1050
- xparse_client-0.2.20.dist-info/RECORD +0 -11
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
"""Qdrant 向量数据库目的地(懒加载 qdrant-client)"""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import uuid
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from ...exceptions import DestinationError
|
|
10
|
+
from .base import Destination, VectorDestinationMixin
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _get_qdrant():
    """Import qdrant-client lazily and return the symbols this module uses.

    Returns:
        The tuple ``(QdrantClient, Distance, PayloadSchemaType, PointStruct,
        VectorParams)``.

    Raises:
        ImportError: One of the qdrant-client imports failed; the message
            points at the ``xparse-client[qdrant]`` extra.
    """
    try:
        from qdrant_client import QdrantClient as client_cls
        from qdrant_client.models import Distance as distance_enum
        from qdrant_client.models import PayloadSchemaType as schema_enum
        from qdrant_client.models import PointStruct as point_cls
        from qdrant_client.models import VectorParams as vector_params_cls
    except ImportError as import_error:
        raise ImportError(
            "使用 QdrantDestination 需要安装 qdrant-client: pip install xparse-client[qdrant]"
        ) from import_error
    else:
        return client_cls, distance_enum, schema_enum, point_cls, vector_params_cls
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class QdrantDestination(Destination, VectorDestinationMixin):
    """Destination that writes embedded elements into a Qdrant collection.

    Supports both a self-hosted Qdrant deployment and Qdrant Cloud.

    Attributes:
        url: Qdrant service URL.
        collection_name: Name of the target collection.
        dimension: Vector dimension used when the collection is created.
        client: The ``QdrantClient`` instance created in ``__init__``.

    Example:
        >>> # Self-hosted Qdrant
        >>> dest = QdrantDestination(
        ...     url="http://localhost:6333",
        ...     collection_name="documents",
        ...     dimension=1024,
        ... )
        >>>
        >>> # Qdrant Cloud
        >>> dest = QdrantDestination(
        ...     url="https://xxx.qdrant.io",
        ...     collection_name="documents",
        ...     dimension=1024,
        ...     api_key="your-api-key",
        ... )
    """

    def __init__(
        self,
        url: str,
        collection_name: str,
        dimension: int,
        api_key: str | None = None,
        prefer_grpc: bool = False,
    ) -> None:
        """Connect to Qdrant and make sure the target collection exists.

        A missing collection is created with cosine distance and the given
        dimension. In both cases a keyword payload index on ``record_id`` is
        requested; a failure to create that index is logged, not raised.

        Args:
            url: Qdrant service URL.
            collection_name: Name of the collection to write to.
            dimension: Vector dimension.
            api_key: API key (used for Qdrant Cloud).
            prefer_grpc: Whether to prefer gRPC. Defaults to False.

        Raises:
            ImportError: qdrant-client is not installed, or an import failed
                while connecting.
            DestinationError: Connecting or preparing the collection failed.
        """
        (
            QdrantClient,
            Distance,
            PayloadSchemaType,
            PointStruct,
            VectorParams,
        ) = _get_qdrant()

        self.url = url
        self.collection_name = collection_name
        self.dimension = dimension
        # Kept on the instance so write() can build points without importing again.
        self._PointStruct = PointStruct

        # Optional settings are only passed when set, leaving client defaults alone.
        client_kwargs: dict[str, Any] = {"url": url}
        if api_key:
            client_kwargs["api_key"] = api_key
        if prefer_grpc:
            client_kwargs["prefer_grpc"] = True

        try:
            self.client = QdrantClient(**client_kwargs)

            # Check for the collection and create it when missing.
            collections = self.client.get_collections()
            collection_exists = any(
                col.name == collection_name for col in collections.collections
            )

            if not collection_exists:
                self.client.create_collection(
                    collection_name=collection_name,
                    vectors_config=VectorParams(size=dimension, distance=Distance.COSINE),
                )
                self._ensure_record_id_index(PayloadSchemaType.KEYWORD, newly_created=True)
                logger.info(f"Qdrant Collection 创建: {collection_name}")
            else:
                logger.info(f"Qdrant Collection 已存在: {collection_name}")
                self._ensure_record_id_index(PayloadSchemaType.KEYWORD, newly_created=False)

        except ImportError:
            # Import problems are re-raised as-is instead of being wrapped below.
            raise
        except Exception as e:
            raise DestinationError(
                f"Qdrant 连接失败: {e}",
                connector_type="qdrant",
                operation="connect",
                details={"url": url, "collection_name": collection_name},
            ) from e

    def _ensure_record_id_index(self, keyword_schema: Any, *, newly_created: bool) -> None:
        """Request a keyword payload index on ``record_id``; never raises.

        Args:
            keyword_schema: The ``PayloadSchemaType.KEYWORD`` value to index with.
            newly_created: True when the collection was just created. A failure
                is then logged as a warning. For a pre-existing collection the
                index may already be there, so it is logged at debug level.
        """
        try:
            self.client.create_payload_index(
                collection_name=self.collection_name,
                field_name="record_id",
                field_schema=keyword_schema,
            )
        except Exception as e:
            if newly_created:
                logger.warning(f"创建 record_id 索引失败: {e}")
            else:
                logger.debug(f"record_id 索引可能已存在,跳过创建: {e}")

    @staticmethod
    def _point_id_for(element_id: Any) -> str:
        """Map an element id to a UUID string usable as a Qdrant point id.

        A value whose string form is a valid UUID is normalised and kept. Any
        other value is mapped to a deterministic UUIDv5 of its string form.
        Converting with ``str()`` first means integer ids and ``uuid.UUID``
        instances are handled rather than raising ``AttributeError`` inside
        ``uuid.UUID``.
        """
        text = str(element_id)
        try:
            return str(uuid.UUID(text))
        except ValueError:
            return str(uuid.uuid5(uuid.NAMESPACE_URL, text))

    def write(self, data: list[dict[str, Any]], metadata: dict[str, Any]) -> bool:
        """Replace the points stored for a record with the given elements.

        Points whose payload ``record_id`` matches are deleted first, on a
        best-effort basis. Then every element with a non-empty ``embeddings``
        value is upserted.

        Args:
            data: Elements to write; those without ``embeddings`` are skipped.
            metadata: Metadata for the record; must contain ``record_id``.

        Returns:
            True when points were upserted. False when ``record_id`` is
            missing or no element carried embeddings.

        Raises:
            DestinationError: Building or upserting the points failed.
        """
        record_id = metadata.get("record_id")
        if not record_id:
            logger.warning("没有 record_id,跳过写入")
            return False

        try:
            # Delete old points for this record; a failure here does not abort the write.
            try:
                # NOTE(review): only one page of up to 10000 matches is fetched, so a
                # record with more points would keep stale entries — confirm acceptable.
                scroll_result = self.client.scroll(
                    collection_name=self.collection_name,
                    scroll_filter={
                        "must": [{"key": "record_id", "match": {"value": record_id}}]
                    },
                    limit=10000,
                )

                # The first element of the scroll result holds the matched points.
                if scroll_result[0]:
                    point_ids = [point.id for point in scroll_result[0]]
                    self.client.delete(
                        collection_name=self.collection_name,
                        points_selector=point_ids,
                    )
                    logger.info(
                        f"删除 Qdrant 旧记录: record_id={record_id}, 数量={len(point_ids)}"
                    )
            except Exception as e:
                logger.warning(f"删除旧记录失败: {e}")

            # Build the points to insert.
            fixed_fields = {"embeddings", "text", "element_id", "record_id", "metadata"}
            points = []

            for item in data:
                if not item.get("embeddings"):
                    continue

                element_id = (
                    item.get("element_id") or item.get("id") or str(uuid.uuid4())
                )

                # Qdrant point ids must be UUIDs or integers.
                point_id = self._point_id_for(element_id)

                payload = self.prepare_payload(item, metadata, record_id, fixed_fields)

                points.append(
                    self._PointStruct(
                        id=point_id,
                        vector=item["embeddings"],
                        payload=payload,
                    )
                )

            if not points:
                logger.warning("没有有效的向量数据")
                return False

            # Insert everything in one batch.
            self.client.upsert(
                collection_name=self.collection_name,
                points=points,
            )
            logger.info(f"写入 Qdrant: {len(points)} 条")
            return True

        except Exception as e:
            raise DestinationError(
                f"写入 Qdrant 失败: {e}",
                connector_type="qdrant",
                operation="write",
            ) from e

    def __repr__(self) -> str:
        """Return a short description with the URL and collection name."""
        return f"<QdrantDestination url={self.url} collection={self.collection_name}>"
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
__all__ = ["QdrantDestination"]
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
"""S3/MinIO 目的地(懒加载 boto3)"""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
from ...exceptions import DestinationError
|
|
11
|
+
from .base import Destination
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _get_boto3():
    """Import boto3 lazily and return what the S3 destination needs.

    Returns:
        The tuple ``(boto3, Config)``, where ``Config`` is
        ``botocore.config.Config``.

    Raises:
        ImportError: boto3/botocore could not be imported; the message points
            at the ``xparse-client[s3]`` extra.
    """
    try:
        import boto3 as boto3_module
        from botocore.config import Config as config_cls
    except ImportError as import_error:
        raise ImportError(
            "使用 S3Destination 需要安装 boto3: pip install xparse-client[s3]"
        ) from import_error
    else:
        return boto3_module, config_cls
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class S3Destination(Destination):
    """Destination that writes results to S3 or an S3-compatible store.

    Each ``write`` call uploads the data as one JSON object.

    Attributes:
        endpoint: S3 endpoint URL.
        bucket: Bucket name.
        prefix: Object key prefix, stored without leading/trailing slashes.
        client: The boto3 S3 client created in ``__init__``.

    Example:
        >>> dest = S3Destination(
        ...     endpoint="https://s3.amazonaws.com",
        ...     access_key="...",
        ...     secret_key="...",
        ...     bucket="my-bucket",
        ...     prefix="output/",
        ... )
        >>> dest.write(elements, {"filename": "doc.pdf"})
    """

    def __init__(
        self,
        endpoint: str,
        access_key: str,
        secret_key: str,
        bucket: str,
        prefix: str = "",
        region: str = "us-east-1",
    ) -> None:
        """Create the S3 client and verify the bucket is reachable and writable.

        Verification calls ``head_bucket``, then uploads an empty
        ``empty.tmp`` object under the prefix and deletes it again.

        Args:
            endpoint: S3 endpoint URL.
            access_key: Access key.
            secret_key: Secret key.
            bucket: Bucket name.
            prefix: Object key prefix.
            region: Region name.

        Raises:
            ImportError: boto3 is not installed.
            DestinationError: The bucket check or the write probe failed.
        """
        boto3, Config = _get_boto3()

        self.endpoint = endpoint
        self.bucket = bucket
        self.prefix = prefix.strip("/") if prefix else ""

        # Pick the signature version and addressing style for the endpoint.
        if endpoint == "https://textin-minio-api.ai.intsig.net":
            config = Config(signature_version="s3v4")
        elif endpoint.endswith(("aliyuncs.com", "myhuaweicloud.com")):
            config = Config(signature_version="s3", s3={"addressing_style": "virtual"})
        else:
            config = Config(signature_version="s3v4", s3={"addressing_style": "virtual"})

        self.client = boto3.client(
            "s3",
            endpoint_url=endpoint,
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
            region_name=region,
            config=config,
        )

        # Verify connectivity and write permission.
        try:
            self.client.head_bucket(Bucket=bucket)

            # Probe write permission with an empty object.
            test_key = f"{self.prefix}/empty.tmp" if self.prefix else "empty.tmp"
            self.client.put_object(Bucket=bucket, Key=test_key, Body=b"")
            try:
                self.client.delete_object(Bucket=bucket, Key=test_key)
            except Exception as e:
                # Cleanup is best-effort; log it so a leftover probe object is visible.
                logger.warning(f"S3 测试对象清理失败: {test_key}: {e}")

            logger.info(f"S3 目的地连接成功: {endpoint}/{bucket}")

        except Exception as e:
            raise DestinationError(
                f"S3 连接或写入测试失败: {e}",
                connector_type="s3",
                operation="connect",
                details={"endpoint": endpoint, "bucket": bucket},
            ) from e

    def write(self, data: list[dict[str, Any]], metadata: dict[str, Any]) -> bool:
        """Upload ``data`` as a UTF-8 JSON object.

        The object key is ``<prefix>/<stem>.json``, where ``<stem>`` is the
        stem of ``metadata["filename"]``. When the filename is missing,
        ``None`` or empty, ``output`` is used.

        Args:
            data: List of items to serialise.
            metadata: Metadata; ``filename`` determines the object name.

        Returns:
            True when the object was uploaded.

        Raises:
            DestinationError: Serialisation or upload failed.
        """
        try:
            # `or` also covers a filename that is present but None or empty,
            # which a .get() default would not.
            filename = metadata.get("filename") or "output"
            base_name = Path(filename).stem or "output"
            object_key = (
                f"{self.prefix}/{base_name}.json" if self.prefix else f"{base_name}.json"
            )

            json_data = json.dumps(data, ensure_ascii=False, indent=2)
            json_bytes = json_data.encode("utf-8")

            self.client.put_object(
                Bucket=self.bucket,
                Key=object_key,
                Body=json_bytes,
                ContentType="application/json",
            )

            logger.info(f"写入 S3: {self.endpoint}/{self.bucket}/{object_key}")
            return True

        except Exception as e:
            raise DestinationError(
                f"写入 S3 失败: {e}",
                connector_type="s3",
                operation="write",
            ) from e

    def __repr__(self) -> str:
        """Return a short description with the endpoint and bucket."""
        return f"<S3Destination endpoint={self.endpoint} bucket={self.bucket}>"
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
__all__ = ["S3Destination"]
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""数据源连接器
|
|
2
|
+
|
|
3
|
+
提供从各种数据源读取文件的能力。
|
|
4
|
+
|
|
5
|
+
可用的数据源:
|
|
6
|
+
- LocalSource: 本地文件系统
|
|
7
|
+
- S3Source: S3/MinIO 对象存储(需要 pip install xparse-client[s3])
|
|
8
|
+
- FtpSource: FTP 服务器
|
|
9
|
+
- SmbSource: SMB/CIFS 共享(需要 pip install xparse-client[smb])
|
|
10
|
+
|
|
11
|
+
Example:
|
|
12
|
+
>>> from xparse_client.connectors.sources import LocalSource
|
|
13
|
+
>>> source = LocalSource(directory="./docs", pattern=["*.pdf"])
|
|
14
|
+
>>> files = source.list_files()
|
|
15
|
+
>>> content, metadata = source.read_file(files[0])
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from .base import Source
|
|
19
|
+
from .local import LocalSource
|
|
20
|
+
|
|
21
|
+
# The other Source classes are loaded lazily to avoid hard dependencies:
# they are listed here but resolved by the module-level __getattr__.
__all__ = [
    "Source",
    "LocalSource",
    "S3Source",
    "FtpSource",
    "SmbSource",
]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def __getattr__(name: str):
    """Resolve the lazily loaded Source classes on attribute access.

    Args:
        name: Attribute requested from this module.

    Returns:
        ``S3Source``, ``FtpSource`` or ``SmbSource``, imported from its
        submodule at the time of the call.

    Raises:
        AttributeError: ``name`` is not one of the lazily loaded classes.
    """
    if name not in ("S3Source", "FtpSource", "SmbSource"):
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    if name == "S3Source":
        from .s3 import S3Source as resolved

        return resolved

    if name == "FtpSource":
        from .ftp import FtpSource as resolved

        return resolved

    from .smb import SmbSource as resolved

    return resolved
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
"""数据源抽象基类"""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Source(ABC):
    """Abstract base class for data sources.

    Every concrete source must implement ``list_files()`` and
    ``read_file()``. A source can be used as a context manager; leaving the
    ``with`` block calls ``close()``.

    Example:
        >>> class MySource(Source):
        ...     def list_files(self) -> list[str]:
        ...         return ["file1.pdf", "file2.pdf"]
        ...
        ...     def read_file(self, file_path: str) -> tuple[bytes, dict[str, Any]]:
        ...         return b"content", {"url": file_path}
    """

    def __repr__(self) -> str:
        return "<{}>".format(self.__class__.__name__)

    def __enter__(self) -> Source:
        """Enter the context manager and return this source."""
        return self

    def __exit__(self, *args: Any) -> None:
        """Leave the context manager, closing the source."""
        self.close()

    def close(self) -> None:  # noqa: B027
        """Close the connection.

        The base implementation does nothing. Subclasses may override it to
        release resources such as network connections.
        """
        return None

    @abstractmethod
    def list_files(self) -> list[str]:
        """List all matching files.

        Returns:
            List of file paths.

        Raises:
            SourceError: Listing the files failed.
        """
        raise NotImplementedError()

    @abstractmethod
    def read_file(self, file_path: str) -> tuple[bytes, dict[str, Any]]:
        """Read the content of one file.

        Args:
            file_path: File path (relative path).

        Returns:
            A ``(file_bytes, data_source)`` tuple:
            - file_bytes: Binary content of the file.
            - data_source: Metadata about where the data came from, with
              keys such as url, version and date_created.

        Raises:
            SourceError: Reading the file failed.
        """
        raise NotImplementedError()
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
__all__ = ["Source"]
|