xparse-client 0.2.20__py3-none-any.whl → 0.3.0b2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. example/1_basic_api_usage.py +198 -0
  2. example/2_async_job.py +210 -0
  3. example/3_local_workflow.py +300 -0
  4. example/4_advanced_workflow.py +327 -0
  5. example/README.md +128 -0
  6. example/config_example.json +95 -0
  7. tests/conftest.py +310 -0
  8. tests/unit/__init__.py +1 -0
  9. tests/unit/api/__init__.py +1 -0
  10. tests/unit/api/test_extract.py +232 -0
  11. tests/unit/api/test_local.py +231 -0
  12. tests/unit/api/test_parse.py +374 -0
  13. tests/unit/api/test_pipeline.py +369 -0
  14. tests/unit/api/test_workflows.py +108 -0
  15. tests/unit/connectors/test_ftp.py +525 -0
  16. tests/unit/connectors/test_local_connectors.py +324 -0
  17. tests/unit/connectors/test_milvus.py +368 -0
  18. tests/unit/connectors/test_qdrant.py +399 -0
  19. tests/unit/connectors/test_s3.py +598 -0
  20. tests/unit/connectors/test_smb.py +442 -0
  21. tests/unit/connectors/test_utils.py +335 -0
  22. tests/unit/models/test_local.py +54 -0
  23. tests/unit/models/test_pipeline_stages.py +144 -0
  24. tests/unit/models/test_workflows.py +55 -0
  25. tests/unit/test_base.py +437 -0
  26. tests/unit/test_client.py +110 -0
  27. tests/unit/test_config.py +160 -0
  28. tests/unit/test_exceptions.py +182 -0
  29. tests/unit/test_http.py +562 -0
  30. xparse_client/__init__.py +110 -20
  31. xparse_client/_base.py +179 -0
  32. xparse_client/_client.py +218 -0
  33. xparse_client/_config.py +221 -0
  34. xparse_client/_http.py +350 -0
  35. xparse_client/api/__init__.py +14 -0
  36. xparse_client/api/extract.py +109 -0
  37. xparse_client/api/local.py +188 -0
  38. xparse_client/api/parse.py +209 -0
  39. xparse_client/api/pipeline.py +132 -0
  40. xparse_client/api/workflows.py +204 -0
  41. xparse_client/connectors/__init__.py +45 -0
  42. xparse_client/connectors/_utils.py +138 -0
  43. xparse_client/connectors/destinations/__init__.py +45 -0
  44. xparse_client/connectors/destinations/base.py +116 -0
  45. xparse_client/connectors/destinations/local.py +91 -0
  46. xparse_client/connectors/destinations/milvus.py +229 -0
  47. xparse_client/connectors/destinations/qdrant.py +238 -0
  48. xparse_client/connectors/destinations/s3.py +163 -0
  49. xparse_client/connectors/sources/__init__.py +45 -0
  50. xparse_client/connectors/sources/base.py +74 -0
  51. xparse_client/connectors/sources/ftp.py +278 -0
  52. xparse_client/connectors/sources/local.py +176 -0
  53. xparse_client/connectors/sources/s3.py +232 -0
  54. xparse_client/connectors/sources/smb.py +259 -0
  55. xparse_client/exceptions.py +398 -0
  56. xparse_client/models/__init__.py +60 -0
  57. xparse_client/models/chunk.py +39 -0
  58. xparse_client/models/embed.py +62 -0
  59. xparse_client/models/extract.py +41 -0
  60. xparse_client/models/local.py +38 -0
  61. xparse_client/models/parse.py +136 -0
  62. xparse_client/models/pipeline.py +132 -0
  63. xparse_client/models/workflows.py +74 -0
  64. xparse_client-0.3.0b2.dist-info/METADATA +1075 -0
  65. xparse_client-0.3.0b2.dist-info/RECORD +68 -0
  66. {xparse_client-0.2.20.dist-info → xparse_client-0.3.0b2.dist-info}/WHEEL +1 -1
  67. {xparse_client-0.2.20.dist-info → xparse_client-0.3.0b2.dist-info}/licenses/LICENSE +1 -1
  68. {xparse_client-0.2.20.dist-info → xparse_client-0.3.0b2.dist-info}/top_level.txt +2 -0
  69. xparse_client/pipeline/__init__.py +0 -3
  70. xparse_client/pipeline/config.py +0 -163
  71. xparse_client/pipeline/destinations.py +0 -489
  72. xparse_client/pipeline/pipeline.py +0 -860
  73. xparse_client/pipeline/sources.py +0 -583
  74. xparse_client-0.2.20.dist-info/METADATA +0 -1050
  75. xparse_client-0.2.20.dist-info/RECORD +0 -11
xparse_client/pipeline/pipeline.py
@@ -1,860 +0,0 @@
- #!/usr/bin/env python
- # -*- encoding: utf-8 -*-
-
- import json
- import logging
- import re
- import time
- from datetime import datetime, timezone
- from pathlib import Path
- from typing import Any, Dict, List, Optional, Tuple, Union
-
- import requests
-
- from .config import (
-     ChunkConfig,
-     EmbedConfig,
-     ExtractConfig,
-     ParseConfig,
-     PipelineConfig,
-     PipelineStats,
-     Stage,
- )
- from .destinations import (
-     Destination,
-     LocalDestination,
-     MilvusDestination,
-     QdrantDestination,
-     S3Destination,
- )
- from .sources import FtpSource, LocalSource, S3Source, SmbSource, Source
-
- logger = logging.getLogger(__name__)
-
-
- class Pipeline:
-     """Data-processing pipeline."""
-
-     def __init__(
-         self,
-         source: Source,
-         destination: Destination,
-         api_base_url: str = "http://localhost:8000/api/xparse",
-         api_headers: Optional[Dict[str, str]] = None,
-         stages: Optional[List[Stage]] = None,
-         pipeline_config: Optional[PipelineConfig] = None,
-         intermediate_results_destination: Optional[Destination] = None,
-     ):
-         self.source = source
-         self.destination = destination
-         self.api_base_url = api_base_url.rstrip("/")
-         self.api_headers = api_headers or {}
-         self.pipeline_config = pipeline_config or PipelineConfig()
-
-         # Handle the intermediate_results_destination argument.
-         # If it was passed directly, prefer it and enable saving of intermediate results automatically.
-         if intermediate_results_destination is not None:
-             self.pipeline_config.include_intermediate_results = True
-             self.pipeline_config.intermediate_results_destination = (
-                 intermediate_results_destination
-             )
-         # If it is already enabled in pipeline_config, use the value from pipeline_config.
-         elif self.pipeline_config.include_intermediate_results:
-             if not self.pipeline_config.intermediate_results_destination:
-                 raise ValueError(
-                     "intermediate_results_destination must be set when include_intermediate_results is True"
-                 )
-
-         # Handle the stages configuration.
-         if stages is None:
-             raise ValueError("The stages argument is required")
-
-         self.stages = stages
-
-         # Validate stages.
-         if not self.stages or self.stages[0].type != "parse":
-             raise ValueError("stages must be non-empty and its first stage must be of type 'parse'")
-
-         # Check whether an extract stage is present.
-         self.has_extract = any(stage.type == "extract" for stage in self.stages)
-
-         # Enforce the constraints on the extract stage.
-         if self.has_extract:
-             # extract may only follow parse and must be the last stage.
-             if len(self.stages) != 2 or self.stages[1].type != "extract":
-                 raise ValueError(
-                     "The extract stage may only follow parse and must be the last stage (i.e. the only valid combination is [parse, extract])"
-                 )
-
-             # The destination must be a file store.
-             if not isinstance(destination, (LocalDestination, S3Destination)):
-                 raise ValueError(
-                     "When an extract stage is used, destination must be a file store (LocalDestination or S3Destination), not a vector database"
-                 )
-
-         # Validate the embed config, if present.
-         for stage in self.stages:
-             if stage.type == "embed" and isinstance(stage.config, EmbedConfig):
-                 stage.config.validate()
-
-         # Validate intermediate_results_destination.
-         if self.pipeline_config.include_intermediate_results:
-             # Verify that it is a supported Destination type.
-             from .destinations import Destination
-
-             if not isinstance(
-                 self.pipeline_config.intermediate_results_destination, Destination
-             ):
-                 raise ValueError(
-                     "intermediate_results_destination must be a Destination"
-                 )
-             self.intermediate_results_destination = (
-                 self.pipeline_config.intermediate_results_destination
-             )
-
-         print("=" * 60)
-         print("Pipeline initialized")
-         print(f" Stages: {[s.type for s in self.stages]}")
-         for stage in self.stages:
-             print(f" - {stage.type}: {stage.config}")
-         if self.pipeline_config.include_intermediate_results:
-             print(" Pipeline Config: saving of intermediate results enabled")
-         print("=" * 60)
-
-     def get_config(self) -> Dict[str, Any]:
-         """Return the pipeline's full configuration as a dict (the same shape accepted by create_pipeline_from_config)."""
-         config = {}
-
-         # Source configuration.
-         source_type = type(self.source).__name__.replace("Source", "").lower()
-         config["source"] = {"type": source_type}
-
-         if isinstance(self.source, S3Source):
-             config["source"].update(
-                 {
-                     "endpoint": self.source.endpoint,
-                     "bucket": self.source.bucket,
-                     "prefix": self.source.prefix,
-                     "pattern": self.source.pattern,
-                     "recursive": self.source.recursive,
-                 }
-             )
-             # access_key and secret_key are not stored on the object and cannot be recovered.
-             # region is not stored on the object either; fall back to the default.
-             config["source"]["region"] = "us-east-1"  # default
-         elif isinstance(self.source, LocalSource):
-             config["source"].update(
-                 {
-                     "directory": str(self.source.directory),
-                     "pattern": self.source.pattern,
-                     "recursive": self.source.recursive,
-                 }
-             )
-         elif isinstance(self.source, FtpSource):
-             config["source"].update(
-                 {
-                     "host": self.source.host,
-                     "port": self.source.port,
-                     "username": self.source.username,
-                     "pattern": self.source.pattern,
-                     "recursive": self.source.recursive,
-                 }
-             )
-             # password is not stored on the object and cannot be recovered.
-         elif isinstance(self.source, SmbSource):
-             config["source"].update(
-                 {
-                     "host": self.source.host,
-                     "share_name": self.source.share_name,
-                     "username": self.source.username,
-                     "domain": self.source.domain,
-                     "port": self.source.port,
-                     "path": self.source.path,
-                     "pattern": self.source.pattern,
-                     "recursive": self.source.recursive,
-                 }
-             )
-             # password is not stored on the object and cannot be recovered.
-
-         # Destination configuration.
-         dest_type = type(self.destination).__name__.replace("Destination", "").lower()
-         # MilvusDestination serves both the 'milvus' and 'zilliz' types.
-         if dest_type == "milvus":
-             # Distinguish local Milvus from Zilliz by inspecting db_path.
-             if self.destination.db_path.startswith("http"):
-                 dest_type = "zilliz"
-             else:
-                 dest_type = "milvus"
-
-         config["destination"] = {"type": dest_type}
-
-         if isinstance(self.destination, MilvusDestination):
-             config["destination"].update(
-                 {
-                     "db_path": self.destination.db_path,
-                     "collection_name": self.destination.collection_name,
-                     "dimension": self.destination.dimension,
-                 }
-             )
-             # api_key and token are not stored on the object and cannot be recovered.
-         elif isinstance(self.destination, QdrantDestination):
-             config["destination"].update(
-                 {
-                     "url": self.destination.url,
-                     "collection_name": self.destination.collection_name,
-                     "dimension": self.destination.dimension,
-                     "prefer_grpc": getattr(self.destination, "prefer_grpc", False),
-                 }
-             )
-             # api_key is not stored on the object and cannot be recovered.
-         elif isinstance(self.destination, LocalDestination):
-             config["destination"].update(
-                 {"output_dir": str(self.destination.output_dir)}
-             )
-         elif isinstance(self.destination, S3Destination):
-             config["destination"].update(
-                 {
-                     "endpoint": self.destination.endpoint,
-                     "bucket": self.destination.bucket,
-                     "prefix": self.destination.prefix,
-                 }
-             )
-             # access_key, secret_key, and region are not stored on the object and cannot be recovered.
-             config["destination"]["region"] = "us-east-1"  # default
-
-         # API configuration.
-         config["api_base_url"] = self.api_base_url
-         config["api_headers"] = {}
-         for key, value in self.api_headers.items():
-             config["api_headers"][key] = value
-
-         # Stages configuration.
-         config["stages"] = []
-         for stage in self.stages:
-             stage_dict = {"type": stage.type, "config": {}}
-
-             if isinstance(stage.config, ParseConfig):
-                 stage_dict["config"] = stage.config.to_dict()
-             elif isinstance(stage.config, ChunkConfig):
-                 stage_dict["config"] = stage.config.to_dict()
-             elif isinstance(stage.config, EmbedConfig):
-                 stage_dict["config"] = stage.config.to_dict()
-             elif isinstance(stage.config, ExtractConfig):
-                 stage_dict["config"] = stage.config.to_dict()
-             else:
-                 # If config is a dict or some other type, convert it as best we can.
-                 if isinstance(stage.config, dict):
-                     stage_dict["config"] = stage.config
-                 else:
-                     stage_dict["config"] = str(stage.config)
-
-             config["stages"].append(stage_dict)
-
-         # Pipeline config.
-         if self.pipeline_config.include_intermediate_results:
-             config["pipeline_config"] = {
-                 "include_intermediate_results": True,
-                 "intermediate_results_destination": {},
-             }
-
-             inter_dest = self.pipeline_config.intermediate_results_destination
-             if inter_dest:
-                 inter_dest_type = (
-                     type(inter_dest).__name__.replace("Destination", "").lower()
-                 )
-                 config["pipeline_config"]["intermediate_results_destination"][
-                     "type"
-                 ] = inter_dest_type
-
-                 if isinstance(inter_dest, LocalDestination):
-                     config["pipeline_config"]["intermediate_results_destination"][
-                         "output_dir"
-                     ] = str(inter_dest.output_dir)
-                 elif isinstance(inter_dest, S3Destination):
-                     config["pipeline_config"][
-                         "intermediate_results_destination"
-                     ].update(
-                         {
-                             "endpoint": inter_dest.endpoint,
-                             "bucket": inter_dest.bucket,
-                             "prefix": inter_dest.prefix,
-                         }
-                     )
-                     # access_key, secret_key, and region are not stored on the object and cannot be recovered.
-                     config["pipeline_config"]["intermediate_results_destination"][
-                         "region"
-                     ] = "us-east-1"  # default
-
-         return config
-
-     def _extract_error_message(self, response: requests.Response) -> Tuple[str, str]:
-         """
-         Extract a normalized error message from a response.
-
-         Returns:
-             Tuple[str, str]: (error_msg, x_request_id)
-         """
-         # First try to read x-request-id from the response headers (requests headers are case-insensitive).
-         x_request_id = response.headers.get("x-request-id", "")
-         error_msg = ""
-
-         # Read the Content-Type.
-         content_type = response.headers.get("Content-Type", "").lower()
-
-         # Try to parse a JSON response.
-         if "application/json" in content_type:
-             try:
-                 result = response.json()
-                 # If the header did not carry x-request-id, try the response body.
-                 if not x_request_id:
-                     x_request_id = result.get("x_request_id", "")
-                 error_msg = result.get(
-                     "message", result.get("msg", f"HTTP {response.status_code}")
-                 )
-                 return error_msg, x_request_id
-             except Exception:
-                 pass
-
-         # Handle HTML responses.
-         if "text/html" in content_type or response.text.strip().startswith("<"):
-             try:
-                 # Extract the HTML title (it usually contains the status code and status text).
-                 title_match = re.search(
-                     r"<title>(.*?)</title>", response.text, re.IGNORECASE
-                 )
-                 if title_match:
-                     error_msg = title_match.group(1).strip()
-                 else:
-                     # If there is no title, try the h1 tag.
-                     h1_match = re.search(
-                         r"<h1>(.*?)</h1>", response.text, re.IGNORECASE
-                     )
-                     if h1_match:
-                         error_msg = h1_match.group(1).strip()
-                     else:
-                         error_msg = f"HTTP {response.status_code}"
-             except Exception:
-                 error_msg = f"HTTP {response.status_code}"
-
-         # Handle plain-text responses.
-         elif "text/plain" in content_type:
-             error_msg = (
-                 response.text[:200].strip()
-                 if response.text
-                 else f"HTTP {response.status_code}"
-             )
-
-         # Everything else.
-         else:
-             if response.text:
-                 # Take the first 200 characters, stripping newlines and extra whitespace.
-                 text = response.text[:200].strip()
-                 # If the text spans multiple lines, keep only the first line.
-                 if "\n" in text:
-                     text = text.split("\n")[0].strip()
-                 error_msg = text if text else f"HTTP {response.status_code}"
-             else:
-                 error_msg = f"HTTP {response.status_code}"
-
-         return error_msg, x_request_id
-
-     def _call_pipeline_api(
-         self, file_bytes: bytes, filename: str, data_source: Dict[str, Any]
-     ) -> Optional[Dict[str, Any]]:
-         url = f"{self.api_base_url}/pipeline"
-         max_retries = 3
-
-         for try_count in range(max_retries):
-             try:
-                 files = {"file": (filename or "file", file_bytes)}
-                 form_data = {}
-
-                 # Convert stages to the API format.
-                 stages_data = [stage.to_dict() for stage in self.stages]
-                 try:
-                     form_data["stages"] = json.dumps(stages_data, ensure_ascii=False)
-                     form_data["data_source"] = json.dumps(
-                         data_source, ensure_ascii=False
-                     )
-
-                     # If intermediate-result saving is enabled, add the parameter to the request.
-                     if self.pipeline_config:
-                         form_data["config"] = json.dumps(
-                             self.pipeline_config.to_dict(), ensure_ascii=False
-                         )
-                 except Exception as e:
-                     print(f" ✗ Failed to prepare request parameters; check the configuration: {e}")
-                     logger.error(f"Failed to prepare request parameters; check the configuration: {e}")
-                     return None
-
-                 response = requests.post(
-                     url,
-                     files=files,
-                     data=form_data,
-                     headers=self.api_headers,
-                     timeout=630,
-                 )
-
-                 if response.status_code == 200:
-                     result = response.json()
-                     x_request_id = result.get("x_request_id", "")
-                     print(f" ✓ Pipeline API returned x_request_id: {x_request_id}")
-                     if result.get("code") == 200 and "data" in result:
-                         return result.get("data")
-                     # If code is not 200, report the error.
-                     error_msg = result.get("message", result.get("msg", "unknown error"))
-                     print(
-                         f" ✗ Pipeline API returned an error: code={result.get('code')}, message={error_msg}, x_request_id={x_request_id}"
-                     )
-                     logger.error(
-                         f"Pipeline API returned an error: code={result.get('code')}, message={error_msg}, x_request_id={x_request_id}"
-                     )
-                     return None
-                 else:
-                     # Use the normalization helper to extract the error message.
-                     error_msg, x_request_id = self._extract_error_message(response)
-
-                     print(
-                         f" ✗ API error {response.status_code}: {error_msg}, x_request_id={x_request_id}, retry {try_count + 1}/{max_retries}"
-                     )
-                     logger.warning(
-                         f"API error {response.status_code}: {error_msg}, x_request_id={x_request_id}, retry {try_count + 1}/{max_retries}"
-                     )
-
-             except Exception as e:
-                 # For requests exceptions, try to pull x_request_id from the response.
-                 x_request_id = ""
-                 error_msg = str(e)
-                 try:
-                     if hasattr(e, "response") and e.response is not None:
-                         try:
-                             result = e.response.json()
-                             x_request_id = result.get("x_request_id", "")
-                             error_msg = result.get(
-                                 "message", result.get("msg", error_msg)
-                             )
-                         except Exception:
-                             pass
-                 except Exception:
-                     pass
-
-                 print(
-                     f" ✗ Request exception: {error_msg}, x_request_id={x_request_id}, retry {try_count + 1}/{max_retries}"
-                 )
-                 logger.error(
-                     f"Pipeline API request exception: {error_msg}, x_request_id={x_request_id}"
-                 )
-
-             if try_count < max_retries - 1:
-                 time.sleep(2)
-
-         return None
-
-     def process_with_pipeline(
-         self, file_bytes: bytes, filename: str, data_source: Dict[str, Any]
-     ) -> Optional[Tuple[Any, PipelineStats]]:
-         print(f" → Calling pipeline API: {filename}")
-         result = self._call_pipeline_api(file_bytes, filename, data_source)
-
-         if not result or "stats" not in result:
-             print(" ✗ Pipeline failed")
-             logger.error(f"Pipeline failed: {filename}")
-             return None
-
-         # Handle extract-style responses.
-         if self.has_extract:
-             # extract returns extract_result instead of elements.
-             if "extract_result" not in result:
-                 print(" ✗ Pipeline failed: extract response is missing extract_result")
-                 logger.error("Pipeline failed: extract response is missing extract_result")
-                 return None
-
-             data = result["extract_result"]  # structured data
-             stats_data = result["stats"]
-
-             stats = PipelineStats(
-                 original_elements=stats_data.get("original_elements", 0),
-                 chunked_elements=0,  # extract does no chunking
-                 embedded_elements=0,  # extract does no embedding
-                 stages=self.stages,
-                 record_id=stats_data.get("record_id"),
-             )
-
-             # If intermediate-result saving is enabled, handle the intermediate results.
-             if (
-                 self.pipeline_config.include_intermediate_results
-                 and "intermediate_results" in result
-             ):
-                 self._save_intermediate_results(
-                     result["intermediate_results"], filename, data_source
-                 )
-
-             print(" ✓ Extract finished:")
-             print(f" - original elements: {stats.original_elements}")
-             print(f" - extract result type: {type(data).__name__}")
-             logger.info(f"Extract finished: {filename}")
-
-             return data, stats
-
-         else:
-             # The original parse/chunk/embed logic.
-             if "elements" not in result:
-                 print(" ✗ Pipeline failed: response is missing elements")
-                 logger.error("Pipeline failed: response is missing elements")
-                 return None
-
-             elements = result["elements"]
-             stats_data = result["stats"]
-
-             stats = PipelineStats(
-                 original_elements=stats_data.get("original_elements", 0),
-                 chunked_elements=stats_data.get("chunked_elements", 0),
-                 embedded_elements=stats_data.get("embedded_elements", 0),
-                 stages=self.stages,  # the stages that were actually executed
-                 record_id=stats_data.get("record_id"),  # record_id from the API response
-             )
-
-             # If intermediate-result saving is enabled, handle the intermediate results.
-             if (
-                 self.pipeline_config.include_intermediate_results
-                 and "intermediate_results" in result
-             ):
-                 self._save_intermediate_results(
-                     result["intermediate_results"], filename, data_source
-                 )
-
-             print(" ✓ Pipeline finished:")
-             print(f" - original elements: {stats.original_elements}")
-             print(f" - after chunking: {stats.chunked_elements}")
-             print(f" - embedded: {stats.embedded_elements}")
-             logger.info(f"Pipeline finished: {filename}, {stats.embedded_elements} vectors")
-
-             return elements, stats
-
-     def _save_intermediate_results(
-         self,
-         intermediate_results: List[Dict[str, Any]],
-         filename: str,
-         data_source: Dict[str, Any],
-     ) -> None:
-         """Save intermediate results.
-
-         Args:
-             intermediate_results: list of intermediate results; each item carries stage and elements fields
-             filename: the file name
-             data_source: data-source information
-         """
-         try:
-             # intermediate_results is a list whose items look like {stage: str, elements: List}.
-             for result_item in intermediate_results:
-                 if "stage" not in result_item or "elements" not in result_item:
-                     logger.warning(
-                         f"Intermediate result item is missing its stage or elements field: {result_item}"
-                     )
-                     continue
-
-                 stage = result_item["stage"]
-                 elements = result_item["elements"]
-
-                 metadata = {
-                     "filename": filename,
-                     "stage": stage,
-                     "total_elements": len(elements),
-                     "processed_at": datetime.now().isoformat(),
-                     "data_source": data_source,
-                 }
-
-                 self.pipeline_config.intermediate_results_destination.write(
-                     elements, metadata
-                 )
-                 print(f" ✓ Saved {stage.upper()} intermediate results: {len(elements)} elements")
-                 logger.info(f"Saved {stage.upper()} intermediate results: {filename}")
-
-         except Exception as e:
-             print(f" ✗ Failed to save intermediate results: {str(e)}")
-             logger.error(f"Failed to save intermediate results: {filename}, {str(e)}")
-
-     def process_file(self, file_path: str) -> bool:
-         print(f"\n{'=' * 60}")
-         print(f"Processing file: {file_path}")
-         logger.info(f"Start processing file: {file_path}")
-
-         try:
-             print(" → Reading file...")
-             file_bytes, data_source = self.source.read_file(file_path)
-             data_source = data_source or {}
-
-             # Check the file size; fail if it exceeds 100MB.
-             MAX_FILE_SIZE = 100 * 1024 * 1024  # 100MB
-             file_size = len(file_bytes)
-             if file_size > MAX_FILE_SIZE:
-                 file_size_mb = file_size / (1024 * 1024)
-                 raise ValueError(f"File too large: {file_size_mb:.2f}MB exceeds the 100MB limit")
-
-             # Convert to a millisecond-timestamp string.
-             timestamp_ms = int(datetime.now(timezone.utc).timestamp() * 1000)
-             data_source["date_processed"] = str(timestamp_ms)
-             print(f" ✓ File read: {len(file_bytes)} bytes")
-
-             result = self.process_with_pipeline(file_bytes, file_path, data_source)
-             if not result:
-                 return False
-
-             embedded_data, stats = result
-
-             print(" → Writing to destination...")
-             metadata = {
-                 "filename": file_path,
-                 "processed_at": str(timestamp_ms),
-             }
-
-             # If stats carries a record_id, add it to the metadata.
-             if stats.record_id:
-                 metadata["record_id"] = stats.record_id
-
-             success = self.destination.write(embedded_data, metadata)
-
-             if success:
-                 print(f"\n✓✓✓ File processed successfully: {file_path}")
-                 logger.info(f"File processed successfully: {file_path}")
-             else:
-                 print(f"\n✗✗✗ File processing failed: {file_path}")
-                 logger.error(f"File processing failed: {file_path}")
-
-             return success
-
-         except Exception as e:
-             print(f"\n✗✗✗ Processing exception: {str(e)}")
-             logger.error(f"Exception while processing file {file_path}: {str(e)}")
-             return False
-
-     def run(self):
-         start_time = time.time()
-
-         print("\n" + "=" * 60)
-         print("Starting pipeline run")
-         print("=" * 60)
-         logger.info("=" * 60)
-         logger.info("Starting pipeline run")
-
-         print("\n→ Listing files...")
-         files = self.source.list_files()
-
-         if not files:
-             print("\n✗ No files found")
-             logger.info("No files found")
-             return
-
-         total = len(files)
-         success_count = 0
-         fail_count = 0
-
-         for idx, file_path in enumerate(files, 1):
-             print(f"\nProgress: [{idx}/{total}]")
-             logger.info(f"Progress: [{idx}/{total}] - {file_path}")
-
-             try:
-                 if self.process_file(file_path):
-                     success_count += 1
-                 else:
-                     fail_count += 1
-             except Exception as e:
-                 print(f"\n✗✗✗ File processing exception: {str(e)}")
-                 logger.error(f"File processing exception {file_path}: {str(e)}")
-                 fail_count += 1
-
-             if idx < total:
-                 time.sleep(1)
-
-         elapsed = time.time() - start_time
-         print("\n" + "=" * 60)
-         print("Pipeline run finished!")
-         print("=" * 60)
-         print(f"Total files: {total}")
-         print(f"Succeeded: {success_count}")
-         print(f"Failed: {fail_count}")
-         print(f"Elapsed: {elapsed:.2f} s")
-         print("=" * 60)
-
-         logger.info("=" * 60)
-         logger.info(
-             f"Pipeline finished - total: {total}, succeeded: {success_count}, failed: {fail_count}, elapsed: {elapsed:.2f}s"
-         )
-         logger.info("=" * 60)
-
-
- def create_pipeline_from_config(config: Dict[str, Any]) -> Pipeline:
-     source_config = config["source"]
-     if source_config["type"] == "s3":
-         source = S3Source(
-             endpoint=source_config["endpoint"],
-             access_key=source_config["access_key"],
-             secret_key=source_config["secret_key"],
-             bucket=source_config["bucket"],
-             prefix=source_config.get("prefix", ""),
-             region=source_config.get("region", "us-east-1"),
-             pattern=source_config.get("pattern", None),
-             recursive=source_config.get("recursive", False),
-         )
-     elif source_config["type"] == "local":
-         source = LocalSource(
-             directory=source_config["directory"],
-             pattern=source_config.get("pattern", None),
-             recursive=source_config.get("recursive", False),
-         )
-     elif source_config["type"] == "ftp":
-         source = FtpSource(
-             host=source_config["host"],
-             port=source_config["port"],
-             username=source_config["username"],
-             password=source_config["password"],
-             pattern=source_config.get("pattern", None),
-             recursive=source_config.get("recursive", False),
-         )
-     elif source_config["type"] == "smb":
-         source = SmbSource(
-             host=source_config["host"],
-             share_name=source_config["share_name"],
-             username=source_config["username"],
-             password=source_config["password"],
-             domain=source_config.get("domain", ""),
-             port=source_config.get("port", 445),
-             path=source_config.get("path", ""),
-             pattern=source_config.get("pattern", None),
-             recursive=source_config.get("recursive", False),
-         )
-     else:
-         raise ValueError(f"Unknown source type: {source_config['type']}")
-
-     dest_config = config["destination"]
-     if dest_config["type"] in ["milvus", "zilliz"]:
-         destination = MilvusDestination(
-             db_path=dest_config["db_path"],
-             collection_name=dest_config["collection_name"],
-             dimension=dest_config["dimension"],
-             api_key=dest_config.get("api_key"),
-             token=dest_config.get("token"),
-         )
-     elif dest_config["type"] == "qdrant":
-         destination = QdrantDestination(
-             url=dest_config["url"],
-             collection_name=dest_config["collection_name"],
-             dimension=dest_config["dimension"],
-             api_key=dest_config.get("api_key"),
-             prefer_grpc=dest_config.get("prefer_grpc", False),
-         )
-     elif dest_config["type"] == "local":
-         destination = LocalDestination(output_dir=dest_config["output_dir"])
-     elif dest_config["type"] == "s3":
-         destination = S3Destination(
-             endpoint=dest_config["endpoint"],
-             access_key=dest_config["access_key"],
-             secret_key=dest_config["secret_key"],
-             bucket=dest_config["bucket"],
-             prefix=dest_config.get("prefix", ""),
-             region=dest_config.get("region", "us-east-1"),
-         )
-     else:
-         raise ValueError(f"Unknown destination type: {dest_config['type']}")
-
-     # Handle the stages configuration.
-     if "stages" not in config or not config["stages"]:
-         raise ValueError("The configuration must contain a 'stages' field")
-
-     stages = []
-     for stage_cfg in config["stages"]:
-         stage_type = stage_cfg.get("type")
-         stage_config_dict = stage_cfg.get("config", {})
-
-         if stage_type == "parse":
-             parse_cfg_copy = dict(stage_config_dict)
-             provider = parse_cfg_copy.pop("provider", "textin")
-             stage_config = ParseConfig(provider=provider, **parse_cfg_copy)
-         elif stage_type == "chunk":
-             stage_config = ChunkConfig(
-                 strategy=stage_config_dict.get("strategy", "basic"),
-                 include_orig_elements=stage_config_dict.get(
-                     "include_orig_elements", False
-                 ),
-                 new_after_n_chars=stage_config_dict.get("new_after_n_chars", 512),
-                 max_characters=stage_config_dict.get("max_characters", 1024),
-                 overlap=stage_config_dict.get("overlap", 0),
-                 overlap_all=stage_config_dict.get("overlap_all", False),
-             )
-         elif stage_type == "embed":
-             stage_config = EmbedConfig(
-                 provider=stage_config_dict.get("provider", "qwen"),
-                 model_name=stage_config_dict.get("model_name", "text-embedding-v3"),
-             )
-         elif stage_type == "extract":
-             schema = stage_config_dict.get("schema")
-             if not schema:
-                 raise ValueError("The config of an extract stage must contain a 'schema' field")
-             stage_config = ExtractConfig(
-                 schema=schema,
-                 generate_citations=stage_config_dict.get("generate_citations", False),
-                 stamp=stage_config_dict.get("stamp", False),
-             )
-         else:
-             raise ValueError(f"Unknown stage type: {stage_type}")
-
-         stages.append(Stage(type=stage_type, config=stage_config))
-
-     # Create the pipeline config.
-     pipeline_config = None
-     if "pipeline_config" in config and config["pipeline_config"]:
-         pipeline_cfg = config["pipeline_config"]
-         include_intermediate_results = pipeline_cfg.get(
-             "include_intermediate_results", False
-         )
-         intermediate_results_destination = None
-
-         if include_intermediate_results:
-             if "intermediate_results_destination" in pipeline_cfg:
-                 dest_cfg = pipeline_cfg["intermediate_results_destination"]
-                 dest_type = dest_cfg.get("type")
-
-                 if dest_type == "local":
-                     intermediate_results_destination = LocalDestination(
-                         output_dir=dest_cfg["output_dir"]
-                     )
-                 elif dest_type == "s3":
-                     intermediate_results_destination = S3Destination(
-                         endpoint=dest_cfg["endpoint"],
-                         access_key=dest_cfg["access_key"],
-                         secret_key=dest_cfg["secret_key"],
-                         bucket=dest_cfg["bucket"],
-                         prefix=dest_cfg.get("prefix", ""),
-                         region=dest_cfg.get("region", "us-east-1"),
-                     )
-                 else:
-                     raise ValueError(
-                         f"Unsupported intermediate_results_destination type: '{dest_type}'; supported types: 'local', 's3'"
-                     )
-             else:
-                 raise ValueError(
-                     "intermediate_results_destination must be set when include_intermediate_results is True"
-                 )
-
-         pipeline_config = PipelineConfig(
-             include_intermediate_results=include_intermediate_results,
-             intermediate_results_destination=intermediate_results_destination,
-         )
-
-     # Create the pipeline.
-     pipeline = Pipeline(
-         source=source,
-         destination=destination,
-         api_base_url=config.get("api_base_url", "http://localhost:8000/api/xparse"),
-         api_headers=config.get("api_headers", {}),
-         stages=stages,
-         pipeline_config=pipeline_config,
-     )
-
-     return pipeline
-
-
- __all__ = [
-     "Pipeline",
-     "create_pipeline_from_config",
- ]
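
For context on what this release removes: the module's public surface was Pipeline and create_pipeline_from_config (see __all__ above). Below is a minimal sketch of how a 0.2.x pipeline was driven from a plain config dict, using only keys and defaults visible in the deleted code; the directory, Qdrant URL, collection name, and stage parameter values are illustrative placeholders, not values taken from this diff:

    from xparse_client.pipeline.pipeline import create_pipeline_from_config

    config = {
        # LocalSource branch of create_pipeline_from_config.
        "source": {"type": "local", "directory": "./docs", "pattern": "*.pdf", "recursive": True},
        # QdrantDestination branch; api_key defaults to None, prefer_grpc to False.
        "destination": {
            "type": "qdrant",
            "url": "http://localhost:6333",    # placeholder
            "collection_name": "xparse_demo",  # placeholder
            "dimension": 1024,                 # placeholder
        },
        "api_base_url": "http://localhost:8000/api/xparse",
        # Pipeline.__init__ requires the first stage to be "parse".
        "stages": [
            {"type": "parse", "config": {"provider": "textin"}},
            {"type": "chunk", "config": {"strategy": "basic", "max_characters": 1024}},
            {"type": "embed", "config": {"provider": "qwen", "model_name": "text-embedding-v3"}},
        ],
    }

    pipeline = create_pipeline_from_config(config)
    pipeline.run()  # lists source files, POSTs each to {api_base_url}/pipeline, writes results to the destination

Judging by the file list above, this role appears to pass in 0.3.0b2 to the new xparse_client/api/pipeline.py together with the xparse_client/connectors sources and destinations packages.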