TransferQueue 0.1.4.dev2__py3-none-any.whl → 0.1.5.dev3__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry, and is provided for informational purposes only.
- recipe/simple_use_case/async_demo.py +1 -1
- recipe/simple_use_case/sync_demo.py +1 -1
- tests/test_client.py +168 -12
- tests/test_controller.py +116 -19
- tests/test_controller_data_partitions.py +12 -146
- tests/test_serial_utils_on_cpu.py +12 -12
- transfer_queue/client.py +286 -53
- transfer_queue/controller.py +354 -170
- transfer_queue/metadata.py +37 -4
- transfer_queue/storage/managers/simple_backend_manager.py +67 -163
- transfer_queue/storage/simple_backend.py +15 -47
- transfer_queue/utils/serial_utils.py +127 -1
- transfer_queue/utils/utils.py +31 -17
- transfer_queue/utils/zmq_utils.py +15 -112
- transfer_queue/version/version +1 -1
- {transferqueue-0.1.4.dev2.dist-info → transferqueue-0.1.5.dev3.dist-info}/METADATA +8 -3
- {transferqueue-0.1.4.dev2.dist-info → transferqueue-0.1.5.dev3.dist-info}/RECORD +24 -21
- {transferqueue-0.1.4.dev2.dist-info → transferqueue-0.1.5.dev3.dist-info}/top_level.txt +1 -1
- tutorial/01_core_components.py +268 -0
- tutorial/02_metadata_concepts.py +444 -0
- tutorial/03_understanding_controller.py +312 -0
- tutorial/04_custom_sampler.py +413 -0
- performance_test.py +0 -383
- {transferqueue-0.1.4.dev2.dist-info → transferqueue-0.1.5.dev3.dist-info}/WHEEL +0 -0
- {transferqueue-0.1.4.dev2.dist-info → transferqueue-0.1.5.dev3.dist-info}/licenses/LICENSE +0 -0
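
The headline change in this release is the clearing API: the single `clear` entry point is replaced by explicit partition-level and sample-level operations, alongside new status probes. A minimal sketch of the new client-side calls, assembled from the demo and test changes below (client construction is elided; the call signatures are as they appear in the diff):

    # Partition-level clear (sync_demo.py / async_demo.py)
    data_system_client.clear_partition(partition_id=f"train_{step}")
    asyncio.run(data_system_client.async_clear_partition(partition_id=f"train_{step}"))

    # Sample-level clear takes a BatchMeta (tests/test_client.py)
    metadata = client.get_meta(data_fields=["tokens", "labels"], batch_size=2, partition_id="0")
    client.clear_samples(metadata=metadata)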
recipe/simple_use_case/async_demo.py
CHANGED

@@ -306,7 +306,7 @@ class Trainer:
 
         # Client notifies controller to clear data status, controller returns metadata;
         # Client then notifies the storage plane to clear based on metadata
-        asyncio.run(self.data_system_client.
+        asyncio.run(self.data_system_client.async_clear_partition(partition_id=f"train_{step}"))
         logger.info("clear ok! ")
         logger.info("demo done!")
 

recipe/simple_use_case/sync_demo.py
CHANGED

@@ -190,7 +190,7 @@ def fit(config, data_system_client):
     # Client then notifies the storage plane to clear based on metadata
     # Client selects one master controller to get metadata,
     # other controllers directly clear without returning metadata
-    data_system_client.
+    data_system_client.clear_partition(partition_id=f"train_{step}")
     logger.info("clear ok! ")
     logger.info("demo done!")
 
tests/test_client.py
CHANGED

@@ -101,11 +101,38 @@ class MockController:
         if request_msg.request_type == ZMQRequestType.GET_META:
             response_body = self._mock_batch_meta(request_msg.body)
             response_type = ZMQRequestType.GET_META_RESPONSE
-        elif request_msg.request_type == ZMQRequestType.GET_CLEAR_META:
-            response_body = self._mock_batch_meta(request_msg.body)
-            response_type = ZMQRequestType.GET_CLEAR_META_RESPONSE
         elif request_msg.request_type == ZMQRequestType.CLEAR_META:
-            response_body = {"message": "clear ok"}
+            response_body = {"message": "clear meta ok"}
+            response_type = ZMQRequestType.CLEAR_META_RESPONSE
+        elif request_msg.request_type == ZMQRequestType.CLEAR_PARTITION:
+            response_body = {"message": "clear partition ok"}
+            response_type = ZMQRequestType.CLEAR_PARTITION_RESPONSE
+        elif request_msg.request_type == ZMQRequestType.GET_PARTITION_META:
+            # Mock partition metadata response
+            response_body = {"metadata": self._mock_batch_meta(request_msg.body)}
+            response_type = ZMQRequestType.GET_PARTITION_META_RESPONSE
+        elif request_msg.request_type == ZMQRequestType.CHECK_CONSUMPTION:
+            # Mock consumption status check - all consumed
+            response_body = {
+                "partition_id": request_msg.body.get("partition_id"),
+                "consumed": True,
+            }
+            response_type = ZMQRequestType.CONSUMPTION_RESPONSE
+        elif request_msg.request_type == ZMQRequestType.CHECK_PRODUCTION:
+            # Mock production status check - all produced
+            response_body = {
+                "partition_id": request_msg.body.get("partition_id"),
+                "produced": True,
+            }
+            response_type = ZMQRequestType.PRODUCTION_RESPONSE
+        elif request_msg.request_type == ZMQRequestType.GET_LIST_PARTITIONS:
+            # Mock partition list
+            response_body = {
+                "partition_ids": ["partition_0", "partition_1", "test_partition"],
+            }
+            response_type = ZMQRequestType.LIST_PARTITIONS_RESPONSE
+        else:
+            response_body = {"error": f"Unknown request type: {request_msg.request_type}"}
             response_type = ZMQRequestType.CLEAR_META_RESPONSE
 
         # Send response

@@ -352,14 +379,6 @@ def test_get_meta(client_setup):
     assert len(metadata.global_indexes) == 10
 
 
-def test_clear_operation(client_setup):
-    """Test clear operation"""
-    client, _, _ = client_setup
-
-    # Test clear operation
-    client.clear(partition_id="0")
-
-
 # Test with single controller and multiple storage units
 def test_single_controller_multiple_storages():
     """Test client with single controller and multiple storage units"""

@@ -426,3 +445,140 @@ def test_put_without_required_params(client_setup):
     # Test put without partition id (should fail)
     with pytest.raises(ValueError):
         client.put(data=test_data)
+
+
+# Test new status checking methods
+def test_check_consumption_status(client_setup):
+    """Test consumption status checking"""
+    client, _, _ = client_setup
+
+    # Test synchronous check_consumption_status
+    is_consumed = client.check_consumption_status(task_name="generate_sequences", partition_id="train_0")
+    assert is_consumed is True
+
+
+def test_check_production_status(client_setup):
+    """Test production status checking"""
+    client, _, _ = client_setup
+
+    # Test synchronous check_production_status
+    is_produced = client.check_production_status(data_fields=["prompt_ids", "attention_mask"], partition_id="train_0")
+    assert is_produced is True
+
+
+def test_get_partition_list(client_setup):
+    """Test partition list retrieval"""
+    client, _, _ = client_setup
+
+    # Test synchronous get_partition_list
+    partition_list = client.get_partition_list()
+    assert isinstance(partition_list, list)
+    assert len(partition_list) > 0
+    assert "partition_0" in partition_list
+    assert "partition_1" in partition_list
+    assert "test_partition" in partition_list
+
+
+@pytest.mark.asyncio
+async def test_async_check_consumption_status(client_setup):
+    """Test async consumption status checking"""
+    client, _, _ = client_setup
+
+    # Test async_check_consumption_status
+    is_consumed = await client.async_check_consumption_status(task_name="generate_sequences", partition_id="train_0")
+    assert is_consumed is True
+
+
+@pytest.mark.asyncio
+async def test_async_check_production_status(client_setup):
+    """Test async production status checking"""
+    client, _, _ = client_setup
+
+    # Test async_check_production_status
+    is_produced = await client.async_check_production_status(
+        data_fields=["prompt_ids", "attention_mask"], partition_id="train_0"
+    )
+    assert is_produced is True
+
+
+@pytest.mark.asyncio
+async def test_async_get_partition_list(client_setup):
+    """Test async partition list retrieval"""
+    client, _, _ = client_setup
+
+    # Test async_get_partition_list
+    partition_list = await client.async_get_partition_list()
+    assert isinstance(partition_list, list)
+    assert len(partition_list) > 0
+    assert "partition_0" in partition_list
+    assert "partition_1" in partition_list
+    assert "test_partition" in partition_list
+
+
+# Test clear methods
+@pytest.mark.asyncio
+async def test_async_clear_partition(client_setup):
+    """Test async clear partition operation"""
+    client, _, _ = client_setup
+
+    # Test async_clear_partition
+    await client.async_clear_partition(partition_id="test_partition")
+
+    # If no exception is raised, the test passes
+    assert True
+
+
+@pytest.mark.asyncio
+async def test_async_clear_samples(client_setup):
+    """Test async clear samples operation"""
+    client, _, _ = client_setup
+
+    # First get metadata to create a BatchMeta object
+    metadata = await client.async_get_meta(data_fields=["tokens", "labels"], batch_size=2, partition_id="0")
+
+    # Test async_clear_samples
+    await client.async_clear_samples(metadata=metadata)
+
+    # If no exception is raised, the test passes
+    assert True
+
+
+def test_clear_partition(client_setup):
+    """Test synchronous clear partition operation"""
+    client, _, _ = client_setup
+
+    # Test synchronous clear_partition
+    client.clear_partition(partition_id="test_partition")
+
+    # If no exception is raised, the test passes
+    assert True
+
+
+def test_clear_samples(client_setup):
+    """Test synchronous clear samples operation"""
+    client, _, _ = client_setup
+
+    # First get metadata to create a BatchMeta object
+    metadata = client.get_meta(data_fields=["tokens", "labels"], batch_size=2, partition_id="0")
+
+    # Test synchronous clear_samples
+    client.clear_samples(metadata=metadata)
+
+    # If no exception is raised, the test passes
+    assert True
+
+
+@pytest.mark.asyncio
+async def test_async_clear_samples_with_empty_metadata(client_setup):
+    """Test async_clear_samples with empty BatchMeta"""
+    client, _, _ = client_setup
+
+    # Create empty BatchMeta
+    metadata = BatchMeta(samples=[])
+
+    # The clear operation should complete without raising an exception
+    # because the mock storage manager is configured to handle this
+    await client.async_clear_samples(metadata=metadata)
+
+    # If no exception is raised, the test passes
+    assert True
tests/test_controller.py
CHANGED

@@ -28,7 +28,6 @@ logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 from transfer_queue import TransferQueueController  # noqa: E402
-from transfer_queue.controller import TQ_INIT_FIELD_NUM  # noqa: E402
 from transfer_queue.utils.utils import ProductionStatus  # noqa: E402
 
 

@@ -92,21 +91,49 @@ class TestTransferQueueController:
             )
         )
         assert success
-        partition = ray.get(tq_controller.
+        partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
         assert partition.production_status is not None
         assert partition.production_status.size(0) == gbs * num_n_samples
-
+
+        # Test for get production status
+        production_status = ray.get(
+            tq_controller.get_production_status.remote(
+                partition_id=partition_id,
+                data_fields=data_fields,
+            )
+        )
+        assert production_status
+
+        # Total fields should match the number of fields we added
+        assert partition.total_fields_num == len(data_fields)
+
+        # Allocated fields should be at least the number of actual fields
+        assert partition.allocated_fields_num >= partition.total_fields_num
+
+        # Check production status for the fields we added
         assert torch.equal(
             sum(partition.production_status[:, : len(data_fields)]),
             torch.Tensor([gbs * num_n_samples, gbs * num_n_samples]),
         )
-
-
-
-
+
+        # Any additional allocated fields should be zero (unused)
+        if partition.allocated_fields_num > len(data_fields):
+            assert torch.equal(
+                sum(partition.production_status[:, len(data_fields) :]),
+                torch.zeros(1 * (partition.allocated_fields_num - len(data_fields))),
+            )
 
         print(f"✓ Updated production status for partition {partition_id}")
 
+        # Test for get consumption status
+        consumption_status = ray.get(
+            tq_controller.get_consumption_status.remote(
+                partition_id=partition_id,
+                task_name="generate_sequences",
+            )
+        )
+        assert torch.equal(consumption_status, torch.zeros(gbs * num_n_samples))
+
         # Test get metadate in fetch mode
         gen_meta = ray.get(
             tq_controller.get_metadata.remote(

@@ -117,13 +144,23 @@ class TestTransferQueueController:
                 task_name="generate_sequences",
             )
         )
+
         assert gen_meta.global_indexes == list(range(gbs * num_n_samples))
         assert gen_meta.samples[0].partition_id == "train_0"
         assert gen_meta.field_names == ["prompt_ids"]
-        partition = ray.get(tq_controller.
+        partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
         assert torch.equal(partition.consumption_status["generate_sequences"], torch.ones(gbs * num_n_samples))
         print("✓ Get metadata in fetch mode correct")
 
+        # Test for get consumption status
+        consumption_status = ray.get(
+            tq_controller.get_consumption_status.remote(
+                partition_id=partition_id,
+                task_name="generate_sequences",
+            )
+        )
+        assert torch.equal(consumption_status, torch.ones(gbs * num_n_samples))
+
         # Test get clear meta
         clear_meta = ray.get(
             tq_controller.get_metadata.remote(

@@ -136,14 +173,13 @@ class TestTransferQueueController:
         assert [sample.fields for sample in clear_meta.samples] == [{}] * (gbs * num_n_samples)
         print("✓ Clear metadata correct")
 
-        # Test
-        ray.get(tq_controller.
-        partition = ray.get(tq_controller.
+        # Test clear_partition
+        ray.get(tq_controller.clear_partition.remote(partition_id))
+        partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
         partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id))
         assert partition_index_range == set()
-        assert
-
-        print("✓ Clear correct")
+        assert partition is None
+        print("✓ Clear partition correct")
 
     def test_controller_with_multi_partitions(self, ray_setup):
         gbs_1 = 8

@@ -248,15 +284,15 @@ class TestTransferQueueController:
         # Clear partition 1
        partition_index_range_1 = ray.get(tq_controller.get_partition_index_range.remote(partition_id_1))
         assert partition_index_range_1
-        ray.get(tq_controller.
-        partition_1_after_clear = ray.get(tq_controller.
+        ray.get(tq_controller.clear_partition.remote(partition_id_1))
+        partition_1_after_clear = ray.get(tq_controller.get_partition_snapshot.remote(partition_id_1))
         partition_index_range_1_after_clear = ray.get(tq_controller.get_partition_index_range.remote(partition_id_1))
 
         assert not partition_index_range_1_after_clear
-        assert
-        assert
+        assert partition_1_after_clear is None
+        assert partition_index_range_1_after_clear == set()
 
-        partition_2 = ray.get(tq_controller.
+        partition_2 = ray.get(tq_controller.get_partition_snapshot.remote(partition_id_2))
         partition_index_range_2 = ray.get(tq_controller.get_partition_index_range.remote(partition_id_2))
         assert partition_index_range_2 == set([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47])
         assert torch.all(

@@ -284,3 +320,64 @@ class TestTransferQueueController:
         partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id_3))
         assert partition_index_range == set(list(range(32)) + list(range(48, 80)))
         print("✓ Correctly assign partition_3")
+
+    def test_controller_clear_meta(self, ray_setup):
+        """Test clear_meta functionality for individual samples"""
+        gbs = 4
+        num_n_samples = 2
+        partition_id = "test_clear_meta"
+
+        tq_controller = TransferQueueController.remote()
+
+        # Create metadata in insert mode
+        data_fields = ["prompt_ids", "attention_mask"]
+        metadata = ray.get(
+            tq_controller.get_metadata.remote(
+                data_fields=data_fields,
+                batch_size=gbs * num_n_samples,
+                partition_id=partition_id,
+                mode="insert",
+            )
+        )
+
+        assert metadata.global_indexes == list(range(gbs * num_n_samples))
+
+        # Update production status
+        dtypes = {k: {"prompt_ids": "torch.int64", "attention_mask": "torch.bool"} for k in metadata.global_indexes}
+        shapes = {k: {"prompt_ids": (32,), "attention_mask": (32,)} for k in metadata.global_indexes}
+        success = ray.get(
+            tq_controller.update_production_status.remote(
+                partition_id=partition_id,
+                global_indexes=metadata.global_indexes,
+                field_names=metadata.field_names,
+                dtypes=dtypes,
+                shapes=shapes,
+            )
+        )
+        assert success
+
+        # Get partition snapshot before clear
+        partition_before = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
+        assert partition_before is not None
+        assert len(partition_before.global_indexes) == gbs * num_n_samples
+        assert set(partition_before.global_indexes) == set(range(gbs * num_n_samples))
+
+        # Test clear_meta - clear first 4 samples (indexes 0-3)
+        global_indexes_to_clear = [0, 1, 2, 3, 6]
+        partition_ids_to_clear = [partition_id] * len(global_indexes_to_clear)
+
+        ray.get(
+            tq_controller.clear_meta.remote(
+                global_indexes=global_indexes_to_clear,
+                partition_ids=partition_ids_to_clear,
+            )
+        )
+
+        # Check that only the cleared samples are affected
+        partition_after = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
+        assert partition_after is not None
+
+        # Verify production status is cleared for the specified indexes
+        assert set(partition_after.global_indexes) == set([4, 5, 7])
+
+        print("✓ Clear meta correct")
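
Note that the controller-side `clear_meta` takes parallel lists (one partition id per global index), which suggests a single call can target samples across partitions; a hedged sketch of that reading (the mixed partition ids here are hypothetical, not from the diff):

    ray.get(
        tq_controller.clear_meta.remote(
            global_indexes=[0, 1, 7],
            partition_ids=["train_0", "train_0", "train_1"],  # hypothetical cross-partition clear
        )
    )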
tests/test_controller_data_partitions.py
CHANGED

@@ -26,8 +26,8 @@ sys.path.append(str(parent_dir))
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-TQ_INIT_SAMPLE_NUM = int(os.environ.get("TQ_INIT_SAMPLE_NUM",
-TQ_INIT_FIELD_NUM = int(os.environ.get("TQ_INIT_FIELD_NUM",
+TQ_INIT_SAMPLE_NUM = int(os.environ.get("TQ_INIT_SAMPLE_NUM", 1))  # Initial number of samples
+TQ_INIT_FIELD_NUM = int(os.environ.get("TQ_INIT_FIELD_NUM", 1))
 
 
 def test_data_partition_status():

@@ -40,7 +40,8 @@ def test_data_partition_status():
     partition = DataPartitionStatus(partition_id="test@partition_1")
 
     # Test initial state
-    assert partition.total_samples_num ==
+    assert partition.total_samples_num == 0
+    assert partition.allocated_samples_num == TQ_INIT_SAMPLE_NUM
     assert partition.total_fields_num == 0
     assert partition.allocated_fields_num == TQ_INIT_FIELD_NUM
     assert partition.production_status is not None

@@ -127,7 +128,7 @@ def test_partition_interface():
 
     # Test that the class can be imported and has expected methods
     assert hasattr(TransferQueueController, "create_partition")
-    assert hasattr(TransferQueueController, "
+    assert hasattr(TransferQueueController, "get_partition_snapshot")
     assert hasattr(TransferQueueController, "update_production_status")
     assert hasattr(TransferQueueController, "scan_data_status")
     assert hasattr(TransferQueueController, "generate_batch_meta")

@@ -171,8 +172,8 @@ def test_dynamic_expansion_scenarios():
             10: {"field_1": (32,)},
         },
     )
-    assert partition.total_samples_num
-
+    assert partition.total_samples_num == 3
+    assert partition.allocated_samples_num >= 11  # Should accommodate index 10
     print("✓ Large index gaps handled correctly")
 
     # Scenario 2: Adding many fields dynamically

@@ -212,7 +213,8 @@ def test_data_partition_status_advanced():
     partition = DataPartitionStatus(partition_id="advanced_test")
 
     # Initially empty
-    assert partition.total_samples_num ==
+    assert partition.total_samples_num == 0
+    assert partition.allocated_samples_num == TQ_INIT_SAMPLE_NUM
     assert partition.total_fields_num == 0
     assert partition.allocated_fields_num == TQ_INIT_FIELD_NUM
 

@@ -289,6 +291,7 @@ def test_data_partition_status_advanced():
         "created_at",
         "total_samples_num",
         "total_fields_num",
+        "allocated_samples_num",
         "allocated_fields_num",
         "registered_tasks",
         "produced_samples",

@@ -311,8 +314,7 @@ def test_data_partition_status_advanced():
     initial_consumption_sum = sum(t.sum().item() for t in partition.consumption_status.values())
 
     # Clear only production data
-
-    assert success
+    partition.clear_data(list(range(4)), clear_consumption=False)
     assert partition.production_status[:4, :].sum().item() == 0
 
     # Consumption data should remain

@@ -353,7 +355,7 @@ def test_edge_cases_and_error_handling():
     task_name = "early_task"
     consumption_tensor = partition.get_consumption_status(task_name)
     assert consumption_tensor is not None
-    assert consumption_tensor.shape[0] == partition.
+    assert consumption_tensor.shape[0] == partition.allocated_samples_num
 
     # Test 4: Production status update error conditions
     # Test with empty lists

@@ -371,80 +373,6 @@ def test_edge_cases_and_error_handling():
     print("Edge cases and error handling tests passed!\n")
 
 
-def test_backward_compatibility():
-    """Test backward compatibility with existing interfaces."""
-    print("Testing backward compatibility...")
-
-    from transfer_queue.controller import DataPartitionStatus
-
-    partition = DataPartitionStatus(partition_id="compat_test")
-
-    # Test 1: Basic workflow should work as before
-    sample_indices = [0, 1, 2, 3, 4]
-    field_names = ["input_ids", "attention_mask", "labels"]
-    dtypes = {
-        k: {"input_ids": "torch.int64", "attention_mask": "torch.bool", "labels": "torch.int64"} for k in sample_indices
-    }
-    shapes = {k: {"input_ids": (32,), "attention_mask": (32,), "labels": (32,)} for k in sample_indices}
-    success = partition.update_production_status(
-        sample_indices,
-        field_names,
-        dtypes=dtypes,
-        shapes=shapes,
-    )
-    assert success
-
-    # Traditional consumption tracking
-    task_name = "training_task"
-    ready_samples = partition.scan_data_status(field_names, task_name)
-    assert len(ready_samples) == 5
-
-    # Mark as consumed
-    partition.mark_consumed(task_name, ready_samples[:3])
-
-    # Should now return only unconsumed samples
-    remaining_ready = partition.scan_data_status(field_names, task_name)
-    assert len(remaining_ready) == 2
-
-    print("✓ Basic workflow maintains compatibility")
-
-    # Test 2: Field mapping should be consistent
-    for field in field_names:
-        assert field in partition.field_name_mapping
-        field_idx = partition.field_name_mapping[field]
-        assert field_idx >= 0
-        assert field_idx < partition.allocated_fields_num
-
-    print("✓ Field mapping consistency maintained")
-
-    # Test 3: Metadata access patterns
-    for sample_idx in sample_indices:
-        for field in field_names:
-            # These should return reasonable values or None
-            dtype = partition.get_field_dtype(sample_idx, field)
-            shape = partition.get_field_shape(sample_idx, field)
-            assert dtype is not None
-            assert shape is not None
-            # Should not crash even if metadata wasn't provided
-
-    print("✓ Metadata access patterns preserved")
-
-    # Test 4: Statistics format should be familiar
-    stats = partition.get_statistics()
-    familiar_keys = ["partition_id", "total_samples_num", "total_fields_num"]
-    for key in familiar_keys:
-        assert key in stats
-
-    assert isinstance(stats["total_samples_num"], int)
-    assert isinstance(stats["total_fields_num"], int)
-    assert stats["total_samples_num"] > 0
-    assert stats["total_fields_num"] == len(field_names)
-
-    print("✓ Statistics format maintains familiarity")
-
-    print("Backward compatibility tests passed!\n")
-
-
 def test_performance_characteristics():
     """Test performance characteristics of the refactored implementation."""
     print("Testing performance characteristics...")

@@ -512,65 +440,3 @@ def test_performance_characteristics():
     print("✓ Memory usage patterns reasonable")
 
     print("Performance characteristics tests passed!\n")
-
-
-def main():
-    """Run all tests."""
-    print("=== Comprehensive Testing of TransferQueue Controller ===\n")
-
-    test_functions = [
-        test_data_partition_status,
-        test_partition_interface,
-        test_dynamic_expansion_scenarios,
-        test_data_partition_status_advanced,
-        test_edge_cases_and_error_handling,
-        test_backward_compatibility,
-        test_performance_characteristics,
-    ]
-
-    passed_tests = 0
-    total_tests = len(test_functions)
-
-    try:
-        for test_func in test_functions:
-            try:
-                test_func()
-                passed_tests += 1
-            except Exception as e:
-                print(f"❌ {test_func.__name__} failed: {e}")
-                import traceback
-
-                traceback.print_exc()
-            print()
-
-        print("=" * 60)
-        print(f"TEST SUMMARY: {passed_tests}/{total_tests} test suites passed")
-
-        if passed_tests == total_tests:
-            print("🎉 ALL TESTS PASSED!")
-            print("\nThe refactored DataPartitionStatus demonstrates:")
-            print("1. ✅ Dynamic row and column expansion without pre-allocation")
-            print("2. ✅ Robust partition-controller interface design")
-            print("3. ✅ Self-contained state management in DataPartitionStatus")
-            print("4. ✅ Flexible consumption tracking per task")
-            print("5. ✅ Comprehensive scanning and query capabilities")
-            print("6. ✅ Advanced error handling and edge case management")
-            print("7. ✅ Backward compatibility with existing interfaces")
-            print("8. ✅ Good performance characteristics for large datasets")
-            print("\n🚀 DataPartitionStatus refactoring is ready for production!")
-        else:
-            print(f"⚠️ {total_tests - passed_tests} test suites failed.")
-            print("Please review the failures before deploying to production.")
-
-        print("=" * 60)
-
-    except Exception as e:
-        print(f"❌ Critical test failure: {e}")
-        import traceback
-
-        traceback.print_exc()
-        sys.exit(1)
-
-
-if __name__ == "__main__":
-    main()