dtflow-0.3.1-py3-none-any.whl → dtflow-0.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dtflow/cli/commands.py CHANGED
@@ -1,7 +1,7 @@
 """
 CLI 命令实现
 """
-import orjson
+
 import os
 import shutil
 import tempfile
@@ -9,13 +9,15 @@ from datetime import datetime
 from pathlib import Path
 from typing import Any, Dict, List, Literal, Optional
 
+import orjson
+
 from ..core import DataTransformer, DictWrapper
-from ..presets import get_preset, list_presets
-from ..storage.io import load_data, save_data, sample_file
+from ..lineage import format_lineage_report, get_lineage_chain, has_lineage, load_lineage
 from ..pipeline import run_pipeline, validate_pipeline
-from ..lineage import load_lineage, format_lineage_report, has_lineage, get_lineage_chain
+from ..presets import get_preset, list_presets
+from ..storage.io import load_data, sample_file, save_data
 from ..streaming import load_stream
-
+from ..utils.field_path import get_field_with_spec
 
 # 支持的文件格式
 SUPPORTED_FORMATS = {".csv", ".jsonl", ".json", ".xlsx", ".xls", ".parquet", ".arrow", ".feather"}
@@ -92,9 +94,7 @@ def sample(
     # 分层采样模式
     if by:
         try:
-            sampled = _stratified_sample(
-                filepath, num, by, uniform, seed, type
-            )
+            sampled = _stratified_sample(filepath, num, by, uniform, seed, type)
         except Exception as e:
             print(f"错误: {e}")
             return
@@ -138,7 +138,12 @@ def _stratified_sample(
     Args:
         filepath: 文件路径
         num: 目标采样总数
-        stratify_field: 分层字段
+        stratify_field: 分层字段,支持嵌套路径语法:
+            - meta.source 嵌套字段
+            - messages[0].role 数组索引
+            - messages[-1].role 负索引
+            - messages.# 数组长度
+            - messages[*].role 展开所有元素(可加 :join/:unique 模式)
         uniform: 是否均匀采样(各组相同数量)
         seed: 随机种子
         sample_type: 采样方式(用于组内采样)
@@ -159,10 +164,13 @@ def _stratified_sample(
     if num <= 0 or num > total:
         num = total
 
-    # 按字段分组
+    # 按字段分组(支持嵌套路径语法)
     groups: Dict[Any, List[Dict]] = defaultdict(list)
     for item in data:
-        key = item.get(stratify_field, "__null__")
+        key = get_field_with_spec(item, stratify_field, default="__null__")
+        # 确保 key 可哈希
+        if isinstance(key, list):
+            key = tuple(key)
         groups[key].append(item)
 
     group_keys = list(groups.keys())
@@ -360,7 +368,7 @@ def _format_nested(
     if isinstance(value, dict):
         items = list(value.items())
         for i, (k, v) in enumerate(items):
-            is_last_item = (i == len(items) - 1)
+            is_last_item = i == len(items) - 1
             b = "└─ " if is_last_item else "├─ "
             c = "   " if is_last_item else "│  "
 
@@ -369,11 +377,12 @@ def _format_nested(
             if isinstance(v, list):
                 # 检测是否为 messages 格式
                 is_messages = (
-                    v and isinstance(v[0], dict)
-                    and "role" in v[0] and "content" in v[0]
+                    v and isinstance(v[0], dict) and "role" in v[0] and "content" in v[0]
                 )
                 if is_messages:
-                    lines.append(f"{indent}{b}[green]{k}[/green]: ({len(v)} items) [dim]→ \\[role]: content[/dim]")
+                    lines.append(
+                        f"{indent}{b}[green]{k}[/green]: ({len(v)} items) [dim]→ \\[role]: content[/dim]"
+                    )
                 else:
                     lines.append(f"{indent}{b}[green]{k}[/green]: ({len(v)} items)")
             else:
@@ -385,7 +394,7 @@ def _format_nested(
 
     elif isinstance(value, list):
         for i, item in enumerate(value):
-            is_last_item = (i == len(value) - 1)
+            is_last_item = i == len(value) - 1
             b = "└─ " if is_last_item else "├─ "
             c = "   " if is_last_item else "│  "
 
@@ -457,8 +466,8 @@ def _print_samples(
 
     try:
         from rich.console import Console
-        from rich.table import Table
         from rich.panel import Panel
+        from rich.table import Table
 
         console = Console()
 
@@ -475,12 +484,14 @@ def _print_samples(
         else:
             info = f"采样: {len(samples)} 条 | 字段: {len(all_fields)} 个"
 
-        console.print(Panel(
-            f"[dim]{info}[/dim]\n[dim]字段: {field_names}[/dim]",
-            title=f"[bold]📊 {filename}[/bold]",
-            expand=False,
-            border_style="dim",
-        ))
+        console.print(
+            Panel(
+                f"[dim]{info}[/dim]\n[dim]字段: {field_names}[/dim]",
+                title=f"[bold]📊 {filename}[/bold]",
+                expand=False,
+                border_style="dim",
+            )
+        )
         console.print()
 
         # 简单数据用表格展示
@@ -514,7 +525,9 @@ def _print_samples(
 
     print(f"\n📊 {filename}")
     if total_count is not None:
-        print(f"   总行数: {total_count:,} | 采样: {len(samples)} 条 | 字段: {len(all_fields)} 个")
+        print(
+            f"   总行数: {total_count:,} | 采样: {len(samples)} 条 | 字段: {len(all_fields)} 个"
+        )
     else:
         print(f"   采样: {len(samples)} 条 | 字段: {len(all_fields)} 个")
     print(f"   字段: {', '.join(sorted(all_fields))}")
@@ -780,7 +793,7 @@ def _generate_default_transform(field_names: List[str]) -> str:
     for name in field_names[:5]:  # 最多显示 5 个字段
         safe_name, _ = _sanitize_field_name(name)
         lines.append(f'        "{name}": item.{safe_name},')
-    return "\n".join(lines) if lines else '        # 在这里定义输出字段'
+    return "\n".join(lines) if lines else "        # 在这里定义输出字段"
 
 
 def _execute_transform(
@@ -827,6 +840,7 @@ def _execute_transform(
     except Exception as e:
         print(f"错误: 转换失败 - {e}")
         import traceback
+
         traceback.print_exc()
         return
 
@@ -852,6 +866,7 @@ def _execute_transform(
     except Exception as e:
         print(f"错误: 转换失败 - {e}")
         import traceback
+
         traceback.print_exc()
         return
 
@@ -930,6 +945,7 @@ def _execute_preset_transform(
             os.unlink(temp_path)
         print(f"错误: 转换失败 - {e}")
         import traceback
+
         traceback.print_exc()
         return
 
@@ -955,6 +971,7 @@ def _execute_preset_transform(
     except Exception as e:
         print(f"错误: 转换失败 - {e}")
         import traceback
+
         traceback.print_exc()
         return
 
@@ -998,7 +1015,13 @@ def dedupe(
 
     Args:
         filename: 输入文件路径,支持 csv/excel/jsonl/json/parquet/arrow/feather 格式
-        key: 去重依据字段,多个字段用逗号分隔。不指定则全量去重
+        key: 去重依据字段,支持嵌套路径语法:
+            - meta.source 嵌套字段
+            - messages[0].role 数组索引
+            - messages[-1].content 负索引
+            - messages.# 数组长度
+            - messages[*].role:join 展开所有元素
+            多个字段用逗号分隔。不指定则全量去重
         similar: 相似度阈值(0-1),指定后启用相似度去重模式,需要指定 --key
         output: 输出文件路径,不指定则覆盖原文件
 
@@ -1006,8 +1029,9 @@ def dedupe(
     Examples:
         dt dedupe data.jsonl                           # 全量精确去重
         dt dedupe data.jsonl --key=text                # 按 text 字段精确去重
         dt dedupe data.jsonl --key=user,timestamp      # 按多字段组合精确去重
-        dt dedupe data.jsonl --key=text --similar=0.8  # 相似度去重
-        dt dedupe data.jsonl --output=clean.jsonl      # 指定输出文件
+        dt dedupe data.jsonl --key=meta.id             # 按嵌套字段去重
+        dt dedupe data.jsonl --key=messages[0].content # 按第一条消息内容去重
+        dt dedupe data.jsonl --key=text --similar=0.8  # 相似度去重
     """
     filepath = Path(filename)
 
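Exact dedupe with nested keys presumably reduces to building one hashable tuple per record, mirroring the "__null__" default and list-to-tuple trick used in _stratified_sample above. A minimal sketch of that reduction (hypothetical helper, not dtflow's actual dedupe internals):

    from dtflow.utils.field_path import get_field_with_spec  # new in 0.4.0

    def dedupe_exact(rows, key_specs):
        """Keep the first record for each distinct combination of key values."""
        seen, kept = set(), []
        for row in rows:
            parts = []
            for spec in key_specs:        # e.g. ["meta.id"] or ["user", "timestamp"]
                v = get_field_with_spec(row, spec, default="__null__")
                parts.append(tuple(v) if isinstance(v, list) else v)  # lists are unhashable
            key = tuple(parts)
            if key not in seen:
                seen.add(key)
                kept.append(row)
        return kept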
@@ -1132,8 +1156,13 @@ def concat(
 
     for filepath in file_paths:
         try:
-            # 只读取第一行来获取字段
-            first_row = load_stream(str(filepath)).head(1).collect()
+            # 只读取第一行来获取字段(根据格式选择加载方式)
+            if _is_streaming_supported(filepath):
+                first_row = load_stream(str(filepath)).head(1).collect()
+            else:
+                # 非流式格式(如 .json, .xlsx)使用全量加载
+                data = load_data(str(filepath))
+                first_row = data[:1] if data else []
             if not first_row:
                 print(f"警告: 文件为空 - {filepath}")
                 fields = set()
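
_is_streaming_supported is called here but not shown in this diff; judging from the comment, it presumably gates on the file extension, along these lines (the exact format set is an assumption, inferred from the _stream_* imports in _concat_streaming below):

    from pathlib import Path

    # Row-oriented formats dtflow can stream; .json and .xlsx need full loads.
    _STREAMING_SUFFIXES = {".jsonl", ".csv", ".parquet", ".arrow", ".feather"}

    def _is_streaming_supported(filepath: Path) -> bool:
        return filepath.suffix.lower() in _STREAMING_SUFFIXES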
@@ -1207,7 +1236,13 @@ def concat(
 
 def _concat_streaming(file_paths: List[Path], output: str) -> int:
     """流式拼接多个文件"""
-    from ..streaming import StreamingTransformer, _stream_jsonl, _stream_csv, _stream_parquet, _stream_arrow
+    from ..streaming import (
+        StreamingTransformer,
+        _stream_arrow,
+        _stream_csv,
+        _stream_jsonl,
+        _stream_parquet,
+    )
 
     def generator():
         for filepath in file_paths:
@@ -1413,12 +1448,16 @@ def _truncate(v: Any, max_width: int) -> str:
     result = []
     for char in s:
         # CJK 字符范围
-        if '\u4e00' <= char <= '\u9fff' or '\u3000' <= char <= '\u303f' or '\uff00' <= char <= '\uffef':
+        if (
+            "\u4e00" <= char <= "\u9fff"
+            or "\u3000" <= char <= "\u303f"
+            or "\uff00" <= char <= "\uffef"
+        ):
             char_width = 2
         else:
             char_width = 1
         if width + char_width > max_width - 3:  # 预留 ... 的宽度
-            return ''.join(result) + "..."
+            return "".join(result) + "..."
         result.append(char)
         width += char_width
     return s
@@ -1429,7 +1468,11 @@ def _display_width(s: str) -> int:
     width = 0
     for char in s:
         # CJK 字符范围
-        if '\u4e00' <= char <= '\u9fff' or '\u3000' <= char <= '\u303f' or '\uff00' <= char <= '\uffef':
+        if (
+            "\u4e00" <= char <= "\u9fff"
+            or "\u3000" <= char <= "\u303f"
+            or "\uff00" <= char <= "\uffef"
+        ):
             width += 2
         else:
             width += 1
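
Both helpers count CJK Unified Ideographs (U+4E00 to U+9FFF), CJK symbols and punctuation (U+3000 to U+303F), and halfwidth/fullwidth forms (U+FF00 to U+FFEF) as two terminal columns. A quick standalone check of the arithmetic:

    def display_width(s: str) -> int:
        # Same rule as _display_width: wide CJK/full-width characters take 2 columns.
        return sum(
            2
            if (
                "\u4e00" <= c <= "\u9fff"
                or "\u3000" <= c <= "\u303f"
                or "\uff00" <= c <= "\uffef"
            )
            else 1
            for c in s
        )

    assert display_width("abc") == 3
    assert display_width("数据abc") == 7   # 2 + 2 + 3
    assert display_width("字段:") == 6     # full-width colon (U+FF1A) is wide too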
@@ -1441,26 +1484,28 @@ def _pad_to_width(s: str, target_width: int) -> str:
     current_width = _display_width(s)
     if current_width >= target_width:
         return s
-    return s + ' ' * (target_width - current_width)
+    return s + " " * (target_width - current_width)
 
 
 def _print_stats(filename: str, total: int, field_stats: List[Dict[str, Any]]) -> None:
     """打印统计信息"""
     try:
         from rich.console import Console
-        from rich.table import Table
         from rich.panel import Panel
+        from rich.table import Table
 
         console = Console()
 
         # 概览
-        console.print(Panel(
-            f"[bold]文件:[/bold] {filename}\n"
-            f"[bold]总数:[/bold] {total:,} 条\n"
-            f"[bold]字段:[/bold] {len(field_stats)} ",
-            title="📊 数据概览",
-            expand=False,
-        ))
+        console.print(
+            Panel(
+                f"[bold]文件:[/bold] {filename}\n"
+                f"[bold]总数:[/bold] {total:,} 条\n"
+                f"[bold]字段:[/bold] {len(field_stats)} 个",
+                title="📊 数据概览",
+                expand=False,
+            )
+        )
 
         # 字段统计表
         table = Table(title="📋 字段统计", show_header=True, header_style="bold cyan")
@@ -1477,12 +1522,18 @@ def _print_stats(filename: str, total: int, field_stats: List[Dict[str, Any]]) -> None:
         # 构建统计信息字符串
         extra = []
         if "len_avg" in stat:
-            extra.append(f"长度: {stat['len_min']}-{stat['len_max']} (avg {stat['len_avg']:.0f})")
+            extra.append(
+                f"长度: {stat['len_min']}-{stat['len_max']} (avg {stat['len_avg']:.0f})"
+            )
         if "avg" in stat:
             if stat["type"] == "int":
-                extra.append(f"范围: {int(stat['min'])}-{int(stat['max'])} (avg {stat['avg']:.1f})")
+                extra.append(
+                    f"范围: {int(stat['min'])}-{int(stat['max'])} (avg {stat['avg']:.1f})"
+                )
             else:
-                extra.append(f"范围: {stat['min']:.2f}-{stat['max']:.2f} (avg {stat['avg']:.2f})")
+                extra.append(
+                    f"范围: {stat['min']:.2f}-{stat['max']:.2f} (avg {stat['avg']:.2f})"
+                )
 
         table.add_row(
             stat["field"],
@@ -1509,7 +1560,9 @@ def _print_stats(filename: str, total: int, field_stats: List[Dict[str, Any]]) -> None:
         if unique_ratio > 0.9 and stat.get("unique", 0) > 100:
             continue
 
-        console.print(f"\n[bold cyan]{stat['field']}[/bold cyan] 值分布 (Top {len(top_values)}):")
+        console.print(
+            f"\n[bold cyan]{stat['field']}[/bold cyan] 值分布 (Top {len(top_values)}):"
+        )
         max_count = max(c for _, c in top_values) if top_values else 1
         for value, count in top_values:
             pct = count / total * 100
@@ -1559,25 +1612,26 @@ def clean(
 
     Args:
         filename: 输入文件路径,支持 csv/excel/jsonl/json/parquet/arrow/feather 格式
-        drop_empty: 删除空值记录
+        drop_empty: 删除空值记录,支持嵌套路径语法
             - 不带值:删除任意字段为空的记录
             - 指定字段:删除指定字段为空的记录(逗号分隔)
-        min_len: 最小长度过滤,格式 "字段:长度"(如 text:10)
-        max_len: 最大长度过滤,格式 "字段:长度"(如 text:1000)
-        keep: 只保留指定字段(逗号分隔)
-        drop: 删除指定字段(逗号分隔)
+        min_len: 最小长度过滤,格式 "字段:长度",字段支持嵌套路径
+        max_len: 最大长度过滤,格式 "字段:长度",字段支持嵌套路径
+        keep: 只保留指定字段(逗号分隔,仅支持顶层字段)
+        drop: 删除指定字段(逗号分隔,仅支持顶层字段)
         strip: 去除所有字符串字段的首尾空白
         output: 输出文件路径,不指定则覆盖原文件
 
     Examples:
         dt clean data.jsonl --drop-empty                # 删除任意空值记录
         dt clean data.jsonl --drop-empty=text,answer    # 删除指定字段为空的记录
+        dt clean data.jsonl --drop-empty=meta.source    # 删除嵌套字段为空的记录
         dt clean data.jsonl --min-len=text:10           # text 字段最少 10 字符
-        dt clean data.jsonl --max-len=text:1000         # text 字段最多 1000 字符
+        dt clean data.jsonl --min-len=messages.#:2      # 至少 2 条消息
+        dt clean data.jsonl --max-len=messages[-1].content:500  # 最后一条消息最多 500 字符
         dt clean data.jsonl --keep=question,answer      # 只保留这些字段
         dt clean data.jsonl --drop=metadata,timestamp   # 删除这些字段
         dt clean data.jsonl --strip                     # 去除字符串首尾空白
-        dt clean data.jsonl --drop-empty --strip -o out.jsonl
     """
     filepath = Path(filename)
 
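One wrinkle in the new examples: the --min-len/--max-len value is still a "field:length" spec, but the field part can now itself contain dots, brackets, and even colons (mode suffixes such as :join). Splitting on the last colon keeps such paths intact; a sketch of that parse (an assumption about the CLI handling, which this diff does not show):

    def parse_len_spec(spec: str) -> tuple:
        """'messages[-1].content:500' -> ('messages[-1].content', 500)."""
        field, _, length = spec.rpartition(":")
        return field, int(length)

    assert parse_len_spec("text:10") == ("text", 10)
    assert parse_len_spec("messages.#:2") == ("messages.#", 2)
    assert parse_len_spec("messages[-1].content:500") == ("messages[-1].content", 500)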
@@ -1666,6 +1720,7 @@ def clean(
             os.unlink(temp_path)
         print(f"错误: 清洗失败 - {e}")
         import traceback
+
         traceback.print_exc()
         return
 
@@ -1746,9 +1801,18 @@ def _is_empty_value(v: Any) -> bool:
 
 
 def _get_value_len(value: Any) -> int:
-    """获取值的长度"""
+    """
+    获取值的长度。
+
+    - str/list/dict: 返回 len()
+    - int/float: 直接返回该数值(用于 messages.# 这种返回数量的场景)
+    - None: 返回 0
+    - 其他: 转为字符串后返回长度
+    """
     if value is None:
         return 0
+    if isinstance(value, (int, float)):
+        return int(value)
     if isinstance(value, (str, list, dict)):
         return len(value)
     return len(str(value))
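
The new int/float branch is what makes specs like --min-len=messages.#:2 meaningful: messages.# resolves to a count, and the count itself is compared against the threshold rather than len(str(n)). The behaviour, restated from the hunk above:

    def _get_value_len(value):
        # As added in 0.4.0 (copied from the hunk above).
        if value is None:
            return 0
        if isinstance(value, (int, float)):
            return int(value)
        if isinstance(value, (str, list, dict)):
            return len(value)
        return len(str(value))

    assert _get_value_len(None) == 0
    assert _get_value_len(3) == 3          # count from messages.#, passed through
    assert _get_value_len("hello") == 5
    assert _get_value_len([1, 2, 3]) == 3
    assert _get_value_len(True) == 1       # bool is an int subclass; was len("True") == 4 in 0.3.1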
@@ -1771,13 +1835,13 @@ def _clean_data_single_pass(
     Args:
         data: 原始数据列表
         strip: 是否去除字符串首尾空白
-        empty_fields: 检查空值的字段列表,空列表表示检查所有字段,None 表示不检查
-        min_len_field: 最小长度检查的字段
+        empty_fields: 检查空值的字段列表(支持嵌套路径),空列表表示检查所有字段,None 表示不检查
+        min_len_field: 最小长度检查的字段(支持嵌套路径)
         min_len_value: 最小长度值
-        max_len_field: 最大长度检查的字段
+        max_len_field: 最大长度检查的字段(支持嵌套路径)
         max_len_value: 最大长度值
-        keep_fields: 只保留的字段列表
-        drop_fields: 要删除的字段集合
+        keep_fields: 只保留的字段列表(仅支持顶层字段)
+        drop_fields: 要删除的字段集合(仅支持顶层字段)
 
     Returns:
         (清洗后的数据, 统计信息列表)
@@ -1805,20 +1869,20 @@ def _clean_data_single_pass(
                     stats["drop_empty"] += 1
                     continue
             else:
-                # 检查指定字段
-                if any(_is_empty_value(item.get(f)) for f in empty_fields):
+                # 检查指定字段(支持嵌套路径)
+                if any(_is_empty_value(get_field_with_spec(item, f)) for f in empty_fields):
                     stats["drop_empty"] += 1
                     continue
 
-        # 3. 最小长度过滤
+        # 3. 最小长度过滤(支持嵌套路径)
         if min_len_field is not None:
-            if _get_value_len(item.get(min_len_field, "")) < min_len_value:
+            if _get_value_len(get_field_with_spec(item, min_len_field, default="")) < min_len_value:
                 stats["min_len"] += 1
                 continue
 
-        # 4. 最大长度过滤
+        # 4. 最大长度过滤(支持嵌套路径)
         if max_len_field is not None:
-            if _get_value_len(item.get(max_len_field, "")) > max_len_value:
+            if _get_value_len(get_field_with_spec(item, max_len_field, default="")) > max_len_value:
                 stats["max_len"] += 1
                 continue
 
@@ -1866,25 +1930,27 @@ def _clean_streaming(
     Returns:
         处理后的数据条数
     """
+
     def clean_filter(item: Dict) -> bool:
-        """过滤函数:返回 True 保留,False 过滤"""
+        """过滤函数:返回 True 保留,False 过滤(支持嵌套路径)"""
         # 空值过滤
         if empty_fields is not None:
             if len(empty_fields) == 0:
                 if any(_is_empty_value(v) for v in item.values()):
                     return False
             else:
-                if any(_is_empty_value(item.get(f)) for f in empty_fields):
+                # 支持嵌套路径
+                if any(_is_empty_value(get_field_with_spec(item, f)) for f in empty_fields):
                     return False
 
-        # 最小长度过滤
+        # 最小长度过滤(支持嵌套路径)
         if min_len_field is not None:
-            if _get_value_len(item.get(min_len_field, "")) < min_len_value:
+            if _get_value_len(get_field_with_spec(item, min_len_field, default="")) < min_len_value:
                 return False
 
-        # 最大长度过滤
+        # 最大长度过滤(支持嵌套路径)
         if max_len_field is not None:
-            if _get_value_len(item.get(max_len_field, "")) > max_len_value:
+            if _get_value_len(get_field_with_spec(item, max_len_field, default="")) > max_len_value:
                 return False
 
         return True
@@ -1908,7 +1974,9 @@ def _clean_streaming(
 
     # 如果需要 strip,先执行 strip 转换(在过滤之前,这样空值检测更准确)
     if strip:
-        st = st.transform(lambda x: {k: v.strip() if isinstance(v, str) else v for k, v in x.items()})
+        st = st.transform(
+            lambda x: {k: v.strip() if isinstance(v, str) else v for k, v in x.items()}
+        )
 
     # 执行过滤
     if empty_fields is not None or min_len_field is not None or max_len_field is not None:
@@ -1916,12 +1984,14 @@ def _clean_streaming(
 
     # 执行字段管理(如果没有 strip,也需要在这里处理)
     if keep_set is not None or drop_fields_set is not None:
+
         def field_transform(item):
             if keep_set is not None:
                 return {k: v for k, v in item.items() if k in keep_set}
             elif drop_fields_set is not None:
                 return {k: v for k, v in item.items() if k not in drop_fields_set}
             return item
+
         st = st.transform(field_transform)
 
     return st.save(output_path)
@@ -1972,6 +2042,7 @@ def run(
     except Exception as e:
         print(f"错误: {e}")
         import traceback
+
         traceback.print_exc()
 
 
@@ -1989,13 +2060,15 @@ def token_stats(
 
     Args:
         filename: 输入文件路径
-        field: 要统计的字段(默认 messages)
+        field: 要统计的字段(默认 messages),支持嵌套路径语法
         model: 分词器: cl100k_base (默认), qwen2.5, llama3, gpt-4 等
         detailed: 是否显示详细统计
 
     Examples:
         dt token-stats data.jsonl
         dt token-stats data.jsonl --field=text --model=qwen2.5
+        dt token-stats data.jsonl --field=conversation.messages
+        dt token-stats data.jsonl --field=messages[-1].content  # 统计最后一条消息
         dt token-stats data.jsonl --detailed
     """
     filepath = Path(filename)
@@ -2023,19 +2096,21 @@ def token_stats(
     print(f"   共 {total} 条数据")
     print(f"🔢 统计 Token (模型: {model}, 字段: {field})...")
 
-    # 检查字段类型并选择合适的统计方法
+    # 检查字段类型并选择合适的统计方法(支持嵌套路径)
     sample = data[0]
-    field_value = sample.get(field)
+    field_value = get_field_with_spec(sample, field)
 
     try:
         if isinstance(field_value, list) and field_value and isinstance(field_value[0], dict):
             # messages 格式
             from ..tokenizers import messages_token_stats
+
             stats = messages_token_stats(data, messages_field=field, model=model)
             _print_messages_token_stats(stats, detailed)
         else:
             # 普通文本字段
             from ..tokenizers import token_stats as compute_token_stats
+
             stats = compute_token_stats(data, fields=field, model=model)
             _print_text_token_stats(stats, detailed)
     except ImportError as e:
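
The dispatch keys off whatever the sample's spec resolves to: a non-empty list of dicts routes to the messages tokenizer, anything else to plain-text stats. For a typical chat record (illustrative data, not from the package):

    from dtflow.utils.field_path import get_field_with_spec

    sample = {"messages": [{"role": "user", "content": "hi"},
                           {"role": "assistant", "content": "hello"}]}

    get_field_with_spec(sample, "messages")              # list of dicts -> messages_token_stats
    get_field_with_spec(sample, "messages[-1].content")  # "hello", a str -> plain token_stats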
@@ -2044,6 +2119,7 @@ def token_stats(
     except Exception as e:
         print(f"错误: 统计失败 - {e}")
         import traceback
+
         traceback.print_exc()
 
 
@@ -2051,8 +2127,8 @@ def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
     """打印 messages 格式的 token 统计"""
     try:
         from rich.console import Console
-        from rich.table import Table
         from rich.panel import Panel
+        from rich.table import Table
 
         console = Console()
 
@@ -2073,8 +2149,12 @@ def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
         table.add_column("Token 数", justify="right")
         table.add_column("占比", justify="right")
 
-        total = stats['total_tokens']
-        for role, key in [("User", "user_tokens"), ("Assistant", "assistant_tokens"), ("System", "system_tokens")]:
+        total = stats["total_tokens"]
+        for role, key in [
+            ("User", "user_tokens"),
+            ("Assistant", "assistant_tokens"),
+            ("System", "system_tokens"),
+        ]:
             tokens = stats.get(key, 0)
             pct = tokens / total * 100 if total > 0 else 0
             table.add_row(role, f"{tokens:,}", f"{pct:.1f}%")
@@ -2097,8 +2177,12 @@ def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
         print(f"\n{'=' * 40}")
         print("📋 分角色统计")
         print(f"{'=' * 40}")
-        total = stats['total_tokens']
-        for role, key in [("User", "user_tokens"), ("Assistant", "assistant_tokens"), ("System", "system_tokens")]:
+        total = stats["total_tokens"]
+        for role, key in [
+            ("User", "user_tokens"),
+            ("Assistant", "assistant_tokens"),
+            ("System", "system_tokens"),
+        ]:
             tokens = stats.get(key, 0)
             pct = tokens / total * 100 if total > 0 else 0
             print(f"{role}: {tokens:,} ({pct:.1f}%)")
@@ -2148,12 +2232,13 @@ def diff(
     Args:
         file1: 第一个文件路径
         file2: 第二个文件路径
-        key: 用于匹配的键字段(可选)
+        key: 用于匹配的键字段,支持嵌套路径语法(可选)
         output: 差异报告输出路径(可选)
 
     Examples:
         dt diff v1/train.jsonl v2/train.jsonl
         dt diff a.jsonl b.jsonl --key=id
+        dt diff a.jsonl b.jsonl --key=meta.uuid  # 按嵌套字段匹配
         dt diff a.jsonl b.jsonl --output=diff_report.json
     """
     path1 = Path(file1)
@@ -2216,9 +2301,9 @@ def _compute_diff(
     }
 
     if key:
-        # 基于 key 的精确匹配
-        dict1 = {item.get(key): item for item in data1 if item.get(key) is not None}
-        dict2 = {item.get(key): item for item in data2 if item.get(key) is not None}
+        # 基于 key 的精确匹配(支持嵌套路径)
+        dict1 = {get_field_with_spec(item, key): item for item in data1 if get_field_with_spec(item, key) is not None}
+        dict2 = {get_field_with_spec(item, key): item for item in data2 if get_field_with_spec(item, key) is not None}
 
         keys1 = set(dict1.keys())
         keys2 = set(dict2.keys())
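
Note a small asymmetry with _stratified_sample: the dict comprehensions here use the resolved value directly as a dict key, with no list-to-tuple fallback, so scalar-valued key paths (ids, uuids) are the safe choice for --key. With illustrative data:

    from dtflow.utils.field_path import get_field_with_spec

    data1 = [{"meta": {"uuid": "a"}, "x": 1}, {"meta": {"uuid": "b"}, "x": 2}]
    data2 = [{"meta": {"uuid": "b"}, "x": 3}, {"meta": {"uuid": "c"}, "x": 4}]

    dict1 = {get_field_with_spec(r, "meta.uuid"): r for r in data1}
    dict2 = {get_field_with_spec(r, "meta.uuid"): r for r in data2}

    set(dict1) - set(dict2)                                        # {'a'} -> removed
    set(dict2) - set(dict1)                                        # {'c'} -> added
    {k for k in set(dict1) & set(dict2) if dict1[k] != dict2[k]}   # {'b'} -> modified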
@@ -2241,11 +2326,13 @@ def _compute_diff(
             else:
                 result["summary"]["modified"] += 1
                 if len(result["details"]["modified"]) < 10:
-                    result["details"]["modified"].append({
-                        "key": k,
-                        "before": dict1[k],
-                        "after": dict2[k],
-                    })
+                    result["details"]["modified"].append(
+                        {
+                            "key": k,
+                            "before": dict1[k],
+                            "after": dict2[k],
+                        }
+                    )
     else:
         # 基于哈希的比较
         def _hash_item(item):
@@ -2290,8 +2377,8 @@ def _print_diff_report(diff_result: Dict[str, Any], name1: str, name2: str) -> None:
 
     try:
         from rich.console import Console
-        from rich.table import Table
         from rich.panel import Panel
+        from rich.table import Table
 
         console = Console()
 
@@ -2311,9 +2398,13 @@ def _print_diff_report(diff_result: Dict[str, Any], name1: str, name2: str) -> N
         if field_changes["added_fields"] or field_changes["removed_fields"]:
             console.print("\n[bold]📋 字段变化:[/bold]")
             if field_changes["added_fields"]:
-                console.print(f"   [green]+ 新增字段:[/green] {', '.join(field_changes['added_fields'])}")
+                console.print(
+                    f"   [green]+ 新增字段:[/green] {', '.join(field_changes['added_fields'])}"
+                )
             if field_changes["removed_fields"]:
-                console.print(f"   [red]- 删除字段:[/red] {', '.join(field_changes['removed_fields'])}")
+                console.print(
+                    f"   [red]- 删除字段:[/red] {', '.join(field_changes['removed_fields'])}"
+                )
 
     except ImportError:
         print(f"\n{'=' * 50}")