TransferQueue 0.1.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. recipe/simple_use_case/async_demo.py +331 -0
  2. recipe/simple_use_case/sync_demo.py +220 -0
  3. tests/test_async_simple_storage_manager.py +339 -0
  4. tests/test_client.py +423 -0
  5. tests/test_controller.py +274 -0
  6. tests/test_controller_data_partitions.py +513 -0
  7. tests/test_kv_storage_manager.py +92 -0
  8. tests/test_put.py +327 -0
  9. tests/test_samplers.py +492 -0
  10. tests/test_serial_utils_on_cpu.py +202 -0
  11. tests/test_simple_storage_unit.py +443 -0
  12. tests/test_storage_client_factory.py +45 -0
  13. transfer_queue/__init__.py +48 -0
  14. transfer_queue/client.py +611 -0
  15. transfer_queue/controller.py +1187 -0
  16. transfer_queue/metadata.py +460 -0
  17. transfer_queue/sampler/__init__.py +19 -0
  18. transfer_queue/sampler/base.py +74 -0
  19. transfer_queue/sampler/grpo_group_n_sampler.py +157 -0
  20. transfer_queue/sampler/sequential_sampler.py +75 -0
  21. transfer_queue/storage/__init__.py +25 -0
  22. transfer_queue/storage/clients/__init__.py +24 -0
  23. transfer_queue/storage/clients/base.py +22 -0
  24. transfer_queue/storage/clients/factory.py +55 -0
  25. transfer_queue/storage/clients/yuanrong_client.py +118 -0
  26. transfer_queue/storage/managers/__init__.py +23 -0
  27. transfer_queue/storage/managers/base.py +460 -0
  28. transfer_queue/storage/managers/factory.py +43 -0
  29. transfer_queue/storage/managers/simple_backend_manager.py +611 -0
  30. transfer_queue/storage/managers/yuanrong_manager.py +18 -0
  31. transfer_queue/storage/simple_backend.py +451 -0
  32. transfer_queue/utils/__init__.py +13 -0
  33. transfer_queue/utils/serial_utils.py +240 -0
  34. transfer_queue/utils/utils.py +132 -0
  35. transfer_queue/utils/zmq_utils.py +170 -0
  36. transfer_queue/version/version +1 -0
  37. transferqueue-0.1.1.dev0.dist-info/METADATA +327 -0
  38. transferqueue-0.1.1.dev0.dist-info/RECORD +41 -0
  39. transferqueue-0.1.1.dev0.dist-info/WHEEL +5 -0
  40. transferqueue-0.1.1.dev0.dist-info/licenses/LICENSE +202 -0
  41. transferqueue-0.1.1.dev0.dist-info/top_level.txt +4 -0
transfer_queue/storage/simple_backend.py
@@ -0,0 +1,451 @@
# Copyright 2025 The TransferQueue Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import logging
import os
from dataclasses import dataclass
from operator import itemgetter
from threading import Thread
from typing import Any, Optional
from uuid import uuid4

import ray
import torch
import zmq
from ray.util import get_node_ip_address
from tensordict import NonTensorStack, TensorDict

from transfer_queue.metadata import SampleMeta
from transfer_queue.utils.utils import TransferQueueRole
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo, create_zmq_socket, get_free_port

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))

# ZMQ poller timeout (in seconds)
TQ_STORAGE_POLLER_TIMEOUT = int(os.environ.get("TQ_STORAGE_POLLER_TIMEOUT", 5))
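# Both knobs above (the logging level and the poller timeout) are read from the
# environment, so they can be tuned per process without code changes, e.g.
# (illustrative command line):
#   TQ_LOGGING_LEVEL=DEBUG TQ_STORAGE_POLLER_TIMEOUT=1 python your_app.py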


class StorageUnitData:
    """Storage unit for managing a 2D data structure (samples × fields).

    This class provides efficient storage and retrieval of data in a 2D matrix format
    where rows represent samples (indexed by local_index) and columns represent fields.
    Each field contains a list of data items indexed by their local position.

    Data Structure Example:
        ┌─────────────┬─────────────┬─────────────┬─────────┐
        │ local_index │ field_name1 │ field_name2 │   ...   │
        ├─────────────┼─────────────┼─────────────┼─────────┤
        │      0      │    item1    │    item2    │   ...   │
        │      1      │    item3    │    item4    │   ...   │
        │      2      │    item5    │    item6    │   ...   │
        └─────────────┴─────────────┴─────────────┴─────────┘
    """

    def __init__(self, storage_size: int):
        # Dict mapping each field name to the data stored in that field.
        # Format: {"field_name": [data_at_index_0, data_at_index_1, ...]}
        self.field_data: dict[str, list] = {}

        # Maximum number of elements stored in this storage unit
        self.storage_size = storage_size

    def get_data(self, fields: list[str], local_indexes: list[int]) -> TensorDict:
        """
        Get data from the storage unit for the given fields and local_indexes.

        Args:
            fields: Field names used for getting data.
            local_indexes: Local indexes used for getting data.

        Returns:
            TensorDict with field names as keys and the corresponding batched data as values.
        """
        result: dict[str, list] = {}

        for field in fields:
            # Validate field name
            if field not in self.field_data:
                raise ValueError(
                    f"StorageUnitData get_data operation received invalid field: {field}; "
                    f"known fields: {list(self.field_data.keys())}"
                )

            if len(local_indexes) == 1:
                gathered_item = self.field_data[field][local_indexes[0]]
                if not isinstance(gathered_item, torch.Tensor):
                    result[field] = NonTensorStack(gathered_item)
                else:
                    # The unsqueeze op turns the shape from (n,) into (1, n)
                    result[field] = gathered_item.unsqueeze(0)
            else:
                gathered_items = list(itemgetter(*local_indexes)(self.field_data[field]))

                if gathered_items:
                    all_tensors = all(isinstance(x, torch.Tensor) for x in gathered_items)
                    if all_tensors:
                        result[field] = torch.nested.as_nested_tensor(gathered_items)
                    else:
                        result[field] = NonTensorStack(*gathered_items)

        # Explicit batch size for stability
        batch_size = 0 if not fields or not local_indexes else len(local_indexes)
        return TensorDict(result, batch_size=batch_size)

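    # Shape behavior (illustrative recap of the branches above): if field
    # "tokens" holds row tensors of shape (n,), get_data(["tokens"], [0])
    # returns a "tokens" entry of shape (1, n), while get_data(["tokens"],
    # [0, 1]) returns a nested tensor over the two (possibly ragged) rows;
    # non-tensor items come back wrapped in a NonTensorStack.
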
    def put_data(self, field_data: TensorDict, local_indexes: list[int]) -> None:
        """
        Put or update data in the storage unit according to the given field_data and local_indexes.

        Args:
            field_data: TensorDict with field names as keys and the corresponding field data as values.
            local_indexes: Local indexes used for putting data.
        """
        extracted_data = field_data.to_dict()

        for f, values in extracted_data.items():
            if f not in self.field_data:
                self.field_data[f] = [None] * self.storage_size

            for i, idx in enumerate(local_indexes):
                if idx < 0 or idx >= self.storage_size:
                    raise ValueError(
                        f"StorageUnitData put_data operation received invalid local_index: {idx} beyond "
                        f"storage_size: {self.storage_size}"
                    )

                self.field_data[f][idx] = values[i]

    def clear(self, local_indexes: list[int]) -> None:
        """
        Clear data at the specified local_indexes by setting all related fields to None.

        Args:
            local_indexes: Local indexes to clear.
        """
        # Validate local_indexes
        for idx in local_indexes:
            if idx < 0 or idx >= self.storage_size:
                raise ValueError(
                    f"StorageUnitData clear operation received invalid local_index: {idx} beyond "
                    f"storage_size: {self.storage_size}"
                )

        # Clear data at the specified local_indexes
        for f in self.field_data:
            for idx in local_indexes:
                self.field_data[f][idx] = None

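
# A minimal round-trip sketch for StorageUnitData (illustrative only; in
# practice this class is driven by SimpleStorageUnit rather than called
# directly):
def _storage_unit_data_example() -> None:
    unit = StorageUnitData(storage_size=4)

    # put_data takes a TensorDict keyed by field name; values are batched
    # along dim 0 and scattered to the given local indexes.
    batch = TensorDict(
        {
            "tokens": torch.tensor([[1, 2], [3, 4]]),
            "tag": NonTensorStack("a", "b"),
        },
        batch_size=2,
    )
    unit.put_data(batch, local_indexes=[0, 2])

    # Tensor fields come back stacked (nested when several rows are gathered);
    # non-tensor fields come back as a NonTensorStack.
    out = unit.get_data(fields=["tokens", "tag"], local_indexes=[0, 2])
    assert out.batch_size[0] == 2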


@ray.remote(num_cpus=1)
class SimpleStorageUnit:
    """A storage unit that provides distributed data storage functionality.

    This class represents a storage unit that stores data in a 2D structure
    (samples × data fields) and provides ZMQ-based communication for put/get/clear operations.

    Note: We use the Ray decorator (@ray.remote) only for initialization purposes.
    We do NOT use Ray's .remote() call capabilities - the storage unit runs
    as a standalone process with its own ZMQ server socket.

    Attributes:
        storage_unit_id: Unique identifier for this storage unit.
        storage_unit_size: Maximum number of elements that can be stored.
        storage_data: Internal StorageUnitData instance for data management.
        zmq_server_info: ZMQ connection information for clients.
    """

    def __init__(self, storage_unit_size: int):
        """Initialize a SimpleStorageUnit with the specified size.

        Args:
            storage_unit_size: Maximum number of elements that can be stored in this storage unit.
        """
        self.storage_unit_id = f"TQ_STORAGE_UNIT_{uuid4().hex[:8]}"
        self.storage_unit_size = storage_unit_size

        self.storage_data = StorageUnitData(self.storage_unit_size)

        self.zmq_server_info = ZMQServerInfo(
            role=TransferQueueRole.STORAGE,
            id=str(self.storage_unit_id),
            ip=get_node_ip_address(),
            ports={"put_get_socket": get_free_port()},
        )
        self._init_zmq_socket()
        self._start_process_put_get()

    def _init_zmq_socket(self) -> None:
        """
        Initialize the ZMQ socket connecting this storage unit to the controller/clients:
        - put_get_socket:
            Handles put/get requests from clients.
        """
        self.zmq_context = zmq.Context()

        self.put_get_socket = create_zmq_socket(self.zmq_context, zmq.ROUTER)
        self.put_get_socket.bind(self.zmq_server_info.to_addr("put_get_socket"))

    def _start_process_put_get(self) -> None:
        """Create a daemon thread and start the put/get processing loop."""
        self.process_put_get_thread = Thread(
            target=self._process_put_get, name=f"StorageUnitProcessPutGetThread-{self.zmq_server_info.id}", daemon=True
        )
        self.process_put_get_thread.start()

    def _process_put_get(self) -> None:
        """Process put_get_socket requests."""
        poller = zmq.Poller()
        poller.register(self.put_get_socket, zmq.POLLIN)

        while True:
            socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT * 1000))

            if self.put_get_socket in socks:
                identity, serialized_msg = self.put_get_socket.recv_multipart()

                try:
                    request_msg = ZMQMessage.deserialize(serialized_msg)
                    operation = request_msg.request_type
                    logger.debug(f"[{self.zmq_server_info.id}]: received operation: {operation}, message: {request_msg}")

                    if operation == ZMQRequestType.PUT_DATA:
                        response_msg = self._handle_put(request_msg)
                    elif operation == ZMQRequestType.GET_DATA:
                        response_msg = self._handle_get(request_msg)
                    elif operation == ZMQRequestType.CLEAR_DATA:
                        response_msg = self._handle_clear(request_msg)
                    else:
                        response_msg = ZMQMessage.create(
                            request_type=ZMQRequestType.PUT_GET_OPERATION_ERROR,
                            sender_id=self.zmq_server_info.id,
                            body={
                                "message": f"Storage unit #{self.zmq_server_info.id} "
                                f"received invalid operation: {operation}."
                            },
                        )
                except Exception as e:
                    response_msg = ZMQMessage.create(
                        request_type=ZMQRequestType.PUT_GET_ERROR,
                        sender_id=self.zmq_server_info.id,
                        body={
                            "message": f"Storage unit #{self.zmq_server_info.id} encountered an error while "
                            f"processing a put/get/clear request; detail: {str(e)}."
                        },
                    )

                self.put_get_socket.send_multipart([identity, response_msg.serialize()])

    def _handle_put(self, data_parts: ZMQMessage) -> ZMQMessage:
        """
        Handle a put request: add or update data in the storage unit.

        Args:
            data_parts: ZMQMessage from a client.

        Returns:
            Put-data success response ZMQMessage.
        """
        try:
            local_indexes = data_parts.body["local_indexes"]
            field_data = data_parts.body["data"]  # field_data should be a TensorDict.

            self.storage_data.put_data(field_data, local_indexes)

            # After the put operation finishes, send a message back to the client
            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.PUT_DATA_RESPONSE, sender_id=self.zmq_server_info.id, body={}
            )

            return response_msg
        except Exception as e:
            return ZMQMessage.create(
                request_type=ZMQRequestType.PUT_ERROR,
                sender_id=self.zmq_server_info.id,
                body={
                    "message": f"Failed to put data into storage unit "
                    f"#{self.zmq_server_info.id}; detail: {str(e)}"
                },
            )

    def _handle_get(self, data_parts: ZMQMessage) -> ZMQMessage:
        """
        Handle a get request: return data from the storage unit.

        Args:
            data_parts: ZMQMessage from a client.

        Returns:
            Get-data success response ZMQMessage containing the target data.
        """
        try:
            fields = data_parts.body["fields"]
            local_indexes = data_parts.body["local_indexes"]

            result_data = self.storage_data.get_data(fields, local_indexes)

            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.GET_DATA_RESPONSE,
                sender_id=self.zmq_server_info.id,
                body={
                    "data": result_data,
                },
            )
        except Exception as e:
            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.GET_ERROR,
                sender_id=self.zmq_server_info.id,
                body={
                    "message": f"Failed to get data from storage unit #{self.zmq_server_info.id}; "
                    f"detail: {str(e)}"
                },
            )
        return response_msg

    def _handle_clear(self, data_parts: ZMQMessage) -> ZMQMessage:
        """
        Handle a clear request: clear data in the storage unit at the given local_indexes.

        Args:
            data_parts: ZMQMessage from a client, including the target local_indexes.

        Returns:
            Clear-data success response ZMQMessage.
        """
        try:
            local_indexes = data_parts.body["local_indexes"]

            self.storage_data.clear(local_indexes)

            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.CLEAR_DATA_RESPONSE,
                sender_id=self.zmq_server_info.id,
                body={"message": f"Cleared data in storage unit #{self.zmq_server_info.id} successfully."},
            )
        except Exception as e:
            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.CLEAR_DATA_ERROR,
                sender_id=self.zmq_server_info.id,
                body={
                    "message": f"Failed to clear data in storage unit #{self.zmq_server_info.id}; "
                    f"detail: {str(e)}"
                },
            )
        return response_msg

    def get_zmq_server_info(self) -> ZMQServerInfo:
        """Get the ZMQ server information for this storage unit.

        Returns:
            ZMQServerInfo containing connection details for this storage unit.
        """
        return self.zmq_server_info

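
# A minimal sketch of reaching a SimpleStorageUnit from a client (illustrative
# only; real callers go through the TransferQueue client and storage managers).
# It assumes a running Ray cluster, fetches the server info via a one-off Ray
# call for simplicity, and mirrors the ZMQMessage protocol handled by the
# unit's ROUTER socket above.
def _storage_unit_client_example() -> None:
    unit = SimpleStorageUnit.remote(storage_unit_size=8)
    server_info = ray.get(unit.get_zmq_server_info.remote())

    ctx = zmq.Context()
    socket = ctx.socket(zmq.DEALER)
    socket.connect(server_info.to_addr("put_get_socket"))

    # A DEALER send arrives as [identity, frame] on the server's ROUTER socket.
    request = ZMQMessage.create(
        request_type=ZMQRequestType.PUT_DATA,
        sender_id="example_client",
        body={
            "local_indexes": [0],
            "data": TensorDict({"tokens": torch.tensor([[1, 2, 3]])}, batch_size=1),
        },
    )
    socket.send(request.serialize())
    response = ZMQMessage.deserialize(socket.recv())
    assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE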


@dataclass
class StorageMetaGroup:
    """
    Represents a group of samples stored in the same storage unit.
    Used to organize samples by their storage_id for efficient client operations.
    """

    storage_id: str
    sample_metas: list[SampleMeta] = dataclasses.field(default_factory=list)
    local_indexes: list[int] = dataclasses.field(default_factory=list)

    def add_sample_meta(self, sample_meta: SampleMeta, local_index: int) -> None:
        """Add a SampleMeta object to this storage group"""
        self.sample_metas.append(sample_meta)
        self.local_indexes.append(local_index)

    def get_batch_indexes(self) -> list[int]:
        """Get all batch indexes from the stored SampleMeta objects"""
        return [meta.batch_index for meta in self.sample_metas]

    def get_global_indexes(self) -> list[int]:
        """Get all global indexes from the stored SampleMeta objects"""
        return [meta.global_index for meta in self.sample_metas]

    def get_local_indexes(self) -> list[int]:
        """Get all local indexes recorded for the stored SampleMeta objects"""
        return self.local_indexes

    def get_field_names(self) -> list[str]:
        """Get all unique field names from the stored SampleMeta objects"""
        all_fields: set[str] = set()
        for meta in self.sample_metas:
            all_fields.update(meta.fields.keys())
        return list(all_fields)

    def get_transfer_data(self, field_names: Optional[list[str]] = None) -> dict[str, list | dict]:
        """Convert metadata to the transfer dictionary format.

        Creates a transfer_dict structure containing indexing and field information
        but without the actual field data. The field_data placeholder will be
        populated by the _add_field_data() function.

        Args:
            field_names: Optional list of field names to include. If None, includes all fields.

        Returns:
            Transfer dictionary with metadata structure:
            {
                "batch_indexes": [batch_idx1, batch_idx2, ...],
                "global_indexes": [global_idx1, global_idx2, ...],
                "local_indexes": [local_idx1, local_idx2, ...],
                "fields": ["field1", "field2", ...],
                "field_data": {}  # Placeholder - actual data added by _add_field_data()
            }

        Example:
            >>> group = StorageMetaGroup("storage1")
            >>> # Add multiple samples with different batch/global indexes and storage locations
            >>> group.add_sample_meta(SampleMeta(batch_index=0, global_index=10, fields={"img": ...}), 4)
            >>> group.add_sample_meta(SampleMeta(batch_index=1, global_index=11, fields={"img": ...}), 5)
            >>> group.add_sample_meta(SampleMeta(batch_index=2, global_index=12, fields={"img": ...}), 6)
            >>> transfer_dict = group.get_transfer_data(["img"])
            >>> transfer_dict["local_indexes"]   # [4, 5, 6] - storage locations
            >>> transfer_dict["batch_indexes"]   # [0, 1, 2] - original data locations
            >>> transfer_dict["global_indexes"]  # [10, 11, 12] - global identifiers
        """
        if field_names is None:
            field_names = self.get_field_names()
        return {
            "batch_indexes": self.get_batch_indexes(),
            "global_indexes": self.get_global_indexes(),
            "local_indexes": self.get_local_indexes(),
            "fields": field_names,
            "field_data": {},  # Placeholder for field data to be filled later
        }

    @property
    def size(self) -> int:
        """Number of samples in this storage meta group"""
        return len(self.sample_metas)

    @property
    def is_empty(self) -> bool:
        """Check whether this storage meta group is empty"""
        return len(self.sample_metas) == 0

    def __len__(self) -> int:
        """Number of samples in this storage meta group"""
        return self.size

    def __bool__(self) -> bool:
        """Truthiness based on whether the group has samples"""
        return not self.is_empty

    def __str__(self) -> str:
        return f"StorageMetaGroup(storage_id='{self.storage_id}', size={self.size})"
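

# A minimal sketch of assembling a StorageMetaGroup (illustrative only; the
# SampleMeta constructor arguments are assumed from the docstring example
# above and may differ from the real signature in transfer_queue.metadata):
def _storage_meta_group_example() -> None:
    group = StorageMetaGroup(storage_id="storage1")
    group.add_sample_meta(SampleMeta(batch_index=0, global_index=10, fields={"img": None}), local_index=4)
    group.add_sample_meta(SampleMeta(batch_index=1, global_index=11, fields={"img": None}), local_index=5)

    # The transfer dict keeps the three index spaces side by side, ready for a
    # storage client to fill in "field_data".
    transfer_dict = group.get_transfer_data(["img"])
    assert transfer_dict["local_indexes"] == [4, 5]
    assert transfer_dict["global_indexes"] == [10, 11]
    assert len(group) == 2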

transfer_queue/utils/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2025 The TransferQueue Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.