xparse-client 0.2.18__py3-none-any.whl → 0.2.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xparse_client/__init__.py +2 -1
- xparse_client/pipeline/config.py +66 -32
- xparse_client/pipeline/pipeline.py +510 -285
- {xparse_client-0.2.18.dist-info → xparse_client-0.2.20.dist-info}/METADATA +1 -1
- xparse_client-0.2.20.dist-info/RECORD +11 -0
- {xparse_client-0.2.18.dist-info → xparse_client-0.2.20.dist-info}/WHEEL +1 -1
- xparse_client-0.2.18.dist-info/RECORD +0 -11
- {xparse_client-0.2.18.dist-info → xparse_client-0.2.20.dist-info}/licenses/LICENSE +0 -0
- {xparse_client-0.2.18.dist-info → xparse_client-0.2.20.dist-info}/top_level.txt +0 -0
|
@@ -3,17 +3,31 @@
|
|
|
3
3
|
|
|
4
4
|
import json
|
|
5
5
|
import logging
|
|
6
|
+
import re
|
|
6
7
|
import time
|
|
7
8
|
from datetime import datetime, timezone
|
|
8
9
|
from pathlib import Path
|
|
9
|
-
from typing import Dict,
|
|
10
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
10
11
|
|
|
11
12
|
import requests
|
|
12
13
|
|
|
13
|
-
from .config import
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
14
|
+
from .config import (
|
|
15
|
+
ChunkConfig,
|
|
16
|
+
EmbedConfig,
|
|
17
|
+
ExtractConfig,
|
|
18
|
+
ParseConfig,
|
|
19
|
+
PipelineConfig,
|
|
20
|
+
PipelineStats,
|
|
21
|
+
Stage,
|
|
22
|
+
)
|
|
23
|
+
from .destinations import (
|
|
24
|
+
Destination,
|
|
25
|
+
LocalDestination,
|
|
26
|
+
MilvusDestination,
|
|
27
|
+
QdrantDestination,
|
|
28
|
+
S3Destination,
|
|
29
|
+
)
|
|
30
|
+
from .sources import FtpSource, LocalSource, S3Source, SmbSource, Source
|
|
17
31
|
|
|
18
32
|
logger = logging.getLogger(__name__)
|
|
19
33
|
|
|
@@ -25,15 +39,15 @@ class Pipeline:
|
|
|
25
39
|
self,
|
|
26
40
|
source: Source,
|
|
27
41
|
destination: Destination,
|
|
28
|
-
api_base_url: str =
|
|
42
|
+
api_base_url: str = "http://localhost:8000/api/xparse",
|
|
29
43
|
api_headers: Optional[Dict[str, str]] = None,
|
|
30
44
|
stages: Optional[List[Stage]] = None,
|
|
31
45
|
pipeline_config: Optional[PipelineConfig] = None,
|
|
32
|
-
intermediate_results_destination: Optional[Destination] = None
|
|
46
|
+
intermediate_results_destination: Optional[Destination] = None,
|
|
33
47
|
):
|
|
34
48
|
self.source = source
|
|
35
49
|
self.destination = destination
|
|
36
|
-
self.api_base_url = api_base_url.rstrip(
|
|
50
|
+
self.api_base_url = api_base_url.rstrip("/")
|
|
37
51
|
self.api_headers = api_headers or {}
|
|
38
52
|
self.pipeline_config = pipeline_config or PipelineConfig()
|
|
39
53
|
|
|
@@ -41,34 +55,62 @@ class Pipeline:
|
|
|
41
55
|
# 如果直接传入了 intermediate_results_destination,优先使用它并自动启用中间结果保存
|
|
42
56
|
if intermediate_results_destination is not None:
|
|
43
57
|
self.pipeline_config.include_intermediate_results = True
|
|
44
|
-
self.pipeline_config.intermediate_results_destination =
|
|
58
|
+
self.pipeline_config.intermediate_results_destination = (
|
|
59
|
+
intermediate_results_destination
|
|
60
|
+
)
|
|
45
61
|
# 如果 pipeline_config 中已设置,使用 pipeline_config 中的值
|
|
46
62
|
elif self.pipeline_config.include_intermediate_results:
|
|
47
63
|
if not self.pipeline_config.intermediate_results_destination:
|
|
48
|
-
raise ValueError(
|
|
64
|
+
raise ValueError(
|
|
65
|
+
"当 include_intermediate_results 为 True 时,必须设置 intermediate_results_destination"
|
|
66
|
+
)
|
|
49
67
|
|
|
50
68
|
# 处理 stages 配置
|
|
51
69
|
if stages is None:
|
|
52
70
|
raise ValueError("必须提供 stages 参数")
|
|
53
|
-
|
|
71
|
+
|
|
54
72
|
self.stages = stages
|
|
55
73
|
|
|
56
74
|
# 验证 stages
|
|
57
|
-
if not self.stages or self.stages[0].type !=
|
|
75
|
+
if not self.stages or self.stages[0].type != "parse":
|
|
58
76
|
raise ValueError("stages 必须包含且第一个必须是 'parse' 类型")
|
|
59
|
-
|
|
77
|
+
|
|
78
|
+
# 检查是否包含 extract 阶段
|
|
79
|
+
self.has_extract = any(stage.type == "extract" for stage in self.stages)
|
|
80
|
+
|
|
81
|
+
# 验证 extract 阶段的约束
|
|
82
|
+
if self.has_extract:
|
|
83
|
+
# extract 只能跟在 parse 后面,且必须是最后一个
|
|
84
|
+
if len(self.stages) != 2 or self.stages[1].type != "extract":
|
|
85
|
+
raise ValueError(
|
|
86
|
+
"extract 阶段只能跟在 parse 后面,且必须是最后一个阶段(即只能是 [parse, extract] 组合)"
|
|
87
|
+
)
|
|
88
|
+
|
|
89
|
+
# destination 必须是文件存储类型
|
|
90
|
+
if not isinstance(destination, (LocalDestination, S3Destination)):
|
|
91
|
+
raise ValueError(
|
|
92
|
+
"当使用 extract 阶段时,destination 必须是文件存储类型(LocalDestination 或 S3Destination),不能是向量数据库"
|
|
93
|
+
)
|
|
94
|
+
|
|
60
95
|
# 验证 embed config(如果存在)
|
|
61
96
|
for stage in self.stages:
|
|
62
|
-
if stage.type ==
|
|
97
|
+
if stage.type == "embed" and isinstance(stage.config, EmbedConfig):
|
|
63
98
|
stage.config.validate()
|
|
64
|
-
|
|
99
|
+
|
|
65
100
|
# 验证 intermediate_results_destination
|
|
66
101
|
if self.pipeline_config.include_intermediate_results:
|
|
67
102
|
# 验证是否为支持的 Destination 类型
|
|
68
103
|
from .destinations import Destination
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
104
|
+
|
|
105
|
+
if not isinstance(
|
|
106
|
+
self.pipeline_config.intermediate_results_destination, Destination
|
|
107
|
+
):
|
|
108
|
+
raise ValueError(
|
|
109
|
+
f"intermediate_results_destination 必须是 Destination 类型"
|
|
110
|
+
)
|
|
111
|
+
self.intermediate_results_destination = (
|
|
112
|
+
self.pipeline_config.intermediate_results_destination
|
|
113
|
+
)
|
|
72
114
|
|
|
73
115
|
print("=" * 60)
|
|
74
116
|
print("Pipeline 初始化完成")
|
|
@@ -82,162 +124,264 @@ class Pipeline:
|
|
|
82
124
|
def get_config(self) -> Dict[str, Any]:
|
|
83
125
|
"""获取 Pipeline 的完整配置信息,返回字典格式(与 create_pipeline_from_config 的入参格式一致)"""
|
|
84
126
|
config = {}
|
|
85
|
-
|
|
127
|
+
|
|
86
128
|
# Source 配置
|
|
87
|
-
source_type = type(self.source).__name__.replace(
|
|
88
|
-
config[
|
|
89
|
-
|
|
129
|
+
source_type = type(self.source).__name__.replace("Source", "").lower()
|
|
130
|
+
config["source"] = {"type": source_type}
|
|
131
|
+
|
|
90
132
|
if isinstance(self.source, S3Source):
|
|
91
|
-
config[
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
133
|
+
config["source"].update(
|
|
134
|
+
{
|
|
135
|
+
"endpoint": self.source.endpoint,
|
|
136
|
+
"bucket": self.source.bucket,
|
|
137
|
+
"prefix": self.source.prefix,
|
|
138
|
+
"pattern": self.source.pattern,
|
|
139
|
+
"recursive": self.source.recursive,
|
|
140
|
+
}
|
|
141
|
+
)
|
|
98
142
|
# access_key 和 secret_key 不在对象中保存,无法恢复
|
|
99
143
|
# region 也不在对象中保存,使用默认值
|
|
100
|
-
config[
|
|
144
|
+
config["source"]["region"] = "us-east-1" # 默认值
|
|
101
145
|
elif isinstance(self.source, LocalSource):
|
|
102
|
-
config[
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
146
|
+
config["source"].update(
|
|
147
|
+
{
|
|
148
|
+
"directory": str(self.source.directory),
|
|
149
|
+
"pattern": self.source.pattern,
|
|
150
|
+
"recursive": self.source.recursive,
|
|
151
|
+
}
|
|
152
|
+
)
|
|
107
153
|
elif isinstance(self.source, FtpSource):
|
|
108
|
-
config[
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
154
|
+
config["source"].update(
|
|
155
|
+
{
|
|
156
|
+
"host": self.source.host,
|
|
157
|
+
"port": self.source.port,
|
|
158
|
+
"username": self.source.username,
|
|
159
|
+
"pattern": self.source.pattern,
|
|
160
|
+
"recursive": self.source.recursive,
|
|
161
|
+
}
|
|
162
|
+
)
|
|
115
163
|
# password 不在对象中保存,无法恢复
|
|
116
164
|
elif isinstance(self.source, SmbSource):
|
|
117
|
-
config[
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
165
|
+
config["source"].update(
|
|
166
|
+
{
|
|
167
|
+
"host": self.source.host,
|
|
168
|
+
"share_name": self.source.share_name,
|
|
169
|
+
"username": self.source.username,
|
|
170
|
+
"domain": self.source.domain,
|
|
171
|
+
"port": self.source.port,
|
|
172
|
+
"path": self.source.path,
|
|
173
|
+
"pattern": self.source.pattern,
|
|
174
|
+
"recursive": self.source.recursive,
|
|
175
|
+
}
|
|
176
|
+
)
|
|
127
177
|
# password 不在对象中保存,无法恢复
|
|
128
|
-
|
|
178
|
+
|
|
129
179
|
# Destination 配置
|
|
130
|
-
dest_type = type(self.destination).__name__.replace(
|
|
180
|
+
dest_type = type(self.destination).__name__.replace("Destination", "").lower()
|
|
131
181
|
# MilvusDestination 和 Zilliz 都使用 'milvus' 或 'zilliz' 类型
|
|
132
|
-
if dest_type ==
|
|
182
|
+
if dest_type == "milvus":
|
|
133
183
|
# 判断是本地 Milvus 还是 Zilliz(通过 db_path 判断)
|
|
134
|
-
if self.destination.db_path.startswith(
|
|
135
|
-
dest_type =
|
|
184
|
+
if self.destination.db_path.startswith("http"):
|
|
185
|
+
dest_type = "zilliz"
|
|
136
186
|
else:
|
|
137
|
-
dest_type =
|
|
138
|
-
|
|
139
|
-
config[
|
|
140
|
-
|
|
187
|
+
dest_type = "milvus"
|
|
188
|
+
|
|
189
|
+
config["destination"] = {"type": dest_type}
|
|
190
|
+
|
|
141
191
|
if isinstance(self.destination, MilvusDestination):
|
|
142
|
-
config[
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
192
|
+
config["destination"].update(
|
|
193
|
+
{
|
|
194
|
+
"db_path": self.destination.db_path,
|
|
195
|
+
"collection_name": self.destination.collection_name,
|
|
196
|
+
"dimension": self.destination.dimension,
|
|
197
|
+
}
|
|
198
|
+
)
|
|
147
199
|
# api_key 和 token 不在对象中保存,无法恢复
|
|
148
200
|
elif isinstance(self.destination, QdrantDestination):
|
|
149
|
-
config[
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
201
|
+
config["destination"].update(
|
|
202
|
+
{
|
|
203
|
+
"url": self.destination.url,
|
|
204
|
+
"collection_name": self.destination.collection_name,
|
|
205
|
+
"dimension": self.destination.dimension,
|
|
206
|
+
"prefer_grpc": getattr(self.destination, "prefer_grpc", False),
|
|
207
|
+
}
|
|
208
|
+
)
|
|
155
209
|
# api_key 不在对象中保存,无法恢复
|
|
156
210
|
elif isinstance(self.destination, LocalDestination):
|
|
157
|
-
config[
|
|
158
|
-
|
|
159
|
-
|
|
211
|
+
config["destination"].update(
|
|
212
|
+
{"output_dir": str(self.destination.output_dir)}
|
|
213
|
+
)
|
|
160
214
|
elif isinstance(self.destination, S3Destination):
|
|
161
|
-
config[
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
215
|
+
config["destination"].update(
|
|
216
|
+
{
|
|
217
|
+
"endpoint": self.destination.endpoint,
|
|
218
|
+
"bucket": self.destination.bucket,
|
|
219
|
+
"prefix": self.destination.prefix,
|
|
220
|
+
}
|
|
221
|
+
)
|
|
166
222
|
# access_key, secret_key, region 不在对象中保存,无法恢复
|
|
167
|
-
config[
|
|
168
|
-
|
|
223
|
+
config["destination"]["region"] = "us-east-1" # 默认值
|
|
224
|
+
|
|
169
225
|
# API 配置
|
|
170
|
-
config[
|
|
171
|
-
config[
|
|
226
|
+
config["api_base_url"] = self.api_base_url
|
|
227
|
+
config["api_headers"] = {}
|
|
172
228
|
for key, value in self.api_headers.items():
|
|
173
|
-
config[
|
|
174
|
-
|
|
229
|
+
config["api_headers"][key] = value
|
|
230
|
+
|
|
175
231
|
# Stages 配置
|
|
176
|
-
config[
|
|
232
|
+
config["stages"] = []
|
|
177
233
|
for stage in self.stages:
|
|
178
|
-
stage_dict = {
|
|
179
|
-
|
|
180
|
-
'config': {}
|
|
181
|
-
}
|
|
182
|
-
|
|
234
|
+
stage_dict = {"type": stage.type, "config": {}}
|
|
235
|
+
|
|
183
236
|
if isinstance(stage.config, ParseConfig):
|
|
184
|
-
stage_dict[
|
|
237
|
+
stage_dict["config"] = stage.config.to_dict()
|
|
185
238
|
elif isinstance(stage.config, ChunkConfig):
|
|
186
|
-
stage_dict[
|
|
239
|
+
stage_dict["config"] = stage.config.to_dict()
|
|
187
240
|
elif isinstance(stage.config, EmbedConfig):
|
|
188
|
-
stage_dict[
|
|
241
|
+
stage_dict["config"] = stage.config.to_dict()
|
|
242
|
+
elif isinstance(stage.config, ExtractConfig):
|
|
243
|
+
stage_dict["config"] = stage.config.to_dict()
|
|
189
244
|
else:
|
|
190
245
|
# 如果 config 是字典或其他类型,尝试转换
|
|
191
246
|
if isinstance(stage.config, dict):
|
|
192
|
-
stage_dict[
|
|
247
|
+
stage_dict["config"] = stage.config
|
|
193
248
|
else:
|
|
194
|
-
stage_dict[
|
|
195
|
-
|
|
196
|
-
config[
|
|
197
|
-
|
|
249
|
+
stage_dict["config"] = str(stage.config)
|
|
250
|
+
|
|
251
|
+
config["stages"].append(stage_dict)
|
|
252
|
+
|
|
198
253
|
# Pipeline Config
|
|
199
254
|
if self.pipeline_config.include_intermediate_results:
|
|
200
|
-
config[
|
|
201
|
-
|
|
202
|
-
|
|
255
|
+
config["pipeline_config"] = {
|
|
256
|
+
"include_intermediate_results": True,
|
|
257
|
+
"intermediate_results_destination": {},
|
|
203
258
|
}
|
|
204
|
-
|
|
259
|
+
|
|
205
260
|
inter_dest = self.pipeline_config.intermediate_results_destination
|
|
206
261
|
if inter_dest:
|
|
207
|
-
inter_dest_type =
|
|
208
|
-
|
|
209
|
-
|
|
262
|
+
inter_dest_type = (
|
|
263
|
+
type(inter_dest).__name__.replace("Destination", "").lower()
|
|
264
|
+
)
|
|
265
|
+
config["pipeline_config"]["intermediate_results_destination"][
|
|
266
|
+
"type"
|
|
267
|
+
] = inter_dest_type
|
|
268
|
+
|
|
210
269
|
if isinstance(inter_dest, LocalDestination):
|
|
211
|
-
config[
|
|
270
|
+
config["pipeline_config"]["intermediate_results_destination"][
|
|
271
|
+
"output_dir"
|
|
272
|
+
] = str(inter_dest.output_dir)
|
|
212
273
|
elif isinstance(inter_dest, S3Destination):
|
|
213
|
-
config[
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
274
|
+
config["pipeline_config"][
|
|
275
|
+
"intermediate_results_destination"
|
|
276
|
+
].update(
|
|
277
|
+
{
|
|
278
|
+
"endpoint": inter_dest.endpoint,
|
|
279
|
+
"bucket": inter_dest.bucket,
|
|
280
|
+
"prefix": inter_dest.prefix,
|
|
281
|
+
}
|
|
282
|
+
)
|
|
218
283
|
# access_key, secret_key, region 不在对象中保存,无法恢复
|
|
219
|
-
config[
|
|
220
|
-
|
|
284
|
+
config["pipeline_config"]["intermediate_results_destination"][
|
|
285
|
+
"region"
|
|
286
|
+
] = "us-east-1" # 默认值
|
|
287
|
+
|
|
221
288
|
return config
|
|
222
289
|
|
|
223
|
-
def
|
|
290
|
+
def _extract_error_message(self, response: requests.Response) -> Tuple[str, str]:
|
|
291
|
+
"""
|
|
292
|
+
从响应中提取规范化的错误信息
|
|
293
|
+
|
|
294
|
+
Returns:
|
|
295
|
+
Tuple[str, str]: (error_msg, x_request_id)
|
|
296
|
+
"""
|
|
297
|
+
# 首先尝试从响应头中提取 x-request-id(requests的headers大小写不敏感)
|
|
298
|
+
x_request_id = response.headers.get("x-request-id", "")
|
|
299
|
+
error_msg = ""
|
|
300
|
+
|
|
301
|
+
# 获取Content-Type
|
|
302
|
+
content_type = response.headers.get("Content-Type", "").lower()
|
|
303
|
+
|
|
304
|
+
# 尝试解析JSON响应
|
|
305
|
+
if "application/json" in content_type:
|
|
306
|
+
try:
|
|
307
|
+
result = response.json()
|
|
308
|
+
# 如果响应头中没有x-request-id,尝试从响应体中获取
|
|
309
|
+
if not x_request_id:
|
|
310
|
+
x_request_id = result.get("x_request_id", "")
|
|
311
|
+
error_msg = result.get(
|
|
312
|
+
"message", result.get("msg", f"HTTP {response.status_code}")
|
|
313
|
+
)
|
|
314
|
+
return error_msg, x_request_id
|
|
315
|
+
except:
|
|
316
|
+
pass
|
|
317
|
+
|
|
318
|
+
# 处理HTML响应
|
|
319
|
+
if "text/html" in content_type or response.text.strip().startswith("<"):
|
|
320
|
+
try:
|
|
321
|
+
# 从HTML中提取标题(通常包含状态码和状态文本)
|
|
322
|
+
title_match = re.search(
|
|
323
|
+
r"<title>(.*?)</title>", response.text, re.IGNORECASE
|
|
324
|
+
)
|
|
325
|
+
if title_match:
|
|
326
|
+
error_msg = title_match.group(1).strip()
|
|
327
|
+
else:
|
|
328
|
+
# 如果没有title,尝试提取h1标签
|
|
329
|
+
h1_match = re.search(
|
|
330
|
+
r"<h1>(.*?)</h1>", response.text, re.IGNORECASE
|
|
331
|
+
)
|
|
332
|
+
if h1_match:
|
|
333
|
+
error_msg = h1_match.group(1).strip()
|
|
334
|
+
else:
|
|
335
|
+
error_msg = f"HTTP {response.status_code}"
|
|
336
|
+
except:
|
|
337
|
+
error_msg = f"HTTP {response.status_code}"
|
|
338
|
+
|
|
339
|
+
# 处理纯文本响应
|
|
340
|
+
elif "text/plain" in content_type:
|
|
341
|
+
error_msg = (
|
|
342
|
+
response.text[:200].strip()
|
|
343
|
+
if response.text
|
|
344
|
+
else f"HTTP {response.status_code}"
|
|
345
|
+
)
|
|
346
|
+
|
|
347
|
+
# 其他情况
|
|
348
|
+
else:
|
|
349
|
+
if response.text:
|
|
350
|
+
# 尝试截取前200字符,但去除换行和多余空格
|
|
351
|
+
text = response.text[:200].strip()
|
|
352
|
+
# 如果包含多行,只取第一行
|
|
353
|
+
if "\n" in text:
|
|
354
|
+
text = text.split("\n")[0].strip()
|
|
355
|
+
error_msg = text if text else f"HTTP {response.status_code}"
|
|
356
|
+
else:
|
|
357
|
+
error_msg = f"HTTP {response.status_code}"
|
|
358
|
+
|
|
359
|
+
return error_msg, x_request_id
|
|
360
|
+
|
|
361
|
+
def _call_pipeline_api(
|
|
362
|
+
self, file_bytes: bytes, filename: str, data_source: Dict[str, Any]
|
|
363
|
+
) -> Optional[Dict[str, Any]]:
|
|
224
364
|
url = f"{self.api_base_url}/pipeline"
|
|
225
365
|
max_retries = 3
|
|
226
366
|
|
|
227
367
|
for try_count in range(max_retries):
|
|
228
368
|
try:
|
|
229
|
-
files = {
|
|
369
|
+
files = {"file": (filename or "file", file_bytes)}
|
|
230
370
|
form_data = {}
|
|
231
371
|
|
|
232
372
|
# 将 stages 转换为 API 格式
|
|
233
373
|
stages_data = [stage.to_dict() for stage in self.stages]
|
|
234
374
|
try:
|
|
235
|
-
form_data[
|
|
236
|
-
form_data[
|
|
237
|
-
|
|
375
|
+
form_data["stages"] = json.dumps(stages_data, ensure_ascii=False)
|
|
376
|
+
form_data["data_source"] = json.dumps(
|
|
377
|
+
data_source, ensure_ascii=False
|
|
378
|
+
)
|
|
379
|
+
|
|
238
380
|
# 如果启用了中间结果保存,在请求中添加参数
|
|
239
381
|
if self.pipeline_config:
|
|
240
|
-
form_data[
|
|
382
|
+
form_data["config"] = json.dumps(
|
|
383
|
+
self.pipeline_config.to_dict(), ensure_ascii=False
|
|
384
|
+
)
|
|
241
385
|
except Exception as e:
|
|
242
386
|
print(f" ✗ 入参处理失败,请检查配置: {e}")
|
|
243
387
|
logger.error(f"入参处理失败,请检查配置: {e}")
|
|
@@ -248,76 +392,136 @@ class Pipeline:
|
|
|
248
392
|
files=files,
|
|
249
393
|
data=form_data,
|
|
250
394
|
headers=self.api_headers,
|
|
251
|
-
timeout=630
|
|
395
|
+
timeout=630,
|
|
252
396
|
)
|
|
253
397
|
|
|
254
398
|
if response.status_code == 200:
|
|
255
399
|
result = response.json()
|
|
256
|
-
x_request_id = result.get(
|
|
400
|
+
x_request_id = result.get("x_request_id", "")
|
|
257
401
|
print(f" ✓ Pipeline 接口返回 x_request_id: {x_request_id}")
|
|
258
|
-
if result.get(
|
|
259
|
-
return result.get(
|
|
402
|
+
if result.get("code") == 200 and "data" in result:
|
|
403
|
+
return result.get("data")
|
|
260
404
|
# 如果 code 不是 200,打印错误信息
|
|
261
|
-
error_msg = result.get(
|
|
262
|
-
print(
|
|
263
|
-
|
|
405
|
+
error_msg = result.get("message", result.get("msg", "未知错误"))
|
|
406
|
+
print(
|
|
407
|
+
f" ✗ Pipeline 接口返回错误: code={result.get('code')}, message={error_msg}, x_request_id={x_request_id}"
|
|
408
|
+
)
|
|
409
|
+
logger.error(
|
|
410
|
+
f"Pipeline 接口返回错误: code={result.get('code')}, message={error_msg}, x_request_id={x_request_id}"
|
|
411
|
+
)
|
|
264
412
|
return None
|
|
265
413
|
else:
|
|
266
|
-
#
|
|
267
|
-
x_request_id =
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
print(f" ✗ API 错误 {response.status_code}: {error_msg}, x_request_id={x_request_id}, 重试 {try_count + 1}/{max_retries}")
|
|
277
|
-
logger.warning(f"API 错误 {response.status_code}: {error_msg}, x_request_id={x_request_id}, 重试 {try_count + 1}/{max_retries}")
|
|
414
|
+
# 使用规范化函数提取错误信息
|
|
415
|
+
error_msg, x_request_id = self._extract_error_message(response)
|
|
416
|
+
|
|
417
|
+
print(
|
|
418
|
+
f" ✗ API 错误 {response.status_code}: {error_msg}, x_request_id={x_request_id}, 重试 {try_count + 1}/{max_retries}"
|
|
419
|
+
)
|
|
420
|
+
logger.warning(
|
|
421
|
+
f"API 错误 {response.status_code}: {error_msg}, x_request_id={x_request_id}, 重试 {try_count + 1}/{max_retries}"
|
|
422
|
+
)
|
|
278
423
|
|
|
279
424
|
except Exception as e:
|
|
280
425
|
# 如果是 requests 异常,尝试从响应中获取 x_request_id
|
|
281
|
-
x_request_id =
|
|
426
|
+
x_request_id = ""
|
|
282
427
|
error_msg = str(e)
|
|
283
428
|
try:
|
|
284
|
-
if hasattr(e,
|
|
429
|
+
if hasattr(e, "response") and e.response is not None:
|
|
285
430
|
try:
|
|
286
431
|
result = e.response.json()
|
|
287
|
-
x_request_id = result.get(
|
|
288
|
-
error_msg = result.get(
|
|
432
|
+
x_request_id = result.get("x_request_id", "")
|
|
433
|
+
error_msg = result.get(
|
|
434
|
+
"message", result.get("msg", error_msg)
|
|
435
|
+
)
|
|
289
436
|
except:
|
|
290
437
|
pass
|
|
291
438
|
except:
|
|
292
439
|
pass
|
|
293
|
-
|
|
294
|
-
print(
|
|
295
|
-
|
|
440
|
+
|
|
441
|
+
print(
|
|
442
|
+
f" ✗ 请求异常: {error_msg}, x_request_id={x_request_id}, 重试 {try_count + 1}/{max_retries}"
|
|
443
|
+
)
|
|
444
|
+
logger.error(
|
|
445
|
+
f"API 请求异常 pipeline: {error_msg}, x_request_id={x_request_id}"
|
|
446
|
+
)
|
|
296
447
|
|
|
297
448
|
if try_count < max_retries - 1:
|
|
298
449
|
time.sleep(2)
|
|
299
450
|
|
|
300
451
|
return None
|
|
301
452
|
|
|
302
|
-
def process_with_pipeline(
|
|
453
|
+
def process_with_pipeline(
|
|
454
|
+
self, file_bytes: bytes, filename: str, data_source: Dict[str, Any]
|
|
455
|
+
) -> Optional[Tuple[Any, PipelineStats]]:
|
|
303
456
|
print(f" → 调用 Pipeline 接口: {filename}")
|
|
304
457
|
result = self._call_pipeline_api(file_bytes, filename, data_source)
|
|
305
458
|
|
|
306
|
-
if
|
|
307
|
-
|
|
308
|
-
|
|
459
|
+
if not result or "stats" not in result:
|
|
460
|
+
print(f" ✗ Pipeline 失败")
|
|
461
|
+
logger.error(f"Pipeline 失败: {filename}")
|
|
462
|
+
return None
|
|
463
|
+
|
|
464
|
+
# 处理 extract 类型的响应
|
|
465
|
+
if self.has_extract:
|
|
466
|
+
# extract 返回 extract_result 而不是 elements
|
|
467
|
+
if "extract_result" not in result:
|
|
468
|
+
print(f" ✗ Pipeline 失败: extract 响应中缺少 extract_result")
|
|
469
|
+
logger.error(f"Pipeline 失败: extract 响应中缺少 extract_result")
|
|
470
|
+
return None
|
|
471
|
+
|
|
472
|
+
data = result["extract_result"] # 结构化数据
|
|
473
|
+
stats_data = result["stats"]
|
|
309
474
|
|
|
310
475
|
stats = PipelineStats(
|
|
311
|
-
original_elements=stats_data.get(
|
|
312
|
-
chunked_elements=
|
|
313
|
-
embedded_elements=
|
|
476
|
+
original_elements=stats_data.get("original_elements", 0),
|
|
477
|
+
chunked_elements=0, # extract 不涉及分块
|
|
478
|
+
embedded_elements=0, # extract 不涉及向量化
|
|
479
|
+
stages=self.stages,
|
|
480
|
+
record_id=stats_data.get("record_id"),
|
|
481
|
+
)
|
|
482
|
+
|
|
483
|
+
# 如果启用了中间结果保存,处理中间结果
|
|
484
|
+
if (
|
|
485
|
+
self.pipeline_config.include_intermediate_results
|
|
486
|
+
and "intermediate_results" in result
|
|
487
|
+
):
|
|
488
|
+
self._save_intermediate_results(
|
|
489
|
+
result["intermediate_results"], filename, data_source
|
|
490
|
+
)
|
|
491
|
+
|
|
492
|
+
print(f" ✓ Extract 完成:")
|
|
493
|
+
print(f" - 原始元素: {stats.original_elements}")
|
|
494
|
+
print(f" - 提取结果类型: {type(data).__name__}")
|
|
495
|
+
logger.info(f"Extract 完成: {filename}")
|
|
496
|
+
|
|
497
|
+
return data, stats
|
|
498
|
+
|
|
499
|
+
else:
|
|
500
|
+
# 原有的 parse/chunk/embed 逻辑
|
|
501
|
+
if "elements" not in result:
|
|
502
|
+
print(f" ✗ Pipeline 失败: 响应中缺少 elements")
|
|
503
|
+
logger.error(f"Pipeline 失败: 响应中缺少 elements")
|
|
504
|
+
return None
|
|
505
|
+
|
|
506
|
+
elements = result["elements"]
|
|
507
|
+
stats_data = result["stats"]
|
|
508
|
+
|
|
509
|
+
stats = PipelineStats(
|
|
510
|
+
original_elements=stats_data.get("original_elements", 0),
|
|
511
|
+
chunked_elements=stats_data.get("chunked_elements", 0),
|
|
512
|
+
embedded_elements=stats_data.get("embedded_elements", 0),
|
|
314
513
|
stages=self.stages, # 使用实际执行的 stages
|
|
315
|
-
record_id=stats_data.get(
|
|
514
|
+
record_id=stats_data.get("record_id"), # 从 API 响应中获取 record_id
|
|
316
515
|
)
|
|
317
516
|
|
|
318
517
|
# 如果启用了中间结果保存,处理中间结果
|
|
319
|
-
if
|
|
320
|
-
self.
|
|
518
|
+
if (
|
|
519
|
+
self.pipeline_config.include_intermediate_results
|
|
520
|
+
and "intermediate_results" in result
|
|
521
|
+
):
|
|
522
|
+
self._save_intermediate_results(
|
|
523
|
+
result["intermediate_results"], filename, data_source
|
|
524
|
+
)
|
|
321
525
|
|
|
322
526
|
print(f" ✓ Pipeline 完成:")
|
|
323
527
|
print(f" - 原始元素: {stats.original_elements}")
|
|
@@ -326,14 +530,15 @@ class Pipeline:
|
|
|
326
530
|
logger.info(f"Pipeline 完成: {filename}, {stats.embedded_elements} 个向量")
|
|
327
531
|
|
|
328
532
|
return elements, stats
|
|
329
|
-
else:
|
|
330
|
-
print(f" ✗ Pipeline 失败")
|
|
331
|
-
logger.error(f"Pipeline 失败: {filename}")
|
|
332
|
-
return None
|
|
333
533
|
|
|
334
|
-
def _save_intermediate_results(
|
|
534
|
+
def _save_intermediate_results(
|
|
535
|
+
self,
|
|
536
|
+
intermediate_results: List[Dict[str, Any]],
|
|
537
|
+
filename: str,
|
|
538
|
+
data_source: Dict[str, Any],
|
|
539
|
+
) -> None:
|
|
335
540
|
"""保存中间结果
|
|
336
|
-
|
|
541
|
+
|
|
337
542
|
Args:
|
|
338
543
|
intermediate_results: 中间结果数组,每个元素包含 stage 和 elements 字段
|
|
339
544
|
filename: 文件名
|
|
@@ -342,22 +547,26 @@ class Pipeline:
|
|
|
342
547
|
try:
|
|
343
548
|
# intermediate_results 是一个数组,每个元素是 {stage: str, elements: List}
|
|
344
549
|
for result_item in intermediate_results:
|
|
345
|
-
if
|
|
346
|
-
logger.warning(
|
|
550
|
+
if "stage" not in result_item or "elements" not in result_item:
|
|
551
|
+
logger.warning(
|
|
552
|
+
f"中间结果项缺少 stage 或 elements 字段: {result_item}"
|
|
553
|
+
)
|
|
347
554
|
continue
|
|
348
|
-
|
|
349
|
-
stage = result_item[
|
|
350
|
-
elements = result_item[
|
|
351
|
-
|
|
555
|
+
|
|
556
|
+
stage = result_item["stage"]
|
|
557
|
+
elements = result_item["elements"]
|
|
558
|
+
|
|
352
559
|
metadata = {
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
560
|
+
"filename": filename,
|
|
561
|
+
"stage": stage,
|
|
562
|
+
"total_elements": len(elements),
|
|
563
|
+
"processed_at": datetime.now().isoformat(),
|
|
564
|
+
"data_source": data_source,
|
|
358
565
|
}
|
|
359
|
-
|
|
360
|
-
self.pipeline_config.intermediate_results_destination.write(
|
|
566
|
+
|
|
567
|
+
self.pipeline_config.intermediate_results_destination.write(
|
|
568
|
+
elements, metadata
|
|
569
|
+
)
|
|
361
570
|
print(f" ✓ 保存 {stage.upper()} 中间结果: {len(elements)} 个元素")
|
|
362
571
|
logger.info(f"保存 {stage.upper()} 中间结果成功: {filename}")
|
|
363
572
|
|
|
@@ -374,17 +583,17 @@ class Pipeline:
|
|
|
374
583
|
print(f" → 读取文件...")
|
|
375
584
|
file_bytes, data_source = self.source.read_file(file_path)
|
|
376
585
|
data_source = data_source or {}
|
|
377
|
-
|
|
586
|
+
|
|
378
587
|
# 检查文件大小,超过 100MB 则报错
|
|
379
588
|
MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB
|
|
380
589
|
file_size = len(file_bytes)
|
|
381
590
|
if file_size > MAX_FILE_SIZE:
|
|
382
591
|
file_size_mb = file_size / (1024 * 1024)
|
|
383
592
|
raise ValueError(f"文件大小过大: {file_size_mb:.2f}MB,超过100MB限制")
|
|
384
|
-
|
|
593
|
+
|
|
385
594
|
# 转换为毫秒时间戳字符串
|
|
386
595
|
timestamp_ms = int(datetime.now(timezone.utc).timestamp() * 1000)
|
|
387
|
-
data_source[
|
|
596
|
+
data_source["date_processed"] = str(timestamp_ms)
|
|
388
597
|
print(f" ✓ 文件读取完成: {len(file_bytes)} bytes")
|
|
389
598
|
|
|
390
599
|
result = self.process_with_pipeline(file_bytes, file_path, data_source)
|
|
@@ -395,13 +604,13 @@ class Pipeline:
|
|
|
395
604
|
|
|
396
605
|
print(f" → 写入目的地...")
|
|
397
606
|
metadata = {
|
|
398
|
-
|
|
399
|
-
|
|
607
|
+
"filename": file_path,
|
|
608
|
+
"processed_at": str(timestamp_ms),
|
|
400
609
|
}
|
|
401
|
-
|
|
610
|
+
|
|
402
611
|
# 如果 stats 中有 record_id,添加到 metadata 中
|
|
403
612
|
if stats.record_id:
|
|
404
|
-
metadata[
|
|
613
|
+
metadata["record_id"] = stats.record_id
|
|
405
614
|
|
|
406
615
|
success = self.destination.write(embedded_data, metadata)
|
|
407
616
|
|
|
@@ -468,168 +677,184 @@ class Pipeline:
|
|
|
468
677
|
print("=" * 60)
|
|
469
678
|
|
|
470
679
|
logger.info("=" * 60)
|
|
471
|
-
logger.info(
|
|
680
|
+
logger.info(
|
|
681
|
+
f"Pipeline 完成 - 总数:{total}, 成功:{success_count}, 失败:{fail_count}, 耗时:{elapsed:.2f}秒"
|
|
682
|
+
)
|
|
472
683
|
logger.info("=" * 60)
|
|
473
684
|
|
|
474
685
|
|
|
475
686
|
def create_pipeline_from_config(config: Dict[str, Any]) -> Pipeline:
|
|
476
|
-
source_config = config[
|
|
477
|
-
if source_config[
|
|
687
|
+
source_config = config["source"]
|
|
688
|
+
if source_config["type"] == "s3":
|
|
478
689
|
source = S3Source(
|
|
479
|
-
endpoint=source_config[
|
|
480
|
-
access_key=source_config[
|
|
481
|
-
secret_key=source_config[
|
|
482
|
-
bucket=source_config[
|
|
483
|
-
prefix=source_config.get(
|
|
484
|
-
region=source_config.get(
|
|
485
|
-
pattern=source_config.get(
|
|
486
|
-
recursive=source_config.get(
|
|
690
|
+
endpoint=source_config["endpoint"],
|
|
691
|
+
access_key=source_config["access_key"],
|
|
692
|
+
secret_key=source_config["secret_key"],
|
|
693
|
+
bucket=source_config["bucket"],
|
|
694
|
+
prefix=source_config.get("prefix", ""),
|
|
695
|
+
region=source_config.get("region", "us-east-1"),
|
|
696
|
+
pattern=source_config.get("pattern", None),
|
|
697
|
+
recursive=source_config.get("recursive", False),
|
|
487
698
|
)
|
|
488
|
-
elif source_config[
|
|
699
|
+
elif source_config["type"] == "local":
|
|
489
700
|
source = LocalSource(
|
|
490
|
-
directory=source_config[
|
|
491
|
-
pattern=source_config.get(
|
|
492
|
-
recursive=source_config.get(
|
|
701
|
+
directory=source_config["directory"],
|
|
702
|
+
pattern=source_config.get("pattern", None),
|
|
703
|
+
recursive=source_config.get("recursive", False),
|
|
493
704
|
)
|
|
494
|
-
elif source_config[
|
|
705
|
+
elif source_config["type"] == "ftp":
|
|
495
706
|
source = FtpSource(
|
|
496
|
-
host=source_config[
|
|
497
|
-
port=source_config[
|
|
498
|
-
username=source_config[
|
|
499
|
-
password=source_config[
|
|
500
|
-
pattern=source_config.get(
|
|
501
|
-
recursive=source_config.get(
|
|
707
|
+
host=source_config["host"],
|
|
708
|
+
port=source_config["port"],
|
|
709
|
+
username=source_config["username"],
|
|
710
|
+
password=source_config["password"],
|
|
711
|
+
pattern=source_config.get("pattern", None),
|
|
712
|
+
recursive=source_config.get("recursive", False),
|
|
502
713
|
)
|
|
503
|
-
elif source_config[
|
|
714
|
+
elif source_config["type"] == "smb":
|
|
504
715
|
source = SmbSource(
|
|
505
|
-
host=source_config[
|
|
506
|
-
share_name=source_config[
|
|
507
|
-
username=source_config[
|
|
508
|
-
password=source_config[
|
|
509
|
-
domain=source_config.get(
|
|
510
|
-
port=source_config.get(
|
|
511
|
-
path=source_config.get(
|
|
512
|
-
pattern=source_config.get(
|
|
513
|
-
recursive=source_config.get(
|
|
716
|
+
host=source_config["host"],
|
|
717
|
+
share_name=source_config["share_name"],
|
|
718
|
+
username=source_config["username"],
|
|
719
|
+
password=source_config["password"],
|
|
720
|
+
domain=source_config.get("domain", ""),
|
|
721
|
+
port=source_config.get("port", 445),
|
|
722
|
+
path=source_config.get("path", ""),
|
|
723
|
+
pattern=source_config.get("pattern", None),
|
|
724
|
+
recursive=source_config.get("recursive", False),
|
|
514
725
|
)
|
|
515
726
|
else:
|
|
516
727
|
raise ValueError(f"未知的 source 类型: {source_config['type']}")
|
|
517
728
|
|
|
518
|
-
dest_config = config[
|
|
519
|
-
if dest_config[
|
|
729
|
+
dest_config = config["destination"]
|
|
730
|
+
if dest_config["type"] in ["milvus", "zilliz"]:
|
|
520
731
|
destination = MilvusDestination(
|
|
521
|
-
db_path=dest_config[
|
|
522
|
-
collection_name=dest_config[
|
|
523
|
-
dimension=dest_config[
|
|
524
|
-
api_key=dest_config.get(
|
|
525
|
-
token=dest_config.get(
|
|
732
|
+
db_path=dest_config["db_path"],
|
|
733
|
+
collection_name=dest_config["collection_name"],
|
|
734
|
+
dimension=dest_config["dimension"],
|
|
735
|
+
api_key=dest_config.get("api_key"),
|
|
736
|
+
token=dest_config.get("token"),
|
|
526
737
|
)
|
|
527
|
-
elif dest_config[
|
|
738
|
+
elif dest_config["type"] == "qdrant":
|
|
528
739
|
destination = QdrantDestination(
|
|
529
|
-
url=dest_config[
|
|
530
|
-
collection_name=dest_config[
|
|
531
|
-
dimension=dest_config[
|
|
532
|
-
api_key=dest_config.get(
|
|
533
|
-
prefer_grpc=dest_config.get(
|
|
740
|
+
url=dest_config["url"],
|
|
741
|
+
collection_name=dest_config["collection_name"],
|
|
742
|
+
dimension=dest_config["dimension"],
|
|
743
|
+
api_key=dest_config.get("api_key"),
|
|
744
|
+
prefer_grpc=dest_config.get("prefer_grpc", False),
|
|
534
745
|
)
|
|
535
|
-
elif dest_config[
|
|
536
|
-
destination = LocalDestination(
|
|
537
|
-
|
|
538
|
-
)
|
|
539
|
-
elif dest_config['type'] == 's3':
|
|
746
|
+
elif dest_config["type"] == "local":
|
|
747
|
+
destination = LocalDestination(output_dir=dest_config["output_dir"])
|
|
748
|
+
elif dest_config["type"] == "s3":
|
|
540
749
|
destination = S3Destination(
|
|
541
|
-
endpoint=dest_config[
|
|
542
|
-
access_key=dest_config[
|
|
543
|
-
secret_key=dest_config[
|
|
544
|
-
bucket=dest_config[
|
|
545
|
-
prefix=dest_config.get(
|
|
546
|
-
region=dest_config.get(
|
|
750
|
+
endpoint=dest_config["endpoint"],
|
|
751
|
+
access_key=dest_config["access_key"],
|
|
752
|
+
secret_key=dest_config["secret_key"],
|
|
753
|
+
bucket=dest_config["bucket"],
|
|
754
|
+
prefix=dest_config.get("prefix", ""),
|
|
755
|
+
region=dest_config.get("region", "us-east-1"),
|
|
547
756
|
)
|
|
548
757
|
else:
|
|
549
758
|
raise ValueError(f"未知的 destination 类型: {dest_config['type']}")
|
|
550
759
|
|
|
551
760
|
# 处理 stages 配置
|
|
552
|
-
if
|
|
761
|
+
if "stages" not in config or not config["stages"]:
|
|
553
762
|
raise ValueError("配置中必须包含 'stages' 字段")
|
|
554
|
-
|
|
763
|
+
|
|
555
764
|
stages = []
|
|
556
|
-
for stage_cfg in config[
|
|
557
|
-
stage_type = stage_cfg.get(
|
|
558
|
-
stage_config_dict = stage_cfg.get(
|
|
559
|
-
|
|
560
|
-
if stage_type ==
|
|
765
|
+
for stage_cfg in config["stages"]:
|
|
766
|
+
stage_type = stage_cfg.get("type")
|
|
767
|
+
stage_config_dict = stage_cfg.get("config", {})
|
|
768
|
+
|
|
769
|
+
if stage_type == "parse":
|
|
561
770
|
parse_cfg_copy = dict(stage_config_dict)
|
|
562
|
-
provider = parse_cfg_copy.pop(
|
|
771
|
+
provider = parse_cfg_copy.pop("provider", "textin")
|
|
563
772
|
stage_config = ParseConfig(provider=provider, **parse_cfg_copy)
|
|
564
|
-
elif stage_type ==
|
|
773
|
+
elif stage_type == "chunk":
|
|
565
774
|
stage_config = ChunkConfig(
|
|
566
|
-
strategy=stage_config_dict.get(
|
|
567
|
-
include_orig_elements=stage_config_dict.get(
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
775
|
+
strategy=stage_config_dict.get("strategy", "basic"),
|
|
776
|
+
include_orig_elements=stage_config_dict.get(
|
|
777
|
+
"include_orig_elements", False
|
|
778
|
+
),
|
|
779
|
+
new_after_n_chars=stage_config_dict.get("new_after_n_chars", 512),
|
|
780
|
+
max_characters=stage_config_dict.get("max_characters", 1024),
|
|
781
|
+
overlap=stage_config_dict.get("overlap", 0),
|
|
782
|
+
overlap_all=stage_config_dict.get("overlap_all", False),
|
|
572
783
|
)
|
|
573
|
-
elif stage_type ==
|
|
784
|
+
elif stage_type == "embed":
|
|
574
785
|
stage_config = EmbedConfig(
|
|
575
|
-
provider=stage_config_dict.get(
|
|
576
|
-
model_name=stage_config_dict.get(
|
|
786
|
+
provider=stage_config_dict.get("provider", "qwen"),
|
|
787
|
+
model_name=stage_config_dict.get("model_name", "text-embedding-v3"),
|
|
788
|
+
)
|
|
789
|
+
elif stage_type == "extract":
|
|
790
|
+
schema = stage_config_dict.get("schema")
|
|
791
|
+
if not schema:
|
|
792
|
+
raise ValueError("extract stage 的 config 中必须包含 'schema' 字段")
|
|
793
|
+
stage_config = ExtractConfig(
|
|
794
|
+
schema=schema,
|
|
795
|
+
generate_citations=stage_config_dict.get("generate_citations", False),
|
|
796
|
+
stamp=stage_config_dict.get("stamp", False),
|
|
577
797
|
)
|
|
578
798
|
else:
|
|
579
799
|
raise ValueError(f"未知的 stage 类型: {stage_type}")
|
|
580
|
-
|
|
800
|
+
|
|
581
801
|
stages.append(Stage(type=stage_type, config=stage_config))
|
|
582
802
|
|
|
583
803
|
# 创建 Pipeline 配置
|
|
584
804
|
pipeline_config = None
|
|
585
|
-
if
|
|
586
|
-
pipeline_cfg = config[
|
|
587
|
-
include_intermediate_results = pipeline_cfg.get(
|
|
805
|
+
if "pipeline_config" in config and config["pipeline_config"]:
|
|
806
|
+
pipeline_cfg = config["pipeline_config"]
|
|
807
|
+
include_intermediate_results = pipeline_cfg.get(
|
|
808
|
+
"include_intermediate_results", False
|
|
809
|
+
)
|
|
588
810
|
intermediate_results_destination = None
|
|
589
|
-
|
|
811
|
+
|
|
590
812
|
if include_intermediate_results:
|
|
591
|
-
if
|
|
592
|
-
dest_cfg = pipeline_cfg[
|
|
593
|
-
dest_type = dest_cfg.get(
|
|
594
|
-
|
|
595
|
-
if dest_type ==
|
|
813
|
+
if "intermediate_results_destination" in pipeline_cfg:
|
|
814
|
+
dest_cfg = pipeline_cfg["intermediate_results_destination"]
|
|
815
|
+
dest_type = dest_cfg.get("type")
|
|
816
|
+
|
|
817
|
+
if dest_type == "local":
|
|
596
818
|
intermediate_results_destination = LocalDestination(
|
|
597
|
-
output_dir=dest_cfg[
|
|
819
|
+
output_dir=dest_cfg["output_dir"]
|
|
598
820
|
)
|
|
599
|
-
elif dest_type ==
|
|
821
|
+
elif dest_type == "s3":
|
|
600
822
|
intermediate_results_destination = S3Destination(
|
|
601
|
-
endpoint=dest_cfg[
|
|
602
|
-
access_key=dest_cfg[
|
|
603
|
-
secret_key=dest_cfg[
|
|
604
|
-
bucket=dest_cfg[
|
|
605
|
-
prefix=dest_cfg.get(
|
|
606
|
-
region=dest_cfg.get(
|
|
823
|
+
endpoint=dest_cfg["endpoint"],
|
|
824
|
+
access_key=dest_cfg["access_key"],
|
|
825
|
+
secret_key=dest_cfg["secret_key"],
|
|
826
|
+
bucket=dest_cfg["bucket"],
|
|
827
|
+
prefix=dest_cfg.get("prefix", ""),
|
|
828
|
+
region=dest_cfg.get("region", "us-east-1"),
|
|
607
829
|
)
|
|
608
830
|
else:
|
|
609
|
-
raise ValueError(
|
|
831
|
+
raise ValueError(
|
|
832
|
+
f"不支持的 intermediate_results_destination 类型: '{dest_type}',支持的类型: 'local', 's3'"
|
|
833
|
+
)
|
|
610
834
|
else:
|
|
611
|
-
raise ValueError(
|
|
612
|
-
|
|
835
|
+
raise ValueError(
|
|
836
|
+
"当 include_intermediate_results 为 True 时,必须设置 intermediate_results_destination"
|
|
837
|
+
)
|
|
838
|
+
|
|
613
839
|
pipeline_config = PipelineConfig(
|
|
614
840
|
include_intermediate_results=include_intermediate_results,
|
|
615
|
-
intermediate_results_destination=intermediate_results_destination
|
|
841
|
+
intermediate_results_destination=intermediate_results_destination,
|
|
616
842
|
)
|
|
617
843
|
|
|
618
844
|
# 创建 Pipeline
|
|
619
845
|
pipeline = Pipeline(
|
|
620
846
|
source=source,
|
|
621
847
|
destination=destination,
|
|
622
|
-
api_base_url=config.get(
|
|
623
|
-
api_headers=config.get(
|
|
848
|
+
api_base_url=config.get("api_base_url", "http://localhost:8000/api/xparse"),
|
|
849
|
+
api_headers=config.get("api_headers", {}),
|
|
624
850
|
stages=stages,
|
|
625
|
-
pipeline_config=pipeline_config
|
|
851
|
+
pipeline_config=pipeline_config,
|
|
626
852
|
)
|
|
627
853
|
|
|
628
854
|
return pipeline
|
|
629
855
|
|
|
630
856
|
|
|
631
857
|
__all__ = [
|
|
632
|
-
|
|
633
|
-
|
|
858
|
+
"Pipeline",
|
|
859
|
+
"create_pipeline_from_config",
|
|
634
860
|
]
|
|
635
|
-
|