TransferQueue 0.0.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- recipe/simple_use_case/async_demo.py +307 -0
- recipe/simple_use_case/sync_demo.py +223 -0
- tests/test_client.py +390 -0
- tests/test_controller.py +268 -0
- tests/test_serial_utils_on_cpu.py +202 -0
- tests/test_simple_storage_unit.py +479 -0
- transfer_queue/__init__.py +42 -0
- transfer_queue/client.py +663 -0
- transfer_queue/controller.py +772 -0
- transfer_queue/metadata.py +603 -0
- transfer_queue/storage.py +515 -0
- transfer_queue/utils/__init__.py +13 -0
- transfer_queue/utils/serial_utils.py +240 -0
- transfer_queue/utils/utils.py +98 -0
- transfer_queue/utils/zmq_utils.py +175 -0
- transfer_queue/version/version +1 -0
- transferqueue-0.0.1.dev0.dist-info/METADATA +15 -0
- transferqueue-0.0.1.dev0.dist-info/RECORD +21 -0
- transferqueue-0.0.1.dev0.dist-info/WHEEL +5 -0
- transferqueue-0.0.1.dev0.dist-info/licenses/LICENSE +202 -0
- transferqueue-0.0.1.dev0.dist-info/top_level.txt +4 -0
|
@@ -0,0 +1,307 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import asyncio
|
|
16
|
+
import logging
|
|
17
|
+
import math
|
|
18
|
+
import os
|
|
19
|
+
import sys
|
|
20
|
+
import time
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
|
|
23
|
+
import ray
|
|
24
|
+
import torch
|
|
25
|
+
from omegaconf import OmegaConf
|
|
26
|
+
from tensordict import NonTensorData, TensorDict
|
|
27
|
+
|
|
28
|
+
parent_dir = Path(__file__).resolve().parent.parent.parent
|
|
29
|
+
sys.path.append(str(parent_dir))
|
|
30
|
+
|
|
31
|
+
from transfer_queue import ( # noqa: E402
|
|
32
|
+
AsyncTransferQueueClient,
|
|
33
|
+
BatchMeta,
|
|
34
|
+
TransferQueueController,
|
|
35
|
+
TransferQueueStorageSimpleUnit,
|
|
36
|
+
process_zmq_server_info,
|
|
37
|
+
)
|
|
38
|
+
from transfer_queue.utils.utils import get_placement_group # noqa: E402
|
|
39
|
+
|
|
40
|
+
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

# RAY_* environment knobs are set before ray.init() (presumably so the Ray
# runtime started below picks them up).
os.environ.update({"RAY_DEDUP_LOGS": "0", "RAY_DEBUG": "1"})
ray.init()
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def compute_old_log_prob(data1, data2, delay=3):
    """Stub for the old-log-prob computation.

    Blocks for ``delay`` seconds to imitate compute latency, then returns
    ``data1`` unchanged. ``data2`` is accepted only for signature parity with a
    real implementation and is ignored.

    Args:
        data1: Value passed straight through as the result.
        data2: Unused.
        delay: Seconds to sleep (defaults to 3, the previously hard-coded
            value); pass 0 to skip the simulated latency.

    Returns:
        ``data1``, unchanged.
    """
    time.sleep(delay)
    return data1
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def generate_sequences(data, delay=3):
    """Stub for sequence generation.

    Blocks for ``delay`` seconds to imitate generation latency, then returns
    ``data`` unchanged.

    Args:
        data: Input passed straight through as the "generated" output.
        delay: Seconds to sleep (defaults to 3, the previously hard-coded
            value); pass 0 to skip the simulated latency.

    Returns:
        ``data``, unchanged.
    """
    time.sleep(delay)
    return data
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class ActorRolloutRefWorker:
    """Synchronous stand-in for the actor/rollout worker group.

    Each task pulls the samples described by ``data_meta`` from the storage
    units through ``data_system_client``, runs a mock computation, writes the
    result back, and returns ``data_meta`` with the new fields added. The async
    client coroutines are driven to completion with ``asyncio.run``.
    """

    @staticmethod
    def _write_back(result, data_meta, data_system_client):
        """Store ``result`` for the samples in ``data_meta`` and record its fields on it."""
        # Write the results back to the storage units according to data_meta.
        asyncio.run(data_system_client.async_put(data=result, metadata=data_meta))
        data_meta.add_fields(result)
        logger.info("demo put data to storages done")
        return data_meta

    def actor_rollout_wg_generate_sequences(self, data_meta, data_system_client):
        """Mock generate-sequences task; returns ``data_meta`` extended with the outputs."""
        # Use the client to pull the real data from the storage units per data_meta.
        fetched = asyncio.run(data_system_client.async_get_data(data_meta))
        logger.info(f"demo get data->generate_sequences {fetched}")

        sequences = generate_sequences(fetched["input_ids"])
        n_rows = sequences.size(0)
        result = TensorDict(
            {
                "generate_sequences_ids": sequences,
                "non_tensor_data": torch.stack([NonTensorData("test_str") for _ in range(n_rows)]),
                "nested_tensor": torch.nested.as_nested_tensor([torch.randn(1, 2) for _ in range(n_rows)]),
            },
            batch_size=n_rows,
        )
        return self._write_back(result, data_meta, data_system_client)

    def actor_rollout_wg_compute_old_log_prob(self, data_meta, data_system_client):
        """Mock compute-old-log-prob task; returns ``data_meta`` extended with ``old_log_prob``."""
        # Use the client to pull the real data from the storage units per data_meta.
        fetched = asyncio.run(data_system_client.async_get_data(data_meta))
        logger.info(f"demo get data->old_log_prob {fetched}")

        log_prob = compute_old_log_prob(fetched["input_ids"], fetched["generate_sequences_ids"])
        result = TensorDict({"old_log_prob": log_prob}, batch_size=log_prob.size(0))
        return self._write_back(result, data_meta, data_system_client)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
@ray.remote
class AsyncvLLMServer:
    """Ray actor imitating an async vLLM server.

    ``generate`` fetches the samples described by ``data_meta``, produces a mock
    ``generate_sequences_ids`` field (``input_ids + 1``), writes it back through
    the shared client together with two demo fields, and returns the metadata
    with those fields recorded.
    """

    def __init__(self, data_system_client):
        self.data_system_client = data_system_client

    async def generate(self, data_meta):
        """Run mock generation for ``data_meta`` and return the updated metadata."""
        data = await self.data_system_client.async_get_data(data_meta)
        logger.info(f"demo get data->generate_sequences {data}")

        data = data["input_ids"]
        data += 1
        # Simulated generation latency; awaited so other requests can proceed.
        await asyncio.sleep(3)

        output = TensorDict(
            {
                "generate_sequences_ids": data,
                "non_tensor_data": torch.stack([NonTensorData("test_str") for _ in range(data.size(0))]),
                "nested_tensor": torch.nested.as_nested_tensor([torch.randn(1, 2) for _ in range(data.size(0))]),
            },
            batch_size=data.size(0),
        )

        await self.data_system_client.async_put(data=output, metadata=data_meta)
        # Record the newly stored fields on the metadata, as the synchronous
        # ActorRolloutRefWorker does after its put; without this the returned
        # metadata does not list the fields that were just written.
        data_meta.add_fields(output)
        logger.info("demo Async Server put data to storages done")

        return data_meta
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
@ray.remote(num_cpus=1)
class AsyncRolloutWorker:
    """Ray actor that fans a metadata chunk out to its own AsyncvLLMServer.

    Every sample in the chunk is sent to the server concurrently, and the
    per-sample results are concatenated back into a single ``BatchMeta``.
    """

    def __init__(self, data_system_client):
        self.async_vllm_server = AsyncvLLMServer.remote(data_system_client)

    async def generate_sequences(self, data_meta_chunk):
        # Each remote call is wrapped in a local coroutine (self.generate) before
        # being scheduled: asyncio.create_task cannot take a Ray actor method call
        # directly and fails with "a coroutine was expected, got ObjectRef(xxx)".
        pending = [
            asyncio.create_task(self.generate(data_meta_chunk[idx]))
            for idx in range(data_meta_chunk.size)
        ]
        finished = await asyncio.gather(*pending)
        return BatchMeta.concat(finished)

    async def generate(self, data_meta):
        # Awaiting the ObjectRef waits for the server actor's result.
        return await self.async_vllm_server.generate.remote(data_meta)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
class RolloutManager:
    """Driver-side manager that spreads rollout work over AsyncRolloutWorker actors.

    ``config.rollout_agent_num_workers`` worker actors are created up front;
    ``generate_sequences`` gives each worker one chunk of the incoming metadata
    and concatenates the results.
    """

    def __init__(self, config, data_system_client):
        self.config = config
        self.data_system_client = data_system_client
        self.async_rollout_workers = [
            AsyncRolloutWorker.remote(self.data_system_client)
            for _ in range(self.config.rollout_agent_num_workers)
        ]

    def generate_sequences(self, data_meta):
        chunks = data_meta.chunk(len(self.async_rollout_workers))
        # strict=True: fail loudly if chunk() yields a different number of
        # chunks than there are workers.
        refs = [
            worker.generate_sequences.remote(chunk)
            for worker, chunk in zip(self.async_rollout_workers, chunks, strict=True)
        ]
        batch_meta = BatchMeta.concat(ray.get(refs))
        logger.info(f"batch_meta: {batch_meta}")
        return batch_meta
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
class Trainer:
    """Toy RL trainer that drives one step through TransferQueue.

    Construction launches the TransferQueue storage units and controllers as
    Ray actors, registers the controllers with every storage unit and builds an
    ``AsyncTransferQueueClient``. ``fit`` puts a prompt batch, runs the mock
    generate-sequences stage (sync worker group, or ``RolloutManager`` when
    ``config.async_rollout_mode`` is truthy) and the mock compute-old-log-prob
    stage, then clears the step's data.
    """

    def __init__(self, config):
        self.config = config
        self.data_system_client = self._initialize_data_system()
        self.actor_rollout_wg = ActorRolloutRefWorker()
        self.async_rollout_manager = RolloutManager(self.config, self.data_system_client)

    def _initialize_data_system(self):
        """Launch storage units and controllers, wire them together, and build the client.

        Sets ``data_system_storage_units``, ``data_system_controllers``,
        ``data_system_controller_infos``, ``data_system_storage_unit_infos`` and
        ``data_system_client`` on ``self``.

        Returns:
            The ``AsyncTransferQueueClient`` (also stored on ``self``).
        """
        cfg = self.config

        # 1. Initialize TransferQueueStorage. Each unit is launched through Ray
        #    (TransferQueueStorageSimpleUnit is a ray.remote class), pinned to one
        #    bundle of the placement group; capacity is split evenly, rounded up.
        total_storage_size = cfg.global_batch_size * cfg.num_global_batch * cfg.num_n_samples
        self.data_system_storage_units = {}
        storage_pg = get_placement_group(cfg.num_data_storage_units, num_cpus_per_actor=1)
        for rank in range(cfg.num_data_storage_units):
            self.data_system_storage_units[rank] = TransferQueueStorageSimpleUnit.options(
                placement_group=storage_pg, placement_group_bundle_index=rank
            ).remote(storage_size=math.ceil(total_storage_size / cfg.num_data_storage_units))
            logger.info(f"TransferQueueStorageSimpleUnit #{rank} has been created.")

        # 2. Initialize TransferQueueController. Several controller instances are
        #    supported for load balancing and large-scale scaling; different
        #    controllers can be assigned to different RL compute tasks.
        self.data_system_controllers = {}
        controller_pg = get_placement_group(cfg.num_data_controllers, num_cpus_per_actor=1)
        for rank in range(cfg.num_data_controllers):
            self.data_system_controllers[rank] = TransferQueueController.options(
                placement_group=controller_pg, placement_group_bundle_index=rank
            ).remote(
                num_storage_units=cfg.num_data_storage_units,
                global_batch_size=cfg.global_batch_size,
                num_global_batch=cfg.num_global_batch,
                num_n_samples=cfg.num_n_samples,
            )
            logger.info(f"TransferQueueController #{rank} has been created.")

        # 3. Register the controllers with every storage unit. Each storage unit
        #    receives the info of all controllers (IP + port obtained via Ray) and
        #    then sets up ZMQ sockets for messaging.
        self.data_system_controller_infos = process_zmq_server_info(self.data_system_controllers)
        self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units)
        registrations = [
            unit.register_controller_info.remote(self.data_system_controller_infos)
            for unit in self.data_system_storage_units.values()
        ]
        ray.get(registrations)

        # 4. Create the client; it is given the info entry at key 0 plus all storage units.
        self.data_system_client = AsyncTransferQueueClient(
            client_id="Trainer",
            controller_infos=self.data_system_controller_infos[0],
            storage_infos=self.data_system_storage_unit_infos,
        )
        return self.data_system_client

    def _get_meta(self, data_fields, step, task_name):
        """Blocking ``async_get_meta`` for all samples of ``step`` (global_batch_size * num_n_samples)."""
        return asyncio.run(
            self.data_system_client.async_get_meta(
                data_fields=data_fields,
                batch_size=self.config.global_batch_size * self.config.num_n_samples,
                global_step=step,
                get_n_samples=False,
                task_name=task_name,
            )
        )

    def fit(self):
        """Run the demo loop (one epoch of one step) and return the merged metadata."""
        num_epochs = 1
        train_dataloader = 1
        for _epoch in range(num_epochs):
            for step in range(train_dataloader):
                # Mock prompts scaled by (step + 1); each prompt is repeated
                # num_n_samples times.
                prompts = torch.tensor(
                    [[1, 2], [3, 4], [5, 6], [7, 8], [10, 11], [100, 111], [200, 222], [300, 333]]
                ) * (step + 1)
                repeated = torch.repeat_interleave(prompts, self.config.num_n_samples, dim=0)
                prompt_batch = TensorDict(
                    {"input_ids": repeated, "attention_mask": repeated},
                    batch_size=repeated.size(0),
                )

                asyncio.run(self.data_system_client.async_put(data=prompt_batch, global_step=step))
                logger.info("demo put prompts ok! ")
                time.sleep(5)

                batch_meta = self._get_meta(["input_ids", "attention_mask"], step, "generate_sequences")
                logger.info(f"demo get meta {batch_meta}")

                # Simulate the generate-sequences task: async rollout manager or sync worker group.
                if self.config.async_rollout_mode:
                    batch_meta = self.async_rollout_manager.generate_sequences(batch_meta)
                else:
                    batch_meta = self.actor_rollout_wg.actor_rollout_wg_generate_sequences(
                        batch_meta, self.data_system_client
                    )

                log_prob_meta = self._get_meta(
                    ["input_ids", "attention_mask", "generate_sequences_ids"], step, "compute_old_log_prob"
                )
                logger.info(f"demo get log prob meta: {log_prob_meta}")

                # Simulate the compute-old-log-prob task of the worker group.
                old_log_prob_meta = self.actor_rollout_wg.actor_rollout_wg_compute_old_log_prob(
                    log_prob_meta, self.data_system_client
                )
                batch_meta = batch_meta.union(old_log_prob_meta)

                # The master client asks all controllers to clear this step's data
                # state; the master controller returns metadata, which the client uses
                # to tell every storage unit to clear. Other controllers just clear
                # without returning metadata.
                asyncio.run(self.data_system_client.async_clear(global_step=step))
                logger.info("clear ok! ")
                logger.info("demo done!")
        return batch_meta
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
if __name__ == "__main__":
    # NOTE: async_rollout_mode is True below, which exercises the async rollout
    # path that mimics AgentLoopManager in verl; set it to False to use the
    # synchronous worker-group path instead.

    config_str = """
    global_batch_size: 8
    num_global_batch: 1
    num_data_storage_units: 2
    num_data_controllers: 1
    async_rollout_mode: True
    rollout_agent_num_workers: 2
    num_n_samples: 2

    """
    dict_conf = OmegaConf.create(config_str)

    try:
        trainer = Trainer(dict_conf)
        trainer.fit()
    finally:
        # Tear down the Ray runtime started by ray.init() at import time, even
        # when setup or fit() raises.
        ray.shutdown()
|
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import logging
|
|
16
|
+
import math
|
|
17
|
+
import os
|
|
18
|
+
import sys
|
|
19
|
+
import time
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
|
|
22
|
+
import ray
|
|
23
|
+
import torch
|
|
24
|
+
from omegaconf import OmegaConf
|
|
25
|
+
from tensordict import NonTensorData, TensorDict
|
|
26
|
+
|
|
27
|
+
parent_dir = Path(__file__).resolve().parent.parent.parent
|
|
28
|
+
sys.path.append(str(parent_dir))
|
|
29
|
+
|
|
30
|
+
from transfer_queue import ( # noqa: E402
|
|
31
|
+
TransferQueueController,
|
|
32
|
+
TransferQueueStorageSimpleUnit,
|
|
33
|
+
process_zmq_server_info,
|
|
34
|
+
)
|
|
35
|
+
from transfer_queue.utils.utils import get_placement_group # noqa: E402
|
|
36
|
+
|
|
37
|
+
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

# RAY_* environment knobs are set before ray.init() (presumably so the Ray
# runtime started below picks them up).
os.environ.update({"RAY_DEDUP_LOGS": "0", "RAY_DEBUG": "1"})
ray.init()
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def initialize_data_system(config):
    """Bring up the TransferQueue data system on Ray.

    Launches ``config.num_data_storage_units`` storage-unit actors and
    ``config.num_data_controllers`` controller actors (each group in its own
    placement group), registers every controller with every storage unit, and
    builds a synchronous ``TransferQueueClient``.

    Returns:
        ``(controllers, storage_units, client)``; the first two are dicts mapping
        rank -> Ray actor handle.
    """
    # 1. Initialize TransferQueueStorage. Each unit is launched through Ray
    #    (TransferQueueStorageSimpleUnit is a ray.remote class), pinned to one
    #    bundle of the placement group; capacity is split evenly, rounded up.
    total_storage_size = config.global_batch_size * config.num_global_batch * config.num_n_samples
    storage_units = {}
    storage_pg = get_placement_group(config.num_data_storage_units, num_cpus_per_actor=1)
    for rank in range(config.num_data_storage_units):
        unit_capacity = math.ceil(total_storage_size / config.num_data_storage_units)
        storage_units[rank] = TransferQueueStorageSimpleUnit.options(
            placement_group=storage_pg, placement_group_bundle_index=rank
        ).remote(storage_size=unit_capacity)
        logger.info(f"TransferQueueStorageSimpleUnit #{rank} has been created.")

    # 2. Initialize TransferQueueController. Several controller instances are
    #    supported for load balancing and large-scale scaling; different
    #    controllers can be assigned to different RL compute tasks.
    controller_pg = get_placement_group(config.num_data_controllers, num_cpus_per_actor=1)
    controllers = {}
    for rank in range(config.num_data_controllers):
        controllers[rank] = TransferQueueController.options(
            placement_group=controller_pg, placement_group_bundle_index=rank
        ).remote(
            num_storage_units=config.num_data_storage_units,
            global_batch_size=config.global_batch_size,
            num_global_batch=config.num_global_batch,
            num_n_samples=config.num_n_samples,
        )
        logger.info(f"TransferQueueController #{rank} has been created.")

    # 3. Register the controllers with every storage unit. Each storage unit
    #    receives the info of all controllers (IP + port obtained via Ray) and
    #    then sets up ZMQ sockets for messaging.
    controller_infos = process_zmq_server_info(controllers)
    storage_unit_infos = process_zmq_server_info(storage_units)
    ray.get([unit.register_controller_info.remote(controller_infos) for unit in storage_units.values()])

    # 4. Create the client.
    from transfer_queue import TransferQueueClient

    client = TransferQueueClient(
        client_id="Trainer",
        # TODO: the master client should be aware of all controllers, while the
        # WorkerGroup and Worker clients are aware of a single controller.
        controller_infos=controller_infos[0],
        storage_infos=storage_unit_infos,
    )

    return controllers, storage_units, client
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def generate_sequences(data, delay=3):
    """Stub for sequence generation.

    Blocks for ``delay`` seconds to imitate generation latency, then returns
    ``data`` unchanged.

    Args:
        data: Input passed straight through as the "generated" output.
        delay: Seconds to sleep (defaults to 3, the previously hard-coded
            value); pass 0 to skip the simulated latency.

    Returns:
        ``data``, unchanged.
    """
    time.sleep(delay)
    return data
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def compute_old_log_prob(data1, data2, delay=3):
    """Stub for the old-log-prob computation.

    Blocks for ``delay`` seconds to imitate compute latency, then returns
    ``data1`` unchanged. ``data2`` is accepted only for signature parity with a
    real implementation and is ignored.

    Args:
        data1: Value passed straight through as the result.
        data2: Unused.
        delay: Seconds to sleep (defaults to 3, the previously hard-coded
            value); pass 0 to skip the simulated latency.

    Returns:
        ``data1``, unchanged.
    """
    time.sleep(delay)
    return data1
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def actor_rollout_wg_generate_sequences(data_meta, data_system_client):
    """Mock the worker group's generate-sequences task with the synchronous client.

    Fetches the samples described by ``data_meta``, passes ``input_ids`` through
    the ``generate_sequences`` stub, writes ``generate_sequences_ids`` plus two
    demo fields (``non_tensor_data``, ``nested_tensor``) back to storage, and
    returns ``data_meta`` with those fields added.
    """
    # Step 1: pull the real data from the storage units via the client, per data_meta.
    batch = data_system_client.get_data(data_meta)
    logger.info(f"demo get data {batch}")

    sequences = generate_sequences(batch["input_ids"])
    rows = sequences.size(0)
    fields = {
        "generate_sequences_ids": sequences,
        "non_tensor_data": torch.stack([NonTensorData("test_str") for _ in range(rows)]),
        "nested_tensor": torch.nested.as_nested_tensor([torch.randn(1, 2) for _ in range(rows)]),
    }
    result = TensorDict(fields, batch_size=rows)

    # Step 2: write the results back to the storage units, per data_meta.
    data_system_client.put(data=result, metadata=data_meta)
    data_meta.add_fields(result)
    logger.info("demo put data to storages done")

    return data_meta
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def actor_rollout_wg_compute_old_log_prob(data_meta, data_system_client):
    """Mock the worker group's compute-old-log-prob task with the synchronous client.

    Fetches the samples described by ``data_meta``, feeds ``input_ids`` and
    ``generate_sequences_ids`` to the ``compute_old_log_prob`` stub, stores its
    output as ``old_log_prob``, and returns ``data_meta`` with that field added.
    """
    # Step 1: pull the real data from the storage units via the client, per data_meta.
    batch = data_system_client.get_data(data_meta)
    logger.info(f"demo get data {batch}")

    log_prob = compute_old_log_prob(batch["input_ids"], batch["generate_sequences_ids"])
    result = TensorDict({"old_log_prob": log_prob}, batch_size=log_prob.size(0))

    # Step 2: write the results back to the storage units, per data_meta.
    data_system_client.put(data=result, metadata=data_meta)
    data_meta.add_fields(result)
    logger.info("demo put data to storages done")

    return data_meta
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
# Simulate the fit function of the trainer
def fit(config, data_system_client):
    """Simulate one trainer step against the synchronous TransferQueue client.

    Puts a mock prompt batch (each prompt repeated ``config.num_n_samples``
    times), runs the mock generate-sequences and compute-old-log-prob stages on
    all samples of the step, merges their metadata, and clears the step's data.

    Args:
        config: Demo config; reads ``global_batch_size`` and ``num_n_samples``.
        data_system_client: A synchronous ``TransferQueueClient``.
    """
    # Each stage consumes every sample of the step: global_batch_size prompts,
    # each repeated num_n_samples times (the same count the storage is sized for
    # in initialize_data_system). Requesting only global_batch_size would leave
    # half of the samples unprocessed.
    samples_per_step = config.global_batch_size * config.num_n_samples

    for _epoch in range(1):
        train_dataloader = 1
        for step in range(train_dataloader):
            input_ids = (torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [10, 11], [100, 111]])) * (step + 1)
            input_ids_repeated = torch.repeat_interleave(input_ids, config.num_n_samples, dim=0)
            prompt_batch = TensorDict(
                {"input_ids": input_ids_repeated, "attention_mask": input_ids_repeated},
                batch_size=input_ids_repeated.size(0),
            )

            data_system_client.put(data=prompt_batch, global_step=step)
            logger.info("demo put prompts ok! ")
            time.sleep(5)

            batch_meta = data_system_client.get_meta(
                data_fields=["input_ids", "attention_mask"],
                batch_size=samples_per_step,
                global_step=step,
                get_n_samples=False,
                task_name="generate_sequences",
            )
            logger.info(f"demo get meta {batch_meta}")

            # Simulate calling the generate sequences task of the worker group
            batch_meta = actor_rollout_wg_generate_sequences(batch_meta, data_system_client)

            log_prob_meta = data_system_client.get_meta(
                data_fields=["input_ids", "attention_mask", "generate_sequences_ids"],
                batch_size=samples_per_step,
                # Must be the step the prompts were put under, not a fixed 0.
                global_step=step,
                get_n_samples=False,
                task_name="compute_old_log_prob",
            )
            logger.info(f"demo get log prob meta: {log_prob_meta}")

            # Simulate calling the compute old log prob task of the worker group
            old_log_prob_meta = actor_rollout_wg_compute_old_log_prob(log_prob_meta, data_system_client)

            batch_meta = batch_meta.union(old_log_prob_meta)

            # The master client asks all controllers to clear this step's data
            # state; the master controller returns metadata, which the client uses
            # to tell every storage unit to clear. Other controllers just clear
            # without returning metadata.
            data_system_client.clear(global_step=step)
            logger.info("clear ok! ")
            logger.info("demo done!")
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def main(config):
    """Launch the TransferQueue data system on Ray and run the simulated fit loop.

    Args:
        config: Demo config passed to ``initialize_data_system`` and ``fit``.
    """
    # Initialize the data system: launch the controllers and storage units via Ray.
    # The controller / storage-unit handles are intentionally kept as locals for the
    # duration of fit(): Ray terminates a (non-detached) actor once every handle to
    # it has gone out of scope.
    data_system_controllers, data_system_storage_units, data_system_client = initialize_data_system(config)

    # Fixed pause before issuing requests (presumably to let the freshly started
    # actors finish setting up). `time` is already imported at module level.
    time.sleep(5)

    fit(config, data_system_client)
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
if __name__ == "__main__":
    config_str = """
    global_batch_size: 6
    num_global_batch: 1
    num_data_storage_units: 2
    num_data_controllers: 1
    num_n_samples: 2
    """
    dict_conf = OmegaConf.create(config_str)

    try:
        main(dict_conf)
    finally:
        # Tear down the Ray runtime started by ray.init() at import time, even
        # when the demo raises.
        ray.shutdown()
|