dtflow 0.4.3__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dtflow/framework.py ADDED
@@ -0,0 +1,610 @@
+ """
+ Training-framework integration module.
+
+ One-call export of data and config files for the major training frameworks:
+ - LLaMA-Factory
+ - ms-swift
+ - Axolotl
+
+ Usage:
+     from dtflow import DataTransformer
+
+     # Check compatibility
+     result = dt.check_compatibility("llama-factory")
+
+     # One-call export
+     dt.export_for("llama-factory", output_dir="./output")
+ """
+
+ import json
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import Dict, List, Literal, Optional
+
+ # Supported framework types
+ FrameworkType = Literal["llama-factory", "swift", "axolotl"]
+
+
+ @dataclass
+ class CompatibilityResult:
+     """Result of a compatibility check."""
+
+     valid: bool
+     framework: str
+     format: str  # detected format type
+     warnings: List[str] = field(default_factory=list)
+     errors: List[str] = field(default_factory=list)
+     suggestions: List[str] = field(default_factory=list)
+
+     def __bool__(self) -> bool:
+         return self.valid
+
+     def __str__(self) -> str:
+         status = "✅ compatible" if self.valid else "❌ incompatible"
+         lines = [f"{status} - {self.framework} ({self.format})"]
+
+         if self.errors:
+             lines.append("\nErrors:")
+             for err in self.errors:
+                 lines.append(f"  - {err}")
+
+         if self.warnings:
+             lines.append("\nWarnings:")
+             for warn in self.warnings:
+                 lines.append(f"  - {warn}")
+
+         if self.suggestions:
+             lines.append("\nSuggestions:")
+             for sug in self.suggestions:
+                 lines.append(f"  - {sug}")
+
+         return "\n".join(lines)
+
+
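The dataclass doubles as a truthy status object, so calling code can branch on it directly. A minimal sketch (values illustrative):

    result = CompatibilityResult(
        valid=False,
        framework="LLaMA-Factory",
        format="unknown",
        errors=["Unrecognized data format"],
    )
    if not result:       # __bool__ delegates to .valid
        print(result)    # __str__ renders status plus errors/warnings/suggestions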
+ # ============================================================================
+ # Format detection
+ # ============================================================================
+
+
+ def detect_format(data: List[dict]) -> str:
+     """
+     Auto-detect the data format.
+
+     Returns:
+         Format name: alpaca, sharegpt, openai_chat, dpo, simple_qa, unknown
+     """
+     if not data:
+         return "unknown"
+
+     sample = data[0]
+
+     # OpenAI Chat format
+     if "messages" in sample:
+         messages = sample["messages"]
+         if isinstance(messages, list) and len(messages) > 0:
+             first_msg = messages[0]
+             if isinstance(first_msg, dict) and "role" in first_msg and "content" in first_msg:
+                 return "openai_chat"
+
+     # ShareGPT format
+     if "conversations" in sample:
+         convs = sample["conversations"]
+         if isinstance(convs, list) and len(convs) > 0:
+             first_conv = convs[0]
+             if isinstance(first_conv, dict) and "from" in first_conv and "value" in first_conv:
+                 return "sharegpt"
+
+     # Alpaca format
+     if "instruction" in sample and "output" in sample:
+         return "alpaca"
+
+     # DPO format
+     if "prompt" in sample and "chosen" in sample and "rejected" in sample:
+         return "dpo"
+
+     # Simple QA format
+     if ("question" in sample and "answer" in sample) or ("q" in sample and "a" in sample):
+         return "simple_qa"
+
+     return "unknown"
+
+
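Detection inspects only the first record, so a mixed dataset resolves to whatever data[0] matches first. Illustrative calls (records invented for the example):

    detect_format([{"instruction": "Translate this", "output": "..."}])     # -> "alpaca"
    detect_format([{"messages": [{"role": "user", "content": "hi"}]}])      # -> "openai_chat"
    detect_format([{"prompt": "q", "chosen": "good", "rejected": "bad"}])   # -> "dpo"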
+ # ============================================================================
+ # Compatibility checks
+ # ============================================================================
+
+
+ def check_compatibility(
+     data: List[dict],
+     framework: FrameworkType,
+ ) -> CompatibilityResult:
+     """
+     Check whether the data is compatible with the target framework.
+
+     Args:
+         data: List of records
+         framework: Target framework name
+
+     Returns:
+         A CompatibilityResult object
+     """
+     framework = framework.lower().replace("_", "-")
+
+     if framework in ("llama-factory", "llamafactory", "lf"):
+         return _check_llama_factory_compatibility(data)
+     elif framework in ("swift", "ms-swift", "modelscope-swift"):
+         return _check_swift_compatibility(data)
+     elif framework == "axolotl":
+         return _check_axolotl_compatibility(data)
+     else:
+         return CompatibilityResult(
+             valid=False,
+             framework=framework,
+             format="unknown",
+             errors=[f"Unsupported framework: {framework}"],
+             suggestions=["Supported frameworks: llama-factory, swift, axolotl"],
+         )
+
+
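A sketch of the checker on ShareGPT-style rows (rows invented for the example). The non-standard role below is reported as a warning rather than an error, so the result still tests true:

    rows = [{"conversations": [{"from": "human", "value": "hi"},
                               {"from": "bot", "value": "hello"}]}]
    result = check_compatibility(rows, "llama-factory")
    print(result)  # warns about "bot" and suggests human/gpt or user/assistant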
+ def _check_llama_factory_compatibility(data: List[dict]) -> CompatibilityResult:
+     """Check LLaMA-Factory compatibility."""
+     format_type = detect_format(data)
+     errors = []
+     warnings = []
+     suggestions = []
+
+     # Check format compatibility
+     if format_type == "unknown":
+         errors.append("Unrecognized data format")
+         suggestions.append("LLaMA-Factory supports: alpaca, sharegpt, openai_chat")
+         return CompatibilityResult(
+             valid=False,
+             framework="LLaMA-Factory",
+             format=format_type,
+             errors=errors,
+             suggestions=suggestions,
+         )
+
+     # Format-specific checks
+     sample = data[0] if data else {}
+
+     if format_type == "openai_chat":
+         # Needs conversion to the sharegpt format
+         suggestions.append("Consider converting with to_llama_factory_sharegpt()")
+
+     elif format_type == "alpaca":
+         # Directly compatible
+         if "input" not in sample:
+             warnings.append("Missing 'input' field; an empty string will be used")
+
+     elif format_type == "sharegpt":
+         # Check role names
+         if data:
+             roles = set()
+             for item in data[:10]:  # only check the first 10 records
+                 for conv in item.get("conversations", []):
+                     roles.add(conv.get("from", ""))
+             valid_roles = {"human", "gpt", "user", "assistant", "system"}
+             invalid_roles = roles - valid_roles
+             if invalid_roles:
+                 warnings.append(f"Non-standard role names: {invalid_roles}")
+                 suggestions.append("Standard roles: human/gpt or user/assistant")
+
+     elif format_type == "dpo":
+         # LLaMA-Factory supports DPO
+         pass
+
+     elif format_type == "simple_qa":
+         suggestions.append("Consider converting to the alpaca format with to_llama_factory()")
+
+     return CompatibilityResult(
+         valid=len(errors) == 0,
+         framework="LLaMA-Factory",
+         format=format_type,
+         errors=errors,
+         warnings=warnings,
+         suggestions=suggestions,
+     )
+
+
+ def _check_swift_compatibility(data: List[dict]) -> CompatibilityResult:
+     """Check ms-swift compatibility."""
+     format_type = detect_format(data)
+     errors = []
+     warnings = []
+     suggestions = []
+
+     if format_type == "unknown":
+         errors.append("Unrecognized data format")
+         suggestions.append("ms-swift supports: messages, query-response, sharegpt")
+         return CompatibilityResult(
+             valid=False,
+             framework="ms-swift",
+             format=format_type,
+             errors=errors,
+             suggestions=suggestions,
+         )
+
+     # ms-swift accepts several formats
+     if format_type == "openai_chat":
+         # The messages format is supported directly
+         pass
+     elif format_type == "alpaca":
+         suggestions.append("Consider converting with to_swift_query_response()")
+     elif format_type == "sharegpt":
+         # Roles need conversion
+         pass
+
+     return CompatibilityResult(
+         valid=len(errors) == 0,
+         framework="ms-swift",
+         format=format_type,
+         errors=errors,
+         warnings=warnings,
+         suggestions=suggestions,
+     )
+
+
+ def _check_axolotl_compatibility(data: List[dict]) -> CompatibilityResult:
+     """Check Axolotl compatibility."""
+     format_type = detect_format(data)
+     errors = []
+     warnings = []
+     suggestions = []
+
+     if format_type == "unknown":
+         errors.append("Unrecognized data format")
+         suggestions.append("Axolotl supports: alpaca, sharegpt, openai_chat")
+         return CompatibilityResult(
+             valid=False,
+             framework="Axolotl",
+             format=format_type,
+             errors=errors,
+             suggestions=suggestions,
+         )
+
+     if format_type == "openai_chat":
+         # Axolotl supports the messages format directly
+         pass
+     elif format_type == "alpaca":
+         pass
+     elif format_type == "sharegpt":
+         pass
+
+     return CompatibilityResult(
+         valid=len(errors) == 0,
+         framework="Axolotl",
+         format=format_type,
+         errors=errors,
+         warnings=warnings,
+         suggestions=suggestions,
+     )
+
+
+ # ============================================================================
+ # Export
+ # ============================================================================
+
+
+ def export_for(
+     data: List[dict],
+     framework: FrameworkType,
+     output_dir: str,
+     dataset_name: str = "custom_dataset",
+     format_type: Optional[str] = None,
+     **kwargs,
+ ) -> Dict[str, str]:
+     """
+     One-call export of data and config files for the target framework.
+
+     Args:
+         data: List of records
+         framework: Target framework
+         output_dir: Output directory
+         dataset_name: Dataset name
+         format_type: Force a format type (auto-detected by default)
+         **kwargs: Framework-specific parameters
+
+     Returns:
+         Dict of generated file paths, e.g. {"data": "...", "config": "...", ...}
+     """
+     framework = framework.lower().replace("_", "-")
+     output_path = Path(output_dir)
+     output_path.mkdir(parents=True, exist_ok=True)
+
+     # Auto-detect the format
+     if format_type is None:
+         format_type = detect_format(data)
+
+     if framework in ("llama-factory", "llamafactory", "lf"):
+         return _export_llama_factory(data, output_path, dataset_name, format_type, **kwargs)
+     elif framework in ("swift", "ms-swift", "modelscope-swift"):
+         return _export_swift(data, output_path, dataset_name, format_type, **kwargs)
+     elif framework == "axolotl":
+         return _export_axolotl(data, output_path, dataset_name, format_type, **kwargs)
+     else:
+         raise ValueError(f"Unsupported framework: {framework}")
+
+
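One call writes both the data file and the framework scaffolding. A minimal sketch (paths and names illustrative):

    files = export_for(rows, "llama-factory", output_dir="./lf_out", dataset_name="my_sft")
    # -> {"data": "lf_out/my_sft.json",
    #     "dataset_info": "lf_out/dataset_info.json",
    #     "train_args": "lf_out/train_args.yaml"}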
+ def _export_llama_factory(
+     data: List[dict],
+     output_path: Path,
+     dataset_name: str,
+     format_type: str,
+     **kwargs,
+ ) -> Dict[str, str]:
+     """Export in LLaMA-Factory format."""
+     files = {}
+
+     # 1. Save the data file
+     data_file = output_path / f"{dataset_name}.json"
+     with open(data_file, "w", encoding="utf-8") as f:
+         json.dump(data, f, ensure_ascii=False, indent=2)
+     files["data"] = str(data_file)
+
+     # 2. Generate dataset_info.json
+     dataset_info = _generate_llama_factory_dataset_info(dataset_name, format_type)
+     info_file = output_path / "dataset_info.json"
+     with open(info_file, "w", encoding="utf-8") as f:
+         json.dump(dataset_info, f, ensure_ascii=False, indent=2)
+     files["dataset_info"] = str(info_file)
+
+     # 3. Generate a training-arguments template
+     train_args = _generate_llama_factory_train_args(dataset_name, **kwargs)
+     args_file = output_path / "train_args.yaml"
+     with open(args_file, "w", encoding="utf-8") as f:
+         f.write(train_args)
+     files["train_args"] = str(args_file)
+
+     print("✅ LLaMA-Factory export complete:")
+     print(f"  Data file: {data_file}")
+     print(f"  Config file: {info_file}")
+     print(f"  Training args: {args_file}")
+
+     return files
+
+
+ def _generate_llama_factory_dataset_info(dataset_name: str, format_type: str) -> dict:
+     """Generate the LLaMA-Factory dataset_info.json."""
+     if format_type in ("openai_chat", "sharegpt"):
+         # ShareGPT/conversation format
+         return {
+             dataset_name: {
+                 "file_name": f"{dataset_name}.json",
+                 "formatting": "sharegpt",
+                 "columns": {
+                     "messages": "conversations",
+                 },
+                 "tags": {
+                     "role_tag": "from",
+                     "content_tag": "value",
+                     "user_tag": "human",
+                     "assistant_tag": "gpt",
+                 },
+             }
+         }
+     elif format_type == "alpaca":
+         return {
+             dataset_name: {
+                 "file_name": f"{dataset_name}.json",
+                 "columns": {
+                     "prompt": "instruction",
+                     "query": "input",
+                     "response": "output",
+                 },
+             }
+         }
+     elif format_type == "dpo":
+         return {
+             dataset_name: {
+                 "file_name": f"{dataset_name}.json",
+                 "ranking": True,
+                 "columns": {
+                     "prompt": "prompt",
+                     "chosen": "chosen",
+                     "rejected": "rejected",
+                 },
+             }
+         }
+     else:
+         # Default to the alpaca format
+         return {
+             dataset_name: {
+                 "file_name": f"{dataset_name}.json",
+             }
+         }
+
+
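For reference, the alpaca branch above emits a mapping LLaMA-Factory can consume as-is (dataset name "my_sft" assumed for the example):

    {
      "my_sft": {
        "file_name": "my_sft.json",
        "columns": {"prompt": "instruction", "query": "input", "response": "output"}
      }
    }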
+ def _generate_llama_factory_train_args(
+     dataset_name: str,
+     model_name: str = "Qwen/Qwen2.5-7B-Instruct",
+     **kwargs,
+ ) -> str:
+     """Generate a LLaMA-Factory training-arguments template."""
+     return f"""### LLaMA-Factory training-arguments template
+ ### Usage: llamafactory-cli train train_args.yaml
+
+ ### Model
+ model_name_or_path: {model_name}
+
+ ### Method
+ stage: sft
+ do_train: true
+ finetuning_type: lora
+
+ ### Dataset
+ dataset: {dataset_name}
+ dataset_dir: .
+ template: qwen
+ cutoff_len: 2048
+
+ ### Output
+ output_dir: ./output
+ logging_steps: 10
+ save_steps: 500
+ plot_loss: true
+
+ ### Training parameters
+ per_device_train_batch_size: 2
+ gradient_accumulation_steps: 4
+ learning_rate: 1.0e-4
+ num_train_epochs: 3.0
+ lr_scheduler_type: cosine
+
+ ### LoRA parameters
+ lora_rank: 8
+ lora_alpha: 16
+ lora_dropout: 0.1
+ lora_target: all
+ """
+
+
+ def _export_swift(
+     data: List[dict],
+     output_path: Path,
+     dataset_name: str,
+     format_type: str,
+     **kwargs,
+ ) -> Dict[str, str]:
+     """Export in ms-swift format."""
+     files = {}
+
+     # 1. Save the data file (JSONL)
+     data_file = output_path / f"{dataset_name}.jsonl"
+     with open(data_file, "w", encoding="utf-8") as f:
+         for item in data:
+             f.write(json.dumps(item, ensure_ascii=False) + "\n")
+     files["data"] = str(data_file)
+
+     # 2. Generate a training script
+     train_script = _generate_swift_train_script(dataset_name, format_type, **kwargs)
+     script_file = output_path / "train.sh"
+     with open(script_file, "w", encoding="utf-8") as f:
+         f.write(train_script)
+     files["train_script"] = str(script_file)
+
+     print("✅ ms-swift export complete:")
+     print(f"  Data file: {data_file}")
+     print(f"  Training script: {script_file}")
+
+     return files
+
+
+ def _generate_swift_train_script(
+     dataset_name: str,
+     format_type: str,
+     model_name: str = "qwen2_5-7b-instruct",
+     **kwargs,
+ ) -> str:
+     """Generate an ms-swift training script."""
+     # Determine the dataset format
+     if format_type in ("openai_chat", "sharegpt"):
+         dataset_format = "messages"
+     else:
+         dataset_format = "query-response"
+
+     return f"""#!/bin/bash
+ # ms-swift training script
+ # Usage: bash train.sh
+
+ swift sft \\
+     --model_type {model_name} \\
+     --dataset {dataset_name}.jsonl \\
+     --output_dir ./output \\
+     --max_length 2048 \\
+     --learning_rate 1e-4 \\
+     --num_train_epochs 3 \\
+     --per_device_train_batch_size 2 \\
+     --gradient_accumulation_steps 4 \\
+     --save_steps 500 \\
+     --logging_steps 10 \\
+     --lora_rank 8 \\
+     --lora_alpha 32
+ """
+
+
+ def _export_axolotl(
+     data: List[dict],
+     output_path: Path,
+     dataset_name: str,
+     format_type: str,
+     **kwargs,
+ ) -> Dict[str, str]:
+     """Export in Axolotl format."""
+     files = {}
+
+     # 1. Save the data file
+     data_file = output_path / f"{dataset_name}.jsonl"
+     with open(data_file, "w", encoding="utf-8") as f:
+         for item in data:
+             f.write(json.dumps(item, ensure_ascii=False) + "\n")
+     files["data"] = str(data_file)
+
+     # 2. Generate the config file
+     config = _generate_axolotl_config(dataset_name, format_type, **kwargs)
+     config_file = output_path / "config.yaml"
+     with open(config_file, "w", encoding="utf-8") as f:
+         f.write(config)
+     files["config"] = str(config_file)
+
+     print("✅ Axolotl export complete:")
+     print(f"  Data file: {data_file}")
+     print(f"  Config file: {config_file}")
+
+     return files
+
+
+ def _generate_axolotl_config(
+     dataset_name: str,
+     format_type: str,
+     model_name: str = "Qwen/Qwen2.5-7B-Instruct",
+     **kwargs,
+ ) -> str:
+     """Generate an Axolotl config file."""
+     # Determine the dataset type
+     if format_type == "openai_chat":
+         ds_type = "chat_template"
+     elif format_type == "sharegpt":
+         ds_type = "sharegpt"
+     elif format_type == "alpaca":
+         ds_type = "alpaca"
+     else:
+         ds_type = "completion"
+
+     return f"""# Axolotl config file
+ # Usage: accelerate launch -m axolotl.cli.train config.yaml
+
+ base_model: {model_name}
+ model_type: AutoModelForCausalLM
+
+ datasets:
+   - path: {dataset_name}.jsonl
+     type: {ds_type}
+
+ sequence_len: 2048
+ sample_packing: true
+ pad_to_sequence_len: true
+
+ adapter: lora
+ lora_r: 8
+ lora_alpha: 16
+ lora_dropout: 0.05
+ lora_target_linear: true
+
+ learning_rate: 1e-4
+ num_epochs: 3
+ micro_batch_size: 2
+ gradient_accumulation_steps: 4
+
+ output_dir: ./output
+ logging_steps: 10
+ save_steps: 500
+
+ bf16: auto
+ tf32: true
+ gradient_checkpointing: true
+
+ warmup_ratio: 0.1
+ lr_scheduler: cosine
+ """