dtflow 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dtflow/converters.py CHANGED
@@ -259,6 +259,462 @@ def to_axolotl(
259
259
  return transform
260
260
 
261
261
 
262
def to_llama_factory_sharegpt(
    messages_field: str = "messages",
    system_field: Optional[str] = None,
    tools_field: Optional[str] = None,
) -> Callable:
    """
    Build a converter to the LLaMA-Factory ShareGPT format (multi-turn chat).

    Output format:
        {
            "conversations": [
                {"from": "human", "value": "..."},
                {"from": "gpt", "value": "..."}
            ],
            "system": "...",  # optional
            "tools": "..."    # optional
        }

    Args:
        messages_field: Name of the input field holding the list of
            ``{"role": ..., "content": ...}`` messages.
        system_field: Name of the field holding the system prompt. When it is
            None, or the record has no value for it, the prompt is taken from
            the ``system`` message inside ``messages`` instead.
        tools_field: Name of the field holding the tool description.

    Returns:
        A function that maps one record to a ShareGPT-style dict.

    Examples:
        >>> dt.transform(to_llama_factory_sharegpt())
        >>> dt.transform(to_llama_factory_sharegpt(system_field="system_prompt"))
    """
    role_map = {
        "user": "human",
        "assistant": "gpt",
        "system": "system",
        "tool": "observation",
        "function_call": "function_call",
    }

    def transform(item) -> dict:
        def read(field):
            # Mapping-like records expose .get(); any other record is read by
            # attribute instead of calling a .get() it does not have.
            if hasattr(item, "get"):
                return item.get(field, "")
            return getattr(item, field, "")

        conversations = []
        system_prompt = None

        for msg in read(messages_field) or []:
            role = msg.get("role", "")
            content = msg.get("content", "")

            # System messages are lifted out of the conversation turns.
            if role == "system":
                system_prompt = content
                continue

            # Unknown roles are passed through unchanged.
            conversations.append({"from": role_map.get(role, role), "value": content})

        result = {"conversations": conversations}

        # Prefer the dedicated field; fall back to the system message found
        # in messages when the field is unset or empty for this record.
        system = (read(system_field) if system_field else None) or system_prompt
        if system:
            result["system"] = system

        if tools_field:
            tools = read(tools_field)
            if tools:
                result["tools"] = tools

        return result

    return transform
338
+
339
+
340
def to_llama_factory_vlm(
    messages_field: str = "messages",
    images_field: str = "images",
    videos_field: Optional[str] = None,
    system_field: Optional[str] = None,
) -> Callable:
    """
    Build a converter to the LLaMA-Factory VLM (vision-language model) format.

    Output format (Alpaca style):
        {
            "instruction": "...",
            "input": "",
            "output": "...",
            "images": ["path1.jpg", "path2.jpg"],  # list of image paths
            "videos": ["path.mp4"],                # optional, list of video paths
            "system": "..."                        # optional
        }

    Args:
        messages_field: Name of the input field holding the messages list.
        images_field: Name of the field holding the image path(s).
        videos_field: Name of the field holding the video path(s).
        system_field: Name of the field holding the system prompt. When it is
            None, or the record has no value for it, the prompt is taken from
            the ``system`` message inside ``messages`` instead.

    Returns:
        A function that maps one record to an Alpaca-style VLM dict.

    Examples:
        >>> dt.transform(to_llama_factory_vlm())
        >>> dt.transform(to_llama_factory_vlm(images_field="image_paths"))
    """

    def transform(item) -> dict:
        def read(field):
            # Mapping-like records expose .get(); any other record is read by
            # attribute instead of calling a .get() it does not have.
            if hasattr(item, "get"):
                return item.get(field)
            return getattr(item, field, None)

        def as_list(value):
            # A single path is wrapped so the output is always a list.
            return value if isinstance(value, list) else [value]

        instruction = ""
        output = ""
        system_prompt = None

        # Each role overwrites its slot, so the last user / assistant /
        # system message in the list is the one that ends up in the output.
        for msg in read(messages_field) or []:
            role = msg.get("role", "")
            content = msg.get("content", "")

            if role == "system":
                system_prompt = content
            elif role == "user":
                instruction = content
            elif role == "assistant":
                output = content

        result = {
            "instruction": instruction,
            "input": "",
            "output": output,
        }

        images = read(images_field)
        if images:
            result["images"] = as_list(images)

        if videos_field:
            videos = read(videos_field)
            if videos:
                result["videos"] = as_list(videos)

        # Prefer the dedicated field; fall back to the system message found
        # in messages when the field is unset or empty for this record.
        system = (read(system_field) if system_field else None) or system_prompt
        if system:
            result["system"] = system

        return result

    return transform
420
+
421
+
422
def to_llama_factory_vlm_sharegpt(
    messages_field: str = "messages",
    images_field: str = "images",
    videos_field: Optional[str] = None,
    system_field: Optional[str] = None,
) -> Callable:
    """
    Build a converter to the LLaMA-Factory VLM ShareGPT format
    (multi-turn, multimodal chat).

    Output format:
        {
            "conversations": [
                {"from": "human", "value": "<image>describe this picture"},
                {"from": "gpt", "value": "This is a ..."}
            ],
            "images": ["path1.jpg"],
            "system": "..."
        }

    Args:
        messages_field: Name of the input field holding the messages list.
        images_field: Name of the field holding the image path(s).
        videos_field: Name of the field holding the video path(s).
        system_field: Name of the field holding the system prompt. When it is
            None, or the record has no value for it, the prompt is taken from
            the ``system`` message inside ``messages`` instead.

    Returns:
        A function that maps one record to a ShareGPT-style VLM dict.

    Examples:
        >>> dt.transform(to_llama_factory_vlm_sharegpt())
    """
    role_map = {"user": "human", "assistant": "gpt", "system": "system"}

    def transform(item) -> dict:
        def read(field):
            # Mapping-like records expose .get(); any other record is read by
            # attribute instead of calling a .get() it does not have.
            if hasattr(item, "get"):
                return item.get(field)
            return getattr(item, field, None)

        def as_list(value):
            # A single path is wrapped so the output is always a list.
            return value if isinstance(value, list) else [value]

        conversations = []
        system_prompt = None

        for msg in read(messages_field) or []:
            role = msg.get("role", "")
            content = msg.get("content", "")

            # System messages are lifted out of the conversation turns.
            if role == "system":
                system_prompt = content
                continue

            # Unknown roles are passed through unchanged.
            conversations.append({"from": role_map.get(role, role), "value": content})

        result = {"conversations": conversations}

        images = read(images_field)
        if images:
            result["images"] = as_list(images)

        if videos_field:
            videos = read(videos_field)
            if videos:
                result["videos"] = as_list(videos)

        # Prefer the dedicated field; fall back to the system message found
        # in messages when the field is unset or empty for this record.
        system = (read(system_field) if system_field else None) or system_prompt
        if system:
            result["system"] = system

        return result

    return transform
497
+
498
+
499
+ # ============== ms-swift 格式转换器 ==============
500
+
501
+
502
def to_swift_messages(
    messages_field: str = "messages",
    system_field: Optional[str] = None,
) -> Callable:
    """
    Build a converter to the ms-swift messages format (standard format).

    Output format:
        {
            "messages": [
                {"role": "system", "content": "..."},
                {"role": "user", "content": "..."},
                {"role": "assistant", "content": "..."}
            ]
        }

    Args:
        messages_field: Name of the input field holding the messages list.
        system_field: Name of a field whose value, when non-empty, is
            prepended as an additional system message.

    Returns:
        A function that maps one record to ``{"messages": [...]}``.

    Examples:
        >>> dt.transform(to_swift_messages())
    """

    def transform(item) -> dict:
        def read(field):
            # Mapping-like records expose .get(); any other record is read by
            # attribute instead of calling a .get() it does not have.
            if hasattr(item, "get"):
                return item.get(field)
            return getattr(item, field, None)

        # A fresh list is built so the input record is never mutated.
        result_messages = []

        if system_field:
            system = read(system_field)
            if system:
                result_messages.append({"role": "system", "content": system})

        # Normalise every message to exactly the role/content pair; a missing
        # role defaults to "user" and missing content to an empty string.
        result_messages.extend(
            {"role": msg.get("role", "user"), "content": msg.get("content", "")}
            for msg in read(messages_field) or []
        )

        return {"messages": result_messages}

    return transform
552
+
553
+
554
def to_swift_query_response(
    query_field: str = "query",
    response_field: str = "response",
    system_field: Optional[str] = None,
    history_field: Optional[str] = None,
) -> Callable:
    """
    Build a converter to the ms-swift query-response format.

    Output format:
        {
            "query": "user question",
            "response": "model answer",
            "system": "system prompt",   # optional
            "history": [["q1", "r1"]]    # optional
        }

    Args:
        query_field: Name of the field holding the user question. If the
            value is a list it is treated as a messages list: the last
            user/assistant turn becomes query/response, earlier completed
            turns become history, and the other ``*_field`` arguments are
            not consulted.
        response_field: Name of the field holding the model answer.
        system_field: Name of the field holding the system prompt.
        history_field: Name of the field holding the dialogue history.

    Returns:
        A function that maps one record to a query-response dict.

    Examples:
        >>> dt.transform(to_swift_query_response())
        >>> # convert from the messages format
        >>> dt.transform(to_swift_query_response(query_field="messages"))
    """

    def transform(item) -> dict:
        def read(field):
            # Mapping-like records expose .get(); any other record is read by
            # attribute instead of calling a .get() it does not have.
            if hasattr(item, "get"):
                return item.get(field)
            return getattr(item, field, None)

        query = read(query_field)

        # query_field pointing at a messages list: fold it into turns.
        if isinstance(query, list):
            system_prompt = None
            history = []
            current_query = ""
            current_response = ""

            for msg in query:
                role = msg.get("role", "")
                content = msg.get("content", "")

                if role == "system":
                    system_prompt = content
                elif role == "user":
                    # A new user turn closes the previous one; only turns
                    # with both a query and a response go into history.
                    if current_query and current_response:
                        history.append([current_query, current_response])
                    current_query = content
                    current_response = ""
                elif role == "assistant":
                    current_response = content

            result = {
                "query": current_query,
                "response": current_response,
            }

            if system_prompt:
                result["system"] = system_prompt
            if history:
                result["history"] = history

            return result

        # Plain fields: copy them through, defaulting to empty strings.
        result = {
            "query": query or "",
            "response": read(response_field) or "",
        }

        if system_field:
            system = read(system_field)
            if system:
                result["system"] = system

        if history_field:
            history = read(history_field)
            if history:
                result["history"] = history

        return result

    return transform
645
+
646
+
647
def to_swift_vlm(
    messages_field: str = "messages",
    images_field: str = "images",
    videos_field: Optional[str] = None,
    system_field: Optional[str] = None,
) -> Callable:
    """
    Build a converter to the ms-swift VLM (vision-language model) format.

    Output format:
        {
            "messages": [
                {"role": "user", "content": "<image>describe the picture"},
                {"role": "assistant", "content": "This is ..."}
            ],
            "images": ["/path/to/image.jpg"]
        }

    Args:
        messages_field: Name of the input field holding the messages list.
        images_field: Name of the field holding the image path(s).
        videos_field: Name of the field holding the video path(s).
        system_field: Name of the field holding the system prompt. When the
            record has a value for it, that value replaces any ``system``
            message found in ``messages``; otherwise the messages' own
            system message is kept.

    Returns:
        A function that maps one record to an ms-swift VLM dict.

    Examples:
        >>> dt.transform(to_swift_vlm())
        >>> dt.transform(to_swift_vlm(images_field="image_paths"))
    """

    def transform(item) -> dict:
        def read(field):
            # Mapping-like records expose .get(); any other record is read by
            # attribute instead of calling a .get() it does not have.
            if hasattr(item, "get"):
                return item.get(field)
            return getattr(item, field, None)

        def as_list(value):
            # A single path is wrapped so the output is always a list.
            return value if isinstance(value, list) else [value]

        result_messages = []

        # The dedicated system field, when it has a value, goes first.
        system = read(system_field) if system_field else None
        if system:
            result_messages.append({"role": "system", "content": system})

        for msg in read(messages_field) or []:
            role = msg.get("role", "")
            content = msg.get("content", "")

            if role == "system":
                # Keep the in-message system prompt unless the dedicated
                # field already supplied one for this record.
                if not system:
                    result_messages.append({"role": "system", "content": content})
            elif role in ("user", "assistant"):
                result_messages.append({"role": role, "content": content})
            # Any other role is left out of the output.

        result = {"messages": result_messages}

        images = read(images_field)
        if images:
            result["images"] = as_list(images)

        if videos_field:
            videos = read(videos_field)
            if videos:
                result["videos"] = as_list(videos)

        return result

    return transform
716
+
717
+
262
718
  def messages_to_text(
263
719
  messages_field: str = "messages",
264
720
  output_field: str = "text",
dtflow/core.py CHANGED
@@ -6,29 +6,16 @@ DataTransformer 核心模块
6
6
  from typing import List, Dict, Any, Optional, Callable, Union, Tuple, Literal
7
7
  from copy import deepcopy
8
8
  from dataclasses import dataclass
9
- import json
10
9
 
11
- from .storage.io import save_data, load_data
10
+ import orjson
12
11
 
13
- # 尝试使用 orjson(更快的 JSON 序列化库)
14
- try:
15
- import orjson
16
- _HAS_ORJSON = True
17
- except ImportError:
18
- _HAS_ORJSON = False
12
+ from .storage.io import save_data, load_data
13
+ from .lineage import LineageTracker
19
14
 
20
15
 
21
16
  def _fast_json_dumps(obj: Any) -> str:
22
- """
23
- 快速 JSON 序列化,优先使用 orjson
24
-
25
- orjson 比标准 json 快约 10 倍,特别适合大量数据的序列化场景。
26
- """
27
- if _HAS_ORJSON:
28
- # orjson.dumps 返回 bytes,需要 decode
29
- return orjson.dumps(obj, option=orjson.OPT_SORT_KEYS).decode('utf-8')
30
- else:
31
- return json.dumps(obj, sort_keys=True, ensure_ascii=False)
17
+ """快速 JSON 序列化(使用 orjson,比标准 json 快约 10 倍)"""
18
+ return orjson.dumps(obj, option=orjson.OPT_SORT_KEYS).decode('utf-8')
32
19
 
33
20
 
34
21
  # ============ 错误处理 ============
@@ -102,8 +89,15 @@ class DataTransformer:
102
89
  - fields/stats: 数据信息
103
90
  """
104
91
 
105
- def __init__(self, data: Optional[List[Dict[str, Any]]] = None):
92
+ def __init__(
93
+ self,
94
+ data: Optional[List[Dict[str, Any]]] = None,
95
+ _source_path: Optional[str] = None,
96
+ _lineage_tracker: Optional[LineageTracker] = None,
97
+ ):
106
98
  self._data = data if data is not None else []
99
+ self._source_path = _source_path
100
+ self._lineage_tracker = _lineage_tracker
107
101
 
108
102
  @property
109
103
  def data(self) -> List[Dict[str, Any]]:
@@ -122,23 +116,38 @@ class DataTransformer:
122
116
  # ============ 加载/保存 ============
123
117
 
124
118
  @classmethod
125
- def load(cls, filepath: str) -> 'DataTransformer':
119
+ def load(cls, filepath: str, track_lineage: bool = False) -> 'DataTransformer':
126
120
  """
127
121
  从文件加载数据。
128
122
 
129
123
  支持格式: jsonl, json, csv, parquet(自动检测)
124
+
125
+ Args:
126
+ filepath: 文件路径
127
+ track_lineage: 是否追踪血缘(默认 False)
130
128
  """
131
129
  data = load_data(filepath)
132
- return cls(data)
130
+ tracker = LineageTracker(filepath) if track_lineage else None
131
+ return cls(data, _source_path=filepath, _lineage_tracker=tracker)
133
132
 
134
- def save(self, filepath: str) -> None:
133
+ def save(self, filepath: str, lineage: bool = False) -> None:
135
134
  """
136
135
  保存数据到文件。
137
136
 
138
137
  支持格式: jsonl, json, csv, parquet(根据扩展名)
138
+
139
+ Args:
140
+ filepath: 文件路径
141
+ lineage: 是否保存血缘元数据(默认 False)
139
142
  """
140
143
  save_data(self._data, filepath)
141
144
 
145
+ # 保存血缘记录
146
+ if lineage and self._lineage_tracker:
147
+ lineage_path = self._lineage_tracker.save(filepath, len(self._data))
148
+ import sys
149
+ print(f"📜 血缘记录已保存: {lineage_path}", file=sys.stderr)
150
+
142
151
  # ============ 核心转换 ============
143
152
 
144
153
  def to(
@@ -230,7 +239,16 @@ class DataTransformer:
230
239
  >>> # 原始模式(大数据集推荐)
231
240
  >>> dt.transform(lambda x: {"q": x["q"]}, raw=True).save("output.jsonl")
232
241
  """
233
- return DataTransformer(self.to(func, on_error=on_error, raw=raw))
242
+ input_count = len(self._data)
243
+ result = self.to(func, on_error=on_error, raw=raw)
244
+ output_count = len(result)
245
+
246
+ # 传递血缘追踪器并记录操作
247
+ tracker = self._lineage_tracker
248
+ if tracker:
249
+ tracker.record("transform", {"func": func}, input_count, output_count)
250
+
251
+ return DataTransformer(result, _lineage_tracker=tracker)
234
252
 
235
253
  # ============ 数据筛选 ============
236
254
 
@@ -281,7 +299,12 @@ class DataTransformer:
281
299
  if errors:
282
300
  _print_error_summary(errors, len(self._data))
283
301
 
284
- return DataTransformer(filtered)
302
+ # 传递血缘追踪器并记录操作
303
+ tracker = self._lineage_tracker
304
+ if tracker:
305
+ tracker.record("filter", {"func": func}, len(self._data), len(filtered))
306
+
307
+ return DataTransformer(filtered, _lineage_tracker=tracker)
285
308
 
286
309
  def sample(self, n: int, seed: Optional[int] = None) -> 'DataTransformer':
287
310
  """
@@ -295,16 +318,30 @@ class DataTransformer:
295
318
  if seed is not None:
296
319
  random.seed(seed)
297
320
 
321
+ input_count = len(self._data)
298
322
  data = self._data[:] if n >= len(self._data) else random.sample(self._data, n)
299
- return DataTransformer(data)
323
+
324
+ tracker = self._lineage_tracker
325
+ if tracker:
326
+ tracker.record("sample", {"n": n, "seed": seed}, input_count, len(data))
327
+
328
+ return DataTransformer(data, _lineage_tracker=tracker)
300
329
 
301
330
  def head(self, n: int = 10) -> 'DataTransformer':
302
331
  """取前 n 条"""
303
- return DataTransformer(self._data[:n])
332
+ data = self._data[:n]
333
+ tracker = self._lineage_tracker
334
+ if tracker:
335
+ tracker.record("head", {"n": n}, len(self._data), len(data))
336
+ return DataTransformer(data, _lineage_tracker=tracker)
304
337
 
305
338
  def tail(self, n: int = 10) -> 'DataTransformer':
306
339
  """取后 n 条"""
307
- return DataTransformer(self._data[-n:])
340
+ data = self._data[-n:]
341
+ tracker = self._lineage_tracker
342
+ if tracker:
343
+ tracker.record("tail", {"n": n}, len(self._data), len(data))
344
+ return DataTransformer(data, _lineage_tracker=tracker)
308
345
 
309
346
  def dedupe(
310
347
  self,
@@ -338,7 +375,11 @@ class DataTransformer:
338
375
  seen.add(k)
339
376
  result.append(item)
340
377
 
341
- return DataTransformer(result)
378
+ tracker = self._lineage_tracker
379
+ if tracker:
380
+ tracker.record("dedupe", {"key": key}, len(self._data), len(result))
381
+
382
+ return DataTransformer(result, _lineage_tracker=tracker)
342
383
 
343
384
  def _get_dedupe_key(
344
385
  self,
@@ -442,7 +483,17 @@ class DataTransformer:
442
483
 
443
484
  # 按原顺序保留数据
444
485
  result = [self._data[i] for i in sorted(keep_indices)]
445
- return DataTransformer(result)
486
+
487
+ tracker = self._lineage_tracker
488
+ if tracker:
489
+ tracker.record(
490
+ "dedupe_similar",
491
+ {"key": key, "threshold": threshold, "num_perm": num_perm, "ngram": ngram},
492
+ len(self._data),
493
+ len(result),
494
+ )
495
+
496
+ return DataTransformer(result, _lineage_tracker=tracker)
446
497
 
447
498
  def _get_text_for_similarity(
448
499
  self,
@@ -581,7 +632,12 @@ class DataTransformer:
581
632
  if seed is not None:
582
633
  random.seed(seed)
583
634
  random.shuffle(data)
584
- return DataTransformer(data)
635
+
636
+ tracker = self._lineage_tracker
637
+ if tracker:
638
+ tracker.record("shuffle", {"seed": seed}, len(self._data), len(data))
639
+
640
+ return DataTransformer(data, _lineage_tracker=tracker)
585
641
 
586
642
  def split(self, ratio: float = 0.8, seed: Optional[int] = None) -> tuple:
587
643
  """
@@ -596,7 +652,16 @@ class DataTransformer:
596
652
  """
597
653
  data = self.shuffle(seed).data
598
654
  split_idx = int(len(data) * ratio)
599
- return DataTransformer(data[:split_idx]), DataTransformer(data[split_idx:])
655
+
656
+ # 分割后血缘追踪器各自独立
657
+ tracker = self._lineage_tracker
658
+ if tracker:
659
+ tracker.record("split", {"ratio": ratio, "seed": seed}, len(self._data), len(data))
660
+
661
+ return (
662
+ DataTransformer(data[:split_idx], _lineage_tracker=tracker),
663
+ DataTransformer(data[split_idx:], _lineage_tracker=tracker),
664
+ )
600
665
 
601
666
  # ============ 并行处理 ============
602
667