dtflow 0.5.2__tar.gz → 0.5.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dtflow-0.5.4/CHANGELOG.md +19 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/PKG-INFO +1 -1
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/__init__.py +7 -7
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/cli/common.py +18 -3
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/cli/sample.py +9 -5
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/converters.py +35 -19
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/presets.py +14 -15
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/utils/__init__.py +3 -0
- dtflow-0.5.4/dtflow/utils/helpers.py +30 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/pyproject.toml +6 -4
- dtflow-0.5.4/tests/test_cli_clean.py +314 -0
- dtflow-0.5.4/tests/test_cli_sample.py +242 -0
- dtflow-0.5.4/tests/test_cli_stats.py +213 -0
- dtflow-0.5.4/tests/test_cli_transform.py +304 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/.gitignore +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/README.md +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/__main__.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/cli/__init__.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/cli/clean.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/cli/commands.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/cli/io_ops.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/cli/lineage.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/cli/pipeline.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/cli/stats.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/cli/transform.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/cli/validate.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/core.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/framework.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/lineage.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/mcp/__init__.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/mcp/__main__.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/mcp/cli.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/mcp/docs.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/mcp/server.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/pipeline.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/schema.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/storage/__init__.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/storage/io.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/streaming.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/tokenizers.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/utils/display.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/dtflow/utils/field_path.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/tests/README.md +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/tests/benchmark_io.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/tests/benchmark_sharegpt.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/tests/test_cli_benchmark.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/tests/test_converters.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/tests/test_field_path.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/tests/test_framework.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/tests/test_io.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/tests/test_lineage.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/tests/test_pipeline.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/tests/test_schema.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/tests/test_streaming.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/tests/test_tokenizers.py +0 -0
- {dtflow-0.5.2 → dtflow-0.5.4}/tests/test_transformer.py +0 -0
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# Changelog
|
|
2
|
+
|
|
3
|
+
## [0.5.2] - 2026-01-18
|
|
4
|
+
|
|
5
|
+
### Miscellaneous
|
|
6
|
+
|
|
7
|
+
- Bump version to 0.5.2
|
|
8
|
+
- 添加 pre-commit 配置和发版脚本
|
|
9
|
+
|
|
10
|
+
## [0.5.1] - 2026-01-18
|
|
11
|
+
|
|
12
|
+
### Features
|
|
13
|
+
|
|
14
|
+
- 优化 sample 命令文本预览显示
|
|
15
|
+
|
|
16
|
+
### Testing
|
|
17
|
+
|
|
18
|
+
- 添加测试运行说明
|
|
19
|
+
- 补充 tail/token-stats/validate 性能测试
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: dtflow
|
|
3
|
-
Version: 0.5.2
|
|
3
|
+
Version: 0.5.4
|
|
4
4
|
Summary: A flexible data transformation tool for ML training formats (SFT, RLHF, Pretrain)
|
|
5
5
|
Project-URL: Homepage, https://github.com/yourusername/DataTransformer
|
|
6
6
|
Project-URL: Documentation, https://github.com/yourusername/DataTransformer#readme
|
|
@@ -26,6 +26,12 @@ from .converters import ( # LLaMA-Factory 扩展; ms-swift
|
|
|
26
26
|
to_swift_vlm,
|
|
27
27
|
)
|
|
28
28
|
from .core import DataTransformer, DictWrapper, TransformError, TransformErrors
|
|
29
|
+
from .framework import (
|
|
30
|
+
CompatibilityResult,
|
|
31
|
+
check_compatibility,
|
|
32
|
+
detect_format,
|
|
33
|
+
export_for,
|
|
34
|
+
)
|
|
29
35
|
from .presets import get_preset, list_presets
|
|
30
36
|
from .schema import (
|
|
31
37
|
Field,
|
|
@@ -38,12 +44,6 @@ from .schema import (
|
|
|
38
44
|
sharegpt_schema,
|
|
39
45
|
validate_data,
|
|
40
46
|
)
|
|
41
|
-
from .framework import (
|
|
42
|
-
CompatibilityResult,
|
|
43
|
-
check_compatibility,
|
|
44
|
-
detect_format,
|
|
45
|
-
export_for,
|
|
46
|
-
)
|
|
47
47
|
from .storage import load_data, sample_file, save_data
|
|
48
48
|
from .streaming import StreamingTransformer, load_sharded, load_stream, process_shards
|
|
49
49
|
from .tokenizers import (
|
|
@@ -60,7 +60,7 @@ from .tokenizers import (
|
|
|
60
60
|
token_stats,
|
|
61
61
|
)
|
|
62
62
|
|
|
63
|
-
__version__ = "0.5.2"
|
|
63
|
+
__version__ = "0.5.4"
|
|
64
64
|
|
|
65
65
|
__all__ = [
|
|
66
66
|
# core
|
|
@@ -100,8 +100,6 @@ def _format_nested(
|
|
|
100
100
|
└─ 最后一项
|
|
101
101
|
"""
|
|
102
102
|
lines = []
|
|
103
|
-
branch = "└─ " if is_last else "├─ "
|
|
104
|
-
cont = " " if is_last else "│ "
|
|
105
103
|
|
|
106
104
|
if isinstance(value, dict):
|
|
107
105
|
items = list(value.items())
|
|
@@ -183,6 +181,7 @@ def _print_samples(
|
|
|
183
181
|
filename: Optional[str] = None,
|
|
184
182
|
total_count: Optional[int] = None,
|
|
185
183
|
fields: Optional[List[str]] = None,
|
|
184
|
+
file_size: Optional[int] = None,
|
|
186
185
|
) -> None:
|
|
187
186
|
"""
|
|
188
187
|
打印采样结果。
|
|
@@ -190,8 +189,9 @@ def _print_samples(
|
|
|
190
189
|
Args:
|
|
191
190
|
samples: 采样数据列表
|
|
192
191
|
filename: 文件名(用于显示概览)
|
|
193
|
-
total_count:
|
|
192
|
+
total_count: 文件总行数(用于显示概览),大文件时可能为 None
|
|
194
193
|
fields: 只显示指定字段
|
|
194
|
+
file_size: 文件大小(字节),当 total_count 为 None 时显示
|
|
195
195
|
"""
|
|
196
196
|
if not samples:
|
|
197
197
|
print("没有数据")
|
|
@@ -219,6 +219,8 @@ def _print_samples(
|
|
|
219
219
|
|
|
220
220
|
if total_count is not None:
|
|
221
221
|
info = f"总行数: {total_count:,} | 采样: {len(samples)} 条 | 字段: {len(all_fields)} 个"
|
|
222
|
+
elif file_size is not None:
|
|
223
|
+
info = f"文件大小: {_format_file_size(file_size)} | 采样: {len(samples)} 条 | 字段: {len(all_fields)} 个"
|
|
222
224
|
else:
|
|
223
225
|
info = f"采样: {len(samples)} 条 | 字段: {len(all_fields)} 个"
|
|
224
226
|
|
|
@@ -266,6 +268,10 @@ def _print_samples(
|
|
|
266
268
|
print(
|
|
267
269
|
f" 总行数: {total_count:,} | 采样: {len(samples)} 条 | 字段: {len(all_fields)} 个"
|
|
268
270
|
)
|
|
271
|
+
elif file_size is not None:
|
|
272
|
+
print(
|
|
273
|
+
f" 文件大小: {_format_file_size(file_size)} | 采样: {len(samples)} 条 | 字段: {len(all_fields)} 个"
|
|
274
|
+
)
|
|
269
275
|
else:
|
|
270
276
|
print(f" 采样: {len(samples)} 条 | 字段: {len(all_fields)} 个")
|
|
271
277
|
print(f" 字段: {', '.join(sorted(all_fields))}")
|
|
@@ -287,6 +293,15 @@ def _parse_field_list(value: Any) -> List[str]:
|
|
|
287
293
|
return [str(value)]
|
|
288
294
|
|
|
289
295
|
|
|
296
|
+
def _format_file_size(size: int) -> str:
|
|
297
|
+
"""格式化文件大小"""
|
|
298
|
+
for unit in ["B", "KB", "MB", "GB"]:
|
|
299
|
+
if size < 1024:
|
|
300
|
+
return f"{size:.1f} {unit}"
|
|
301
|
+
size /= 1024
|
|
302
|
+
return f"{size:.1f} TB"
|
|
303
|
+
|
|
304
|
+
|
|
290
305
|
def _is_empty_value(v: Any) -> bool:
|
|
291
306
|
"""判断值是否为空"""
|
|
292
307
|
if v is None:
|
|
@@ -99,11 +99,15 @@ def sample(
|
|
|
99
99
|
for item in sampled:
|
|
100
100
|
print(orjson.dumps(item, option=orjson.OPT_INDENT_2).decode("utf-8"))
|
|
101
101
|
else:
|
|
102
|
-
#
|
|
103
|
-
|
|
102
|
+
# 大文件跳过行数统计(50MB 阈值)
|
|
103
|
+
file_size = filepath.stat().st_size
|
|
104
|
+
if file_size < 50 * 1024 * 1024:
|
|
105
|
+
total_count = _get_file_row_count(filepath)
|
|
106
|
+
else:
|
|
107
|
+
total_count = None
|
|
104
108
|
# 解析 fields 参数
|
|
105
109
|
field_list = _parse_field_list(fields) if fields else None
|
|
106
|
-
_print_samples(sampled, filepath.name, total_count, field_list)
|
|
110
|
+
_print_samples(sampled, filepath.name, total_count, field_list, file_size)
|
|
107
111
|
|
|
108
112
|
|
|
109
113
|
def _stratified_sample(
|
|
@@ -196,7 +200,7 @@ def _stratified_sample(
|
|
|
196
200
|
|
|
197
201
|
# 执行各组采样
|
|
198
202
|
result = []
|
|
199
|
-
print(
|
|
203
|
+
print("🔄 执行采样...")
|
|
200
204
|
for key in group_keys:
|
|
201
205
|
group_data = groups[key]
|
|
202
206
|
target = min(sample_counts[key], len(group_data))
|
|
@@ -215,7 +219,7 @@ def _stratified_sample(
|
|
|
215
219
|
result.extend(sampled)
|
|
216
220
|
|
|
217
221
|
# 打印采样结果
|
|
218
|
-
print(
|
|
222
|
+
print("\n📋 采样结果:")
|
|
219
223
|
result_groups: Dict[Any, int] = defaultdict(int)
|
|
220
224
|
for item in result:
|
|
221
225
|
key = item.get(stratify_field, "__null__")
|
|
@@ -4,7 +4,7 @@
|
|
|
4
4
|
提供与 HuggingFace datasets 等常用格式的互转功能。
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
-
from typing import Any, Callable, Dict, List, Optional
|
|
7
|
+
from typing import Any, Callable, Dict, List, Optional
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
def to_hf_dataset(data: List[Dict[str, Any]]):
|
|
@@ -23,8 +23,8 @@ def to_hf_dataset(data: List[Dict[str, Any]]):
|
|
|
23
23
|
"""
|
|
24
24
|
try:
|
|
25
25
|
from datasets import Dataset
|
|
26
|
-
except ImportError:
|
|
27
|
-
raise ImportError("需要安装 datasets: pip install datasets")
|
|
26
|
+
except ImportError as e:
|
|
27
|
+
raise ImportError("需要安装 datasets: pip install datasets") from e
|
|
28
28
|
|
|
29
29
|
return Dataset.from_list(data)
|
|
30
30
|
|
|
@@ -45,9 +45,9 @@ def from_hf_dataset(dataset, split: Optional[str] = None) -> List[Dict[str, Any]
|
|
|
45
45
|
>>> data = from_hf_dataset(my_dataset, split="train")
|
|
46
46
|
"""
|
|
47
47
|
try:
|
|
48
|
-
from datasets import
|
|
49
|
-
except ImportError:
|
|
50
|
-
raise ImportError("需要安装 datasets: pip install datasets")
|
|
48
|
+
from datasets import load_dataset
|
|
49
|
+
except ImportError as e:
|
|
50
|
+
raise ImportError("需要安装 datasets: pip install datasets") from e
|
|
51
51
|
|
|
52
52
|
# 如果是字符串,加载数据集
|
|
53
53
|
if isinstance(dataset, str):
|
|
@@ -143,14 +143,16 @@ def to_openai_batch(
|
|
|
143
143
|
>>> batch_input = dt.to(to_openai_batch(model="gpt-4o"))
|
|
144
144
|
"""
|
|
145
145
|
|
|
146
|
-
|
|
146
|
+
counter = {"idx": 0}
|
|
147
|
+
|
|
148
|
+
def transform(item) -> dict:
|
|
147
149
|
messages = item.get(messages_field, []) if hasattr(item, "get") else item[messages_field]
|
|
148
150
|
|
|
149
151
|
if custom_id_field:
|
|
150
152
|
custom_id = item.get(custom_id_field) if hasattr(item, "get") else item[custom_id_field]
|
|
151
153
|
else:
|
|
152
|
-
custom_id = f"request-{idx
|
|
153
|
-
idx
|
|
154
|
+
custom_id = f"request-{counter['idx']}"
|
|
155
|
+
counter["idx"] += 1
|
|
154
156
|
|
|
155
157
|
return {
|
|
156
158
|
"custom_id": str(custom_id),
|
|
@@ -196,7 +198,8 @@ def to_llama_factory(
|
|
|
196
198
|
"""
|
|
197
199
|
|
|
198
200
|
def transform(item) -> dict:
|
|
199
|
-
|
|
201
|
+
def get(f):
|
|
202
|
+
return item.get(f, "") if hasattr(item, "get") else getattr(item, f, "")
|
|
200
203
|
|
|
201
204
|
result = {
|
|
202
205
|
"instruction": get(instruction_field),
|
|
@@ -248,7 +251,7 @@ def to_axolotl(
|
|
|
248
251
|
conversations = (
|
|
249
252
|
item.get(conversations_field, [])
|
|
250
253
|
if hasattr(item, "get")
|
|
251
|
-
else item
|
|
254
|
+
else getattr(item, conversations_field, [])
|
|
252
255
|
)
|
|
253
256
|
|
|
254
257
|
# 如果已经是正确格式,直接返回
|
|
@@ -257,7 +260,9 @@ def to_axolotl(
|
|
|
257
260
|
return {"conversations": conversations}
|
|
258
261
|
|
|
259
262
|
# 尝试从 messages 格式转换
|
|
260
|
-
messages =
|
|
263
|
+
messages = (
|
|
264
|
+
item.get("messages", []) if hasattr(item, "get") else getattr(item, "messages", [])
|
|
265
|
+
)
|
|
261
266
|
if messages:
|
|
262
267
|
role_map = {"user": "human", "assistant": "gpt", "system": "system"}
|
|
263
268
|
conversations = [
|
|
@@ -312,7 +317,9 @@ def to_llama_factory_sharegpt(
|
|
|
312
317
|
}
|
|
313
318
|
|
|
314
319
|
def transform(item) -> dict:
|
|
315
|
-
|
|
320
|
+
def get(f):
|
|
321
|
+
return item.get(f, "") if hasattr(item, "get") else getattr(item, f, "")
|
|
322
|
+
|
|
316
323
|
messages = get(messages_field) or []
|
|
317
324
|
|
|
318
325
|
conversations = []
|
|
@@ -385,7 +392,9 @@ def to_llama_factory_vlm(
|
|
|
385
392
|
"""
|
|
386
393
|
|
|
387
394
|
def transform(item) -> dict:
|
|
388
|
-
|
|
395
|
+
def get(f):
|
|
396
|
+
return item.get(f) if hasattr(item, "get") else getattr(item, f, None)
|
|
397
|
+
|
|
389
398
|
messages = get(messages_field) or []
|
|
390
399
|
|
|
391
400
|
instruction = ""
|
|
@@ -467,7 +476,9 @@ def to_llama_factory_vlm_sharegpt(
|
|
|
467
476
|
role_map = {"user": "human", "assistant": "gpt", "system": "system"}
|
|
468
477
|
|
|
469
478
|
def transform(item) -> dict:
|
|
470
|
-
|
|
479
|
+
def get(f):
|
|
480
|
+
return item.get(f) if hasattr(item, "get") else getattr(item, f, None)
|
|
481
|
+
|
|
471
482
|
messages = get(messages_field) or []
|
|
472
483
|
|
|
473
484
|
conversations = []
|
|
@@ -541,7 +552,9 @@ def to_swift_messages(
|
|
|
541
552
|
"""
|
|
542
553
|
|
|
543
554
|
def transform(item) -> dict:
|
|
544
|
-
|
|
555
|
+
def get(f):
|
|
556
|
+
return item.get(f) if hasattr(item, "get") else getattr(item, f, None)
|
|
557
|
+
|
|
545
558
|
messages = get(messages_field) or []
|
|
546
559
|
|
|
547
560
|
# 复制 messages,避免修改原数据
|
|
@@ -600,7 +613,8 @@ def to_swift_query_response(
|
|
|
600
613
|
"""
|
|
601
614
|
|
|
602
615
|
def transform(item) -> dict:
|
|
603
|
-
|
|
616
|
+
def get(f):
|
|
617
|
+
return item.get(f) if hasattr(item, "get") else getattr(item, f, None)
|
|
604
618
|
|
|
605
619
|
query = get(query_field)
|
|
606
620
|
response = get(response_field)
|
|
@@ -613,7 +627,7 @@ def to_swift_query_response(
|
|
|
613
627
|
current_query = ""
|
|
614
628
|
current_response = ""
|
|
615
629
|
|
|
616
|
-
for
|
|
630
|
+
for _i, msg in enumerate(messages):
|
|
617
631
|
role = msg.get("role", "")
|
|
618
632
|
content = msg.get("content", "")
|
|
619
633
|
|
|
@@ -693,7 +707,9 @@ def to_swift_vlm(
|
|
|
693
707
|
"""
|
|
694
708
|
|
|
695
709
|
def transform(item) -> dict:
|
|
696
|
-
|
|
710
|
+
def get(f):
|
|
711
|
+
return item.get(f) if hasattr(item, "get") else getattr(item, f, None)
|
|
712
|
+
|
|
697
713
|
messages = get(messages_field) or []
|
|
698
714
|
|
|
699
715
|
result_messages = []
|
|
@@ -6,6 +6,8 @@
|
|
|
6
6
|
|
|
7
7
|
from typing import Any, Callable
|
|
8
8
|
|
|
9
|
+
from dtflow.utils.helpers import get_field_value
|
|
10
|
+
|
|
9
11
|
|
|
10
12
|
def openai_chat(
|
|
11
13
|
user_field: str = "q", assistant_field: str = "a", system_prompt: str = None
|
|
@@ -33,8 +35,8 @@ def openai_chat(
|
|
|
33
35
|
if system_prompt:
|
|
34
36
|
messages.append({"role": "system", "content": system_prompt})
|
|
35
37
|
|
|
36
|
-
user_content =
|
|
37
|
-
assistant_content =
|
|
38
|
+
user_content = get_field_value(item, user_field)
|
|
39
|
+
assistant_content = get_field_value(item, assistant_field)
|
|
38
40
|
|
|
39
41
|
messages.append({"role": "user", "content": user_content})
|
|
40
42
|
messages.append({"role": "assistant", "content": assistant_content})
|
|
@@ -60,10 +62,9 @@ def alpaca(
|
|
|
60
62
|
|
|
61
63
|
def transform(item: Any) -> dict:
|
|
62
64
|
return {
|
|
63
|
-
"instruction":
|
|
64
|
-
|
|
65
|
-
"
|
|
66
|
-
"output": getattr(item, output_field, None) or item.get(output_field, ""),
|
|
65
|
+
"instruction": get_field_value(item, instruction_field),
|
|
66
|
+
"input": get_field_value(item, input_field),
|
|
67
|
+
"output": get_field_value(item, output_field),
|
|
67
68
|
}
|
|
68
69
|
|
|
69
70
|
return transform
|
|
@@ -84,9 +85,7 @@ def sharegpt(conversations_field: str = "conversations", role_mapping: dict = No
|
|
|
84
85
|
role_mapping = role_mapping or {"user": "human", "assistant": "gpt"}
|
|
85
86
|
|
|
86
87
|
def transform(item: Any) -> dict:
|
|
87
|
-
conversations =
|
|
88
|
-
conversations_field, []
|
|
89
|
-
)
|
|
88
|
+
conversations = get_field_value(item, conversations_field, [])
|
|
90
89
|
|
|
91
90
|
# 如果已经是对话格式,直接返回
|
|
92
91
|
if conversations:
|
|
@@ -102,7 +101,7 @@ def sharegpt(conversations_field: str = "conversations", role_mapping: dict = No
|
|
|
102
101
|
("answer", "gpt"),
|
|
103
102
|
("output", "gpt"),
|
|
104
103
|
]:
|
|
105
|
-
value =
|
|
104
|
+
value = get_field_value(item, field, None)
|
|
106
105
|
if value:
|
|
107
106
|
result.append({"from": role, "value": value})
|
|
108
107
|
|
|
@@ -127,9 +126,9 @@ def dpo_pair(
|
|
|
127
126
|
|
|
128
127
|
def transform(item: Any) -> dict:
|
|
129
128
|
return {
|
|
130
|
-
"prompt":
|
|
131
|
-
"chosen":
|
|
132
|
-
"rejected":
|
|
129
|
+
"prompt": get_field_value(item, prompt_field),
|
|
130
|
+
"chosen": get_field_value(item, chosen_field),
|
|
131
|
+
"rejected": get_field_value(item, rejected_field),
|
|
133
132
|
}
|
|
134
133
|
|
|
135
134
|
return transform
|
|
@@ -148,8 +147,8 @@ def simple_qa(question_field: str = "q", answer_field: str = "a") -> Callable:
|
|
|
148
147
|
|
|
149
148
|
def transform(item: Any) -> dict:
|
|
150
149
|
return {
|
|
151
|
-
"question":
|
|
152
|
-
"answer":
|
|
150
|
+
"question": get_field_value(item, question_field),
|
|
151
|
+
"answer": get_field_value(item, answer_field),
|
|
153
152
|
}
|
|
154
153
|
|
|
155
154
|
return transform
|
|
@@ -9,6 +9,7 @@ from .field_path import (
|
|
|
9
9
|
get_field_with_spec,
|
|
10
10
|
parse_field_spec,
|
|
11
11
|
)
|
|
12
|
+
from .helpers import get_field_value
|
|
12
13
|
|
|
13
14
|
__all__ = [
|
|
14
15
|
"display_data",
|
|
@@ -20,4 +21,6 @@ __all__ = [
|
|
|
20
21
|
"extract",
|
|
21
22
|
"extract_with_spec",
|
|
22
23
|
"ExpandMode",
|
|
24
|
+
# helpers
|
|
25
|
+
"get_field_value",
|
|
23
26
|
]
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""公共辅助函数"""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def get_field_value(item: Any, field: str, default: Any = "") -> Any:
|
|
7
|
+
"""
|
|
8
|
+
获取字段值,支持 DictWrapper 和普通 dict。
|
|
9
|
+
|
|
10
|
+
优先尝试 dict.get(),如果没有 get 方法则使用 getattr()。
|
|
11
|
+
|
|
12
|
+
Args:
|
|
13
|
+
item: 数据对象(dict 或 DictWrapper)
|
|
14
|
+
field: 字段名
|
|
15
|
+
default: 默认值
|
|
16
|
+
|
|
17
|
+
Returns:
|
|
18
|
+
字段值或默认值
|
|
19
|
+
|
|
20
|
+
Examples:
|
|
21
|
+
>>> get_field_value({"name": "test"}, "name")
|
|
22
|
+
'test'
|
|
23
|
+
>>> get_field_value({"name": ""}, "name", "default")
|
|
24
|
+
'default'
|
|
25
|
+
"""
|
|
26
|
+
if hasattr(item, "get"):
|
|
27
|
+
value = item.get(field, default)
|
|
28
|
+
else:
|
|
29
|
+
value = getattr(item, field, default)
|
|
30
|
+
return value if value else default
|
|
@@ -300,8 +300,10 @@ ignore_missing_imports = true
|
|
|
300
300
|
|
|
301
301
|
# Ruff configuration (optional alternative to flake8)
|
|
302
302
|
[tool.ruff]
|
|
303
|
-
target-version = "
|
|
303
|
+
target-version = "py38"
|
|
304
304
|
line-length = 100
|
|
305
|
+
|
|
306
|
+
[tool.ruff.lint]
|
|
305
307
|
select = [
|
|
306
308
|
"E", # pycodestyle errors
|
|
307
309
|
"W", # pycodestyle warnings
|
|
@@ -311,13 +313,13 @@ select = [
|
|
|
311
313
|
"B", # flake8-bugbear
|
|
312
314
|
]
|
|
313
315
|
ignore = [
|
|
314
|
-
"E501", # line too long, handled by
|
|
316
|
+
"E501", # line too long, handled by ruff-format
|
|
315
317
|
"B008", # do not perform function calls in argument defaults
|
|
316
318
|
"C901", # too complex
|
|
317
319
|
]
|
|
318
320
|
|
|
319
|
-
[tool.ruff.per-file-ignores]
|
|
321
|
+
[tool.ruff.lint.per-file-ignores]
|
|
320
322
|
"__init__.py" = ["F401"]
|
|
321
323
|
|
|
322
|
-
[tool.ruff.isort]
|
|
324
|
+
[tool.ruff.lint.isort]
|
|
323
325
|
known-first-party = ["dtflow"]
|