dtflow 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dtflow/__init__.py +36 -2
- dtflow/__main__.py +292 -239
- dtflow/cli/__init__.py +8 -2
- dtflow/cli/commands.py +1030 -92
- dtflow/converters.py +456 -0
- dtflow/core.py +96 -31
- dtflow/lineage.py +407 -0
- dtflow/mcp/cli.py +14 -14
- dtflow/pipeline.py +450 -0
- dtflow/storage/io.py +376 -370
- dtflow/streaming.py +661 -0
- dtflow/tokenizers.py +387 -31
- dtflow/utils/display.py +5 -4
- {dtflow-0.2.0.dist-info → dtflow-0.3.1.dist-info}/METADATA +234 -15
- dtflow-0.3.1.dist-info/RECORD +24 -0
- dtflow-0.2.0.dist-info/RECORD +0 -21
- {dtflow-0.2.0.dist-info → dtflow-0.3.1.dist-info}/WHEEL +0 -0
- {dtflow-0.2.0.dist-info → dtflow-0.3.1.dist-info}/entry_points.txt +0 -0
dtflow/converters.py
CHANGED
|
@@ -259,6 +259,462 @@ def to_axolotl(
|
|
|
259
259
|
return transform
|
|
260
260
|
|
|
261
261
|
|
|
262
|
+
def to_llama_factory_sharegpt(
    messages_field: str = "messages",
    system_field: Optional[str] = None,
    tools_field: Optional[str] = None,
) -> Callable:
    """
    Convert records to the LLaMA-Factory ShareGPT format (multi-turn dialogue).

    Output format:
        {
            "conversations": [
                {"from": "human", "value": "..."},
                {"from": "gpt", "value": "..."}
            ],
            "system": "...",  # optional
            "tools": "..."    # optional
        }

    Args:
        messages_field: Name of the input field holding the list of
            ``{"role": ..., "content": ...}`` messages.
        system_field: Name of a field holding the system prompt. When set and
            non-empty on a record it takes priority; otherwise the system
            message found in ``messages`` (if any) is used.
        tools_field: Name of a field holding the tool description; copied
            through unchanged when non-empty.

    Returns:
        A transform function mapping one record to one output dict.

    Examples:
        >>> dt.transform(to_llama_factory_sharegpt())
        >>> dt.transform(to_llama_factory_sharegpt(system_field="system_prompt"))
    """
    role_map = {
        "user": "human",
        "assistant": "gpt",
        "system": "system",
        "tool": "observation",
        "function_call": "function_call",
    }

    def transform(item) -> dict:
        messages = item.get(messages_field) or []

        conversations = []
        system_prompt = None

        for msg in messages:
            role = msg.get("role", "")
            content = msg.get("content", "")

            # System messages are lifted out of the conversation into the
            # top-level "system" key; if several are present the last wins.
            if role == "system":
                system_prompt = content
                continue

            # Roles missing from role_map are passed through unchanged.
            conversations.append({"from": role_map.get(role, role), "value": content})

        result = {"conversations": conversations}

        # The explicit system field wins; fall back to the system message
        # extracted from `messages` when the field is unset OR empty.
        system = item.get(system_field) if system_field else None
        system = system or system_prompt
        if system:
            result["system"] = system

        if tools_field:
            tools = item.get(tools_field)
            if tools:
                result["tools"] = tools

        return result

    return transform
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def to_llama_factory_vlm(
    messages_field: str = "messages",
    images_field: str = "images",
    videos_field: Optional[str] = None,
    system_field: Optional[str] = None,
) -> Callable:
    """
    Convert records to the LLaMA-Factory VLM (vision-language model) format.

    Output format (Alpaca style):
        {
            "instruction": "...",
            "input": "",
            "output": "...",
            "images": ["path1.jpg", "path2.jpg"],  # list of image paths
            "videos": ["path.mp4"],                 # optional, list of video paths
            "system": "..."                         # optional
        }

    Only a single turn is emitted: the last user message becomes
    ``instruction`` and the last assistant message becomes ``output``;
    earlier turns are discarded. Use ``to_llama_factory_vlm_sharegpt`` for
    multi-turn dialogues.

    Args:
        messages_field: Name of the input field holding the messages list.
        images_field: Name of the image-path field; a scalar value is wrapped
            in a one-element list.
        videos_field: Name of the video-path field (optional); a scalar value
            is wrapped in a one-element list.
        system_field: Name of a system-prompt field. When set and non-empty it
            takes priority; otherwise the system message from ``messages`` is
            used.

    Returns:
        A transform function mapping one record to one output dict.

    Examples:
        >>> dt.transform(to_llama_factory_vlm())
        >>> dt.transform(to_llama_factory_vlm(images_field="image_paths"))
    """

    def transform(item) -> dict:
        messages = item.get(messages_field) or []

        instruction = ""
        output = ""
        system_prompt = None

        # Later messages overwrite earlier ones, so only the last turn survives.
        for msg in messages:
            role = msg.get("role", "")
            content = msg.get("content", "")

            if role == "system":
                system_prompt = content
            elif role == "user":
                instruction = content
            elif role == "assistant":
                output = content

        result = {
            "instruction": instruction,
            "input": "",
            "output": output,
        }

        images = item.get(images_field)
        if images:
            result["images"] = images if isinstance(images, list) else [images]

        if videos_field:
            videos = item.get(videos_field)
            if videos:
                result["videos"] = videos if isinstance(videos, list) else [videos]

        # The explicit system field wins; fall back to the system message
        # from `messages` when the field is unset OR empty.
        system = item.get(system_field) if system_field else None
        system = system or system_prompt
        if system:
            result["system"] = system

        return result

    return transform
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
def to_llama_factory_vlm_sharegpt(
    messages_field: str = "messages",
    images_field: str = "images",
    videos_field: Optional[str] = None,
    system_field: Optional[str] = None,
) -> Callable:
    """
    Convert records to the LLaMA-Factory VLM ShareGPT format (multi-turn,
    multimodal dialogue).

    Output format:
        {
            "conversations": [
                {"from": "human", "value": "<image>Describe this image"},
                {"from": "gpt", "value": "This is a ..."}
            ],
            "images": ["path1.jpg"],
            "system": "..."
        }

    Args:
        messages_field: Name of the input field holding the messages list.
        images_field: Name of the image-path field; a scalar value is wrapped
            in a one-element list.
        videos_field: Name of the video-path field (optional); a scalar value
            is wrapped in a one-element list.
        system_field: Name of a system-prompt field. When set and non-empty it
            takes priority; otherwise the system message from ``messages`` is
            used.

    Returns:
        A transform function mapping one record to one output dict.

    Examples:
        >>> dt.transform(to_llama_factory_vlm_sharegpt())
    """
    # Same role mapping as the text-only ShareGPT converter, so tool turns
    # become "observation" instead of an unrecognised "tool" role.
    role_map = {
        "user": "human",
        "assistant": "gpt",
        "system": "system",
        "tool": "observation",
        "function_call": "function_call",
    }

    def transform(item) -> dict:
        messages = item.get(messages_field) or []

        conversations = []
        system_prompt = None

        for msg in messages:
            role = msg.get("role", "")
            content = msg.get("content", "")

            # System messages go to the top-level "system" key (last one wins).
            if role == "system":
                system_prompt = content
                continue

            conversations.append({"from": role_map.get(role, role), "value": content})

        result = {"conversations": conversations}

        images = item.get(images_field)
        if images:
            result["images"] = images if isinstance(images, list) else [images]

        if videos_field:
            videos = item.get(videos_field)
            if videos:
                result["videos"] = videos if isinstance(videos, list) else [videos]

        # The explicit system field wins; fall back to the system message
        # from `messages` when the field is unset OR empty.
        system = item.get(system_field) if system_field else None
        system = system or system_prompt
        if system:
            result["system"] = system

        return result

    return transform
|
|
497
|
+
|
|
498
|
+
|
|
499
|
+
# ============== ms-swift 格式转换器 ==============
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
def to_swift_messages(
    messages_field: str = "messages",
    system_field: Optional[str] = None,
) -> Callable:
    """
    Convert records to the ms-swift messages format (the standard format).

    Output format:
        {
            "messages": [
                {"role": "system", "content": "..."},
                {"role": "user", "content": "..."},
                {"role": "assistant", "content": "..."}
            ]
        }

    Each message is normalised to only ``role`` (default ``"user"``) and
    ``content`` (default ``""``); other keys are dropped.

    Args:
        messages_field: Name of the input field holding the messages list.
        system_field: Name of a system-prompt field. When set and non-empty on
            a record, it is emitted as the leading system message and replaces
            any system message already present in ``messages``, so the output
            never contains two system messages. When unset or empty, system
            messages from ``messages`` are kept as-is.

    Returns:
        A transform function mapping one record to one output dict.

    Examples:
        >>> dt.transform(to_swift_messages())
    """

    def transform(item) -> dict:
        messages = item.get(messages_field) or []
        external_system = item.get(system_field) if system_field else None

        # Build a fresh list so the source record is never mutated.
        result_messages = []
        if external_system:
            result_messages.append({"role": "system", "content": external_system})

        for msg in messages:
            role = msg.get("role", "user")
            # Avoid duplicate system messages when an explicit one was given.
            if external_system and role == "system":
                continue
            result_messages.append({"role": role, "content": msg.get("content", "")})

        return {"messages": result_messages}

    return transform
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
def to_swift_query_response(
    query_field: str = "query",
    response_field: str = "response",
    system_field: Optional[str] = None,
    history_field: Optional[str] = None,
) -> Callable:
    """
    Convert records to the ms-swift query-response format.

    Output format:
        {
            "query": "user question",
            "response": "model answer",
            "system": "system prompt",  # optional
            "history": [["q1", "r1"]]   # optional
        }

    If the value of ``query_field`` is a list, it is treated as a messages
    list: the last user/assistant turn becomes ``query``/``response`` and
    earlier answered turns become ``history``. In that mode ``response_field``
    and ``history_field`` are not read.

    Args:
        query_field: User-question field (or a messages-list field).
        response_field: Model-answer field.
        system_field: System-prompt field. When set and non-empty it takes
            priority; in messages mode it otherwise falls back to the system
            message found in the list.
        history_field: History field (direct-field mode only).

    Returns:
        A transform function mapping one record to one output dict.

    Examples:
        >>> dt.transform(to_swift_query_response())
        >>> # convert from the messages format
        >>> dt.transform(to_swift_query_response(query_field="messages"))
    """

    def transform(item) -> dict:
        query = item.get(query_field)
        response = item.get(response_field)
        explicit_system = item.get(system_field) if system_field else None

        if isinstance(query, list):
            system_prompt = None
            history = []
            current_query = ""
            current_response = ""

            for msg in query:
                role = msg.get("role", "")
                content = msg.get("content", "")

                if role == "system":
                    system_prompt = content
                elif role == "user":
                    # Archive the previous turn only once it has been answered;
                    # an unanswered user message is overwritten by the next one.
                    if current_query and current_response:
                        history.append([current_query, current_response])
                    current_query = content
                    current_response = ""
                elif role == "assistant":
                    current_response = content

            result = {
                "query": current_query,
                "response": current_response,
            }

            system = explicit_system or system_prompt
            if system:
                result["system"] = system
            if history:
                result["history"] = history

            return result

        # Direct-field mode.
        result = {
            "query": query or "",
            "response": response or "",
        }

        if explicit_system:
            result["system"] = explicit_system

        if history_field:
            history = item.get(history_field)
            if history:
                result["history"] = history

        return result

    return transform
|
|
645
|
+
|
|
646
|
+
|
|
647
|
+
def to_swift_vlm(
    messages_field: str = "messages",
    images_field: str = "images",
    videos_field: Optional[str] = None,
    system_field: Optional[str] = None,
) -> Callable:
    """
    Convert records to the ms-swift VLM (vision-language model) format.

    Output format:
        {
            "messages": [
                {"role": "user", "content": "<image>Describe the image"},
                {"role": "assistant", "content": "This is ..."}
            ],
            "images": ["/path/to/image.jpg"]
        }

    Only system, user and assistant messages are kept; other roles are
    dropped.

    Args:
        messages_field: Name of the input field holding the messages list.
        images_field: Name of the image-path field; a scalar value is wrapped
            in a one-element list.
        videos_field: Name of the video-path field (optional); a scalar value
            is wrapped in a one-element list.
        system_field: Name of a system-prompt field. When set and non-empty it
            becomes the leading system message and replaces any system
            message in ``messages``; when unset or empty, system messages from
            ``messages`` are kept.

    Returns:
        A transform function mapping one record to one output dict.

    Examples:
        >>> dt.transform(to_swift_vlm())
        >>> dt.transform(to_swift_vlm(images_field="image_paths"))
    """

    def transform(item) -> dict:
        messages = item.get(messages_field) or []
        external_system = item.get(system_field) if system_field else None

        result_messages = []
        if external_system:
            result_messages.append({"role": "system", "content": external_system})

        for msg in messages:
            role = msg.get("role", "")
            content = msg.get("content", "")

            if role == "system":
                # Keep the in-message system prompt only when no explicit one
                # was supplied (previously it was dropped whenever
                # system_field was set, even if the field was empty).
                if not external_system:
                    result_messages.append({"role": "system", "content": content})
            elif role in ("user", "assistant"):
                result_messages.append({"role": role, "content": content})

        result = {"messages": result_messages}

        images = item.get(images_field)
        if images:
            result["images"] = images if isinstance(images, list) else [images]

        if videos_field:
            videos = item.get(videos_field)
            if videos:
                result["videos"] = videos if isinstance(videos, list) else [videos]

        return result

    return transform
|
|
716
|
+
|
|
717
|
+
|
|
262
718
|
def messages_to_text(
|
|
263
719
|
messages_field: str = "messages",
|
|
264
720
|
output_field: str = "text",
|
dtflow/core.py
CHANGED
|
@@ -6,29 +6,16 @@ DataTransformer 核心模块
|
|
|
6
6
|
from typing import List, Dict, Any, Optional, Callable, Union, Tuple, Literal
|
|
7
7
|
from copy import deepcopy
|
|
8
8
|
from dataclasses import dataclass
|
|
9
|
-
import json
|
|
10
9
|
|
|
11
|
-
|
|
10
|
+
import orjson
|
|
12
11
|
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
import orjson
|
|
16
|
-
_HAS_ORJSON = True
|
|
17
|
-
except ImportError:
|
|
18
|
-
_HAS_ORJSON = False
|
|
12
|
+
from .storage.io import save_data, load_data
|
|
13
|
+
from .lineage import LineageTracker
|
|
19
14
|
|
|
20
15
|
|
|
21
16
|
def _fast_json_dumps(obj: Any) -> str:
    """Serialize ``obj`` to a JSON string with orjson (roughly 10x faster than stdlib json).

    Keys are sorted (``OPT_SORT_KEYS``), so dicts with equal content always
    produce identical strings regardless of insertion order.

    ``OPT_NON_STR_KEYS`` lets dicts with int/float/bool/None keys serialize
    (the keys are emitted as strings), as the standard ``json`` module does.
    Without it, orjson raises ``TypeError`` on such keys.
    """
    options = orjson.OPT_SORT_KEYS | orjson.OPT_NON_STR_KEYS
    # orjson.dumps returns bytes; decode to str for callers.
    return orjson.dumps(obj, option=options).decode('utf-8')
|
|
32
19
|
|
|
33
20
|
|
|
34
21
|
# ============ 错误处理 ============
|
|
@@ -102,8 +89,15 @@ class DataTransformer:
|
|
|
102
89
|
- fields/stats: 数据信息
|
|
103
90
|
"""
|
|
104
91
|
|
|
105
|
-
def __init__(
    self,
    data: Optional[List[Dict[str, Any]]] = None,
    _source_path: Optional[str] = None,
    _lineage_tracker: Optional[LineageTracker] = None,
):
    """Wrap a list of record dicts.

    Args:
        data: Records to wrap; ``None`` becomes an empty list. The list is
            stored as given (not copied).
        _source_path: Internal: path the data was loaded from, if any.
        _lineage_tracker: Internal: lineage tracker handed down from the
            parent transformer, or ``None`` when lineage is not tracked.
    """
    if data is None:
        data = []
    self._data = data
    self._source_path = _source_path
    self._lineage_tracker = _lineage_tracker
|
|
107
101
|
|
|
108
102
|
@property
|
|
109
103
|
def data(self) -> List[Dict[str, Any]]:
|
|
@@ -122,23 +116,38 @@ class DataTransformer:
|
|
|
122
116
|
# ============ 加载/保存 ============
|
|
123
117
|
|
|
124
118
|
@classmethod
def load(cls, filepath: str, track_lineage: bool = False) -> 'DataTransformer':
    """
    Load data from a file.

    Supported formats: jsonl, json, csv, parquet (auto-detected).

    Args:
        filepath: Path of the file to read.
        track_lineage: Whether to start a LineageTracker rooted at
            ``filepath`` (default False).

    Returns:
        A new instance wrapping the loaded records, remembering ``filepath``
        as its source path.
    """
    records = load_data(filepath)
    tracker = None
    if track_lineage:
        tracker = LineageTracker(filepath)
    return cls(records, _source_path=filepath, _lineage_tracker=tracker)
|
|
133
132
|
|
|
134
|
-
def save(self, filepath: str, lineage: bool = False) -> None:
    """
    Save data to a file.

    Supported formats: jsonl, json, csv, parquet (chosen by file extension).

    Args:
        filepath: Destination path.
        lineage: Whether to also save the lineage record (default False).
            A lineage record is only available when a tracker is attached,
            e.g. after ``load(..., track_lineage=True)``. If ``lineage`` is
            requested without a tracker, a warning is printed to stderr and
            only the data is written.
    """
    import sys

    save_data(self._data, filepath)

    if not lineage:
        return

    tracker = self._lineage_tracker
    if not tracker:
        # This used to be a silent no-op, which made it look as if lineage
        # had been saved. Tell the user explicitly.
        print("⚠️ 未启用血缘追踪,未保存血缘记录(加载时请使用 track_lineage=True)", file=sys.stderr)
        return

    lineage_path = tracker.save(filepath, len(self._data))
    print(f"📜 血缘记录已保存: {lineage_path}", file=sys.stderr)
|
|
150
|
+
|
|
142
151
|
# ============ 核心转换 ============
|
|
143
152
|
|
|
144
153
|
def to(
|
|
@@ -230,7 +239,16 @@ class DataTransformer:
|
|
|
230
239
|
>>> # 原始模式(大数据集推荐)
|
|
231
240
|
>>> dt.transform(lambda x: {"q": x["q"]}, raw=True).save("output.jsonl")
|
|
232
241
|
"""
|
|
233
|
-
|
|
242
|
+
input_count = len(self._data)
|
|
243
|
+
result = self.to(func, on_error=on_error, raw=raw)
|
|
244
|
+
output_count = len(result)
|
|
245
|
+
|
|
246
|
+
# 传递血缘追踪器并记录操作
|
|
247
|
+
tracker = self._lineage_tracker
|
|
248
|
+
if tracker:
|
|
249
|
+
tracker.record("transform", {"func": func}, input_count, output_count)
|
|
250
|
+
|
|
251
|
+
return DataTransformer(result, _lineage_tracker=tracker)
|
|
234
252
|
|
|
235
253
|
# ============ 数据筛选 ============
|
|
236
254
|
|
|
@@ -281,7 +299,12 @@ class DataTransformer:
|
|
|
281
299
|
if errors:
|
|
282
300
|
_print_error_summary(errors, len(self._data))
|
|
283
301
|
|
|
284
|
-
|
|
302
|
+
# 传递血缘追踪器并记录操作
|
|
303
|
+
tracker = self._lineage_tracker
|
|
304
|
+
if tracker:
|
|
305
|
+
tracker.record("filter", {"func": func}, len(self._data), len(filtered))
|
|
306
|
+
|
|
307
|
+
return DataTransformer(filtered, _lineage_tracker=tracker)
|
|
285
308
|
|
|
286
309
|
def sample(self, n: int, seed: Optional[int] = None) -> 'DataTransformer':
|
|
287
310
|
"""
|
|
@@ -295,16 +318,30 @@ class DataTransformer:
|
|
|
295
318
|
if seed is not None:
|
|
296
319
|
random.seed(seed)
|
|
297
320
|
|
|
321
|
+
input_count = len(self._data)
|
|
298
322
|
data = self._data[:] if n >= len(self._data) else random.sample(self._data, n)
|
|
299
|
-
|
|
323
|
+
|
|
324
|
+
tracker = self._lineage_tracker
|
|
325
|
+
if tracker:
|
|
326
|
+
tracker.record("sample", {"n": n, "seed": seed}, input_count, len(data))
|
|
327
|
+
|
|
328
|
+
return DataTransformer(data, _lineage_tracker=tracker)
|
|
300
329
|
|
|
301
330
|
def head(self, n: int = 10) -> 'DataTransformer':
    """Take the first n records.

    Args:
        n: Number of leading records to keep (plain slice semantics).

    Returns:
        A new DataTransformer sharing this instance's lineage tracker; when a
        tracker is attached, the step is recorded as "head".
    """
    total = len(self._data)
    first = self._data[:n]
    tracker = self._lineage_tracker
    if tracker:
        tracker.record("head", {"n": n}, total, len(first))
    return DataTransformer(first, _lineage_tracker=tracker)
|
|
304
337
|
|
|
305
338
|
def tail(self, n: int = 10) -> 'DataTransformer':
    """Take the last n records.

    Args:
        n: Number of trailing records to keep. ``0`` yields an empty result.
            A negative ``n`` keeps everything except the first ``|n|``
            records (same as ``self._data[-n:]``).

    Returns:
        A new DataTransformer sharing this instance's lineage tracker; when a
        tracker is attached, the step is recorded as "tail".
    """
    # `self._data[-0:]` is `self._data[0:]`, i.e. the WHOLE list, so n == 0
    # must be special-cased to return nothing.
    data = self._data[-n:] if n != 0 else []
    tracker = self._lineage_tracker
    if tracker:
        tracker.record("tail", {"n": n}, len(self._data), len(data))
    return DataTransformer(data, _lineage_tracker=tracker)
|
|
308
345
|
|
|
309
346
|
def dedupe(
|
|
310
347
|
self,
|
|
@@ -338,7 +375,11 @@ class DataTransformer:
|
|
|
338
375
|
seen.add(k)
|
|
339
376
|
result.append(item)
|
|
340
377
|
|
|
341
|
-
|
|
378
|
+
tracker = self._lineage_tracker
|
|
379
|
+
if tracker:
|
|
380
|
+
tracker.record("dedupe", {"key": key}, len(self._data), len(result))
|
|
381
|
+
|
|
382
|
+
return DataTransformer(result, _lineage_tracker=tracker)
|
|
342
383
|
|
|
343
384
|
def _get_dedupe_key(
|
|
344
385
|
self,
|
|
@@ -442,7 +483,17 @@ class DataTransformer:
|
|
|
442
483
|
|
|
443
484
|
# 按原顺序保留数据
|
|
444
485
|
result = [self._data[i] for i in sorted(keep_indices)]
|
|
445
|
-
|
|
486
|
+
|
|
487
|
+
tracker = self._lineage_tracker
|
|
488
|
+
if tracker:
|
|
489
|
+
tracker.record(
|
|
490
|
+
"dedupe_similar",
|
|
491
|
+
{"key": key, "threshold": threshold, "num_perm": num_perm, "ngram": ngram},
|
|
492
|
+
len(self._data),
|
|
493
|
+
len(result),
|
|
494
|
+
)
|
|
495
|
+
|
|
496
|
+
return DataTransformer(result, _lineage_tracker=tracker)
|
|
446
497
|
|
|
447
498
|
def _get_text_for_similarity(
|
|
448
499
|
self,
|
|
@@ -581,7 +632,12 @@ class DataTransformer:
|
|
|
581
632
|
if seed is not None:
|
|
582
633
|
random.seed(seed)
|
|
583
634
|
random.shuffle(data)
|
|
584
|
-
|
|
635
|
+
|
|
636
|
+
tracker = self._lineage_tracker
|
|
637
|
+
if tracker:
|
|
638
|
+
tracker.record("shuffle", {"seed": seed}, len(self._data), len(data))
|
|
639
|
+
|
|
640
|
+
return DataTransformer(data, _lineage_tracker=tracker)
|
|
585
641
|
|
|
586
642
|
def split(self, ratio: float = 0.8, seed: Optional[int] = None) -> tuple:
|
|
587
643
|
"""
|
|
@@ -596,7 +652,16 @@ class DataTransformer:
|
|
|
596
652
|
"""
|
|
597
653
|
data = self.shuffle(seed).data
|
|
598
654
|
split_idx = int(len(data) * ratio)
|
|
599
|
-
|
|
655
|
+
|
|
656
|
+
# 分割后血缘追踪器各自独立
|
|
657
|
+
tracker = self._lineage_tracker
|
|
658
|
+
if tracker:
|
|
659
|
+
tracker.record("split", {"ratio": ratio, "seed": seed}, len(self._data), len(data))
|
|
660
|
+
|
|
661
|
+
return (
|
|
662
|
+
DataTransformer(data[:split_idx], _lineage_tracker=tracker),
|
|
663
|
+
DataTransformer(data[split_idx:], _lineage_tracker=tracker),
|
|
664
|
+
)
|
|
600
665
|
|
|
601
666
|
# ============ 并行处理 ============
|
|
602
667
|
|