dtflow 0.5.0__tar.gz → 0.5.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. dtflow-0.5.3/CHANGELOG.md +19 -0
  2. {dtflow-0.5.0 → dtflow-0.5.3}/PKG-INFO +11 -1
  3. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/__init__.py +7 -7
  4. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/cli/common.py +13 -9
  5. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/cli/stats.py +114 -36
  6. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/converters.py +17 -13
  7. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/core.py +66 -10
  8. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/lineage.py +17 -0
  9. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/presets.py +14 -15
  10. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/streaming.py +93 -35
  11. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/tokenizers.py +84 -29
  12. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/utils/__init__.py +3 -0
  13. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/utils/field_path.py +6 -2
  14. dtflow-0.5.3/dtflow/utils/helpers.py +30 -0
  15. {dtflow-0.5.0 → dtflow-0.5.3}/pyproject.toml +17 -4
  16. dtflow-0.5.3/tests/README.md +88 -0
  17. dtflow-0.5.3/tests/benchmark_sharegpt.py +392 -0
  18. dtflow-0.5.3/tests/test_cli_benchmark.py +565 -0
  19. dtflow-0.5.3/tests/test_cli_clean.py +314 -0
  20. dtflow-0.5.3/tests/test_cli_sample.py +242 -0
  21. dtflow-0.5.3/tests/test_cli_stats.py +213 -0
  22. dtflow-0.5.3/tests/test_cli_transform.py +304 -0
  23. {dtflow-0.5.0 → dtflow-0.5.3}/tests/test_streaming.py +80 -0
  24. {dtflow-0.5.0 → dtflow-0.5.3}/tests/test_transformer.py +77 -0
  25. {dtflow-0.5.0 → dtflow-0.5.3}/.gitignore +0 -0
  26. {dtflow-0.5.0 → dtflow-0.5.3}/README.md +0 -0
  27. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/__main__.py +0 -0
  28. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/cli/__init__.py +0 -0
  29. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/cli/clean.py +0 -0
  30. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/cli/commands.py +0 -0
  31. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/cli/io_ops.py +0 -0
  32. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/cli/lineage.py +0 -0
  33. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/cli/pipeline.py +0 -0
  34. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/cli/sample.py +0 -0
  35. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/cli/transform.py +0 -0
  36. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/cli/validate.py +0 -0
  37. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/framework.py +0 -0
  38. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/mcp/__init__.py +0 -0
  39. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/mcp/__main__.py +0 -0
  40. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/mcp/cli.py +0 -0
  41. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/mcp/docs.py +0 -0
  42. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/mcp/server.py +0 -0
  43. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/pipeline.py +0 -0
  44. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/schema.py +0 -0
  45. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/storage/__init__.py +0 -0
  46. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/storage/io.py +0 -0
  47. {dtflow-0.5.0 → dtflow-0.5.3}/dtflow/utils/display.py +0 -0
  48. {dtflow-0.5.0 → dtflow-0.5.3}/tests/benchmark_io.py +0 -0
  49. {dtflow-0.5.0 → dtflow-0.5.3}/tests/test_converters.py +0 -0
  50. {dtflow-0.5.0 → dtflow-0.5.3}/tests/test_field_path.py +0 -0
  51. {dtflow-0.5.0 → dtflow-0.5.3}/tests/test_framework.py +0 -0
  52. {dtflow-0.5.0 → dtflow-0.5.3}/tests/test_io.py +0 -0
  53. {dtflow-0.5.0 → dtflow-0.5.3}/tests/test_lineage.py +0 -0
  54. {dtflow-0.5.0 → dtflow-0.5.3}/tests/test_pipeline.py +0 -0
  55. {dtflow-0.5.0 → dtflow-0.5.3}/tests/test_schema.py +0 -0
  56. {dtflow-0.5.0 → dtflow-0.5.3}/tests/test_tokenizers.py +0 -0
--- /dev/null
+++ dtflow-0.5.3/CHANGELOG.md
@@ -0,0 +1,19 @@
+ # Changelog
+
+ ## [0.5.2] - 2026-01-18
+
+ ### Miscellaneous
+
+ - Bump version to 0.5.2
+ - Add pre-commit configuration and a release script
+
+ ## [0.5.1] - 2026-01-18
+
+ ### Features
+
+ - Improve the sample command's text preview display
+
+ ### Testing
+
+ - Add instructions for running the tests
+ - Add performance tests for tail/token-stats/validate
--- dtflow-0.5.0/PKG-INFO
+++ dtflow-0.5.3/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dtflow
- Version: 0.5.0
+ Version: 0.5.3
  Summary: A flexible data transformation tool for ML training formats (SFT, RLHF, Pretrain)
  Project-URL: Homepage, https://github.com/yourusername/DataTransformer
  Project-URL: Documentation, https://github.com/yourusername/DataTransformer#readme
@@ -32,16 +32,26 @@ Requires-Dist: orjson>=3.9.0
  Requires-Dist: polars>=0.20.0
  Requires-Dist: pyyaml>=5.4.0
  Requires-Dist: rich>=10.0.0
+ Requires-Dist: tiktoken>=0.5.0
  Requires-Dist: typer>=0.9.0
  Provides-Extra: converters
  Requires-Dist: datasets>=2.0.0; extra == 'converters'
  Provides-Extra: dev
  Requires-Dist: black>=21.0; extra == 'dev'
+ Requires-Dist: datasets>=2.0.0; extra == 'dev'
+ Requires-Dist: datasketch>=1.5.0; extra == 'dev'
  Requires-Dist: flake8>=3.9.0; extra == 'dev'
+ Requires-Dist: huggingface-hub>=0.20.0; extra == 'dev'
  Requires-Dist: isort>=5.9.0; extra == 'dev'
  Requires-Dist: mypy>=0.910; extra == 'dev'
+ Requires-Dist: pyarrow; extra == 'dev'
  Requires-Dist: pytest-cov>=2.12.0; extra == 'dev'
  Requires-Dist: pytest>=6.0.0; extra == 'dev'
+ Requires-Dist: rich>=10.0.0; extra == 'dev'
+ Requires-Dist: scikit-learn>=0.24.0; extra == 'dev'
+ Requires-Dist: tiktoken>=0.5.0; extra == 'dev'
+ Requires-Dist: tokenizers>=0.15.0; extra == 'dev'
+ Requires-Dist: toolong>=1.5.0; extra == 'dev'
  Provides-Extra: display
  Provides-Extra: docs
  Requires-Dist: myst-parser>=0.15.0; extra == 'docs'
--- dtflow-0.5.0/dtflow/__init__.py
+++ dtflow-0.5.3/dtflow/__init__.py
@@ -26,6 +26,12 @@ from .converters import (  # LLaMA-Factory extensions; ms-swift
      to_swift_vlm,
  )
  from .core import DataTransformer, DictWrapper, TransformError, TransformErrors
+ from .framework import (
+     CompatibilityResult,
+     check_compatibility,
+     detect_format,
+     export_for,
+ )
  from .presets import get_preset, list_presets
  from .schema import (
      Field,
@@ -38,12 +44,6 @@ from .schema import (
      sharegpt_schema,
      validate_data,
  )
- from .framework import (
-     CompatibilityResult,
-     check_compatibility,
-     detect_format,
-     export_for,
- )
  from .storage import load_data, sample_file, save_data
  from .streaming import StreamingTransformer, load_sharded, load_stream, process_shards
  from .tokenizers import (
@@ -60,7 +60,7 @@ from .tokenizers import (
      token_stats,
  )
 
- __version__ = "0.5.0"
+ __version__ = "0.5.3"
 
  __all__ = [
      # core
--- dtflow-0.5.0/dtflow/cli/common.py
+++ dtflow-0.5.3/dtflow/cli/common.py
@@ -57,7 +57,7 @@ def _get_file_row_count(filepath: Path) -> Optional[int]:
      return None
 
 
- def _format_value(value: Any, max_len: int = 80) -> str:
+ def _format_value(value: Any, max_len: int = 120) -> str:
      """Format a single value, truncating long text."""
      if value is None:
          return "[dim]null[/dim]"
@@ -66,18 +66,22 @@ def _format_value(value: Any, max_len: int = 80) -> str:
      if isinstance(value, (int, float)):
          return f"[cyan]{value}[/cyan]"
      if isinstance(value, str):
+         half_len = max_len // 2
          # Handle multi-line text
          if "\n" in value:
              lines = value.split("\n")
-             if len(lines) > 3:
-                 preview = lines[0][:max_len] + f"... [dim]({len(lines)} lines)[/dim]"
-             else:
-                 preview = value.replace("\n", "\\n")
-                 if len(preview) > max_len:
-                     preview = preview[:max_len] + "..."
+             preview = value.replace("\n", "\\n")
+             if len(preview) > max_len:
+                 # first half + elision marker + second half
+                 head = preview[:half_len]
+                 tail = preview[-half_len:]
+                 return f'"{head} [yellow]<<<{len(lines)} lines>>>[/yellow] {tail}"'
              return f'"{preview}"'
          if len(value) > max_len:
-             return f'"{value[:max_len]}..." [dim]({len(value)} chars)[/dim]'
+             # first half + elision marker + second half
+             head = value[:half_len]
+             tail = value[-half_len:]
+             return f'"{head} [yellow]<<<{len(value)} chars>>>[/yellow] {tail}"'
          return f'"{value}"'
      return str(value)
 
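The rewritten `_format_value` keeps both ends of an over-long string and embeds the elided length in the middle, so previews no longer lose the tail of the text. A minimal plain-text sketch of the same head-and-tail logic (rich markup omitted; the marker text is illustrative):

```python
# Plain-text sketch of the head-and-tail truncation introduced above.
def truncate(value: str, max_len: int = 120) -> str:
    half_len = max_len // 2
    if len(value) <= max_len:
        return f'"{value}"'
    head, tail = value[:half_len], value[-half_len:]
    return f'"{head} <<<{len(value)} chars>>> {tail}"'

print(truncate("x" * 300, max_len=20))
# "xxxxxxxxxx <<<300 chars>>> xxxxxxxxxx"
```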
@@ -86,7 +90,7 @@ def _format_nested(
      value: Any,
      indent: str = "",
      is_last: bool = True,
-     max_len: int = 80,
+     max_len: int = 120,
  ) -> List[str]:
      """
      Recursively format a nested structure, returning a list of lines.
--- dtflow-0.5.0/dtflow/cli/stats.py
+++ dtflow-0.5.3/dtflow/cli/stats.py
@@ -465,34 +465,65 @@ def token_stats(
          return
 
      total = len(data)
-     print(f" Total: {total} records")
-     print(f"🔢 Counting tokens (model: {model}, field: {field})...")
+     print(f" Total: {total:,} records")
 
      # Inspect the field type and choose the right stats method (nested paths supported)
      sample = data[0]
      field_value = get_field_with_spec(sample, field)
 
+     # Try to use a rich progress bar
      try:
-         if isinstance(field_value, list) and field_value and isinstance(field_value[0], dict):
-             # messages format
-             from ..tokenizers import messages_token_stats
-
-             stats_result = messages_token_stats(data, messages_field=field, model=model)
-             _print_messages_token_stats(stats_result, detailed)
-         else:
-             # plain text field
-             from ..tokenizers import token_stats as compute_token_stats
-
-             stats_result = compute_token_stats(data, fields=field, model=model)
-             _print_text_token_stats(stats_result, detailed)
-     except ImportError as e:
-         print(f"Error: {e}")
-         return
-     except Exception as e:
-         print(f"Error: token stats failed - {e}")
-         import traceback
+         from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn
+
+         with Progress(
+             SpinnerColumn(),
+             TextColumn("[bold blue]Counting tokens"),
+             BarColumn(),
+             TaskProgressColumn(),
+             TextColumn(f"(model: {model})"),
+         ) as progress:
+             task = progress.add_task("", total=total)
+
+             def update_progress(current: int, total_count: int):
+                 progress.update(task, completed=current)
+
+             if isinstance(field_value, list) and field_value and isinstance(field_value[0], dict):
+                 from ..tokenizers import messages_token_stats
+
+                 stats_result = messages_token_stats(
+                     data, messages_field=field, model=model, progress_callback=update_progress
+                 )
+                 _print_messages_token_stats(stats_result, detailed)
+             else:
+                 from ..tokenizers import token_stats as compute_token_stats
+
+                 stats_result = compute_token_stats(
+                     data, fields=field, model=model, progress_callback=update_progress
+                 )
+                 _print_text_token_stats(stats_result, detailed)
+
+     except ImportError:
+         # rich is unavailable; fall back to plain output
+         print(f"🔢 Counting tokens (model: {model}, field: {field})...")
+         try:
+             if isinstance(field_value, list) and field_value and isinstance(field_value[0], dict):
+                 from ..tokenizers import messages_token_stats
+
+                 stats_result = messages_token_stats(data, messages_field=field, model=model)
+                 _print_messages_token_stats(stats_result, detailed)
+             else:
+                 from ..tokenizers import token_stats as compute_token_stats
 
-         traceback.print_exc()
+                 stats_result = compute_token_stats(data, fields=field, model=model)
+                 _print_text_token_stats(stats_result, detailed)
+         except ImportError as e:
+             print(f"Error: {e}")
+             return
+         except Exception as e:
+             print(f"Error: token stats failed - {e}")
+             import traceback
+
+             traceback.print_exc()
 
 
  def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
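Both tokenizer helpers now accept an optional `progress_callback`, which the CLI wires into the rich progress bar above. A usage sketch under stated assumptions (the record shape and the `gpt-4o` model name are illustrative; the `(processed, total)` callback contract is taken from the diff):

```python
# Sketch: the helpers periodically invoke progress_callback(current, total).
from dtflow.tokenizers import token_stats

def on_progress(current: int, total: int) -> None:
    print(f"\rtokenizing {current}/{total}", end="")

data = [{"text": "hello world"}, {"text": "a longer example sentence"}]
stats = token_stats(data, fields="text", model="gpt-4o", progress_callback=on_progress)
print("\ntotal tokens:", stats["total_tokens"])
```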
@@ -505,21 +536,39 @@ def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
          console = Console()
 
          # Overview
+         std = stats.get("std_tokens", 0)
          overview = (
              f"[bold]Samples:[/bold] {stats['count']:,}\n"
              f"[bold]Total tokens:[/bold] {stats['total_tokens']:,}\n"
-             f"[bold]Avg tokens:[/bold] {stats['avg_tokens']:,}\n"
-             f"[bold]Median:[/bold] {stats['median_tokens']:,}\n"
+             f"[bold]Avg tokens:[/bold] {stats['avg_tokens']:,} (std: {std:.1f})\n"
              f"[bold]Range:[/bold] {stats['min_tokens']:,} - {stats['max_tokens']:,}"
          )
          console.print(Panel(overview, title="📊 Token Stats Overview", expand=False))
 
+         # Percentile table
+         table = Table(title="📈 Distribution")
+         table.add_column("Percentile", style="cyan", justify="center")
+         table.add_column("Tokens", justify="right")
+         percentiles = [
+             ("Min", stats["min_tokens"]),
+             ("P25", stats.get("p25", "-")),
+             ("P50 (median)", stats.get("median_tokens", "-")),
+             ("P75", stats.get("p75", "-")),
+             ("P90", stats.get("p90", "-")),
+             ("P95", stats.get("p95", "-")),
+             ("P99", stats.get("p99", "-")),
+             ("Max", stats["max_tokens"]),
+         ]
+         for name, val in percentiles:
+             table.add_row(name, f"{val:,}" if isinstance(val, int) else str(val))
+         console.print(table)
+
          if detailed:
-             # Detailed stats
-             table = Table(title="📋 Per-Role Stats")
-             table.add_column("Role", style="cyan")
-             table.add_column("Tokens", justify="right")
-             table.add_column("Share", justify="right")
+             # Per-role stats
+             role_table = Table(title="📋 Per-Role Stats")
+             role_table.add_column("Role", style="cyan")
+             role_table.add_column("Tokens", justify="right")
+             role_table.add_column("Share", justify="right")
 
              total = stats["total_tokens"]
              for role, key in [
@@ -529,22 +578,27 @@ def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
              ]:
                  tokens = stats.get(key, 0)
                  pct = tokens / total * 100 if total > 0 else 0
-                 table.add_row(role, f"{tokens:,}", f"{pct:.1f}%")
+                 role_table.add_row(role, f"{tokens:,}", f"{pct:.1f}%")
 
-             console.print(table)
+             console.print(role_table)
              console.print(f"\nAverage conversation turns: {stats.get('avg_turns', 0)}")
 
      except ImportError:
          # rich not available; plain printing
+         std = stats.get("std_tokens", 0)
          print(f"\n{'=' * 40}")
          print("📊 Token Stats Overview")
          print(f"{'=' * 40}")
          print(f"Samples: {stats['count']:,}")
          print(f"Total tokens: {stats['total_tokens']:,}")
-         print(f"Avg tokens: {stats['avg_tokens']:,}")
-         print(f"Median: {stats['median_tokens']:,}")
+         print(f"Avg tokens: {stats['avg_tokens']:,} (std: {std:.1f})")
          print(f"Range: {stats['min_tokens']:,} - {stats['max_tokens']:,}")
 
+         print(f"\n📈 Percentile distribution:")
+         print(f" P25: {stats.get('p25', '-'):,} P50: {stats.get('median_tokens', '-'):,}")
+         print(f" P75: {stats.get('p75', '-'):,} P90: {stats.get('p90', '-'):,}")
+         print(f" P95: {stats.get('p95', '-'):,} P99: {stats.get('p99', '-'):,}")
+
          if detailed:
              print(f"\n{'=' * 40}")
              print("📋 Per-Role Stats")
@@ -566,24 +620,48 @@ def _print_text_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
      try:
          from rich.console import Console
          from rich.panel import Panel
+         from rich.table import Table
 
          console = Console()
 
+         std = stats.get("std_tokens", 0)
          overview = (
              f"[bold]Samples:[/bold] {stats['count']:,}\n"
              f"[bold]Total tokens:[/bold] {stats['total_tokens']:,}\n"
-             f"[bold]Avg tokens:[/bold] {stats['avg_tokens']:.1f}\n"
-             f"[bold]Median:[/bold] {stats['median_tokens']:,}\n"
+             f"[bold]Avg tokens:[/bold] {stats['avg_tokens']:.1f} (std: {std:.1f})\n"
              f"[bold]Range:[/bold] {stats['min_tokens']:,} - {stats['max_tokens']:,}"
          )
          console.print(Panel(overview, title="📊 Token Stats", expand=False))
 
+         # Percentile table
+         table = Table(title="📈 Distribution")
+         table.add_column("Percentile", style="cyan", justify="center")
+         table.add_column("Tokens", justify="right")
+         percentiles = [
+             ("Min", stats["min_tokens"]),
+             ("P25", stats.get("p25", "-")),
+             ("P50 (median)", stats.get("median_tokens", "-")),
+             ("P75", stats.get("p75", "-")),
+             ("P90", stats.get("p90", "-")),
+             ("P95", stats.get("p95", "-")),
+             ("P99", stats.get("p99", "-")),
+             ("Max", stats["max_tokens"]),
+         ]
+         for name, val in percentiles:
+             table.add_row(name, f"{val:,}" if isinstance(val, int) else str(val))
+         console.print(table)
+
      except ImportError:
+         std = stats.get("std_tokens", 0)
          print(f"\n{'=' * 40}")
          print("📊 Token Stats")
          print(f"{'=' * 40}")
          print(f"Samples: {stats['count']:,}")
          print(f"Total tokens: {stats['total_tokens']:,}")
-         print(f"Avg tokens: {stats['avg_tokens']:.1f}")
-         print(f"Median: {stats['median_tokens']:,}")
+         print(f"Avg tokens: {stats['avg_tokens']:.1f} (std: {std:.1f})")
          print(f"Range: {stats['min_tokens']:,} - {stats['max_tokens']:,}")
+
+         print(f"\n📈 Percentile distribution:")
+         print(f" P25: {stats.get('p25', '-'):,} P50: {stats.get('median_tokens', '-'):,}")
+         print(f" P75: {stats.get('p75', '-'):,} P90: {stats.get('p90', '-'):,}")
+         print(f" P95: {stats.get('p95', '-'):,} P99: {stats.get('p99', '-'):,}")
--- dtflow-0.5.0/dtflow/converters.py
+++ dtflow-0.5.3/dtflow/converters.py
@@ -4,7 +4,7 @@
  Provides conversion to and from common formats such as HuggingFace datasets.
  """
 
- from typing import Any, Callable, Dict, List, Optional, Union
+ from typing import Any, Callable, Dict, List, Optional
 
 
  def to_hf_dataset(data: List[Dict[str, Any]]):
@@ -143,14 +143,16 @@ def to_openai_batch(
      >>> batch_input = dt.to(to_openai_batch(model="gpt-4o"))
      """
 
-     def transform(item, idx=[0]) -> dict:
+     counter = {"idx": 0}
+
+     def transform(item) -> dict:
          messages = item.get(messages_field, []) if hasattr(item, "get") else item[messages_field]
 
          if custom_id_field:
              custom_id = item.get(custom_id_field) if hasattr(item, "get") else item[custom_id_field]
          else:
-             custom_id = f"request-{idx[0]}"
-             idx[0] += 1
+             custom_id = f"request-{counter['idx']}"
+             counter["idx"] += 1
 
          return {
              "custom_id": str(custom_id),
@@ -196,7 +198,7 @@ def to_llama_factory(
      """
 
      def transform(item) -> dict:
-         get = lambda f: (item.get(f, "") if hasattr(item, "get") else item.get(f, ""))
+         get = lambda f: item.get(f, "") if hasattr(item, "get") else getattr(item, f, "")
 
          result = {
              "instruction": get(instruction_field),
@@ -248,7 +250,7 @@ def to_axolotl(
          conversations = (
              item.get(conversations_field, [])
              if hasattr(item, "get")
-             else item.get(conversations_field, [])
+             else getattr(item, conversations_field, [])
          )
 
          # Already in the target format; return as-is
@@ -257,7 +259,9 @@
              return {"conversations": conversations}
 
          # Try converting from the messages format
-         messages = item.get("messages", []) if hasattr(item, "get") else item.get("messages", [])
+         messages = (
+             item.get("messages", []) if hasattr(item, "get") else getattr(item, "messages", [])
+         )
          if messages:
              role_map = {"user": "human", "assistant": "gpt", "system": "system"}
              conversations = [
@@ -312,7 +316,7 @@ def to_llama_factory_sharegpt(
      }
 
      def transform(item) -> dict:
-         get = lambda f: (item.get(f, "") if hasattr(item, "get") else item.get(f, ""))
+         get = lambda f: item.get(f, "") if hasattr(item, "get") else getattr(item, f, "")
          messages = get(messages_field) or []
 
          conversations = []
@@ -385,7 +389,7 @@ def to_llama_factory_vlm(
      """
 
      def transform(item) -> dict:
-         get = lambda f: item.get(f) if hasattr(item, "get") else item.get(f)
+         get = lambda f: item.get(f) if hasattr(item, "get") else getattr(item, f, None)
          messages = get(messages_field) or []
 
          instruction = ""
@@ -467,7 +471,7 @@ def to_llama_factory_vlm_sharegpt(
      role_map = {"user": "human", "assistant": "gpt", "system": "system"}
 
      def transform(item) -> dict:
-         get = lambda f: item.get(f) if hasattr(item, "get") else item.get(f)
+         get = lambda f: item.get(f) if hasattr(item, "get") else getattr(item, f, None)
          messages = get(messages_field) or []
 
          conversations = []
@@ -541,7 +545,7 @@ def to_swift_messages(
      """
 
      def transform(item) -> dict:
-         get = lambda f: item.get(f) if hasattr(item, "get") else item.get(f)
+         get = lambda f: item.get(f) if hasattr(item, "get") else getattr(item, f, None)
          messages = get(messages_field) or []
 
          # Copy messages to avoid mutating the original data
@@ -600,7 +604,7 @@ def to_swift_query_response(
      """
 
      def transform(item) -> dict:
-         get = lambda f: item.get(f) if hasattr(item, "get") else item.get(f)
+         get = lambda f: item.get(f) if hasattr(item, "get") else getattr(item, f, None)
 
          query = get(query_field)
          response = get(response_field)
@@ -693,7 +697,7 @@ def to_swift_vlm(
      """
 
      def transform(item) -> dict:
-         get = lambda f: item.get(f) if hasattr(item, "get") else item.get(f)
+         get = lambda f: item.get(f) if hasattr(item, "get") else getattr(item, f, None)
          messages = get(messages_field) or []
 
          result_messages = []
--- dtflow-0.5.0/dtflow/core.py
+++ dtflow-0.5.3/dtflow/core.py
@@ -793,19 +793,29 @@ class DataTransformer:
              seed: random seed
 
          Returns:
-             (train, test) as two DataTransformers
+             (train, test) as two DataTransformers, each with its own independent lineage tracker
          """
          data = self.shuffle(seed).data
          split_idx = int(len(data) * ratio)
 
-         # The lineage trackers are independent after a split
+         # The lineage trackers are independent after a split (deep copies avoid cross-talk)
          tracker = self._lineage_tracker
+         train_tracker = None
+         test_tracker = None
+
          if tracker:
              tracker.record("split", {"ratio": ratio, "seed": seed}, len(self._data), len(data))
+             # Give each sub-dataset its own copy of the tracker
+             train_tracker = tracker.copy()
+             train_tracker.record("split_part", {"part": "train", "ratio": ratio}, len(data), split_idx)
+             test_tracker = tracker.copy()
+             test_tracker.record(
+                 "split_part", {"part": "test", "ratio": 1 - ratio}, len(data), len(data) - split_idx
+             )
 
          return (
-             DataTransformer(data[:split_idx], _lineage_tracker=tracker),
-             DataTransformer(data[split_idx:], _lineage_tracker=tracker),
+             DataTransformer(data[:split_idx], _lineage_tracker=train_tracker),
+             DataTransformer(data[split_idx:], _lineage_tracker=test_tracker),
          )
 
      # ============ Parallel processing ============
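Before this change both halves of a `split()` shared one `LineageTracker`, so an operation recorded on the train side silently appeared in the test side's history; now each half carries a deep copy ending in its own `split_part` record. A usage sketch, assuming the `ratio`/`seed` parameters shown in the docstring above:

```python
# Sketch: each half of a split now carries an independent lineage history.
from dtflow import DataTransformer

dt = DataTransformer([{"id": i} for i in range(10)])
train, test = dt.split(ratio=0.8, seed=42)
print(len(train.data), len(test.data))  # 8 2
# With lineage enabled, train's history ends in a
# ("split_part", {"part": "train", ...}) record and test's in a "test" one.
```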
@@ -815,6 +825,7 @@ class DataTransformer:
          func: Callable[[Dict], Any],
          workers: Optional[int] = None,
          chunksize: int = 1000,
+         timeout: Optional[float] = None,
      ) -> List[Any]:
          """
          Run the transform function in parallel (multiprocessing).
@@ -825,24 +836,46 @@ class DataTransformer:
              func: transform function; receives the raw dict, returns the transformed result
              workers: number of processes, defaults to the CPU core count
              chunksize: size of the data chunk handled by each process
+             timeout: timeout in seconds; None disables the timeout
 
          Returns:
              List of transformed results
 
+         Raises:
+             TypeError: if func cannot be pickled (e.g. a lambda)
+             RuntimeError: if a worker process fails or the call times out
+
          Examples:
              >>> def transform(item):
              ...     return {"id": item["id"], "text": item["text"].upper()}
              >>> results = dt.map_parallel(transform)
          """
-         from multiprocessing import Pool, cpu_count
+         from multiprocessing import Pool, TimeoutError, cpu_count
+         import pickle
 
          if not self._data:
              return []
 
+         # Check that the function can be pickled
+         try:
+             pickle.dumps(func)
+         except (pickle.PicklingError, AttributeError, TypeError) as e:
+             func_name = getattr(func, "__name__", str(func))
+             raise TypeError(
+                 f"Function '{func_name}' cannot be pickled, so it cannot run in parallel. "
+                 f"Use a module-level function instead of a lambda or closure. Error: {e}"
+             ) from e
+
          workers = workers or cpu_count()
 
-         with Pool(workers) as pool:
-             results = pool.map(func, self._data, chunksize=chunksize)
+         try:
+             with Pool(workers) as pool:
+                 async_result = pool.map_async(func, self._data, chunksize=chunksize)
+                 results = async_result.get(timeout=timeout)
+         except TimeoutError:
+             raise RuntimeError(f"Parallel processing timed out ({timeout}s)")
+         except Exception as e:
+             raise RuntimeError(f"Parallel processing failed: {type(e).__name__}: {e}") from e
 
          return results
 
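With the pre-flight `pickle.dumps` check, an unpicklable callable now fails fast in the parent process with a `TypeError` instead of dying inside a worker, and `map_async(...).get(timeout=...)` turns hangs into a `RuntimeError`. A usage sketch (the `__main__` guard matters on spawn-based platforms):

```python
# Sketch: module-level functions are picklable; lambdas are rejected up front.
from dtflow import DataTransformer

def upper_text(item):  # module-level, therefore picklable
    return {"text": item["text"].upper()}

if __name__ == "__main__":
    dt = DataTransformer([{"text": "a"}, {"text": "b"}])
    print(dt.map_parallel(upper_text, workers=2, timeout=60))

    try:
        dt.map_parallel(lambda item: item)  # raises before any worker starts
    except TypeError as e:
        print(e)
```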
@@ -851,6 +884,7 @@ class DataTransformer:
          func: Callable[[Dict], bool],
          workers: Optional[int] = None,
          chunksize: int = 1000,
+         timeout: Optional[float] = None,
      ) -> "DataTransformer":
          """
          Run the filter function in parallel (multiprocessing).
@@ -861,24 +895,46 @@ class DataTransformer:
              func: filter function; receives the raw dict, returns True to keep the item
              workers: number of processes, defaults to the CPU core count
              chunksize: size of the data chunk handled by each process
+             timeout: timeout in seconds; None disables the timeout
 
          Returns:
              A new DataTransformer with the filtered data
 
+         Raises:
+             TypeError: if func cannot be pickled (e.g. a lambda)
+             RuntimeError: if a worker process fails or the call times out
+
          Examples:
              >>> def is_valid(item):
              ...     return len(item["text"]) > 10
              >>> filtered = dt.filter_parallel(is_valid)
          """
-         from multiprocessing import Pool, cpu_count
+         from multiprocessing import Pool, TimeoutError, cpu_count
+         import pickle
 
          if not self._data:
              return DataTransformer([])
 
+         # Check that the function can be pickled
+         try:
+             pickle.dumps(func)
+         except (pickle.PicklingError, AttributeError, TypeError) as e:
+             func_name = getattr(func, "__name__", str(func))
+             raise TypeError(
+                 f"Function '{func_name}' cannot be pickled, so it cannot run in parallel. "
+                 f"Use a module-level function instead of a lambda or closure. Error: {e}"
+             ) from e
+
          workers = workers or cpu_count()
 
-         with Pool(workers) as pool:
-             mask = pool.map(func, self._data, chunksize=chunksize)
+         try:
+             with Pool(workers) as pool:
+                 async_result = pool.map_async(func, self._data, chunksize=chunksize)
+                 mask = async_result.get(timeout=timeout)
+         except TimeoutError:
+             raise RuntimeError(f"Parallel processing timed out ({timeout}s)")
+         except Exception as e:
+             raise RuntimeError(f"Parallel processing failed: {type(e).__name__}: {e}") from e
 
          filtered = [item for item, keep in zip(self._data, mask) if keep]
          return DataTransformer(filtered)
--- dtflow-0.5.0/dtflow/lineage.py
+++ dtflow-0.5.3/dtflow/lineage.py
@@ -237,6 +237,23 @@ class LineageTracker:
 
          return lineage_path
 
+     def copy(self) -> "LineageTracker":
+         """
+         Create a deep copy of this tracker.
+
+         Used by split() and similar operations so that each sub-dataset
+         gets independent lineage tracking.
+
+         Returns:
+             A new LineageTracker instance
+         """
+         import copy as copy_module
+
+         new_tracker = LineageTracker.__new__(LineageTracker)
+         new_tracker.source_path = self.source_path
+         new_tracker.source_lineage = self.source_lineage  # LineageRecord is immutable and safe to share
+         new_tracker.operations = copy_module.deepcopy(self.operations)
+         return new_tracker
+
 
  def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
      """