dtflow 0.4.3 → 0.5.2 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dtflow/core.py CHANGED
@@ -386,6 +386,88 @@ class DataTransformer:
 
         return errors
 
+    def validate_schema(
+        self,
+        schema: "Schema",
+        on_error: Literal["skip", "raise", "filter"] = "skip",
+        max_errors: int = 100,
+    ) -> Union["DataTransformer", List[tuple]]:
+        """
+        Validate the data structure against a Schema.
+
+        Args:
+            schema: Schema object that defines the validation rules for the data structure
+            on_error: how to handle validation errors
+                - "skip": print a warning and return the list of records that failed validation
+                - "raise": raise an exception on the first error
+                - "filter": drop the records that failed validation and return a new DataTransformer
+            max_errors: maximum number of errors (only applies when on_error="skip")
+
+        Returns:
+            - on_error="skip": the list of failed records as [(index, ValidationResult), ...]
+            - on_error="raise": no return value (on success), or raises ValueError
+            - on_error="filter": a new DataTransformer with the failed records removed
+
+        Examples:
+            >>> from dtflow import Schema, Field
+            >>> schema = Schema({
+            ...     "messages": Field(type="list", required=True, min_length=1),
+            ...     "messages[*].role": Field(type="str", choices=["user", "assistant"]),
+            ... })
+
+            >>> # Collect the records that failed validation
+            >>> errors = dt.validate_schema(schema)
+            >>> for idx, result in errors:
+            ...     print(f"Row {idx} failed validation: {result.errors}")
+
+            >>> # Drop the invalid records
+            >>> valid_dt = dt.validate_schema(schema, on_error="filter")
+
+            >>> # Stop immediately on the first error
+            >>> dt.validate_schema(schema, on_error="raise")
+        """
+        from .schema import Schema, ValidationResult
+
+        failed: List[tuple] = []
+        valid_data: List[dict] = []
+        error_count = 0
+
+        for i, item in enumerate(self._data):
+            result = schema.validate(item)
+            if result.valid:
+                valid_data.append(item)
+            else:
+                failed.append((i, result))
+                error_count += len(result.errors)
+
+                if on_error == "raise":
+                    error_msgs = [str(e) for e in result.errors[:3]]
+                    raise ValueError(
+                        f"Row {i} failed validation:\n " + "\n ".join(error_msgs)
+                    )
+
+                if on_error == "skip" and error_count >= max_errors:
+                    print(f"⚠️ Reached the maximum error count {max_errors}; stopping validation")
+                    break
+
+        if on_error == "skip":
+            if failed:
+                print(f"⚠️ {len(failed)} records failed validation ({error_count} errors in total)")
+            return failed
+
+        if on_error == "filter":
+            tracker = self._lineage_tracker
+            if tracker:
+                tracker.record(
+                    "validate_schema",
+                    {"schema": repr(schema), "on_error": on_error},
+                    len(self._data),
+                    len(valid_data),
+                )
+            return DataTransformer(valid_data, _lineage_tracker=tracker)
+
+        return failed
+
     def dedupe(
         self,
         key: Union[None, str, List[str], Callable[[Any], Any]] = None,
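As a usage sketch of the new validate_schema method (assuming dt is an existing DataTransformer holding chat-style records; the Schema/Field constructors follow the docstring above):

from dtflow import Schema, Field

# Rules taken from the docstring example above
schema = Schema({
    "messages": Field(type="list", required=True, min_length=1),
    "messages[*].role": Field(type="str", choices=["user", "assistant"]),
})

# Default mode ("skip"): collect failures for inspection
failures = dt.validate_schema(schema, max_errors=50)
for idx, result in failures:
    print(idx, result.errors)

# "filter" mode: drop invalid records and keep chaining; the lineage
# tracker (if one is attached) records the before/after row counts
clean_dt = dt.validate_schema(schema, on_error="filter")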
@@ -711,19 +793,29 @@ class DataTransformer:
             seed: random seed
 
         Returns:
-            (train, test): two DataTransformers
+            (train, test): two DataTransformers, each with its own independent lineage tracker
         """
         data = self.shuffle(seed).data
         split_idx = int(len(data) * ratio)
 
-        # After splitting, each lineage tracker is independent
+        # After splitting, each lineage tracker is independent (deep copies avoid cross-contamination)
         tracker = self._lineage_tracker
+        train_tracker = None
+        test_tracker = None
+
         if tracker:
             tracker.record("split", {"ratio": ratio, "seed": seed}, len(self._data), len(data))
+            # Create an independent tracker copy for each sub-dataset
+            train_tracker = tracker.copy()
+            train_tracker.record("split_part", {"part": "train", "ratio": ratio}, len(data), split_idx)
+            test_tracker = tracker.copy()
+            test_tracker.record(
+                "split_part", {"part": "test", "ratio": 1 - ratio}, len(data), len(data) - split_idx
+            )
 
         return (
-            DataTransformer(data[:split_idx], _lineage_tracker=tracker),
-            DataTransformer(data[split_idx:], _lineage_tracker=tracker),
+            DataTransformer(data[:split_idx], _lineage_tracker=train_tracker),
+            DataTransformer(data[split_idx:], _lineage_tracker=test_tracker),
         )
 
     # ============ Parallel processing ============
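A short sketch of the new split behavior (assuming the method is invoked as dt.split(ratio=..., seed=...), matching the parameters recorded above):

# Hypothetical 90/10 split; each half now receives its own tracker copy
train, test = dt.split(ratio=0.9, seed=42)

# Because the trackers are independent copies, operations recorded on `train`
# after this point no longer show up in `test`'s lineage, and each half
# carries a "split_part" entry with its own before/after row counts.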
@@ -733,6 +825,7 @@ class DataTransformer:
         func: Callable[[Dict], Any],
         workers: Optional[int] = None,
         chunksize: int = 1000,
+        timeout: Optional[float] = None,
     ) -> List[Any]:
         """
         Run a transform function in parallel (using multiple processes).
@@ -743,24 +836,46 @@ class DataTransformer:
             func: transform function; receives the raw dict and returns the transformed result
             workers: number of worker processes; defaults to the number of CPU cores
             chunksize: size of the data chunk handled by each process
+            timeout: timeout in seconds; None means no timeout
 
         Returns:
             List of transformed results
 
+        Raises:
+            TypeError: if func cannot be pickled (e.g. a lambda)
+            RuntimeError: if a worker process fails or the call times out
+
         Examples:
             >>> def transform(item):
             ...     return {"id": item["id"], "text": item["text"].upper()}
             >>> results = dt.map_parallel(transform)
         """
-        from multiprocessing import Pool, cpu_count
+        from multiprocessing import Pool, TimeoutError, cpu_count
+        import pickle
 
         if not self._data:
             return []
 
+        # Check that the function can be pickled
+        try:
+            pickle.dumps(func)
+        except (pickle.PicklingError, AttributeError, TypeError) as e:
+            func_name = getattr(func, "__name__", str(func))
+            raise TypeError(
+                f"Function '{func_name}' cannot be pickled and cannot be used for parallel processing. "
+                f"Use a module-level function instead of a lambda or closure. Error: {e}"
+            ) from e
+
         workers = workers or cpu_count()
 
-        with Pool(workers) as pool:
-            results = pool.map(func, self._data, chunksize=chunksize)
+        try:
+            with Pool(workers) as pool:
+                async_result = pool.map_async(func, self._data, chunksize=chunksize)
+                results = async_result.get(timeout=timeout)
+        except TimeoutError:
+            raise RuntimeError(f"Parallel processing timed out ({timeout}s)")
+        except Exception as e:
+            raise RuntimeError(f"Parallel processing failed: {type(e).__name__}: {e}") from e
 
         return results
 
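A sketch of how the hardened map_parallel behaves (assuming dt holds dicts with a "text" field; the helper name to_upper is illustrative only):

# Module-level functions are picklable and therefore safe to pass in
def to_upper(item):
    return {**item, "text": item["text"].upper()}

results = dt.map_parallel(to_upper, workers=4, timeout=300)

# A lambda cannot be pickled, so the call below would now raise TypeError
# up front instead of failing inside the worker pool:
# dt.map_parallel(lambda item: item["text"].upper())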
 
@@ -769,6 +884,7 @@ class DataTransformer:
         func: Callable[[Dict], bool],
         workers: Optional[int] = None,
         chunksize: int = 1000,
+        timeout: Optional[float] = None,
     ) -> "DataTransformer":
         """
         Run a filter function in parallel (using multiple processes).
@@ -779,28 +895,122 @@ class DataTransformer:
             func: filter function; receives the raw dict and returns True to keep the record
             workers: number of worker processes; defaults to the number of CPU cores
             chunksize: size of the data chunk handled by each process
+            timeout: timeout in seconds; None means no timeout
 
         Returns:
             A new DataTransformer with the filtered data
 
+        Raises:
+            TypeError: if func cannot be pickled (e.g. a lambda)
+            RuntimeError: if a worker process fails or the call times out
+
         Examples:
             >>> def is_valid(item):
             ...     return len(item["text"]) > 10
             >>> filtered = dt.filter_parallel(is_valid)
         """
-        from multiprocessing import Pool, cpu_count
+        from multiprocessing import Pool, TimeoutError, cpu_count
+        import pickle
 
         if not self._data:
             return DataTransformer([])
 
+        # Check that the function can be pickled
+        try:
+            pickle.dumps(func)
+        except (pickle.PicklingError, AttributeError, TypeError) as e:
+            func_name = getattr(func, "__name__", str(func))
+            raise TypeError(
+                f"Function '{func_name}' cannot be pickled and cannot be used for parallel processing. "
+                f"Use a module-level function instead of a lambda or closure. Error: {e}"
+            ) from e
+
         workers = workers or cpu_count()
 
-        with Pool(workers) as pool:
-            mask = pool.map(func, self._data, chunksize=chunksize)
+        try:
+            with Pool(workers) as pool:
+                async_result = pool.map_async(func, self._data, chunksize=chunksize)
+                mask = async_result.get(timeout=timeout)
+        except TimeoutError:
+            raise RuntimeError(f"Parallel processing timed out ({timeout}s)")
+        except Exception as e:
+            raise RuntimeError(f"Parallel processing failed: {type(e).__name__}: {e}") from e
 
         filtered = [item for item, keep in zip(self._data, mask) if keep]
         return DataTransformer(filtered)
 
+    # ============ Training-framework integration ============
+
+    def check_compatibility(
+        self,
+        framework: Literal["llama-factory", "swift", "axolotl"],
+    ) -> "CompatibilityResult":
+        """
+        Check whether the data is compatible with the target training framework.
+
+        Args:
+            framework: target framework name
+                - "llama-factory": LLaMA-Factory
+                - "swift": ms-swift (ModelScope)
+                - "axolotl": Axolotl
+
+        Returns:
+            CompatibilityResult object containing valid, errors, warnings, suggestions
+
+        Examples:
+            >>> result = dt.check_compatibility("llama-factory")
+            >>> if result.valid:
+            ...     print("Compatible!")
+            >>> else:
+            ...     print(result.errors)
+        """
+        from .framework import check_compatibility
+
+        return check_compatibility(self._data, framework)
+
+    def export_for(
+        self,
+        framework: Literal["llama-factory", "swift", "axolotl"],
+        output_dir: str,
+        dataset_name: str = "custom_dataset",
+        **kwargs,
+    ) -> Dict[str, str]:
+        """
+        Export the data and config files for the target training framework in one step.
+
+        Args:
+            framework: target framework name
+            output_dir: output directory
+            dataset_name: dataset name
+            **kwargs: framework-specific arguments (e.g. model_name)
+
+        Returns:
+            Dict of generated file paths: {"data": "...", "config": "...", ...}
+
+        Examples:
+            >>> # Export for LLaMA-Factory
+            >>> dt.export_for("llama-factory", "./llama_ready")
+            # Generates:
+            # - ./llama_ready/custom_dataset.json
+            # - ./llama_ready/dataset_info.json
+            # - ./llama_ready/train_args.yaml
+
+            >>> # Export for ms-swift
+            >>> dt.export_for("swift", "./swift_ready", dataset_name="my_data")
+
+            >>> # Export for Axolotl
+            >>> dt.export_for("axolotl", "./axolotl_ready")
+        """
+        from .framework import export_for
+
+        return export_for(
+            self._data,
+            framework,
+            output_dir,
+            dataset_name=dataset_name,
+            **kwargs,
+        )
+
 
 def _sanitize_key(name: str) -> str:
     """Normalize a field name into a valid Python identifier"""