distflow 0.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. distflow-0.0.0/PKG-INFO +114 -0
  2. distflow-0.0.0/README.md +89 -0
  3. distflow-0.0.0/pyproject.toml +44 -0
  4. distflow-0.0.0/setup.cfg +4 -0
  5. distflow-0.0.0/src/distflow/__init__.py +0 -0
  6. distflow-0.0.0/src/distflow/cache/__init__.py +0 -0
  7. distflow-0.0.0/src/distflow/cache/protocol.py +7 -0
  8. distflow-0.0.0/src/distflow/cache/redis_cache.py +122 -0
  9. distflow-0.0.0/src/distflow/data/__init__.py +0 -0
  10. distflow-0.0.0/src/distflow/data/data_formatter.py +112 -0
  11. distflow-0.0.0/src/distflow/data/data_loader.py +64 -0
  12. distflow-0.0.0/src/distflow/data/types.py +13 -0
  13. distflow-0.0.0/src/distflow/embed/__init__.py +0 -0
  14. distflow-0.0.0/src/distflow/embed/base.py +19 -0
  15. distflow-0.0.0/src/distflow/embed/cache_wrapper.py +157 -0
  16. distflow-0.0.0/src/distflow/embed/openai_embed.py +244 -0
  17. distflow-0.0.0/src/distflow/embed/sentence_transformers.py +152 -0
  18. distflow-0.0.0/src/distflow/embed/types.py +13 -0
  19. distflow-0.0.0/src/distflow/embed/vllm.py +133 -0
  20. distflow-0.0.0/src/distflow/mmd.py +216 -0
  21. distflow-0.0.0/src/distflow/utils/__init__.py +0 -0
  22. distflow-0.0.0/src/distflow/utils/logger.py +126 -0
  23. distflow-0.0.0/src/distflow/utils/stats.py +111 -0
  24. distflow-0.0.0/src/distflow/utils/timing.py +106 -0
  25. distflow-0.0.0/src/distflow.egg-info/PKG-INFO +114 -0
  26. distflow-0.0.0/src/distflow.egg-info/SOURCES.txt +27 -0
  27. distflow-0.0.0/src/distflow.egg-info/dependency_links.txt +1 -0
  28. distflow-0.0.0/src/distflow.egg-info/requires.txt +20 -0
  29. distflow-0.0.0/src/distflow.egg-info/top_level.txt +1 -0
distflow-0.0.0/PKG-INFO
@@ -0,0 +1,114 @@
+ Metadata-Version: 2.4
+ Name: distflow
+ Version: 0.0.0
+ Summary: Distance Computation Package for Data Preparation Bench
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ Requires-Dist: aiohttp
+ Requires-Dist: datasets
+ Requires-Dist: modelscope
+ Requires-Dist: openai
+ Requires-Dist: pandas
+ Requires-Dist: pydantic
+ Requires-Dist: pyyaml
+ Requires-Dist: redis
+ Requires-Dist: scikit-learn
+ Requires-Dist: sentence-transformers
+ Requires-Dist: torch
+ Requires-Dist: transformers
+ Provides-Extra: vllm
+ Requires-Dist: vllm; extra == "vllm"
+ Provides-Extra: dev
+ Requires-Dist: pre-commit; extra == "dev"
+ Requires-Dist: pyright; extra == "dev"
+ Requires-Dist: pytest; extra == "dev"
+
+ # Data-Preparation-Bench
+
+ A benchmark for evaluating the data preparation capabilities of large language models (LLMs). It is organized into two modules:
+
+ ## Modules
+
+ ### 1. Data Synthesis & Augmentation
+
+ Given raw metadata, the model is tasked with synthesizing or augmenting datasets to improve downstream model training.
+
+ ### 2. Data Quality Assessment
+
+ Given raw metadata, the model is tasked with predicting the training data's impact on downstream task performance.
+
+ ## Quick Start
+
+ ### Usage
+
+ The package is published on PyPI and can be installed via pip:
+
+ ```bash
+ pip install distflow
+ ```
+
+ For vLLM embedding support, install the optional dependency (quoted so the extra survives shells such as zsh):
+
+ ```bash
+ pip install "distflow[vllm]"
+ ```
+
+ This project uses [uv](https://docs.astral.sh/uv/) for dependency management. To get started from source:
+
+ ```bash
+ git clone https://github.com/haolpku/Data-Preparation-Bench.git
+ cd Data-Preparation-Bench
+ uv sync
+ ```
+
+ To use your own datasets, modify the configuration dictionaries and formatters in [compute_mmd.py](./examples/compute_mmd.py):
+
+ ```python
+ DS1_CONFIG = {
+     "name": "oda-math",
+     "data_path": "OpenDataArena/ODA-Math-460k",
+     "data_size": 5000,
+     "split": "train",
+     "shuffle_seed": 42,
+ }
+ formatter1 = AlpacaFormatter(
+     user_key="question",
+     assistant_key="response",
+ )
+
+ DS2_CONFIG = {
+     "name": "infinity-instruct",
+     "data_path": "BAAI/Infinity-Instruct",
+     "data_size": 5000,
+     "split": "train",
+     "shuffle_seed": 42,
+ }
+ formatter2 = ShareGptFormatter(
+     conversations_key="conversations",
+ )
+ ```
+
+ Typically, you only need to point `data_path` at your dataset and define a formatter that converts raw items into the required message format. After making these changes, run the MMD computation with:
+
+ ```bash
+ uv run examples/compute_mmd.py
+ ```
+
+ ### Development
+
+ To set up the development environment locally:
+
+ ```bash
+ uv sync --extra dev
+ uv run pre-commit install
+ ```
+
+ Before committing, format and lint the code:
+
+ ```bash
+ uv run pre-commit run --all-files
+ ```
+
+ ## Experiment Settings
+
+ Please refer to [Experiment.md](./Experiment.md) for detailed experiment configurations.
distflow-0.0.0/README.md
@@ -0,0 +1,89 @@
+ # Data-Preparation-Bench
+
+ A benchmark for evaluating the data preparation capabilities of large language models (LLMs). It is organized into two modules:
+
+ ## Modules
+
+ ### 1. Data Synthesis & Augmentation
+
+ Given raw metadata, the model is tasked with synthesizing or augmenting datasets to improve downstream model training.
+
+ ### 2. Data Quality Assessment
+
+ Given raw metadata, the model is tasked with predicting the training data's impact on downstream task performance.
+
+ ## Quick Start
+
+ ### Usage
+
+ The package is published on PyPI and can be installed via pip:
+
+ ```bash
+ pip install distflow
+ ```
+
+ For vLLM embedding support, install the optional dependency (quoted so the extra survives shells such as zsh):
+
+ ```bash
+ pip install "distflow[vllm]"
+ ```
+
+ This project uses [uv](https://docs.astral.sh/uv/) for dependency management. To get started from source:
+
+ ```bash
+ git clone https://github.com/haolpku/Data-Preparation-Bench.git
+ cd Data-Preparation-Bench
+ uv sync
+ ```
+
+ To use your own datasets, modify the configuration dictionaries and formatters in [compute_mmd.py](./examples/compute_mmd.py):
+
+ ```python
+ DS1_CONFIG = {
+     "name": "oda-math",
+     "data_path": "OpenDataArena/ODA-Math-460k",
+     "data_size": 5000,
+     "split": "train",
+     "shuffle_seed": 42,
+ }
+ formatter1 = AlpacaFormatter(
+     user_key="question",
+     assistant_key="response",
+ )
+
+ DS2_CONFIG = {
+     "name": "infinity-instruct",
+     "data_path": "BAAI/Infinity-Instruct",
+     "data_size": 5000,
+     "split": "train",
+     "shuffle_seed": 42,
+ }
+ formatter2 = ShareGptFormatter(
+     conversations_key="conversations",
+ )
+ ```
+
+ Typically, you only need to point `data_path` at your dataset and define a formatter that converts raw items into the required message format. After making these changes, run the MMD computation with:
+
+ ```bash
+ uv run examples/compute_mmd.py
+ ```
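For a dataset that is in neither Alpaca nor ShareGPT layout, any object with a matching `format` method satisfies the formatter contract, since `FormatterProtocol` (defined in `src/distflow/data/data_formatter.py` below) is a structural protocol. A minimal sketch, assuming hypothetical raw fields `prompt` and `answer`:

```python
from typing import Any

from distflow.data.data_formatter import FormatterProtocol
from distflow.data.types import DatasetProcessOutputItem, MessageData


class MyFormatter:
    def format(self, raw_item: dict[str, Any]) -> DatasetProcessOutputItem:
        # Map your raw fields onto user/assistant messages.
        return DatasetProcessOutputItem(
            messages=[
                MessageData(role="user", content=raw_item["prompt"]),
                MessageData(role="assistant", content=raw_item["answer"]),
            ],
            meta={"raw_item": raw_item},
        )


formatter1: FormatterProtocol = MyFormatter()  # type-checks: protocol matching is structural
```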
+
+ ### Development
+
+ To set up the development environment locally:
+
+ ```bash
+ uv sync --extra dev
+ uv run pre-commit install
+ ```
+
+ Before committing, format and lint the code:
+
+ ```bash
+ uv run pre-commit run --all-files
+ ```
+
+ ## Experiment Settings
+
+ Please refer to [Experiment.md](./Experiment.md) for detailed experiment configurations.
distflow-0.0.0/pyproject.toml
@@ -0,0 +1,44 @@
+ [project]
+ name = "distflow"
+ dynamic = ["version"]
+ description = "Distance Computation Package for Data Preparation Bench"
+ readme = "README.md"
+ requires-python = ">=3.10"
+ dependencies = [
+     "aiohttp",
+     "datasets",
+     "modelscope",
+     "openai",
+     "pandas",
+     "pydantic",
+     "pyyaml",
+     "redis",
+     "scikit-learn",
+     "sentence-transformers",
+     "torch",
+     "transformers",
+ ]
+
+ [project.optional-dependencies]
+ vllm = ["vllm"]
+ dev = [
+     "pre-commit",
+     "pyright",
+     "pytest",
+ ]
+
+ [tool.black]
+ line-length = 88
+ target-version = ['py312']
+ include = '\.pyi?$'
+
+ [tool.isort]
+ profile = "black"
+ line_length = 88
+ src_paths = ["src", "tests"]
+
+ [tool.hatch.version]
+ source = "vcs"
+
+ [tool.hatch.build.targets.wheel]
+ packages = ["src/distflow"]
distflow-0.0.0/setup.cfg
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
distflow-0.0.0/src/distflow/__init__.py
File without changes
distflow-0.0.0/src/distflow/cache/__init__.py
File without changes
distflow-0.0.0/src/distflow/cache/protocol.py
@@ -0,0 +1,7 @@
+ from typing import Any, Protocol
+
+
+ class CacheProtocol(Protocol):
+     async def load_cache(self, cache_key: str) -> dict[str, Any] | None: ...
+
+     async def save_cache(self, cache_key: str, cache_value: dict[str, Any]) -> bool: ...
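Because `CacheProtocol` is a `typing.Protocol`, any class with matching async methods conforms without inheriting from it. A minimal in-memory stand-in (hypothetical, not part of the package) that type-checks against the protocol:

```python
from typing import Any

from distflow.cache.protocol import CacheProtocol


class InMemoryCache:
    """Dict-backed cache that structurally satisfies CacheProtocol."""

    def __init__(self) -> None:
        self._store: dict[str, dict[str, Any]] = {}

    async def load_cache(self, cache_key: str) -> dict[str, Any] | None:
        return self._store.get(cache_key)

    async def save_cache(self, cache_key: str, cache_value: dict[str, Any]) -> bool:
        self._store[cache_key] = cache_value
        return True


cache: CacheProtocol = InMemoryCache()  # accepted: protocol matching is structural
```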
distflow-0.0.0/src/distflow/cache/redis_cache.py
@@ -0,0 +1,122 @@
+ import asyncio
+ import json
+ from typing import Any
+
+ from redis.asyncio import Redis
+
+ from distflow.utils import logger
+
+
+ class RedisCache:
+     """Cache implementation backed by Redis.
+
+     Talks to the Redis server directly through a Redis client to provide a
+     distributed cache. A semaphore caps the number of concurrent requests.
+     """
+
+     def __init__(
+         self,
+         redis_url: str = "redis://127.0.0.1:6379",
+         max_concurrent_requests: int = 50,
+         redis_db: int = 0,
+     ) -> None:
+         """Initialize the Redis cache.
+
+         Args:
+             redis_url: Redis connection URL, e.g. "redis://127.0.0.1:6379"
+             max_concurrent_requests: maximum number of concurrent requests
+             redis_db: Redis database number, defaults to 0
+         """
+         self._semaphore = asyncio.Semaphore(max_concurrent_requests)
+
+         # The Redis client is created lazily on first use
+         self._redis: Redis | None = None
+         self._redis_url = redis_url
+         self._redis_db = redis_db
+
+     async def _get_redis(self) -> Redis:
+         """Get or lazily create the Redis client."""
+         if self._redis is None:
+             self._redis = Redis.from_url(
+                 self._redis_url,
+                 db=self._redis_db,
+                 decode_responses=True,
+             )
+             try:
+                 # Verify the connection; ping() is a coroutine on the asyncio
+                 # client, so it must be awaited to actually test connectivity.
+                 await self._redis.ping()
+                 logger.info(
+                     f"Connected to Redis: {self._redis_url}, DB: {self._redis_db}"
+                 )
+             except Exception as e:
+                 logger.error(
+                     f"Cannot connect to Redis: {self._redis_url}, DB: {self._redis_db}, error: {e}"
+                 )
+                 raise ConnectionError(
+                     f"Cannot connect to Redis: {self._redis_url}, DB: {self._redis_db}"
+                 ) from e
+         return self._redis
+
+     async def load_cache(self, cache_key: str) -> dict[str, Any] | None:
+         """Fetch a single cached value from Redis (concurrency capped by the semaphore).
+
+         Args:
+             cache_key: cache key
+
+         Returns:
+             The cached dict, or None if the key does not exist
+         """
+         for attempt in range(3):
+             async with self._semaphore:
+                 try:
+                     redis = await self._get_redis()
+                     cached_data = await redis.get(cache_key)
+                     if cached_data:
+                         return json.loads(cached_data)
+                     return None
+                 except Exception as e:
+                     logger.warning(
+                         f"Redis cache read failed {attempt + 1} / 3: {type(e).__name__}: {e}"
+                     )
+                     await asyncio.sleep(0.1 * (attempt + 1))  # simple linear backoff
+                     self._redis = None  # drop the client so the next attempt reconnects
+         return None
+
+     async def save_cache(self, cache_key: str, cache_value: dict[str, Any]) -> bool:
+         """Store a single cache value in Redis (concurrency capped by the semaphore).
+
+         Args:
+             cache_key: cache key
+             cache_value: cache value
+
+         Returns:
+             Whether the write succeeded
+         """
+         for attempt in range(3):
+             async with self._semaphore:
+                 try:
+                     redis = await self._get_redis()
+                     serialized = json.dumps(cache_value)
+                     await redis.set(cache_key, serialized)
+                     return True
+                 except Exception as e:
+                     logger.warning(
+                         f"Redis cache write failed {attempt + 1} / 3: {type(e).__name__}: {e}"
+                     )
+                     await asyncio.sleep(0.1 * (attempt + 1))  # simple linear backoff
+                     self._redis = None  # drop the client so the next attempt reconnects
+         return False
+
+     async def close(self) -> None:
+         """Close the Redis connection."""
+         if self._redis:
+             await self._redis.close()
+             logger.info("Redis connection closed")
+
+     async def __aenter__(self) -> "RedisCache":
+         """Async context manager entry."""
+         return self
+
+     async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+         """Async context manager exit."""
+         await self.close()
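A minimal usage sketch of the class above, assuming a Redis server at the default URL (the key and payload are illustrative):

```python
import asyncio

from distflow.cache.redis_cache import RedisCache


async def main() -> None:
    # The async context manager closes the connection on exit.
    async with RedisCache("redis://127.0.0.1:6379") as cache:
        await cache.save_cache("emb:42", {"vector": [0.1, 0.2]})
        print(await cache.load_cache("emb:42"))  # {'vector': [0.1, 0.2]}


asyncio.run(main())
```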
distflow-0.0.0/src/distflow/data/__init__.py
File without changes
distflow-0.0.0/src/distflow/data/data_formatter.py
@@ -0,0 +1,112 @@
+ from __future__ import annotations
+
+ from typing import Any, Protocol, cast, runtime_checkable
+
+ from distflow.data.types import DatasetProcessOutputItem, MessageData
+
+
+ @runtime_checkable
+ class FormatterProtocol(Protocol):
+     def format(self, raw_item: dict[str, Any]) -> DatasetProcessOutputItem: ...
+
+
+ class AlpacaFormatter:
+     def __init__(self, *, user_key: str, assistant_key: str) -> None:
+         self.user_key = user_key
+         self.assistant_key = assistant_key
+
+     def format(self, raw_item: dict[str, Any]) -> DatasetProcessOutputItem:
+         assert (
+             self.user_key in raw_item
+         ), f"User key '{self.user_key}' not found in raw item"
+         assert (
+             self.assistant_key in raw_item
+         ), f"Assistant key '{self.assistant_key}' not found in raw item"
+         user_content = raw_item[self.user_key]
+         assert isinstance(
+             user_content, str
+         ), f"User content must be a string, got {type(user_content).__name__}: {user_content}"
+         assistant_content = raw_item[self.assistant_key]
+         assert isinstance(
+             assistant_content, str
+         ), f"Assistant content must be a string, got {type(assistant_content).__name__}: {assistant_content}"
+
+         return DatasetProcessOutputItem(
+             messages=[
+                 cast(MessageData, {"role": "user", "content": user_content}),
+                 cast(MessageData, {"role": "assistant", "content": assistant_content}),
+             ],
+             meta={
+                 "user_key": self.user_key,
+                 "assistant_key": self.assistant_key,
+                 "raw_item": raw_item,
+             },
+         )
+
+
+ class ShareGptFormatter:
+     def __init__(self, *, conversations_key: str) -> None:
+         self.conversations_key = conversations_key
+
+     def format(self, raw_item: dict[str, Any]) -> DatasetProcessOutputItem:
+         assert (
+             self.conversations_key in raw_item
+         ), f"Conversations key '{self.conversations_key}' not found in raw item"
+
+         conversations = raw_item[self.conversations_key]
+         assert isinstance(
+             conversations, list
+         ), f"Conversations must be a list, got {type(conversations).__name__}: {conversations}"
+
+         messages: list[MessageData] = []
+
+         for conv in conversations:
+             if not isinstance(conv, dict):
+                 continue
+
+             # Detect the entry's format and extract the fields
+             role = None
+             content = None
+
+             # Standard format: {"role": "user", "content": "..."}
+             if "role" in conv and "content" in conv:
+                 role = conv.get("role")
+                 content = conv.get("content")
+
+             # ShareGPT format: {"from": "human", "value": "..."}
+             elif "from" in conv and "value" in conv:
+                 from_field = conv.get("from")
+                 content = conv.get("value")
+
+                 assert isinstance(from_field, str) and isinstance(
+                     content, str
+                 ), "'from' and 'value' must both be strings"
+
+                 role_mapping = {
+                     "human": "user",
+                     "gpt": "assistant",
+                     "system": "system",
+                     "user": "user",
+                     "assistant": "assistant",
+                 }
+                 role = role_mapping.get(from_field, from_field)
+
+             # Append to messages
+             if role is not None and content is not None:
+                 messages.append(cast(MessageData, {"role": role, "content": content}))
+
+         return DatasetProcessOutputItem(
+             messages=messages,
+             meta={
+                 "conversations_key": self.conversations_key,
+                 "raw_item": raw_item,
+                 "detected_format": (
+                     "sharegpt"
+                     if any(isinstance(c, dict) and "from" in c for c in conversations)
+                     else "standard"
+                 ),
+             },
+         )
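A quick illustration of both formatters as defined above (the input dicts are made up for the example):

```python
from distflow.data.data_formatter import AlpacaFormatter, ShareGptFormatter

# Alpaca-style: one user field, one assistant field
alpaca = AlpacaFormatter(user_key="question", assistant_key="response")
item = alpaca.format({"question": "What is 2 + 2?", "response": "4"})
print(item.messages[0].role, "->", item.messages[1].content)  # user -> 4

# ShareGPT-style: a list of {"from": ..., "value": ...} turns
sharegpt = ShareGptFormatter(conversations_key="conversations")
item = sharegpt.format(
    {
        "conversations": [
            {"from": "human", "value": "Hi"},
            {"from": "gpt", "value": "Hello!"},
        ]
    }
)
print([m.role for m in item.messages])  # ['user', 'assistant']
```

Note the original file re-imported `Any` and `cast` between the two classes; the duplicate import is dropped above since both names are already imported at the top.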
distflow-0.0.0/src/distflow/data/data_loader.py
@@ -0,0 +1,64 @@
+ import builtins
+ import random
+ from typing import Any, Literal, cast
+
+ from distflow.data.data_formatter import FormatterProtocol
+ from distflow.data.types import DatasetProcessOutputItem
+ from distflow.utils import logger
+
+
+ def load_dataset(
+     dataset_name: str,
+     data_path: str,
+     load_type: Literal["datasets", "modelscope", "pandas"],
+     formatter: FormatterProtocol,
+     data_size: int = -1,
+     split: str = "train",
+     sep: str = "\t",
+     dtype: str = "str",
+     shuffle_seed: int = 42,
+     use_json: bool = False,
+ ) -> tuple[str, list[DatasetProcessOutputItem]]:
+     logger.info(f"Loading dataset: {dataset_name}, path: {data_path}, type: {load_type}")
+
+     # Data size limit
+     logger.debug(f"Data size limit: {data_size if data_size > 0 else 'all'}")
+
+     match load_type:
+         case "datasets":
+             # Aliased so the import does not shadow this function's own name
+             from datasets import load_dataset as hf_load_dataset
+
+             logger.debug(f"Loading with datasets, split={split}, use_json={use_json}")
+             if use_json:
+                 dataset = hf_load_dataset("json", data_files=data_path, split=split)
+             else:
+                 dataset = hf_load_dataset(path=data_path, split=split)
+         case "modelscope":
+             from modelscope.msdatasets import MsDataset
+
+             logger.debug(f"Loading with modelscope, split={split}")
+             dataset = MsDataset.load(data_path, split=split)
+         case "pandas":
+             from datasets import Dataset
+             from pandas import read_csv
+
+             logger.debug("Loading with pandas")
+             dtype_actual = getattr(builtins, dtype)
+             df = read_csv(data_path, sep=sep, dtype=dtype_actual)
+             dataset = Dataset.from_pandas(df)
+         case _:
+             raise ValueError(f"Unsupported load_type: {load_type}")
+
+     logger.info(f"Dataset loaded, total samples: {len(dataset)}")
+
+     random.seed(shuffle_seed)
+     logger.debug(f"Using random seed: {shuffle_seed}")
+     random_indices = list(range(len(dataset)))
+     if data_size > 0 and data_size < len(dataset):
+         logger.info(f"Randomly sampling {data_size} items")
+         random_indices = random.sample(random_indices, data_size)
+     else:
+         logger.info("Using the full dataset")
+         random.shuffle(random_indices)
+     sampled_data = cast(list[dict[str, Any]], [dataset[i] for i in random_indices])
+     logger.debug("Sampling done; formatting data")
+     formatted_data = [formatter.format(data_item) for data_item in sampled_data]
+     return dataset_name, formatted_data
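Putting the loader and a formatter together; a sketch using the dataset path and keys from the README example (downloading the dataset requires network access):

```python
from distflow.data.data_formatter import AlpacaFormatter
from distflow.data.data_loader import load_dataset

name, items = load_dataset(
    dataset_name="oda-math",
    data_path="OpenDataArena/ODA-Math-460k",
    load_type="datasets",
    formatter=AlpacaFormatter(user_key="question", assistant_key="response"),
    data_size=100,  # sample 100 items with the default seed
    split="train",
)
print(name, len(items), items[0].messages[0].role)  # oda-math 100 user
```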
distflow-0.0.0/src/distflow/data/types.py
@@ -0,0 +1,13 @@
+ from typing import Any
+
+ from pydantic import BaseModel
+
+
+ class MessageData(BaseModel):  # type: ignore[misc]
+     role: str
+     content: str | dict[str, Any]
+
+
+ class DatasetProcessOutputItem(BaseModel):  # type: ignore[misc]
+     messages: list[MessageData]
+     meta: dict[str, Any]
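The shape of one formatted item, constructed directly (values are illustrative):

```python
from distflow.data.types import DatasetProcessOutputItem, MessageData

item = DatasetProcessOutputItem(
    messages=[
        MessageData(role="user", content="What is MMD?"),
        MessageData(role="assistant", content="A kernel distance between distributions."),
    ],
    meta={"source": "example"},
)
print(item.messages[0].role)  # "user"
```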
distflow-0.0.0/src/distflow/embed/__init__.py
File without changes
distflow-0.0.0/src/distflow/embed/base.py
@@ -0,0 +1,19 @@
+ from abc import ABC, abstractmethod
+
+ from distflow.embed.types import EmbeddingInputItem, EmbeddingResult
+
+
+ class BaseEmbed(ABC):
+     def __init__(self, model_name: str) -> None:
+         self.model_name = model_name
+
+     @abstractmethod
+     def embed(self, dataset: list[EmbeddingInputItem]) -> list[EmbeddingResult | None]:
+         """Compute embeddings for the dataset.
+
+         Args:
+             dataset: data items to embed
+
+         Returns:
+             One embedding result per input item; failed items are None
+         """