TransferQueue 0.1.4.dev2__py3-none-any.whl → 0.1.5.dev3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -306,7 +306,7 @@ class Trainer:
 
         # Client notifies controller to clear data status, controller returns metadata;
         # Client then notifies the storage plane to clear based on metadata
-        asyncio.run(self.data_system_client.async_clear(partition_id=f"train_{step}"))
+        asyncio.run(self.data_system_client.async_clear_partition(partition_id=f"train_{step}"))
         logger.info("clear ok! ")
         logger.info("demo done!")
 
@@ -190,7 +190,7 @@ def fit(config, data_system_client):
     # Client then notifies the storage plane to clear based on metadata
     # Client selects one master controller to get metadata,
     # other controllers directly clear without returning metadata
-    data_system_client.clear(partition_id=f"train_{step}")
+    data_system_client.clear_partition(partition_id=f"train_{step}")
     logger.info("clear ok! ")
     logger.info("demo done!")
 
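Both hunks above are the same rename seen from the demo's async (Trainer) and sync (fit) call sites: `async_clear` becomes `async_clear_partition` and `clear` becomes `clear_partition`, with the `partition_id` keyword unchanged. Below is a minimal sketch of the two call patterns after the rename; `data_system_client` and `step` are stand-ins taken from the demo code, and only the methods visible in this diff are assumed to exist.

import asyncio


def clear_step(data_system_client, step: int) -> None:
    # Synchronous path (0.1.4.dev2: data_system_client.clear(...)).
    data_system_client.clear_partition(partition_id=f"train_{step}")


def clear_step_from_sync_code(data_system_client, step: int) -> None:
    # Async path driven from synchronous code, as the Trainer demo does
    # (0.1.4.dev2: async_clear(...)).
    asyncio.run(data_system_client.async_clear_partition(partition_id=f"train_{step}"))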
tests/test_client.py CHANGED
@@ -101,11 +101,38 @@ class MockController:
         if request_msg.request_type == ZMQRequestType.GET_META:
             response_body = self._mock_batch_meta(request_msg.body)
             response_type = ZMQRequestType.GET_META_RESPONSE
-        elif request_msg.request_type == ZMQRequestType.GET_CLEAR_META:
-            response_body = self._mock_batch_meta(request_msg.body)
-            response_type = ZMQRequestType.GET_CLEAR_META_RESPONSE
         elif request_msg.request_type == ZMQRequestType.CLEAR_META:
-            response_body = {"message": "clear ok"}
+            response_body = {"message": "clear meta ok"}
+            response_type = ZMQRequestType.CLEAR_META_RESPONSE
+        elif request_msg.request_type == ZMQRequestType.CLEAR_PARTITION:
+            response_body = {"message": "clear partition ok"}
+            response_type = ZMQRequestType.CLEAR_PARTITION_RESPONSE
+        elif request_msg.request_type == ZMQRequestType.GET_PARTITION_META:
+            # Mock partition metadata response
+            response_body = {"metadata": self._mock_batch_meta(request_msg.body)}
+            response_type = ZMQRequestType.GET_PARTITION_META_RESPONSE
+        elif request_msg.request_type == ZMQRequestType.CHECK_CONSUMPTION:
+            # Mock consumption status check - all consumed
+            response_body = {
+                "partition_id": request_msg.body.get("partition_id"),
+                "consumed": True,
+            }
+            response_type = ZMQRequestType.CONSUMPTION_RESPONSE
+        elif request_msg.request_type == ZMQRequestType.CHECK_PRODUCTION:
+            # Mock production status check - all produced
+            response_body = {
+                "partition_id": request_msg.body.get("partition_id"),
+                "produced": True,
+            }
+            response_type = ZMQRequestType.PRODUCTION_RESPONSE
+        elif request_msg.request_type == ZMQRequestType.GET_LIST_PARTITIONS:
+            # Mock partition list
+            response_body = {
+                "partition_ids": ["partition_0", "partition_1", "test_partition"],
+            }
+            response_type = ZMQRequestType.LIST_PARTITIONS_RESPONSE
+        else:
+            response_body = {"error": f"Unknown request type: {request_msg.request_type}"}
             response_type = ZMQRequestType.CLEAR_META_RESPONSE
 
         # Send response
@@ -352,14 +379,6 @@ def test_get_meta(client_setup):
     assert len(metadata.global_indexes) == 10
 
 
-def test_clear_operation(client_setup):
-    """Test clear operation"""
-    client, _, _ = client_setup
-
-    # Test clear operation
-    client.clear(partition_id="0")
-
-
 # Test with single controller and multiple storage units
 def test_single_controller_multiple_storages():
     """Test client with single controller and multiple storage units"""
@@ -426,3 +445,140 @@ def test_put_without_required_params(client_setup):
     # Test put without partition id (should fail)
     with pytest.raises(ValueError):
         client.put(data=test_data)
+
+
+# Test new status checking methods
+def test_check_consumption_status(client_setup):
+    """Test consumption status checking"""
+    client, _, _ = client_setup
+
+    # Test synchronous check_consumption_status
+    is_consumed = client.check_consumption_status(task_name="generate_sequences", partition_id="train_0")
+    assert is_consumed is True
+
+
+def test_check_production_status(client_setup):
+    """Test production status checking"""
+    client, _, _ = client_setup
+
+    # Test synchronous check_production_status
+    is_produced = client.check_production_status(data_fields=["prompt_ids", "attention_mask"], partition_id="train_0")
+    assert is_produced is True
+
+
+def test_get_partition_list(client_setup):
+    """Test partition list retrieval"""
+    client, _, _ = client_setup
+
+    # Test synchronous get_partition_list
+    partition_list = client.get_partition_list()
+    assert isinstance(partition_list, list)
+    assert len(partition_list) > 0
+    assert "partition_0" in partition_list
+    assert "partition_1" in partition_list
+    assert "test_partition" in partition_list
+
+
+@pytest.mark.asyncio
+async def test_async_check_consumption_status(client_setup):
+    """Test async consumption status checking"""
+    client, _, _ = client_setup
+
+    # Test async_check_consumption_status
+    is_consumed = await client.async_check_consumption_status(task_name="generate_sequences", partition_id="train_0")
+    assert is_consumed is True
+
+
+@pytest.mark.asyncio
+async def test_async_check_production_status(client_setup):
+    """Test async production status checking"""
+    client, _, _ = client_setup
+
+    # Test async_check_production_status
+    is_produced = await client.async_check_production_status(
+        data_fields=["prompt_ids", "attention_mask"], partition_id="train_0"
+    )
+    assert is_produced is True
+
+
+@pytest.mark.asyncio
+async def test_async_get_partition_list(client_setup):
+    """Test async partition list retrieval"""
+    client, _, _ = client_setup
+
+    # Test async_get_partition_list
+    partition_list = await client.async_get_partition_list()
+    assert isinstance(partition_list, list)
+    assert len(partition_list) > 0
+    assert "partition_0" in partition_list
+    assert "partition_1" in partition_list
+    assert "test_partition" in partition_list
+
+
+# Test clear methods
+@pytest.mark.asyncio
+async def test_async_clear_partition(client_setup):
+    """Test async clear partition operation"""
+    client, _, _ = client_setup
+
+    # Test async_clear_partition
+    await client.async_clear_partition(partition_id="test_partition")
+
+    # If no exception is raised, the test passes
+    assert True
+
+
+@pytest.mark.asyncio
+async def test_async_clear_samples(client_setup):
+    """Test async clear samples operation"""
+    client, _, _ = client_setup
+
+    # First get metadata to create a BatchMeta object
+    metadata = await client.async_get_meta(data_fields=["tokens", "labels"], batch_size=2, partition_id="0")
+
+    # Test async_clear_samples
+    await client.async_clear_samples(metadata=metadata)
+
+    # If no exception is raised, the test passes
+    assert True
+
+
+def test_clear_partition(client_setup):
+    """Test synchronous clear partition operation"""
+    client, _, _ = client_setup
+
+    # Test synchronous clear_partition
+    client.clear_partition(partition_id="test_partition")
+
+    # If no exception is raised, the test passes
+    assert True
+
+
+def test_clear_samples(client_setup):
+    """Test synchronous clear samples operation"""
+    client, _, _ = client_setup
+
+    # First get metadata to create a BatchMeta object
+    metadata = client.get_meta(data_fields=["tokens", "labels"], batch_size=2, partition_id="0")
+
+    # Test synchronous clear_samples
+    client.clear_samples(metadata=metadata)
+
+    # If no exception is raised, the test passes
+    assert True
+
+
+@pytest.mark.asyncio
+async def test_async_clear_samples_with_empty_metadata(client_setup):
+    """Test async_clear_samples with empty BatchMeta"""
+    client, _, _ = client_setup
+
+    # Create empty BatchMeta
+    metadata = BatchMeta(samples=[])
+
+    # The clear operation should complete without raising an exception
+    # because the mock storage manager is configured to handle this
+    await client.async_clear_samples(metadata=metadata)
+
+    # If no exception is raised, the test passes
+    assert True
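Taken together, the new tests above outline the client surface added in 0.1.5.dev3. The sketch below restates those calls exactly as the tests exercise them; `client` is assumed to be constructed the way the `client_setup` fixture does it, so this is a reading of the tests, not an authoritative API reference.

def exercise_new_client_api(client):
    # Status checks (both return booleans under the mock controller above).
    consumed = client.check_consumption_status(task_name="generate_sequences", partition_id="train_0")
    produced = client.check_production_status(data_fields=["prompt_ids", "attention_mask"], partition_id="train_0")

    # Partition listing returns a list of partition ids.
    partitions = client.get_partition_list()

    # Two clearing granularities: a whole partition by id, or individual
    # samples described by a BatchMeta obtained from get_meta().
    client.clear_partition(partition_id="test_partition")
    metadata = client.get_meta(data_fields=["tokens", "labels"], batch_size=2, partition_id="0")
    client.clear_samples(metadata=metadata)

    return consumed, produced, partitions

Each call also has an `async_`-prefixed coroutine variant (`async_check_consumption_status`, `async_get_partition_list`, `async_clear_partition`, `async_clear_samples`), as the `@pytest.mark.asyncio` tests show.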
tests/test_controller.py CHANGED
@@ -28,7 +28,6 @@ logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 from transfer_queue import TransferQueueController  # noqa: E402
-from transfer_queue.controller import TQ_INIT_FIELD_NUM  # noqa: E402
 from transfer_queue.utils.utils import ProductionStatus  # noqa: E402
 
 
@@ -92,21 +91,49 @@ class TestTransferQueueController:
             )
         )
         assert success
-        partition = ray.get(tq_controller.get_partition.remote(partition_id))
+        partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
         assert partition.production_status is not None
         assert partition.production_status.size(0) == gbs * num_n_samples
-        assert partition.production_status.size(1) == TQ_INIT_FIELD_NUM
+
+        # Test for get production status
+        production_status = ray.get(
+            tq_controller.get_production_status.remote(
+                partition_id=partition_id,
+                data_fields=data_fields,
+            )
+        )
+        assert production_status
+
+        # Total fields should match the number of fields we added
+        assert partition.total_fields_num == len(data_fields)
+
+        # Allocated fields should be at least the number of actual fields
+        assert partition.allocated_fields_num >= partition.total_fields_num
+
+        # Check production status for the fields we added
         assert torch.equal(
             sum(partition.production_status[:, : len(data_fields)]),
             torch.Tensor([gbs * num_n_samples, gbs * num_n_samples]),
         )
-        assert torch.equal(
-            sum(partition.production_status[:, len(data_fields) :]),
-            torch.zeros(1 * (TQ_INIT_FIELD_NUM - len(data_fields))),
-        )
+
+        # Any additional allocated fields should be zero (unused)
+        if partition.allocated_fields_num > len(data_fields):
+            assert torch.equal(
+                sum(partition.production_status[:, len(data_fields) :]),
+                torch.zeros(1 * (partition.allocated_fields_num - len(data_fields))),
+            )
 
         print(f"✓ Updated production status for partition {partition_id}")
 
+        # Test for get consumption status
+        consumption_status = ray.get(
+            tq_controller.get_consumption_status.remote(
+                partition_id=partition_id,
+                task_name="generate_sequences",
+            )
+        )
+        assert torch.equal(consumption_status, torch.zeros(gbs * num_n_samples))
+
         # Test get metadate in fetch mode
         gen_meta = ray.get(
             tq_controller.get_metadata.remote(
@@ -117,13 +144,23 @@
                 task_name="generate_sequences",
             )
         )
+
         assert gen_meta.global_indexes == list(range(gbs * num_n_samples))
         assert gen_meta.samples[0].partition_id == "train_0"
         assert gen_meta.field_names == ["prompt_ids"]
-        partition = ray.get(tq_controller.get_partition.remote(partition_id))
+        partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
         assert torch.equal(partition.consumption_status["generate_sequences"], torch.ones(gbs * num_n_samples))
         print("✓ Get metadata in fetch mode correct")
 
+        # Test for get consumption status
+        consumption_status = ray.get(
+            tq_controller.get_consumption_status.remote(
+                partition_id=partition_id,
+                task_name="generate_sequences",
+            )
+        )
+        assert torch.equal(consumption_status, torch.ones(gbs * num_n_samples))
+
         # Test get clear meta
         clear_meta = ray.get(
             tq_controller.get_metadata.remote(
@@ -136,14 +173,13 @@
         assert [sample.fields for sample in clear_meta.samples] == [{}] * (gbs * num_n_samples)
         print("✓ Clear metadata correct")
 
-        # Test clear
-        ray.get(tq_controller.clear.remote(partition_id))
-        partition = ray.get(tq_controller.get_partition.remote(partition_id))
+        # Test clear_partition
+        ray.get(tq_controller.clear_partition.remote(partition_id))
+        partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
         partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id))
         assert partition_index_range == set()
-        assert torch.all(partition.production_status == 0)
-        assert torch.all(partition.consumption_status["generate_sequences"] == 0)
-        print("✓ Clear correct")
+        assert partition is None
+        print("✓ Clear partition correct")
 
     def test_controller_with_multi_partitions(self, ray_setup):
         gbs_1 = 8
@@ -248,15 +284,15 @@
         # Clear partition 1
         partition_index_range_1 = ray.get(tq_controller.get_partition_index_range.remote(partition_id_1))
         assert partition_index_range_1
-        ray.get(tq_controller.clear.remote(partition_id_1))
-        partition_1_after_clear = ray.get(tq_controller.get_partition.remote(partition_id_1))
+        ray.get(tq_controller.clear_partition.remote(partition_id_1))
+        partition_1_after_clear = ray.get(tq_controller.get_partition_snapshot.remote(partition_id_1))
         partition_index_range_1_after_clear = ray.get(tq_controller.get_partition_index_range.remote(partition_id_1))
 
         assert not partition_index_range_1_after_clear
-        assert torch.all(partition_1_after_clear.production_status[list(partition_index_range_1), :] == 0)
-        assert torch.all(partition_1_after_clear.consumption_status["generate_sequences"] == 0)
+        assert partition_1_after_clear is None
+        assert partition_index_range_1_after_clear == set()
 
-        partition_2 = ray.get(tq_controller.get_partition.remote(partition_id_2))
+        partition_2 = ray.get(tq_controller.get_partition_snapshot.remote(partition_id_2))
         partition_index_range_2 = ray.get(tq_controller.get_partition_index_range.remote(partition_id_2))
         assert partition_index_range_2 == set([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47])
         assert torch.all(
@@ -284,3 +320,64 @@
         partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id_3))
         assert partition_index_range == set(list(range(32)) + list(range(48, 80)))
         print("✓ Correctly assign partition_3")
+
+    def test_controller_clear_meta(self, ray_setup):
+        """Test clear_meta functionality for individual samples"""
+        gbs = 4
+        num_n_samples = 2
+        partition_id = "test_clear_meta"
+
+        tq_controller = TransferQueueController.remote()
+
+        # Create metadata in insert mode
+        data_fields = ["prompt_ids", "attention_mask"]
+        metadata = ray.get(
+            tq_controller.get_metadata.remote(
+                data_fields=data_fields,
+                batch_size=gbs * num_n_samples,
+                partition_id=partition_id,
+                mode="insert",
+            )
+        )
+
+        assert metadata.global_indexes == list(range(gbs * num_n_samples))
+
+        # Update production status
+        dtypes = {k: {"prompt_ids": "torch.int64", "attention_mask": "torch.bool"} for k in metadata.global_indexes}
+        shapes = {k: {"prompt_ids": (32,), "attention_mask": (32,)} for k in metadata.global_indexes}
+        success = ray.get(
+            tq_controller.update_production_status.remote(
+                partition_id=partition_id,
+                global_indexes=metadata.global_indexes,
+                field_names=metadata.field_names,
+                dtypes=dtypes,
+                shapes=shapes,
+            )
+        )
+        assert success
+
+        # Get partition snapshot before clear
+        partition_before = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
+        assert partition_before is not None
+        assert len(partition_before.global_indexes) == gbs * num_n_samples
+        assert set(partition_before.global_indexes) == set(range(gbs * num_n_samples))
+
+        # Test clear_meta - clear samples at indexes 0-3 and 6
+        global_indexes_to_clear = [0, 1, 2, 3, 6]
+        partition_ids_to_clear = [partition_id] * len(global_indexes_to_clear)
+
+        ray.get(
+            tq_controller.clear_meta.remote(
+                global_indexes=global_indexes_to_clear,
+                partition_ids=partition_ids_to_clear,
+            )
+        )
+
+        # Check that only the cleared samples are affected
+        partition_after = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
+        assert partition_after is not None
+
+        # Verify production status is cleared for the specified indexes
+        assert set(partition_after.global_indexes) == set([4, 5, 7])
+
+        print("✓ Clear meta correct")
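These controller tests pin down the new lifecycle semantics: `clear_partition` drops the partition outright, so a later `get_partition_snapshot` returns `None`, while `clear_meta` removes only the named samples and leaves the partition alive with the survivors. Below is a condensed sketch of that contract, assuming a `TransferQueueController` Ray actor with data already produced into the two partitions named in the tests.

import ray


def sketch_clear_contract(tq_controller):
    # Whole-partition clear: the partition ceases to exist.
    ray.get(tq_controller.clear_partition.remote("train_0"))
    assert ray.get(tq_controller.get_partition_snapshot.remote("train_0")) is None

    # Per-sample clear: only the listed global indexes are dropped.
    to_clear = [0, 1, 2, 3, 6]
    ray.get(
        tq_controller.clear_meta.remote(
            global_indexes=to_clear,
            partition_ids=["test_clear_meta"] * len(to_clear),
        )
    )
    snapshot = ray.get(tq_controller.get_partition_snapshot.remote("test_clear_meta"))
    assert snapshot is not None  # the partition survives with samples {4, 5, 7}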
@@ -26,8 +26,8 @@ sys.path.append(str(parent_dir))
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-TQ_INIT_SAMPLE_NUM = int(os.environ.get("TQ_INIT_SAMPLE_NUM", 10))  # Initial number of samples
-TQ_INIT_FIELD_NUM = int(os.environ.get("TQ_INIT_FIELD_NUM", 10))
+TQ_INIT_SAMPLE_NUM = int(os.environ.get("TQ_INIT_SAMPLE_NUM", 1))  # Initial number of samples
+TQ_INIT_FIELD_NUM = int(os.environ.get("TQ_INIT_FIELD_NUM", 1))
 
 
 def test_data_partition_status():
@@ -40,7 +40,8 @@ def test_data_partition_status():
     partition = DataPartitionStatus(partition_id="test@partition_1")
 
     # Test initial state
-    assert partition.total_samples_num == TQ_INIT_SAMPLE_NUM
+    assert partition.total_samples_num == 0
+    assert partition.allocated_samples_num == TQ_INIT_SAMPLE_NUM
     assert partition.total_fields_num == 0
     assert partition.allocated_fields_num == TQ_INIT_FIELD_NUM
     assert partition.production_status is not None
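This version separates logical size from backing storage: `total_samples_num` now counts samples actually present (0 for a fresh partition), while `allocated_samples_num` reports the rows pre-allocated in the status tensors (starting at `TQ_INIT_SAMPLE_NUM` and growing on demand). A small sketch of the invariant these assertions rely on, using the same `DataPartitionStatus` import as the surrounding tests:

from transfer_queue.controller import DataPartitionStatus

partition = DataPartitionStatus(partition_id="sketch")

# Fresh partition: nothing stored yet, but capacity is pre-allocated.
assert partition.total_samples_num == 0
assert partition.allocated_samples_num >= 1  # TQ_INIT_SAMPLE_NUM defaults to 1

# Capacity always bounds the logical size, even after sparse inserts force
# the tensors to grow (see the large-index-gap scenario below).
assert partition.allocated_samples_num >= partition.total_samples_num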
@@ -127,7 +128,7 @@ def test_partition_interface():
 
     # Test that the class can be imported and has expected methods
     assert hasattr(TransferQueueController, "create_partition")
-    assert hasattr(TransferQueueController, "get_partition")
+    assert hasattr(TransferQueueController, "get_partition_snapshot")
     assert hasattr(TransferQueueController, "update_production_status")
     assert hasattr(TransferQueueController, "scan_data_status")
     assert hasattr(TransferQueueController, "generate_batch_meta")
@@ -171,8 +172,8 @@ def test_dynamic_expansion_scenarios():
             10: {"field_1": (32,)},
         },
     )
-    assert partition.total_samples_num >= 11  # Should accommodate index 10
-
+    assert partition.total_samples_num == 3
+    assert partition.allocated_samples_num >= 11  # Should accommodate index 10
     print("✓ Large index gaps handled correctly")
 
     # Scenario 2: Adding many fields dynamically
@@ -212,7 +213,8 @@ def test_data_partition_status_advanced():
    partition = DataPartitionStatus(partition_id="advanced_test")
 
     # Initially empty
-    assert partition.total_samples_num == TQ_INIT_SAMPLE_NUM
+    assert partition.total_samples_num == 0
+    assert partition.allocated_samples_num == TQ_INIT_SAMPLE_NUM
     assert partition.total_fields_num == 0
     assert partition.allocated_fields_num == TQ_INIT_FIELD_NUM
 
@@ -289,6 +291,7 @@
         "created_at",
         "total_samples_num",
         "total_fields_num",
+        "allocated_samples_num",
         "allocated_fields_num",
         "registered_tasks",
         "produced_samples",
@@ -311,8 +314,7 @@
     initial_consumption_sum = sum(t.sum().item() for t in partition.consumption_status.values())
 
     # Clear only production data
-    success = partition.clear_data(list(range(4)), clear_consumption=False)
-    assert success
+    partition.clear_data(list(range(4)), clear_consumption=False)
     assert partition.production_status[:4, :].sum().item() == 0
 
     # Consumption data should remain
@@ -353,7 +355,7 @@ def test_edge_cases_and_error_handling():
     task_name = "early_task"
     consumption_tensor = partition.get_consumption_status(task_name)
     assert consumption_tensor is not None
-    assert consumption_tensor.shape[0] == partition.total_samples_num
+    assert consumption_tensor.shape[0] == partition.allocated_samples_num
 
     # Test 4: Production status update error conditions
     # Test with empty lists
@@ -371,80 +373,6 @@
     print("Edge cases and error handling tests passed!\n")
 
 
-def test_backward_compatibility():
-    """Test backward compatibility with existing interfaces."""
-    print("Testing backward compatibility...")
-
-    from transfer_queue.controller import DataPartitionStatus
-
-    partition = DataPartitionStatus(partition_id="compat_test")
-
-    # Test 1: Basic workflow should work as before
-    sample_indices = [0, 1, 2, 3, 4]
-    field_names = ["input_ids", "attention_mask", "labels"]
-    dtypes = {
-        k: {"input_ids": "torch.int64", "attention_mask": "torch.bool", "labels": "torch.int64"} for k in sample_indices
-    }
-    shapes = {k: {"input_ids": (32,), "attention_mask": (32,), "labels": (32,)} for k in sample_indices}
-    success = partition.update_production_status(
-        sample_indices,
-        field_names,
-        dtypes=dtypes,
-        shapes=shapes,
-    )
-    assert success
-
-    # Traditional consumption tracking
-    task_name = "training_task"
-    ready_samples = partition.scan_data_status(field_names, task_name)
-    assert len(ready_samples) == 5
-
-    # Mark as consumed
-    partition.mark_consumed(task_name, ready_samples[:3])
-
-    # Should now return only unconsumed samples
-    remaining_ready = partition.scan_data_status(field_names, task_name)
-    assert len(remaining_ready) == 2
-
-    print("✓ Basic workflow maintains compatibility")
-
-    # Test 2: Field mapping should be consistent
-    for field in field_names:
-        assert field in partition.field_name_mapping
-        field_idx = partition.field_name_mapping[field]
-        assert field_idx >= 0
-        assert field_idx < partition.allocated_fields_num
-
-    print("✓ Field mapping consistency maintained")
-
-    # Test 3: Metadata access patterns
-    for sample_idx in sample_indices:
-        for field in field_names:
-            # These should return reasonable values or None
-            dtype = partition.get_field_dtype(sample_idx, field)
-            shape = partition.get_field_shape(sample_idx, field)
-            assert dtype is not None
-            assert shape is not None
-            # Should not crash even if metadata wasn't provided
-
-    print("✓ Metadata access patterns preserved")
-
-    # Test 4: Statistics format should be familiar
-    stats = partition.get_statistics()
-    familiar_keys = ["partition_id", "total_samples_num", "total_fields_num"]
-    for key in familiar_keys:
-        assert key in stats
-
-    assert isinstance(stats["total_samples_num"], int)
-    assert isinstance(stats["total_fields_num"], int)
-    assert stats["total_samples_num"] > 0
-    assert stats["total_fields_num"] == len(field_names)
-
-    print("✓ Statistics format maintains familiarity")
-
-    print("Backward compatibility tests passed!\n")
-
-
 def test_performance_characteristics():
     """Test performance characteristics of the refactored implementation."""
     print("Testing performance characteristics...")
@@ -512,65 +440,3 @@
     print("✓ Memory usage patterns reasonable")
 
     print("Performance characteristics tests passed!\n")
-
-
-def main():
-    """Run all tests."""
-    print("=== Comprehensive Testing of TransferQueue Controller ===\n")
-
-    test_functions = [
-        test_data_partition_status,
-        test_partition_interface,
-        test_dynamic_expansion_scenarios,
-        test_data_partition_status_advanced,
-        test_edge_cases_and_error_handling,
-        test_backward_compatibility,
-        test_performance_characteristics,
-    ]
-
-    passed_tests = 0
-    total_tests = len(test_functions)
-
-    try:
-        for test_func in test_functions:
-            try:
-                test_func()
-                passed_tests += 1
-            except Exception as e:
-                print(f"❌ {test_func.__name__} failed: {e}")
-                import traceback
-
-                traceback.print_exc()
-                print()
-
-        print("=" * 60)
-        print(f"TEST SUMMARY: {passed_tests}/{total_tests} test suites passed")
-
-        if passed_tests == total_tests:
-            print("🎉 ALL TESTS PASSED!")
-            print("\nThe refactored DataPartitionStatus demonstrates:")
-            print("1. ✅ Dynamic row and column expansion without pre-allocation")
-            print("2. ✅ Robust partition-controller interface design")
-            print("3. ✅ Self-contained state management in DataPartitionStatus")
-            print("4. ✅ Flexible consumption tracking per task")
-            print("5. ✅ Comprehensive scanning and query capabilities")
-            print("6. ✅ Advanced error handling and edge case management")
-            print("7. ✅ Backward compatibility with existing interfaces")
-            print("8. ✅ Good performance characteristics for large datasets")
-            print("\n🚀 DataPartitionStatus refactoring is ready for production!")
-        else:
-            print(f"⚠️ {total_tests - passed_tests} test suites failed.")
-            print("Please review the failures before deploying to production.")
-
-        print("=" * 60)
-
-    except Exception as e:
-        print(f"❌ Critical test failure: {e}")
-        import traceback
-
-        traceback.print_exc()
-        sys.exit(1)
-
-
-if __name__ == "__main__":
-    main()