TransferQueue 0.1.1.dev0 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- recipe/simple_use_case/async_demo.py +331 -0
- recipe/simple_use_case/sync_demo.py +220 -0
- tests/test_async_simple_storage_manager.py +339 -0
- tests/test_client.py +423 -0
- tests/test_controller.py +274 -0
- tests/test_controller_data_partitions.py +513 -0
- tests/test_kv_storage_manager.py +92 -0
- tests/test_put.py +327 -0
- tests/test_samplers.py +492 -0
- tests/test_serial_utils_on_cpu.py +202 -0
- tests/test_simple_storage_unit.py +443 -0
- tests/test_storage_client_factory.py +45 -0
- transfer_queue/__init__.py +48 -0
- transfer_queue/client.py +611 -0
- transfer_queue/controller.py +1187 -0
- transfer_queue/metadata.py +460 -0
- transfer_queue/sampler/__init__.py +19 -0
- transfer_queue/sampler/base.py +74 -0
- transfer_queue/sampler/grpo_group_n_sampler.py +157 -0
- transfer_queue/sampler/sequential_sampler.py +75 -0
- transfer_queue/storage/__init__.py +25 -0
- transfer_queue/storage/clients/__init__.py +24 -0
- transfer_queue/storage/clients/base.py +22 -0
- transfer_queue/storage/clients/factory.py +55 -0
- transfer_queue/storage/clients/yuanrong_client.py +118 -0
- transfer_queue/storage/managers/__init__.py +23 -0
- transfer_queue/storage/managers/base.py +460 -0
- transfer_queue/storage/managers/factory.py +43 -0
- transfer_queue/storage/managers/simple_backend_manager.py +611 -0
- transfer_queue/storage/managers/yuanrong_manager.py +18 -0
- transfer_queue/storage/simple_backend.py +451 -0
- transfer_queue/utils/__init__.py +13 -0
- transfer_queue/utils/serial_utils.py +240 -0
- transfer_queue/utils/utils.py +132 -0
- transfer_queue/utils/zmq_utils.py +170 -0
- transfer_queue/version/version +1 -0
- transferqueue-0.1.1.dev0.dist-info/METADATA +327 -0
- transferqueue-0.1.1.dev0.dist-info/RECORD +41 -0
- transferqueue-0.1.1.dev0.dist-info/WHEEL +5 -0
- transferqueue-0.1.1.dev0.dist-info/licenses/LICENSE +202 -0
- transferqueue-0.1.1.dev0.dist-info/top_level.txt +4 -0
recipe/simple_use_case/async_demo.py
@@ -0,0 +1,331 @@
# Copyright 2025 The TransferQueue Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import logging
import math
import os
import sys
import time
from pathlib import Path

import ray
import torch
from omegaconf import OmegaConf
from tensordict import NonTensorData, TensorDict

parent_dir = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(parent_dir))

from transfer_queue import (  # noqa: E402
    AsyncTransferQueueClient,
    BatchMeta,
    SimpleStorageUnit,
    TransferQueueController,
    process_zmq_server_info,
)
from transfer_queue.utils.utils import get_placement_group  # noqa: E402

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

os.environ["RAY_DEDUP_LOGS"] = "0"
os.environ["RAY_DEBUG"] = "1"
ray.init()

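
# The two helpers below are stand-ins for real model computation: each simply
# sleeps for a few seconds to simulate the latency of generation or a forward pass.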
def compute_old_log_prob(data1, data2):
    time.sleep(3)
    return data1


def generate_sequences(data):
    time.sleep(3)
    return data

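
# Worker-group tasks receive only BatchMeta (metadata); the actual tensors are
# pulled from, and written back to, the storage plane through the client.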
class ActorRolloutRefWorker:
    def actor_rollout_wg_generate_sequences(self, data_meta, data_system_client):
        # 1. Pull the real data from the storage plane through the client, based on data_meta
        data = asyncio.run(data_system_client.async_get_data(data_meta))
        logger.info(f"demo get data->generate_sequences {data}")

        output = generate_sequences(data["input_ids"])

        output = TensorDict(
            {
                "generate_sequences_ids": output,
                "non_tensor_data": torch.stack([NonTensorData("test_str") for _ in range(output.size(0))]),
                "nested_tensor": torch.nested.as_nested_tensor([torch.randn(1, 2) for _ in range(output.size(0))]),
            },
            batch_size=output.size(0),
        )

        # 2. Write the results back to the storage plane, based on data_meta
        asyncio.run(data_system_client.async_put(data=output, metadata=data_meta))
        data_meta.add_fields(output)
        logger.info("demo put data to storages done")

        return data_meta

    def actor_rollout_wg_compute_old_log_prob(self, data_meta, data_system_client):
        # 1. Pull the real data from the storage plane through the client, based on data_meta
        data = asyncio.run(data_system_client.async_get_data(data_meta))
        logger.info(f"demo get data->old_log_prob {data}")

        output = compute_old_log_prob(data["input_ids"], data["generate_sequences_ids"])

        output = TensorDict({"old_log_prob": output}, batch_size=output.size(0))

        # 2. Write the results back to the storage plane, based on data_meta
        asyncio.run(data_system_client.async_put(data=output, metadata=data_meta))
        data_meta.add_fields(output)
        logger.info("demo put data to storages done")

        return data_meta

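
# Each Ray actor constructs its own AsyncTransferQueueClient: client instances hold
# ZMQ objects and cannot be serialized and shipped between processes (see the note
# in Trainer._initialize_data_system below).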
@ray.remote
class AsyncvLLMServer:
    def __init__(self, config, data_system_controller_info):
        self.config = config
        self.data_system_client = AsyncTransferQueueClient(
            client_id="AsyncvLLMServer",
            controller_info=data_system_controller_info,
        )

        self.data_system_client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=self.config)

    async def generate(self, data_meta):
        data = await self.data_system_client.async_get_data(data_meta)
        logger.info(f"demo get data->generate_sequences {data}")

        data = data["input_ids"]
        data += 1
        await asyncio.sleep(3)

        output = TensorDict(
            {
                "generate_sequences_ids": data,
                "non_tensor_data": torch.stack([NonTensorData("test_str") for _ in range(data.size(0))]),
                "nested_tensor": torch.nested.as_nested_tensor([torch.randn(1, 2) for _ in range(data.size(0))]),
            },
            batch_size=data.size(0),
        )

        await self.data_system_client.async_put(data=output, metadata=data_meta)
        logger.info("demo Async Server put data to storages done")

        return data_meta

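
# AsyncRolloutWorker fans a BatchMeta chunk out into one request per sample against
# its vLLM server actor, then gathers the finished metadata back into a single BatchMeta.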
@ray.remote(num_cpus=1)
class AsyncRolloutWorker:
    def __init__(
        self,
        config,
        data_system_controller_info,
    ):
        self.async_vllm_server = AsyncvLLMServer.remote(
            config,
            data_system_controller_info,
        )

    async def generate_sequences(self, data_meta_chunk):
        tasks = []
        for i in range(data_meta_chunk.size):
            # asyncio.create_task cannot wrap a Ray actor method call directly;
            # doing so raises "a coroutine was expected, got ObjectRef(...)",
            # so we go through the local coroutine self.generate instead.
            tasks.append(asyncio.create_task(self.generate(data_meta_chunk[i])))
        data_metas = await asyncio.gather(*tasks)
        return BatchMeta.concat(data_metas)

    async def generate(self, data_meta):
        data_meta_new = await self.async_vllm_server.generate.remote(data_meta)
        return data_meta_new

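
# RolloutManager splits an incoming BatchMeta evenly across the async rollout
# workers and concatenates the per-worker results.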
class RolloutManager:
    def __init__(self, config, data_system_storage_unit_infos, data_system_controller_info):
        self.config = config

        self.data_system_client = AsyncTransferQueueClient(
            client_id="RolloutManager",
            controller_info=data_system_controller_info,
        )

        self.data_system_client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=self.config)

        self.async_rollout_workers = []
        num_workers = self.config.rollout_agent_num_workers
        for _ in range(num_workers):
            self.async_rollout_workers.append(AsyncRolloutWorker.remote(config, data_system_controller_info))

    def generate_sequences(self, data_meta):
        data_meta_chunks = data_meta.chunk(len(self.async_rollout_workers))
        data_metas = ray.get(
            [
                worker.generate_sequences.remote(data_meta_chunk)
                for worker, data_meta_chunk in zip(self.async_rollout_workers, data_meta_chunks, strict=True)
            ]
        )
        batch_meta = BatchMeta.concat(data_metas)
        logger.info(f"batch_meta: {batch_meta}")

        return batch_meta

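
# The Trainer wires everything together: it launches the storage units and the
# controller, creates its own client, and drives either the synchronous worker
# group or the async rollout manager.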
class Trainer:
    def __init__(self, config):
        self.config = config
        self.data_system_client = self._initialize_data_system()
        self.actor_rollout_wg = ActorRolloutRefWorker()
        self.async_rollout_manager = RolloutManager(
            self.config,
            self.data_system_storage_unit_infos,
            self.data_system_controller_info,
        )

    def _initialize_data_system(self):
        # 1. Initialize TransferQueueStorage
        total_storage_size = self.config.global_batch_size * self.config.num_global_batch * self.config.num_n_samples
        self.data_system_storage_units = {}
        storage_placement_group = get_placement_group(self.config.num_data_storage_units, num_cpus_per_actor=1)
        for storage_unit_rank in range(self.config.num_data_storage_units):
            storage_node = SimpleStorageUnit.options(
                placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank
            ).remote(storage_unit_size=math.ceil(total_storage_size / self.config.num_data_storage_units))
            self.data_system_storage_units[storage_unit_rank] = storage_node
            logger.info(f"SimpleStorageUnit #{storage_unit_rank} has been created.")

        # 2. Initialize TransferQueueController (single controller only)

        # Sampler usage instructions:
        # For GRPO grouped sampling, you can initialize the controller with GRPOGroupNSampler.
        # Option 1: Pass the sampler class (it will be instantiated automatically)
        #   self.data_system_controller = TransferQueueController.remote(sampler=GRPOGroupNSampler)
        # Option 2: Pass a sampler instance (if you need custom configuration)
        #   grpo_sampler = GRPOGroupNSampler()
        #   self.data_system_controller = TransferQueueController.remote(sampler=grpo_sampler)
        # Then pass sampling_config in get_meta calls, e.g.:
        #   sampling_config={"n_samples_per_prompt": 4}
        self.data_system_controller = TransferQueueController.remote()
        logger.info("TransferQueueController has been created.")

        # 3. Prepare necessary information
        self.data_system_controller_info = process_zmq_server_info(self.data_system_controller)
        self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units)

        # Note: we need to create a new DictConfig with allow_objects=True to keep the
        # ZMQServerInfo instances intact; otherwise they would be flattened into plain dicts.
        tq_config = OmegaConf.create({}, flags={"allow_objects": True})
        tq_config.controller_info = self.data_system_controller_info
        tq_config.storage_unit_infos = self.data_system_storage_unit_infos
        self.config = OmegaConf.merge(tq_config, self.config)

        # 4. Create client
        self.data_system_client = AsyncTransferQueueClient(
            client_id="Trainer",
            controller_info=self.data_system_controller_info,
        )

        self.data_system_client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=self.config)
        # Note: the client contains ZMQ objects, so the same client instance cannot be
        # transmitted to multiple places; doing so causes serialization errors in Ray.
        # Workaround: if you need a client in multiple Ray actors or processes, create a
        # separate AsyncTransferQueueClient instance for each actor/process instead of
        # sharing or transmitting the same instance.
        return self.data_system_client

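    # fit() walks one step of an RL-style loop: put prompts into the queue, fetch
    # BatchMeta for generation, run rollout (sync worker group or async rollout
    # manager), fetch metadata for log-prob computation, union the metadata, and
    # finally clear the partition.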
    def fit(self):
        for epoch in range(1):
            # Stand-in for a real dataloader: a single step with hand-crafted prompts
            train_dataloader = 1
            for step in range(train_dataloader):
                input_ids = (
                    torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [10, 11], [100, 111], [200, 222], [300, 333]])
                ) * (step + 1)
                input_ids_repeated = torch.repeat_interleave(input_ids, self.config.num_n_samples, dim=0)
                prompt_batch = TensorDict(
                    {"input_ids": input_ids_repeated, "attention_mask": input_ids_repeated},
                    batch_size=input_ids_repeated.size(0),
                )

                asyncio.run(self.data_system_client.async_put(data=prompt_batch, partition_id=f"train_{step}"))

                logger.info("demo put prompts ok!")
                time.sleep(5)

                batch_meta = asyncio.run(
                    self.data_system_client.async_get_meta(
                        data_fields=["input_ids", "attention_mask"],
                        batch_size=self.config.global_batch_size * self.config.num_n_samples,
                        partition_id=f"train_{step}",
                        task_name="generate_sequences",
                    )
                )
                logger.info(f"demo get meta {batch_meta}")

                # Simulate calling the generate-sequences task of the worker group
                if not self.config.async_rollout_mode:
                    batch_meta = self.actor_rollout_wg.actor_rollout_wg_generate_sequences(
                        batch_meta, self.data_system_client
                    )
                else:
                    batch_meta = self.async_rollout_manager.generate_sequences(batch_meta)

                log_prob_meta = asyncio.run(
                    self.data_system_client.async_get_meta(
                        data_fields=["input_ids", "attention_mask", "generate_sequences_ids"],
                        batch_size=self.config.global_batch_size * self.config.num_n_samples,
                        partition_id=f"train_{step}",
                        task_name="compute_old_log_prob",
                    )
                )
                logger.info(f"demo get log prob meta: {log_prob_meta}")

                # Simulate calling the compute-old-log-prob task of the worker group
                old_log_prob_meta = self.actor_rollout_wg.actor_rollout_wg_compute_old_log_prob(
                    log_prob_meta, self.data_system_client
                )

                batch_meta = batch_meta.union(old_log_prob_meta)

                # The client notifies the controller to clear the data status; the controller
                # returns metadata, and the client then notifies the storage plane to clear
                # based on that metadata.
                asyncio.run(self.data_system_client.async_clear(partition_id=f"train_{step}"))
                logger.info("clear ok!")
                logger.info("demo done!")

                # Cleanup resources
                self.data_system_client.close()
                return batch_meta


if __name__ == "__main__":
    # NOTE: you may choose to set async_rollout_mode=True to test the async rollout
    # mode that mimics the AgentLoopManager in verl
    config_str = """
    global_batch_size: 8
    num_global_batch: 1
    num_data_storage_units: 2
    async_rollout_mode: True
    rollout_agent_num_workers: 2
    num_n_samples: 2
    """
    dict_conf = OmegaConf.create(config_str)

    trainer = Trainer(dict_conf)
    trainer.fit()

    ray.shutdown()

recipe/simple_use_case/sync_demo.py
@@ -0,0 +1,220 @@
# Copyright 2025 The TransferQueue Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import math
import os
import sys
import time
from pathlib import Path

import ray
import torch
from omegaconf import OmegaConf
from tensordict import NonTensorData, TensorDict

parent_dir = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(parent_dir))

from transfer_queue import (  # noqa: E402
    SimpleStorageUnit,
    TransferQueueClient,
    TransferQueueController,
    process_zmq_server_info,
)
from transfer_queue.utils.utils import get_placement_group  # noqa: E402

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

os.environ["RAY_DEDUP_LOGS"] = "0"
os.environ["RAY_DEBUG"] = "1"
ray.init()

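
# This synchronous demo mirrors async_demo.py, but drives the data system through
# the blocking TransferQueueClient API (put / get_meta / get_data / clear).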
def initialize_data_system(config):
    # 1. Initialize TransferQueueStorage
    total_storage_size = config.global_batch_size * config.num_global_batch * config.num_n_samples
    data_system_storage_units = {}
    storage_placement_group = get_placement_group(config.num_data_storage_units, num_cpus_per_actor=1)
    for storage_unit_rank in range(config.num_data_storage_units):
        storage_node = SimpleStorageUnit.options(
            placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank
        ).remote(storage_unit_size=math.ceil(total_storage_size / config.num_data_storage_units))
        data_system_storage_units[storage_unit_rank] = storage_node
        logger.info(f"SimpleStorageUnit #{storage_unit_rank} has been created.")

    # 2. Initialize TransferQueueController (single controller only)

    # Sampler usage instructions:
    # For GRPO grouped sampling, you can initialize the controller with GRPOGroupNSampler.
    # Option 1: Pass the sampler class (it will be instantiated automatically)
    #   data_system_controller = TransferQueueController.remote(sampler=GRPOGroupNSampler)
    # Option 2: Pass a sampler instance (if you need custom configuration)
    #   grpo_sampler = GRPOGroupNSampler()
    #   data_system_controller = TransferQueueController.remote(sampler=grpo_sampler)
    # Then pass sampling_config in get_meta calls, e.g.:
    #   sampling_config={"n_samples_per_prompt": 4}
    data_system_controller = TransferQueueController.remote()
    logger.info("TransferQueueController has been created.")

    # 3. Prepare necessary information
    data_system_controller_info = process_zmq_server_info(data_system_controller)
    data_system_storage_unit_infos = process_zmq_server_info(data_system_storage_units)

    # Note: we need to create a new DictConfig with allow_objects=True to keep the
    # ZMQServerInfo instances intact; otherwise they would be flattened into plain dicts.
    tq_config = OmegaConf.create({}, flags={"allow_objects": True})
    tq_config.controller_info = data_system_controller_info
    tq_config.storage_unit_infos = data_system_storage_unit_infos
    config = OmegaConf.merge(tq_config, config)

    # 4. Create client
    data_system_client = TransferQueueClient(
        client_id="Trainer",
        controller_info=data_system_controller_info,
    )

    data_system_client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config)

    return data_system_controller, data_system_storage_units, data_system_client

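
# Stand-ins for real model computation: each helper sleeps to simulate latency.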
def generate_sequences(data):
    time.sleep(3)
    return data


def compute_old_log_prob(data1, _data2):
    time.sleep(3)
    return data1


def actor_rollout_wg_generate_sequences(data_meta, data_system_client):
    # 1. Pull the real data from the storage plane through the client, based on data_meta
    data = data_system_client.get_data(data_meta)
    logger.info(f"demo get data {data}")

    output = generate_sequences(data["input_ids"])

    output = TensorDict(
        {
            "generate_sequences_ids": output,
            "non_tensor_data": torch.stack([NonTensorData("test_str") for _ in range(output.size(0))]),
            "nested_tensor": torch.nested.as_nested_tensor([torch.randn(1, 2) for _ in range(output.size(0))]),
        },
        batch_size=output.size(0),
    )

    # 2. Write the results back to the storage plane, based on data_meta
    data_system_client.put(data=output, metadata=data_meta)
    data_meta.add_fields(output)
    logger.info("demo put data to storages done")

    return data_meta


def actor_rollout_wg_compute_old_log_prob(data_meta, data_system_client):
    # 1. Pull the real data from the storage plane through the client, based on data_meta
    data = data_system_client.get_data(data_meta)
    logger.info(f"demo get data {data}")

    output = compute_old_log_prob(data["input_ids"], data["generate_sequences_ids"])

    output = TensorDict({"old_log_prob": output}, batch_size=output.size(0))

    # 2. Write the results back to the storage plane, based on data_meta
    data_system_client.put(data=output, metadata=data_meta)
    data_meta.add_fields(output)
    logger.info("demo put data to storages done")

    return data_meta

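
# The flow below is the synchronous counterpart of Trainer.fit in async_demo.py:
# put prompts -> get_meta -> run tasks -> union metadata -> clear the partition.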
# Simulate the fit function of the trainer
def fit(config, data_system_client):
    for _epoch in range(1):
        train_dataloader = 1
        for step in range(train_dataloader):
            input_ids = (torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [10, 11], [100, 111]])) * (step + 1)
            input_ids_repeated = torch.repeat_interleave(input_ids, config.num_n_samples, dim=0)
            prompt_batch = TensorDict(
                {"input_ids": input_ids_repeated, "attention_mask": input_ids_repeated},
                batch_size=input_ids_repeated.size(0),
            )

            data_system_client.put(data=prompt_batch, partition_id=f"train_{step}")
            logger.info("demo put prompts ok!")
            time.sleep(5)

            # Set output fields for RL training - here we want to generate sequences from input_ids
            batch_meta = data_system_client.get_meta(
                data_fields=["input_ids", "attention_mask"],
                batch_size=config.global_batch_size,
                partition_id=f"train_{step}",
                task_name="generate_sequences",
            )
            logger.info(f"demo get meta {batch_meta}")

            # Simulate calling the generate-sequences task of the worker group
            batch_meta = actor_rollout_wg_generate_sequences(batch_meta, data_system_client)

            # Set output fields for RL training - we want to compute log probs for the generated sequences
            log_prob_meta = data_system_client.get_meta(
                data_fields=["input_ids", "attention_mask", "generate_sequences_ids"],
                batch_size=config.global_batch_size,
                partition_id=f"train_{step}",
                task_name="compute_old_log_prob",
            )
            logger.info(f"demo get log prob meta: {log_prob_meta}")

            # Simulate calling the compute-old-log-prob task of the worker group
            old_log_prob_meta = actor_rollout_wg_compute_old_log_prob(log_prob_meta, data_system_client)

            batch_meta = batch_meta.union(old_log_prob_meta)

            # The client picks one master controller, which clears the data status and
            # returns metadata; the other controllers clear directly without returning
            # metadata. The client then notifies the storage plane to clear based on
            # that metadata.
            data_system_client.clear(partition_id=f"train_{step}")
            logger.info("clear ok!")
            logger.info("demo done!")


def main(config):
    # Initialize the data system: launch the controller and storage units on Ray
    _data_system_controller, _data_system_storage_units, data_system_client = initialize_data_system(config)

    # Give the Ray actors a moment to come up before issuing requests
    # (time is already imported at module level)
    time.sleep(5)

    fit(config, data_system_client)

    # Cleanup resources
    data_system_client.close()


if __name__ == "__main__":
    config_str = """
    global_batch_size: 6
    num_global_batch: 1
    num_data_storage_units: 2
    num_n_samples: 2
    """
    dict_conf = OmegaConf.create(config_str)

    main(dict_conf)

    ray.shutdown()