trinity-rft 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. trinity/__init__.py +4 -0
  2. trinity/buffer/__init__.py +7 -0
  3. trinity/buffer/buffer.py +65 -0
  4. trinity/buffer/buffer_reader.py +13 -0
  5. trinity/buffer/buffer_writer.py +15 -0
  6. trinity/buffer/queue.py +54 -0
  7. trinity/buffer/reader/__init__.py +0 -0
  8. trinity/buffer/reader/file_reader.py +231 -0
  9. trinity/buffer/reader/queue_reader.py +34 -0
  10. trinity/buffer/reader/sql_reader.py +79 -0
  11. trinity/buffer/schema/__init__.py +3 -0
  12. trinity/buffer/schema/sql_schema.py +149 -0
  13. trinity/buffer/utils.py +33 -0
  14. trinity/buffer/writer/__init__.py +0 -0
  15. trinity/buffer/writer/queue_writer.py +30 -0
  16. trinity/buffer/writer/sql_writer.py +46 -0
  17. trinity/cli/client.py +44 -0
  18. trinity/cli/launcher.py +238 -0
  19. trinity/cli/server.py +32 -0
  20. trinity/common/__init__.py +0 -0
  21. trinity/common/config.py +578 -0
  22. trinity/common/constants.py +119 -0
  23. trinity/common/experience.py +278 -0
  24. trinity/common/models/__init__.py +139 -0
  25. trinity/common/models/model.py +130 -0
  26. trinity/common/models/openai_api.py +79 -0
  27. trinity/common/models/utils.py +265 -0
  28. trinity/common/models/vllm_async_model.py +353 -0
  29. trinity/common/models/vllm_model.py +287 -0
  30. trinity/common/models/vllm_worker.py +74 -0
  31. trinity/common/rewards/__init__.py +11 -0
  32. trinity/common/rewards/accuracy_reward.py +33 -0
  33. trinity/common/rewards/agents_reward.py +1 -0
  34. trinity/common/rewards/base.py +24 -0
  35. trinity/common/rewards/composite_reward.py +24 -0
  36. trinity/common/rewards/format_reward.py +29 -0
  37. trinity/common/rewards/human_reward.py +1 -0
  38. trinity/common/rewards/reward_fn.py +197 -0
  39. trinity/common/rewards/tool_reward.py +1 -0
  40. trinity/common/schema.py +148 -0
  41. trinity/common/verl_config.py +346 -0
  42. trinity/common/workflows/__init__.py +16 -0
  43. trinity/common/workflows/envs/alfworld/alfworld_workflow.py +179 -0
  44. trinity/common/workflows/envs/sciworld/sciworld_workflow.py +144 -0
  45. trinity/common/workflows/envs/webshop/webshop_workflow.py +273 -0
  46. trinity/common/workflows/workflow.py +245 -0
  47. trinity/data/controllers/active_iterator.py +290 -0
  48. trinity/data/controllers/default_ops.py +77 -0
  49. trinity/data/controllers/task_parser.py +303 -0
  50. trinity/data/core/comparator.py +84 -0
  51. trinity/data/core/dataset.py +136 -0
  52. trinity/data/core/dataset_db.py +84 -0
  53. trinity/data/core/formatter.py +151 -0
  54. trinity/data/processors/base.py +143 -0
  55. trinity/data/processors/cleaner.py +229 -0
  56. trinity/data/processors/human_annotator.py +47 -0
  57. trinity/data/processors/synthesizer.py +107 -0
  58. trinity/data/server.py +27 -0
  59. trinity/explorer/__init__.py +4 -0
  60. trinity/explorer/explorer.py +299 -0
  61. trinity/explorer/runner_pool.py +269 -0
  62. trinity/explorer/workflow_runner.py +109 -0
  63. trinity/manager/__init__.py +7 -0
  64. trinity/manager/config_manager.py +1786 -0
  65. trinity/manager/manager.py +69 -0
  66. trinity/trainer/__init__.py +3 -0
  67. trinity/trainer/trainer.py +175 -0
  68. trinity/trainer/verl/__init__.py +0 -0
  69. trinity/trainer/verl/core_algos.py +717 -0
  70. trinity/trainer/verl/dp_actor.py +538 -0
  71. trinity/trainer/verl/fsdp_workers.py +1522 -0
  72. trinity/trainer/verl/ray_trainer.py +1160 -0
  73. trinity/trainer/verl_trainer.py +552 -0
  74. trinity/utils/__init__.py +0 -0
  75. trinity/utils/distributed.py +82 -0
  76. trinity/utils/dlc_utils.py +86 -0
  77. trinity/utils/eval_utils.py +78 -0
  78. trinity/utils/log.py +65 -0
  79. trinity/utils/monitor.py +109 -0
  80. trinity/utils/registry.py +127 -0
  81. trinity_rft-0.1.0.dist-info/METADATA +396 -0
  82. trinity_rft-0.1.0.dist-info/RECORD +86 -0
  83. trinity_rft-0.1.0.dist-info/WHEEL +5 -0
  84. trinity_rft-0.1.0.dist-info/entry_points.txt +2 -0
  85. trinity_rft-0.1.0.dist-info/licenses/LICENSE +201 -0
  86. trinity_rft-0.1.0.dist-info/top_level.txt +1 -0
trinity/__init__.py ADDED
@@ -0,0 +1,4 @@
1
# -*- coding: utf-8 -*-
"""Trinity-RFT (Reinforcement Fine-Tuning)"""

# Public version string of the package.
__version__ = "0.1.0"
@@ -0,0 +1,7 @@
1
+ from trinity.buffer.buffer import Buffer, get_buffer_reader, get_buffer_writer
2
+
3
+ __all__ = [
4
+ "Buffer",
5
+ "get_buffer_reader",
6
+ "get_buffer_writer",
7
+ ]
@@ -0,0 +1,65 @@
1
+ # -*- coding: utf-8 -*-
2
+ """The buffer module"""
3
+ import ray
4
+
5
+ from trinity.buffer.buffer_reader import BufferReader
6
+ from trinity.buffer.buffer_writer import BufferWriter
7
+ from trinity.common.config import BufferConfig, Config, StorageConfig
8
+ from trinity.common.constants import StorageType
9
+
10
+
11
@ray.remote(name="buffer")
class Buffer:
    """Responsible for storing experiences.

    Ray actor holding a registry that maps a dataset name to its
    ``StorageConfig``.
    """

    def __init__(self, config: Config):
        self.buffer_mapping: dict[str, StorageConfig] = {}
        # NOTE(review): `_register_from_config` is not defined on this class or
        # anywhere else in this module, so unless a subclass provides it,
        # constructing this actor raises AttributeError. Implement or remove.
        self._register_from_config(config)

    def get_dataset_info(self, dataset_name: str) -> StorageConfig:
        """Return the storage config registered under ``dataset_name``.

        Raises:
            ValueError: if nothing (or ``None``) is registered under that name.
        """
        found = self.buffer_mapping.get(dataset_name)
        if found is None:
            raise ValueError(f"{dataset_name} not found.")
        return found

    def register_dataset(self, storage_config: StorageConfig) -> None:
        """Register ``storage_config`` under its ``name``.

        Raises:
            ValueError: if a dataset with the same name is already registered.
        """
        key = storage_config.name
        if key in self.buffer_mapping:
            raise ValueError(f"{storage_config.name} already exists.")
        self.buffer_mapping[key] = storage_config
29
+
30
+
31
def get_buffer_reader(storage_config: StorageConfig, buffer_config: BufferConfig) -> BufferReader:
    """Get a buffer reader for the given storage config.

    SQL and QUEUE storages get their dedicated readers. FILE storage picks a
    reader from ``FILE_READERS`` by the value of ``storage_config.algorithm_type``.
    When no algorithm type is set, it falls back to the ``"rollout"`` reader.

    Raises:
        ValueError: if the storage type is unsupported, or if no file reader
            is registered for the FILE storage's algorithm type.
    """
    if storage_config.storage_type == StorageType.SQL:
        from trinity.buffer.reader.sql_reader import SQLReader

        return SQLReader(storage_config, buffer_config)
    if storage_config.storage_type == StorageType.QUEUE:
        from trinity.buffer.reader.queue_reader import QueueReader

        return QueueReader(storage_config, buffer_config)
    if storage_config.storage_type == StorageType.FILE:
        from trinity.buffer.reader.file_reader import FILE_READERS

        algorithm_type = storage_config.algorithm_type
        file_read_type = algorithm_type.value if algorithm_type is not None else "rollout"
        reader_cls = FILE_READERS.get(file_read_type)
        # Assumes Registry.get returns None for unknown names. Without this
        # check, e.g. a PPO dataset on FILE storage fails with an opaque
        # "'NoneType' object is not callable".
        if reader_cls is None:
            raise ValueError(
                f"No file reader registered for algorithm type `{file_read_type}`."
            )
        return reader_cls(storage_config, buffer_config)
    raise ValueError(f"{storage_config.storage_type} not supported.")
52
+
53
+
54
def get_buffer_writer(storage_config: StorageConfig, buffer_config: BufferConfig) -> BufferWriter:
    """Build the buffer writer matching ``storage_config.storage_type``.

    Only SQL and QUEUE storages can be written to.

    Raises:
        ValueError: for any other storage type.
    """
    storage_type = storage_config.storage_type
    if storage_type == StorageType.SQL:
        from trinity.buffer.writer.sql_writer import SQLWriter

        writer_cls = SQLWriter
    elif storage_type == StorageType.QUEUE:
        from trinity.buffer.writer.queue_writer import QueueWriter

        writer_cls = QueueWriter
    else:
        raise ValueError(f"{storage_config.storage_type} not supported.")
    return writer_cls(storage_config, buffer_config)
@@ -0,0 +1,13 @@
1
+ """Reader of the buffer."""
2
+ from abc import ABC, abstractmethod
3
+ from typing import List, Optional
4
+
5
+ from trinity.common.constants import ReadStrategy
6
+
7
+
8
class BufferReader(ABC):
    """Abstract interface implemented by every buffer reader."""

    @abstractmethod
    def read(self, strategy: Optional[ReadStrategy] = None) -> List:
        """Read from buffer.

        Args:
            strategy: how items are selected; ``None`` lets the concrete
                reader choose its own default.
        """
@@ -0,0 +1,15 @@
1
+ """Writer of the buffer."""
2
+ from abc import ABC, abstractmethod
3
+ from typing import List
4
+
5
+
6
class BufferWriter(ABC):
    """Abstract interface implemented by every buffer writer."""

    @abstractmethod
    def write(self, data: List) -> None:
        """Write to buffer.

        Args:
            data: the items to store.
        """

    @abstractmethod
    def finish(self) -> None:
        """Finish writing."""
@@ -0,0 +1,54 @@
1
+ """A queue implemented by Ray Actor."""
2
+ import asyncio
3
+ from copy import deepcopy
4
+ from typing import List
5
+
6
+ import ray
7
+
8
+ from trinity.buffer.writer.sql_writer import SQLWriter
9
+ from trinity.common.config import BufferConfig, StorageConfig
10
+ from trinity.common.constants import StorageType
11
+
12
+
13
@ray.remote
class QueueActor:
    """An asyncio.Queue based queue actor.

    Producers push lists of experiences with :meth:`put_batch`, and consumers
    pull them with :meth:`get_batch`. If ``storage_config.path`` is non-empty,
    every pushed batch is also written through an ``SQLWriter``.
    """

    # Sentinel enqueued by `finish` to mark the end of the stream.
    FINISH_MESSAGE = "$FINISH$"

    def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
        self.config = config
        # Maximum number of queued *batches*; 10000 if the config has no `capacity`.
        self.capacity = getattr(config, "capacity", 10000)
        self.queue = asyncio.Queue(self.capacity)
        if storage_config.path is not None and len(storage_config.path) > 0:
            # Mirror the queued data to SQL storage at the same path.
            sql_config = deepcopy(storage_config)
            sql_config.storage_type = StorageType.SQL
            self.sql_writer = SQLWriter(sql_config, self.config)
        else:
            self.sql_writer = None

    def length(self) -> int:
        """The length of the queue.

        Counts queued batches, not individual experiences. After `finish` has
        been called, the finish sentinel is included in the count.
        """
        return self.queue.qsize()

    async def put_batch(self, exp_list: List) -> None:
        """Put batch of experience, waiting while the queue is full."""
        await self.queue.put(exp_list)
        if self.sql_writer is not None:
            self.sql_writer.write(exp_list)

    async def finish(self) -> None:
        """Stop the queue by enqueuing the finish sentinel."""
        await self.queue.put(self.FINISH_MESSAGE)

    async def get_batch(self, batch_size: int) -> List:
        """Get at least ``batch_size`` experiences.

        Whole pushed batches are taken until the collected count reaches
        ``batch_size``, so more than ``batch_size`` items may be returned.

        Raises:
            StopAsyncIteration: once the finish sentinel is reached. Items
                already collected by this call are discarded.
        """
        batch = []
        while True:
            exp_list = await self.queue.get()
            if exp_list == self.FINISH_MESSAGE:
                # Put the sentinel back so concurrent and later callers also see
                # the end of the stream instead of blocking forever. There is no
                # await between `get` and here, so the slot just freed is still
                # available and `put_nowait` cannot raise QueueFull.
                self.queue.put_nowait(self.FINISH_MESSAGE)
                raise StopAsyncIteration()
            batch.extend(exp_list)
            if len(batch) >= batch_size:
                break
        return batch
File without changes
@@ -0,0 +1,231 @@
1
"""File-based buffer reader."""
2
+
3
+ from typing import List, Optional
4
+
5
+ import datasets
6
+ import transformers
7
+ from datasets import load_dataset
8
+
9
+ from trinity.buffer.buffer_reader import BufferReader
10
+ from trinity.common.config import BufferConfig, StorageConfig
11
+ from trinity.common.constants import AlgorithmType, PromptType, ReadStrategy, TaskType
12
+ from trinity.common.experience import Experience
13
+ from trinity.common.rewards import REWARD_FUNCTIONS
14
+ from trinity.common.workflows import WORKFLOWS, Task
15
+ from trinity.utils.registry import Registry
16
+
17
# Registry of file-based readers keyed by algorithm-type value (plus "rollout").
FILE_READERS = Registry("file_readers")


@FILE_READERS.register_module(AlgorithmType.SFT.value)
class SFTDataReader(BufferReader):
    """Reader for SFT file data.

    Each ``read`` tokenizes one batch of rows into :class:`Experience` objects.
    Their ``prompt_length`` marks where the supervised response begins.
    Batches are fixed-size (a trailing partial batch is dropped). When an epoch
    is exhausted, the dataset is reshuffled and iteration restarts.
    """

    def __init__(self, meta: StorageConfig, config: BufferConfig):
        fmt = meta.format
        self.split = meta.split
        self.prompt_type = fmt.prompt_type
        self.messages_key = fmt.messages_key
        self.prompt_key = fmt.prompt_key
        self.response_key = fmt.response_key
        self.read_batch_size = config.read_batch_size
        # TODO: support resume
        self.dataset = load_dataset(meta.path, name=meta.subset_name, split=self.split)
        self.data_iter = self.dataset.iter(self.read_batch_size, drop_last_batch=True)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)

    def _next_batch(self):
        """Return the next batch, reshuffling and restarting after an epoch ends."""
        try:
            return next(self.data_iter)
        except StopIteration:
            pass
        self.dataset = self.dataset.shuffle()
        self.data_iter = self.dataset.iter(self.read_batch_size, drop_last_batch=True)
        return next(self.data_iter)

    def _chat_tokens(self, messages, add_generation_prompt: bool):
        """Token ids of ``messages`` rendered through the tokenizer's chat template."""
        return self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=add_generation_prompt, return_tensors="pt"
        )[0]

    def _text_tokens(self, text):
        """Token ids of a plain ``text`` string."""
        return self.tokenizer(text, return_tensors="pt")["input_ids"][0]

    @staticmethod
    def _as_list(value) -> list:
        """Wrap a single message in a list; lists are returned unchanged."""
        return value if isinstance(value, list) else [value]

    def read(self, strategy: Optional[ReadStrategy] = None) -> List:
        """Tokenize the next batch into experiences (``strategy`` is unused).

        Raises:
            ValueError: if ``prompt_type`` is not MESSAGES, CHATPAIR or PLAINTEXT.
        """
        batch = self._next_batch()
        if self.prompt_type == PromptType.MESSAGES:
            # The last message is the response; everything before it is the prompt.
            pairs = [
                (self._chat_tokens(msgs, False), self._chat_tokens(msgs[:-1], True))
                for msgs in batch[self.messages_key]
            ]
        elif self.prompt_type == PromptType.CHATPAIR:
            pairs = []
            for prompt_part, response_part in zip(
                batch[self.prompt_key], batch[self.response_key]
            ):
                prompt_msgs = self._as_list(prompt_part)
                full_msgs = prompt_msgs + self._as_list(response_part)
                pairs.append(
                    (self._chat_tokens(full_msgs, False), self._chat_tokens(prompt_msgs, True))
                )
        elif self.prompt_type == PromptType.PLAINTEXT:
            # TODO: support HF format without chat template
            pairs = [
                (self._text_tokens(prompt + response), self._text_tokens(prompt))
                for prompt, response in zip(batch[self.prompt_key], batch[self.response_key])
            ]
        else:
            raise ValueError(f"Unknown data format: {self.prompt_type}")
        return [
            Experience(tokens=tokens, prompt_length=len(prompt_tokens))
            for tokens, prompt_tokens in pairs
        ]
97
+
98
+
99
@FILE_READERS.register_module(AlgorithmType.DPO.value)
class DPODataReader(BufferReader):
    """Reader for DPO preference data stored in files.

    Each row provides a prompt, a chosen response and a rejected response. The
    resulting :class:`Experience` has the prompt as ``tokens``. Its ``chosen``
    and ``rejected`` fields hold only the response tokens that follow the
    prompt.
    """

    def __init__(self, meta: StorageConfig, config: BufferConfig):
        fmt = meta.format
        self.split = meta.split
        self.prompt_type = fmt.prompt_type
        self.prompt_key = fmt.prompt_key
        self.chosen_key = fmt.chosen_key
        self.rejected_key = fmt.rejected_key
        self.read_batch_size = config.read_batch_size
        # TODO: support resume
        self.dataset = load_dataset(meta.path, name=meta.subset_name, split=self.split)
        self.data_iter = self.dataset.iter(self.read_batch_size, drop_last_batch=True)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)

    def _next_batch(self):
        """Return the next batch, reshuffling and restarting after an epoch ends."""
        try:
            return next(self.data_iter)
        except StopIteration:
            pass
        self.dataset = self.dataset.shuffle()
        self.data_iter = self.dataset.iter(self.read_batch_size, drop_last_batch=True)
        return next(self.data_iter)

    def _chat_tokens(self, messages, add_generation_prompt: bool):
        """Token ids of ``messages`` rendered through the tokenizer's chat template."""
        return self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=add_generation_prompt, return_tensors="pt"
        )[0]

    def _get_assistant_message(self, item) -> dict:
        """Normalize a response into one assistant message.

        A list is reduced to its first element. A plain string becomes
        ``{"role": "assistant", "content": ...}``. Anything else is returned
        unchanged.
        """
        first = item[0] if isinstance(item, list) else item
        if isinstance(first, str):
            return {"role": "assistant", "content": first}
        return first

    def _prompt_messages(self, prompt) -> list:
        """Turn a raw prompt into chat messages according to ``prompt_type``.

        Raises:
            ValueError: for prompt types other than MESSAGES and PLAINTEXT.
        """
        if self.prompt_type == PromptType.MESSAGES:
            return prompt
        if self.prompt_type == PromptType.PLAINTEXT:
            return [{"role": "user", "content": prompt}]
        raise ValueError(f"Unknown prompt type: {self.prompt_type}")

    def read(self, strategy: Optional[ReadStrategy] = None) -> List:
        """Tokenize the next batch of preference pairs (``strategy`` is unused)."""
        batch = self._next_batch()
        rows = zip(batch[self.prompt_key], batch[self.chosen_key], batch[self.rejected_key])
        exp_list = []
        for prompt, chosen, rejected in rows:
            prompt_messages = self._prompt_messages(prompt)
            prompt_tokens = self._chat_tokens(prompt_messages, True)
            prompt_length = len(prompt_tokens)
            # Render prompt + response, then keep only the part after the prompt.
            chosen_tokens = self._chat_tokens(
                prompt_messages + [self._get_assistant_message(chosen)], False
            )[prompt_length:]
            rejected_tokens = self._chat_tokens(
                prompt_messages + [self._get_assistant_message(rejected)], False
            )[prompt_length:]
            exp_list.append(
                Experience(
                    tokens=prompt_tokens,
                    prompt_length=prompt_length,
                    chosen=chosen_tokens,
                    rejected=rejected_tokens,
                )
            )
        return exp_list
170
+
171
+
172
@FILE_READERS.register_module("rollout")
class RolloutDataReader(BufferReader):
    """Reader that turns rows of a file dataset into rollout :class:`Task` objects.

    For each row, the workflow class and reward function come from the row's
    own ``workflow_key``/``reward_fn_key`` columns when present. Otherwise the
    configured defaults are used.
    """

    def __init__(self, meta: StorageConfig, config: BufferConfig):
        self.meta = meta
        self.name = meta.name
        self.split = meta.split
        subset_name = meta.subset_name
        # Disable datasets caching to avoid reusing an old version of the
        # dataset. `finally` re-enables it even if loading fails, so the global
        # setting is never left disabled.
        datasets.disable_caching()
        try:
            # TODO: may load from db_url
            self.dataset = load_dataset(meta.path, name=subset_name, split=self.split)
        finally:
            datasets.enable_caching()
        self.index = meta.index  # TODO: apply shuffle

        self.prompt_key = meta.format.prompt_key
        self.response_key = meta.format.response_key
        self.workflow_key = meta.format.workflow_key
        self.reward_fn_key = meta.format.reward_fn_key

        self.task_type = meta.task_type
        self.default_workflow_cls = WORKFLOWS.get(meta.default_workflow_type)
        self.default_reward_fn_cls = REWARD_FUNCTIONS.get(meta.default_reward_fn_type)
        # Only EXPLORE tasks repeat for `total_epochs`; other task types use one pass.
        self.total_epochs = meta.total_epochs if self.task_type == TaskType.EXPLORE else 1

    def __len__(self):
        return len(self.dataset)

    def read(self, strategy: Optional[ReadStrategy] = None) -> Task:
        """Build the next :class:`Task` (``strategy`` is unused).

        For EVAL tasks, the index wraps back to 0 after the last row so the
        dataset can be iterated again.

        Raises:
            StopIteration: once ``len(dataset) * total_epochs`` rows were read.
            AssertionError: if neither the row nor the config names a workflow.
        """
        if self.index >= len(self.dataset) * self.total_epochs:
            raise StopIteration
        sample = self.dataset[self.index % len(self.dataset)]
        workflow_class = (
            WORKFLOWS.get(sample[self.workflow_key])
            if self.workflow_key in sample
            else self.default_workflow_cls
        )
        reward_fn = (
            REWARD_FUNCTIONS.get(sample[self.reward_fn_key])
            if self.reward_fn_key in sample
            else self.default_reward_fn_cls
        )
        # The message previously named `default_reward_fn_type`, but the check
        # is on the workflow class.
        assert workflow_class is not None, "`default_workflow_type` or `workflow_key` is required"
        task = Task(
            workflow=workflow_class,
            format_args=self.meta.format,
            rollout_args=self.meta.rollout_args,
            is_eval=self.meta.task_type == TaskType.EVAL,
            reward_fn=reward_fn,
            raw_task=sample,
        )
        self.index += 1
        if self.task_type == TaskType.EVAL and self.index == len(self.dataset):
            self.index = 0
        return task
@@ -0,0 +1,34 @@
1
+ """Reader of the Queue buffer."""
2
+
3
+ from typing import List, Optional
4
+
5
+ import ray
6
+
7
+ from trinity.buffer.buffer_reader import BufferReader
8
+ from trinity.buffer.queue import QueueActor
9
+ from trinity.common.config import BufferConfig, StorageConfig
10
+ from trinity.common.constants import ReadStrategy, StorageType
11
+ from trinity.utils.log import get_logger
12
+
13
logger = get_logger(__name__)


class QueueReader(BufferReader):
    """Reader of the Queue buffer.

    Attaches to (or creates) the Ray ``QueueActor`` named ``queue-<name>``
    and pulls ``read_batch_size`` experiences per call.
    """

    def __init__(self, meta: StorageConfig, config: BufferConfig):
        assert meta.storage_type == StorageType.QUEUE
        self.config = config
        actor_options = {"name": f"queue-{meta.name}", "get_if_exists": True}
        self.queue = QueueActor.options(**actor_options).remote(meta, config)

    def read(self, strategy: Optional[ReadStrategy] = None) -> List:
        """Fetch one batch from the queue actor.

        Raises:
            NotImplementedError: for any strategy other than ``None`` or FIFO.
            StopIteration: once the queue has been finished.
        """
        if strategy not in (None, ReadStrategy.FIFO):
            raise NotImplementedError(f"Read strategy {strategy} not supported for Queue Reader.")
        try:
            return ray.get(self.queue.get_batch.remote(self.config.read_batch_size))
        except StopAsyncIteration:
            raise StopIteration()
@@ -0,0 +1,79 @@
1
+ """Reader of the SQL buffer."""
2
+
3
+ import time
4
+ from typing import List, Optional
5
+
6
+ from sqlalchemy import asc, create_engine, desc
7
+ from sqlalchemy.exc import OperationalError
8
+ from sqlalchemy.orm import sessionmaker
9
+ from sqlalchemy.pool import NullPool
10
+
11
+ from trinity.buffer.buffer_reader import BufferReader
12
+ from trinity.buffer.schema import Base, create_dynamic_table
13
+ from trinity.buffer.utils import retry_session
14
+ from trinity.common.config import BufferConfig, StorageConfig
15
+ from trinity.common.constants import ReadStrategy, StorageType
16
+ from trinity.utils.log import get_logger
17
+
18
logger = get_logger(__name__)


class SQLReader(BufferReader):
    """Reader of the SQL buffer.

    Reads rows whose ``reward`` is set from a table created by
    ``create_dynamic_table`` and increments each row's ``consumed`` counter.
    """

    def __init__(self, meta: StorageConfig, config: BufferConfig) -> None:
        assert meta.storage_type == StorageType.SQL
        self.engine = create_engine(meta.path, poolclass=NullPool)

        self.table_model_cls = create_dynamic_table(meta.algorithm_type, meta.name)
        try:
            Base.metadata.create_all(self.engine, checkfirst=True)
        except OperationalError:
            logger.warning("Failed to create database, assuming it already exists.")
        self.session = sessionmaker(bind=self.engine)
        self.batch_size = config.read_batch_size
        self.max_retry_times = config.max_retry_times
        self.max_retry_interval = config.max_retry_interval

    def _sort_order(self, strategy: ReadStrategy) -> tuple:
        """Return the ORDER BY clauses implementing ``strategy``.

        Raises:
            NotImplementedError: for strategies this reader does not support.
        """
        table = self.table_model_cls
        if strategy == ReadStrategy.LFU:
            # Least-consumed rows first, newest first among ties.
            return (asc(table.consumed), desc(table.id))
        if strategy == ReadStrategy.LRU:
            return (desc(table.id),)
        if strategy == ReadStrategy.PRIORITY:
            return (desc(table.priority), desc(table.id))
        raise NotImplementedError(f"Unsupported strategy {strategy} by SQLStorage")

    def read(self, strategy: Optional[ReadStrategy] = None) -> List:
        """Read ``batch_size`` distinct experiences, polling until enough exist.

        Args:
            strategy: row-selection order; defaults to ``ReadStrategy.LFU``.

        Returns:
            A list of ``batch_size`` experiences, or an empty list when
            ``batch_size`` <= 0.

        Raises:
            NotImplementedError: for unsupported strategies.
        """
        if strategy is None:
            strategy = ReadStrategy.LFU
        sort_order = self._sort_order(strategy)
        table = self.table_model_cls

        exp_list = []
        fetched_ids = []
        while len(exp_list) < self.batch_size:
            with retry_session(
                self.session, self.max_retry_times, self.max_retry_interval
            ) as session:
                query = session.query(table).filter(table.reward.isnot(None))
                if fetched_ids:
                    # Skip rows already taken into this batch. Otherwise a retry
                    # returns the same top-ranked rows again (LRU and PRIORITY
                    # ordering ignore `consumed`), duplicating them in the batch.
                    query = query.filter(~table.id.in_(fetched_ids))
                experiences = (
                    query.order_by(*sort_order)  # TODO: very slow
                    .limit(self.batch_size - len(exp_list))
                    .with_for_update()
                    .all()
                )
                # update the consumed field
                for exp in experiences:
                    exp.consumed += 1
                    fetched_ids.append(exp.id)
                # Convert inside the `with` block, while rows are still attached.
                exp_list.extend(table.to_experience(exp) for exp in experiences)
            if len(exp_list) < self.batch_size:
                # Back off before polling again, including when nothing has been
                # read yet, so an empty table does not cause a tight query loop.
                logger.info("waiting for experiences...")
                time.sleep(1)

        if exp_list:
            logger.info(f"get {len(exp_list)} experiences:")
            logger.info(f"reward = {[exp.reward for exp in exp_list]}")
            logger.info(f"first prompt_text = {exp_list[0].prompt_text}")
            logger.info(f"first response_text = {exp_list[0].response_text}")
        return exp_list
@@ -0,0 +1,3 @@
1
+ from .sql_schema import Base, create_dynamic_table
2
+
3
+ __all__ = ["create_dynamic_table", "Base"]
@@ -0,0 +1,149 @@
1
+ """Schema for SQLAlchemy models."""
2
+
3
+ from typing import Any, Optional, Union
4
+
5
+ from sqlalchemy import Column, Float, Integer, LargeBinary, String
6
+ from sqlalchemy.ext.declarative import declarative_base
7
+
8
+ from trinity.common.constants import AlgorithmType
9
+ from trinity.common.experience import Experience
10
+ from trinity.common.models.utils import tokenize_and_mask_messages_hf
11
+
12
Base = declarative_base()


class TaskModel(Base):  # type: ignore
    """Model for storing tasks in SQLAlchemy.

    Abstract: no table is created for this class itself.
    """

    __abstract__ = True

    # Reuse an already-defined table of the same name instead of redefining it.
    __table_args__ = dict(keep_existing=True)

    id = Column(Integer, primary_key=True, autoincrement=True)
    task_desc = Column(String, nullable=True)
    workflow_type = Column(String, nullable=True)
    reward_type = Column(String, nullable=True)
28
+
29
+
30
class ExperienceModel(Base):  # type: ignore
    """SQLAlchemy model for Experience.

    Abstract. Each row holds the serialized experience plus its reward and
    prompt/response text as separately queryable columns.
    """

    __abstract__ = True

    __table_args__ = dict(keep_existing=True)

    id = Column(Integer, primary_key=True, autoincrement=True)
    serialized_exp = Column(LargeBinary, nullable=True)
    prompt = Column(String, nullable=True)
    response = Column(String, nullable=True)
    reward = Column(Float, nullable=True)
    consumed = Column(Integer, default=0)
    priority = Column(Float, default=0.0)

    def to_experience(self) -> Experience:
        """Rebuild the :class:`Experience` stored in this row."""
        payload = self.serialized_exp
        return Experience.deserialize(payload)

    @classmethod
    def from_experience(cls, experience: Experience):
        """Build an (unsaved) row from ``experience``."""
        row_values = {
            "serialized_exp": experience.serialize(),
            "reward": experience.reward,
            "prompt": experience.prompt_text,
            "response": experience.response_text,
        }
        return cls(**row_values)
60
+
61
+
62
class SFTDataModel(Base):  # type: ignore
    """SQLAlchemy model for SFT data.

    Abstract. Each row holds the serialized :class:`Experience` and the source
    conversation as JSON text.
    """

    __abstract__ = True

    __table_args__ = {
        "keep_existing": True,
    }

    id = Column(Integer, primary_key=True, autoincrement=True)
    serialized_exp = Column(LargeBinary, nullable=True)
    messages = Column(String, nullable=True)  # JSON-encoded list of chat messages
    consumed = Column(Integer, default=0)

    def to_experience(self) -> Experience:
        """Load the experience from the database."""
        return Experience.deserialize(self.serialized_exp)

    @classmethod
    def from_messages(
        cls,
        messages: list[dict],
        tokenizer: Any,
        chat_template: Optional[str] = None,
    ) -> "SFTDataModel":
        """Convert a list of messages into a single instance of SFT data.

        Args:
            messages: chat messages, each a dict with at least a ``"role"`` key.
            tokenizer: passed through to ``tokenize_and_mask_messages_hf``.
            chat_template: optional template override, also passed through.

        Returns:
            An unsaved row. Its serialized experience has ``prompt_length=0``,
            the ``action_mask`` returned by the helper, and the number of
            assistant turns in ``info["response_num"]``. ``messages`` is stored
            as JSON text.
        """
        import json

        token_ids, action_mask = tokenize_and_mask_messages_hf(
            tokenizer=tokenizer,
            messages=messages,
            chat_template=chat_template,
        )
        exp = Experience(
            tokens=token_ids,
            prompt_length=0,
            action_mask=action_mask,
            info={"response_num": sum(1 for m in messages if m["role"] == "assistant")},
        )
        return cls(
            serialized_exp=exp.serialize(),
            # `messages` is a String column. Binding the raw list fails in DB
            # drivers such as sqlite3, so store it as JSON text instead.
            messages=json.dumps(messages, ensure_ascii=False),
        )
103
+
104
+
105
class DPODataModel(Base):  # type: ignore
    """SQLAlchemy model for DPO data.

    Abstract. Each row stores three serialized blobs: the base experience plus
    its chosen and rejected counterparts, which are re-attached on load.
    """

    __abstract__ = True

    __table_args__ = dict(keep_existing=True)

    id = Column(Integer, primary_key=True, autoincrement=True)
    serialized_exp = Column(LargeBinary, nullable=True)
    chosen = Column(LargeBinary, nullable=True)
    rejected = Column(LargeBinary, nullable=True)
    consumed = Column(Integer, default=0)

    def to_experience(self) -> Experience:
        """Load the experience and attach its deserialized ``chosen``/``rejected``."""
        experience = Experience.deserialize(self.serialized_exp)
        for field_name in ("chosen", "rejected"):
            setattr(experience, field_name, Experience.deserialize(getattr(self, field_name)))
        return experience
126
+
127
+
128
# Abstract model each algorithm type's tables derive from; `None` maps to task tables.
SCHEMA_MAPPING = {
    None: TaskModel,
    AlgorithmType.SFT: SFTDataModel,
    AlgorithmType.PPO: ExperienceModel,
    AlgorithmType.GRPO: ExperienceModel,
    AlgorithmType.OPMD: ExperienceModel,
    AlgorithmType.DPO: DPODataModel,
}

# Mapped classes already created, keyed by (algorithm_type, table_name).
_DYNAMIC_TABLES: dict = {}


def create_dynamic_table(algorithm_type: Optional[AlgorithmType], table_name: str) -> Any:
    """Create a dynamic table based on the provided algorithm type and table name.

    The mapped class is created once per ``(algorithm_type, table_name)`` and
    reused afterwards. Re-declaring a class with the same name on the same
    declarative ``Base`` (e.g. when a reader and a writer in one process open
    the same table) makes SQLAlchemy warn and register a duplicate mapper.

    Raises:
        ValueError: if ``algorithm_type`` has no entry in ``SCHEMA_MAPPING``.
    """
    if algorithm_type not in SCHEMA_MAPPING:
        raise ValueError(f"Unknown schema: {algorithm_type}")

    key = (algorithm_type, table_name)
    table_cls = _DYNAMIC_TABLES.get(key)
    if table_cls is None:
        base_class = SCHEMA_MAPPING[algorithm_type]
        table_cls = type(table_name, (base_class,), {"__tablename__": table_name})
        _DYNAMIC_TABLES[key] = table_cls
    return table_cls