xparse-client 0.2.2 → 0.2.4 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- example/run_pipeline.py +479 -0
- example/run_pipeline_test.py +458 -0
- xparse_client/__init__.py +21 -1
- xparse_client/pipeline/__init__.py +3 -0
- xparse_client/pipeline/config.py +128 -0
- xparse_client/pipeline/destinations.py +250 -0
- xparse_client/pipeline/pipeline.py +440 -0
- xparse_client/pipeline/sources.py +342 -0
- {xparse_client-0.2.2.dist-info → xparse_client-0.2.4.dist-info}/METADATA +1 -1
- xparse_client-0.2.4.dist-info/RECORD +13 -0
- {xparse_client-0.2.2.dist-info → xparse_client-0.2.4.dist-info}/top_level.txt +1 -0
- xparse_client-0.2.2.dist-info/RECORD +0 -6
- {xparse_client-0.2.2.dist-info → xparse_client-0.2.4.dist-info}/WHEEL +0 -0
- {xparse_client-0.2.2.dist-info → xparse_client-0.2.4.dist-info}/licenses/LICENSE +0 -0
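Taken together, the new xparse_client/pipeline package wires a Source, an ordered list of Stages, and a Destination into a single Pipeline object; the two hunks below show the destination implementations and the pipeline driver. A minimal end-to-end sketch under those signatures (config.py and sources.py are not included in this diff, so the exact ParseConfig and LocalSource signatures beyond what pipeline.py itself passes are assumptions):

    from xparse_client.pipeline.pipeline import Pipeline
    from xparse_client.pipeline.sources import LocalSource
    from xparse_client.pipeline.destinations import LocalDestination
    from xparse_client.pipeline.config import Stage, ParseConfig

    # Hypothetical wiring; argument names follow the code shown below.
    pipeline = Pipeline(
        source=LocalSource(directory='./docs', pattern='*.pdf'),  # assumed signature
        destination=LocalDestination(output_dir='./out'),
        api_base_url='http://localhost:8000/api/xparse',
        stages=[Stage(type='parse', config=ParseConfig(provider='textin'))],
    )
    pipeline.run()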
xparse_client/pipeline/destinations.py
@@ -0,0 +1,250 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import json
+import logging
+import uuid
+import boto3
+
+from abc import ABC, abstractmethod
+from datetime import datetime
+from pathlib import Path
+from typing import List, Dict, Any
+
+from botocore.config import Config
+from pymilvus import MilvusClient
+
+
+logger = logging.getLogger(__name__)
+
+
+class Destination(ABC):
+    """Abstract base class for data destinations."""
+
+    @abstractmethod
+    def write(self, data: List[Dict[str, Any]], metadata: Dict[str, Any]) -> bool:
+        """Write data."""
+        raise NotImplementedError
+
+
+class MilvusDestination(Destination):
+    """Milvus/Zilliz vector database destination."""
+
+    def __init__(self, db_path: str, collection_name: str, dimension: int, api_key: str = None, token: str = None):
+        from pymilvus import DataType
+
+        self.db_path = db_path
+        self.collection_name = collection_name
+        self.dimension = dimension
+
+        client_kwargs = {'uri': db_path}
+        if api_key:
+            client_kwargs['token'] = api_key
+        elif token:
+            client_kwargs['token'] = token
+
+        self.client = MilvusClient(**client_kwargs)
+
+        if not self.client.has_collection(collection_name):
+            schema = self.client.create_schema(
+                auto_id=False,
+                enable_dynamic_field=True
+            )
+
+            schema.add_field(field_name="element_id", datatype=DataType.VARCHAR, max_length=128, is_primary=True)
+            schema.add_field(field_name="embeddings", datatype=DataType.FLOAT_VECTOR, dim=dimension)
+            schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=65535)
+            schema.add_field(field_name="record_id", datatype=DataType.VARCHAR, max_length=128)
+            schema.add_field(field_name="metadata", datatype=DataType.JSON)
+
+            index_params = self.client.prepare_index_params()
+            index_params.add_index(
+                field_name="embeddings",
+                index_type="AUTOINDEX",
+                metric_type="COSINE"
+            )
+
+            self.client.create_collection(
+                collection_name=collection_name,
+                schema=schema,
+                index_params=index_params
+            )
+            print(f"✓ Milvus/Zilliz collection created: {collection_name} (custom schema)")
+        else:
+            print(f"✓ Milvus/Zilliz collection exists: {collection_name}")
+
+        logger.info(f"Connected to Milvus/Zilliz: {db_path}")
+
+    def write(self, data: List[Dict[str, Any]], metadata: Dict[str, Any]) -> bool:
+        try:
+            insert_data = []
+            for item in data:
+                # Element-level metadata
+                element_metadata = item.get('metadata', {})
+
+                if 'embeddings' in item and item['embeddings']:
+                    element_id = item.get('element_id') or item.get('id') or str(uuid.uuid4())
+
+                    # Build the base record
+                    insert_item = {
+                        'embeddings': item['embeddings'],
+                        'text': item.get('text', ''),
+                        'element_id': element_id,
+                        'record_id': element_metadata.get('record_id', ''),
+                        'created_at': datetime.now().isoformat()
+                    }
+
+                    # Merge file-level metadata with element-level metadata;
+                    # file-level metadata takes precedence
+                    merged_metadata = {**element_metadata, **metadata}
+
+                    # Flatten metadata fields to the top level as dynamic fields,
+                    # skipping the fixed fields to avoid conflicts
+                    fixed_fields = {'embeddings', 'text', 'element_id', 'record_id', 'created_at', 'metadata'}
+                    for key, value in merged_metadata.items():
+                        if key not in fixed_fields:
+                            # Special-case the data_source field: flatten it if it is a dict
+                            if key == 'data_source' and isinstance(value, dict):
+                                # Flatten the data_source dict into data_source_* keys
+                                for sub_key, sub_value in value.items():
+                                    flat_key = f'data_source_{sub_key}'
+                                    if flat_key not in fixed_fields:
+                                        # Serialize nested dicts/lists as JSON strings
+                                        if isinstance(sub_value, (dict, list)):
+                                            insert_item[flat_key] = json.dumps(sub_value, ensure_ascii=False)
+                                        else:
+                                            insert_item[flat_key] = sub_value
+                            elif isinstance(value, (dict, list)):
+                                continue
+                            else:
+                                insert_item[key] = value
+
+                    insert_data.append(insert_item)
+
+            if not insert_data:
+                print(" ! Warning: no valid vector data")
+                return False
+
+            self.client.insert(
+                collection_name=self.collection_name,
+                data=insert_data
+            )
+            print(f" ✓ Wrote to Milvus: {len(insert_data)} records")
+            logger.info(f"Wrote to Milvus: {len(insert_data)} records")
+            return True
+        except Exception as e:
+            print(f" ✗ Milvus write failed: {str(e)}")
+            logger.error(f"Milvus write failed: {str(e)}")
+            return False
+
+
+class LocalDestination(Destination):
+    """Local filesystem destination."""
+
+    def __init__(self, output_dir: str):
+        self.output_dir = Path(output_dir)
+        self.output_dir.mkdir(parents=True, exist_ok=True)
+        print(f"✓ Output directory: {self.output_dir}")
+        logger.info(f"Output directory: {self.output_dir}")
+
+    def write(self, data: List[Dict[str, Any]], metadata: Dict[str, Any]) -> bool:
+        try:
+            file_name = metadata.get('file_name', 'output')
+            base_name = Path(file_name).stem
+            stage = metadata.get('stage')  # distinguishes intermediate-result stages
+
+            # For intermediate results, tag the file name with the stage
+            if stage:
+                output_file = self.output_dir / f"{base_name}_{stage}.json"
+            else:
+                output_file = self.output_dir / f"{base_name}.json"
+
+            with open(output_file, 'w', encoding='utf-8') as f:
+                json.dump(data, f, ensure_ascii=False, indent=2)
+
+            print(f" ✓ Wrote local file: {output_file}")
+            logger.info(f"Wrote local file: {output_file}")
+            return True
+        except Exception as e:
+            print(f" ✗ Local write failed: {str(e)}")
+            logger.error(f"Local write failed: {str(e)}")
+            return False
+
+
+class S3Destination(Destination):
+    """S3/MinIO destination."""
+
+    def __init__(self, endpoint: str, access_key: str, secret_key: str,
+                 bucket: str, prefix: str = '', region: str = 'us-east-1'):
+        self.endpoint = endpoint
+        self.bucket = bucket
+        self.prefix = prefix.strip('/') if prefix else ''
+
+        if self.endpoint == 'https://textin-minio-api.ai.intsig.net':
+            config = Config(signature_version='s3v4')
+        elif self.endpoint.endswith('aliyuncs.com'):
+            config = Config(signature_version='s3', s3={'addressing_style': 'virtual'})
+        else:
+            config = Config(signature_version='s3v4', s3={'addressing_style': 'virtual'})
+
+        self.client = boto3.client(
+            's3',
+            endpoint_url=endpoint,
+            aws_access_key_id=access_key,
+            aws_secret_access_key=secret_key,
+            region_name=region,
+            config=config
+        )
+
+        try:
+            self.client.head_bucket(Bucket=bucket)
+            test_key = f"{self.prefix}/empty.tmp" if self.prefix else "empty.tmp"
+            self.client.put_object(
+                Bucket=bucket,
+                Key=test_key,
+                Body=b''
+            )
+            try:
+                self.client.delete_object(Bucket=bucket, Key=test_key)
+            except Exception:
+                pass
+
+            print(f"✓ S3 connection OK and writable: {endpoint}/{bucket}")
+            logger.info(f"S3 connection OK and writable: {endpoint}/{bucket}")
+        except Exception as e:
+            print(f"✗ S3 connection or write test failed: {str(e)}")
+            logger.error(f"S3 connection or write test failed: {str(e)}")
+            raise
+
+    def write(self, data: List[Dict[str, Any]], metadata: Dict[str, Any]) -> bool:
+        try:
+            file_name = metadata.get('file_name', 'output')
+            base_name = Path(file_name).stem
+            object_key = f"{self.prefix}/{base_name}.json" if self.prefix else f"{base_name}.json"
+
+            json_data = json.dumps(data, ensure_ascii=False, indent=2)
+            json_bytes = json_data.encode('utf-8')
+
+            self.client.put_object(
+                Bucket=self.bucket,
+                Key=object_key,
+                Body=json_bytes,
+                ContentType='application/json'
+            )
+
+            print(f" ✓ Wrote to S3: {self.endpoint}/{self.bucket}/{object_key}")
+            logger.info(f"Wrote to S3: {self.endpoint}/{self.bucket}/{object_key}")
+            return True
+        except Exception as e:
+            print(f" ✗ S3 write failed: {str(e)}")
+            logger.error(f"S3 write failed: {str(e)}")
+            return False
+
+
+__all__ = [
+    'Destination',
+    'MilvusDestination',
+    'LocalDestination',
+    'S3Destination',
+]
+
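All three destinations above share the write(data, metadata) contract: they return True/False rather than raising, and LocalDestination derives its output path from metadata['file_name'] plus an optional metadata['stage'] tag. A minimal sketch of that contract (paths and element values are illustrative):

    from xparse_client.pipeline.destinations import LocalDestination

    dest = LocalDestination(output_dir='./out')
    elements = [{'element_id': 'e1', 'text': 'hello', 'metadata': {'record_id': 'r1'}}]

    # Writes ./out/report_parse.json; without 'stage' it would be ./out/report.json.
    ok = dest.write(elements, {'file_name': 'report.pdf', 'stage': 'parse'})

Note that MilvusDestination.write only inserts items with a non-empty 'embeddings' field, while LocalDestination and S3Destination serialize whatever they are given.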
xparse_client/pipeline/pipeline.py
@@ -0,0 +1,440 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import json
+import logging
+import time
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Dict, Any, Optional, Tuple, List, Union
+
+import requests
+
+from .config import ParseConfig, ChunkConfig, EmbedConfig, Stage, PipelineStats, PipelineConfig
+from .sources import Source, S3Source, LocalSource, FtpSource, SmbSource
+from .destinations import Destination, MilvusDestination, LocalDestination, S3Destination
+
+
+logger = logging.getLogger(__name__)
+
+
+class Pipeline:
+    """Data-processing pipeline."""
+
+    def __init__(
+        self,
+        source: Source,
+        destination: Destination,
+        api_base_url: str = 'http://localhost:8000/api/xparse',
+        api_headers: Optional[Dict[str, str]] = None,
+        stages: Optional[List[Stage]] = None,
+        pipeline_config: Optional[PipelineConfig] = None,
+        intermediate_results_destination: Optional[Destination] = None
+    ):
+        self.source = source
+        self.destination = destination
+        self.api_base_url = api_base_url.rstrip('/')
+        self.api_headers = api_headers or {}
+        self.pipeline_config = pipeline_config or PipelineConfig()
+
+        # Handle the intermediate_results_destination parameter.
+        # If it is passed directly, prefer it and enable intermediate-result saving automatically.
+        if intermediate_results_destination is not None:
+            self.pipeline_config.include_intermediate_results = True
+            self.pipeline_config.intermediate_results_destination = intermediate_results_destination
+        # Otherwise fall back to whatever pipeline_config specifies
+        elif self.pipeline_config.include_intermediate_results:
+            if not self.pipeline_config.intermediate_results_destination:
+                raise ValueError("intermediate_results_destination must be set when include_intermediate_results is True")
+
+        # Handle the stages configuration
+        if stages is None:
+            raise ValueError("The stages parameter is required")
+
+        self.stages = stages
+
+        # Validate stages
+        if not self.stages or self.stages[0].type != 'parse':
+            raise ValueError("stages must be non-empty and the first stage must be of type 'parse'")
+
+        # Validate the embed config (if present)
+        for stage in self.stages:
+            if stage.type == 'embed' and isinstance(stage.config, EmbedConfig):
+                stage.config.validate()
+
+        # Validate intermediate_results_destination
+        if self.pipeline_config.include_intermediate_results:
+            # It must be a supported Destination type
+            from .destinations import Destination
+            if not isinstance(self.pipeline_config.intermediate_results_destination, Destination):
+                raise ValueError("intermediate_results_destination must be a Destination")
+            self.intermediate_results_destination = self.pipeline_config.intermediate_results_destination
+
+        print("=" * 60)
+        print("Pipeline initialized")
+        print(f" Stages: {[s.type for s in self.stages]}")
+        for stage in self.stages:
+            print(f" - {stage.type}: {stage.config}")
+        if self.pipeline_config.include_intermediate_results:
+            print(" Pipeline Config: intermediate-result saving enabled")
+        print("=" * 60)
+
+    def _call_pipeline_api(self, file_bytes: bytes, file_name: str, data_source: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+        url = f"{self.api_base_url}/pipeline"
+        max_retries = 3
+
+        for try_count in range(max_retries):
+            try:
+                files = {'file': (file_name or 'file', file_bytes)}
+                form_data = {}
+
+                # Convert stages to the API format
+                stages_data = [stage.to_dict() for stage in self.stages]
+                form_data['stages'] = json.dumps(stages_data)
+                form_data['data_source'] = json.dumps(data_source, ensure_ascii=False)
+
+                # If intermediate-result saving is enabled, pass it along in the request
+                if self.pipeline_config:
+                    form_data['config'] = json.dumps(self.pipeline_config.to_dict(), ensure_ascii=False)
+
+                response = requests.post(
+                    url,
+                    files=files,
+                    data=form_data,
+                    headers=self.api_headers,
+                    timeout=120
+                )
+
+                if response.status_code == 200:
+                    result = response.json()
+                    print(f" ✓ Pipeline API returned x_request_id: {result.get('x_request_id')}")
+                    if result.get('code') == 200 and 'data' in result:
+                        return result.get('data')
+                    return None
+                else:
+                    print(f" ! API error {response.status_code}, retry {try_count + 1}/{max_retries}")
+                    logger.warning(f"API error {response.status_code}: pipeline")
+
+            except Exception as e:
+                print(f" ! Request exception: {str(e)}, retry {try_count + 1}/{max_retries}")
+                logger.error(f"Pipeline API request exception: {str(e)}")
+
+            if try_count < max_retries - 1:
+                time.sleep(2)
+
+        return None
+
+    def process_with_pipeline(self, file_bytes: bytes, file_name: str, data_source: Dict[str, Any]) -> Optional[Tuple[List[Dict[str, Any]], PipelineStats]]:
+        print(f" → Calling pipeline API: {file_name}")
+        result = self._call_pipeline_api(file_bytes, file_name, data_source)
+
+        if result and 'elements' in result and 'stats' in result:
+            elements = result['elements']
+            stats_data = result['stats']
+
+            stats = PipelineStats(
+                original_elements=stats_data.get('original_elements', 0),
+                chunked_elements=stats_data.get('chunked_elements', 0),
+                embedded_elements=stats_data.get('embedded_elements', 0),
+                stages=self.stages  # the stages that were actually executed
+            )
+
+            # If intermediate-result saving is enabled, handle the intermediate results
+            if self.pipeline_config.include_intermediate_results and 'intermediate_results' in result:
+                self._save_intermediate_results(result['intermediate_results'], file_name, data_source)
+
+            print(" ✓ Pipeline finished:")
+            print(f" - original elements: {stats.original_elements}")
+            print(f" - after chunking: {stats.chunked_elements}")
+            print(f" - embedded: {stats.embedded_elements}")
+            logger.info(f"Pipeline finished: {file_name}, {stats.embedded_elements} vectors")
+
+            return elements, stats
+        else:
+            print(" ✗ Pipeline failed")
+            logger.error(f"Pipeline failed: {file_name}")
+            return None
+
+    def _save_intermediate_results(self, intermediate_results: List[Dict[str, Any]], file_name: str, data_source: Dict[str, Any]) -> None:
+        """Save intermediate results.
+
+        Args:
+            intermediate_results: array of intermediate results; each item has stage and elements fields
+            file_name: file name
+            data_source: data-source info
+        """
+        try:
+            # intermediate_results is an array whose items are {stage: str, elements: List}
+            for result_item in intermediate_results:
+                if 'stage' not in result_item or 'elements' not in result_item:
+                    logger.warning(f"Intermediate result item lacks a stage or elements field: {result_item}")
+                    continue
+
+                stage = result_item['stage']
+                elements = result_item['elements']
+
+                metadata = {
+                    'file_name': file_name,
+                    'stage': stage,
+                    'total_elements': len(elements),
+                    'processed_at': datetime.now().isoformat(),
+                    'data_source': data_source
+                }
+
+                self.pipeline_config.intermediate_results_destination.write(elements, metadata)
+                print(f" ✓ Saved {stage.upper()} intermediate results: {len(elements)} elements")
+                logger.info(f"Saved {stage.upper()} intermediate results: {file_name}")
+
+        except Exception as e:
+            print(f" ✗ Failed to save intermediate results: {str(e)}")
+            logger.error(f"Failed to save intermediate results: {file_name}, {str(e)}")
+
+    def process_file(self, file_path: str) -> bool:
+        print(f"\n{'=' * 60}")
+        print(f"Processing file: {file_path}")
+        logger.info(f"Start processing file: {file_path}")
+
+        try:
+            print(" → Reading file...")
+            file_bytes, data_source = self.source.read_file(file_path)
+            data_source = data_source or {}
+            data_source['date_processed'] = datetime.now(timezone.utc).timestamp()
+            print(f" ✓ File read: {len(file_bytes)} bytes")
+
+            result = self.process_with_pipeline(file_bytes, file_path, data_source)
+            if not result:
+                return False
+
+            embedded_data, stats = result
+
+            print(" → Writing to destination...")
+            metadata = {
+                'file_name': file_path,
+                'total_elements': len(embedded_data),
+                'processed_at': datetime.now().isoformat(),
+                'data_source': data_source,
+                'stats': {
+                    'original_elements': stats.original_elements,
+                    'chunked_elements': stats.chunked_elements,
+                    'embedded_elements': stats.embedded_elements
+                }
+            }
+
+            success = self.destination.write(embedded_data, metadata)
+
+            if success:
+                print(f"\n✓✓✓ File processed successfully: {file_path}")
+                logger.info(f"File processed successfully: {file_path}")
+            else:
+                print(f"\n✗✗✗ File processing failed: {file_path}")
+                logger.error(f"File processing failed: {file_path}")
+
+            return success
+
+        except Exception as e:
+            print(f"\n✗✗✗ Processing exception: {str(e)}")
+            logger.error(f"Exception while processing file {file_path}: {str(e)}")
+            return False
+
+    def run(self):
+        start_time = time.time()
+
+        print("\n" + "=" * 60)
+        print("Starting pipeline run")
+        print("=" * 60)
+        logger.info("=" * 60)
+        logger.info("Starting pipeline run")
+
+        print("\n→ Listing files...")
+        files = self.source.list_files()
+
+        if not files:
+            print("\n✗ No files found")
+            logger.info("No files found")
+            return
+
+        total = len(files)
+        success_count = 0
+        fail_count = 0
+
+        for idx, file_path in enumerate(files, 1):
+            print(f"\nProgress: [{idx}/{total}]")
+            logger.info(f"Progress: [{idx}/{total}] - {file_path}")
+
+            try:
+                if self.process_file(file_path):
+                    success_count += 1
+                else:
+                    fail_count += 1
+            except Exception as e:
+                print(f"\n✗✗✗ File-processing exception: {str(e)}")
+                logger.error(f"File-processing exception {file_path}: {str(e)}")
+                fail_count += 1
+
+            if idx < total:
+                time.sleep(1)
+
+        elapsed = time.time() - start_time
+        print("\n" + "=" * 60)
+        print("Pipeline run complete!")
+        print("=" * 60)
+        print(f"Total files: {total}")
+        print(f"Succeeded: {success_count}")
+        print(f"Failed: {fail_count}")
+        print(f"Elapsed: {elapsed:.2f} s")
+        print("=" * 60)
+
+        logger.info("=" * 60)
+        logger.info(f"Pipeline complete - total: {total}, succeeded: {success_count}, failed: {fail_count}, elapsed: {elapsed:.2f}s")
+        logger.info("=" * 60)
+
+
+def create_pipeline_from_config(config: Dict[str, Any]) -> Pipeline:
+    source_config = config['source']
+    if source_config['type'] == 's3':
+        source = S3Source(
+            endpoint=source_config['endpoint'],
+            access_key=source_config['access_key'],
+            secret_key=source_config['secret_key'],
+            bucket=source_config['bucket'],
+            prefix=source_config.get('prefix', ''),
+            region=source_config.get('region', 'us-east-1'),
+            pattern=source_config.get('pattern', '*')
+        )
+    elif source_config['type'] == 'local':
+        source = LocalSource(
+            directory=source_config['directory'],
+            pattern=source_config.get('pattern', '*')
+        )
+    elif source_config['type'] == 'ftp':
+        source = FtpSource(
+            host=source_config['host'],
+            port=source_config['port'],
+            username=source_config['username'],
+            password=source_config['password'],
+            pattern=source_config.get('pattern', '*')
+        )
+    elif source_config['type'] == 'smb':
+        source = SmbSource(
+            host=source_config['host'],
+            share_name=source_config['share_name'],
+            username=source_config['username'],
+            password=source_config['password'],
+            domain=source_config.get('domain', ''),
+            port=source_config.get('port', 445),
+            path=source_config.get('path', ''),
+            pattern=source_config.get('pattern', '*')
+        )
+    else:
+        raise ValueError(f"Unknown source type: {source_config['type']}")
+
+    dest_config = config['destination']
+    if dest_config['type'] in ['milvus', 'zilliz']:
+        destination = MilvusDestination(
+            db_path=dest_config['db_path'],
+            collection_name=dest_config['collection_name'],
+            dimension=dest_config['dimension'],
+            api_key=dest_config.get('api_key'),
+            token=dest_config.get('token')
+        )
+    elif dest_config['type'] == 'local':
+        destination = LocalDestination(
+            output_dir=dest_config['output_dir']
+        )
+    elif dest_config['type'] == 's3':
+        destination = S3Destination(
+            endpoint=dest_config['endpoint'],
+            access_key=dest_config['access_key'],
+            secret_key=dest_config['secret_key'],
+            bucket=dest_config['bucket'],
+            prefix=dest_config.get('prefix', ''),
+            region=dest_config.get('region', 'us-east-1')
+        )
+    else:
+        raise ValueError(f"Unknown destination type: {dest_config['type']}")
+
+    # Handle the stages configuration
+    if 'stages' not in config or not config['stages']:
+        raise ValueError("The config must contain a 'stages' field")
+
+    stages = []
+    for stage_cfg in config['stages']:
+        stage_type = stage_cfg.get('type')
+        stage_config_dict = stage_cfg.get('config', {})
+
+        if stage_type == 'parse':
+            parse_cfg_copy = dict(stage_config_dict)
+            provider = parse_cfg_copy.pop('provider', 'textin')
+            stage_config = ParseConfig(provider=provider, **parse_cfg_copy)
+        elif stage_type == 'chunk':
+            stage_config = ChunkConfig(
+                strategy=stage_config_dict.get('strategy', 'basic'),
+                include_orig_elements=stage_config_dict.get('include_orig_elements', False),
+                new_after_n_chars=stage_config_dict.get('new_after_n_chars', 512),
+                max_characters=stage_config_dict.get('max_characters', 1024),
+                overlap=stage_config_dict.get('overlap', 0),
+                overlap_all=stage_config_dict.get('overlap_all', False)
+            )
+        elif stage_type == 'embed':
+            stage_config = EmbedConfig(
+                provider=stage_config_dict.get('provider', 'qwen'),
+                model_name=stage_config_dict.get('model_name', 'text-embedding-v3')
+            )
+        else:
+            raise ValueError(f"Unknown stage type: {stage_type}")
+
+        stages.append(Stage(type=stage_type, config=stage_config))
+
+    # Build the pipeline config
+    pipeline_config = None
+    if 'pipeline_config' in config and config['pipeline_config']:
+        pipeline_cfg = config['pipeline_config']
+        include_intermediate_results = pipeline_cfg.get('include_intermediate_results', False)
+        intermediate_results_destination = None
+
+        if include_intermediate_results:
+            if 'intermediate_results_destination' in pipeline_cfg:
+                dest_cfg = pipeline_cfg['intermediate_results_destination']
+                dest_type = dest_cfg.get('type')
+
+                if dest_type == 'local':
+                    intermediate_results_destination = LocalDestination(
+                        output_dir=dest_cfg['output_dir']
+                    )
+                elif dest_type == 's3':
+                    intermediate_results_destination = S3Destination(
+                        endpoint=dest_cfg['endpoint'],
+                        access_key=dest_cfg['access_key'],
+                        secret_key=dest_cfg['secret_key'],
+                        bucket=dest_cfg['bucket'],
+                        prefix=dest_cfg.get('prefix', ''),
+                        region=dest_cfg.get('region', 'us-east-1')
+                    )
+                else:
+                    raise ValueError(f"Unsupported intermediate_results_destination type: '{dest_type}'; supported types: 'local', 's3'")
+            else:
+                raise ValueError("intermediate_results_destination must be set when include_intermediate_results is True")
+
+        pipeline_config = PipelineConfig(
+            include_intermediate_results=include_intermediate_results,
+            intermediate_results_destination=intermediate_results_destination
+        )
+
+    # Create the pipeline
+    pipeline = Pipeline(
+        source=source,
+        destination=destination,
+        api_base_url=config.get('api_base_url', 'http://localhost:8000/api/xparse'),
+        api_headers=config.get('api_headers', {}),
+        stages=stages,
+        pipeline_config=pipeline_config
+    )

+    return pipeline
+
+
+__all__ = [
+    'Pipeline',
+    'create_pipeline_from_config',
+]
+
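For reference, a config dict that exercises create_pipeline_from_config end to end. Every key below is read by the factory above; the header name, credentials, and dimension are placeholders, not values taken from this package:

    from xparse_client.pipeline.pipeline import create_pipeline_from_config

    config = {
        'api_base_url': 'http://localhost:8000/api/xparse',
        'api_headers': {'Authorization': 'Bearer <token>'},  # placeholder header
        'source': {'type': 'local', 'directory': './docs', 'pattern': '*.pdf'},
        'destination': {
            'type': 'milvus',
            'db_path': './milvus.db',
            'collection_name': 'xparse_demo',
            'dimension': 1024,  # must match the embedding model's output size
        },
        'stages': [
            {'type': 'parse', 'config': {'provider': 'textin'}},
            {'type': 'chunk', 'config': {'strategy': 'basic', 'max_characters': 1024}},
            {'type': 'embed', 'config': {'provider': 'qwen', 'model_name': 'text-embedding-v3'}},
        ],
        'pipeline_config': {
            'include_intermediate_results': True,
            'intermediate_results_destination': {'type': 'local', 'output_dir': './intermediate'},
        },
    }

    pipeline = create_pipeline_from_config(config)
    pipeline.run()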