TransferQueue 0.1.1.dev1__py3-none-any.whl → 0.1.2.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transfer_queue/metadata.py +34 -2
- transfer_queue/storage/managers/simple_backend_manager.py +5 -2
- transfer_queue/version/version +1 -1
- {transferqueue-0.1.1.dev1.dist-info → transferqueue-0.1.2.dev0.dist-info}/METADATA +12 -12
- {transferqueue-0.1.1.dev1.dist-info → transferqueue-0.1.2.dev0.dist-info}/RECORD +8 -9
- tests/test_put.py +0 -327
- {transferqueue-0.1.1.dev1.dist-info → transferqueue-0.1.2.dev0.dist-info}/WHEEL +0 -0
- {transferqueue-0.1.1.dev1.dist-info → transferqueue-0.1.2.dev0.dist-info}/licenses/LICENSE +0 -0
- {transferqueue-0.1.1.dev1.dist-info → transferqueue-0.1.2.dev0.dist-info}/top_level.txt +0 -0
transfer_queue/metadata.py
CHANGED
|
@@ -14,15 +14,22 @@
|
|
|
14
14
|
|
|
15
15
|
import dataclasses
|
|
16
16
|
import itertools
|
|
17
|
-
|
|
17
|
+
import logging
|
|
18
|
+
import os
|
|
19
|
+
from collections import defaultdict
|
|
18
20
|
from dataclasses import dataclass
|
|
19
21
|
from typing import Any, Optional
|
|
20
22
|
|
|
21
23
|
import numpy as np
|
|
24
|
+
import torch
|
|
22
25
|
from tensordict import TensorDict
|
|
26
|
+
from tensordict.tensorclass import NonTensorData, NonTensorStack
|
|
23
27
|
|
|
24
28
|
from transfer_queue.utils.utils import ProductionStatus
|
|
25
29
|
|
|
30
|
+
logger = logging.getLogger(__name__)
|
|
31
|
+
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
|
|
32
|
+
|
|
26
33
|
|
|
27
34
|
@dataclass
|
|
28
35
|
class FieldMeta:
|
|
@@ -296,8 +303,33 @@ class BatchMeta:
|
|
|
296
303
|
|
|
297
304
|
# Combine all samples
|
|
298
305
|
all_samples = list(itertools.chain.from_iterable(chunk.samples for chunk in data))
|
|
306
|
+
|
|
299
307
|
# Merge all extra_info dictionaries from the chunks
|
|
300
|
-
merged_extra_info = dict(
|
|
308
|
+
merged_extra_info = dict()
|
|
309
|
+
|
|
310
|
+
values_by_key = defaultdict(list)
|
|
311
|
+
for chunk in data:
|
|
312
|
+
for key, value in chunk.extra_info.items():
|
|
313
|
+
values_by_key[key].append(value)
|
|
314
|
+
for key, values in values_by_key.items():
|
|
315
|
+
if all(isinstance(v, torch.Tensor) for v in values):
|
|
316
|
+
try:
|
|
317
|
+
if all(v.dim() == 0 for v in values):
|
|
318
|
+
merged_extra_info[key] = torch.cat([v.unsqueeze(0) for v in values], dim=0)
|
|
319
|
+
else:
|
|
320
|
+
merged_extra_info[key] = torch.cat(values, dim=0)
|
|
321
|
+
except RuntimeError as e:
|
|
322
|
+
logger.warning(
|
|
323
|
+
f"BatchMeta.concat try to use torch.cat(dim=0) to merge extra_info key '{key}'"
|
|
324
|
+
f" fails, with RuntimeError {e}. Falling back to use list."
|
|
325
|
+
)
|
|
326
|
+
merged_extra_info[key] = values
|
|
327
|
+
elif all(isinstance(v, NonTensorStack | NonTensorData) for v in values):
|
|
328
|
+
merged_extra_info[key] = torch.stack(values)
|
|
329
|
+
elif all(isinstance(v, list) for v in values):
|
|
330
|
+
merged_extra_info[key] = list(itertools.chain.from_iterable(values))
|
|
331
|
+
else:
|
|
332
|
+
merged_extra_info[key] = values[-1]
|
|
301
333
|
|
|
302
334
|
return BatchMeta(samples=all_samples, extra_info=merged_extra_info)
|
|
303
335
|
|
|
@@ -34,6 +34,9 @@ from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServer
|
|
|
34
34
|
logger = logging.getLogger(__name__)
|
|
35
35
|
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
|
|
36
36
|
|
|
37
|
+
TQ_SIMPLE_STORAGE_MANAGER_RECV_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_MANAGER_RECV_TIMEOUT", 200)) # seconds
|
|
38
|
+
TQ_SIMPLE_STORAGE_MANAGER_SEND_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_MANAGER_SEND_TIMEOUT", 200)) # seconds
|
|
39
|
+
|
|
37
40
|
|
|
38
41
|
@TransferQueueStorageManagerFactory.register("AsyncSimpleStorageManager")
|
|
39
42
|
class AsyncSimpleStorageManager(TransferQueueStorageManager):
|
|
@@ -132,8 +135,8 @@ class AsyncSimpleStorageManager(TransferQueueStorageManager):
|
|
|
132
135
|
try:
|
|
133
136
|
sock.connect(address)
|
|
134
137
|
# Timeouts to avoid indefinite await on recv/send
|
|
135
|
-
sock.setsockopt(zmq.RCVTIMEO,
|
|
136
|
-
sock.setsockopt(zmq.SNDTIMEO,
|
|
138
|
+
sock.setsockopt(zmq.RCVTIMEO, TQ_SIMPLE_STORAGE_MANAGER_RECV_TIMEOUT * 1000)
|
|
139
|
+
sock.setsockopt(zmq.SNDTIMEO, TQ_SIMPLE_STORAGE_MANAGER_SEND_TIMEOUT * 1000)
|
|
137
140
|
logger.info(
|
|
138
141
|
f"[{self.storage_manager_id}]: Connected to StorageUnit {server_info.id} at {address} "
|
|
139
142
|
f"with identity {identity.decode()}"
|
transfer_queue/version/version
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
0.1.
|
|
1
|
+
0.1.2.dev0
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: TransferQueue
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.2.dev0
|
|
4
4
|
Summary: TransferQueue: An Asynchronous Streaming Data Management Module
|
|
5
5
|
Author-email: The TransferQueue Team <hanzy19@tsinghua.org.cn>
|
|
6
6
|
License: Apache-2.0
|
|
@@ -40,13 +40,13 @@ Dynamic: license-file
|
|
|
40
40
|
TransferQueue is a high-performance data storage and transfer module with panoramic data visibility and streaming scheduling capabilities, optimized for efficient dataflow in post-training workflows.
|
|
41
41
|
|
|
42
42
|
<p align="center">
|
|
43
|
-
<img src="https://
|
|
43
|
+
<img src="https://github.com/TransferQueue/community_doc/blob/main/docs/tq_arch.png?raw=true" width="70%">
|
|
44
44
|
</p>
|
|
45
45
|
|
|
46
46
|
TransferQueue offers **fine-grained, sample-level** data management and **load-balancing** (on the way) capabilities, serving as a data gateway that decouples explicit data dependencies across computational tasks. This enables a divide-and-conquer approach, significantly simplifies the algorithm controller design.
|
|
47
47
|
|
|
48
48
|
<p align="center">
|
|
49
|
-
<img src="https://
|
|
49
|
+
<img src="https://github.com/TransferQueue/community_doc/blob/main/docs/main_func.png?raw=true" width="70%">
|
|
50
50
|
</p>
|
|
51
51
|
|
|
52
52
|
<h2 id="updates">🔄 Updates</h2>
|
|
@@ -69,7 +69,7 @@ In the control plane, `TransferQueueController` tracks the **production status**
|
|
|
69
69
|
For consumption status, we record the consumption records for each computational task (e.g., `generate_sequences`, `compute_log_prob`, etc.). Therefore, even when different computation tasks require the same data field, they can consume the data independently without interfering with each other.
|
|
70
70
|
|
|
71
71
|
<p align="center">
|
|
72
|
-
<img src="https://
|
|
72
|
+
<img src="https://github.com/TransferQueue/community_doc/blob/main/docs/control_plane.png?raw=true" width="70%">
|
|
73
73
|
</p>
|
|
74
74
|
|
|
75
75
|
To make the data retrieval process more customizable, we provide a `Sampler` class that allows users to define their own data retrieval and consumption logic. Refer to the [Customize](#customize) section for details.
|
|
@@ -91,9 +91,9 @@ This class encapsulates the core interaction logic within the TransferQueue syst
|
|
|
91
91
|
Currently, we support the following storage backends:
|
|
92
92
|
|
|
93
93
|
- SimpleStorageUnit: A basic CPU memory storage with minimal data format constraints and easy usability.
|
|
94
|
-
- [MoonCakeStore](https://github.com/kvcache-ai/Mooncake): A high-performance, KV-based hierarchical storage that supports RDMA transport between GPU and DRAM.
|
|
95
94
|
- [Yuanrong](https://gitee.com/openeuler/yuanrong-datasystem): An Ascend native data system that provides hierarchical storage interfaces including HBM/DRAM/SSD.
|
|
96
|
-
- [
|
|
95
|
+
- [MoonCakeStore](https://github.com/kvcache-ai/Mooncake) (WIP): A high-performance, KV-based hierarchical storage that supports RDMA transport between GPU and DRAM.
|
|
96
|
+
- [Ray Direct Transport](https://docs.ray.io/en/master/ray-core/direct-transport.html) ([WIP](https://github.com/TransferQueue/TransferQueue/pull/108)): Ray's new feature that allows Ray to store and pass objects directly between Ray actors.
|
|
97
97
|
|
|
98
98
|
Among them, `SimpleStorageUnit` serves as our default storage backend, coordinated by the `AsyncSimpleStorageManager` class. Each storage unit can be deployed on a separate node, allowing for distributed data management.
|
|
99
99
|
|
|
@@ -105,7 +105,7 @@ Among them, `SimpleStorageUnit` serves as our default storage backend, coordinat
|
|
|
105
105
|
This data structure design is motivated by the computational characteristics of the post-training process, where each training sample is generated in a relayed manner across task pipelines. It provides an accurate addressing capability, which allows fine-grained, concurrent data read/write operations in a streaming manner.
|
|
106
106
|
|
|
107
107
|
<p align="center">
|
|
108
|
-
<img src="https://
|
|
108
|
+
<img src="https://github.com/TransferQueue/community_doc/blob/main/docs/data_plane.png?raw=true" width="70%">
|
|
109
109
|
</p>
|
|
110
110
|
|
|
111
111
|
### User Interface: Asynchronous & Synchronous Client
|
|
@@ -140,7 +140,7 @@ We will soon release a detailed tutorial and API documentation.
|
|
|
140
140
|
#### verl
|
|
141
141
|
The primary motivation for integrating TransferQueue to verl now is to **alleviate the data transfer bottleneck of the single controller `RayPPOTrainer`**. Currently, all `DataProto` objects must be routed through `RayPPOTrainer`, resulting in a single point bottleneck of the whole post-training system.
|
|
142
142
|
|
|
143
|
-

|
|
144
144
|
|
|
145
145
|
Leveraging TransferQueue, we separate experience data transfer from metadata dispatch by
|
|
146
146
|
|
|
@@ -148,7 +148,7 @@ Leveraging TransferQueue, we separate experience data transfer from metadata dis
|
|
|
148
148
|
- Preserving verl's original Dispatch/Collect logic via BatchMeta (maintaining single-controller debuggability)
|
|
149
149
|
- Accelerating data transfer by TransferQueue's distributed storage units
|
|
150
150
|
|
|
151
|
-

|
|
152
152
|
|
|
153
153
|
You may refer to the [recipe](https://github.com/TransferQueue/TransferQueue/tree/dev/recipe/simple_use_case), where we mimic the verl usage in both async & sync scenarios. Official integration to verl is also available now at [verl/pulls/3649](https://github.com/volcengine/verl/pull/3649) (with subsequent PRs to further optimize the integration).
|
|
154
154
|
|
|
@@ -157,7 +157,7 @@ You may refer to the [recipe](https://github.com/TransferQueue/TransferQueue/tre
|
|
|
157
157
|
Work in progress :)
|
|
158
158
|
|
|
159
159
|
<p align="center">
|
|
160
|
-
<img src="https://
|
|
160
|
+
<img src="https://github.com/TransferQueue/community_doc/blob/main/docs/tq_streaming_dataloader.png?raw=true" width="70%">
|
|
161
161
|
</p>
|
|
162
162
|
|
|
163
163
|
<h2 id="quick-start">🚀 Quick Start</h2>
|
|
@@ -188,12 +188,12 @@ Follow these steps to build and install:
|
|
|
188
188
|
<h2 id="performance">📊 Performance</h2>
|
|
189
189
|
|
|
190
190
|
<p align="center">
|
|
191
|
-
<img src="https://
|
|
191
|
+
<img src="https://github.com/TransferQueue/community_doc/blob/main/docs/performance_0.1.1.dev2.png?raw=true" width="100%">
|
|
192
192
|
</p>
|
|
193
193
|
|
|
194
194
|
> Note: The above benchmark for TransferQueue is based on our naive `SimpleStorageUnit` backend. By introducing high-performance storage backends and optimizing serialization/deserialization, we expect to achieve even better performance. Warmly welcome contributions from the community!
|
|
195
195
|
|
|
196
|
-
For detailed performance benchmarks, please refer to [this blog](https://www.yuque.com/haomingzi-lfse7/hlx5g0/
|
|
196
|
+
For detailed performance benchmarks, please refer to [this blog](https://www.yuque.com/haomingzi-lfse7/hlx5g0/tml8ke0zkgn6roey?singleDoc#).
|
|
197
197
|
|
|
198
198
|
<h2 id="customize"> 🛠️ Customize TransferQueue</h2>
|
|
199
199
|
|
|
@@ -5,7 +5,6 @@ tests/test_client.py,sha256=74Pm1D4SI_GCg0Kxwm5Wqa4ppSfc57mpHJGuIdtNUrs,15325
|
|
|
5
5
|
tests/test_controller.py,sha256=ZcvFCC3jSnNN_fEerjA37RQv0SSO0Xh8vjcL2mvF03o,11084
|
|
6
6
|
tests/test_controller_data_partitions.py,sha256=qZxMHerMwKIwyRmT8FZke8TEd80Z9vAkBzU_k5Jz1bY,19185
|
|
7
7
|
tests/test_kv_storage_manager.py,sha256=Eh6xykdhLBMpxikXfRHxn1crhLsQjn9QGa0O7TLO-5o,3582
|
|
8
|
-
tests/test_put.py,sha256=WnRKCGPXmRITAvbD8KWlZor4jpvv6sdg2gg3NCw9gyQ,10453
|
|
9
8
|
tests/test_samplers.py,sha256=CvYqfmbHEWWa1RyymztCAn0GcitAPOBbfJ4ud1VvO2o,19168
|
|
10
9
|
tests/test_serial_utils_on_cpu.py,sha256=Hju5yAV52JP-cPw-PT_jsut-y6J_7lUX_SbG5EO1lNQ,7379
|
|
11
10
|
tests/test_simple_storage_unit.py,sha256=Sczhw2bdCfTfa7RwnpW-aKCV1of8mO6z3l3q2PDzZCA,16535
|
|
@@ -13,7 +12,7 @@ tests/test_storage_client_factory.py,sha256=U0gS_l4bc_bP7K_uPhy8UlJ186HyTrA10TSJ
|
|
|
13
12
|
transfer_queue/__init__.py,sha256=68c0sBfqHPqTa7OdzO4sAZB52XvwtjpwLqP9BWAh4fA,1535
|
|
14
13
|
transfer_queue/client.py,sha256=zDlH1beWwRbjz0a7S8QH9IOJ1wp5yEQ36XYwFwJTXOY,24949
|
|
15
14
|
transfer_queue/controller.py,sha256=uc_MQAlG_QmJ8szxc5yPLaZvCeP2CRULnAADtcf41-8,49702
|
|
16
|
-
transfer_queue/metadata.py,sha256=
|
|
15
|
+
transfer_queue/metadata.py,sha256=mH3nev_lK1HiIqZ6przEgCd0BiHkO7jI6ji3x6VSBgY,19076
|
|
17
16
|
transfer_queue/sampler/__init__.py,sha256=1oauDy2Dwb5GXhKi7tl5DWAHv8i4t2MQK1S4U36Sy4g,788
|
|
18
17
|
transfer_queue/sampler/base.py,sha256=wFti4dNJb3YArYpGzxA_YDfyUTdTG8wVz6HclPDyZPw,3299
|
|
19
18
|
transfer_queue/sampler/grpo_group_n_sampler.py,sha256=Kq3hGAz8mboBNvw4Dj0P8lP6Qs8TDojx81fxSh57w28,6566
|
|
@@ -27,15 +26,15 @@ transfer_queue/storage/clients/yuanrong_client.py,sha256=rWOPQPLgHLav7cFdGCr8CeV
|
|
|
27
26
|
transfer_queue/storage/managers/__init__.py,sha256=y5x5OzZwK_YormdHdzc-smnJNew3niuqU2g_kkaiXIk,876
|
|
28
27
|
transfer_queue/storage/managers/base.py,sha256=ntlo6sLWIbTiAOUxgJTUyjLl5m4HBuAUUNCYZxq9BFM,20352
|
|
29
28
|
transfer_queue/storage/managers/factory.py,sha256=58kp2mCKz1K8Ea7RWMsWxdDhN3y4ZhgE-G647AKq7-I,1752
|
|
30
|
-
transfer_queue/storage/managers/simple_backend_manager.py,sha256=
|
|
29
|
+
transfer_queue/storage/managers/simple_backend_manager.py,sha256=zn0RrRng2zo0kLHl3eKQ9Ksf9O1W9DCEzpuZes8uSjM,27684
|
|
31
30
|
transfer_queue/storage/managers/yuanrong_manager.py,sha256=NjHC3LBW0fQwm30Oq_qEKoCQEq8oWO0D-AobcpQNPNg,777
|
|
32
31
|
transfer_queue/utils/__init__.py,sha256=vki-5RVaRBKxVc6Q7XPQox3VNPio2DvJYvRz0SZtu-w,586
|
|
33
32
|
transfer_queue/utils/serial_utils.py,sha256=9ZgsytTp-441YKtIRFqyH5NhNifSqRKa2h0FI44ltcc,10200
|
|
34
33
|
transfer_queue/utils/utils.py,sha256=Pno4h3WjX_eT7q4xiVV6Jkhquc39Fp-Ycsg1cv0qNKQ,4544
|
|
35
34
|
transfer_queue/utils/zmq_utils.py,sha256=jCg2pQfvy_IYdGyZq4nvL4CAwjJc7Li0Trp3T3GDBMg,5118
|
|
36
|
-
transfer_queue/version/version,sha256=
|
|
37
|
-
transferqueue-0.1.
|
|
38
|
-
transferqueue-0.1.
|
|
39
|
-
transferqueue-0.1.
|
|
40
|
-
transferqueue-0.1.
|
|
41
|
-
transferqueue-0.1.
|
|
35
|
+
transfer_queue/version/version,sha256=Q11DRFEP_n2U67FVBUV-pmge1VJtg5UP4Tj8Ski4SkU,11
|
|
36
|
+
transferqueue-0.1.2.dev0.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
|
|
37
|
+
transferqueue-0.1.2.dev0.dist-info/METADATA,sha256=JnPLADNXjifw9xZxuj6vvNahLyT8k0bDc1gDuJxosOk,19502
|
|
38
|
+
transferqueue-0.1.2.dev0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
39
|
+
transferqueue-0.1.2.dev0.dist-info/top_level.txt,sha256=BiBclu7jWJ0AZ35vUr3hN9_cg8JL9EiH_hjFxquMxtw,33
|
|
40
|
+
transferqueue-0.1.2.dev0.dist-info/RECORD,,
|
tests/test_put.py
DELETED
|
@@ -1,327 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 The TransferQueue Team
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
import asyncio
|
|
16
|
-
import logging
|
|
17
|
-
import sys
|
|
18
|
-
from pathlib import Path
|
|
19
|
-
|
|
20
|
-
import pytest
|
|
21
|
-
import ray
|
|
22
|
-
import torch
|
|
23
|
-
from tensordict import TensorDict
|
|
24
|
-
|
|
25
|
-
parent_dir = Path(__file__).resolve().parent.parent
|
|
26
|
-
sys.path.append(str(parent_dir))
|
|
27
|
-
|
|
28
|
-
from transfer_queue import ( # noqa: E402
|
|
29
|
-
AsyncTransferQueueClient,
|
|
30
|
-
SimpleStorageUnit,
|
|
31
|
-
TransferQueueController,
|
|
32
|
-
process_zmq_server_info,
|
|
33
|
-
)
|
|
34
|
-
from transfer_queue.utils.utils import get_placement_group # noqa: E402
|
|
35
|
-
|
|
36
|
-
# Set up logging
|
|
37
|
-
logging.basicConfig(level=logging.INFO)
|
|
38
|
-
logger = logging.getLogger(__name__)
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
@pytest.fixture(scope="module")
|
|
42
|
-
def ray_setup():
|
|
43
|
-
"""Initialize Ray for testing."""
|
|
44
|
-
ray.init(ignore_reinit_error=True)
|
|
45
|
-
yield
|
|
46
|
-
ray.shutdown()
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
@pytest.fixture(scope="module")
|
|
50
|
-
def data_system_setup(ray_setup):
|
|
51
|
-
"""Set up data system for testing."""
|
|
52
|
-
# Initialize storage units
|
|
53
|
-
num_storage_units = 2
|
|
54
|
-
storage_size = 10000
|
|
55
|
-
storage_units = {}
|
|
56
|
-
|
|
57
|
-
storage_placement_group = get_placement_group(num_storage_units, num_cpus_per_actor=1)
|
|
58
|
-
for storage_unit_rank in range(num_storage_units):
|
|
59
|
-
storage_node = SimpleStorageUnit.options(
|
|
60
|
-
placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank
|
|
61
|
-
).remote(storage_unit_size=storage_size)
|
|
62
|
-
storage_units[storage_unit_rank] = storage_node
|
|
63
|
-
logger.info(f"SimpleStorageUnit #{storage_unit_rank} has been created.")
|
|
64
|
-
|
|
65
|
-
# Initialize controller
|
|
66
|
-
controller = TransferQueueController.remote()
|
|
67
|
-
logger.info("TransferQueueController has been created.")
|
|
68
|
-
|
|
69
|
-
# Prepare connection info
|
|
70
|
-
controller_info = process_zmq_server_info(controller)
|
|
71
|
-
storage_unit_infos = process_zmq_server_info(storage_units)
|
|
72
|
-
|
|
73
|
-
# Create config
|
|
74
|
-
config = {
|
|
75
|
-
"controller_info": controller_info,
|
|
76
|
-
"storage_unit_infos": storage_unit_infos,
|
|
77
|
-
}
|
|
78
|
-
|
|
79
|
-
yield controller, storage_units, config
|
|
80
|
-
|
|
81
|
-
# Cleanup
|
|
82
|
-
ray.kill(controller)
|
|
83
|
-
for storage_unit in storage_units.values():
|
|
84
|
-
ray.kill(storage_unit)
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
@pytest.fixture
|
|
88
|
-
async def client_setup(data_system_setup):
|
|
89
|
-
"""Set up client for testing."""
|
|
90
|
-
controller, storage_units, config = data_system_setup
|
|
91
|
-
|
|
92
|
-
client = AsyncTransferQueueClient(
|
|
93
|
-
client_id="TestClient",
|
|
94
|
-
controller_info=config["controller_info"],
|
|
95
|
-
)
|
|
96
|
-
|
|
97
|
-
client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config)
|
|
98
|
-
|
|
99
|
-
# Wait a bit for connections to establish
|
|
100
|
-
await asyncio.sleep(1)
|
|
101
|
-
|
|
102
|
-
return client
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
class TestMultipleAsyncPut:
|
|
106
|
-
"""Test class for multiple async_put operations."""
|
|
107
|
-
|
|
108
|
-
def __init__(self):
|
|
109
|
-
self.controller = None
|
|
110
|
-
self.storage_units = None
|
|
111
|
-
self.config = None
|
|
112
|
-
|
|
113
|
-
async def setup(self):
|
|
114
|
-
"""Setup for the test class."""
|
|
115
|
-
if not ray.is_initialized():
|
|
116
|
-
ray.init(ignore_reinit_error=True)
|
|
117
|
-
|
|
118
|
-
# Initialize data system
|
|
119
|
-
num_storage_units = 2
|
|
120
|
-
self.storage_units = {}
|
|
121
|
-
storage_placement_group = get_placement_group(num_storage_units, num_cpus_per_actor=1)
|
|
122
|
-
|
|
123
|
-
for i in range(num_storage_units):
|
|
124
|
-
self.storage_units[i] = SimpleStorageUnit.options(
|
|
125
|
-
placement_group=storage_placement_group, placement_group_bundle_index=i
|
|
126
|
-
).remote(storage_unit_size=10000)
|
|
127
|
-
|
|
128
|
-
self.controller = TransferQueueController.remote()
|
|
129
|
-
|
|
130
|
-
# Wait for initialization
|
|
131
|
-
await asyncio.sleep(2)
|
|
132
|
-
|
|
133
|
-
controller_info = process_zmq_server_info(self.controller)
|
|
134
|
-
storage_unit_infos = process_zmq_server_info(self.storage_units)
|
|
135
|
-
|
|
136
|
-
self.config = {
|
|
137
|
-
"controller_info": controller_info,
|
|
138
|
-
"storage_unit_infos": storage_unit_infos,
|
|
139
|
-
}
|
|
140
|
-
logger.info("TestMultipleAsyncPut setup completed")
|
|
141
|
-
|
|
142
|
-
async def teardown(self):
|
|
143
|
-
"""Teardown for the test class."""
|
|
144
|
-
if self.controller:
|
|
145
|
-
ray.kill(self.controller)
|
|
146
|
-
if self.storage_units:
|
|
147
|
-
for storage in self.storage_units.values():
|
|
148
|
-
ray.kill(storage)
|
|
149
|
-
if ray.is_initialized():
|
|
150
|
-
ray.shutdown()
|
|
151
|
-
|
|
152
|
-
async def create_client(self, client_id="TestClient"):
|
|
153
|
-
"""Create a new client instance."""
|
|
154
|
-
if self.config is None:
|
|
155
|
-
await self.setup()
|
|
156
|
-
|
|
157
|
-
client = AsyncTransferQueueClient(
|
|
158
|
-
client_id=client_id,
|
|
159
|
-
controller_info=self.config["controller_info"],
|
|
160
|
-
)
|
|
161
|
-
client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=self.config)
|
|
162
|
-
await asyncio.sleep(1) # Wait for connections
|
|
163
|
-
return client
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
@pytest.mark.asyncio
|
|
167
|
-
async def test_concurrent_async_put():
|
|
168
|
-
"""Test concurrent async_put operations."""
|
|
169
|
-
test_instance = TestMultipleAsyncPut()
|
|
170
|
-
await test_instance.setup()
|
|
171
|
-
|
|
172
|
-
try:
|
|
173
|
-
client = await test_instance.create_client("ConcurrentTestClient")
|
|
174
|
-
|
|
175
|
-
async def put_operation(partition_id, data):
|
|
176
|
-
await client.async_put(data=data, partition_id=partition_id)
|
|
177
|
-
return partition_id
|
|
178
|
-
|
|
179
|
-
# Create multiple put operations
|
|
180
|
-
tasks = []
|
|
181
|
-
for i in range(5):
|
|
182
|
-
data = TensorDict({f"concurrent_data_{i}": torch.randn(4, 8)}, batch_size=[4])
|
|
183
|
-
task = put_operation(f"concurrent_partition_{i}", data)
|
|
184
|
-
tasks.append(task)
|
|
185
|
-
|
|
186
|
-
# Execute concurrently
|
|
187
|
-
results = await asyncio.gather(*tasks)
|
|
188
|
-
|
|
189
|
-
assert len(results) == 5
|
|
190
|
-
logger.info(f"Completed {len(results)} concurrent put operations")
|
|
191
|
-
|
|
192
|
-
# Cleanup
|
|
193
|
-
for partition in results:
|
|
194
|
-
try:
|
|
195
|
-
await client.async_clear(partition)
|
|
196
|
-
except Exception as e:
|
|
197
|
-
logger.warning(f"Failed to clear partition {partition}: {e}")
|
|
198
|
-
|
|
199
|
-
client.close()
|
|
200
|
-
finally:
|
|
201
|
-
await test_instance.teardown()
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
@pytest.mark.asyncio
|
|
205
|
-
async def test_sequential_async_put_with_verification():
|
|
206
|
-
"""Test sequential async_put operations with data verification."""
|
|
207
|
-
test_instance = TestMultipleAsyncPut()
|
|
208
|
-
await test_instance.setup()
|
|
209
|
-
|
|
210
|
-
try:
|
|
211
|
-
client = await test_instance.create_client("SequentialTestClient")
|
|
212
|
-
|
|
213
|
-
# Test data
|
|
214
|
-
test_cases = [
|
|
215
|
-
("sequential_1", torch.randn(4, 5)),
|
|
216
|
-
("sequential_2", torch.randn(4, 10)),
|
|
217
|
-
("sequential_3", torch.randn(4, 15)),
|
|
218
|
-
]
|
|
219
|
-
|
|
220
|
-
for partition_id, tensor_data in test_cases:
|
|
221
|
-
data = TensorDict({"sequential_data": tensor_data}, batch_size=[4])
|
|
222
|
-
|
|
223
|
-
# Put data
|
|
224
|
-
await client.async_put(data=data, partition_id=partition_id)
|
|
225
|
-
logger.info(f"Put data to {partition_id}")
|
|
226
|
-
|
|
227
|
-
# Verify by reading back
|
|
228
|
-
metadata = await client.async_get_meta(
|
|
229
|
-
data_fields=["sequential_data"],
|
|
230
|
-
batch_size=4,
|
|
231
|
-
partition_id=partition_id,
|
|
232
|
-
mode="fetch",
|
|
233
|
-
task_name="verification_task",
|
|
234
|
-
)
|
|
235
|
-
|
|
236
|
-
retrieved_data = await client.async_get_data(metadata)
|
|
237
|
-
|
|
238
|
-
# Verify shape and content
|
|
239
|
-
assert retrieved_data["sequential_data"].shape == tensor_data.shape
|
|
240
|
-
logger.info(f"Verified data in {partition_id}")
|
|
241
|
-
|
|
242
|
-
# Cleanup
|
|
243
|
-
await client.async_clear(partition_id)
|
|
244
|
-
|
|
245
|
-
client.close()
|
|
246
|
-
finally:
|
|
247
|
-
await test_instance.teardown()
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
@pytest.mark.asyncio
|
|
251
|
-
async def test_multiple_puts_same_partition():
|
|
252
|
-
"""Test multiple puts to the same partition."""
|
|
253
|
-
test_instance = TestMultipleAsyncPut()
|
|
254
|
-
await test_instance.setup()
|
|
255
|
-
|
|
256
|
-
try:
|
|
257
|
-
client = await test_instance.create_client("SamePartitionTestClient")
|
|
258
|
-
|
|
259
|
-
partition_id = "same_partition_test"
|
|
260
|
-
batch_size = 4
|
|
261
|
-
|
|
262
|
-
# First put
|
|
263
|
-
data1 = TensorDict({"first_batch": torch.randn(batch_size, 8)}, batch_size=[batch_size])
|
|
264
|
-
|
|
265
|
-
await client.async_put(data=data1, partition_id=partition_id)
|
|
266
|
-
logger.info("First put completed")
|
|
267
|
-
|
|
268
|
-
# Second put to same partition
|
|
269
|
-
data2 = TensorDict({"second_batch": torch.randn(batch_size, 8) * 2}, batch_size=[batch_size])
|
|
270
|
-
|
|
271
|
-
await client.async_put(data=data2, partition_id=partition_id)
|
|
272
|
-
logger.info("Second put completed")
|
|
273
|
-
|
|
274
|
-
# Third put to same partition
|
|
275
|
-
data3 = TensorDict({"third_batch": torch.randn(batch_size, 8) * 3}, batch_size=[batch_size])
|
|
276
|
-
|
|
277
|
-
await client.async_put(data=data3, partition_id=partition_id)
|
|
278
|
-
logger.info("Third put completed")
|
|
279
|
-
|
|
280
|
-
# Verify the last data is what we get
|
|
281
|
-
metadata = await client.async_get_meta(
|
|
282
|
-
data_fields=["third_batch"],
|
|
283
|
-
batch_size=batch_size,
|
|
284
|
-
partition_id=partition_id,
|
|
285
|
-
mode="fetch",
|
|
286
|
-
task_name="verification_task",
|
|
287
|
-
)
|
|
288
|
-
|
|
289
|
-
retrieved_data = await client.async_get_data(metadata)
|
|
290
|
-
|
|
291
|
-
assert "third_batch" in retrieved_data
|
|
292
|
-
assert retrieved_data["third_batch"].shape == (batch_size, 8)
|
|
293
|
-
logger.info("Verified multiple puts to same partition")
|
|
294
|
-
|
|
295
|
-
# Cleanup
|
|
296
|
-
await client.async_clear(partition_id)
|
|
297
|
-
|
|
298
|
-
client.close()
|
|
299
|
-
finally:
|
|
300
|
-
await test_instance.teardown()
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
@pytest.mark.asyncio
|
|
304
|
-
async def test_simple_multiple_async_put():
|
|
305
|
-
"""Simple test using the existing client_setup fixture."""
|
|
306
|
-
test_instance = TestMultipleAsyncPut()
|
|
307
|
-
await test_instance.setup()
|
|
308
|
-
|
|
309
|
-
try:
|
|
310
|
-
client = await test_instance.create_client("SimpleTestClient")
|
|
311
|
-
|
|
312
|
-
# Test basic multiple puts
|
|
313
|
-
for i in range(3):
|
|
314
|
-
data = TensorDict({f"simple_data_{i}": torch.randn(2, 4)}, batch_size=[2])
|
|
315
|
-
|
|
316
|
-
partition_id = f"simple_partition_{i}"
|
|
317
|
-
await client.async_put(data=data, partition_id=partition_id)
|
|
318
|
-
logger.info(f"Put data to {partition_id}")
|
|
319
|
-
|
|
320
|
-
# Cleanup each partition
|
|
321
|
-
await client.async_clear(partition_id)
|
|
322
|
-
|
|
323
|
-
logger.info("Simple multiple async_put test completed")
|
|
324
|
-
|
|
325
|
-
client.close()
|
|
326
|
-
finally:
|
|
327
|
-
await test_instance.teardown()
|
|
File without changes
|
|
File without changes
|
|
File without changes
|