distflow 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- distflow/__init__.py +0 -0
- distflow/cache/__init__.py +0 -0
- distflow/cache/protocol.py +7 -0
- distflow/cache/redis_cache.py +122 -0
- distflow/data/__init__.py +0 -0
- distflow/data/data_formatter.py +73 -0
- distflow/data/data_loader.py +64 -0
- distflow/data/types.py +13 -0
- distflow/embed/__init__.py +0 -0
- distflow/embed/base.py +19 -0
- distflow/embed/cache_wrapper.py +154 -0
- distflow/embed/sentence_transformers.py +152 -0
- distflow/embed/types.py +13 -0
- distflow/embed/vllm.py +133 -0
- distflow/mmd.py +175 -0
- distflow/utils/__init__.py +0 -0
- distflow/utils/logger.py +126 -0
- distflow/utils/stats.py +111 -0
- distflow/utils/timing.py +106 -0
- distflow-1.0.0.dist-info/METADATA +102 -0
- distflow-1.0.0.dist-info/RECORD +22 -0
- distflow-1.0.0.dist-info/WHEEL +4 -0
distflow/__init__.py
ADDED
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from redis.asyncio import Redis
|
|
6
|
+
|
|
7
|
+
from distflow.utils import logger
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class RedisCache:
    """Redis-backed cache implementation.

    Talks to a Redis server through ``redis.asyncio``, so the cache can be
    shared across processes/hosts. A semaphore bounds the number of
    in-flight Redis requests.
    """

    def __init__(
        self,
        redis_url: str = "redis://127.0.0.1:6379",
        max_concurrent_requests: int = 50,
        redis_db: int = 0,
    ) -> None:
        """Initialize the Redis cache.

        Args:
            redis_url: Redis connection URL, e.g. "redis://127.0.0.1:6379".
            max_concurrent_requests: Maximum number of concurrent requests.
            redis_db: Redis database number, defaults to 0.
        """
        self._semaphore = asyncio.Semaphore(max_concurrent_requests)

        # The client is created lazily on first use, so constructing the
        # cache never performs network I/O.
        self._redis: Redis | None = None
        self._redis_url = redis_url
        self._redis_db = redis_db

    async def _get_redis(self) -> Redis:
        """Return the Redis client, creating and ping-testing it on first use.

        BUG FIX: this method was previously synchronous and called
        ``self._redis.ping()`` without awaiting it — on ``redis.asyncio`` that
        only creates an un-awaited coroutine, so the connectivity check never
        actually ran. The method is now async and awaits the ping.

        Raises:
            ConnectionError: If the server cannot be reached.
        """
        if self._redis is None:
            client = Redis.from_url(
                self._redis_url,
                db=self._redis_db,
                decode_responses=True,
            )
            try:
                # Verify connectivity before handing the client out.
                await client.ping()
                logger.info(
                    f"成功连接到 Redis: {self._redis_url}, DB: {self._redis_db}"
                )
            except Exception as e:
                logger.error(
                    f"无法连接到 Redis: {self._redis_url}, DB: {self._redis_db}, 错误: {e}"
                )
                raise ConnectionError(
                    f"无法连接到 Redis: {self._redis_url}, DB: {self._redis_db}"
                ) from e
            self._redis = client
        return self._redis

    async def load_cache(self, cache_key: str) -> dict[str, Any] | None:
        """Fetch a single cached value from Redis (concurrency-limited).

        Retries up to 3 times with linear backoff on any error; after each
        failure the client is reset so the next attempt reconnects.

        Args:
            cache_key: Cache key.

        Returns:
            The cached value as a dict, or None if absent or all retries fail.
        """
        for attempt in range(3):
            async with self._semaphore:
                try:
                    redis = await self._get_redis()
                    cached_data = await redis.get(cache_key)
                    if cached_data:
                        return json.loads(cached_data)
                    return None
                except Exception as e:
                    logger.warning(
                        f"Redis 缓存查询失败 {attempt + 1} / 3: {type(e).__name__}: {e}"
                    )
                    # Linear backoff, then force a reconnect on the next attempt.
                    await asyncio.sleep(0.1 * (attempt + 1))
                    self._redis = None
        return None

    async def save_cache(self, cache_key: str, cache_value: dict[str, Any]) -> bool:
        """Store a single cache value in Redis (concurrency-limited).

        Retries up to 3 times with linear backoff on any error; after each
        failure the client is reset so the next attempt reconnects.

        Args:
            cache_key: Cache key.
            cache_value: Value to store (must be JSON-serializable).

        Returns:
            True on success, False if all retries failed.
        """
        for attempt in range(3):
            async with self._semaphore:
                try:
                    redis = await self._get_redis()
                    serialized = json.dumps(cache_value)
                    await redis.set(cache_key, serialized)
                    return True
                except Exception as e:
                    logger.warning(
                        f"Redis 缓存写入失败 {attempt + 1} / 3: {type(e).__name__}: {e}"
                    )
                    # Linear backoff, then force a reconnect on the next attempt.
                    await asyncio.sleep(0.1 * (attempt + 1))
                    self._redis = None
        return False

    async def close(self) -> None:
        """Close the Redis connection."""
        if self._redis:
            await self._redis.close()
            # Allow reuse after close: the next operation reconnects.
            self._redis = None
            logger.info("Redis 连接已关闭")

    async def __aenter__(self) -> "RedisCache":
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Async context manager exit: closes the connection."""
        await self.close()
|
|
File without changes
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, Protocol, cast, runtime_checkable
|
|
4
|
+
|
|
5
|
+
from distflow.data.types import DatasetProcessOutputItem, MessageData
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@runtime_checkable
class FormatterProtocol(Protocol):
    """Structural interface for dataset formatters.

    Any object exposing a matching ``format`` method satisfies this protocol;
    ``@runtime_checkable`` allows ``isinstance`` checks (method presence only).
    """

    # Convert one raw dataset record into the normalized output item.
    def format(self, raw_item: dict[str, Any]) -> DatasetProcessOutputItem: ...
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class AlpacaFormatter:
    """Formats Alpaca-style records (one user field, one assistant field)
    into a two-turn chat conversation."""

    def __init__(self, *, user_key: str, assistant_key: str) -> None:
        """Store the record keys holding the user and assistant texts."""
        self.user_key = user_key
        self.assistant_key = assistant_key

    def format(self, raw_item: dict[str, Any]) -> DatasetProcessOutputItem:
        """Convert one raw record into a DatasetProcessOutputItem.

        Args:
            raw_item: Raw dataset record.

        Returns:
            Item with a user message and an assistant message, plus metadata
            recording the source keys and the original record.

        Raises:
            KeyError: If either configured key is missing from the record.
            TypeError: If either field is not a string.
        """
        # Validation previously used `assert`, which is stripped under
        # `python -O`; raise explicit exceptions so malformed records always
        # fail loudly.
        if self.user_key not in raw_item:
            raise KeyError(f"User key '{self.user_key}' not found in raw item")
        if self.assistant_key not in raw_item:
            raise KeyError(
                f"Assistant key '{self.assistant_key}' not found in raw item"
            )
        user_content = raw_item[self.user_key]
        if not isinstance(user_content, str):
            raise TypeError(
                f"User content must be a string, got {type(user_content).__name__}: {user_content}"
            )
        assistant_content = raw_item[self.assistant_key]
        if not isinstance(assistant_content, str):
            raise TypeError(
                f"Assistant content must be a string, got {type(assistant_content).__name__}: {assistant_content}"
            )

        return DatasetProcessOutputItem(
            messages=[
                cast(MessageData, {"role": "user", "content": user_content}),
                cast(MessageData, {"role": "assistant", "content": assistant_content}),
            ],
            meta={
                "user_key": self.user_key,
                "assistant_key": self.assistant_key,
                "raw_item": raw_item,
            },
        )
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class ShareGptFormatter:
    """Formats ShareGPT-style records (a list of conversation turns)
    into a chat conversation."""

    def __init__(self, *, conversations_key: str) -> None:
        """Store the record key holding the conversation list."""
        self.conversations_key = conversations_key

    def format(self, raw_item: dict[str, Any]) -> DatasetProcessOutputItem:
        """Convert one raw record into a DatasetProcessOutputItem.

        Args:
            raw_item: Raw dataset record.

        Returns:
            Item with one message per well-formed conversation turn, plus
            metadata recording the source key and the original record.

        Raises:
            KeyError: If the configured conversations key is missing.
            TypeError: If the conversations value is not a list.
        """
        # Validation previously used `assert`, which is stripped under
        # `python -O`; raise explicit exceptions so malformed records always
        # fail loudly.
        if self.conversations_key not in raw_item:
            raise KeyError(
                f"Conversations key '{self.conversations_key}' not found in raw item"
            )
        conversations = raw_item[self.conversations_key]
        if not isinstance(conversations, list):
            raise TypeError(
                f"Conversations must be a list, got {type(conversations).__name__}: {conversations}"
            )

        messages: list[MessageData] = []
        for conv in conversations:
            # Best-effort: turns that are not dicts, or that lack role/content,
            # are silently skipped (preserves the original lenient behavior).
            if isinstance(conv, dict):
                role = conv.get("role")
                content = conv.get("content")
                if role is not None and content is not None:
                    messages.append(
                        cast(MessageData, {"role": role, "content": content})
                    )

        return DatasetProcessOutputItem(
            messages=messages,
            meta={"conversations_key": self.conversations_key, "raw_item": raw_item},
        )
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
import builtins
|
|
2
|
+
import random
|
|
3
|
+
from typing import Any, Literal, cast
|
|
4
|
+
|
|
5
|
+
from distflow.data.data_formatter import FormatterProtocol
|
|
6
|
+
from distflow.data.types import DatasetProcessOutputItem
|
|
7
|
+
from distflow.utils import logger
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def load_dataset(
|
|
11
|
+
dataset_name: str,
|
|
12
|
+
data_path: str,
|
|
13
|
+
load_type: Literal["datasets", "modelscope", "pandas"],
|
|
14
|
+
formatter: FormatterProtocol,
|
|
15
|
+
data_size: int = -1,
|
|
16
|
+
split: str = "train",
|
|
17
|
+
sep: str = "\t",
|
|
18
|
+
dtype: str = "str",
|
|
19
|
+
shuffle_seed: int = 42,
|
|
20
|
+
use_json: bool = False,
|
|
21
|
+
) -> tuple[str, list[DatasetProcessOutputItem]]:
|
|
22
|
+
logger.info(f"开始加载数据集: {dataset_name}, 路径: {data_path}, 类型: {load_type}")
|
|
23
|
+
|
|
24
|
+
# 数据大小
|
|
25
|
+
logger.debug(f"数据大小限制: {data_size if data_size > 0 else '全部'}")
|
|
26
|
+
|
|
27
|
+
match load_type:
|
|
28
|
+
case "datasets":
|
|
29
|
+
from datasets import load_dataset
|
|
30
|
+
|
|
31
|
+
logger.debug(f"使用 datasets 加载, split={split}, use_json={use_json}")
|
|
32
|
+
if use_json:
|
|
33
|
+
dataset = load_dataset("json", data_files=data_path, split=split)
|
|
34
|
+
else:
|
|
35
|
+
dataset = load_dataset(path=data_path, split=split)
|
|
36
|
+
case "modelscope":
|
|
37
|
+
from modelscope.msdatasets import MsDataset
|
|
38
|
+
|
|
39
|
+
logger.debug(f"使用 modelscope 加载, split={split}")
|
|
40
|
+
dataset = MsDataset.load(data_path, split=split)
|
|
41
|
+
case "pandas":
|
|
42
|
+
from datasets import Dataset, load_dataset
|
|
43
|
+
from pandas import read_csv
|
|
44
|
+
|
|
45
|
+
logger.debug("使用 pandas 加载")
|
|
46
|
+
dtype_actual = getattr(builtins, dtype)
|
|
47
|
+
df = read_csv(data_path, sep=sep, dtype=dtype_actual)
|
|
48
|
+
dataset = Dataset.from_pandas(df)
|
|
49
|
+
|
|
50
|
+
logger.info(f"数据集加载完成,总样本数: {len(dataset)}")
|
|
51
|
+
|
|
52
|
+
random.seed(shuffle_seed)
|
|
53
|
+
logger.debug(f"使用随机种子: {shuffle_seed}")
|
|
54
|
+
random_indices = list(range(len(dataset)))
|
|
55
|
+
if data_size > 0 and data_size < len(dataset):
|
|
56
|
+
logger.info(f"随机采样 {data_size} 条数据")
|
|
57
|
+
random_indices = random.sample(random_indices, data_size)
|
|
58
|
+
else:
|
|
59
|
+
logger.info("使用全部数据")
|
|
60
|
+
random.shuffle(random_indices)
|
|
61
|
+
sampled_data = cast(list[dict[str, Any]], [dataset[i] for i in random_indices])
|
|
62
|
+
logger.debug(f"采样完成,开始格式化数据")
|
|
63
|
+
formatted_data = [formatter.format(data_item) for data_item in sampled_data]
|
|
64
|
+
return dataset_name, formatted_data
|
distflow/data/types.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class MessageData(BaseModel): # type: ignore[misc]
    """A single chat message."""

    # Speaker role, e.g. "user" or "assistant" (free-form, not validated here).
    role: str
    # Message payload: plain text, or a structured dict for non-text content.
    content: str | dict[str, Any]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class DatasetProcessOutputItem(BaseModel): # type: ignore[misc]
    """Normalized output of dataset formatting: a conversation plus metadata."""

    # Ordered conversation turns.
    messages: list[MessageData]
    # Provenance metadata (e.g. source keys and the original raw record).
    meta: dict[str, Any]
|
|
File without changes
|
distflow/embed/base.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
from distflow.embed.types import EmbeddingInputItem, EmbeddingResult
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class BaseEmbed(ABC):
    """Abstract base class for embedding backends."""

    def __init__(self, model_name: str) -> None:
        # Name or path of the underlying embedding model.
        self.model_name = model_name

    @abstractmethod
    def embed(self, dataset: list[EmbeddingInputItem]) -> list[EmbeddingResult]:
        """Compute embeddings for the given items.

        Note: despite the original docstring saying "异步" (async), this is a
        synchronous method — it blocks until all embeddings are computed.

        Args:
            dataset: Input items to embed.

        Returns:
            One embedding result per input item.
        """
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import hashlib
|
|
3
|
+
import json
|
|
4
|
+
from collections.abc import Coroutine
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from distflow.cache.protocol import CacheProtocol
|
|
8
|
+
from distflow.embed.base import BaseEmbed
|
|
9
|
+
from distflow.embed.types import EmbeddingInputItem, EmbeddingResult
|
|
10
|
+
from distflow.utils import logger
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def dict_to_hash(d: dict[Any, Any]) -> str:
    """Return a stable SHA-256 hex digest for *d*.

    Keys are sorted before serialization, so dicts with identical contents
    hash identically regardless of insertion order.
    """
    canonical = json.dumps(d, sort_keys=True)
    return hashlib.sha256(canonical.encode()).hexdigest()
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class CachedEmbed(BaseEmbed):
    """Embedding wrapper that serves results from a shared cache.

    Looks every item up in a ``CacheProtocol`` backend (e.g. Redis) and only
    forwards cache misses to the wrapped embedder, then writes the newly
    computed embeddings back.
    """

    def __init__(
        self,
        embedder: BaseEmbed,
        cache: CacheProtocol,
        cache_model_id: str | None = None,
        legacy_key: bool = False,
    ) -> None:
        """Initialize the cached embedder.

        Args:
            embedder: Underlying embedder used to compute cache misses.
            cache: Cache implementation conforming to CacheProtocol.
            cache_model_id: Identifier used in cache keys; defaults to the
                model path. Lets an old cache be reused after model files move.
            legacy_key: Use the legacy key format (full data_item, including
                meta) instead of the new format (model_id + messages only).
        """
        self.embedder = embedder
        self._cache = cache
        self.model_path = (
            getattr(embedder, "model_name", None)
            or getattr(embedder, "model_path", None)
            or "unknown"
        )
        # Key-building identifier, decoupled from the physical model path.
        self.cache_model_id = cache_model_id if cache_model_id else self.model_path
        self.legacy_key = legacy_key

        super().__init__(self.model_path)

    def _build_cache_key(self, item: EmbeddingInputItem) -> str:
        """Build the SHA-256 cache key for *item*.

        Args:
            item: Input item.

        Returns:
            SHA-256 hex digest string.
        """
        if self.legacy_key:
            # Legacy format: model path plus the full data_item (messages AND meta).
            key_payload = {
                "model_path": self.model_path,
                "data_item": item.model_dump(),
            }
        else:
            # New format: cache_model_id and messages only — meta is excluded,
            # so identical conversations share one cache entry.
            key_payload = {
                "model_id": self.cache_model_id,
                "messages": [msg.model_dump() for msg in item.messages],
            }
        return dict_to_hash(key_payload)

    def embed(self, dataset: list[EmbeddingInputItem]) -> list[EmbeddingResult]:
        """Embed *dataset*, serving previously seen items from the cache.

        NOTE(review): drives the async cache via ``asyncio.run``, so this
        method must not be called from within a running event loop.

        Args:
            dataset: Input items to embed.

        Returns:
            One result per input item (cache hits plus fresh computations).
        """
        logger.info(f"开始缓存嵌入计算,数据量: {len(dataset)}")

        # Look up all keys concurrently.
        cache_keys = [self._build_cache_key(item) for item in dataset]
        cache_tasks = [self._cache.load_cache(key) for key in cache_keys]

        async def _run_all_get_cache() -> list[dict[str, Any] | None | BaseException]:
            return await asyncio.gather(*cache_tasks, return_exceptions=True)

        cached_values = asyncio.run(_run_all_get_cache())

        # Split into cache hits and misses.
        results: list[EmbeddingResult | None] = [None] * len(dataset)
        missing_items: list[EmbeddingInputItem] = []
        missing_indices: list[int] = []
        missing_keys: list[str] = []

        for idx, (item, key, cached_result) in enumerate(
            zip(dataset, cache_keys, cached_values)
        ):
            if isinstance(cached_result, BaseException):
                # Lookup failure — treat as a miss and recompute.
                logger.debug(f"缓存查询异常,将重新计算: {cached_result}")
                missing_items.append(item)
                missing_indices.append(idx)
                missing_keys.append(key)
            elif cached_result is None:
                missing_items.append(item)
                missing_indices.append(idx)
                missing_keys.append(key)
            else:
                results[idx] = EmbeddingResult(
                    embedding=cached_result["embedding"],
                    data_item=item,
                    meta=cached_result.get("meta", item.meta),
                )
                logger.debug(f"缓存命中: {key[:16]}...")

        logger.info(f"缓存命中: {len(dataset) - len(missing_items)}/{len(dataset)}")

        # Compute the embeddings that were not cached.
        if missing_items:
            new_results = self.embedder.embed(missing_items)

            # Write the fresh embeddings back concurrently.
            write_tasks: list[Coroutine[Any, Any, bool]] = []
            # strict=True: the old code silently truncated on a length
            # mismatch, dropping results; fail loudly if the embedder
            # returns a different number of results than inputs.
            for key, idx, result in zip(
                missing_keys, missing_indices, new_results, strict=True
            ):
                cache_value = {
                    "embedding": result.embedding,
                    "meta": result.meta,
                }
                write_tasks.append(self._cache.save_cache(key, cache_value))
                results[idx] = EmbeddingResult(
                    embedding=result.embedding,
                    data_item=dataset[idx],
                    meta=result.meta,
                )

            async def _run_all_save_cache() -> list[bool | BaseException]:
                # BUG FIX: was return_exceptions=False, which contradicted the
                # declared return type and let one failed write raise out of
                # the whole batch (losing all results). With True, failures
                # are collected and only lower the success count below.
                return await asyncio.gather(*write_tasks, return_exceptions=True)

            write_results = asyncio.run(_run_all_save_cache())
            success_count = sum(1 for r in write_results if r is True)
            logger.info(f"缓存写入完成: {success_count}/{len(write_tasks)} 成功")

        logger.info(f"嵌入计算完成,共 {len(results)} 条结果")
        return [result for result in results if result is not None]
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, override
|
|
4
|
+
|
|
5
|
+
from distflow.data.types import MessageData
|
|
6
|
+
from distflow.embed.base import BaseEmbed
|
|
7
|
+
from distflow.embed.types import EmbeddingInputItem, EmbeddingResult
|
|
8
|
+
from distflow.utils import logger
|
|
9
|
+
from distflow.utils.timing import timing_context
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from sentence_transformers import SentenceTransformer
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SentenceTransformersEmbed(BaseEmbed):
    """Embedder backed by the sentence-transformers library.

    The model is loaded lazily on first use; encoding is batched, optionally
    L2-normalized, and may prepend a fixed prompt to every input text.
    """

    def __init__(
        self,
        model_name: str,
        device: str = "cuda",
        batch_size: int = 32,
        normalize_embeddings: bool = True,
        trust_remote_code: bool = False,
        prompt: str | None = None,
    ) -> None:
        """Create the embedder without loading the model yet.

        Args:
            model_name: Model name or path.
            device: Device to run on, defaults to "cuda".
            batch_size: Encoding batch size, defaults to 32.
            normalize_embeddings: Whether to normalize embedding vectors.
            trust_remote_code: Whether to trust remote model code.
            prompt: Optional prefix prepended to every input text.
        """
        logger.info(f"创建 SentenceTransformersEmbed,模型: {model_name}")
        # Only store the configuration here; the expensive model load is
        # deferred until an embedding is actually requested.
        self._model_name = model_name
        self._device = device
        self._batch_size = batch_size
        self._normalize_embeddings = normalize_embeddings
        self._trust_remote_code = trust_remote_code
        self._prompt = prompt
        self._model: SentenceTransformer | None = None
        super().__init__(model_name)

    def _ensure_initialized(self) -> None:
        """Load the model on first call; later calls are no-ops."""
        if self._model is None:
            logger.info(f"开始加载 Sentence Transformers 模型: {self._model_name}")
            logger.debug(
                f"配置参数: device={self._device}, batch_size={self._batch_size}, "
                f"normalize_embeddings={self._normalize_embeddings}"
            )

            with timing_context("模型加载"):
                from sentence_transformers import SentenceTransformer

                self._model = SentenceTransformer(
                    self._model_name,
                    device=self._device,
                    trust_remote_code=self._trust_remote_code,
                )
            logger.info(f"Sentence Transformers 模型加载完成: {self._model_name}")

    @property
    def model(self) -> SentenceTransformer:
        """The underlying model instance, initialized on first access."""
        self._ensure_initialized()
        assert self._model, "模型初始化后仍为 None"
        return self._model

    def _format_messages(self, messages: list[MessageData]) -> str:
        """Render a message list as one "role: content" line per message.

        Args:
            messages: Messages, each carrying a role and a content payload.

        Returns:
            The joined multi-line text.
        """
        rendered: list[str] = []
        for message in messages:
            body = message.content
            if isinstance(body, dict):
                # Dict content: prefer the "text" field, else stringify the dict.
                body = body.get("text", str(body))
            rendered.append(f"{message.role}: {body}")
        return "\n".join(rendered)

    def _prepare_texts(self, dataset: list[EmbeddingInputItem]) -> list[str]:
        """Format every item and apply the optional prompt prefix.

        Args:
            dataset: Items whose messages will be rendered to text.

        Returns:
            One formatted text per input item.
        """
        prefix = self._prompt if self._prompt else ""
        return [prefix + self._format_messages(item.messages) for item in dataset]

    @override
    def embed(self, dataset: list[EmbeddingInputItem]) -> list[EmbeddingResult]:
        """Compute embeddings for the given items.

        Args:
            dataset: Items to embed.

        Returns:
            One embedding result per input item.
        """
        logger.info(f"开始嵌入计算,数据量: {len(dataset)}")

        # Trigger the lazy model load before any text preparation.
        self._ensure_initialized()

        logger.debug("准备输入文本...")
        texts = self._prepare_texts(dataset)

        logger.info("开始模型推理...")
        with timing_context("模型推理"):
            embeddings = self.model.encode(
                texts,
                batch_size=self._batch_size,
                normalize_embeddings=self._normalize_embeddings,
                show_progress_bar=True,
            )
        logger.info(f"嵌入计算完成,输出 {len(embeddings)} 条结果")

        # Pair each embedding vector with its source item.
        results: list[EmbeddingResult] = []
        for vector, item in zip(embeddings, dataset):
            results.append(
                EmbeddingResult(
                    embedding=vector.tolist(),
                    data_item=item,
                    meta=item.meta,
                )
            )
        return results
|
distflow/embed/types.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from typing import Any, TypeAlias
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
from distflow.data.types import DatasetProcessOutputItem
|
|
6
|
+
|
|
7
|
+
# Embedder input is exactly the dataset-processing output item.
EmbeddingInputItem: TypeAlias = DatasetProcessOutputItem


class EmbeddingResult(BaseModel): # type: ignore[misc]
    """An embedding vector paired with the item it was computed from."""

    # Dense embedding vector.
    embedding: list[float]
    # The original input item this embedding corresponds to.
    data_item: DatasetProcessOutputItem
    # Metadata associated with the result (carried from the input or cache).
    meta: dict[str, Any]
|