dtflow 0.4.3__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dtflow/__init__.py +34 -1
- dtflow/__main__.py +22 -0
- dtflow/cli/commands.py +5 -0
- dtflow/cli/validate.py +152 -0
- dtflow/core.py +154 -0
- dtflow/framework.py +610 -0
- dtflow/schema.py +508 -0
- {dtflow-0.4.3.dist-info → dtflow-0.5.0.dist-info}/METADATA +107 -2
- {dtflow-0.4.3.dist-info → dtflow-0.5.0.dist-info}/RECORD +11 -8
- {dtflow-0.4.3.dist-info → dtflow-0.5.0.dist-info}/WHEEL +0 -0
- {dtflow-0.4.3.dist-info → dtflow-0.5.0.dist-info}/entry_points.txt +0 -0
dtflow/schema.py
ADDED
|
@@ -0,0 +1,508 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Schema 验证模块
|
|
3
|
+
|
|
4
|
+
提供轻量级的数据结构验证,支持字段路径语法。
|
|
5
|
+
|
|
6
|
+
用法:
|
|
7
|
+
from dtflow import Schema, Field
|
|
8
|
+
|
|
9
|
+
# 定义 Schema
|
|
10
|
+
schema = Schema({
|
|
11
|
+
"messages": Field(type="list", required=True, min_length=1),
|
|
12
|
+
"messages[*].role": Field(type="str", choices=["user", "assistant", "system"]),
|
|
13
|
+
"messages[*].content": Field(type="str", min_length=1),
|
|
14
|
+
"score": Field(type="float", min=0, max=1),
|
|
15
|
+
})
|
|
16
|
+
|
|
17
|
+
# 验证单条数据
|
|
18
|
+
result = schema.validate(item)
|
|
19
|
+
if result.valid:
|
|
20
|
+
print("验证通过")
|
|
21
|
+
else:
|
|
22
|
+
print(f"验证失败: {result.errors}")
|
|
23
|
+
|
|
24
|
+
# 验证整个数据集
|
|
25
|
+
dt = DataTransformer.load("data.jsonl")
|
|
26
|
+
results = dt.validate_schema(schema)
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
from dataclasses import dataclass, field as dataclass_field
|
|
30
|
+
from typing import Any, Callable, Dict, List, Literal, Optional, Set, Union
|
|
31
|
+
|
|
32
|
+
from .utils.field_path import get_field, _parse_path, _get_value_by_segments
|
|
33
|
+
|
|
34
|
+
# 支持的类型
|
|
35
|
+
FieldType = Literal["str", "int", "float", "bool", "list", "dict", "any"]
|
|
36
|
+
|
|
37
|
+
# 类型映射
|
|
38
|
+
_TYPE_MAP: Dict[str, type] = {
|
|
39
|
+
"str": str,
|
|
40
|
+
"int": int,
|
|
41
|
+
"float": (int, float), # type: ignore # float 也接受 int
|
|
42
|
+
"bool": bool,
|
|
43
|
+
"list": list,
|
|
44
|
+
"dict": dict,
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass
class ValidationError:
    """A single validation failure."""

    path: str  # field path the failure is reported against
    message: str  # description of the failure
    value: Any = None  # offending value (optional)

    def __str__(self) -> str:
        """Render as ``path: message``, appending the value when it is not None."""
        summary = f"{self.path}: {self.message}"
        if self.value is None:
            return summary
        return f"{summary} (got: {self.value!r})"
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@dataclass
class ValidationResult:
    """Outcome of a validation run: a pass/fail flag plus the collected errors."""

    valid: bool
    errors: List[ValidationError] = dataclass_field(default_factory=list)

    def __bool__(self) -> bool:
        """Truthiness mirrors the ``valid`` flag."""
        return self.valid

    def __str__(self) -> str:
        """Summarise the result, listing at most the first five errors."""
        if self.valid:
            return "ValidationResult(valid=True)"

        shown = list(map(str, self.errors[:5]))
        hidden = len(self.errors) - 5
        if hidden > 0:
            # Keep the summary short: report only how many errors were left out.
            shown.append(f"... and {hidden} more errors")
        joined = ", ".join(shown)
        return f"ValidationResult(valid=False, errors=[{joined}])"
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@dataclass
class Field:
    """
    Declarative validation rule for a single field.

    Args:
        type: expected type name; one of "str", "int", "float", "bool",
            "list", "dict", "any". "float" also accepts int values. bool
            values are rejected for "int" and "float".
        required: whether the field is required (default False; a field is
            only required when explicitly marked).
        nullable: whether None is allowed (default False). A None value only
            produces an error when ``required`` is True and ``nullable`` is
            False.
        min: minimum value (numeric values).
        max: maximum value (numeric values).
        min_length: minimum length (str, list or tuple values).
        max_length: maximum length (str, list or tuple values).
        choices: list of allowed values.
        pattern: regular expression applied with ``re.match`` (str values).
        custom: custom validator; receives the value and returns True/False
            or an error-message string.

    Examples:
        >>> Field(type="str", required=True, min_length=1)
        >>> Field(type="int", min=0, max=100)
        >>> Field(type="str", choices=["user", "assistant", "system"])
        >>> Field(type="float", min=0, max=1, nullable=True)
    """

    type: FieldType = "any"
    required: bool = False  # optional by default; must be marked explicitly
    nullable: bool = False
    min: Optional[float] = None
    max: Optional[float] = None
    min_length: Optional[int] = None
    max_length: Optional[int] = None
    choices: Optional[List[Any]] = None
    pattern: Optional[str] = None
    custom: Optional[Callable[[Any], Union[bool, str]]] = None

    def validate(self, value: Any, path: str = "") -> List[ValidationError]:
        """
        Validate a single value against this rule.

        Args:
            value: the value to validate
            path: field path, used only in the reported errors

        Returns:
            List of validation errors (an empty list means the value passed).
        """
        errors: List[ValidationError] = []

        # None handling: no further checks are run on a None value.
        if value is None:
            if self.nullable:
                return []
            if self.required:
                errors.append(ValidationError(path, "字段不能为 None", value))
            return errors

        # Type check; a mismatch short-circuits the remaining checks.
        if self.type != "any":
            expected_type = _TYPE_MAP.get(self.type)
            type_ok = not expected_type or isinstance(value, expected_type)
            # bool is a subclass of int, so isinstance() alone would accept
            # True/False for "int" and "float" fields.
            if type_ok and self.type in ("int", "float") and isinstance(value, bool):
                type_ok = False
            if not type_ok:
                errors.append(
                    ValidationError(
                        path, f"类型错误,期望 {self.type},实际 {type(value).__name__}", value
                    )
                )
                return errors

        # Numeric range check (bool is excluded even though it is an int).
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            if self.min is not None and value < self.min:
                errors.append(ValidationError(path, f"值不能小于 {self.min}", value))
            if self.max is not None and value > self.max:
                errors.append(ValidationError(path, f"值不能大于 {self.max}", value))

        # Length check; the error records the measured length, not the value.
        if isinstance(value, (str, list, tuple)):
            length = len(value)
            if self.min_length is not None and length < self.min_length:
                errors.append(ValidationError(path, f"长度不能小于 {self.min_length}", length))
            if self.max_length is not None and length > self.max_length:
                errors.append(ValidationError(path, f"长度不能大于 {self.max_length}", length))

        # Choices check.
        if self.choices is not None and value not in self.choices:
            errors.append(
                ValidationError(path, f"值必须是 {self.choices} 之一", value)
            )

        # Regular-expression check; re.match anchors at the start of the string only.
        if self.pattern is not None and isinstance(value, str):
            import re

            if not re.match(self.pattern, value):
                errors.append(ValidationError(path, f"不匹配模式 {self.pattern}", value))

        # Custom validator: False -> generic failure, str -> that message,
        # anything else passes; an exception is reported as an error.
        if self.custom is not None:
            try:
                result = self.custom(value)
                if result is False:
                    errors.append(ValidationError(path, "自定义验证失败", value))
                elif isinstance(result, str):
                    errors.append(ValidationError(path, result, value))
            except Exception as e:
                errors.append(ValidationError(path, f"自定义验证异常: {e}", value))

        return errors
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class Schema:
    """
    Data-structure validation schema.

    Rules are keyed by field path, so nested structures can be described.
    A path containing ``[*]`` is applied to every element of the array named
    by the part of the path before the first ``[*]``.

    Args:
        fields: mapping of field path -> Field rule

    Examples:
        >>> schema = Schema({
        ...     "messages": Field(type="list", required=True, min_length=1),
        ...     "messages[*].role": Field(type="str", choices=["user", "assistant", "system"]),
        ...     "messages[*].content": Field(type="str", min_length=1),
        ...     "score": Field(type="float", min=0, max=1, required=False),
        ... })

        >>> result = schema.validate({"messages": [{"role": "user", "content": "hello"}]})
        >>> result.valid
        True
    """

    def __init__(self, fields: Dict[str, Field]):
        self._fields = fields
        # Split rules into plain paths and expanding paths (those containing "[*]").
        self._regular_fields: Dict[str, Field] = {}
        self._expand_fields: Dict[str, Field] = {}

        for path, field_def in fields.items():
            if "[*]" in path:
                self._expand_fields[path] = field_def
            else:
                self._regular_fields[path] = field_def

    def validate(self, data: dict) -> ValidationResult:
        """
        Validate a single record.

        Args:
            data: the dict to validate

        Returns:
            ValidationResult; ``valid`` is True when no errors were collected.
        """
        if not isinstance(data, dict):
            return ValidationResult(
                valid=False,
                errors=[ValidationError("", "数据必须是字典类型", type(data).__name__)],
            )

        errors: List[ValidationError] = []

        # Plain fields.
        for path, field_def in self._regular_fields.items():
            value = get_field(data, path)

            if value is None and field_def.required:
                # Tell "field is missing" apart from "field is present with value None".
                if not self._field_exists(data, path):
                    errors.append(ValidationError(path, "必填字段缺失"))
                    continue

            field_errors = field_def.validate(value, path)
            errors.extend(field_errors)

        # Expanding fields (paths containing "[*]").
        for path, field_def in self._expand_fields.items():
            expand_errors = self._validate_expand_field(data, path, field_def)
            errors.extend(expand_errors)

        return ValidationResult(valid=len(errors) == 0, errors=errors)

    def _field_exists(self, data: dict, path: str) -> bool:
        """Return True when ``path`` resolves in ``data``, even if the value there is None."""
        segments = _parse_path(path)
        current = data

        for seg in segments:
            if current is None:
                return False
            if isinstance(seg, str):
                # String segment: a key lookup, only valid on a dict.
                if isinstance(current, dict):
                    if seg not in current:
                        return False
                    current = current[seg]
                else:
                    return False
            elif isinstance(seg, int):
                # Integer segment: an index, only valid on a list/tuple.
                if isinstance(current, (list, tuple)):
                    try:
                        current = current[seg]
                    except IndexError:
                        return False
                else:
                    return False

        return True

    def _validate_expand_field(
        self, data: dict, path: str, field_def: Field
    ) -> List[ValidationError]:
        """
        Validate a rule whose path contains ``[*]``.

        The path is split at the first ``[*]`` only: the part before names
        the array, the part after (if any) is looked up inside each element.
        """
        errors: List[ValidationError] = []

        # Split into the part before and the part after the first "[*]".
        before, _, after = path.partition("[*]")
        prefix = before.rstrip(".")
        suffix = after.lstrip(".")

        # Resolve the array to iterate over.
        array = get_field(data, prefix) if prefix else data

        if array is None:
            # A missing prefix is left to the plain-field validation.
            return errors

        if not isinstance(array, (list, tuple)):
            errors.append(ValidationError(prefix, "期望是数组类型", type(array).__name__))
            return errors

        # Apply the rule to every element.
        for i, item in enumerate(array):
            actual_path = f"{prefix}[{i}]" if prefix else f"[{i}]"

            if suffix:
                # Nested lookup; a non-dict element is treated as a None value.
                value = get_field(item, suffix) if isinstance(item, dict) else None
                actual_path = f"{actual_path}.{suffix}"
            else:
                value = item

            field_errors = field_def.validate(value, actual_path)
            errors.extend(field_errors)

        return errors

    def validate_batch(
        self, data: List[dict], max_errors: int = 100
    ) -> List[tuple]:
        """
        Validate a list of records.

        Args:
            data: list of records
            max_errors: stop once the accumulated error count reaches this
                value (the record that crosses the limit is still reported
                in full)

        Returns:
            [(index, ValidationResult), ...] for the records that failed.
        """
        failed: List[tuple] = []
        error_count = 0

        for i, item in enumerate(data):
            result = self.validate(item)
            if not result.valid:
                failed.append((i, result))
                error_count += len(result.errors)
                if error_count >= max_errors:
                    break

        return failed

    def __repr__(self) -> str:
        """Return a multi-line listing of every path and its rule."""
        field_strs = [f" {path}: {field_def}" for path, field_def in self._fields.items()]
        # Plain literals on both sides: doubling the brace in a non-f-string
        # emitted a stray "}" and left the output unbalanced.
        return "Schema({\n" + ",\n".join(field_strs) + "\n})"
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
# ============================================================================
|
|
359
|
+
# 预定义 Schema 模板
|
|
360
|
+
# ============================================================================
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
def openai_chat_schema(
    min_messages: int = 1,
    max_messages: Optional[int] = None,
    roles: Optional[List[str]] = None,
) -> Schema:
    """
    Build a Schema for the OpenAI chat format.

    Args:
        min_messages: minimum number of messages (default 1)
        max_messages: maximum number of messages (default: unlimited)
        roles: allowed roles (default ["system", "user", "assistant"])

    Returns:
        Schema object

    Examples:
        >>> schema = openai_chat_schema()
        >>> result = schema.validate({"messages": [{"role": "user", "content": "hi"}]})
    """
    allowed_roles = ["system", "user", "assistant"] if roles is None else roles

    messages_rule = Field(
        type="list",
        required=True,
        min_length=min_messages,
        max_length=max_messages,
    )
    role_rule = Field(type="str", required=True, choices=allowed_roles)
    content_rule = Field(type="str", required=True, min_length=1)

    return Schema(
        {
            "messages": messages_rule,
            "messages[*].role": role_rule,
            "messages[*].content": content_rule,
        }
    )
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
def alpaca_schema(
    require_input: bool = False,
    min_output_length: int = 1,
) -> Schema:
    """
    Build a Schema for the Alpaca format.

    Args:
        require_input: whether the ``input`` field is required (default False)
        min_output_length: minimum length of ``output`` (default 1)

    Returns:
        Schema object
    """
    rules = {
        "instruction": Field(type="str", required=True, min_length=1),
        "input": Field(type="str", required=require_input),
        "output": Field(type="str", required=True, min_length=min_output_length),
    }
    return Schema(rules)
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
def dpo_schema(
    min_chosen_length: int = 1,
    min_rejected_length: int = 1,
) -> Schema:
    """
    Build a Schema for DPO preference pairs.

    Args:
        min_chosen_length: minimum length of ``chosen``
        min_rejected_length: minimum length of ``rejected``

    Returns:
        Schema object
    """
    rules = {
        "prompt": Field(type="str", required=True, min_length=1),
        "chosen": Field(type="str", required=True, min_length=min_chosen_length),
        "rejected": Field(type="str", required=True, min_length=min_rejected_length),
    }
    return Schema(rules)
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
def sharegpt_schema(
    min_conversations: int = 1,
    human_role: str = "human",
    gpt_role: str = "gpt",
) -> Schema:
    """
    Build a Schema for the ShareGPT multi-turn conversation format.

    Args:
        min_conversations: minimum number of conversation turns
        human_role: role name used for the user
        gpt_role: role name used for the assistant

    Returns:
        Schema object
    """
    speakers = [human_role, gpt_role]
    rules = {
        "conversations": Field(type="list", required=True, min_length=min_conversations),
        "conversations[*].from": Field(type="str", required=True, choices=speakers),
        "conversations[*].value": Field(type="str", required=True, min_length=1),
    }
    return Schema(rules)
|
|
473
|
+
|
|
474
|
+
|
|
475
|
+
# ============================================================================
|
|
476
|
+
# 便捷函数
|
|
477
|
+
# ============================================================================
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
def validate_data(data: Union[dict, List[dict]], schema: Schema) -> ValidationResult:
    """
    Convenience wrapper: validate one record or a list of records.

    Args:
        data: a single record (dict) or a list of records (List[dict])
        schema: Schema object

    Returns:
        ValidationResult; for a list, a combined result in which every error
        path is prefixed with the record's index, e.g. ``[3].field``.
    """
    if isinstance(data, dict):
        return schema.validate(data)

    # Batch mode: collect every record's errors into one flat list.
    collected: List[ValidationError] = []
    for index, record in enumerate(data):
        outcome = schema.validate(record)
        if outcome.valid:
            continue
        tag = f"[{index}]"
        collected.extend(
            ValidationError(
                path=f"{tag}.{problem.path}" if problem.path else tag,
                message=problem.message,
                value=problem.value,
            )
            for problem in outcome.errors
        )

    return ValidationResult(valid=not collected, errors=collected)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: dtflow
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.5.0
|
|
4
4
|
Summary: A flexible data transformation tool for ML training formats (SFT, RLHF, Pretrain)
|
|
5
5
|
Project-URL: Homepage, https://github.com/yourusername/DataTransformer
|
|
6
6
|
Project-URL: Documentation, https://github.com/yourusername/DataTransformer#readme
|
|
@@ -129,7 +129,7 @@ dt.filter(lambda x: x.language == "zh")
|
|
|
129
129
|
### 数据验证
|
|
130
130
|
|
|
131
131
|
```python
|
|
132
|
-
#
|
|
132
|
+
# 简单验证,返回不通过的记录列表
|
|
133
133
|
errors = dt.validate(lambda x: len(x.messages) >= 2)
|
|
134
134
|
|
|
135
135
|
if errors:
|
|
@@ -137,6 +137,53 @@ if errors:
|
|
|
137
137
|
print(f"第 {e.index} 行: {e.error}")
|
|
138
138
|
```
|
|
139
139
|
|
|
140
|
+
### Schema 验证
|
|
141
|
+
|
|
142
|
+
使用 Schema 进行结构化数据验证:
|
|
143
|
+
|
|
144
|
+
```python
|
|
145
|
+
from dtflow import Schema, Field, openai_chat_schema
|
|
146
|
+
|
|
147
|
+
# 使用预设 Schema
|
|
148
|
+
result = dt.validate_schema(openai_chat_schema())
|
|
149
|
+
print(result) # ValidationResult(valid=950, invalid=50, errors=[...])
|
|
150
|
+
|
|
151
|
+
# 自定义 Schema
|
|
152
|
+
schema = Schema({
|
|
153
|
+
"messages": Field(type="list", required=True, min_length=1),
|
|
154
|
+
"messages[*].role": Field(type="str", choices=["user", "assistant", "system"]),
|
|
155
|
+
"messages[*].content": Field(type="str", min_length=1),
|
|
156
|
+
"score": Field(type="float", min=0, max=1),
|
|
157
|
+
})
|
|
158
|
+
|
|
159
|
+
result = dt.validate_schema(schema)
|
|
160
|
+
|
|
161
|
+
# 过滤出有效数据
|
|
162
|
+
valid_dt = dt.validate_schema(schema, filter_invalid=True)
|
|
163
|
+
valid_dt.save("valid.jsonl")
|
|
164
|
+
```
|
|
165
|
+
|
|
166
|
+
**预设 Schema**:
|
|
167
|
+
|
|
168
|
+
| Schema 名称 | 用途 |
|
|
169
|
+
|------------|------|
|
|
170
|
+
| `openai_chat_schema` | OpenAI messages 格式验证 |
|
|
171
|
+
| `alpaca_schema` | Alpaca instruction/output 格式 |
|
|
172
|
+
| `sharegpt_schema` | ShareGPT conversations 格式 |
|
|
173
|
+
| `dpo_schema` | DPO prompt/chosen/rejected 格式 |
|
|
174
|
+
|
|
175
|
+
**Field 参数**:
|
|
176
|
+
|
|
177
|
+
| 参数 | 说明 | 示例 |
|
|
178
|
+
|------|------|------|
|
|
179
|
+
| `type` | 类型验证 | `"str"`, `"int"`, `"float"`, `"bool"`, `"list"`, `"dict"` |
|
|
180
|
+
| `required` | 是否必填 | `True` / `False` |
|
|
181
|
+
| `min` / `max` | 数值范围 | `min=0, max=1` |
|
|
182
|
+
| `min_length` / `max_length` | 长度范围 | `min_length=1` |
|
|
183
|
+
| `choices` | 枚举值 | `choices=["user", "assistant"]` |
|
|
184
|
+
| `pattern` | 正则匹配 | `pattern=r"^\d{4}-\d{2}-\d{2}$"` |
|
|
185
|
+
| `custom` | 自定义验证 | `custom=lambda x: x > 0` |
|
|
186
|
+
|
|
140
187
|
### 数据转换
|
|
141
188
|
|
|
142
189
|
```python
|
|
@@ -286,6 +333,58 @@ dt.transform(to_swift_vlm(images_field="images")).save("swift_vlm.jsonl")
|
|
|
286
333
|
# 输出: {"messages": [...], "images": ["/path/to/img.jpg"]}
|
|
287
334
|
```
|
|
288
335
|
|
|
336
|
+
### 训练框架一键导出
|
|
337
|
+
|
|
338
|
+
将数据导出为目标训练框架可直接使用的格式,自动生成配置文件:
|
|
339
|
+
|
|
340
|
+
```python
|
|
341
|
+
from dtflow import DataTransformer
|
|
342
|
+
|
|
343
|
+
dt = DataTransformer.load("data.jsonl")
|
|
344
|
+
|
|
345
|
+
# 1. 检查框架兼容性
|
|
346
|
+
result = dt.check_compatibility("llama-factory")
|
|
347
|
+
print(result)
|
|
348
|
+
# ✅ 兼容 - LLaMA-Factory (openai_chat)
|
|
349
|
+
# 或
|
|
350
|
+
# ❌ 不兼容 - 错误: xxx
|
|
351
|
+
|
|
352
|
+
# 2. 一键导出到 LLaMA-Factory
|
|
353
|
+
files = dt.export_for("llama-factory", "./llama_ready/")
|
|
354
|
+
# 生成文件:
|
|
355
|
+
# - ./llama_ready/custom_dataset.json # 数据文件
|
|
356
|
+
# - ./llama_ready/dataset_info.json # 数据集配置
|
|
357
|
+
# - ./llama_ready/train_args.yaml # 训练参数模板
|
|
358
|
+
|
|
359
|
+
# 3. 导出到 ms-swift
|
|
360
|
+
files = dt.export_for("swift", "./swift_ready/")
|
|
361
|
+
# 生成: data.jsonl + train_swift.sh
|
|
362
|
+
|
|
363
|
+
# 4. 导出到 Axolotl
|
|
364
|
+
files = dt.export_for("axolotl", "./axolotl_ready/")
|
|
365
|
+
# 生成: data.jsonl + config.yaml
|
|
366
|
+
|
|
367
|
+
# 指定数据集名称
|
|
368
|
+
dt.export_for("llama-factory", "./output/", dataset_name="my_sft_data")
|
|
369
|
+
```
|
|
370
|
+
|
|
371
|
+
**支持的框架**:
|
|
372
|
+
|
|
373
|
+
| 框架 | 导出内容 | 使用方式 |
|
|
374
|
+
|------|---------|---------|
|
|
375
|
+
| `llama-factory` | data.json + dataset_info.json + train_args.yaml | `llamafactory-cli train train_args.yaml` |
|
|
376
|
+
| `swift` | data.jsonl + train_swift.sh | `bash train_swift.sh` |
|
|
377
|
+
| `axolotl` | data.jsonl + config.yaml | `accelerate launch -m axolotl.cli.train config.yaml` |
|
|
378
|
+
|
|
379
|
+
**自动格式检测**:
|
|
380
|
+
|
|
381
|
+
| 检测到的格式 | 数据结构 |
|
|
382
|
+
|------------|---------|
|
|
383
|
+
| `openai_chat` | `{"messages": [{"role": "user", ...}]}` |
|
|
384
|
+
| `alpaca` | `{"instruction": ..., "output": ...}` |
|
|
385
|
+
| `sharegpt` | `{"conversations": [{"from": "human", ...}]}` |
|
|
386
|
+
| `dpo` | `{"prompt": ..., "chosen": ..., "rejected": ...}` |
|
|
387
|
+
|
|
289
388
|
### 其他操作
|
|
290
389
|
|
|
291
390
|
```python
|
|
@@ -361,6 +460,12 @@ dt concat a.jsonl b.jsonl -o merged.jsonl
|
|
|
361
460
|
|
|
362
461
|
# 数据统计
|
|
363
462
|
dt stats data.jsonl
|
|
463
|
+
|
|
464
|
+
# 数据验证
|
|
465
|
+
dt validate data.jsonl --preset=openai_chat # 使用预设 schema 验证
|
|
466
|
+
dt validate data.jsonl --preset=alpaca --verbose # 详细输出
|
|
467
|
+
dt validate data.jsonl --preset=sharegpt --filter-invalid -o valid.jsonl # 过滤出有效数据
|
|
468
|
+
dt validate data.jsonl --preset=dpo --max-errors=100 # 限制错误输出数量
|
|
364
469
|
```
|
|
365
470
|
|
|
366
471
|
### 字段路径语法
|
|
@@ -1,15 +1,17 @@
|
|
|
1
|
-
dtflow/__init__.py,sha256=
|
|
2
|
-
dtflow/__main__.py,sha256=
|
|
1
|
+
dtflow/__init__.py,sha256=fOkG8g8VXS1HFk2ztmaJpjHBXmArHGBW8WE8tHPHXts,3031
|
|
2
|
+
dtflow/__main__.py,sha256=ySpqvEn7k-vsrYFPx-8O6p-yx_24KccgnOSPd2XybhM,12572
|
|
3
3
|
dtflow/converters.py,sha256=gyy-K15zjzGBawFnZa8D9JX37JZ47rey2GhjKa2pxFo,22081
|
|
4
|
-
dtflow/core.py,sha256=
|
|
4
|
+
dtflow/core.py,sha256=HJAlxOaCtwvLOWF9JSC-2li3fsyRE2Q-H9unj9GQJ6M,35445
|
|
5
|
+
dtflow/framework.py,sha256=jyICi_RWHjX7WfsXdSbWmP1SL7y1OWSPyd5G5Y-lvg4,17578
|
|
5
6
|
dtflow/lineage.py,sha256=vQ06lxBHftu-Ma5HlISp3F2eiIvwagQSnUGaLeABDZY,12190
|
|
6
7
|
dtflow/pipeline.py,sha256=zZaC4fg5vsp_30Fhbg75vu0yggsdvf28bWBiVDWzZ6Y,13901
|
|
7
8
|
dtflow/presets.py,sha256=OP1nnM5NFk5Kli9FsXK0xAot48E5OQ6-VOIJT9ffXPg,5023
|
|
9
|
+
dtflow/schema.py,sha256=IFcij22_UFKcgKT1YWwRg2QJO0vcAvCb1arZmsGByts,16824
|
|
8
10
|
dtflow/streaming.py,sha256=jtWQjkhhZqfyzIaFskXNvooGAYDQBn1b6X8FHgaCZYk,22704
|
|
9
11
|
dtflow/tokenizers.py,sha256=zxE6XZGjZ_DOGCjRSClI9xaAbFVf8FS6jwwssGoi_9U,18111
|
|
10
12
|
dtflow/cli/__init__.py,sha256=QhZ-thgx9IBTFII7T_hdoWFUl0CCsdGQHN5ZEZw2XB0,423
|
|
11
13
|
dtflow/cli/clean.py,sha256=y9VCRibgK1j8WIY3h0XZX0m93EdELQC7TdnseMWwS-0,17799
|
|
12
|
-
dtflow/cli/commands.py,sha256=
|
|
14
|
+
dtflow/cli/commands.py,sha256=ST65Ox_MKu-CKAtPVaxECAPXYOJiF7BhL32A4nsZZl0,1175
|
|
13
15
|
dtflow/cli/common.py,sha256=FsDFVNcLj_874qSg2dGef4V7mqPU9THLchT8PxJpBt8,12955
|
|
14
16
|
dtflow/cli/io_ops.py,sha256=BMDisP6dxzzmSjYwmeFwaHmpHHPqirmXAWeNTD-9MQM,13254
|
|
15
17
|
dtflow/cli/lineage.py,sha256=_lNh35nF9AA0Zy6FyZ4g8IzrXH2ZQnp3inF-o2Hs1pw,1383
|
|
@@ -17,6 +19,7 @@ dtflow/cli/pipeline.py,sha256=QNEo-BJlaC1CVnVeRZr7TwfuZYloJ4TebIzJ5ALzry0,1426
|
|
|
17
19
|
dtflow/cli/sample.py,sha256=vPTQlF0OXEry4QjO8uaD9vOae4AQbX9zDwVYOxg59ZI,10339
|
|
18
20
|
dtflow/cli/stats.py,sha256=HByF0sFMqY1kM75dnjTcJbMKDdQNdOt4iDba4au_-pI,20495
|
|
19
21
|
dtflow/cli/transform.py,sha256=w6xqMOxPxQvL2u_BPCfpDHuPSC9gmcqMPVN8s-B6bbY,15052
|
|
22
|
+
dtflow/cli/validate.py,sha256=65aGVlMS_Rq0Ch0YQ-TclVJ03RQP4CnG137wthzb8Ao,4384
|
|
20
23
|
dtflow/mcp/__init__.py,sha256=huEJ3rXDbxDRjsLPEvjNT2u3tWs6Poiv6fokPIrByjw,897
|
|
21
24
|
dtflow/mcp/__main__.py,sha256=PoT2ZZmJq9xDZxDACJfqDW9Ld_ukHrGNK-0XUd7WGnY,448
|
|
22
25
|
dtflow/mcp/cli.py,sha256=ck0oOS_642cNktxULaMRE7BJfMxsBCwotmCj3PSPwVk,13110
|
|
@@ -27,7 +30,7 @@ dtflow/storage/io.py,sha256=ZH2aSE-S89gpy3z4oTqhcqWf4u10OdkDoyul7o_YBDI,23374
|
|
|
27
30
|
dtflow/utils/__init__.py,sha256=f8v9HJZMWRI5AL64Vjr76Pf2Na_whOF9nJBKgPbXXYg,429
|
|
28
31
|
dtflow/utils/display.py,sha256=OeOdTh6mbDwSkDWlmkjfpTjy2QG8ZUaYU0NpHUWkpEQ,5881
|
|
29
32
|
dtflow/utils/field_path.py,sha256=WcNA-LZh3H61a77FEzB_R7YAyyZl3M8ofdq05ytQGmI,7459
|
|
30
|
-
dtflow-0.
|
|
31
|
-
dtflow-0.
|
|
32
|
-
dtflow-0.
|
|
33
|
-
dtflow-0.
|
|
33
|
+
dtflow-0.5.0.dist-info/METADATA,sha256=chELFIevPb1h7ZydbWtH9rM7RiA2n3Ep-XWL1qbaHk0,22084
|
|
34
|
+
dtflow-0.5.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
35
|
+
dtflow-0.5.0.dist-info/entry_points.txt,sha256=dadIDOK7Iu9pMxnMPBfpb4aAPe4hQbBOshpQYjVYpGc,44
|
|
36
|
+
dtflow-0.5.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|