dtflow 0.5.9__py3-none-any.whl → 0.5.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dtflow/SKILL.md +2 -0
- dtflow/__init__.py +1 -1
- dtflow/__main__.py +8 -2
- dtflow/cli/stats.py +23 -10
- dtflow/cli/validate.py +52 -19
- dtflow/parallel.py +115 -0
- dtflow/schema.py +99 -13
- dtflow/tokenizers.py +104 -21
- {dtflow-0.5.9.dist-info → dtflow-0.5.10.dist-info}/METADATA +3 -1
- {dtflow-0.5.9.dist-info → dtflow-0.5.10.dist-info}/RECORD +12 -11
- {dtflow-0.5.9.dist-info → dtflow-0.5.10.dist-info}/WHEEL +0 -0
- {dtflow-0.5.9.dist-info → dtflow-0.5.10.dist-info}/entry_points.txt +0 -0
dtflow/SKILL.md
CHANGED

@@ -154,6 +154,7 @@ dt token-stats data.jsonl # 默认统计 messages 字段
 dt token-stats data.jsonl -f text # 指定统计字段
 dt token-stats data.jsonl -m qwen2.5 # 指定分词器 (cl100k_base/qwen2.5/llama3)
 dt token-stats data.jsonl --detailed # 显示详细统计
+dt token-stats data.jsonl -w 4 # 多进程加速(数据量>=1000时自动启用)
 
 # 采样(支持字段路径语法)
 dt sample data.jsonl 100 # 随机采样 100 条
@@ -204,6 +205,7 @@ dt validate data.jsonl --preset=openai_chat # 预设: openai_chat/alpaca/d
 dt validate data.jsonl -p alpaca -f -o valid.jsonl # 过滤无效数据并保存
 dt validate data.jsonl -p openai_chat -v # 显示详细信息
 dt validate data.jsonl -p openai_chat --max-errors=50 # 最多显示 50 条错误
+dt validate data.jsonl -p openai_chat -w 4 # 多进程加速
 
 # 转换
 dt transform data.jsonl --preset=openai_chat

dtflow/__init__.py
CHANGED
dtflow/__main__.py
CHANGED

@@ -256,9 +256,12 @@ def token_stats(
         "cl100k_base", "--model", "-m", help="分词器: cl100k_base (默认), qwen2.5, llama3, gpt-4 等"
     ),
     detailed: bool = typer.Option(False, "--detailed", "-d", help="显示详细统计"),
+    workers: Optional[int] = typer.Option(
+        None, "--workers", "-w", help="并行进程数 (默认自动, 1 禁用并行)"
+    ),
 ):
     """统计数据集的 Token 信息"""
-    _token_stats(filename, field, model, detailed)
+    _token_stats(filename, field, model, detailed, workers)
 
 
 @app.command()
@@ -359,9 +362,12 @@ def validate(
     filter: bool = typer.Option(False, "--filter", "-f", help="过滤无效数据并保存"),
     max_errors: int = typer.Option(20, "--max-errors", help="最多显示的错误数量"),
     verbose: bool = typer.Option(False, "--verbose", "-v", help="显示详细信息"),
+    workers: Optional[int] = typer.Option(
+        None, "--workers", "-w", help="并行进程数 (默认自动, 1 禁用并行)"
+    ),
 ):
     """使用预设 Schema 验证数据格式"""
-    _validate(filename, preset, output, filter, max_errors, verbose)
+    _validate(filename, preset, output, filter, max_errors, verbose, workers)
 
 
 # ============ 工具命令 ============

dtflow/cli/stats.py
CHANGED

@@ -209,7 +209,7 @@ def _quick_stats(filepath: Path) -> None:
     print(f"字段: {len(fields)} 个")
 
     if fields:
-        print(
+        print("\n📋 字段结构:")
         for i, f in enumerate(fields, 1):
             print(f" {i}. {f['field']} ({f['type']})")
 
@@ -597,14 +597,14 @@ def _print_stats(filename: str, total: int, field_stats: List[Dict[str, Any]]) -
     except ImportError:
         # 没有 rich,使用普通打印
         print(f"\n{'=' * 50}")
-        print(
+        print("📊 数据概览")
         print(f"{'=' * 50}")
         print(f"文件: {filename}")
         print(f"总数: {total:,} 条")
         print(f"字段: {len(field_stats)} 个")
 
         print(f"\n{'=' * 50}")
-        print(
+        print("📋 字段统计")
         print(f"{'=' * 50}")
         print(f"{'字段':<20} {'类型':<8} {'非空率':<8} {'唯一值':<8}")
         print("-" * 50)
@@ -620,6 +620,7 @@ def token_stats(
     field: str = "messages",
     model: str = "cl100k_base",
     detailed: bool = False,
+    workers: Optional[int] = None,
 ) -> None:
     """
     统计数据集的 Token 信息。
@@ -629,6 +630,7 @@ def token_stats(
         field: 要统计的字段(默认 messages),支持嵌套路径语法
         model: 分词器: cl100k_base (默认), qwen2.5, llama3, gpt-4 等
         detailed: 是否显示详细统计
+        workers: 并行进程数,None 自动检测,1 禁用并行
 
     Examples:
         dt token-stats data.jsonl
@@ -636,6 +638,7 @@ def token_stats(
         dt token-stats data.jsonl --field=conversation.messages
         dt token-stats data.jsonl --field=messages[-1].content  # 统计最后一条消息
         dt token-stats data.jsonl --detailed
+        dt token-stats data.jsonl --workers=4  # 使用 4 进程
     """
     filepath = Path(filename)
 
@@ -667,7 +670,7 @@ def token_stats(
 
     # 尝试使用 rich 进度条
     try:
-        from rich.progress import Progress, SpinnerColumn,
+        from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
 
         with Progress(
             SpinnerColumn(),
@@ -685,14 +688,22 @@ def token_stats(
                 from ..tokenizers import messages_token_stats
 
                 stats_result = messages_token_stats(
-                    data,
+                    data,
+                    messages_field=field,
+                    model=model,
+                    progress_callback=update_progress,
+                    workers=workers,
                 )
                 _print_messages_token_stats(stats_result, detailed)
             else:
                 from ..tokenizers import token_stats as compute_token_stats
 
                 stats_result = compute_token_stats(
-                    data,
+                    data,
+                    fields=field,
+                    model=model,
+                    progress_callback=update_progress,
+                    workers=workers,
                 )
                 _print_text_token_stats(stats_result, detailed)
 
@@ -703,12 +714,14 @@ def token_stats(
         if isinstance(field_value, list) and field_value and isinstance(field_value[0], dict):
             from ..tokenizers import messages_token_stats
 
-            stats_result = messages_token_stats(
+            stats_result = messages_token_stats(
+                data, messages_field=field, model=model, workers=workers
+            )
             _print_messages_token_stats(stats_result, detailed)
         else:
             from ..tokenizers import token_stats as compute_token_stats
 
-            stats_result = compute_token_stats(data, fields=field, model=model)
+            stats_result = compute_token_stats(data, fields=field, model=model, workers=workers)
             _print_text_token_stats(stats_result, detailed)
     except ImportError as e:
         print(f"错误: {e}")
@@ -788,7 +801,7 @@ def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
     print(f"平均 Token: {stats['avg_tokens']:,} (std: {std:.1f})")
     print(f"范围: {stats['min_tokens']:,} - {stats['max_tokens']:,}")
 
-    print(
+    print("\n📈 百分位分布:")
     print(f" P25: {stats.get('p25', '-'):,} P50: {stats.get('median_tokens', '-'):,}")
     print(f" P75: {stats.get('p75', '-'):,} P90: {stats.get('p90', '-'):,}")
     print(f" P95: {stats.get('p95', '-'):,} P99: {stats.get('p99', '-'):,}")
@@ -855,7 +868,7 @@ def _print_text_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
     print(f"平均 Token: {stats['avg_tokens']:.1f} (std: {std:.1f})")
     print(f"范围: {stats['min_tokens']:,} - {stats['max_tokens']:,}")
 
-    print(
+    print("\n📈 百分位分布:")
     print(f" P25: {stats.get('p25', '-'):,} P50: {stats.get('median_tokens', '-'):,}")
     print(f" P75: {stats.get('p75', '-'):,} P90: {stats.get('p90', '-'):,}")
    print(f" P95: {stats.get('p95', '-'):,} P99: {stats.get('p99', '-'):,}")

dtflow/cli/validate.py
CHANGED

@@ -6,8 +6,6 @@ from pathlib import Path
 from typing import Optional
 
 from ..schema import (
-    Schema,
-    Field,
     alpaca_schema,
     dpo_schema,
     openai_chat_schema,
@@ -16,7 +14,6 @@ from ..schema import (
 from ..storage.io import load_data, save_data
 from .common import _check_file_format
 
-
 # 预设 Schema 映射
 PRESET_SCHEMAS = {
     "openai_chat": openai_chat_schema,
@@ -36,6 +33,7 @@ def validate(
     filter_invalid: bool = False,
     max_errors: int = 20,
     verbose: bool = False,
+    workers: Optional[int] = None,
 ) -> None:
     """
     使用 Schema 验证数据文件。
@@ -47,11 +45,13 @@ def validate(
         filter_invalid: 过滤无效数据并保存
         max_errors: 最多显示的错误数量
         verbose: 显示详细信息
+        workers: 并行进程数,None 自动检测,1 禁用并行
 
     Examples:
         dt validate data.jsonl --preset=openai_chat
         dt validate data.jsonl --preset=alpaca -o valid.jsonl
         dt validate data.jsonl --preset=chat --filter
+        dt validate data.jsonl --preset=chat --workers=4
     """
     filepath = Path(filename)
 
@@ -99,19 +99,54 @@
     print(f"总记录数: {total}")
     print()
 
-    #
-
-
-
-
-
-
-
-
-
-
-
-
+    # 验证(使用并行或串行)
+    use_parallel = workers != 1 and total >= 1000
+
+    if use_parallel:
+        # 使用进度条(如果有 rich)
+        try:
+            from rich.progress import (
+                BarColumn,
+                Progress,
+                SpinnerColumn,
+                TaskProgressColumn,
+                TextColumn,
+            )
+
+            with Progress(
+                SpinnerColumn(),
+                TextColumn("[bold blue]验证数据"),
+                BarColumn(),
+                TaskProgressColumn(),
+            ) as progress:
+                task = progress.add_task("", total=total)
+
+                def update_progress(current: int, total_count: int):
+                    progress.update(task, completed=current)
+
+                valid_data, invalid_results = schema.validate_parallel(
+                    data, workers=workers, progress_callback=update_progress
+                )
+        except ImportError:
+            print("🔍 验证数据...")
+            valid_data, invalid_results = schema.validate_parallel(data, workers=workers)
+
+        invalid_count = len(invalid_results)
+        error_samples = invalid_results[:max_errors]
+    else:
+        # 串行验证
+        valid_data = []
+        invalid_count = 0
+        error_samples = []
+
+        for i, item in enumerate(data):
+            result = schema.validate(item)
+            if result.valid:
+                valid_data.append(item)
+            else:
+                invalid_count += 1
+                if len(error_samples) < max_errors:
+                    error_samples.append((i, result))
 
     valid_count = len(valid_data)
     valid_ratio = valid_count / total * 100 if total > 0 else 0
@@ -138,9 +173,7 @@
 
     # 保存有效数据
     if output or filter_invalid:
-        output_path = output or str(filepath).replace(
-            filepath.suffix, f"_valid{filepath.suffix}"
-        )
+        output_path = output or str(filepath).replace(filepath.suffix, f"_valid{filepath.suffix}")
         save_data(valid_data, output_path)
         print(f"✅ 有效数据已保存: {output_path} ({valid_count} 条)")
 

dtflow/parallel.py
ADDED

@@ -0,0 +1,115 @@
+"""
+并行处理模块
+
+提供多进程并行处理工具,用于加速大数据集的 token 统计和 schema 验证。
+"""
+
+from multiprocessing import Pool, cpu_count
+from typing import Callable, List, Optional, TypeVar
+
+T = TypeVar("T")
+R = TypeVar("R")
+
+
+def parallel_map(
+    func: Callable[[T], R],
+    data: List[T],
+    workers: Optional[int] = None,
+    threshold: int = 1000,
+    chunksize: Optional[int] = None,
+) -> List[R]:
+    """
+    并行 map 操作。
+
+    Args:
+        func: 处理函数(必须可 pickle,不能是 lambda 或闭包)
+        data: 数据列表
+        workers: 进程数,None 则使用 CPU 核数
+        threshold: 数据量阈值,低于此值使用串行
+        chunksize: 每个进程的任务块大小,None 则自动计算
+
+    Returns:
+        处理结果列表(保持顺序)
+    """
+    n = len(data)
+
+    # 数据量小或指定单进程,使用串行
+    if n < threshold or workers == 1:
+        return [func(item) for item in data]
+
+    workers = workers or cpu_count()
+    workers = min(workers, n)  # 进程数不超过数据量
+
+    # 自动计算 chunksize
+    if chunksize is None:
+        chunksize = max(1, n // (workers * 4))
+
+    with Pool(processes=workers) as pool:
+        return pool.map(func, data, chunksize=chunksize)
+
+
+def parallel_imap(
+    func: Callable[[T], R],
+    data: List[T],
+    workers: Optional[int] = None,
+    threshold: int = 1000,
+    chunksize: Optional[int] = None,
+):
+    """
+    并行 imap 操作(惰性迭代器版本,支持进度回调)。
+
+    Args:
+        func: 处理函数(必须可 pickle)
+        data: 数据列表
+        workers: 进程数,None 则使用 CPU 核数
+        threshold: 数据量阈值,低于此值使用串行
+        chunksize: 每个进程的任务块大小
+
+    Yields:
+        处理结果(按顺序)
+    """
+    n = len(data)
+
+    # 数据量小或指定单进程,使用串行
+    if n < threshold or workers == 1:
+        for item in data:
+            yield func(item)
+        return
+
+    workers = workers or cpu_count()
+    workers = min(workers, n)
+
+    if chunksize is None:
+        chunksize = max(1, n // (workers * 4))
+
+    with Pool(processes=workers) as pool:
+        for result in pool.imap(func, data, chunksize=chunksize):
+            yield result
+
+
+def get_optimal_workers(data_size: int, default: Optional[int] = None) -> int:
+    """
+    根据数据量计算最优进程数。
+
+    Args:
+        data_size: 数据量
+        default: 用户指定的进程数,None 则自动计算
+
+    Returns:
+        最优进程数
+    """
+    if default is not None:
+        return default
+
+    cpu_cores = cpu_count()
+
+    # 数据量小于阈值,单进程
+    if data_size < 1000:
+        return 1
+
+    # 数据量适中,使用一半 CPU
+    if data_size < 10000:
+        return max(1, cpu_cores // 2)
+
+    # 大数据量,使用全部 CPU
+    return cpu_cores

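The helpers above are generic, so they can also be used outside the CLI. A minimal usage sketch (the `score` function and sample records below are hypothetical; the worker must be a module-level callable so it can be pickled, and the `__main__` guard is needed on platforms that spawn processes):

```python
from dtflow.parallel import get_optimal_workers, parallel_map


def score(record: dict) -> int:
    # Hypothetical per-record work; must be a top-level function, not a lambda.
    return len(str(record))


if __name__ == "__main__":
    records = [{"text": "x" * (i % 50)} for i in range(5000)]

    # 1 process below 1,000 records, half the CPUs below 10,000, all CPUs above that.
    workers = get_optimal_workers(len(records))

    # parallel_map itself falls back to a serial loop when len(data) < threshold or workers == 1.
    token_counts = parallel_map(score, records, workers=workers)
    print(len(token_counts), max(token_counts))
```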
dtflow/schema.py
CHANGED

@@ -26,10 +26,35 @@ Schema 验证模块
     results = dt.validate_schema(schema)
 """
 
-from dataclasses import dataclass
-from
+from dataclasses import dataclass
+from dataclasses import field as dataclass_field
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+
+from .utils.field_path import _parse_path, get_field
+
+
+def _validate_item_wrapper(args: tuple) -> Tuple[int, bool, list]:
+    """
+    验证单条数据(用于多进程)。
+
+    Args:
+        args: (index, item, schema_fields) 元组
+
+    Returns:
+        (index, is_valid, errors_as_dicts) - 返回字典列表而非对象(pickle 兼容)
+    """
+    idx, item, fields = args
+    # 在子进程中重建 Schema
+    schema = Schema(fields)
+    result = schema.validate(item)
+
+    if result.valid:
+        return (idx, True, [])
+    else:
+        # 将错误转换为字典(pickle 兼容)
+        errors = [{"path": e.path, "message": e.message, "value": e.value} for e in result.errors]
+        return (idx, False, errors)
 
-from .utils.field_path import get_field, _parse_path, _get_value_by_segments
 
 # 支持的类型
 FieldType = Literal["str", "int", "float", "bool", "list", "dict", "any"]
@@ -162,9 +187,7 @@ class Field:
 
         # 选项检查
         if self.choices is not None and value not in self.choices:
-            errors.append(
-                ValidationError(path, f"值必须是 {self.choices} 之一", value)
-            )
+            errors.append(ValidationError(path, f"值必须是 {self.choices} 之一", value))
 
         # 正则表达式检查
         if self.pattern is not None and isinstance(value, str):
@@ -324,9 +347,7 @@ class Schema:
 
         return errors
 
-    def validate_batch(
-        self, data: List[dict], max_errors: int = 100
-    ) -> List[tuple]:
+    def validate_batch(self, data: List[dict], max_errors: int = 100) -> List[tuple]:
         """
         批量验证数据
 
@@ -350,9 +371,76 @@ class Schema:
 
         return failed
 
+    def validate_parallel(
+        self,
+        data: List[dict],
+        workers: Optional[int] = None,
+        progress_callback: Optional[Callable[[int, int], None]] = None,
+    ) -> tuple:
+        """
+        并行验证数据列表。
+
+        Args:
+            data: 数据列表
+            workers: 进程数,None 自动检测,1 禁用并行
+            progress_callback: 进度回调函数
+
+        Returns:
+            (valid_data, invalid_indices_results) 元组
+            - valid_data: 有效数据列表
+            - invalid_indices_results: [(index, ValidationResult), ...] 无效数据
+        """
+        if not data:
+            return [], []
+
+        total = len(data)
+        use_parallel = workers != 1 and total >= 1000
+
+        valid_data = []
+        invalid_results = []
+
+        if use_parallel:
+            from .parallel import get_optimal_workers, parallel_imap
+
+            actual_workers = get_optimal_workers(total, workers)
+            # 准备参数:(index, item, schema_fields)
+            args_list = [(i, item, self._fields) for i, item in enumerate(data)]
+
+            for i, (idx, is_valid, result_data) in enumerate(
+                parallel_imap(
+                    _validate_item_wrapper,
+                    args_list,
+                    workers=actual_workers,
+                    threshold=1000,
+                )
+            ):
+                if is_valid:
+                    valid_data.append(data[idx])
+                else:
+                    # 重建 ValidationResult(因为不能直接 pickle)
+                    errors = [
+                        ValidationError(path=e["path"], message=e["message"], value=e.get("value"))
+                        for e in result_data
+                    ]
+                    invalid_results.append((idx, ValidationResult(valid=False, errors=errors)))
+                if progress_callback:
+                    progress_callback(i + 1, total)
+        else:
+            # 串行处理
+            for i, item in enumerate(data):
+                result = self.validate(item)
+                if result.valid:
+                    valid_data.append(item)
+                else:
+                    invalid_results.append((i, result))
+                if progress_callback:
+                    progress_callback(i + 1, total)
+
+        return valid_data, invalid_results
+
     def __repr__(self) -> str:
         field_strs = [f"  {path}: {field_def}" for path, field_def in self._fields.items()]
-        return
+        return "Schema({\n" + ",\n".join(field_strs) + "\n}})"
 
 
 # ============================================================================
@@ -461,9 +549,7 @@ def sharegpt_schema(
     """
     return Schema(
         {
-            "conversations": Field(
-                type="list", required=True, min_length=min_conversations
-            ),
+            "conversations": Field(type="list", required=True, min_length=min_conversations),
             "conversations[*].from": Field(
                 type="str", required=True, choices=[human_role, gpt_role]
             ),

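A minimal sketch of the new `Schema.validate_parallel` entry point; the field definitions and sample records here are illustrative, not taken from the package:

```python
from dtflow.schema import Field, Schema

# Hypothetical schema: one required string field, one optional int field.
schema = Schema(
    {
        "text": Field(type="str", required=True),
        "score": Field(type="int", required=False),
    }
)

if __name__ == "__main__":
    data = [{"text": "ok", "score": 1}] * 1500 + [{"score": "oops"}]

    # Below 1,000 records (or with workers=1) the serial branch runs; otherwise
    # items are validated in worker processes and the error dicts are rebuilt
    # into ValidationResult objects in the parent.
    valid_data, invalid_results = schema.validate_parallel(data, workers=4)

    print(len(valid_data))
    for idx, result in invalid_results:
        print(idx, [e.message for e in result.errors])
```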
dtflow/tokenizers.py
CHANGED

@@ -122,8 +122,8 @@ def _get_tiktoken_encoder(model: str):
             _tokenizer_cache[model] = tiktoken.get_encoding(model)
         else:
             _tokenizer_cache[model] = tiktoken.encoding_for_model(model)
-    except ImportError:
-        raise ImportError("需要安装 tiktoken: pip install tiktoken")
+    except ImportError as e:
+        raise ImportError("需要安装 tiktoken: pip install tiktoken") from e
     return _tokenizer_cache[model]
 
 
@@ -149,12 +149,12 @@ def _get_hf_tokenizer(model: str):
 
             tokenizer = AutoTokenizer.from_pretrained(resolved, trust_remote_code=True)
             _tokenizer_cache[resolved] = ("transformers", tokenizer)
-    except ImportError:
+    except ImportError as e:
         raise ImportError(
             "需要安装 tokenizers 或 transformers:\n"
             " pip install tokenizers huggingface_hub (推荐,更轻量)\n"
             " pip install transformers"
-        )
+        ) from e
     return _tokenizer_cache[resolved]
 
 
@@ -309,12 +309,29 @@ def _std(counts: List[int], avg: float) -> float:
     return variance**0.5
 
 
+def _count_item_tokens(args: tuple) -> int:
+    """
+    计算单条数据的 token 数(用于多进程)。
+
+    Args:
+        args: (item, fields, model, backend) 元组
+    """
+    item, fields, model, backend = args
+    total = 0
+    for field in fields:
+        value = get_field_with_spec(item, field, default="")
+        if value:
+            total += count_tokens(str(value), model=model, backend=backend)
+    return total
+
+
 def token_stats(
     data: List[Dict[str, Any]],
     fields: Union[str, List[str]],
     model: str = DEFAULT_MODEL,
     backend: Optional[str] = None,
     progress_callback: Optional[Callable[[int, int], None]] = None,
+    workers: Optional[int] = None,
 ) -> Dict[str, Any]:
     """
     统计数据集的 token 信息。
@@ -325,6 +342,7 @@ def token_stats(
         model: 模型名称或别名,如 "qwen2.5", "gpt-4" 等
         backend: 后端选择,None 则自动检测
         progress_callback: 进度回调函数,接收 (current, total) 两个参数
+        workers: 进程数,None 自动检测,1 表示禁用并行
 
     Returns:
         统计信息字典,包含:
@@ -342,17 +360,42 @@ def token_stats(
     if not data:
         return {"total_tokens": 0, "count": 0}
 
-    counts = []
     total_items = len(data)
-
-
-
-
-
-
-
-
-
+    _backend = backend or _auto_backend(model)
+
+    # 判断是否使用多进程
+    use_parallel = workers != 1 and total_items >= 1000
+
+    if use_parallel:
+        from .parallel import get_optimal_workers, parallel_imap
+
+        actual_workers = get_optimal_workers(total_items, workers)
+        # 准备参数
+        args_list = [(item, fields, model, _backend) for item in data]
+        counts = []
+        for i, result in enumerate(
+            parallel_imap(
+                _count_item_tokens,
+                args_list,
+                workers=actual_workers,
+                threshold=1000,
+            )
+        ):
+            counts.append(result)
+            if progress_callback:
+                progress_callback(i + 1, total_items)
+    else:
+        # 串行处理
+        counts = []
+        for i, item in enumerate(data):
+            total = 0
+            for field in fields:
+                value = get_field_with_spec(item, field, default="")
+                if value:
+                    total += count_tokens(str(value), model=model, backend=_backend)
+            counts.append(total)
+            if progress_callback:
+                progress_callback(i + 1, total_items)
 
     sorted_counts = sorted(counts)
     avg = sum(counts) / len(counts)
@@ -548,12 +591,27 @@ def messages_token_filter(
     return filter_func
 
 
+def _count_messages_tokens_wrapper(args: tuple) -> Optional[Dict[str, int]]:
+    """
+    计算单条 messages 的 token 数(用于多进程)。
+
+    Args:
+        args: (item, messages_field, model, backend) 元组
+    """
+    item, messages_field, model, backend = args
+    messages = get_field_with_spec(item, messages_field, default=[])
+    if messages:
+        return _count_messages_tokens(messages, model=model, backend=backend)
+    return None
+
+
 def messages_token_stats(
     data: List[Dict[str, Any]],
     messages_field: str = "messages",
     model: str = DEFAULT_MODEL,
     backend: Optional[str] = None,
     progress_callback: Optional[Callable[[int, int], None]] = None,
+    workers: Optional[int] = None,
 ) -> Dict[str, Any]:
     """
     统计数据集中 messages 的 token 信息。
@@ -564,6 +622,7 @@ def messages_token_stats(
         model: 模型名称或别名
         backend: 后端,None 则自动检测
         progress_callback: 进度回调函数,接收 (current, total) 两个参数
+        workers: 进程数,None 自动检测,1 表示禁用并行
 
     Returns:
         统计信息字典,包含:
@@ -581,14 +640,38 @@ def messages_token_stats(
     if not data:
         return {"count": 0, "total_tokens": 0}
 
-    all_stats = []
     total_items = len(data)
-
-
-
-
-
-
+
+    # 判断是否使用多进程
+    use_parallel = workers != 1 and total_items >= 1000
+
+    all_stats = []
+    if use_parallel:
+        from .parallel import get_optimal_workers, parallel_imap
+
+        actual_workers = get_optimal_workers(total_items, workers)
+        args_list = [(item, messages_field, model, _backend) for item in data]
+
+        for i, result in enumerate(
+            parallel_imap(
+                _count_messages_tokens_wrapper,
+                args_list,
+                workers=actual_workers,
+                threshold=1000,
+            )
+        ):
+            if result is not None:
+                all_stats.append(result)
+            if progress_callback:
+                progress_callback(i + 1, total_items)
+    else:
+        # 串行处理
+        for i, item in enumerate(data):
+            messages = get_field_with_spec(item, messages_field, default=[])
+            if messages:
+                all_stats.append(_count_messages_tokens(messages, model=model, backend=_backend))
+            if progress_callback:
+                progress_callback(i + 1, total_items)
 
     if not all_stats:
         return {"count": 0, "total_tokens": 0}

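The same `workers` switch is available when calling the tokenizer statistics directly. A sketch, assuming `tiktoken` is installed for the default `cl100k_base` encoding; the sample rows are hypothetical and the printed keys are the ones referenced by the CLI printers above:

```python
from dtflow.tokenizers import messages_token_stats, token_stats

if __name__ == "__main__":
    rows = [{"text": "hello world " * 20}] * 2000
    chats = [{"messages": [{"role": "user", "content": "hi there"}]}] * 2000

    # workers=None lets get_optimal_workers pick a process count from the data size;
    # workers=1 forces the serial path regardless of size.
    text_stats = token_stats(rows, fields="text", model="cl100k_base", workers=None)
    chat_stats = messages_token_stats(chats, messages_field="messages", model="cl100k_base", workers=4)

    print(text_stats.get("total_tokens"), text_stats.get("avg_tokens"))
    print(chat_stats.get("count"), chat_stats.get("avg_tokens"))
```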
{dtflow-0.5.9.dist-info → dtflow-0.5.10.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dtflow
-Version: 0.5.
+Version: 0.5.10
 Summary: A flexible data transformation tool for ML training formats (SFT, RLHF, Pretrain)
 Project-URL: Homepage, https://github.com/yourusername/DataTransformer
 Project-URL: Documentation, https://github.com/yourusername/DataTransformer#readme
@@ -464,6 +464,7 @@ dt run pipeline.yaml --input=new_data.jsonl --output=result.jsonl
 dt token-stats data.jsonl --field=messages --model=gpt-4
 dt token-stats data.jsonl --field=messages[-1].content # 统计最后一条消息
 dt token-stats data.jsonl --field=text --detailed
+dt token-stats data.jsonl --workers=4 # 多进程加速(数据量大时自动启用)
 
 # 数据对比
 dt diff v1/train.jsonl v2/train.jsonl
@@ -521,6 +522,7 @@ dt validate data.jsonl --preset=openai_chat # 使用预设 schema 验
 dt validate data.jsonl --preset=alpaca --verbose # 详细输出
 dt validate data.jsonl --preset=sharegpt --filter-invalid -o valid.jsonl # 过滤出有效数据
 dt validate data.jsonl --preset=dpo --max-errors=100 # 限制错误输出数量
+dt validate data.jsonl --preset=openai_chat --workers=4 # 多进程加速
 ```
 
 ### 字段路径语法

{dtflow-0.5.9.dist-info → dtflow-0.5.10.dist-info}/RECORD
CHANGED

@@ -1,16 +1,17 @@
-dtflow/SKILL.md,sha256=
-dtflow/__init__.py,sha256=
-dtflow/__main__.py,sha256=
+dtflow/SKILL.md,sha256=Oq8Kb5JghZMJ1WoP8OWhX3qAWaUY9Sip_iWAv8S2eMg,10567
+dtflow/__init__.py,sha256=2A-P6k9VBIWZXRgXwYPFOwHMCmgkfKZVYuGuBziqqhc,3032
+dtflow/__main__.py,sha256=_wrpYfOog6G83I17yuBe-hryBsaCrIwbXSEnzT-r28g,18008
 dtflow/converters.py,sha256=X3qeFD7FCOMnfiP3MicL5MXimOm4XUYBs5pczIkudU0,22331
 dtflow/core.py,sha256=qMo6B3LK--TWRK7ZBKObGcs3pKFnd0NPoaM0T8JC7Jw,38135
 dtflow/eval.py,sha256=_c-XP2zsOBznYltSyKEScOqvmPVX2orqepg5cNhXXB0,9836
 dtflow/framework.py,sha256=jyICi_RWHjX7WfsXdSbWmP1SL7y1OWSPyd5G5Y-lvg4,17578
 dtflow/lineage.py,sha256=jie3OL1qK90-_cOOqqLbhSJ1oGUktDM1x5HRpQ5Qiyc,12800
+dtflow/parallel.py,sha256=EnIdGEGMrZUNT2-CBIV93UFfpqr_jU_heqqvdGXcP-Y,3046
 dtflow/pipeline.py,sha256=zZaC4fg5vsp_30Fhbg75vu0yggsdvf28bWBiVDWzZ6Y,13901
 dtflow/presets.py,sha256=qa8WQJhbNMuGxqqgA9BFadEBwDB9s0zWNxxhzF3q1K8,4701
-dtflow/schema.py,sha256=
+dtflow/schema.py,sha256=zCZNEAqTMT1BS_p2t0CYczR5S9rqyDREa7ZsYI5pFGA,19885
 dtflow/streaming.py,sha256=dxpNd1-Wz_PTLTdvM5qn06_2TJr5NRlIIuw0LOSS2Iw,24755
-dtflow/tokenizers.py,sha256=
+dtflow/tokenizers.py,sha256=GFQsuLSLn2GHn2kaXhJkP8G85lgsdLzYtJNbppQhYPE,23408
 dtflow/cli/__init__.py,sha256=QhZ-thgx9IBTFII7T_hdoWFUl0CCsdGQHN5ZEZw2XB0,423
 dtflow/cli/clean.py,sha256=BEQQlH2q6luCbx51M3oxxOwcnwlOA8vo9WX3Fp7I6AY,29498
 dtflow/cli/commands.py,sha256=LvyDQ_nWUM7UlPDEFQadRdw5O2ZKDLgF41_xAJRhYxI,1583
@@ -23,9 +24,9 @@ dtflow/cli/pipeline.py,sha256=QNEo-BJlaC1CVnVeRZr7TwfuZYloJ4TebIzJ5ALzry0,1426
 dtflow/cli/sample.py,sha256=etbro5I0pyNgn0Qfhp1M6Bh-95JN-AntDa5AwVe_oKY,18269
 dtflow/cli/skill.py,sha256=opiTEBejA7JHKrEMftMOPDQlOgZ4n59rwaHXGU1Nukk,2022
 dtflow/cli/split.py,sha256=96bhWnxHnjIqifoliLgciApkLbwQU8bWHovK8bcMk9g,3667
-dtflow/cli/stats.py,sha256=
+dtflow/cli/stats.py,sha256=HkTZD80h4tzYXTtMnfpjLUMP6kl_es6ifcmExxzGdMU,31813
 dtflow/cli/transform.py,sha256=w6xqMOxPxQvL2u_BPCfpDHuPSC9gmcqMPVN8s-B6bbY,15052
-dtflow/cli/validate.py,sha256=
+dtflow/cli/validate.py,sha256=Frs-jKcDHmYozpmIYZueDSX5o2i1Xn-WW81FGUyUrng,5796
 dtflow/storage/__init__.py,sha256=C0jpWNQU808Ezz7lWneddABal3wILy8ijFUNiSKbHV4,362
 dtflow/storage/io.py,sha256=ZH2aSE-S89gpy3z4oTqhcqWf4u10OdkDoyul7o_YBDI,23374
 dtflow/utils/__init__.py,sha256=Pn-ltwV04fBQmeZG7FxInDQmzH29LYOi90LgeLMEuQk,506
@@ -33,7 +34,7 @@ dtflow/utils/display.py,sha256=OeOdTh6mbDwSkDWlmkjfpTjy2QG8ZUaYU0NpHUWkpEQ,5881
 dtflow/utils/field_path.py,sha256=K8nU196RxTSJ1OoieTWGcYOWl9KjGq2iSxCAkfjECuM,7621
 dtflow/utils/helpers.py,sha256=JXN176_B2pm53GLVyZ1wj3wrmBJG52Tkw6AMQSdj7M8,791
 dtflow/utils/text_parser.py,sha256=0t2TMOSha4dTiDu9H4ygdb67cI20zhtBH1XavDspL_g,3727
-dtflow-0.5.
-dtflow-0.5.
-dtflow-0.5.
-dtflow-0.5.
+dtflow-0.5.10.dist-info/METADATA,sha256=OGefMoe17by5IbxdxZgqoJ1Y6OWPt_iGEFM4KgltRZw,26023
+dtflow-0.5.10.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+dtflow-0.5.10.dist-info/entry_points.txt,sha256=dadIDOK7Iu9pMxnMPBfpb4aAPe4hQbBOshpQYjVYpGc,44
+dtflow-0.5.10.dist-info/RECORD,,

{dtflow-0.5.9.dist-info → dtflow-0.5.10.dist-info}/WHEEL
File without changes

{dtflow-0.5.9.dist-info → dtflow-0.5.10.dist-info}/entry_points.txt
File without changes