dtflow-0.5.8-py3-none-any.whl → dtflow-0.5.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dtflow/cli/stats.py CHANGED
@@ -209,7 +209,7 @@ def _quick_stats(filepath: Path) -> None:
     print(f"字段: {len(fields)} 个")
 
     if fields:
-        print("\n📋 字段结构:")
+        print(f"\n📋 字段结构:")
         for i, f in enumerate(fields, 1):
             print(f"  {i}. {f['field']} ({f['type']})")
 
@@ -597,14 +597,14 @@ def _print_stats(filename: str, total: int, field_stats: List[Dict[str, Any]]) -
     except ImportError:
         # 没有 rich,使用普通打印
         print(f"\n{'=' * 50}")
-        print("📊 数据概览")
+        print(f"📊 数据概览")
         print(f"{'=' * 50}")
         print(f"文件: {filename}")
         print(f"总数: {total:,} 条")
         print(f"字段: {len(field_stats)} 个")
 
         print(f"\n{'=' * 50}")
-        print("📋 字段统计")
+        print(f"📋 字段统计")
         print(f"{'=' * 50}")
         print(f"{'字段':<20} {'类型':<8} {'非空率':<8} {'唯一值':<8}")
         print("-" * 50)
@@ -620,7 +620,6 @@ def token_stats(
     field: str = "messages",
     model: str = "cl100k_base",
     detailed: bool = False,
-    workers: Optional[int] = None,
 ) -> None:
     """
     统计数据集的 Token 信息。
@@ -630,7 +629,6 @@ def token_stats(
         field: 要统计的字段(默认 messages),支持嵌套路径语法
         model: 分词器: cl100k_base (默认), qwen2.5, llama3, gpt-4 等
         detailed: 是否显示详细统计
-        workers: 并行进程数,None 自动检测,1 禁用并行
 
     Examples:
         dt token-stats data.jsonl
@@ -638,7 +636,6 @@ def token_stats(
         dt token-stats data.jsonl --field=conversation.messages
         dt token-stats data.jsonl --field=messages[-1].content  # 统计最后一条消息
         dt token-stats data.jsonl --detailed
-        dt token-stats data.jsonl --workers=4  # 使用 4 进程
     """
     filepath = Path(filename)
 
@@ -670,7 +667,7 @@ def token_stats(
 
     # 尝试使用 rich 进度条
     try:
-        from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
+        from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn
 
         with Progress(
             SpinnerColumn(),
@@ -688,22 +685,14 @@ def token_stats(
                 from ..tokenizers import messages_token_stats
 
                 stats_result = messages_token_stats(
-                    data,
-                    messages_field=field,
-                    model=model,
-                    progress_callback=update_progress,
-                    workers=workers,
+                    data, messages_field=field, model=model, progress_callback=update_progress
                 )
                 _print_messages_token_stats(stats_result, detailed)
             else:
                 from ..tokenizers import token_stats as compute_token_stats
 
                 stats_result = compute_token_stats(
-                    data,
-                    fields=field,
-                    model=model,
-                    progress_callback=update_progress,
-                    workers=workers,
+                    data, fields=field, model=model, progress_callback=update_progress
                 )
                 _print_text_token_stats(stats_result, detailed)
 
@@ -714,14 +703,12 @@ def token_stats(
         if isinstance(field_value, list) and field_value and isinstance(field_value[0], dict):
             from ..tokenizers import messages_token_stats
 
-            stats_result = messages_token_stats(
-                data, messages_field=field, model=model, workers=workers
-            )
+            stats_result = messages_token_stats(data, messages_field=field, model=model)
             _print_messages_token_stats(stats_result, detailed)
         else:
             from ..tokenizers import token_stats as compute_token_stats
 
-            stats_result = compute_token_stats(data, fields=field, model=model, workers=workers)
+            stats_result = compute_token_stats(data, fields=field, model=model)
             _print_text_token_stats(stats_result, detailed)
     except ImportError as e:
         print(f"错误: {e}")
@@ -801,7 +788,7 @@ def _print_messages_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
     print(f"平均 Token: {stats['avg_tokens']:,} (std: {std:.1f})")
     print(f"范围: {stats['min_tokens']:,} - {stats['max_tokens']:,}")
 
-    print("\n📈 百分位分布:")
+    print(f"\n📈 百分位分布:")
     print(f"  P25: {stats.get('p25', '-'):,}   P50: {stats.get('median_tokens', '-'):,}")
     print(f"  P75: {stats.get('p75', '-'):,}   P90: {stats.get('p90', '-'):,}")
     print(f"  P95: {stats.get('p95', '-'):,}   P99: {stats.get('p99', '-'):,}")
@@ -868,7 +855,7 @@ def _print_text_token_stats(stats: Dict[str, Any], detailed: bool) -> None:
     print(f"平均 Token: {stats['avg_tokens']:.1f} (std: {std:.1f})")
    print(f"范围: {stats['min_tokens']:,} - {stats['max_tokens']:,}")
 
-    print("\n📈 百分位分布:")
+    print(f"\n📈 百分位分布:")
     print(f"  P25: {stats.get('p25', '-'):,}   P50: {stats.get('median_tokens', '-'):,}")
     print(f"  P75: {stats.get('p75', '-'):,}   P90: {stats.get('p90', '-'):,}")
     print(f"  P95: {stats.get('p95', '-'):,}   P99: {stats.get('p99', '-'):,}")
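Note: both `messages_token_stats` and `compute_token_stats` lose their `workers=` keyword in this release. A minimal sketch of the 0.5.9 call shape, assuming the sample records and the `on_progress` helper are illustrative and a tokenizer backend (e.g. tiktoken for `cl100k_base`) is installed:

```python
from dtflow.tokenizers import token_stats as compute_token_stats

# Illustrative records; `fields` takes the same path syntax as `dt token-stats --field=...`.
data = [{"text": "hello world"}, {"text": "dtflow token stats"}]

def on_progress(current: int, total: int) -> None:
    # Same (current, total) shape the CLI wires into its rich progress bar.
    print(f"{current}/{total}")

# 0.5.9 signature: no `workers=` keyword; progress_callback remains optional.
stats = compute_token_stats(
    data, fields="text", model="cl100k_base", progress_callback=on_progress
)
print(stats["avg_tokens"], stats["max_tokens"])
```

Callers upgrading from 0.5.8 need to drop the `workers=` keyword; passing it to the 0.5.9 functions raises a `TypeError`.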
dtflow/cli/validate.py CHANGED
@@ -6,6 +6,8 @@ from pathlib import Path
 from typing import Optional
 
 from ..schema import (
+    Schema,
+    Field,
     alpaca_schema,
     dpo_schema,
     openai_chat_schema,
@@ -14,6 +16,7 @@ from ..schema import (
 from ..storage.io import load_data, save_data
 from .common import _check_file_format
 
+
 # 预设 Schema 映射
 PRESET_SCHEMAS = {
     "openai_chat": openai_chat_schema,
@@ -33,7 +36,6 @@ def validate(
     filter_invalid: bool = False,
     max_errors: int = 20,
     verbose: bool = False,
-    workers: Optional[int] = None,
 ) -> None:
     """
     使用 Schema 验证数据文件。
@@ -45,13 +47,11 @@ def validate(
         filter_invalid: 过滤无效数据并保存
         max_errors: 最多显示的错误数量
         verbose: 显示详细信息
-        workers: 并行进程数,None 自动检测,1 禁用并行
 
     Examples:
         dt validate data.jsonl --preset=openai_chat
         dt validate data.jsonl --preset=alpaca -o valid.jsonl
         dt validate data.jsonl --preset=chat --filter
-        dt validate data.jsonl --preset=chat --workers=4
     """
     filepath = Path(filename)
 
@@ -99,54 +99,19 @@ def validate(
     print(f"总记录数: {total}")
     print()
 
-    # 验证(使用并行或串行)
-    use_parallel = workers != 1 and total >= 1000
-
-    if use_parallel:
-        # 使用进度条(如果有 rich)
-        try:
-            from rich.progress import (
-                BarColumn,
-                Progress,
-                SpinnerColumn,
-                TaskProgressColumn,
-                TextColumn,
-            )
-
-            with Progress(
-                SpinnerColumn(),
-                TextColumn("[bold blue]验证数据"),
-                BarColumn(),
-                TaskProgressColumn(),
-            ) as progress:
-                task = progress.add_task("", total=total)
-
-                def update_progress(current: int, total_count: int):
-                    progress.update(task, completed=current)
-
-                valid_data, invalid_results = schema.validate_parallel(
-                    data, workers=workers, progress_callback=update_progress
-                )
-        except ImportError:
-            print("🔍 验证数据...")
-            valid_data, invalid_results = schema.validate_parallel(data, workers=workers)
-
-        invalid_count = len(invalid_results)
-        error_samples = invalid_results[:max_errors]
-    else:
-        # 串行验证
-        valid_data = []
-        invalid_count = 0
-        error_samples = []
-
-        for i, item in enumerate(data):
-            result = schema.validate(item)
-            if result.valid:
-                valid_data.append(item)
-            else:
-                invalid_count += 1
-                if len(error_samples) < max_errors:
-                    error_samples.append((i, result))
+    # 验证
+    valid_data = []
+    invalid_count = 0
+    error_samples = []
+
+    for i, item in enumerate(data):
+        result = schema.validate(item)
+        if result.valid:
+            valid_data.append(item)
+        else:
+            invalid_count += 1
+            if len(error_samples) < max_errors:
+                error_samples.append((i, result))
 
     valid_count = len(valid_data)
     valid_ratio = valid_count / total * 100 if total > 0 else 0
@@ -173,7 +138,9 @@ def validate(
 
     # 保存有效数据
     if output or filter_invalid:
-        output_path = output or str(filepath).replace(filepath.suffix, f"_valid{filepath.suffix}")
+        output_path = output or str(filepath).replace(
+            filepath.suffix, f"_valid{filepath.suffix}"
+        )
         save_data(valid_data, output_path)
         print(f"✅ 有效数据已保存: {output_path} ({valid_count} 条)")
 
dtflow/eval.py ADDED
@@ -0,0 +1,276 @@
+"""
+评估指标计算模块
+
+提供分类任务的指标计算和评估报告导出:
+- MetricsCalculator: 计算 accuracy/precision/recall/F1/混淆矩阵
+- export_eval_report: 生成 metrics.md + result.jsonl + bad_case.jsonl
+
+依赖: scikit-learn, pandas
+安装: pip install dtflow[eval]
+"""
+
+import os
+from datetime import datetime
+from pathlib import Path
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from pandas import DataFrame
+
+
+def _check_eval_deps():
+    """检查 eval 依赖是否已安装"""
+    try:
+        import pandas  # noqa: F401
+        import sklearn  # noqa: F401
+    except ImportError as e:
+        missing = str(e).split("'")[1] if "'" in str(e) else str(e)
+        raise ImportError(
+            f"eval 功能需要额外依赖: {missing}\n" f"请运行: pip install dtflow[eval]"
+        ) from e
+
+
+class MetricsCalculator:
+    """分类指标计算器
+
+    基于 sklearn 计算 accuracy/precision/recall/F1/混淆矩阵/分类报告。
+
+    Args:
+        df: 包含预测列和标签列的 DataFrame
+        pred_col: 预测值列名
+        label_col: 标签值列名
+        include_macro_micro_avg: 是否在报告中包含 macro/micro 平均
+        remove_matrix_zero_row: 是否移除混淆矩阵中 support=0 的行
+    """
+
+    def __init__(
+        self,
+        df: "DataFrame",
+        pred_col: str = "predict",
+        label_col: str = "label",
+        include_macro_micro_avg: bool = False,
+        remove_matrix_zero_row: bool = False,
+    ):
+        _check_eval_deps()
+        self.df = df
+        self.y_pred = df[pred_col]
+        self.y_true = df[label_col]
+        self.all_labels = sorted(set(self.y_true.unique()).union(set(self.y_pred.unique())))
+        self.needed_labels = None
+        self.remove_matrix_zero_row = remove_matrix_zero_row
+        self.include_macro_micro_avg = include_macro_micro_avg
+        self.metrics = self._calculate_metrics()
+
+    def _calculate_metrics(self):
+        from sklearn.metrics import (
+            accuracy_score,
+            classification_report,
+            confusion_matrix,
+            precision_score,
+            recall_score,
+        )
+
+        accuracy = accuracy_score(self.y_true, self.y_pred)
+        precision = precision_score(
+            self.y_true, self.y_pred, labels=self.all_labels, average="weighted", zero_division=0
+        )
+        recall = recall_score(
+            self.y_true, self.y_pred, labels=self.all_labels, average="weighted", zero_division=0
+        )
+        conf_matrix = confusion_matrix(self.y_true, self.y_pred, labels=self.all_labels)
+        report = classification_report(
+            self.y_true, self.y_pred, labels=self.all_labels, output_dict=True, zero_division=0
+        )
+
+        # 默认只保留加权平均
+        if not self.include_macro_micro_avg:
+            report = {
+                label: metrics
+                for label, metrics in report.items()
+                if label in self.all_labels or label == "weighted avg"
+            }
+
+        # 去除 support=0 的类别(注意 accuracy 是 float 不是 dict)
+        report = {
+            label: metrics
+            for label, metrics in report.items()
+            if isinstance(metrics, dict) and metrics.get("support", 0) > 0
+        }
+
+        self.needed_labels = [label for label in report.keys() if label in self.all_labels]
+
+        # 可选移除混淆矩阵中不需要的行
+        needed_idx_list = [self.all_labels.index(label) for label in self.needed_labels]
+        if self.remove_matrix_zero_row:
+            conf_matrix = conf_matrix[needed_idx_list]
+
+        return {
+            "accuracy": accuracy,
+            "precision": precision,
+            "recall": recall,
+            "confusion_matrix": conf_matrix,
+            "classification_report": report,
+        }
+
+    def get_metrics(self):
+        return self.metrics
+
+    def format_classification_report_as_markdown(self):
+        """将分类报告格式化为 Markdown 表格"""
+        report = self.metrics["classification_report"]
+        header = "| Label | Precision | Recall | F1-score | Support |\n"
+        separator = "|-------|-----------|--------|----------|---------|\n"
+        rows = []
+        for label, metrics in report.items():
+            if isinstance(metrics, dict):
+                rows.append(
+                    f"| {label} | {metrics['precision']:.2f} | {metrics['recall']:.2f} "
+                    f"| {metrics['f1-score']:.2f} | {metrics['support']:.0f} |"
+                )
+        return header + separator + "\n".join(rows)
+
+    def _clean_label_for_markdown(self, label, max_length=20):
+        """清理标签文本,使其适合 Markdown 表格显示"""
+        label = str(label).replace("\n", " ")
+        label = label.replace("|", "\\|")
+        label = label.replace("-", "\\-")
+        label = label.replace("<", "&lt;")
+        label = label.replace(">", "&gt;")
+        if len(label) > max_length:
+            label = label[:max_length] + "..."
+        label = label.strip()
+        if not label:
+            label = "(empty)"
+        return label
+
+    def format_confusion_matrix_as_markdown(self, max_label_length=20):
+        """将混淆矩阵格式化为 Markdown 表格"""
+        matrix = self.metrics["confusion_matrix"]
+
+        if self.remove_matrix_zero_row:
+            labels = self.needed_labels
+        else:
+            labels = self.all_labels
+
+        processed_labels = [self._clean_label_for_markdown(lb, max_label_length) for lb in labels]
+
+        header = "| 真实值/预测值 | " + " | ".join(processed_labels) + " |\n"
+        separator_parts = [":---:"] * (len(processed_labels) + 1)
+        separator = "| " + " | ".join(separator_parts) + " |\n"
+
+        rows = []
+        for i, row in enumerate(matrix):
+            row_label = self._clean_label_for_markdown(labels[i], max_label_length)
+            formatted_row = [f"{num:,}" for num in row]
+            rows.append(f"| {row_label} | " + " | ".join(formatted_row) + " |")
+
+        return header + separator + "\n".join(rows)
+
+
+def export_eval_report(
+    df: "DataFrame",
+    pred_col: str,
+    label_col: str,
+    record_folder: str = "record",
+    input_name: Optional[str] = None,
+):
+    """生成评估报告并保存到指定目录
+
+    输出文件:
+    - metrics.md: 指标概览 + 分类报告 + 混淆矩阵
+    - result.jsonl: 完整预测结果
+    - bad_case.jsonl: 预测错误样本
+
+    Args:
+        df: 包含预测和标签的 DataFrame
+        pred_col: 预测值列名
+        label_col: 标签值列名
+        record_folder: 输出根目录
+        input_name: 输入文件名(用于子目录命名)
+    """
+    from rich.console import Console
+    from rich.markdown import Markdown
+
+    calculator = MetricsCalculator(df, pred_col=pred_col, label_col=label_col)
+    metrics = calculator.get_metrics()
+
+    # 用 Rich Table 构建指标概览(替代 tabulate)
+    from rich.table import Table
+
+    overview_table = Table(title="指标概览", show_header=True)
+    overview_table.add_column("Accuracy", justify="center")
+    overview_table.add_column("Precision", justify="center")
+    overview_table.add_column("Recall", justify="center")
+    overview_table.add_row(
+        f"{metrics['accuracy']:.4f}",
+        f"{metrics['precision']:.4f}",
+        f"{metrics['recall']:.4f}",
+    )
+
+    # 构建 Markdown 报告内容
+    md = (
+        f"\n\n### 指标概览\n\n"
+        f"| Accuracy | Precision | Recall |\n"
+        f"|----------|-----------|--------|\n"
+        f"| {metrics['accuracy']:.4f} | {metrics['precision']:.4f} | {metrics['recall']:.4f} |"
+    )
+    metrics_md = calculator.format_classification_report_as_markdown()
+    confusion_md = calculator.format_confusion_matrix_as_markdown()
+    md += f"\n\n### Classification Report\n{metrics_md}\n" f"\n### Confusion Matrix\n{confusion_md}"
+
+    # 创建输出目录(带序号和时间戳)
+    now = datetime.now().strftime("%Y%m%d-%H-%M-%S")
+    record_path = Path(record_folder)
+    if input_name:
+        record_path = record_path / input_name
+
+    if record_path.exists():
+        existing = [d.name for d in record_path.iterdir() if d.is_dir()]
+        max_idx = 0
+        for name in existing:
+            parts = name.split("-", 1)
+            if parts[0].isdigit():
+                max_idx = max(max_idx, int(parts[0]))
+        idx = max_idx + 1
+    else:
+        idx = 1
+
+    record_path = record_path / f"{idx}-{now}"
+    record_path.mkdir(parents=True, exist_ok=True)
+
+    # 终端输出
+    console = Console()
+    console.print(overview_table)
+    console.print(Markdown(md))
+
+    # 保存文件
+    with open(os.path.join(record_path, "metrics.md"), "w", encoding="utf-8") as f:
+        f.write(md)
+
+    bad_case_df = df[df[pred_col] != df[label_col]]
+
+    # 保存 JSONL
+    df.to_json(
+        os.path.join(record_path, "result.jsonl"),
+        orient="records",
+        lines=True,
+        force_ascii=False,
+    )
+    bad_case_df.to_json(
+        os.path.join(record_path, "bad_case.jsonl"),
+        orient="records",
+        lines=True,
+        force_ascii=False,
+    )
+
+    # 尝试保存 CSV
+    try:
+        df.to_csv(os.path.join(record_path, "result.csv"), index=False)
+        bad_case_df.to_csv(os.path.join(record_path, "bad_case.csv"), index=False)
+    except Exception:
+        pass
+
+    console.print(f"\n[green]报告已保存到: {record_path}[/green]")
+    console.print(f"[dim]  - metrics.md ({len(df)} 条数据, {len(bad_case_df)} 条错误)[/dim]")
+
+    return record_path
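The new module is self-contained apart from its optional dependencies. A minimal smoke test, assuming `dtflow[eval]` and rich are installed and using the default `predict`/`label` column names; the toy DataFrame is illustrative:

```python
import pandas as pd

from dtflow.eval import MetricsCalculator, export_eval_report

# Toy data: one of four predictions is wrong.
df = pd.DataFrame(
    {
        "predict": ["cat", "dog", "cat", "bird"],
        "label": ["cat", "dog", "dog", "bird"],
    }
)

calc = MetricsCalculator(df)  # defaults: pred_col="predict", label_col="label"
print(calc.get_metrics()["accuracy"])  # 0.75 on this toy data
print(calc.format_confusion_matrix_as_markdown())

# Writes metrics.md, result.jsonl, bad_case.jsonl (plus CSVs when they
# serialize cleanly) under record/<n>-<timestamp>/.
export_eval_report(df, pred_col="predict", label_col="label")
```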
dtflow/schema.py CHANGED
@@ -26,35 +26,10 @@ Schema 验证模块
     results = dt.validate_schema(schema)
 """
 
-from dataclasses import dataclass
-from dataclasses import field as dataclass_field
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
-
-from .utils.field_path import _parse_path, get_field
-
-
-def _validate_item_wrapper(args: tuple) -> Tuple[int, bool, list]:
-    """
-    验证单条数据(用于多进程)。
-
-    Args:
-        args: (index, item, schema_fields) 元组
-
-    Returns:
-        (index, is_valid, errors_as_dicts) - 返回字典列表而非对象(pickle 兼容)
-    """
-    idx, item, fields = args
-    # 在子进程中重建 Schema
-    schema = Schema(fields)
-    result = schema.validate(item)
-
-    if result.valid:
-        return (idx, True, [])
-    else:
-        # 将错误转换为字典(pickle 兼容)
-        errors = [{"path": e.path, "message": e.message, "value": e.value} for e in result.errors]
-        return (idx, False, errors)
+from dataclasses import dataclass, field as dataclass_field
+from typing import Any, Callable, Dict, List, Literal, Optional, Set, Union
 
+from .utils.field_path import get_field, _parse_path, _get_value_by_segments
 
 # 支持的类型
 FieldType = Literal["str", "int", "float", "bool", "list", "dict", "any"]
@@ -187,7 +162,9 @@ class Field:
 
         # 选项检查
         if self.choices is not None and value not in self.choices:
-            errors.append(ValidationError(path, f"值必须是 {self.choices} 之一", value))
+            errors.append(
+                ValidationError(path, f"值必须是 {self.choices} 之一", value)
+            )
 
         # 正则表达式检查
         if self.pattern is not None and isinstance(value, str):
@@ -347,7 +324,9 @@ class Schema:
 
         return errors
 
-    def validate_batch(self, data: List[dict], max_errors: int = 100) -> List[tuple]:
+    def validate_batch(
+        self, data: List[dict], max_errors: int = 100
+    ) -> List[tuple]:
         """
         批量验证数据
 
@@ -371,76 +350,9 @@ class Schema:
 
         return failed
 
-    def validate_parallel(
-        self,
-        data: List[dict],
-        workers: Optional[int] = None,
-        progress_callback: Optional[Callable[[int, int], None]] = None,
-    ) -> tuple:
-        """
-        并行验证数据列表。
-
-        Args:
-            data: 数据列表
-            workers: 进程数,None 自动检测,1 禁用并行
-            progress_callback: 进度回调函数
-
-        Returns:
-            (valid_data, invalid_indices_results) 元组
-            - valid_data: 有效数据列表
-            - invalid_indices_results: [(index, ValidationResult), ...] 无效数据
-        """
-        if not data:
-            return [], []
-
-        total = len(data)
-        use_parallel = workers != 1 and total >= 1000
-
-        valid_data = []
-        invalid_results = []
-
-        if use_parallel:
-            from .parallel import get_optimal_workers, parallel_imap
-
-            actual_workers = get_optimal_workers(total, workers)
-            # 准备参数:(index, item, schema_fields)
-            args_list = [(i, item, self._fields) for i, item in enumerate(data)]
-
-            for i, (idx, is_valid, result_data) in enumerate(
-                parallel_imap(
-                    _validate_item_wrapper,
-                    args_list,
-                    workers=actual_workers,
-                    threshold=1000,
-                )
-            ):
-                if is_valid:
-                    valid_data.append(data[idx])
-                else:
-                    # 重建 ValidationResult(因为不能直接 pickle)
-                    errors = [
-                        ValidationError(path=e["path"], message=e["message"], value=e.get("value"))
-                        for e in result_data
-                    ]
-                    invalid_results.append((idx, ValidationResult(valid=False, errors=errors)))
-                if progress_callback:
-                    progress_callback(i + 1, total)
-        else:
-            # 串行处理
-            for i, item in enumerate(data):
-                result = self.validate(item)
-                if result.valid:
-                    valid_data.append(item)
-                else:
-                    invalid_results.append((i, result))
-                if progress_callback:
-                    progress_callback(i + 1, total)
-
-        return valid_data, invalid_results
-
     def __repr__(self) -> str:
         field_strs = [f"  {path}: {field_def}" for path, field_def in self._fields.items()]
-        return "Schema({\n" + ",\n".join(field_strs) + "\n}})"
+        return f"Schema({{\n" + ",\n".join(field_strs) + "\n}})"
 
 
 # ============================================================================
@@ -549,7 +461,9 @@ def sharegpt_schema(
     """
     return Schema(
         {
-            "conversations": Field(type="list", required=True, min_length=min_conversations),
+            "conversations": Field(
+                type="list", required=True, min_length=min_conversations
+            ),
            "conversations[*].from": Field(
                 type="str", required=True, choices=[human_role, gpt_role]
             ),
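With the parallel path gone, `Schema.validate` is the single entry point for record-level checks. A minimal sketch built on the `sharegpt_schema` field layout shown above; the `human`/`gpt` role names and the sample records are illustrative:

```python
from dtflow.schema import Field, Schema

schema = Schema(
    {
        "conversations": Field(type="list", required=True, min_length=2),
        "conversations[*].from": Field(type="str", required=True, choices=["human", "gpt"]),
    }
)

records = [
    {"conversations": [{"from": "human"}, {"from": "gpt"}]},  # valid
    {"conversations": [{"from": "alien"}, {"from": "gpt"}]},  # fails the choices check
]

for i, item in enumerate(records):
    result = schema.validate(item)  # the serial path that replaced validate_parallel
    if not result.valid:
        for err in result.errors:
            print(f"record {i}: {err.path} -> {err.message}")
```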