dtflow 0.5.0__tar.gz → 0.5.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dtflow-0.5.0 → dtflow-0.5.2}/PKG-INFO +11 -1
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/__init__.py +1 -1
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/cli/common.py +13 -9
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/cli/stats.py +114 -36
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/core.py +66 -10
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/lineage.py +17 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/streaming.py +93 -35
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/tokenizers.py +84 -29
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/utils/field_path.py +6 -2
- {dtflow-0.5.0 → dtflow-0.5.2}/pyproject.toml +11 -0
- dtflow-0.5.2/tests/README.md +88 -0
- dtflow-0.5.2/tests/benchmark_sharegpt.py +392 -0
- dtflow-0.5.2/tests/test_cli_benchmark.py +565 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/tests/test_streaming.py +80 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/tests/test_transformer.py +77 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/.gitignore +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/README.md +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/__main__.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/cli/__init__.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/cli/clean.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/cli/commands.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/cli/io_ops.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/cli/lineage.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/cli/pipeline.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/cli/sample.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/cli/transform.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/cli/validate.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/converters.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/framework.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/mcp/__init__.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/mcp/__main__.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/mcp/cli.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/mcp/docs.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/mcp/server.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/pipeline.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/presets.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/schema.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/storage/__init__.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/storage/io.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/utils/__init__.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/utils/display.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/tests/benchmark_io.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/tests/test_converters.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/tests/test_field_path.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/tests/test_framework.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/tests/test_io.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/tests/test_lineage.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/tests/test_pipeline.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/tests/test_schema.py +0 -0
- {dtflow-0.5.0 → dtflow-0.5.2}/tests/test_tokenizers.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: dtflow
|
|
3
|
-
Version: 0.5.0
|
|
3
|
+
Version: 0.5.2
|
|
4
4
|
Summary: A flexible data transformation tool for ML training formats (SFT, RLHF, Pretrain)
|
|
5
5
|
Project-URL: Homepage, https://github.com/yourusername/DataTransformer
|
|
6
6
|
Project-URL: Documentation, https://github.com/yourusername/DataTransformer#readme
|
|
@@ -32,16 +32,26 @@ Requires-Dist: orjson>=3.9.0
|
|
|
32
32
|
Requires-Dist: polars>=0.20.0
|
|
33
33
|
Requires-Dist: pyyaml>=5.4.0
|
|
34
34
|
Requires-Dist: rich>=10.0.0
|
|
35
|
+
Requires-Dist: tiktoken>=0.5.0
|
|
35
36
|
Requires-Dist: typer>=0.9.0
|
|
36
37
|
Provides-Extra: converters
|
|
37
38
|
Requires-Dist: datasets>=2.0.0; extra == 'converters'
|
|
38
39
|
Provides-Extra: dev
|
|
39
40
|
Requires-Dist: black>=21.0; extra == 'dev'
|
|
41
|
+
Requires-Dist: datasets>=2.0.0; extra == 'dev'
|
|
42
|
+
Requires-Dist: datasketch>=1.5.0; extra == 'dev'
|
|
40
43
|
Requires-Dist: flake8>=3.9.0; extra == 'dev'
|
|
44
|
+
Requires-Dist: huggingface-hub>=0.20.0; extra == 'dev'
|
|
41
45
|
Requires-Dist: isort>=5.9.0; extra == 'dev'
|
|
42
46
|
Requires-Dist: mypy>=0.910; extra == 'dev'
|
|
47
|
+
Requires-Dist: pyarrow; extra == 'dev'
|
|
43
48
|
Requires-Dist: pytest-cov>=2.12.0; extra == 'dev'
|
|
44
49
|
Requires-Dist: pytest>=6.0.0; extra == 'dev'
|
|
50
|
+
Requires-Dist: rich>=10.0.0; extra == 'dev'
|
|
51
|
+
Requires-Dist: scikit-learn>=0.24.0; extra == 'dev'
|
|
52
|
+
Requires-Dist: tiktoken>=0.5.0; extra == 'dev'
|
|
53
|
+
Requires-Dist: tokenizers>=0.15.0; extra == 'dev'
|
|
54
|
+
Requires-Dist: toolong>=1.5.0; extra == 'dev'
|
|
45
55
|
Provides-Extra: display
|
|
46
56
|
Provides-Extra: docs
|
|
47
57
|
Requires-Dist: myst-parser>=0.15.0; extra == 'docs'
|
|
@@ -57,7 +57,7 @@ def _get_file_row_count(filepath: Path) -> Optional[int]:
|
|
|
57
57
|
return None
|
|
58
58
|
|
|
59
59
|
|
|
60
|
-
def _format_value(value: Any, max_len: int = 80) -> str:
|
|
60
|
+
def _format_value(value: Any, max_len: int = 120) -> str:
|
|
61
61
|
"""格式化单个值,长文本截断。"""
|
|
62
62
|
if value is None:
|
|
63
63
|
return "[dim]null[/dim]"
|
|
@@ -66,18 +66,22 @@ def _format_value(value: Any, max_len: int = 80) -> str:
|
|
|
66
66
|
if isinstance(value, (int, float)):
|
|
67
67
|
return f"[cyan]{value}[/cyan]"
|
|
68
68
|
if isinstance(value, str):
|
|
69
|
+
half_len = max_len // 2
|
|
69
70
|
# 处理多行文本
|
|
70
71
|
if "\n" in value:
|
|
71
72
|
lines = value.split("\n")
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
73
|
+
preview = value.replace("\n", "\\n")
|
|
74
|
+
if len(preview) > max_len:
|
|
75
|
+
# 前半 + 省略标记 + 后半
|
|
76
|
+
head = preview[:half_len]
|
|
77
|
+
tail = preview[-half_len:]
|
|
78
|
+
return f'"{head} [yellow]<<<{len(lines)}行>>>[/yellow] {tail}"'
|
|
78
79
|
return f'"{preview}"'
|
|
79
80
|
if len(value) > max_len:
|
|
80
|
-
|
|
81
|
+
# 前半 + 省略标记 + 后半
|
|
82
|
+
head = value[:half_len]
|
|
83
|
+
tail = value[-half_len:]
|
|
84
|
+
return f'"{head} [yellow]<<<{len(value)}字符>>>[/yellow] {tail}"'
|
|
81
85
|
return f'"{value}"'
|
|
82
86
|
return str(value)
|
|
83
87
|
|
|
@@ -86,7 +90,7 @@ def _format_nested(
|
|
|
86
90
|
value: Any,
|
|
87
91
|
indent: str = "",
|
|
88
92
|
is_last: bool = True,
|
|
89
|
-
max_len: int =
|
|
93
|
+
max_len: int = 120,
|
|
90
94
|
) -> List[str]:
|
|
91
95
|
"""
|
|
92
96
|
递归格式化嵌套结构,返回行列表。
|
|
@@ -465,34 +465,65 @@ def token_stats(
|
|
|
465
465
|
return
|
|
466
466
|
|
|
467
467
|
total = len(data)
|
|
468
|
-
print(f" 共 {total} 条数据")
|
|
469
|
-
print(f"🔢 统计 Token (模型: {model}, 字段: {field})...")
|
|
468
|
+
print(f" 共 {total:,} 条数据")
|
|
470
469
|
|
|
471
470
|
# 检查字段类型并选择合适的统计方法(支持嵌套路径)
|
|
472
471
|
sample = data[0]
|
|
473
472
|
field_value = get_field_with_spec(sample, field)
|
|
474
473
|
|
|
474
|
+
# 尝试使用 rich 进度条
|
|
475
475
|
try:
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
476
|
+
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn
|
|
477
|
+
|
|
478
|
+
with Progress(
|
|
479
|
+
SpinnerColumn(),
|
|
480
|
+
TextColumn("[bold blue]统计 Token"),
|
|
481
|
+
BarColumn(),
|
|
482
|
+
TaskProgressColumn(),
|
|
483
|
+
TextColumn(f"(模型: {model})"),
|
|
484
|
+
) as progress:
|
|
485
|
+
task = progress.add_task("", total=total)
|
|
486
|
+
|
|
487
|
+
def update_progress(current: int, total_count: int):
|
|
488
|
+
progress.update(task, completed=current)
|
|
489
|
+
|
|
490
|
+
if isinstance(field_value, list) and field_value and isinstance(field_value[0], dict):
|
|
491
|
+
from ..tokenizers import messages_token_stats
|
|
492
|
+
|
|
493
|
+
stats_result = messages_token_stats(
|
|
494
|
+
data, messages_field=field, model=model, progress_callback=update_progress
|
|
495
|
+
)
|
|
496
|
+
_print_messages_token_stats(stats_result, detailed)
|
|
497
|
+
else:
|
|
498
|
+
from ..tokenizers import token_stats as compute_token_stats
|
|
499
|
+
|
|
500
|
+
stats_result = compute_token_stats(
|
|
501
|
+
data, fields=field, model=model, progress_callback=update_progress
|
|
502
|
+
)
|
|
503
|
+
_print_text_token_stats(stats_result, detailed)
|
|
504
|
+
|
|
505
|
+
except ImportError:
|
|
506
|
+
# 没有 rich,显示简单进度
|
|
507
|
+
print(f"🔢 统计 Token (模型: {model}, 字段: {field})...")
|
|
508
|
+
try:
|
|
509
|
+
if isinstance(field_value, list) and field_value and isinstance(field_value[0], dict):
|
|
510
|
+
from ..tokenizers import messages_token_stats
|
|
511
|
+
|
|
512
|
+
stats_result = messages_token_stats(data, messages_field=field, model=model)
|
|
513
|
+
_print_messages_token_stats(stats_result, detailed)
|
|
514
|
+
else:
|
|
515
|
+
from ..tokenizers import token_stats as compute_token_stats
|
|
494
516
|
|
|
495
|
-
|
|
517
|
+
stats_result = compute_token_stats(data, fields=field, model=model)
|
|
518
|
+
_print_text_token_stats(stats_result, detailed)
|
|
519
|
+
except ImportError as e:
|
|
520
|
+
print(f"错误: {e}")
|
|
521
|
+
return
|
|
522
|
+
except Exception as e:
|
|
523
|
+
print(f"错误: 统计失败 - {e}")
|
|
524
|
+
import traceback
|
|
525
|
+
|
|
526
|
+
traceback.print_exc()
|
|
496
527
|
|
|
497
528
|
|
|
498
529
|
def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
|
|
@@ -505,21 +536,39 @@ def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
|
|
|
505
536
|
console = Console()
|
|
506
537
|
|
|
507
538
|
# 概览
|
|
539
|
+
std = stats.get("std_tokens", 0)
|
|
508
540
|
overview = (
|
|
509
541
|
f"[bold]总样本数:[/bold] {stats['count']:,}\n"
|
|
510
542
|
f"[bold]总 Token:[/bold] {stats['total_tokens']:,}\n"
|
|
511
|
-
f"[bold]平均 Token:[/bold] {stats['avg_tokens']:,}\n"
|
|
512
|
-
f"[bold]中位数:[/bold] {stats['median_tokens']:,}\n"
|
|
543
|
+
f"[bold]平均 Token:[/bold] {stats['avg_tokens']:,} (std: {std:.1f})\n"
|
|
513
544
|
f"[bold]范围:[/bold] {stats['min_tokens']:,} - {stats['max_tokens']:,}"
|
|
514
545
|
)
|
|
515
546
|
console.print(Panel(overview, title="📊 Token 统计概览", expand=False))
|
|
516
547
|
|
|
548
|
+
# 百分位数表格
|
|
549
|
+
table = Table(title="📈 分布统计")
|
|
550
|
+
table.add_column("百分位", style="cyan", justify="center")
|
|
551
|
+
table.add_column("Token 数", justify="right")
|
|
552
|
+
percentiles = [
|
|
553
|
+
("Min", stats["min_tokens"]),
|
|
554
|
+
("P25", stats.get("p25", "-")),
|
|
555
|
+
("P50 (中位数)", stats.get("median_tokens", "-")),
|
|
556
|
+
("P75", stats.get("p75", "-")),
|
|
557
|
+
("P90", stats.get("p90", "-")),
|
|
558
|
+
("P95", stats.get("p95", "-")),
|
|
559
|
+
("P99", stats.get("p99", "-")),
|
|
560
|
+
("Max", stats["max_tokens"]),
|
|
561
|
+
]
|
|
562
|
+
for name, val in percentiles:
|
|
563
|
+
table.add_row(name, f"{val:,}" if isinstance(val, int) else str(val))
|
|
564
|
+
console.print(table)
|
|
565
|
+
|
|
517
566
|
if detailed:
|
|
518
|
-
#
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
567
|
+
# 分角色统计
|
|
568
|
+
role_table = Table(title="📋 分角色统计")
|
|
569
|
+
role_table.add_column("角色", style="cyan")
|
|
570
|
+
role_table.add_column("Token 数", justify="right")
|
|
571
|
+
role_table.add_column("占比", justify="right")
|
|
523
572
|
|
|
524
573
|
total = stats["total_tokens"]
|
|
525
574
|
for role, key in [
|
|
@@ -529,22 +578,27 @@ def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
|
|
|
529
578
|
]:
|
|
530
579
|
tokens = stats.get(key, 0)
|
|
531
580
|
pct = tokens / total * 100 if total > 0 else 0
|
|
532
|
-
|
|
581
|
+
role_table.add_row(role, f"{tokens:,}", f"{pct:.1f}%")
|
|
533
582
|
|
|
534
|
-
console.print(
|
|
583
|
+
console.print(role_table)
|
|
535
584
|
console.print(f"\n平均对话轮数: {stats.get('avg_turns', 0)}")
|
|
536
585
|
|
|
537
586
|
except ImportError:
|
|
538
587
|
# 没有 rich,使用普通打印
|
|
588
|
+
std = stats.get("std_tokens", 0)
|
|
539
589
|
print(f"\n{'=' * 40}")
|
|
540
590
|
print("📊 Token 统计概览")
|
|
541
591
|
print(f"{'=' * 40}")
|
|
542
592
|
print(f"总样本数: {stats['count']:,}")
|
|
543
593
|
print(f"总 Token: {stats['total_tokens']:,}")
|
|
544
|
-
print(f"平均 Token: {stats['avg_tokens']:,}")
|
|
545
|
-
print(f"中位数: {stats['median_tokens']:,}")
|
|
594
|
+
print(f"平均 Token: {stats['avg_tokens']:,} (std: {std:.1f})")
|
|
546
595
|
print(f"范围: {stats['min_tokens']:,} - {stats['max_tokens']:,}")
|
|
547
596
|
|
|
597
|
+
print(f"\n📈 百分位分布:")
|
|
598
|
+
print(f" P25: {stats.get('p25', '-'):,} P50: {stats.get('median_tokens', '-'):,}")
|
|
599
|
+
print(f" P75: {stats.get('p75', '-'):,} P90: {stats.get('p90', '-'):,}")
|
|
600
|
+
print(f" P95: {stats.get('p95', '-'):,} P99: {stats.get('p99', '-'):,}")
|
|
601
|
+
|
|
548
602
|
if detailed:
|
|
549
603
|
print(f"\n{'=' * 40}")
|
|
550
604
|
print("📋 分角色统计")
|
|
@@ -566,24 +620,48 @@ def _print_text_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
|
|
|
566
620
|
try:
|
|
567
621
|
from rich.console import Console
|
|
568
622
|
from rich.panel import Panel
|
|
623
|
+
from rich.table import Table
|
|
569
624
|
|
|
570
625
|
console = Console()
|
|
571
626
|
|
|
627
|
+
std = stats.get("std_tokens", 0)
|
|
572
628
|
overview = (
|
|
573
629
|
f"[bold]总样本数:[/bold] {stats['count']:,}\n"
|
|
574
630
|
f"[bold]总 Token:[/bold] {stats['total_tokens']:,}\n"
|
|
575
|
-
f"[bold]平均 Token:[/bold] {stats['avg_tokens']:.1f}\n"
|
|
576
|
-
f"[bold]中位数:[/bold] {stats['median_tokens']:,}\n"
|
|
631
|
+
f"[bold]平均 Token:[/bold] {stats['avg_tokens']:.1f} (std: {std:.1f})\n"
|
|
577
632
|
f"[bold]范围:[/bold] {stats['min_tokens']:,} - {stats['max_tokens']:,}"
|
|
578
633
|
)
|
|
579
634
|
console.print(Panel(overview, title="📊 Token 统计", expand=False))
|
|
580
635
|
|
|
636
|
+
# 百分位数表格
|
|
637
|
+
table = Table(title="📈 分布统计")
|
|
638
|
+
table.add_column("百分位", style="cyan", justify="center")
|
|
639
|
+
table.add_column("Token 数", justify="right")
|
|
640
|
+
percentiles = [
|
|
641
|
+
("Min", stats["min_tokens"]),
|
|
642
|
+
("P25", stats.get("p25", "-")),
|
|
643
|
+
("P50 (中位数)", stats.get("median_tokens", "-")),
|
|
644
|
+
("P75", stats.get("p75", "-")),
|
|
645
|
+
("P90", stats.get("p90", "-")),
|
|
646
|
+
("P95", stats.get("p95", "-")),
|
|
647
|
+
("P99", stats.get("p99", "-")),
|
|
648
|
+
("Max", stats["max_tokens"]),
|
|
649
|
+
]
|
|
650
|
+
for name, val in percentiles:
|
|
651
|
+
table.add_row(name, f"{val:,}" if isinstance(val, int) else str(val))
|
|
652
|
+
console.print(table)
|
|
653
|
+
|
|
581
654
|
except ImportError:
|
|
655
|
+
std = stats.get("std_tokens", 0)
|
|
582
656
|
print(f"\n{'=' * 40}")
|
|
583
657
|
print("📊 Token 统计")
|
|
584
658
|
print(f"{'=' * 40}")
|
|
585
659
|
print(f"总样本数: {stats['count']:,}")
|
|
586
660
|
print(f"总 Token: {stats['total_tokens']:,}")
|
|
587
|
-
print(f"平均 Token: {stats['avg_tokens']:.1f}")
|
|
588
|
-
print(f"中位数: {stats['median_tokens']:,}")
|
|
661
|
+
print(f"平均 Token: {stats['avg_tokens']:.1f} (std: {std:.1f})")
|
|
589
662
|
print(f"范围: {stats['min_tokens']:,} - {stats['max_tokens']:,}")
|
|
663
|
+
|
|
664
|
+
print(f"\n📈 百分位分布:")
|
|
665
|
+
print(f" P25: {stats.get('p25', '-'):,} P50: {stats.get('median_tokens', '-'):,}")
|
|
666
|
+
print(f" P75: {stats.get('p75', '-'):,} P90: {stats.get('p90', '-'):,}")
|
|
667
|
+
print(f" P95: {stats.get('p95', '-'):,} P99: {stats.get('p99', '-'):,}")
|
|
@@ -793,19 +793,29 @@ class DataTransformer:
|
|
|
793
793
|
seed: 随机种子
|
|
794
794
|
|
|
795
795
|
Returns:
|
|
796
|
-
(train, test) 两个 DataTransformer
|
|
796
|
+
(train, test) 两个 DataTransformer,各自拥有独立的血缘追踪器
|
|
797
797
|
"""
|
|
798
798
|
data = self.shuffle(seed).data
|
|
799
799
|
split_idx = int(len(data) * ratio)
|
|
800
800
|
|
|
801
|
-
#
|
|
801
|
+
# 分割后血缘追踪器各自独立(使用深拷贝避免相互影响)
|
|
802
802
|
tracker = self._lineage_tracker
|
|
803
|
+
train_tracker = None
|
|
804
|
+
test_tracker = None
|
|
805
|
+
|
|
803
806
|
if tracker:
|
|
804
807
|
tracker.record("split", {"ratio": ratio, "seed": seed}, len(self._data), len(data))
|
|
808
|
+
# 为每个子数据集创建独立的追踪器副本
|
|
809
|
+
train_tracker = tracker.copy()
|
|
810
|
+
train_tracker.record("split_part", {"part": "train", "ratio": ratio}, len(data), split_idx)
|
|
811
|
+
test_tracker = tracker.copy()
|
|
812
|
+
test_tracker.record(
|
|
813
|
+
"split_part", {"part": "test", "ratio": 1 - ratio}, len(data), len(data) - split_idx
|
|
814
|
+
)
|
|
805
815
|
|
|
806
816
|
return (
|
|
807
|
-
DataTransformer(data[:split_idx], _lineage_tracker=
|
|
808
|
-
DataTransformer(data[split_idx:], _lineage_tracker=
|
|
817
|
+
DataTransformer(data[:split_idx], _lineage_tracker=train_tracker),
|
|
818
|
+
DataTransformer(data[split_idx:], _lineage_tracker=test_tracker),
|
|
809
819
|
)
|
|
810
820
|
|
|
811
821
|
# ============ 并行处理 ============
|
|
@@ -815,6 +825,7 @@ class DataTransformer:
|
|
|
815
825
|
func: Callable[[Dict], Any],
|
|
816
826
|
workers: Optional[int] = None,
|
|
817
827
|
chunksize: int = 1000,
|
|
828
|
+
timeout: Optional[float] = None,
|
|
818
829
|
) -> List[Any]:
|
|
819
830
|
"""
|
|
820
831
|
并行执行转换函数(使用多进程)。
|
|
@@ -825,24 +836,46 @@ class DataTransformer:
|
|
|
825
836
|
func: 转换函数,接收原始 dict,返回转换结果
|
|
826
837
|
workers: 进程数,默认为 CPU 核心数
|
|
827
838
|
chunksize: 每个进程处理的数据块大小
|
|
839
|
+
timeout: 超时时间(秒),None 表示无超时
|
|
828
840
|
|
|
829
841
|
Returns:
|
|
830
842
|
转换后的结果列表
|
|
831
843
|
|
|
844
|
+
Raises:
|
|
845
|
+
TypeError: 如果 func 无法被 pickle(如 lambda 函数)
|
|
846
|
+
RuntimeError: 如果子进程执行出错或超时
|
|
847
|
+
|
|
832
848
|
Examples:
|
|
833
849
|
>>> def transform(item):
|
|
834
850
|
... return {"id": item["id"], "text": item["text"].upper()}
|
|
835
851
|
>>> results = dt.map_parallel(transform)
|
|
836
852
|
"""
|
|
837
|
-
from multiprocessing import Pool, cpu_count
|
|
853
|
+
from multiprocessing import Pool, TimeoutError, cpu_count
|
|
854
|
+
import pickle
|
|
838
855
|
|
|
839
856
|
if not self._data:
|
|
840
857
|
return []
|
|
841
858
|
|
|
859
|
+
# 检查函数是否可 pickle
|
|
860
|
+
try:
|
|
861
|
+
pickle.dumps(func)
|
|
862
|
+
except (pickle.PicklingError, AttributeError, TypeError) as e:
|
|
863
|
+
func_name = getattr(func, "__name__", str(func))
|
|
864
|
+
raise TypeError(
|
|
865
|
+
f"函数 '{func_name}' 无法被 pickle,不能用于并行处理。"
|
|
866
|
+
f"请使用模块级函数而非 lambda 或闭包。错误: {e}"
|
|
867
|
+
) from e
|
|
868
|
+
|
|
842
869
|
workers = workers or cpu_count()
|
|
843
870
|
|
|
844
|
-
|
|
845
|
-
|
|
871
|
+
try:
|
|
872
|
+
with Pool(workers) as pool:
|
|
873
|
+
async_result = pool.map_async(func, self._data, chunksize=chunksize)
|
|
874
|
+
results = async_result.get(timeout=timeout)
|
|
875
|
+
except TimeoutError:
|
|
876
|
+
raise RuntimeError(f"并行处理超时({timeout}秒)")
|
|
877
|
+
except Exception as e:
|
|
878
|
+
raise RuntimeError(f"并行处理失败: {type(e).__name__}: {e}") from e
|
|
846
879
|
|
|
847
880
|
return results
|
|
848
881
|
|
|
@@ -851,6 +884,7 @@ class DataTransformer:
|
|
|
851
884
|
func: Callable[[Dict], bool],
|
|
852
885
|
workers: Optional[int] = None,
|
|
853
886
|
chunksize: int = 1000,
|
|
887
|
+
timeout: Optional[float] = None,
|
|
854
888
|
) -> "DataTransformer":
|
|
855
889
|
"""
|
|
856
890
|
并行执行过滤函数(使用多进程)。
|
|
@@ -861,24 +895,46 @@ class DataTransformer:
|
|
|
861
895
|
func: 过滤函数,接收原始 dict,返回 True 保留
|
|
862
896
|
workers: 进程数,默认为 CPU 核心数
|
|
863
897
|
chunksize: 每个进程处理的数据块大小
|
|
898
|
+
timeout: 超时时间(秒),None 表示无超时
|
|
864
899
|
|
|
865
900
|
Returns:
|
|
866
901
|
过滤后的新 DataTransformer
|
|
867
902
|
|
|
903
|
+
Raises:
|
|
904
|
+
TypeError: 如果 func 无法被 pickle(如 lambda 函数)
|
|
905
|
+
RuntimeError: 如果子进程执行出错或超时
|
|
906
|
+
|
|
868
907
|
Examples:
|
|
869
908
|
>>> def is_valid(item):
|
|
870
909
|
... return len(item["text"]) > 10
|
|
871
910
|
>>> filtered = dt.filter_parallel(is_valid)
|
|
872
911
|
"""
|
|
873
|
-
from multiprocessing import Pool, cpu_count
|
|
912
|
+
from multiprocessing import Pool, TimeoutError, cpu_count
|
|
913
|
+
import pickle
|
|
874
914
|
|
|
875
915
|
if not self._data:
|
|
876
916
|
return DataTransformer([])
|
|
877
917
|
|
|
918
|
+
# 检查函数是否可 pickle
|
|
919
|
+
try:
|
|
920
|
+
pickle.dumps(func)
|
|
921
|
+
except (pickle.PicklingError, AttributeError, TypeError) as e:
|
|
922
|
+
func_name = getattr(func, "__name__", str(func))
|
|
923
|
+
raise TypeError(
|
|
924
|
+
f"函数 '{func_name}' 无法被 pickle,不能用于并行处理。"
|
|
925
|
+
f"请使用模块级函数而非 lambda 或闭包。错误: {e}"
|
|
926
|
+
) from e
|
|
927
|
+
|
|
878
928
|
workers = workers or cpu_count()
|
|
879
929
|
|
|
880
|
-
|
|
881
|
-
|
|
930
|
+
try:
|
|
931
|
+
with Pool(workers) as pool:
|
|
932
|
+
async_result = pool.map_async(func, self._data, chunksize=chunksize)
|
|
933
|
+
mask = async_result.get(timeout=timeout)
|
|
934
|
+
except TimeoutError:
|
|
935
|
+
raise RuntimeError(f"并行处理超时({timeout}秒)")
|
|
936
|
+
except Exception as e:
|
|
937
|
+
raise RuntimeError(f"并行处理失败: {type(e).__name__}: {e}") from e
|
|
882
938
|
|
|
883
939
|
filtered = [item for item, keep in zip(self._data, mask) if keep]
|
|
884
940
|
return DataTransformer(filtered)
|
|
@@ -237,6 +237,23 @@ class LineageTracker:
|
|
|
237
237
|
|
|
238
238
|
return lineage_path
|
|
239
239
|
|
|
240
|
+
def copy(self) -> "LineageTracker":
|
|
241
|
+
"""
|
|
242
|
+
创建追踪器的深拷贝。
|
|
243
|
+
|
|
244
|
+
用于 split() 等场景,确保子数据集有独立的血缘追踪。
|
|
245
|
+
|
|
246
|
+
Returns:
|
|
247
|
+
新的 LineageTracker 实例
|
|
248
|
+
"""
|
|
249
|
+
import copy as copy_module
|
|
250
|
+
|
|
251
|
+
new_tracker = LineageTracker.__new__(LineageTracker)
|
|
252
|
+
new_tracker.source_path = self.source_path
|
|
253
|
+
new_tracker.source_lineage = self.source_lineage # LineageRecord 是不可变的,可共享
|
|
254
|
+
new_tracker.operations = copy_module.deepcopy(self.operations)
|
|
255
|
+
return new_tracker
|
|
256
|
+
|
|
240
257
|
|
|
241
258
|
def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
|
|
242
259
|
"""
|
|
@@ -365,50 +365,108 @@ class StreamingTransformer:
|
|
|
365
365
|
"""
|
|
366
366
|
批量流式保存(CSV/Parquet/Arrow)。
|
|
367
367
|
|
|
368
|
-
|
|
368
|
+
真正的流式写入:分批处理,每批写入后释放内存。
|
|
369
|
+
内存占用 O(batch_size) 而非 O(n)。
|
|
369
370
|
"""
|
|
370
371
|
path = Path(filepath)
|
|
371
|
-
|
|
372
|
+
count = 0
|
|
373
|
+
batch = []
|
|
374
|
+
first_batch = True
|
|
372
375
|
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
if self._total is not None:
|
|
376
|
-
columns = [
|
|
377
|
-
SpinnerColumn(),
|
|
378
|
-
TextColumn("[progress.description]{task.description}"),
|
|
379
|
-
BarColumn(),
|
|
380
|
-
TaskProgressColumn(),
|
|
381
|
-
MofNCompleteColumn(),
|
|
382
|
-
TimeElapsedColumn(),
|
|
383
|
-
TimeRemainingColumn(),
|
|
384
|
-
]
|
|
385
|
-
else:
|
|
386
|
-
columns = [
|
|
387
|
-
SpinnerColumn(),
|
|
388
|
-
TextColumn("[progress.description]{task.description}"),
|
|
389
|
-
MofNCompleteColumn(),
|
|
390
|
-
TimeElapsedColumn(),
|
|
391
|
-
]
|
|
376
|
+
# 进度条配置
|
|
377
|
+
progress_columns = self._get_progress_columns()
|
|
392
378
|
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
for item in self._iterator:
|
|
400
|
-
all_items.append(item)
|
|
379
|
+
def write_batch(items: List[Dict], is_first: bool, writer_state: Dict):
|
|
380
|
+
"""写入一批数据"""
|
|
381
|
+
if not items:
|
|
382
|
+
return
|
|
383
|
+
|
|
384
|
+
df = pl.DataFrame(items)
|
|
401
385
|
|
|
402
|
-
if all_items:
|
|
403
|
-
df = pl.DataFrame(all_items)
|
|
404
386
|
if fmt == "csv":
|
|
405
|
-
|
|
387
|
+
if is_first:
|
|
388
|
+
df.write_csv(path)
|
|
389
|
+
else:
|
|
390
|
+
# CSV 追加模式:不写表头
|
|
391
|
+
with open(path, "ab") as f:
|
|
392
|
+
f.write(df.write_csv(include_header=False).encode("utf-8"))
|
|
393
|
+
|
|
406
394
|
elif fmt == "parquet":
|
|
407
|
-
|
|
395
|
+
import pyarrow as pa
|
|
396
|
+
import pyarrow.parquet as pq
|
|
397
|
+
|
|
398
|
+
table = df.to_arrow()
|
|
399
|
+
if is_first:
|
|
400
|
+
writer_state["writer"] = pq.ParquetWriter(str(path), table.schema)
|
|
401
|
+
writer_state["writer"].write_table(table)
|
|
402
|
+
|
|
408
403
|
elif fmt == "arrow":
|
|
409
|
-
|
|
404
|
+
import pyarrow as pa
|
|
405
|
+
|
|
406
|
+
table = df.to_arrow()
|
|
407
|
+
if is_first:
|
|
408
|
+
writer_state["writer"] = pa.ipc.new_file(str(path), table.schema)
|
|
409
|
+
for record_batch in table.to_batches():
|
|
410
|
+
writer_state["writer"].write_batch(record_batch)
|
|
411
|
+
|
|
412
|
+
writer_state: Dict[str, Any] = {}
|
|
413
|
+
|
|
414
|
+
try:
|
|
415
|
+
if show_progress:
|
|
416
|
+
with Progress(*progress_columns) as progress:
|
|
417
|
+
task = progress.add_task("处理中", total=self._total)
|
|
418
|
+
for item in self._iterator:
|
|
419
|
+
batch.append(item)
|
|
420
|
+
count += 1
|
|
421
|
+
progress.update(task, advance=1)
|
|
422
|
+
|
|
423
|
+
if len(batch) >= batch_size:
|
|
424
|
+
write_batch(batch, first_batch, writer_state)
|
|
425
|
+
first_batch = False
|
|
426
|
+
batch = [] # 释放内存
|
|
427
|
+
|
|
428
|
+
# 写入最后一批
|
|
429
|
+
if batch:
|
|
430
|
+
write_batch(batch, first_batch, writer_state)
|
|
431
|
+
else:
|
|
432
|
+
for item in self._iterator:
|
|
433
|
+
batch.append(item)
|
|
434
|
+
count += 1
|
|
435
|
+
|
|
436
|
+
if len(batch) >= batch_size:
|
|
437
|
+
write_batch(batch, first_batch, writer_state)
|
|
438
|
+
first_batch = False
|
|
439
|
+
batch = []
|
|
410
440
|
|
|
411
|
-
|
|
441
|
+
if batch:
|
|
442
|
+
write_batch(batch, first_batch, writer_state)
|
|
443
|
+
|
|
444
|
+
finally:
|
|
445
|
+
# 关闭 writer
|
|
446
|
+
if "writer" in writer_state:
|
|
447
|
+
writer_state["writer"].close()
|
|
448
|
+
|
|
449
|
+
return count
|
|
450
|
+
|
|
451
|
+
def _get_progress_columns(self):
|
|
452
|
+
"""获取进度条列配置"""
|
|
453
|
+
if self._total is not None:
|
|
454
|
+
return [
|
|
455
|
+
SpinnerColumn(),
|
|
456
|
+
TextColumn("[progress.description]{task.description}"),
|
|
457
|
+
BarColumn(),
|
|
458
|
+
TaskProgressColumn(),
|
|
459
|
+
MofNCompleteColumn(),
|
|
460
|
+
TimeElapsedColumn(),
|
|
461
|
+
TimeRemainingColumn(),
|
|
462
|
+
]
|
|
463
|
+
else:
|
|
464
|
+
return [
|
|
465
|
+
SpinnerColumn(),
|
|
466
|
+
TextColumn("[progress.description]{task.description}"),
|
|
467
|
+
MofNCompleteColumn(),
|
|
468
|
+
TimeElapsedColumn(),
|
|
469
|
+
]
|
|
412
470
|
|
|
413
471
|
def save_sharded(
|
|
414
472
|
self,
|