dtflow 0.3.1__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dtflow/__init__.py +69 -58
- dtflow/__main__.py +29 -20
- dtflow/cli/__init__.py +25 -4
- dtflow/cli/commands.py +184 -93
- dtflow/converters.py +39 -23
- dtflow/core.py +79 -51
- dtflow/lineage.py +6 -3
- dtflow/mcp/__init__.py +1 -0
- dtflow/mcp/__main__.py +2 -0
- dtflow/mcp/cli.py +22 -4
- dtflow/mcp/docs.py +0 -5
- dtflow/pipeline.py +33 -23
- dtflow/presets.py +24 -22
- dtflow/storage/__init__.py +11 -10
- dtflow/storage/io.py +19 -10
- dtflow/streaming.py +13 -18
- dtflow/tokenizers.py +32 -12
- dtflow/utils/__init__.py +20 -1
- dtflow/utils/display.py +23 -23
- dtflow/utils/field_path.py +274 -0
- {dtflow-0.3.1.dist-info → dtflow-0.4.0.dist-info}/METADATA +48 -3
- dtflow-0.4.0.dist-info/RECORD +25 -0
- dtflow-0.3.1.dist-info/RECORD +0 -24
- {dtflow-0.3.1.dist-info → dtflow-0.4.0.dist-info}/WHEEL +0 -0
- {dtflow-0.3.1.dist-info → dtflow-0.4.0.dist-info}/entry_points.txt +0 -0
dtflow/cli/commands.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""
|
|
2
2
|
CLI 命令实现
|
|
3
3
|
"""
|
|
4
|
-
|
|
4
|
+
|
|
5
5
|
import os
|
|
6
6
|
import shutil
|
|
7
7
|
import tempfile
|
|
@@ -9,13 +9,15 @@ from datetime import datetime
|
|
|
9
9
|
from pathlib import Path
|
|
10
10
|
from typing import Any, Dict, List, Literal, Optional
|
|
11
11
|
|
|
12
|
+
import orjson
|
|
13
|
+
|
|
12
14
|
from ..core import DataTransformer, DictWrapper
|
|
13
|
-
from ..
|
|
14
|
-
from ..storage.io import load_data, save_data, sample_file
|
|
15
|
+
from ..lineage import format_lineage_report, get_lineage_chain, has_lineage, load_lineage
|
|
15
16
|
from ..pipeline import run_pipeline, validate_pipeline
|
|
16
|
-
from ..
|
|
17
|
+
from ..presets import get_preset, list_presets
|
|
18
|
+
from ..storage.io import load_data, sample_file, save_data
|
|
17
19
|
from ..streaming import load_stream
|
|
18
|
-
|
|
20
|
+
from ..utils.field_path import get_field_with_spec
|
|
19
21
|
|
|
20
22
|
# 支持的文件格式
|
|
21
23
|
SUPPORTED_FORMATS = {".csv", ".jsonl", ".json", ".xlsx", ".xls", ".parquet", ".arrow", ".feather"}
|
|
@@ -92,9 +94,7 @@ def sample(
|
|
|
92
94
|
# 分层采样模式
|
|
93
95
|
if by:
|
|
94
96
|
try:
|
|
95
|
-
sampled = _stratified_sample(
|
|
96
|
-
filepath, num, by, uniform, seed, type
|
|
97
|
-
)
|
|
97
|
+
sampled = _stratified_sample(filepath, num, by, uniform, seed, type)
|
|
98
98
|
except Exception as e:
|
|
99
99
|
print(f"错误: {e}")
|
|
100
100
|
return
|
|
@@ -138,7 +138,12 @@ def _stratified_sample(
|
|
|
138
138
|
Args:
|
|
139
139
|
filepath: 文件路径
|
|
140
140
|
num: 目标采样总数
|
|
141
|
-
stratify_field:
|
|
141
|
+
stratify_field: 分层字段,支持嵌套路径语法:
|
|
142
|
+
- meta.source 嵌套字段
|
|
143
|
+
- messages[0].role 数组索引
|
|
144
|
+
- messages[-1].role 负索引
|
|
145
|
+
- messages.# 数组长度
|
|
146
|
+
- messages[*].role 展开所有元素(可加 :join/:unique 模式)
|
|
142
147
|
uniform: 是否均匀采样(各组相同数量)
|
|
143
148
|
seed: 随机种子
|
|
144
149
|
sample_type: 采样方式(用于组内采样)
|
|
@@ -159,10 +164,13 @@ def _stratified_sample(
|
|
|
159
164
|
if num <= 0 or num > total:
|
|
160
165
|
num = total
|
|
161
166
|
|
|
162
|
-
#
|
|
167
|
+
# 按字段分组(支持嵌套路径语法)
|
|
163
168
|
groups: Dict[Any, List[Dict]] = defaultdict(list)
|
|
164
169
|
for item in data:
|
|
165
|
-
key = item
|
|
170
|
+
key = get_field_with_spec(item, stratify_field, default="__null__")
|
|
171
|
+
# 确保 key 可哈希
|
|
172
|
+
if isinstance(key, list):
|
|
173
|
+
key = tuple(key)
|
|
166
174
|
groups[key].append(item)
|
|
167
175
|
|
|
168
176
|
group_keys = list(groups.keys())
|
|
@@ -360,7 +368,7 @@ def _format_nested(
|
|
|
360
368
|
if isinstance(value, dict):
|
|
361
369
|
items = list(value.items())
|
|
362
370
|
for i, (k, v) in enumerate(items):
|
|
363
|
-
is_last_item =
|
|
371
|
+
is_last_item = i == len(items) - 1
|
|
364
372
|
b = "└─ " if is_last_item else "├─ "
|
|
365
373
|
c = " " if is_last_item else "│ "
|
|
366
374
|
|
|
@@ -369,11 +377,12 @@ def _format_nested(
|
|
|
369
377
|
if isinstance(v, list):
|
|
370
378
|
# 检测是否为 messages 格式
|
|
371
379
|
is_messages = (
|
|
372
|
-
v and isinstance(v[0], dict)
|
|
373
|
-
and "role" in v[0] and "content" in v[0]
|
|
380
|
+
v and isinstance(v[0], dict) and "role" in v[0] and "content" in v[0]
|
|
374
381
|
)
|
|
375
382
|
if is_messages:
|
|
376
|
-
lines.append(
|
|
383
|
+
lines.append(
|
|
384
|
+
f"{indent}{b}[green]{k}[/green]: ({len(v)} items) [dim]→ \\[role]: content[/dim]"
|
|
385
|
+
)
|
|
377
386
|
else:
|
|
378
387
|
lines.append(f"{indent}{b}[green]{k}[/green]: ({len(v)} items)")
|
|
379
388
|
else:
|
|
@@ -385,7 +394,7 @@ def _format_nested(
|
|
|
385
394
|
|
|
386
395
|
elif isinstance(value, list):
|
|
387
396
|
for i, item in enumerate(value):
|
|
388
|
-
is_last_item =
|
|
397
|
+
is_last_item = i == len(value) - 1
|
|
389
398
|
b = "└─ " if is_last_item else "├─ "
|
|
390
399
|
c = " " if is_last_item else "│ "
|
|
391
400
|
|
|
@@ -457,8 +466,8 @@ def _print_samples(
|
|
|
457
466
|
|
|
458
467
|
try:
|
|
459
468
|
from rich.console import Console
|
|
460
|
-
from rich.table import Table
|
|
461
469
|
from rich.panel import Panel
|
|
470
|
+
from rich.table import Table
|
|
462
471
|
|
|
463
472
|
console = Console()
|
|
464
473
|
|
|
@@ -475,12 +484,14 @@ def _print_samples(
|
|
|
475
484
|
else:
|
|
476
485
|
info = f"采样: {len(samples)} 条 | 字段: {len(all_fields)} 个"
|
|
477
486
|
|
|
478
|
-
console.print(
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
487
|
+
console.print(
|
|
488
|
+
Panel(
|
|
489
|
+
f"[dim]{info}[/dim]\n[dim]字段: {field_names}[/dim]",
|
|
490
|
+
title=f"[bold]📊 {filename}[/bold]",
|
|
491
|
+
expand=False,
|
|
492
|
+
border_style="dim",
|
|
493
|
+
)
|
|
494
|
+
)
|
|
484
495
|
console.print()
|
|
485
496
|
|
|
486
497
|
# 简单数据用表格展示
|
|
@@ -514,7 +525,9 @@ def _print_samples(
|
|
|
514
525
|
|
|
515
526
|
print(f"\n📊 {filename}")
|
|
516
527
|
if total_count is not None:
|
|
517
|
-
print(
|
|
528
|
+
print(
|
|
529
|
+
f" 总行数: {total_count:,} | 采样: {len(samples)} 条 | 字段: {len(all_fields)} 个"
|
|
530
|
+
)
|
|
518
531
|
else:
|
|
519
532
|
print(f" 采样: {len(samples)} 条 | 字段: {len(all_fields)} 个")
|
|
520
533
|
print(f" 字段: {', '.join(sorted(all_fields))}")
|
|
@@ -780,7 +793,7 @@ def _generate_default_transform(field_names: List[str]) -> str:
|
|
|
780
793
|
for name in field_names[:5]: # 最多显示 5 个字段
|
|
781
794
|
safe_name, _ = _sanitize_field_name(name)
|
|
782
795
|
lines.append(f' "{name}": item.{safe_name},')
|
|
783
|
-
return "\n".join(lines) if lines else
|
|
796
|
+
return "\n".join(lines) if lines else " # 在这里定义输出字段"
|
|
784
797
|
|
|
785
798
|
|
|
786
799
|
def _execute_transform(
|
|
@@ -827,6 +840,7 @@ def _execute_transform(
|
|
|
827
840
|
except Exception as e:
|
|
828
841
|
print(f"错误: 转换失败 - {e}")
|
|
829
842
|
import traceback
|
|
843
|
+
|
|
830
844
|
traceback.print_exc()
|
|
831
845
|
return
|
|
832
846
|
|
|
@@ -852,6 +866,7 @@ def _execute_transform(
|
|
|
852
866
|
except Exception as e:
|
|
853
867
|
print(f"错误: 转换失败 - {e}")
|
|
854
868
|
import traceback
|
|
869
|
+
|
|
855
870
|
traceback.print_exc()
|
|
856
871
|
return
|
|
857
872
|
|
|
@@ -930,6 +945,7 @@ def _execute_preset_transform(
|
|
|
930
945
|
os.unlink(temp_path)
|
|
931
946
|
print(f"错误: 转换失败 - {e}")
|
|
932
947
|
import traceback
|
|
948
|
+
|
|
933
949
|
traceback.print_exc()
|
|
934
950
|
return
|
|
935
951
|
|
|
@@ -955,6 +971,7 @@ def _execute_preset_transform(
|
|
|
955
971
|
except Exception as e:
|
|
956
972
|
print(f"错误: 转换失败 - {e}")
|
|
957
973
|
import traceback
|
|
974
|
+
|
|
958
975
|
traceback.print_exc()
|
|
959
976
|
return
|
|
960
977
|
|
|
@@ -998,7 +1015,13 @@ def dedupe(
|
|
|
998
1015
|
|
|
999
1016
|
Args:
|
|
1000
1017
|
filename: 输入文件路径,支持 csv/excel/jsonl/json/parquet/arrow/feather 格式
|
|
1001
|
-
key:
|
|
1018
|
+
key: 去重依据字段,支持嵌套路径语法:
|
|
1019
|
+
- meta.source 嵌套字段
|
|
1020
|
+
- messages[0].role 数组索引
|
|
1021
|
+
- messages[-1].content 负索引
|
|
1022
|
+
- messages.# 数组长度
|
|
1023
|
+
- messages[*].role:join 展开所有元素
|
|
1024
|
+
多个字段用逗号分隔。不指定则全量去重
|
|
1002
1025
|
similar: 相似度阈值(0-1),指定后启用相似度去重模式,需要指定 --key
|
|
1003
1026
|
output: 输出文件路径,不指定则覆盖原文件
|
|
1004
1027
|
|
|
@@ -1006,8 +1029,9 @@ def dedupe(
|
|
|
1006
1029
|
dt dedupe data.jsonl # 全量精确去重
|
|
1007
1030
|
dt dedupe data.jsonl --key=text # 按 text 字段精确去重
|
|
1008
1031
|
dt dedupe data.jsonl --key=user,timestamp # 按多字段组合精确去重
|
|
1009
|
-
dt dedupe data.jsonl --key=
|
|
1010
|
-
dt dedupe data.jsonl --
|
|
1032
|
+
dt dedupe data.jsonl --key=meta.id # 按嵌套字段去重
|
|
1033
|
+
dt dedupe data.jsonl --key=messages[0].content # 按第一条消息内容去重
|
|
1034
|
+
dt dedupe data.jsonl --key=text --similar=0.8 # 相似度去重
|
|
1011
1035
|
"""
|
|
1012
1036
|
filepath = Path(filename)
|
|
1013
1037
|
|
|
@@ -1132,8 +1156,13 @@ def concat(
|
|
|
1132
1156
|
|
|
1133
1157
|
for filepath in file_paths:
|
|
1134
1158
|
try:
|
|
1135
|
-
#
|
|
1136
|
-
|
|
1159
|
+
# 只读取第一行来获取字段(根据格式选择加载方式)
|
|
1160
|
+
if _is_streaming_supported(filepath):
|
|
1161
|
+
first_row = load_stream(str(filepath)).head(1).collect()
|
|
1162
|
+
else:
|
|
1163
|
+
# 非流式格式(如 .json, .xlsx)使用全量加载
|
|
1164
|
+
data = load_data(str(filepath))
|
|
1165
|
+
first_row = data[:1] if data else []
|
|
1137
1166
|
if not first_row:
|
|
1138
1167
|
print(f"警告: 文件为空 - {filepath}")
|
|
1139
1168
|
fields = set()
|
|
@@ -1207,7 +1236,13 @@ def concat(
|
|
|
1207
1236
|
|
|
1208
1237
|
def _concat_streaming(file_paths: List[Path], output: str) -> int:
|
|
1209
1238
|
"""流式拼接多个文件"""
|
|
1210
|
-
from ..streaming import
|
|
1239
|
+
from ..streaming import (
|
|
1240
|
+
StreamingTransformer,
|
|
1241
|
+
_stream_arrow,
|
|
1242
|
+
_stream_csv,
|
|
1243
|
+
_stream_jsonl,
|
|
1244
|
+
_stream_parquet,
|
|
1245
|
+
)
|
|
1211
1246
|
|
|
1212
1247
|
def generator():
|
|
1213
1248
|
for filepath in file_paths:
|
|
@@ -1413,12 +1448,16 @@ def _truncate(v: Any, max_width: int) -> str:
|
|
|
1413
1448
|
result = []
|
|
1414
1449
|
for char in s:
|
|
1415
1450
|
# CJK 字符范围
|
|
1416
|
-
if
|
|
1451
|
+
if (
|
|
1452
|
+
"\u4e00" <= char <= "\u9fff"
|
|
1453
|
+
or "\u3000" <= char <= "\u303f"
|
|
1454
|
+
or "\uff00" <= char <= "\uffef"
|
|
1455
|
+
):
|
|
1417
1456
|
char_width = 2
|
|
1418
1457
|
else:
|
|
1419
1458
|
char_width = 1
|
|
1420
1459
|
if width + char_width > max_width - 3: # 预留 ... 的宽度
|
|
1421
|
-
return
|
|
1460
|
+
return "".join(result) + "..."
|
|
1422
1461
|
result.append(char)
|
|
1423
1462
|
width += char_width
|
|
1424
1463
|
return s
|
|
@@ -1429,7 +1468,11 @@ def _display_width(s: str) -> int:
|
|
|
1429
1468
|
width = 0
|
|
1430
1469
|
for char in s:
|
|
1431
1470
|
# CJK 字符范围
|
|
1432
|
-
if
|
|
1471
|
+
if (
|
|
1472
|
+
"\u4e00" <= char <= "\u9fff"
|
|
1473
|
+
or "\u3000" <= char <= "\u303f"
|
|
1474
|
+
or "\uff00" <= char <= "\uffef"
|
|
1475
|
+
):
|
|
1433
1476
|
width += 2
|
|
1434
1477
|
else:
|
|
1435
1478
|
width += 1
|
|
@@ -1441,26 +1484,28 @@ def _pad_to_width(s: str, target_width: int) -> str:
|
|
|
1441
1484
|
current_width = _display_width(s)
|
|
1442
1485
|
if current_width >= target_width:
|
|
1443
1486
|
return s
|
|
1444
|
-
return s +
|
|
1487
|
+
return s + " " * (target_width - current_width)
|
|
1445
1488
|
|
|
1446
1489
|
|
|
1447
1490
|
def _print_stats(filename: str, total: int, field_stats: List[Dict[str, Any]]) -> None:
|
|
1448
1491
|
"""打印统计信息"""
|
|
1449
1492
|
try:
|
|
1450
1493
|
from rich.console import Console
|
|
1451
|
-
from rich.table import Table
|
|
1452
1494
|
from rich.panel import Panel
|
|
1495
|
+
from rich.table import Table
|
|
1453
1496
|
|
|
1454
1497
|
console = Console()
|
|
1455
1498
|
|
|
1456
1499
|
# 概览
|
|
1457
|
-
console.print(
|
|
1458
|
-
|
|
1459
|
-
|
|
1460
|
-
|
|
1461
|
-
|
|
1462
|
-
|
|
1463
|
-
|
|
1500
|
+
console.print(
|
|
1501
|
+
Panel(
|
|
1502
|
+
f"[bold]文件:[/bold] {filename}\n"
|
|
1503
|
+
f"[bold]总数:[/bold] {total:,} 条\n"
|
|
1504
|
+
f"[bold]字段:[/bold] {len(field_stats)} 个",
|
|
1505
|
+
title="📊 数据概览",
|
|
1506
|
+
expand=False,
|
|
1507
|
+
)
|
|
1508
|
+
)
|
|
1464
1509
|
|
|
1465
1510
|
# 字段统计表
|
|
1466
1511
|
table = Table(title="📋 字段统计", show_header=True, header_style="bold cyan")
|
|
@@ -1477,12 +1522,18 @@ def _print_stats(filename: str, total: int, field_stats: List[Dict[str, Any]]) -
|
|
|
1477
1522
|
# 构建统计信息字符串
|
|
1478
1523
|
extra = []
|
|
1479
1524
|
if "len_avg" in stat:
|
|
1480
|
-
extra.append(
|
|
1525
|
+
extra.append(
|
|
1526
|
+
f"长度: {stat['len_min']}-{stat['len_max']} (avg {stat['len_avg']:.0f})"
|
|
1527
|
+
)
|
|
1481
1528
|
if "avg" in stat:
|
|
1482
1529
|
if stat["type"] == "int":
|
|
1483
|
-
extra.append(
|
|
1530
|
+
extra.append(
|
|
1531
|
+
f"范围: {int(stat['min'])}-{int(stat['max'])} (avg {stat['avg']:.1f})"
|
|
1532
|
+
)
|
|
1484
1533
|
else:
|
|
1485
|
-
extra.append(
|
|
1534
|
+
extra.append(
|
|
1535
|
+
f"范围: {stat['min']:.2f}-{stat['max']:.2f} (avg {stat['avg']:.2f})"
|
|
1536
|
+
)
|
|
1486
1537
|
|
|
1487
1538
|
table.add_row(
|
|
1488
1539
|
stat["field"],
|
|
@@ -1509,7 +1560,9 @@ def _print_stats(filename: str, total: int, field_stats: List[Dict[str, Any]]) -
|
|
|
1509
1560
|
if unique_ratio > 0.9 and stat.get("unique", 0) > 100:
|
|
1510
1561
|
continue
|
|
1511
1562
|
|
|
1512
|
-
console.print(
|
|
1563
|
+
console.print(
|
|
1564
|
+
f"\n[bold cyan]{stat['field']}[/bold cyan] 值分布 (Top {len(top_values)}):"
|
|
1565
|
+
)
|
|
1513
1566
|
max_count = max(c for _, c in top_values) if top_values else 1
|
|
1514
1567
|
for value, count in top_values:
|
|
1515
1568
|
pct = count / total * 100
|
|
@@ -1559,25 +1612,26 @@ def clean(
|
|
|
1559
1612
|
|
|
1560
1613
|
Args:
|
|
1561
1614
|
filename: 输入文件路径,支持 csv/excel/jsonl/json/parquet/arrow/feather 格式
|
|
1562
|
-
drop_empty:
|
|
1615
|
+
drop_empty: 删除空值记录,支持嵌套路径语法
|
|
1563
1616
|
- 不带值:删除任意字段为空的记录
|
|
1564
1617
|
- 指定字段:删除指定字段为空的记录(逗号分隔)
|
|
1565
|
-
min_len: 最小长度过滤,格式 "字段:长度"
|
|
1566
|
-
max_len: 最大长度过滤,格式 "字段:长度"
|
|
1567
|
-
keep:
|
|
1568
|
-
drop:
|
|
1618
|
+
min_len: 最小长度过滤,格式 "字段:长度",字段支持嵌套路径
|
|
1619
|
+
max_len: 最大长度过滤,格式 "字段:长度",字段支持嵌套路径
|
|
1620
|
+
keep: 只保留指定字段(逗号分隔,仅支持顶层字段)
|
|
1621
|
+
drop: 删除指定字段(逗号分隔,仅支持顶层字段)
|
|
1569
1622
|
strip: 去除所有字符串字段的首尾空白
|
|
1570
1623
|
output: 输出文件路径,不指定则覆盖原文件
|
|
1571
1624
|
|
|
1572
1625
|
Examples:
|
|
1573
1626
|
dt clean data.jsonl --drop-empty # 删除任意空值记录
|
|
1574
1627
|
dt clean data.jsonl --drop-empty=text,answer # 删除指定字段为空的记录
|
|
1628
|
+
dt clean data.jsonl --drop-empty=meta.source # 删除嵌套字段为空的记录
|
|
1575
1629
|
dt clean data.jsonl --min-len=text:10 # text 字段最少 10 字符
|
|
1576
|
-
dt clean data.jsonl --
|
|
1630
|
+
dt clean data.jsonl --min-len=messages.#:2 # 至少 2 条消息
|
|
1631
|
+
dt clean data.jsonl --max-len=messages[-1].content:500 # 最后一条消息最多 500 字符
|
|
1577
1632
|
dt clean data.jsonl --keep=question,answer # 只保留这些字段
|
|
1578
1633
|
dt clean data.jsonl --drop=metadata,timestamp # 删除这些字段
|
|
1579
1634
|
dt clean data.jsonl --strip # 去除字符串首尾空白
|
|
1580
|
-
dt clean data.jsonl --drop-empty --strip -o out.jsonl
|
|
1581
1635
|
"""
|
|
1582
1636
|
filepath = Path(filename)
|
|
1583
1637
|
|
|
@@ -1666,6 +1720,7 @@ def clean(
|
|
|
1666
1720
|
os.unlink(temp_path)
|
|
1667
1721
|
print(f"错误: 清洗失败 - {e}")
|
|
1668
1722
|
import traceback
|
|
1723
|
+
|
|
1669
1724
|
traceback.print_exc()
|
|
1670
1725
|
return
|
|
1671
1726
|
|
|
@@ -1746,9 +1801,18 @@ def _is_empty_value(v: Any) -> bool:
|
|
|
1746
1801
|
|
|
1747
1802
|
|
|
1748
1803
|
def _get_value_len(value: Any) -> int:
|
|
1749
|
-
"""
|
|
1804
|
+
"""
|
|
1805
|
+
获取值的长度。
|
|
1806
|
+
|
|
1807
|
+
- str/list/dict: 返回 len()
|
|
1808
|
+
- int/float: 直接返回该数值(用于 messages.# 这种返回数量的场景)
|
|
1809
|
+
- None: 返回 0
|
|
1810
|
+
- 其他: 转为字符串后返回长度
|
|
1811
|
+
"""
|
|
1750
1812
|
if value is None:
|
|
1751
1813
|
return 0
|
|
1814
|
+
if isinstance(value, (int, float)):
|
|
1815
|
+
return int(value)
|
|
1752
1816
|
if isinstance(value, (str, list, dict)):
|
|
1753
1817
|
return len(value)
|
|
1754
1818
|
return len(str(value))
|
|
@@ -1771,13 +1835,13 @@ def _clean_data_single_pass(
|
|
|
1771
1835
|
Args:
|
|
1772
1836
|
data: 原始数据列表
|
|
1773
1837
|
strip: 是否去除字符串首尾空白
|
|
1774
|
-
empty_fields:
|
|
1775
|
-
min_len_field:
|
|
1838
|
+
empty_fields: 检查空值的字段列表(支持嵌套路径),空列表表示检查所有字段,None 表示不检查
|
|
1839
|
+
min_len_field: 最小长度检查的字段(支持嵌套路径)
|
|
1776
1840
|
min_len_value: 最小长度值
|
|
1777
|
-
max_len_field:
|
|
1841
|
+
max_len_field: 最大长度检查的字段(支持嵌套路径)
|
|
1778
1842
|
max_len_value: 最大长度值
|
|
1779
|
-
keep_fields:
|
|
1780
|
-
drop_fields:
|
|
1843
|
+
keep_fields: 只保留的字段列表(仅支持顶层字段)
|
|
1844
|
+
drop_fields: 要删除的字段集合(仅支持顶层字段)
|
|
1781
1845
|
|
|
1782
1846
|
Returns:
|
|
1783
1847
|
(清洗后的数据, 统计信息列表)
|
|
@@ -1805,20 +1869,20 @@ def _clean_data_single_pass(
|
|
|
1805
1869
|
stats["drop_empty"] += 1
|
|
1806
1870
|
continue
|
|
1807
1871
|
else:
|
|
1808
|
-
#
|
|
1809
|
-
if any(_is_empty_value(item
|
|
1872
|
+
# 检查指定字段(支持嵌套路径)
|
|
1873
|
+
if any(_is_empty_value(get_field_with_spec(item, f)) for f in empty_fields):
|
|
1810
1874
|
stats["drop_empty"] += 1
|
|
1811
1875
|
continue
|
|
1812
1876
|
|
|
1813
|
-
# 3.
|
|
1877
|
+
# 3. 最小长度过滤(支持嵌套路径)
|
|
1814
1878
|
if min_len_field is not None:
|
|
1815
|
-
if _get_value_len(item
|
|
1879
|
+
if _get_value_len(get_field_with_spec(item, min_len_field, default="")) < min_len_value:
|
|
1816
1880
|
stats["min_len"] += 1
|
|
1817
1881
|
continue
|
|
1818
1882
|
|
|
1819
|
-
# 4.
|
|
1883
|
+
# 4. 最大长度过滤(支持嵌套路径)
|
|
1820
1884
|
if max_len_field is not None:
|
|
1821
|
-
if _get_value_len(item
|
|
1885
|
+
if _get_value_len(get_field_with_spec(item, max_len_field, default="")) > max_len_value:
|
|
1822
1886
|
stats["max_len"] += 1
|
|
1823
1887
|
continue
|
|
1824
1888
|
|
|
@@ -1866,25 +1930,27 @@ def _clean_streaming(
|
|
|
1866
1930
|
Returns:
|
|
1867
1931
|
处理后的数据条数
|
|
1868
1932
|
"""
|
|
1933
|
+
|
|
1869
1934
|
def clean_filter(item: Dict) -> bool:
|
|
1870
|
-
"""过滤函数:返回 True 保留,False
|
|
1935
|
+
"""过滤函数:返回 True 保留,False 过滤(支持嵌套路径)"""
|
|
1871
1936
|
# 空值过滤
|
|
1872
1937
|
if empty_fields is not None:
|
|
1873
1938
|
if len(empty_fields) == 0:
|
|
1874
1939
|
if any(_is_empty_value(v) for v in item.values()):
|
|
1875
1940
|
return False
|
|
1876
1941
|
else:
|
|
1877
|
-
|
|
1942
|
+
# 支持嵌套路径
|
|
1943
|
+
if any(_is_empty_value(get_field_with_spec(item, f)) for f in empty_fields):
|
|
1878
1944
|
return False
|
|
1879
1945
|
|
|
1880
|
-
#
|
|
1946
|
+
# 最小长度过滤(支持嵌套路径)
|
|
1881
1947
|
if min_len_field is not None:
|
|
1882
|
-
if _get_value_len(item
|
|
1948
|
+
if _get_value_len(get_field_with_spec(item, min_len_field, default="")) < min_len_value:
|
|
1883
1949
|
return False
|
|
1884
1950
|
|
|
1885
|
-
#
|
|
1951
|
+
# 最大长度过滤(支持嵌套路径)
|
|
1886
1952
|
if max_len_field is not None:
|
|
1887
|
-
if _get_value_len(item
|
|
1953
|
+
if _get_value_len(get_field_with_spec(item, max_len_field, default="")) > max_len_value:
|
|
1888
1954
|
return False
|
|
1889
1955
|
|
|
1890
1956
|
return True
|
|
@@ -1908,7 +1974,9 @@ def _clean_streaming(
|
|
|
1908
1974
|
|
|
1909
1975
|
# 如果需要 strip,先执行 strip 转换(在过滤之前,这样空值检测更准确)
|
|
1910
1976
|
if strip:
|
|
1911
|
-
st = st.transform(
|
|
1977
|
+
st = st.transform(
|
|
1978
|
+
lambda x: {k: v.strip() if isinstance(v, str) else v for k, v in x.items()}
|
|
1979
|
+
)
|
|
1912
1980
|
|
|
1913
1981
|
# 执行过滤
|
|
1914
1982
|
if empty_fields is not None or min_len_field is not None or max_len_field is not None:
|
|
@@ -1916,12 +1984,14 @@ def _clean_streaming(
|
|
|
1916
1984
|
|
|
1917
1985
|
# 执行字段管理(如果没有 strip,也需要在这里处理)
|
|
1918
1986
|
if keep_set is not None or drop_fields_set is not None:
|
|
1987
|
+
|
|
1919
1988
|
def field_transform(item):
|
|
1920
1989
|
if keep_set is not None:
|
|
1921
1990
|
return {k: v for k, v in item.items() if k in keep_set}
|
|
1922
1991
|
elif drop_fields_set is not None:
|
|
1923
1992
|
return {k: v for k, v in item.items() if k not in drop_fields_set}
|
|
1924
1993
|
return item
|
|
1994
|
+
|
|
1925
1995
|
st = st.transform(field_transform)
|
|
1926
1996
|
|
|
1927
1997
|
return st.save(output_path)
|
|
@@ -1972,6 +2042,7 @@ def run(
|
|
|
1972
2042
|
except Exception as e:
|
|
1973
2043
|
print(f"错误: {e}")
|
|
1974
2044
|
import traceback
|
|
2045
|
+
|
|
1975
2046
|
traceback.print_exc()
|
|
1976
2047
|
|
|
1977
2048
|
|
|
@@ -1989,13 +2060,15 @@ def token_stats(
|
|
|
1989
2060
|
|
|
1990
2061
|
Args:
|
|
1991
2062
|
filename: 输入文件路径
|
|
1992
|
-
field: 要统计的字段(默认 messages
|
|
2063
|
+
field: 要统计的字段(默认 messages),支持嵌套路径语法
|
|
1993
2064
|
model: 分词器: cl100k_base (默认), qwen2.5, llama3, gpt-4 等
|
|
1994
2065
|
detailed: 是否显示详细统计
|
|
1995
2066
|
|
|
1996
2067
|
Examples:
|
|
1997
2068
|
dt token-stats data.jsonl
|
|
1998
2069
|
dt token-stats data.jsonl --field=text --model=qwen2.5
|
|
2070
|
+
dt token-stats data.jsonl --field=conversation.messages
|
|
2071
|
+
dt token-stats data.jsonl --field=messages[-1].content # 统计最后一条消息
|
|
1999
2072
|
dt token-stats data.jsonl --detailed
|
|
2000
2073
|
"""
|
|
2001
2074
|
filepath = Path(filename)
|
|
@@ -2023,19 +2096,21 @@ def token_stats(
|
|
|
2023
2096
|
print(f" 共 {total} 条数据")
|
|
2024
2097
|
print(f"🔢 统计 Token (模型: {model}, 字段: {field})...")
|
|
2025
2098
|
|
|
2026
|
-
#
|
|
2099
|
+
# 检查字段类型并选择合适的统计方法(支持嵌套路径)
|
|
2027
2100
|
sample = data[0]
|
|
2028
|
-
field_value = sample
|
|
2101
|
+
field_value = get_field_with_spec(sample, field)
|
|
2029
2102
|
|
|
2030
2103
|
try:
|
|
2031
2104
|
if isinstance(field_value, list) and field_value and isinstance(field_value[0], dict):
|
|
2032
2105
|
# messages 格式
|
|
2033
2106
|
from ..tokenizers import messages_token_stats
|
|
2107
|
+
|
|
2034
2108
|
stats = messages_token_stats(data, messages_field=field, model=model)
|
|
2035
2109
|
_print_messages_token_stats(stats, detailed)
|
|
2036
2110
|
else:
|
|
2037
2111
|
# 普通文本字段
|
|
2038
2112
|
from ..tokenizers import token_stats as compute_token_stats
|
|
2113
|
+
|
|
2039
2114
|
stats = compute_token_stats(data, fields=field, model=model)
|
|
2040
2115
|
_print_text_token_stats(stats, detailed)
|
|
2041
2116
|
except ImportError as e:
|
|
@@ -2044,6 +2119,7 @@ def token_stats(
|
|
|
2044
2119
|
except Exception as e:
|
|
2045
2120
|
print(f"错误: 统计失败 - {e}")
|
|
2046
2121
|
import traceback
|
|
2122
|
+
|
|
2047
2123
|
traceback.print_exc()
|
|
2048
2124
|
|
|
2049
2125
|
|
|
@@ -2051,8 +2127,8 @@ def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
|
|
|
2051
2127
|
"""打印 messages 格式的 token 统计"""
|
|
2052
2128
|
try:
|
|
2053
2129
|
from rich.console import Console
|
|
2054
|
-
from rich.table import Table
|
|
2055
2130
|
from rich.panel import Panel
|
|
2131
|
+
from rich.table import Table
|
|
2056
2132
|
|
|
2057
2133
|
console = Console()
|
|
2058
2134
|
|
|
@@ -2073,8 +2149,12 @@ def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
|
|
|
2073
2149
|
table.add_column("Token 数", justify="right")
|
|
2074
2150
|
table.add_column("占比", justify="right")
|
|
2075
2151
|
|
|
2076
|
-
total = stats[
|
|
2077
|
-
for role, key in [
|
|
2152
|
+
total = stats["total_tokens"]
|
|
2153
|
+
for role, key in [
|
|
2154
|
+
("User", "user_tokens"),
|
|
2155
|
+
("Assistant", "assistant_tokens"),
|
|
2156
|
+
("System", "system_tokens"),
|
|
2157
|
+
]:
|
|
2078
2158
|
tokens = stats.get(key, 0)
|
|
2079
2159
|
pct = tokens / total * 100 if total > 0 else 0
|
|
2080
2160
|
table.add_row(role, f"{tokens:,}", f"{pct:.1f}%")
|
|
@@ -2097,8 +2177,12 @@ def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
|
|
|
2097
2177
|
print(f"\n{'=' * 40}")
|
|
2098
2178
|
print("📋 分角色统计")
|
|
2099
2179
|
print(f"{'=' * 40}")
|
|
2100
|
-
total = stats[
|
|
2101
|
-
for role, key in [
|
|
2180
|
+
total = stats["total_tokens"]
|
|
2181
|
+
for role, key in [
|
|
2182
|
+
("User", "user_tokens"),
|
|
2183
|
+
("Assistant", "assistant_tokens"),
|
|
2184
|
+
("System", "system_tokens"),
|
|
2185
|
+
]:
|
|
2102
2186
|
tokens = stats.get(key, 0)
|
|
2103
2187
|
pct = tokens / total * 100 if total > 0 else 0
|
|
2104
2188
|
print(f"{role}: {tokens:,} ({pct:.1f}%)")
|
|
@@ -2148,12 +2232,13 @@ def diff(
|
|
|
2148
2232
|
Args:
|
|
2149
2233
|
file1: 第一个文件路径
|
|
2150
2234
|
file2: 第二个文件路径
|
|
2151
|
-
key:
|
|
2235
|
+
key: 用于匹配的键字段,支持嵌套路径语法(可选)
|
|
2152
2236
|
output: 差异报告输出路径(可选)
|
|
2153
2237
|
|
|
2154
2238
|
Examples:
|
|
2155
2239
|
dt diff v1/train.jsonl v2/train.jsonl
|
|
2156
2240
|
dt diff a.jsonl b.jsonl --key=id
|
|
2241
|
+
dt diff a.jsonl b.jsonl --key=meta.uuid # 按嵌套字段匹配
|
|
2157
2242
|
dt diff a.jsonl b.jsonl --output=diff_report.json
|
|
2158
2243
|
"""
|
|
2159
2244
|
path1 = Path(file1)
|
|
@@ -2216,9 +2301,9 @@ def _compute_diff(
|
|
|
2216
2301
|
}
|
|
2217
2302
|
|
|
2218
2303
|
if key:
|
|
2219
|
-
# 基于 key
|
|
2220
|
-
dict1 = {item
|
|
2221
|
-
dict2 = {item
|
|
2304
|
+
# 基于 key 的精确匹配(支持嵌套路径)
|
|
2305
|
+
dict1 = {get_field_with_spec(item, key): item for item in data1 if get_field_with_spec(item, key) is not None}
|
|
2306
|
+
dict2 = {get_field_with_spec(item, key): item for item in data2 if get_field_with_spec(item, key) is not None}
|
|
2222
2307
|
|
|
2223
2308
|
keys1 = set(dict1.keys())
|
|
2224
2309
|
keys2 = set(dict2.keys())
|
|
@@ -2241,11 +2326,13 @@ def _compute_diff(
|
|
|
2241
2326
|
else:
|
|
2242
2327
|
result["summary"]["modified"] += 1
|
|
2243
2328
|
if len(result["details"]["modified"]) < 10:
|
|
2244
|
-
result["details"]["modified"].append(
|
|
2245
|
-
|
|
2246
|
-
|
|
2247
|
-
|
|
2248
|
-
|
|
2329
|
+
result["details"]["modified"].append(
|
|
2330
|
+
{
|
|
2331
|
+
"key": k,
|
|
2332
|
+
"before": dict1[k],
|
|
2333
|
+
"after": dict2[k],
|
|
2334
|
+
}
|
|
2335
|
+
)
|
|
2249
2336
|
else:
|
|
2250
2337
|
# 基于哈希的比较
|
|
2251
2338
|
def _hash_item(item):
|
|
@@ -2290,8 +2377,8 @@ def _print_diff_report(diff_result: Dict[str, Any], name1: str, name2: str) -> N
|
|
|
2290
2377
|
|
|
2291
2378
|
try:
|
|
2292
2379
|
from rich.console import Console
|
|
2293
|
-
from rich.table import Table
|
|
2294
2380
|
from rich.panel import Panel
|
|
2381
|
+
from rich.table import Table
|
|
2295
2382
|
|
|
2296
2383
|
console = Console()
|
|
2297
2384
|
|
|
@@ -2311,9 +2398,13 @@ def _print_diff_report(diff_result: Dict[str, Any], name1: str, name2: str) -> N
|
|
|
2311
2398
|
if field_changes["added_fields"] or field_changes["removed_fields"]:
|
|
2312
2399
|
console.print("\n[bold]📋 字段变化:[/bold]")
|
|
2313
2400
|
if field_changes["added_fields"]:
|
|
2314
|
-
console.print(
|
|
2401
|
+
console.print(
|
|
2402
|
+
f" [green]+ 新增字段:[/green] {', '.join(field_changes['added_fields'])}"
|
|
2403
|
+
)
|
|
2315
2404
|
if field_changes["removed_fields"]:
|
|
2316
|
-
console.print(
|
|
2405
|
+
console.print(
|
|
2406
|
+
f" [red]- 删除字段:[/red] {', '.join(field_changes['removed_fields'])}"
|
|
2407
|
+
)
|
|
2317
2408
|
|
|
2318
2409
|
except ImportError:
|
|
2319
2410
|
print(f"\n{'=' * 50}")
|