dtflow 0.4.3__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dtflow/__init__.py +34 -1
- dtflow/__main__.py +22 -0
- dtflow/cli/commands.py +5 -0
- dtflow/cli/validate.py +152 -0
- dtflow/core.py +154 -0
- dtflow/framework.py +610 -0
- dtflow/schema.py +508 -0
- {dtflow-0.4.3.dist-info → dtflow-0.5.0.dist-info}/METADATA +107 -2
- {dtflow-0.4.3.dist-info → dtflow-0.5.0.dist-info}/RECORD +11 -8
- {dtflow-0.4.3.dist-info → dtflow-0.5.0.dist-info}/WHEEL +0 -0
- {dtflow-0.4.3.dist-info → dtflow-0.5.0.dist-info}/entry_points.txt +0 -0
dtflow/__init__.py
CHANGED
@@ -4,6 +4,7 @@ DataTransformer: a concise data format conversion tool
 Core features:
 - DataTransformer: data loading, transformation, and saving
 - presets: preset transformation templates (openai_chat, alpaca, sharegpt, dpo_pair, simple_qa)
+- schema: data structure validation (Schema, Field)
 - tokenizers: token counting and filtering
 - converters: conversion for HuggingFace/OpenAI and other formats
 """
@@ -26,6 +27,23 @@ from .converters import (  # LLaMA-Factory extensions; ms-swift
 )
 from .core import DataTransformer, DictWrapper, TransformError, TransformErrors
 from .presets import get_preset, list_presets
+from .schema import (
+    Field,
+    Schema,
+    ValidationError,
+    ValidationResult,
+    alpaca_schema,
+    dpo_schema,
+    openai_chat_schema,
+    sharegpt_schema,
+    validate_data,
+)
+from .framework import (
+    CompatibilityResult,
+    check_compatibility,
+    detect_format,
+    export_for,
+)
 from .storage import load_data, sample_file, save_data
 from .streaming import StreamingTransformer, load_sharded, load_stream, process_shards
 from .tokenizers import (
@@ -42,7 +60,7 @@ from .tokenizers import (
     token_stats,
 )
 
-__version__ = "0.4.3"
+__version__ = "0.5.0"
 
 __all__ = [
     # core
@@ -53,6 +71,21 @@ __all__ = [
     # presets
     "get_preset",
    "list_presets",
+    # schema
+    "Schema",
+    "Field",
+    "ValidationResult",
+    "ValidationError",
+    "validate_data",
+    "openai_chat_schema",
+    "alpaca_schema",
+    "dpo_schema",
+    "sharegpt_schema",
+    # framework
+    "CompatibilityResult",
+    "check_compatibility",
+    "detect_format",
+    "export_for",
     # storage
     "save_data",
     "load_data",
dtflow/__main__.py
CHANGED
@@ -18,6 +18,7 @@ Commands:
     clean      data cleaning
     run        run a Pipeline config file
     history    show data lineage history
+    validate   validate data format against a Schema
     mcp        MCP service management (install/uninstall/status)
     logs       usage notes for the log viewer
 """
@@ -40,6 +41,7 @@ from .cli.commands import stats as _stats
 from .cli.commands import tail as _tail
 from .cli.commands import token_stats as _token_stats
 from .cli.commands import transform as _transform
+from .cli.commands import validate as _validate
 
 # create the main application
 app = typer.Typer(
@@ -211,6 +213,26 @@ def history(
     _history(filename, json)
 
 
+# ============ validation command ============
+
+
+@app.command()
+def validate(
+    filename: str = typer.Argument(..., help="input file path"),
+    preset: Optional[str] = typer.Option(
+        None, "--preset", "-p", help="preset Schema: openai_chat, alpaca, dpo, sharegpt"
+    ),
+    output: Optional[str] = typer.Option(None, "--output", "-o", help="output path for valid records"),
+    filter: bool = typer.Option(
+        False, "--filter", "-f", help="filter out invalid records and save"
+    ),
+    max_errors: int = typer.Option(20, "--max-errors", help="maximum number of errors to show"),
+    verbose: bool = typer.Option(False, "--verbose", "-v", help="show verbose output"),
+):
+    """Validate data format against a preset Schema."""
+    _validate(filename, preset, output, filter, max_errors, verbose)
+
+
 # ============ utility commands ============
 
 
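Typical invocations of the new subcommand, mirroring the Examples block in dtflow/cli/validate.py below:

    dt validate data.jsonl --preset=openai_chat
    dt validate data.jsonl --preset=alpaca -o valid.jsonl
    dt validate data.jsonl --preset=chat --filter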
dtflow/cli/commands.py
CHANGED
@@ -33,6 +33,9 @@ from .pipeline import run
 # lineage tracking commands
 from .lineage import history
 
+# validation command
+from .validate import validate
+
 __all__ = [
     # sampling
     "sample",
@@ -53,4 +56,6 @@ __all__ = [
     "run",
     # lineage
     "history",
+    # validation
+    "validate",
 ]
dtflow/cli/validate.py
ADDED
@@ -0,0 +1,152 @@
+"""
+CLI Schema validation command
+"""
+
+from pathlib import Path
+from typing import Optional
+
+from ..schema import (
+    Schema,
+    Field,
+    alpaca_schema,
+    dpo_schema,
+    openai_chat_schema,
+    sharegpt_schema,
+)
+from ..storage.io import load_data, save_data
+from .common import _check_file_format
+
+
+# preset Schema mapping
+PRESET_SCHEMAS = {
+    "openai_chat": openai_chat_schema,
+    "openai-chat": openai_chat_schema,
+    "chat": openai_chat_schema,
+    "alpaca": alpaca_schema,
+    "dpo": dpo_schema,
+    "dpo_pair": dpo_schema,
+    "sharegpt": sharegpt_schema,
+}
+
+
+def validate(
+    filename: str,
+    preset: Optional[str] = None,
+    output: Optional[str] = None,
+    filter_invalid: bool = False,
+    max_errors: int = 20,
+    verbose: bool = False,
+) -> None:
+    """
+    Validate a data file against a Schema.
+
+    Args:
+        filename: input file path
+        preset: preset Schema name (openai_chat, alpaca, dpo, sharegpt)
+        output: output file path (for saving valid records)
+        filter_invalid: filter out invalid records and save the rest
+        max_errors: maximum number of errors to display
+        verbose: show verbose output
+
+    Examples:
+        dt validate data.jsonl --preset=openai_chat
+        dt validate data.jsonl --preset=alpaca -o valid.jsonl
+        dt validate data.jsonl --preset=chat --filter
+    """
+    filepath = Path(filename)
+
+    if not filepath.exists():
+        print(f"Error: file not found - {filename}")
+        return
+
+    if not _check_file_format(filepath):
+        return
+
+    # resolve the Schema
+    if preset is None:
+        # list the available presets
+        print("Please specify a preset Schema (--preset):")
+        print()
+        for name in ["openai_chat", "alpaca", "dpo", "sharegpt"]:
+            print(f"  --preset={name}")
+        print()
+        print("Example:")
+        print(f"  dt validate {filename} --preset=openai_chat")
+        return
+
+    preset_lower = preset.lower().replace("-", "_")
+    if preset_lower not in PRESET_SCHEMAS:
+        print(f"Error: unknown preset Schema '{preset}'")
+        print(f"Available presets: {', '.join(['openai_chat', 'alpaca', 'dpo', 'sharegpt'])}")
+        return
+
+    schema = PRESET_SCHEMAS[preset_lower]()
+
+    # load the data
+    try:
+        data = load_data(str(filepath))
+    except Exception as e:
+        print(f"Error: failed to read file - {e}")
+        return
+
+    if not data:
+        print("File is empty")
+        return
+
+    total = len(data)
+    print(f"Validating file: {filepath.name}")
+    print(f"Preset Schema: {preset}")
+    print(f"Total records: {total}")
+    print()
+
+    # validate
+    valid_data = []
+    invalid_count = 0
+    error_samples = []
+
+    for i, item in enumerate(data):
+        result = schema.validate(item)
+        if result.valid:
+            valid_data.append(item)
+        else:
+            invalid_count += 1
+            if len(error_samples) < max_errors:
+                error_samples.append((i, result))
+
+    valid_count = len(valid_data)
+    valid_ratio = valid_count / total * 100 if total > 0 else 0
+
+    # report the results
+    if invalid_count == 0:
+        print(f"✅ All passed! {valid_count}/{total} records valid (100%)")
+    else:
+        print(f"⚠️ Validation result: {valid_count}/{total} valid ({valid_ratio:.1f}%)")
+        print(f"   Invalid records: {invalid_count}")
+        print()
+
+        # show sample errors
+        print(f"Sample errors (showing at most {max_errors}):")
+        print("-" * 60)
+
+        for idx, result in error_samples:
+            print(f"[row {idx}]")
+            for err in result.errors[:3]:  # show at most 3 errors per record
+                print(f"  - {err}")
+            if len(result.errors) > 3:
+                print(f"  ... {len(result.errors) - 3} more errors")
+            print()
+
+    # save the valid records
+    if output or filter_invalid:
+        output_path = output or str(filepath).replace(
+            filepath.suffix, f"_valid{filepath.suffix}"
+        )
+        save_data(valid_data, output_path)
+        print(f"✅ Valid records saved: {output_path} ({valid_count} records)")
+
+    # verbose mode: show the Schema definition
+    if verbose:
+        print()
+        print("Schema definition:")
+        print("-" * 40)
+        print(schema)
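Since the underlying validate is a plain function (the Typer command above only forwards to it), it can also be driven programmatically; a sketch with hypothetical file paths:

    from dtflow.cli.validate import validate

    # validate against the alpaca preset and write the valid records
    # to a separate file ("data.jsonl" and "valid.jsonl" are hypothetical)
    validate("data.jsonl", preset="alpaca", output="valid.jsonl")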
dtflow/core.py
CHANGED
@@ -386,6 +386,88 @@ class DataTransformer:
 
         return errors
 
+    def validate_schema(
+        self,
+        schema: "Schema",
+        on_error: Literal["skip", "raise", "filter"] = "skip",
+        max_errors: int = 100,
+    ) -> Union["DataTransformer", List[tuple]]:
+        """
+        Validate the data structure against a Schema.
+
+        Args:
+            schema: Schema object defining the structural validation rules
+            on_error: error-handling mode
+                - "skip": print a warning and return the records that failed validation
+                - "raise": raise an exception on the first error
+                - "filter": drop the failed records and return a new DataTransformer
+            max_errors: maximum number of errors (applies when on_error="skip")
+
+        Returns:
+            - on_error="skip": a list of failed records [(index, ValidationResult), ...]
+            - on_error="raise": no return value (on success), or raises ValueError
+            - on_error="filter": a new, filtered DataTransformer
+
+        Examples:
+            >>> from dtflow import Schema, Field
+            >>> schema = Schema({
+            ...     "messages": Field(type="list", required=True, min_length=1),
+            ...     "messages[*].role": Field(type="str", choices=["user", "assistant"]),
+            ... })
+
+            >>> # collect the records that failed validation
+            >>> errors = dt.validate_schema(schema)
+            >>> for idx, result in errors:
+            ...     print(f"row {idx} failed validation: {result.errors}")
+
+            >>> # drop the invalid records
+            >>> valid_dt = dt.validate_schema(schema, on_error="filter")
+
+            >>> # stop as soon as an error is hit
+            >>> dt.validate_schema(schema, on_error="raise")
+        """
+        from .schema import Schema, ValidationResult
+
+        failed: List[tuple] = []
+        valid_data: List[dict] = []
+        error_count = 0
+
+        for i, item in enumerate(self._data):
+            result = schema.validate(item)
+            if result.valid:
+                valid_data.append(item)
+            else:
+                failed.append((i, result))
+                error_count += len(result.errors)
+
+                if on_error == "raise":
+                    error_msgs = [str(e) for e in result.errors[:3]]
+                    raise ValueError(
+                        f"row {i} failed validation:\n  " + "\n  ".join(error_msgs)
+                    )
+
+                if on_error == "skip" and error_count >= max_errors:
+                    print(f"⚠️ Reached the error limit ({max_errors}); stopping validation")
+                    break
+
+        if on_error == "skip":
+            if failed:
+                print(f"⚠️ Validation failed for {len(failed)} records ({error_count} errors in total)")
+            return failed
+
+        if on_error == "filter":
+            tracker = self._lineage_tracker
+            if tracker:
+                tracker.record(
+                    "validate_schema",
+                    {"schema": repr(schema), "on_error": on_error},
+                    len(self._data),
+                    len(valid_data),
+                )
+            return DataTransformer(valid_data, _lineage_tracker=tracker)
+
+        return failed
+
     def dedupe(
         self,
         key: Union[None, str, List[str], Callable[[Any], Any]] = None,
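A minimal sketch of the filter mode, assuming DataTransformer can be constructed from a list of dicts (the method body itself rebuilds the transformer that way):

    from dtflow import DataTransformer, Schema, Field

    dt = DataTransformer([
        {"messages": [{"role": "user", "content": "hi"}]},
        {"messages": []},  # violates the min_length=1 rule below
    ])
    schema = Schema({
        "messages": Field(type="list", required=True, min_length=1),
    })
    valid_dt = dt.validate_schema(schema, on_error="filter")  # keeps only the first record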
@@ -801,6 +883,78 @@
         filtered = [item for item, keep in zip(self._data, mask) if keep]
         return DataTransformer(filtered)
 
+    # ============ training framework integration ============
+
+    def check_compatibility(
+        self,
+        framework: Literal["llama-factory", "swift", "axolotl"],
+    ) -> "CompatibilityResult":
+        """
+        Check the data's compatibility with a target training framework.
+
+        Args:
+            framework: target framework name
+                - "llama-factory": LLaMA-Factory
+                - "swift": ms-swift (ModelScope)
+                - "axolotl": Axolotl
+
+        Returns:
+            A CompatibilityResult object with valid, errors, warnings, suggestions
+
+        Examples:
+            >>> result = dt.check_compatibility("llama-factory")
+            >>> if result.valid:
+            ...     print("Compatible!")
+            >>> else:
+            ...     print(result.errors)
+        """
+        from .framework import check_compatibility
+
+        return check_compatibility(self._data, framework)
+
+    def export_for(
+        self,
+        framework: Literal["llama-factory", "swift", "axolotl"],
+        output_dir: str,
+        dataset_name: str = "custom_dataset",
+        **kwargs,
+    ) -> Dict[str, str]:
+        """
+        One-step export of the data plus config files for a target training framework.
+
+        Args:
+            framework: target framework name
+            output_dir: output directory
+            dataset_name: dataset name
+            **kwargs: framework-specific options (e.g. model_name)
+
+        Returns:
+            A dict of generated file paths {"data": "...", "config": "...", ...}
+
+        Examples:
+            >>> # export for LLaMA-Factory
+            >>> dt.export_for("llama-factory", "./llama_ready")
+            # generates:
+            # - ./llama_ready/custom_dataset.json
+            # - ./llama_ready/dataset_info.json
+            # - ./llama_ready/train_args.yaml
+
+            >>> # export for ms-swift
+            >>> dt.export_for("swift", "./swift_ready", dataset_name="my_data")
+
+            >>> # export for Axolotl
+            >>> dt.export_for("axolotl", "./axolotl_ready")
+        """
+        from .framework import export_for
+
+        return export_for(
+            self._data,
+            framework,
+            output_dir,
+            dataset_name=dataset_name,
+            **kwargs,
+        )
+
 
 def _sanitize_key(name: str) -> str:
     """Normalize a field name into a valid Python identifier"""