TransferQueue 0.1.1.dev1__py3-none-any.whl → 0.1.2.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,15 +14,22 @@
14
14
 
15
15
  import dataclasses
16
16
  import itertools
17
- from collections import ChainMap
17
+ import logging
18
+ import os
19
+ from collections import defaultdict
18
20
  from dataclasses import dataclass
19
21
  from typing import Any, Optional
20
22
 
21
23
  import numpy as np
24
+ import torch
22
25
  from tensordict import TensorDict
26
+ from tensordict.tensorclass import NonTensorData, NonTensorStack
23
27
 
24
28
  from transfer_queue.utils.utils import ProductionStatus
25
29
 
30
+ logger = logging.getLogger(__name__)
31
+ logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
32
+
26
33
 
27
34
  @dataclass
28
35
  class FieldMeta:
@@ -296,8 +303,33 @@ class BatchMeta:
296
303
 
297
304
  # Combine all samples
298
305
  all_samples = list(itertools.chain.from_iterable(chunk.samples for chunk in data))
306
+
299
307
  # Merge all extra_info dictionaries from the chunks
300
- merged_extra_info = dict(ChainMap(*(chunk.extra_info for chunk in data)))
308
+ merged_extra_info = dict()
309
+
310
+ values_by_key = defaultdict(list)
311
+ for chunk in data:
312
+ for key, value in chunk.extra_info.items():
313
+ values_by_key[key].append(value)
314
+ for key, values in values_by_key.items():
315
+ if all(isinstance(v, torch.Tensor) for v in values):
316
+ try:
317
+ if all(v.dim() == 0 for v in values):
318
+ merged_extra_info[key] = torch.cat([v.unsqueeze(0) for v in values], dim=0)
319
+ else:
320
+ merged_extra_info[key] = torch.cat(values, dim=0)
321
+ except RuntimeError as e:
322
+ logger.warning(
323
+ f"BatchMeta.concat try to use torch.cat(dim=0) to merge extra_info key '{key}'"
324
+ f" fails, with RuntimeError {e}. Falling back to use list."
325
+ )
326
+ merged_extra_info[key] = values
327
+ elif all(isinstance(v, NonTensorStack | NonTensorData) for v in values):
328
+ merged_extra_info[key] = torch.stack(values)
329
+ elif all(isinstance(v, list) for v in values):
330
+ merged_extra_info[key] = list(itertools.chain.from_iterable(values))
331
+ else:
332
+ merged_extra_info[key] = values[-1]
301
333
 
302
334
  return BatchMeta(samples=all_samples, extra_info=merged_extra_info)
303
335
 
@@ -34,6 +34,9 @@ from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServer
34
34
  logger = logging.getLogger(__name__)
35
35
  logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
36
36
 
37
+ TQ_SIMPLE_STORAGE_MANAGER_RECV_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_MANAGER_RECV_TIMEOUT", 200)) # seconds
38
+ TQ_SIMPLE_STORAGE_MANAGER_SEND_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_MANAGER_SEND_TIMEOUT", 200)) # seconds
39
+
37
40
 
38
41
  @TransferQueueStorageManagerFactory.register("AsyncSimpleStorageManager")
39
42
  class AsyncSimpleStorageManager(TransferQueueStorageManager):
@@ -132,8 +135,8 @@ class AsyncSimpleStorageManager(TransferQueueStorageManager):
132
135
  try:
133
136
  sock.connect(address)
134
137
  # Timeouts to avoid indefinite await on recv/send
135
- sock.setsockopt(zmq.RCVTIMEO, 10_000) # 10s
136
- sock.setsockopt(zmq.SNDTIMEO, 10_000) # 10s
138
+ sock.setsockopt(zmq.RCVTIMEO, TQ_SIMPLE_STORAGE_MANAGER_RECV_TIMEOUT * 1000)
139
+ sock.setsockopt(zmq.SNDTIMEO, TQ_SIMPLE_STORAGE_MANAGER_SEND_TIMEOUT * 1000)
137
140
  logger.info(
138
141
  f"[{self.storage_manager_id}]: Connected to StorageUnit {server_info.id} at {address} "
139
142
  f"with identity {identity.decode()}"
@@ -1 +1 @@
1
- 0.1.1.dev1
1
+ 0.1.2.dev0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: TransferQueue
3
- Version: 0.1.1.dev1
3
+ Version: 0.1.2.dev0
4
4
  Summary: TransferQueue: An Asynchronous Streaming Data Management Module
5
5
  Author-email: The TransferQueue Team <hanzy19@tsinghua.org.cn>
6
6
  License: Apache-2.0
@@ -40,13 +40,13 @@ Dynamic: license-file
40
40
  TransferQueue is a high-performance data storage and transfer module with panoramic data visibility and streaming scheduling capabilities, optimized for efficient dataflow in post-training workflows.
41
41
 
42
42
  <p align="center">
43
- <img src="https://cdn.nlark.com/yuque/0/2025/png/23208217/1761356010763-b05751d3-f975-4890-ba59-c8d753cf95f2.png" width="70%">
43
+ <img src="https://github.com/TransferQueue/community_doc/blob/main/docs/tq_arch.png?raw=true" width="70%">
44
44
  </p>
45
45
 
46
46
  TransferQueue offers **fine-grained, sample-level** data management and **load-balancing** (on the way) capabilities, serving as a data gateway that decouples explicit data dependencies across computational tasks. This enables a divide-and-conquer approach that significantly simplifies the algorithm controller design.
47
47
 
48
48
  <p align="center">
49
- <img src="https://cdn.nlark.com/yuque/0/2025/png/23208217/1758696791245-fa7baf96-46af-4c19-8606-28ffadc4556c.png" width="70%">
49
+ <img src="https://github.com/TransferQueue/community_doc/blob/main/docs/main_func.png?raw=true" width="70%">
50
50
  </p>
51
51
 
52
52
  <h2 id="updates">🔄 Updates</h2>
@@ -69,7 +69,7 @@ In the control plane, `TransferQueueController` tracks the **production status**
69
69
  For consumption status, we record the consumption records for each computational task (e.g., `generate_sequences`, `compute_log_prob`, etc.). Therefore, even when different computation tasks require the same data field, they can consume the data independently without interfering with each other.
70
70
 
71
71
  <p align="center">
72
- <img src="https://cdn.nlark.com/yuque/0/2025/png/23208217/1758696820173-456c1784-42ba-40c8-a292-2ff1401f49c5.png" width="70%">
72
+ <img src="https://github.com/TransferQueue/community_doc/blob/main/docs/control_plane.png?raw=true" width="70%">
73
73
  </p>
74
74
 
75
75
  To make the data retrieval process more customizable, we provide a `Sampler` class that allows users to define their own data retrieval and consumption logic. Refer to the [Customize](#customize) section for details.
@@ -91,9 +91,9 @@ This class encapsulates the core interaction logic within the TransferQueue syst
91
91
  Currently, we support the following storage backends:
92
92
 
93
93
  - SimpleStorageUnit: A basic CPU memory storage with minimal data format constraints and easy usability.
94
- - [MoonCakeStore](https://github.com/kvcache-ai/Mooncake): A high-performance, KV-based hierarchical storage that supports RDMA transport between GPU and DRAM.
95
94
  - [Yuanrong](https://gitee.com/openeuler/yuanrong-datasystem): An Ascend native data system that provides hierarchical storage interfaces including HBM/DRAM/SSD.
96
- - [Ray Direct Transport](https://docs.ray.io/en/master/ray-core/direct-transport.html): Ray's new feature that allows Ray to store and pass objects directly between Ray actors.
95
+ - [MoonCakeStore](https://github.com/kvcache-ai/Mooncake) (WIP): A high-performance, KV-based hierarchical storage that supports RDMA transport between GPU and DRAM.
96
+ - [Ray Direct Transport](https://docs.ray.io/en/master/ray-core/direct-transport.html) ([WIP](https://github.com/TransferQueue/TransferQueue/pull/108)): Ray's new feature that allows Ray to store and pass objects directly between Ray actors.
97
97
 
98
98
  Among them, `SimpleStorageUnit` serves as our default storage backend, coordinated by the `AsyncSimpleStorageManager` class. Each storage unit can be deployed on a separate node, allowing for distributed data management.
99
99
 
@@ -105,7 +105,7 @@ Among them, `SimpleStorageUnit` serves as our default storage backend, coordinat
105
105
  This data structure design is motivated by the computational characteristics of the post-training process, where each training sample is generated in a relayed manner across task pipelines. It provides an accurate addressing capability, which allows fine-grained, concurrent data read/write operations in a streaming manner.
106
106
 
107
107
  <p align="center">
108
- <img src="https://cdn.nlark.com/yuque/0/2025/png/23208217/1758696805154-3817011f-84e6-40d0-a80c-58b7e3e5f6a7.png" width="70%">
108
+ <img src="https://github.com/TransferQueue/community_doc/blob/main/docs/data_plane.png?raw=true" width="70%">
109
109
  </p>
110
110
 
111
111
  ### User Interface: Asynchronous & Synchronous Client
@@ -140,7 +140,7 @@ We will soon release a detailed tutorial and API documentation.
140
140
  #### verl
141
141
  The primary motivation for integrating TransferQueue into verl now is to **alleviate the data transfer bottleneck of the single controller `RayPPOTrainer`**. Currently, all `DataProto` objects must be routed through `RayPPOTrainer`, resulting in a single-point bottleneck for the whole post-training system.
142
142
 
143
- ![verl_dataflow_DataProto](https://cdn.nlark.com/yuque/0/2025/jpeg/23208217/1758704289414-bcc54228-716b-4d4a-ad3b-f9ace6d10fcf.jpeg)
143
+ ![verl_dataflow_DataProto](https://github.com/TransferQueue/community_doc/blob/main/docs/verl_workflow.jpeg?raw=true)
144
144
 
145
145
  Leveraging TransferQueue, we separate experience data transfer from metadata dispatch by
146
146
 
@@ -148,7 +148,7 @@ Leveraging TransferQueue, we separate experience data transfer from metadata dis
148
148
  - Preserving verl's original Dispatch/Collect logic via BatchMeta (maintaining single-controller debuggability)
149
149
  - Accelerating data transfer by TransferQueue's distributed storage units
150
150
 
151
- ![verl_dataflow_TransferQueue](https://cdn.nlark.com/yuque/0/2025/jpeg/23208217/1758704301666-0807dc06-766c-4a2d-9cde-889a6bb56b34.jpeg)
151
+ ![verl_dataflow_TransferQueue](https://github.com/TransferQueue/community_doc/blob/main/docs/verl_workflow_with_tq.jpeg?raw=true)
152
152
 
153
153
  You may refer to the [recipe](https://github.com/TransferQueue/TransferQueue/tree/dev/recipe/simple_use_case), where we mimic the verl usage in both async & sync scenarios. Official integration into verl is also available now at [verl/pulls/3649](https://github.com/volcengine/verl/pull/3649) (with subsequent PRs to further optimize the integration).
154
154
 
@@ -157,7 +157,7 @@ You may refer to the [recipe](https://github.com/TransferQueue/TransferQueue/tre
157
157
  Work in progress :)
158
158
 
159
159
  <p align="center">
160
- <img src="https://cdn.nlark.com/yuque/0/2025/png/23208217/1758696840817-14ba4c3b-b96e-4390-ac7c-4ecf7b8c0ac3.png" width="70%">
160
+ <img src="https://github.com/TransferQueue/community_doc/blob/main/docs/tq_streaming_dataloader.png?raw=true" width="70%">
161
161
  </p>
162
162
 
163
163
  <h2 id="quick-start">🚀 Quick Start</h2>
@@ -188,12 +188,12 @@ Follow these steps to build and install:
188
188
  <h2 id="performance">📊 Performance</h2>
189
189
 
190
190
  <p align="center">
191
- <img src="https://cdn.nlark.com/yuque/0/2025/png/23208217/1761294403612-76ca20a7-9108-42fc-b3f5-60f84d70f39b.png" width="100%">
191
+ <img src="https://github.com/TransferQueue/community_doc/blob/main/docs/performance_0.1.1.dev2.png?raw=true" width="100%">
192
192
  </p>
193
193
 
194
194
  > Note: The above benchmark for TransferQueue is based on our naive `SimpleStorageUnit` backend. By introducing high-performance storage backends and optimizing serialization/deserialization, we expect to achieve even better performance. We warmly welcome contributions from the community!
195
195
 
196
- For detailed performance benchmarks, please refer to [this blog](https://www.yuque.com/haomingzi-lfse7/hlx5g0/obi4ovmy9wf08zz3?singleDoc#).
196
+ For detailed performance benchmarks, please refer to [this blog](https://www.yuque.com/haomingzi-lfse7/hlx5g0/tml8ke0zkgn6roey?singleDoc#).
197
197
 
198
198
  <h2 id="customize"> 🛠️ Customize TransferQueue</h2>
199
199
 
@@ -5,7 +5,6 @@ tests/test_client.py,sha256=74Pm1D4SI_GCg0Kxwm5Wqa4ppSfc57mpHJGuIdtNUrs,15325
5
5
  tests/test_controller.py,sha256=ZcvFCC3jSnNN_fEerjA37RQv0SSO0Xh8vjcL2mvF03o,11084
6
6
  tests/test_controller_data_partitions.py,sha256=qZxMHerMwKIwyRmT8FZke8TEd80Z9vAkBzU_k5Jz1bY,19185
7
7
  tests/test_kv_storage_manager.py,sha256=Eh6xykdhLBMpxikXfRHxn1crhLsQjn9QGa0O7TLO-5o,3582
8
- tests/test_put.py,sha256=WnRKCGPXmRITAvbD8KWlZor4jpvv6sdg2gg3NCw9gyQ,10453
9
8
  tests/test_samplers.py,sha256=CvYqfmbHEWWa1RyymztCAn0GcitAPOBbfJ4ud1VvO2o,19168
10
9
  tests/test_serial_utils_on_cpu.py,sha256=Hju5yAV52JP-cPw-PT_jsut-y6J_7lUX_SbG5EO1lNQ,7379
11
10
  tests/test_simple_storage_unit.py,sha256=Sczhw2bdCfTfa7RwnpW-aKCV1of8mO6z3l3q2PDzZCA,16535
@@ -13,7 +12,7 @@ tests/test_storage_client_factory.py,sha256=U0gS_l4bc_bP7K_uPhy8UlJ186HyTrA10TSJ
13
12
  transfer_queue/__init__.py,sha256=68c0sBfqHPqTa7OdzO4sAZB52XvwtjpwLqP9BWAh4fA,1535
14
13
  transfer_queue/client.py,sha256=zDlH1beWwRbjz0a7S8QH9IOJ1wp5yEQ36XYwFwJTXOY,24949
15
14
  transfer_queue/controller.py,sha256=uc_MQAlG_QmJ8szxc5yPLaZvCeP2CRULnAADtcf41-8,49702
16
- transfer_queue/metadata.py,sha256=LHQhM7vixw5QcA05nZcqWbxSh7mhJ2ETJOo_1gFb3eI,17645
15
+ transfer_queue/metadata.py,sha256=mH3nev_lK1HiIqZ6przEgCd0BiHkO7jI6ji3x6VSBgY,19076
17
16
  transfer_queue/sampler/__init__.py,sha256=1oauDy2Dwb5GXhKi7tl5DWAHv8i4t2MQK1S4U36Sy4g,788
18
17
  transfer_queue/sampler/base.py,sha256=wFti4dNJb3YArYpGzxA_YDfyUTdTG8wVz6HclPDyZPw,3299
19
18
  transfer_queue/sampler/grpo_group_n_sampler.py,sha256=Kq3hGAz8mboBNvw4Dj0P8lP6Qs8TDojx81fxSh57w28,6566
@@ -27,15 +26,15 @@ transfer_queue/storage/clients/yuanrong_client.py,sha256=rWOPQPLgHLav7cFdGCr8CeV
27
26
  transfer_queue/storage/managers/__init__.py,sha256=y5x5OzZwK_YormdHdzc-smnJNew3niuqU2g_kkaiXIk,876
28
27
  transfer_queue/storage/managers/base.py,sha256=ntlo6sLWIbTiAOUxgJTUyjLl5m4HBuAUUNCYZxq9BFM,20352
29
28
  transfer_queue/storage/managers/factory.py,sha256=58kp2mCKz1K8Ea7RWMsWxdDhN3y4ZhgE-G647AKq7-I,1752
30
- transfer_queue/storage/managers/simple_backend_manager.py,sha256=fCj-0BnS3RbzPn03KEcxmUWbxZVWK11cfxpj8tSf-yY,27381
29
+ transfer_queue/storage/managers/simple_backend_manager.py,sha256=zn0RrRng2zo0kLHl3eKQ9Ksf9O1W9DCEzpuZes8uSjM,27684
31
30
  transfer_queue/storage/managers/yuanrong_manager.py,sha256=NjHC3LBW0fQwm30Oq_qEKoCQEq8oWO0D-AobcpQNPNg,777
32
31
  transfer_queue/utils/__init__.py,sha256=vki-5RVaRBKxVc6Q7XPQox3VNPio2DvJYvRz0SZtu-w,586
33
32
  transfer_queue/utils/serial_utils.py,sha256=9ZgsytTp-441YKtIRFqyH5NhNifSqRKa2h0FI44ltcc,10200
34
33
  transfer_queue/utils/utils.py,sha256=Pno4h3WjX_eT7q4xiVV6Jkhquc39Fp-Ycsg1cv0qNKQ,4544
35
34
  transfer_queue/utils/zmq_utils.py,sha256=jCg2pQfvy_IYdGyZq4nvL4CAwjJc7Li0Trp3T3GDBMg,5118
36
- transfer_queue/version/version,sha256=IBGGJN0Ii7hsCVlxnB5t6XS0CWaPTLXLjPL0jyY3vKE,11
37
- transferqueue-0.1.1.dev1.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
38
- transferqueue-0.1.1.dev1.dist-info/METADATA,sha256=nYLoCMpxUeP1Go3Pnai9W7lVRc4wRAUocqrzyP4K5ds,19531
39
- transferqueue-0.1.1.dev1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
40
- transferqueue-0.1.1.dev1.dist-info/top_level.txt,sha256=BiBclu7jWJ0AZ35vUr3hN9_cg8JL9EiH_hjFxquMxtw,33
41
- transferqueue-0.1.1.dev1.dist-info/RECORD,,
35
+ transfer_queue/version/version,sha256=Q11DRFEP_n2U67FVBUV-pmge1VJtg5UP4Tj8Ski4SkU,11
36
+ transferqueue-0.1.2.dev0.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
37
+ transferqueue-0.1.2.dev0.dist-info/METADATA,sha256=JnPLADNXjifw9xZxuj6vvNahLyT8k0bDc1gDuJxosOk,19502
38
+ transferqueue-0.1.2.dev0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
39
+ transferqueue-0.1.2.dev0.dist-info/top_level.txt,sha256=BiBclu7jWJ0AZ35vUr3hN9_cg8JL9EiH_hjFxquMxtw,33
40
+ transferqueue-0.1.2.dev0.dist-info/RECORD,,
tests/test_put.py DELETED
@@ -1,327 +0,0 @@
1
- # Copyright 2025 The TransferQueue Team
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import asyncio
16
- import logging
17
- import sys
18
- from pathlib import Path
19
-
20
- import pytest
21
- import ray
22
- import torch
23
- from tensordict import TensorDict
24
-
25
- parent_dir = Path(__file__).resolve().parent.parent
26
- sys.path.append(str(parent_dir))
27
-
28
- from transfer_queue import ( # noqa: E402
29
- AsyncTransferQueueClient,
30
- SimpleStorageUnit,
31
- TransferQueueController,
32
- process_zmq_server_info,
33
- )
34
- from transfer_queue.utils.utils import get_placement_group # noqa: E402
35
-
36
- # Set up logging
37
- logging.basicConfig(level=logging.INFO)
38
- logger = logging.getLogger(__name__)
39
-
40
-
41
- @pytest.fixture(scope="module")
42
- def ray_setup():
43
- """Initialize Ray for testing."""
44
- ray.init(ignore_reinit_error=True)
45
- yield
46
- ray.shutdown()
47
-
48
-
49
- @pytest.fixture(scope="module")
50
- def data_system_setup(ray_setup):
51
- """Set up data system for testing."""
52
- # Initialize storage units
53
- num_storage_units = 2
54
- storage_size = 10000
55
- storage_units = {}
56
-
57
- storage_placement_group = get_placement_group(num_storage_units, num_cpus_per_actor=1)
58
- for storage_unit_rank in range(num_storage_units):
59
- storage_node = SimpleStorageUnit.options(
60
- placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank
61
- ).remote(storage_unit_size=storage_size)
62
- storage_units[storage_unit_rank] = storage_node
63
- logger.info(f"SimpleStorageUnit #{storage_unit_rank} has been created.")
64
-
65
- # Initialize controller
66
- controller = TransferQueueController.remote()
67
- logger.info("TransferQueueController has been created.")
68
-
69
- # Prepare connection info
70
- controller_info = process_zmq_server_info(controller)
71
- storage_unit_infos = process_zmq_server_info(storage_units)
72
-
73
- # Create config
74
- config = {
75
- "controller_info": controller_info,
76
- "storage_unit_infos": storage_unit_infos,
77
- }
78
-
79
- yield controller, storage_units, config
80
-
81
- # Cleanup
82
- ray.kill(controller)
83
- for storage_unit in storage_units.values():
84
- ray.kill(storage_unit)
85
-
86
-
87
- @pytest.fixture
88
- async def client_setup(data_system_setup):
89
- """Set up client for testing."""
90
- controller, storage_units, config = data_system_setup
91
-
92
- client = AsyncTransferQueueClient(
93
- client_id="TestClient",
94
- controller_info=config["controller_info"],
95
- )
96
-
97
- client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config)
98
-
99
- # Wait a bit for connections to establish
100
- await asyncio.sleep(1)
101
-
102
- return client
103
-
104
-
105
- class TestMultipleAsyncPut:
106
- """Test class for multiple async_put operations."""
107
-
108
- def __init__(self):
109
- self.controller = None
110
- self.storage_units = None
111
- self.config = None
112
-
113
- async def setup(self):
114
- """Setup for the test class."""
115
- if not ray.is_initialized():
116
- ray.init(ignore_reinit_error=True)
117
-
118
- # Initialize data system
119
- num_storage_units = 2
120
- self.storage_units = {}
121
- storage_placement_group = get_placement_group(num_storage_units, num_cpus_per_actor=1)
122
-
123
- for i in range(num_storage_units):
124
- self.storage_units[i] = SimpleStorageUnit.options(
125
- placement_group=storage_placement_group, placement_group_bundle_index=i
126
- ).remote(storage_unit_size=10000)
127
-
128
- self.controller = TransferQueueController.remote()
129
-
130
- # Wait for initialization
131
- await asyncio.sleep(2)
132
-
133
- controller_info = process_zmq_server_info(self.controller)
134
- storage_unit_infos = process_zmq_server_info(self.storage_units)
135
-
136
- self.config = {
137
- "controller_info": controller_info,
138
- "storage_unit_infos": storage_unit_infos,
139
- }
140
- logger.info("TestMultipleAsyncPut setup completed")
141
-
142
- async def teardown(self):
143
- """Teardown for the test class."""
144
- if self.controller:
145
- ray.kill(self.controller)
146
- if self.storage_units:
147
- for storage in self.storage_units.values():
148
- ray.kill(storage)
149
- if ray.is_initialized():
150
- ray.shutdown()
151
-
152
- async def create_client(self, client_id="TestClient"):
153
- """Create a new client instance."""
154
- if self.config is None:
155
- await self.setup()
156
-
157
- client = AsyncTransferQueueClient(
158
- client_id=client_id,
159
- controller_info=self.config["controller_info"],
160
- )
161
- client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=self.config)
162
- await asyncio.sleep(1) # Wait for connections
163
- return client
164
-
165
-
166
- @pytest.mark.asyncio
167
- async def test_concurrent_async_put():
168
- """Test concurrent async_put operations."""
169
- test_instance = TestMultipleAsyncPut()
170
- await test_instance.setup()
171
-
172
- try:
173
- client = await test_instance.create_client("ConcurrentTestClient")
174
-
175
- async def put_operation(partition_id, data):
176
- await client.async_put(data=data, partition_id=partition_id)
177
- return partition_id
178
-
179
- # Create multiple put operations
180
- tasks = []
181
- for i in range(5):
182
- data = TensorDict({f"concurrent_data_{i}": torch.randn(4, 8)}, batch_size=[4])
183
- task = put_operation(f"concurrent_partition_{i}", data)
184
- tasks.append(task)
185
-
186
- # Execute concurrently
187
- results = await asyncio.gather(*tasks)
188
-
189
- assert len(results) == 5
190
- logger.info(f"Completed {len(results)} concurrent put operations")
191
-
192
- # Cleanup
193
- for partition in results:
194
- try:
195
- await client.async_clear(partition)
196
- except Exception as e:
197
- logger.warning(f"Failed to clear partition {partition}: {e}")
198
-
199
- client.close()
200
- finally:
201
- await test_instance.teardown()
202
-
203
-
204
- @pytest.mark.asyncio
205
- async def test_sequential_async_put_with_verification():
206
- """Test sequential async_put operations with data verification."""
207
- test_instance = TestMultipleAsyncPut()
208
- await test_instance.setup()
209
-
210
- try:
211
- client = await test_instance.create_client("SequentialTestClient")
212
-
213
- # Test data
214
- test_cases = [
215
- ("sequential_1", torch.randn(4, 5)),
216
- ("sequential_2", torch.randn(4, 10)),
217
- ("sequential_3", torch.randn(4, 15)),
218
- ]
219
-
220
- for partition_id, tensor_data in test_cases:
221
- data = TensorDict({"sequential_data": tensor_data}, batch_size=[4])
222
-
223
- # Put data
224
- await client.async_put(data=data, partition_id=partition_id)
225
- logger.info(f"Put data to {partition_id}")
226
-
227
- # Verify by reading back
228
- metadata = await client.async_get_meta(
229
- data_fields=["sequential_data"],
230
- batch_size=4,
231
- partition_id=partition_id,
232
- mode="fetch",
233
- task_name="verification_task",
234
- )
235
-
236
- retrieved_data = await client.async_get_data(metadata)
237
-
238
- # Verify shape and content
239
- assert retrieved_data["sequential_data"].shape == tensor_data.shape
240
- logger.info(f"Verified data in {partition_id}")
241
-
242
- # Cleanup
243
- await client.async_clear(partition_id)
244
-
245
- client.close()
246
- finally:
247
- await test_instance.teardown()
248
-
249
-
250
- @pytest.mark.asyncio
251
- async def test_multiple_puts_same_partition():
252
- """Test multiple puts to the same partition."""
253
- test_instance = TestMultipleAsyncPut()
254
- await test_instance.setup()
255
-
256
- try:
257
- client = await test_instance.create_client("SamePartitionTestClient")
258
-
259
- partition_id = "same_partition_test"
260
- batch_size = 4
261
-
262
- # First put
263
- data1 = TensorDict({"first_batch": torch.randn(batch_size, 8)}, batch_size=[batch_size])
264
-
265
- await client.async_put(data=data1, partition_id=partition_id)
266
- logger.info("First put completed")
267
-
268
- # Second put to same partition
269
- data2 = TensorDict({"second_batch": torch.randn(batch_size, 8) * 2}, batch_size=[batch_size])
270
-
271
- await client.async_put(data=data2, partition_id=partition_id)
272
- logger.info("Second put completed")
273
-
274
- # Third put to same partition
275
- data3 = TensorDict({"third_batch": torch.randn(batch_size, 8) * 3}, batch_size=[batch_size])
276
-
277
- await client.async_put(data=data3, partition_id=partition_id)
278
- logger.info("Third put completed")
279
-
280
- # Verify the last data is what we get
281
- metadata = await client.async_get_meta(
282
- data_fields=["third_batch"],
283
- batch_size=batch_size,
284
- partition_id=partition_id,
285
- mode="fetch",
286
- task_name="verification_task",
287
- )
288
-
289
- retrieved_data = await client.async_get_data(metadata)
290
-
291
- assert "third_batch" in retrieved_data
292
- assert retrieved_data["third_batch"].shape == (batch_size, 8)
293
- logger.info("Verified multiple puts to same partition")
294
-
295
- # Cleanup
296
- await client.async_clear(partition_id)
297
-
298
- client.close()
299
- finally:
300
- await test_instance.teardown()
301
-
302
-
303
- @pytest.mark.asyncio
304
- async def test_simple_multiple_async_put():
305
- """Simple test using the existing client_setup fixture."""
306
- test_instance = TestMultipleAsyncPut()
307
- await test_instance.setup()
308
-
309
- try:
310
- client = await test_instance.create_client("SimpleTestClient")
311
-
312
- # Test basic multiple puts
313
- for i in range(3):
314
- data = TensorDict({f"simple_data_{i}": torch.randn(2, 4)}, batch_size=[2])
315
-
316
- partition_id = f"simple_partition_{i}"
317
- await client.async_put(data=data, partition_id=partition_id)
318
- logger.info(f"Put data to {partition_id}")
319
-
320
- # Cleanup each partition
321
- await client.async_clear(partition_id)
322
-
323
- logger.info("Simple multiple async_put test completed")
324
-
325
- client.close()
326
- finally:
327
- await test_instance.teardown()