TransferQueue 0.1.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- recipe/simple_use_case/async_demo.py +331 -0
- recipe/simple_use_case/sync_demo.py +220 -0
- tests/test_async_simple_storage_manager.py +339 -0
- tests/test_client.py +423 -0
- tests/test_controller.py +274 -0
- tests/test_controller_data_partitions.py +513 -0
- tests/test_kv_storage_manager.py +92 -0
- tests/test_put.py +327 -0
- tests/test_samplers.py +492 -0
- tests/test_serial_utils_on_cpu.py +202 -0
- tests/test_simple_storage_unit.py +443 -0
- tests/test_storage_client_factory.py +45 -0
- transfer_queue/__init__.py +48 -0
- transfer_queue/client.py +611 -0
- transfer_queue/controller.py +1187 -0
- transfer_queue/metadata.py +460 -0
- transfer_queue/sampler/__init__.py +19 -0
- transfer_queue/sampler/base.py +74 -0
- transfer_queue/sampler/grpo_group_n_sampler.py +157 -0
- transfer_queue/sampler/sequential_sampler.py +75 -0
- transfer_queue/storage/__init__.py +25 -0
- transfer_queue/storage/clients/__init__.py +24 -0
- transfer_queue/storage/clients/base.py +22 -0
- transfer_queue/storage/clients/factory.py +55 -0
- transfer_queue/storage/clients/yuanrong_client.py +118 -0
- transfer_queue/storage/managers/__init__.py +23 -0
- transfer_queue/storage/managers/base.py +460 -0
- transfer_queue/storage/managers/factory.py +43 -0
- transfer_queue/storage/managers/simple_backend_manager.py +611 -0
- transfer_queue/storage/managers/yuanrong_manager.py +18 -0
- transfer_queue/storage/simple_backend.py +451 -0
- transfer_queue/utils/__init__.py +13 -0
- transfer_queue/utils/serial_utils.py +240 -0
- transfer_queue/utils/utils.py +132 -0
- transfer_queue/utils/zmq_utils.py +170 -0
- transfer_queue/version/version +1 -0
- transferqueue-0.1.1.dev0.dist-info/METADATA +327 -0
- transferqueue-0.1.1.dev0.dist-info/RECORD +41 -0
- transferqueue-0.1.1.dev0.dist-info/WHEEL +5 -0
- transferqueue-0.1.1.dev0.dist-info/licenses/LICENSE +202 -0
- transferqueue-0.1.1.dev0.dist-info/top_level.txt +4 -0
tests/test_samplers.py
ADDED
|
@@ -0,0 +1,492 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Unit tests for TransferQueue samplers."""
|
|
16
|
+
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
import pytest
|
|
20
|
+
|
|
21
|
+
from transfer_queue.sampler import BaseSampler
|
|
22
|
+
from transfer_queue.sampler.grpo_group_n_sampler import GRPOGroupNSampler
|
|
23
|
+
from transfer_queue.sampler.sequential_sampler import SequentialSampler
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class TestBaseSampler:
    """Test cases for the BaseSampler abstract class."""

    def test_base_sampler_is_abstract(self):
        """BaseSampler cannot be instantiated directly."""
        with pytest.raises(TypeError) as exc_info:
            BaseSampler()

        message = str(exc_info.value)
        assert "Can't instantiate abstract class" in message
        assert "sample" in message

    def test_base_sampler_has_abstract_methods(self):
        """BaseSampler declares ``sample`` as an abstract method."""
        assert hasattr(BaseSampler, "sample")
        assert getattr(BaseSampler.sample, "__isabstractmethod__", False)

    def test_base_sampler_has_call_method(self):
        """BaseSampler (or one of its bases) defines ``__call__`` for instances."""
        # ``callable(BaseSampler)`` is true for *every* class, because calling a
        # class constructs an instance. Look for ``__call__`` in the class
        # namespaces along the MRO so the check is about instances instead.
        assert any("__call__" in vars(klass) for klass in BaseSampler.__mro__)

    def test_base_sampler_initialization_states(self):
        """A concrete subclass starts with an empty ``_states`` mapping."""

        # Minimal concrete implementation: BaseSampler itself is abstract.
        class _ConcreteSampler(BaseSampler):
            def sample(self, ready_indexes: list[int], batch_size: int, **kwargs: Any) -> tuple[list[int], list[int]]:
                return ready_indexes[:batch_size], ready_indexes[:batch_size]

        sampler = _ConcreteSampler()
        assert hasattr(sampler, "_states")
        assert sampler._states == {}
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class TestSequentialSampler:
    """Behavioural checks for SequentialSampler."""

    def test_sequential_sampler_initialization(self):
        """A new sampler is a BaseSampler with empty ``_states``."""
        instance = SequentialSampler()
        assert isinstance(instance, BaseSampler)
        assert hasattr(instance, "_states")
        assert instance._states == {}

    def test_sequential_sampler_basic_functionality(self):
        """The first ``batch_size`` ready indexes are sampled and consumed."""
        want = 3
        picked, used = SequentialSampler().sample([0, 1, 2, 3, 4, 5], want)

        assert picked == [0, 1, 2]
        assert used == [0, 1, 2]
        assert len(picked) == want
        assert len(used) == want

    def test_sequential_sampler_empty_ready_indexes(self):
        """No ready indexes yields empty results."""
        picked, used = SequentialSampler().sample([], 3)

        assert picked == []
        assert used == []

    def test_sequential_sampler_batch_size_larger_than_ready(self):
        """Asking for more than is ready returns everything that is ready."""
        available = [0, 1]
        picked, used = SequentialSampler().sample(available, 5)

        assert picked == [0, 1]
        assert used == [0, 1]
        assert len(picked) == len(available)

    def test_sequential_sampler_zero_batch_size(self):
        """A batch size of zero yields empty results."""
        picked, used = SequentialSampler().sample([0, 1, 2, 3], 0)

        assert picked == []
        assert used == []

    def test_sequential_sampler_negative_batch_size(self):
        """A negative batch size matches Python's negative-stop slicing."""
        available = [0, 1, 2, 3]
        want = -1
        picked, used = SequentialSampler().sample(available, want)

        # For -1 the reference slice drops the last element: [0, 1, 2].
        reference = available[:want]
        assert picked == reference
        assert used == reference

    def test_sequential_sampler_non_sequential_indexes(self):
        """Input order is preserved; indexes are not sorted."""
        picked, used = SequentialSampler().sample([10, 5, 15, 20, 8], 3)

        assert picked == [10, 5, 15]
        assert used == [10, 5, 15]

    def test_sequential_sampler_duplicate_indexes(self):
        """Duplicate indexes are passed through unchanged."""
        picked, used = SequentialSampler().sample([0, 1, 0, 2, 1, 3], 4)

        assert picked == [0, 1, 0, 2]
        assert used == [0, 1, 0, 2]

    def test_sequential_sampler_call_method(self):
        """Calling the sampler instance behaves like ``sample``."""
        instance = SequentialSampler()
        picked, used = instance([0, 1, 2, 3], 2)

        assert picked == [0, 1]
        assert used == [0, 1]

    def test_sequential_sampler_with_extra_kwargs(self):
        """Unknown keyword arguments are accepted and have no effect."""
        picked, used = SequentialSampler().sample([0, 1, 2, 3], 2, extra_param="ignored")

        assert picked == [0, 1]
        assert used == [0, 1]
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
class TestGRPOGroupNSampler:
    """Behavioural checks for GRPOGroupNSampler."""

    def test_grpo_sampler_initialization(self):
        """A new sampler is a BaseSampler with empty ``_states``."""
        instance = GRPOGroupNSampler()
        assert isinstance(instance, BaseSampler)
        assert hasattr(instance, "_states")
        assert instance._states == {}

    def test_grpo_sampler_basic_functionality(self):
        """Eight ready indexes form two groups of four."""
        want, group = 8, 4
        picked, used = GRPOGroupNSampler().sample([0, 1, 2, 3, 4, 5, 6, 7], want, group)

        assert picked == [0, 1, 2, 3, 4, 5, 6, 7]
        assert used == [0, 1, 2, 3, 4, 5, 6, 7]
        assert len(picked) == want
        assert len(used) == want

    def test_grpo_sampler_partial_batch(self):
        """Only the first two groups are taken when twelve indexes are ready."""
        want, group = 8, 4
        available = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
        picked, used = GRPOGroupNSampler().sample(available, want, group)

        assert picked == [0, 1, 2, 3, 4, 5, 6, 7]
        assert used == [0, 1, 2, 3, 4, 5, 6, 7]
        assert len(picked) == want
        assert len(used) == want

    def test_grpo_sampler_different_group_sizes(self):
        """Different ``n_samples_per_prompt`` values give the same first eight indexes."""
        instance = GRPOGroupNSampler()
        available = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
        first_eight = [0, 1, 2, 3, 4, 5, 6, 7]

        # Groups of two.
        picked, used = instance.sample(available, 8, n_samples_per_prompt=2)
        assert picked == first_eight
        assert used == first_eight

        # Groups of eight, same sampler instance.
        picked, used = instance.sample(available, 8, n_samples_per_prompt=8)
        assert picked == first_eight
        assert used == first_eight

    def test_grpo_sampler_batch_size_divisibility(self):
        """A batch size that is not a multiple of the group size is rejected."""
        instance = GRPOGroupNSampler()
        available = [0, 1, 2, 3, 4, 5, 6, 7]  # enough indexes for batch_size=7

        with pytest.raises(ValueError) as exc_info:
            instance.sample(available, 7, 4)

        assert "must be a multiple of n_samples_per_prompt" in str(exc_info.value)

    def test_grpo_sampler_insufficient_ready_indexes(self):
        """Four ready indexes cannot satisfy two groups of four."""
        picked, used = GRPOGroupNSampler().sample([0, 1, 2, 3], 8, 4)

        # Too few complete groups -> nothing is sampled.
        assert picked == []
        assert used == []

    def test_grpo_sampler_exact_multiple_available(self):
        """Ready count is exactly a multiple of the group size."""
        picked, used = GRPOGroupNSampler().sample([0, 1, 2, 3, 4, 5, 6, 7], 8, 4)

        assert picked == [0, 1, 2, 3, 4, 5, 6, 7]
        assert used == [0, 1, 2, 3, 4, 5, 6, 7]

    def test_grpo_sampler_zero_batch_size(self):
        """A batch size of zero yields empty results."""
        picked, used = GRPOGroupNSampler().sample([0, 1, 2, 3], 0, 2)

        assert picked == []
        assert used == []

    def test_grpo_sampler_single_sample_per_prompt(self):
        """Group size of one."""
        picked, used = GRPOGroupNSampler().sample([0, 1, 2, 3, 4, 5], 3, 1)

        assert picked == [0, 1, 2]
        assert used == [0, 1, 2]

    def test_grpo_sampler_large_group_size(self):
        """Twenty ready indexes split into two groups of ten."""
        available = list(range(20))
        picked, used = GRPOGroupNSampler().sample(available, 20, 10)

        assert picked == list(range(20))
        assert used == list(range(20))

    def test_grpo_sampler_call_method(self):
        """Calling the sampler instance behaves like ``sample``."""
        instance = GRPOGroupNSampler()
        group = 2
        picked, used = instance([0, 1, 2, 3, 4, 5, 6, 7], 4, n_samples_per_prompt=group)

        assert picked == [0, 1, 2, 3]
        assert used == [0, 1, 2, 3]

    def test_grpo_sampler_parameter_order_independence(self):
        """Keyword-argument order does not change the result."""
        instance = GRPOGroupNSampler()
        available = [0, 1, 2, 3, 4, 5, 6, 7]

        first_result = instance.sample(n_samples_per_prompt=4, batch_size=8, ready_indexes=available)
        second_result = instance.sample(batch_size=8, ready_indexes=available, n_samples_per_prompt=4)

        picked_a, used_a = first_result
        picked_b, used_b = second_result
        assert picked_a == picked_b
        assert used_a == used_b

    def test_grpo_sampler_with_extra_kwargs(self):
        """Unknown keyword arguments are accepted and have no effect."""
        instance = GRPOGroupNSampler()
        available = [0, 1, 2, 3, 4, 5, 6, 7]

        picked, used = instance.sample(available, 8, 4, extra_param="ignored", another_param=42)

        assert picked == [0, 1, 2, 3, 4, 5, 6, 7]
        assert used == [0, 1, 2, 3, 4, 5, 6, 7]

    def test_grpo_sampler_non_sequential_indexes(self):
        """Ready indexes with a gap that still contain consecutive groups."""
        # Two consecutive runs: [3, 4, 5, 6] and [9, 10, 11, 12].
        available = [3, 4, 5, 6, 9, 10, 11, 12]
        picked, used = GRPOGroupNSampler().sample(available, 8, 4)

        reference = [3, 4, 5, 6, 9, 10, 11, 12]
        assert picked == reference
        assert used == reference

    def test_grpo_sampler_invalid_n_samples_per_prompt(self):
        """Zero and negative group sizes are rejected."""
        instance = GRPOGroupNSampler()
        available = [0, 1, 2, 3, 4, 5, 6, 7]
        want = 8

        # Zero group size.
        with pytest.raises(ValueError) as exc_info:
            instance.sample(available, want, n_samples_per_prompt=0)
        assert "must be positive" in str(exc_info.value)

        # Negative group size.
        with pytest.raises(ValueError) as exc_info:
            instance.sample(available, want, n_samples_per_prompt=-2)
        assert "must be positive" in str(exc_info.value)

    def test_grpo_sampler_no_complete_groups(self):
        """No run of three consecutive indexes is available."""
        available = [0, 1, 3, 4, 6, 7]
        picked, used = GRPOGroupNSampler().sample(available, 6, 3)

        # Without a complete group nothing is sampled.
        assert picked == []
        assert used == []

    def test_grpo_sampler_mixed_groups(self):
        """Complete and incomplete groups are mixed in the ready list."""
        available = [0, 1, 3, 4, 5, 6, 7, 9, 10, 11]
        picked, used = GRPOGroupNSampler().sample(available, 6, 3)

        # The complete groups found are [3, 4, 5] and [9, 10, 11].
        assert picked == [3, 4, 5, 9, 10, 11]
        assert used == [3, 4, 5, 9, 10, 11]

    def test_grpo_sampler_sorting_functionality(self):
        """Out-of-order ready indexes are sorted before groups are detected."""
        available = [10, 11, 12, 5, 6, 7, 8, 9]
        picked, used = GRPOGroupNSampler().sample(available, 8, 4)

        # Sorted input is [5..12], giving [5, 6, 7, 8] and [9, 10, 11, 12].
        reference = [5, 6, 7, 8, 9, 10, 11, 12]
        assert picked == reference
        assert used == reference

    def test_grpo_sampler_insufficient_groups(self):
        """Requesting more groups than are available returns nothing."""
        instance = GRPOGroupNSampler()
        available = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]  # 4 groups of 4
        group = 4

        # Three of the four available groups: succeeds.
        picked, used = instance.sample(available, 12, group)
        assert len(picked) == 12
        assert len(used) == 12

        # Five groups requested but only four exist: empty result.
        picked, used = instance.sample(available, 20, group)
        assert picked == []
        assert used == []
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
class TestSamplerIntegration:
    """Integration tests that exercise every sampler through the shared interface."""

    def test_samplers_implement_base_interface(self):
        """Every sampler is a BaseSampler exposing ``sample`` and ``__call__``."""
        samplers = [SequentialSampler(), GRPOGroupNSampler()]

        for sampler in samplers:
            assert isinstance(sampler, BaseSampler)

            # Required entry points.
            assert hasattr(sampler, "sample")
            assert callable(sampler.sample)
            assert callable(sampler)
            assert callable(sampler.__call__)

    def test_samplers_return_consistent_types(self):
        """Every sampler returns two equally long lists of ints."""
        samplers = [(SequentialSampler(), {}), (GRPOGroupNSampler(), {"n_samples_per_prompt": 2})]

        ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7]
        batch_size = 4

        for sampler, kwargs in samplers:
            sampled, consumed = sampler.sample(ready_indexes, batch_size, **kwargs)

            # Return types.
            assert isinstance(sampled, list)
            assert isinstance(consumed, list)
            assert all(isinstance(index, int) for index in sampled)
            assert all(isinstance(index, int) for index in consumed)

            # Return value consistency.
            assert len(sampled) <= batch_size
            assert len(sampled) == len(consumed)

    def test_samplers_handle_edge_cases_consistently(self):
        """Empty input and zero batch size yield empty results for every sampler."""
        samplers = [(SequentialSampler(), {}), (GRPOGroupNSampler(), {"n_samples_per_prompt": 2})]

        # Empty ready indexes.
        for sampler, kwargs in samplers:
            try:
                sampled, consumed = sampler.sample([], 0, **kwargs)
            except ValueError:
                # A sampler may explicitly reject empty input. Only that
                # rejection is tolerated: the assertions stay outside the
                # ``try`` so a wrong result can no longer be swallowed.
                # NOTE(review): assumes rejection is signalled with ValueError,
                # as in the other validation tests in this file -- confirm.
                continue
            assert sampled == []
            assert consumed == []

        # Zero batch size with ready indexes available. Both samplers have
        # dedicated zero-batch-size tests with these inputs, so no exception
        # is expected here.
        for sampler, kwargs in samplers:
            sampled, consumed = sampler.sample([0, 1, 2, 3], 0, **kwargs)
            assert sampled == []
            assert consumed == []
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file as a script reports
    # failures to the shell instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__]))
|
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import sys
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
import pytest
|
|
20
|
+
import tensordict
|
|
21
|
+
import torch
|
|
22
|
+
from tensordict import NonTensorData, NonTensorStack, TensorDict
|
|
23
|
+
|
|
24
|
+
# Import your classes here
|
|
25
|
+
parent_dir = Path(__file__).resolve().parent.parent
|
|
26
|
+
sys.path.append(str(parent_dir))
|
|
27
|
+
|
|
28
|
+
from transfer_queue.utils.serial_utils import MsgpackDecoder, MsgpackEncoder # noqa: E402
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict: dict | None = None) -> TensorDict:
    """Build a TensorDict from batched entries plus optional non-tensor entries.

    Args:
        tensor_dict: Mapping of key to a ``torch.Tensor`` or a list. All values
            must have the same length, which becomes the batch size. List
            values must not contain ``torch.Tensor`` items.
        non_tensor_dict: Optional mapping of extra entries, each wrapped in
            ``NonTensorData``. Its keys must not collide with ``tensor_dict``.

    Returns:
        A TensorDict whose batch size is ``[len(value)]`` of the batched
        entries, or ``[]`` when ``tensor_dict`` is empty.

    Raises:
        AssertionError: If a value has an unsupported type, lengths differ, a
            list contains a tensor, or a non-tensor key collides.
    """
    if non_tensor_dict is None:
        non_tensor_dict = {}

    batch_size = None

    for val in tensor_dict.values():
        if isinstance(val, list):
            for v in val:
                assert not isinstance(v, torch.Tensor), (
                    "Passing a list makes the data NonTensorStack, "
                    "which doesn't support torch.Tensor. Please convert to numpy first"
                )

        assert isinstance(val, torch.Tensor | list)

        if batch_size is None:
            batch_size = len(val)
        else:
            assert len(val) == batch_size

    batch_size = [] if batch_size is None else [batch_size]

    # Merge into a shallow copy so the caller's dict is not mutated.
    source = dict(tensor_dict)
    for key, val in non_tensor_dict.items():
        assert key not in source
        source[key] = NonTensorData(val)

    return TensorDict(source=source, batch_size=batch_size)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@pytest.mark.parametrize(
    "dtype",
    [
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ],
)
def test_tensor_serialization(dtype):
    """Round-trip a dense tensor and a jagged nested tensor through msgpack."""
    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(torch.Tensor)

    # Dense tensor round trip.
    tensor = torch.randn(100, 10, dtype=dtype)
    serialized = encoder.encode(tensor)
    deserialized = decoder.decode(serialized)
    assert torch.allclose(tensor, deserialized)

    # Jagged nested tensor built from two rows of different lengths.
    vocab_size = 128
    a = torch.randint(low=0, high=vocab_size, size=(11,))
    b = torch.randint(low=0, high=vocab_size, size=(13,))
    input_ids = [a, b]
    input_ids = torch.nested.as_nested_tensor(input_ids, layout=torch.jagged, dtype=dtype)

    input_ids_serialized = encoder.encode(input_ids)
    input_ids_deserialized = decoder.decode(input_ids_serialized)
    # Compare every row. The loop previously indexed row 0 on each iteration,
    # so the second row was never checked.
    for i, original_row in enumerate(input_ids.unbind()):
        assert torch.allclose(original_row, input_ids_deserialized[i])
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def test_tensordict_serialization_with_nontensor():
    """Round-trip a TensorDict holding tensors and a list of strings."""
    n_rows = 100
    payload = {
        "obs": torch.randn(n_rows, 10),
        "act": torch.randn(n_rows, 3),
        "data_sources": ["gsm8k"] * n_rows,
    }
    data1 = get_tensordict(tensor_dict=payload)

    deserialized = MsgpackDecoder(TensorDict).decode(MsgpackEncoder().encode(data1))

    assert deserialized.keys() == data1.keys()
    assert deserialized.batch_size[0] == 100
    assert isinstance(deserialized.get("data_sources"), NonTensorStack)

    # Tensors are compared numerically, non-tensor stacks by equality.
    for k, v in data1.items():
        if isinstance(v, torch.Tensor):
            assert torch.allclose(deserialized[k], v)
            continue
        if isinstance(v, NonTensorStack):
            assert deserialized[k] == data1[k]
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def test_tensordict_serialization_with_images():
    """Round-trip a TensorDict where each sample has several differently sized images."""
    vocab_size = 128
    a = torch.randint(low=0, high=vocab_size, size=(11,))
    b = torch.randint(low=0, high=vocab_size, size=(13,))
    input_ids = torch.nested.as_nested_tensor([a, b], layout=torch.jagged)

    # Sample ``a`` carries two images, sample ``b`` carries three.
    a_images = [
        torch.randint(low=0, high=255, size=(3, side, side), dtype=torch.uint8).numpy() for side in (256, 128)
    ]
    b_images = [
        torch.randint(low=0, high=255, size=(3, side, side), dtype=torch.uint8).numpy() for side in (256, 128, 64)
    ]

    data = get_tensordict({"input_ids": input_ids, "images": [a_images, b_images]})

    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(TensorDict)
    deserialized = decoder.decode(encoder.encode(data))

    first_sample_image = deserialized[0]["images"][0]
    assert np.all(np.equal(first_sample_image, a_images[0]))
    assert torch.all(torch.eq(deserialized[0]["input_ids"], a))
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
# Copied from https://github.com/volcengine/verl/blob/33edd95e13c72b9494585765b5fedc679fd73923/tests/test_protocol_v2_on_cpu.py#L119
|
|
147
|
+
def test_tensordict_with_packing():
    """A jagged ``input_ids`` tensor keeps offsets and rows after a round trip."""
    vocab_size = 128
    a = torch.randint(low=0, high=vocab_size, size=(11,))
    b = torch.randint(low=0, high=vocab_size, size=(13,))
    packed = torch.nested.as_nested_tensor([a, b], layout=torch.jagged)

    data = get_tensordict({"input_ids": packed})
    wire = MsgpackEncoder().encode(data)
    deserialized_data = MsgpackDecoder(TensorDict).decode(wire)

    # Cumulative sequence lengths: rows of 11 and 13 tokens.
    expected_offsets = torch.tensor([0, 11, 24])
    assert torch.all(torch.eq(expected_offsets, deserialized_data["input_ids"].offsets()))

    # Index the field first, then the row.
    restored_ids = deserialized_data["input_ids"]
    assert torch.all(torch.eq(restored_ids[0], a))
    assert torch.all(torch.eq(restored_ids[1], b))

    # Index the row first, then the field.
    assert torch.all(torch.eq(deserialized_data[0]["input_ids"], a))
    assert torch.all(torch.eq(deserialized_data[1]["input_ids"], b))

    # Each of the two chunks holds one row.
    first_chunk, second_chunk = deserialized_data.chunk(2)[:2]
    assert torch.all(torch.eq(first_chunk["input_ids"][0], a))
    assert torch.all(torch.eq(second_chunk["input_ids"][0], b))
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def test_nested_tensordict_serialization():
    """A TensorDict containing two nested TensorDicts survives a round trip."""
    batch = [2]
    td1 = tensordict.TensorDict({"a": torch.randn(2, 3), "b": torch.randn(2, 4)}, batch_size=batch)
    td2 = tensordict.TensorDict({"c": torch.randn(2, 5), "d": torch.randn(2, 6)}, batch_size=batch)
    td = tensordict.TensorDict({"part1": td1, "part2": td2, "e": torch.randn(2, 7)}, batch_size=batch)

    wire = MsgpackEncoder().encode(td)
    deserialized_td = MsgpackDecoder(TensorDict).decode(wire)

    # Structure: same top-level keys, nested parts are still TensorDicts.
    assert isinstance(deserialized_td, tensordict.TensorDict)
    assert set(deserialized_td.keys()) == set(td.keys())
    assert isinstance(deserialized_td["part1"], tensordict.TensorDict)
    assert isinstance(deserialized_td["part2"], tensordict.TensorDict)

    assert set(deserialized_td["part1"].keys()) == set(td1.keys())
    assert set(deserialized_td["part2"].keys()) == set(td2.keys())

    # Values: leaves are compared directly, nested parts one level down.
    for key in td.keys():
        if not isinstance(td[key], tensordict.TensorDict):
            assert torch.allclose(deserialized_td[key], td[key]), f"Values for key '{key}' do not match"
            continue
        for inner_key in td[key].keys():
            assert torch.allclose(deserialized_td[key][inner_key], td[key][inner_key]), (
                f"Values for key '{key}.{inner_key}' do not match"
            )
|