dtflow 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dtflow/pipeline.py ADDED
@@ -0,0 +1,450 @@
1
+ """
2
+ Pipeline 配置模块
3
+
4
+ 支持将数据处理流程导出为 YAML 配置,实现可复现的数据处理。
5
+ """
6
+ import random
7
+ from pathlib import Path
8
+ from typing import Any, Callable, Dict, List, Optional, Union
9
+
10
+ from .core import DataTransformer
11
+ from .presets import get_preset, PRESETS
12
+ from .storage.io import load_data, save_data
13
+
14
+
15
# ============ Pipeline config format ============

# Config schema version. Written into generated templates and compared with a
# config's ``version`` key by run_pipeline, where a mismatch only prints a warning.
PIPELINE_VERSION = "1.0"
18
+
19
+
20
def _load_yaml(filepath: str) -> Dict[str, Any]:
    """
    Load a YAML config file with ``yaml.safe_load``.

    An empty file yields ``{}`` so callers can always use mapping operations.

    Args:
        filepath: Path to the YAML file (read as UTF-8).

    Returns:
        The parsed top-level mapping.

    Raises:
        ImportError: If PyYAML is not installed.
        ValueError: If the top level of the document is not a mapping.
    """
    try:
        import yaml
    except ImportError as exc:
        raise ImportError("需要安装 PyYAML: pip install pyyaml") from exc

    with open(filepath, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)

    # An empty document parses to None; normalize it instead of letting callers
    # fail later on `.get(...)` / `in` checks.
    if data is None:
        return {}
    if not isinstance(data, dict):
        raise ValueError(f"配置文件顶层必须是映射(字典): {filepath}")
    return data
29
+
30
+
31
def _save_yaml(data: Dict[str, Any], filepath: str) -> None:
    """
    Write ``data`` to ``filepath`` as block-style YAML (UTF-8).

    Missing parent directories are created first. Key order is preserved
    (``sort_keys=False``) and non-ASCII text is written unescaped.

    Raises:
        ImportError: If PyYAML is not installed.
    """
    try:
        import yaml
    except ImportError:
        raise ImportError("需要安装 PyYAML: pip install pyyaml")

    target = Path(filepath)
    target.parent.mkdir(parents=True, exist_ok=True)

    dump_options = dict(allow_unicode=True, default_flow_style=False, sort_keys=False)
    with target.open("w", encoding="utf-8") as stream:
        yaml.dump(data, stream, **dump_options)
41
+
42
+
43
+ # ============ 步骤执行器 ============
44
+
45
+
46
def _execute_filter(dt: DataTransformer, step: Dict[str, Any]) -> DataTransformer:
    """
    Run a ``filter`` step.

    The step either names a ``field`` (keep records whose value is truthy) or
    provides a ``condition`` expression, which is turned into a predicate by
    ``_parse_condition``. When both are present, ``condition`` wins.

    Raises:
        ValueError: If neither ``condition`` nor ``field`` is given.
    """
    expression = step.get("condition", "")
    target_field = step.get("field")

    if expression:
        return dt.filter(_parse_condition(expression), raw=True)

    if target_field:
        # Bind the field name as a default so the lambda does not close over a variable.
        return dt.filter(lambda record, name=target_field: bool(record.get(name)), raw=True)

    raise ValueError("filter 步骤需要指定 condition 或 field")
68
+
69
+
70
+ def _parse_condition(condition: str) -> Callable:
71
+ """
72
+ 解析条件表达式为过滤函数。
73
+
74
+ 支持的格式:
75
+ - "score > 0.5"
76
+ - "len(text) > 10"
77
+ - "category == 'A'"
78
+ - "field is not empty"
79
+ """
80
+ import re
81
+
82
+ condition = condition.strip()
83
+
84
+ # 长度比较:len(field) op value
85
+ len_match = re.match(r"len\((\w+)\)\s*(>|<|>=|<=|==|!=)\s*(\d+)", condition)
86
+ if len_match:
87
+ field, op, value = len_match.groups()
88
+ value = int(value)
89
+ ops = {
90
+ ">": lambda a, b: a > b,
91
+ "<": lambda a, b: a < b,
92
+ ">=": lambda a, b: a >= b,
93
+ "<=": lambda a, b: a <= b,
94
+ "==": lambda a, b: a == b,
95
+ "!=": lambda a, b: a != b,
96
+ }
97
+ return lambda x, f=field, o=ops[op], v=value: o(len(str(x.get(f, ""))), v)
98
+
99
+ # 非空判断:field is not empty / field is not None
100
+ nonempty_match = re.match(r"(\w+)\s+is\s+not\s+(empty|None)", condition)
101
+ if nonempty_match:
102
+ field = nonempty_match.group(1)
103
+ return lambda x, f=field: bool(x.get(f))
104
+
105
+ # 数值比较:field op value
106
+ num_match = re.match(r"(\w+)\s*(>|<|>=|<=|==|!=)\s*([\d.]+)", condition)
107
+ if num_match:
108
+ field, op, value = num_match.groups()
109
+ value = float(value)
110
+ ops = {
111
+ ">": lambda a, b: a > b,
112
+ "<": lambda a, b: a < b,
113
+ ">=": lambda a, b: a >= b,
114
+ "<=": lambda a, b: a <= b,
115
+ "==": lambda a, b: a == b,
116
+ "!=": lambda a, b: a != b,
117
+ }
118
+ return lambda x, f=field, o=ops[op], v=value: o(float(x.get(f, 0)), v)
119
+
120
+ # 字符串比较:field == 'value' 或 field != 'value'
121
+ str_match = re.match(r"(\w+)\s*(==|!=)\s*['\"](.+)['\"]", condition)
122
+ if str_match:
123
+ field, op, value = str_match.groups()
124
+ if op == "==":
125
+ return lambda x, f=field, v=value: x.get(f) == v
126
+ else:
127
+ return lambda x, f=field, v=value: x.get(f) != v
128
+
129
+ raise ValueError(f"无法解析条件表达式: {condition}")
130
+
131
+
132
def _execute_transform(dt: DataTransformer, step: Dict[str, Any]) -> DataTransformer:
    """
    Run a ``transform`` step by applying a named preset from ``PRESETS``.

    Step keys:
        preset: Name of the preset (required).
        params: Optional mapping of keyword arguments forwarded to
            ``get_preset``. An empty ``params:`` entry (YAML null) is treated
            as no params.

    Raises:
        ValueError: If ``preset`` is missing or unknown, or ``params`` is not a mapping.
    """
    preset = step.get("preset")
    # `params:` with no value parses to None in YAML; `**None` would raise a
    # confusing TypeError, so normalize it to an empty mapping.
    params = step.get("params") or {}

    if not preset:
        raise ValueError("transform 步骤需要指定 preset")

    if preset not in PRESETS:
        available = ", ".join(PRESETS.keys())
        raise ValueError(f"未知预设: {preset}。可用预设: {available}")

    if not isinstance(params, dict):
        raise ValueError(f"transform 步骤的 params 必须是字典: {params!r}")

    transform_func = get_preset(preset, **params)
    return dt.transform(transform_func)
146
+
147
+
148
def _execute_dedupe(dt: DataTransformer, step: Dict[str, Any]) -> DataTransformer:
    """
    Run a ``dedupe`` step.

    If ``similar`` is set (any non-None value), calls ``dt.dedupe_similar`` on
    ``key`` with ``similar`` as the threshold; ``key`` is then required.
    Otherwise calls ``dt.dedupe(key)``. A comma-separated ``key`` string is
    first split into a list of stripped field names, and an absent ``key`` is
    passed through as ``None``.

    Raises:
        ValueError: If ``similar`` is given without ``key``.
    """
    key = step.get("key")
    threshold = step.get("similar")

    if threshold is None:
        # Exact dedupe; "a, b" means dedupe on several fields.
        if isinstance(key, str) and "," in key:
            key = [part.strip() for part in key.split(",")]
        return dt.dedupe(key)

    if not key:
        raise ValueError("相似度去重需要指定 key")
    return dt.dedupe_similar(key, threshold=threshold)
164
+
165
+
166
def _execute_sample(dt: DataTransformer, step: Dict[str, Any]) -> DataTransformer:
    """Run a ``sample`` step: ``dt.sample`` with ``num`` (default 10) and optional ``seed``."""
    return dt.sample(step.get("num", 10), seed=step.get("seed"))
171
+
172
+
173
def _execute_head(dt: DataTransformer, step: Dict[str, Any]) -> DataTransformer:
    """Run a ``head`` step: ``dt.head`` with ``num`` (default 10)."""
    return dt.head(step.get("num", 10))
177
+
178
+
179
def _execute_tail(dt: DataTransformer, step: Dict[str, Any]) -> DataTransformer:
    """Run a ``tail`` step: ``dt.tail`` with ``num`` (default 10)."""
    return dt.tail(step.get("num", 10))
183
+
184
+
185
def _execute_shuffle(dt: DataTransformer, step: Dict[str, Any]) -> DataTransformer:
    """Run a ``shuffle`` step: ``dt.shuffle`` with the step's optional ``seed``."""
    return dt.shuffle(seed=step.get("seed"))
189
+
190
+
191
def _execute_split(dt: DataTransformer, step: Dict[str, Any]) -> DataTransformer:
    """
    Run a ``split`` step and continue the pipeline with the first (train) part.

    Step keys:
        ratio: Ratio passed to ``dt.split`` (default 0.8).
        seed: Optional seed passed to ``dt.split``.
        test_output: Optional path. When set, the second (test) part is saved
            there via its ``save`` method. Without it, the test part is discarded,
            because the pipeline only carries one dataset forward.

    Returns:
        The first part returned by ``dt.split``.
    """
    ratio = step.get("ratio", 0.8)
    seed = step.get("seed")
    train, test = dt.split(ratio=ratio, seed=seed)

    # Nothing downstream receives the held-out part, so persisting it here is
    # the only way to keep it.
    test_output = step.get("test_output")
    if test_output:
        test.save(test_output)

    return train
202
+
203
+
204
# Dispatch table mapping a step's ``type`` value to the function that runs it.
# Used by run_pipeline and validate_pipeline; the insertion order below is the
# order in which type names are listed in "unknown type" error messages.
STEP_EXECUTORS: Dict[str, Callable[[DataTransformer, Dict[str, Any]], DataTransformer]] = dict(
    filter=_execute_filter,
    transform=_execute_transform,
    dedupe=_execute_dedupe,
    sample=_execute_sample,
    head=_execute_head,
    tail=_execute_tail,
    shuffle=_execute_shuffle,
    split=_execute_split,
)
215
+
216
+
217
+ # ============ Pipeline 执行器 ============
218
+
219
+
220
def _check_steps(steps: Any) -> List[Dict[str, Any]]:
    """
    Check the shape of a config's ``steps`` before any data is loaded.

    Returns the steps as a list; a null ``steps`` value becomes ``[]``.

    Raises:
        ValueError: If ``steps`` is not a list, a step is not a mapping, a step
            has no ``type``, or its type is not in ``STEP_EXECUTORS``.
    """
    if steps is None:
        return []
    if not isinstance(steps, list):
        raise ValueError("steps 字段必须是列表")

    for i, step in enumerate(steps, 1):
        if not isinstance(step, dict):
            raise ValueError(f"步骤 {i} 格式错误,应为映射(字典)")
        step_type = step.get("type")
        if not step_type:
            raise ValueError(f"步骤 {i} 未指定 type")
        if step_type not in STEP_EXECUTORS:
            available = ", ".join(STEP_EXECUTORS.keys())
            raise ValueError(f"未知步骤类型: {step_type}。可用类型: {available}")
    return steps


def run_pipeline(
    config_path: str,
    input_file: Optional[str] = None,
    output_file: Optional[str] = None,
    verbose: bool = True,
) -> DataTransformer:
    """
    Run the pipeline described by a YAML config file.

    Args:
        config_path: Path to the YAML config file.
        input_file: Input data path; overrides the config's ``input`` key.
        output_file: Output data path; overrides the config's ``output`` key.
            If neither is set, the result is returned but not saved.
        verbose: Whether to print progress messages.

    Returns:
        The DataTransformer after all steps have been applied.

    Raises:
        ValueError: If the config is not a mapping, no input file is given,
            or ``steps`` is malformed (checked before the data is loaded).

    Examples:
        >>> run_pipeline("pipeline.yaml")
        >>> run_pipeline("pipeline.yaml", input_file="new_data.jsonl")
    """
    config = _load_yaml(config_path)
    # An empty YAML file parses to None; treat it as an empty config so the
    # clearer "no input" error below is raised instead of an AttributeError.
    if config is None:
        config = {}
    if not isinstance(config, dict):
        raise ValueError(f"配置文件顶层必须是映射(字典): {config_path}")

    # Version check (warning only). An unquoted `version: 1.0` is parsed by
    # YAML as the float 1.0, so compare string forms to avoid a false warning.
    version = config.get("version", "1.0")
    if str(version) != PIPELINE_VERSION and verbose:
        print(f"⚠ 配置版本 {version} 与当前版本 {PIPELINE_VERSION} 不一致")

    # Seed the module-level `random` generator for the whole run.
    seed = config.get("seed")
    if seed is not None:
        random.seed(seed)
        if verbose:
            print(f"🎲 设置随机种子: {seed}")

    # Explicit arguments take precedence over the config file.
    input_path = input_file or config.get("input")
    if not input_path:
        raise ValueError("未指定输入文件,请在配置中设置 input 或使用 --input 参数")

    # Fail fast on malformed steps before loading (possibly large) data.
    steps = _check_steps(config.get("steps"))

    if verbose:
        print(f"📂 加载数据: {input_path}")
    dt = DataTransformer.load(input_path)
    if verbose:
        print(f" 共 {len(dt)} 条数据")

    for i, step in enumerate(steps, 1):
        if verbose:
            print(f"🔄 步骤 {i}: {_format_step_description(step)}")

        before_count = len(dt)
        dt = STEP_EXECUTORS[step["type"]](dt, step)
        after_count = len(dt)

        if verbose and before_count != after_count:
            print(f" {before_count} → {after_count} 条")

    output_path = output_file or config.get("output")
    if output_path:
        if verbose:
            print(f"💾 保存结果: {output_path}")
        dt.save(output_path)

    if verbose:
        print(f"\n✅ 完成! 共 {len(dt)} 条数据")

    return dt
302
+
303
+
304
+ def _format_step_description(step: Dict[str, Any]) -> str:
305
+ """格式化步骤描述"""
306
+ step_type = step.get("type", "")
307
+
308
+ if step_type == "filter":
309
+ cond = step.get("condition") or step.get("field")
310
+ return f"filter ({cond})"
311
+ elif step_type == "transform":
312
+ preset = step.get("preset", "")
313
+ return f"transform ({preset})"
314
+ elif step_type == "dedupe":
315
+ key = step.get("key", "全量")
316
+ similar = step.get("similar")
317
+ if similar:
318
+ return f"dedupe ({key}, 相似度={similar})"
319
+ return f"dedupe ({key})"
320
+ elif step_type == "sample":
321
+ num = step.get("num", 10)
322
+ return f"sample ({num})"
323
+ elif step_type in ("head", "tail"):
324
+ num = step.get("num", 10)
325
+ return f"{step_type} ({num})"
326
+ elif step_type == "shuffle":
327
+ return "shuffle"
328
+ elif step_type == "split":
329
+ ratio = step.get("ratio", 0.8)
330
+ return f"split (ratio={ratio})"
331
+ else:
332
+ return step_type
333
+
334
+
335
+ # ============ Pipeline 模板生成 ============
336
+
337
+
338
def generate_pipeline_template(
    input_file: str,
    output_file: str = "pipeline.yaml",
    preset: Optional[str] = None,
) -> str:
    """
    Generate a starter pipeline config by inspecting the input data.

    Field names are taken from the first record. The generated steps are, in order:

    1. A non-empty filter on the first field (only when no ``preset`` is given).
    2. An exact dedupe on the first field.
    3. A transform step: either the requested ``preset``, or one inferred from
       the field names (``q``/``a`` -> ``openai_chat``, ``instruction``/``output``
       -> ``alpaca``). No transform is inferred when the data already has
       ``messages``.

    Dedupe comes before the transform because its key is an original field
    name, which a format-converting preset may rename or drop.

    Args:
        input_file: Path to the input data file.
        output_file: Where to write the YAML config.
        preset: Optional transform preset to use instead of inferring one.

    Returns:
        The path the config was written to (``output_file``).

    Raises:
        ValueError: If the input file contains no records.
    """
    import re

    data = load_data(input_file)
    if not data:
        raise ValueError("输入文件为空")

    fields = list(data[0].keys())
    first_field = fields[0] if fields else None

    steps: List[Dict[str, Any]] = []

    if not preset and first_field is not None:
        if re.fullmatch(r"\w+", str(first_field)):
            steps.append({
                "type": "filter",
                "condition": f"len({first_field}) > 0",
            })
        else:
            # Condition expressions only accept \w+ field names, so fall back
            # to the `field` form (keep records where the field is truthy).
            steps.append({
                "type": "filter",
                "field": first_field,
            })

    steps.append({
        "type": "dedupe",
        "key": first_field,
    })

    if preset:
        steps.append({
            "type": "transform",
            "preset": preset,
        })
    elif "messages" in fields:
        pass  # already in messages format; no conversion needed
    elif "q" in fields and "a" in fields:
        steps.append({
            "type": "transform",
            "preset": "openai_chat",
            "params": {"user_field": "q", "assistant_field": "a"},
        })
    elif "instruction" in fields and "output" in fields:
        steps.append({
            "type": "transform",
            "preset": "alpaca",
        })

    config = {
        "version": PIPELINE_VERSION,
        "seed": 42,
        "input": input_file,
        "output": Path(input_file).stem + "_output.jsonl",
        "steps": steps,
    }

    _save_yaml(config, output_file)

    return output_file
408
+
409
+
410
def validate_pipeline(config_path: str) -> List[str]:
    """
    Validate a pipeline config file without running it.

    The following are checked:

    - The file loads and its top level is a mapping.
    - ``steps`` is present and is a list of mappings.
    - Every step has a ``type`` known to ``STEP_EXECUTORS``.
    - The per-type requirements enforced by the step executors are met:
      transform needs ``preset``; filter needs ``condition`` or ``field``, and
      a condition must be a string that ``_parse_condition`` accepts; a
      similarity dedupe needs ``key``.

    Args:
        config_path: Path to the YAML config file.

    Returns:
        A list of error messages; an empty list means validation passed.
    """
    try:
        config = _load_yaml(config_path)
    except Exception as e:  # report any load/parse failure instead of raising
        return [f"无法解析配置文件: {e}"]

    if config is None:  # an empty YAML file parses to None
        config = {}
    if not isinstance(config, dict):
        return ["配置文件顶层必须是映射(字典)"]

    errors: List[str] = []

    if "steps" not in config:
        errors.append("缺少 steps 字段")

    steps = config.get("steps") or []
    if not isinstance(steps, list):
        errors.append("steps 字段必须是列表")
        return errors

    for i, step in enumerate(steps, 1):
        if not isinstance(step, dict):
            errors.append(f"步骤 {i} 格式错误,应为映射(字典)")
            continue

        if "type" not in step:
            errors.append(f"步骤 {i} 缺少 type 字段")
            continue

        step_type = step["type"]
        # isinstance guard: an unhashable YAML value (list/dict) would make the
        # membership test itself raise.
        if not isinstance(step_type, str) or step_type not in STEP_EXECUTORS:
            available = ", ".join(STEP_EXECUTORS.keys())
            errors.append(f"步骤 {i}: 未知类型 '{step_type}',可用: {available}")

        # Step-specific checks mirroring what the executors require at run time.
        if step_type == "transform" and "preset" not in step:
            errors.append(f"步骤 {i}: transform 需要指定 preset")

        if step_type == "filter":
            condition = step.get("condition")
            if not condition and not step.get("field"):
                errors.append(f"步骤 {i}: filter 需要指定 condition 或 field")
            elif condition:
                if not isinstance(condition, str):
                    errors.append(f"步骤 {i}: filter 的 condition 必须是字符串")
                else:
                    try:
                        _parse_condition(condition)
                    except ValueError as e:
                        errors.append(f"步骤 {i}: {e}")

        if step_type == "dedupe" and step.get("similar") is not None and not step.get("key"):
            errors.append(f"步骤 {i}: 相似度去重需要指定 key")

    return errors