dtflow 0.4.3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dtflow/__init__.py CHANGED
@@ -4,6 +4,7 @@ DataTransformer: 简洁的数据格式转换工具
4
4
  核心功能:
5
5
  - DataTransformer: 数据加载、转换、保存
6
6
  - presets: 预设转换模板 (openai_chat, alpaca, sharegpt, dpo_pair, simple_qa)
7
+ - schema: 数据结构验证 (Schema, Field)
7
8
  - tokenizers: Token 统计和过滤
8
9
  - converters: HuggingFace/OpenAI 等格式转换
9
10
  """
@@ -26,6 +27,23 @@ from .converters import ( # LLaMA-Factory 扩展; ms-swift
26
27
  )
27
28
  from .core import DataTransformer, DictWrapper, TransformError, TransformErrors
28
29
  from .presets import get_preset, list_presets
30
+ from .schema import (
31
+ Field,
32
+ Schema,
33
+ ValidationError,
34
+ ValidationResult,
35
+ alpaca_schema,
36
+ dpo_schema,
37
+ openai_chat_schema,
38
+ sharegpt_schema,
39
+ validate_data,
40
+ )
41
+ from .framework import (
42
+ CompatibilityResult,
43
+ check_compatibility,
44
+ detect_format,
45
+ export_for,
46
+ )
29
47
  from .storage import load_data, sample_file, save_data
30
48
  from .streaming import StreamingTransformer, load_sharded, load_stream, process_shards
31
49
  from .tokenizers import (
@@ -42,7 +60,7 @@ from .tokenizers import (
42
60
  token_stats,
43
61
  )
44
62
 
45
- __version__ = "0.4.3"
63
+ __version__ = "0.5.2"
46
64
 
47
65
  __all__ = [
48
66
  # core
@@ -53,6 +71,21 @@ __all__ = [
53
71
  # presets
54
72
  "get_preset",
55
73
  "list_presets",
74
+ # schema
75
+ "Schema",
76
+ "Field",
77
+ "ValidationResult",
78
+ "ValidationError",
79
+ "validate_data",
80
+ "openai_chat_schema",
81
+ "alpaca_schema",
82
+ "dpo_schema",
83
+ "sharegpt_schema",
84
+ # framework
85
+ "CompatibilityResult",
86
+ "check_compatibility",
87
+ "detect_format",
88
+ "export_for",
56
89
  # storage
57
90
  "save_data",
58
91
  "load_data",
dtflow/__main__.py CHANGED
@@ -18,6 +18,7 @@ Commands:
18
18
  clean 数据清洗
19
19
  run 执行 Pipeline 配置文件
20
20
  history 显示数据血缘历史
21
+ validate 使用 Schema 验证数据格式
21
22
  mcp MCP 服务管理(install/uninstall/status)
22
23
  logs 日志查看工具使用说明
23
24
  """
@@ -40,6 +41,7 @@ from .cli.commands import stats as _stats
40
41
  from .cli.commands import tail as _tail
41
42
  from .cli.commands import token_stats as _token_stats
42
43
  from .cli.commands import transform as _transform
44
+ from .cli.commands import validate as _validate
43
45
 
44
46
  # 创建主应用
45
47
  app = typer.Typer(
@@ -211,6 +213,26 @@ def history(
211
213
  _history(filename, json)
212
214
 
213
215
 
216
# ============ Validation command ============


@app.command()
def validate(
    filename: str = typer.Argument(..., help="输入文件路径"),
    preset: Optional[str] = typer.Option(
        None, "--preset", "-p", help="预设 Schema: openai_chat, alpaca, dpo, sharegpt"
    ),
    output: Optional[str] = typer.Option(
        None, "--output", "-o", help="输出有效数据的文件路径"
    ),
    filter: bool = typer.Option(False, "--filter", "-f", help="过滤无效数据并保存"),
    max_errors: int = typer.Option(20, "--max-errors", help="最多显示的错误数量"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="显示详细信息"),
):
    """使用预设 Schema 验证数据格式"""
    # The docstring above is rendered by typer as the `--help` text of this
    # command, so it stays verbatim. In English: "Validate the data format
    # using a preset Schema."
    #
    # Thin CLI wrapper: all logic lives in the `validate` implementation
    # imported as `_validate`. The `filter` flag (whose name shadows the builtin
    # only inside this wrapper) is forwarded as that function's `filter_invalid`.
    _validate(
        filename,
        preset=preset,
        output=output,
        filter_invalid=filter,
        max_errors=max_errors,
        verbose=verbose,
    )
234
+
235
+
214
236
  # ============ 工具命令 ============
215
237
 
216
238
 
dtflow/cli/commands.py CHANGED
@@ -33,6 +33,9 @@ from .pipeline import run
33
33
  # 血缘追踪命令
34
34
  from .lineage import history
35
35
 
36
+ # 验证命令
37
+ from .validate import validate
38
+
36
39
  __all__ = [
37
40
  # 采样
38
41
  "sample",
@@ -53,4 +56,6 @@ __all__ = [
53
56
  "run",
54
57
  # 血缘
55
58
  "history",
59
+ # 验证
60
+ "validate",
56
61
  ]
dtflow/cli/common.py CHANGED
@@ -57,7 +57,7 @@ def _get_file_row_count(filepath: Path) -> Optional[int]:
57
57
  return None
58
58
 
59
59
 
60
- def _format_value(value: Any, max_len: int = 80) -> str:
60
+ def _format_value(value: Any, max_len: int = 120) -> str:
61
61
  """格式化单个值,长文本截断。"""
62
62
  if value is None:
63
63
  return "[dim]null[/dim]"
@@ -66,18 +66,22 @@ def _format_value(value: Any, max_len: int = 80) -> str:
66
66
  if isinstance(value, (int, float)):
67
67
  return f"[cyan]{value}[/cyan]"
68
68
  if isinstance(value, str):
69
+ half_len = max_len // 2
69
70
  # 处理多行文本
70
71
  if "\n" in value:
71
72
  lines = value.split("\n")
72
- if len(lines) > 3:
73
- preview = lines[0][:max_len] + f"... [dim]({len(lines)} 行)[/dim]"
74
- else:
75
- preview = value.replace("\n", "\\n")
76
- if len(preview) > max_len:
77
- preview = preview[:max_len] + "..."
73
+ preview = value.replace("\n", "\\n")
74
+ if len(preview) > max_len:
75
+ # 前半 + 省略标记 + 后半
76
+ head = preview[:half_len]
77
+ tail = preview[-half_len:]
78
+ return f'"{head} [yellow]<<<{len(lines)}行>>>[/yellow] {tail}"'
78
79
  return f'"{preview}"'
79
80
  if len(value) > max_len:
80
- return f'"{value[:max_len]}..." [dim]({len(value)} 字符)[/dim]'
81
+ # 前半 + 省略标记 + 后半
82
+ head = value[:half_len]
83
+ tail = value[-half_len:]
84
+ return f'"{head} [yellow]<<<{len(value)}字符>>>[/yellow] {tail}"'
81
85
  return f'"{value}"'
82
86
  return str(value)
83
87
 
@@ -86,7 +90,7 @@ def _format_nested(
86
90
  value: Any,
87
91
  indent: str = "",
88
92
  is_last: bool = True,
89
- max_len: int = 80,
93
+ max_len: int = 120,
90
94
  ) -> List[str]:
91
95
  """
92
96
  递归格式化嵌套结构,返回行列表。
dtflow/cli/stats.py CHANGED
@@ -465,34 +465,65 @@ def token_stats(
465
465
  return
466
466
 
467
467
  total = len(data)
468
- print(f" 共 {total} 条数据")
469
- print(f"🔢 统计 Token (模型: {model}, 字段: {field})...")
468
+ print(f" 共 {total:,} 条数据")
470
469
 
471
470
  # 检查字段类型并选择合适的统计方法(支持嵌套路径)
472
471
  sample = data[0]
473
472
  field_value = get_field_with_spec(sample, field)
474
473
 
474
+ # 尝试使用 rich 进度条
475
475
  try:
476
- if isinstance(field_value, list) and field_value and isinstance(field_value[0], dict):
477
- # messages 格式
478
- from ..tokenizers import messages_token_stats
479
-
480
- stats_result = messages_token_stats(data, messages_field=field, model=model)
481
- _print_messages_token_stats(stats_result, detailed)
482
- else:
483
- # 普通文本字段
484
- from ..tokenizers import token_stats as compute_token_stats
485
-
486
- stats_result = compute_token_stats(data, fields=field, model=model)
487
- _print_text_token_stats(stats_result, detailed)
488
- except ImportError as e:
489
- print(f"错误: {e}")
490
- return
491
- except Exception as e:
492
- print(f"错误: 统计失败 - {e}")
493
- import traceback
476
+ from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn
477
+
478
+ with Progress(
479
+ SpinnerColumn(),
480
+ TextColumn("[bold blue]统计 Token"),
481
+ BarColumn(),
482
+ TaskProgressColumn(),
483
+ TextColumn(f"(模型: {model})"),
484
+ ) as progress:
485
+ task = progress.add_task("", total=total)
486
+
487
+ def update_progress(current: int, total_count: int):
488
+ progress.update(task, completed=current)
489
+
490
+ if isinstance(field_value, list) and field_value and isinstance(field_value[0], dict):
491
+ from ..tokenizers import messages_token_stats
492
+
493
+ stats_result = messages_token_stats(
494
+ data, messages_field=field, model=model, progress_callback=update_progress
495
+ )
496
+ _print_messages_token_stats(stats_result, detailed)
497
+ else:
498
+ from ..tokenizers import token_stats as compute_token_stats
499
+
500
+ stats_result = compute_token_stats(
501
+ data, fields=field, model=model, progress_callback=update_progress
502
+ )
503
+ _print_text_token_stats(stats_result, detailed)
504
+
505
+ except ImportError:
506
+ # 没有 rich,显示简单进度
507
+ print(f"🔢 统计 Token (模型: {model}, 字段: {field})...")
508
+ try:
509
+ if isinstance(field_value, list) and field_value and isinstance(field_value[0], dict):
510
+ from ..tokenizers import messages_token_stats
511
+
512
+ stats_result = messages_token_stats(data, messages_field=field, model=model)
513
+ _print_messages_token_stats(stats_result, detailed)
514
+ else:
515
+ from ..tokenizers import token_stats as compute_token_stats
494
516
 
495
- traceback.print_exc()
517
+ stats_result = compute_token_stats(data, fields=field, model=model)
518
+ _print_text_token_stats(stats_result, detailed)
519
+ except ImportError as e:
520
+ print(f"错误: {e}")
521
+ return
522
+ except Exception as e:
523
+ print(f"错误: 统计失败 - {e}")
524
+ import traceback
525
+
526
+ traceback.print_exc()
496
527
 
497
528
 
498
529
  def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
@@ -505,21 +536,39 @@ def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
505
536
  console = Console()
506
537
 
507
538
  # 概览
539
+ std = stats.get("std_tokens", 0)
508
540
  overview = (
509
541
  f"[bold]总样本数:[/bold] {stats['count']:,}\n"
510
542
  f"[bold]总 Token:[/bold] {stats['total_tokens']:,}\n"
511
- f"[bold]平均 Token:[/bold] {stats['avg_tokens']:,}\n"
512
- f"[bold]中位数:[/bold] {stats['median_tokens']:,}\n"
543
+ f"[bold]平均 Token:[/bold] {stats['avg_tokens']:,} (std: {std:.1f})\n"
513
544
  f"[bold]范围:[/bold] {stats['min_tokens']:,} - {stats['max_tokens']:,}"
514
545
  )
515
546
  console.print(Panel(overview, title="📊 Token 统计概览", expand=False))
516
547
 
548
+ # 百分位数表格
549
+ table = Table(title="📈 分布统计")
550
+ table.add_column("百分位", style="cyan", justify="center")
551
+ table.add_column("Token 数", justify="right")
552
+ percentiles = [
553
+ ("Min", stats["min_tokens"]),
554
+ ("P25", stats.get("p25", "-")),
555
+ ("P50 (中位数)", stats.get("median_tokens", "-")),
556
+ ("P75", stats.get("p75", "-")),
557
+ ("P90", stats.get("p90", "-")),
558
+ ("P95", stats.get("p95", "-")),
559
+ ("P99", stats.get("p99", "-")),
560
+ ("Max", stats["max_tokens"]),
561
+ ]
562
+ for name, val in percentiles:
563
+ table.add_row(name, f"{val:,}" if isinstance(val, int) else str(val))
564
+ console.print(table)
565
+
517
566
  if detailed:
518
- # 详细统计
519
- table = Table(title="📋 分角色统计")
520
- table.add_column("角色", style="cyan")
521
- table.add_column("Token 数", justify="right")
522
- table.add_column("占比", justify="right")
567
+ # 分角色统计
568
+ role_table = Table(title="📋 分角色统计")
569
+ role_table.add_column("角色", style="cyan")
570
+ role_table.add_column("Token 数", justify="right")
571
+ role_table.add_column("占比", justify="right")
523
572
 
524
573
  total = stats["total_tokens"]
525
574
  for role, key in [
@@ -529,22 +578,27 @@ def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
529
578
  ]:
530
579
  tokens = stats.get(key, 0)
531
580
  pct = tokens / total * 100 if total > 0 else 0
532
- table.add_row(role, f"{tokens:,}", f"{pct:.1f}%")
581
+ role_table.add_row(role, f"{tokens:,}", f"{pct:.1f}%")
533
582
 
534
- console.print(table)
583
+ console.print(role_table)
535
584
  console.print(f"\n平均对话轮数: {stats.get('avg_turns', 0)}")
536
585
 
537
586
  except ImportError:
538
587
  # 没有 rich,使用普通打印
588
+ std = stats.get("std_tokens", 0)
539
589
  print(f"\n{'=' * 40}")
540
590
  print("📊 Token 统计概览")
541
591
  print(f"{'=' * 40}")
542
592
  print(f"总样本数: {stats['count']:,}")
543
593
  print(f"总 Token: {stats['total_tokens']:,}")
544
- print(f"平均 Token: {stats['avg_tokens']:,}")
545
- print(f"中位数: {stats['median_tokens']:,}")
594
+ print(f"平均 Token: {stats['avg_tokens']:,} (std: {std:.1f})")
546
595
  print(f"范围: {stats['min_tokens']:,} - {stats['max_tokens']:,}")
547
596
 
597
+ print(f"\n📈 百分位分布:")
598
+ print(f" P25: {stats.get('p25', '-'):,} P50: {stats.get('median_tokens', '-'):,}")
599
+ print(f" P75: {stats.get('p75', '-'):,} P90: {stats.get('p90', '-'):,}")
600
+ print(f" P95: {stats.get('p95', '-'):,} P99: {stats.get('p99', '-'):,}")
601
+
548
602
  if detailed:
549
603
  print(f"\n{'=' * 40}")
550
604
  print("📋 分角色统计")
@@ -566,24 +620,48 @@ def _print_text_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
566
620
  try:
567
621
  from rich.console import Console
568
622
  from rich.panel import Panel
623
+ from rich.table import Table
569
624
 
570
625
  console = Console()
571
626
 
627
+ std = stats.get("std_tokens", 0)
572
628
  overview = (
573
629
  f"[bold]总样本数:[/bold] {stats['count']:,}\n"
574
630
  f"[bold]总 Token:[/bold] {stats['total_tokens']:,}\n"
575
- f"[bold]平均 Token:[/bold] {stats['avg_tokens']:.1f}\n"
576
- f"[bold]中位数:[/bold] {stats['median_tokens']:,}\n"
631
+ f"[bold]平均 Token:[/bold] {stats['avg_tokens']:.1f} (std: {std:.1f})\n"
577
632
  f"[bold]范围:[/bold] {stats['min_tokens']:,} - {stats['max_tokens']:,}"
578
633
  )
579
634
  console.print(Panel(overview, title="📊 Token 统计", expand=False))
580
635
 
636
+ # 百分位数表格
637
+ table = Table(title="📈 分布统计")
638
+ table.add_column("百分位", style="cyan", justify="center")
639
+ table.add_column("Token 数", justify="right")
640
+ percentiles = [
641
+ ("Min", stats["min_tokens"]),
642
+ ("P25", stats.get("p25", "-")),
643
+ ("P50 (中位数)", stats.get("median_tokens", "-")),
644
+ ("P75", stats.get("p75", "-")),
645
+ ("P90", stats.get("p90", "-")),
646
+ ("P95", stats.get("p95", "-")),
647
+ ("P99", stats.get("p99", "-")),
648
+ ("Max", stats["max_tokens"]),
649
+ ]
650
+ for name, val in percentiles:
651
+ table.add_row(name, f"{val:,}" if isinstance(val, int) else str(val))
652
+ console.print(table)
653
+
581
654
  except ImportError:
655
+ std = stats.get("std_tokens", 0)
582
656
  print(f"\n{'=' * 40}")
583
657
  print("📊 Token 统计")
584
658
  print(f"{'=' * 40}")
585
659
  print(f"总样本数: {stats['count']:,}")
586
660
  print(f"总 Token: {stats['total_tokens']:,}")
587
- print(f"平均 Token: {stats['avg_tokens']:.1f}")
588
- print(f"中位数: {stats['median_tokens']:,}")
661
+ print(f"平均 Token: {stats['avg_tokens']:.1f} (std: {std:.1f})")
589
662
  print(f"范围: {stats['min_tokens']:,} - {stats['max_tokens']:,}")
663
+
664
+ print(f"\n📈 百分位分布:")
665
+ print(f" P25: {stats.get('p25', '-'):,} P50: {stats.get('median_tokens', '-'):,}")
666
+ print(f" P75: {stats.get('p75', '-'):,} P90: {stats.get('p90', '-'):,}")
667
+ print(f" P95: {stats.get('p95', '-'):,} P99: {stats.get('p99', '-'):,}")
dtflow/cli/validate.py ADDED
@@ -0,0 +1,152 @@
1
+ """
2
+ CLI Schema 验证命令
3
+ """
4
+
5
+ from pathlib import Path
6
+ from typing import Optional
7
+
8
+ from ..schema import (
9
+ Schema,
10
+ Field,
11
+ alpaca_schema,
12
+ dpo_schema,
13
+ openai_chat_schema,
14
+ sharegpt_schema,
15
+ )
16
+ from ..storage.io import load_data, save_data
17
+ from .common import _check_file_format
18
+
19
+
20
# Preset schema registry: every accepted --preset spelling (canonical names
# plus aliases) maps to the factory function that builds the matching Schema.
# Note: validate() lower-cases the preset and turns "-" into "_" before the
# lookup, so hyphenated keys such as "openai-chat" are never hit by it.
PRESET_SCHEMAS = {
    alias: factory
    for factory, aliases in (
        (openai_chat_schema, ("openai_chat", "openai-chat", "chat")),
        (alpaca_schema, ("alpaca",)),
        (dpo_schema, ("dpo", "dpo_pair")),
        (sharegpt_schema, ("sharegpt",)),
    )
    for alias in aliases
}
30
+
31
+
32
def _resolve_schema(preset: Optional[str], filename: str) -> Optional["Schema"]:
    """Turn a ``--preset`` value into a Schema instance.

    Prints usage (when *preset* is None) or an error (when the name is
    unknown) and returns None in those cases. The lookup is case-insensitive
    and treats ``-`` as ``_``.
    """
    available = ("openai_chat", "alpaca", "dpo", "sharegpt")

    if preset is None:
        # No preset given: list the choices instead of guessing one.
        print("请指定预设 Schema (--preset):")
        print()
        for name in available:
            print(f" --preset={name}")
        print()
        print("示例:")
        print(f" dt validate {filename} --preset=openai_chat")
        return None

    factory = PRESET_SCHEMAS.get(preset.lower().replace("-", "_"))
    if factory is None:
        print(f"错误: 未知的预设 Schema '{preset}'")
        print(f"可用预设: {', '.join(available)}")
        return None
    return factory()


def _default_output_path(filepath: Path) -> Path:
    """Return ``<stem>_valid<suffix>`` in the same directory as *filepath*.

    Only the final path component is rewritten. A plain ``str.replace`` of the
    suffix would also rewrite matching text in directory names, and with an
    empty suffix it would insert ``_valid`` between every character.
    """
    return filepath.with_name(f"{filepath.stem}_valid{filepath.suffix}")


def _print_error_samples(error_samples: list, max_errors: int) -> None:
    """Print the collected ``(index, result)`` pairs, up to 3 errors per record."""
    print(f"错误示例 (最多显示 {max_errors} 条):")
    print("-" * 60)
    for idx, result in error_samples:
        # idx is the 0-based position in the loaded data; report it 1-based so
        # the first record is shown as row 1 rather than row 0.
        print(f"[第 {idx + 1} 行]")
        for err in result.errors[:3]:
            print(f" - {err}")
        hidden = len(result.errors) - 3
        if hidden > 0:
            print(f" ... 还有 {hidden} 个错误")
        print()


def validate(
    filename: str,
    preset: Optional[str] = None,
    output: Optional[str] = None,
    filter_invalid: bool = False,
    max_errors: int = 20,
    verbose: bool = False,
) -> None:
    """
    Validate a data file against a preset Schema.

    Reports the valid/invalid counts and a sample of errors. It can also write
    the valid records to a new file. Problems (missing file, unknown preset,
    unreadable or unwritable file) are printed rather than raised.

    Args:
        filename: Path of the input file.
        preset: Preset Schema name (openai_chat, alpaca, dpo, sharegpt).
            Case-insensitive; ``-`` and ``_`` are interchangeable.
        output: Path to save the valid records to.
        filter_invalid: Save only the valid records. Without ``output`` they
            are written next to the input as ``<stem>_valid<suffix>``.
        max_errors: Maximum number of invalid records to display.
        verbose: Also print the Schema definition.

    Examples:
        dt validate data.jsonl --preset=openai_chat
        dt validate data.jsonl --preset=alpaca -o valid.jsonl
        dt validate data.jsonl --preset=chat --filter
    """
    filepath = Path(filename)

    if not filepath.exists():
        print(f"错误: 文件不存在 - {filename}")
        return

    if not _check_file_format(filepath):
        return

    schema = _resolve_schema(preset, filename)
    if schema is None:
        return

    # Load data
    try:
        data = load_data(str(filepath))
    except Exception as e:
        print(f"错误: 无法读取文件 - {e}")
        return

    if not data:
        print("文件为空")
        return

    total = len(data)
    print(f"验证文件: {filepath.name}")
    print(f"预设 Schema: {preset}")
    print(f"总记录数: {total}")
    print()

    # Validate every record, keeping at most `max_errors` failures for display.
    valid_data = []
    invalid_count = 0
    error_samples = []
    for i, item in enumerate(data):
        result = schema.validate(item)
        if result.valid:
            valid_data.append(item)
            continue
        invalid_count += 1
        if len(error_samples) < max_errors:
            error_samples.append((i, result))

    valid_count = len(valid_data)
    # total > 0 here: the empty-file case returned above.
    valid_ratio = valid_count / total * 100

    # Report results
    if invalid_count == 0:
        print(f"✅ 全部通过! {valid_count}/{total} 条记录有效 (100%)")
    else:
        print(f"⚠️ 验证结果: {valid_count}/{total} 条有效 ({valid_ratio:.1f}%)")
        print(f" 无效记录: {invalid_count} 条")
        print()
        # Skip the (empty) sample section when max_errors <= 0.
        if error_samples:
            _print_error_samples(error_samples, max_errors)

    # Save valid records
    if output or filter_invalid:
        output_path = output or str(_default_output_path(filepath))
        try:
            save_data(valid_data, output_path)
        except Exception as e:
            # Mirror the load-error handling above instead of crashing.
            print(f"错误: 无法保存文件 - {e}")
        else:
            print(f"✅ 有效数据已保存: {output_path} ({valid_count} 条)")

    # Verbose mode: show the Schema definition
    if verbose:
        print()
        print("Schema 定义:")
        print("-" * 40)
        print(schema)