xparse-client 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
example/run_pipeline.py CHANGED
@@ -96,8 +96,8 @@ def run_with_manual_setup():
  # )
  # source = S3Source(
  # endpoint='https://s3.oss-cn-shanghai.aliyuncs.com',
- # access_key='',
- # secret_key='',
+ # access_key='LTAI5tBgsaVfkbh9rbPyuB17',
+ # secret_key='JFIIaTGiXelv7DgBYNIBSStofF0S98',
  # bucket='textin',
  # prefix='',
  # region='cn-shanghai',
@@ -113,8 +113,8 @@ def run_with_manual_setup():
  # )
  # source = S3Source(
  # endpoint='https://tos-s3-cn-shanghai.volces.com',
- # access_key='',
- # secret_key='',
+ # access_key='AKLTMzNkZjk1OGM3MzBjNGQ1ZjhkMGQ4MThlNjBjYjZjYzA',
+ # secret_key='TnpWaE0yRTVaamRqTmpSbU5EY3pObUZrTTJVNE5qUm1NR0ppWkRrMFlqVQ==',
  # bucket='textin',
  # prefix='',
  # region='cn-shanghai'
@@ -127,14 +127,14 @@ def run_with_manual_setup():
  # prefix='',
  # region='cn-east-3'
  # )
- # source = S3Source(
- # endpoint='https://s3.us-east-1.amazonaws.com',
- # access_key='',
- # secret_key='',
- # bucket='textin-xparse',
- # prefix='',
- # region='us-east-1'
- # )
+ source = S3Source(
+ endpoint='https://s3.us-east-1.amazonaws.com',
+ access_key='AKIA6QUE3TVZADUWA4PO',
+ secret_key='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',
+ bucket='textin-xparse',
+ prefix='',
+ region='us-east-1'
+ )
  # source = S3Source(
  # endpoint='http://127.0.0.1:9000',
  # access_key='',
@@ -153,13 +153,14 @@ def run_with_manual_setup():
  # source = FtpSource(
  # host='127.0.0.1',
  # port=21,
+ # # recursive=True,
  # username='', # user name, fill in as appropriate
  # password='' # password, fill in as appropriate
  # )
- source = LocalSource(
- directory='/Users/ke_wang/Documents/doc',
- pattern='*.pdf' # wildcards supported: *.pdf, *.docx, **/*.txt
- )
+ # source = LocalSource(
+ # directory='/Users/ke_wang/Documents/doc',
+ # pattern='*.pdf' # wildcards supported: *.pdf, *.docx, **/*.txt
+ # )

  # Create the Milvus destination
  # destination = MilvusDestination(
@@ -174,7 +175,7 @@ def run_with_manual_setup():

  destination = MilvusDestination(
  db_path='https://in03-5388093d0db1707.serverless.ali-cn-hangzhou.cloud.zilliz.com.cn', # Zilliz connection address
- collection_name='textin_test_2', # collection name
+ collection_name='textin_test_3', # collection name
  dimension=1024, # vector dimension, must match the embed API output
  api_key='872c3f5b3f3995c80dcda5c3d34f1f608815aef7671b6ee391ab37e40e79c892ce56d9c8c6565a03a3fd66da7e11b67f384c5c46' # Zilliz Cloud API Key
  )
xparse_client/pipeline/config.py CHANGED
@@ -98,6 +98,7 @@ class PipelineStats:
  chunked_elements: int = 0
  embedded_elements: int = 0
  stages: Optional[List[Stage]] = None # stores the stages that were actually executed
+ record_id: Optional[str] = None # record ID, identifies the records to be written to Milvus


  @dataclass
xparse_client/pipeline/destinations.py CHANGED
@@ -18,6 +18,41 @@ from pymilvus import MilvusClient
  logger = logging.getLogger(__name__)


+ def _flatten_dict(data: Dict[str, Any], prefix: str = '', fixed_fields: set = None) -> Dict[str, Any]:
+ """Recursively flatten a nested dict
+
+ Args:
+ data: the dict to flatten
+ prefix: prefix prepended to the keys
+ fixed_fields: set of field names to exclude
+
+ Returns:
+ the flattened dict
+ """
+ if fixed_fields is None:
+ fixed_fields = set()
+
+ result = {}
+ for key, value in data.items():
+ flat_key = f'{prefix}_{key}' if prefix else key
+
+ if flat_key in fixed_fields:
+ continue
+
+ if isinstance(value, dict):
+ # recursively flatten nested dicts
+ nested = _flatten_dict(value, flat_key, fixed_fields)
+ result.update(nested)
+ elif isinstance(value, list):
+ # lists are serialized to JSON strings
+ result[flat_key] = json.dumps(value, ensure_ascii=False)
+ else:
+ # other value types are stored as-is
+ result[flat_key] = value
+
+ return result
+
+
  class Destination(ABC):
  """Abstract base class for data destinations"""

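For reference, a minimal sketch of what the new `_flatten_dict` helper produces for a nested `data_source` dict. The sample metadata is hypothetical, and importing the private helper assumes xparse-client 0.2.6 is installed:

```python
# Illustrative only: sample values are made up; the import assumes the
# private helper shipped in 0.2.6 is available.
from xparse_client.pipeline.destinations import _flatten_dict

meta = {
    'record_locator': {'protocol': 's3', 'remote_file_path': 'docs/a.pdf'},
    'permissions': ['read', 'write'],  # lists are serialized to JSON strings
    'version': '1700000000000',
}

print(_flatten_dict(meta, prefix='data_source'))
# {'data_source_record_locator_protocol': 's3',
#  'data_source_record_locator_remote_file_path': 'docs/a.pdf',
#  'data_source_permissions': '["read", "write"]',
#  'data_source_version': '1700000000000'}
```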
@@ -54,8 +89,7 @@ class MilvusDestination(Destination):
  schema.add_field(field_name="element_id", datatype=DataType.VARCHAR, max_length=128, is_primary=True)
  schema.add_field(field_name="embeddings", datatype=DataType.FLOAT_VECTOR, dim=dimension)
  schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=65535)
- schema.add_field(field_name="record_id", datatype=DataType.VARCHAR, max_length=128)
- schema.add_field(field_name="metadata", datatype=DataType.JSON)
+ schema.add_field(field_name="record_id", datatype=DataType.VARCHAR, max_length=200)

  index_params = self.client.prepare_index_params()
  index_params.add_index(
@@ -77,6 +111,32 @@

  def write(self, data: List[Dict[str, Any]], metadata: Dict[str, Any]) -> bool:
  try:
+ # if metadata carries a record_id, first delete existing records with the same record_id
+ record_id = metadata.get('record_id')
+ if record_id:
+ try:
+ # delete all records with the same record_id
+ # MilvusClient.delete returns the number of deleted records (may be an int or a dict)
+ result = self.client.delete(
+ collection_name=self.collection_name,
+ filter=f'record_id == "{record_id}"'
+ )
+ # handle the return value: it may be a number or a dict
+ deleted_count = result if isinstance(result, int) else result.get('delete_count', 0) if isinstance(result, dict) else 0
+ if deleted_count > 0:
+ print(f" ✓ 删除现有记录: record_id={record_id}, 删除 {deleted_count} 条")
+ logger.info(f"删除 Milvus 现有记录: record_id={record_id}, 删除 {deleted_count} 条")
+ else:
+ print(f" → 未找到现有记录: record_id={record_id}")
+ except Exception as e:
+ print(f" ! 删除现有记录失败: {str(e)}")
+ logger.warning(f"删除 Milvus 现有记录失败: record_id={record_id}, {str(e)}")
+ # continue with the write; do not abort because the delete failed
+ else:
+ print(f" → 没有 record_id")
+ logger.warning(f"没有 record_id")
+ return
+
  insert_data = []
  for item in data:
  # get element-level metadata
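To make the new upsert behaviour concrete, here is a minimal delete-then-insert sketch against `pymilvus.MilvusClient`. The collection name, record_id and vector are placeholders, and the collection is assumed to already exist with the schema shown above:

```python
from pymilvus import MilvusClient

# Placeholder values; assumes a collection with element_id/embeddings/text/record_id already exists.
client = MilvusClient(uri="./milvus_pipeline.db")
record_id = "doc-123"

# Drop any rows previously written for this record_id ...
res = client.delete(collection_name="my_collection", filter=f'record_id == "{record_id}"')
deleted = res if isinstance(res, int) else res.get("delete_count", 0)
print(f"deleted {deleted} stale rows")

# ... then write the fresh rows.
client.insert(
    collection_name="my_collection",
    data=[{
        "element_id": f"{record_id}-0000",
        "embeddings": [0.0] * 1024,
        "text": "hello world",
        "record_id": record_id,
    }],
)
```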
@@ -90,8 +150,7 @@
  'embeddings': item['embeddings'],
  'text': item.get('text', ''),
  'element_id': element_id,
- 'record_id': element_metadata.get('record_id', ''),
- 'created_at': datetime.now().isoformat()
+ 'record_id': record_id
  }

  # merge file-level metadata with element-level metadata
@@ -103,17 +162,13 @@
  fixed_fields = {'embeddings', 'text', 'element_id', 'record_id', 'created_at', 'metadata'}
  for key, value in merged_metadata.items():
  if key not in fixed_fields:
- # special-case the data_source field: flatten it if it is a dict
+ # special-case the data_source field: recursively flatten it if it is a dict
  if key == 'data_source' and isinstance(value, dict):
- # flatten the data_source dict into data_source_* keys
- for sub_key, sub_value in value.items():
- flat_key = f'data_source_{sub_key}'
- if flat_key not in fixed_fields:
- # if the sub-value is itself a dict or list, serialize it to a JSON string
- if isinstance(sub_value, (dict, list)):
- insert_item[flat_key] = json.dumps(sub_value, ensure_ascii=False)
- else:
- insert_item[flat_key] = sub_value
+ # recursively flatten the data_source dict, including nested dicts
+ flattened = _flatten_dict(value, 'data_source', fixed_fields)
+ insert_item.update(flattened)
+ elif key == 'coordinates' and isinstance(value, list):
+ insert_item[key] = value
  elif isinstance(value, (dict, list)):
  continue
  else:
@@ -149,8 +204,8 @@ class LocalDestination(Destination):

  def write(self, data: List[Dict[str, Any]], metadata: Dict[str, Any]) -> bool:
  try:
- file_name = metadata.get('file_name', 'output')
- base_name = Path(file_name).stem
+ filename = metadata.get('filename', 'output')
+ base_name = Path(filename).stem
  stage = metadata.get('stage') # distinguishes the stage of an intermediate result

  # for intermediate results, add a stage marker to the file name
@@ -218,8 +273,8 @@ class S3Destination(Destination):

  def write(self, data: List[Dict[str, Any]], metadata: Dict[str, Any]) -> bool:
  try:
- file_name = metadata.get('file_name', 'output')
- base_name = Path(file_name).stem
+ filename = metadata.get('filename', 'output')
+ base_name = Path(filename).stem
  object_key = f"{self.prefix}/{base_name}.json" if self.prefix else f"{base_name}.json"

  json_data = json.dumps(data, ensure_ascii=False, indent=2)
xparse_client/pipeline/pipeline.py CHANGED
@@ -79,13 +79,13 @@ class Pipeline:
  print(f" Pipeline Config: 中间结果保存已启用")
  print("=" * 60)

- def _call_pipeline_api(self, file_bytes: bytes, file_name: str, data_source: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+ def _call_pipeline_api(self, file_bytes: bytes, filename: str, data_source: Dict[str, Any]) -> Optional[Dict[str, Any]]:
  url = f"{self.api_base_url}/pipeline"
  max_retries = 3

  for try_count in range(max_retries):
  try:
- files = {'file': (file_name or 'file', file_bytes)}
+ files = {'file': (filename or 'file', file_bytes)}
  form_data = {}

  # convert stages into the API format
@@ -107,26 +107,55 @@ class Pipeline:

  if response.status_code == 200:
  result = response.json()
- print(f" ✓ Pipeline 接口返回 x_request_id: {result.get('x_request_id')}")
+ x_request_id = result.get('x_request_id', '')
+ print(f" ✓ Pipeline 接口返回 x_request_id: {x_request_id}")
  if result.get('code') == 200 and 'data' in result:
  return result.get('data')
+ # if code is not 200, print the error information
+ error_msg = result.get('message', result.get('msg', '未知错误'))
+ print(f" ✗ Pipeline 接口返回错误: code={result.get('code')}, message={error_msg}, x_request_id={x_request_id}")
+ logger.error(f"Pipeline 接口返回错误: code={result.get('code')}, message={error_msg}, x_request_id={x_request_id}")
  return None
  else:
- print(f" ! API 错误 {response.status_code}, 重试 {try_count + 1}/{max_retries}")
- logger.warning(f"API 错误 {response.status_code}: pipeline")
+ # try to parse the response for x_request_id and the error message
+ x_request_id = ''
+ error_msg = ''
+ try:
+ result = response.json()
+ x_request_id = result.get('x_request_id', '')
+ error_msg = result.get('message', result.get('msg', response.text[:200]))
+ except:
+ error_msg = response.text[:200] if response.text else f'HTTP {response.status_code}'
+
+ print(f" ✗ API 错误 {response.status_code}: {error_msg}, x_request_id={x_request_id}, 重试 {try_count + 1}/{max_retries}")
+ logger.warning(f"API 错误 {response.status_code}: {error_msg}, x_request_id={x_request_id}, 重试 {try_count + 1}/{max_retries}")

  except Exception as e:
- print(f" ! 请求异常: {str(e)}, 重试 {try_count + 1}/{max_retries}")
- logger.error(f"API 请求异常 pipeline: {str(e)}")
+ # for requests exceptions, try to read x_request_id from the response
+ x_request_id = ''
+ error_msg = str(e)
+ try:
+ if hasattr(e, 'response') and e.response is not None:
+ try:
+ result = e.response.json()
+ x_request_id = result.get('x_request_id', '')
+ error_msg = result.get('message', result.get('msg', error_msg))
+ except:
+ pass
+ except:
+ pass
+
+ print(f" ✗ 请求异常: {error_msg}, x_request_id={x_request_id}, 重试 {try_count + 1}/{max_retries}")
+ logger.error(f"API 请求异常 pipeline: {error_msg}, x_request_id={x_request_id}")

  if try_count < max_retries - 1:
  time.sleep(2)

  return None

- def process_with_pipeline(self, file_bytes: bytes, file_name: str, data_source: Dict[str, Any]) -> Optional[Tuple[List[Dict[str, Any]], PipelineStats]]:
- print(f" → 调用 Pipeline 接口: {file_name}")
- result = self._call_pipeline_api(file_bytes, file_name, data_source)
+ def process_with_pipeline(self, file_bytes: bytes, filename: str, data_source: Dict[str, Any]) -> Optional[Tuple[List[Dict[str, Any]], PipelineStats]]:
+ print(f" → 调用 Pipeline 接口: {filename}")
+ result = self._call_pipeline_api(file_bytes, filename, data_source)

  if result and 'elements' in result and 'stats' in result:
  elements = result['elements']
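The retry branches above mostly differ in how they dig `x_request_id` out of an error response. A standalone sketch of that pattern with `requests` (URL and payload are placeholders, and the response shape is assumed from the diff):

```python
import requests

resp = requests.post("https://api.example.com/pipeline",
                     files={"file": ("a.pdf", b"%PDF-...")})
if resp.status_code != 200:
    try:
        body = resp.json()
        x_request_id = body.get("x_request_id", "")
        error_msg = body.get("message", body.get("msg", resp.text[:200]))
    except ValueError:  # body was not JSON
        x_request_id, error_msg = "", resp.text[:200]
    print(f"API error {resp.status_code}: {error_msg}, x_request_id={x_request_id}")
```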
@@ -136,31 +165,32 @@ class Pipeline:
  original_elements=stats_data.get('original_elements', 0),
  chunked_elements=stats_data.get('chunked_elements', 0),
  embedded_elements=stats_data.get('embedded_elements', 0),
- stages=self.stages # use the stages that were actually executed
+ stages=self.stages, # use the stages that were actually executed
+ record_id=stats_data.get('record_id') # take the record_id from the API response
  )

  # if saving intermediate results is enabled, handle them
  if self.pipeline_config.include_intermediate_results and 'intermediate_results' in result:
- self._save_intermediate_results(result['intermediate_results'], file_name, data_source)
+ self._save_intermediate_results(result['intermediate_results'], filename, data_source)

  print(f" ✓ Pipeline 完成:")
  print(f" - 原始元素: {stats.original_elements}")
  print(f" - 分块后: {stats.chunked_elements}")
  print(f" - 向量化: {stats.embedded_elements}")
- logger.info(f"Pipeline 完成: {file_name}, {stats.embedded_elements} 个向量")
+ logger.info(f"Pipeline 完成: {filename}, {stats.embedded_elements} 个向量")

  return elements, stats
  else:
  print(f" ✗ Pipeline 失败")
- logger.error(f"Pipeline 失败: {file_name}")
+ logger.error(f"Pipeline 失败: {filename}")
  return None

- def _save_intermediate_results(self, intermediate_results: List[Dict[str, Any]], file_name: str, data_source: Dict[str, Any]) -> None:
+ def _save_intermediate_results(self, intermediate_results: List[Dict[str, Any]], filename: str, data_source: Dict[str, Any]) -> None:
  """Save intermediate results

  Args:
  intermediate_results: array of intermediate results; each item contains stage and elements fields
- file_name: file name
+ filename: file name
  data_source: data source information
  """
  try:
@@ -174,7 +204,7 @@ class Pipeline:
  elements = result_item['elements']

  metadata = {
- 'file_name': file_name,
+ 'filename': filename,
  'stage': stage,
  'total_elements': len(elements),
  'processed_at': datetime.now().isoformat(),
@@ -183,11 +213,11 @@ class Pipeline:

  self.pipeline_config.intermediate_results_destination.write(elements, metadata)
  print(f" ✓ 保存 {stage.upper()} 中间结果: {len(elements)} 个元素")
- logger.info(f"保存 {stage.upper()} 中间结果成功: {file_name}")
+ logger.info(f"保存 {stage.upper()} 中间结果成功: {filename}")

  except Exception as e:
  print(f" ✗ 保存中间结果失败: {str(e)}")
- logger.error(f"保存中间结果失败: {file_name}, {str(e)}")
+ logger.error(f"保存中间结果失败: {filename}, {str(e)}")

  def process_file(self, file_path: str) -> bool:
  print(f"\n{'=' * 60}")
@@ -198,7 +228,10 @@ class Pipeline:
  print(f" → 读取文件...")
  file_bytes, data_source = self.source.read_file(file_path)
  data_source = data_source or {}
- data_source['date_processed'] = datetime.now(timezone.utc).timestamp()
+ print(f" data_source: {data_source}")
+ # convert to a millisecond-timestamp string
+ timestamp_ms = int(datetime.now(timezone.utc).timestamp() * 1000)
+ data_source['date_processed'] = str(timestamp_ms)
  print(f" ✓ 文件读取完成: {len(file_bytes)} bytes")

  result = self.process_with_pipeline(file_bytes, file_path, data_source)
@@ -209,16 +242,13 @@ class Pipeline:

  print(f" → 写入目的地...")
  metadata = {
- 'file_name': file_path,
- 'total_elements': len(embedded_data),
- 'processed_at': datetime.now().isoformat(),
- 'data_source': data_source,
- 'stats': {
- 'original_elements': stats.original_elements,
- 'chunked_elements': stats.chunked_elements,
- 'embedded_elements': stats.embedded_elements
- }
+ 'filename': file_path,
+ 'processed_at': str(timestamp_ms),
  }
+
+ # if stats carries a record_id, add it to the metadata
+ if stats.record_id:
+ metadata['record_id'] = stats.record_id

  success = self.destination.write(embedded_data, metadata)

@@ -299,12 +329,14 @@ def create_pipeline_from_config(config: Dict[str, Any]) -> Pipeline:
  bucket=source_config['bucket'],
  prefix=source_config.get('prefix', ''),
  region=source_config.get('region', 'us-east-1'),
- pattern=source_config.get('pattern', '*')
+ pattern=source_config.get('pattern', '*'),
+ recursive=source_config.get('recursive', False)
  )
  elif source_config['type'] == 'local':
  source = LocalSource(
  directory=source_config['directory'],
- pattern=source_config.get('pattern', '*')
+ pattern=source_config.get('pattern', '*'),
+ recursive=source_config.get('recursive', False)
  )
  elif source_config['type'] == 'ftp':
  source = FtpSource(
@@ -312,7 +344,8 @@ def create_pipeline_from_config(config: Dict[str, Any]) -> Pipeline:
  port=source_config['port'],
  username=source_config['username'],
  password=source_config['password'],
- pattern=source_config.get('pattern', '*')
+ pattern=source_config.get('pattern', '*'),
+ recursive=source_config.get('recursive', False)
  )
  elif source_config['type'] == 'smb':
  source = SmbSource(
@@ -323,7 +356,8 @@ def create_pipeline_from_config(config: Dict[str, Any]) -> Pipeline:
  domain=source_config.get('domain', ''),
  port=source_config.get('port', 445),
  path=source_config.get('path', ''),
- pattern=source_config.get('pattern', '*')
+ pattern=source_config.get('pattern', '*'),
+ recursive=source_config.get('recursive', False)
  )
  else:
  raise ValueError(f"未知的 source 类型: {source_config['type']}")
xparse_client/pipeline/sources.py CHANGED
@@ -20,6 +20,30 @@ from botocore.config import Config
  logger = logging.getLogger(__name__)


+ def _to_millis_timestamp_string(timestamp):
+ """Convert a timestamp to a millisecond-timestamp string
+
+ Args:
+ timestamp: timestamp (seconds or milliseconds); may be int, float or None
+
+ Returns:
+ str: millisecond-timestamp string, or an empty string if the input is None
+ """
+ if timestamp is None:
+ return ""
+
+ # if it is already a millisecond timestamp (greater than 1e12), convert directly
+ if isinstance(timestamp, (int, float)):
+ if timestamp > 1e12:
+ # already a millisecond timestamp
+ return str(int(timestamp))
+ else:
+ # a seconds timestamp; convert to milliseconds
+ return str(int(timestamp * 1000))
+
+ return str(timestamp)
+
+
  class Source(ABC):
  """Abstract base class for data sources"""

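Expected behaviour of the new timestamp normaliser, assuming the private helper can be imported from the 0.2.6 sources module:

```python
from xparse_client.pipeline.sources import _to_millis_timestamp_string

print(_to_millis_timestamp_string(None))           # ""
print(_to_millis_timestamp_string(1700000000))     # "1700000000000"  (seconds -> ms)
print(_to_millis_timestamp_string(1700000000.5))   # "1700000000500"
print(_to_millis_timestamp_string(1700000000000))  # "1700000000000"  (already ms)
```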
@@ -38,11 +62,12 @@ class S3Source(Source):
  """S3/MinIO data source"""

  def __init__(self, endpoint: str, access_key: str, secret_key: str,
- bucket: str, prefix: str = '', region: str = 'us-east-1', pattern: str = '*'):
+ bucket: str, prefix: str = '', region: str = 'us-east-1', pattern: str = '*', recursive: bool = False):
  self.endpoint = endpoint
  self.bucket = bucket
  self.prefix = prefix
  self.pattern = pattern or '*'
+ self.recursive = recursive

  if self.endpoint == 'https://textin-minio-api.ai.intsig.net':
  config = Config(signature_version='s3v4')
@@ -73,8 +98,12 @@
  params = {'Bucket': self.bucket}
  if self.prefix:
  params['Prefix'] = self.prefix
+ if not self.recursive:
+ # non-recursive mode: use Delimiter to list only files in the current directory
+ params['Delimiter'] = '/'

  for page in paginator.paginate(**params):
+ print(page)
  if 'Contents' in page:
  for obj in page['Contents']:
  key = obj['Key']
@@ -82,6 +111,11 @@
  continue
  if fnmatch(key, self.pattern):
  files.append(key)
+
+ # in non-recursive mode, CommonPrefixes holds the sub-directories; we ignore them
+ if not self.recursive and 'CommonPrefixes' in page:
+ # these are sub-directories, skipped in non-recursive mode
+ pass

  print(f"✓ S3 找到 {len(files)} 个文件")
  return files
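The non-recursive path relies on standard S3 listing semantics: with `Delimiter='/'` only keys directly under the prefix come back in `Contents`, while sub-directories are reported as `CommonPrefixes`. A standalone boto3 sketch (bucket and prefix are placeholders):

```python
import boto3

s3 = boto3.client("s3")
paginator = s3.get_paginator("list_objects_v2")

for page in paginator.paginate(Bucket="my-bucket", Prefix="docs/", Delimiter="/"):
    for obj in page.get("Contents", []):
        print("file:", obj["Key"])            # only keys directly under docs/
    for sub in page.get("CommonPrefixes", []):
        print("skipped dir:", sub["Prefix"])  # sub-directories, ignored when recursive=False
```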
@@ -91,7 +125,9 @@
  file_bytes = response['Body'].read()

  headers = response.get('ResponseMetadata', {}).get('HTTPHeaders', {})
- version = response.get('VersionId') or headers.get('x-amz-version-id')
+ version = headers.get('etag') or ""
+ if version.startswith('"') and version.endswith('"'):
+ version = version[1:-1]
  last_modified = headers.get('last-modified')
  server = headers.get('server') or "unknown"
  date_modified = None
@@ -106,8 +142,8 @@
  data_source = {
  'url': f"s3://{self.bucket}/{normalized_key}",
  'version': version,
- 'date_created': date_modified,
- 'date_modified': date_modified,
+ 'date_created': _to_millis_timestamp_string(date_modified),
+ 'date_modified': _to_millis_timestamp_string(date_modified),
  'record_locator': {
  'server': server,
  'protocol': 's3',
@@ -121,9 +157,10 @@
  class LocalSource(Source):
  """Local filesystem data source"""

- def __init__(self, directory: str, pattern: str = '*'):
+ def __init__(self, directory: str, pattern: str = '*', recursive: bool = False):
  self.directory = Path(directory)
  self.pattern = pattern or '*'
+ self.recursive = recursive

  if not self.directory.exists():
  raise ValueError(f"目录不存在: {directory}")
@@ -132,11 +169,20 @@ class LocalSource(Source):
  logger.info(f"本地目录: {self.directory}")

  def list_files(self) -> List[str]:
- files = [
- str(f.relative_to(self.directory))
- for f in self.directory.rglob(self.pattern)
- if f.is_file()
- ]
+ if self.recursive:
+ # recursive mode: use rglob
+ files = [
+ str(f.relative_to(self.directory))
+ for f in self.directory.rglob(self.pattern)
+ if f.is_file()
+ ]
+ else:
+ # non-recursive mode: only list files in the root directory, using glob
+ files = [
+ str(f.relative_to(self.directory))
+ for f in self.directory.glob(self.pattern)
+ if f.is_file()
+ ]
  print(f"✓ 本地找到 {len(files)} 个文件")
  return files

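The same recursive/non-recursive split in plain pathlib terms, since that is all the new branch does: `rglob` descends into sub-directories, `glob` stays at the top level (the directory is a placeholder):

```python
from pathlib import Path

root = Path("/tmp/docs")
top_level = [p for p in root.glob("*.pdf") if p.is_file()]    # recursive=False
all_levels = [p for p in root.rglob("*.pdf") if p.is_file()]  # recursive=True
print(len(top_level), len(all_levels))
```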
@@ -159,8 +205,8 @@
  data_source = {
  'url': full_path.as_uri(),
  'version': version,
- 'date_created': date_created,
- 'date_modified': date_modified,
+ 'date_created': _to_millis_timestamp_string(date_created),
+ 'date_modified': _to_millis_timestamp_string(date_modified),
  'record_locator': {
  'protocol': 'file',
  'remote_file_path': str(full_path)
@@ -172,12 +218,13 @@
  class FtpSource(Source):
  """FTP data source"""

- def __init__(self, host: str, port: int, username: str, password: str, pattern: str = '*'):
+ def __init__(self, host: str, port: int, username: str, password: str, pattern: str = '*', recursive: bool = False):
  self.host = host
  self.port = port
  self.username = username
  self.password = password
  self.pattern = pattern or '*'
+ self.recursive = recursive

  self.client = ftplib.FTP()
  self.client.connect(self.host, self.port)
@@ -187,8 +234,139 @@
  logger.info(f"FTP 连接成功: {self.host}:{self.port}")

  def list_files(self) -> List[str]:
- raw_files = self.client.nlst()
- files = [f for f in raw_files if fnmatch(f, self.pattern)]
+ if self.recursive:
+ # recursive mode: list all files recursively
+ files = []
+ current_dir = self.client.pwd()
+
+ def _list_recursive(path=''):
+ try:
+ # remember the current directory
+ original_dir = self.client.pwd()
+ if path:
+ try:
+ self.client.cwd(path)
+ except:
+ return
+
+ items = []
+ try:
+ # try the MLSD command first (more reliable)
+ items = []
+ for item in self.client.mlsd():
+ items.append(item)
+ except:
+ # if MLSD is not supported, fall back to the LIST command
+ try:
+ lines = []
+ self.client.retrlines('LIST', lines.append)
+ for line in lines:
+ parts = line.split()
+ if len(parts) >= 9:
+ # parse the LIST output; the first character indicates the entry type
+ item_name = ' '.join(parts[8:])
+ is_dir = parts[0].startswith('d')
+ items.append((item_name, {'type': 'dir' if is_dir else 'file'}))
+ except:
+ # finally fall back to nlst, which cannot tell files from directories
+ for item_name in self.client.nlst():
+ items.append((item_name, {'type': 'unknown'}))
+
+ for item_name, item_info in items:
+ if item_name in ['.', '..']:
+ continue
+
+ item_type = item_info.get('type', 'unknown')
+ full_path = f"{path}/{item_name}" if path else item_name
+
+ if item_type == 'dir' or item_type == 'unknown':
+ # try changing into the entry to tell whether it is a directory
+ try:
+ self.client.cwd(item_name)
+ self.client.cwd('..')
+ # it is a directory; recurse into it
+ _list_recursive(full_path)
+ except:
+ # not a directory, so it is a file
+ relative_path = full_path.lstrip('/')
+ if fnmatch(relative_path, self.pattern):
+ files.append(relative_path)
+ else:
+ # it is a file
+ relative_path = full_path.lstrip('/')
+ if fnmatch(relative_path, self.pattern):
+ files.append(relative_path)
+
+ # restore the original directory
+ self.client.cwd(original_dir)
+ except Exception as e:
+ logger.warning(f"FTP 列出路径失败 {path}: {str(e)}")
+ try:
+ self.client.cwd(current_dir)
+ except:
+ pass
+
+ _list_recursive()
+ # make sure we return to the original directory
+ try:
+ self.client.cwd(current_dir)
+ except:
+ pass
+ else:
+ # non-recursive mode: list only files in the current directory (excluding directories)
+ files = []
+ current_dir = self.client.pwd()
+
+ try:
+ # try the MLSD command first (more reliable)
+ items = []
+ for item_name, item_info in self.client.mlsd():
+ if item_name in ['.', '..']:
+ continue
+ item_type = item_info.get('type', 'unknown')
+ # only add files, excluding directories
+ if item_type == 'file' or (item_type == 'unknown' and not item_info.get('type', '').startswith('dir')):
+ if fnmatch(item_name, self.pattern):
+ files.append(item_name)
+ except:
+ # if MLSD is not supported, fall back to the LIST command
+ try:
+ lines = []
+ self.client.retrlines('LIST', lines.append)
+ for line in lines:
+ parts = line.split()
+ if len(parts) >= 9:
+ # parse the LIST output; the first character indicates the entry type
+ item_name = ' '.join(parts[8:])
+ if item_name in ['.', '..']:
+ continue
+ is_dir = parts[0].startswith('d')
+ # only add files, excluding directories
+ if not is_dir and fnmatch(item_name, self.pattern):
+ files.append(item_name)
+ except:
+ # finally fall back to nlst, detecting directories by trying to cd into each entry
+ raw_items = self.client.nlst()
+ for item_name in raw_items:
+ if item_name in ['.', '..']:
+ continue
+ # try changing into the entry to tell whether it is a directory
+ try:
+ self.client.cwd(item_name)
+ self.client.cwd('..')
+ # the cwd succeeded, so it is a directory; skip it
+ continue
+ except:
+ # the cwd failed, so it is a file
+ if fnmatch(item_name, self.pattern):
+ files.append(item_name)
+
+ # make sure we return to the original directory
+ try:
+ self.client.cwd(current_dir)
+ except:
+ pass
+
  print(f"✓ FTP 找到 {len(files)} 个文件 (匹配 pattern)")
  return files

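The FTP listing above prefers MLSD and degrades to LIST and then NLST. A minimal sketch of the MLSD happy path (host and credentials are placeholders; servers without MLSD raise `ftplib.error_perm`, which is what triggers the fallbacks):

```python
import ftplib

ftp = ftplib.FTP()
ftp.connect("ftp.example.com", 21)
ftp.login("user", "password")

# mlsd() yields (name, facts) pairs; the 'type' fact separates files from directories.
files = [name for name, facts in ftp.mlsd() if facts.get("type") == "file"]
print(files)
ftp.quit()
```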
@@ -208,11 +386,12 @@
  logger.debug(f"FTP 获取文件时间失败 {file_path}: {exc}")

  normalized_path = file_path.lstrip('/')
+ version = _to_millis_timestamp_string(date_modified)
  data_source = {
  'url': f"ftp://{self.host}:{self.port}/{normalized_path}",
- 'version': None,
- 'date_created': None,
- 'date_modified': date_modified,
+ 'version': version,
+ 'date_created': version,
+ 'date_modified': version,
  'record_locator': {
  'server': f"{self.host}:{self.port}",
  'protocol': 'ftp',
@@ -227,7 +406,7 @@
  """SMB/CIFS data source"""

  def __init__(self, host: str, share_name: str, username: str, password: str,
- domain: str = '', port: int = 445, path: str = '', pattern: str = '*'):
+ domain: str = '', port: int = 445, path: str = '', pattern: str = '*', recursive: bool = False):
  self.host = host
  self.share_name = share_name
  self.username = username
@@ -236,6 +415,7 @@
  self.port = port
  self.path = path.strip('/').strip('\\') if path else ''
  self.pattern = pattern or '*'
+ self.recursive = recursive

  self.conn = SMBConnection(
  username,
@@ -267,7 +447,10 @@
  item_path = f"{current_path.rstrip('/')}/{item.filename}" if current_path != '/' else f"/{item.filename}"
  relative_path = item_path[len(base_path):].lstrip('/')
  if item.isDirectory:
- _list_recursive(conn, share, item_path)
+ if self.recursive:
+ # recursive mode: keep descending into sub-directories
+ _list_recursive(conn, share, item_path)
+ # non-recursive mode: ignore sub-directories
  else:
  if fnmatch(relative_path, self.pattern):
  files.append(relative_path)
@@ -310,9 +493,9 @@
  smb_url = f"smb://{self.host}/{self.share_name}{full_path}"
  data_source = {
  'url': smb_url,
- 'version': None,
- 'date_created': date_created,
- 'date_modified': date_modified,
+ 'version': _to_millis_timestamp_string(date_modified),
+ 'date_created': _to_millis_timestamp_string(date_created),
+ 'date_modified': _to_millis_timestamp_string(date_modified),
  'record_locator': {
  'server': self.host,
  'share': self.share_name,
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: xparse-client
- Version: 0.2.4
+ Version: 0.2.6
  Summary: 面向Agent和RAG的新一代文档处理 AI Infra
  License-Expression: MIT
  Project-URL: Homepage, https://gitlab.intsig.net/xparse1/xparse-pipeline
@@ -323,28 +323,34 @@ source = SmbSource(
  )
  ```

- > Tip: All Sources support the `pattern` parameter, which uses shell wildcards (`*.pdf`, `**/*.txt`, etc.) to filter the files to process; the default is `*`, i.e. process all files.
+ > Note 1: All Sources support the `pattern` parameter, which uses shell wildcards (`*.pdf`, `**/*.txt`, etc.) to filter the files to process; the default is `*`, i.e. process all files.
+
+ > Note 2: All Sources support the `recursive` parameter, which controls whether directories are traversed recursively; the default is `False`.

  ### Destination configuration

  #### Local Milvus vector store

+ The collection must contain at least the four fields `element_id`, `text`, `embeddings` and `record_id`.
+
  ```python
  destination = MilvusDestination(
- db_path: './milvus_pipeline.db', # local database file
- collection_name: 'my_collection', # collection name
- dimension: 1024 # vector dimension, must match the embed API output
+ db_path='./milvus_pipeline.db', # local database file
+ collection_name='my_collection', # collection name
+ dimension=1024 # vector dimension, must match the embed API output
  )
  ```

  #### Zilliz vector store

+ The collection must contain at least the four fields `element_id`, `text`, `embeddings` and `record_id`.
+
  ```python
  destination = MilvusDestination(
- db_path: 'https://xxxxxxx.serverless.xxxxxxx.cloud.zilliz.com.cn', # Zilliz connection address
- collection_name: 'my_collection', # collection name
- dimension: 1024, # vector dimension, must match the embed API output
- api_key: 'your-api-key' # Zilliz Cloud API Key
+ db_path='https://xxxxxxx.serverless.xxxxxxx.cloud.zilliz.com.cn', # Zilliz connection address
+ collection_name='my_collection', # collection name
+ dimension=1024, # vector dimension, must match the embed API output
+ api_key='your-api-key' # Zilliz Cloud API Key
  )
  ```

@@ -354,7 +360,7 @@ destination = MilvusDestination(

  ```python
  destination = LocalDestination(
- output_dir: './output'
+ output_dir='./output'
  )
  ```

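A quick illustration of the `recursive` flag described in note 2 of the updated README; the directory is a placeholder, and the import path assumes the 0.2.6 module layout shown in RECORD below:

```python
# Assumes LocalSource is importable from the sources module as in 0.2.6.
from xparse_client.pipeline.sources import LocalSource

source = LocalSource(
    directory='./docs',   # placeholder path
    pattern='*.pdf',
    recursive=True,       # also descend into sub-directories; defaults to False
)
print(source.list_files())
```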
@@ -0,0 +1,13 @@
+ example/run_pipeline.py,sha256=ijws5q_vMmV0-bMHuFtOUMrEnxnL1LvOBCtcCD2c8zc,15366
+ example/run_pipeline_test.py,sha256=uIU09FTv_VnTQS1Lc94ydc3kaD86eHkaHQbVXpsGEcA,14861
+ xparse_client/__init__.py,sha256=je1ena3HwLL4CRtLU4r6EAzoOIJthlPjTwshxZnzQDM,1677
+ xparse_client/pipeline/__init__.py,sha256=TVlb2AGCNKP0jrv3p4ZLZCPKp68hTVMFi00DTdi6QAo,49
+ xparse_client/pipeline/config.py,sha256=FFYq2a0dBWBEj70s2aInXOiQ5MwwHimd6SI2_tkp52w,4138
+ xparse_client/pipeline/destinations.py,sha256=F0z1AgVIBOn0m32i4l7LCMkJE0IbBdlpykO_at_wLaE,11931
+ xparse_client/pipeline/pipeline.py,sha256=BEC1kf5HKn4qIiNMe5QPNOs9PJcjaE_4qRK0I9yaLSQ,20430
+ xparse_client/pipeline/sources.py,sha256=0m2GFYCQTORSkbn_fWIWrOOSFes4Aso6oyVJDHfDeqc,19964
+ xparse_client-0.2.6.dist-info/licenses/LICENSE,sha256=ckIP-MbocsP9nqYnta5KgfAicYF196B5TNdHIR6kOO0,1075
+ xparse_client-0.2.6.dist-info/METADATA,sha256=-VNzfb68RsczdETZ17vRhGzClj2X5mrwiwT5-HzK1XE,26805
+ xparse_client-0.2.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ xparse_client-0.2.6.dist-info/top_level.txt,sha256=bfX8BWo1sEEQVsI4Ql4Uu80vrfEh5zfajU9YqFTzxMo,22
+ xparse_client-0.2.6.dist-info/RECORD,,
@@ -1,13 +0,0 @@
- example/run_pipeline.py,sha256=6gavTizAIqD62g4n9Pjq2-yW57ItZMJOOw8GEKm0Byk,15125
- example/run_pipeline_test.py,sha256=uIU09FTv_VnTQS1Lc94ydc3kaD86eHkaHQbVXpsGEcA,14861
- xparse_client/__init__.py,sha256=je1ena3HwLL4CRtLU4r6EAzoOIJthlPjTwshxZnzQDM,1677
- xparse_client/pipeline/__init__.py,sha256=TVlb2AGCNKP0jrv3p4ZLZCPKp68hTVMFi00DTdi6QAo,49
- xparse_client/pipeline/config.py,sha256=gkhAF-55PNvPPyfTZ0HkP95XB_K0HKCyYl6R4PTQLhI,4045
- xparse_client/pipeline/destinations.py,sha256=rqcxmsn1YGClVxGQxSVmyr-uumOVilOv_vX82fUBj-I,9859
- xparse_client/pipeline/pipeline.py,sha256=oz_BKWLbslkuRsxG0zEfh9url7saLWgtoTH1mrK6gCc,18282
- xparse_client/pipeline/sources.py,sha256=-0Eutg9t8xni12cfv2bdQVdImlkCQ7gWlOXIFBt6tpE,11568
- xparse_client-0.2.4.dist-info/licenses/LICENSE,sha256=ckIP-MbocsP9nqYnta5KgfAicYF196B5TNdHIR6kOO0,1075
- xparse_client-0.2.4.dist-info/METADATA,sha256=IMgXO9a7wnN0Ygzauk7eOkyrRFE3A2rq73eofvq3wBs,26508
- xparse_client-0.2.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- xparse_client-0.2.4.dist-info/top_level.txt,sha256=bfX8BWo1sEEQVsI4Ql4Uu80vrfEh5zfajU9YqFTzxMo,22
- xparse_client-0.2.4.dist-info/RECORD,,