trinity-rft 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- trinity/__init__.py +4 -0
- trinity/buffer/__init__.py +7 -0
- trinity/buffer/buffer.py +65 -0
- trinity/buffer/buffer_reader.py +13 -0
- trinity/buffer/buffer_writer.py +15 -0
- trinity/buffer/queue.py +54 -0
- trinity/buffer/reader/__init__.py +0 -0
- trinity/buffer/reader/file_reader.py +231 -0
- trinity/buffer/reader/queue_reader.py +34 -0
- trinity/buffer/reader/sql_reader.py +79 -0
- trinity/buffer/schema/__init__.py +3 -0
- trinity/buffer/schema/sql_schema.py +149 -0
- trinity/buffer/utils.py +33 -0
- trinity/buffer/writer/__init__.py +0 -0
- trinity/buffer/writer/queue_writer.py +30 -0
- trinity/buffer/writer/sql_writer.py +46 -0
- trinity/cli/client.py +44 -0
- trinity/cli/launcher.py +238 -0
- trinity/cli/server.py +32 -0
- trinity/common/__init__.py +0 -0
- trinity/common/config.py +578 -0
- trinity/common/constants.py +119 -0
- trinity/common/experience.py +278 -0
- trinity/common/models/__init__.py +139 -0
- trinity/common/models/model.py +130 -0
- trinity/common/models/openai_api.py +79 -0
- trinity/common/models/utils.py +265 -0
- trinity/common/models/vllm_async_model.py +353 -0
- trinity/common/models/vllm_model.py +287 -0
- trinity/common/models/vllm_worker.py +74 -0
- trinity/common/rewards/__init__.py +11 -0
- trinity/common/rewards/accuracy_reward.py +33 -0
- trinity/common/rewards/agents_reward.py +1 -0
- trinity/common/rewards/base.py +24 -0
- trinity/common/rewards/composite_reward.py +24 -0
- trinity/common/rewards/format_reward.py +29 -0
- trinity/common/rewards/human_reward.py +1 -0
- trinity/common/rewards/reward_fn.py +197 -0
- trinity/common/rewards/tool_reward.py +1 -0
- trinity/common/schema.py +148 -0
- trinity/common/verl_config.py +346 -0
- trinity/common/workflows/__init__.py +16 -0
- trinity/common/workflows/envs/alfworld/alfworld_workflow.py +179 -0
- trinity/common/workflows/envs/sciworld/sciworld_workflow.py +144 -0
- trinity/common/workflows/envs/webshop/webshop_workflow.py +273 -0
- trinity/common/workflows/workflow.py +245 -0
- trinity/data/controllers/active_iterator.py +290 -0
- trinity/data/controllers/default_ops.py +77 -0
- trinity/data/controllers/task_parser.py +303 -0
- trinity/data/core/comparator.py +84 -0
- trinity/data/core/dataset.py +136 -0
- trinity/data/core/dataset_db.py +84 -0
- trinity/data/core/formatter.py +151 -0
- trinity/data/processors/base.py +143 -0
- trinity/data/processors/cleaner.py +229 -0
- trinity/data/processors/human_annotator.py +47 -0
- trinity/data/processors/synthesizer.py +107 -0
- trinity/data/server.py +27 -0
- trinity/explorer/__init__.py +4 -0
- trinity/explorer/explorer.py +299 -0
- trinity/explorer/runner_pool.py +269 -0
- trinity/explorer/workflow_runner.py +109 -0
- trinity/manager/__init__.py +7 -0
- trinity/manager/config_manager.py +1786 -0
- trinity/manager/manager.py +69 -0
- trinity/trainer/__init__.py +3 -0
- trinity/trainer/trainer.py +175 -0
- trinity/trainer/verl/__init__.py +0 -0
- trinity/trainer/verl/core_algos.py +717 -0
- trinity/trainer/verl/dp_actor.py +538 -0
- trinity/trainer/verl/fsdp_workers.py +1522 -0
- trinity/trainer/verl/ray_trainer.py +1160 -0
- trinity/trainer/verl_trainer.py +552 -0
- trinity/utils/__init__.py +0 -0
- trinity/utils/distributed.py +82 -0
- trinity/utils/dlc_utils.py +86 -0
- trinity/utils/eval_utils.py +78 -0
- trinity/utils/log.py +65 -0
- trinity/utils/monitor.py +109 -0
- trinity/utils/registry.py +127 -0
- trinity_rft-0.1.0.dist-info/METADATA +396 -0
- trinity_rft-0.1.0.dist-info/RECORD +86 -0
- trinity_rft-0.1.0.dist-info/WHEEL +5 -0
- trinity_rft-0.1.0.dist-info/entry_points.txt +2 -0
- trinity_rft-0.1.0.dist-info/licenses/LICENSE +201 -0
- trinity_rft-0.1.0.dist-info/top_level.txt +1 -0
trinity/__init__.py
ADDED
trinity/buffer/buffer.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
"""The buffer module"""
|
|
3
|
+
import ray
|
|
4
|
+
|
|
5
|
+
from trinity.buffer.buffer_reader import BufferReader
|
|
6
|
+
from trinity.buffer.buffer_writer import BufferWriter
|
|
7
|
+
from trinity.common.config import BufferConfig, Config, StorageConfig
|
|
8
|
+
from trinity.common.constants import StorageType
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@ray.remote(name="buffer")
class Buffer:
    """Ray actor that keeps a name -> ``StorageConfig`` mapping of datasets."""

    def __init__(self, config: Config):
        """Create an empty mapping and populate it from ``config``."""
        self.buffer_mapping: dict[str, StorageConfig] = {}
        # NOTE(review): `_register_from_config` is not defined in the visible
        # class body -- confirm it is provided elsewhere before relying on it.
        self._register_from_config(config)

    def get_dataset_info(self, dataset_name: str) -> StorageConfig:
        """Return the storage config registered under ``dataset_name``.

        Raises:
            ValueError: if no dataset with that name has been registered.
        """
        found = self.buffer_mapping.get(dataset_name)
        if found is not None:
            return found
        raise ValueError(f"{dataset_name} not found.")

    def register_dataset(self, storage_config: StorageConfig) -> None:
        """Register ``storage_config`` under its own ``name``.

        Raises:
            ValueError: if a dataset with the same name is already registered.
        """
        key = storage_config.name
        if key in self.buffer_mapping:
            raise ValueError(f"{storage_config.name} already exists.")
        self.buffer_mapping[key] = storage_config
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def get_buffer_reader(storage_config: StorageConfig, buffer_config: BufferConfig) -> BufferReader:
    """Build the buffer reader that matches ``storage_config.storage_type``.

    Reader implementations are imported lazily inside each branch so that only
    the backend actually requested gets imported.

    Args:
        storage_config: Describes the storage to read from (type, path, ...).
        buffer_config: Buffer-wide settings forwarded to the reader.

    Returns:
        A reader instance for SQL, QUEUE or FILE storage.

    Raises:
        ValueError: if the storage type is not supported, or if no file reader
            is registered for the requested algorithm type.
    """
    storage_type = storage_config.storage_type
    if storage_type == StorageType.SQL:
        from trinity.buffer.reader.sql_reader import SQLReader

        return SQLReader(storage_config, buffer_config)
    if storage_type == StorageType.QUEUE:
        from trinity.buffer.reader.queue_reader import QueueReader

        return QueueReader(storage_config, buffer_config)
    if storage_type == StorageType.FILE:
        from trinity.buffer.reader.file_reader import FILE_READERS

        # File readers are keyed by algorithm type; "rollout" is the fallback
        # when no algorithm type is configured.
        algorithm_type = storage_config.algorithm_type
        file_read_type = "rollout" if algorithm_type is None else algorithm_type.value
        reader_cls = FILE_READERS.get(file_read_type)
        # assumes Registry.get returns None on a miss -- TODO confirm; without
        # this guard a miss would surface as "'NoneType' object is not callable".
        if reader_cls is None:
            raise ValueError(f"No file reader registered for `{file_read_type}`.")
        return reader_cls(storage_config, buffer_config)
    raise ValueError(f"{storage_config.storage_type} not supported.")
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def get_buffer_writer(storage_config: StorageConfig, buffer_config: BufferConfig) -> BufferWriter:
    """Build the buffer writer that matches ``storage_config.storage_type``.

    Only SQL and QUEUE storages can be written to; any other storage type
    raises ``ValueError``. Writer classes are imported lazily per branch.
    """
    kind = storage_config.storage_type
    if kind == StorageType.SQL:
        from trinity.buffer.writer.sql_writer import SQLWriter

        writer_cls = SQLWriter
    elif kind == StorageType.QUEUE:
        from trinity.buffer.writer.queue_writer import QueueWriter

        writer_cls = QueueWriter
    else:
        raise ValueError(f"{storage_config.storage_type} not supported.")
    return writer_cls(storage_config, buffer_config)
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""Reader of the buffer."""
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from typing import List, Optional
|
|
4
|
+
|
|
5
|
+
from trinity.common.constants import ReadStrategy
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BufferReader(ABC):
    """Abstract base class that every buffer reader implements."""

    @abstractmethod
    def read(self, strategy: Optional[ReadStrategy] = None) -> List:
        """Fetch data from the underlying buffer.

        Args:
            strategy: Optional read strategy; how (and whether) it is honoured
                is up to the concrete reader.

        Returns:
            The items that were read.
        """
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Writer of the buffer."""
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from typing import List
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class BufferWriter(ABC):
    """Abstract base class that every buffer writer implements."""

    @abstractmethod
    def write(self, data: List) -> None:
        """Store ``data`` in the underlying buffer.

        Args:
            data: The items to write.
        """

    @abstractmethod
    def finish(self) -> None:
        """Signal that no more data will be written."""
|
trinity/buffer/queue.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""A queue implemented by Ray Actor."""
|
|
2
|
+
import asyncio
|
|
3
|
+
from copy import deepcopy
|
|
4
|
+
from typing import List
|
|
5
|
+
|
|
6
|
+
import ray
|
|
7
|
+
|
|
8
|
+
from trinity.buffer.writer.sql_writer import SQLWriter
|
|
9
|
+
from trinity.common.config import BufferConfig, StorageConfig
|
|
10
|
+
from trinity.common.constants import StorageType
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@ray.remote
class QueueActor:
    """Ray actor wrapping a bounded ``asyncio.Queue`` of experience batches.

    When ``storage_config.path`` is non-empty, every batch put into the queue
    is also written through a ``SQLWriter``.
    """

    # Sentinel placed on the queue by `finish` to tell readers to stop.
    FINISH_MESSAGE = "$FINISH$"

    def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
        """Create the queue and, if a path is configured, the SQL mirror writer."""
        self.config = config
        # Falls back to 10000 when the buffer config has no `capacity` attribute.
        self.capacity = getattr(config, "capacity", 10000)
        self.queue = asyncio.Queue(self.capacity)
        if storage_config.path is not None and len(storage_config.path) > 0:
            # Reuse the queue's storage config, switched to SQL storage.
            sql_config = deepcopy(storage_config)
            sql_config.storage_type = StorageType.SQL
            self.sql_writer = SQLWriter(sql_config, self.config)
        else:
            self.sql_writer = None

    def length(self) -> int:
        """Return the number of entries currently in the queue."""
        return self.queue.qsize()

    async def put_batch(self, exp_list: List) -> None:
        """Enqueue one batch of experiences (waits while the queue is full)."""
        await self.queue.put(exp_list)
        if self.sql_writer is not None:
            self.sql_writer.write(exp_list)

    async def finish(self) -> None:
        """Enqueue the finish sentinel so that readers stop."""
        await self.queue.put(self.FINISH_MESSAGE)

    async def get_batch(self, batch_size: int) -> List:
        """Collect queued batches until at least ``batch_size`` items are gathered.

        Raises:
            StopAsyncIteration: when the finish sentinel is reached. Items
                gathered before the sentinel in this call are discarded.
        """
        batch = []
        while True:
            exp_list = await self.queue.get()
            if exp_list == self.FINISH_MESSAGE:
                # Put the sentinel back so later (or concurrent) readers also
                # observe the end of the stream instead of blocking forever.
                await self.queue.put(self.FINISH_MESSAGE)
                raise StopAsyncIteration()
            batch.extend(exp_list)
            if len(batch) >= batch_size:
                break
        return batch
|
|
File without changes
|
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
"""File-based buffer reader."""
|
|
2
|
+
|
|
3
|
+
from typing import List, Optional
|
|
4
|
+
|
|
5
|
+
import datasets
|
|
6
|
+
import transformers
|
|
7
|
+
from datasets import load_dataset
|
|
8
|
+
|
|
9
|
+
from trinity.buffer.buffer_reader import BufferReader
|
|
10
|
+
from trinity.common.config import BufferConfig, StorageConfig
|
|
11
|
+
from trinity.common.constants import AlgorithmType, PromptType, ReadStrategy, TaskType
|
|
12
|
+
from trinity.common.experience import Experience
|
|
13
|
+
from trinity.common.rewards import REWARD_FUNCTIONS
|
|
14
|
+
from trinity.common.workflows import WORKFLOWS, Task
|
|
15
|
+
from trinity.utils.registry import Registry
|
|
16
|
+
|
|
17
|
+
FILE_READERS = Registry("file_readers")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@FILE_READERS.register_module(AlgorithmType.SFT.value)
class SFTDataReader(BufferReader):
    """File reader that converts SFT samples into tokenized experiences."""

    def __init__(self, meta: StorageConfig, config: BufferConfig):
        """Load the dataset and tokenizer described by ``meta`` and ``config``."""
        data_format = meta.format
        self.split = meta.split
        self.prompt_type = data_format.prompt_type
        self.messages_key = data_format.messages_key
        self.prompt_key = data_format.prompt_key
        self.response_key = data_format.response_key
        self.read_batch_size = config.read_batch_size
        # TODO: support resume
        self.dataset = load_dataset(meta.path, name=meta.subset_name, split=self.split)
        self.data_iter = self.dataset.iter(self.read_batch_size, drop_last_batch=True)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)

    def _next_batch(self):
        """Return the next batch; reshuffle and restart once the data is exhausted."""
        try:
            return next(self.data_iter)
        except StopIteration:
            self.dataset = self.dataset.shuffle()
            self.data_iter = self.dataset.iter(self.read_batch_size, drop_last_batch=True)
            return next(self.data_iter)

    def _chat_tokens(self, messages, add_generation_prompt: bool):
        """Apply the chat template and return the first row of the token tensor."""
        encoded = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=add_generation_prompt, return_tensors="pt"
        )
        return encoded[0]

    def _chat_experience(self, full_messages, prompt_messages) -> Experience:
        """Build an experience from a full conversation and its prompt part."""
        tokens = self._chat_tokens(full_messages, False)
        prompt_tokens = self._chat_tokens(prompt_messages, True)
        return Experience(tokens=tokens, prompt_length=len(prompt_tokens))

    def _plaintext_experience(self, prompt, response) -> Experience:
        """Build an experience by tokenizing raw prompt/response text."""
        tokens = self.tokenizer(prompt + response, return_tensors="pt")["input_ids"][0]
        prompt_tokens = self.tokenizer(prompt, return_tensors="pt")["input_ids"][0]
        return Experience(tokens=tokens, prompt_length=len(prompt_tokens))

    def read(self, strategy: Optional[ReadStrategy] = None) -> List:
        """Read one batch and convert it to experiences per ``prompt_type``.

        ``strategy`` is accepted for interface compatibility and is not used.

        Raises:
            ValueError: if ``prompt_type`` is not MESSAGES, CHATPAIR or PLAINTEXT.
        """
        batch_data = self._next_batch()
        exp_list = []
        if self.prompt_type == PromptType.MESSAGES:
            # The last message is the response; everything before is the prompt.
            for messages in batch_data[self.messages_key]:
                exp_list.append(self._chat_experience(messages, messages[:-1]))
        elif self.prompt_type == PromptType.CHATPAIR:
            pairs = zip(batch_data[self.prompt_key], batch_data[self.response_key])
            for prompt_messages, response_messages in pairs:
                # Single messages are wrapped so both sides are lists.
                if not isinstance(prompt_messages, list):
                    prompt_messages = [prompt_messages]
                if not isinstance(response_messages, list):
                    response_messages = [response_messages]
                exp_list.append(
                    self._chat_experience(prompt_messages + response_messages, prompt_messages)
                )
        elif self.prompt_type == PromptType.PLAINTEXT:
            # TODO: support HF format without chat template
            pairs = zip(batch_data[self.prompt_key], batch_data[self.response_key])
            for prompt, response in pairs:
                exp_list.append(self._plaintext_experience(prompt, response))
        else:
            raise ValueError(f"Unknown data format: {self.prompt_type}")
        return exp_list
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
@FILE_READERS.register_module(AlgorithmType.DPO.value)
class DPODataReader(BufferReader):
    """File reader that converts preference pairs into DPO experiences."""

    def __init__(self, meta: StorageConfig, config: BufferConfig):
        """Load the dataset and tokenizer described by ``meta`` and ``config``."""
        self.split = meta.split
        subset_name = meta.subset_name
        self.prompt_type = meta.format.prompt_type
        self.prompt_key = meta.format.prompt_key
        self.chosen_key = meta.format.chosen_key
        self.rejected_key = meta.format.rejected_key
        self.read_batch_size = config.read_batch_size
        self.dataset = load_dataset(
            meta.path, name=subset_name, split=self.split
        )  # TODO: support resume
        self.data_iter = self.dataset.iter(self.read_batch_size, drop_last_batch=True)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)

    def _get_assistant_message(self, item) -> dict:
        """Normalize a chosen/rejected entry into a single assistant message.

        A list contributes its first element; a plain string is wrapped as an
        assistant message; anything else is returned unchanged.
        """
        if isinstance(item, list):
            item = item[0]
        if isinstance(item, str):
            return {"role": "assistant", "content": item}
        return item

    def _to_prompt_messages(self, prompt):
        """Return ``prompt`` as chat messages according to ``prompt_type``."""
        if self.prompt_type == PromptType.MESSAGES:
            return prompt
        if self.prompt_type == PromptType.PLAINTEXT:
            return [{"role": "user", "content": prompt}]
        raise ValueError(f"Unknown prompt type: {self.prompt_type}")

    def _response_tokens(self, prompt_messages, response, prompt_length: int):
        """Tokenize prompt + response and return the tokens after the prompt."""
        full_messages = prompt_messages + [self._get_assistant_message(response)]
        return self.tokenizer.apply_chat_template(
            full_messages,
            add_generation_prompt=False,
            return_tensors="pt",
        )[0][prompt_length:]

    def read(self, strategy: Optional[ReadStrategy] = None) -> List:
        """Read one batch and build one experience per (prompt, chosen, rejected).

        ``strategy`` is accepted for interface compatibility and is not used.
        When the iterator is exhausted the dataset is shuffled and restarted.

        Raises:
            ValueError: if ``prompt_type`` is neither MESSAGES nor PLAINTEXT.
        """
        try:
            batch_data = next(self.data_iter)
        except StopIteration:
            self.dataset = self.dataset.shuffle()
            self.data_iter = self.dataset.iter(self.read_batch_size, drop_last_batch=True)
            batch_data = next(self.data_iter)
        exp_list = []
        for prompt, chosen, rejected in zip(
            batch_data[self.prompt_key], batch_data[self.chosen_key], batch_data[self.rejected_key]
        ):
            prompt_messages = self._to_prompt_messages(prompt)
            prompt_tokens = self.tokenizer.apply_chat_template(
                prompt_messages, add_generation_prompt=True, return_tensors="pt"
            )[0]
            prompt_length = len(prompt_tokens)
            exp_list.append(
                Experience(
                    tokens=prompt_tokens,
                    prompt_length=prompt_length,
                    chosen=self._response_tokens(prompt_messages, chosen, prompt_length),
                    rejected=self._response_tokens(prompt_messages, rejected, prompt_length),
                )
            )
        return exp_list
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
@FILE_READERS.register_module("rollout")
class RolloutDataReader(BufferReader):
    """File reader that yields one rollout ``Task`` per call to ``read``."""

    def __init__(self, meta: StorageConfig, config: BufferConfig):
        """Load the dataset described by ``meta`` and resolve default classes."""
        self.meta = meta
        self.name = meta.name
        self.split = meta.split
        subset_name = meta.subset_name
        # disable datasets caching to avoid reuse old-version dataset
        datasets.disable_caching()
        try:
            self.dataset = load_dataset(meta.path, name=subset_name, split=self.split)
        finally:
            # Re-enable caching even when loading fails, so a failed load does
            # not leave caching switched off for the rest of the process.
            datasets.enable_caching()
        # TODO: optionally load the dataset from a database URL instead of a path
        self.index = meta.index  # TODO: apply shuffle

        self.prompt_key = meta.format.prompt_key
        self.response_key = meta.format.response_key
        self.workflow_key = meta.format.workflow_key
        self.reward_fn_key = meta.format.reward_fn_key

        self.task_type = meta.task_type
        self.default_workflow_cls = WORKFLOWS.get(meta.default_workflow_type)
        self.default_reward_fn_cls = REWARD_FUNCTIONS.get(meta.default_reward_fn_type)
        # Only EXPLORE tasks iterate over the data for several epochs.
        self.total_epochs = meta.total_epochs if self.task_type == TaskType.EXPLORE else 1

    def __len__(self):
        """Return the number of samples in the loaded dataset."""
        return len(self.dataset)

    def read(self, strategy: Optional[ReadStrategy] = None):
        """Return the next sample wrapped in a ``Task``.

        ``strategy`` is accepted for interface compatibility and is not used.
        A per-sample workflow / reward function (looked up via ``workflow_key``
        / ``reward_fn_key``) takes precedence over the configured defaults.

        Raises:
            StopIteration: once ``total_epochs`` passes over the data are done.
            AssertionError: if no workflow class could be resolved.
        """
        dataset_size = len(self.dataset)
        if self.index >= dataset_size * self.total_epochs:
            raise StopIteration
        sample = self.dataset[self.index % dataset_size]
        workflow_class = (
            WORKFLOWS.get(sample[self.workflow_key])
            if self.workflow_key in sample
            else self.default_workflow_cls
        )
        reward_fn = (
            REWARD_FUNCTIONS.get(sample[self.reward_fn_key])
            if self.reward_fn_key in sample
            else self.default_reward_fn_cls
        )
        assert workflow_class is not None, "`default_workflow_type` or `workflow_key` is required"
        task = Task(
            workflow=workflow_class,
            format_args=self.meta.format,
            rollout_args=self.meta.rollout_args,
            is_eval=self.meta.task_type == TaskType.EVAL,
            reward_fn=reward_fn,
            raw_task=sample,
        )
        self.index += 1
        # Evaluation data wraps around so it can be read again from the start.
        if self.task_type == TaskType.EVAL and self.index == dataset_size:
            self.index = 0
        return task
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""Reader of the Queue buffer."""
|
|
2
|
+
|
|
3
|
+
from typing import List, Optional
|
|
4
|
+
|
|
5
|
+
import ray
|
|
6
|
+
|
|
7
|
+
from trinity.buffer.buffer_reader import BufferReader
|
|
8
|
+
from trinity.buffer.queue import QueueActor
|
|
9
|
+
from trinity.common.config import BufferConfig, StorageConfig
|
|
10
|
+
from trinity.common.constants import ReadStrategy, StorageType
|
|
11
|
+
from trinity.utils.log import get_logger
|
|
12
|
+
|
|
13
|
+
logger = get_logger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class QueueReader(BufferReader):
    """Buffer reader that pulls experience batches from a ``QueueActor``."""

    def __init__(self, meta: StorageConfig, config: BufferConfig):
        """Attach to (or create) the named queue actor for ``meta``."""
        assert meta.storage_type == StorageType.QUEUE
        self.config = config
        # `get_if_exists` makes readers and writers share one actor per name.
        actor_options = {"name": f"queue-{meta.name}", "get_if_exists": True}
        self.queue = QueueActor.options(**actor_options).remote(meta, config)

    def read(self, strategy: Optional[ReadStrategy] = None) -> List:
        """Fetch ``read_batch_size`` experiences from the queue.

        Raises:
            NotImplementedError: for any strategy other than ``None`` or FIFO.
            StopIteration: when the queue actor signals the end of the stream.
        """
        fifo_requested = strategy is None or strategy == ReadStrategy.FIFO
        if not fifo_requested:
            raise NotImplementedError(f"Read strategy {strategy} not supported for Queue Reader.")
        try:
            return ray.get(self.queue.get_batch.remote(self.config.read_batch_size))
        except StopAsyncIteration:
            raise StopIteration()
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
"""Reader of the SQL buffer."""
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
from typing import List, Optional
|
|
5
|
+
|
|
6
|
+
from sqlalchemy import asc, create_engine, desc
|
|
7
|
+
from sqlalchemy.exc import OperationalError
|
|
8
|
+
from sqlalchemy.orm import sessionmaker
|
|
9
|
+
from sqlalchemy.pool import NullPool
|
|
10
|
+
|
|
11
|
+
from trinity.buffer.buffer_reader import BufferReader
|
|
12
|
+
from trinity.buffer.schema import Base, create_dynamic_table
|
|
13
|
+
from trinity.buffer.utils import retry_session
|
|
14
|
+
from trinity.common.config import BufferConfig, StorageConfig
|
|
15
|
+
from trinity.common.constants import ReadStrategy, StorageType
|
|
16
|
+
from trinity.utils.log import get_logger
|
|
17
|
+
|
|
18
|
+
logger = get_logger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class SQLReader(BufferReader):
    """Buffer reader that pulls experiences from a SQL table."""

    def __init__(self, meta: StorageConfig, config: BufferConfig) -> None:
        """Connect to ``meta.path`` and make sure the table for ``meta`` exists."""
        assert meta.storage_type == StorageType.SQL
        self.engine = create_engine(meta.path, poolclass=NullPool)

        self.table_model_cls = create_dynamic_table(meta.algorithm_type, meta.name)
        try:
            Base.metadata.create_all(self.engine, checkfirst=True)
        except OperationalError:
            logger.warning("Failed to create database, assuming it already exists.")
        self.session = sessionmaker(bind=self.engine)
        self.batch_size = config.read_batch_size
        self.max_retry_times = config.max_retry_times
        self.max_retry_interval = config.max_retry_interval

    def _sort_order(self, strategy: ReadStrategy) -> tuple:
        """Translate a read strategy into ORDER BY clauses."""
        model = self.table_model_cls
        if strategy == ReadStrategy.LFU:
            return (asc(model.consumed), desc(model.id))
        if strategy == ReadStrategy.LRU:
            return (desc(model.id),)
        if strategy == ReadStrategy.PRIORITY:
            return (desc(model.priority), desc(model.id))
        raise NotImplementedError(f"Unsupported strategy {strategy} by SQLStorage")

    def read(self, strategy: Optional[ReadStrategy] = None) -> List:
        """Read ``batch_size`` experiences, polling until enough are available.

        Rows without a reward are skipped, and each returned row has its
        ``consumed`` counter incremented.

        Args:
            strategy: Ordering strategy; defaults to LFU when ``None``.

        Raises:
            NotImplementedError: if the strategy is not LFU, LRU or PRIORITY.
        """
        if strategy is None:
            strategy = ReadStrategy.LFU
        sort_order = self._sort_order(strategy)

        exp_list = []
        first_attempt = True
        while len(exp_list) < self.batch_size:
            # Back off before every retry, including when nothing has been
            # fetched yet; otherwise an empty table makes this loop spin on
            # the database without any delay.
            if not first_attempt:
                logger.info("waiting for experiences...")
                time.sleep(1)
            first_attempt = False
            with retry_session(
                self.session, self.max_retry_times, self.max_retry_interval
            ) as session:
                # get a batch of experiences from the database
                experiences = (
                    session.query(self.table_model_cls)
                    .filter(self.table_model_cls.reward.isnot(None))
                    .order_by(*sort_order)  # TODO: very slow
                    .limit(self.batch_size - len(exp_list))
                    .with_for_update()
                    .all()
                )
                # update the consumed field
                for exp in experiences:
                    exp.consumed += 1
                exp_list.extend([self.table_model_cls.to_experience(exp) for exp in experiences])
        logger.info(f"get {len(exp_list)} experiences:")
        logger.info(f"reward = {[exp.reward for exp in exp_list]}")
        # The list is empty when batch_size is 0; skip the sample logging then.
        if exp_list:
            logger.info(f"first prompt_text = {exp_list[0].prompt_text}")
            logger.info(f"first response_text = {exp_list[0].response_text}")
        return exp_list
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
"""Schema for SQLAlchemy models."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Optional, Union
|
|
4
|
+
|
|
5
|
+
from sqlalchemy import Column, Float, Integer, LargeBinary, String
|
|
6
|
+
from sqlalchemy.ext.declarative import declarative_base
|
|
7
|
+
|
|
8
|
+
from trinity.common.constants import AlgorithmType
|
|
9
|
+
from trinity.common.experience import Experience
|
|
10
|
+
from trinity.common.models.utils import tokenize_and_mask_messages_hf
|
|
11
|
+
|
|
12
|
+
Base = declarative_base()
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class TaskModel(Base):  # type: ignore
    """Abstract SQLAlchemy base describing the columns of a task table."""

    # Abstract: concrete tables are derived from this class at runtime.
    __abstract__ = True

    __table_args__ = dict(keep_existing=True)

    id = Column(Integer, primary_key=True, autoincrement=True)
    task_desc = Column(String, nullable=True)
    workflow_type = Column(String, nullable=True)
    reward_type = Column(String, nullable=True)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ExperienceModel(Base):  # type: ignore
    """Abstract SQLAlchemy base describing the columns of an experience table."""

    # Abstract: concrete tables are derived from this class at runtime.
    __abstract__ = True

    __table_args__ = dict(keep_existing=True)

    id = Column(Integer, primary_key=True, autoincrement=True)
    serialized_exp = Column(LargeBinary, nullable=True)
    prompt = Column(String, nullable=True)
    response = Column(String, nullable=True)
    reward = Column(Float, nullable=True)
    consumed = Column(Integer, default=0)
    priority = Column(Float, default=0.0)

    def to_experience(self) -> Experience:
        """Rebuild the ``Experience`` from the serialized column."""
        payload = self.serialized_exp
        return Experience.deserialize(payload)

    @classmethod
    def from_experience(cls, experience: Experience):
        """Create a row holding ``experience`` plus its reward and texts."""
        row_values = {
            "serialized_exp": experience.serialize(),
            "reward": experience.reward,
            "prompt": experience.prompt_text,
            "response": experience.response_text,
        }
        return cls(**row_values)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class SFTDataModel(Base):  # type: ignore
    """Abstract SQLAlchemy base describing the columns of an SFT data table."""

    # Abstract: concrete tables are derived from this class at runtime.
    __abstract__ = True

    __table_args__ = {
        "keep_existing": True,
    }

    id = Column(Integer, primary_key=True, autoincrement=True)
    serialized_exp = Column(LargeBinary, nullable=True)
    # JSON-encoded text of the original message list (see `from_messages`).
    messages = Column(String, nullable=True)
    consumed = Column(Integer, default=0)

    def to_experience(self) -> Experience:
        """Load the experience from the database."""
        return Experience.deserialize(self.serialized_exp)

    @classmethod
    def from_messages(
        cls,
        messages: list[dict],
        tokenizer: Any,
        chat_template: Optional[str] = None,
    ) -> "SFTDataModel":
        """Convert a list of messages into a single instance of SFT data.

        Args:
            messages: Chat messages; each is a dict with at least a ``role`` key.
            tokenizer: Tokenizer forwarded to ``tokenize_and_mask_messages_hf``.
            chat_template: Optional chat template forwarded to the same helper.

        Returns:
            A row holding the serialized experience and the JSON-encoded messages.
        """
        import json

        token_ids, action_mask = tokenize_and_mask_messages_hf(
            tokenizer=tokenizer,
            messages=messages,
            chat_template=chat_template,
        )
        exp = Experience(
            tokens=token_ids,
            prompt_length=0,
            action_mask=action_mask,
            info={"response_num": sum(1 for m in messages if m["role"] == "assistant")},
        )
        return cls(
            serialized_exp=exp.serialize(),
            # `messages` is a String column, so the list of dicts must be
            # serialized to text before it can be stored.
            messages=json.dumps(messages, ensure_ascii=False),
        )
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class DPODataModel(Base):  # type: ignore
    """Abstract SQLAlchemy base describing the columns of a DPO data table."""

    # Abstract: concrete tables are derived from this class at runtime.
    __abstract__ = True

    __table_args__ = dict(keep_existing=True)

    id = Column(Integer, primary_key=True, autoincrement=True)
    serialized_exp = Column(LargeBinary, nullable=True)
    chosen = Column(LargeBinary, nullable=True)
    rejected = Column(LargeBinary, nullable=True)
    consumed = Column(Integer, default=0)

    def to_experience(self) -> Experience:
        """Rebuild the experience and attach its deserialized chosen/rejected parts."""
        restored = Experience.deserialize(self.serialized_exp)
        restored.chosen = Experience.deserialize(self.chosen)
        restored.rejected = Experience.deserialize(self.rejected)
        return restored
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
# Maps an algorithm type (or None, for task tables) to its abstract model base.
SCHEMA_MAPPING = {
    None: TaskModel,
    AlgorithmType.SFT: SFTDataModel,
    # PPO, GRPO and OPMD all share the generic experience schema.
    **dict.fromkeys(
        (AlgorithmType.PPO, AlgorithmType.GRPO, AlgorithmType.OPMD),
        ExperienceModel,
    ),
    AlgorithmType.DPO: DPODataModel,
}
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def create_dynamic_table(algorithm_type: Optional[AlgorithmType], table_name: str) -> Any:
    """Create a concrete model class bound to ``table_name``.

    The abstract base is chosen from ``SCHEMA_MAPPING`` by ``algorithm_type``
    and subclassed with ``__tablename__`` set to ``table_name``.

    Args:
        algorithm_type: Key into ``SCHEMA_MAPPING``; ``None`` selects the task schema.
        table_name: Name used both for the new class and for its table.

    Returns:
        The newly created model class.

    Raises:
        ValueError: if ``algorithm_type`` has no entry in ``SCHEMA_MAPPING``.
    """
    if algorithm_type not in SCHEMA_MAPPING:
        raise ValueError(f"Unknown schema: {algorithm_type}")

    base_class = SCHEMA_MAPPING[algorithm_type]
    return type(table_name, (base_class,), {"__tablename__": table_name})
|