xparse-client 0.2.1-py3-none-any.whl → 0.2.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,220 @@
+ #!/usr/bin/env python
+ # -*- encoding: utf-8 -*-
+
+ import json
+ import logging
+ import uuid
+ import boto3
+
+ from abc import ABC, abstractmethod
+ from datetime import datetime
+ from pathlib import Path
+ from typing import List, Dict, Any, Optional
+
+ from botocore.config import Config
+ from pymilvus import MilvusClient
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class Destination(ABC):
+     """Abstract base class for data destinations."""
+
+     @abstractmethod
+     def write(self, data: List[Dict[str, Any]], metadata: Dict[str, Any]) -> bool:
+         """Write the given elements to the destination."""
+         raise NotImplementedError
+
+
+ class MilvusDestination(Destination):
+     """Milvus/Zilliz vector database destination."""
+
+     def __init__(self, db_path: str, collection_name: str, dimension: int, api_key: Optional[str] = None, token: Optional[str] = None):
+         from pymilvus import DataType
+
+         self.db_path = db_path
+         self.collection_name = collection_name
+         self.dimension = dimension
+
+         client_kwargs = {'uri': db_path}
+         if api_key:
+             client_kwargs['token'] = api_key
+         elif token:
+             client_kwargs['token'] = token
+
+         self.client = MilvusClient(**client_kwargs)
+
+         if not self.client.has_collection(collection_name):
+             schema = self.client.create_schema(
+                 auto_id=False,
+                 enable_dynamic_field=True
+             )
+
+             schema.add_field(field_name="element_id", datatype=DataType.VARCHAR, max_length=128, is_primary=True)
+             schema.add_field(field_name="embeddings", datatype=DataType.FLOAT_VECTOR, dim=dimension)
+             schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=65535)
+             schema.add_field(field_name="record_id", datatype=DataType.VARCHAR, max_length=128)
+             schema.add_field(field_name="metadata", datatype=DataType.JSON)
+
+             index_params = self.client.prepare_index_params()
+             index_params.add_index(
+                 field_name="embeddings",
+                 index_type="AUTOINDEX",
+                 metric_type="COSINE"
+             )
+
+             self.client.create_collection(
+                 collection_name=collection_name,
+                 schema=schema,
+                 index_params=index_params
+             )
+             print(f"✓ Milvus/Zilliz collection created: {collection_name} (custom schema)")
+         else:
+             print(f"✓ Milvus/Zilliz collection exists: {collection_name}")
+
+         logger.info(f"Milvus/Zilliz connected: {db_path}")
+
+     def write(self, data: List[Dict[str, Any]], metadata: Dict[str, Any]) -> bool:
+         try:
+             insert_data = []
+             for item in data:
+                 # Per-item metadata; named to avoid shadowing the function argument
+                 item_metadata = item.get('metadata', {})
+                 if 'embeddings' in item and item['embeddings']:
+                     element_id = item.get('element_id') or item.get('id') or str(uuid.uuid4())
+                     insert_data.append({
+                         'embeddings': item['embeddings'],
+                         'text': item.get('text', ''),
+                         'element_id': element_id,
+                         'record_id': item_metadata.get('record_id', ''),
+                         'metadata': item_metadata,
+                         'created_at': datetime.now().isoformat()
+                     })
+
+             if not insert_data:
+                 print(" ! Warning: no valid vector data")
+                 return False
+
+             self.client.insert(
+                 collection_name=self.collection_name,
+                 data=insert_data
+             )
+             print(f" ✓ Wrote to Milvus: {len(insert_data)} records")
+             logger.info(f"Wrote to Milvus: {len(insert_data)} records")
+             return True
+         except Exception as e:
+             print(f" ✗ Failed to write to Milvus: {str(e)}")
+             logger.error(f"Failed to write to Milvus: {str(e)}")
+             return False
+
+
+ class LocalDestination(Destination):
+     """Local filesystem destination."""
+
+     def __init__(self, output_dir: str):
+         self.output_dir = Path(output_dir)
+         self.output_dir.mkdir(parents=True, exist_ok=True)
+         print(f"✓ Output directory: {self.output_dir}")
+         logger.info(f"Output directory: {self.output_dir}")
+
+     def write(self, data: List[Dict[str, Any]], metadata: Dict[str, Any]) -> bool:
+         try:
+             file_name = metadata.get('file_name', 'output')
+             base_name = Path(file_name).stem
+             stage = metadata.get('stage')  # identifies the stage of an intermediate result
+
+             # For intermediate results, append the stage name to the file name
+             if stage:
+                 output_file = self.output_dir / f"{base_name}_{stage}.json"
+             else:
+                 output_file = self.output_dir / f"{base_name}.json"
+
+             with open(output_file, 'w', encoding='utf-8') as f:
+                 json.dump(data, f, ensure_ascii=False, indent=2)
+
+             print(f" ✓ Wrote local file: {output_file}")
+             logger.info(f"Wrote local file: {output_file}")
+             return True
+         except Exception as e:
+             print(f" ✗ Failed to write local file: {str(e)}")
+             logger.error(f"Failed to write local file: {str(e)}")
+             return False
+
+
+ class S3Destination(Destination):
+     """S3/MinIO data destination."""
+
+     def __init__(self, endpoint: str, access_key: str, secret_key: str,
+                  bucket: str, prefix: str = '', region: str = 'us-east-1'):
+         self.endpoint = endpoint
+         self.bucket = bucket
+         self.prefix = prefix.strip('/') if prefix else ''
+
+         if self.endpoint == 'https://textin-minio-api.ai.intsig.net':
+             config = Config(signature_version='s3v4')
+         elif self.endpoint.endswith('aliyuncs.com'):
+             config = Config(signature_version='s3', s3={'addressing_style': 'virtual'})
+         else:
+             config = Config(signature_version='s3v4', s3={'addressing_style': 'virtual'})
+
+         self.client = boto3.client(
+             's3',
+             endpoint_url=endpoint,
+             aws_access_key_id=access_key,
+             aws_secret_access_key=secret_key,
+             region_name=region,
+             config=config
+         )
+
+         try:
+             self.client.head_bucket(Bucket=bucket)
+             # Probe write access with an empty object, then clean it up
+             test_key = f"{self.prefix}/empty.tmp" if self.prefix else "empty.tmp"
+             self.client.put_object(
+                 Bucket=bucket,
+                 Key=test_key,
+                 Body=b''
+             )
+             try:
+                 self.client.delete_object(Bucket=bucket, Key=test_key)
+             except Exception:
+                 pass
+
+             print(f"✓ S3 connected and writable: {endpoint}/{bucket}")
+             logger.info(f"S3 connected and writable: {endpoint}/{bucket}")
+         except Exception as e:
+             print(f"✗ S3 connection or write test failed: {str(e)}")
+             logger.error(f"S3 connection or write test failed: {str(e)}")
+             raise
+
+     def write(self, data: List[Dict[str, Any]], metadata: Dict[str, Any]) -> bool:
+         try:
+             file_name = metadata.get('file_name', 'output')
+             base_name = Path(file_name).stem
+             object_key = f"{self.prefix}/{base_name}.json" if self.prefix else f"{base_name}.json"
+
+             json_data = json.dumps(data, ensure_ascii=False, indent=2)
+             json_bytes = json_data.encode('utf-8')
+
+             self.client.put_object(
+                 Bucket=self.bucket,
+                 Key=object_key,
+                 Body=json_bytes,
+                 ContentType='application/json'
+             )
+
+             print(f" ✓ Wrote to S3: {self.endpoint}/{self.bucket}/{object_key}")
+             logger.info(f"Wrote to S3: {self.endpoint}/{self.bucket}/{object_key}")
+             return True
+         except Exception as e:
+             print(f" ✗ Failed to write to S3: {str(e)}")
+             logger.error(f"Failed to write to S3: {str(e)}")
+             return False
+
+
+ __all__ = [
+     'Destination',
+     'MilvusDestination',
+     'LocalDestination',
+     'S3Destination',
+ ]
+
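For context, a minimal sketch of how the destination classes above compose. The import path and the element fields are assumptions inferred from what the writers expect (`element_id`, `text`, `embeddings`, `metadata.record_id`), not documented API:

# Hypothetical usage sketch; assumes the package installs as `xparse_client`.
from xparse_client.destinations import LocalDestination, MilvusDestination

elements = [{
    'element_id': 'elem-001',
    'text': 'Example chunk of parsed text.',
    'embeddings': [0.1] * 1024,  # length must equal the collection dimension
    'metadata': {'record_id': 'doc-001'},
}]
meta = {'file_name': 'report.pdf'}

# Local JSON output: writes ./output/report.json
LocalDestination(output_dir='./output').write(elements, meta)

# A local `.db` URI selects Milvus Lite; pass `token` (or `api_key`) for Zilliz Cloud instead.
dest = MilvusDestination(db_path='./demo.db', collection_name='xparse_demo', dimension=1024)
dest.write(elements, meta)

Because the schema is created with `enable_dynamic_field=True`, the extra `created_at` field written by `MilvusDestination.write` is accepted even though it is not declared in the schema.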
@@ -0,0 +1,440 @@
+ #!/usr/bin/env python
+ # -*- encoding: utf-8 -*-
+
+ import json
+ import logging
+ import time
+ from datetime import datetime, timezone
+ from pathlib import Path
+ from typing import Dict, Any, Optional, Tuple, List
+
+ import requests
+
+ from .config import ParseConfig, ChunkConfig, EmbedConfig, Stage, PipelineStats, PipelineConfig
+ from .sources import Source, S3Source, LocalSource, FtpSource, SmbSource
+ from .destinations import Destination, MilvusDestination, LocalDestination, S3Destination
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class Pipeline:
+     """Data processing pipeline."""
+
+     def __init__(
+         self,
+         source: Source,
+         destination: Destination,
+         api_base_url: str = 'http://localhost:8000/api/xparse',
+         api_headers: Optional[Dict[str, str]] = None,
+         stages: Optional[List[Stage]] = None,
+         pipeline_config: Optional[PipelineConfig] = None,
+         intermediate_results_destination: Optional[Destination] = None
+     ):
+         self.source = source
+         self.destination = destination
+         self.api_base_url = api_base_url.rstrip('/')
+         self.api_headers = api_headers or {}
+         self.pipeline_config = pipeline_config or PipelineConfig()
+
+         # Handle the intermediate_results_destination argument.
+         # If it is passed directly, it takes precedence and enables intermediate-result saving.
+         if intermediate_results_destination is not None:
+             self.pipeline_config.include_intermediate_results = True
+             self.pipeline_config.intermediate_results_destination = intermediate_results_destination
+         # Otherwise fall back to whatever is configured on pipeline_config
+         elif self.pipeline_config.include_intermediate_results:
+             if not self.pipeline_config.intermediate_results_destination:
+                 raise ValueError("intermediate_results_destination must be set when include_intermediate_results is True")
+
+         # Handle the stages configuration
+         if stages is None:
+             raise ValueError("the stages argument is required")
+
+         self.stages = stages
+
+         # Validate stages
+         if not self.stages or self.stages[0].type != 'parse':
+             raise ValueError("stages must be non-empty and the first stage must be of type 'parse'")
+
+         # Validate the embed config, if present
+         for stage in self.stages:
+             if stage.type == 'embed' and isinstance(stage.config, EmbedConfig):
+                 stage.config.validate()
+
+         # Validate intermediate_results_destination
+         if self.pipeline_config.include_intermediate_results:
+             # Must be a supported Destination type (Destination is imported at module level)
+             if not isinstance(self.pipeline_config.intermediate_results_destination, Destination):
+                 raise ValueError("intermediate_results_destination must be a Destination instance")
+             self.intermediate_results_destination = self.pipeline_config.intermediate_results_destination
+
+         print("=" * 60)
+         print("Pipeline initialized")
+         print(f"  Stages: {[s.type for s in self.stages]}")
+         for stage in self.stages:
+             print(f"  - {stage.type}: {stage.config}")
+         if self.pipeline_config.include_intermediate_results:
+             print("  Pipeline Config: intermediate-result saving enabled")
+         print("=" * 60)
+
+     def _call_pipeline_api(self, file_bytes: bytes, file_name: str, data_source: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+         url = f"{self.api_base_url}/pipeline"
+         max_retries = 3
+
+         for try_count in range(max_retries):
+             try:
+                 files = {'file': (file_name or 'file', file_bytes)}
+                 form_data = {}
+
+                 # Convert the stages to the API wire format
+                 stages_data = [stage.to_dict() for stage in self.stages]
+                 form_data['stages'] = json.dumps(stages_data)
+                 form_data['data_source'] = json.dumps(data_source, ensure_ascii=False)
+
+                 # Pass the pipeline config (e.g. intermediate-result saving) along with the request
+                 if self.pipeline_config:
+                     form_data['config'] = json.dumps(self.pipeline_config.to_dict(), ensure_ascii=False)
+
+                 response = requests.post(
+                     url,
+                     files=files,
+                     data=form_data,
+                     headers=self.api_headers,
+                     timeout=120
+                 )
+
+                 if response.status_code == 200:
+                     result = response.json()
+                     print(f" ✓ Pipeline API returned x_request_id: {result.get('x_request_id')}")
+                     if result.get('code') == 200 and 'data' in result:
+                         return result.get('data')
+                     return None
+                 else:
+                     print(f" ! API error {response.status_code}, retry {try_count + 1}/{max_retries}")
+                     logger.warning(f"API error {response.status_code}: pipeline")
+
+             except Exception as e:
+                 print(f" ! Request exception: {str(e)}, retry {try_count + 1}/{max_retries}")
+                 logger.error(f"API request exception pipeline: {str(e)}")
+
+             if try_count < max_retries - 1:
+                 time.sleep(2)
+
+         return None
+
+     def process_with_pipeline(self, file_bytes: bytes, file_name: str, data_source: Dict[str, Any]) -> Optional[Tuple[List[Dict[str, Any]], PipelineStats]]:
+         print(f" → Calling pipeline API: {file_name}")
+         result = self._call_pipeline_api(file_bytes, file_name, data_source)
+
+         if result and 'elements' in result and 'stats' in result:
+             elements = result['elements']
+             stats_data = result['stats']
+
+             stats = PipelineStats(
+                 original_elements=stats_data.get('original_elements', 0),
+                 chunked_elements=stats_data.get('chunked_elements', 0),
+                 embedded_elements=stats_data.get('embedded_elements', 0),
+                 stages=self.stages  # the stages that were actually executed
+             )
+
+             # If intermediate-result saving is enabled, persist the intermediate results
+             if self.pipeline_config.include_intermediate_results and 'intermediate_results' in result:
+                 self._save_intermediate_results(result['intermediate_results'], file_name, data_source)
+
+             print(" ✓ Pipeline finished:")
+             print(f"   - original elements: {stats.original_elements}")
+             print(f"   - after chunking: {stats.chunked_elements}")
+             print(f"   - embedded: {stats.embedded_elements}")
+             logger.info(f"Pipeline finished: {file_name}, {stats.embedded_elements} vectors")
+
+             return elements, stats
+         else:
+             print(" ✗ Pipeline failed")
+             logger.error(f"Pipeline failed: {file_name}")
+             return None
+
+     def _save_intermediate_results(self, intermediate_results: List[Dict[str, Any]], file_name: str, data_source: Dict[str, Any]) -> None:
+         """Save intermediate results.
+
+         Args:
+             intermediate_results: Array of intermediate results; each item has 'stage' and 'elements' fields
+             file_name: File name
+             data_source: Data source information
+         """
+         try:
+             # intermediate_results is an array of {stage: str, elements: List}
+             for result_item in intermediate_results:
+                 if 'stage' not in result_item or 'elements' not in result_item:
+                     logger.warning(f"Intermediate result item missing 'stage' or 'elements': {result_item}")
+                     continue
+
+                 stage = result_item['stage']
+                 elements = result_item['elements']
+
+                 metadata = {
+                     'file_name': file_name,
+                     'stage': stage,
+                     'total_elements': len(elements),
+                     'processed_at': datetime.now().isoformat(),
+                     'data_source': data_source
+                 }
+
+                 self.pipeline_config.intermediate_results_destination.write(elements, metadata)
+                 print(f" ✓ Saved {stage.upper()} intermediate results: {len(elements)} elements")
+                 logger.info(f"Saved {stage.upper()} intermediate results: {file_name}")
+
+         except Exception as e:
+             print(f" ✗ Failed to save intermediate results: {str(e)}")
+             logger.error(f"Failed to save intermediate results: {file_name}, {str(e)}")
+
+     def process_file(self, file_path: str) -> bool:
+         print(f"\n{'=' * 60}")
+         print(f"Processing file: {file_path}")
+         logger.info(f"Start processing file: {file_path}")
+
+         try:
+             print(" → Reading file...")
+             file_bytes, data_source = self.source.read_file(file_path)
+             data_source = data_source or {}
+             data_source['date_processed'] = datetime.now(timezone.utc).timestamp()
+             print(f" ✓ File read: {len(file_bytes)} bytes")
+
+             result = self.process_with_pipeline(file_bytes, file_path, data_source)
+             if not result:
+                 return False
+
+             embedded_data, stats = result
+
+             print(" → Writing to destination...")
+             metadata = {
+                 'file_name': file_path,
+                 'total_elements': len(embedded_data),
+                 'processed_at': datetime.now().isoformat(),
+                 'data_source': data_source,
+                 'stats': {
+                     'original_elements': stats.original_elements,
+                     'chunked_elements': stats.chunked_elements,
+                     'embedded_elements': stats.embedded_elements
+                 }
+             }
+
+             success = self.destination.write(embedded_data, metadata)
+
+             if success:
+                 print(f"\n✓✓✓ File processed successfully: {file_path}")
+                 logger.info(f"File processed successfully: {file_path}")
+             else:
+                 print(f"\n✗✗✗ File processing failed: {file_path}")
+                 logger.error(f"File processing failed: {file_path}")
+
+             return success
+
+         except Exception as e:
+             print(f"\n✗✗✗ Processing exception: {str(e)}")
+             logger.error(f"Exception while processing file {file_path}: {str(e)}")
+             return False
+
+     def run(self):
+         start_time = time.time()
+
+         print("\n" + "=" * 60)
+         print("Starting pipeline run")
+         print("=" * 60)
+         logger.info("=" * 60)
+         logger.info("Starting pipeline run")
+
+         print("\n→ Listing files...")
+         files = self.source.list_files()
+
+         if not files:
+             print("\n✗ No files found")
+             logger.info("No files found")
+             return
+
+         total = len(files)
+         success_count = 0
+         fail_count = 0
+
+         for idx, file_path in enumerate(files, 1):
+             print(f"\nProgress: [{idx}/{total}]")
+             logger.info(f"Progress: [{idx}/{total}] - {file_path}")
+
+             try:
+                 if self.process_file(file_path):
+                     success_count += 1
+                 else:
+                     fail_count += 1
+             except Exception as e:
+                 print(f"\n✗✗✗ File processing exception: {str(e)}")
+                 logger.error(f"File processing exception {file_path}: {str(e)}")
+                 fail_count += 1
+
+             if idx < total:
+                 time.sleep(1)
+
+         elapsed = time.time() - start_time
+         print("\n" + "=" * 60)
+         print("Pipeline run complete!")
+         print("=" * 60)
+         print(f"Total files: {total}")
+         print(f"Succeeded: {success_count}")
+         print(f"Failed: {fail_count}")
+         print(f"Elapsed: {elapsed:.2f} s")
+         print("=" * 60)
+
+         logger.info("=" * 60)
+         logger.info(f"Pipeline complete - total: {total}, succeeded: {success_count}, failed: {fail_count}, elapsed: {elapsed:.2f}s")
+         logger.info("=" * 60)
+
+
+ def create_pipeline_from_config(config: Dict[str, Any]) -> Pipeline:
+     source_config = config['source']
+     if source_config['type'] == 's3':
+         source = S3Source(
+             endpoint=source_config['endpoint'],
+             access_key=source_config['access_key'],
+             secret_key=source_config['secret_key'],
+             bucket=source_config['bucket'],
+             prefix=source_config.get('prefix', ''),
+             region=source_config.get('region', 'us-east-1'),
+             pattern=source_config.get('pattern', '*')
+         )
+     elif source_config['type'] == 'local':
+         source = LocalSource(
+             directory=source_config['directory'],
+             pattern=source_config.get('pattern', '*')
+         )
+     elif source_config['type'] == 'ftp':
+         source = FtpSource(
+             host=source_config['host'],
+             port=source_config['port'],
+             username=source_config['username'],
+             password=source_config['password'],
+             pattern=source_config.get('pattern', '*')
+         )
+     elif source_config['type'] == 'smb':
+         source = SmbSource(
+             host=source_config['host'],
+             share_name=source_config['share_name'],
+             username=source_config['username'],
+             password=source_config['password'],
+             domain=source_config.get('domain', ''),
+             port=source_config.get('port', 445),
+             path=source_config.get('path', ''),
+             pattern=source_config.get('pattern', '*')
+         )
+     else:
+         raise ValueError(f"Unknown source type: {source_config['type']}")
+
+     dest_config = config['destination']
+     if dest_config['type'] in ['milvus', 'zilliz']:
+         destination = MilvusDestination(
+             db_path=dest_config['db_path'],
+             collection_name=dest_config['collection_name'],
+             dimension=dest_config['dimension'],
+             api_key=dest_config.get('api_key'),
+             token=dest_config.get('token')
+         )
+     elif dest_config['type'] == 'local':
+         destination = LocalDestination(
+             output_dir=dest_config['output_dir']
+         )
+     elif dest_config['type'] == 's3':
+         destination = S3Destination(
+             endpoint=dest_config['endpoint'],
+             access_key=dest_config['access_key'],
+             secret_key=dest_config['secret_key'],
+             bucket=dest_config['bucket'],
+             prefix=dest_config.get('prefix', ''),
+             region=dest_config.get('region', 'us-east-1')
+         )
+     else:
+         raise ValueError(f"Unknown destination type: {dest_config['type']}")
+
+     # Handle the stages configuration
+     if 'stages' not in config or not config['stages']:
+         raise ValueError("config must contain a 'stages' field")
+
+     stages = []
+     for stage_cfg in config['stages']:
+         stage_type = stage_cfg.get('type')
+         stage_config_dict = stage_cfg.get('config', {})
+
+         if stage_type == 'parse':
+             parse_cfg_copy = dict(stage_config_dict)
+             provider = parse_cfg_copy.pop('provider', 'textin')
+             stage_config = ParseConfig(provider=provider, **parse_cfg_copy)
+         elif stage_type == 'chunk':
+             stage_config = ChunkConfig(
+                 strategy=stage_config_dict.get('strategy', 'basic'),
+                 include_orig_elements=stage_config_dict.get('include_orig_elements', False),
+                 new_after_n_chars=stage_config_dict.get('new_after_n_chars', 512),
+                 max_characters=stage_config_dict.get('max_characters', 1024),
+                 overlap=stage_config_dict.get('overlap', 0),
+                 overlap_all=stage_config_dict.get('overlap_all', False)
+             )
+         elif stage_type == 'embed':
+             stage_config = EmbedConfig(
+                 provider=stage_config_dict.get('provider', 'qwen'),
+                 model_name=stage_config_dict.get('model_name', 'text-embedding-v3')
+             )
+         else:
+             raise ValueError(f"Unknown stage type: {stage_type}")
+
+         stages.append(Stage(type=stage_type, config=stage_config))
+
+     # Build the pipeline config
+     pipeline_config = None
+     if 'pipeline_config' in config and config['pipeline_config']:
+         pipeline_cfg = config['pipeline_config']
+         include_intermediate_results = pipeline_cfg.get('include_intermediate_results', False)
+         intermediate_results_destination = None
+
+         if include_intermediate_results:
+             if 'intermediate_results_destination' in pipeline_cfg:
+                 dest_cfg = pipeline_cfg['intermediate_results_destination']
+                 dest_type = dest_cfg.get('type')
+
+                 if dest_type == 'local':
+                     intermediate_results_destination = LocalDestination(
+                         output_dir=dest_cfg['output_dir']
+                     )
+                 elif dest_type == 's3':
+                     intermediate_results_destination = S3Destination(
+                         endpoint=dest_cfg['endpoint'],
+                         access_key=dest_cfg['access_key'],
+                         secret_key=dest_cfg['secret_key'],
+                         bucket=dest_cfg['bucket'],
+                         prefix=dest_cfg.get('prefix', ''),
+                         region=dest_cfg.get('region', 'us-east-1')
+                     )
+                 else:
+                     raise ValueError(f"Unsupported intermediate_results_destination type: '{dest_type}'; supported types: 'local', 's3'")
+             else:
+                 raise ValueError("intermediate_results_destination must be set when include_intermediate_results is True")
+
+         pipeline_config = PipelineConfig(
+             include_intermediate_results=include_intermediate_results,
+             intermediate_results_destination=intermediate_results_destination
+         )
+
+     # Build the pipeline
+     pipeline = Pipeline(
+         source=source,
+         destination=destination,
+         api_base_url=config.get('api_base_url', 'http://localhost:8000/api/xparse'),
+         api_headers=config.get('api_headers', {}),
+         stages=stages,
+         pipeline_config=pipeline_config
+     )
+
+     return pipeline
+
+
+ __all__ = [
+     'Pipeline',
+     'create_pipeline_from_config',
+ ]
+
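A minimal end-to-end sketch of the config-driven entry point above. All values are illustrative: the module path is assumed from the package name, the stage parameters mirror the defaults read by create_pipeline_from_config, and the auth header is a placeholder, since the expected header names are not shown in this diff:

# Hypothetical usage sketch; assumes the package installs as `xparse_client`.
from xparse_client.pipeline import create_pipeline_from_config

config = {
    'source': {'type': 'local', 'directory': './docs', 'pattern': '*.pdf'},
    'destination': {'type': 'local', 'output_dir': './output'},
    'stages': [
        {'type': 'parse', 'config': {'provider': 'textin'}},
        {'type': 'chunk', 'config': {'strategy': 'basic', 'max_characters': 1024}},
        {'type': 'embed', 'config': {'provider': 'qwen', 'model_name': 'text-embedding-v3'}},
    ],
    'api_base_url': 'http://localhost:8000/api/xparse',
    'api_headers': {'Authorization': 'Bearer <token>'},  # placeholder; real header names unknown
}

create_pipeline_from_config(config).run()

Note that the first stage must be 'parse' (Pipeline.__init__ enforces this), and adding a 'pipeline_config' key with include_intermediate_results set requires a 'local' or 's3' intermediate_results_destination as validated above.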