dtflow 0.5.0__tar.gz → 0.5.2__tar.gz

This diff shows the changes between two publicly released versions of the package, as they appear in its public registry.
Files changed (50)
  1. {dtflow-0.5.0 → dtflow-0.5.2}/PKG-INFO +11 -1
  2. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/__init__.py +1 -1
  3. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/cli/common.py +13 -9
  4. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/cli/stats.py +114 -36
  5. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/core.py +66 -10
  6. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/lineage.py +17 -0
  7. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/streaming.py +93 -35
  8. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/tokenizers.py +84 -29
  9. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/utils/field_path.py +6 -2
  10. {dtflow-0.5.0 → dtflow-0.5.2}/pyproject.toml +11 -0
  11. dtflow-0.5.2/tests/README.md +88 -0
  12. dtflow-0.5.2/tests/benchmark_sharegpt.py +392 -0
  13. dtflow-0.5.2/tests/test_cli_benchmark.py +565 -0
  14. {dtflow-0.5.0 → dtflow-0.5.2}/tests/test_streaming.py +80 -0
  15. {dtflow-0.5.0 → dtflow-0.5.2}/tests/test_transformer.py +77 -0
  16. {dtflow-0.5.0 → dtflow-0.5.2}/.gitignore +0 -0
  17. {dtflow-0.5.0 → dtflow-0.5.2}/README.md +0 -0
  18. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/__main__.py +0 -0
  19. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/cli/__init__.py +0 -0
  20. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/cli/clean.py +0 -0
  21. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/cli/commands.py +0 -0
  22. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/cli/io_ops.py +0 -0
  23. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/cli/lineage.py +0 -0
  24. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/cli/pipeline.py +0 -0
  25. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/cli/sample.py +0 -0
  26. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/cli/transform.py +0 -0
  27. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/cli/validate.py +0 -0
  28. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/converters.py +0 -0
  29. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/framework.py +0 -0
  30. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/mcp/__init__.py +0 -0
  31. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/mcp/__main__.py +0 -0
  32. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/mcp/cli.py +0 -0
  33. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/mcp/docs.py +0 -0
  34. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/mcp/server.py +0 -0
  35. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/pipeline.py +0 -0
  36. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/presets.py +0 -0
  37. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/schema.py +0 -0
  38. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/storage/__init__.py +0 -0
  39. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/storage/io.py +0 -0
  40. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/utils/__init__.py +0 -0
  41. {dtflow-0.5.0 → dtflow-0.5.2}/dtflow/utils/display.py +0 -0
  42. {dtflow-0.5.0 → dtflow-0.5.2}/tests/benchmark_io.py +0 -0
  43. {dtflow-0.5.0 → dtflow-0.5.2}/tests/test_converters.py +0 -0
  44. {dtflow-0.5.0 → dtflow-0.5.2}/tests/test_field_path.py +0 -0
  45. {dtflow-0.5.0 → dtflow-0.5.2}/tests/test_framework.py +0 -0
  46. {dtflow-0.5.0 → dtflow-0.5.2}/tests/test_io.py +0 -0
  47. {dtflow-0.5.0 → dtflow-0.5.2}/tests/test_lineage.py +0 -0
  48. {dtflow-0.5.0 → dtflow-0.5.2}/tests/test_pipeline.py +0 -0
  49. {dtflow-0.5.0 → dtflow-0.5.2}/tests/test_schema.py +0 -0
  50. {dtflow-0.5.0 → dtflow-0.5.2}/tests/test_tokenizers.py +0 -0
{dtflow-0.5.0 → dtflow-0.5.2}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dtflow
- Version: 0.5.0
+ Version: 0.5.2
  Summary: A flexible data transformation tool for ML training formats (SFT, RLHF, Pretrain)
  Project-URL: Homepage, https://github.com/yourusername/DataTransformer
  Project-URL: Documentation, https://github.com/yourusername/DataTransformer#readme
@@ -32,16 +32,26 @@ Requires-Dist: orjson>=3.9.0
  Requires-Dist: polars>=0.20.0
  Requires-Dist: pyyaml>=5.4.0
  Requires-Dist: rich>=10.0.0
+ Requires-Dist: tiktoken>=0.5.0
  Requires-Dist: typer>=0.9.0
  Provides-Extra: converters
  Requires-Dist: datasets>=2.0.0; extra == 'converters'
  Provides-Extra: dev
  Requires-Dist: black>=21.0; extra == 'dev'
+ Requires-Dist: datasets>=2.0.0; extra == 'dev'
+ Requires-Dist: datasketch>=1.5.0; extra == 'dev'
  Requires-Dist: flake8>=3.9.0; extra == 'dev'
+ Requires-Dist: huggingface-hub>=0.20.0; extra == 'dev'
  Requires-Dist: isort>=5.9.0; extra == 'dev'
  Requires-Dist: mypy>=0.910; extra == 'dev'
+ Requires-Dist: pyarrow; extra == 'dev'
  Requires-Dist: pytest-cov>=2.12.0; extra == 'dev'
  Requires-Dist: pytest>=6.0.0; extra == 'dev'
+ Requires-Dist: rich>=10.0.0; extra == 'dev'
+ Requires-Dist: scikit-learn>=0.24.0; extra == 'dev'
+ Requires-Dist: tiktoken>=0.5.0; extra == 'dev'
+ Requires-Dist: tokenizers>=0.15.0; extra == 'dev'
+ Requires-Dist: toolong>=1.5.0; extra == 'dev'
  Provides-Extra: display
  Provides-Extra: docs
  Requires-Dist: myst-parser>=0.15.0; extra == 'docs'
{dtflow-0.5.0 → dtflow-0.5.2}/dtflow/__init__.py
@@ -60,7 +60,7 @@ from .tokenizers import (
      token_stats,
  )
  
- __version__ = "0.5.0"
+ __version__ = "0.5.2"
  
  __all__ = [
      # core
{dtflow-0.5.0 → dtflow-0.5.2}/dtflow/cli/common.py
@@ -57,7 +57,7 @@ def _get_file_row_count(filepath: Path) -> Optional[int]:
      return None
  
  
- def _format_value(value: Any, max_len: int = 80) -> str:
+ def _format_value(value: Any, max_len: int = 120) -> str:
      """Format a single value, truncating long text."""
      if value is None:
          return "[dim]null[/dim]"
@@ -66,18 +66,22 @@ def _format_value(value: Any, max_len: int = 80) -> str:
      if isinstance(value, (int, float)):
          return f"[cyan]{value}[/cyan]"
      if isinstance(value, str):
+         half_len = max_len // 2
          # Handle multi-line text
          if "\n" in value:
              lines = value.split("\n")
-             if len(lines) > 3:
-                 preview = lines[0][:max_len] + f"... [dim]({len(lines)} lines)[/dim]"
-             else:
-                 preview = value.replace("\n", "\\n")
-                 if len(preview) > max_len:
-                     preview = preview[:max_len] + "..."
+             preview = value.replace("\n", "\\n")
+             if len(preview) > max_len:
+                 # first half + elision marker + second half
+                 head = preview[:half_len]
+                 tail = preview[-half_len:]
+                 return f'"{head} [yellow]<<<{len(lines)} lines>>>[/yellow] {tail}"'
              return f'"{preview}"'
          if len(value) > max_len:
-             return f'"{value[:max_len]}..." [dim]({len(value)} chars)[/dim]'
+             # first half + elision marker + second half
+             head = value[:half_len]
+             tail = value[-half_len:]
+             return f'"{head} [yellow]<<<{len(value)} chars>>>[/yellow] {tail}"'
          return f'"{value}"'
      return str(value)
  
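Note: the new logic keeps both ends of an over-long string instead of only the head. A minimal standalone sketch of the same middle-elision technique (the helper name is hypothetical, not dtflow API):

def truncate_middle(text: str, max_len: int = 120) -> str:
    """Keep the first and last max_len // 2 chars, marking what was elided."""
    if len(text) <= max_len:
        return text
    half = max_len // 2
    return f"{text[:half]} <<<{len(text)} chars>>> {text[-half:]}"

# truncate_middle("a" * 300, 10) -> "aaaaa <<<300 chars>>> aaaaa"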
@@ -86,7 +90,7 @@ def _format_nested(
      value: Any,
      indent: str = "",
      is_last: bool = True,
-     max_len: int = 80,
+     max_len: int = 120,
  ) -> List[str]:
      """
      Recursively format a nested structure; returns a list of lines.
{dtflow-0.5.0 → dtflow-0.5.2}/dtflow/cli/stats.py
@@ -465,34 +465,65 @@ def token_stats(
          return
  
      total = len(data)
-     print(f" Total: {total} records")
-     print(f"🔢 Counting tokens (model: {model}, field: {field})...")
+     print(f" Total: {total:,} records")
  
      # Check the field type and choose a suitable stats method (supports nested paths)
      sample = data[0]
      field_value = get_field_with_spec(sample, field)
  
+     # Try to use a rich progress bar
      try:
-         if isinstance(field_value, list) and field_value and isinstance(field_value[0], dict):
-             # messages format
-             from ..tokenizers import messages_token_stats
-
-             stats_result = messages_token_stats(data, messages_field=field, model=model)
-             _print_messages_token_stats(stats_result, detailed)
-         else:
-             # plain text field
-             from ..tokenizers import token_stats as compute_token_stats
-
-             stats_result = compute_token_stats(data, fields=field, model=model)
-             _print_text_token_stats(stats_result, detailed)
-     except ImportError as e:
-         print(f"Error: {e}")
-         return
-     except Exception as e:
-         print(f"Error: stats failed - {e}")
-         import traceback
+         from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn
+
+         with Progress(
+             SpinnerColumn(),
+             TextColumn("[bold blue]Counting tokens"),
+             BarColumn(),
+             TaskProgressColumn(),
+             TextColumn(f"(model: {model})"),
+         ) as progress:
+             task = progress.add_task("", total=total)
+
+             def update_progress(current: int, total_count: int):
+                 progress.update(task, completed=current)
+
+             if isinstance(field_value, list) and field_value and isinstance(field_value[0], dict):
+                 from ..tokenizers import messages_token_stats
+
+                 stats_result = messages_token_stats(
+                     data, messages_field=field, model=model, progress_callback=update_progress
+                 )
+                 _print_messages_token_stats(stats_result, detailed)
+             else:
+                 from ..tokenizers import token_stats as compute_token_stats
+
+                 stats_result = compute_token_stats(
+                     data, fields=field, model=model, progress_callback=update_progress
+                 )
+                 _print_text_token_stats(stats_result, detailed)
+
+     except ImportError:
+         # rich not available; fall back to simple output
+         print(f"🔢 Counting tokens (model: {model}, field: {field})...")
+         try:
+             if isinstance(field_value, list) and field_value and isinstance(field_value[0], dict):
+                 from ..tokenizers import messages_token_stats
+
+                 stats_result = messages_token_stats(data, messages_field=field, model=model)
+                 _print_messages_token_stats(stats_result, detailed)
+             else:
+                 from ..tokenizers import token_stats as compute_token_stats
  
-         traceback.print_exc()
+                 stats_result = compute_token_stats(data, fields=field, model=model)
+                 _print_text_token_stats(stats_result, detailed)
+         except ImportError as e:
+             print(f"Error: {e}")
+             return
+         except Exception as e:
+             print(f"Error: stats failed - {e}")
+             import traceback
+
+             traceback.print_exc()
  
  
  def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
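Note: the `progress_callback` now threaded into `messages_token_stats` and `compute_token_stats` follows the common (current, total) callback pattern. A hedged sketch of what the consumer side of such a callback looks like (the `encode` function stands in for the real tokenizer; this is not dtflow code):

from typing import Callable, List, Optional

def count_tokens(
    texts: List[str],
    encode: Callable[[str], List[int]],
    progress_callback: Optional[Callable[[int, int], None]] = None,
) -> List[int]:
    """Tokenize each text and report (processed, total) after each item."""
    counts = []
    for i, text in enumerate(texts, start=1):
        counts.append(len(encode(text)))
        if progress_callback:
            progress_callback(i, len(texts))
    return counts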
@@ -505,21 +536,39 @@ def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
          console = Console()
  
          # Overview
+         std = stats.get("std_tokens", 0)
          overview = (
              f"[bold]Samples:[/bold] {stats['count']:,}\n"
              f"[bold]Total tokens:[/bold] {stats['total_tokens']:,}\n"
-             f"[bold]Avg tokens:[/bold] {stats['avg_tokens']:,}\n"
-             f"[bold]Median:[/bold] {stats['median_tokens']:,}\n"
+             f"[bold]Avg tokens:[/bold] {stats['avg_tokens']:,} (std: {std:.1f})\n"
              f"[bold]Range:[/bold] {stats['min_tokens']:,} - {stats['max_tokens']:,}"
          )
          console.print(Panel(overview, title="📊 Token Stats Overview", expand=False))
  
+         # Percentile table
+         table = Table(title="📈 Distribution")
+         table.add_column("Percentile", style="cyan", justify="center")
+         table.add_column("Tokens", justify="right")
+         percentiles = [
+             ("Min", stats["min_tokens"]),
+             ("P25", stats.get("p25", "-")),
+             ("P50 (median)", stats.get("median_tokens", "-")),
+             ("P75", stats.get("p75", "-")),
+             ("P90", stats.get("p90", "-")),
+             ("P95", stats.get("p95", "-")),
+             ("P99", stats.get("p99", "-")),
+             ("Max", stats["max_tokens"]),
+         ]
+         for name, val in percentiles:
+             table.add_row(name, f"{val:,}" if isinstance(val, int) else str(val))
+         console.print(table)
+
          if detailed:
-             # Detailed stats
-             table = Table(title="📋 Per-Role Stats")
-             table.add_column("Role", style="cyan")
-             table.add_column("Tokens", justify="right")
-             table.add_column("Share", justify="right")
+             # Per-role stats
+             role_table = Table(title="📋 Per-Role Stats")
+             role_table.add_column("Role", style="cyan")
+             role_table.add_column("Tokens", justify="right")
+             role_table.add_column("Share", justify="right")
  
              total = stats["total_tokens"]
              for role, key in [
@@ -529,22 +578,27 @@ def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
              ]:
                  tokens = stats.get(key, 0)
                  pct = tokens / total * 100 if total > 0 else 0
-                 table.add_row(role, f"{tokens:,}", f"{pct:.1f}%")
+                 role_table.add_row(role, f"{tokens:,}", f"{pct:.1f}%")
  
-             console.print(table)
+             console.print(role_table)
              console.print(f"\nAvg turns per conversation: {stats.get('avg_turns', 0)}")
  
      except ImportError:
          # rich not available; plain printing
+         std = stats.get("std_tokens", 0)
          print(f"\n{'=' * 40}")
          print("📊 Token Stats Overview")
          print(f"{'=' * 40}")
          print(f"Samples: {stats['count']:,}")
          print(f"Total tokens: {stats['total_tokens']:,}")
-         print(f"Avg tokens: {stats['avg_tokens']:,}")
-         print(f"Median: {stats['median_tokens']:,}")
+         print(f"Avg tokens: {stats['avg_tokens']:,} (std: {std:.1f})")
          print(f"Range: {stats['min_tokens']:,} - {stats['max_tokens']:,}")
  
+         print(f"\n📈 Percentile distribution:")
+         print(f"  P25: {stats.get('p25', '-'):,}  P50: {stats.get('median_tokens', '-'):,}")
+         print(f"  P75: {stats.get('p75', '-'):,}  P90: {stats.get('p90', '-'):,}")
+         print(f"  P95: {stats.get('p95', '-'):,}  P99: {stats.get('p99', '-'):,}")
+
          if detailed:
              print(f"\n{'=' * 40}")
              print("📋 Per-Role Stats")
@@ -566,24 +620,48 @@ def _print_text_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
      try:
          from rich.console import Console
          from rich.panel import Panel
+         from rich.table import Table
  
          console = Console()
  
+         std = stats.get("std_tokens", 0)
          overview = (
              f"[bold]Samples:[/bold] {stats['count']:,}\n"
              f"[bold]Total tokens:[/bold] {stats['total_tokens']:,}\n"
-             f"[bold]Avg tokens:[/bold] {stats['avg_tokens']:.1f}\n"
-             f"[bold]Median:[/bold] {stats['median_tokens']:,}\n"
+             f"[bold]Avg tokens:[/bold] {stats['avg_tokens']:.1f} (std: {std:.1f})\n"
              f"[bold]Range:[/bold] {stats['min_tokens']:,} - {stats['max_tokens']:,}"
          )
          console.print(Panel(overview, title="📊 Token Stats", expand=False))
  
+         # Percentile table
+         table = Table(title="📈 Distribution")
+         table.add_column("Percentile", style="cyan", justify="center")
+         table.add_column("Tokens", justify="right")
+         percentiles = [
+             ("Min", stats["min_tokens"]),
+             ("P25", stats.get("p25", "-")),
+             ("P50 (median)", stats.get("median_tokens", "-")),
+             ("P75", stats.get("p75", "-")),
+             ("P90", stats.get("p90", "-")),
+             ("P95", stats.get("p95", "-")),
+             ("P99", stats.get("p99", "-")),
+             ("Max", stats["max_tokens"]),
+         ]
+         for name, val in percentiles:
+             table.add_row(name, f"{val:,}" if isinstance(val, int) else str(val))
+         console.print(table)
+
      except ImportError:
+         std = stats.get("std_tokens", 0)
          print(f"\n{'=' * 40}")
          print("📊 Token Stats")
          print(f"{'=' * 40}")
          print(f"Samples: {stats['count']:,}")
          print(f"Total tokens: {stats['total_tokens']:,}")
-         print(f"Avg tokens: {stats['avg_tokens']:.1f}")
-         print(f"Median: {stats['median_tokens']:,}")
+         print(f"Avg tokens: {stats['avg_tokens']:.1f} (std: {std:.1f})")
          print(f"Range: {stats['min_tokens']:,} - {stats['max_tokens']:,}")
+
+         print(f"\n📈 Percentile distribution:")
+         print(f"  P25: {stats.get('p25', '-'):,}  P50: {stats.get('median_tokens', '-'):,}")
+         print(f"  P75: {stats.get('p75', '-'):,}  P90: {stats.get('p90', '-'):,}")
+         print(f"  P95: {stats.get('p95', '-'):,}  P99: {stats.get('p99', '-'):,}")
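Note: the new p25/p75/p90/p95/p99 rows imply per-sample percentiles over token counts. One plausible nearest-rank implementation is sketched below; this is an assumption, since the actual computation presumably lives in dtflow/tokenizers.py, whose diff is not shown here:

from typing import Dict, List

def token_percentiles(counts: List[int]) -> Dict[str, int]:
    """Nearest-rank percentiles over per-sample token counts (assumes non-empty input)."""
    s = sorted(counts)

    def pct(p: float) -> int:
        return s[min(len(s) - 1, int(len(s) * p / 100))]

    return {
        "p25": pct(25), "median_tokens": pct(50), "p75": pct(75),
        "p90": pct(90), "p95": pct(95), "p99": pct(99),
    }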
{dtflow-0.5.0 → dtflow-0.5.2}/dtflow/core.py
@@ -793,19 +793,29 @@ class DataTransformer:
              seed: Random seed
  
          Returns:
-             (train, test) as two DataTransformers
+             (train, test) as two DataTransformers, each with its own independent lineage tracker
          """
          data = self.shuffle(seed).data
          split_idx = int(len(data) * ratio)
  
-         # Each split keeps an independent lineage tracker
+         # Each split keeps an independent lineage tracker (deep copy avoids cross-talk)
          tracker = self._lineage_tracker
+         train_tracker = None
+         test_tracker = None
+
          if tracker:
              tracker.record("split", {"ratio": ratio, "seed": seed}, len(self._data), len(data))
+             # Create an independent tracker copy for each sub-dataset
+             train_tracker = tracker.copy()
+             train_tracker.record("split_part", {"part": "train", "ratio": ratio}, len(data), split_idx)
+             test_tracker = tracker.copy()
+             test_tracker.record(
+                 "split_part", {"part": "test", "ratio": 1 - ratio}, len(data), len(data) - split_idx
+             )
  
          return (
-             DataTransformer(data[:split_idx], _lineage_tracker=tracker),
-             DataTransformer(data[split_idx:], _lineage_tracker=tracker),
+             DataTransformer(data[:split_idx], _lineage_tracker=train_tracker),
+             DataTransformer(data[split_idx:], _lineage_tracker=test_tracker),
          )
  
      # ============ Parallel processing ============
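Note: previously both halves received the same tracker object, so recording into one mutated the other's history. The `LineageTracker.copy()` added in dtflow/lineage.py (diffed further below) deep-copies `operations` to prevent exactly this aliasing; a generic Python illustration, not dtflow code:

import copy

ops = [{"op": "split"}]
alias = ops                        # old behavior: both halves shared one list
alias.append({"op": "split_part"})
assert len(ops) == 2               # the "other" half saw the mutation

fork = copy.deepcopy(ops)          # new behavior: copy() deep-copies operations
fork.append({"op": "more"})
assert len(ops) == 2               # original history untouched

# With lineage tracking enabled, each half now records its own
# split_part operation independently (parameter names from the diff):
#   train, test = dt.split(ratio=0.9, seed=42)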
@@ -815,6 +825,7 @@ class DataTransformer:
          func: Callable[[Dict], Any],
          workers: Optional[int] = None,
          chunksize: int = 1000,
+         timeout: Optional[float] = None,
      ) -> List[Any]:
          """
          Run a transform function in parallel (multiprocessing).
@@ -825,24 +836,46 @@ class DataTransformer:
              func: Transform function; receives a raw dict and returns the transformed result
              workers: Number of processes; defaults to the CPU core count
              chunksize: Size of the data chunk handled by each worker process
+             timeout: Timeout in seconds; None means no timeout
  
          Returns:
              List of transformed results
  
+         Raises:
+             TypeError: If func cannot be pickled (e.g. a lambda)
+             RuntimeError: If a worker process fails or times out
+
          Examples:
              >>> def transform(item):
              ...     return {"id": item["id"], "text": item["text"].upper()}
              >>> results = dt.map_parallel(transform)
          """
-         from multiprocessing import Pool, cpu_count
+         from multiprocessing import Pool, TimeoutError, cpu_count
+         import pickle
  
          if not self._data:
              return []
  
+         # Check that the function can be pickled
+         try:
+             pickle.dumps(func)
+         except (pickle.PicklingError, AttributeError, TypeError) as e:
+             func_name = getattr(func, "__name__", str(func))
+             raise TypeError(
+                 f"Function '{func_name}' cannot be pickled, so it cannot be used for parallel processing. "
+                 f"Use a module-level function instead of a lambda or closure. Error: {e}"
+             ) from e
+
          workers = workers or cpu_count()
  
-         with Pool(workers) as pool:
-             results = pool.map(func, self._data, chunksize=chunksize)
+         try:
+             with Pool(workers) as pool:
+                 async_result = pool.map_async(func, self._data, chunksize=chunksize)
+                 results = async_result.get(timeout=timeout)
+         except TimeoutError:
+             raise RuntimeError(f"Parallel processing timed out ({timeout}s)")
+         except Exception as e:
+             raise RuntimeError(f"Parallel processing failed: {type(e).__name__}: {e}") from e
  
          return results
  
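Note: the pickling pre-check turns an obscure in-pool multiprocessing failure into an immediate TypeError, and `map_async(...).get(timeout=...)` is what makes the new `timeout` possible (a blocking `Pool.map` has no timeout). A usage sketch, given `dt`, a DataTransformer over items like {"text": ...}:

# Module-level functions are picklable and work with map_parallel:
def to_upper(item: dict) -> dict:
    return {**item, "text": item["text"].upper()}

results = dt.map_parallel(to_upper, workers=4, timeout=600)

# A lambda now fails fast, before any worker is spawned:
#   dt.map_parallel(lambda x: x)  # TypeError: cannot be pickled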
@@ -851,6 +884,7 @@ class DataTransformer:
          func: Callable[[Dict], bool],
          workers: Optional[int] = None,
          chunksize: int = 1000,
+         timeout: Optional[float] = None,
      ) -> "DataTransformer":
          """
          Run a filter function in parallel (multiprocessing).
@@ -861,24 +895,46 @@ class DataTransformer:
              func: Filter function; receives a raw dict and returns True to keep the item
              workers: Number of processes; defaults to the CPU core count
              chunksize: Size of the data chunk handled by each worker process
+             timeout: Timeout in seconds; None means no timeout
  
          Returns:
              A new, filtered DataTransformer
  
+         Raises:
+             TypeError: If func cannot be pickled (e.g. a lambda)
+             RuntimeError: If a worker process fails or times out
+
          Examples:
              >>> def is_valid(item):
              ...     return len(item["text"]) > 10
              >>> filtered = dt.filter_parallel(is_valid)
          """
-         from multiprocessing import Pool, cpu_count
+         from multiprocessing import Pool, TimeoutError, cpu_count
+         import pickle
  
          if not self._data:
              return DataTransformer([])
  
+         # Check that the function can be pickled
+         try:
+             pickle.dumps(func)
+         except (pickle.PicklingError, AttributeError, TypeError) as e:
+             func_name = getattr(func, "__name__", str(func))
+             raise TypeError(
+                 f"Function '{func_name}' cannot be pickled, so it cannot be used for parallel processing. "
+                 f"Use a module-level function instead of a lambda or closure. Error: {e}"
+             ) from e
+
          workers = workers or cpu_count()
  
-         with Pool(workers) as pool:
-             mask = pool.map(func, self._data, chunksize=chunksize)
+         try:
+             with Pool(workers) as pool:
+                 async_result = pool.map_async(func, self._data, chunksize=chunksize)
+                 mask = async_result.get(timeout=timeout)
+         except TimeoutError:
+             raise RuntimeError(f"Parallel processing timed out ({timeout}s)")
+         except Exception as e:
+             raise RuntimeError(f"Parallel processing failed: {type(e).__name__}: {e}") from e
  
          filtered = [item for item, keep in zip(self._data, mask) if keep]
          return DataTransformer(filtered)
{dtflow-0.5.0 → dtflow-0.5.2}/dtflow/lineage.py
@@ -237,6 +237,23 @@ class LineageTracker:
  
          return lineage_path
  
+     def copy(self) -> "LineageTracker":
+         """
+         Create a deep copy of this tracker.
+
+         Used by split() and similar operations so each sub-dataset
+         gets independent lineage tracking.
+
+         Returns:
+             A new LineageTracker instance
+         """
+         import copy as copy_module
+
+         new_tracker = LineageTracker.__new__(LineageTracker)
+         new_tracker.source_path = self.source_path
+         new_tracker.source_lineage = self.source_lineage  # LineageRecord is immutable and can be shared
+         new_tracker.operations = copy_module.deepcopy(self.operations)
+         return new_tracker
+
  
  def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
      """
{dtflow-0.5.0 → dtflow-0.5.2}/dtflow/streaming.py
@@ -365,50 +365,108 @@ class StreamingTransformer:
          """
          Batched streaming save (CSV/Parquet/Arrow).
  
-         Reading and processing are streamed; writing collects everything and writes once.
+         True streaming write: process in batches, releasing memory after each batch is written.
+         Memory usage is O(batch_size) rather than O(n).
          """
          path = Path(filepath)
-         all_items = []
+         count = 0
+         batch = []
+         first_batch = True
  
-         if show_progress:
-             # Pick the progress bar style based on whether a total is known
-             if self._total is not None:
-                 columns = [
-                     SpinnerColumn(),
-                     TextColumn("[progress.description]{task.description}"),
-                     BarColumn(),
-                     TaskProgressColumn(),
-                     MofNCompleteColumn(),
-                     TimeElapsedColumn(),
-                     TimeRemainingColumn(),
-                 ]
-             else:
-                 columns = [
-                     SpinnerColumn(),
-                     TextColumn("[progress.description]{task.description}"),
-                     MofNCompleteColumn(),
-                     TimeElapsedColumn(),
-                 ]
+         # Progress bar configuration
+         progress_columns = self._get_progress_columns()
  
-             with Progress(*columns) as progress:
-                 task = progress.add_task("Processing", total=self._total)
-                 for item in self._iterator:
-                     all_items.append(item)
-                     progress.update(task, advance=1)
-         else:
-             for item in self._iterator:
-                 all_items.append(item)
+         def write_batch(items: List[Dict], is_first: bool, writer_state: Dict):
+             """Write one batch of items."""
+             if not items:
+                 return
+
+             df = pl.DataFrame(items)
  
-         if all_items:
-             df = pl.DataFrame(all_items)
              if fmt == "csv":
-                 df.write_csv(path)
+                 if is_first:
+                     df.write_csv(path)
+                 else:
+                     # CSV append mode: skip the header
+                     with open(path, "ab") as f:
+                         f.write(df.write_csv(include_header=False).encode("utf-8"))
+
              elif fmt == "parquet":
-                 df.write_parquet(path)
+                 import pyarrow as pa
+                 import pyarrow.parquet as pq
+
+                 table = df.to_arrow()
+                 if is_first:
+                     writer_state["writer"] = pq.ParquetWriter(str(path), table.schema)
+                 writer_state["writer"].write_table(table)
+
              elif fmt == "arrow":
-                 df.write_ipc(path)
+                 import pyarrow as pa
+
+                 table = df.to_arrow()
+                 if is_first:
+                     writer_state["writer"] = pa.ipc.new_file(str(path), table.schema)
+                 for record_batch in table.to_batches():
+                     writer_state["writer"].write_batch(record_batch)
+
+         writer_state: Dict[str, Any] = {}
+
+         try:
+             if show_progress:
+                 with Progress(*progress_columns) as progress:
+                     task = progress.add_task("Processing", total=self._total)
+                     for item in self._iterator:
+                         batch.append(item)
+                         count += 1
+                         progress.update(task, advance=1)
+
+                         if len(batch) >= batch_size:
+                             write_batch(batch, first_batch, writer_state)
+                             first_batch = False
+                             batch = []  # release memory
+
+                 # Write the final batch
+                 if batch:
+                     write_batch(batch, first_batch, writer_state)
+             else:
+                 for item in self._iterator:
+                     batch.append(item)
+                     count += 1
+
+                     if len(batch) >= batch_size:
+                         write_batch(batch, first_batch, writer_state)
+                         first_batch = False
+                         batch = []
  
-         return len(all_items)
+                 if batch:
+                     write_batch(batch, first_batch, writer_state)
+
+         finally:
+             # Close the writer if one was opened
+             if "writer" in writer_state:
+                 writer_state["writer"].close()
+
+         return count
+
+     def _get_progress_columns(self):
+         """Return the progress bar column configuration."""
+         if self._total is not None:
+             return [
+                 SpinnerColumn(),
+                 TextColumn("[progress.description]{task.description}"),
+                 BarColumn(),
+                 TaskProgressColumn(),
+                 MofNCompleteColumn(),
+                 TimeElapsedColumn(),
+                 TimeRemainingColumn(),
+             ]
+         else:
+             return [
+                 SpinnerColumn(),
+                 TextColumn("[progress.description]{task.description}"),
+                 MofNCompleteColumn(),
+                 TimeElapsedColumn(),
+             ]
  
      def save_sharded(
          self,
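Note: the Parquet branch leans on pyarrow's ParquetWriter, which appends one row group per batch to a single open file, so memory stays bounded by the batch size. A self-contained sketch of the same technique outside dtflow (file name hypothetical):

import pyarrow as pa
import pyarrow.parquet as pq

def write_batches(path: str, batches) -> int:
    """Stream batches of dicts into one Parquet file, one row group each."""
    writer = None
    count = 0
    try:
        for batch in batches:
            table = pa.Table.from_pylist(batch)
            if writer is None:
                writer = pq.ParquetWriter(path, table.schema)  # schema from first batch
            writer.write_table(table)
            count += len(batch)
    finally:
        if writer is not None:
            writer.close()
    return count

# write_batches("out.parquet", [[{"a": 1}, {"a": 2}], [{"a": 3}]]) -> 3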