xparse-client 0.2.19__py3-none-any.whl → 0.2.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xparse_client/__init__.py +2 -1
- xparse_client/pipeline/config.py +66 -32
- xparse_client/pipeline/pipeline.py +470 -300
- {xparse_client-0.2.19.dist-info → xparse_client-0.2.20.dist-info}/METADATA +1 -1
- xparse_client-0.2.20.dist-info/RECORD +11 -0
- {xparse_client-0.2.19.dist-info → xparse_client-0.2.20.dist-info}/WHEEL +1 -1
- xparse_client-0.2.19.dist-info/RECORD +0 -11
- {xparse_client-0.2.19.dist-info → xparse_client-0.2.20.dist-info}/licenses/LICENSE +0 -0
- {xparse_client-0.2.19.dist-info → xparse_client-0.2.20.dist-info}/top_level.txt +0 -0
@@ -7,14 +7,27 @@ import re
 import time
 from datetime import datetime, timezone
 from pathlib import Path
-from typing import Dict,
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import requests
 
-from .config import
-
-
-
+from .config import (
+    ChunkConfig,
+    EmbedConfig,
+    ExtractConfig,
+    ParseConfig,
+    PipelineConfig,
+    PipelineStats,
+    Stage,
+)
+from .destinations import (
+    Destination,
+    LocalDestination,
+    MilvusDestination,
+    QdrantDestination,
+    S3Destination,
+)
+from .sources import FtpSource, LocalSource, S3Source, SmbSource, Source
 
 logger = logging.getLogger(__name__)
 
@@ -26,15 +39,15 @@ class Pipeline:
         self,
         source: Source,
         destination: Destination,
-        api_base_url: str =
+        api_base_url: str = "http://localhost:8000/api/xparse",
         api_headers: Optional[Dict[str, str]] = None,
         stages: Optional[List[Stage]] = None,
         pipeline_config: Optional[PipelineConfig] = None,
-        intermediate_results_destination: Optional[Destination] = None
+        intermediate_results_destination: Optional[Destination] = None,
     ):
         self.source = source
         self.destination = destination
-        self.api_base_url = api_base_url.rstrip(
+        self.api_base_url = api_base_url.rstrip("/")
         self.api_headers = api_headers or {}
         self.pipeline_config = pipeline_config or PipelineConfig()
 
@@ -42,34 +55,62 @@ class Pipeline:
         # 如果直接传入了 intermediate_results_destination,优先使用它并自动启用中间结果保存
         if intermediate_results_destination is not None:
             self.pipeline_config.include_intermediate_results = True
-            self.pipeline_config.intermediate_results_destination =
+            self.pipeline_config.intermediate_results_destination = (
+                intermediate_results_destination
+            )
         # 如果 pipeline_config 中已设置,使用 pipeline_config 中的值
         elif self.pipeline_config.include_intermediate_results:
             if not self.pipeline_config.intermediate_results_destination:
-                raise ValueError(
+                raise ValueError(
+                    "当 include_intermediate_results 为 True 时,必须设置 intermediate_results_destination"
+                )
 
         # 处理 stages 配置
         if stages is None:
             raise ValueError("必须提供 stages 参数")
-
+
         self.stages = stages
 
         # 验证 stages
-        if not self.stages or self.stages[0].type !=
+        if not self.stages or self.stages[0].type != "parse":
             raise ValueError("stages 必须包含且第一个必须是 'parse' 类型")
-
+
+        # 检查是否包含 extract 阶段
+        self.has_extract = any(stage.type == "extract" for stage in self.stages)
+
+        # 验证 extract 阶段的约束
+        if self.has_extract:
+            # extract 只能跟在 parse 后面,且必须是最后一个
+            if len(self.stages) != 2 or self.stages[1].type != "extract":
+                raise ValueError(
+                    "extract 阶段只能跟在 parse 后面,且必须是最后一个阶段(即只能是 [parse, extract] 组合)"
+                )
+
+            # destination 必须是文件存储类型
+            if not isinstance(destination, (LocalDestination, S3Destination)):
+                raise ValueError(
+                    "当使用 extract 阶段时,destination 必须是文件存储类型(LocalDestination 或 S3Destination),不能是向量数据库"
+                )
+
         # 验证 embed config(如果存在)
         for stage in self.stages:
-            if stage.type ==
+            if stage.type == "embed" and isinstance(stage.config, EmbedConfig):
                 stage.config.validate()
-
+
         # 验证 intermediate_results_destination
         if self.pipeline_config.include_intermediate_results:
             # 验证是否为支持的 Destination 类型
             from .destinations import Destination
-
-
-
+
+            if not isinstance(
+                self.pipeline_config.intermediate_results_destination, Destination
+            ):
+                raise ValueError(
+                    f"intermediate_results_destination 必须是 Destination 类型"
+                )
+            self.intermediate_results_destination = (
+                self.pipeline_config.intermediate_results_destination
+            )
 
         print("=" * 60)
         print("Pipeline 初始化完成")
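Note on the new extract support above: stages are now constrained to exactly [parse, extract] with a file-type destination. The following is a minimal usage sketch, not part of the diff; the import paths and the schema payload are assumptions, while the class names, keyword arguments, and constraints come from the code shown in this file.

# Hypothetical usage sketch for the new [parse, extract] combination.
from xparse_client.pipeline.config import ExtractConfig, ParseConfig, Stage  # assumed paths
from xparse_client.pipeline.destinations import LocalDestination
from xparse_client.pipeline.pipeline import Pipeline
from xparse_client.pipeline.sources import LocalSource

pipeline = Pipeline(
    source=LocalSource(directory="./docs", pattern="*.pdf", recursive=True),
    # extract requires a file-type destination (LocalDestination or S3Destination)
    destination=LocalDestination(output_dir="./output"),
    api_base_url="http://localhost:8000/api/xparse",
    stages=[
        Stage(type="parse", config=ParseConfig(provider="textin")),
        Stage(
            type="extract",
            config=ExtractConfig(
                schema={"title": "string", "author": "string"},  # assumed schema shape
                generate_citations=False,
                stamp=False,
            ),
        ),
    ],
)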
@@ -83,223 +124,264 @@ class Pipeline:
     def get_config(self) -> Dict[str, Any]:
         """获取 Pipeline 的完整配置信息,返回字典格式(与 create_pipeline_from_config 的入参格式一致)"""
         config = {}
-
+
         # Source 配置
-        source_type = type(self.source).__name__.replace(
-        config[
-
+        source_type = type(self.source).__name__.replace("Source", "").lower()
+        config["source"] = {"type": source_type}
+
         if isinstance(self.source, S3Source):
-            config[
-
-
-
-
-
-
+            config["source"].update(
+                {
+                    "endpoint": self.source.endpoint,
+                    "bucket": self.source.bucket,
+                    "prefix": self.source.prefix,
+                    "pattern": self.source.pattern,
+                    "recursive": self.source.recursive,
+                }
+            )
             # access_key 和 secret_key 不在对象中保存,无法恢复
             # region 也不在对象中保存,使用默认值
-            config[
+            config["source"]["region"] = "us-east-1"  # 默认值
         elif isinstance(self.source, LocalSource):
-            config[
-
-
-
-
+            config["source"].update(
+                {
+                    "directory": str(self.source.directory),
+                    "pattern": self.source.pattern,
+                    "recursive": self.source.recursive,
+                }
+            )
         elif isinstance(self.source, FtpSource):
-            config[
-
-
-
-
-
-
+            config["source"].update(
+                {
+                    "host": self.source.host,
+                    "port": self.source.port,
+                    "username": self.source.username,
+                    "pattern": self.source.pattern,
+                    "recursive": self.source.recursive,
+                }
+            )
             # password 不在对象中保存,无法恢复
         elif isinstance(self.source, SmbSource):
-            config[
-
-
-
-
-
-
-
-
-
+            config["source"].update(
+                {
+                    "host": self.source.host,
+                    "share_name": self.source.share_name,
+                    "username": self.source.username,
+                    "domain": self.source.domain,
+                    "port": self.source.port,
+                    "path": self.source.path,
+                    "pattern": self.source.pattern,
+                    "recursive": self.source.recursive,
+                }
+            )
             # password 不在对象中保存,无法恢复
-
+
         # Destination 配置
-        dest_type = type(self.destination).__name__.replace(
+        dest_type = type(self.destination).__name__.replace("Destination", "").lower()
         # MilvusDestination 和 Zilliz 都使用 'milvus' 或 'zilliz' 类型
-        if dest_type ==
+        if dest_type == "milvus":
             # 判断是本地 Milvus 还是 Zilliz(通过 db_path 判断)
-            if self.destination.db_path.startswith(
-                dest_type =
+            if self.destination.db_path.startswith("http"):
+                dest_type = "zilliz"
             else:
-                dest_type =
-
-        config[
-
+                dest_type = "milvus"
+
+        config["destination"] = {"type": dest_type}
+
         if isinstance(self.destination, MilvusDestination):
-            config[
-
-
-
-
+            config["destination"].update(
+                {
+                    "db_path": self.destination.db_path,
+                    "collection_name": self.destination.collection_name,
+                    "dimension": self.destination.dimension,
+                }
+            )
             # api_key 和 token 不在对象中保存,无法恢复
         elif isinstance(self.destination, QdrantDestination):
-            config[
-
-
-
-
-
+            config["destination"].update(
+                {
+                    "url": self.destination.url,
+                    "collection_name": self.destination.collection_name,
+                    "dimension": self.destination.dimension,
+                    "prefer_grpc": getattr(self.destination, "prefer_grpc", False),
+                }
+            )
             # api_key 不在对象中保存,无法恢复
         elif isinstance(self.destination, LocalDestination):
-            config[
-
-
+            config["destination"].update(
+                {"output_dir": str(self.destination.output_dir)}
+            )
         elif isinstance(self.destination, S3Destination):
-            config[
-
-
-
-
+            config["destination"].update(
+                {
+                    "endpoint": self.destination.endpoint,
+                    "bucket": self.destination.bucket,
+                    "prefix": self.destination.prefix,
+                }
+            )
             # access_key, secret_key, region 不在对象中保存,无法恢复
-            config[
-
+            config["destination"]["region"] = "us-east-1"  # 默认值
+
         # API 配置
-        config[
-        config[
+        config["api_base_url"] = self.api_base_url
+        config["api_headers"] = {}
         for key, value in self.api_headers.items():
-            config[
-
+            config["api_headers"][key] = value
+
         # Stages 配置
-        config[
+        config["stages"] = []
         for stage in self.stages:
-            stage_dict = {
-
-                'config': {}
-            }
-
+            stage_dict = {"type": stage.type, "config": {}}
+
             if isinstance(stage.config, ParseConfig):
-                stage_dict[
+                stage_dict["config"] = stage.config.to_dict()
             elif isinstance(stage.config, ChunkConfig):
-                stage_dict[
+                stage_dict["config"] = stage.config.to_dict()
             elif isinstance(stage.config, EmbedConfig):
-                stage_dict[
+                stage_dict["config"] = stage.config.to_dict()
+            elif isinstance(stage.config, ExtractConfig):
+                stage_dict["config"] = stage.config.to_dict()
             else:
                 # 如果 config 是字典或其他类型,尝试转换
                 if isinstance(stage.config, dict):
-                    stage_dict[
+                    stage_dict["config"] = stage.config
                 else:
-                    stage_dict[
-
-            config[
-
+                    stage_dict["config"] = str(stage.config)
+
+            config["stages"].append(stage_dict)
+
         # Pipeline Config
         if self.pipeline_config.include_intermediate_results:
-            config[
-
-
+            config["pipeline_config"] = {
+                "include_intermediate_results": True,
+                "intermediate_results_destination": {},
             }
-
+
             inter_dest = self.pipeline_config.intermediate_results_destination
             if inter_dest:
-                inter_dest_type =
-
-
+                inter_dest_type = (
+                    type(inter_dest).__name__.replace("Destination", "").lower()
+                )
+                config["pipeline_config"]["intermediate_results_destination"][
+                    "type"
+                ] = inter_dest_type
+
                 if isinstance(inter_dest, LocalDestination):
-                    config[
+                    config["pipeline_config"]["intermediate_results_destination"][
+                        "output_dir"
+                    ] = str(inter_dest.output_dir)
                 elif isinstance(inter_dest, S3Destination):
-                    config[
-
-
-
-
+                    config["pipeline_config"][
+                        "intermediate_results_destination"
+                    ].update(
+                        {
+                            "endpoint": inter_dest.endpoint,
+                            "bucket": inter_dest.bucket,
+                            "prefix": inter_dest.prefix,
+                        }
+                    )
                     # access_key, secret_key, region 不在对象中保存,无法恢复
-                    config[
-
+                    config["pipeline_config"]["intermediate_results_destination"][
+                        "region"
+                    ] = "us-east-1"  # 默认值
+
         return config
 
     def _extract_error_message(self, response: requests.Response) -> Tuple[str, str]:
         """
         从响应中提取规范化的错误信息
-
+
         Returns:
             Tuple[str, str]: (error_msg, x_request_id)
         """
         # 首先尝试从响应头中提取 x-request-id(requests的headers大小写不敏感)
-        x_request_id = response.headers.get(
-        error_msg =
-
+        x_request_id = response.headers.get("x-request-id", "")
+        error_msg = ""
+
         # 获取Content-Type
-        content_type = response.headers.get(
-
+        content_type = response.headers.get("Content-Type", "").lower()
+
         # 尝试解析JSON响应
-        if
+        if "application/json" in content_type:
             try:
                 result = response.json()
                 # 如果响应头中没有x-request-id,尝试从响应体中获取
                 if not x_request_id:
-                    x_request_id = result.get(
-                error_msg = result.get(
+                    x_request_id = result.get("x_request_id", "")
+                error_msg = result.get(
+                    "message", result.get("msg", f"HTTP {response.status_code}")
+                )
                 return error_msg, x_request_id
             except:
                 pass
-
+
         # 处理HTML响应
-        if
+        if "text/html" in content_type or response.text.strip().startswith("<"):
             try:
                 # 从HTML中提取标题(通常包含状态码和状态文本)
-                title_match = re.search(
+                title_match = re.search(
+                    r"<title>(.*?)</title>", response.text, re.IGNORECASE
+                )
                 if title_match:
                     error_msg = title_match.group(1).strip()
                 else:
                     # 如果没有title,尝试提取h1标签
-                    h1_match = re.search(
+                    h1_match = re.search(
+                        r"<h1>(.*?)</h1>", response.text, re.IGNORECASE
+                    )
                     if h1_match:
                         error_msg = h1_match.group(1).strip()
                     else:
-                        error_msg = f
+                        error_msg = f"HTTP {response.status_code}"
             except:
-                error_msg = f
-
+                error_msg = f"HTTP {response.status_code}"
+
         # 处理纯文本响应
-        elif
-            error_msg =
-
+        elif "text/plain" in content_type:
+            error_msg = (
+                response.text[:200].strip()
+                if response.text
+                else f"HTTP {response.status_code}"
+            )
+
         # 其他情况
         else:
             if response.text:
                 # 尝试截取前200字符,但去除换行和多余空格
                 text = response.text[:200].strip()
                 # 如果包含多行,只取第一行
-                if
-                    text = text.split(
-                error_msg = text if text else f
+                if "\n" in text:
+                    text = text.split("\n")[0].strip()
+                error_msg = text if text else f"HTTP {response.status_code}"
             else:
-                error_msg = f
-
+                error_msg = f"HTTP {response.status_code}"
+
         return error_msg, x_request_id
 
-    def _call_pipeline_api(
+    def _call_pipeline_api(
+        self, file_bytes: bytes, filename: str, data_source: Dict[str, Any]
+    ) -> Optional[Dict[str, Any]]:
         url = f"{self.api_base_url}/pipeline"
         max_retries = 3
 
         for try_count in range(max_retries):
             try:
-                files = {
+                files = {"file": (filename or "file", file_bytes)}
                 form_data = {}
 
                 # 将 stages 转换为 API 格式
                 stages_data = [stage.to_dict() for stage in self.stages]
                 try:
-                    form_data[
-                    form_data[
-
+                    form_data["stages"] = json.dumps(stages_data, ensure_ascii=False)
+                    form_data["data_source"] = json.dumps(
+                        data_source, ensure_ascii=False
+                    )
+
                     # 如果启用了中间结果保存,在请求中添加参数
                     if self.pipeline_config:
-                        form_data[
+                        form_data["config"] = json.dumps(
+                            self.pipeline_config.to_dict(), ensure_ascii=False
+                        )
                 except Exception as e:
                     print(f" ✗ 入参处理失败,请检查配置: {e}")
                     logger.error(f"入参处理失败,请检查配置: {e}")
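For reference, the reworked get_config() above returns a plain dict mirroring the create_pipeline_from_config input format; secrets (access_key, secret_key, api_key, token, password) are never stored on the objects and so never appear in it. An illustrative sketch of the returned shape, with placeholder values that are not taken from the diff:

# Hypothetical example of a get_config() result for a local source/destination.
example_config = {
    "source": {"type": "local", "directory": "./docs", "pattern": "*.pdf", "recursive": True},
    "destination": {"type": "local", "output_dir": "./output"},
    "api_base_url": "http://localhost:8000/api/xparse",
    "api_headers": {},
    "stages": [
        {"type": "parse", "config": {}},    # populated from ParseConfig.to_dict()
        {"type": "extract", "config": {}},  # populated from ExtractConfig.to_dict()
    ],
}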
@@ -310,69 +392,136 @@ class Pipeline:
                     files=files,
                     data=form_data,
                     headers=self.api_headers,
-                    timeout=630
+                    timeout=630,
                 )
 
                 if response.status_code == 200:
                     result = response.json()
-                    x_request_id = result.get(
+                    x_request_id = result.get("x_request_id", "")
                     print(f" ✓ Pipeline 接口返回 x_request_id: {x_request_id}")
-                    if result.get(
-                        return result.get(
+                    if result.get("code") == 200 and "data" in result:
+                        return result.get("data")
                     # 如果 code 不是 200,打印错误信息
-                    error_msg = result.get(
-                    print(
-
+                    error_msg = result.get("message", result.get("msg", "未知错误"))
+                    print(
+                        f" ✗ Pipeline 接口返回错误: code={result.get('code')}, message={error_msg}, x_request_id={x_request_id}"
+                    )
+                    logger.error(
+                        f"Pipeline 接口返回错误: code={result.get('code')}, message={error_msg}, x_request_id={x_request_id}"
+                    )
                     return None
                 else:
                     # 使用规范化函数提取错误信息
                     error_msg, x_request_id = self._extract_error_message(response)
-
-                    print(
-
+
+                    print(
+                        f" ✗ API 错误 {response.status_code}: {error_msg}, x_request_id={x_request_id}, 重试 {try_count + 1}/{max_retries}"
+                    )
+                    logger.warning(
+                        f"API 错误 {response.status_code}: {error_msg}, x_request_id={x_request_id}, 重试 {try_count + 1}/{max_retries}"
+                    )
 
             except Exception as e:
                 # 如果是 requests 异常,尝试从响应中获取 x_request_id
-                x_request_id =
+                x_request_id = ""
                 error_msg = str(e)
                 try:
-                    if hasattr(e,
+                    if hasattr(e, "response") and e.response is not None:
                         try:
                             result = e.response.json()
-                            x_request_id = result.get(
-                            error_msg = result.get(
+                            x_request_id = result.get("x_request_id", "")
+                            error_msg = result.get(
+                                "message", result.get("msg", error_msg)
+                            )
                         except:
                             pass
                 except:
                     pass
-
-                print(
-
+
+                print(
+                    f" ✗ 请求异常: {error_msg}, x_request_id={x_request_id}, 重试 {try_count + 1}/{max_retries}"
+                )
+                logger.error(
+                    f"API 请求异常 pipeline: {error_msg}, x_request_id={x_request_id}"
+                )
 
             if try_count < max_retries - 1:
                 time.sleep(2)
 
         return None
 
-    def process_with_pipeline(
+    def process_with_pipeline(
+        self, file_bytes: bytes, filename: str, data_source: Dict[str, Any]
+    ) -> Optional[Tuple[Any, PipelineStats]]:
         print(f" → 调用 Pipeline 接口: {filename}")
         result = self._call_pipeline_api(file_bytes, filename, data_source)
 
-        if
-
-
+        if not result or "stats" not in result:
+            print(f" ✗ Pipeline 失败")
+            logger.error(f"Pipeline 失败: {filename}")
+            return None
+
+        # 处理 extract 类型的响应
+        if self.has_extract:
+            # extract 返回 extract_result 而不是 elements
+            if "extract_result" not in result:
+                print(f" ✗ Pipeline 失败: extract 响应中缺少 extract_result")
+                logger.error(f"Pipeline 失败: extract 响应中缺少 extract_result")
+                return None
+
+            data = result["extract_result"]  # 结构化数据
+            stats_data = result["stats"]
+
+            stats = PipelineStats(
+                original_elements=stats_data.get("original_elements", 0),
+                chunked_elements=0,  # extract 不涉及分块
+                embedded_elements=0,  # extract 不涉及向量化
+                stages=self.stages,
+                record_id=stats_data.get("record_id"),
+            )
+
+            # 如果启用了中间结果保存,处理中间结果
+            if (
+                self.pipeline_config.include_intermediate_results
+                and "intermediate_results" in result
+            ):
+                self._save_intermediate_results(
+                    result["intermediate_results"], filename, data_source
+                )
+
+            print(f" ✓ Extract 完成:")
+            print(f" - 原始元素: {stats.original_elements}")
+            print(f" - 提取结果类型: {type(data).__name__}")
+            logger.info(f"Extract 完成: {filename}")
+
+            return data, stats
+
+        else:
+            # 原有的 parse/chunk/embed 逻辑
+            if "elements" not in result:
+                print(f" ✗ Pipeline 失败: 响应中缺少 elements")
+                logger.error(f"Pipeline 失败: 响应中缺少 elements")
+                return None
+
+            elements = result["elements"]
+            stats_data = result["stats"]
 
             stats = PipelineStats(
-                original_elements=stats_data.get(
-                chunked_elements=stats_data.get(
-                embedded_elements=stats_data.get(
+                original_elements=stats_data.get("original_elements", 0),
+                chunked_elements=stats_data.get("chunked_elements", 0),
+                embedded_elements=stats_data.get("embedded_elements", 0),
                 stages=self.stages,  # 使用实际执行的 stages
-                record_id=stats_data.get(
+                record_id=stats_data.get("record_id"),  # 从 API 响应中获取 record_id
             )
 
             # 如果启用了中间结果保存,处理中间结果
-            if
-                self.
+            if (
+                self.pipeline_config.include_intermediate_results
+                and "intermediate_results" in result
+            ):
+                self._save_intermediate_results(
+                    result["intermediate_results"], filename, data_source
+                )
 
             print(f" ✓ Pipeline 完成:")
             print(f" - 原始元素: {stats.original_elements}")
@@ -381,14 +530,15 @@ class Pipeline:
             logger.info(f"Pipeline 完成: {filename}, {stats.embedded_elements} 个向量")
 
             return elements, stats
-        else:
-            print(f" ✗ Pipeline 失败")
-            logger.error(f"Pipeline 失败: {filename}")
-            return None
 
-    def _save_intermediate_results(
+    def _save_intermediate_results(
+        self,
+        intermediate_results: List[Dict[str, Any]],
+        filename: str,
+        data_source: Dict[str, Any],
+    ) -> None:
         """保存中间结果
-
+
         Args:
             intermediate_results: 中间结果数组,每个元素包含 stage 和 elements 字段
             filename: 文件名
@@ -397,22 +547,26 @@ class Pipeline:
         try:
             # intermediate_results 是一个数组,每个元素是 {stage: str, elements: List}
             for result_item in intermediate_results:
-                if
-                    logger.warning(
+                if "stage" not in result_item or "elements" not in result_item:
+                    logger.warning(
+                        f"中间结果项缺少 stage 或 elements 字段: {result_item}"
+                    )
                     continue
-
-                stage = result_item[
-                elements = result_item[
-
+
+                stage = result_item["stage"]
+                elements = result_item["elements"]
+
                 metadata = {
-
-
-
-
-
+                    "filename": filename,
+                    "stage": stage,
+                    "total_elements": len(elements),
+                    "processed_at": datetime.now().isoformat(),
+                    "data_source": data_source,
                 }
-
-                self.pipeline_config.intermediate_results_destination.write(
+
+                self.pipeline_config.intermediate_results_destination.write(
+                    elements, metadata
+                )
                 print(f" ✓ 保存 {stage.upper()} 中间结果: {len(elements)} 个元素")
                 logger.info(f"保存 {stage.upper()} 中间结果成功: {filename}")
 
@@ -429,17 +583,17 @@ class Pipeline:
         print(f" → 读取文件...")
         file_bytes, data_source = self.source.read_file(file_path)
         data_source = data_source or {}
-
+
         # 检查文件大小,超过 100MB 则报错
         MAX_FILE_SIZE = 100 * 1024 * 1024  # 100MB
         file_size = len(file_bytes)
         if file_size > MAX_FILE_SIZE:
             file_size_mb = file_size / (1024 * 1024)
             raise ValueError(f"文件大小过大: {file_size_mb:.2f}MB,超过100MB限制")
-
+
         # 转换为毫秒时间戳字符串
         timestamp_ms = int(datetime.now(timezone.utc).timestamp() * 1000)
-        data_source[
+        data_source["date_processed"] = str(timestamp_ms)
         print(f" ✓ 文件读取完成: {len(file_bytes)} bytes")
 
         result = self.process_with_pipeline(file_bytes, file_path, data_source)
@@ -450,13 +604,13 @@ class Pipeline:
 
         print(f" → 写入目的地...")
         metadata = {
-
-
+            "filename": file_path,
+            "processed_at": str(timestamp_ms),
         }
-
+
         # 如果 stats 中有 record_id,添加到 metadata 中
         if stats.record_id:
-            metadata[
+            metadata["record_id"] = stats.record_id
 
         success = self.destination.write(embedded_data, metadata)
 
@@ -523,168 +677,184 @@ class Pipeline:
         print("=" * 60)
 
         logger.info("=" * 60)
-        logger.info(
+        logger.info(
+            f"Pipeline 完成 - 总数:{total}, 成功:{success_count}, 失败:{fail_count}, 耗时:{elapsed:.2f}秒"
+        )
         logger.info("=" * 60)
 
 
 def create_pipeline_from_config(config: Dict[str, Any]) -> Pipeline:
-    source_config = config[
-    if source_config[
+    source_config = config["source"]
+    if source_config["type"] == "s3":
         source = S3Source(
-            endpoint=source_config[
-            access_key=source_config[
-            secret_key=source_config[
-            bucket=source_config[
-            prefix=source_config.get(
-            region=source_config.get(
-            pattern=source_config.get(
-            recursive=source_config.get(
+            endpoint=source_config["endpoint"],
+            access_key=source_config["access_key"],
+            secret_key=source_config["secret_key"],
+            bucket=source_config["bucket"],
+            prefix=source_config.get("prefix", ""),
+            region=source_config.get("region", "us-east-1"),
+            pattern=source_config.get("pattern", None),
+            recursive=source_config.get("recursive", False),
         )
-    elif source_config[
+    elif source_config["type"] == "local":
         source = LocalSource(
-            directory=source_config[
-            pattern=source_config.get(
-            recursive=source_config.get(
+            directory=source_config["directory"],
+            pattern=source_config.get("pattern", None),
+            recursive=source_config.get("recursive", False),
         )
-    elif source_config[
+    elif source_config["type"] == "ftp":
         source = FtpSource(
-            host=source_config[
-            port=source_config[
-            username=source_config[
-            password=source_config[
-            pattern=source_config.get(
-            recursive=source_config.get(
+            host=source_config["host"],
+            port=source_config["port"],
+            username=source_config["username"],
+            password=source_config["password"],
+            pattern=source_config.get("pattern", None),
+            recursive=source_config.get("recursive", False),
        )
-    elif source_config[
+    elif source_config["type"] == "smb":
         source = SmbSource(
-            host=source_config[
-            share_name=source_config[
-            username=source_config[
-            password=source_config[
-            domain=source_config.get(
-            port=source_config.get(
-            path=source_config.get(
-            pattern=source_config.get(
-            recursive=source_config.get(
+            host=source_config["host"],
+            share_name=source_config["share_name"],
+            username=source_config["username"],
+            password=source_config["password"],
+            domain=source_config.get("domain", ""),
+            port=source_config.get("port", 445),
+            path=source_config.get("path", ""),
+            pattern=source_config.get("pattern", None),
+            recursive=source_config.get("recursive", False),
         )
     else:
         raise ValueError(f"未知的 source 类型: {source_config['type']}")
 
-    dest_config = config[
-    if dest_config[
+    dest_config = config["destination"]
+    if dest_config["type"] in ["milvus", "zilliz"]:
         destination = MilvusDestination(
-            db_path=dest_config[
-            collection_name=dest_config[
-            dimension=dest_config[
-            api_key=dest_config.get(
-            token=dest_config.get(
+            db_path=dest_config["db_path"],
+            collection_name=dest_config["collection_name"],
+            dimension=dest_config["dimension"],
+            api_key=dest_config.get("api_key"),
+            token=dest_config.get("token"),
         )
-    elif dest_config[
+    elif dest_config["type"] == "qdrant":
         destination = QdrantDestination(
-            url=dest_config[
-            collection_name=dest_config[
-            dimension=dest_config[
-            api_key=dest_config.get(
-            prefer_grpc=dest_config.get(
-        )
-    elif dest_config['type'] == 'local':
-        destination = LocalDestination(
-            output_dir=dest_config['output_dir']
+            url=dest_config["url"],
+            collection_name=dest_config["collection_name"],
+            dimension=dest_config["dimension"],
+            api_key=dest_config.get("api_key"),
+            prefer_grpc=dest_config.get("prefer_grpc", False),
         )
-    elif dest_config[
+    elif dest_config["type"] == "local":
+        destination = LocalDestination(output_dir=dest_config["output_dir"])
+    elif dest_config["type"] == "s3":
         destination = S3Destination(
-            endpoint=dest_config[
-            access_key=dest_config[
-            secret_key=dest_config[
-            bucket=dest_config[
-            prefix=dest_config.get(
-            region=dest_config.get(
+            endpoint=dest_config["endpoint"],
+            access_key=dest_config["access_key"],
+            secret_key=dest_config["secret_key"],
+            bucket=dest_config["bucket"],
+            prefix=dest_config.get("prefix", ""),
+            region=dest_config.get("region", "us-east-1"),
         )
     else:
         raise ValueError(f"未知的 destination 类型: {dest_config['type']}")
 
     # 处理 stages 配置
-    if
+    if "stages" not in config or not config["stages"]:
         raise ValueError("配置中必须包含 'stages' 字段")
-
+
     stages = []
-    for stage_cfg in config[
-        stage_type = stage_cfg.get(
-        stage_config_dict = stage_cfg.get(
-
-        if stage_type ==
+    for stage_cfg in config["stages"]:
+        stage_type = stage_cfg.get("type")
+        stage_config_dict = stage_cfg.get("config", {})
+
+        if stage_type == "parse":
             parse_cfg_copy = dict(stage_config_dict)
-            provider = parse_cfg_copy.pop(
+            provider = parse_cfg_copy.pop("provider", "textin")
             stage_config = ParseConfig(provider=provider, **parse_cfg_copy)
-        elif stage_type ==
+        elif stage_type == "chunk":
             stage_config = ChunkConfig(
-                strategy=stage_config_dict.get(
-                include_orig_elements=stage_config_dict.get(
-
-
-
-
+                strategy=stage_config_dict.get("strategy", "basic"),
+                include_orig_elements=stage_config_dict.get(
+                    "include_orig_elements", False
+                ),
+                new_after_n_chars=stage_config_dict.get("new_after_n_chars", 512),
+                max_characters=stage_config_dict.get("max_characters", 1024),
+                overlap=stage_config_dict.get("overlap", 0),
+                overlap_all=stage_config_dict.get("overlap_all", False),
             )
-        elif stage_type ==
+        elif stage_type == "embed":
             stage_config = EmbedConfig(
-                provider=stage_config_dict.get(
-                model_name=stage_config_dict.get(
+                provider=stage_config_dict.get("provider", "qwen"),
+                model_name=stage_config_dict.get("model_name", "text-embedding-v3"),
+            )
+        elif stage_type == "extract":
+            schema = stage_config_dict.get("schema")
+            if not schema:
+                raise ValueError("extract stage 的 config 中必须包含 'schema' 字段")
+            stage_config = ExtractConfig(
+                schema=schema,
+                generate_citations=stage_config_dict.get("generate_citations", False),
+                stamp=stage_config_dict.get("stamp", False),
             )
         else:
             raise ValueError(f"未知的 stage 类型: {stage_type}")
-
+
         stages.append(Stage(type=stage_type, config=stage_config))
 
     # 创建 Pipeline 配置
     pipeline_config = None
-    if
-        pipeline_cfg = config[
-        include_intermediate_results = pipeline_cfg.get(
+    if "pipeline_config" in config and config["pipeline_config"]:
+        pipeline_cfg = config["pipeline_config"]
+        include_intermediate_results = pipeline_cfg.get(
+            "include_intermediate_results", False
+        )
         intermediate_results_destination = None
-
+
         if include_intermediate_results:
-            if
-                dest_cfg = pipeline_cfg[
-                dest_type = dest_cfg.get(
-
-                if dest_type ==
+            if "intermediate_results_destination" in pipeline_cfg:
+                dest_cfg = pipeline_cfg["intermediate_results_destination"]
+                dest_type = dest_cfg.get("type")
+
+                if dest_type == "local":
                     intermediate_results_destination = LocalDestination(
-                        output_dir=dest_cfg[
+                        output_dir=dest_cfg["output_dir"]
                     )
-                elif dest_type ==
+                elif dest_type == "s3":
                     intermediate_results_destination = S3Destination(
-                        endpoint=dest_cfg[
-                        access_key=dest_cfg[
-                        secret_key=dest_cfg[
-                        bucket=dest_cfg[
-                        prefix=dest_cfg.get(
-                        region=dest_cfg.get(
+                        endpoint=dest_cfg["endpoint"],
+                        access_key=dest_cfg["access_key"],
+                        secret_key=dest_cfg["secret_key"],
+                        bucket=dest_cfg["bucket"],
+                        prefix=dest_cfg.get("prefix", ""),
+                        region=dest_cfg.get("region", "us-east-1"),
                     )
                 else:
-                    raise ValueError(
+                    raise ValueError(
+                        f"不支持的 intermediate_results_destination 类型: '{dest_type}',支持的类型: 'local', 's3'"
+                    )
             else:
-                raise ValueError(
-
+                raise ValueError(
+                    "当 include_intermediate_results 为 True 时,必须设置 intermediate_results_destination"
+                )
+
         pipeline_config = PipelineConfig(
             include_intermediate_results=include_intermediate_results,
-            intermediate_results_destination=intermediate_results_destination
+            intermediate_results_destination=intermediate_results_destination,
        )
 
     # 创建 Pipeline
     pipeline = Pipeline(
         source=source,
         destination=destination,
-        api_base_url=config.get(
-        api_headers=config.get(
+        api_base_url=config.get("api_base_url", "http://localhost:8000/api/xparse"),
+        api_headers=config.get("api_headers", {}),
         stages=stages,
-        pipeline_config=pipeline_config
+        pipeline_config=pipeline_config,
     )
 
     return pipeline
 
 
 __all__ = [
-
-
+    "Pipeline",
+    "create_pipeline_from_config",
 ]
-