dtflow-0.5.8-py3-none-any.whl → dtflow-0.5.9-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries.
- dtflow/SKILL.md +22 -8
- dtflow/__init__.py +1 -1
- dtflow/__main__.py +108 -14
- dtflow/cli/clean.py +90 -1
- dtflow/cli/commands.py +17 -1
- dtflow/cli/eval.py +288 -0
- dtflow/cli/export.py +81 -0
- dtflow/cli/sample.py +90 -3
- dtflow/cli/split.py +138 -0
- dtflow/cli/stats.py +10 -23
- dtflow/cli/validate.py +19 -52
- dtflow/eval.py +276 -0
- dtflow/schema.py +13 -99
- dtflow/tokenizers.py +21 -104
- dtflow/utils/text_parser.py +124 -0
- {dtflow-0.5.8.dist-info → dtflow-0.5.9.dist-info}/METADATA +29 -3
- {dtflow-0.5.8.dist-info → dtflow-0.5.9.dist-info}/RECORD +19 -15
- dtflow/parallel.py +0 -115
- {dtflow-0.5.8.dist-info → dtflow-0.5.9.dist-info}/WHEEL +0 -0
- {dtflow-0.5.8.dist-info → dtflow-0.5.9.dist-info}/entry_points.txt +0 -0
dtflow/cli/stats.py
CHANGED
@@ -209,7 +209,7 @@ def _quick_stats(filepath: Path) -> None:
     print(f"字段: {len(fields)} 个")
 
     if fields:
-        print("\n📋 字段结构:")
+        print(f"\n📋 字段结构:")
         for i, f in enumerate(fields, 1):
            print(f"  {i}. {f['field']} ({f['type']})")
 
@@ -597,14 +597,14 @@ def _print_stats(filename: str, total: int, field_stats: List[Dict[str, Any]]) -> None:
     except ImportError:
         # 没有 rich,使用普通打印
         print(f"\n{'=' * 50}")
-        print("📊 数据概览")
+        print(f"📊 数据概览")
         print(f"{'=' * 50}")
         print(f"文件: {filename}")
         print(f"总数: {total:,} 条")
         print(f"字段: {len(field_stats)} 个")
 
         print(f"\n{'=' * 50}")
-        print("📋 字段统计")
+        print(f"📋 字段统计")
         print(f"{'=' * 50}")
         print(f"{'字段':<20} {'类型':<8} {'非空率':<8} {'唯一值':<8}")
         print("-" * 50)
@@ -620,7 +620,6 @@ def token_stats(
     field: str = "messages",
     model: str = "cl100k_base",
     detailed: bool = False,
-    workers: Optional[int] = None,
 ) -> None:
     """
     统计数据集的 Token 信息。
@@ -630,7 +629,6 @@ def token_stats(
         field: 要统计的字段(默认 messages),支持嵌套路径语法
         model: 分词器: cl100k_base (默认), qwen2.5, llama3, gpt-4 等
         detailed: 是否显示详细统计
-        workers: 并行进程数,None 自动检测,1 禁用并行
 
     Examples:
         dt token-stats data.jsonl
@@ -638,7 +636,6 @@ def token_stats(
         dt token-stats data.jsonl --field=conversation.messages
         dt token-stats data.jsonl --field=messages[-1].content  # 统计最后一条消息
         dt token-stats data.jsonl --detailed
-        dt token-stats data.jsonl --workers=4  # 使用 4 进程
     """
     filepath = Path(filename)
 
@@ -670,7 +667,7 @@ def token_stats(
 
     # 尝试使用 rich 进度条
     try:
-        from rich.progress import
+        from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn
 
         with Progress(
             SpinnerColumn(),
@@ -688,22 +685,14 @@ def token_stats(
             from ..tokenizers import messages_token_stats
 
             stats_result = messages_token_stats(
-                data,
-                messages_field=field,
-                model=model,
-                progress_callback=update_progress,
-                workers=workers,
+                data, messages_field=field, model=model, progress_callback=update_progress
             )
             _print_messages_token_stats(stats_result, detailed)
         else:
             from ..tokenizers import token_stats as compute_token_stats
 
             stats_result = compute_token_stats(
-                data,
-                fields=field,
-                model=model,
-                progress_callback=update_progress,
-                workers=workers,
+                data, fields=field, model=model, progress_callback=update_progress
             )
             _print_text_token_stats(stats_result, detailed)
 
@@ -714,14 +703,12 @@ def token_stats(
         if isinstance(field_value, list) and field_value and isinstance(field_value[0], dict):
             from ..tokenizers import messages_token_stats
 
-            stats_result = messages_token_stats(
-                data, messages_field=field, model=model, workers=workers
-            )
+            stats_result = messages_token_stats(data, messages_field=field, model=model)
             _print_messages_token_stats(stats_result, detailed)
         else:
             from ..tokenizers import token_stats as compute_token_stats
 
-            stats_result = compute_token_stats(data, fields=field, model=model, workers=workers)
+            stats_result = compute_token_stats(data, fields=field, model=model)
             _print_text_token_stats(stats_result, detailed)
     except ImportError as e:
         print(f"错误: {e}")
@@ -801,7 +788,7 @@ def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
     print(f"平均 Token: {stats['avg_tokens']:,} (std: {std:.1f})")
     print(f"范围: {stats['min_tokens']:,} - {stats['max_tokens']:,}")
 
-    print("\n📈 百分位分布:")
+    print(f"\n📈 百分位分布:")
     print(f"  P25: {stats.get('p25', '-'):,}  P50: {stats.get('median_tokens', '-'):,}")
     print(f"  P75: {stats.get('p75', '-'):,}  P90: {stats.get('p90', '-'):,}")
     print(f"  P95: {stats.get('p95', '-'):,}  P99: {stats.get('p99', '-'):,}")
@@ -868,7 +855,7 @@ def _print_text_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
     print(f"平均 Token: {stats['avg_tokens']:.1f} (std: {std:.1f})")
     print(f"范围: {stats['min_tokens']:,} - {stats['max_tokens']:,}")
 
-    print("\n📈 百分位分布:")
+    print(f"\n📈 百分位分布:")
     print(f"  P25: {stats.get('p25', '-'):,}  P50: {stats.get('median_tokens', '-'):,}")
     print(f"  P75: {stats.get('p75', '-'):,}  P90: {stats.get('p90', '-'):,}")
     print(f"  P95: {stats.get('p95', '-'):,}  P99: {stats.get('p99', '-'):,}")
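The net effect of the stats.py changes above is that the `workers=` knob disappears from both tokenizer entry points; with dtflow/parallel.py deleted in this release, token counting appears to always run in-process. A minimal sketch of the resulting 0.5.9 call surface, using only names that appear in this diff (the sample records are hypothetical):

```python
from dtflow.tokenizers import messages_token_stats, token_stats

# Hypothetical records in the two shapes the CLI distinguishes.
data = [
    {"text": "hello world", "messages": [{"role": "user", "content": "hi"}]},
]

# Plain-text field: fields= takes a field path (nested paths per the docstring).
text_stats = token_stats(data, fields="text", model="cl100k_base")

# Chat data: messages_field= points at the list of role/content dicts.
msg_stats = messages_token_stats(data, messages_field="messages", model="cl100k_base")
```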
dtflow/cli/validate.py
CHANGED
@@ -6,6 +6,8 @@ from pathlib import Path
 from typing import Optional
 
 from ..schema import (
+    Schema,
+    Field,
     alpaca_schema,
     dpo_schema,
     openai_chat_schema,
@@ -14,6 +16,7 @@ from ..schema import (
 from ..storage.io import load_data, save_data
 from .common import _check_file_format
 
+
 # 预设 Schema 映射
 PRESET_SCHEMAS = {
     "openai_chat": openai_chat_schema,
@@ -33,7 +36,6 @@ def validate(
     filter_invalid: bool = False,
     max_errors: int = 20,
     verbose: bool = False,
-    workers: Optional[int] = None,
 ) -> None:
     """
     使用 Schema 验证数据文件。
@@ -45,13 +47,11 @@ def validate(
         filter_invalid: 过滤无效数据并保存
         max_errors: 最多显示的错误数量
         verbose: 显示详细信息
-        workers: 并行进程数,None 自动检测,1 禁用并行
 
     Examples:
         dt validate data.jsonl --preset=openai_chat
         dt validate data.jsonl --preset=alpaca -o valid.jsonl
         dt validate data.jsonl --preset=chat --filter
-        dt validate data.jsonl --preset=chat --workers=4
     """
     filepath = Path(filename)
 
@@ -99,54 +99,19 @@ def validate(
     print(f"总记录数: {total}")
     print()
 
-    #
-
-
-
-
-
-
-
-
-
-
-
-
-
-            with Progress(
-                SpinnerColumn(),
-                TextColumn("[bold blue]验证数据"),
-                BarColumn(),
-                TaskProgressColumn(),
-            ) as progress:
-                task = progress.add_task("", total=total)
-
-                def update_progress(current: int, total_count: int):
-                    progress.update(task, completed=current)
-
-                valid_data, invalid_results = schema.validate_parallel(
-                    data, workers=workers, progress_callback=update_progress
-                )
-        except ImportError:
-            print("🔍 验证数据...")
-            valid_data, invalid_results = schema.validate_parallel(data, workers=workers)
-
-        invalid_count = len(invalid_results)
-        error_samples = invalid_results[:max_errors]
-    else:
-        # 串行验证
-        valid_data = []
-        invalid_count = 0
-        error_samples = []
-
-        for i, item in enumerate(data):
-            result = schema.validate(item)
-            if result.valid:
-                valid_data.append(item)
-            else:
-                invalid_count += 1
-                if len(error_samples) < max_errors:
-                    error_samples.append((i, result))
+    # 验证
+    valid_data = []
+    invalid_count = 0
+    error_samples = []
+
+    for i, item in enumerate(data):
+        result = schema.validate(item)
+        if result.valid:
+            valid_data.append(item)
+        else:
+            invalid_count += 1
+            if len(error_samples) < max_errors:
+                error_samples.append((i, result))
 
     valid_count = len(valid_data)
     valid_ratio = valid_count / total * 100 if total > 0 else 0
@@ -173,7 +138,9 @@ def validate(
 
     # 保存有效数据
     if output or filter_invalid:
-        output_path = output or str(filepath).replace(filepath.suffix, f"_valid{filepath.suffix}")
+        output_path = output or str(filepath).replace(
+            filepath.suffix, f"_valid{filepath.suffix}"
+        )
         save_data(valid_data, output_path)
         print(f"✅ 有效数据已保存: {output_path} ({valid_count} 条)")
 
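With `validate_parallel` gone, the CLI's validation path reduces to the plain serial loop shown in the hunk above. A standalone sketch of the same flow, assuming the preset factories construct with defaults (all names are ones this file imports):

```python
from dtflow.schema import openai_chat_schema
from dtflow.storage.io import load_data, save_data

schema = openai_chat_schema()  # assumption: presets take no required arguments
data = load_data("data.jsonl")  # hypothetical input file

valid_data, error_samples = [], []
for i, item in enumerate(data):
    result = schema.validate(item)
    if result.valid:
        valid_data.append(item)
    elif len(error_samples) < 20:  # max_errors default in the CLI
        error_samples.append((i, result))

save_data(valid_data, "data_valid.jsonl")
```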
dtflow/eval.py
ADDED
@@ -0,0 +1,276 @@
+"""
+评估指标计算模块
+
+提供分类任务的指标计算和评估报告导出:
+- MetricsCalculator: 计算 accuracy/precision/recall/F1/混淆矩阵
+- export_eval_report: 生成 metrics.md + result.jsonl + bad_case.jsonl
+
+依赖: scikit-learn, pandas
+安装: pip install dtflow[eval]
+"""
+
+import os
+from datetime import datetime
+from pathlib import Path
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from pandas import DataFrame
+
+
+def _check_eval_deps():
+    """检查 eval 依赖是否已安装"""
+    try:
+        import pandas  # noqa: F401
+        import sklearn  # noqa: F401
+    except ImportError as e:
+        missing = str(e).split("'")[1] if "'" in str(e) else str(e)
+        raise ImportError(
+            f"eval 功能需要额外依赖: {missing}\n" f"请运行: pip install dtflow[eval]"
+        ) from e
+
+
+class MetricsCalculator:
+    """分类指标计算器
+
+    基于 sklearn 计算 accuracy/precision/recall/F1/混淆矩阵/分类报告。
+
+    Args:
+        df: 包含预测列和标签列的 DataFrame
+        pred_col: 预测值列名
+        label_col: 标签值列名
+        include_macro_micro_avg: 是否在报告中包含 macro/micro 平均
+        remove_matrix_zero_row: 是否移除混淆矩阵中 support=0 的行
+    """
+
+    def __init__(
+        self,
+        df: "DataFrame",
+        pred_col: str = "predict",
+        label_col: str = "label",
+        include_macro_micro_avg: bool = False,
+        remove_matrix_zero_row: bool = False,
+    ):
+        _check_eval_deps()
+        self.df = df
+        self.y_pred = df[pred_col]
+        self.y_true = df[label_col]
+        self.all_labels = sorted(set(self.y_true.unique()).union(set(self.y_pred.unique())))
+        self.needed_labels = None
+        self.remove_matrix_zero_row = remove_matrix_zero_row
+        self.include_macro_micro_avg = include_macro_micro_avg
+        self.metrics = self._calculate_metrics()
+
+    def _calculate_metrics(self):
+        from sklearn.metrics import (
+            accuracy_score,
+            classification_report,
+            confusion_matrix,
+            precision_score,
+            recall_score,
+        )
+
+        accuracy = accuracy_score(self.y_true, self.y_pred)
+        precision = precision_score(
+            self.y_true, self.y_pred, labels=self.all_labels, average="weighted", zero_division=0
+        )
+        recall = recall_score(
+            self.y_true, self.y_pred, labels=self.all_labels, average="weighted", zero_division=0
+        )
+        conf_matrix = confusion_matrix(self.y_true, self.y_pred, labels=self.all_labels)
+        report = classification_report(
+            self.y_true, self.y_pred, labels=self.all_labels, output_dict=True, zero_division=0
+        )
+
+        # 默认只保留加权平均
+        if not self.include_macro_micro_avg:
+            report = {
+                label: metrics
+                for label, metrics in report.items()
+                if label in self.all_labels or label == "weighted avg"
+            }
+
+        # 去除 support=0 的类别(注意 accuracy 是 float 不是 dict)
+        report = {
+            label: metrics
+            for label, metrics in report.items()
+            if isinstance(metrics, dict) and metrics.get("support", 0) > 0
+        }
+
+        self.needed_labels = [label for label in report.keys() if label in self.all_labels]
+
+        # 可选移除混淆矩阵中不需要的行
+        needed_idx_list = [self.all_labels.index(label) for label in self.needed_labels]
+        if self.remove_matrix_zero_row:
+            conf_matrix = conf_matrix[needed_idx_list]
+
+        return {
+            "accuracy": accuracy,
+            "precision": precision,
+            "recall": recall,
+            "confusion_matrix": conf_matrix,
+            "classification_report": report,
+        }
+
+    def get_metrics(self):
+        return self.metrics
+
+    def format_classification_report_as_markdown(self):
+        """将分类报告格式化为 Markdown 表格"""
+        report = self.metrics["classification_report"]
+        header = "| Label | Precision | Recall | F1-score | Support |\n"
+        separator = "|-------|-----------|--------|----------|---------|\n"
+        rows = []
+        for label, metrics in report.items():
+            if isinstance(metrics, dict):
+                rows.append(
+                    f"| {label} | {metrics['precision']:.2f} | {metrics['recall']:.2f} "
+                    f"| {metrics['f1-score']:.2f} | {metrics['support']:.0f} |"
+                )
+        return header + separator + "\n".join(rows)
+
+    def _clean_label_for_markdown(self, label, max_length=20):
+        """清理标签文本,使其适合 Markdown 表格显示"""
+        label = str(label).replace("\n", " ")
+        label = label.replace("|", "\\|")
+        label = label.replace("-", "\\-")
+        label = label.replace("<", "&lt;")
+        label = label.replace(">", "&gt;")
+        if len(label) > max_length:
+            label = label[:max_length] + "..."
+        label = label.strip()
+        if not label:
+            label = "(empty)"
+        return label
+
+    def format_confusion_matrix_as_markdown(self, max_label_length=20):
+        """将混淆矩阵格式化为 Markdown 表格"""
+        matrix = self.metrics["confusion_matrix"]
+
+        if self.remove_matrix_zero_row:
+            labels = self.needed_labels
+        else:
+            labels = self.all_labels
+
+        processed_labels = [self._clean_label_for_markdown(lb, max_label_length) for lb in labels]
+
+        header = "| 真实值/预测值 | " + " | ".join(processed_labels) + " |\n"
+        separator_parts = [":---:"] * (len(processed_labels) + 1)
+        separator = "| " + " | ".join(separator_parts) + " |\n"
+
+        rows = []
+        for i, row in enumerate(matrix):
+            row_label = self._clean_label_for_markdown(labels[i], max_label_length)
+            formatted_row = [f"{num:,}" for num in row]
+            rows.append(f"| {row_label} | " + " | ".join(formatted_row) + " |")
+
+        return header + separator + "\n".join(rows)
+
+
+def export_eval_report(
+    df: "DataFrame",
+    pred_col: str,
+    label_col: str,
+    record_folder: str = "record",
+    input_name: Optional[str] = None,
+):
+    """生成评估报告并保存到指定目录
+
+    输出文件:
+    - metrics.md: 指标概览 + 分类报告 + 混淆矩阵
+    - result.jsonl: 完整预测结果
+    - bad_case.jsonl: 预测错误样本
+
+    Args:
+        df: 包含预测和标签的 DataFrame
+        pred_col: 预测值列名
+        label_col: 标签值列名
+        record_folder: 输出根目录
+        input_name: 输入文件名(用于子目录命名)
+    """
+    from rich.console import Console
+    from rich.markdown import Markdown
+
+    calculator = MetricsCalculator(df, pred_col=pred_col, label_col=label_col)
+    metrics = calculator.get_metrics()
+
+    # 用 Rich Table 构建指标概览(替代 tabulate)
+    from rich.table import Table
+
+    overview_table = Table(title="指标概览", show_header=True)
+    overview_table.add_column("Accuracy", justify="center")
+    overview_table.add_column("Precision", justify="center")
+    overview_table.add_column("Recall", justify="center")
+    overview_table.add_row(
+        f"{metrics['accuracy']:.4f}",
+        f"{metrics['precision']:.4f}",
+        f"{metrics['recall']:.4f}",
+    )
+
+    # 构建 Markdown 报告内容
+    md = (
+        f"\n\n### 指标概览\n\n"
+        f"| Accuracy | Precision | Recall |\n"
+        f"|----------|-----------|--------|\n"
+        f"| {metrics['accuracy']:.4f} | {metrics['precision']:.4f} | {metrics['recall']:.4f} |"
+    )
+    metrics_md = calculator.format_classification_report_as_markdown()
+    confusion_md = calculator.format_confusion_matrix_as_markdown()
+    md += f"\n\n### Classification Report\n{metrics_md}\n" f"\n### Confusion Matrix\n{confusion_md}"
+
+    # 创建输出目录(带序号和时间戳)
+    now = datetime.now().strftime("%Y%m%d-%H-%M-%S")
+    record_path = Path(record_folder)
+    if input_name:
+        record_path = record_path / input_name
+
+    if record_path.exists():
+        existing = [d.name for d in record_path.iterdir() if d.is_dir()]
+        max_idx = 0
+        for name in existing:
+            parts = name.split("-", 1)
+            if parts[0].isdigit():
+                max_idx = max(max_idx, int(parts[0]))
+        idx = max_idx + 1
+    else:
+        idx = 1
+
+    record_path = record_path / f"{idx}-{now}"
+    record_path.mkdir(parents=True, exist_ok=True)
+
+    # 终端输出
+    console = Console()
+    console.print(overview_table)
+    console.print(Markdown(md))
+
+    # 保存文件
+    with open(os.path.join(record_path, "metrics.md"), "w", encoding="utf-8") as f:
+        f.write(md)
+
+    bad_case_df = df[df[pred_col] != df[label_col]]
+
+    # 保存 JSONL
+    df.to_json(
+        os.path.join(record_path, "result.jsonl"),
+        orient="records",
+        lines=True,
+        force_ascii=False,
+    )
+    bad_case_df.to_json(
+        os.path.join(record_path, "bad_case.jsonl"),
+        orient="records",
+        lines=True,
+        force_ascii=False,
+    )
+
+    # 尝试保存 CSV
+    try:
+        df.to_csv(os.path.join(record_path, "result.csv"), index=False)
+        bad_case_df.to_csv(os.path.join(record_path, "bad_case.csv"), index=False)
+    except Exception:
+        pass
+
+    console.print(f"\n[green]报告已保存到: {record_path}[/green]")
+    console.print(f"[dim]  - metrics.md ({len(df)} 条数据, {len(bad_case_df)} 条错误)[/dim]")
+
+    return record_path
dtflow/schema.py
CHANGED
@@ -26,35 +26,10 @@ Schema 验证模块
     results = dt.validate_schema(schema)
 """
 
-from dataclasses import dataclass
-from
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
-
-from .utils.field_path import _parse_path, get_field
-
-
-def _validate_item_wrapper(args: tuple) -> Tuple[int, bool, list]:
-    """
-    验证单条数据(用于多进程)。
-
-    Args:
-        args: (index, item, schema_fields) 元组
-
-    Returns:
-        (index, is_valid, errors_as_dicts) - 返回字典列表而非对象(pickle 兼容)
-    """
-    idx, item, fields = args
-    # 在子进程中重建 Schema
-    schema = Schema(fields)
-    result = schema.validate(item)
-
-    if result.valid:
-        return (idx, True, [])
-    else:
-        # 将错误转换为字典(pickle 兼容)
-        errors = [{"path": e.path, "message": e.message, "value": e.value} for e in result.errors]
-        return (idx, False, errors)
+from dataclasses import dataclass, field as dataclass_field
+from typing import Any, Callable, Dict, List, Literal, Optional, Set, Union
 
+from .utils.field_path import get_field, _parse_path, _get_value_by_segments
 
 # 支持的类型
 FieldType = Literal["str", "int", "float", "bool", "list", "dict", "any"]
@@ -187,7 +162,9 @@ class Field:
 
         # 选项检查
         if self.choices is not None and value not in self.choices:
-            errors.append(ValidationError(path, f"值必须是 {self.choices} 之一", value))
+            errors.append(
+                ValidationError(path, f"值必须是 {self.choices} 之一", value)
+            )
 
         # 正则表达式检查
         if self.pattern is not None and isinstance(value, str):
@@ -347,7 +324,9 @@ class Schema:
 
         return errors
 
-    def validate_batch(self, data: List[dict], max_errors: int = 100) -> List[tuple]:
+    def validate_batch(
+        self, data: List[dict], max_errors: int = 100
+    ) -> List[tuple]:
         """
         批量验证数据
 
@@ -371,76 +350,9 @@ class Schema:
 
         return failed
 
-    def validate_parallel(
-        self,
-        data: List[dict],
-        workers: Optional[int] = None,
-        progress_callback: Optional[Callable[[int, int], None]] = None,
-    ) -> tuple:
-        """
-        并行验证数据列表。
-
-        Args:
-            data: 数据列表
-            workers: 进程数,None 自动检测,1 禁用并行
-            progress_callback: 进度回调函数
-
-        Returns:
-            (valid_data, invalid_indices_results) 元组
-            - valid_data: 有效数据列表
-            - invalid_indices_results: [(index, ValidationResult), ...] 无效数据
-        """
-        if not data:
-            return [], []
-
-        total = len(data)
-        use_parallel = workers != 1 and total >= 1000
-
-        valid_data = []
-        invalid_results = []
-
-        if use_parallel:
-            from .parallel import get_optimal_workers, parallel_imap
-
-            actual_workers = get_optimal_workers(total, workers)
-            # 准备参数:(index, item, schema_fields)
-            args_list = [(i, item, self._fields) for i, item in enumerate(data)]
-
-            for i, (idx, is_valid, result_data) in enumerate(
-                parallel_imap(
-                    _validate_item_wrapper,
-                    args_list,
-                    workers=actual_workers,
-                    threshold=1000,
-                )
-            ):
-                if is_valid:
-                    valid_data.append(data[idx])
-                else:
-                    # 重建 ValidationResult(因为不能直接 pickle)
-                    errors = [
-                        ValidationError(path=e["path"], message=e["message"], value=e.get("value"))
-                        for e in result_data
-                    ]
-                    invalid_results.append((idx, ValidationResult(valid=False, errors=errors)))
-                if progress_callback:
-                    progress_callback(i + 1, total)
-        else:
-            # 串行处理
-            for i, item in enumerate(data):
-                result = self.validate(item)
-                if result.valid:
-                    valid_data.append(item)
-                else:
-                    invalid_results.append((i, result))
-                if progress_callback:
-                    progress_callback(i + 1, total)
-
-        return valid_data, invalid_results
-
     def __repr__(self) -> str:
         field_strs = [f"  {path}: {field_def}" for path, field_def in self._fields.items()]
-        return "Schema({\n" + ",\n".join(field_strs) + "\n}})"
+        return f"Schema({{\n" + ",\n".join(field_strs) + "\n}})"
 
 
 # ============================================================================
@@ -549,7 +461,9 @@ def sharegpt_schema(
     """
     return Schema(
         {
-            "conversations": Field(type="list", required=True, min_length=min_conversations),
+            "conversations": Field(
+                type="list", required=True, min_length=min_conversations
+            ),
             "conversations[*].from": Field(
                 type="str", required=True, choices=[human_role, gpt_role]
             ),
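The schema.py changes remove the parallel path but leave the declarative API intact. A minimal sketch of that API, built only from constructs visible in this diff (the field names and choices mirror the sharegpt preset; the record is hypothetical):

```python
from dtflow.schema import Field, Schema

schema = Schema(
    {
        "conversations": Field(type="list", required=True, min_length=1),
        "conversations[*].from": Field(type="str", required=True, choices=["human", "gpt"]),
    }
)

result = schema.validate({"conversations": [{"from": "human", "value": "hi"}]})
if not result.valid:
    for err in result.errors:  # ValidationError carries .path / .message / .value
        print(err.path, err.message)
```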
|