dtflow 0.4.3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dtflow/__init__.py +34 -1
- dtflow/__main__.py +22 -0
- dtflow/cli/commands.py +5 -0
- dtflow/cli/common.py +13 -9
- dtflow/cli/stats.py +114 -36
- dtflow/cli/validate.py +152 -0
- dtflow/core.py +220 -10
- dtflow/framework.py +610 -0
- dtflow/lineage.py +17 -0
- dtflow/schema.py +508 -0
- dtflow/streaming.py +93 -35
- dtflow/tokenizers.py +84 -29
- dtflow/utils/field_path.py +6 -2
- {dtflow-0.4.3.dist-info → dtflow-0.5.2.dist-info}/METADATA +117 -2
- {dtflow-0.4.3.dist-info → dtflow-0.5.2.dist-info}/RECORD +17 -14
- {dtflow-0.4.3.dist-info → dtflow-0.5.2.dist-info}/WHEEL +0 -0
- {dtflow-0.4.3.dist-info → dtflow-0.5.2.dist-info}/entry_points.txt +0 -0
dtflow/core.py
CHANGED
```diff
@@ -386,6 +386,88 @@ class DataTransformer:
 
         return errors
 
+    def validate_schema(
+        self,
+        schema: "Schema",
+        on_error: Literal["skip", "raise", "filter"] = "skip",
+        max_errors: int = 100,
+    ) -> Union["DataTransformer", List[tuple]]:
+        """
+        Validate the data structure against a Schema.
+
+        Args:
+            schema: Schema object defining the structural validation rules
+            on_error: error-handling strategy
+                - "skip": print a warning and return the list of records that failed
+                - "raise": raise an exception at the first error
+                - "filter": drop records that fail validation and return a new DataTransformer
+            max_errors: maximum number of errors (takes effect when on_error="skip")
+
+        Returns:
+            - on_error="skip": list of failed records [(index, ValidationResult), ...]
+            - on_error="raise": no return value (success) or raises ValueError
+            - on_error="filter": a new, filtered DataTransformer
+
+        Examples:
+            >>> from dtflow import Schema, Field
+            >>> schema = Schema({
+            ...     "messages": Field(type="list", required=True, min_length=1),
+            ...     "messages[*].role": Field(type="str", choices=["user", "assistant"]),
+            ... })
+
+            >>> # Collect the records that failed validation
+            >>> errors = dt.validate_schema(schema)
+            >>> for idx, result in errors:
+            ...     print(f"Row {idx} failed validation: {result.errors}")
+
+            >>> # Filter out invalid records
+            >>> valid_dt = dt.validate_schema(schema, on_error="filter")
+
+            >>> # Stop immediately on error
+            >>> dt.validate_schema(schema, on_error="raise")
+        """
+        from .schema import Schema, ValidationResult
+
+        failed: List[tuple] = []
+        valid_data: List[dict] = []
+        error_count = 0
+
+        for i, item in enumerate(self._data):
+            result = schema.validate(item)
+            if result.valid:
+                valid_data.append(item)
+            else:
+                failed.append((i, result))
+                error_count += len(result.errors)
+
+                if on_error == "raise":
+                    error_msgs = [str(e) for e in result.errors[:3]]
+                    raise ValueError(
+                        f"Row {i} failed validation:\n  " + "\n  ".join(error_msgs)
+                    )
+
+                if on_error == "skip" and error_count >= max_errors:
+                    print(f"⚠️ Reached the maximum error count {max_errors}; stopping validation")
+                    break
+
+        if on_error == "skip":
+            if failed:
+                print(f"⚠️ Validation failed for {len(failed)} records ({error_count} errors in total)")
+            return failed
+
+        if on_error == "filter":
+            tracker = self._lineage_tracker
+            if tracker:
+                tracker.record(
+                    "validate_schema",
+                    {"schema": repr(schema), "on_error": on_error},
+                    len(self._data),
+                    len(valid_data),
+                )
+            return DataTransformer(valid_data, _lineage_tracker=tracker)
+
+        return failed
+
     def dedupe(
         self,
         key: Union[None, str, List[str], Callable[[Any], Any]] = None,
```
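The three `on_error` modes return different types, which is easy to trip over. A minimal sketch of the differences, assuming `DataTransformer` can be constructed directly from a list of dicts (as the `filter` branch above does) and using the `Schema`/`Field` API from the docstring:

```python
from dtflow import DataTransformer, Schema, Field

schema = Schema({
    "messages": Field(type="list", required=True, min_length=1),
    "messages[*].role": Field(type="str", choices=["user", "assistant"]),
})

dt = DataTransformer([
    {"messages": [{"role": "user", "content": "hi"}]},  # valid
    {"messages": []},                                    # violates min_length=1
])

# "skip" (default): returns [(index, ValidationResult), ...] for failed rows
for idx, result in dt.validate_schema(schema):
    print(idx, result.errors)

# "filter": returns a new DataTransformer containing only the valid rows
valid_dt = dt.validate_schema(schema, on_error="filter")

# "raise": raises ValueError on the first invalid row
# dt.validate_schema(schema, on_error="raise")
```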
```diff
@@ -711,19 +793,29 @@
             seed: random seed
 
         Returns:
-            (train, test) as two DataTransformers
+            (train, test) as two DataTransformers, each with its own independent lineage tracker
         """
         data = self.shuffle(seed).data
         split_idx = int(len(data) * ratio)
 
-        #
+        # After the split the lineage trackers are independent (deep copies avoid mutual interference)
         tracker = self._lineage_tracker
+        train_tracker = None
+        test_tracker = None
+
         if tracker:
             tracker.record("split", {"ratio": ratio, "seed": seed}, len(self._data), len(data))
+            # Create an independent tracker copy for each sub-dataset
+            train_tracker = tracker.copy()
+            train_tracker.record("split_part", {"part": "train", "ratio": ratio}, len(data), split_idx)
+            test_tracker = tracker.copy()
+            test_tracker.record(
+                "split_part", {"part": "test", "ratio": 1 - ratio}, len(data), len(data) - split_idx
+            )
 
         return (
-            DataTransformer(data[:split_idx], _lineage_tracker=tracker),
-            DataTransformer(data[split_idx:], _lineage_tracker=tracker),
+            DataTransformer(data[:split_idx], _lineage_tracker=train_tracker),
+            DataTransformer(data[split_idx:], _lineage_tracker=test_tracker),
         )
 
     # ============ Parallel processing ============
```
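This hunk only shows the body of the split method, so the exact signature is an inference; assuming it is `split(ratio, seed)` as the recorded lineage parameters suggest, the behavioral change looks like this:

```python
# Hypothetical call; parameter names inferred from tracker.record("split", {...}) above.
train, test = dt.split(ratio=0.9, seed=42)

# Before 0.5.2 both halves shared one tracker, so an operation on one half
# (e.g. train.dedupe()) also appeared in the other's lineage. Each half now
# gets its own copy, with a "split_part" entry marking which side it is.
```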
```diff
@@ -733,6 +825,7 @@
         func: Callable[[Dict], Any],
         workers: Optional[int] = None,
         chunksize: int = 1000,
+        timeout: Optional[float] = None,
     ) -> List[Any]:
         """
         Apply a transform function in parallel (multiprocessing).
```
```diff
@@ -743,24 +836,46 @@
             func: transform function; receives the raw dict and returns the transformed result
             workers: number of processes; defaults to the CPU core count
             chunksize: size of the data chunk each process handles
+            timeout: timeout in seconds; None means no timeout
 
         Returns:
             List of transformed results
 
+        Raises:
+            TypeError: if func cannot be pickled (e.g. a lambda)
+            RuntimeError: if a worker process fails or times out
+
         Examples:
             >>> def transform(item):
             ...     return {"id": item["id"], "text": item["text"].upper()}
             >>> results = dt.map_parallel(transform)
         """
-        from multiprocessing import Pool, cpu_count
+        from multiprocessing import Pool, TimeoutError, cpu_count
+        import pickle
 
         if not self._data:
             return []
 
+        # Check whether the function can be pickled
+        try:
+            pickle.dumps(func)
+        except (pickle.PicklingError, AttributeError, TypeError) as e:
+            func_name = getattr(func, "__name__", str(func))
+            raise TypeError(
+                f"Function '{func_name}' cannot be pickled and cannot be used for parallel processing. "
+                f"Please use a module-level function instead of a lambda or closure. Error: {e}"
+            ) from e
+
         workers = workers or cpu_count()
 
-        with Pool(workers) as pool:
-            results = pool.map(func, self._data, chunksize=chunksize)
+        try:
+            with Pool(workers) as pool:
+                async_result = pool.map_async(func, self._data, chunksize=chunksize)
+                results = async_result.get(timeout=timeout)
+        except TimeoutError:
+            raise RuntimeError(f"Parallel processing timed out ({timeout}s)")
+        except Exception as e:
+            raise RuntimeError(f"Parallel processing failed: {type(e).__name__}: {e}") from e
 
         return results
 
```
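Because `Pool` pickles the function to ship it to worker processes, only module-level functions work; the new up-front `pickle.dumps(func)` check turns the previously obscure pool crash into an immediate `TypeError`. Usage stays as in the docstring, with the new `timeout` knob:

```python
# Module-level function: picklable, so it passes the new pre-flight check.
def transform(item):
    return {"id": item["id"], "text": item["text"].upper()}

results = dt.map_parallel(transform, workers=4, timeout=60.0)

# A lambda now fails fast with TypeError before any worker starts:
# dt.map_parallel(lambda x: x)   # TypeError: ... cannot be pickled ...
```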
```diff
@@ -769,6 +884,7 @@
         func: Callable[[Dict], bool],
         workers: Optional[int] = None,
         chunksize: int = 1000,
+        timeout: Optional[float] = None,
     ) -> "DataTransformer":
         """
         Apply a filter function in parallel (multiprocessing).
```
```diff
@@ -779,28 +895,122 @@
             func: filter function; receives the raw dict, returns True to keep the item
             workers: number of processes; defaults to the CPU core count
             chunksize: size of the data chunk each process handles
+            timeout: timeout in seconds; None means no timeout
 
         Returns:
             A new, filtered DataTransformer
 
+        Raises:
+            TypeError: if func cannot be pickled (e.g. a lambda)
+            RuntimeError: if a worker process fails or times out
+
         Examples:
             >>> def is_valid(item):
             ...     return len(item["text"]) > 10
             >>> filtered = dt.filter_parallel(is_valid)
         """
-        from multiprocessing import Pool, cpu_count
+        from multiprocessing import Pool, TimeoutError, cpu_count
+        import pickle
 
         if not self._data:
             return DataTransformer([])
 
+        # Check whether the function can be pickled
+        try:
+            pickle.dumps(func)
+        except (pickle.PicklingError, AttributeError, TypeError) as e:
+            func_name = getattr(func, "__name__", str(func))
+            raise TypeError(
+                f"Function '{func_name}' cannot be pickled and cannot be used for parallel processing. "
+                f"Please use a module-level function instead of a lambda or closure. Error: {e}"
+            ) from e
+
         workers = workers or cpu_count()
 
-        with Pool(workers) as pool:
-            mask = pool.map(func, self._data, chunksize=chunksize)
+        try:
+            with Pool(workers) as pool:
+                async_result = pool.map_async(func, self._data, chunksize=chunksize)
+                mask = async_result.get(timeout=timeout)
+        except TimeoutError:
+            raise RuntimeError(f"Parallel processing timed out ({timeout}s)")
+        except Exception as e:
+            raise RuntimeError(f"Parallel processing failed: {type(e).__name__}: {e}") from e
 
         filtered = [item for item, keep in zip(self._data, mask) if keep]
         return DataTransformer(filtered)
 
+    # ============ Training framework integration ============
+
+    def check_compatibility(
+        self,
+        framework: Literal["llama-factory", "swift", "axolotl"],
+    ) -> "CompatibilityResult":
+        """
+        Check the data's compatibility with the target training framework.
+
+        Args:
+            framework: target framework name
+                - "llama-factory": LLaMA-Factory
+                - "swift": ms-swift (ModelScope)
+                - "axolotl": Axolotl
+
+        Returns:
+            CompatibilityResult object containing valid, errors, warnings, suggestions
+
+        Examples:
+            >>> result = dt.check_compatibility("llama-factory")
+            >>> if result.valid:
+            ...     print("Compatible!")
+            ... else:
+            ...     print(result.errors)
+        """
+        from .framework import check_compatibility
+
+        return check_compatibility(self._data, framework)
+
+    def export_for(
+        self,
+        framework: Literal["llama-factory", "swift", "axolotl"],
+        output_dir: str,
+        dataset_name: str = "custom_dataset",
+        **kwargs,
+    ) -> Dict[str, str]:
+        """
+        One-step export of data and config files for the target training framework.
+
+        Args:
+            framework: target framework name
+            output_dir: output directory
+            dataset_name: dataset name
+            **kwargs: framework-specific arguments (e.g. model_name)
+
+        Returns:
+            Dict of generated file paths {"data": "...", "config": "...", ...}
+
+        Examples:
+            >>> # Export for LLaMA-Factory
+            >>> dt.export_for("llama-factory", "./llama_ready")
+            # Generates:
+            # - ./llama_ready/custom_dataset.json
+            # - ./llama_ready/dataset_info.json
+            # - ./llama_ready/train_args.yaml
+
+            >>> # Export for ms-swift
+            >>> dt.export_for("swift", "./swift_ready", dataset_name="my_data")
+
+            >>> # Export for Axolotl
+            >>> dt.export_for("axolotl", "./axolotl_ready")
+        """
+        from .framework import export_for
+
+        return export_for(
+            self._data,
+            framework,
+            output_dir,
+            dataset_name=dataset_name,
+            **kwargs,
+        )
+
 
 def _sanitize_key(name: str) -> str:
     """Normalize a field name into a valid Python identifier"""
```