distflow 1.0.0__tar.gz

@@ -0,0 +1,102 @@
+ Metadata-Version: 2.3
+ Name: distflow
+ Version: 1.0.0
+ Summary: Distance Computation Package for Data Preparation Bench
+ Requires-Dist: addict>=2.4.0
+ Requires-Dist: aiohttp>=3.11.0
+ Requires-Dist: datasets>=2.14.4
+ Requires-Dist: modelscope>=1.34.0
+ Requires-Dist: pandas>=2.3.3
+ Requires-Dist: pydantic>=2.12.5
+ Requires-Dist: pyyaml>=6.0
+ Requires-Dist: redis>=7.3.0
+ Requires-Dist: scikit-learn>=1.8.0
+ Requires-Dist: sentence-transformers>=5.3.0
+ Requires-Dist: torch>=2.6.0
+ Requires-Dist: transformers>=4.53.0
+ Requires-Dist: pre-commit>=4.2.0 ; extra == 'dev'
+ Requires-Dist: pyright>=1.1.408 ; extra == 'dev'
+ Requires-Dist: pytest>=8.4.1 ; extra == 'dev'
+ Requires-Dist: vllm>=0.8.5.post1 ; extra == 'vllm'
+ Requires-Python: >=3.12, <3.13
+ Provides-Extra: dev
+ Provides-Extra: vllm
+ Description-Content-Type: text/markdown
+
+ # Data-Preparation-Bench
+
+ A benchmark for evaluating the data preparation capabilities of large language models (LLMs). The benchmark is organized into two modules:
+
+ ## Modules
+
+ ### 1. Data Synthesis & Augmentation
+
+ Given raw metadata, the model is tasked with synthesizing or augmenting datasets to improve downstream model training.
+
+ ### 2. Data Quality Assessment
+
+ Given raw metadata, the model is tasked with predicting the training data's impact on downstream task performance.
+
+ ## Quick Start
+
+ ### Usage
+
+ This project uses [uv](https://docs.astral.sh/uv/) for dependency management. To get started:
+
+ ```bash
+ git clone https://github.com/haolpku/Data-Preparation-Bench.git
+ cd Data-Preparation-Bench
+ uv sync
+ ```
+
+ To use your own datasets, modify the configuration dictionaries and formatters in [compute_mmd.py](./examples/compute_mmd.py):
+
+ ```python
+ DS1_CONFIG = {
+     "name": "oda-math",
+     "data_path": "OpenDataArena/ODA-Math-460k",
+     "data_size": 5000,
+     "split": "train",
+     "shuffle_seed": 42,
+ }
+ formatter1 = AlpacaFormatter(
+     user_key="question",
+     assistant_key="response",
+ )
+
+ DS2_CONFIG = {
+     "name": "infinity-instruct",
+     "data_path": "BAAI/Infinity-Instruct",
+     "data_size": 5000,
+     "split": "train",
+     "shuffle_seed": 42,
+ }
+ formatter2 = ShareGptFormatter(
+     conversations_key="conversations",
+ )
+ ```
+
+ Typically, you only need to update `data_path` to point at your dataset and define a formatter that converts raw items into the required message format. After making these changes, run the MMD computation with:
+
+ ```bash
+ uv run examples/compute_mmd.py
+ ```
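+
+ If your dataset follows neither the Alpaca nor the ShareGPT layout, you can supply your own formatter: any object with a `format(raw_item) -> DatasetProcessOutputItem` method satisfies the package's `FormatterProtocol`. A minimal sketch (the `"prompt"`/`"answer"` keys are hypothetical, for illustration only):
+
+ ```python
+ from typing import Any, cast
+
+ from distflow.data.types import DatasetProcessOutputItem, MessageData
+
+
+ class PromptAnswerFormatter:
+     """Illustrative formatter for items shaped like {"prompt": ..., "answer": ...}."""
+
+     def format(self, raw_item: dict[str, Any]) -> DatasetProcessOutputItem:
+         return DatasetProcessOutputItem(
+             messages=[
+                 cast(MessageData, {"role": "user", "content": raw_item["prompt"]}),
+                 cast(MessageData, {"role": "assistant", "content": raw_item["answer"]}),
+             ],
+             meta={"raw_item": raw_item},
+         )
+ ```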
+
+ ### Development
+
+ To set up the development environment locally:
+
+ ```bash
+ uv sync --extra dev
+ uv run pre-commit install
+ ```
+
+ Before committing, format and lint the code:
+
+ ```bash
+ uv run pre-commit run --all-files
+ ```
+
+ ## Experiment Settings
+
+ Please refer to [Experiment.md](./Experiment.md) for detailed experiment configurations.
@@ -0,0 +1,77 @@
+ # Data-Preparation-Bench
+
+ A benchmark for evaluating the data preparation capabilities of large language models (LLMs). The benchmark is organized into two modules:
+
+ ## Modules
+
+ ### 1. Data Synthesis & Augmentation
+
+ Given raw metadata, the model is tasked with synthesizing or augmenting datasets to improve downstream model training.
+
+ ### 2. Data Quality Assessment
+
+ Given raw metadata, the model is tasked with predicting the training data's impact on downstream task performance.
+
+ ## Quick Start
+
+ ### Usage
+
+ This project uses [uv](https://docs.astral.sh/uv/) for dependency management. To get started:
+
+ ```bash
+ git clone https://github.com/haolpku/Data-Preparation-Bench.git
+ cd Data-Preparation-Bench
+ uv sync
+ ```
+
+ To use your own datasets, modify the configuration dictionaries and formatters in [compute_mmd.py](./examples/compute_mmd.py):
+
+ ```python
+ DS1_CONFIG = {
+     "name": "oda-math",
+     "data_path": "OpenDataArena/ODA-Math-460k",
+     "data_size": 5000,
+     "split": "train",
+     "shuffle_seed": 42,
+ }
+ formatter1 = AlpacaFormatter(
+     user_key="question",
+     assistant_key="response",
+ )
+
+ DS2_CONFIG = {
+     "name": "infinity-instruct",
+     "data_path": "BAAI/Infinity-Instruct",
+     "data_size": 5000,
+     "split": "train",
+     "shuffle_seed": 42,
+ }
+ formatter2 = ShareGptFormatter(
+     conversations_key="conversations",
+ )
+ ```
+
+ Typically, you only need to update `data_path` to point at your dataset and define a formatter that converts raw items into the required message format. After making these changes, run the MMD computation with:
+
+ ```bash
+ uv run examples/compute_mmd.py
+ ```
+
+ ### Development
+
+ To set up the development environment locally:
+
+ ```bash
+ uv sync --extra dev
+ uv run pre-commit install
+ ```
+
+ Before committing, format and lint the code:
+
+ ```bash
+ uv run pre-commit run --all-files
+ ```
+
+ ## Experiment Settings
+
+ Please refer to [Experiment.md](./Experiment.md) for detailed experiment configurations.
@@ -0,0 +1,45 @@
+ [project]
+ name = "distflow"
+ version = "1.0.0"
+ description = "Distance Computation Package for Data Preparation Bench"
+ readme = "README.md"
+ requires-python = ">=3.12,<3.13"
+ dependencies = [
+     "addict>=2.4.0",
+     "aiohttp>=3.11.0",
+     "datasets>=2.14.4",
+     "modelscope>=1.34.0",
+     "pandas>=2.3.3",
+     "pydantic>=2.12.5",
+     "pyyaml>=6.0",
+     "redis>=7.3.0",
+     "scikit-learn>=1.8.0",
+     "sentence-transformers>=5.3.0",
+     "torch>=2.6.0",
+     "transformers>=4.53.0",
+ ]
+
+ [project.optional-dependencies]
+ vllm = ["vllm>=0.8.5.post1"]
+ dev = [
+     "pre-commit>=4.2.0",
+     "pyright>=1.1.408",
+     "pytest>=8.4.1",
+ ]
+
+ [tool.black]
+ line-length = 88
+ target-version = ['py312']
+ include = '\.pyi?$'
+
+ [tool.isort]
+ profile = "black"
+ line_length = 88
+ src_paths = ["src", "tests"]
+
+ [build-system]
+ requires = ["uv_build>=0.9.5,<0.12"]
+ build-backend = "uv_build"
+
+ [tool.uv]
+ index-url = "https://mirrors.aliyun.com/pypi/simple"
File without changes
File without changes
@@ -0,0 +1,7 @@
+ from typing import Any, Protocol
+
+
+ class CacheProtocol(Protocol):
+     async def load_cache(self, cache_key: str) -> dict[str, Any] | None: ...
+
+     async def save_cache(self, cache_key: str, cache_value: dict[str, Any]) -> bool: ...
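Because `CacheProtocol` is structural, any class with matching async methods conforms; no inheritance is needed. A minimal in-memory sketch (the class name is hypothetical, handy for tests):

```python
from typing import Any


class InMemoryCache:
    """Illustrative CacheProtocol conformer backed by a plain dict."""

    def __init__(self) -> None:
        self._store: dict[str, dict[str, Any]] = {}

    async def load_cache(self, cache_key: str) -> dict[str, Any] | None:
        return self._store.get(cache_key)

    async def save_cache(self, cache_key: str, cache_value: dict[str, Any]) -> bool:
        self._store[cache_key] = cache_value
        return True
```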
@@ -0,0 +1,122 @@
+ import asyncio
+ import json
+ from typing import Any
+
+ from redis.asyncio import Redis
+
+ from distflow.utils import logger
+
+
+ class RedisCache:
+     """Cache backend implemented on top of Redis.
+
+     Talks to the Redis server directly through the async Redis client,
+     providing a distributed cache. A semaphore caps the number of
+     concurrent requests.
+     """
+
+     def __init__(
+         self,
+         redis_url: str = "redis://127.0.0.1:6379",
+         max_concurrent_requests: int = 50,
+         redis_db: int = 0,
+     ) -> None:
+         """Initialize the Redis cache.
+
+         Args:
+             redis_url: Redis connection URL, e.g. "redis://127.0.0.1:6379"
+             max_concurrent_requests: Maximum number of concurrent requests
+             redis_db: Redis database number, defaults to 0
+         """
+         self._semaphore = asyncio.Semaphore(max_concurrent_requests)
+
+         # The Redis client is created lazily on first use.
+         self._redis: Redis | None = None
+         self._redis_url = redis_url
+         self._redis_db = redis_db
+
+     async def _get_redis(self) -> Redis:
+         """Get or lazily create the Redis client."""
+         if self._redis is None:
+             self._redis = Redis.from_url(
+                 self._redis_url,
+                 db=self._redis_db,
+                 decode_responses=True,
+             )
+             try:
+                 # Probe the connection; ping() is a coroutine on the
+                 # asyncio client and must be awaited.
+                 await self._redis.ping()
+                 logger.info(
+                     f"Connected to Redis: {self._redis_url}, DB: {self._redis_db}"
+                 )
+             except Exception as e:
+                 logger.error(
+                     f"Cannot connect to Redis: {self._redis_url}, DB: {self._redis_db}, error: {e}"
+                 )
+                 raise ConnectionError(
+                     f"Cannot connect to Redis: {self._redis_url}, DB: {self._redis_db}"
+                 ) from e
+         return self._redis
+
+     async def load_cache(self, cache_key: str) -> dict[str, Any] | None:
+         """Fetch a single cached value from Redis (concurrency bounded by the semaphore).
+
+         Args:
+             cache_key: Cache key
+
+         Returns:
+             The cached value as a dict, or None if the key does not exist
+         """
+         for attempt in range(3):
+             async with self._semaphore:
+                 try:
+                     redis = await self._get_redis()
+                     cached_data = await redis.get(cache_key)
+                     if cached_data:
+                         return json.loads(cached_data)
+                     return None
+                 except Exception as e:
+                     logger.warning(
+                         f"Redis cache lookup failed {attempt + 1} / 3: {type(e).__name__}: {e}"
+                     )
+                     await asyncio.sleep(0.1 * (attempt + 1))  # simple linear backoff
+                     self._redis = None  # reset the client to force a reconnect
+         return None
+
+     async def save_cache(self, cache_key: str, cache_value: dict[str, Any]) -> bool:
+         """Write a single cache value to Redis (concurrency bounded by the semaphore).
+
+         Args:
+             cache_key: Cache key
+             cache_value: Cache value
+
+         Returns:
+             Whether the write succeeded
+         """
+         for attempt in range(3):
+             async with self._semaphore:
+                 try:
+                     redis = await self._get_redis()
+                     serialized = json.dumps(cache_value)
+                     await redis.set(cache_key, serialized)
+                     return True
+                 except Exception as e:
+                     logger.warning(
+                         f"Redis cache write failed {attempt + 1} / 3: {type(e).__name__}: {e}"
+                     )
+                     await asyncio.sleep(0.1 * (attempt + 1))  # simple linear backoff
+                     self._redis = None  # reset the client to force a reconnect
+         return False
+
+     async def close(self) -> None:
+         """Close the Redis connection."""
+         if self._redis:
+             await self._redis.close()
+             logger.info("Redis connection closed")
+
+     async def __aenter__(self) -> "RedisCache":
+         """Async context manager entry."""
+         return self
+
+     async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+         """Async context manager exit."""
+         await self.close()
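A usage sketch for `RedisCache`; the module path `distflow.cache.redis_cache` is assumed (the diff does not show file names), and a reachable Redis server is required:

```python
import asyncio

from distflow.cache.redis_cache import RedisCache  # assumed module path


async def main() -> None:
    # The async context manager closes the connection on exit.
    async with RedisCache(redis_url="redis://127.0.0.1:6379") as cache:
        await cache.save_cache("example-key", {"embedding": [0.1, 0.2]})
        value = await cache.load_cache("example-key")
        print(value)


asyncio.run(main())
```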
File without changes
@@ -0,0 +1,73 @@
+ from __future__ import annotations
+
+ from typing import Any, Protocol, cast, runtime_checkable
+
+ from distflow.data.types import DatasetProcessOutputItem, MessageData
+
+
+ @runtime_checkable
+ class FormatterProtocol(Protocol):
+     def format(self, raw_item: dict[str, Any]) -> DatasetProcessOutputItem: ...
+
+
+ class AlpacaFormatter:
+     def __init__(self, *, user_key: str, assistant_key: str) -> None:
+         self.user_key = user_key
+         self.assistant_key = assistant_key
+
+     def format(self, raw_item: dict[str, Any]) -> DatasetProcessOutputItem:
+         assert (
+             self.user_key in raw_item
+         ), f"User key '{self.user_key}' not found in raw item"
+         assert (
+             self.assistant_key in raw_item
+         ), f"Assistant key '{self.assistant_key}' not found in raw item"
+         user_content = raw_item[self.user_key]
+         assert isinstance(
+             user_content, str
+         ), f"User content must be a string, got {type(user_content).__name__}: {user_content}"
+         assistant_content = raw_item[self.assistant_key]
+         assert isinstance(
+             assistant_content, str
+         ), f"Assistant content must be a string, got {type(assistant_content).__name__}: {assistant_content}"
+
+         return DatasetProcessOutputItem(
+             messages=[
+                 cast(MessageData, {"role": "user", "content": user_content}),
+                 cast(MessageData, {"role": "assistant", "content": assistant_content}),
+             ],
+             meta={
+                 "user_key": self.user_key,
+                 "assistant_key": self.assistant_key,
+                 "raw_item": raw_item,
+             },
+         )
+
+
+ class ShareGptFormatter:
+     def __init__(self, *, conversations_key: str) -> None:
+         self.conversations_key = conversations_key
+
+     def format(self, raw_item: dict[str, Any]) -> DatasetProcessOutputItem:
+         assert (
+             self.conversations_key in raw_item
+         ), f"Conversations key '{self.conversations_key}' not found in raw item"
+         conversations = raw_item[self.conversations_key]
+         assert isinstance(
+             conversations, list
+         ), f"Conversations must be a list, got {type(conversations).__name__}: {conversations}"
+
+         messages: list[MessageData] = []
+         for conv in conversations:
+             # Silently skip entries that are not dicts carrying both role and content.
+             if isinstance(conv, dict):
+                 role = conv.get("role")
+                 content = conv.get("content")
+                 if role is not None and content is not None:
+                     messages.append(
+                         cast(MessageData, {"role": role, "content": content})
+                     )
+
+         return DatasetProcessOutputItem(
+             messages=messages,
+             meta={"conversations_key": self.conversations_key, "raw_item": raw_item},
+         )
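For illustration, formatting one raw Alpaca-style item (the key names mirror those configured in the README example):

```python
from distflow.data.data_formatter import AlpacaFormatter

formatter = AlpacaFormatter(user_key="question", assistant_key="response")
item = formatter.format({"question": "What is 2 + 2?", "response": "4"})

# item.messages holds a user turn and an assistant turn;
# item.meta records the keys used plus the original raw item.
print([m.role for m in item.messages])  # ['user', 'assistant']
```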
@@ -0,0 +1,64 @@
+ import builtins
+ import random
+ from typing import Any, Literal, cast
+
+ from distflow.data.data_formatter import FormatterProtocol
+ from distflow.data.types import DatasetProcessOutputItem
+ from distflow.utils import logger
+
+
+ def load_dataset(
+     dataset_name: str,
+     data_path: str,
+     load_type: Literal["datasets", "modelscope", "pandas"],
+     formatter: FormatterProtocol,
+     data_size: int = -1,
+     split: str = "train",
+     sep: str = "\t",
+     dtype: str = "str",
+     shuffle_seed: int = 42,
+     use_json: bool = False,
+ ) -> tuple[str, list[DatasetProcessOutputItem]]:
+     logger.info(f"Loading dataset: {dataset_name}, path: {data_path}, type: {load_type}")
+
+     logger.debug(f"Data size limit: {data_size if data_size > 0 else 'all'}")
+
+     match load_type:
+         case "datasets":
+             # Alias the import so it does not shadow this function's name.
+             from datasets import load_dataset as hf_load_dataset
+
+             logger.debug(f"Loading with datasets, split={split}, use_json={use_json}")
+             if use_json:
+                 dataset = hf_load_dataset("json", data_files=data_path, split=split)
+             else:
+                 dataset = hf_load_dataset(path=data_path, split=split)
+         case "modelscope":
+             from modelscope.msdatasets import MsDataset
+
+             logger.debug(f"Loading with modelscope, split={split}")
+             dataset = MsDataset.load(data_path, split=split)
+         case "pandas":
+             from datasets import Dataset
+             from pandas import read_csv
+
+             logger.debug("Loading with pandas")
+             dtype_actual = getattr(builtins, dtype)
+             df = read_csv(data_path, sep=sep, dtype=dtype_actual)
+             dataset = Dataset.from_pandas(df)
+         case _:
+             # Guard against an unbound `dataset` if an unknown type slips past typing.
+             raise ValueError(f"Unsupported load_type: {load_type}")
+
+     logger.info(f"Dataset loaded, total samples: {len(dataset)}")
+
+     random.seed(shuffle_seed)
+     logger.debug(f"Using random seed: {shuffle_seed}")
+     random_indices = list(range(len(dataset)))
+     if data_size > 0 and data_size < len(dataset):
+         logger.info(f"Randomly sampling {data_size} items")
+         random_indices = random.sample(random_indices, data_size)
+     else:
+         logger.info("Using all data")
+         random.shuffle(random_indices)
+     sampled_data = cast(list[dict[str, Any]], [dataset[i] for i in random_indices])
+     logger.debug("Sampling done, formatting data")
+     formatted_data = [formatter.format(data_item) for data_item in sampled_data]
+     return dataset_name, formatted_data
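A usage sketch wiring the loader to a formatter. The dataset path and keys mirror the README example; the module path `distflow.data.loader` is an assumption, and downloading the dataset requires network access:

```python
from distflow.data.data_formatter import AlpacaFormatter
from distflow.data.loader import load_dataset  # assumed module path

name, items = load_dataset(
    dataset_name="oda-math",
    data_path="OpenDataArena/ODA-Math-460k",
    load_type="datasets",
    formatter=AlpacaFormatter(user_key="question", assistant_key="response"),
    data_size=100,
    split="train",
    shuffle_seed=42,
)
print(name, len(items))  # "oda-math" and up to 100 formatted items
```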
@@ -0,0 +1,13 @@
+ from typing import Any
+
+ from pydantic import BaseModel
+
+
+ class MessageData(BaseModel):  # type: ignore[misc]
+     role: str
+     content: str | dict[str, Any]
+
+
+ class DatasetProcessOutputItem(BaseModel):  # type: ignore[misc]
+     messages: list[MessageData]
+     meta: dict[str, Any]
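Since both types are Pydantic models, nested dicts are validated and coerced on construction; a quick sketch:

```python
from distflow.data.types import DatasetProcessOutputItem

item = DatasetProcessOutputItem(
    messages=[{"role": "user", "content": "hello"}],  # coerced to MessageData
    meta={"source": "example"},
)
print(item.messages[0].role)  # "user"
```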
File without changes
@@ -0,0 +1,19 @@
+ from abc import ABC, abstractmethod
+
+ from distflow.embed.types import EmbeddingInputItem, EmbeddingResult
+
+
+ class BaseEmbed(ABC):
+     def __init__(self, model_name: str) -> None:
+         self.model_name = model_name
+
+     @abstractmethod
+     def embed(self, dataset: list[EmbeddingInputItem]) -> list[EmbeddingResult]:
+         """Compute embeddings for a batch of items.
+
+         Args:
+             dataset: Data items to embed
+
+         Returns:
+             List of embedding results
+         """
@@ -0,0 +1,154 @@
+ import asyncio
+ import hashlib
+ import json
+ from collections.abc import Coroutine
+ from typing import Any
+
+ from distflow.cache.protocol import CacheProtocol
+ from distflow.embed.base import BaseEmbed
+ from distflow.embed.types import EmbeddingInputItem, EmbeddingResult
+ from distflow.utils import logger
+
+
+ def dict_to_hash(d: dict[Any, Any]) -> str:
+     """Produce a SHA256 digest of a dictionary."""
+     s = json.dumps(d, sort_keys=True).encode()
+     return hashlib.sha256(s).hexdigest()
+
+
+ class CachedEmbed(BaseEmbed):
+     """Embedding wrapper with a cache backend such as Redis.
+
+     Works with any CacheProtocol implementation (e.g. RedisCache),
+     providing a distributed cache of embedding results.
+     """
+
+     def __init__(
+         self,
+         embedder: BaseEmbed,
+         cache: CacheProtocol,
+         cache_model_id: str | None = None,
+         legacy_key: bool = False,
+     ) -> None:
+         """Initialize the cached embedder.
+
+         Args:
+             embedder: Underlying embedder used to compute uncached items
+             cache: A cache implementation conforming to CacheProtocol
+             cache_model_id: Model identifier used in cache keys; defaults to the
+                 model path. Lets an old cache keep working after a model is moved.
+             legacy_key: Whether to use the legacy cache key format (which
+                 includes the full data_item); defaults to False (new format:
+                 model_id + messages only)
+         """
+         self.embedder = embedder
+         self._cache = cache
+         self.model_path = (
+             getattr(embedder, "model_name", None)
+             or getattr(embedder, "model_path", None)
+             or "unknown"
+         )
+         # Model identifier used when building cache keys.
+         self.cache_model_id = cache_model_id if cache_model_id else self.model_path
+         self.legacy_key = legacy_key
+
+         super().__init__(self.model_path)
+
+     def _build_cache_key(self, item: EmbeddingInputItem) -> str:
+         """Build the cache key for an input item.
+
+         Args:
+             item: Input data item
+
+         Returns:
+             SHA256 hash key
+         """
+         if self.legacy_key:
+             # Legacy format: includes the full data_item (messages and meta).
+             key_payload = {
+                 "model_path": self.model_path,
+                 "data_item": item.model_dump(),
+             }
+         else:
+             # New format: only cache_model_id and messages (meta excluded).
+             key_payload = {
+                 "model_id": self.cache_model_id,
+                 "messages": [msg.model_dump() for msg in item.messages],
+             }
+         return dict_to_hash(key_payload)
+
+     def embed(self, dataset: list[EmbeddingInputItem]) -> list[EmbeddingResult]:
+         """Compute embeddings, serving hits from the cache.
+
+         Args:
+             dataset: Data items to embed
+
+         Returns:
+             List of embedding results
+         """
+         logger.info(f"Starting cached embedding computation, items: {len(dataset)}")
+
+         # Query all cache entries concurrently.
+         cache_keys = [self._build_cache_key(item) for item in dataset]
+         cache_tasks = [self._cache.load_cache(key) for key in cache_keys]
+
+         async def _run_all_get_cache() -> list[dict[str, Any] | None | BaseException]:
+             return await asyncio.gather(*cache_tasks, return_exceptions=True)
+
+         cached_values = asyncio.run(_run_all_get_cache())
+
+         # Split items into cache hits and misses.
+         results: list[EmbeddingResult | None] = [None] * len(dataset)
+         missing_items: list[EmbeddingInputItem] = []
+         missing_indices: list[int] = []
+         missing_keys: list[str] = []
+
+         for idx, (item, key, cached_result) in enumerate(
+             zip(dataset, cache_keys, cached_values)
+         ):
+             # Treat lookup exceptions as misses and recompute.
+             if isinstance(cached_result, BaseException):
+                 logger.debug(f"Cache lookup failed, recomputing: {cached_result}")
+                 missing_items.append(item)
+                 missing_indices.append(idx)
+                 missing_keys.append(key)
+             elif cached_result is None:
+                 missing_items.append(item)
+                 missing_indices.append(idx)
+                 missing_keys.append(key)
+             else:
+                 results[idx] = EmbeddingResult(
+                     embedding=cached_result["embedding"],
+                     data_item=item,
+                     meta=cached_result.get("meta", item.meta),
+                 )
+                 logger.debug(f"Cache hit: {key[:16]}...")
+
+         logger.info(f"Cache hits: {len(dataset) - len(missing_items)}/{len(dataset)}")
+
+         # Compute embeddings for the misses.
+         if missing_items:
+             new_results = self.embedder.embed(missing_items)
+
+             # Write the new results back to the cache concurrently.
+             write_tasks: list[Coroutine[Any, Any, bool]] = []
+             for key, idx, result in zip(missing_keys, missing_indices, new_results):
+                 cache_value = {
+                     "embedding": result.embedding,
+                     "meta": result.meta,
+                 }
+                 write_tasks.append(self._cache.save_cache(key, cache_value))
+                 results[idx] = EmbeddingResult(
+                     embedding=result.embedding,
+                     data_item=dataset[idx],
+                     meta=result.meta,
+                 )
+
+             # Wait for all writes to finish; tolerate individual failures so a
+             # failed write never discards a computed embedding.
+             async def _run_all_save_cache() -> list[bool | BaseException]:
+                 return await asyncio.gather(*write_tasks, return_exceptions=True)
+
+             write_results = asyncio.run(_run_all_save_cache())
+             success_count = sum(1 for r in write_results if r is True)
+             logger.info(f"Cache writes finished: {success_count}/{len(write_tasks)} succeeded")
+
+         logger.info(f"Embedding computation finished, {len(results)} results")
+         return [result for result in results if result is not None]
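Tying the layers together, a wiring sketch. The module paths and the `EmbeddingInputItem` constructor are assumptions inferred from the imports above, and `SentenceTransformerEmbed` refers to the illustrative subclass sketched after `BaseEmbed`:

```python
from distflow.cache.redis_cache import RedisCache  # assumed module path
from distflow.embed.cached import CachedEmbed  # assumed module path
from distflow.embed.types import EmbeddingInputItem

cache = RedisCache(redis_url="redis://127.0.0.1:6379")
embedder = CachedEmbed(
    embedder=SentenceTransformerEmbed(),  # illustrative BaseEmbed subclass
    cache=cache,
    cache_model_id="all-MiniLM-L6-v2",  # keys stay stable if the model path changes
)

# EmbeddingInputItem fields inferred from CachedEmbed's usage (messages, meta).
dataset = [EmbeddingInputItem(messages=[{"role": "user", "content": "hello"}], meta={})]

# embed() drives its cache I/O with asyncio.run, so call it outside any event loop.
results = embedder.embed(dataset)
print(results[0].embedding[:4])
```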