xparse-client 0.2.18__py3-none-any.whl → 0.2.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,17 +3,31 @@
3
3
 
4
4
  import json
5
5
  import logging
6
+ import re
6
7
  import time
7
8
  from datetime import datetime, timezone
8
9
  from pathlib import Path
9
- from typing import Dict, Any, Optional, Tuple, List, Union
10
+ from typing import Any, Dict, List, Optional, Tuple, Union
10
11
 
11
12
  import requests
12
13
 
13
- from .config import ParseConfig, ChunkConfig, EmbedConfig, Stage, PipelineStats, PipelineConfig
14
- from .sources import Source, S3Source, LocalSource, FtpSource, SmbSource
15
- from .destinations import Destination, MilvusDestination, QdrantDestination, LocalDestination, S3Destination
16
-
14
+ from .config import (
15
+ ChunkConfig,
16
+ EmbedConfig,
17
+ ExtractConfig,
18
+ ParseConfig,
19
+ PipelineConfig,
20
+ PipelineStats,
21
+ Stage,
22
+ )
23
+ from .destinations import (
24
+ Destination,
25
+ LocalDestination,
26
+ MilvusDestination,
27
+ QdrantDestination,
28
+ S3Destination,
29
+ )
30
+ from .sources import FtpSource, LocalSource, S3Source, SmbSource, Source
17
31
 
18
32
  logger = logging.getLogger(__name__)
19
33
 
@@ -25,15 +39,15 @@ class Pipeline:
25
39
  self,
26
40
  source: Source,
27
41
  destination: Destination,
28
- api_base_url: str = 'http://localhost:8000/api/xparse',
42
+ api_base_url: str = "http://localhost:8000/api/xparse",
29
43
  api_headers: Optional[Dict[str, str]] = None,
30
44
  stages: Optional[List[Stage]] = None,
31
45
  pipeline_config: Optional[PipelineConfig] = None,
32
- intermediate_results_destination: Optional[Destination] = None
46
+ intermediate_results_destination: Optional[Destination] = None,
33
47
  ):
34
48
  self.source = source
35
49
  self.destination = destination
36
- self.api_base_url = api_base_url.rstrip('/')
50
+ self.api_base_url = api_base_url.rstrip("/")
37
51
  self.api_headers = api_headers or {}
38
52
  self.pipeline_config = pipeline_config or PipelineConfig()
39
53
 
@@ -41,34 +55,62 @@ class Pipeline:
41
55
  # 如果直接传入了 intermediate_results_destination,优先使用它并自动启用中间结果保存
42
56
  if intermediate_results_destination is not None:
43
57
  self.pipeline_config.include_intermediate_results = True
44
- self.pipeline_config.intermediate_results_destination = intermediate_results_destination
58
+ self.pipeline_config.intermediate_results_destination = (
59
+ intermediate_results_destination
60
+ )
45
61
  # 如果 pipeline_config 中已设置,使用 pipeline_config 中的值
46
62
  elif self.pipeline_config.include_intermediate_results:
47
63
  if not self.pipeline_config.intermediate_results_destination:
48
- raise ValueError("当 include_intermediate_results 为 True 时,必须设置 intermediate_results_destination")
64
+ raise ValueError(
65
+ "当 include_intermediate_results 为 True 时,必须设置 intermediate_results_destination"
66
+ )
49
67
 
50
68
  # 处理 stages 配置
51
69
  if stages is None:
52
70
  raise ValueError("必须提供 stages 参数")
53
-
71
+
54
72
  self.stages = stages
55
73
 
56
74
  # 验证 stages
57
- if not self.stages or self.stages[0].type != 'parse':
75
+ if not self.stages or self.stages[0].type != "parse":
58
76
  raise ValueError("stages 必须包含且第一个必须是 'parse' 类型")
59
-
77
+
78
+ # 检查是否包含 extract 阶段
79
+ self.has_extract = any(stage.type == "extract" for stage in self.stages)
80
+
81
+ # 验证 extract 阶段的约束
82
+ if self.has_extract:
83
+ # extract 只能跟在 parse 后面,且必须是最后一个
84
+ if len(self.stages) != 2 or self.stages[1].type != "extract":
85
+ raise ValueError(
86
+ "extract 阶段只能跟在 parse 后面,且必须是最后一个阶段(即只能是 [parse, extract] 组合)"
87
+ )
88
+
89
+ # destination 必须是文件存储类型
90
+ if not isinstance(destination, (LocalDestination, S3Destination)):
91
+ raise ValueError(
92
+ "当使用 extract 阶段时,destination 必须是文件存储类型(LocalDestination 或 S3Destination),不能是向量数据库"
93
+ )
94
+
60
95
  # 验证 embed config(如果存在)
61
96
  for stage in self.stages:
62
- if stage.type == 'embed' and isinstance(stage.config, EmbedConfig):
97
+ if stage.type == "embed" and isinstance(stage.config, EmbedConfig):
63
98
  stage.config.validate()
64
-
99
+
65
100
  # 验证 intermediate_results_destination
66
101
  if self.pipeline_config.include_intermediate_results:
67
102
  # 验证是否为支持的 Destination 类型
68
103
  from .destinations import Destination
69
- if not isinstance(self.pipeline_config.intermediate_results_destination, Destination):
70
- raise ValueError(f"intermediate_results_destination 必须是 Destination 类型")
71
- self.intermediate_results_destination = self.pipeline_config.intermediate_results_destination
104
+
105
+ if not isinstance(
106
+ self.pipeline_config.intermediate_results_destination, Destination
107
+ ):
108
+ raise ValueError(
109
+ f"intermediate_results_destination 必须是 Destination 类型"
110
+ )
111
+ self.intermediate_results_destination = (
112
+ self.pipeline_config.intermediate_results_destination
113
+ )
72
114
 
73
115
  print("=" * 60)
74
116
  print("Pipeline 初始化完成")
@@ -82,162 +124,264 @@ class Pipeline:
82
124
def get_config(self) -> Dict[str, Any]:
    """Return the pipeline's configuration as a plain dictionary.

    The layout mirrors the input accepted by ``create_pipeline_from_config``.
    Per the original notes, credentials (access keys, passwords, API keys,
    tokens) are not kept on the source/destination objects and are therefore
    absent from the result; S3 regions are reported as the default
    ``"us-east-1"``.

    Returns:
        Dict[str, Any]: ``source``, ``destination``, ``api_base_url``,
        ``api_headers``, ``stages`` and, only when intermediate results are
        enabled, ``pipeline_config``.
    """
    # ---- Source ---------------------------------------------------------
    src = self.source
    source_cfg: Dict[str, Any] = {
        "type": type(src).__name__.replace("Source", "").lower()
    }
    if isinstance(src, S3Source):
        source_cfg["endpoint"] = src.endpoint
        source_cfg["bucket"] = src.bucket
        source_cfg["prefix"] = src.prefix
        source_cfg["pattern"] = src.pattern
        source_cfg["recursive"] = src.recursive
        # access_key / secret_key / region are not stored on the object;
        # the region falls back to the default value.
        source_cfg["region"] = "us-east-1"
    elif isinstance(src, LocalSource):
        source_cfg["directory"] = str(src.directory)
        source_cfg["pattern"] = src.pattern
        source_cfg["recursive"] = src.recursive
    elif isinstance(src, FtpSource):
        # The password is not stored on the object and cannot be restored.
        source_cfg["host"] = src.host
        source_cfg["port"] = src.port
        source_cfg["username"] = src.username
        source_cfg["pattern"] = src.pattern
        source_cfg["recursive"] = src.recursive
    elif isinstance(src, SmbSource):
        # The password is not stored on the object and cannot be restored.
        source_cfg["host"] = src.host
        source_cfg["share_name"] = src.share_name
        source_cfg["username"] = src.username
        source_cfg["domain"] = src.domain
        source_cfg["port"] = src.port
        source_cfg["path"] = src.path
        source_cfg["pattern"] = src.pattern
        source_cfg["recursive"] = src.recursive

    # ---- Destination ----------------------------------------------------
    dest = self.destination
    dest_type = type(dest).__name__.replace("Destination", "").lower()
    # MilvusDestination serves both local Milvus and Zilliz; an http(s)
    # db_path is reported as "zilliz".
    if dest_type == "milvus" and dest.db_path.startswith("http"):
        dest_type = "zilliz"

    dest_cfg: Dict[str, Any] = {"type": dest_type}
    if isinstance(dest, MilvusDestination):
        # api_key / token are not stored on the object.
        dest_cfg["db_path"] = dest.db_path
        dest_cfg["collection_name"] = dest.collection_name
        dest_cfg["dimension"] = dest.dimension
    elif isinstance(dest, QdrantDestination):
        # api_key is not stored on the object.
        dest_cfg["url"] = dest.url
        dest_cfg["collection_name"] = dest.collection_name
        dest_cfg["dimension"] = dest.dimension
        dest_cfg["prefer_grpc"] = getattr(dest, "prefer_grpc", False)
    elif isinstance(dest, LocalDestination):
        dest_cfg["output_dir"] = str(dest.output_dir)
    elif isinstance(dest, S3Destination):
        dest_cfg["endpoint"] = dest.endpoint
        dest_cfg["bucket"] = dest.bucket
        dest_cfg["prefix"] = dest.prefix
        # access_key / secret_key / region are not stored on the object.
        dest_cfg["region"] = "us-east-1"

    # ---- Stages ---------------------------------------------------------
    typed_configs = (ParseConfig, ChunkConfig, EmbedConfig, ExtractConfig)
    stage_entries: List[Dict[str, Any]] = []
    for stage in self.stages:
        entry: Dict[str, Any] = {"type": stage.type, "config": {}}
        stage_cfg = stage.config
        if isinstance(stage_cfg, typed_configs):
            entry["config"] = stage_cfg.to_dict()
        elif isinstance(stage_cfg, dict):
            entry["config"] = stage_cfg
        else:
            # Anything else is emitted as its string representation.
            entry["config"] = str(stage_cfg)
        stage_entries.append(entry)

    config: Dict[str, Any] = {
        "source": source_cfg,
        "destination": dest_cfg,
        "api_base_url": self.api_base_url,
        "api_headers": {key: value for key, value in self.api_headers.items()},
        "stages": stage_entries,
    }

    # ---- Pipeline config (only when intermediate results are enabled) ---
    if self.pipeline_config.include_intermediate_results:
        inter_cfg: Dict[str, Any] = {}
        config["pipeline_config"] = {
            "include_intermediate_results": True,
            "intermediate_results_destination": inter_cfg,
        }

        inter_dest = self.pipeline_config.intermediate_results_destination
        if inter_dest:
            inter_cfg["type"] = (
                type(inter_dest).__name__.replace("Destination", "").lower()
            )
            if isinstance(inter_dest, LocalDestination):
                inter_cfg["output_dir"] = str(inter_dest.output_dir)
            elif isinstance(inter_dest, S3Destination):
                inter_cfg["endpoint"] = inter_dest.endpoint
                inter_cfg["bucket"] = inter_dest.bucket
                inter_cfg["prefix"] = inter_dest.prefix
                # access_key / secret_key / region are not stored on the object.
                inter_cfg["region"] = "us-east-1"

    return config
222
289
 
223
- def _call_pipeline_api(self, file_bytes: bytes, filename: str, data_source: Dict[str, Any]) -> Optional[Dict[str, Any]]:
290
def _extract_error_message(self, response: "requests.Response") -> Tuple[str, str]:
    """Build a short, normalized error description from an HTTP response.

    The message is taken, in order of preference, from:

    1. a JSON object body (``message`` then ``msg``) when the Content-Type
       is ``application/json``;
    2. the ``<title>`` or, failing that, the ``<h1>`` of an HTML body;
    3. the first 200 characters of a ``text/plain`` body;
    4. the first line of the first 200 characters of any other body.

    Whenever a step yields nothing usable, ``"HTTP <status_code>"`` is
    returned instead, so the message is never empty and always a ``str``.

    Args:
        response: The HTTP response to inspect. Only ``status_code``,
            ``headers``, ``text`` and ``json()`` are used.

    Returns:
        Tuple[str, str]: ``(error_msg, x_request_id)``. The request id comes
        from the ``x-request-id`` response header, falling back to the
        ``x_request_id`` field of a JSON object body; it is ``""`` when
        neither is present.
    """
    fallback = f"HTTP {response.status_code}"

    # The header value wins over any id found in the body.
    x_request_id = response.headers.get("x-request-id", "") or ""
    content_type = (response.headers.get("Content-Type", "") or "").lower()

    if "application/json" in content_type:
        try:
            payload = response.json()
        except ValueError:
            # Body is not valid JSON despite the header; fall through to the
            # text-based handling below.
            payload = None
        # Only a JSON object can carry message/x_request_id fields; lists,
        # strings, etc. are handled as plain text further down.
        if isinstance(payload, dict):
            if not x_request_id:
                x_request_id = str(payload.get("x_request_id") or "")
            # `or` (not a .get default) so that null/empty values also fall back.
            message = payload.get("message") or payload.get("msg") or fallback
            return str(message), x_request_id

    body = response.text or ""

    if "text/html" in content_type or body.strip().startswith("<"):
        # DOTALL lets the tag content span several lines; [^>]* tolerates
        # attributes on the opening tag.
        html_flags = re.IGNORECASE | re.DOTALL
        for pattern in (r"<title\b[^>]*>(.*?)</title>", r"<h1\b[^>]*>(.*?)</h1>"):
            match = re.search(pattern, body, html_flags)
            if match:
                # Collapse internal newlines/indentation into single spaces.
                candidate = " ".join(match.group(1).split())
                if candidate:
                    return candidate, x_request_id
        return fallback, x_request_id

    snippet = body[:200].strip()
    if "text/plain" not in content_type:
        # Unknown content type: keep only the first line of the snippet.
        snippet = snippet.split("\n", 1)[0].strip()
    return (snippet or fallback), x_request_id
360
+
361
+ def _call_pipeline_api(
362
+ self, file_bytes: bytes, filename: str, data_source: Dict[str, Any]
363
+ ) -> Optional[Dict[str, Any]]:
224
364
  url = f"{self.api_base_url}/pipeline"
225
365
  max_retries = 3
226
366
 
227
367
  for try_count in range(max_retries):
228
368
  try:
229
- files = {'file': (filename or 'file', file_bytes)}
369
+ files = {"file": (filename or "file", file_bytes)}
230
370
  form_data = {}
231
371
 
232
372
  # 将 stages 转换为 API 格式
233
373
  stages_data = [stage.to_dict() for stage in self.stages]
234
374
  try:
235
- form_data['stages'] = json.dumps(stages_data)
236
- form_data['data_source'] = json.dumps(data_source, ensure_ascii=False)
237
-
375
+ form_data["stages"] = json.dumps(stages_data, ensure_ascii=False)
376
+ form_data["data_source"] = json.dumps(
377
+ data_source, ensure_ascii=False
378
+ )
379
+
238
380
  # 如果启用了中间结果保存,在请求中添加参数
239
381
  if self.pipeline_config:
240
- form_data['config'] = json.dumps(self.pipeline_config.to_dict(), ensure_ascii=False)
382
+ form_data["config"] = json.dumps(
383
+ self.pipeline_config.to_dict(), ensure_ascii=False
384
+ )
241
385
  except Exception as e:
242
386
  print(f" ✗ 入参处理失败,请检查配置: {e}")
243
387
  logger.error(f"入参处理失败,请检查配置: {e}")
@@ -248,76 +392,136 @@ class Pipeline:
248
392
  files=files,
249
393
  data=form_data,
250
394
  headers=self.api_headers,
251
- timeout=630
395
+ timeout=630,
252
396
  )
253
397
 
254
398
  if response.status_code == 200:
255
399
  result = response.json()
256
- x_request_id = result.get('x_request_id', '')
400
+ x_request_id = result.get("x_request_id", "")
257
401
  print(f" ✓ Pipeline 接口返回 x_request_id: {x_request_id}")
258
- if result.get('code') == 200 and 'data' in result:
259
- return result.get('data')
402
+ if result.get("code") == 200 and "data" in result:
403
+ return result.get("data")
260
404
  # 如果 code 不是 200,打印错误信息
261
- error_msg = result.get('message', result.get('msg', '未知错误'))
262
- print(f" ✗ Pipeline 接口返回错误: code={result.get('code')}, message={error_msg}, x_request_id={x_request_id}")
263
- logger.error(f"Pipeline 接口返回错误: code={result.get('code')}, message={error_msg}, x_request_id={x_request_id}")
405
+ error_msg = result.get("message", result.get("msg", "未知错误"))
406
+ print(
407
+ f"Pipeline 接口返回错误: code={result.get('code')}, message={error_msg}, x_request_id={x_request_id}"
408
+ )
409
+ logger.error(
410
+ f"Pipeline 接口返回错误: code={result.get('code')}, message={error_msg}, x_request_id={x_request_id}"
411
+ )
264
412
  return None
265
413
  else:
266
- # 尝试解析响应获取 x_request_id 和错误信息
267
- x_request_id = ''
268
- error_msg = ''
269
- try:
270
- result = response.json()
271
- x_request_id = result.get('x_request_id', '')
272
- error_msg = result.get('message', result.get('msg', response.text[:200]))
273
- except:
274
- error_msg = response.text[:200] if response.text else f'HTTP {response.status_code}'
275
-
276
- print(f" ✗ API 错误 {response.status_code}: {error_msg}, x_request_id={x_request_id}, 重试 {try_count + 1}/{max_retries}")
277
- logger.warning(f"API 错误 {response.status_code}: {error_msg}, x_request_id={x_request_id}, 重试 {try_count + 1}/{max_retries}")
414
+ # 使用规范化函数提取错误信息
415
+ error_msg, x_request_id = self._extract_error_message(response)
416
+
417
+ print(
418
+ f" ✗ API 错误 {response.status_code}: {error_msg}, x_request_id={x_request_id}, 重试 {try_count + 1}/{max_retries}"
419
+ )
420
+ logger.warning(
421
+ f"API 错误 {response.status_code}: {error_msg}, x_request_id={x_request_id}, 重试 {try_count + 1}/{max_retries}"
422
+ )
278
423
 
279
424
  except Exception as e:
280
425
  # 如果是 requests 异常,尝试从响应中获取 x_request_id
281
- x_request_id = ''
426
+ x_request_id = ""
282
427
  error_msg = str(e)
283
428
  try:
284
- if hasattr(e, 'response') and e.response is not None:
429
+ if hasattr(e, "response") and e.response is not None:
285
430
  try:
286
431
  result = e.response.json()
287
- x_request_id = result.get('x_request_id', '')
288
- error_msg = result.get('message', result.get('msg', error_msg))
432
+ x_request_id = result.get("x_request_id", "")
433
+ error_msg = result.get(
434
+ "message", result.get("msg", error_msg)
435
+ )
289
436
  except:
290
437
  pass
291
438
  except:
292
439
  pass
293
-
294
- print(f" ✗ 请求异常: {error_msg}, x_request_id={x_request_id}, 重试 {try_count + 1}/{max_retries}")
295
- logger.error(f"API 请求异常 pipeline: {error_msg}, x_request_id={x_request_id}")
440
+
441
+ print(
442
+ f" 请求异常: {error_msg}, x_request_id={x_request_id}, 重试 {try_count + 1}/{max_retries}"
443
+ )
444
+ logger.error(
445
+ f"API 请求异常 pipeline: {error_msg}, x_request_id={x_request_id}"
446
+ )
296
447
 
297
448
  if try_count < max_retries - 1:
298
449
  time.sleep(2)
299
450
 
300
451
  return None
301
452
 
302
- def process_with_pipeline(self, file_bytes: bytes, filename: str, data_source: Dict[str, Any]) -> Optional[Tuple[List[Dict[str, Any]], PipelineStats]]:
453
+ def process_with_pipeline(
454
+ self, file_bytes: bytes, filename: str, data_source: Dict[str, Any]
455
+ ) -> Optional[Tuple[Any, PipelineStats]]:
303
456
  print(f" → 调用 Pipeline 接口: {filename}")
304
457
  result = self._call_pipeline_api(file_bytes, filename, data_source)
305
458
 
306
- if result and 'elements' in result and 'stats' in result:
307
- elements = result['elements']
308
- stats_data = result['stats']
459
+ if not result or "stats" not in result:
460
+ print(f" ✗ Pipeline 失败")
461
+ logger.error(f"Pipeline 失败: {filename}")
462
+ return None
463
+
464
+ # 处理 extract 类型的响应
465
+ if self.has_extract:
466
+ # extract 返回 extract_result 而不是 elements
467
+ if "extract_result" not in result:
468
+ print(f" ✗ Pipeline 失败: extract 响应中缺少 extract_result")
469
+ logger.error(f"Pipeline 失败: extract 响应中缺少 extract_result")
470
+ return None
471
+
472
+ data = result["extract_result"] # 结构化数据
473
+ stats_data = result["stats"]
309
474
 
310
475
  stats = PipelineStats(
311
- original_elements=stats_data.get('original_elements', 0),
312
- chunked_elements=stats_data.get('chunked_elements', 0),
313
- embedded_elements=stats_data.get('embedded_elements', 0),
476
+ original_elements=stats_data.get("original_elements", 0),
477
+ chunked_elements=0, # extract 不涉及分块
478
+ embedded_elements=0, # extract 不涉及向量化
479
+ stages=self.stages,
480
+ record_id=stats_data.get("record_id"),
481
+ )
482
+
483
+ # 如果启用了中间结果保存,处理中间结果
484
+ if (
485
+ self.pipeline_config.include_intermediate_results
486
+ and "intermediate_results" in result
487
+ ):
488
+ self._save_intermediate_results(
489
+ result["intermediate_results"], filename, data_source
490
+ )
491
+
492
+ print(f" ✓ Extract 完成:")
493
+ print(f" - 原始元素: {stats.original_elements}")
494
+ print(f" - 提取结果类型: {type(data).__name__}")
495
+ logger.info(f"Extract 完成: {filename}")
496
+
497
+ return data, stats
498
+
499
+ else:
500
+ # 原有的 parse/chunk/embed 逻辑
501
+ if "elements" not in result:
502
+ print(f" ✗ Pipeline 失败: 响应中缺少 elements")
503
+ logger.error(f"Pipeline 失败: 响应中缺少 elements")
504
+ return None
505
+
506
+ elements = result["elements"]
507
+ stats_data = result["stats"]
508
+
509
+ stats = PipelineStats(
510
+ original_elements=stats_data.get("original_elements", 0),
511
+ chunked_elements=stats_data.get("chunked_elements", 0),
512
+ embedded_elements=stats_data.get("embedded_elements", 0),
314
513
  stages=self.stages, # 使用实际执行的 stages
315
- record_id=stats_data.get('record_id') # 从 API 响应中获取 record_id
514
+ record_id=stats_data.get("record_id"), # 从 API 响应中获取 record_id
316
515
  )
317
516
 
318
517
  # 如果启用了中间结果保存,处理中间结果
319
- if self.pipeline_config.include_intermediate_results and 'intermediate_results' in result:
320
- self._save_intermediate_results(result['intermediate_results'], filename, data_source)
518
+ if (
519
+ self.pipeline_config.include_intermediate_results
520
+ and "intermediate_results" in result
521
+ ):
522
+ self._save_intermediate_results(
523
+ result["intermediate_results"], filename, data_source
524
+ )
321
525
 
322
526
  print(f" ✓ Pipeline 完成:")
323
527
  print(f" - 原始元素: {stats.original_elements}")
@@ -326,14 +530,15 @@ class Pipeline:
326
530
  logger.info(f"Pipeline 完成: {filename}, {stats.embedded_elements} 个向量")
327
531
 
328
532
  return elements, stats
329
- else:
330
- print(f" ✗ Pipeline 失败")
331
- logger.error(f"Pipeline 失败: {filename}")
332
- return None
333
533
 
334
- def _save_intermediate_results(self, intermediate_results: List[Dict[str, Any]], filename: str, data_source: Dict[str, Any]) -> None:
534
+ def _save_intermediate_results(
535
+ self,
536
+ intermediate_results: List[Dict[str, Any]],
537
+ filename: str,
538
+ data_source: Dict[str, Any],
539
+ ) -> None:
335
540
  """保存中间结果
336
-
541
+
337
542
  Args:
338
543
  intermediate_results: 中间结果数组,每个元素包含 stage 和 elements 字段
339
544
  filename: 文件名
@@ -342,22 +547,26 @@ class Pipeline:
342
547
  try:
343
548
  # intermediate_results 是一个数组,每个元素是 {stage: str, elements: List}
344
549
  for result_item in intermediate_results:
345
- if 'stage' not in result_item or 'elements' not in result_item:
346
- logger.warning(f"中间结果项缺少 stage 或 elements 字段: {result_item}")
550
+ if "stage" not in result_item or "elements" not in result_item:
551
+ logger.warning(
552
+ f"中间结果项缺少 stage 或 elements 字段: {result_item}"
553
+ )
347
554
  continue
348
-
349
- stage = result_item['stage']
350
- elements = result_item['elements']
351
-
555
+
556
+ stage = result_item["stage"]
557
+ elements = result_item["elements"]
558
+
352
559
  metadata = {
353
- 'filename': filename,
354
- 'stage': stage,
355
- 'total_elements': len(elements),
356
- 'processed_at': datetime.now().isoformat(),
357
- 'data_source': data_source
560
+ "filename": filename,
561
+ "stage": stage,
562
+ "total_elements": len(elements),
563
+ "processed_at": datetime.now().isoformat(),
564
+ "data_source": data_source,
358
565
  }
359
-
360
- self.pipeline_config.intermediate_results_destination.write(elements, metadata)
566
+
567
+ self.pipeline_config.intermediate_results_destination.write(
568
+ elements, metadata
569
+ )
361
570
  print(f" ✓ 保存 {stage.upper()} 中间结果: {len(elements)} 个元素")
362
571
  logger.info(f"保存 {stage.upper()} 中间结果成功: {filename}")
363
572
 
@@ -374,17 +583,17 @@ class Pipeline:
374
583
  print(f" → 读取文件...")
375
584
  file_bytes, data_source = self.source.read_file(file_path)
376
585
  data_source = data_source or {}
377
-
586
+
378
587
  # 检查文件大小,超过 100MB 则报错
379
588
  MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB
380
589
  file_size = len(file_bytes)
381
590
  if file_size > MAX_FILE_SIZE:
382
591
  file_size_mb = file_size / (1024 * 1024)
383
592
  raise ValueError(f"文件大小过大: {file_size_mb:.2f}MB,超过100MB限制")
384
-
593
+
385
594
  # 转换为毫秒时间戳字符串
386
595
  timestamp_ms = int(datetime.now(timezone.utc).timestamp() * 1000)
387
- data_source['date_processed'] = str(timestamp_ms)
596
+ data_source["date_processed"] = str(timestamp_ms)
388
597
  print(f" ✓ 文件读取完成: {len(file_bytes)} bytes")
389
598
 
390
599
  result = self.process_with_pipeline(file_bytes, file_path, data_source)
@@ -395,13 +604,13 @@ class Pipeline:
395
604
 
396
605
  print(f" → 写入目的地...")
397
606
  metadata = {
398
- 'filename': file_path,
399
- 'processed_at': str(timestamp_ms),
607
+ "filename": file_path,
608
+ "processed_at": str(timestamp_ms),
400
609
  }
401
-
610
+
402
611
  # 如果 stats 中有 record_id,添加到 metadata 中
403
612
  if stats.record_id:
404
- metadata['record_id'] = stats.record_id
613
+ metadata["record_id"] = stats.record_id
405
614
 
406
615
  success = self.destination.write(embedded_data, metadata)
407
616
 
@@ -468,168 +677,184 @@ class Pipeline:
468
677
  print("=" * 60)
469
678
 
470
679
  logger.info("=" * 60)
471
- logger.info(f"Pipeline 完成 - 总数:{total}, 成功:{success_count}, 失败:{fail_count}, 耗时:{elapsed:.2f}秒")
680
+ logger.info(
681
+ f"Pipeline 完成 - 总数:{total}, 成功:{success_count}, 失败:{fail_count}, 耗时:{elapsed:.2f}秒"
682
+ )
472
683
  logger.info("=" * 60)
473
684
 
474
685
 
475
686
  def create_pipeline_from_config(config: Dict[str, Any]) -> Pipeline:
476
- source_config = config['source']
477
- if source_config['type'] == 's3':
687
+ source_config = config["source"]
688
+ if source_config["type"] == "s3":
478
689
  source = S3Source(
479
- endpoint=source_config['endpoint'],
480
- access_key=source_config['access_key'],
481
- secret_key=source_config['secret_key'],
482
- bucket=source_config['bucket'],
483
- prefix=source_config.get('prefix', ''),
484
- region=source_config.get('region', 'us-east-1'),
485
- pattern=source_config.get('pattern', None),
486
- recursive=source_config.get('recursive', False)
690
+ endpoint=source_config["endpoint"],
691
+ access_key=source_config["access_key"],
692
+ secret_key=source_config["secret_key"],
693
+ bucket=source_config["bucket"],
694
+ prefix=source_config.get("prefix", ""),
695
+ region=source_config.get("region", "us-east-1"),
696
+ pattern=source_config.get("pattern", None),
697
+ recursive=source_config.get("recursive", False),
487
698
  )
488
- elif source_config['type'] == 'local':
699
+ elif source_config["type"] == "local":
489
700
  source = LocalSource(
490
- directory=source_config['directory'],
491
- pattern=source_config.get('pattern', None),
492
- recursive=source_config.get('recursive', False)
701
+ directory=source_config["directory"],
702
+ pattern=source_config.get("pattern", None),
703
+ recursive=source_config.get("recursive", False),
493
704
  )
494
- elif source_config['type'] == 'ftp':
705
+ elif source_config["type"] == "ftp":
495
706
  source = FtpSource(
496
- host=source_config['host'],
497
- port=source_config['port'],
498
- username=source_config['username'],
499
- password=source_config['password'],
500
- pattern=source_config.get('pattern', None),
501
- recursive=source_config.get('recursive', False)
707
+ host=source_config["host"],
708
+ port=source_config["port"],
709
+ username=source_config["username"],
710
+ password=source_config["password"],
711
+ pattern=source_config.get("pattern", None),
712
+ recursive=source_config.get("recursive", False),
502
713
  )
503
- elif source_config['type'] == 'smb':
714
+ elif source_config["type"] == "smb":
504
715
  source = SmbSource(
505
- host=source_config['host'],
506
- share_name=source_config['share_name'],
507
- username=source_config['username'],
508
- password=source_config['password'],
509
- domain=source_config.get('domain', ''),
510
- port=source_config.get('port', 445),
511
- path=source_config.get('path', ''),
512
- pattern=source_config.get('pattern', None),
513
- recursive=source_config.get('recursive', False)
716
+ host=source_config["host"],
717
+ share_name=source_config["share_name"],
718
+ username=source_config["username"],
719
+ password=source_config["password"],
720
+ domain=source_config.get("domain", ""),
721
+ port=source_config.get("port", 445),
722
+ path=source_config.get("path", ""),
723
+ pattern=source_config.get("pattern", None),
724
+ recursive=source_config.get("recursive", False),
514
725
  )
515
726
  else:
516
727
  raise ValueError(f"未知的 source 类型: {source_config['type']}")
517
728
 
518
- dest_config = config['destination']
519
- if dest_config['type'] in ['milvus', 'zilliz']:
729
+ dest_config = config["destination"]
730
+ if dest_config["type"] in ["milvus", "zilliz"]:
520
731
  destination = MilvusDestination(
521
- db_path=dest_config['db_path'],
522
- collection_name=dest_config['collection_name'],
523
- dimension=dest_config['dimension'],
524
- api_key=dest_config.get('api_key'),
525
- token=dest_config.get('token')
732
+ db_path=dest_config["db_path"],
733
+ collection_name=dest_config["collection_name"],
734
+ dimension=dest_config["dimension"],
735
+ api_key=dest_config.get("api_key"),
736
+ token=dest_config.get("token"),
526
737
  )
527
- elif dest_config['type'] == 'qdrant':
738
+ elif dest_config["type"] == "qdrant":
528
739
  destination = QdrantDestination(
529
- url=dest_config['url'],
530
- collection_name=dest_config['collection_name'],
531
- dimension=dest_config['dimension'],
532
- api_key=dest_config.get('api_key'),
533
- prefer_grpc=dest_config.get('prefer_grpc', False)
740
+ url=dest_config["url"],
741
+ collection_name=dest_config["collection_name"],
742
+ dimension=dest_config["dimension"],
743
+ api_key=dest_config.get("api_key"),
744
+ prefer_grpc=dest_config.get("prefer_grpc", False),
534
745
  )
535
- elif dest_config['type'] == 'local':
536
- destination = LocalDestination(
537
- output_dir=dest_config['output_dir']
538
- )
539
- elif dest_config['type'] == 's3':
746
+ elif dest_config["type"] == "local":
747
+ destination = LocalDestination(output_dir=dest_config["output_dir"])
748
+ elif dest_config["type"] == "s3":
540
749
  destination = S3Destination(
541
- endpoint=dest_config['endpoint'],
542
- access_key=dest_config['access_key'],
543
- secret_key=dest_config['secret_key'],
544
- bucket=dest_config['bucket'],
545
- prefix=dest_config.get('prefix', ''),
546
- region=dest_config.get('region', 'us-east-1')
750
+ endpoint=dest_config["endpoint"],
751
+ access_key=dest_config["access_key"],
752
+ secret_key=dest_config["secret_key"],
753
+ bucket=dest_config["bucket"],
754
+ prefix=dest_config.get("prefix", ""),
755
+ region=dest_config.get("region", "us-east-1"),
547
756
  )
548
757
  else:
549
758
  raise ValueError(f"未知的 destination 类型: {dest_config['type']}")
550
759
 
551
760
  # 处理 stages 配置
552
- if 'stages' not in config or not config['stages']:
761
+ if "stages" not in config or not config["stages"]:
553
762
  raise ValueError("配置中必须包含 'stages' 字段")
554
-
763
+
555
764
  stages = []
556
- for stage_cfg in config['stages']:
557
- stage_type = stage_cfg.get('type')
558
- stage_config_dict = stage_cfg.get('config', {})
559
-
560
- if stage_type == 'parse':
765
+ for stage_cfg in config["stages"]:
766
+ stage_type = stage_cfg.get("type")
767
+ stage_config_dict = stage_cfg.get("config", {})
768
+
769
+ if stage_type == "parse":
561
770
  parse_cfg_copy = dict(stage_config_dict)
562
- provider = parse_cfg_copy.pop('provider', 'textin')
771
+ provider = parse_cfg_copy.pop("provider", "textin")
563
772
  stage_config = ParseConfig(provider=provider, **parse_cfg_copy)
564
- elif stage_type == 'chunk':
773
+ elif stage_type == "chunk":
565
774
  stage_config = ChunkConfig(
566
- strategy=stage_config_dict.get('strategy', 'basic'),
567
- include_orig_elements=stage_config_dict.get('include_orig_elements', False),
568
- new_after_n_chars=stage_config_dict.get('new_after_n_chars', 512),
569
- max_characters=stage_config_dict.get('max_characters', 1024),
570
- overlap=stage_config_dict.get('overlap', 0),
571
- overlap_all=stage_config_dict.get('overlap_all', False)
775
+ strategy=stage_config_dict.get("strategy", "basic"),
776
+ include_orig_elements=stage_config_dict.get(
777
+ "include_orig_elements", False
778
+ ),
779
+ new_after_n_chars=stage_config_dict.get("new_after_n_chars", 512),
780
+ max_characters=stage_config_dict.get("max_characters", 1024),
781
+ overlap=stage_config_dict.get("overlap", 0),
782
+ overlap_all=stage_config_dict.get("overlap_all", False),
572
783
  )
573
- elif stage_type == 'embed':
784
+ elif stage_type == "embed":
574
785
  stage_config = EmbedConfig(
575
- provider=stage_config_dict.get('provider', 'qwen'),
576
- model_name=stage_config_dict.get('model_name', 'text-embedding-v3')
786
+ provider=stage_config_dict.get("provider", "qwen"),
787
+ model_name=stage_config_dict.get("model_name", "text-embedding-v3"),
788
+ )
789
+ elif stage_type == "extract":
790
+ schema = stage_config_dict.get("schema")
791
+ if not schema:
792
+ raise ValueError("extract stage 的 config 中必须包含 'schema' 字段")
793
+ stage_config = ExtractConfig(
794
+ schema=schema,
795
+ generate_citations=stage_config_dict.get("generate_citations", False),
796
+ stamp=stage_config_dict.get("stamp", False),
577
797
  )
578
798
  else:
579
799
  raise ValueError(f"未知的 stage 类型: {stage_type}")
580
-
800
+
581
801
  stages.append(Stage(type=stage_type, config=stage_config))
582
802
 
583
803
  # 创建 Pipeline 配置
584
804
  pipeline_config = None
585
- if 'pipeline_config' in config and config['pipeline_config']:
586
- pipeline_cfg = config['pipeline_config']
587
- include_intermediate_results = pipeline_cfg.get('include_intermediate_results', False)
805
+ if "pipeline_config" in config and config["pipeline_config"]:
806
+ pipeline_cfg = config["pipeline_config"]
807
+ include_intermediate_results = pipeline_cfg.get(
808
+ "include_intermediate_results", False
809
+ )
588
810
  intermediate_results_destination = None
589
-
811
+
590
812
  if include_intermediate_results:
591
- if 'intermediate_results_destination' in pipeline_cfg:
592
- dest_cfg = pipeline_cfg['intermediate_results_destination']
593
- dest_type = dest_cfg.get('type')
594
-
595
- if dest_type == 'local':
813
+ if "intermediate_results_destination" in pipeline_cfg:
814
+ dest_cfg = pipeline_cfg["intermediate_results_destination"]
815
+ dest_type = dest_cfg.get("type")
816
+
817
+ if dest_type == "local":
596
818
  intermediate_results_destination = LocalDestination(
597
- output_dir=dest_cfg['output_dir']
819
+ output_dir=dest_cfg["output_dir"]
598
820
  )
599
- elif dest_type == 's3':
821
+ elif dest_type == "s3":
600
822
  intermediate_results_destination = S3Destination(
601
- endpoint=dest_cfg['endpoint'],
602
- access_key=dest_cfg['access_key'],
603
- secret_key=dest_cfg['secret_key'],
604
- bucket=dest_cfg['bucket'],
605
- prefix=dest_cfg.get('prefix', ''),
606
- region=dest_cfg.get('region', 'us-east-1')
823
+ endpoint=dest_cfg["endpoint"],
824
+ access_key=dest_cfg["access_key"],
825
+ secret_key=dest_cfg["secret_key"],
826
+ bucket=dest_cfg["bucket"],
827
+ prefix=dest_cfg.get("prefix", ""),
828
+ region=dest_cfg.get("region", "us-east-1"),
607
829
  )
608
830
  else:
609
- raise ValueError(f"不支持的 intermediate_results_destination 类型: '{dest_type}',支持的类型: 'local', 's3'")
831
+ raise ValueError(
832
+ f"不支持的 intermediate_results_destination 类型: '{dest_type}',支持的类型: 'local', 's3'"
833
+ )
610
834
  else:
611
- raise ValueError("当 include_intermediate_results 为 True 时,必须设置 intermediate_results_destination")
612
-
835
+ raise ValueError(
836
+ "当 include_intermediate_results 为 True 时,必须设置 intermediate_results_destination"
837
+ )
838
+
613
839
  pipeline_config = PipelineConfig(
614
840
  include_intermediate_results=include_intermediate_results,
615
- intermediate_results_destination=intermediate_results_destination
841
+ intermediate_results_destination=intermediate_results_destination,
616
842
  )
617
843
 
618
844
  # 创建 Pipeline
619
845
  pipeline = Pipeline(
620
846
  source=source,
621
847
  destination=destination,
622
- api_base_url=config.get('api_base_url', 'http://localhost:8000/api/xparse'),
623
- api_headers=config.get('api_headers', {}),
848
+ api_base_url=config.get("api_base_url", "http://localhost:8000/api/xparse"),
849
+ api_headers=config.get("api_headers", {}),
624
850
  stages=stages,
625
- pipeline_config=pipeline_config
851
+ pipeline_config=pipeline_config,
626
852
  )
627
853
 
628
854
  return pipeline
629
855
 
630
856
 
631
857
  __all__ = [
632
- 'Pipeline',
633
- 'create_pipeline_from_config',
858
+ "Pipeline",
859
+ "create_pipeline_from_config",
634
860
  ]
635
-