xparse-client 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff shows the content of publicly released package versions as they appear in their public registries. It is provided for informational purposes only and reflects the changes between the two versions.
@@ -0,0 +1,250 @@
+ #!/usr/bin/env python
+ # -*- encoding: utf-8 -*-
+
+ import json
+ import logging
+ import uuid
+ import boto3
+
+ from abc import ABC, abstractmethod
+ from datetime import datetime
+ from pathlib import Path
+ from typing import List, Dict, Any
+
+ from botocore.config import Config
+ from pymilvus import MilvusClient
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class Destination(ABC):
+     """Abstract base class for data destinations"""
+
+     @abstractmethod
+     def write(self, data: List[Dict[str, Any]], metadata: Dict[str, Any]) -> bool:
+         """Write data to the destination"""
+         raise NotImplementedError
+
+
+ class MilvusDestination(Destination):
+     """Milvus/Zilliz vector database destination"""
+
+     def __init__(self, db_path: str, collection_name: str, dimension: int, api_key: str = None, token: str = None):
+         from pymilvus import DataType
+
+         self.db_path = db_path
+         self.collection_name = collection_name
+         self.dimension = dimension
+
+         client_kwargs = {'uri': db_path}
+         if api_key:
+             client_kwargs['token'] = api_key
+         elif token:
+             client_kwargs['token'] = token
+
+         self.client = MilvusClient(**client_kwargs)
+
+         if not self.client.has_collection(collection_name):
+             schema = self.client.create_schema(
+                 auto_id=False,
+                 enable_dynamic_field=True
+             )
+
+             schema.add_field(field_name="element_id", datatype=DataType.VARCHAR, max_length=128, is_primary=True)
+             schema.add_field(field_name="embeddings", datatype=DataType.FLOAT_VECTOR, dim=dimension)
+             schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=65535)
+             schema.add_field(field_name="record_id", datatype=DataType.VARCHAR, max_length=128)
+             schema.add_field(field_name="metadata", datatype=DataType.JSON)
+
+             index_params = self.client.prepare_index_params()
+             index_params.add_index(
+                 field_name="embeddings",
+                 index_type="AUTOINDEX",
+                 metric_type="COSINE"
+             )
+
+             self.client.create_collection(
+                 collection_name=collection_name,
+                 schema=schema,
+                 index_params=index_params
+             )
+             print(f"✓ Milvus/Zilliz collection created: {collection_name} (custom schema)")
+         else:
+             print(f"✓ Milvus/Zilliz collection exists: {collection_name}")
+
+         logger.info(f"Milvus/Zilliz connected: {db_path}")
+
+     def write(self, data: List[Dict[str, Any]], metadata: Dict[str, Any]) -> bool:
+         try:
+             insert_data = []
+             for item in data:
+                 # Element-level metadata
+                 element_metadata = item.get('metadata', {})
+
+                 if 'embeddings' in item and item['embeddings']:
+                     element_id = item.get('element_id') or item.get('id') or str(uuid.uuid4())
+
+                     # Build the base record
+                     insert_item = {
+                         'embeddings': item['embeddings'],
+                         'text': item.get('text', ''),
+                         'element_id': element_id,
+                         'record_id': element_metadata.get('record_id', ''),
+                         'created_at': datetime.now().isoformat()
+                     }
+
+                     # Merge file-level metadata with element-level metadata;
+                     # file-level metadata takes precedence
+                     merged_metadata = {**element_metadata, **metadata}
+
+                     # Flatten metadata fields to the top level as dynamic fields,
+                     # skipping fixed fields to avoid conflicts
+                     fixed_fields = {'embeddings', 'text', 'element_id', 'record_id', 'created_at', 'metadata'}
+                     for key, value in merged_metadata.items():
+                         if key not in fixed_fields:
+                             # Special-case the data_source field: flatten it if it is a dict
+                             if key == 'data_source' and isinstance(value, dict):
+                                 # Flatten the data_source dict into data_source_* keys
+                                 for sub_key, sub_value in value.items():
+                                     flat_key = f'data_source_{sub_key}'
+                                     if flat_key not in fixed_fields:
+                                         # Serialize nested dicts/lists as JSON strings
+                                         if isinstance(sub_value, (dict, list)):
+                                             insert_item[flat_key] = json.dumps(sub_value, ensure_ascii=False)
+                                         else:
+                                             insert_item[flat_key] = sub_value
+                             elif isinstance(value, (dict, list)):
+                                 continue
+                             else:
+                                 insert_item[key] = value
+
+                     insert_data.append(insert_item)
+
+             if not insert_data:
+                 print(f" ! Warning: no valid vector data")
+                 return False
+
+             self.client.insert(
+                 collection_name=self.collection_name,
+                 data=insert_data
+             )
+             print(f" ✓ Wrote to Milvus: {len(insert_data)} records")
+             logger.info(f"Milvus write succeeded: {len(insert_data)} records")
+             return True
+         except Exception as e:
+             print(f" ✗ Milvus write failed: {str(e)}")
+             logger.error(f"Milvus write failed: {str(e)}")
+             return False
+
+
+ class LocalDestination(Destination):
+     """Local filesystem destination"""
+
+     def __init__(self, output_dir: str):
+         self.output_dir = Path(output_dir)
+         self.output_dir.mkdir(parents=True, exist_ok=True)
+         print(f"✓ Output directory: {self.output_dir}")
+         logger.info(f"Output directory: {self.output_dir}")
+
+     def write(self, data: List[Dict[str, Any]], metadata: Dict[str, Any]) -> bool:
+         try:
+             file_name = metadata.get('file_name', 'output')
+             base_name = Path(file_name).stem
+             stage = metadata.get('stage')  # identifies the stage of an intermediate result
+
+             # For intermediate results, append the stage name to the file name
+             if stage:
+                 output_file = self.output_dir / f"{base_name}_{stage}.json"
+             else:
+                 output_file = self.output_dir / f"{base_name}.json"
+
+             with open(output_file, 'w', encoding='utf-8') as f:
+                 json.dump(data, f, ensure_ascii=False, indent=2)
+
+             print(f" ✓ Wrote local file: {output_file}")
+             logger.info(f"Local write succeeded: {output_file}")
+             return True
+         except Exception as e:
+             print(f" ✗ Local write failed: {str(e)}")
+             logger.error(f"Local write failed: {str(e)}")
+             return False
+
+
+ class S3Destination(Destination):
+     """S3/MinIO data destination"""
+
+     def __init__(self, endpoint: str, access_key: str, secret_key: str,
+                  bucket: str, prefix: str = '', region: str = 'us-east-1'):
+         self.endpoint = endpoint
+         self.bucket = bucket
+         self.prefix = prefix.strip('/') if prefix else ''
+
+         if self.endpoint == 'https://textin-minio-api.ai.intsig.net':
+             config = Config(signature_version='s3v4')
+         elif self.endpoint.endswith('aliyuncs.com'):
+             config = Config(signature_version='s3', s3={'addressing_style': 'virtual'})
+         else:
+             config = Config(signature_version='s3v4', s3={'addressing_style': 'virtual'})
+
+         self.client = boto3.client(
+             's3',
+             endpoint_url=endpoint,
+             aws_access_key_id=access_key,
+             aws_secret_access_key=secret_key,
+             region_name=region,
+             config=config
+         )
+
+         try:
+             self.client.head_bucket(Bucket=bucket)
+             test_key = f"{self.prefix}/empty.tmp" if self.prefix else "empty.tmp"
+             self.client.put_object(
+                 Bucket=bucket,
+                 Key=test_key,
+                 Body=b''
+             )
+             try:
+                 self.client.delete_object(Bucket=bucket, Key=test_key)
+             except Exception:
+                 pass
+
+             print(f"✓ S3 connection OK and writable: {endpoint}/{bucket}")
+             logger.info(f"S3 connection OK and writable: {endpoint}/{bucket}")
+         except Exception as e:
+             print(f"✗ S3 connection or write test failed: {str(e)}")
+             logger.error(f"S3 connection or write test failed: {str(e)}")
+             raise
+
+     def write(self, data: List[Dict[str, Any]], metadata: Dict[str, Any]) -> bool:
+         try:
+             file_name = metadata.get('file_name', 'output')
+             base_name = Path(file_name).stem
+             object_key = f"{self.prefix}/{base_name}.json" if self.prefix else f"{base_name}.json"
+
+             json_data = json.dumps(data, ensure_ascii=False, indent=2)
+             json_bytes = json_data.encode('utf-8')
+
+             self.client.put_object(
+                 Bucket=self.bucket,
+                 Key=object_key,
+                 Body=json_bytes,
+                 ContentType='application/json'
+             )
+
+             print(f" ✓ Wrote to S3: {self.endpoint}/{self.bucket}/{object_key}")
+             logger.info(f"S3 write succeeded: {self.endpoint}/{self.bucket}/{object_key}")
+             return True
+         except Exception as e:
+             print(f" ✗ S3 write failed: {str(e)}")
+             logger.error(f"S3 write failed: {str(e)}")
+             return False
+
+
+ __all__ = [
+     'Destination',
+     'MilvusDestination',
+     'LocalDestination',
+     'S3Destination',
+ ]
+
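The first new module above defines a small Destination interface with three interchangeable sinks (Milvus/Zilliz, local filesystem, S3/MinIO) behind a single write(data, metadata) contract. A minimal usage sketch follows; the import path is an assumption not shown in the diff, while the constructor parameters and the write() signature come from the code above:

    # Sketch only: the import path below is assumed, not confirmed by the diff.
    from xparse_client.destinations import LocalDestination, MilvusDestination

    elements = [
        {'text': 'hello world', 'embeddings': [0.1, 0.2, 0.3], 'metadata': {'record_id': 'r1'}},
    ]

    # Local sink: writes <stem of file_name>.json into the output directory.
    local_dest = LocalDestination(output_dir='./output')
    local_dest.write(elements, {'file_name': 'demo.pdf'})

    # Milvus/Zilliz sink: dimension must match the length of each embeddings vector.
    milvus_dest = MilvusDestination(db_path='./demo.db', collection_name='xparse_demo', dimension=3)
    milvus_dest.write(elements, {'file_name': 'demo.pdf'})

Note that write() reports failure by returning False rather than raising, so callers are expected to check the return value.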
@@ -0,0 +1,440 @@
+ #!/usr/bin/env python
+ # -*- encoding: utf-8 -*-
+
+ import json
+ import logging
+ import time
+ from datetime import datetime, timezone
+ from pathlib import Path
+ from typing import Dict, Any, Optional, Tuple, List, Union
+
+ import requests
+
+ from .config import ParseConfig, ChunkConfig, EmbedConfig, Stage, PipelineStats, PipelineConfig
+ from .sources import Source, S3Source, LocalSource, FtpSource, SmbSource
+ from .destinations import Destination, MilvusDestination, LocalDestination, S3Destination
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class Pipeline:
+     """Data processing pipeline"""
+
+     def __init__(
+         self,
+         source: Source,
+         destination: Destination,
+         api_base_url: str = 'http://localhost:8000/api/xparse',
+         api_headers: Optional[Dict[str, str]] = None,
+         stages: Optional[List[Stage]] = None,
+         pipeline_config: Optional[PipelineConfig] = None,
+         intermediate_results_destination: Optional[Destination] = None
+     ):
+         self.source = source
+         self.destination = destination
+         self.api_base_url = api_base_url.rstrip('/')
+         self.api_headers = api_headers or {}
+         self.pipeline_config = pipeline_config or PipelineConfig()
+
+         # Handle the intermediate_results_destination argument.
+         # If it is passed directly, it takes precedence and intermediate-result saving is enabled automatically.
+         if intermediate_results_destination is not None:
+             self.pipeline_config.include_intermediate_results = True
+             self.pipeline_config.intermediate_results_destination = intermediate_results_destination
+         # Otherwise fall back to whatever is already set on pipeline_config
+         elif self.pipeline_config.include_intermediate_results:
+             if not self.pipeline_config.intermediate_results_destination:
+                 raise ValueError("intermediate_results_destination must be set when include_intermediate_results is True")
+
+         # Handle the stages configuration
+         if stages is None:
+             raise ValueError("the stages argument must be provided")
+
+         self.stages = stages
+
+         # Validate stages
+         if not self.stages or self.stages[0].type != 'parse':
+             raise ValueError("stages must be non-empty and the first stage must be of type 'parse'")
+
+         # Validate the embed config, if present
+         for stage in self.stages:
+             if stage.type == 'embed' and isinstance(stage.config, EmbedConfig):
+                 stage.config.validate()
+
+         # Validate intermediate_results_destination
+         if self.pipeline_config.include_intermediate_results:
+             # It must be a supported Destination type
+             from .destinations import Destination
+             if not isinstance(self.pipeline_config.intermediate_results_destination, Destination):
+                 raise ValueError(f"intermediate_results_destination must be a Destination instance")
+             self.intermediate_results_destination = self.pipeline_config.intermediate_results_destination
+
+         print("=" * 60)
+         print("Pipeline initialized")
+         print(f" Stages: {[s.type for s in self.stages]}")
+         for stage in self.stages:
+             print(f" - {stage.type}: {stage.config}")
+         if self.pipeline_config.include_intermediate_results:
+             print(f" Pipeline Config: intermediate-result saving enabled")
+         print("=" * 60)
+
+     def _call_pipeline_api(self, file_bytes: bytes, file_name: str, data_source: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+         url = f"{self.api_base_url}/pipeline"
+         max_retries = 3
+
+         for try_count in range(max_retries):
+             try:
+                 files = {'file': (file_name or 'file', file_bytes)}
+                 form_data = {}
+
+                 # Convert stages to the API format
+                 stages_data = [stage.to_dict() for stage in self.stages]
+                 form_data['stages'] = json.dumps(stages_data)
+                 form_data['data_source'] = json.dumps(data_source, ensure_ascii=False)
+
+                 # If intermediate-result saving is enabled, pass it along in the request
+                 if self.pipeline_config:
+                     form_data['config'] = json.dumps(self.pipeline_config.to_dict(), ensure_ascii=False)
+
+                 response = requests.post(
+                     url,
+                     files=files,
+                     data=form_data,
+                     headers=self.api_headers,
+                     timeout=120
+                 )
+
+                 if response.status_code == 200:
+                     result = response.json()
+                     print(f" ✓ Pipeline API returned x_request_id: {result.get('x_request_id')}")
+                     if result.get('code') == 200 and 'data' in result:
+                         return result.get('data')
+                     return None
+                 else:
+                     print(f" ! API error {response.status_code}, retry {try_count + 1}/{max_retries}")
+                     logger.warning(f"API error {response.status_code}: pipeline")
+
+             except Exception as e:
+                 print(f" ! Request exception: {str(e)}, retry {try_count + 1}/{max_retries}")
+                 logger.error(f"API request exception pipeline: {str(e)}")
+
+             if try_count < max_retries - 1:
+                 time.sleep(2)
+
+         return None
+
+     def process_with_pipeline(self, file_bytes: bytes, file_name: str, data_source: Dict[str, Any]) -> Optional[Tuple[List[Dict[str, Any]], PipelineStats]]:
+         print(f" → Calling pipeline API: {file_name}")
+         result = self._call_pipeline_api(file_bytes, file_name, data_source)
+
+         if result and 'elements' in result and 'stats' in result:
+             elements = result['elements']
+             stats_data = result['stats']
+
+             stats = PipelineStats(
+                 original_elements=stats_data.get('original_elements', 0),
+                 chunked_elements=stats_data.get('chunked_elements', 0),
+                 embedded_elements=stats_data.get('embedded_elements', 0),
+                 stages=self.stages  # the stages that were actually executed
+             )
+
+             # If intermediate-result saving is enabled, handle the intermediate results
+             if self.pipeline_config.include_intermediate_results and 'intermediate_results' in result:
+                 self._save_intermediate_results(result['intermediate_results'], file_name, data_source)
+
+             print(f" ✓ Pipeline finished:")
+             print(f" - original elements: {stats.original_elements}")
+             print(f" - after chunking: {stats.chunked_elements}")
+             print(f" - embedded: {stats.embedded_elements}")
+             logger.info(f"Pipeline finished: {file_name}, {stats.embedded_elements} vectors")
+
+             return elements, stats
+         else:
+             print(f" ✗ Pipeline failed")
+             logger.error(f"Pipeline failed: {file_name}")
+             return None
+
+     def _save_intermediate_results(self, intermediate_results: List[Dict[str, Any]], file_name: str, data_source: Dict[str, Any]) -> None:
+         """Save intermediate results
+
+         Args:
+             intermediate_results: list of intermediate results; each item has a stage and an elements field
+             file_name: name of the file being processed
+             data_source: data source information
+         """
+         try:
+             # intermediate_results is a list whose items look like {stage: str, elements: List}
+             for result_item in intermediate_results:
+                 if 'stage' not in result_item or 'elements' not in result_item:
+                     logger.warning(f"Intermediate result item is missing the stage or elements field: {result_item}")
+                     continue
+
+                 stage = result_item['stage']
+                 elements = result_item['elements']
+
+                 metadata = {
+                     'file_name': file_name,
+                     'stage': stage,
+                     'total_elements': len(elements),
+                     'processed_at': datetime.now().isoformat(),
+                     'data_source': data_source
+                 }
+
+                 self.pipeline_config.intermediate_results_destination.write(elements, metadata)
+                 print(f" ✓ Saved {stage.upper()} intermediate results: {len(elements)} elements")
+                 logger.info(f"Saved {stage.upper()} intermediate results: {file_name}")
+
+         except Exception as e:
+             print(f" ✗ Failed to save intermediate results: {str(e)}")
+             logger.error(f"Failed to save intermediate results: {file_name}, {str(e)}")
+
+     def process_file(self, file_path: str) -> bool:
+         print(f"\n{'=' * 60}")
+         print(f"Processing file: {file_path}")
+         logger.info(f"Start processing file: {file_path}")
+
+         try:
+             print(f" → Reading file...")
+             file_bytes, data_source = self.source.read_file(file_path)
+             data_source = data_source or {}
+             data_source['date_processed'] = datetime.now(timezone.utc).timestamp()
+             print(f" ✓ File read: {len(file_bytes)} bytes")
+
+             result = self.process_with_pipeline(file_bytes, file_path, data_source)
+             if not result:
+                 return False
+
+             embedded_data, stats = result
+
+             print(f" → Writing to destination...")
+             metadata = {
+                 'file_name': file_path,
+                 'total_elements': len(embedded_data),
+                 'processed_at': datetime.now().isoformat(),
+                 'data_source': data_source,
+                 'stats': {
+                     'original_elements': stats.original_elements,
+                     'chunked_elements': stats.chunked_elements,
+                     'embedded_elements': stats.embedded_elements
+                 }
+             }
+
+             success = self.destination.write(embedded_data, metadata)
+
+             if success:
+                 print(f"\n✓✓✓ File processed successfully: {file_path}")
+                 logger.info(f"File processed successfully: {file_path}")
+             else:
+                 print(f"\n✗✗✗ File processing failed: {file_path}")
+                 logger.error(f"File processing failed: {file_path}")
+
+             return success
+
+         except Exception as e:
+             print(f"\n✗✗✗ Processing exception: {str(e)}")
+             logger.error(f"Exception while processing file {file_path}: {str(e)}")
+             return False
+
+     def run(self):
+         start_time = time.time()
+
+         print("\n" + "=" * 60)
+         print("Starting pipeline run")
+         print("=" * 60)
+         logger.info("=" * 60)
+         logger.info("Starting pipeline run")
+
+         print("\n→ Listing files...")
+         files = self.source.list_files()
+
+         if not files:
+             print("\n✗ No files found")
+             logger.info("No files found")
+             return
+
+         total = len(files)
+         success_count = 0
+         fail_count = 0
+
+         for idx, file_path in enumerate(files, 1):
+             print(f"\nProgress: [{idx}/{total}]")
+             logger.info(f"Progress: [{idx}/{total}] - {file_path}")
+
+             try:
+                 if self.process_file(file_path):
+                     success_count += 1
+                 else:
+                     fail_count += 1
+             except Exception as e:
+                 print(f"\n✗✗✗ Exception while processing file: {str(e)}")
+                 logger.error(f"Exception while processing file {file_path}: {str(e)}")
+                 fail_count += 1
+
+             if idx < total:
+                 time.sleep(1)
+
+         elapsed = time.time() - start_time
+         print("\n" + "=" * 60)
+         print("Pipeline run finished!")
+         print("=" * 60)
+         print(f"Total files: {total}")
+         print(f"Succeeded: {success_count}")
+         print(f"Failed: {fail_count}")
+         print(f"Elapsed: {elapsed:.2f} s")
+         print("=" * 60)
+
+         logger.info("=" * 60)
+         logger.info(f"Pipeline finished - total: {total}, succeeded: {success_count}, failed: {fail_count}, elapsed: {elapsed:.2f}s")
+         logger.info("=" * 60)
+
+
+ def create_pipeline_from_config(config: Dict[str, Any]) -> Pipeline:
+     source_config = config['source']
+     if source_config['type'] == 's3':
+         source = S3Source(
+             endpoint=source_config['endpoint'],
+             access_key=source_config['access_key'],
+             secret_key=source_config['secret_key'],
+             bucket=source_config['bucket'],
+             prefix=source_config.get('prefix', ''),
+             region=source_config.get('region', 'us-east-1'),
+             pattern=source_config.get('pattern', '*')
+         )
+     elif source_config['type'] == 'local':
+         source = LocalSource(
+             directory=source_config['directory'],
+             pattern=source_config.get('pattern', '*')
+         )
+     elif source_config['type'] == 'ftp':
+         source = FtpSource(
+             host=source_config['host'],
+             port=source_config['port'],
+             username=source_config['username'],
+             password=source_config['password'],
+             pattern=source_config.get('pattern', '*')
+         )
+     elif source_config['type'] == 'smb':
+         source = SmbSource(
+             host=source_config['host'],
+             share_name=source_config['share_name'],
+             username=source_config['username'],
+             password=source_config['password'],
+             domain=source_config.get('domain', ''),
+             port=source_config.get('port', 445),
+             path=source_config.get('path', ''),
+             pattern=source_config.get('pattern', '*')
+         )
+     else:
+         raise ValueError(f"Unknown source type: {source_config['type']}")
+
+     dest_config = config['destination']
+     if dest_config['type'] in ['milvus', 'zilliz']:
+         destination = MilvusDestination(
+             db_path=dest_config['db_path'],
+             collection_name=dest_config['collection_name'],
+             dimension=dest_config['dimension'],
+             api_key=dest_config.get('api_key'),
+             token=dest_config.get('token')
+         )
+     elif dest_config['type'] == 'local':
+         destination = LocalDestination(
+             output_dir=dest_config['output_dir']
+         )
+     elif dest_config['type'] == 's3':
+         destination = S3Destination(
+             endpoint=dest_config['endpoint'],
+             access_key=dest_config['access_key'],
+             secret_key=dest_config['secret_key'],
+             bucket=dest_config['bucket'],
+             prefix=dest_config.get('prefix', ''),
+             region=dest_config.get('region', 'us-east-1')
+         )
+     else:
+         raise ValueError(f"Unknown destination type: {dest_config['type']}")
+
+     # Handle the stages configuration
+     if 'stages' not in config or not config['stages']:
+         raise ValueError("the configuration must contain a 'stages' field")
+
+     stages = []
+     for stage_cfg in config['stages']:
+         stage_type = stage_cfg.get('type')
+         stage_config_dict = stage_cfg.get('config', {})
+
+         if stage_type == 'parse':
+             parse_cfg_copy = dict(stage_config_dict)
+             provider = parse_cfg_copy.pop('provider', 'textin')
+             stage_config = ParseConfig(provider=provider, **parse_cfg_copy)
+         elif stage_type == 'chunk':
+             stage_config = ChunkConfig(
+                 strategy=stage_config_dict.get('strategy', 'basic'),
+                 include_orig_elements=stage_config_dict.get('include_orig_elements', False),
+                 new_after_n_chars=stage_config_dict.get('new_after_n_chars', 512),
+                 max_characters=stage_config_dict.get('max_characters', 1024),
+                 overlap=stage_config_dict.get('overlap', 0),
+                 overlap_all=stage_config_dict.get('overlap_all', False)
+             )
+         elif stage_type == 'embed':
+             stage_config = EmbedConfig(
+                 provider=stage_config_dict.get('provider', 'qwen'),
+                 model_name=stage_config_dict.get('model_name', 'text-embedding-v3')
+             )
+         else:
+             raise ValueError(f"Unknown stage type: {stage_type}")
+
+         stages.append(Stage(type=stage_type, config=stage_config))
+
+     # Build the pipeline configuration
+     pipeline_config = None
+     if 'pipeline_config' in config and config['pipeline_config']:
+         pipeline_cfg = config['pipeline_config']
+         include_intermediate_results = pipeline_cfg.get('include_intermediate_results', False)
+         intermediate_results_destination = None
+
+         if include_intermediate_results:
+             if 'intermediate_results_destination' in pipeline_cfg:
+                 dest_cfg = pipeline_cfg['intermediate_results_destination']
+                 dest_type = dest_cfg.get('type')
+
+                 if dest_type == 'local':
+                     intermediate_results_destination = LocalDestination(
+                         output_dir=dest_cfg['output_dir']
+                     )
+                 elif dest_type == 's3':
+                     intermediate_results_destination = S3Destination(
+                         endpoint=dest_cfg['endpoint'],
+                         access_key=dest_cfg['access_key'],
+                         secret_key=dest_cfg['secret_key'],
+                         bucket=dest_cfg['bucket'],
+                         prefix=dest_cfg.get('prefix', ''),
+                         region=dest_cfg.get('region', 'us-east-1')
+                     )
+                 else:
+                     raise ValueError(f"Unsupported intermediate_results_destination type: '{dest_type}'; supported types: 'local', 's3'")
+             else:
+                 raise ValueError("intermediate_results_destination must be set when include_intermediate_results is True")
+
+         pipeline_config = PipelineConfig(
+             include_intermediate_results=include_intermediate_results,
+             intermediate_results_destination=intermediate_results_destination
+         )
+
+     # Build the Pipeline
+     pipeline = Pipeline(
+         source=source,
+         destination=destination,
+         api_base_url=config.get('api_base_url', 'http://localhost:8000/api/xparse'),
+         api_headers=config.get('api_headers', {}),
+         stages=stages,
+         pipeline_config=pipeline_config
+     )
+
+     return pipeline
+
+
+ __all__ = [
+     'Pipeline',
+     'create_pipeline_from_config',
+ ]
+
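The second new module wires a Source, the Destination classes from the first file, and a list of Stage objects into a Pipeline; create_pipeline_from_config builds all of that from a plain dict. A minimal configuration sketch based only on the keys the function reads above; the import path, directory paths, and header values are placeholders/assumptions, not values taken from the package:

    # Sketch only: the import path and all concrete values are placeholders.
    from xparse_client.pipeline import create_pipeline_from_config

    config = {
        'api_base_url': 'http://localhost:8000/api/xparse',
        'api_headers': {'Authorization': '<token>'},  # whatever headers your deployment expects
        'source': {'type': 'local', 'directory': './docs', 'pattern': '*.pdf'},
        'destination': {'type': 'local', 'output_dir': './output'},
        # The first stage must be 'parse'; 'chunk' and 'embed' stages are optional follow-ups.
        'stages': [
            {'type': 'parse', 'config': {'provider': 'textin'}},
            {'type': 'chunk', 'config': {'strategy': 'basic', 'max_characters': 1024}},
        ],
    }

    pipeline = create_pipeline_from_config(config)
    pipeline.run()

An 'embed' stage ({'type': 'embed', 'config': {'provider': 'qwen', 'model_name': 'text-embedding-v3'}}) can be appended to 'stages'; when the destination is 'milvus' or 'zilliz', its dimension must match the embedding model's output size.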