dtflow 0.5.7__py3-none-any.whl → 0.5.9__py3-none-any.whl

dtflow/cli/stats.py CHANGED
@@ -3,7 +3,7 @@ CLI data statistics commands
 """
 
 from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 import orjson
 
@@ -22,6 +22,8 @@ def stats(
     filename: str,
     top: int = 10,
     full: bool = False,
+    fields: Optional[List[str]] = None,
+    expand_fields: Optional[List[str]] = None,
 ) -> None:
     """
     Display statistics for a data file.
@@ -33,11 +35,15 @@ def stats(
         filename: Input file path; supports csv/excel/jsonl/json/parquet/arrow/feather
         top: Show the N most frequent values, default 10 (full mode only)
         full: Full mode; report value distributions, unique counts, and other details
+        fields: Fields to compute statistics for (nested paths supported)
+        expand_fields: List-type fields whose elements are expanded before counting
 
     Examples:
         dt stats data.jsonl                          # quick mode (default)
         dt stats data.jsonl --full                   # full mode
         dt stats data.csv -f --top=5                 # full mode, show Top 5
+        dt stats data.jsonl --full --field=category  # a specific field
+        dt stats data.jsonl --full --expand=tags     # expand a list field
     """
     filepath = Path(filename)
 
@@ -48,7 +54,10 @@ def stats(
     if not _check_file_format(filepath):
         return
 
+    # Quick mode: ignore --field and --expand
     if not full:
+        if fields or expand_fields:
+            print("⚠️ Warning: --field and --expand only take effect in full mode (--full)")
         _quick_stats(filepath)
         return
 
@@ -65,7 +74,7 @@ def stats(
     # Compute the statistics
     total = len(data)
-    field_stats = _compute_field_stats(data, top)
+    field_stats = _compute_field_stats(data, top, fields, expand_fields)
 
     # Print the statistics
     _print_stats(filepath.name, total, field_stats)
@@ -205,11 +214,99 @@ def _quick_stats(filepath: Path) -> None:
         print(f" {i}. {f['field']} ({f['type']})")
 
 
-def _compute_field_stats(data: List[Dict], top: int) -> List[Dict[str, Any]]:
+def _extract_with_wildcard(item: dict, field_spec: str) -> List[Any]:
+    """Handle a field path containing [*]; return all matched values."""
+    if "[*]" not in field_spec:
+        # No [*]: return the single value wrapped in a list
+        value = get_field_with_spec(item, field_spec)
+        return [value] if value is not None else []
+
+    # Split the path: messages[*].role -> ("messages", ".role")
+    before, after = field_spec.split("[*]", 1)
+    after = after.lstrip(".")  # strip the leading dot
+
+    # Fetch the array
+    array = get_field_with_spec(item, before) if before else item
+    if not isinstance(array, list):
+        return []
+
+    # Apply the remaining path to each element
+    results = []
+    for elem in array:
+        if after:
+            val = get_field_with_spec(elem, after)
+        else:
+            val = elem
+        if val is not None:
+            results.append(val)
+
+    return results
+
+
+def _extract_field_values(
+    data: List[Dict],
+    field_spec: str,
+    expand: bool = False,
+) -> List[Any]:
+    """
+    Extract field values from the data.
+
+    Args:
+        data: List of records
+        field_spec: Field path spec (e.g. "messages[*].role")
+        expand: Whether to expand lists
+
+    Returns:
+        List of values (expanded or not)
+    """
+    all_values = []
+
+    for item in data:
+        if "[*]" in field_spec or expand:
+            # Use the wildcard helper to extract all values
+            values = _extract_with_wildcard(item, field_spec)
+
+            if expand and len(values) == 1 and isinstance(values[0], list):
+                # Expand mode: a single returned list is flattened
+                all_values.extend(values[0])
+            elif expand and values and isinstance(values[0], list):
+                # Several lists: flatten them all
+                for v in values:
+                    if isinstance(v, list):
+                        all_values.extend(v)
+                    else:
+                        all_values.append(v)
+            else:
+                # No expansion, or non-list values
+                all_values.extend(values)
+        else:
+            # Plain field path
+            value = get_field_with_spec(item, field_spec)
+            if expand and isinstance(value, list):
+                # Expand the list
+                all_values.extend(value)
+            else:
+                all_values.append(value)
+
+    return all_values
+
+
+def _compute_field_stats(
+    data: List[Dict],
+    top: int,
+    fields: Optional[List[str]] = None,
+    expand_fields: Optional[List[str]] = None,
+) -> List[Dict[str, Any]]:
     """
     Compute each field's statistics in a single pass.
 
     Optimization: several passes are merged into one, collecting all statistics while iterating.
+
+    Args:
+        data: List of records
+        top: Number of top values to keep
+        fields: Fields to compute statistics for
+        expand_fields: List-type fields whose elements are expanded before counting
     """
     from collections import Counter, defaultdict
 
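Note: get_field_with_spec above is a dtflow-internal helper. As a self-contained illustration of how a [*] spec resolves against a record, here is a minimal sketch with a simplified dot-path getter standing in for it (get_by_path and the sample record are hypothetical):

from typing import Any, List

def get_by_path(obj: Any, path: str) -> Any:
    """Simplified stand-in for dtflow's internal get_field_with_spec (dot paths only)."""
    for key in path.split("."):
        if not isinstance(obj, dict) or key not in obj:
            return None
        obj = obj[key]
    return obj

def extract_with_wildcard(item: dict, field_spec: str) -> List[Any]:
    """Same splitting logic as _extract_with_wildcard above: one [*] over a list field."""
    if "[*]" not in field_spec:
        value = get_by_path(item, field_spec)
        return [value] if value is not None else []
    before, after = field_spec.split("[*]", 1)
    after = after.lstrip(".")
    array = get_by_path(item, before) if before else item
    if not isinstance(array, list):
        return []
    picked = [get_by_path(elem, after) if after else elem for elem in array]
    return [v for v in picked if v is not None]

record = {"messages": [{"role": "user", "content": "hi"},
                       {"role": "assistant", "content": "hello"}]}
print(extract_with_wildcard(record, "messages[*].role"))  # -> ['user', 'assistant']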
@@ -218,38 +315,115 @@ def _compute_field_stats(data: List[Dict], top: int) -> List[Dict[str, Any]]:
 
     total = len(data)
 
-    # Collect every field's values and statistics in a single pass
-    field_values = defaultdict(list)  # all values per field
-    field_counters = defaultdict(Counter)  # value frequencies per field (for top N)
+    # No fields specified: profile every top-level field (keeps backward compatibility)
+    if not fields and not expand_fields:
+        # Collect every field's values and statistics in a single pass
+        field_values = defaultdict(list)  # all values per field
+        field_counters = defaultdict(Counter)  # value frequencies per field (for top N)
+
+        for item in data:
+            for k, v in item.items():
+                field_values[k].append(v)
+                # Count truncated values (for the top N display)
+                displayable = _truncate(v if v is not None else "", 30)
+                field_counters[k][displayable] += 1
+
+        # Compute the statistics from the collected data
+        stats_list = []
+        for field in sorted(field_values.keys()):
+            values = field_values[field]
+            non_null = [v for v in values if v is not None and v != ""]
+            non_null_count = len(non_null)
+
+            # Infer the type (from the first non-null value)
+            field_type = _infer_type(non_null)
+
+            # Basic statistics
+            stat = {
+                "field": field,
+                "non_null": non_null_count,
+                "null_rate": f"{non_null_count / total * 100:.1f}%",
+                "type": field_type,
+            }
+
+            # Type-specific statistics
+            if non_null:
+                # Unique-value count (complex types are hashed to save memory)
+                stat["unique"] = _count_unique(non_null, field_type)
+
+                # Strings: length statistics
+                if field_type == "str":
+                    lengths = [len(str(v)) for v in non_null]
+                    stat["len_min"] = min(lengths)
+                    stat["len_max"] = max(lengths)
+                    stat["len_avg"] = sum(lengths) / len(lengths)
+
+                # Numbers: numeric statistics
+                elif field_type in ("int", "float"):
+                    nums = [float(v) for v in non_null if _is_numeric(v)]
+                    if nums:
+                        stat["min"] = min(nums)
+                        stat["max"] = max(nums)
+                        stat["avg"] = sum(nums) / len(nums)
+
+                # Lists: length statistics
+                elif field_type == "list":
+                    lengths = [len(v) if isinstance(v, list) else 0 for v in non_null]
+                    stat["len_min"] = min(lengths)
+                    stat["len_max"] = max(lengths)
+                    stat["len_avg"] = sum(lengths) / len(lengths)
+
+                # Top N values (already collected during the pass)
+                stat["top_values"] = field_counters[field].most_common(top)
+
+            stats_list.append(stat)
+
+        return stats_list
+
+    # Fields were specified: collect statistics for just those fields
+    stats_list = []
+    expand_set = set(expand_fields) if expand_fields else set()
 
-    for item in data:
-        for k, v in item.items():
-            field_values[k].append(v)
-            # Count truncated values (for the top N display)
-            displayable = _truncate(v if v is not None else "", 30)
-            field_counters[k][displayable] += 1
+    # Merge the two field lists
+    all_fields = set(fields) if fields else set()
+    all_fields.update(expand_set)
 
-    # Compute the statistics from the collected data
-    stats_list = []
-    for field in sorted(field_values.keys()):
-        values = field_values[field]
+    for field_spec in sorted(all_fields):
+        is_expanded = field_spec in expand_set
+
+        # Extract the field's values
+        values = _extract_field_values(data, field_spec, expand=is_expanded)
+
+        # Filter out None and empty values
         non_null = [v for v in values if v is not None and v != ""]
         non_null_count = len(non_null)
 
-        # Infer the type (from the first non-null value)
+        # Infer the type
        field_type = _infer_type(non_null)
 
         # Basic statistics
-        stat = {
-            "field": field,
-            "non_null": non_null_count,
-            "null_rate": f"{(total - non_null_count) / total * 100:.1f}%",
-            "type": field_type,
-        }
+        if is_expanded:
+            # Expand mode: report the total element count instead of a non-null rate
+            stat = {
+                "field": field_spec,
+                "non_null": non_null_count,
+                "null_rate": f"total elements: {len(values)}",
+                "type": field_type,
+                "is_expanded": is_expanded,
+            }
+        else:
+            # Normal mode: report the non-null rate
+            stat = {
+                "field": field_spec,
+                "non_null": non_null_count,
+                "null_rate": f"{non_null_count / total * 100:.1f}%",
+                "type": field_type,
+                "is_expanded": is_expanded,
+            }
 
         # Type-specific statistics
         if non_null:
-            # Unique-value count (complex types are hashed to save memory)
+            # Unique-value count
             stat["unique"] = _count_unique(non_null, field_type)
 
             # Strings: length statistics
@@ -274,8 +448,12 @@ def _compute_field_stats(data: List[Dict], top: int) -> List[Dict[str, Any]]:
                 stat["len_max"] = max(lengths)
                 stat["len_avg"] = sum(lengths) / len(lengths)
 
-            # Top N values (already collected during the pass)
-            stat["top_values"] = field_counters[field].most_common(top)
+            # Top N values (recounted here)
+            counter = Counter()
+            for v in non_null:
+                displayable = _truncate(v if v is not None else "", 30)
+                counter[displayable] += 1
+            stat["top_values"] = counter.most_common(top)
 
         stats_list.append(stat)
 
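Note: the two stat variants built above overload the null_rate key: in normal mode it carries a non-null percentage, while in expand mode it carries a total-element count. A sketch of the two dict shapes (field names match the diff; the counts are invented):

# Normal mode: null_rate holds the percentage of non-empty records (hypothetical values)
stat_plain = {
    "field": "category",
    "non_null": 980,
    "null_rate": "98.0%",
    "type": "str",
    "is_expanded": False,
}

# Expand mode: null_rate is reused to carry the element count (hypothetical values)
stat_expanded = {
    "field": "tags",
    "non_null": 3214,
    "null_rate": "total elements: 3250",
    "type": "str",
    "is_expanded": True,
}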
@@ -343,9 +521,18 @@ def _print_stats(filename: str, total: int, field_stats: List[Dict[str, Any]]) -
     table.add_column("Stats", style="dim")
 
     for stat in field_stats:
-        non_null_rate = f"{stat['non_null'] / total * 100:.0f}%"
+        # Use null_rate from the stat dict (supports the special expand-mode display)
+        if "null_rate" in stat:
+            non_null_rate = stat["null_rate"]
+        else:
+            non_null_rate = f"{stat['non_null'] / total * 100:.0f}%"
         unique = str(stat.get("unique", "-"))
 
+        # Field name (with an expand marker)
+        field_name = stat["field"]
+        if stat.get("is_expanded"):
+            field_name += " (expanded)"
+
         # Build the stats summary string
         extra = []
         if "len_avg" in stat:
@@ -363,7 +550,7 @@ def _print_stats(filename: str, total: int, field_stats: List[Dict[str, Any]]) -
         )
 
         table.add_row(
-            stat["field"],
+            field_name,
             stat["type"],
             non_null_rate,
             unique,
@@ -387,12 +574,19 @@ def _print_stats(filename: str, total: int, field_stats: List[Dict[str, Any]]) -
         if unique_ratio > 0.9 and stat.get("unique", 0) > 100:
             continue
 
+        # Field name (with an expand marker)
+        field_display = stat["field"]
+        if stat.get("is_expanded"):
+            field_display += " (expanded)"
+
         console.print(
-            f"\n[bold cyan]{stat['field']}[/bold cyan] value distribution (Top {len(top_values)}):"
+            f"\n[bold cyan]{field_display}[/bold cyan] value distribution (Top {len(top_values)}):"
         )
         max_count = max(c for _, c in top_values) if top_values else 1
+        # In expand mode the base is non_null (element count); otherwise total (record count)
+        base_count = stat["non_null"] if stat.get("is_expanded") else total
         for value, count in top_values:
-            pct = count / total * 100
+            pct = count / base_count * 100 if base_count > 0 else 0
             bar_len = int(count / max_count * 20)  # relative scale, 20 chars max
             bar = "█" * bar_len
             display_value = value if value else "[empty]"
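Note: the base_count switch above changes what the distribution percentages mean for list fields, where one record can contribute several elements. A small arithmetic sketch with invented numbers:

# 1,000 records; "tags" expands to 3,250 elements, 900 of them "python"
total, elements, count = 1000, 3250, 900

pct_vs_records  = count / total * 100     # 90.0; for list fields this can even exceed 100
pct_vs_elements = count / elements * 100  # about 27.7: share of all tag elements (expand mode)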
dtflow/eval.py ADDED
@@ -0,0 +1,276 @@
+"""
+Evaluation metrics module.
+
+Metric computation and report export for classification tasks:
+- MetricsCalculator: computes accuracy/precision/recall/F1/confusion matrix
+- export_eval_report: writes metrics.md + result.jsonl + bad_case.jsonl
+
+Dependencies: scikit-learn, pandas
+Install: pip install dtflow[eval]
+"""
+
+import os
+from datetime import datetime
+from pathlib import Path
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from pandas import DataFrame
+
+
+def _check_eval_deps():
+    """Check that the eval dependencies are installed."""
+    try:
+        import pandas  # noqa: F401
+        import sklearn  # noqa: F401
+    except ImportError as e:
+        missing = str(e).split("'")[1] if "'" in str(e) else str(e)
+        raise ImportError(
+            f"The eval feature needs extra dependencies: {missing}\n" f"Please run: pip install dtflow[eval]"
+        ) from e
+
+
+class MetricsCalculator:
+    """Classification metrics calculator.
+
+    Uses sklearn to compute accuracy/precision/recall/F1, the confusion matrix, and the classification report.
+
+    Args:
+        df: DataFrame containing the prediction and label columns
+        pred_col: Name of the prediction column
+        label_col: Name of the label column
+        include_macro_micro_avg: Whether to include macro/micro averages in the report
+        remove_matrix_zero_row: Whether to drop confusion-matrix rows whose support is 0
+    """
+
+    def __init__(
+        self,
+        df: "DataFrame",
+        pred_col: str = "predict",
+        label_col: str = "label",
+        include_macro_micro_avg: bool = False,
+        remove_matrix_zero_row: bool = False,
+    ):
+        _check_eval_deps()
+        self.df = df
+        self.y_pred = df[pred_col]
+        self.y_true = df[label_col]
+        self.all_labels = sorted(set(self.y_true.unique()).union(set(self.y_pred.unique())))
+        self.needed_labels = None
+        self.remove_matrix_zero_row = remove_matrix_zero_row
+        self.include_macro_micro_avg = include_macro_micro_avg
+        self.metrics = self._calculate_metrics()
+
+    def _calculate_metrics(self):
+        from sklearn.metrics import (
+            accuracy_score,
+            classification_report,
+            confusion_matrix,
+            precision_score,
+            recall_score,
+        )
+
+        accuracy = accuracy_score(self.y_true, self.y_pred)
+        precision = precision_score(
+            self.y_true, self.y_pred, labels=self.all_labels, average="weighted", zero_division=0
+        )
+        recall = recall_score(
+            self.y_true, self.y_pred, labels=self.all_labels, average="weighted", zero_division=0
+        )
+        conf_matrix = confusion_matrix(self.y_true, self.y_pred, labels=self.all_labels)
+        report = classification_report(
+            self.y_true, self.y_pred, labels=self.all_labels, output_dict=True, zero_division=0
+        )
+
+        # By default keep only the weighted average
+        if not self.include_macro_micro_avg:
+            report = {
+                label: metrics
+                for label, metrics in report.items()
+                if label in self.all_labels or label == "weighted avg"
+            }
+
+        # Drop classes with support=0 (note: accuracy is a float, not a dict)
+        report = {
+            label: metrics
+            for label, metrics in report.items()
+            if isinstance(metrics, dict) and metrics.get("support", 0) > 0
+        }
+
+        self.needed_labels = [label for label in report.keys() if label in self.all_labels]
+
+        # Optionally drop unneeded rows from the confusion matrix
+        needed_idx_list = [self.all_labels.index(label) for label in self.needed_labels]
+        if self.remove_matrix_zero_row:
+            conf_matrix = conf_matrix[needed_idx_list]
+
+        return {
+            "accuracy": accuracy,
+            "precision": precision,
+            "recall": recall,
+            "confusion_matrix": conf_matrix,
+            "classification_report": report,
+        }
+
+    def get_metrics(self):
+        return self.metrics
+
+    def format_classification_report_as_markdown(self):
+        """Format the classification report as a Markdown table."""
+        report = self.metrics["classification_report"]
+        header = "| Label | Precision | Recall | F1-score | Support |\n"
+        separator = "|-------|-----------|--------|----------|---------|\n"
+        rows = []
+        for label, metrics in report.items():
+            if isinstance(metrics, dict):
+                rows.append(
+                    f"| {label} | {metrics['precision']:.2f} | {metrics['recall']:.2f} "
+                    f"| {metrics['f1-score']:.2f} | {metrics['support']:.0f} |"
+                )
+        return header + separator + "\n".join(rows)
+
+    def _clean_label_for_markdown(self, label, max_length=20):
+        """Sanitize a label so it renders safely inside a Markdown table."""
+        label = str(label).replace("\n", " ")
+        label = label.replace("|", "\\|")
+        label = label.replace("-", "\\-")
+        label = label.replace("<", "&lt;")
+        label = label.replace(">", "&gt;")
+        if len(label) > max_length:
+            label = label[:max_length] + "..."
+        label = label.strip()
+        if not label:
+            label = "(empty)"
+        return label
+
+    def format_confusion_matrix_as_markdown(self, max_label_length=20):
+        """Format the confusion matrix as a Markdown table."""
+        matrix = self.metrics["confusion_matrix"]
+
+        if self.remove_matrix_zero_row:
+            labels = self.needed_labels
+        else:
+            labels = self.all_labels
+
+        processed_labels = [self._clean_label_for_markdown(lb, max_label_length) for lb in labels]
+
+        header = "| True / Predicted | " + " | ".join(processed_labels) + " |\n"
+        separator_parts = [":---:"] * (len(processed_labels) + 1)
+        separator = "| " + " | ".join(separator_parts) + " |\n"
+
+        rows = []
+        for i, row in enumerate(matrix):
+            row_label = self._clean_label_for_markdown(labels[i], max_label_length)
+            formatted_row = [f"{num:,}" for num in row]
+            rows.append(f"| {row_label} | " + " | ".join(formatted_row) + " |")
+
+        return header + separator + "\n".join(rows)
+
+
+def export_eval_report(
+    df: "DataFrame",
+    pred_col: str,
+    label_col: str,
+    record_folder: str = "record",
+    input_name: Optional[str] = None,
+):
+    """Generate an evaluation report and save it under the given directory.
+
+    Output files:
+    - metrics.md: metrics overview + classification report + confusion matrix
+    - result.jsonl: full prediction results
+    - bad_case.jsonl: misclassified samples
+
+    Args:
+        df: DataFrame containing predictions and labels
+        pred_col: Name of the prediction column
+        label_col: Name of the label column
+        record_folder: Output root directory
+        input_name: Input file name (used to name the subdirectory)
+    """
+    from rich.console import Console
+    from rich.markdown import Markdown
+
+    calculator = MetricsCalculator(df, pred_col=pred_col, label_col=label_col)
+    metrics = calculator.get_metrics()
+
+    # Build the metrics overview with a Rich Table (replaces tabulate)
+    from rich.table import Table
+
+    overview_table = Table(title="Metrics Overview", show_header=True)
+    overview_table.add_column("Accuracy", justify="center")
+    overview_table.add_column("Precision", justify="center")
+    overview_table.add_column("Recall", justify="center")
+    overview_table.add_row(
+        f"{metrics['accuracy']:.4f}",
+        f"{metrics['precision']:.4f}",
+        f"{metrics['recall']:.4f}",
+    )
+
+    # Build the Markdown report body
+    md = (
+        f"\n\n### Metrics Overview\n\n"
+        f"| Accuracy | Precision | Recall |\n"
+        f"|----------|-----------|--------|\n"
+        f"| {metrics['accuracy']:.4f} | {metrics['precision']:.4f} | {metrics['recall']:.4f} |"
+    )
+    metrics_md = calculator.format_classification_report_as_markdown()
+    confusion_md = calculator.format_confusion_matrix_as_markdown()
+    md += f"\n\n### Classification Report\n{metrics_md}\n" f"\n### Confusion Matrix\n{confusion_md}"
+
+    # Create the output directory (indexed and timestamped)
+    now = datetime.now().strftime("%Y%m%d-%H-%M-%S")
+    record_path = Path(record_folder)
+    if input_name:
+        record_path = record_path / input_name
+
+    if record_path.exists():
+        existing = [d.name for d in record_path.iterdir() if d.is_dir()]
+        max_idx = 0
+        for name in existing:
+            parts = name.split("-", 1)
+            if parts[0].isdigit():
+                max_idx = max(max_idx, int(parts[0]))
+        idx = max_idx + 1
+    else:
+        idx = 1
+
+    record_path = record_path / f"{idx}-{now}"
+    record_path.mkdir(parents=True, exist_ok=True)
+
+    # Terminal output
+    console = Console()
+    console.print(overview_table)
+    console.print(Markdown(md))
+
+    # Save the report file
+    with open(os.path.join(record_path, "metrics.md"), "w", encoding="utf-8") as f:
+        f.write(md)
+
+    bad_case_df = df[df[pred_col] != df[label_col]]
+
+    # Save JSONL
+    df.to_json(
+        os.path.join(record_path, "result.jsonl"),
+        orient="records",
+        lines=True,
+        force_ascii=False,
+    )
+    bad_case_df.to_json(
+        os.path.join(record_path, "bad_case.jsonl"),
+        orient="records",
+        lines=True,
+        force_ascii=False,
+    )
+
+    # Try to also save CSV
+    try:
+        df.to_csv(os.path.join(record_path, "result.csv"), index=False)
+        bad_case_df.to_csv(os.path.join(record_path, "bad_case.csv"), index=False)
+    except Exception:
+        pass
+
+    console.print(f"\n[green]Report saved to: {record_path}[/green]")
+    console.print(f"[dim]  - metrics.md ({len(df)} records, {len(bad_case_df)} errors)[/dim]")
+
+    return record_path
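Note: a hedged end-to-end usage sketch for the new module (the dtflow.eval import path follows the file location above; the column names match the defaults, and the data is invented). Assumes pip install dtflow[eval]:

import pandas as pd

from dtflow.eval import MetricsCalculator, export_eval_report

# Toy data: four predictions, one of them wrong (invented for illustration)
df = pd.DataFrame({
    "label":   ["pos", "neg", "pos", "neu"],
    "predict": ["pos", "pos", "pos", "neu"],
})

# Metrics only
calc = MetricsCalculator(df, pred_col="predict", label_col="label")
print(calc.get_metrics()["accuracy"])  # 0.75

# Full report: writes metrics.md / result.jsonl / bad_case.jsonl
# under record/demo/<idx>-<timestamp>/
out_dir = export_eval_report(df, pred_col="predict", label_col="label", input_name="demo")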