dtflow 0.4.2__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dtflow/__init__.py +1 -1
- dtflow/__main__.py +6 -3
- dtflow/cli/clean.py +486 -0
- dtflow/cli/commands.py +53 -2637
- dtflow/cli/common.py +384 -0
- dtflow/cli/io_ops.py +385 -0
- dtflow/cli/lineage.py +49 -0
- dtflow/cli/pipeline.py +54 -0
- dtflow/cli/sample.py +294 -0
- dtflow/cli/stats.py +589 -0
- dtflow/cli/transform.py +486 -0
- dtflow/core.py +35 -0
- dtflow/storage/io.py +49 -6
- dtflow/streaming.py +25 -4
- {dtflow-0.4.2.dist-info → dtflow-0.4.3.dist-info}/METADATA +12 -1
- dtflow-0.4.3.dist-info/RECORD +33 -0
- dtflow-0.4.2.dist-info/RECORD +0 -25
- {dtflow-0.4.2.dist-info → dtflow-0.4.3.dist-info}/WHEEL +0 -0
- {dtflow-0.4.2.dist-info → dtflow-0.4.3.dist-info}/entry_points.txt +0 -0
dtflow/cli/stats.py
ADDED
|
@@ -0,0 +1,589 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CLI 数据统计相关命令
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Dict, List
|
|
7
|
+
|
|
8
|
+
import orjson
|
|
9
|
+
|
|
10
|
+
from ..storage.io import load_data
|
|
11
|
+
from ..utils.field_path import get_field_with_spec
|
|
12
|
+
from .common import (
|
|
13
|
+
_check_file_format,
|
|
14
|
+
_infer_type,
|
|
15
|
+
_is_numeric,
|
|
16
|
+
_pad_to_width,
|
|
17
|
+
_truncate,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def stats(
|
|
22
|
+
filename: str,
|
|
23
|
+
top: int = 10,
|
|
24
|
+
full: bool = False,
|
|
25
|
+
) -> None:
|
|
26
|
+
"""
|
|
27
|
+
显示数据文件的统计信息。
|
|
28
|
+
|
|
29
|
+
默认快速模式:只统计行数和字段结构。
|
|
30
|
+
完整模式(--full):统计值分布、唯一值、长度等详细信息。
|
|
31
|
+
|
|
32
|
+
Args:
|
|
33
|
+
filename: 输入文件路径,支持 csv/excel/jsonl/json/parquet/arrow/feather 格式
|
|
34
|
+
top: 显示频率最高的前 N 个值,默认 10(仅完整模式)
|
|
35
|
+
full: 完整模式,统计值分布、唯一值等详细信息
|
|
36
|
+
|
|
37
|
+
Examples:
|
|
38
|
+
dt stats data.jsonl # 快速模式(默认)
|
|
39
|
+
dt stats data.jsonl --full # 完整模式
|
|
40
|
+
dt stats data.csv -f --top=5 # 完整模式,显示 Top 5
|
|
41
|
+
"""
|
|
42
|
+
filepath = Path(filename)
|
|
43
|
+
|
|
44
|
+
if not filepath.exists():
|
|
45
|
+
print(f"错误: 文件不存在 - {filename}")
|
|
46
|
+
return
|
|
47
|
+
|
|
48
|
+
if not _check_file_format(filepath):
|
|
49
|
+
return
|
|
50
|
+
|
|
51
|
+
if not full:
|
|
52
|
+
_quick_stats(filepath)
|
|
53
|
+
return
|
|
54
|
+
|
|
55
|
+
# 加载数据
|
|
56
|
+
try:
|
|
57
|
+
data = load_data(str(filepath))
|
|
58
|
+
except Exception as e:
|
|
59
|
+
print(f"错误: 无法读取文件 - {e}")
|
|
60
|
+
return
|
|
61
|
+
|
|
62
|
+
if not data:
|
|
63
|
+
print("文件为空")
|
|
64
|
+
return
|
|
65
|
+
|
|
66
|
+
# 计算统计信息
|
|
67
|
+
total = len(data)
|
|
68
|
+
field_stats = _compute_field_stats(data, top)
|
|
69
|
+
|
|
70
|
+
# 输出统计信息
|
|
71
|
+
_print_stats(filepath.name, total, field_stats)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _quick_stats(filepath: Path) -> None:
|
|
75
|
+
"""
|
|
76
|
+
快速统计模式:只统计行数和字段结构,不遍历全部数据。
|
|
77
|
+
|
|
78
|
+
特点:
|
|
79
|
+
- 使用流式计数,不加载全部数据到内存
|
|
80
|
+
- 只读取前几条数据来推断字段结构
|
|
81
|
+
- 不计算值分布、唯一值等耗时统计
|
|
82
|
+
"""
|
|
83
|
+
from ..streaming import _count_rows_fast
|
|
84
|
+
|
|
85
|
+
ext = filepath.suffix.lower()
|
|
86
|
+
file_size = filepath.stat().st_size
|
|
87
|
+
|
|
88
|
+
# 格式化文件大小
|
|
89
|
+
def format_size(size: int) -> str:
|
|
90
|
+
for unit in ["B", "KB", "MB", "GB"]:
|
|
91
|
+
if size < 1024:
|
|
92
|
+
return f"{size:.1f} {unit}"
|
|
93
|
+
size /= 1024
|
|
94
|
+
return f"{size:.1f} TB"
|
|
95
|
+
|
|
96
|
+
# 快速统计行数
|
|
97
|
+
total = _count_rows_fast(str(filepath))
|
|
98
|
+
if total is None:
|
|
99
|
+
# 回退:手动计数
|
|
100
|
+
total = 0
|
|
101
|
+
try:
|
|
102
|
+
with open(filepath, "rb") as f:
|
|
103
|
+
for line in f:
|
|
104
|
+
if line.strip():
|
|
105
|
+
total += 1
|
|
106
|
+
except Exception:
|
|
107
|
+
total = -1
|
|
108
|
+
|
|
109
|
+
# 读取前几条数据推断字段结构
|
|
110
|
+
sample_data = []
|
|
111
|
+
sample_size = 5
|
|
112
|
+
try:
|
|
113
|
+
if ext == ".jsonl":
|
|
114
|
+
with open(filepath, "rb") as f:
|
|
115
|
+
for i, line in enumerate(f):
|
|
116
|
+
if i >= sample_size:
|
|
117
|
+
break
|
|
118
|
+
line = line.strip()
|
|
119
|
+
if line:
|
|
120
|
+
sample_data.append(orjson.loads(line))
|
|
121
|
+
elif ext == ".csv":
|
|
122
|
+
import polars as pl
|
|
123
|
+
|
|
124
|
+
df = pl.scan_csv(str(filepath)).head(sample_size).collect()
|
|
125
|
+
sample_data = df.to_dicts()
|
|
126
|
+
elif ext == ".parquet":
|
|
127
|
+
import polars as pl
|
|
128
|
+
|
|
129
|
+
df = pl.scan_parquet(str(filepath)).head(sample_size).collect()
|
|
130
|
+
sample_data = df.to_dicts()
|
|
131
|
+
elif ext in (".arrow", ".feather"):
|
|
132
|
+
import polars as pl
|
|
133
|
+
|
|
134
|
+
df = pl.scan_ipc(str(filepath)).head(sample_size).collect()
|
|
135
|
+
sample_data = df.to_dicts()
|
|
136
|
+
elif ext == ".json":
|
|
137
|
+
with open(filepath, "rb") as f:
|
|
138
|
+
data = orjson.loads(f.read())
|
|
139
|
+
if isinstance(data, list):
|
|
140
|
+
sample_data = data[:sample_size]
|
|
141
|
+
except Exception:
|
|
142
|
+
pass
|
|
143
|
+
|
|
144
|
+
# 分析字段结构
|
|
145
|
+
fields = []
|
|
146
|
+
if sample_data:
|
|
147
|
+
all_keys = set()
|
|
148
|
+
for item in sample_data:
|
|
149
|
+
all_keys.update(item.keys())
|
|
150
|
+
|
|
151
|
+
for key in sorted(all_keys):
|
|
152
|
+
# 从采样数据中推断类型
|
|
153
|
+
sample_values = [item.get(key) for item in sample_data if key in item]
|
|
154
|
+
non_null = [v for v in sample_values if v is not None]
|
|
155
|
+
if non_null:
|
|
156
|
+
field_type = _infer_type(non_null)
|
|
157
|
+
else:
|
|
158
|
+
field_type = "unknown"
|
|
159
|
+
fields.append({"field": key, "type": field_type})
|
|
160
|
+
|
|
161
|
+
# 输出
|
|
162
|
+
try:
|
|
163
|
+
from rich.console import Console
|
|
164
|
+
from rich.panel import Panel
|
|
165
|
+
from rich.table import Table
|
|
166
|
+
|
|
167
|
+
console = Console()
|
|
168
|
+
|
|
169
|
+
# 概览
|
|
170
|
+
console.print(
|
|
171
|
+
Panel(
|
|
172
|
+
f"[bold]文件:[/bold] {filepath.name}\n"
|
|
173
|
+
f"[bold]大小:[/bold] {format_size(file_size)}\n"
|
|
174
|
+
f"[bold]总数:[/bold] {total:,} 条\n"
|
|
175
|
+
f"[bold]字段:[/bold] {len(fields)} 个",
|
|
176
|
+
title="📊 快速统计",
|
|
177
|
+
expand=False,
|
|
178
|
+
)
|
|
179
|
+
)
|
|
180
|
+
|
|
181
|
+
if fields:
|
|
182
|
+
table = Table(title="📋 字段结构", show_header=True, header_style="bold cyan")
|
|
183
|
+
table.add_column("#", style="dim", justify="right")
|
|
184
|
+
table.add_column("字段", style="green")
|
|
185
|
+
table.add_column("类型", style="yellow")
|
|
186
|
+
|
|
187
|
+
for i, f in enumerate(fields, 1):
|
|
188
|
+
table.add_row(str(i), f["field"], f["type"])
|
|
189
|
+
|
|
190
|
+
console.print(table)
|
|
191
|
+
|
|
192
|
+
except ImportError:
|
|
193
|
+
# 没有 rich,使用普通打印
|
|
194
|
+
print(f"\n{'=' * 40}")
|
|
195
|
+
print("📊 快速统计")
|
|
196
|
+
print(f"{'=' * 40}")
|
|
197
|
+
print(f"文件: {filepath.name}")
|
|
198
|
+
print(f"大小: {format_size(file_size)}")
|
|
199
|
+
print(f"总数: {total:,} 条")
|
|
200
|
+
print(f"字段: {len(fields)} 个")
|
|
201
|
+
|
|
202
|
+
if fields:
|
|
203
|
+
print(f"\n📋 字段结构:")
|
|
204
|
+
for i, f in enumerate(fields, 1):
|
|
205
|
+
print(f" {i}. {f['field']} ({f['type']})")
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def _compute_field_stats(data: List[Dict], top: int) -> List[Dict[str, Any]]:
|
|
209
|
+
"""
|
|
210
|
+
单次遍历计算每个字段的统计信息。
|
|
211
|
+
|
|
212
|
+
优化:将多次遍历合并为单次遍历,在遍历过程中同时收集所有统计数据。
|
|
213
|
+
"""
|
|
214
|
+
from collections import Counter, defaultdict
|
|
215
|
+
|
|
216
|
+
if not data:
|
|
217
|
+
return []
|
|
218
|
+
|
|
219
|
+
total = len(data)
|
|
220
|
+
|
|
221
|
+
# 单次遍历收集所有字段的值和统计信息
|
|
222
|
+
field_values = defaultdict(list) # 存储每个字段的所有值
|
|
223
|
+
field_counters = defaultdict(Counter) # 存储每个字段的值频率(用于 top N)
|
|
224
|
+
|
|
225
|
+
for item in data:
|
|
226
|
+
for k, v in item.items():
|
|
227
|
+
field_values[k].append(v)
|
|
228
|
+
# 对值进行截断后计数(用于 top N 显示)
|
|
229
|
+
displayable = _truncate(v if v is not None else "", 30)
|
|
230
|
+
field_counters[k][displayable] += 1
|
|
231
|
+
|
|
232
|
+
# 根据收集的数据计算统计信息
|
|
233
|
+
stats_list = []
|
|
234
|
+
for field in sorted(field_values.keys()):
|
|
235
|
+
values = field_values[field]
|
|
236
|
+
non_null = [v for v in values if v is not None and v != ""]
|
|
237
|
+
non_null_count = len(non_null)
|
|
238
|
+
|
|
239
|
+
# 推断类型(从第一个非空值)
|
|
240
|
+
field_type = _infer_type(non_null)
|
|
241
|
+
|
|
242
|
+
# 基础统计
|
|
243
|
+
stat = {
|
|
244
|
+
"field": field,
|
|
245
|
+
"non_null": non_null_count,
|
|
246
|
+
"null_rate": f"{(total - non_null_count) / total * 100:.1f}%",
|
|
247
|
+
"type": field_type,
|
|
248
|
+
}
|
|
249
|
+
|
|
250
|
+
# 类型特定统计
|
|
251
|
+
if non_null:
|
|
252
|
+
# 唯一值计数(对复杂类型使用 hash 节省内存)
|
|
253
|
+
stat["unique"] = _count_unique(non_null, field_type)
|
|
254
|
+
|
|
255
|
+
# 字符串类型:计算长度统计
|
|
256
|
+
if field_type == "str":
|
|
257
|
+
lengths = [len(str(v)) for v in non_null]
|
|
258
|
+
stat["len_min"] = min(lengths)
|
|
259
|
+
stat["len_max"] = max(lengths)
|
|
260
|
+
stat["len_avg"] = sum(lengths) / len(lengths)
|
|
261
|
+
|
|
262
|
+
# 数值类型:计算数值统计
|
|
263
|
+
elif field_type in ("int", "float"):
|
|
264
|
+
nums = [float(v) for v in non_null if _is_numeric(v)]
|
|
265
|
+
if nums:
|
|
266
|
+
stat["min"] = min(nums)
|
|
267
|
+
stat["max"] = max(nums)
|
|
268
|
+
stat["avg"] = sum(nums) / len(nums)
|
|
269
|
+
|
|
270
|
+
# 列表类型:计算长度统计
|
|
271
|
+
elif field_type == "list":
|
|
272
|
+
lengths = [len(v) if isinstance(v, list) else 0 for v in non_null]
|
|
273
|
+
stat["len_min"] = min(lengths)
|
|
274
|
+
stat["len_max"] = max(lengths)
|
|
275
|
+
stat["len_avg"] = sum(lengths) / len(lengths)
|
|
276
|
+
|
|
277
|
+
# Top N 值(已在遍历时收集)
|
|
278
|
+
stat["top_values"] = field_counters[field].most_common(top)
|
|
279
|
+
|
|
280
|
+
stats_list.append(stat)
|
|
281
|
+
|
|
282
|
+
return stats_list
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def _count_unique(values: List[Any], field_type: str) -> int:
|
|
286
|
+
"""
|
|
287
|
+
计算唯一值数量。
|
|
288
|
+
|
|
289
|
+
对于简单类型直接比较,对于 list/dict 或混合类型使用 hash。
|
|
290
|
+
"""
|
|
291
|
+
if field_type in ("list", "dict"):
|
|
292
|
+
return _count_unique_by_hash(values)
|
|
293
|
+
else:
|
|
294
|
+
# 简单类型:尝试直接比较,失败则回退到 hash 方式
|
|
295
|
+
try:
|
|
296
|
+
return len(set(values))
|
|
297
|
+
except TypeError:
|
|
298
|
+
# 混合类型(如字段中既有 str 又有 dict),回退到 hash
|
|
299
|
+
return _count_unique_by_hash(values)
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def _count_unique_by_hash(values: List[Any]) -> int:
|
|
303
|
+
"""使用 orjson 序列化后计算 hash 来统计唯一值"""
|
|
304
|
+
import hashlib
|
|
305
|
+
|
|
306
|
+
seen = set()
|
|
307
|
+
for v in values:
|
|
308
|
+
try:
|
|
309
|
+
h = hashlib.md5(orjson.dumps(v, option=orjson.OPT_SORT_KEYS)).digest()
|
|
310
|
+
seen.add(h)
|
|
311
|
+
except TypeError:
|
|
312
|
+
# 无法序列化的值,用 repr 兜底
|
|
313
|
+
seen.add(repr(v))
|
|
314
|
+
return len(seen)
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def _print_stats(filename: str, total: int, field_stats: List[Dict[str, Any]]) -> None:
|
|
318
|
+
"""打印统计信息"""
|
|
319
|
+
try:
|
|
320
|
+
from rich.console import Console
|
|
321
|
+
from rich.panel import Panel
|
|
322
|
+
from rich.table import Table
|
|
323
|
+
|
|
324
|
+
console = Console()
|
|
325
|
+
|
|
326
|
+
# 概览
|
|
327
|
+
console.print(
|
|
328
|
+
Panel(
|
|
329
|
+
f"[bold]文件:[/bold] {filename}\n"
|
|
330
|
+
f"[bold]总数:[/bold] {total:,} 条\n"
|
|
331
|
+
f"[bold]字段:[/bold] {len(field_stats)} 个",
|
|
332
|
+
title="📊 数据概览",
|
|
333
|
+
expand=False,
|
|
334
|
+
)
|
|
335
|
+
)
|
|
336
|
+
|
|
337
|
+
# 字段统计表
|
|
338
|
+
table = Table(title="📋 字段统计", show_header=True, header_style="bold cyan")
|
|
339
|
+
table.add_column("字段", style="green")
|
|
340
|
+
table.add_column("类型", style="yellow")
|
|
341
|
+
table.add_column("非空率", justify="right")
|
|
342
|
+
table.add_column("唯一值", justify="right")
|
|
343
|
+
table.add_column("统计", style="dim")
|
|
344
|
+
|
|
345
|
+
for stat in field_stats:
|
|
346
|
+
non_null_rate = f"{stat['non_null'] / total * 100:.0f}%"
|
|
347
|
+
unique = str(stat.get("unique", "-"))
|
|
348
|
+
|
|
349
|
+
# 构建统计信息字符串
|
|
350
|
+
extra = []
|
|
351
|
+
if "len_avg" in stat:
|
|
352
|
+
extra.append(
|
|
353
|
+
f"长度: {stat['len_min']}-{stat['len_max']} (avg {stat['len_avg']:.0f})"
|
|
354
|
+
)
|
|
355
|
+
if "avg" in stat:
|
|
356
|
+
if stat["type"] == "int":
|
|
357
|
+
extra.append(
|
|
358
|
+
f"范围: {int(stat['min'])}-{int(stat['max'])} (avg {stat['avg']:.1f})"
|
|
359
|
+
)
|
|
360
|
+
else:
|
|
361
|
+
extra.append(
|
|
362
|
+
f"范围: {stat['min']:.2f}-{stat['max']:.2f} (avg {stat['avg']:.2f})"
|
|
363
|
+
)
|
|
364
|
+
|
|
365
|
+
table.add_row(
|
|
366
|
+
stat["field"],
|
|
367
|
+
stat["type"],
|
|
368
|
+
non_null_rate,
|
|
369
|
+
unique,
|
|
370
|
+
"; ".join(extra) if extra else "-",
|
|
371
|
+
)
|
|
372
|
+
|
|
373
|
+
console.print(table)
|
|
374
|
+
|
|
375
|
+
# Top 值统计(仅显示有意义的字段)
|
|
376
|
+
for stat in field_stats:
|
|
377
|
+
top_values = stat.get("top_values", [])
|
|
378
|
+
if not top_values:
|
|
379
|
+
continue
|
|
380
|
+
|
|
381
|
+
# 跳过数值类型(min/max/avg 已足够)
|
|
382
|
+
if stat["type"] in ("int", "float"):
|
|
383
|
+
continue
|
|
384
|
+
|
|
385
|
+
# 跳过唯一值过多的字段(基本都是唯一的)
|
|
386
|
+
unique_ratio = stat.get("unique", 0) / total if total > 0 else 0
|
|
387
|
+
if unique_ratio > 0.9 and stat.get("unique", 0) > 100:
|
|
388
|
+
continue
|
|
389
|
+
|
|
390
|
+
console.print(
|
|
391
|
+
f"\n[bold cyan]{stat['field']}[/bold cyan] 值分布 (Top {len(top_values)}):"
|
|
392
|
+
)
|
|
393
|
+
max_count = max(c for _, c in top_values) if top_values else 1
|
|
394
|
+
for value, count in top_values:
|
|
395
|
+
pct = count / total * 100
|
|
396
|
+
bar_len = int(count / max_count * 20) # 按相对比例,最长 20 字符
|
|
397
|
+
bar = "█" * bar_len
|
|
398
|
+
display_value = value if value else "[空]"
|
|
399
|
+
# 使用显示宽度对齐(处理中文字符)
|
|
400
|
+
padded_value = _pad_to_width(display_value, 32)
|
|
401
|
+
console.print(f" {padded_value} {count:>6} ({pct:>5.1f}%) {bar}")
|
|
402
|
+
|
|
403
|
+
except ImportError:
|
|
404
|
+
# 没有 rich,使用普通打印
|
|
405
|
+
print(f"\n{'=' * 50}")
|
|
406
|
+
print(f"📊 数据概览")
|
|
407
|
+
print(f"{'=' * 50}")
|
|
408
|
+
print(f"文件: {filename}")
|
|
409
|
+
print(f"总数: {total:,} 条")
|
|
410
|
+
print(f"字段: {len(field_stats)} 个")
|
|
411
|
+
|
|
412
|
+
print(f"\n{'=' * 50}")
|
|
413
|
+
print(f"📋 字段统计")
|
|
414
|
+
print(f"{'=' * 50}")
|
|
415
|
+
print(f"{'字段':<20} {'类型':<8} {'非空率':<8} {'唯一值':<8}")
|
|
416
|
+
print("-" * 50)
|
|
417
|
+
|
|
418
|
+
for stat in field_stats:
|
|
419
|
+
non_null_rate = f"{stat['non_null'] / total * 100:.0f}%"
|
|
420
|
+
unique = str(stat.get("unique", "-"))
|
|
421
|
+
print(f"{stat['field']:<20} {stat['type']:<8} {non_null_rate:<8} {unique:<8}")
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
def token_stats(
|
|
425
|
+
filename: str,
|
|
426
|
+
field: str = "messages",
|
|
427
|
+
model: str = "cl100k_base",
|
|
428
|
+
detailed: bool = False,
|
|
429
|
+
) -> None:
|
|
430
|
+
"""
|
|
431
|
+
统计数据集的 Token 信息。
|
|
432
|
+
|
|
433
|
+
Args:
|
|
434
|
+
filename: 输入文件路径
|
|
435
|
+
field: 要统计的字段(默认 messages),支持嵌套路径语法
|
|
436
|
+
model: 分词器: cl100k_base (默认), qwen2.5, llama3, gpt-4 等
|
|
437
|
+
detailed: 是否显示详细统计
|
|
438
|
+
|
|
439
|
+
Examples:
|
|
440
|
+
dt token-stats data.jsonl
|
|
441
|
+
dt token-stats data.jsonl --field=text --model=qwen2.5
|
|
442
|
+
dt token-stats data.jsonl --field=conversation.messages
|
|
443
|
+
dt token-stats data.jsonl --field=messages[-1].content # 统计最后一条消息
|
|
444
|
+
dt token-stats data.jsonl --detailed
|
|
445
|
+
"""
|
|
446
|
+
filepath = Path(filename)
|
|
447
|
+
|
|
448
|
+
if not filepath.exists():
|
|
449
|
+
print(f"错误: 文件不存在 - {filename}")
|
|
450
|
+
return
|
|
451
|
+
|
|
452
|
+
if not _check_file_format(filepath):
|
|
453
|
+
return
|
|
454
|
+
|
|
455
|
+
# 加载数据
|
|
456
|
+
print(f"📊 加载数据: {filepath}")
|
|
457
|
+
try:
|
|
458
|
+
data = load_data(str(filepath))
|
|
459
|
+
except Exception as e:
|
|
460
|
+
print(f"错误: 无法读取文件 - {e}")
|
|
461
|
+
return
|
|
462
|
+
|
|
463
|
+
if not data:
|
|
464
|
+
print("文件为空")
|
|
465
|
+
return
|
|
466
|
+
|
|
467
|
+
total = len(data)
|
|
468
|
+
print(f" 共 {total} 条数据")
|
|
469
|
+
print(f"🔢 统计 Token (模型: {model}, 字段: {field})...")
|
|
470
|
+
|
|
471
|
+
# 检查字段类型并选择合适的统计方法(支持嵌套路径)
|
|
472
|
+
sample = data[0]
|
|
473
|
+
field_value = get_field_with_spec(sample, field)
|
|
474
|
+
|
|
475
|
+
try:
|
|
476
|
+
if isinstance(field_value, list) and field_value and isinstance(field_value[0], dict):
|
|
477
|
+
# messages 格式
|
|
478
|
+
from ..tokenizers import messages_token_stats
|
|
479
|
+
|
|
480
|
+
stats_result = messages_token_stats(data, messages_field=field, model=model)
|
|
481
|
+
_print_messages_token_stats(stats_result, detailed)
|
|
482
|
+
else:
|
|
483
|
+
# 普通文本字段
|
|
484
|
+
from ..tokenizers import token_stats as compute_token_stats
|
|
485
|
+
|
|
486
|
+
stats_result = compute_token_stats(data, fields=field, model=model)
|
|
487
|
+
_print_text_token_stats(stats_result, detailed)
|
|
488
|
+
except ImportError as e:
|
|
489
|
+
print(f"错误: {e}")
|
|
490
|
+
return
|
|
491
|
+
except Exception as e:
|
|
492
|
+
print(f"错误: 统计失败 - {e}")
|
|
493
|
+
import traceback
|
|
494
|
+
|
|
495
|
+
traceback.print_exc()
|
|
496
|
+
|
|
497
|
+
|
|
498
|
+
def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
|
|
499
|
+
"""打印 messages 格式的 token 统计"""
|
|
500
|
+
try:
|
|
501
|
+
from rich.console import Console
|
|
502
|
+
from rich.panel import Panel
|
|
503
|
+
from rich.table import Table
|
|
504
|
+
|
|
505
|
+
console = Console()
|
|
506
|
+
|
|
507
|
+
# 概览
|
|
508
|
+
overview = (
|
|
509
|
+
f"[bold]总样本数:[/bold] {stats['count']:,}\n"
|
|
510
|
+
f"[bold]总 Token:[/bold] {stats['total_tokens']:,}\n"
|
|
511
|
+
f"[bold]平均 Token:[/bold] {stats['avg_tokens']:,}\n"
|
|
512
|
+
f"[bold]中位数:[/bold] {stats['median_tokens']:,}\n"
|
|
513
|
+
f"[bold]范围:[/bold] {stats['min_tokens']:,} - {stats['max_tokens']:,}"
|
|
514
|
+
)
|
|
515
|
+
console.print(Panel(overview, title="📊 Token 统计概览", expand=False))
|
|
516
|
+
|
|
517
|
+
if detailed:
|
|
518
|
+
# 详细统计
|
|
519
|
+
table = Table(title="📋 分角色统计")
|
|
520
|
+
table.add_column("角色", style="cyan")
|
|
521
|
+
table.add_column("Token 数", justify="right")
|
|
522
|
+
table.add_column("占比", justify="right")
|
|
523
|
+
|
|
524
|
+
total = stats["total_tokens"]
|
|
525
|
+
for role, key in [
|
|
526
|
+
("User", "user_tokens"),
|
|
527
|
+
("Assistant", "assistant_tokens"),
|
|
528
|
+
("System", "system_tokens"),
|
|
529
|
+
]:
|
|
530
|
+
tokens = stats.get(key, 0)
|
|
531
|
+
pct = tokens / total * 100 if total > 0 else 0
|
|
532
|
+
table.add_row(role, f"{tokens:,}", f"{pct:.1f}%")
|
|
533
|
+
|
|
534
|
+
console.print(table)
|
|
535
|
+
console.print(f"\n平均对话轮数: {stats.get('avg_turns', 0)}")
|
|
536
|
+
|
|
537
|
+
except ImportError:
|
|
538
|
+
# 没有 rich,使用普通打印
|
|
539
|
+
print(f"\n{'=' * 40}")
|
|
540
|
+
print("📊 Token 统计概览")
|
|
541
|
+
print(f"{'=' * 40}")
|
|
542
|
+
print(f"总样本数: {stats['count']:,}")
|
|
543
|
+
print(f"总 Token: {stats['total_tokens']:,}")
|
|
544
|
+
print(f"平均 Token: {stats['avg_tokens']:,}")
|
|
545
|
+
print(f"中位数: {stats['median_tokens']:,}")
|
|
546
|
+
print(f"范围: {stats['min_tokens']:,} - {stats['max_tokens']:,}")
|
|
547
|
+
|
|
548
|
+
if detailed:
|
|
549
|
+
print(f"\n{'=' * 40}")
|
|
550
|
+
print("📋 分角色统计")
|
|
551
|
+
print(f"{'=' * 40}")
|
|
552
|
+
total = stats["total_tokens"]
|
|
553
|
+
for role, key in [
|
|
554
|
+
("User", "user_tokens"),
|
|
555
|
+
("Assistant", "assistant_tokens"),
|
|
556
|
+
("System", "system_tokens"),
|
|
557
|
+
]:
|
|
558
|
+
tokens = stats.get(key, 0)
|
|
559
|
+
pct = tokens / total * 100 if total > 0 else 0
|
|
560
|
+
print(f"{role}: {tokens:,} ({pct:.1f}%)")
|
|
561
|
+
print(f"\n平均对话轮数: {stats.get('avg_turns', 0)}")
|
|
562
|
+
|
|
563
|
+
|
|
564
|
+
def _print_text_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
|
|
565
|
+
"""打印普通文本的 token 统计"""
|
|
566
|
+
try:
|
|
567
|
+
from rich.console import Console
|
|
568
|
+
from rich.panel import Panel
|
|
569
|
+
|
|
570
|
+
console = Console()
|
|
571
|
+
|
|
572
|
+
overview = (
|
|
573
|
+
f"[bold]总样本数:[/bold] {stats['count']:,}\n"
|
|
574
|
+
f"[bold]总 Token:[/bold] {stats['total_tokens']:,}\n"
|
|
575
|
+
f"[bold]平均 Token:[/bold] {stats['avg_tokens']:.1f}\n"
|
|
576
|
+
f"[bold]中位数:[/bold] {stats['median_tokens']:,}\n"
|
|
577
|
+
f"[bold]范围:[/bold] {stats['min_tokens']:,} - {stats['max_tokens']:,}"
|
|
578
|
+
)
|
|
579
|
+
console.print(Panel(overview, title="📊 Token 统计", expand=False))
|
|
580
|
+
|
|
581
|
+
except ImportError:
|
|
582
|
+
print(f"\n{'=' * 40}")
|
|
583
|
+
print("📊 Token 统计")
|
|
584
|
+
print(f"{'=' * 40}")
|
|
585
|
+
print(f"总样本数: {stats['count']:,}")
|
|
586
|
+
print(f"总 Token: {stats['total_tokens']:,}")
|
|
587
|
+
print(f"平均 Token: {stats['avg_tokens']:.1f}")
|
|
588
|
+
print(f"中位数: {stats['median_tokens']:,}")
|
|
589
|
+
print(f"范围: {stats['min_tokens']:,} - {stats['max_tokens']:,}")
|