xparse-client 0.2.19__py3-none-any.whl → 0.2.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,14 +7,27 @@ import re
7
7
  import time
8
8
  from datetime import datetime, timezone
9
9
  from pathlib import Path
10
- from typing import Dict, Any, Optional, Tuple, List, Union
10
+ from typing import Any, Dict, List, Optional, Tuple, Union
11
11
 
12
12
  import requests
13
13
 
14
- from .config import ParseConfig, ChunkConfig, EmbedConfig, Stage, PipelineStats, PipelineConfig
15
- from .sources import Source, S3Source, LocalSource, FtpSource, SmbSource
16
- from .destinations import Destination, MilvusDestination, QdrantDestination, LocalDestination, S3Destination
17
-
14
+ from .config import (
15
+ ChunkConfig,
16
+ EmbedConfig,
17
+ ExtractConfig,
18
+ ParseConfig,
19
+ PipelineConfig,
20
+ PipelineStats,
21
+ Stage,
22
+ )
23
+ from .destinations import (
24
+ Destination,
25
+ LocalDestination,
26
+ MilvusDestination,
27
+ QdrantDestination,
28
+ S3Destination,
29
+ )
30
+ from .sources import FtpSource, LocalSource, S3Source, SmbSource, Source
18
31
 
19
32
  logger = logging.getLogger(__name__)
20
33
 
@@ -26,15 +39,15 @@ class Pipeline:
26
39
  self,
27
40
  source: Source,
28
41
  destination: Destination,
29
- api_base_url: str = 'http://localhost:8000/api/xparse',
42
+ api_base_url: str = "http://localhost:8000/api/xparse",
30
43
  api_headers: Optional[Dict[str, str]] = None,
31
44
  stages: Optional[List[Stage]] = None,
32
45
  pipeline_config: Optional[PipelineConfig] = None,
33
- intermediate_results_destination: Optional[Destination] = None
46
+ intermediate_results_destination: Optional[Destination] = None,
34
47
  ):
35
48
  self.source = source
36
49
  self.destination = destination
37
- self.api_base_url = api_base_url.rstrip('/')
50
+ self.api_base_url = api_base_url.rstrip("/")
38
51
  self.api_headers = api_headers or {}
39
52
  self.pipeline_config = pipeline_config or PipelineConfig()
40
53
 
@@ -42,34 +55,62 @@ class Pipeline:
42
55
  # 如果直接传入了 intermediate_results_destination,优先使用它并自动启用中间结果保存
43
56
  if intermediate_results_destination is not None:
44
57
  self.pipeline_config.include_intermediate_results = True
45
- self.pipeline_config.intermediate_results_destination = intermediate_results_destination
58
+ self.pipeline_config.intermediate_results_destination = (
59
+ intermediate_results_destination
60
+ )
46
61
  # 如果 pipeline_config 中已设置,使用 pipeline_config 中的值
47
62
  elif self.pipeline_config.include_intermediate_results:
48
63
  if not self.pipeline_config.intermediate_results_destination:
49
- raise ValueError("当 include_intermediate_results 为 True 时,必须设置 intermediate_results_destination")
64
+ raise ValueError(
65
+ "当 include_intermediate_results 为 True 时,必须设置 intermediate_results_destination"
66
+ )
50
67
 
51
68
  # 处理 stages 配置
52
69
  if stages is None:
53
70
  raise ValueError("必须提供 stages 参数")
54
-
71
+
55
72
  self.stages = stages
56
73
 
57
74
  # 验证 stages
58
- if not self.stages or self.stages[0].type != 'parse':
75
+ if not self.stages or self.stages[0].type != "parse":
59
76
  raise ValueError("stages 必须包含且第一个必须是 'parse' 类型")
60
-
77
+
78
+ # 检查是否包含 extract 阶段
79
+ self.has_extract = any(stage.type == "extract" for stage in self.stages)
80
+
81
+ # 验证 extract 阶段的约束
82
+ if self.has_extract:
83
+ # extract 只能跟在 parse 后面,且必须是最后一个
84
+ if len(self.stages) != 2 or self.stages[1].type != "extract":
85
+ raise ValueError(
86
+ "extract 阶段只能跟在 parse 后面,且必须是最后一个阶段(即只能是 [parse, extract] 组合)"
87
+ )
88
+
89
+ # destination 必须是文件存储类型
90
+ if not isinstance(destination, (LocalDestination, S3Destination)):
91
+ raise ValueError(
92
+ "当使用 extract 阶段时,destination 必须是文件存储类型(LocalDestination 或 S3Destination),不能是向量数据库"
93
+ )
94
+
61
95
  # 验证 embed config(如果存在)
62
96
  for stage in self.stages:
63
- if stage.type == 'embed' and isinstance(stage.config, EmbedConfig):
97
+ if stage.type == "embed" and isinstance(stage.config, EmbedConfig):
64
98
  stage.config.validate()
65
-
99
+
66
100
  # 验证 intermediate_results_destination
67
101
  if self.pipeline_config.include_intermediate_results:
68
102
  # 验证是否为支持的 Destination 类型
69
103
  from .destinations import Destination
70
- if not isinstance(self.pipeline_config.intermediate_results_destination, Destination):
71
- raise ValueError(f"intermediate_results_destination 必须是 Destination 类型")
72
- self.intermediate_results_destination = self.pipeline_config.intermediate_results_destination
104
+
105
+ if not isinstance(
106
+ self.pipeline_config.intermediate_results_destination, Destination
107
+ ):
108
+ raise ValueError(
109
+ f"intermediate_results_destination 必须是 Destination 类型"
110
+ )
111
+ self.intermediate_results_destination = (
112
+ self.pipeline_config.intermediate_results_destination
113
+ )
73
114
 
74
115
  print("=" * 60)
75
116
  print("Pipeline 初始化完成")
@@ -83,223 +124,264 @@ class Pipeline:
83
124
  def get_config(self) -> Dict[str, Any]:
84
125
  """获取 Pipeline 的完整配置信息,返回字典格式(与 create_pipeline_from_config 的入参格式一致)"""
85
126
  config = {}
86
-
127
+
87
128
  # Source 配置
88
- source_type = type(self.source).__name__.replace('Source', '').lower()
89
- config['source'] = {'type': source_type}
90
-
129
+ source_type = type(self.source).__name__.replace("Source", "").lower()
130
+ config["source"] = {"type": source_type}
131
+
91
132
  if isinstance(self.source, S3Source):
92
- config['source'].update({
93
- 'endpoint': self.source.endpoint,
94
- 'bucket': self.source.bucket,
95
- 'prefix': self.source.prefix,
96
- 'pattern': self.source.pattern,
97
- 'recursive': self.source.recursive
98
- })
133
+ config["source"].update(
134
+ {
135
+ "endpoint": self.source.endpoint,
136
+ "bucket": self.source.bucket,
137
+ "prefix": self.source.prefix,
138
+ "pattern": self.source.pattern,
139
+ "recursive": self.source.recursive,
140
+ }
141
+ )
99
142
  # access_key 和 secret_key 不在对象中保存,无法恢复
100
143
  # region 也不在对象中保存,使用默认值
101
- config['source']['region'] = 'us-east-1' # 默认值
144
+ config["source"]["region"] = "us-east-1" # 默认值
102
145
  elif isinstance(self.source, LocalSource):
103
- config['source'].update({
104
- 'directory': str(self.source.directory),
105
- 'pattern': self.source.pattern,
106
- 'recursive': self.source.recursive
107
- })
146
+ config["source"].update(
147
+ {
148
+ "directory": str(self.source.directory),
149
+ "pattern": self.source.pattern,
150
+ "recursive": self.source.recursive,
151
+ }
152
+ )
108
153
  elif isinstance(self.source, FtpSource):
109
- config['source'].update({
110
- 'host': self.source.host,
111
- 'port': self.source.port,
112
- 'username': self.source.username,
113
- 'pattern': self.source.pattern,
114
- 'recursive': self.source.recursive
115
- })
154
+ config["source"].update(
155
+ {
156
+ "host": self.source.host,
157
+ "port": self.source.port,
158
+ "username": self.source.username,
159
+ "pattern": self.source.pattern,
160
+ "recursive": self.source.recursive,
161
+ }
162
+ )
116
163
  # password 不在对象中保存,无法恢复
117
164
  elif isinstance(self.source, SmbSource):
118
- config['source'].update({
119
- 'host': self.source.host,
120
- 'share_name': self.source.share_name,
121
- 'username': self.source.username,
122
- 'domain': self.source.domain,
123
- 'port': self.source.port,
124
- 'path': self.source.path,
125
- 'pattern': self.source.pattern,
126
- 'recursive': self.source.recursive
127
- })
165
+ config["source"].update(
166
+ {
167
+ "host": self.source.host,
168
+ "share_name": self.source.share_name,
169
+ "username": self.source.username,
170
+ "domain": self.source.domain,
171
+ "port": self.source.port,
172
+ "path": self.source.path,
173
+ "pattern": self.source.pattern,
174
+ "recursive": self.source.recursive,
175
+ }
176
+ )
128
177
  # password 不在对象中保存,无法恢复
129
-
178
+
130
179
  # Destination 配置
131
- dest_type = type(self.destination).__name__.replace('Destination', '').lower()
180
+ dest_type = type(self.destination).__name__.replace("Destination", "").lower()
132
181
  # MilvusDestination 和 Zilliz 都使用 'milvus' 或 'zilliz' 类型
133
- if dest_type == 'milvus':
182
+ if dest_type == "milvus":
134
183
  # 判断是本地 Milvus 还是 Zilliz(通过 db_path 判断)
135
- if self.destination.db_path.startswith('http'):
136
- dest_type = 'zilliz'
184
+ if self.destination.db_path.startswith("http"):
185
+ dest_type = "zilliz"
137
186
  else:
138
- dest_type = 'milvus'
139
-
140
- config['destination'] = {'type': dest_type}
141
-
187
+ dest_type = "milvus"
188
+
189
+ config["destination"] = {"type": dest_type}
190
+
142
191
  if isinstance(self.destination, MilvusDestination):
143
- config['destination'].update({
144
- 'db_path': self.destination.db_path,
145
- 'collection_name': self.destination.collection_name,
146
- 'dimension': self.destination.dimension
147
- })
192
+ config["destination"].update(
193
+ {
194
+ "db_path": self.destination.db_path,
195
+ "collection_name": self.destination.collection_name,
196
+ "dimension": self.destination.dimension,
197
+ }
198
+ )
148
199
  # api_key 和 token 不在对象中保存,无法恢复
149
200
  elif isinstance(self.destination, QdrantDestination):
150
- config['destination'].update({
151
- 'url': self.destination.url,
152
- 'collection_name': self.destination.collection_name,
153
- 'dimension': self.destination.dimension,
154
- 'prefer_grpc': getattr(self.destination, 'prefer_grpc', False)
155
- })
201
+ config["destination"].update(
202
+ {
203
+ "url": self.destination.url,
204
+ "collection_name": self.destination.collection_name,
205
+ "dimension": self.destination.dimension,
206
+ "prefer_grpc": getattr(self.destination, "prefer_grpc", False),
207
+ }
208
+ )
156
209
  # api_key 不在对象中保存,无法恢复
157
210
  elif isinstance(self.destination, LocalDestination):
158
- config['destination'].update({
159
- 'output_dir': str(self.destination.output_dir)
160
- })
211
+ config["destination"].update(
212
+ {"output_dir": str(self.destination.output_dir)}
213
+ )
161
214
  elif isinstance(self.destination, S3Destination):
162
- config['destination'].update({
163
- 'endpoint': self.destination.endpoint,
164
- 'bucket': self.destination.bucket,
165
- 'prefix': self.destination.prefix
166
- })
215
+ config["destination"].update(
216
+ {
217
+ "endpoint": self.destination.endpoint,
218
+ "bucket": self.destination.bucket,
219
+ "prefix": self.destination.prefix,
220
+ }
221
+ )
167
222
  # access_key, secret_key, region 不在对象中保存,无法恢复
168
- config['destination']['region'] = 'us-east-1' # 默认值
169
-
223
+ config["destination"]["region"] = "us-east-1" # 默认值
224
+
170
225
  # API 配置
171
- config['api_base_url'] = self.api_base_url
172
- config['api_headers'] = {}
226
+ config["api_base_url"] = self.api_base_url
227
+ config["api_headers"] = {}
173
228
  for key, value in self.api_headers.items():
174
- config['api_headers'][key] = value
175
-
229
+ config["api_headers"][key] = value
230
+
176
231
  # Stages 配置
177
- config['stages'] = []
232
+ config["stages"] = []
178
233
  for stage in self.stages:
179
- stage_dict = {
180
- 'type': stage.type,
181
- 'config': {}
182
- }
183
-
234
+ stage_dict = {"type": stage.type, "config": {}}
235
+
184
236
  if isinstance(stage.config, ParseConfig):
185
- stage_dict['config'] = stage.config.to_dict()
237
+ stage_dict["config"] = stage.config.to_dict()
186
238
  elif isinstance(stage.config, ChunkConfig):
187
- stage_dict['config'] = stage.config.to_dict()
239
+ stage_dict["config"] = stage.config.to_dict()
188
240
  elif isinstance(stage.config, EmbedConfig):
189
- stage_dict['config'] = stage.config.to_dict()
241
+ stage_dict["config"] = stage.config.to_dict()
242
+ elif isinstance(stage.config, ExtractConfig):
243
+ stage_dict["config"] = stage.config.to_dict()
190
244
  else:
191
245
  # 如果 config 是字典或其他类型,尝试转换
192
246
  if isinstance(stage.config, dict):
193
- stage_dict['config'] = stage.config
247
+ stage_dict["config"] = stage.config
194
248
  else:
195
- stage_dict['config'] = str(stage.config)
196
-
197
- config['stages'].append(stage_dict)
198
-
249
+ stage_dict["config"] = str(stage.config)
250
+
251
+ config["stages"].append(stage_dict)
252
+
199
253
  # Pipeline Config
200
254
  if self.pipeline_config.include_intermediate_results:
201
- config['pipeline_config'] = {
202
- 'include_intermediate_results': True,
203
- 'intermediate_results_destination': {}
255
+ config["pipeline_config"] = {
256
+ "include_intermediate_results": True,
257
+ "intermediate_results_destination": {},
204
258
  }
205
-
259
+
206
260
  inter_dest = self.pipeline_config.intermediate_results_destination
207
261
  if inter_dest:
208
- inter_dest_type = type(inter_dest).__name__.replace('Destination', '').lower()
209
- config['pipeline_config']['intermediate_results_destination']['type'] = inter_dest_type
210
-
262
+ inter_dest_type = (
263
+ type(inter_dest).__name__.replace("Destination", "").lower()
264
+ )
265
+ config["pipeline_config"]["intermediate_results_destination"][
266
+ "type"
267
+ ] = inter_dest_type
268
+
211
269
  if isinstance(inter_dest, LocalDestination):
212
- config['pipeline_config']['intermediate_results_destination']['output_dir'] = str(inter_dest.output_dir)
270
+ config["pipeline_config"]["intermediate_results_destination"][
271
+ "output_dir"
272
+ ] = str(inter_dest.output_dir)
213
273
  elif isinstance(inter_dest, S3Destination):
214
- config['pipeline_config']['intermediate_results_destination'].update({
215
- 'endpoint': inter_dest.endpoint,
216
- 'bucket': inter_dest.bucket,
217
- 'prefix': inter_dest.prefix
218
- })
274
+ config["pipeline_config"][
275
+ "intermediate_results_destination"
276
+ ].update(
277
+ {
278
+ "endpoint": inter_dest.endpoint,
279
+ "bucket": inter_dest.bucket,
280
+ "prefix": inter_dest.prefix,
281
+ }
282
+ )
219
283
  # access_key, secret_key, region 不在对象中保存,无法恢复
220
- config['pipeline_config']['intermediate_results_destination']['region'] = 'us-east-1' # 默认值
221
-
284
+ config["pipeline_config"]["intermediate_results_destination"][
285
+ "region"
286
+ ] = "us-east-1" # 默认值
287
+
222
288
  return config
223
289
 
224
290
  def _extract_error_message(self, response: requests.Response) -> Tuple[str, str]:
225
291
  """
226
292
  从响应中提取规范化的错误信息
227
-
293
+
228
294
  Returns:
229
295
  Tuple[str, str]: (error_msg, x_request_id)
230
296
  """
231
297
  # 首先尝试从响应头中提取 x-request-id(requests的headers大小写不敏感)
232
- x_request_id = response.headers.get('x-request-id', '')
233
- error_msg = ''
234
-
298
+ x_request_id = response.headers.get("x-request-id", "")
299
+ error_msg = ""
300
+
235
301
  # 获取Content-Type
236
- content_type = response.headers.get('Content-Type', '').lower()
237
-
302
+ content_type = response.headers.get("Content-Type", "").lower()
303
+
238
304
  # 尝试解析JSON响应
239
- if 'application/json' in content_type:
305
+ if "application/json" in content_type:
240
306
  try:
241
307
  result = response.json()
242
308
  # 如果响应头中没有x-request-id,尝试从响应体中获取
243
309
  if not x_request_id:
244
- x_request_id = result.get('x_request_id', '')
245
- error_msg = result.get('message', result.get('msg', f'HTTP {response.status_code}'))
310
+ x_request_id = result.get("x_request_id", "")
311
+ error_msg = result.get(
312
+ "message", result.get("msg", f"HTTP {response.status_code}")
313
+ )
246
314
  return error_msg, x_request_id
247
315
  except:
248
316
  pass
249
-
317
+
250
318
  # 处理HTML响应
251
- if 'text/html' in content_type or response.text.strip().startswith('<'):
319
+ if "text/html" in content_type or response.text.strip().startswith("<"):
252
320
  try:
253
321
  # 从HTML中提取标题(通常包含状态码和状态文本)
254
- title_match = re.search(r'<title>(.*?)</title>', response.text, re.IGNORECASE)
322
+ title_match = re.search(
323
+ r"<title>(.*?)</title>", response.text, re.IGNORECASE
324
+ )
255
325
  if title_match:
256
326
  error_msg = title_match.group(1).strip()
257
327
  else:
258
328
  # 如果没有title,尝试提取h1标签
259
- h1_match = re.search(r'<h1>(.*?)</h1>', response.text, re.IGNORECASE)
329
+ h1_match = re.search(
330
+ r"<h1>(.*?)</h1>", response.text, re.IGNORECASE
331
+ )
260
332
  if h1_match:
261
333
  error_msg = h1_match.group(1).strip()
262
334
  else:
263
- error_msg = f'HTTP {response.status_code}'
335
+ error_msg = f"HTTP {response.status_code}"
264
336
  except:
265
- error_msg = f'HTTP {response.status_code}'
266
-
337
+ error_msg = f"HTTP {response.status_code}"
338
+
267
339
  # 处理纯文本响应
268
- elif 'text/plain' in content_type:
269
- error_msg = response.text[:200].strip() if response.text else f'HTTP {response.status_code}'
270
-
340
+ elif "text/plain" in content_type:
341
+ error_msg = (
342
+ response.text[:200].strip()
343
+ if response.text
344
+ else f"HTTP {response.status_code}"
345
+ )
346
+
271
347
  # 其他情况
272
348
  else:
273
349
  if response.text:
274
350
  # 尝试截取前200字符,但去除换行和多余空格
275
351
  text = response.text[:200].strip()
276
352
  # 如果包含多行,只取第一行
277
- if '\n' in text:
278
- text = text.split('\n')[0].strip()
279
- error_msg = text if text else f'HTTP {response.status_code}'
353
+ if "\n" in text:
354
+ text = text.split("\n")[0].strip()
355
+ error_msg = text if text else f"HTTP {response.status_code}"
280
356
  else:
281
- error_msg = f'HTTP {response.status_code}'
282
-
357
+ error_msg = f"HTTP {response.status_code}"
358
+
283
359
  return error_msg, x_request_id
284
360
 
285
- def _call_pipeline_api(self, file_bytes: bytes, filename: str, data_source: Dict[str, Any]) -> Optional[Dict[str, Any]]:
361
+ def _call_pipeline_api(
362
+ self, file_bytes: bytes, filename: str, data_source: Dict[str, Any]
363
+ ) -> Optional[Dict[str, Any]]:
286
364
  url = f"{self.api_base_url}/pipeline"
287
365
  max_retries = 3
288
366
 
289
367
  for try_count in range(max_retries):
290
368
  try:
291
- files = {'file': (filename or 'file', file_bytes)}
369
+ files = {"file": (filename or "file", file_bytes)}
292
370
  form_data = {}
293
371
 
294
372
  # 将 stages 转换为 API 格式
295
373
  stages_data = [stage.to_dict() for stage in self.stages]
296
374
  try:
297
- form_data['stages'] = json.dumps(stages_data)
298
- form_data['data_source'] = json.dumps(data_source, ensure_ascii=False)
299
-
375
+ form_data["stages"] = json.dumps(stages_data, ensure_ascii=False)
376
+ form_data["data_source"] = json.dumps(
377
+ data_source, ensure_ascii=False
378
+ )
379
+
300
380
  # 如果启用了中间结果保存,在请求中添加参数
301
381
  if self.pipeline_config:
302
- form_data['config'] = json.dumps(self.pipeline_config.to_dict(), ensure_ascii=False)
382
+ form_data["config"] = json.dumps(
383
+ self.pipeline_config.to_dict(), ensure_ascii=False
384
+ )
303
385
  except Exception as e:
304
386
  print(f" ✗ 入参处理失败,请检查配置: {e}")
305
387
  logger.error(f"入参处理失败,请检查配置: {e}")
@@ -310,69 +392,136 @@ class Pipeline:
310
392
  files=files,
311
393
  data=form_data,
312
394
  headers=self.api_headers,
313
- timeout=630
395
+ timeout=630,
314
396
  )
315
397
 
316
398
  if response.status_code == 200:
317
399
  result = response.json()
318
- x_request_id = result.get('x_request_id', '')
400
+ x_request_id = result.get("x_request_id", "")
319
401
  print(f" ✓ Pipeline 接口返回 x_request_id: {x_request_id}")
320
- if result.get('code') == 200 and 'data' in result:
321
- return result.get('data')
402
+ if result.get("code") == 200 and "data" in result:
403
+ return result.get("data")
322
404
  # 如果 code 不是 200,打印错误信息
323
- error_msg = result.get('message', result.get('msg', '未知错误'))
324
- print(f" ✗ Pipeline 接口返回错误: code={result.get('code')}, message={error_msg}, x_request_id={x_request_id}")
325
- logger.error(f"Pipeline 接口返回错误: code={result.get('code')}, message={error_msg}, x_request_id={x_request_id}")
405
+ error_msg = result.get("message", result.get("msg", "未知错误"))
406
+ print(
407
+ f"Pipeline 接口返回错误: code={result.get('code')}, message={error_msg}, x_request_id={x_request_id}"
408
+ )
409
+ logger.error(
410
+ f"Pipeline 接口返回错误: code={result.get('code')}, message={error_msg}, x_request_id={x_request_id}"
411
+ )
326
412
  return None
327
413
  else:
328
414
  # 使用规范化函数提取错误信息
329
415
  error_msg, x_request_id = self._extract_error_message(response)
330
-
331
- print(f" ✗ API 错误 {response.status_code}: {error_msg}, x_request_id={x_request_id}, 重试 {try_count + 1}/{max_retries}")
332
- logger.warning(f"API 错误 {response.status_code}: {error_msg}, x_request_id={x_request_id}, 重试 {try_count + 1}/{max_retries}")
416
+
417
+ print(
418
+ f"API 错误 {response.status_code}: {error_msg}, x_request_id={x_request_id}, 重试 {try_count + 1}/{max_retries}"
419
+ )
420
+ logger.warning(
421
+ f"API 错误 {response.status_code}: {error_msg}, x_request_id={x_request_id}, 重试 {try_count + 1}/{max_retries}"
422
+ )
333
423
 
334
424
  except Exception as e:
335
425
  # 如果是 requests 异常,尝试从响应中获取 x_request_id
336
- x_request_id = ''
426
+ x_request_id = ""
337
427
  error_msg = str(e)
338
428
  try:
339
- if hasattr(e, 'response') and e.response is not None:
429
+ if hasattr(e, "response") and e.response is not None:
340
430
  try:
341
431
  result = e.response.json()
342
- x_request_id = result.get('x_request_id', '')
343
- error_msg = result.get('message', result.get('msg', error_msg))
432
+ x_request_id = result.get("x_request_id", "")
433
+ error_msg = result.get(
434
+ "message", result.get("msg", error_msg)
435
+ )
344
436
  except:
345
437
  pass
346
438
  except:
347
439
  pass
348
-
349
- print(f" ✗ 请求异常: {error_msg}, x_request_id={x_request_id}, 重试 {try_count + 1}/{max_retries}")
350
- logger.error(f"API 请求异常 pipeline: {error_msg}, x_request_id={x_request_id}")
440
+
441
+ print(
442
+ f" 请求异常: {error_msg}, x_request_id={x_request_id}, 重试 {try_count + 1}/{max_retries}"
443
+ )
444
+ logger.error(
445
+ f"API 请求异常 pipeline: {error_msg}, x_request_id={x_request_id}"
446
+ )
351
447
 
352
448
  if try_count < max_retries - 1:
353
449
  time.sleep(2)
354
450
 
355
451
  return None
356
452
 
357
- def process_with_pipeline(self, file_bytes: bytes, filename: str, data_source: Dict[str, Any]) -> Optional[Tuple[List[Dict[str, Any]], PipelineStats]]:
453
+ def process_with_pipeline(
454
+ self, file_bytes: bytes, filename: str, data_source: Dict[str, Any]
455
+ ) -> Optional[Tuple[Any, PipelineStats]]:
358
456
  print(f" → 调用 Pipeline 接口: {filename}")
359
457
  result = self._call_pipeline_api(file_bytes, filename, data_source)
360
458
 
361
- if result and 'elements' in result and 'stats' in result:
362
- elements = result['elements']
363
- stats_data = result['stats']
459
+ if not result or "stats" not in result:
460
+ print(f" ✗ Pipeline 失败")
461
+ logger.error(f"Pipeline 失败: {filename}")
462
+ return None
463
+
464
+ # 处理 extract 类型的响应
465
+ if self.has_extract:
466
+ # extract 返回 extract_result 而不是 elements
467
+ if "extract_result" not in result:
468
+ print(f" ✗ Pipeline 失败: extract 响应中缺少 extract_result")
469
+ logger.error(f"Pipeline 失败: extract 响应中缺少 extract_result")
470
+ return None
471
+
472
+ data = result["extract_result"] # 结构化数据
473
+ stats_data = result["stats"]
474
+
475
+ stats = PipelineStats(
476
+ original_elements=stats_data.get("original_elements", 0),
477
+ chunked_elements=0, # extract 不涉及分块
478
+ embedded_elements=0, # extract 不涉及向量化
479
+ stages=self.stages,
480
+ record_id=stats_data.get("record_id"),
481
+ )
482
+
483
+ # 如果启用了中间结果保存,处理中间结果
484
+ if (
485
+ self.pipeline_config.include_intermediate_results
486
+ and "intermediate_results" in result
487
+ ):
488
+ self._save_intermediate_results(
489
+ result["intermediate_results"], filename, data_source
490
+ )
491
+
492
+ print(f" ✓ Extract 完成:")
493
+ print(f" - 原始元素: {stats.original_elements}")
494
+ print(f" - 提取结果类型: {type(data).__name__}")
495
+ logger.info(f"Extract 完成: {filename}")
496
+
497
+ return data, stats
498
+
499
+ else:
500
+ # 原有的 parse/chunk/embed 逻辑
501
+ if "elements" not in result:
502
+ print(f" ✗ Pipeline 失败: 响应中缺少 elements")
503
+ logger.error(f"Pipeline 失败: 响应中缺少 elements")
504
+ return None
505
+
506
+ elements = result["elements"]
507
+ stats_data = result["stats"]
364
508
 
365
509
  stats = PipelineStats(
366
- original_elements=stats_data.get('original_elements', 0),
367
- chunked_elements=stats_data.get('chunked_elements', 0),
368
- embedded_elements=stats_data.get('embedded_elements', 0),
510
+ original_elements=stats_data.get("original_elements", 0),
511
+ chunked_elements=stats_data.get("chunked_elements", 0),
512
+ embedded_elements=stats_data.get("embedded_elements", 0),
369
513
  stages=self.stages, # 使用实际执行的 stages
370
- record_id=stats_data.get('record_id') # 从 API 响应中获取 record_id
514
+ record_id=stats_data.get("record_id"), # 从 API 响应中获取 record_id
371
515
  )
372
516
 
373
517
  # 如果启用了中间结果保存,处理中间结果
374
- if self.pipeline_config.include_intermediate_results and 'intermediate_results' in result:
375
- self._save_intermediate_results(result['intermediate_results'], filename, data_source)
518
+ if (
519
+ self.pipeline_config.include_intermediate_results
520
+ and "intermediate_results" in result
521
+ ):
522
+ self._save_intermediate_results(
523
+ result["intermediate_results"], filename, data_source
524
+ )
376
525
 
377
526
  print(f" ✓ Pipeline 完成:")
378
527
  print(f" - 原始元素: {stats.original_elements}")
@@ -381,14 +530,15 @@ class Pipeline:
381
530
  logger.info(f"Pipeline 完成: {filename}, {stats.embedded_elements} 个向量")
382
531
 
383
532
  return elements, stats
384
- else:
385
- print(f" ✗ Pipeline 失败")
386
- logger.error(f"Pipeline 失败: {filename}")
387
- return None
388
533
 
389
- def _save_intermediate_results(self, intermediate_results: List[Dict[str, Any]], filename: str, data_source: Dict[str, Any]) -> None:
534
+ def _save_intermediate_results(
535
+ self,
536
+ intermediate_results: List[Dict[str, Any]],
537
+ filename: str,
538
+ data_source: Dict[str, Any],
539
+ ) -> None:
390
540
  """保存中间结果
391
-
541
+
392
542
  Args:
393
543
  intermediate_results: 中间结果数组,每个元素包含 stage 和 elements 字段
394
544
  filename: 文件名
@@ -397,22 +547,26 @@ class Pipeline:
397
547
  try:
398
548
  # intermediate_results 是一个数组,每个元素是 {stage: str, elements: List}
399
549
  for result_item in intermediate_results:
400
- if 'stage' not in result_item or 'elements' not in result_item:
401
- logger.warning(f"中间结果项缺少 stage 或 elements 字段: {result_item}")
550
+ if "stage" not in result_item or "elements" not in result_item:
551
+ logger.warning(
552
+ f"中间结果项缺少 stage 或 elements 字段: {result_item}"
553
+ )
402
554
  continue
403
-
404
- stage = result_item['stage']
405
- elements = result_item['elements']
406
-
555
+
556
+ stage = result_item["stage"]
557
+ elements = result_item["elements"]
558
+
407
559
  metadata = {
408
- 'filename': filename,
409
- 'stage': stage,
410
- 'total_elements': len(elements),
411
- 'processed_at': datetime.now().isoformat(),
412
- 'data_source': data_source
560
+ "filename": filename,
561
+ "stage": stage,
562
+ "total_elements": len(elements),
563
+ "processed_at": datetime.now().isoformat(),
564
+ "data_source": data_source,
413
565
  }
414
-
415
- self.pipeline_config.intermediate_results_destination.write(elements, metadata)
566
+
567
+ self.pipeline_config.intermediate_results_destination.write(
568
+ elements, metadata
569
+ )
416
570
  print(f" ✓ 保存 {stage.upper()} 中间结果: {len(elements)} 个元素")
417
571
  logger.info(f"保存 {stage.upper()} 中间结果成功: {filename}")
418
572
 
@@ -429,17 +583,17 @@ class Pipeline:
429
583
  print(f" → 读取文件...")
430
584
  file_bytes, data_source = self.source.read_file(file_path)
431
585
  data_source = data_source or {}
432
-
586
+
433
587
  # 检查文件大小,超过 100MB 则报错
434
588
  MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB
435
589
  file_size = len(file_bytes)
436
590
  if file_size > MAX_FILE_SIZE:
437
591
  file_size_mb = file_size / (1024 * 1024)
438
592
  raise ValueError(f"文件大小过大: {file_size_mb:.2f}MB,超过100MB限制")
439
-
593
+
440
594
  # 转换为毫秒时间戳字符串
441
595
  timestamp_ms = int(datetime.now(timezone.utc).timestamp() * 1000)
442
- data_source['date_processed'] = str(timestamp_ms)
596
+ data_source["date_processed"] = str(timestamp_ms)
443
597
  print(f" ✓ 文件读取完成: {len(file_bytes)} bytes")
444
598
 
445
599
  result = self.process_with_pipeline(file_bytes, file_path, data_source)
@@ -450,13 +604,13 @@ class Pipeline:
450
604
 
451
605
  print(f" → 写入目的地...")
452
606
  metadata = {
453
- 'filename': file_path,
454
- 'processed_at': str(timestamp_ms),
607
+ "filename": file_path,
608
+ "processed_at": str(timestamp_ms),
455
609
  }
456
-
610
+
457
611
  # 如果 stats 中有 record_id,添加到 metadata 中
458
612
  if stats.record_id:
459
- metadata['record_id'] = stats.record_id
613
+ metadata["record_id"] = stats.record_id
460
614
 
461
615
  success = self.destination.write(embedded_data, metadata)
462
616
 
@@ -523,168 +677,184 @@ class Pipeline:
523
677
  print("=" * 60)
524
678
 
525
679
  logger.info("=" * 60)
526
- logger.info(f"Pipeline 完成 - 总数:{total}, 成功:{success_count}, 失败:{fail_count}, 耗时:{elapsed:.2f}秒")
680
+ logger.info(
681
+ f"Pipeline 完成 - 总数:{total}, 成功:{success_count}, 失败:{fail_count}, 耗时:{elapsed:.2f}秒"
682
+ )
527
683
  logger.info("=" * 60)
528
684
 
529
685
 
530
686
  def create_pipeline_from_config(config: Dict[str, Any]) -> Pipeline:
531
- source_config = config['source']
532
- if source_config['type'] == 's3':
687
+ source_config = config["source"]
688
+ if source_config["type"] == "s3":
533
689
  source = S3Source(
534
- endpoint=source_config['endpoint'],
535
- access_key=source_config['access_key'],
536
- secret_key=source_config['secret_key'],
537
- bucket=source_config['bucket'],
538
- prefix=source_config.get('prefix', ''),
539
- region=source_config.get('region', 'us-east-1'),
540
- pattern=source_config.get('pattern', None),
541
- recursive=source_config.get('recursive', False)
690
+ endpoint=source_config["endpoint"],
691
+ access_key=source_config["access_key"],
692
+ secret_key=source_config["secret_key"],
693
+ bucket=source_config["bucket"],
694
+ prefix=source_config.get("prefix", ""),
695
+ region=source_config.get("region", "us-east-1"),
696
+ pattern=source_config.get("pattern", None),
697
+ recursive=source_config.get("recursive", False),
542
698
  )
543
- elif source_config['type'] == 'local':
699
+ elif source_config["type"] == "local":
544
700
  source = LocalSource(
545
- directory=source_config['directory'],
546
- pattern=source_config.get('pattern', None),
547
- recursive=source_config.get('recursive', False)
701
+ directory=source_config["directory"],
702
+ pattern=source_config.get("pattern", None),
703
+ recursive=source_config.get("recursive", False),
548
704
  )
549
- elif source_config['type'] == 'ftp':
705
+ elif source_config["type"] == "ftp":
550
706
  source = FtpSource(
551
- host=source_config['host'],
552
- port=source_config['port'],
553
- username=source_config['username'],
554
- password=source_config['password'],
555
- pattern=source_config.get('pattern', None),
556
- recursive=source_config.get('recursive', False)
707
+ host=source_config["host"],
708
+ port=source_config["port"],
709
+ username=source_config["username"],
710
+ password=source_config["password"],
711
+ pattern=source_config.get("pattern", None),
712
+ recursive=source_config.get("recursive", False),
557
713
  )
558
- elif source_config['type'] == 'smb':
714
+ elif source_config["type"] == "smb":
559
715
  source = SmbSource(
560
- host=source_config['host'],
561
- share_name=source_config['share_name'],
562
- username=source_config['username'],
563
- password=source_config['password'],
564
- domain=source_config.get('domain', ''),
565
- port=source_config.get('port', 445),
566
- path=source_config.get('path', ''),
567
- pattern=source_config.get('pattern', None),
568
- recursive=source_config.get('recursive', False)
716
+ host=source_config["host"],
717
+ share_name=source_config["share_name"],
718
+ username=source_config["username"],
719
+ password=source_config["password"],
720
+ domain=source_config.get("domain", ""),
721
+ port=source_config.get("port", 445),
722
+ path=source_config.get("path", ""),
723
+ pattern=source_config.get("pattern", None),
724
+ recursive=source_config.get("recursive", False),
569
725
  )
570
726
  else:
571
727
  raise ValueError(f"未知的 source 类型: {source_config['type']}")
572
728
 
573
- dest_config = config['destination']
574
- if dest_config['type'] in ['milvus', 'zilliz']:
729
+ dest_config = config["destination"]
730
+ if dest_config["type"] in ["milvus", "zilliz"]:
575
731
  destination = MilvusDestination(
576
- db_path=dest_config['db_path'],
577
- collection_name=dest_config['collection_name'],
578
- dimension=dest_config['dimension'],
579
- api_key=dest_config.get('api_key'),
580
- token=dest_config.get('token')
732
+ db_path=dest_config["db_path"],
733
+ collection_name=dest_config["collection_name"],
734
+ dimension=dest_config["dimension"],
735
+ api_key=dest_config.get("api_key"),
736
+ token=dest_config.get("token"),
581
737
  )
582
- elif dest_config['type'] == 'qdrant':
738
+ elif dest_config["type"] == "qdrant":
583
739
  destination = QdrantDestination(
584
- url=dest_config['url'],
585
- collection_name=dest_config['collection_name'],
586
- dimension=dest_config['dimension'],
587
- api_key=dest_config.get('api_key'),
588
- prefer_grpc=dest_config.get('prefer_grpc', False)
589
- )
590
- elif dest_config['type'] == 'local':
591
- destination = LocalDestination(
592
- output_dir=dest_config['output_dir']
740
+ url=dest_config["url"],
741
+ collection_name=dest_config["collection_name"],
742
+ dimension=dest_config["dimension"],
743
+ api_key=dest_config.get("api_key"),
744
+ prefer_grpc=dest_config.get("prefer_grpc", False),
593
745
  )
594
- elif dest_config['type'] == 's3':
746
+ elif dest_config["type"] == "local":
747
+ destination = LocalDestination(output_dir=dest_config["output_dir"])
748
+ elif dest_config["type"] == "s3":
595
749
  destination = S3Destination(
596
- endpoint=dest_config['endpoint'],
597
- access_key=dest_config['access_key'],
598
- secret_key=dest_config['secret_key'],
599
- bucket=dest_config['bucket'],
600
- prefix=dest_config.get('prefix', ''),
601
- region=dest_config.get('region', 'us-east-1')
750
+ endpoint=dest_config["endpoint"],
751
+ access_key=dest_config["access_key"],
752
+ secret_key=dest_config["secret_key"],
753
+ bucket=dest_config["bucket"],
754
+ prefix=dest_config.get("prefix", ""),
755
+ region=dest_config.get("region", "us-east-1"),
602
756
  )
603
757
  else:
604
758
  raise ValueError(f"未知的 destination 类型: {dest_config['type']}")
605
759
 
606
760
  # 处理 stages 配置
607
- if 'stages' not in config or not config['stages']:
761
+ if "stages" not in config or not config["stages"]:
608
762
  raise ValueError("配置中必须包含 'stages' 字段")
609
-
763
+
610
764
  stages = []
611
- for stage_cfg in config['stages']:
612
- stage_type = stage_cfg.get('type')
613
- stage_config_dict = stage_cfg.get('config', {})
614
-
615
- if stage_type == 'parse':
765
+ for stage_cfg in config["stages"]:
766
+ stage_type = stage_cfg.get("type")
767
+ stage_config_dict = stage_cfg.get("config", {})
768
+
769
+ if stage_type == "parse":
616
770
  parse_cfg_copy = dict(stage_config_dict)
617
- provider = parse_cfg_copy.pop('provider', 'textin')
771
+ provider = parse_cfg_copy.pop("provider", "textin")
618
772
  stage_config = ParseConfig(provider=provider, **parse_cfg_copy)
619
- elif stage_type == 'chunk':
773
+ elif stage_type == "chunk":
620
774
  stage_config = ChunkConfig(
621
- strategy=stage_config_dict.get('strategy', 'basic'),
622
- include_orig_elements=stage_config_dict.get('include_orig_elements', False),
623
- new_after_n_chars=stage_config_dict.get('new_after_n_chars', 512),
624
- max_characters=stage_config_dict.get('max_characters', 1024),
625
- overlap=stage_config_dict.get('overlap', 0),
626
- overlap_all=stage_config_dict.get('overlap_all', False)
775
+ strategy=stage_config_dict.get("strategy", "basic"),
776
+ include_orig_elements=stage_config_dict.get(
777
+ "include_orig_elements", False
778
+ ),
779
+ new_after_n_chars=stage_config_dict.get("new_after_n_chars", 512),
780
+ max_characters=stage_config_dict.get("max_characters", 1024),
781
+ overlap=stage_config_dict.get("overlap", 0),
782
+ overlap_all=stage_config_dict.get("overlap_all", False),
627
783
  )
628
- elif stage_type == 'embed':
784
+ elif stage_type == "embed":
629
785
  stage_config = EmbedConfig(
630
- provider=stage_config_dict.get('provider', 'qwen'),
631
- model_name=stage_config_dict.get('model_name', 'text-embedding-v3')
786
+ provider=stage_config_dict.get("provider", "qwen"),
787
+ model_name=stage_config_dict.get("model_name", "text-embedding-v3"),
788
+ )
789
+ elif stage_type == "extract":
790
+ schema = stage_config_dict.get("schema")
791
+ if not schema:
792
+ raise ValueError("extract stage 的 config 中必须包含 'schema' 字段")
793
+ stage_config = ExtractConfig(
794
+ schema=schema,
795
+ generate_citations=stage_config_dict.get("generate_citations", False),
796
+ stamp=stage_config_dict.get("stamp", False),
632
797
  )
633
798
  else:
634
799
  raise ValueError(f"未知的 stage 类型: {stage_type}")
635
-
800
+
636
801
  stages.append(Stage(type=stage_type, config=stage_config))
637
802
 
638
803
  # 创建 Pipeline 配置
639
804
  pipeline_config = None
640
- if 'pipeline_config' in config and config['pipeline_config']:
641
- pipeline_cfg = config['pipeline_config']
642
- include_intermediate_results = pipeline_cfg.get('include_intermediate_results', False)
805
+ if "pipeline_config" in config and config["pipeline_config"]:
806
+ pipeline_cfg = config["pipeline_config"]
807
+ include_intermediate_results = pipeline_cfg.get(
808
+ "include_intermediate_results", False
809
+ )
643
810
  intermediate_results_destination = None
644
-
811
+
645
812
  if include_intermediate_results:
646
- if 'intermediate_results_destination' in pipeline_cfg:
647
- dest_cfg = pipeline_cfg['intermediate_results_destination']
648
- dest_type = dest_cfg.get('type')
649
-
650
- if dest_type == 'local':
813
+ if "intermediate_results_destination" in pipeline_cfg:
814
+ dest_cfg = pipeline_cfg["intermediate_results_destination"]
815
+ dest_type = dest_cfg.get("type")
816
+
817
+ if dest_type == "local":
651
818
  intermediate_results_destination = LocalDestination(
652
- output_dir=dest_cfg['output_dir']
819
+ output_dir=dest_cfg["output_dir"]
653
820
  )
654
- elif dest_type == 's3':
821
+ elif dest_type == "s3":
655
822
  intermediate_results_destination = S3Destination(
656
- endpoint=dest_cfg['endpoint'],
657
- access_key=dest_cfg['access_key'],
658
- secret_key=dest_cfg['secret_key'],
659
- bucket=dest_cfg['bucket'],
660
- prefix=dest_cfg.get('prefix', ''),
661
- region=dest_cfg.get('region', 'us-east-1')
823
+ endpoint=dest_cfg["endpoint"],
824
+ access_key=dest_cfg["access_key"],
825
+ secret_key=dest_cfg["secret_key"],
826
+ bucket=dest_cfg["bucket"],
827
+ prefix=dest_cfg.get("prefix", ""),
828
+ region=dest_cfg.get("region", "us-east-1"),
662
829
  )
663
830
  else:
664
- raise ValueError(f"不支持的 intermediate_results_destination 类型: '{dest_type}',支持的类型: 'local', 's3'")
831
+ raise ValueError(
832
+ f"不支持的 intermediate_results_destination 类型: '{dest_type}',支持的类型: 'local', 's3'"
833
+ )
665
834
  else:
666
- raise ValueError("当 include_intermediate_results 为 True 时,必须设置 intermediate_results_destination")
667
-
835
+ raise ValueError(
836
+ "当 include_intermediate_results 为 True 时,必须设置 intermediate_results_destination"
837
+ )
838
+
668
839
  pipeline_config = PipelineConfig(
669
840
  include_intermediate_results=include_intermediate_results,
670
- intermediate_results_destination=intermediate_results_destination
841
+ intermediate_results_destination=intermediate_results_destination,
671
842
  )
672
843
 
673
844
  # 创建 Pipeline
674
845
  pipeline = Pipeline(
675
846
  source=source,
676
847
  destination=destination,
677
- api_base_url=config.get('api_base_url', 'http://localhost:8000/api/xparse'),
678
- api_headers=config.get('api_headers', {}),
848
+ api_base_url=config.get("api_base_url", "http://localhost:8000/api/xparse"),
849
+ api_headers=config.get("api_headers", {}),
679
850
  stages=stages,
680
- pipeline_config=pipeline_config
851
+ pipeline_config=pipeline_config,
681
852
  )
682
853
 
683
854
  return pipeline
684
855
 
685
856
 
686
857
  __all__ = [
687
- 'Pipeline',
688
- 'create_pipeline_from_config',
858
+ "Pipeline",
859
+ "create_pipeline_from_config",
689
860
  ]
690
-