dtflow 0.4.3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dtflow/__init__.py +34 -1
- dtflow/__main__.py +22 -0
- dtflow/cli/commands.py +5 -0
- dtflow/cli/common.py +13 -9
- dtflow/cli/stats.py +114 -36
- dtflow/cli/validate.py +152 -0
- dtflow/core.py +220 -10
- dtflow/framework.py +610 -0
- dtflow/lineage.py +17 -0
- dtflow/schema.py +508 -0
- dtflow/streaming.py +93 -35
- dtflow/tokenizers.py +84 -29
- dtflow/utils/field_path.py +6 -2
- {dtflow-0.4.3.dist-info → dtflow-0.5.2.dist-info}/METADATA +117 -2
- {dtflow-0.4.3.dist-info → dtflow-0.5.2.dist-info}/RECORD +17 -14
- {dtflow-0.4.3.dist-info → dtflow-0.5.2.dist-info}/WHEEL +0 -0
- {dtflow-0.4.3.dist-info → dtflow-0.5.2.dist-info}/entry_points.txt +0 -0
dtflow/framework.py
ADDED
|
@@ -0,0 +1,610 @@
|
|
|
1
|
+
"""
Training-framework integration module.

Exports data together with ready-to-edit config files for mainstream
training frameworks in one step:
- LLaMA-Factory
- ms-swift
- Axolotl

Usage:
    from dtflow import DataTransformer

    dt = ...  # a DataTransformer instance holding your records (construction not shown in this module)

    # Check compatibility
    result = dt.check_compatibility("llama-factory")

    # One-step export
    dt.export_for("llama-factory", output_dir="./output")
"""
|
|
18
|
+
|
|
19
|
+
import json
|
|
20
|
+
import os
|
|
21
|
+
from dataclasses import dataclass, field
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
from typing import Any, Callable, Dict, List, Literal, Optional, Union
|
|
24
|
+
|
|
25
|
+
# Canonical names of the supported training frameworks. check_compatibility()
# and export_for() additionally accept aliases (e.g. "lf", "ms-swift") after
# lower-casing the name and replacing "_" with "-".
FrameworkType = Literal["llama-factory", "swift", "axolotl"]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
class CompatibilityResult:
    """Outcome of checking a dataset against one training framework.

    The object is truthy exactly when ``valid`` is true, so ``if result:``
    reads naturally. ``str(result)`` renders a human-readable report.
    """

    valid: bool
    framework: str
    format: str  # name of the detected data format
    warnings: List[str] = field(default_factory=list)
    errors: List[str] = field(default_factory=list)
    suggestions: List[str] = field(default_factory=list)

    def __bool__(self) -> bool:
        return self.valid

    def __str__(self) -> str:
        if self.valid:
            verdict = "✅ 兼容"
        else:
            verdict = "❌ 不兼容"
        report = [f"{verdict} - {self.framework} ({self.format})"]

        # Sections are emitted in a fixed order and skipped when empty.
        sections = (
            ("\n错误:", self.errors),
            ("\n警告:", self.warnings),
            ("\n建议:", self.suggestions),
        )
        for heading, entries in sections:
            if not entries:
                continue
            report.append(heading)
            report.extend(f" - {entry}" for entry in entries)

        return "\n".join(report)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# ============================================================================
|
|
66
|
+
# 格式检测
|
|
67
|
+
# ============================================================================
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def detect_format(data: List[dict]) -> str:
    """
    Detect the data format from the first record.

    Only ``data[0]`` is inspected; later records are assumed to share its
    shape. Checks run in this order and the first match wins:

    - ``openai_chat``: ``messages`` is a non-empty list whose first item is a
      dict with ``role`` and ``content``
    - ``sharegpt``: ``conversations`` is a non-empty list whose first item is
      a dict with ``from`` and ``value``
    - ``alpaca``: has ``instruction`` and ``output``
    - ``dpo``: has ``prompt``, ``chosen`` and ``rejected``
    - ``simple_qa``: has ``question``/``answer`` or ``q``/``a``

    Returns:
        One of "openai_chat", "sharegpt", "alpaca", "dpo", "simple_qa" or
        "unknown". "unknown" is also returned for empty input and when the
        first record is not a dict.
    """
    if not data:
        return "unknown"

    sample = data[0]
    # A non-dict record (e.g. a raw string) would turn the `in` checks below
    # into substring tests and then fail on item access.
    if not isinstance(sample, dict):
        return "unknown"

    # OpenAI chat format
    messages = sample.get("messages")
    if isinstance(messages, list) and messages:
        first_msg = messages[0]
        if isinstance(first_msg, dict) and "role" in first_msg and "content" in first_msg:
            return "openai_chat"

    # ShareGPT format
    convs = sample.get("conversations")
    if isinstance(convs, list) and convs:
        first_conv = convs[0]
        if isinstance(first_conv, dict) and "from" in first_conv and "value" in first_conv:
            return "sharegpt"

    # Alpaca format
    if "instruction" in sample and "output" in sample:
        return "alpaca"

    # DPO (preference) format
    if "prompt" in sample and "chosen" in sample and "rejected" in sample:
        return "dpo"

    # Simple QA format
    if ("question" in sample and "answer" in sample) or ("q" in sample and "a" in sample):
        return "simple_qa"

    return "unknown"
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
# ============================================================================
|
|
114
|
+
# 兼容性检查
|
|
115
|
+
# ============================================================================
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def check_compatibility(
    data: List[dict],
    framework: FrameworkType,
) -> CompatibilityResult:
    """
    Check whether ``data`` can be consumed by the given training framework.

    Args:
        data: list of records.
        framework: framework name. Matching is case-insensitive and treats
            "_" as "-". Accepted names: llama-factory / llamafactory / lf,
            swift / ms-swift / modelscope-swift, axolotl.

    Returns:
        A CompatibilityResult. An unsupported framework name does not raise;
        it yields an invalid result carrying an error and a suggestion.
    """
    normalized = framework.lower().replace("_", "-")

    checker_by_name = {
        "llama-factory": _check_llama_factory_compatibility,
        "llamafactory": _check_llama_factory_compatibility,
        "lf": _check_llama_factory_compatibility,
        "swift": _check_swift_compatibility,
        "ms-swift": _check_swift_compatibility,
        "modelscope-swift": _check_swift_compatibility,
        "axolotl": _check_axolotl_compatibility,
    }
    checker = checker_by_name.get(normalized)
    if checker is not None:
        return checker(data)

    return CompatibilityResult(
        valid=False,
        framework=normalized,
        format="unknown",
        errors=[f"不支持的框架: {normalized}"],
        suggestions=["支持的框架: llama-factory, swift, axolotl"],
    )
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def _check_llama_factory_compatibility(data: List[dict]) -> CompatibilityResult:
    """
    Check LLaMA-Factory compatibility.

    An unrecognised format is invalid. Every other detected format is
    reported valid, possibly with notes:

    - openai_chat: suggestion to convert with to_llama_factory_sharegpt()
    - alpaca: warning when the first record has no 'input' field
    - sharegpt: warning listing role names (from the first 10 records) that
      are not among the LLaMA-Factory sharegpt role tags
    - dpo: no notes
    - simple_qa: suggestion to convert to alpaca with to_llama_factory()
    """
    format_type = detect_format(data)
    errors = []
    warnings = []
    suggestions = []

    # Format compatibility
    if format_type == "unknown":
        errors.append("无法识别数据格式")
        suggestions.append("LLaMA-Factory 支持: alpaca, sharegpt, openai_chat")
        return CompatibilityResult(
            valid=False,
            framework="LLaMA-Factory",
            format=format_type,
            errors=errors,
            suggestions=suggestions,
        )

    # Format-specific checks
    sample = data[0] if data else {}

    if format_type == "openai_chat":
        # Suggest converting to sharegpt layout
        suggestions.append("建议使用 to_llama_factory_sharegpt() 转换")

    elif format_type == "alpaca":
        if "input" not in sample:
            warnings.append("缺少 'input' 字段,将使用空字符串")

    elif format_type == "sharegpt":
        roles = set()
        for item in data[:10]:  # only sample the first 10 records
            # detect_format only validated data[0]; later records may be malformed.
            if not isinstance(item, dict):
                continue
            for conv in item.get("conversations") or []:
                if not isinstance(conv, dict):
                    continue
                role = conv.get("from", "")
                # Keep unhashable/non-string role values reportable instead of crashing set.add.
                roles.add(role if isinstance(role, str) else repr(role))
        # LLaMA-Factory's sharegpt tags also include observation/function_call
        # (tool-use turns), so those are not flagged.
        valid_roles = {
            "human",
            "gpt",
            "user",
            "assistant",
            "system",
            "observation",
            "function_call",
        }
        invalid_roles = roles - valid_roles
        if invalid_roles:
            warnings.append(f"非标准角色名: {invalid_roles}")
            suggestions.append("标准角色: human/gpt 或 user/assistant")

    elif format_type == "dpo":
        # LLaMA-Factory supports preference (ranking) data directly.
        pass

    elif format_type == "simple_qa":
        suggestions.append("建议使用 to_llama_factory() 转换为 alpaca 格式")

    return CompatibilityResult(
        valid=len(errors) == 0,
        framework="LLaMA-Factory",
        format=format_type,
        errors=errors,
        warnings=warnings,
        suggestions=suggestions,
    )
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def _check_swift_compatibility(data: List[dict]) -> CompatibilityResult:
    """
    Check whether ``data`` is usable by ms-swift.

    Only an unrecognised format is reported invalid. alpaca data gets a
    conversion hint; every other detected format passes without notes.
    """
    detected = detect_format(data)

    if detected == "unknown":
        return CompatibilityResult(
            valid=False,
            framework="ms-swift",
            format=detected,
            errors=["无法识别数据格式"],
            suggestions=["ms-swift 支持: messages, query-response, sharegpt"],
        )

    # openai_chat (messages) is consumed directly; sharegpt is accepted as-is
    # (role names are not inspected here).
    hints = []
    if detected == "alpaca":
        hints.append("建议使用 to_swift_query_response() 转换")

    return CompatibilityResult(
        valid=True,
        framework="ms-swift",
        format=detected,
        errors=[],
        warnings=[],
        suggestions=hints,
    )
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def _check_axolotl_compatibility(data: List[dict]) -> CompatibilityResult:
    """
    Check Axolotl compatibility.

    An unrecognised format is invalid. openai_chat, sharegpt and alpaca map
    onto native Axolotl dataset types and pass without notes. dpo and
    simple_qa are still reported valid, but get a warning: the exported
    config.yaml declares any other format as ``type: completion``, which
    reads a plain-text field instead of these records' fields.
    """
    format_type = detect_format(data)
    errors = []
    warnings = []
    suggestions = []

    if format_type == "unknown":
        errors.append("无法识别数据格式")
        suggestions.append("Axolotl 支持: alpaca, sharegpt, openai_chat")
        return CompatibilityResult(
            valid=False,
            framework="Axolotl",
            format=format_type,
            errors=errors,
            suggestions=suggestions,
        )

    if format_type in ("dpo", "simple_qa"):
        # These formats fall through to `type: completion` in the generated
        # config (default field: 'text'), so training would not see their data.
        warnings.append(
            f"{format_type} 格式导出时会被声明为 completion 类型(默认读取 'text' 字段),"
            "需先转换数据或手动修改配置"
        )
        if format_type == "dpo":
            suggestions.append("Axolotl 偏好训练需在配置中设置 rl: dpo 并选择对应的数据集类型")
        else:
            suggestions.append("建议先转换为 alpaca 或 openai_chat 格式")

    return CompatibilityResult(
        valid=len(errors) == 0,
        framework="Axolotl",
        format=format_type,
        errors=errors,
        warnings=warnings,
        suggestions=suggestions,
    )
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
# ============================================================================
|
|
286
|
+
# 导出功能
|
|
287
|
+
# ============================================================================
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def export_for(
    data: List[dict],
    framework: FrameworkType,
    output_dir: str,
    dataset_name: str = "custom_dataset",
    format_type: Optional[str] = None,
    **kwargs,
) -> Dict[str, str]:
    """
    Export data plus framework config files in one step.

    Args:
        data: list of records.
        framework: target framework. Matching is case-insensitive and treats
            "_" as "-". Accepted: llama-factory / llamafactory / lf,
            swift / ms-swift / modelscope-swift, axolotl.
        output_dir: output directory (created, with parents, if missing).
        dataset_name: dataset name (used for file names / config entries).
        format_type: force a format name instead of auto-detecting it.
        **kwargs: framework-specific options forwarded to the exporter.

    Returns:
        Mapping of generated file roles to paths, e.g. {"data": "...", ...}.

    Raises:
        ValueError: if the framework is not supported. This is raised before
            anything is written, so no output directory is created.
    """
    framework = framework.lower().replace("_", "-")

    # Resolve the exporter first so an unsupported name fails fast without
    # leaving an empty output directory behind.
    if framework in ("llama-factory", "llamafactory", "lf"):
        exporter = _export_llama_factory
    elif framework in ("swift", "ms-swift", "modelscope-swift"):
        exporter = _export_swift
    elif framework == "axolotl":
        exporter = _export_axolotl
    else:
        raise ValueError(f"不支持的框架: {framework}")

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Auto-detect the format unless the caller forced one.
    if format_type is None:
        format_type = detect_format(data)

    return exporter(data, output_path, dataset_name, format_type, **kwargs)
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def _export_llama_factory(
    data: List[dict],
    output_path: Path,
    dataset_name: str,
    format_type: str,
    **kwargs,
) -> Dict[str, str]:
    """
    Export ``data`` for LLaMA-Factory.

    Writes into ``output_path`` (which must already exist):

    - ``<dataset_name>.json``: the records unchanged, as one JSON array
    - ``dataset_info.json``: the column/tag mapping for ``format_type``
    - ``train_args.yaml``: a training-argument template

    Args:
        data: records to export.
        output_path: existing output directory.
        dataset_name: key used in dataset_info.json and stem of the data file.
        format_type: data format name (as returned by ``detect_format``).
        **kwargs: forwarded to the train-args template (``model_name``,
            ``template``, ``stage``; other keys are ignored).

    Returns:
        Paths of the written files under keys "data", "dataset_info" and
        "train_args".
    """
    files = {}

    # 1. Data file (pretty-printed JSON array, non-ASCII kept as-is).
    data_file = output_path / f"{dataset_name}.json"
    with open(data_file, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)
    files["data"] = str(data_file)

    # 2. dataset_info.json
    dataset_info = _generate_llama_factory_dataset_info(dataset_name, format_type)
    info_file = output_path / "dataset_info.json"
    with open(info_file, "w", encoding="utf-8") as f:
        json.dump(dataset_info, f, ensure_ascii=False, indent=2)
    files["dataset_info"] = str(info_file)

    # 3. Training-argument template. DPO data is registered with
    # "ranking": true, which LLaMA-Factory refuses in the sft stage, so the
    # stage defaults to dpo for that format unless the caller chose one.
    kwargs.setdefault("stage", "dpo" if format_type == "dpo" else "sft")
    train_args = _generate_llama_factory_train_args(dataset_name, **kwargs)
    args_file = output_path / "train_args.yaml"
    with open(args_file, "w", encoding="utf-8") as f:
        f.write(train_args)
    files["train_args"] = str(args_file)

    print("✅ LLaMA-Factory 导出完成:")
    print(f" 数据文件: {data_file}")
    print(f" 配置文件: {info_file}")
    print(f" 训练参数: {args_file}")

    return files


def _generate_llama_factory_dataset_info(dataset_name: str, format_type: str) -> dict:
    """
    Build the dataset_info.json content registering one dataset.

    - openai_chat: sharegpt loader reading ``messages`` with role/content tags
    - sharegpt: sharegpt loader reading ``conversations`` with from/value tags
    - alpaca: instruction/input/output columns
    - dpo: ranking dataset with prompt/chosen/rejected columns
    - anything else: only ``file_name`` (LLaMA-Factory's default alpaca columns)

    Returns:
        ``{dataset_name: {...}}``.
    """
    file_name = f"{dataset_name}.json"

    if format_type == "openai_chat":
        # Records look like {"messages": [{"role": ..., "content": ...}]}; the
        # data is written unchanged, so the tags must use role/content names.
        entry = {
            "file_name": file_name,
            "formatting": "sharegpt",
            "columns": {
                "messages": "messages",
            },
            "tags": {
                "role_tag": "role",
                "content_tag": "content",
                "user_tag": "user",
                "assistant_tag": "assistant",
                "system_tag": "system",
            },
        }
    elif format_type == "sharegpt":
        entry = {
            "file_name": file_name,
            "formatting": "sharegpt",
            "columns": {
                "messages": "conversations",
            },
            "tags": {
                "role_tag": "from",
                "content_tag": "value",
                "user_tag": "human",
                "assistant_tag": "gpt",
            },
        }
    elif format_type == "alpaca":
        entry = {
            "file_name": file_name,
            "columns": {
                "prompt": "instruction",
                "query": "input",
                "response": "output",
            },
        }
    elif format_type == "dpo":
        entry = {
            "file_name": file_name,
            "ranking": True,
            "columns": {
                "prompt": "prompt",
                "chosen": "chosen",
                "rejected": "rejected",
            },
        }
    else:
        # Fall back to LLaMA-Factory's default (alpaca) column names.
        entry = {
            "file_name": file_name,
        }

    return {dataset_name: entry}


def _generate_llama_factory_train_args(
    dataset_name: str,
    model_name: str = "Qwen/Qwen2.5-7B-Instruct",
    stage: str = "sft",
    template: str = "qwen",
    **kwargs,
) -> str:
    """
    Build a LLaMA-Factory LoRA training-argument YAML template.

    Args:
        dataset_name: dataset key registered in dataset_info.json.
        model_name: value for ``model_name_or_path``.
        stage: training stage, e.g. "sft" or "dpo".
        template: chat template name; it should match ``model_name``.
        **kwargs: accepted and ignored.

    Returns:
        YAML text (meant to be run from the directory holding dataset_info.json).
    """
    return f"""### LLaMA-Factory 训练参数模板
### 使用: llamafactory-cli train train_args.yaml

### 模型
model_name_or_path: {model_name}

### 方法
stage: {stage}
do_train: true
finetuning_type: lora

### 数据集
dataset: {dataset_name}
dataset_dir: .
template: {template}
cutoff_len: 2048

### 输出
output_dir: ./output
logging_steps: 10
save_steps: 500
plot_loss: true

### 训练参数
per_device_train_batch_size: 2
gradient_accumulation_steps: 4
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine

### LoRA 参数
lora_rank: 8
lora_alpha: 16
lora_dropout: 0.1
lora_target: all
"""
|
|
461
|
+
|
|
462
|
+
|
|
463
|
+
def _export_swift(
    data: List[dict],
    output_path: Path,
    dataset_name: str,
    format_type: str,
    **kwargs,
) -> Dict[str, str]:
    """
    Write ``data`` and an ms-swift launch script into ``output_path``.

    Produces ``<dataset_name>.jsonl`` (one JSON object per line, non-ASCII
    kept as-is) and ``train.sh``, then prints a short summary.

    Returns:
        {"data": <jsonl path>, "train_script": <train.sh path>}
    """
    jsonl_path = output_path / f"{dataset_name}.jsonl"
    with open(jsonl_path, "w", encoding="utf-8") as sink:
        sink.writelines(json.dumps(record, ensure_ascii=False) + "\n" for record in data)

    script_text = _generate_swift_train_script(dataset_name, format_type, **kwargs)
    script_path = output_path / "train.sh"
    with open(script_path, "w", encoding="utf-8") as sink:
        sink.write(script_text)

    print("✅ ms-swift 导出完成:")
    print(f" 数据文件: {jsonl_path}")
    print(f" 训练脚本: {script_path}")

    return {"data": str(jsonl_path), "train_script": str(script_path)}
|
|
492
|
+
|
|
493
|
+
|
|
494
|
+
def _generate_swift_train_script(
    dataset_name: str,
    format_type: str,
    model_name: str = "qwen2_5-7b-instruct",
    **kwargs,
) -> str:
    """
    Build a bash script that runs ``swift sft`` on ``<dataset_name>.jsonl``.

    Args:
        dataset_name: stem of the JSONL file written next to the script.
        format_type: detected data format. It does not change the generated
            script (the JSONL is passed to ``--dataset`` as-is); the parameter
            is kept so all generators share the same call shape.
        model_name: value passed to ``--model_type``.
        **kwargs: accepted and ignored.

    Returns:
        Script text (meant to be run from the output directory).
    """
    return f"""#!/bin/bash
# ms-swift 训练脚本
# 使用: bash train.sh

swift sft \\
--model_type {model_name} \\
--dataset {dataset_name}.jsonl \\
--output_dir ./output \\
--max_length 2048 \\
--learning_rate 1e-4 \\
--num_train_epochs 3 \\
--per_device_train_batch_size 2 \\
--gradient_accumulation_steps 4 \\
--save_steps 500 \\
--logging_steps 10 \\
--lora_rank 8 \\
--lora_alpha 32
"""
|
|
525
|
+
|
|
526
|
+
|
|
527
|
+
def _export_axolotl(
    data: List[dict],
    output_path: Path,
    dataset_name: str,
    format_type: str,
    **kwargs,
) -> Dict[str, str]:
    """
    Write ``data`` and an Axolotl config into ``output_path``.

    Produces ``<dataset_name>.jsonl`` (one JSON object per line, non-ASCII
    kept as-is) and ``config.yaml``, then prints a short summary.

    Returns:
        {"data": <jsonl path>, "config": <config.yaml path>}
    """
    jsonl_path = output_path / f"{dataset_name}.jsonl"
    with open(jsonl_path, "w", encoding="utf-8") as out:
        out.writelines(json.dumps(record, ensure_ascii=False) + "\n" for record in data)

    config_text = _generate_axolotl_config(dataset_name, format_type, **kwargs)
    config_path = output_path / "config.yaml"
    with open(config_path, "w", encoding="utf-8") as out:
        out.write(config_text)

    print("✅ Axolotl 导出完成:")
    print(f" 数据文件: {jsonl_path}")
    print(f" 配置文件: {config_path}")

    return {"data": str(jsonl_path), "config": str(config_path)}
|
|
556
|
+
|
|
557
|
+
|
|
558
|
+
def _generate_axolotl_config(
    dataset_name: str,
    format_type: str,
    model_name: str = "Qwen/Qwen2.5-7B-Instruct",
    **kwargs,
) -> str:
    """
    Build an Axolotl LoRA config (YAML) for ``<dataset_name>.jsonl``.

    Dataset type by ``format_type``:

    - openai_chat -> ``chat_template`` (Axolotl reads the ``messages`` field
      with role/content by default)
    - sharegpt -> ``chat_template`` reading ``conversations`` with from/value
      mapped to role/content (the standalone ``sharegpt`` type is deprecated
      in Axolotl)
    - alpaca -> ``alpaca``
    - anything else -> ``completion``

    Args:
        dataset_name: stem of the JSONL file referenced by ``datasets.path``.
        format_type: detected data format.
        model_name: value for ``base_model``.
        **kwargs: accepted and ignored.

    Returns:
        YAML text (meant to be run from the output directory).
    """
    # Keys nested under the single `datasets` list item (4-space indent).
    if format_type == "openai_chat":
        dataset_options = "    type: chat_template\n"
    elif format_type == "sharegpt":
        # NOTE(review): relies on Axolotl's default role mapping to translate
        # human/gpt into user/assistant; confirm for the Axolotl version in use.
        dataset_options = (
            "    type: chat_template\n"
            "    field_messages: conversations\n"
            "    message_property_mappings:\n"
            "      role: from\n"
            "      content: value\n"
        )
    elif format_type == "alpaca":
        dataset_options = "    type: alpaca\n"
    else:
        dataset_options = "    type: completion\n"

    # learning_rate is written as 1.0e-4: YAML 1.1 loaders resolve "1e-4"
    # (no dot) to a string rather than a float.
    return f"""# Axolotl 配置文件
# 使用: accelerate launch -m axolotl.cli.train config.yaml

base_model: {model_name}
model_type: AutoModelForCausalLM

datasets:
  - path: {dataset_name}.jsonl
{dataset_options}
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true

adapter: lora
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true

learning_rate: 1.0e-4
num_epochs: 3
micro_batch_size: 2
gradient_accumulation_steps: 4

output_dir: ./output
logging_steps: 10
save_steps: 500

bf16: auto
tf32: true
gradient_checkpointing: true

warmup_ratio: 0.1
lr_scheduler: cosine
"""
|
dtflow/lineage.py
CHANGED
|
@@ -237,6 +237,23 @@ class LineageTracker:
|
|
|
237
237
|
|
|
238
238
|
return lineage_path
|
|
239
239
|
|
|
240
|
+
def copy(self) -> "LineageTracker":
    """
    Return a copy of this tracker with its own operation history.

    Used by split() and similar operations so that each resulting dataset
    tracks its lineage independently.

    Returns:
        A new LineageTracker instance.
    """
    from copy import deepcopy

    # Instantiate without running __init__; only the three attributes below
    # are carried over to the clone.
    duplicate = LineageTracker.__new__(LineageTracker)
    duplicate.source_path = self.source_path
    # Shared by reference: the original implementation treats LineageRecord as immutable.
    duplicate.source_lineage = self.source_lineage
    # Deep-copied so later changes to one tracker's operations do not leak into the other.
    duplicate.operations = deepcopy(self.operations)
    return duplicate
|
|
256
|
+
|
|
240
257
|
|
|
241
258
|
def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
|
|
242
259
|
"""
|