TransferQueue 0.0.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,307 @@
1
+ # Copyright 2025 The TransferQueue Team
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import asyncio
16
+ import logging
17
+ import math
18
+ import os
19
+ import sys
20
+ import time
21
+ from pathlib import Path
22
+
23
+ import ray
24
+ import torch
25
+ from omegaconf import OmegaConf
26
+ from tensordict import NonTensorData, TensorDict
27
+
28
+ parent_dir = Path(__file__).resolve().parent.parent.parent
29
+ sys.path.append(str(parent_dir))
30
+
31
+ from transfer_queue import ( # noqa: E402
32
+ AsyncTransferQueueClient,
33
+ BatchMeta,
34
+ TransferQueueController,
35
+ TransferQueueStorageSimpleUnit,
36
+ process_zmq_server_info,
37
+ )
38
+ from transfer_queue.utils.utils import get_placement_group # noqa: E402
39
+
40
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
41
+ logger = logging.getLogger(__name__)
42
+
43
+ os.environ["RAY_DEDUP_LOGS"] = "0"
44
+ os.environ["RAY_DEBUG"] = "1"
45
+ ray.init()
46
+
47
+
48
def compute_old_log_prob(data1, data2):
    """Placeholder for the real old-log-prob computation.

    Pauses to mimic a model forward pass and returns *data1* unchanged;
    *data2* is accepted only to match the real task's interface.
    """
    simulated_latency_s = 3
    time.sleep(simulated_latency_s)
    return data1
51
+
52
+
53
def generate_sequences(data):
    """Placeholder rollout step: mimics generation latency, then echoes its input."""
    simulated_latency_s = 3
    time.sleep(simulated_latency_s)
    return data
56
+
57
+
58
class ActorRolloutRefWorker:
    """Synchronous demo worker: pulls data via the client, runs a fake task, writes results back."""

    def actor_rollout_wg_generate_sequences(self, data_meta, data_system_client):
        """Fetch the tensors described by *data_meta*, run fake generation, and store the outputs."""
        # Step 1: pull the actual data from the storage units through the client.
        batch = asyncio.run(data_system_client.async_get_data(data_meta))
        logger.info(f"demo get data->generate_sequences {batch}")

        generated = generate_sequences(batch["input_ids"])
        n_rows = generated.size(0)

        result = TensorDict(
            {
                "generate_sequences_ids": generated,
                "non_tensor_data": torch.stack([NonTensorData("test_str") for _ in range(n_rows)]),
                "nested_tensor": torch.nested.as_nested_tensor([torch.randn(1, 2) for _ in range(n_rows)]),
            },
            batch_size=n_rows,
        )

        # Step 2: write the results back to the storage units under the same metadata.
        asyncio.run(data_system_client.async_put(data=result, metadata=data_meta))
        data_meta.add_fields(result)
        logger.info("demo put data to storages done")

        return data_meta

    def actor_rollout_wg_compute_old_log_prob(self, data_meta, data_system_client):
        """Fetch prompt + generation tensors, run the fake log-prob task, and store the result."""
        # Step 1: pull the actual data from the storage units through the client.
        batch = asyncio.run(data_system_client.async_get_data(data_meta))
        logger.info(f"demo get data->old_log_prob {batch}")

        log_probs = compute_old_log_prob(batch["input_ids"], batch["generate_sequences_ids"])
        result = TensorDict({"old_log_prob": log_probs}, batch_size=log_probs.size(0))

        # Step 2: write the results back to the storage units under the same metadata.
        asyncio.run(data_system_client.async_put(data=result, metadata=data_meta))
        data_meta.add_fields(result)
        logger.info("demo put data to storages done")

        return data_meta
97
+
98
+
99
@ray.remote
class AsyncvLLMServer:
    """Ray actor standing in for an asynchronous vLLM server in this demo."""

    def __init__(self, data_system_client):
        # Client used to exchange tensors with the TransferQueue storage units.
        self.data_system_client = data_system_client

    async def generate(self, data_meta):
        """Fetch inputs for *data_meta*, fake one generation step, and persist the outputs."""
        batch = await self.data_system_client.async_get_data(data_meta)
        logger.info(f"demo get data->generate_sequences {batch}")

        ids = batch["input_ids"]
        ids += 1  # in-place bump stands in for real token generation
        await asyncio.sleep(3)  # simulate inference latency

        n_rows = ids.size(0)
        result = TensorDict(
            {
                "generate_sequences_ids": ids,
                "non_tensor_data": torch.stack([NonTensorData("test_str") for _ in range(n_rows)]),
                "nested_tensor": torch.nested.as_nested_tensor([torch.randn(1, 2) for _ in range(n_rows)]),
            },
            batch_size=n_rows,
        )

        await self.data_system_client.async_put(data=result, metadata=data_meta)
        logger.info("demo Async Server put data to storages done")

        return data_meta
125
+
126
+
127
@ray.remote(num_cpus=1)
class AsyncRolloutWorker:
    """Ray actor that fans a metadata chunk out to its own async server, one sample at a time."""

    def __init__(self, data_system_client):
        # Each rollout worker owns a dedicated server actor.
        self.async_vllm_server = AsyncvLLMServer.remote(data_system_client)

    async def generate_sequences(self, data_meta_chunk):
        """Run generation concurrently for every sample in *data_meta_chunk* and merge results."""
        # asyncio.create_task cannot wrap a Ray actor call directly (it would receive an
        # ObjectRef instead of a coroutine), so each call is routed through self.generate.
        tasks = [
            asyncio.create_task(self.generate(data_meta_chunk[idx])) for idx in range(data_meta_chunk.size)
        ]
        data_metas = await asyncio.gather(*tasks)
        return BatchMeta.concat(data_metas)

    async def generate(self, data_meta):
        """Await the server actor's generate call for a single sample."""
        return await self.async_vllm_server.generate.remote(data_meta)
144
+
145
+
146
class RolloutManager:
    """Drives a pool of AsyncRolloutWorker actors and merges their results."""

    def __init__(self, config, data_system_client):
        self.config = config
        self.data_system_client = data_system_client
        # Spawn one Ray worker actor per configured rollout agent.
        self.async_rollout_workers = [
            AsyncRolloutWorker.remote(self.data_system_client)
            for _ in range(self.config.rollout_agent_num_workers)
        ]

    def generate_sequences(self, data_meta):
        """Split *data_meta* across the workers, run them in parallel, and concat the outputs."""
        chunks = data_meta.chunk(len(self.async_rollout_workers))
        futures = [
            worker.generate_sequences.remote(chunk)
            for worker, chunk in zip(self.async_rollout_workers, chunks, strict=True)
        ]
        merged = BatchMeta.concat(ray.get(futures))
        logger.info(f"batch_meta: {merged}")

        return merged
167
+
168
+
169
class Trainer:
    """Demo trainer that wires up the TransferQueue data system and runs one fake RL iteration."""

    def __init__(self, config):
        self.config = config
        self.data_system_client = self._initialize_data_system()
        self.actor_rollout_wg = ActorRolloutRefWorker()
        self.async_rollout_manager = RolloutManager(self.config, self.data_system_client)

    def _initialize_data_system(self):
        """Create storage units, controllers, and the async client, and wire them together.

        Returns the :class:`AsyncTransferQueueClient` used by all subsequent demo tasks.
        """
        # 1. Initialize TransferQueueStorage.
        total_storage_size = self.config.global_batch_size * self.config.num_global_batch * self.config.num_n_samples
        self.data_system_storage_units = {}
        storage_placement_group = get_placement_group(self.config.num_data_storage_units, num_cpus_per_actor=1)
        for storage_unit_rank in range(self.config.num_data_storage_units):
            # TransferQueueStorage is launched through Ray: it is a ray.remote-decorated class.
            storage_node = TransferQueueStorageSimpleUnit.options(
                placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank
            ).remote(storage_size=math.ceil(total_storage_size / self.config.num_data_storage_units))
            self.data_system_storage_units[storage_unit_rank] = storage_node
            logger.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.")

        # 2. Initialize TransferQueueController.
        # Multiple controller instances are supported for load balancing and large-scale
        # deployments; different controllers can be assigned to different RL compute tasks.
        self.data_system_controllers = {}
        controller_placement_group = get_placement_group(self.config.num_data_controllers, num_cpus_per_actor=1)
        for controller_rank in range(self.config.num_data_controllers):
            self.data_system_controllers[controller_rank] = TransferQueueController.options(
                placement_group=controller_placement_group, placement_group_bundle_index=controller_rank
            ).remote(
                num_storage_units=self.config.num_data_storage_units,
                global_batch_size=self.config.global_batch_size,
                num_global_batch=self.config.num_global_batch,
                num_n_samples=self.config.num_n_samples,
            )
            logger.info(f"TransferQueueController #{controller_rank} has been created.")

        # 3. Register the controllers with every storage unit.
        # Each storage unit receives all controller handles, resolves their IP+port via Ray,
        # then opens ZMQ sockets for message transport.
        self.data_system_controller_infos = process_zmq_server_info(self.data_system_controllers)
        self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units)

        ray.get(
            [
                storage_unit.register_controller_info.remote(self.data_system_controller_infos)
                for storage_unit in self.data_system_storage_units.values()
            ]
        )

        # 4. Create the client.
        self.data_system_client = AsyncTransferQueueClient(
            client_id="Trainer",
            controller_infos=self.data_system_controller_infos[0],
            storage_infos=self.data_system_storage_unit_infos,
        )

        return self.data_system_client

    def fit(self):
        """Run the demo loop: put prompts, generate, compute log-probs, then clear the step."""
        for epoch in range(1):
            train_dataloader = 1
            for step in range(train_dataloader):
                input_ids = (
                    torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [10, 11], [100, 111], [200, 222], [300, 333]])
                ) * (step + 1)
                input_ids_repeated = torch.repeat_interleave(input_ids, self.config.num_n_samples, dim=0)
                prompt_batch = TensorDict(
                    {"input_ids": input_ids_repeated, "attention_mask": input_ids_repeated},
                    batch_size=input_ids_repeated.size(0),
                )

                asyncio.run(self.data_system_client.async_put(data=prompt_batch, global_step=step))

                logger.info("demo put prompts ok! ")
                time.sleep(5)  # give the storage units time to settle before reading metadata

                batch_meta = asyncio.run(
                    self.data_system_client.async_get_meta(
                        data_fields=["input_ids", "attention_mask"],
                        batch_size=self.config.global_batch_size * self.config.num_n_samples,
                        global_step=step,
                        get_n_samples=False,
                        task_name="generate_sequences",
                    )
                )
                logger.info(f"demo get meta {batch_meta}")

                # Simulate calling the generate sequences task of the worker group.
                if not self.config.async_rollout_mode:
                    batch_meta = self.actor_rollout_wg.actor_rollout_wg_generate_sequences(
                        batch_meta, self.data_system_client
                    )
                else:
                    batch_meta = self.async_rollout_manager.generate_sequences(batch_meta)
                log_prob_meta = asyncio.run(
                    self.data_system_client.async_get_meta(
                        data_fields=["input_ids", "attention_mask", "generate_sequences_ids"],
                        batch_size=self.config.global_batch_size * self.config.num_n_samples,
                        global_step=step,
                        get_n_samples=False,
                        task_name="compute_old_log_prob",
                    )
                )
                logger.info(f"demo get log prob meta: {log_prob_meta}")

                # Simulate calling the compute old log prob task of the worker group.
                old_log_prob_meta = self.actor_rollout_wg.actor_rollout_wg_compute_old_log_prob(
                    log_prob_meta, self.data_system_client
                )

                batch_meta = batch_meta.union(old_log_prob_meta)

                # The main-control client asks all controllers to clear their data state; the
                # primary controller returns metadata, which the client then uses to tell every
                # storage unit to clear. Only the primary controller needs to return metadata;
                # the other controllers simply clear without returning anything.
                asyncio.run(self.data_system_client.async_clear(global_step=step))
                logger.info("clear ok! ")
        logger.info("demo done!")
        return batch_meta
286
+
287
+
288
if __name__ == "__main__":
    # NOTE: you may choose to set async_rollout_mode=True to test the async rollout mode that mimics
    # AgentLoopManager in verl

    # YAML demo configuration parsed by OmegaConf below.
    config_str = """
global_batch_size: 8
num_global_batch: 1
num_data_storage_units: 2
num_data_controllers: 1
async_rollout_mode: True
rollout_agent_num_workers: 2
num_n_samples: 2
"""
    dict_conf = OmegaConf.create(config_str)

    trainer = Trainer(dict_conf)
    trainer.fit()

    ray.shutdown()
@@ -0,0 +1,223 @@
1
+ # Copyright 2025 The TransferQueue Team
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import logging
16
+ import math
17
+ import os
18
+ import sys
19
+ import time
20
+ from pathlib import Path
21
+
22
+ import ray
23
+ import torch
24
+ from omegaconf import OmegaConf
25
+ from tensordict import NonTensorData, TensorDict
26
+
27
+ parent_dir = Path(__file__).resolve().parent.parent.parent
28
+ sys.path.append(str(parent_dir))
29
+
30
+ from transfer_queue import ( # noqa: E402
31
+ TransferQueueController,
32
+ TransferQueueStorageSimpleUnit,
33
+ process_zmq_server_info,
34
+ )
35
+ from transfer_queue.utils.utils import get_placement_group # noqa: E402
36
+
37
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
38
+ logger = logging.getLogger(__name__)
39
+
40
+ os.environ["RAY_DEDUP_LOGS"] = "0"
41
+ os.environ["RAY_DEBUG"] = "1"
42
+ ray.init()
43
+
44
+
45
def initialize_data_system(config):
    """Bring up the TransferQueue data plane: storage units, controllers, and a client.

    Returns a tuple ``(controllers, storage_units, client)`` where the first two are
    rank-indexed dicts of Ray actor handles.
    """
    # 1. Launch the storage units (ray.remote-decorated classes) on a placement group.
    capacity = config.global_batch_size * config.num_global_batch * config.num_n_samples
    per_unit_capacity = math.ceil(capacity / config.num_data_storage_units)
    storage_pg = get_placement_group(config.num_data_storage_units, num_cpus_per_actor=1)
    data_system_storage_units = {}
    for rank in range(config.num_data_storage_units):
        data_system_storage_units[rank] = TransferQueueStorageSimpleUnit.options(
            placement_group=storage_pg, placement_group_bundle_index=rank
        ).remote(storage_size=per_unit_capacity)
        logger.info(f"TransferQueueStorageSimpleUnit #{rank} has been created.")

    # 2. Launch the controllers. Multiple controller instances are supported for load
    # balancing at scale; distinct controllers may serve distinct RL compute tasks.
    controller_pg = get_placement_group(config.num_data_controllers, num_cpus_per_actor=1)
    data_system_controllers = {}
    for rank in range(config.num_data_controllers):
        data_system_controllers[rank] = TransferQueueController.options(
            placement_group=controller_pg, placement_group_bundle_index=rank
        ).remote(
            num_storage_units=config.num_data_storage_units,
            global_batch_size=config.global_batch_size,
            num_global_batch=config.num_global_batch,
            num_n_samples=config.num_n_samples,
        )
        logger.info(f"TransferQueueController #{rank} has been created.")

    # 3. Register every controller with every storage unit. Each storage unit resolves the
    # controllers' IP+port via Ray and then opens ZMQ sockets for message transport.
    data_system_controller_infos = process_zmq_server_info(data_system_controllers)
    data_system_storage_unit_infos = process_zmq_server_info(data_system_storage_units)

    ray.get(
        [
            unit.register_controller_info.remote(data_system_controller_infos)
            for unit in data_system_storage_units.values()
        ]
    )

    # 4. Build the (synchronous) client.
    from transfer_queue import TransferQueueClient

    data_system_client = TransferQueueClient(
        client_id="Trainer",
        # TODO: the main-control client should see all controllers; worker-group and
        # per-worker clients should each see a single controller.
        controller_infos=data_system_controller_infos[0],
        storage_infos=data_system_storage_unit_infos,
    )

    return data_system_controllers, data_system_storage_units, data_system_client
97
+
98
+
99
def generate_sequences(data):
    """Placeholder rollout step: mimics generation latency, then echoes its input."""
    simulated_latency_s = 3
    time.sleep(simulated_latency_s)
    return data
102
+
103
+
104
def compute_old_log_prob(data1, data2):
    """Placeholder for the real old-log-prob computation.

    Pauses to mimic a model forward pass and returns *data1* unchanged;
    *data2* is accepted only to match the real task's interface.
    """
    simulated_latency_s = 3
    time.sleep(simulated_latency_s)
    return data1
107
+
108
+
109
def actor_rollout_wg_generate_sequences(data_meta, data_system_client):
    """Pull inputs for *data_meta*, run the fake generation task, and write the outputs back."""
    # 1. Fetch the actual tensors referenced by data_meta from the storage units.
    batch = data_system_client.get_data(data_meta)
    logger.info(f"demo get data {batch}")

    generated = generate_sequences(batch["input_ids"])
    n_rows = generated.size(0)

    result = TensorDict(
        {
            "generate_sequences_ids": generated,
            "non_tensor_data": torch.stack([NonTensorData("test_str") for _ in range(n_rows)]),
            "nested_tensor": torch.nested.as_nested_tensor([torch.randn(1, 2) for _ in range(n_rows)]),
        },
        batch_size=n_rows,
    )

    # 2. Persist the outputs to the storage units under the same metadata.
    data_system_client.put(data=result, metadata=data_meta)
    data_meta.add_fields(result)
    logger.info("demo put data to storages done")

    return data_meta
131
+
132
+
133
def actor_rollout_wg_compute_old_log_prob(data_meta, data_system_client):
    """Pull prompt + generation tensors, run the fake log-prob task, and store the result."""
    # 1. Fetch the actual tensors referenced by data_meta from the storage units.
    batch = data_system_client.get_data(data_meta)
    logger.info(f"demo get data {batch}")

    log_probs = compute_old_log_prob(batch["input_ids"], batch["generate_sequences_ids"])
    result = TensorDict({"old_log_prob": log_probs}, batch_size=log_probs.size(0))

    # 2. Persist the outputs to the storage units under the same metadata.
    data_system_client.put(data=result, metadata=data_meta)
    data_meta.add_fields(result)
    logger.info("demo put data to storages done")

    return data_meta
148
+
149
+
150
# Simulate the fit function of the trainer
def fit(config, data_system_client):
    """Run the demo training loop against the synchronous client.

    For each step: put prompts, fetch metadata, run fake generation, fetch metadata for
    the generated sequences, run the fake log-prob task, then clear the step's data.
    """
    for epoch in range(1):
        train_dataloader = 1
        for step in range(train_dataloader):
            input_ids = (torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [10, 11], [100, 111]])) * (step + 1)
            input_ids_repeated = torch.repeat_interleave(input_ids, config.num_n_samples, dim=0)
            prompt_batch = TensorDict(
                {"input_ids": input_ids_repeated, "attention_mask": input_ids_repeated},
                batch_size=input_ids_repeated.size(0),
            )

            data_system_client.put(data=prompt_batch, global_step=step)
            logger.info("demo put prompts ok! ")
            time.sleep(5)  # give the storage units time to settle before reading metadata

            batch_meta = data_system_client.get_meta(
                data_fields=["input_ids", "attention_mask"],
                batch_size=config.global_batch_size,
                global_step=step,
                get_n_samples=False,
                task_name="generate_sequences",
            )
            logger.info(f"demo get meta {batch_meta}")

            # Simulate calling the generate sequences task of the worker group.
            batch_meta = actor_rollout_wg_generate_sequences(batch_meta, data_system_client)

            # Fix: this call previously hardcoded global_step=0, which only matches the
            # first loop iteration; use the current step so the log-prob metadata tracks
            # the batch that was just produced (consistent with every other call here).
            log_prob_meta = data_system_client.get_meta(
                data_fields=["input_ids", "attention_mask", "generate_sequences_ids"],
                batch_size=config.global_batch_size,
                global_step=step,
                get_n_samples=False,
                task_name="compute_old_log_prob",
            )
            logger.info(f"demo get log prob meta: {log_prob_meta}")

            # Simulate calling the compute old log prob task of the worker group.
            old_log_prob_meta = actor_rollout_wg_compute_old_log_prob(log_prob_meta, data_system_client)

            batch_meta = batch_meta.union(old_log_prob_meta)

            # The main-control client asks all controllers to clear their data state; the
            # primary controller returns metadata, which the client then uses to tell every
            # storage unit to clear. Only the primary controller needs to return metadata;
            # the other controllers simply clear without returning anything.
            data_system_client.clear(global_step=step)
            logger.info("clear ok! ")
    logger.info("demo done!")
199
+
200
+
201
def main(config):
    """Initialize the data system (controllers + storage on Ray) and run the demo loop."""
    data_system_controllers, data_system_storage_units, data_system_client = initialize_data_system(config)

    # Give the Ray actors a moment to finish starting up before issuing requests.
    # (Removed a redundant function-local `import time` that shadowed the module-level import.)
    time.sleep(5)

    fit(config, data_system_client)
210
+
211
if __name__ == "__main__":
    # YAML demo configuration for the synchronous TransferQueue example, parsed below.
    config_str = """
global_batch_size: 6
num_global_batch: 1
num_data_storage_units: 2
num_data_controllers: 1
num_n_samples: 2
"""
    dict_conf = OmegaConf.create(config_str)

    main(dict_conf)

    ray.shutdown()