dtflow 0.4.3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dtflow/__init__.py +34 -1
- dtflow/__main__.py +22 -0
- dtflow/cli/commands.py +5 -0
- dtflow/cli/common.py +13 -9
- dtflow/cli/stats.py +114 -36
- dtflow/cli/validate.py +152 -0
- dtflow/core.py +220 -10
- dtflow/framework.py +610 -0
- dtflow/lineage.py +17 -0
- dtflow/schema.py +508 -0
- dtflow/streaming.py +93 -35
- dtflow/tokenizers.py +84 -29
- dtflow/utils/field_path.py +6 -2
- {dtflow-0.4.3.dist-info → dtflow-0.5.2.dist-info}/METADATA +117 -2
- {dtflow-0.4.3.dist-info → dtflow-0.5.2.dist-info}/RECORD +17 -14
- {dtflow-0.4.3.dist-info → dtflow-0.5.2.dist-info}/WHEEL +0 -0
- {dtflow-0.4.3.dist-info → dtflow-0.5.2.dist-info}/entry_points.txt +0 -0
dtflow/__init__.py
CHANGED
|
@@ -4,6 +4,7 @@ DataTransformer: 简洁的数据格式转换工具
|
|
|
4
4
|
核心功能:
|
|
5
5
|
- DataTransformer: 数据加载、转换、保存
|
|
6
6
|
- presets: 预设转换模板 (openai_chat, alpaca, sharegpt, dpo_pair, simple_qa)
|
|
7
|
+
- schema: 数据结构验证 (Schema, Field)
|
|
7
8
|
- tokenizers: Token 统计和过滤
|
|
8
9
|
- converters: HuggingFace/OpenAI 等格式转换
|
|
9
10
|
"""
|
|
@@ -26,6 +27,23 @@ from .converters import ( # LLaMA-Factory 扩展; ms-swift
|
|
|
26
27
|
)
|
|
27
28
|
from .core import DataTransformer, DictWrapper, TransformError, TransformErrors
|
|
28
29
|
from .presets import get_preset, list_presets
|
|
30
|
+
from .schema import (
|
|
31
|
+
Field,
|
|
32
|
+
Schema,
|
|
33
|
+
ValidationError,
|
|
34
|
+
ValidationResult,
|
|
35
|
+
alpaca_schema,
|
|
36
|
+
dpo_schema,
|
|
37
|
+
openai_chat_schema,
|
|
38
|
+
sharegpt_schema,
|
|
39
|
+
validate_data,
|
|
40
|
+
)
|
|
41
|
+
from .framework import (
|
|
42
|
+
CompatibilityResult,
|
|
43
|
+
check_compatibility,
|
|
44
|
+
detect_format,
|
|
45
|
+
export_for,
|
|
46
|
+
)
|
|
29
47
|
from .storage import load_data, sample_file, save_data
|
|
30
48
|
from .streaming import StreamingTransformer, load_sharded, load_stream, process_shards
|
|
31
49
|
from .tokenizers import (
|
|
@@ -42,7 +60,7 @@ from .tokenizers import (
|
|
|
42
60
|
token_stats,
|
|
43
61
|
)
|
|
44
62
|
|
|
45
|
-
__version__ = "0.4.3"
|
|
63
|
+
__version__ = "0.5.2"
|
|
46
64
|
|
|
47
65
|
__all__ = [
|
|
48
66
|
# core
|
|
@@ -53,6 +71,21 @@ __all__ = [
|
|
|
53
71
|
# presets
|
|
54
72
|
"get_preset",
|
|
55
73
|
"list_presets",
|
|
74
|
+
# schema
|
|
75
|
+
"Schema",
|
|
76
|
+
"Field",
|
|
77
|
+
"ValidationResult",
|
|
78
|
+
"ValidationError",
|
|
79
|
+
"validate_data",
|
|
80
|
+
"openai_chat_schema",
|
|
81
|
+
"alpaca_schema",
|
|
82
|
+
"dpo_schema",
|
|
83
|
+
"sharegpt_schema",
|
|
84
|
+
# framework
|
|
85
|
+
"CompatibilityResult",
|
|
86
|
+
"check_compatibility",
|
|
87
|
+
"detect_format",
|
|
88
|
+
"export_for",
|
|
56
89
|
# storage
|
|
57
90
|
"save_data",
|
|
58
91
|
"load_data",
|
dtflow/__main__.py
CHANGED
|
@@ -18,6 +18,7 @@ Commands:
|
|
|
18
18
|
clean 数据清洗
|
|
19
19
|
run 执行 Pipeline 配置文件
|
|
20
20
|
history 显示数据血缘历史
|
|
21
|
+
validate 使用 Schema 验证数据格式
|
|
21
22
|
mcp MCP 服务管理(install/uninstall/status)
|
|
22
23
|
logs 日志查看工具使用说明
|
|
23
24
|
"""
|
|
@@ -40,6 +41,7 @@ from .cli.commands import stats as _stats
|
|
|
40
41
|
from .cli.commands import tail as _tail
|
|
41
42
|
from .cli.commands import token_stats as _token_stats
|
|
42
43
|
from .cli.commands import transform as _transform
|
|
44
|
+
from .cli.commands import validate as _validate
|
|
43
45
|
|
|
44
46
|
# 创建主应用
|
|
45
47
|
app = typer.Typer(
|
|
@@ -211,6 +213,26 @@ def history(
|
|
|
211
213
|
_history(filename, json)
|
|
212
214
|
|
|
213
215
|
|
|
216
|
+
# ============ 验证命令 ============
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
@app.command()
|
|
220
|
+
def validate(
|
|
221
|
+
filename: str = typer.Argument(..., help="输入文件路径"),
|
|
222
|
+
preset: Optional[str] = typer.Option(
|
|
223
|
+
None, "--preset", "-p", help="预设 Schema: openai_chat, alpaca, dpo, sharegpt"
|
|
224
|
+
),
|
|
225
|
+
output: Optional[str] = typer.Option(None, "--output", "-o", help="输出有效数据的文件路径"),
|
|
226
|
+
filter: bool = typer.Option(
|
|
227
|
+
False, "--filter", "-f", help="过滤无效数据并保存"
|
|
228
|
+
),
|
|
229
|
+
max_errors: int = typer.Option(20, "--max-errors", help="最多显示的错误数量"),
|
|
230
|
+
verbose: bool = typer.Option(False, "--verbose", "-v", help="显示详细信息"),
|
|
231
|
+
):
|
|
232
|
+
"""使用预设 Schema 验证数据格式"""
|
|
233
|
+
_validate(filename, preset, output, filter, max_errors, verbose)
|
|
234
|
+
|
|
235
|
+
|
|
214
236
|
# ============ 工具命令 ============
|
|
215
237
|
|
|
216
238
|
|
dtflow/cli/commands.py
CHANGED
|
@@ -33,6 +33,9 @@ from .pipeline import run
|
|
|
33
33
|
# 血缘追踪命令
|
|
34
34
|
from .lineage import history
|
|
35
35
|
|
|
36
|
+
# 验证命令
|
|
37
|
+
from .validate import validate
|
|
38
|
+
|
|
36
39
|
__all__ = [
|
|
37
40
|
# 采样
|
|
38
41
|
"sample",
|
|
@@ -53,4 +56,6 @@ __all__ = [
|
|
|
53
56
|
"run",
|
|
54
57
|
# 血缘
|
|
55
58
|
"history",
|
|
59
|
+
# 验证
|
|
60
|
+
"validate",
|
|
56
61
|
]
|
dtflow/cli/common.py
CHANGED
|
@@ -57,7 +57,7 @@ def _get_file_row_count(filepath: Path) -> Optional[int]:
|
|
|
57
57
|
return None
|
|
58
58
|
|
|
59
59
|
|
|
60
|
-
def _format_value(value: Any, max_len: int = 80) -> str:
|
|
60
|
+
def _format_value(value: Any, max_len: int = 120) -> str:
|
|
61
61
|
"""格式化单个值,长文本截断。"""
|
|
62
62
|
if value is None:
|
|
63
63
|
return "[dim]null[/dim]"
|
|
@@ -66,18 +66,22 @@ def _format_value(value: Any, max_len: int = 80) -> str:
|
|
|
66
66
|
if isinstance(value, (int, float)):
|
|
67
67
|
return f"[cyan]{value}[/cyan]"
|
|
68
68
|
if isinstance(value, str):
|
|
69
|
+
half_len = max_len // 2
|
|
69
70
|
# 处理多行文本
|
|
70
71
|
if "\n" in value:
|
|
71
72
|
lines = value.split("\n")
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
73
|
+
preview = value.replace("\n", "\\n")
|
|
74
|
+
if len(preview) > max_len:
|
|
75
|
+
# 前半 + 省略标记 + 后半
|
|
76
|
+
head = preview[:half_len]
|
|
77
|
+
tail = preview[-half_len:]
|
|
78
|
+
return f'"{head} [yellow]<<<{len(lines)}行>>>[/yellow] {tail}"'
|
|
78
79
|
return f'"{preview}"'
|
|
79
80
|
if len(value) > max_len:
|
|
80
|
-
|
|
81
|
+
# 前半 + 省略标记 + 后半
|
|
82
|
+
head = value[:half_len]
|
|
83
|
+
tail = value[-half_len:]
|
|
84
|
+
return f'"{head} [yellow]<<<{len(value)}字符>>>[/yellow] {tail}"'
|
|
81
85
|
return f'"{value}"'
|
|
82
86
|
return str(value)
|
|
83
87
|
|
|
@@ -86,7 +90,7 @@ def _format_nested(
|
|
|
86
90
|
value: Any,
|
|
87
91
|
indent: str = "",
|
|
88
92
|
is_last: bool = True,
|
|
89
|
-
max_len: int =
|
|
93
|
+
max_len: int = 120,
|
|
90
94
|
) -> List[str]:
|
|
91
95
|
"""
|
|
92
96
|
递归格式化嵌套结构,返回行列表。
|
dtflow/cli/stats.py
CHANGED
|
@@ -465,34 +465,65 @@ def token_stats(
|
|
|
465
465
|
return
|
|
466
466
|
|
|
467
467
|
total = len(data)
|
|
468
|
-
print(f" 共 {total} 条数据")
|
|
469
|
-
print(f"🔢 统计 Token (模型: {model}, 字段: {field})...")
|
|
468
|
+
print(f" 共 {total:,} 条数据")
|
|
470
469
|
|
|
471
470
|
# 检查字段类型并选择合适的统计方法(支持嵌套路径)
|
|
472
471
|
sample = data[0]
|
|
473
472
|
field_value = get_field_with_spec(sample, field)
|
|
474
473
|
|
|
474
|
+
# 尝试使用 rich 进度条
|
|
475
475
|
try:
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
476
|
+
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn
|
|
477
|
+
|
|
478
|
+
with Progress(
|
|
479
|
+
SpinnerColumn(),
|
|
480
|
+
TextColumn("[bold blue]统计 Token"),
|
|
481
|
+
BarColumn(),
|
|
482
|
+
TaskProgressColumn(),
|
|
483
|
+
TextColumn(f"(模型: {model})"),
|
|
484
|
+
) as progress:
|
|
485
|
+
task = progress.add_task("", total=total)
|
|
486
|
+
|
|
487
|
+
def update_progress(current: int, total_count: int):
|
|
488
|
+
progress.update(task, completed=current)
|
|
489
|
+
|
|
490
|
+
if isinstance(field_value, list) and field_value and isinstance(field_value[0], dict):
|
|
491
|
+
from ..tokenizers import messages_token_stats
|
|
492
|
+
|
|
493
|
+
stats_result = messages_token_stats(
|
|
494
|
+
data, messages_field=field, model=model, progress_callback=update_progress
|
|
495
|
+
)
|
|
496
|
+
_print_messages_token_stats(stats_result, detailed)
|
|
497
|
+
else:
|
|
498
|
+
from ..tokenizers import token_stats as compute_token_stats
|
|
499
|
+
|
|
500
|
+
stats_result = compute_token_stats(
|
|
501
|
+
data, fields=field, model=model, progress_callback=update_progress
|
|
502
|
+
)
|
|
503
|
+
_print_text_token_stats(stats_result, detailed)
|
|
504
|
+
|
|
505
|
+
except ImportError:
|
|
506
|
+
# 没有 rich,显示简单进度
|
|
507
|
+
print(f"🔢 统计 Token (模型: {model}, 字段: {field})...")
|
|
508
|
+
try:
|
|
509
|
+
if isinstance(field_value, list) and field_value and isinstance(field_value[0], dict):
|
|
510
|
+
from ..tokenizers import messages_token_stats
|
|
511
|
+
|
|
512
|
+
stats_result = messages_token_stats(data, messages_field=field, model=model)
|
|
513
|
+
_print_messages_token_stats(stats_result, detailed)
|
|
514
|
+
else:
|
|
515
|
+
from ..tokenizers import token_stats as compute_token_stats
|
|
494
516
|
|
|
495
|
-
|
|
517
|
+
stats_result = compute_token_stats(data, fields=field, model=model)
|
|
518
|
+
_print_text_token_stats(stats_result, detailed)
|
|
519
|
+
except ImportError as e:
|
|
520
|
+
print(f"错误: {e}")
|
|
521
|
+
return
|
|
522
|
+
except Exception as e:
|
|
523
|
+
print(f"错误: 统计失败 - {e}")
|
|
524
|
+
import traceback
|
|
525
|
+
|
|
526
|
+
traceback.print_exc()
|
|
496
527
|
|
|
497
528
|
|
|
498
529
|
def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
|
|
@@ -505,21 +536,39 @@ def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
|
|
|
505
536
|
console = Console()
|
|
506
537
|
|
|
507
538
|
# 概览
|
|
539
|
+
std = stats.get("std_tokens", 0)
|
|
508
540
|
overview = (
|
|
509
541
|
f"[bold]总样本数:[/bold] {stats['count']:,}\n"
|
|
510
542
|
f"[bold]总 Token:[/bold] {stats['total_tokens']:,}\n"
|
|
511
|
-
f"[bold]平均 Token:[/bold] {stats['avg_tokens']:,}\n"
|
|
512
|
-
f"[bold]中位数:[/bold] {stats['median_tokens']:,}\n"
|
|
543
|
+
f"[bold]平均 Token:[/bold] {stats['avg_tokens']:,} (std: {std:.1f})\n"
|
|
513
544
|
f"[bold]范围:[/bold] {stats['min_tokens']:,} - {stats['max_tokens']:,}"
|
|
514
545
|
)
|
|
515
546
|
console.print(Panel(overview, title="📊 Token 统计概览", expand=False))
|
|
516
547
|
|
|
548
|
+
# 百分位数表格
|
|
549
|
+
table = Table(title="📈 分布统计")
|
|
550
|
+
table.add_column("百分位", style="cyan", justify="center")
|
|
551
|
+
table.add_column("Token 数", justify="right")
|
|
552
|
+
percentiles = [
|
|
553
|
+
("Min", stats["min_tokens"]),
|
|
554
|
+
("P25", stats.get("p25", "-")),
|
|
555
|
+
("P50 (中位数)", stats.get("median_tokens", "-")),
|
|
556
|
+
("P75", stats.get("p75", "-")),
|
|
557
|
+
("P90", stats.get("p90", "-")),
|
|
558
|
+
("P95", stats.get("p95", "-")),
|
|
559
|
+
("P99", stats.get("p99", "-")),
|
|
560
|
+
("Max", stats["max_tokens"]),
|
|
561
|
+
]
|
|
562
|
+
for name, val in percentiles:
|
|
563
|
+
table.add_row(name, f"{val:,}" if isinstance(val, int) else str(val))
|
|
564
|
+
console.print(table)
|
|
565
|
+
|
|
517
566
|
if detailed:
|
|
518
|
-
#
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
567
|
+
# 分角色统计
|
|
568
|
+
role_table = Table(title="📋 分角色统计")
|
|
569
|
+
role_table.add_column("角色", style="cyan")
|
|
570
|
+
role_table.add_column("Token 数", justify="right")
|
|
571
|
+
role_table.add_column("占比", justify="right")
|
|
523
572
|
|
|
524
573
|
total = stats["total_tokens"]
|
|
525
574
|
for role, key in [
|
|
@@ -529,22 +578,27 @@ def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
|
|
|
529
578
|
]:
|
|
530
579
|
tokens = stats.get(key, 0)
|
|
531
580
|
pct = tokens / total * 100 if total > 0 else 0
|
|
532
|
-
|
|
581
|
+
role_table.add_row(role, f"{tokens:,}", f"{pct:.1f}%")
|
|
533
582
|
|
|
534
|
-
console.print(
|
|
583
|
+
console.print(role_table)
|
|
535
584
|
console.print(f"\n平均对话轮数: {stats.get('avg_turns', 0)}")
|
|
536
585
|
|
|
537
586
|
except ImportError:
|
|
538
587
|
# 没有 rich,使用普通打印
|
|
588
|
+
std = stats.get("std_tokens", 0)
|
|
539
589
|
print(f"\n{'=' * 40}")
|
|
540
590
|
print("📊 Token 统计概览")
|
|
541
591
|
print(f"{'=' * 40}")
|
|
542
592
|
print(f"总样本数: {stats['count']:,}")
|
|
543
593
|
print(f"总 Token: {stats['total_tokens']:,}")
|
|
544
|
-
print(f"平均 Token: {stats['avg_tokens']:,}")
|
|
545
|
-
print(f"中位数: {stats['median_tokens']:,}")
|
|
594
|
+
print(f"平均 Token: {stats['avg_tokens']:,} (std: {std:.1f})")
|
|
546
595
|
print(f"范围: {stats['min_tokens']:,} - {stats['max_tokens']:,}")
|
|
547
596
|
|
|
597
|
+
print(f"\n📈 百分位分布:")
|
|
598
|
+
print(f" P25: {stats.get('p25', '-'):,} P50: {stats.get('median_tokens', '-'):,}")
|
|
599
|
+
print(f" P75: {stats.get('p75', '-'):,} P90: {stats.get('p90', '-'):,}")
|
|
600
|
+
print(f" P95: {stats.get('p95', '-'):,} P99: {stats.get('p99', '-'):,}")
|
|
601
|
+
|
|
548
602
|
if detailed:
|
|
549
603
|
print(f"\n{'=' * 40}")
|
|
550
604
|
print("📋 分角色统计")
|
|
@@ -566,24 +620,48 @@ def _print_text_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
|
|
|
566
620
|
try:
|
|
567
621
|
from rich.console import Console
|
|
568
622
|
from rich.panel import Panel
|
|
623
|
+
from rich.table import Table
|
|
569
624
|
|
|
570
625
|
console = Console()
|
|
571
626
|
|
|
627
|
+
std = stats.get("std_tokens", 0)
|
|
572
628
|
overview = (
|
|
573
629
|
f"[bold]总样本数:[/bold] {stats['count']:,}\n"
|
|
574
630
|
f"[bold]总 Token:[/bold] {stats['total_tokens']:,}\n"
|
|
575
|
-
f"[bold]平均 Token:[/bold] {stats['avg_tokens']:.1f}\n"
|
|
576
|
-
f"[bold]中位数:[/bold] {stats['median_tokens']:,}\n"
|
|
631
|
+
f"[bold]平均 Token:[/bold] {stats['avg_tokens']:.1f} (std: {std:.1f})\n"
|
|
577
632
|
f"[bold]范围:[/bold] {stats['min_tokens']:,} - {stats['max_tokens']:,}"
|
|
578
633
|
)
|
|
579
634
|
console.print(Panel(overview, title="📊 Token 统计", expand=False))
|
|
580
635
|
|
|
636
|
+
# 百分位数表格
|
|
637
|
+
table = Table(title="📈 分布统计")
|
|
638
|
+
table.add_column("百分位", style="cyan", justify="center")
|
|
639
|
+
table.add_column("Token 数", justify="right")
|
|
640
|
+
percentiles = [
|
|
641
|
+
("Min", stats["min_tokens"]),
|
|
642
|
+
("P25", stats.get("p25", "-")),
|
|
643
|
+
("P50 (中位数)", stats.get("median_tokens", "-")),
|
|
644
|
+
("P75", stats.get("p75", "-")),
|
|
645
|
+
("P90", stats.get("p90", "-")),
|
|
646
|
+
("P95", stats.get("p95", "-")),
|
|
647
|
+
("P99", stats.get("p99", "-")),
|
|
648
|
+
("Max", stats["max_tokens"]),
|
|
649
|
+
]
|
|
650
|
+
for name, val in percentiles:
|
|
651
|
+
table.add_row(name, f"{val:,}" if isinstance(val, int) else str(val))
|
|
652
|
+
console.print(table)
|
|
653
|
+
|
|
581
654
|
except ImportError:
|
|
655
|
+
std = stats.get("std_tokens", 0)
|
|
582
656
|
print(f"\n{'=' * 40}")
|
|
583
657
|
print("📊 Token 统计")
|
|
584
658
|
print(f"{'=' * 40}")
|
|
585
659
|
print(f"总样本数: {stats['count']:,}")
|
|
586
660
|
print(f"总 Token: {stats['total_tokens']:,}")
|
|
587
|
-
print(f"平均 Token: {stats['avg_tokens']:.1f}")
|
|
588
|
-
print(f"中位数: {stats['median_tokens']:,}")
|
|
661
|
+
print(f"平均 Token: {stats['avg_tokens']:.1f} (std: {std:.1f})")
|
|
589
662
|
print(f"范围: {stats['min_tokens']:,} - {stats['max_tokens']:,}")
|
|
663
|
+
|
|
664
|
+
print(f"\n📈 百分位分布:")
|
|
665
|
+
print(f" P25: {stats.get('p25', '-'):,} P50: {stats.get('median_tokens', '-'):,}")
|
|
666
|
+
print(f" P75: {stats.get('p75', '-'):,} P90: {stats.get('p90', '-'):,}")
|
|
667
|
+
print(f" P95: {stats.get('p95', '-'):,} P99: {stats.get('p99', '-'):,}")
|
dtflow/cli/validate.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CLI Schema 验证命令
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
from ..schema import (
|
|
9
|
+
Schema,
|
|
10
|
+
Field,
|
|
11
|
+
alpaca_schema,
|
|
12
|
+
dpo_schema,
|
|
13
|
+
openai_chat_schema,
|
|
14
|
+
sharegpt_schema,
|
|
15
|
+
)
|
|
16
|
+
from ..storage.io import load_data, save_data
|
|
17
|
+
from .common import _check_file_format
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# 预设 Schema 映射
|
|
21
|
+
PRESET_SCHEMAS = {
|
|
22
|
+
"openai_chat": openai_chat_schema,
|
|
23
|
+
"openai-chat": openai_chat_schema,
|
|
24
|
+
"chat": openai_chat_schema,
|
|
25
|
+
"alpaca": alpaca_schema,
|
|
26
|
+
"dpo": dpo_schema,
|
|
27
|
+
"dpo_pair": dpo_schema,
|
|
28
|
+
"sharegpt": sharegpt_schema,
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def validate(
|
|
33
|
+
filename: str,
|
|
34
|
+
preset: Optional[str] = None,
|
|
35
|
+
output: Optional[str] = None,
|
|
36
|
+
filter_invalid: bool = False,
|
|
37
|
+
max_errors: int = 20,
|
|
38
|
+
verbose: bool = False,
|
|
39
|
+
) -> None:
|
|
40
|
+
"""
|
|
41
|
+
使用 Schema 验证数据文件。
|
|
42
|
+
|
|
43
|
+
Args:
|
|
44
|
+
filename: 输入文件路径
|
|
45
|
+
preset: 预设 Schema 名称 (openai_chat, alpaca, dpo, sharegpt)
|
|
46
|
+
output: 输出文件路径(保存有效数据)
|
|
47
|
+
filter_invalid: 过滤无效数据并保存
|
|
48
|
+
max_errors: 最多显示的错误数量
|
|
49
|
+
verbose: 显示详细信息
|
|
50
|
+
|
|
51
|
+
Examples:
|
|
52
|
+
dt validate data.jsonl --preset=openai_chat
|
|
53
|
+
dt validate data.jsonl --preset=alpaca -o valid.jsonl
|
|
54
|
+
dt validate data.jsonl --preset=chat --filter
|
|
55
|
+
"""
|
|
56
|
+
filepath = Path(filename)
|
|
57
|
+
|
|
58
|
+
if not filepath.exists():
|
|
59
|
+
print(f"错误: 文件不存在 - {filename}")
|
|
60
|
+
return
|
|
61
|
+
|
|
62
|
+
if not _check_file_format(filepath):
|
|
63
|
+
return
|
|
64
|
+
|
|
65
|
+
# 确定 Schema
|
|
66
|
+
if preset is None:
|
|
67
|
+
# 列出可用的预设
|
|
68
|
+
print("请指定预设 Schema (--preset):")
|
|
69
|
+
print()
|
|
70
|
+
for name in ["openai_chat", "alpaca", "dpo", "sharegpt"]:
|
|
71
|
+
print(f" --preset={name}")
|
|
72
|
+
print()
|
|
73
|
+
print("示例:")
|
|
74
|
+
print(f" dt validate {filename} --preset=openai_chat")
|
|
75
|
+
return
|
|
76
|
+
|
|
77
|
+
preset_lower = preset.lower().replace("-", "_")
|
|
78
|
+
if preset_lower not in PRESET_SCHEMAS:
|
|
79
|
+
print(f"错误: 未知的预设 Schema '{preset}'")
|
|
80
|
+
print(f"可用预设: {', '.join(['openai_chat', 'alpaca', 'dpo', 'sharegpt'])}")
|
|
81
|
+
return
|
|
82
|
+
|
|
83
|
+
schema = PRESET_SCHEMAS[preset_lower]()
|
|
84
|
+
|
|
85
|
+
# 加载数据
|
|
86
|
+
try:
|
|
87
|
+
data = load_data(str(filepath))
|
|
88
|
+
except Exception as e:
|
|
89
|
+
print(f"错误: 无法读取文件 - {e}")
|
|
90
|
+
return
|
|
91
|
+
|
|
92
|
+
if not data:
|
|
93
|
+
print("文件为空")
|
|
94
|
+
return
|
|
95
|
+
|
|
96
|
+
total = len(data)
|
|
97
|
+
print(f"验证文件: {filepath.name}")
|
|
98
|
+
print(f"预设 Schema: {preset}")
|
|
99
|
+
print(f"总记录数: {total}")
|
|
100
|
+
print()
|
|
101
|
+
|
|
102
|
+
# 验证
|
|
103
|
+
valid_data = []
|
|
104
|
+
invalid_count = 0
|
|
105
|
+
error_samples = []
|
|
106
|
+
|
|
107
|
+
for i, item in enumerate(data):
|
|
108
|
+
result = schema.validate(item)
|
|
109
|
+
if result.valid:
|
|
110
|
+
valid_data.append(item)
|
|
111
|
+
else:
|
|
112
|
+
invalid_count += 1
|
|
113
|
+
if len(error_samples) < max_errors:
|
|
114
|
+
error_samples.append((i, result))
|
|
115
|
+
|
|
116
|
+
valid_count = len(valid_data)
|
|
117
|
+
valid_ratio = valid_count / total * 100 if total > 0 else 0
|
|
118
|
+
|
|
119
|
+
# 输出结果
|
|
120
|
+
if invalid_count == 0:
|
|
121
|
+
print(f"✅ 全部通过! {valid_count}/{total} 条记录有效 (100%)")
|
|
122
|
+
else:
|
|
123
|
+
print(f"⚠️ 验证结果: {valid_count}/{total} 条有效 ({valid_ratio:.1f}%)")
|
|
124
|
+
print(f" 无效记录: {invalid_count} 条")
|
|
125
|
+
print()
|
|
126
|
+
|
|
127
|
+
# 显示错误示例
|
|
128
|
+
print(f"错误示例 (最多显示 {max_errors} 条):")
|
|
129
|
+
print("-" * 60)
|
|
130
|
+
|
|
131
|
+
for idx, result in error_samples:
|
|
132
|
+
print(f"[第 {idx} 行]")
|
|
133
|
+
for err in result.errors[:3]: # 每条记录最多显示 3 个错误
|
|
134
|
+
print(f" - {err}")
|
|
135
|
+
if len(result.errors) > 3:
|
|
136
|
+
print(f" ... 还有 {len(result.errors) - 3} 个错误")
|
|
137
|
+
print()
|
|
138
|
+
|
|
139
|
+
# 保存有效数据
|
|
140
|
+
if output or filter_invalid:
|
|
141
|
+
output_path = output or str(filepath).replace(
|
|
142
|
+
filepath.suffix, f"_valid{filepath.suffix}"
|
|
143
|
+
)
|
|
144
|
+
save_data(valid_data, output_path)
|
|
145
|
+
print(f"✅ 有效数据已保存: {output_path} ({valid_count} 条)")
|
|
146
|
+
|
|
147
|
+
# 详细模式:显示 Schema 定义
|
|
148
|
+
if verbose:
|
|
149
|
+
print()
|
|
150
|
+
print("Schema 定义:")
|
|
151
|
+
print("-" * 40)
|
|
152
|
+
print(schema)
|