dtflow 0.5.0__tar.gz → 0.5.3__tar.gz
This diff shows the changes between package versions that have been publicly released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- dtflow-0.5.3/CHANGELOG.md +19 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/PKG-INFO +11 -1
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/__init__.py +7 -7
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/cli/common.py +13 -9
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/cli/stats.py +114 -36
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/converters.py +17 -13
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/core.py +66 -10
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/lineage.py +17 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/presets.py +14 -15
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/streaming.py +93 -35
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/tokenizers.py +84 -29
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/utils/__init__.py +3 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/utils/field_path.py +6 -2
- dtflow-0.5.3/dtflow/utils/helpers.py +30 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/pyproject.toml +17 -4
- dtflow-0.5.3/tests/README.md +88 -0
- dtflow-0.5.3/tests/benchmark_sharegpt.py +392 -0
- dtflow-0.5.3/tests/test_cli_benchmark.py +565 -0
- dtflow-0.5.3/tests/test_cli_clean.py +314 -0
- dtflow-0.5.3/tests/test_cli_sample.py +242 -0
- dtflow-0.5.3/tests/test_cli_stats.py +213 -0
- dtflow-0.5.3/tests/test_cli_transform.py +304 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/tests/test_streaming.py +80 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/tests/test_transformer.py +77 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/.gitignore +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/README.md +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/__main__.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/cli/__init__.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/cli/clean.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/cli/commands.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/cli/io_ops.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/cli/lineage.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/cli/pipeline.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/cli/sample.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/cli/transform.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/cli/validate.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/framework.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/mcp/__init__.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/mcp/__main__.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/mcp/cli.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/mcp/docs.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/mcp/server.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/pipeline.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/schema.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/storage/__init__.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/storage/io.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/utils/display.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/tests/benchmark_io.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/tests/test_converters.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/tests/test_field_path.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/tests/test_framework.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/tests/test_io.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/tests/test_lineage.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/tests/test_pipeline.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/tests/test_schema.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.3}/tests/test_tokenizers.py +0 -0
dtflow-0.5.3/CHANGELOG.md

@@ -0,0 +1,19 @@
+# Changelog
+
+## [0.5.2] - 2026-01-18
+
+### Miscellaneous
+
+- Bump version to 0.5.2
+- Add pre-commit config and release script
+
+## [0.5.1] - 2026-01-18
+
+### Features
+
+- Improve text preview display of the sample command
+
+### Testing
+
+- Add instructions for running the tests
+- Add performance tests for tail/token-stats/validate

{dtflow-0.5.0 → dtflow-0.5.3}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dtflow
-Version: 0.5.0
+Version: 0.5.3
 Summary: A flexible data transformation tool for ML training formats (SFT, RLHF, Pretrain)
 Project-URL: Homepage, https://github.com/yourusername/DataTransformer
 Project-URL: Documentation, https://github.com/yourusername/DataTransformer#readme

@@ -32,16 +32,26 @@ Requires-Dist: orjson>=3.9.0
 Requires-Dist: polars>=0.20.0
 Requires-Dist: pyyaml>=5.4.0
 Requires-Dist: rich>=10.0.0
+Requires-Dist: tiktoken>=0.5.0
 Requires-Dist: typer>=0.9.0
 Provides-Extra: converters
 Requires-Dist: datasets>=2.0.0; extra == 'converters'
 Provides-Extra: dev
 Requires-Dist: black>=21.0; extra == 'dev'
+Requires-Dist: datasets>=2.0.0; extra == 'dev'
+Requires-Dist: datasketch>=1.5.0; extra == 'dev'
 Requires-Dist: flake8>=3.9.0; extra == 'dev'
+Requires-Dist: huggingface-hub>=0.20.0; extra == 'dev'
 Requires-Dist: isort>=5.9.0; extra == 'dev'
 Requires-Dist: mypy>=0.910; extra == 'dev'
+Requires-Dist: pyarrow; extra == 'dev'
 Requires-Dist: pytest-cov>=2.12.0; extra == 'dev'
 Requires-Dist: pytest>=6.0.0; extra == 'dev'
+Requires-Dist: rich>=10.0.0; extra == 'dev'
+Requires-Dist: scikit-learn>=0.24.0; extra == 'dev'
+Requires-Dist: tiktoken>=0.5.0; extra == 'dev'
+Requires-Dist: tokenizers>=0.15.0; extra == 'dev'
+Requires-Dist: toolong>=1.5.0; extra == 'dev'
 Provides-Extra: display
 Provides-Extra: docs
 Requires-Dist: myst-parser>=0.15.0; extra == 'docs'

{dtflow-0.5.0 → dtflow-0.5.3}/dtflow/__init__.py

@@ -26,6 +26,12 @@ from .converters import (  # LLaMA-Factory extensions; ms-swift
     to_swift_vlm,
 )
 from .core import DataTransformer, DictWrapper, TransformError, TransformErrors
+from .framework import (
+    CompatibilityResult,
+    check_compatibility,
+    detect_format,
+    export_for,
+)
 from .presets import get_preset, list_presets
 from .schema import (
     Field,

@@ -38,12 +44,6 @@ from .schema import (
     sharegpt_schema,
     validate_data,
 )
-from .framework import (
-    CompatibilityResult,
-    check_compatibility,
-    detect_format,
-    export_for,
-)
 from .storage import load_data, sample_file, save_data
 from .streaming import StreamingTransformer, load_sharded, load_stream, process_shards
 from .tokenizers import (

@@ -60,7 +60,7 @@ from .tokenizers import (
     token_stats,
 )
 
-__version__ = "0.5.0"
+__version__ = "0.5.3"
 
 __all__ = [
     # core

{dtflow-0.5.0 → dtflow-0.5.3}/dtflow/cli/common.py

@@ -57,7 +57,7 @@ def _get_file_row_count(filepath: Path) -> Optional[int]:
     return None
 
 
-def _format_value(value: Any, max_len: int = 80) -> str:
+def _format_value(value: Any, max_len: int = 120) -> str:
     """Format a single value, truncating long text."""
     if value is None:
         return "[dim]null[/dim]"

@@ -66,18 +66,22 @@ def _format_value(value: Any, max_len: int = 80) -> str:
     if isinstance(value, (int, float)):
         return f"[cyan]{value}[/cyan]"
     if isinstance(value, str):
+        half_len = max_len // 2
         # Handle multi-line text
        if "\n" in value:
             lines = value.split("\n")
-
-
-
-
-
-
+            preview = value.replace("\n", "\\n")
+            if len(preview) > max_len:
+                # first half + ellipsis marker + second half
+                head = preview[:half_len]
+                tail = preview[-half_len:]
+                return f'"{head} [yellow]<<<{len(lines)} lines>>>[/yellow] {tail}"'
             return f'"{preview}"'
         if len(value) > max_len:
-
+            # first half + ellipsis marker + second half
+            head = value[:half_len]
+            tail = value[-half_len:]
+            return f'"{head} [yellow]<<<{len(value)} chars>>>[/yellow] {tail}"'
         return f'"{value}"'
     return str(value)
 

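The rule the new branch implements: keep the first and last `max_len // 2` characters and put a size marker in the middle. A self-contained sketch with the rich markup stripped (illustrative, not the package's code):

```python
# Head-and-tail preview, mirroring the diff above: long strings keep
# max_len // 2 characters on each side with a size marker in between.
def preview(text: str, max_len: int = 120) -> str:
    lines = text.split("\n")
    flat = text.replace("\n", "\\n")
    if len(flat) <= max_len:
        return f'"{flat}"'
    half = max_len // 2
    marker = f"<<<{len(lines)} lines>>>" if len(lines) > 1 else f"<<<{len(text)} chars>>>"
    return f'"{flat[:half]} {marker} {flat[-half:]}"'

print(preview("x" * 300))              # middle collapses to <<<300 chars>>>
print(preview("line1\n" + "y" * 200))  # multi-line input gets a line count
```
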
@@ -86,7 +90,7 @@ def _format_nested(
     value: Any,
     indent: str = "",
     is_last: bool = True,
-    max_len: int = 80,
+    max_len: int = 120,
 ) -> List[str]:
     """
     Recursively format a nested structure, returning a list of lines.

{dtflow-0.5.0 → dtflow-0.5.3}/dtflow/cli/stats.py

@@ -465,34 +465,65 @@ def token_stats(
         return
 
     total = len(data)
-    print(f" Total: {total} records")
-    print(f"🔢 Counting tokens (model: {model}, field: {field})...")
+    print(f" Total: {total:,} records")
 
     # Check the field type and choose a suitable stats method (nested paths supported)
     sample = data[0]
     field_value = get_field_with_spec(sample, field)
 
+    # Try to use a rich progress bar
     try:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn
+
+        with Progress(
+            SpinnerColumn(),
+            TextColumn("[bold blue]Counting tokens"),
+            BarColumn(),
+            TaskProgressColumn(),
+            TextColumn(f"(model: {model})"),
+        ) as progress:
+            task = progress.add_task("", total=total)
+
+            def update_progress(current: int, total_count: int):
+                progress.update(task, completed=current)
+
+            if isinstance(field_value, list) and field_value and isinstance(field_value[0], dict):
+                from ..tokenizers import messages_token_stats
+
+                stats_result = messages_token_stats(
+                    data, messages_field=field, model=model, progress_callback=update_progress
+                )
+                _print_messages_token_stats(stats_result, detailed)
+            else:
+                from ..tokenizers import token_stats as compute_token_stats
+
+                stats_result = compute_token_stats(
+                    data, fields=field, model=model, progress_callback=update_progress
+                )
+                _print_text_token_stats(stats_result, detailed)
+
+    except ImportError:
+        # rich not available; show simple progress
+        print(f"🔢 Counting tokens (model: {model}, field: {field})...")
+        try:
+            if isinstance(field_value, list) and field_value and isinstance(field_value[0], dict):
+                from ..tokenizers import messages_token_stats
+
+                stats_result = messages_token_stats(data, messages_field=field, model=model)
+                _print_messages_token_stats(stats_result, detailed)
+            else:
+                from ..tokenizers import token_stats as compute_token_stats
 
-
+                stats_result = compute_token_stats(data, fields=field, model=model)
+                _print_text_token_stats(stats_result, detailed)
+        except ImportError as e:
+            print(f"Error: {e}")
+            return
+        except Exception as e:
+            print(f"Error: stats failed - {e}")
+            import traceback
+
+            traceback.print_exc()
 
 
 def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:

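The rewritten `token_stats` threads a `progress_callback` into the tokenizer-side functions. The diff implies a `(current, total)` calling contract; a minimal self-contained sketch of that pattern (the function below is illustrative, not dtflow's API):

```python
from typing import Callable, Optional

def count_tokens(data: list, progress_callback: Optional[Callable[[int, int], None]] = None) -> int:
    """Walk the dataset, reporting (current, total) after each item."""
    total_tokens = 0
    for i, item in enumerate(data, start=1):
        total_tokens += len(str(item).split())  # stand-in for a real tokenizer
        if progress_callback:
            progress_callback(i, len(data))
    return total_tokens

# Wiring it to a plain-text progress display:
data = [{"text": "hello world"}] * 4
count_tokens(data, lambda cur, tot: print(f"\r{cur}/{tot}", end=""))
```
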
@@ -505,21 +536,39 @@ def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
         console = Console()
 
         # Overview
+        std = stats.get("std_tokens", 0)
         overview = (
             f"[bold]Total samples:[/bold] {stats['count']:,}\n"
             f"[bold]Total tokens:[/bold] {stats['total_tokens']:,}\n"
-            f"[bold]Average tokens:[/bold] {stats['avg_tokens']:,}\n"
-            f"[bold]Median:[/bold] {stats['median_tokens']:,}\n"
+            f"[bold]Average tokens:[/bold] {stats['avg_tokens']:,} (std: {std:.1f})\n"
             f"[bold]Range:[/bold] {stats['min_tokens']:,} - {stats['max_tokens']:,}"
         )
         console.print(Panel(overview, title="📊 Token Stats Overview", expand=False))
 
+        # Percentile table
+        table = Table(title="📈 Distribution")
+        table.add_column("Percentile", style="cyan", justify="center")
+        table.add_column("Tokens", justify="right")
+        percentiles = [
+            ("Min", stats["min_tokens"]),
+            ("P25", stats.get("p25", "-")),
+            ("P50 (median)", stats.get("median_tokens", "-")),
+            ("P75", stats.get("p75", "-")),
+            ("P90", stats.get("p90", "-")),
+            ("P95", stats.get("p95", "-")),
+            ("P99", stats.get("p99", "-")),
+            ("Max", stats["max_tokens"]),
+        ]
+        for name, val in percentiles:
+            table.add_row(name, f"{val:,}" if isinstance(val, int) else str(val))
+        console.print(table)
+
         if detailed:
-            #
-
-
-
-
+            # Per-role stats
+            role_table = Table(title="📋 Per-Role Stats")
+            role_table.add_column("Role", style="cyan")
+            role_table.add_column("Tokens", justify="right")
+            role_table.add_column("Share", justify="right")
 
             total = stats["total_tokens"]
             for role, key in [

@@ -529,22 +578,27 @@ def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
             ]:
                 tokens = stats.get(key, 0)
                 pct = tokens / total * 100 if total > 0 else 0
-
+                role_table.add_row(role, f"{tokens:,}", f"{pct:.1f}%")
 
-            console.print(
+            console.print(role_table)
             console.print(f"\nAverage conversation turns: {stats.get('avg_turns', 0)}")
 
     except ImportError:
         # rich not available; plain printing
+        std = stats.get("std_tokens", 0)
         print(f"\n{'=' * 40}")
         print("📊 Token Stats Overview")
         print(f"{'=' * 40}")
         print(f"Total samples: {stats['count']:,}")
         print(f"Total tokens: {stats['total_tokens']:,}")
-        print(f"Average tokens: {stats['avg_tokens']:,}")
-        print(f"Median: {stats['median_tokens']:,}")
+        print(f"Average tokens: {stats['avg_tokens']:,} (std: {std:.1f})")
         print(f"Range: {stats['min_tokens']:,} - {stats['max_tokens']:,}")
 
+        print(f"\n📈 Percentile distribution:")
+        print(f"  P25: {stats.get('p25', '-'):,}  P50: {stats.get('median_tokens', '-'):,}")
+        print(f"  P75: {stats.get('p75', '-'):,}  P90: {stats.get('p90', '-'):,}")
+        print(f"  P95: {stats.get('p95', '-'):,}  P99: {stats.get('p99', '-'):,}")
+
         if detailed:
             print(f"\n{'=' * 40}")
             print("📋 Per-Role Stats")

@@ -566,24 +620,48 @@ def _print_text_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
     try:
         from rich.console import Console
         from rich.panel import Panel
+        from rich.table import Table
 
         console = Console()
 
+        std = stats.get("std_tokens", 0)
         overview = (
             f"[bold]Total samples:[/bold] {stats['count']:,}\n"
             f"[bold]Total tokens:[/bold] {stats['total_tokens']:,}\n"
-            f"[bold]Average tokens:[/bold] {stats['avg_tokens']:.1f}\n"
-            f"[bold]Median:[/bold] {stats['median_tokens']:,}\n"
+            f"[bold]Average tokens:[/bold] {stats['avg_tokens']:.1f} (std: {std:.1f})\n"
             f"[bold]Range:[/bold] {stats['min_tokens']:,} - {stats['max_tokens']:,}"
         )
         console.print(Panel(overview, title="📊 Token Stats", expand=False))
 
+        # Percentile table
+        table = Table(title="📈 Distribution")
+        table.add_column("Percentile", style="cyan", justify="center")
+        table.add_column("Tokens", justify="right")
+        percentiles = [
+            ("Min", stats["min_tokens"]),
+            ("P25", stats.get("p25", "-")),
+            ("P50 (median)", stats.get("median_tokens", "-")),
+            ("P75", stats.get("p75", "-")),
+            ("P90", stats.get("p90", "-")),
+            ("P95", stats.get("p95", "-")),
+            ("P99", stats.get("p99", "-")),
+            ("Max", stats["max_tokens"]),
+        ]
+        for name, val in percentiles:
+            table.add_row(name, f"{val:,}" if isinstance(val, int) else str(val))
+        console.print(table)
+
     except ImportError:
+        std = stats.get("std_tokens", 0)
         print(f"\n{'=' * 40}")
         print("📊 Token Stats")
         print(f"{'=' * 40}")
         print(f"Total samples: {stats['count']:,}")
         print(f"Total tokens: {stats['total_tokens']:,}")
-        print(f"Average tokens: {stats['avg_tokens']:.1f}")
-        print(f"Median: {stats['median_tokens']:,}")
+        print(f"Average tokens: {stats['avg_tokens']:.1f} (std: {std:.1f})")
         print(f"Range: {stats['min_tokens']:,} - {stats['max_tokens']:,}")
+
+        print(f"\n📈 Percentile distribution:")
+        print(f"  P25: {stats.get('p25', '-'):,}  P50: {stats.get('median_tokens', '-'):,}")
+        print(f"  P75: {stats.get('p75', '-'):,}  P90: {stats.get('p90', '-'):,}")
+        print(f"  P95: {stats.get('p95', '-'):,}  P99: {stats.get('p99', '-'):,}")

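Both printers now expect `p25`/`p50`/.../`p99` and `std_tokens` in the stats dict; the diff does not show how those fields are computed on the tokenizer side. One standard way to derive them from a list of per-sample token counts, as a sketch only (not dtflow's implementation):

```python
import statistics

def percentile_summary(token_counts: list[int]) -> dict:
    # quantiles(n=100) returns 99 cut points; index k-1 is the k-th percentile.
    qs = statistics.quantiles(token_counts, n=100, method="inclusive")
    return {
        "min_tokens": min(token_counts),
        "p25": round(qs[24]),
        "median_tokens": round(qs[49]),
        "p75": round(qs[74]),
        "p90": round(qs[89]),
        "p95": round(qs[94]),
        "p99": round(qs[98]),
        "max_tokens": max(token_counts),
        "std_tokens": statistics.pstdev(token_counts),
    }

print(percentile_summary([120, 340, 95, 410, 280, 150]))
```
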
{dtflow-0.5.0 → dtflow-0.5.3}/dtflow/converters.py

@@ -4,7 +4,7 @@
 Conversion to and from common formats such as HuggingFace datasets.
 """
 
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 
 def to_hf_dataset(data: List[Dict[str, Any]]):

@@ -143,14 +143,16 @@ def to_openai_batch(
     >>> batch_input = dt.to(to_openai_batch(model="gpt-4o"))
     """
 
-
+    counter = {"idx": 0}
+
+    def transform(item) -> dict:
         messages = item.get(messages_field, []) if hasattr(item, "get") else item[messages_field]
 
         if custom_id_field:
             custom_id = item.get(custom_id_field) if hasattr(item, "get") else item[custom_id_field]
         else:
-            custom_id = f"request-{idx
-            idx
+            custom_id = f"request-{counter['idx']}"
+            counter["idx"] += 1
 
         return {
             "custom_id": str(custom_id),

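The `to_openai_batch` fix replaces a plain integer counter, which the nested `transform` closure could not rebind without `nonlocal`, with a one-key dict it can mutate in place. The pattern in isolation (illustrative, not dtflow's code):

```python
def make_id_generator(prefix: str = "request"):
    counter = {"idx": 0}  # mutable container: the closure mutates it, never rebinds it

    def next_id() -> str:
        cid = f"{prefix}-{counter['idx']}"
        counter["idx"] += 1
        return cid

    return next_id

gen = make_id_generator()
assert [gen(), gen(), gen()] == ["request-0", "request-1", "request-2"]
```
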
@@ -196,7 +198,7 @@ def to_llama_factory(
     """
 
     def transform(item) -> dict:
-        get = lambda f:
+        get = lambda f: item.get(f, "") if hasattr(item, "get") else getattr(item, f, "")
 
         result = {
             "instruction": get(instruction_field),

@@ -248,7 +250,7 @@ def to_axolotl(
         conversations = (
             item.get(conversations_field, [])
             if hasattr(item, "get")
-            else item
+            else getattr(item, conversations_field, [])
         )
 
         # If it's already in the right format, return it directly

@@ -257,7 +259,9 @@ def to_axolotl(
             return {"conversations": conversations}
 
         # Try converting from the messages format
-        messages =
+        messages = (
+            item.get("messages", []) if hasattr(item, "get") else getattr(item, "messages", [])
+        )
         if messages:
             role_map = {"user": "human", "assistant": "gpt", "system": "system"}
             conversations = [

@@ -312,7 +316,7 @@ def to_llama_factory_sharegpt(
     }
 
     def transform(item) -> dict:
-        get = lambda f:
+        get = lambda f: item.get(f, "") if hasattr(item, "get") else getattr(item, f, "")
         messages = get(messages_field) or []
 
         conversations = []

@@ -385,7 +389,7 @@ def to_llama_factory_vlm(
     """
 
     def transform(item) -> dict:
-        get = lambda f: item.get(f) if hasattr(item, "get") else item
+        get = lambda f: item.get(f) if hasattr(item, "get") else getattr(item, f, None)
         messages = get(messages_field) or []
 
         instruction = ""

@@ -467,7 +471,7 @@ def to_llama_factory_vlm_sharegpt(
     role_map = {"user": "human", "assistant": "gpt", "system": "system"}
 
     def transform(item) -> dict:
-        get = lambda f: item.get(f) if hasattr(item, "get") else item
+        get = lambda f: item.get(f) if hasattr(item, "get") else getattr(item, f, None)
         messages = get(messages_field) or []
 
         conversations = []

@@ -541,7 +545,7 @@ def to_swift_messages(
     """
 
     def transform(item) -> dict:
-        get = lambda f: item.get(f) if hasattr(item, "get") else item
+        get = lambda f: item.get(f) if hasattr(item, "get") else getattr(item, f, None)
         messages = get(messages_field) or []
 
         # Copy messages to avoid mutating the original data

@@ -600,7 +604,7 @@ def to_swift_query_response(
     """
 
     def transform(item) -> dict:
-        get = lambda f: item.get(f) if hasattr(item, "get") else item
+        get = lambda f: item.get(f) if hasattr(item, "get") else getattr(item, f, None)
 
         query = get(query_field)
         response = get(response_field)

@@ -693,7 +697,7 @@ def to_swift_vlm(
     """
 
    def transform(item) -> dict:
-        get = lambda f: item.get(f) if hasattr(item, "get") else item
+        get = lambda f: item.get(f) if hasattr(item, "get") else getattr(item, f, None)
         messages = get(messages_field) or []
 
         result_messages = []

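Every converter hunk above applies the same one-line fix: the duck-typed `get` helper used to fall back to returning the whole item for attribute-style objects instead of the requested field. A standalone before/after illustration (not dtflow's code):

```python
from types import SimpleNamespace

item = SimpleNamespace(messages=[{"role": "user", "content": "hi"}])

get_old = lambda f: item.get(f) if hasattr(item, "get") else item
get_new = lambda f: item.get(f) if hasattr(item, "get") else getattr(item, f, None)

print(get_old("messages"))  # the whole namespace object -- the old bug
print(get_new("messages"))  # [{'role': 'user', 'content': 'hi'}]
```
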
{dtflow-0.5.0 → dtflow-0.5.3}/dtflow/core.py

@@ -793,19 +793,29 @@ class DataTransformer:
             seed: Random seed
 
         Returns:
-            (train, test): two DataTransformers
+            (train, test): two DataTransformers, each with an independent lineage tracker
         """
         data = self.shuffle(seed).data
         split_idx = int(len(data) * ratio)
 
-        #
+        # Each split gets its own lineage tracker (deep-copied so they don't affect each other)
         tracker = self._lineage_tracker
+        train_tracker = None
+        test_tracker = None
+
         if tracker:
             tracker.record("split", {"ratio": ratio, "seed": seed}, len(self._data), len(data))
+            # Create an independent tracker copy for each sub-dataset
+            train_tracker = tracker.copy()
+            train_tracker.record("split_part", {"part": "train", "ratio": ratio}, len(data), split_idx)
+            test_tracker = tracker.copy()
+            test_tracker.record(
+                "split_part", {"part": "test", "ratio": 1 - ratio}, len(data), len(data) - split_idx
+            )
 
         return (
-            DataTransformer(data[:split_idx], _lineage_tracker=
-            DataTransformer(data[split_idx:], _lineage_tracker=
+            DataTransformer(data[:split_idx], _lineage_tracker=train_tracker),
+            DataTransformer(data[split_idx:], _lineage_tracker=test_tracker),
         )
 
     # ============ Parallel processing ============

@@ -815,6 +825,7 @@ class DataTransformer:
         func: Callable[[Dict], Any],
         workers: Optional[int] = None,
         chunksize: int = 1000,
+        timeout: Optional[float] = None,
     ) -> List[Any]:
         """
         Execute a transform function in parallel (using multiprocessing).

@@ -825,24 +836,46 @@ class DataTransformer:
             func: Transform function; takes a raw dict, returns the transformed result
             workers: Number of worker processes, defaults to the CPU count
             chunksize: Size of the data chunk each process handles
+            timeout: Timeout in seconds; None means no timeout
 
         Returns:
             List of transformed results
 
+        Raises:
+            TypeError: If func cannot be pickled (e.g. a lambda)
+            RuntimeError: If a worker process fails or times out
+
         Examples:
             >>> def transform(item):
             ...     return {"id": item["id"], "text": item["text"].upper()}
             >>> results = dt.map_parallel(transform)
         """
-        from multiprocessing import Pool, cpu_count
+        from multiprocessing import Pool, TimeoutError, cpu_count
+        import pickle
 
         if not self._data:
             return []
 
+        # Check that the function can be pickled
+        try:
+            pickle.dumps(func)
+        except (pickle.PicklingError, AttributeError, TypeError) as e:
+            func_name = getattr(func, "__name__", str(func))
+            raise TypeError(
+                f"Function '{func_name}' cannot be pickled, so it cannot be used for parallel processing. "
+                f"Use a module-level function rather than a lambda or closure. Error: {e}"
+            ) from e
+
         workers = workers or cpu_count()
 
-
-
+        try:
+            with Pool(workers) as pool:
+                async_result = pool.map_async(func, self._data, chunksize=chunksize)
+                results = async_result.get(timeout=timeout)
+        except TimeoutError:
+            raise RuntimeError(f"Parallel processing timed out ({timeout}s)")
+        except Exception as e:
+            raise RuntimeError(f"Parallel processing failed: {type(e).__name__}: {e}") from e
 
         return results
 

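A usage sketch for the hardened parallel API, assuming the `map_parallel(func, workers, chunksize, timeout)` signature shown above; the data and function are made up for illustration, and the same guard is applied to `filter_parallel` below. The function must be module-level, since the new pickle check rejects lambdas and closures up front:

```python
from dtflow import DataTransformer

def clean(item: dict) -> dict:
    # Module-level so it can be pickled and shipped to worker processes.
    return {"text": item["text"].strip().lower()}

dt = DataTransformer([{"text": "  Hello "}, {"text": "World  "}])
results = dt.map_parallel(clean, workers=2, timeout=30.0)
# -> [{'text': 'hello'}, {'text': 'world'}]
# dt.map_parallel(lambda x: x) would now raise TypeError before any fork.
```
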
@@ -851,6 +884,7 @@ class DataTransformer:
         func: Callable[[Dict], bool],
         workers: Optional[int] = None,
         chunksize: int = 1000,
+        timeout: Optional[float] = None,
     ) -> "DataTransformer":
         """
         Execute a filter function in parallel (using multiprocessing).

@@ -861,24 +895,46 @@ class DataTransformer:
             func: Filter function; takes a raw dict, returns True to keep the item
             workers: Number of worker processes, defaults to the CPU count
             chunksize: Size of the data chunk each process handles
+            timeout: Timeout in seconds; None means no timeout
 
         Returns:
             A new, filtered DataTransformer
 
+        Raises:
+            TypeError: If func cannot be pickled (e.g. a lambda)
+            RuntimeError: If a worker process fails or times out
+
         Examples:
             >>> def is_valid(item):
             ...     return len(item["text"]) > 10
             >>> filtered = dt.filter_parallel(is_valid)
         """
-        from multiprocessing import Pool, cpu_count
+        from multiprocessing import Pool, TimeoutError, cpu_count
+        import pickle
 
         if not self._data:
             return DataTransformer([])
 
+        # Check that the function can be pickled
+        try:
+            pickle.dumps(func)
+        except (pickle.PicklingError, AttributeError, TypeError) as e:
+            func_name = getattr(func, "__name__", str(func))
+            raise TypeError(
+                f"Function '{func_name}' cannot be pickled, so it cannot be used for parallel processing. "
+                f"Use a module-level function rather than a lambda or closure. Error: {e}"
+            ) from e
+
         workers = workers or cpu_count()
 
-
-
+        try:
+            with Pool(workers) as pool:
+                async_result = pool.map_async(func, self._data, chunksize=chunksize)
+                mask = async_result.get(timeout=timeout)
+        except TimeoutError:
+            raise RuntimeError(f"Parallel processing timed out ({timeout}s)")
+        except Exception as e:
+            raise RuntimeError(f"Parallel processing failed: {type(e).__name__}: {e}") from e
 
         filtered = [item for item, keep in zip(self._data, mask) if keep]
         return DataTransformer(filtered)

{dtflow-0.5.0 → dtflow-0.5.3}/dtflow/lineage.py

@@ -237,6 +237,23 @@ class LineageTracker:
 
         return lineage_path
 
+    def copy(self) -> "LineageTracker":
+        """
+        Create a deep copy of the tracker.
+
+        Used by split() and similar operations to give sub-datasets independent lineage tracking.
+
+        Returns:
+            A new LineageTracker instance
+        """
+        import copy as copy_module
+
+        new_tracker = LineageTracker.__new__(LineageTracker)
+        new_tracker.source_path = self.source_path
+        new_tracker.source_lineage = self.source_lineage  # LineageRecord is immutable, safe to share
+        new_tracker.operations = copy_module.deepcopy(self.operations)
+        return new_tracker
+
 
 def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
     """