TransferQueue 0.1.1.dev0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. recipe/simple_use_case/async_demo.py +331 -0
  2. recipe/simple_use_case/sync_demo.py +220 -0
  3. tests/test_async_simple_storage_manager.py +339 -0
  4. tests/test_client.py +423 -0
  5. tests/test_controller.py +274 -0
  6. tests/test_controller_data_partitions.py +513 -0
  7. tests/test_kv_storage_manager.py +92 -0
  8. tests/test_put.py +327 -0
  9. tests/test_samplers.py +492 -0
  10. tests/test_serial_utils_on_cpu.py +202 -0
  11. tests/test_simple_storage_unit.py +443 -0
  12. tests/test_storage_client_factory.py +45 -0
  13. transfer_queue/__init__.py +48 -0
  14. transfer_queue/client.py +611 -0
  15. transfer_queue/controller.py +1187 -0
  16. transfer_queue/metadata.py +460 -0
  17. transfer_queue/sampler/__init__.py +19 -0
  18. transfer_queue/sampler/base.py +74 -0
  19. transfer_queue/sampler/grpo_group_n_sampler.py +157 -0
  20. transfer_queue/sampler/sequential_sampler.py +75 -0
  21. transfer_queue/storage/__init__.py +25 -0
  22. transfer_queue/storage/clients/__init__.py +24 -0
  23. transfer_queue/storage/clients/base.py +22 -0
  24. transfer_queue/storage/clients/factory.py +55 -0
  25. transfer_queue/storage/clients/yuanrong_client.py +118 -0
  26. transfer_queue/storage/managers/__init__.py +23 -0
  27. transfer_queue/storage/managers/base.py +460 -0
  28. transfer_queue/storage/managers/factory.py +43 -0
  29. transfer_queue/storage/managers/simple_backend_manager.py +611 -0
  30. transfer_queue/storage/managers/yuanrong_manager.py +18 -0
  31. transfer_queue/storage/simple_backend.py +451 -0
  32. transfer_queue/utils/__init__.py +13 -0
  33. transfer_queue/utils/serial_utils.py +240 -0
  34. transfer_queue/utils/utils.py +132 -0
  35. transfer_queue/utils/zmq_utils.py +170 -0
  36. transfer_queue/version/version +1 -0
  37. transferqueue-0.1.1.dev0.dist-info/METADATA +327 -0
  38. transferqueue-0.1.1.dev0.dist-info/RECORD +41 -0
  39. transferqueue-0.1.1.dev0.dist-info/WHEEL +5 -0
  40. transferqueue-0.1.1.dev0.dist-info/licenses/LICENSE +202 -0
  41. transferqueue-0.1.1.dev0.dist-info/top_level.txt +4 -0
recipe/simple_use_case/async_demo.py
@@ -0,0 +1,331 @@
+ # Copyright 2025 The TransferQueue Team
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import asyncio
+ import logging
+ import math
+ import os
+ import sys
+ import time
+ from pathlib import Path
+
+ import ray
+ import torch
+ from omegaconf import OmegaConf
+ from tensordict import NonTensorData, TensorDict
+
+ parent_dir = Path(__file__).resolve().parent.parent.parent
+ sys.path.append(str(parent_dir))
+
+ from transfer_queue import (  # noqa: E402
+     AsyncTransferQueueClient,
+     BatchMeta,
+     SimpleStorageUnit,
+     TransferQueueController,
+     process_zmq_server_info,
+ )
+ from transfer_queue.utils.utils import get_placement_group  # noqa: E402
+
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+ logger = logging.getLogger(__name__)
+
+ os.environ["RAY_DEDUP_LOGS"] = "0"
+ os.environ["RAY_DEBUG"] = "1"
+ ray.init()
+
+
+ def compute_old_log_prob(data1, data2):
+     time.sleep(3)
+     return data1
+
+
+ def generate_sequences(data):
+     time.sleep(3)
+     return data
+
+
+ class ActorRolloutRefWorker:
+     def actor_rollout_wg_generate_sequences(self, data_meta, data_system_client):
+         # 1. Pull the real data from the storage plane through the client, based on data_meta
+         data = asyncio.run(data_system_client.async_get_data(data_meta))
+         logger.info(f"demo get data->generate_sequences {data}")
+
+         output = generate_sequences(data["input_ids"])
+
+         output = TensorDict(
+             {
+                 "generate_sequences_ids": output,
+                 "non_tensor_data": torch.stack([NonTensorData("test_str") for _ in range(output.size(0))]),
+                 "nested_tensor": torch.nested.as_nested_tensor([torch.randn(1, 2) for _ in range(output.size(0))]),
+             },
+             batch_size=output.size(0),
+         )
+
+         # 2. Write the results back to the storage plane, based on data_meta
+         asyncio.run(data_system_client.async_put(data=output, metadata=data_meta))
+         data_meta.add_fields(output)
+         logger.info("demo put data to storages done")
+
+         return data_meta
+
+     def actor_rollout_wg_compute_old_log_prob(self, data_meta, data_system_client):
+         # 1. Pull the real data from the storage plane through the client, based on data_meta
+         data = asyncio.run(data_system_client.async_get_data(data_meta))
+         logger.info(f"demo get data->old_log_prob {data}")
+
+         output = compute_old_log_prob(data["input_ids"], data["generate_sequences_ids"])
+
+         output = TensorDict({"old_log_prob": output}, batch_size=output.size(0))
+
+         # 2. Write the results back to the storage plane, based on data_meta
+         asyncio.run(data_system_client.async_put(data=output, metadata=data_meta))
+         data_meta.add_fields(output)
+         logger.info("demo put data to storages done")
+
+         return data_meta
+
+
+ @ray.remote
+ class AsyncvLLMServer:
+     def __init__(self, config, data_system_controller_info):
+         self.config = config
+         self.data_system_client = AsyncTransferQueueClient(
+             client_id="AsyncvLLMServer",
+             controller_info=data_system_controller_info,
+         )
+
+         self.data_system_client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=self.config)
+
+     async def generate(self, data_meta):
+         data = await self.data_system_client.async_get_data(data_meta)
+         logger.info(f"demo get data->generate_sequences {data}")
+
+         data = data["input_ids"]
+         data += 1
+         await asyncio.sleep(3)
+
+         output = TensorDict(
+             {
+                 "generate_sequences_ids": data,
+                 "non_tensor_data": torch.stack([NonTensorData("test_str") for _ in range(data.size(0))]),
+                 "nested_tensor": torch.nested.as_nested_tensor([torch.randn(1, 2) for _ in range(data.size(0))]),
+             },
+             batch_size=data.size(0),
+         )
+
+         await self.data_system_client.async_put(data=output, metadata=data_meta)
+         logger.info("demo Async Server put data to storages done")
+
+         return data_meta
+
+
+ @ray.remote(num_cpus=1)
+ class AsyncRolloutWorker:
+     def __init__(
+         self,
+         config,
+         data_system_controller_info,
+     ):
+         self.async_vllm_server = AsyncvLLMServer.remote(
+             config,
+             data_system_controller_info,
+         )
+
+     async def generate_sequences(self, data_meta_chunk):
+         tasks = []
+         for i in range(data_meta_chunk.size):
+             # Note: asyncio.create_task cannot wrap a Ray actor method call directly;
+             # doing so raises "a coroutine was expected, got ObjectRef(...)".
+             tasks.append(asyncio.create_task(self.generate(data_meta_chunk[i])))
+         data_metas = await asyncio.gather(*tasks)
+         return BatchMeta.concat(data_metas)
+
+     async def generate(self, data_meta):
+         data_meta_new = await self.async_vllm_server.generate.remote(data_meta)
+         return data_meta_new
+
+
+ class RolloutManager:
+     def __init__(self, config, data_system_storage_unit_infos, data_system_controller_info):
+         self.config = config
+
+         self.data_system_client = AsyncTransferQueueClient(
+             client_id="RolloutManager",
+             controller_info=data_system_controller_info,
+         )
+
+         self.data_system_client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=self.config)
+
+         self.async_rollout_workers = []
+         num_workers = self.config.rollout_agent_num_workers
+         for _ in range(num_workers):
+             self.async_rollout_workers.append(AsyncRolloutWorker.remote(config, data_system_controller_info))
+
+     def generate_sequences(self, data_meta):
+         data_meta_chunks = data_meta.chunk(len(self.async_rollout_workers))
+         data_metas = ray.get(
+             [
+                 worker.generate_sequences.remote(data_meta_chunk)
+                 for worker, data_meta_chunk in zip(self.async_rollout_workers, data_meta_chunks, strict=True)
+             ]
+         )
+         batch_meta = BatchMeta.concat(data_metas)
+         logger.info(f"batch_meta: {batch_meta}")
+
+         return batch_meta
+
+
+ class Trainer:
+     def __init__(self, config):
+         self.config = config
+         self.data_system_client = self._initialize_data_system()
+         self.actor_rollout_wg = ActorRolloutRefWorker()
+         self.async_rollout_manager = RolloutManager(
+             self.config,
+             self.data_system_storage_unit_infos,
+             self.data_system_controller_info,
+         )
+
+     def _initialize_data_system(self):
+         # 1. Initialize TransferQueueStorage
+         total_storage_size = self.config.global_batch_size * self.config.num_global_batch * self.config.num_n_samples
+         self.data_system_storage_units = {}
+         storage_placement_group = get_placement_group(self.config.num_data_storage_units, num_cpus_per_actor=1)
+         for storage_unit_rank in range(self.config.num_data_storage_units):
+             storage_node = SimpleStorageUnit.options(
+                 placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank
+             ).remote(storage_unit_size=math.ceil(total_storage_size / self.config.num_data_storage_units))
+             self.data_system_storage_units[storage_unit_rank] = storage_node
+             logger.info(f"SimpleStorageUnit #{storage_unit_rank} has been created.")
+
+         # 2. Initialize TransferQueueController (single controller only)
+         #
+         # Sampler usage instructions:
+         # For GRPO grouped sampling, you can initialize the controller with GRPOGroupNSampler.
+         # Option 1: pass the sampler class (it will be instantiated automatically):
+         #     self.data_system_controller = TransferQueueController.remote(sampler=GRPOGroupNSampler)
+         # Option 2: pass a sampler instance (if you need custom configuration):
+         #     grpo_sampler = GRPOGroupNSampler()
+         #     self.data_system_controller = TransferQueueController.remote(sampler=grpo_sampler)
+         # Then pass sampling_config in get_meta calls:
+         #     sampling_config={"n_samples_per_prompt": 4}
+         self.data_system_controller = TransferQueueController.remote()
+         logger.info("TransferQueueController has been created.")
+
+         # 3. Prepare the necessary information
+         self.data_system_controller_info = process_zmq_server_info(self.data_system_controller)
+         self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units)
+
+         # Note: we need a new DictConfig with allow_objects=True to preserve the ZMQServerInfo
+         # instances; otherwise they would be flattened into plain dicts.
+         tq_config = OmegaConf.create({}, flags={"allow_objects": True})
+         tq_config.controller_info = self.data_system_controller_info
+         tq_config.storage_unit_infos = self.data_system_storage_unit_infos
+         self.config = OmegaConf.merge(tq_config, self.config)
+
+         # 4. Create the client
+         self.data_system_client = AsyncTransferQueueClient(
+             client_id="Trainer",
+             controller_info=self.data_system_controller_info,
+         )
+
+         self.data_system_client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=self.config)
+         # Note: the client contains ZMQ objects. Currently we cannot transmit the same client instance
+         # to multiple places, as this causes serialization errors in Ray.
+         # Workaround: if you need a client in multiple Ray actors or processes, create a separate
+         # AsyncTransferQueueClient instance for each actor/process instead of sharing the same instance.
+         return self.data_system_client
+
+     def fit(self):
+         for _epoch in range(1):
+             train_dataloader = 1
+             for step in range(train_dataloader):
+                 input_ids = (
+                     torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [10, 11], [100, 111], [200, 222], [300, 333]])
+                 ) * (step + 1)
+                 input_ids_repeated = torch.repeat_interleave(input_ids, self.config.num_n_samples, dim=0)
+                 prompt_batch = TensorDict(
+                     {"input_ids": input_ids_repeated, "attention_mask": input_ids_repeated},
+                     batch_size=input_ids_repeated.size(0),
+                 )
+
+                 asyncio.run(self.data_system_client.async_put(data=prompt_batch, partition_id=f"train_{step}"))
+
+                 logger.info("demo put prompts ok!")
+                 time.sleep(5)
+
+                 batch_meta = asyncio.run(
+                     self.data_system_client.async_get_meta(
+                         data_fields=["input_ids", "attention_mask"],
+                         batch_size=self.config.global_batch_size * self.config.num_n_samples,
+                         partition_id=f"train_{step}",
+                         task_name="generate_sequences",
+                     )
+                 )
+                 logger.info(f"demo get meta {batch_meta}")
+
+                 # Simulate calling the generate-sequences task of the worker group
+                 if not self.config.async_rollout_mode:
+                     batch_meta = self.actor_rollout_wg.actor_rollout_wg_generate_sequences(
+                         batch_meta, self.data_system_client
+                     )
+                 else:
+                     batch_meta = self.async_rollout_manager.generate_sequences(batch_meta)
+
+                 log_prob_meta = asyncio.run(
+                     self.data_system_client.async_get_meta(
+                         data_fields=["input_ids", "attention_mask", "generate_sequences_ids"],
+                         batch_size=self.config.global_batch_size * self.config.num_n_samples,
+                         partition_id=f"train_{step}",
+                         task_name="compute_old_log_prob",
+                     )
+                 )
+                 logger.info(f"demo get log prob meta: {log_prob_meta}")
+
+                 # Simulate calling the compute-old-log-prob task of the worker group
+                 old_log_prob_meta = self.actor_rollout_wg.actor_rollout_wg_compute_old_log_prob(
+                     log_prob_meta, self.data_system_client
+                 )
+
+                 batch_meta = batch_meta.union(old_log_prob_meta)
+
+                 # The client notifies the controller to clear the data status and the controller
+                 # returns the metadata; the client then notifies the storage plane to clear
+                 # based on that metadata.
+                 asyncio.run(self.data_system_client.async_clear(partition_id=f"train_{step}"))
+                 logger.info("clear ok!")
+         logger.info("demo done!")
+
+         # Cleanup resources
+         self.data_system_client.close()
+         return batch_meta
+
+
+ if __name__ == "__main__":
+     # NOTE: you may choose to set async_rollout_mode=True to test the async rollout mode that
+     # mimics the AgentLoopManager in verl
+     config_str = """
+ global_batch_size: 8
+ num_global_batch: 1
+ num_data_storage_units: 2
+ async_rollout_mode: True
+ rollout_agent_num_workers: 2
+ num_n_samples: 2
+ """
+     dict_conf = OmegaConf.create(config_str)
+
+     trainer = Trainer(dict_conf)
+     trainer.fit()
+
+     ray.shutdown()
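
The sampler comments inside `_initialize_data_system` describe two ways to enable GRPO grouped sampling but leave the wiring implicit. Below is a minimal sketch of how those pieces could fit together, assuming GRPOGroupNSampler is importable from transfer_queue.sampler (the wheel ships transfer_queue/sampler/grpo_group_n_sampler.py) and that async_get_meta accepts the sampling_config argument the comments mention; the client and config are created exactly as in the demo above.

# Sketch only: wiring the GRPO sampler options described in the demo comments.
# Assumes GRPOGroupNSampler is exported by transfer_queue.sampler.
from transfer_queue.sampler import GRPOGroupNSampler

# Option 1: pass the sampler class; the controller instantiates it.
controller = TransferQueueController.remote(sampler=GRPOGroupNSampler)

# Option 2: pass a pre-configured sampler instance.
controller = TransferQueueController.remote(sampler=GRPOGroupNSampler())

# Grouped metadata request: n_samples_per_prompt tells the sampler how many
# rollouts belong to each prompt group (value taken from the demo comments).
batch_meta = asyncio.run(
    client.async_get_meta(  # `client` built as in _initialize_data_system
        data_fields=["input_ids", "attention_mask"],
        batch_size=8 * 2,  # global_batch_size * num_n_samples, as in fit()
        partition_id="train_0",
        task_name="generate_sequences",
        sampling_config={"n_samples_per_prompt": 4},
    )
)
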
recipe/simple_use_case/sync_demo.py
@@ -0,0 +1,220 @@
+ # Copyright 2025 The TransferQueue Team
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ import math
+ import os
+ import sys
+ import time
+ from pathlib import Path
+
+ import ray
+ import torch
+ from omegaconf import OmegaConf
+ from tensordict import NonTensorData, TensorDict
+
+ parent_dir = Path(__file__).resolve().parent.parent.parent
+ sys.path.append(str(parent_dir))
+
+ from transfer_queue import (  # noqa: E402
+     SimpleStorageUnit,
+     TransferQueueClient,
+     TransferQueueController,
+     process_zmq_server_info,
+ )
+ from transfer_queue.utils.utils import get_placement_group  # noqa: E402
+
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+ logger = logging.getLogger(__name__)
+
+ os.environ["RAY_DEDUP_LOGS"] = "0"
+ os.environ["RAY_DEBUG"] = "1"
+ ray.init()
+
+
+ def initialize_data_system(config):
+     # 1. Initialize TransferQueueStorage
+     total_storage_size = config.global_batch_size * config.num_global_batch * config.num_n_samples
+     data_system_storage_units = {}
+     storage_placement_group = get_placement_group(config.num_data_storage_units, num_cpus_per_actor=1)
+     for storage_unit_rank in range(config.num_data_storage_units):
+         storage_node = SimpleStorageUnit.options(
+             placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank
+         ).remote(storage_unit_size=math.ceil(total_storage_size / config.num_data_storage_units))
+         data_system_storage_units[storage_unit_rank] = storage_node
+         logger.info(f"SimpleStorageUnit #{storage_unit_rank} has been created.")
+
+     # 2. Initialize TransferQueueController (single controller only)
+     #
+     # Sampler usage instructions:
+     # For GRPO grouped sampling, you can initialize the controller with GRPOGroupNSampler.
+     # Option 1: pass the sampler class (it will be instantiated automatically):
+     #     data_system_controller = TransferQueueController.remote(sampler=GRPOGroupNSampler)
+     # Option 2: pass a sampler instance (if you need custom configuration):
+     #     grpo_sampler = GRPOGroupNSampler()
+     #     data_system_controller = TransferQueueController.remote(sampler=grpo_sampler)
+     # Then pass sampling_config in get_meta calls:
+     #     sampling_config={"n_samples_per_prompt": 4}
+     data_system_controller = TransferQueueController.remote()
+     logger.info("TransferQueueController has been created.")
+
+     # 3. Prepare the necessary information
+     data_system_controller_info = process_zmq_server_info(data_system_controller)
+     data_system_storage_unit_infos = process_zmq_server_info(data_system_storage_units)
+
+     # Note: we need a new DictConfig with allow_objects=True to preserve the ZMQServerInfo
+     # instances; otherwise they would be flattened into plain dicts.
+     tq_config = OmegaConf.create({}, flags={"allow_objects": True})
+     tq_config.controller_info = data_system_controller_info
+     tq_config.storage_unit_infos = data_system_storage_unit_infos
+     config = OmegaConf.merge(tq_config, config)
+
+     # 4. Create the client
+     data_system_client = TransferQueueClient(
+         client_id="Trainer",
+         controller_info=data_system_controller_info,
+     )
+
+     data_system_client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config)
+
+     return data_system_controller, data_system_storage_units, data_system_client
+
+
+ def generate_sequences(data):
+     time.sleep(3)
+     return data
+
+
+ def compute_old_log_prob(data1, _data2):
+     time.sleep(3)
+     return data1
+
+
+ def actor_rollout_wg_generate_sequences(data_meta, data_system_client):
+     # 1. Pull the real data from the storage plane through the client, based on data_meta
+     data = data_system_client.get_data(data_meta)
+     logger.info(f"demo get data {data}")
+
+     output = generate_sequences(data["input_ids"])
+
+     output = TensorDict(
+         {
+             "generate_sequences_ids": output,
+             "non_tensor_data": torch.stack([NonTensorData("test_str") for _ in range(output.size(0))]),
+             "nested_tensor": torch.nested.as_nested_tensor([torch.randn(1, 2) for _ in range(output.size(0))]),
+         },
+         batch_size=output.size(0),
+     )
+
+     # 2. Write the results back to the storage plane, based on data_meta
+     data_system_client.put(data=output, metadata=data_meta)
+     data_meta.add_fields(output)
+     logger.info("demo put data to storages done")
+
+     return data_meta
+
+
+ def actor_rollout_wg_compute_old_log_prob(data_meta, data_system_client):
+     # 1. Pull the real data from the storage plane through the client, based on data_meta
+     data = data_system_client.get_data(data_meta)
+     logger.info(f"demo get data {data}")
+
+     output = compute_old_log_prob(data["input_ids"], data["generate_sequences_ids"])
+
+     output = TensorDict({"old_log_prob": output}, batch_size=output.size(0))
+
+     # 2. Write the results back to the storage plane, based on data_meta
+     data_system_client.put(data=output, metadata=data_meta)
+     data_meta.add_fields(output)
+     logger.info("demo put data to storages done")
+
+     return data_meta
+
+
+ # Simulate the fit function of the trainer
+ def fit(config, data_system_client):
+     for _epoch in range(1):
+         train_dataloader = 1
+         for step in range(train_dataloader):
+             input_ids = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [10, 11], [100, 111]]) * (step + 1)
+             input_ids_repeated = torch.repeat_interleave(input_ids, config.num_n_samples, dim=0)
+             prompt_batch = TensorDict(
+                 {"input_ids": input_ids_repeated, "attention_mask": input_ids_repeated},
+                 batch_size=input_ids_repeated.size(0),
+             )
+
+             data_system_client.put(data=prompt_batch, partition_id=f"train_{step}")
+             logger.info("demo put prompts ok!")
+             time.sleep(5)
+
+             # Get metadata for the generate-sequences task: we want to generate sequences from input_ids
+             batch_meta = data_system_client.get_meta(
+                 data_fields=["input_ids", "attention_mask"],
+                 batch_size=config.global_batch_size,
+                 partition_id=f"train_{step}",
+                 task_name="generate_sequences",
+             )
+             logger.info(f"demo get meta {batch_meta}")
+
+             # Simulate calling the generate-sequences task of the worker group
+             batch_meta = actor_rollout_wg_generate_sequences(batch_meta, data_system_client)
+
+             # Get metadata for the compute-old-log-prob task: we want to compute log probs
+             # for the generated sequences
+             log_prob_meta = data_system_client.get_meta(
+                 data_fields=["input_ids", "attention_mask", "generate_sequences_ids"],
+                 batch_size=config.global_batch_size,
+                 partition_id=f"train_{step}",
+                 task_name="compute_old_log_prob",
+             )
+             logger.info(f"demo get log prob meta: {log_prob_meta}")
+
+             # Simulate calling the compute-old-log-prob task of the worker group
+             old_log_prob_meta = actor_rollout_wg_compute_old_log_prob(log_prob_meta, data_system_client)
+
+             batch_meta = batch_meta.union(old_log_prob_meta)
+
+             # The client asks the controllers to clear the data status: one master controller
+             # returns the metadata while the other controllers clear directly without returning it;
+             # the client then notifies the storage plane to clear based on that metadata.
+             data_system_client.clear(partition_id=f"train_{step}")
+             logger.info("clear ok!")
+     logger.info("demo done!")
+
+
+ def main(config):
+     # Initialize the data system: launch the controller and storage units on Ray
+     _data_system_controller, _data_system_storage_units, data_system_client = initialize_data_system(config)
+
+     time.sleep(5)
+
+     fit(config, data_system_client)
+
+     # Cleanup resources
+     data_system_client.close()
+
+
+ if __name__ == "__main__":
+     config_str = """
+ global_batch_size: 6
+ num_global_batch: 1
+ num_data_storage_units: 2
+ num_n_samples: 2
+ """
+     dict_conf = OmegaConf.create(config_str)
+
+     main(dict_conf)
+
+     ray.shutdown()
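
Both demos construct a fresh client in every process that needs one, rather than passing a client around; the note in async_demo.py explains why: the client holds ZMQ objects that do not survive Ray serialization. A minimal sketch of that one-client-per-actor pattern, using only the constructor and initialize_storage_manager calls shown in the demos (the worker class itself is hypothetical):

import ray

from transfer_queue import AsyncTransferQueueClient


@ray.remote
class MyWorker:
    """Hypothetical worker illustrating the one-client-per-actor workaround."""

    def __init__(self, controller_info, config):
        # Build the client inside the actor: ZMQ sockets cannot be pickled,
        # so sharing a single client instance across actors would fail.
        self.client = AsyncTransferQueueClient(
            client_id="MyWorker",
            controller_info=controller_info,
        )
        self.client.initialize_storage_manager(
            manager_type="AsyncSimpleStorageManager", config=config
        )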