dtflow 0.5.7__py3-none-any.whl → 0.5.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dtflow/SKILL.md +39 -5
- dtflow/__init__.py +1 -1
- dtflow/__main__.py +137 -8
- dtflow/cli/clean.py +294 -9
- dtflow/cli/commands.py +17 -1
- dtflow/cli/eval.py +288 -0
- dtflow/cli/export.py +81 -0
- dtflow/cli/sample.py +90 -3
- dtflow/cli/split.py +138 -0
- dtflow/cli/stats.py +224 -30
- dtflow/eval.py +276 -0
- dtflow/utils/text_parser.py +124 -0
- {dtflow-0.5.7.dist-info → dtflow-0.5.9.dist-info}/METADATA +34 -2
- {dtflow-0.5.7.dist-info → dtflow-0.5.9.dist-info}/RECORD +16 -11
- {dtflow-0.5.7.dist-info → dtflow-0.5.9.dist-info}/WHEEL +0 -0
- {dtflow-0.5.7.dist-info → dtflow-0.5.9.dist-info}/entry_points.txt +0 -0
dtflow/cli/stats.py
CHANGED
@@ -3,7 +3,7 @@ CLI 数据统计相关命令
 """
 
 from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 import orjson
 
@@ -22,6 +22,8 @@ def stats(
     filename: str,
     top: int = 10,
     full: bool = False,
+    fields: Optional[List[str]] = None,
+    expand_fields: Optional[List[str]] = None,
 ) -> None:
     """
     显示数据文件的统计信息。
@@ -33,11 +35,15 @@ def stats(
         filename: 输入文件路径,支持 csv/excel/jsonl/json/parquet/arrow/feather 格式
        top: 显示频率最高的前 N 个值,默认 10(仅完整模式)
        full: 完整模式,统计值分布、唯一值等详细信息
+        fields: 指定统计的字段列表(支持嵌套路径)
+        expand_fields: 展开 list 字段统计的字段列表
 
     Examples:
         dt stats data.jsonl                          # 快速模式(默认)
        dt stats data.jsonl --full                   # 完整模式
        dt stats data.csv -f --top=5                 # 完整模式,显示 Top 5
+        dt stats data.jsonl --full --field=category  # 指定字段
+        dt stats data.jsonl --full --expand=tags     # 展开 list 字段
     """
     filepath = Path(filename)
 
@@ -48,7 +54,10 @@ def stats(
     if not _check_file_format(filepath):
         return
 
+    # 快速模式:忽略 --field 和 --expand 参数
     if not full:
+        if fields or expand_fields:
+            print("⚠️ 警告: --field 和 --expand 参数仅在完整模式 (--full) 下生效")
         _quick_stats(filepath)
         return
 
@@ -65,7 +74,7 @@ def stats(
 
     # 计算统计信息
     total = len(data)
-    field_stats = _compute_field_stats(data, top)
+    field_stats = _compute_field_stats(data, top, fields, expand_fields)
 
     # 输出统计信息
     _print_stats(filepath.name, total, field_stats)
@@ -205,11 +214,99 @@ def _quick_stats(filepath: Path) -> None:
         print(f" {i}. {f['field']} ({f['type']})")
 
 
-def _compute_field_stats(data: List[Dict], top: int) -> List[Dict[str, Any]]:
+def _extract_with_wildcard(item: dict, field_spec: str) -> List[Any]:
+    """处理包含 [*] 的字段路径,返回所有值"""
+    if "[*]" not in field_spec:
+        # 无 [*],直接返回单个值的列表
+        value = get_field_with_spec(item, field_spec)
+        return [value] if value is not None else []
+
+    # 分割路径:messages[*].role -> ("messages", ".role")
+    before, after = field_spec.split("[*]", 1)
+    after = after.lstrip(".")  # 移除开头的点
+
+    # 获取数组
+    array = get_field_with_spec(item, before) if before else item
+    if not isinstance(array, list):
+        return []
+
+    # 提取每个元素的后续路径
+    results = []
+    for elem in array:
+        if after:
+            val = get_field_with_spec(elem, after)
+        else:
+            val = elem
+        if val is not None:
+            results.append(val)
+
+    return results
+
+
+def _extract_field_values(
+    data: List[Dict],
+    field_spec: str,
+    expand: bool = False,
+) -> List[Any]:
+    """
+    从数据中提取字段值。
+
+    Args:
+        data: 数据列表
+        field_spec: 字段路径规格(如 "messages[*].role")
+        expand: 是否展开 list
+
+    Returns:
+        值列表(展开或不展开)
+    """
+    all_values = []
+
+    for item in data:
+        if "[*]" in field_spec or expand:
+            # 使用通配符提取所有值
+            values = _extract_with_wildcard(item, field_spec)
+
+            if expand and len(values) == 1 and isinstance(values[0], list):
+                # 展开模式:如果返回单个列表,展开其元素
+                all_values.extend(values[0])
+            elif expand and values and isinstance(values[0], list):
+                # 多个列表,全部展开
+                for v in values:
+                    if isinstance(v, list):
+                        all_values.extend(v)
+                    else:
+                        all_values.append(v)
+            else:
+                # 不展开或非列表值
+                all_values.extend(values)
+        else:
+            # 普通字段路径
+            value = get_field_with_spec(item, field_spec)
+            if expand and isinstance(value, list):
+                # 展开 list
+                all_values.extend(value)
+            else:
+                all_values.append(value)
+
+    return all_values
+
+
+def _compute_field_stats(
+    data: List[Dict],
+    top: int,
+    fields: Optional[List[str]] = None,
+    expand_fields: Optional[List[str]] = None,
+) -> List[Dict[str, Any]]:
     """
     单次遍历计算每个字段的统计信息。
 
     优化:将多次遍历合并为单次遍历,在遍历过程中同时收集所有统计数据。
+
+    Args:
+        data: 数据列表
+        top: Top N 值数量
+        fields: 指定统计的字段列表
+        expand_fields: 展开 list 字段统计的字段列表
     """
     from collections import Counter, defaultdict
 
@@ -218,38 +315,115 @@ def _compute_field_stats(data: List[Dict], top: int) -> List[Dict[str, Any]]:
 
     total = len(data)
 
-    # 单次遍历收集所有字段的值和统计信息
-    field_values = defaultdict(list)  # 存储每个字段的所有值
-    field_counters = defaultdict(Counter)  # 存储每个字段的值频率(用于 top N)
+    # 如果没有指定字段,统计所有顶层字段(保持向后兼容)
+    if not fields and not expand_fields:
+        # 单次遍历收集所有字段的值和统计信息
+        field_values = defaultdict(list)  # 存储每个字段的所有值
+        field_counters = defaultdict(Counter)  # 存储每个字段的值频率(用于 top N)
+
+        for item in data:
+            for k, v in item.items():
+                field_values[k].append(v)
+                # 对值进行截断后计数(用于 top N 显示)
+                displayable = _truncate(v if v is not None else "", 30)
+                field_counters[k][displayable] += 1
+
+        # 根据收集的数据计算统计信息
+        stats_list = []
+        for field in sorted(field_values.keys()):
+            values = field_values[field]
+            non_null = [v for v in values if v is not None and v != ""]
+            non_null_count = len(non_null)
+
+            # 推断类型(从第一个非空值)
+            field_type = _infer_type(non_null)
+
+            # 基础统计
+            stat = {
+                "field": field,
+                "non_null": non_null_count,
+                "null_rate": f"{non_null_count / total * 100:.1f}%",
+                "type": field_type,
+            }
+
+            # 类型特定统计
+            if non_null:
+                # 唯一值计数(对复杂类型使用 hash 节省内存)
+                stat["unique"] = _count_unique(non_null, field_type)
+
+                # 字符串类型:计算长度统计
+                if field_type == "str":
+                    lengths = [len(str(v)) for v in non_null]
+                    stat["len_min"] = min(lengths)
+                    stat["len_max"] = max(lengths)
+                    stat["len_avg"] = sum(lengths) / len(lengths)
+
+                # 数值类型:计算数值统计
+                elif field_type in ("int", "float"):
+                    nums = [float(v) for v in non_null if _is_numeric(v)]
+                    if nums:
+                        stat["min"] = min(nums)
+                        stat["max"] = max(nums)
+                        stat["avg"] = sum(nums) / len(nums)
+
+                # 列表类型:计算长度统计
+                elif field_type == "list":
+                    lengths = [len(v) if isinstance(v, list) else 0 for v in non_null]
+                    stat["len_min"] = min(lengths)
+                    stat["len_max"] = max(lengths)
+                    stat["len_avg"] = sum(lengths) / len(lengths)
+
+                # Top N 值(已在遍历时收集)
+                stat["top_values"] = field_counters[field].most_common(top)
+
+            stats_list.append(stat)
+
+        return stats_list
+
+    # 指定了字段:收集指定字段的统计
+    stats_list = []
+    expand_set = set(expand_fields) if expand_fields else set()
 
-    for item in data:
-        for k, v in item.items():
-            field_values[k].append(v)
-            # 对值进行截断后计数(用于 top N 显示)
-            displayable = _truncate(v if v is not None else "", 30)
-            field_counters[k][displayable] += 1
+    # 合并字段列表
+    all_fields = set(fields) if fields else set()
+    all_fields.update(expand_set)
 
-    # 根据收集的数据计算统计信息
-    stats_list = []
-    for field in sorted(field_values.keys()):
-        values = field_values[field]
+    for field_spec in sorted(all_fields):
+        is_expanded = field_spec in expand_set
+
+        # 提取字段值
+        values = _extract_field_values(data, field_spec, expand=is_expanded)
+
+        # 过滤 None 和空值
         non_null = [v for v in values if v is not None and v != ""]
         non_null_count = len(non_null)
 
-        # 推断类型(从第一个非空值)
+        # 推断类型
         field_type = _infer_type(non_null)
 
         # 基础统计
-        stat = {
-            "field": field,
-            "non_null": non_null_count,
-            "null_rate": f"{non_null_count / total * 100:.1f}%",
-            "type": field_type,
-        }
+        if is_expanded:
+            # 展开模式:显示元素总数和平均数,而非非空率
+            stat = {
+                "field": field_spec,
+                "non_null": non_null_count,
+                "null_rate": f"总元素: {len(values)}",
+                "type": field_type,
+                "is_expanded": is_expanded,
+            }
+        else:
+            # 普通模式:显示非空率
+            stat = {
+                "field": field_spec,
+                "non_null": non_null_count,
+                "null_rate": f"{non_null_count / total * 100:.1f}%",
+                "type": field_type,
+                "is_expanded": is_expanded,
+            }
 
         # 类型特定统计
         if non_null:
-            # 唯一值计数(对复杂类型使用 hash 节省内存)
+            # 唯一值计数
             stat["unique"] = _count_unique(non_null, field_type)
 
             # 字符串类型:计算长度统计
@@ -274,8 +448,12 @@ def _compute_field_stats(data: List[Dict], top: int) -> List[Dict[str, Any]]:
                 stat["len_max"] = max(lengths)
                 stat["len_avg"] = sum(lengths) / len(lengths)
 
-            # Top N 值(已在遍历时收集)
-            stat["top_values"] = field_counters[field].most_common(top)
+            # Top N 值(需要重新计数)
+            counter = Counter()
+            for v in non_null:
+                displayable = _truncate(v if v is not None else "", 30)
+                counter[displayable] += 1
+            stat["top_values"] = counter.most_common(top)
 
         stats_list.append(stat)
 
@@ -343,9 +521,18 @@ def _print_stats(filename: str, total: int, field_stats: List[Dict[str, Any]]) -> None:
     table.add_column("统计", style="dim")
 
     for stat in field_stats:
-        non_null_rate = f"{stat['non_null'] / total * 100:.0f}%"
+        # 使用 stat 中的 null_rate(支持展开模式的特殊显示)
+        if "null_rate" in stat:
+            non_null_rate = stat["null_rate"]
+        else:
+            non_null_rate = f"{stat['non_null'] / total * 100:.0f}%"
         unique = str(stat.get("unique", "-"))
 
+        # 字段名(添加展开标记)
+        field_name = stat["field"]
+        if stat.get("is_expanded"):
+            field_name += " (展开)"
+
         # 构建统计信息字符串
         extra = []
         if "len_avg" in stat:
@@ -363,7 +550,7 @@ def _print_stats(filename: str, total: int, field_stats: List[Dict[str, Any]]) -> None:
         )
 
         table.add_row(
-            stat["field"],
+            field_name,
             stat["type"],
             non_null_rate,
             unique,
@@ -387,12 +574,19 @@ def _print_stats(filename: str, total: int, field_stats: List[Dict[str, Any]]) -> None:
         if unique_ratio > 0.9 and stat.get("unique", 0) > 100:
             continue
 
+        # 字段名(添加展开标记)
+        field_display = stat["field"]
+        if stat.get("is_expanded"):
+            field_display += " (展开)"
+
         console.print(
-            f"\n[bold cyan]{stat['field']}[/bold cyan] 值分布 (Top {len(top_values)}):"
+            f"\n[bold cyan]{field_display}[/bold cyan] 值分布 (Top {len(top_values)}):"
         )
         max_count = max(c for _, c in top_values) if top_values else 1
+        # 展开模式下使用 non_null(元素总数),否则使用 total(数据条数)
+        base_count = stat["non_null"] if stat.get("is_expanded") else total
         for value, count in top_values:
-            pct = count / total * 100
+            pct = count / base_count * 100 if base_count > 0 else 0
             bar_len = int(count / max_count * 20)  # 按相对比例,最长 20 字符
             bar = "█" * bar_len
             display_value = value if value else "[空]"
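
The `--field`/`--expand` support above hinges on `_extract_with_wildcard`, which splits a spec such as `messages[*].role` at the `[*]` marker, resolves the prefix to a list, and then reads the suffix from every element. The sketch below replays that traversal outside dtflow; `get_path` is a hypothetical stand-in for the package's `get_field_with_spec` helper, which this diff does not show.

```python
from typing import Any, List, Optional


def get_path(obj: Any, path: str) -> Optional[Any]:
    """Hypothetical dotted-path lookup standing in for dtflow's get_field_with_spec."""
    cur = obj
    for part in path.split("."):
        if isinstance(cur, dict) and part in cur:
            cur = cur[part]
        else:
            return None
    return cur


def extract_with_wildcard(item: dict, field_spec: str) -> List[Any]:
    """Collect every value matched by a spec like 'messages[*].role'."""
    if "[*]" not in field_spec:
        value = get_path(item, field_spec)
        return [value] if value is not None else []
    before, after = field_spec.split("[*]", 1)
    after = after.lstrip(".")
    array = get_path(item, before) if before else item
    if not isinstance(array, list):
        return []
    results = []
    for elem in array:
        val = get_path(elem, after) if after else elem
        if val is not None:
            results.append(val)
    return results


row = {"messages": [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}]}
print(extract_with_wildcard(row, "messages[*].role"))  # ['user', 'assistant']
```

On the command line, the docstring added above pairs these specs with the new options, e.g. `dt stats data.jsonl --full --expand=tags`.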
|
dtflow/eval.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
1
|
+
"""
|
|
2
|
+
评估指标计算模块
|
|
3
|
+
|
|
4
|
+
提供分类任务的指标计算和评估报告导出:
|
|
5
|
+
- MetricsCalculator: 计算 accuracy/precision/recall/F1/混淆矩阵
|
|
6
|
+
- export_eval_report: 生成 metrics.md + result.jsonl + bad_case.jsonl
|
|
7
|
+
|
|
8
|
+
依赖: scikit-learn, pandas
|
|
9
|
+
安装: pip install dtflow[eval]
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import os
|
|
13
|
+
from datetime import datetime
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from typing import TYPE_CHECKING, Optional
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from pandas import DataFrame
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _check_eval_deps():
|
|
22
|
+
"""检查 eval 依赖是否已安装"""
|
|
23
|
+
try:
|
|
24
|
+
import pandas # noqa: F401
|
|
25
|
+
import sklearn # noqa: F401
|
|
26
|
+
except ImportError as e:
|
|
27
|
+
missing = str(e).split("'")[1] if "'" in str(e) else str(e)
|
|
28
|
+
raise ImportError(
|
|
29
|
+
f"eval 功能需要额外依赖: {missing}\n" f"请运行: pip install dtflow[eval]"
|
|
30
|
+
) from e
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class MetricsCalculator:
|
|
34
|
+
"""分类指标计算器
|
|
35
|
+
|
|
36
|
+
基于 sklearn 计算 accuracy/precision/recall/F1/混淆矩阵/分类报告。
|
|
37
|
+
|
|
38
|
+
Args:
|
|
39
|
+
df: 包含预测列和标签列的 DataFrame
|
|
40
|
+
pred_col: 预测值列名
|
|
41
|
+
label_col: 标签值列名
|
|
42
|
+
include_macro_micro_avg: 是否在报告中包含 macro/micro 平均
|
|
43
|
+
remove_matrix_zero_row: 是否移除混淆矩阵中 support=0 的行
|
|
44
|
+
"""
|
|
45
|
+
|
|
46
|
+
def __init__(
|
|
47
|
+
self,
|
|
48
|
+
df: "DataFrame",
|
|
49
|
+
pred_col: str = "predict",
|
|
50
|
+
label_col: str = "label",
|
|
51
|
+
include_macro_micro_avg: bool = False,
|
|
52
|
+
remove_matrix_zero_row: bool = False,
|
|
53
|
+
):
|
|
54
|
+
_check_eval_deps()
|
|
55
|
+
self.df = df
|
|
56
|
+
self.y_pred = df[pred_col]
|
|
57
|
+
self.y_true = df[label_col]
|
|
58
|
+
self.all_labels = sorted(set(self.y_true.unique()).union(set(self.y_pred.unique())))
|
|
59
|
+
self.needed_labels = None
|
|
60
|
+
self.remove_matrix_zero_row = remove_matrix_zero_row
|
|
61
|
+
self.include_macro_micro_avg = include_macro_micro_avg
|
|
62
|
+
self.metrics = self._calculate_metrics()
|
|
63
|
+
|
|
64
|
+
def _calculate_metrics(self):
|
|
65
|
+
from sklearn.metrics import (
|
|
66
|
+
accuracy_score,
|
|
67
|
+
classification_report,
|
|
68
|
+
confusion_matrix,
|
|
69
|
+
precision_score,
|
|
70
|
+
recall_score,
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
accuracy = accuracy_score(self.y_true, self.y_pred)
|
|
74
|
+
precision = precision_score(
|
|
75
|
+
self.y_true, self.y_pred, labels=self.all_labels, average="weighted", zero_division=0
|
|
76
|
+
)
|
|
77
|
+
recall = recall_score(
|
|
78
|
+
self.y_true, self.y_pred, labels=self.all_labels, average="weighted", zero_division=0
|
|
79
|
+
)
|
|
80
|
+
conf_matrix = confusion_matrix(self.y_true, self.y_pred, labels=self.all_labels)
|
|
81
|
+
report = classification_report(
|
|
82
|
+
self.y_true, self.y_pred, labels=self.all_labels, output_dict=True, zero_division=0
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
# 默认只保留加权平均
|
|
86
|
+
if not self.include_macro_micro_avg:
|
|
87
|
+
report = {
|
|
88
|
+
label: metrics
|
|
89
|
+
for label, metrics in report.items()
|
|
90
|
+
if label in self.all_labels or label == "weighted avg"
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
# 去除 support=0 的类别(注意 accuracy 是 float 不是 dict)
|
|
94
|
+
report = {
|
|
95
|
+
label: metrics
|
|
96
|
+
for label, metrics in report.items()
|
|
97
|
+
if isinstance(metrics, dict) and metrics.get("support", 0) > 0
|
|
98
|
+
}
|
|
99
|
+
|
|
100
|
+
self.needed_labels = [label for label in report.keys() if label in self.all_labels]
|
|
101
|
+
|
|
102
|
+
# 可选移除混淆矩阵中不需要的行
|
|
103
|
+
needed_idx_list = [self.all_labels.index(label) for label in self.needed_labels]
|
|
104
|
+
if self.remove_matrix_zero_row:
|
|
105
|
+
conf_matrix = conf_matrix[needed_idx_list]
|
|
106
|
+
|
|
107
|
+
return {
|
|
108
|
+
"accuracy": accuracy,
|
|
109
|
+
"precision": precision,
|
|
110
|
+
"recall": recall,
|
|
111
|
+
"confusion_matrix": conf_matrix,
|
|
112
|
+
"classification_report": report,
|
|
113
|
+
}
|
|
114
|
+
|
|
115
|
+
def get_metrics(self):
|
|
116
|
+
return self.metrics
|
|
117
|
+
|
|
118
|
+
def format_classification_report_as_markdown(self):
|
|
119
|
+
"""将分类报告格式化为 Markdown 表格"""
|
|
120
|
+
report = self.metrics["classification_report"]
|
|
121
|
+
header = "| Label | Precision | Recall | F1-score | Support |\n"
|
|
122
|
+
separator = "|-------|-----------|--------|----------|---------|\n"
|
|
123
|
+
rows = []
|
|
124
|
+
for label, metrics in report.items():
|
|
125
|
+
if isinstance(metrics, dict):
|
|
126
|
+
rows.append(
|
|
127
|
+
f"| {label} | {metrics['precision']:.2f} | {metrics['recall']:.2f} "
|
|
128
|
+
f"| {metrics['f1-score']:.2f} | {metrics['support']:.0f} |"
|
|
129
|
+
)
|
|
130
|
+
return header + separator + "\n".join(rows)
|
|
131
|
+
|
|
132
|
+
def _clean_label_for_markdown(self, label, max_length=20):
|
|
133
|
+
"""清理标签文本,使其适合 Markdown 表格显示"""
|
|
134
|
+
label = str(label).replace("\n", " ")
|
|
135
|
+
label = label.replace("|", "\\|")
|
|
136
|
+
label = label.replace("-", "\\-")
|
|
137
|
+
label = label.replace("<", "<")
|
|
138
|
+
label = label.replace(">", ">")
|
|
139
|
+
if len(label) > max_length:
|
|
140
|
+
label = label[:max_length] + "..."
|
|
141
|
+
label = label.strip()
|
|
142
|
+
if not label:
|
|
143
|
+
label = "(empty)"
|
|
144
|
+
return label
|
|
145
|
+
|
|
146
|
+
def format_confusion_matrix_as_markdown(self, max_label_length=20):
|
|
147
|
+
"""将混淆矩阵格式化为 Markdown 表格"""
|
|
148
|
+
matrix = self.metrics["confusion_matrix"]
|
|
149
|
+
|
|
150
|
+
if self.remove_matrix_zero_row:
|
|
151
|
+
labels = self.needed_labels
|
|
152
|
+
else:
|
|
153
|
+
labels = self.all_labels
|
|
154
|
+
|
|
155
|
+
processed_labels = [self._clean_label_for_markdown(lb, max_label_length) for lb in labels]
|
|
156
|
+
|
|
157
|
+
header = "| 真实值/预测值 | " + " | ".join(processed_labels) + " |\n"
|
|
158
|
+
separator_parts = [":---:"] * (len(processed_labels) + 1)
|
|
159
|
+
separator = "| " + " | ".join(separator_parts) + " |\n"
|
|
160
|
+
|
|
161
|
+
rows = []
|
|
162
|
+
for i, row in enumerate(matrix):
|
|
163
|
+
row_label = self._clean_label_for_markdown(labels[i], max_label_length)
|
|
164
|
+
formatted_row = [f"{num:,}" for num in row]
|
|
165
|
+
rows.append(f"| {row_label} | " + " | ".join(formatted_row) + " |")
|
|
166
|
+
|
|
167
|
+
return header + separator + "\n".join(rows)
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def export_eval_report(
|
|
171
|
+
df: "DataFrame",
|
|
172
|
+
pred_col: str,
|
|
173
|
+
label_col: str,
|
|
174
|
+
record_folder: str = "record",
|
|
175
|
+
input_name: Optional[str] = None,
|
|
176
|
+
):
|
|
177
|
+
"""生成评估报告并保存到指定目录
|
|
178
|
+
|
|
179
|
+
输出文件:
|
|
180
|
+
- metrics.md: 指标概览 + 分类报告 + 混淆矩阵
|
|
181
|
+
- result.jsonl: 完整预测结果
|
|
182
|
+
- bad_case.jsonl: 预测错误样本
|
|
183
|
+
|
|
184
|
+
Args:
|
|
185
|
+
df: 包含预测和标签的 DataFrame
|
|
186
|
+
pred_col: 预测值列名
|
|
187
|
+
label_col: 标签值列名
|
|
188
|
+
record_folder: 输出根目录
|
|
189
|
+
input_name: 输入文件名(用于子目录命名)
|
|
190
|
+
"""
|
|
191
|
+
from rich.console import Console
|
|
192
|
+
from rich.markdown import Markdown
|
|
193
|
+
|
|
194
|
+
calculator = MetricsCalculator(df, pred_col=pred_col, label_col=label_col)
|
|
195
|
+
metrics = calculator.get_metrics()
|
|
196
|
+
|
|
197
|
+
# 用 Rich Table 构建指标概览(替代 tabulate)
|
|
198
|
+
from rich.table import Table
|
|
199
|
+
|
|
200
|
+
overview_table = Table(title="指标概览", show_header=True)
|
|
201
|
+
overview_table.add_column("Accuracy", justify="center")
|
|
202
|
+
overview_table.add_column("Precision", justify="center")
|
|
203
|
+
overview_table.add_column("Recall", justify="center")
|
|
204
|
+
overview_table.add_row(
|
|
205
|
+
f"{metrics['accuracy']:.4f}",
|
|
206
|
+
f"{metrics['precision']:.4f}",
|
|
207
|
+
f"{metrics['recall']:.4f}",
|
|
208
|
+
)
|
|
209
|
+
|
|
210
|
+
# 构建 Markdown 报告内容
|
|
211
|
+
md = (
|
|
212
|
+
f"\n\n### 指标概览\n\n"
|
|
213
|
+
f"| Accuracy | Precision | Recall |\n"
|
|
214
|
+
f"|----------|-----------|--------|\n"
|
|
215
|
+
f"| {metrics['accuracy']:.4f} | {metrics['precision']:.4f} | {metrics['recall']:.4f} |"
|
|
216
|
+
)
|
|
217
|
+
metrics_md = calculator.format_classification_report_as_markdown()
|
|
218
|
+
confusion_md = calculator.format_confusion_matrix_as_markdown()
|
|
219
|
+
md += f"\n\n### Classification Report\n{metrics_md}\n" f"\n### Confusion Matrix\n{confusion_md}"
|
|
220
|
+
|
|
221
|
+
# 创建输出目录(带序号和时间戳)
|
|
222
|
+
now = datetime.now().strftime("%Y%m%d-%H-%M-%S")
|
|
223
|
+
record_path = Path(record_folder)
|
|
224
|
+
if input_name:
|
|
225
|
+
record_path = record_path / input_name
|
|
226
|
+
|
|
227
|
+
if record_path.exists():
|
|
228
|
+
existing = [d.name for d in record_path.iterdir() if d.is_dir()]
|
|
229
|
+
max_idx = 0
|
|
230
|
+
for name in existing:
|
|
231
|
+
parts = name.split("-", 1)
|
|
232
|
+
if parts[0].isdigit():
|
|
233
|
+
max_idx = max(max_idx, int(parts[0]))
|
|
234
|
+
idx = max_idx + 1
|
|
235
|
+
else:
|
|
236
|
+
idx = 1
|
|
237
|
+
|
|
238
|
+
record_path = record_path / f"{idx}-{now}"
|
|
239
|
+
record_path.mkdir(parents=True, exist_ok=True)
|
|
240
|
+
|
|
241
|
+
# 终端输出
|
|
242
|
+
console = Console()
|
|
243
|
+
console.print(overview_table)
|
|
244
|
+
console.print(Markdown(md))
|
|
245
|
+
|
|
246
|
+
# 保存文件
|
|
247
|
+
with open(os.path.join(record_path, "metrics.md"), "w", encoding="utf-8") as f:
|
|
248
|
+
f.write(md)
|
|
249
|
+
|
|
250
|
+
bad_case_df = df[df[pred_col] != df[label_col]]
|
|
251
|
+
|
|
252
|
+
# 保存 JSONL
|
|
253
|
+
df.to_json(
|
|
254
|
+
os.path.join(record_path, "result.jsonl"),
|
|
255
|
+
orient="records",
|
|
256
|
+
lines=True,
|
|
257
|
+
force_ascii=False,
|
|
258
|
+
)
|
|
259
|
+
bad_case_df.to_json(
|
|
260
|
+
os.path.join(record_path, "bad_case.jsonl"),
|
|
261
|
+
orient="records",
|
|
262
|
+
lines=True,
|
|
263
|
+
force_ascii=False,
|
|
264
|
+
)
|
|
265
|
+
|
|
266
|
+
# 尝试保存 CSV
|
|
267
|
+
try:
|
|
268
|
+
df.to_csv(os.path.join(record_path, "result.csv"), index=False)
|
|
269
|
+
bad_case_df.to_csv(os.path.join(record_path, "bad_case.csv"), index=False)
|
|
270
|
+
except Exception:
|
|
271
|
+
pass
|
|
272
|
+
|
|
273
|
+
console.print(f"\n[green]报告已保存到: {record_path}[/green]")
|
|
274
|
+
console.print(f"[dim] - metrics.md ({len(df)} 条数据, {len(bad_case_df)} 条错误)[/dim]")
|
|
275
|
+
|
|
276
|
+
return record_path
|