xparse-client 0.2.19__py3-none-any.whl → 0.3.0b3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. example/1_basic_api_usage.py +198 -0
  2. example/2_async_job.py +210 -0
  3. example/3_local_workflow.py +300 -0
  4. example/4_advanced_workflow.py +327 -0
  5. example/README.md +128 -0
  6. example/config_example.json +95 -0
  7. tests/conftest.py +310 -0
  8. tests/unit/__init__.py +1 -0
  9. tests/unit/api/__init__.py +1 -0
  10. tests/unit/api/test_extract.py +232 -0
  11. tests/unit/api/test_local.py +231 -0
  12. tests/unit/api/test_parse.py +374 -0
  13. tests/unit/api/test_pipeline.py +369 -0
  14. tests/unit/api/test_workflows.py +108 -0
  15. tests/unit/connectors/test_ftp.py +525 -0
  16. tests/unit/connectors/test_local_connectors.py +324 -0
  17. tests/unit/connectors/test_milvus.py +368 -0
  18. tests/unit/connectors/test_qdrant.py +399 -0
  19. tests/unit/connectors/test_s3.py +598 -0
  20. tests/unit/connectors/test_smb.py +442 -0
  21. tests/unit/connectors/test_utils.py +335 -0
  22. tests/unit/models/test_local.py +54 -0
  23. tests/unit/models/test_pipeline_stages.py +144 -0
  24. tests/unit/models/test_workflows.py +55 -0
  25. tests/unit/test_base.py +437 -0
  26. tests/unit/test_client.py +110 -0
  27. tests/unit/test_config.py +160 -0
  28. tests/unit/test_exceptions.py +182 -0
  29. tests/unit/test_http.py +562 -0
  30. xparse_client/__init__.py +111 -20
  31. xparse_client/_base.py +179 -0
  32. xparse_client/_client.py +218 -0
  33. xparse_client/_config.py +221 -0
  34. xparse_client/_http.py +350 -0
  35. xparse_client/api/__init__.py +14 -0
  36. xparse_client/api/extract.py +109 -0
  37. xparse_client/api/local.py +215 -0
  38. xparse_client/api/parse.py +209 -0
  39. xparse_client/api/pipeline.py +134 -0
  40. xparse_client/api/workflows.py +204 -0
  41. xparse_client/connectors/__init__.py +45 -0
  42. xparse_client/connectors/_utils.py +138 -0
  43. xparse_client/connectors/destinations/__init__.py +45 -0
  44. xparse_client/connectors/destinations/base.py +116 -0
  45. xparse_client/connectors/destinations/local.py +91 -0
  46. xparse_client/connectors/destinations/milvus.py +229 -0
  47. xparse_client/connectors/destinations/qdrant.py +238 -0
  48. xparse_client/connectors/destinations/s3.py +163 -0
  49. xparse_client/connectors/sources/__init__.py +45 -0
  50. xparse_client/connectors/sources/base.py +74 -0
  51. xparse_client/connectors/sources/ftp.py +278 -0
  52. xparse_client/connectors/sources/local.py +176 -0
  53. xparse_client/connectors/sources/s3.py +232 -0
  54. xparse_client/connectors/sources/smb.py +259 -0
  55. xparse_client/exceptions.py +398 -0
  56. xparse_client/models/__init__.py +60 -0
  57. xparse_client/models/chunk.py +39 -0
  58. xparse_client/models/embed.py +62 -0
  59. xparse_client/models/extract.py +41 -0
  60. xparse_client/models/local.py +38 -0
  61. xparse_client/models/parse.py +136 -0
  62. xparse_client/models/pipeline.py +134 -0
  63. xparse_client/models/workflows.py +74 -0
  64. xparse_client-0.3.0b3.dist-info/METADATA +1075 -0
  65. xparse_client-0.3.0b3.dist-info/RECORD +68 -0
  66. {xparse_client-0.2.19.dist-info → xparse_client-0.3.0b3.dist-info}/WHEEL +1 -1
  67. {xparse_client-0.2.19.dist-info → xparse_client-0.3.0b3.dist-info}/licenses/LICENSE +1 -1
  68. {xparse_client-0.2.19.dist-info → xparse_client-0.3.0b3.dist-info}/top_level.txt +2 -0
  69. xparse_client/pipeline/__init__.py +0 -3
  70. xparse_client/pipeline/config.py +0 -129
  71. xparse_client/pipeline/destinations.py +0 -489
  72. xparse_client/pipeline/pipeline.py +0 -690
  73. xparse_client/pipeline/sources.py +0 -583
  74. xparse_client-0.2.19.dist-info/METADATA +0 -1050
  75. xparse_client-0.2.19.dist-info/RECORD +0 -11
@@ -1,690 +0,0 @@
- #!/usr/bin/env python
- # -*- encoding: utf-8 -*-
-
- import json
- import logging
- import re
- import time
- from datetime import datetime, timezone
- from pathlib import Path
- from typing import Dict, Any, Optional, Tuple, List, Union
-
- import requests
-
- from .config import ParseConfig, ChunkConfig, EmbedConfig, Stage, PipelineStats, PipelineConfig
- from .sources import Source, S3Source, LocalSource, FtpSource, SmbSource
- from .destinations import Destination, MilvusDestination, QdrantDestination, LocalDestination, S3Destination
-
-
- logger = logging.getLogger(__name__)
-
-
- class Pipeline:
-     """Data-processing pipeline."""
-
-     def __init__(
-         self,
-         source: Source,
-         destination: Destination,
-         api_base_url: str = 'http://localhost:8000/api/xparse',
-         api_headers: Optional[Dict[str, str]] = None,
-         stages: Optional[List[Stage]] = None,
-         pipeline_config: Optional[PipelineConfig] = None,
-         intermediate_results_destination: Optional[Destination] = None
-     ):
-         self.source = source
-         self.destination = destination
-         self.api_base_url = api_base_url.rstrip('/')
-         self.api_headers = api_headers or {}
-         self.pipeline_config = pipeline_config or PipelineConfig()
-
-         # Handle the intermediate_results_destination argument.
-         # If intermediate_results_destination is passed directly, prefer it and enable saving of intermediate results automatically.
-         if intermediate_results_destination is not None:
-             self.pipeline_config.include_intermediate_results = True
-             self.pipeline_config.intermediate_results_destination = intermediate_results_destination
-         # Otherwise, if it is already enabled in pipeline_config, use the value from pipeline_config.
-         elif self.pipeline_config.include_intermediate_results:
-             if not self.pipeline_config.intermediate_results_destination:
-                 raise ValueError("intermediate_results_destination must be set when include_intermediate_results is True")
-
-         # Handle the stages configuration.
-         if stages is None:
-             raise ValueError("the stages argument is required")
-
-         self.stages = stages
-
-         # Validate stages.
-         if not self.stages or self.stages[0].type != 'parse':
-             raise ValueError("stages must be non-empty and the first stage must be of type 'parse'")
-
-         # Validate the embed config (if present).
-         for stage in self.stages:
-             if stage.type == 'embed' and isinstance(stage.config, EmbedConfig):
-                 stage.config.validate()
-
-         # Validate intermediate_results_destination.
-         if self.pipeline_config.include_intermediate_results:
-             # Check that it is a supported Destination type.
-             from .destinations import Destination
-             if not isinstance(self.pipeline_config.intermediate_results_destination, Destination):
-                 raise ValueError(f"intermediate_results_destination must be a Destination")
-             self.intermediate_results_destination = self.pipeline_config.intermediate_results_destination
-
-         print("=" * 60)
-         print("Pipeline initialized")
-         print(f" Stages: {[s.type for s in self.stages]}")
-         for stage in self.stages:
-             print(f" - {stage.type}: {stage.config}")
-         if self.pipeline_config.include_intermediate_results:
-             print(f" Pipeline Config: saving of intermediate results enabled")
-         print("=" * 60)
-
-     def get_config(self) -> Dict[str, Any]:
-         """Return the pipeline's full configuration as a dict (the same format that create_pipeline_from_config accepts)."""
-         config = {}
-
-         # Source configuration.
-         source_type = type(self.source).__name__.replace('Source', '').lower()
-         config['source'] = {'type': source_type}
-
-         if isinstance(self.source, S3Source):
-             config['source'].update({
-                 'endpoint': self.source.endpoint,
-                 'bucket': self.source.bucket,
-                 'prefix': self.source.prefix,
-                 'pattern': self.source.pattern,
-                 'recursive': self.source.recursive
-             })
-             # access_key and secret_key are not stored on the object and cannot be recovered.
-             # region is not stored on the object either; use the default value.
-             config['source']['region'] = 'us-east-1'  # default value
-         elif isinstance(self.source, LocalSource):
-             config['source'].update({
-                 'directory': str(self.source.directory),
-                 'pattern': self.source.pattern,
-                 'recursive': self.source.recursive
-             })
-         elif isinstance(self.source, FtpSource):
-             config['source'].update({
-                 'host': self.source.host,
-                 'port': self.source.port,
-                 'username': self.source.username,
-                 'pattern': self.source.pattern,
-                 'recursive': self.source.recursive
-             })
-             # password is not stored on the object and cannot be recovered.
-         elif isinstance(self.source, SmbSource):
-             config['source'].update({
-                 'host': self.source.host,
-                 'share_name': self.source.share_name,
-                 'username': self.source.username,
-                 'domain': self.source.domain,
-                 'port': self.source.port,
-                 'path': self.source.path,
-                 'pattern': self.source.pattern,
-                 'recursive': self.source.recursive
-             })
-             # password is not stored on the object and cannot be recovered.
-
-         # Destination configuration.
-         dest_type = type(self.destination).__name__.replace('Destination', '').lower()
-         # MilvusDestination covers both the 'milvus' and 'zilliz' types.
-         if dest_type == 'milvus':
-             # Distinguish local Milvus from Zilliz by inspecting db_path.
-             if self.destination.db_path.startswith('http'):
-                 dest_type = 'zilliz'
-             else:
-                 dest_type = 'milvus'
-
-         config['destination'] = {'type': dest_type}
-
-         if isinstance(self.destination, MilvusDestination):
-             config['destination'].update({
-                 'db_path': self.destination.db_path,
-                 'collection_name': self.destination.collection_name,
-                 'dimension': self.destination.dimension
-             })
-             # api_key and token are not stored on the object and cannot be recovered.
-         elif isinstance(self.destination, QdrantDestination):
-             config['destination'].update({
-                 'url': self.destination.url,
-                 'collection_name': self.destination.collection_name,
-                 'dimension': self.destination.dimension,
-                 'prefer_grpc': getattr(self.destination, 'prefer_grpc', False)
-             })
-             # api_key is not stored on the object and cannot be recovered.
-         elif isinstance(self.destination, LocalDestination):
-             config['destination'].update({
-                 'output_dir': str(self.destination.output_dir)
-             })
-         elif isinstance(self.destination, S3Destination):
-             config['destination'].update({
-                 'endpoint': self.destination.endpoint,
-                 'bucket': self.destination.bucket,
-                 'prefix': self.destination.prefix
-             })
-             # access_key, secret_key, and region are not stored on the object and cannot be recovered.
-             config['destination']['region'] = 'us-east-1'  # default value
-
-         # API configuration.
-         config['api_base_url'] = self.api_base_url
-         config['api_headers'] = {}
-         for key, value in self.api_headers.items():
-             config['api_headers'][key] = value
-
-         # Stages configuration.
-         config['stages'] = []
-         for stage in self.stages:
-             stage_dict = {
-                 'type': stage.type,
-                 'config': {}
-             }
-
-             if isinstance(stage.config, ParseConfig):
-                 stage_dict['config'] = stage.config.to_dict()
-             elif isinstance(stage.config, ChunkConfig):
-                 stage_dict['config'] = stage.config.to_dict()
-             elif isinstance(stage.config, EmbedConfig):
-                 stage_dict['config'] = stage.config.to_dict()
-             else:
-                 # If config is a dict or some other type, try to convert it.
-                 if isinstance(stage.config, dict):
-                     stage_dict['config'] = stage.config
-                 else:
-                     stage_dict['config'] = str(stage.config)
-
-             config['stages'].append(stage_dict)
-
-         # Pipeline config.
-         if self.pipeline_config.include_intermediate_results:
-             config['pipeline_config'] = {
-                 'include_intermediate_results': True,
-                 'intermediate_results_destination': {}
-             }
-
-             inter_dest = self.pipeline_config.intermediate_results_destination
-             if inter_dest:
-                 inter_dest_type = type(inter_dest).__name__.replace('Destination', '').lower()
-                 config['pipeline_config']['intermediate_results_destination']['type'] = inter_dest_type
-
-                 if isinstance(inter_dest, LocalDestination):
-                     config['pipeline_config']['intermediate_results_destination']['output_dir'] = str(inter_dest.output_dir)
-                 elif isinstance(inter_dest, S3Destination):
-                     config['pipeline_config']['intermediate_results_destination'].update({
-                         'endpoint': inter_dest.endpoint,
-                         'bucket': inter_dest.bucket,
-                         'prefix': inter_dest.prefix
-                     })
-                     # access_key, secret_key, and region are not stored on the object and cannot be recovered.
-                     config['pipeline_config']['intermediate_results_destination']['region'] = 'us-east-1'  # default value
-
-         return config
-
-     def _extract_error_message(self, response: requests.Response) -> Tuple[str, str]:
-         """
-         Extract a normalized error message from a response.
-
-         Returns:
-             Tuple[str, str]: (error_msg, x_request_id)
-         """
-         # First try to read x-request-id from the response headers (requests headers are case-insensitive).
-         x_request_id = response.headers.get('x-request-id', '')
-         error_msg = ''
-
-         # Get the Content-Type.
-         content_type = response.headers.get('Content-Type', '').lower()
-
-         # Try to parse a JSON response.
-         if 'application/json' in content_type:
-             try:
-                 result = response.json()
-                 # If the response headers did not carry x-request-id, try the response body.
-                 if not x_request_id:
-                     x_request_id = result.get('x_request_id', '')
-                 error_msg = result.get('message', result.get('msg', f'HTTP {response.status_code}'))
-                 return error_msg, x_request_id
-             except:
-                 pass
-
-         # Handle HTML responses.
-         if 'text/html' in content_type or response.text.strip().startswith('<'):
-             try:
-                 # Extract the HTML title (it usually contains the status code and status text).
-                 title_match = re.search(r'<title>(.*?)</title>', response.text, re.IGNORECASE)
-                 if title_match:
-                     error_msg = title_match.group(1).strip()
-                 else:
-                     # No title; try the h1 tag instead.
-                     h1_match = re.search(r'<h1>(.*?)</h1>', response.text, re.IGNORECASE)
-                     if h1_match:
-                         error_msg = h1_match.group(1).strip()
-                     else:
-                         error_msg = f'HTTP {response.status_code}'
-             except:
-                 error_msg = f'HTTP {response.status_code}'
-
-         # Handle plain-text responses.
-         elif 'text/plain' in content_type:
-             error_msg = response.text[:200].strip() if response.text else f'HTTP {response.status_code}'
-
-         # Everything else.
-         else:
-             if response.text:
-                 # Take the first 200 characters, stripping line breaks and extra whitespace.
-                 text = response.text[:200].strip()
-                 # If there are multiple lines, keep only the first.
-                 if '\n' in text:
-                     text = text.split('\n')[0].strip()
-                 error_msg = text if text else f'HTTP {response.status_code}'
-             else:
-                 error_msg = f'HTTP {response.status_code}'
-
-         return error_msg, x_request_id
-
-     def _call_pipeline_api(self, file_bytes: bytes, filename: str, data_source: Dict[str, Any]) -> Optional[Dict[str, Any]]:
-         url = f"{self.api_base_url}/pipeline"
-         max_retries = 3
-
-         for try_count in range(max_retries):
-             try:
-                 files = {'file': (filename or 'file', file_bytes)}
-                 form_data = {}
-
-                 # Convert the stages to the API format.
-                 stages_data = [stage.to_dict() for stage in self.stages]
-                 try:
-                     form_data['stages'] = json.dumps(stages_data)
-                     form_data['data_source'] = json.dumps(data_source, ensure_ascii=False)
-
-                     # If saving of intermediate results is enabled, add the parameter to the request.
-                     if self.pipeline_config:
-                         form_data['config'] = json.dumps(self.pipeline_config.to_dict(), ensure_ascii=False)
-                 except Exception as e:
-                     print(f" ✗ Failed to serialize request parameters; check the configuration: {e}")
-                     logger.error(f"Failed to serialize request parameters; check the configuration: {e}")
-                     return None
-
-                 response = requests.post(
-                     url,
-                     files=files,
-                     data=form_data,
-                     headers=self.api_headers,
-                     timeout=630
-                 )
-
-                 if response.status_code == 200:
-                     result = response.json()
-                     x_request_id = result.get('x_request_id', '')
-                     print(f" ✓ Pipeline endpoint returned x_request_id: {x_request_id}")
-                     if result.get('code') == 200 and 'data' in result:
-                         return result.get('data')
-                     # If code is not 200, report the error.
-                     error_msg = result.get('message', result.get('msg', 'unknown error'))
-                     print(f" ✗ Pipeline endpoint returned an error: code={result.get('code')}, message={error_msg}, x_request_id={x_request_id}")
-                     logger.error(f"Pipeline endpoint returned an error: code={result.get('code')}, message={error_msg}, x_request_id={x_request_id}")
-                     return None
-                 else:
-                     # Use the normalization helper to extract the error message.
-                     error_msg, x_request_id = self._extract_error_message(response)
-
-                     print(f" ✗ API error {response.status_code}: {error_msg}, x_request_id={x_request_id}, retry {try_count + 1}/{max_retries}")
-                     logger.warning(f"API error {response.status_code}: {error_msg}, x_request_id={x_request_id}, retry {try_count + 1}/{max_retries}")
-
-             except Exception as e:
-                 # For requests exceptions, try to recover x_request_id from the response.
-                 x_request_id = ''
-                 error_msg = str(e)
-                 try:
-                     if hasattr(e, 'response') and e.response is not None:
-                         try:
-                             result = e.response.json()
-                             x_request_id = result.get('x_request_id', '')
-                             error_msg = result.get('message', result.get('msg', error_msg))
-                         except:
-                             pass
-                 except:
-                     pass
-
-                 print(f" ✗ Request exception: {error_msg}, x_request_id={x_request_id}, retry {try_count + 1}/{max_retries}")
-                 logger.error(f"Pipeline API request exception: {error_msg}, x_request_id={x_request_id}")
-
-             if try_count < max_retries - 1:
-                 time.sleep(2)
-
-         return None
-
-     def process_with_pipeline(self, file_bytes: bytes, filename: str, data_source: Dict[str, Any]) -> Optional[Tuple[List[Dict[str, Any]], PipelineStats]]:
-         print(f" → Calling the pipeline endpoint: {filename}")
-         result = self._call_pipeline_api(file_bytes, filename, data_source)
-
-         if result and 'elements' in result and 'stats' in result:
-             elements = result['elements']
-             stats_data = result['stats']
-
-             stats = PipelineStats(
-                 original_elements=stats_data.get('original_elements', 0),
-                 chunked_elements=stats_data.get('chunked_elements', 0),
-                 embedded_elements=stats_data.get('embedded_elements', 0),
-                 stages=self.stages,  # the stages that were actually executed
-                 record_id=stats_data.get('record_id')  # record_id taken from the API response
-             )
-
-             # If saving of intermediate results is enabled, handle the intermediate results.
-             if self.pipeline_config.include_intermediate_results and 'intermediate_results' in result:
-                 self._save_intermediate_results(result['intermediate_results'], filename, data_source)
-
-             print(f" ✓ Pipeline finished:")
-             print(f" - original elements: {stats.original_elements}")
-             print(f" - after chunking: {stats.chunked_elements}")
-             print(f" - embedded: {stats.embedded_elements}")
-             logger.info(f"Pipeline finished: {filename}, {stats.embedded_elements} vectors")
-
-             return elements, stats
-         else:
-             print(f" ✗ Pipeline failed")
-             logger.error(f"Pipeline failed: {filename}")
-             return None
-
-     def _save_intermediate_results(self, intermediate_results: List[Dict[str, Any]], filename: str, data_source: Dict[str, Any]) -> None:
-         """Save intermediate results.
-
-         Args:
-             intermediate_results: list of intermediate results; each item carries a stage and an elements field
-             filename: the file name
-             data_source: data-source information
-         """
-         try:
-             # intermediate_results is a list; each item is {stage: str, elements: List}.
-             for result_item in intermediate_results:
-                 if 'stage' not in result_item or 'elements' not in result_item:
-                     logger.warning(f"Intermediate result item is missing the stage or elements field: {result_item}")
-                     continue
-
-                 stage = result_item['stage']
-                 elements = result_item['elements']
-
-                 metadata = {
-                     'filename': filename,
-                     'stage': stage,
-                     'total_elements': len(elements),
-                     'processed_at': datetime.now().isoformat(),
-                     'data_source': data_source
-                 }
-
-                 self.pipeline_config.intermediate_results_destination.write(elements, metadata)
-                 print(f" ✓ Saved {stage.upper()} intermediate results: {len(elements)} elements")
-                 logger.info(f"Saved {stage.upper()} intermediate results for: {filename}")
-
-         except Exception as e:
-             print(f" ✗ Failed to save intermediate results: {str(e)}")
-             logger.error(f"Failed to save intermediate results: {filename}, {str(e)}")
-
-     def process_file(self, file_path: str) -> bool:
-         print(f"\n{'=' * 60}")
-         print(f"Processing file: {file_path}")
-         logger.info(f"Start processing file: {file_path}")
-
-         try:
-             print(f" → Reading file...")
-             file_bytes, data_source = self.source.read_file(file_path)
-             data_source = data_source or {}
-
-             # Check the file size; fail if it exceeds 100MB.
-             MAX_FILE_SIZE = 100 * 1024 * 1024  # 100MB
-             file_size = len(file_bytes)
-             if file_size > MAX_FILE_SIZE:
-                 file_size_mb = file_size / (1024 * 1024)
-                 raise ValueError(f"File too large: {file_size_mb:.2f}MB exceeds the 100MB limit")
-
-             # Convert to a millisecond timestamp string.
-             timestamp_ms = int(datetime.now(timezone.utc).timestamp() * 1000)
-             data_source['date_processed'] = str(timestamp_ms)
-             print(f" ✓ File read: {len(file_bytes)} bytes")
-
-             result = self.process_with_pipeline(file_bytes, file_path, data_source)
-             if not result:
-                 return False
-
-             embedded_data, stats = result
-
-             print(f" → Writing to destination...")
-             metadata = {
-                 'filename': file_path,
-                 'processed_at': str(timestamp_ms),
-             }
-
-             # If stats carries a record_id, add it to the metadata.
-             if stats.record_id:
-                 metadata['record_id'] = stats.record_id
-
-             success = self.destination.write(embedded_data, metadata)
-
-             if success:
-                 print(f"\n✓✓✓ File processed successfully: {file_path}")
-                 logger.info(f"File processed successfully: {file_path}")
-             else:
-                 print(f"\n✗✗✗ File processing failed: {file_path}")
-                 logger.error(f"File processing failed: {file_path}")
-
-             return success
-
-         except Exception as e:
-             print(f"\n✗✗✗ Processing exception: {str(e)}")
-             logger.error(f"Exception while processing file {file_path}: {str(e)}")
-             return False
-
-     def run(self):
-         start_time = time.time()
-
-         print("\n" + "=" * 60)
-         print("Starting pipeline run")
-         print("=" * 60)
-         logger.info("=" * 60)
-         logger.info("Starting pipeline run")
-
-         print("\n→ Listing files...")
-         files = self.source.list_files()
-
-         if not files:
-             print("\n✗ No files found")
-             logger.info("No files found")
-             return
-
-         total = len(files)
-         success_count = 0
-         fail_count = 0
-
-         for idx, file_path in enumerate(files, 1):
-             print(f"\nProgress: [{idx}/{total}]")
-             logger.info(f"Progress: [{idx}/{total}] - {file_path}")
-
-             try:
-                 if self.process_file(file_path):
-                     success_count += 1
-                 else:
-                     fail_count += 1
-             except Exception as e:
-                 print(f"\n✗✗✗ File processing exception: {str(e)}")
-                 logger.error(f"File processing exception {file_path}: {str(e)}")
-                 fail_count += 1
-
-             if idx < total:
-                 time.sleep(1)
-
-         elapsed = time.time() - start_time
-         print("\n" + "=" * 60)
-         print("Pipeline run finished!")
-         print("=" * 60)
-         print(f"Total files: {total}")
-         print(f"Succeeded: {success_count}")
-         print(f"Failed: {fail_count}")
-         print(f"Total time: {elapsed:.2f} s")
-         print("=" * 60)
-
-         logger.info("=" * 60)
-         logger.info(f"Pipeline finished - total: {total}, succeeded: {success_count}, failed: {fail_count}, elapsed: {elapsed:.2f}s")
-         logger.info("=" * 60)
-
-
- def create_pipeline_from_config(config: Dict[str, Any]) -> Pipeline:
-     source_config = config['source']
-     if source_config['type'] == 's3':
-         source = S3Source(
-             endpoint=source_config['endpoint'],
-             access_key=source_config['access_key'],
-             secret_key=source_config['secret_key'],
-             bucket=source_config['bucket'],
-             prefix=source_config.get('prefix', ''),
-             region=source_config.get('region', 'us-east-1'),
-             pattern=source_config.get('pattern', None),
-             recursive=source_config.get('recursive', False)
-         )
-     elif source_config['type'] == 'local':
-         source = LocalSource(
-             directory=source_config['directory'],
-             pattern=source_config.get('pattern', None),
-             recursive=source_config.get('recursive', False)
-         )
-     elif source_config['type'] == 'ftp':
-         source = FtpSource(
-             host=source_config['host'],
-             port=source_config['port'],
-             username=source_config['username'],
-             password=source_config['password'],
-             pattern=source_config.get('pattern', None),
-             recursive=source_config.get('recursive', False)
-         )
-     elif source_config['type'] == 'smb':
-         source = SmbSource(
-             host=source_config['host'],
-             share_name=source_config['share_name'],
-             username=source_config['username'],
-             password=source_config['password'],
-             domain=source_config.get('domain', ''),
-             port=source_config.get('port', 445),
-             path=source_config.get('path', ''),
-             pattern=source_config.get('pattern', None),
-             recursive=source_config.get('recursive', False)
-         )
-     else:
-         raise ValueError(f"Unknown source type: {source_config['type']}")
-
-     dest_config = config['destination']
-     if dest_config['type'] in ['milvus', 'zilliz']:
-         destination = MilvusDestination(
-             db_path=dest_config['db_path'],
-             collection_name=dest_config['collection_name'],
-             dimension=dest_config['dimension'],
-             api_key=dest_config.get('api_key'),
-             token=dest_config.get('token')
-         )
-     elif dest_config['type'] == 'qdrant':
-         destination = QdrantDestination(
-             url=dest_config['url'],
-             collection_name=dest_config['collection_name'],
-             dimension=dest_config['dimension'],
-             api_key=dest_config.get('api_key'),
-             prefer_grpc=dest_config.get('prefer_grpc', False)
-         )
-     elif dest_config['type'] == 'local':
-         destination = LocalDestination(
-             output_dir=dest_config['output_dir']
-         )
-     elif dest_config['type'] == 's3':
-         destination = S3Destination(
-             endpoint=dest_config['endpoint'],
-             access_key=dest_config['access_key'],
-             secret_key=dest_config['secret_key'],
-             bucket=dest_config['bucket'],
-             prefix=dest_config.get('prefix', ''),
-             region=dest_config.get('region', 'us-east-1')
-         )
-     else:
-         raise ValueError(f"Unknown destination type: {dest_config['type']}")
-
-     # Handle the stages configuration.
-     if 'stages' not in config or not config['stages']:
-         raise ValueError("the configuration must contain a 'stages' field")
-
-     stages = []
-     for stage_cfg in config['stages']:
-         stage_type = stage_cfg.get('type')
-         stage_config_dict = stage_cfg.get('config', {})
-
-         if stage_type == 'parse':
-             parse_cfg_copy = dict(stage_config_dict)
-             provider = parse_cfg_copy.pop('provider', 'textin')
-             stage_config = ParseConfig(provider=provider, **parse_cfg_copy)
-         elif stage_type == 'chunk':
-             stage_config = ChunkConfig(
-                 strategy=stage_config_dict.get('strategy', 'basic'),
-                 include_orig_elements=stage_config_dict.get('include_orig_elements', False),
-                 new_after_n_chars=stage_config_dict.get('new_after_n_chars', 512),
-                 max_characters=stage_config_dict.get('max_characters', 1024),
-                 overlap=stage_config_dict.get('overlap', 0),
-                 overlap_all=stage_config_dict.get('overlap_all', False)
-             )
-         elif stage_type == 'embed':
-             stage_config = EmbedConfig(
-                 provider=stage_config_dict.get('provider', 'qwen'),
-                 model_name=stage_config_dict.get('model_name', 'text-embedding-v3')
-             )
-         else:
-             raise ValueError(f"Unknown stage type: {stage_type}")
-
-         stages.append(Stage(type=stage_type, config=stage_config))
-
-     # Build the pipeline config.
-     pipeline_config = None
-     if 'pipeline_config' in config and config['pipeline_config']:
-         pipeline_cfg = config['pipeline_config']
-         include_intermediate_results = pipeline_cfg.get('include_intermediate_results', False)
-         intermediate_results_destination = None
-
-         if include_intermediate_results:
-             if 'intermediate_results_destination' in pipeline_cfg:
-                 dest_cfg = pipeline_cfg['intermediate_results_destination']
-                 dest_type = dest_cfg.get('type')
-
-                 if dest_type == 'local':
-                     intermediate_results_destination = LocalDestination(
-                         output_dir=dest_cfg['output_dir']
-                     )
-                 elif dest_type == 's3':
-                     intermediate_results_destination = S3Destination(
-                         endpoint=dest_cfg['endpoint'],
-                         access_key=dest_cfg['access_key'],
-                         secret_key=dest_cfg['secret_key'],
-                         bucket=dest_cfg['bucket'],
-                         prefix=dest_cfg.get('prefix', ''),
-                         region=dest_cfg.get('region', 'us-east-1')
-                     )
-                 else:
-                     raise ValueError(f"Unsupported intermediate_results_destination type: '{dest_type}'; supported types: 'local', 's3'")
-             else:
-                 raise ValueError("intermediate_results_destination must be set when include_intermediate_results is True")
-
-         pipeline_config = PipelineConfig(
-             include_intermediate_results=include_intermediate_results,
-             intermediate_results_destination=intermediate_results_destination
-         )
-
-     # Create the pipeline.
-     pipeline = Pipeline(
-         source=source,
-         destination=destination,
-         api_base_url=config.get('api_base_url', 'http://localhost:8000/api/xparse'),
-         api_headers=config.get('api_headers', {}),
-         stages=stages,
-         pipeline_config=pipeline_config
-     )
-
-     return pipeline
-
-
- __all__ = [
-     'Pipeline',
-     'create_pipeline_from_config',
- ]
-
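
The deleted module above was the 0.2.19 entry point for running a pipeline from a plain config dict; in 0.3.0b3 that surface is replaced by the new xparse_client/api and xparse_client/connectors packages listed in the files-changed table. For context, here is a minimal sketch of how the removed create_pipeline_from_config was typically driven; the import path, file paths, and values are illustrative, and only the config keys are taken from the deleted code:

    from xparse_client.pipeline import create_pipeline_from_config  # 0.2.19 layout; exact re-export path assumed

    # Illustrative config; the keys mirror what the deleted factory parsed.
    # The first stage must be 'parse', and 'local' source/destination need
    # only a directory and an output_dir.
    config = {
        'source': {'type': 'local', 'directory': './docs', 'pattern': '*.pdf', 'recursive': True},
        'destination': {'type': 'local', 'output_dir': './output'},
        'api_base_url': 'http://localhost:8000/api/xparse',
        'api_headers': {},
        'stages': [
            {'type': 'parse', 'config': {'provider': 'textin'}},
            {'type': 'chunk', 'config': {'strategy': 'basic', 'max_characters': 1024}},
            {'type': 'embed', 'config': {'provider': 'qwen', 'model_name': 'text-embedding-v3'}},
        ],
    }

    pipeline = create_pipeline_from_config(config)
    pipeline.run()  # lists source files, posts each to the /pipeline endpoint, writes results to the destination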