TransferQueue 0.1.4.dev0__py3-none-any.whl → 0.1.4.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- performance_test.py +5 -5
- tests/test_metadata.py +147 -0
- tests/test_serial_utils_on_cpu.py +53 -0
- transfer_queue/controller.py +115 -100
- transfer_queue/metadata.py +30 -1
- transfer_queue/storage/managers/simple_backend_manager.py +6 -1
- transfer_queue/storage/simple_backend.py +28 -21
- transfer_queue/utils/perf_utils.py +104 -0
- transfer_queue/utils/zmq_utils.py +6 -5
- transfer_queue/version/version +1 -1
- {transferqueue-0.1.4.dev0.dist-info → transferqueue-0.1.4.dev1.dist-info}/METADATA +1 -1
- {transferqueue-0.1.4.dev0.dist-info → transferqueue-0.1.4.dev1.dist-info}/RECORD +16 -14
- {transferqueue-0.1.4.dev0.dist-info → transferqueue-0.1.4.dev1.dist-info}/top_level.txt +1 -0
- verify_fix.py +109 -0
- {transferqueue-0.1.4.dev0.dist-info → transferqueue-0.1.4.dev1.dist-info}/WHEEL +0 -0
- {transferqueue-0.1.4.dev0.dist-info → transferqueue-0.1.4.dev1.dist-info}/licenses/LICENSE +0 -0
performance_test.py
CHANGED
|
@@ -137,8 +137,8 @@ class RayBandwidthTester:
|
|
|
137
137
|
RemoteDataStore = RemoteDataStoreRemote
|
|
138
138
|
|
|
139
139
|
self.remote_store = RemoteDataStore.options(
|
|
140
|
-
num_cpus=
|
|
141
|
-
resources={f"node:{WORKER_NODE_IP}":
|
|
140
|
+
num_cpus=1,
|
|
141
|
+
resources={f"node:{WORKER_NODE_IP}": 1}
|
|
142
142
|
).remote()
|
|
143
143
|
|
|
144
144
|
logger.info(f"Remote data store created on worker node {WORKER_NODE_IP}")
|
|
@@ -232,15 +232,15 @@ class TQBandwidthTester:
|
|
|
232
232
|
# 限制在远程节点
|
|
233
233
|
for storage_unit_rank in range(self.config.num_data_storage_units):
|
|
234
234
|
storage_node = SimpleStorageUnit.options(
|
|
235
|
-
num_cpus=
|
|
236
|
-
resources={f"node:{WORKER_NODE_IP}":
|
|
235
|
+
num_cpus=1,
|
|
236
|
+
resources={f"node:{WORKER_NODE_IP}": 1},
|
|
237
237
|
runtime_env={"env_vars": {"OMP_NUM_THREADS": "2"}},
|
|
238
238
|
).remote(
|
|
239
239
|
storage_unit_size=math.ceil(total_storage_size / self.config.num_data_storage_units)
|
|
240
240
|
)
|
|
241
241
|
self.data_system_storage_units[storage_unit_rank] = storage_node
|
|
242
242
|
else:
|
|
243
|
-
storage_placement_group = get_placement_group(self.config.num_data_storage_units, num_cpus_per_actor=
|
|
243
|
+
storage_placement_group = get_placement_group(self.config.num_data_storage_units, num_cpus_per_actor=1)
|
|
244
244
|
for storage_unit_rank in range(self.config.num_data_storage_units):
|
|
245
245
|
storage_node = SimpleStorageUnit.options(
|
|
246
246
|
placement_group=storage_placement_group,
|
tests/test_metadata.py
CHANGED
|
@@ -535,6 +535,153 @@ class TestBatchMeta:
|
|
|
535
535
|
assert selected_field.production_status == ProductionStatus.READY_FOR_CONSUME
|
|
536
536
|
assert selected_field.name == "field1"
|
|
537
537
|
|
|
538
|
+
def test_batch_meta_select_samples(self):
|
|
539
|
+
"""Example: Select specific samples from a batch."""
|
|
540
|
+
fields = {
|
|
541
|
+
"field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)),
|
|
542
|
+
"field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)),
|
|
543
|
+
}
|
|
544
|
+
samples = [
|
|
545
|
+
SampleMeta(partition_id="partition_0", global_index=0, fields=fields),
|
|
546
|
+
SampleMeta(partition_id="partition_0", global_index=1, fields=fields),
|
|
547
|
+
SampleMeta(partition_id="partition_0", global_index=2, fields=fields),
|
|
548
|
+
SampleMeta(partition_id="partition_0", global_index=3, fields=fields),
|
|
549
|
+
]
|
|
550
|
+
batch = BatchMeta(samples=samples, extra_info={"test_key": "test_value"})
|
|
551
|
+
|
|
552
|
+
# Select samples at indices [0, 2]
|
|
553
|
+
selected_batch = batch.select_samples([0, 2])
|
|
554
|
+
|
|
555
|
+
# Check number of samples
|
|
556
|
+
assert len(selected_batch) == 2
|
|
557
|
+
# Check global indexes
|
|
558
|
+
assert selected_batch.global_indexes == [0, 2]
|
|
559
|
+
# Check fields are preserved
|
|
560
|
+
for sample in selected_batch.samples:
|
|
561
|
+
assert "field1" in sample.fields
|
|
562
|
+
assert "field2" in sample.fields
|
|
563
|
+
# Original batch is unchanged
|
|
564
|
+
assert len(batch) == 4
|
|
565
|
+
# Extra info is preserved
|
|
566
|
+
assert selected_batch.extra_info["test_key"] == "test_value"
|
|
567
|
+
|
|
568
|
+
def test_batch_meta_select_samples_all_indices(self):
|
|
569
|
+
"""Example: Select all samples using complete index list."""
|
|
570
|
+
fields = {
|
|
571
|
+
"test_field": FieldMeta(
|
|
572
|
+
name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME
|
|
573
|
+
)
|
|
574
|
+
}
|
|
575
|
+
samples = [
|
|
576
|
+
SampleMeta(partition_id="partition_0", global_index=0, fields=fields),
|
|
577
|
+
SampleMeta(partition_id="partition_0", global_index=1, fields=fields),
|
|
578
|
+
SampleMeta(partition_id="partition_0", global_index=2, fields=fields),
|
|
579
|
+
]
|
|
580
|
+
batch = BatchMeta(samples=samples, extra_info={"test_key": "test_value"})
|
|
581
|
+
|
|
582
|
+
# Select all samples
|
|
583
|
+
selected_batch = batch.select_samples([0, 1, 2])
|
|
584
|
+
|
|
585
|
+
# All samples are selected
|
|
586
|
+
assert len(selected_batch) == 3
|
|
587
|
+
assert selected_batch.global_indexes == [0, 1, 2]
|
|
588
|
+
# Extra info is preserved
|
|
589
|
+
assert selected_batch.extra_info["test_key"] == "test_value"
|
|
590
|
+
|
|
591
|
+
def test_batch_meta_select_samples_single_sample(self):
|
|
592
|
+
"""Example: Select a single sample from batch."""
|
|
593
|
+
fields = {
|
|
594
|
+
"test_field": FieldMeta(
|
|
595
|
+
name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME
|
|
596
|
+
)
|
|
597
|
+
}
|
|
598
|
+
samples = [
|
|
599
|
+
SampleMeta(partition_id="partition_0", global_index=0, fields=fields),
|
|
600
|
+
SampleMeta(partition_id="partition_0", global_index=1, fields=fields),
|
|
601
|
+
SampleMeta(partition_id="partition_0", global_index=2, fields=fields),
|
|
602
|
+
]
|
|
603
|
+
batch = BatchMeta(samples=samples)
|
|
604
|
+
|
|
605
|
+
# Select only the middle sample
|
|
606
|
+
selected_batch = batch.select_samples([1])
|
|
607
|
+
|
|
608
|
+
assert len(selected_batch) == 1
|
|
609
|
+
assert selected_batch.global_indexes == [1]
|
|
610
|
+
assert selected_batch.samples[0].batch_index == 0 # New batch index
|
|
611
|
+
|
|
612
|
+
def test_batch_meta_select_samples_empty_list(self):
|
|
613
|
+
"""Example: Select with empty list returns empty batch."""
|
|
614
|
+
fields = {
|
|
615
|
+
"test_field": FieldMeta(
|
|
616
|
+
name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME
|
|
617
|
+
)
|
|
618
|
+
}
|
|
619
|
+
samples = [
|
|
620
|
+
SampleMeta(partition_id="partition_0", global_index=0, fields=fields),
|
|
621
|
+
SampleMeta(partition_id="partition_0", global_index=1, fields=fields),
|
|
622
|
+
]
|
|
623
|
+
batch = BatchMeta(samples=samples, extra_info={"test_key": "test_value"})
|
|
624
|
+
|
|
625
|
+
# Select with empty list
|
|
626
|
+
selected_batch = batch.select_samples([])
|
|
627
|
+
|
|
628
|
+
assert len(selected_batch) == 0
|
|
629
|
+
assert selected_batch.global_indexes == []
|
|
630
|
+
# Extra info is still preserved
|
|
631
|
+
assert selected_batch.extra_info["test_key"] == "test_value"
|
|
632
|
+
|
|
633
|
+
def test_batch_meta_select_samples_reverse_order(self):
|
|
634
|
+
"""Example: Select samples in reverse order."""
|
|
635
|
+
fields = {
|
|
636
|
+
"test_field": FieldMeta(
|
|
637
|
+
name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME
|
|
638
|
+
)
|
|
639
|
+
}
|
|
640
|
+
samples = [
|
|
641
|
+
SampleMeta(partition_id="partition_0", global_index=0, fields=fields),
|
|
642
|
+
SampleMeta(partition_id="partition_0", global_index=1, fields=fields),
|
|
643
|
+
SampleMeta(partition_id="partition_0", global_index=2, fields=fields),
|
|
644
|
+
]
|
|
645
|
+
batch = BatchMeta(samples=samples)
|
|
646
|
+
|
|
647
|
+
# Select samples in reverse order
|
|
648
|
+
selected_batch = batch.select_samples([2, 1, 0])
|
|
649
|
+
|
|
650
|
+
assert len(selected_batch) == 3
|
|
651
|
+
assert selected_batch.global_indexes == [2, 1, 0]
|
|
652
|
+
# Batch indexes are re-assigned
|
|
653
|
+
assert selected_batch.samples[0].global_index == 2
|
|
654
|
+
assert selected_batch.samples[1].global_index == 1
|
|
655
|
+
assert selected_batch.samples[2].global_index == 0
|
|
656
|
+
|
|
657
|
+
def test_batch_meta_select_samples_with_extra_info(self):
|
|
658
|
+
"""Example: Select samples preserves all extra info types."""
|
|
659
|
+
fields = {
|
|
660
|
+
"test_field": FieldMeta(
|
|
661
|
+
name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME
|
|
662
|
+
)
|
|
663
|
+
}
|
|
664
|
+
samples = [
|
|
665
|
+
SampleMeta(partition_id="partition_0", global_index=0, fields=fields),
|
|
666
|
+
SampleMeta(partition_id="partition_0", global_index=1, fields=fields),
|
|
667
|
+
]
|
|
668
|
+
batch = BatchMeta(samples=samples)
|
|
669
|
+
|
|
670
|
+
# Add various extra info types
|
|
671
|
+
batch.extra_info["tensor"] = torch.randn(3, 4)
|
|
672
|
+
batch.extra_info["string"] = "test_string"
|
|
673
|
+
batch.extra_info["number"] = 42
|
|
674
|
+
batch.extra_info["list"] = [1, 2, 3]
|
|
675
|
+
|
|
676
|
+
# Select one sample
|
|
677
|
+
selected_batch = batch.select_samples([0])
|
|
678
|
+
|
|
679
|
+
# All extra info is preserved
|
|
680
|
+
assert "tensor" in selected_batch.extra_info
|
|
681
|
+
assert selected_batch.extra_info["string"] == "test_string"
|
|
682
|
+
assert selected_batch.extra_info["number"] == 42
|
|
683
|
+
assert selected_batch.extra_info["list"] == [1, 2, 3]
|
|
684
|
+
|
|
538
685
|
def test_batch_meta_extra_info_operations(self):
|
|
539
686
|
"""Example: Extra info management operations."""
|
|
540
687
|
fields = {
|
|
@@ -541,3 +541,56 @@ def test_nested_jagged_tensor_serialization(enable_zero_copy):
|
|
|
541
541
|
# Verify individual components
|
|
542
542
|
for i in range(len(outer_td["nested_jagged1"].unbind())):
|
|
543
543
|
assert torch.allclose(decoded_msg.body["data"]["nested_jagged1"][i], outer_td["nested_jagged1"][i])
|
|
544
|
+
|
|
545
|
+
|
|
546
|
+
@pytest.mark.parametrize("enable_zero_copy", [True, False])
|
|
547
|
+
def test_single_nested_tensor_serialization(enable_zero_copy):
|
|
548
|
+
"""Test serialization of nested tensor with only one element (edge case for zero-copy)."""
|
|
549
|
+
with patch("transfer_queue.utils.zmq_utils.TQ_ZERO_COPY_SERIALIZATION", enable_zero_copy):
|
|
550
|
+
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType
|
|
551
|
+
|
|
552
|
+
# Create nested tensor with only one element
|
|
553
|
+
# This is the critical edge case where a nested tensor with 1 element
|
|
554
|
+
# must be distinguished from a regular tensor during deserialization
|
|
555
|
+
single_nested = torch.nested.as_nested_tensor([torch.randn(4, 3)], layout=torch.strided)
|
|
556
|
+
# For normal tensor, expand to batch_size=1 to match the nested tensor's batch dimension
|
|
557
|
+
normal_tensor = torch.randn(1, 4, 3)
|
|
558
|
+
|
|
559
|
+
# Create TensorDict with both types
|
|
560
|
+
td = TensorDict(
|
|
561
|
+
{
|
|
562
|
+
"single_nested_tensor": single_nested,
|
|
563
|
+
"normal_tensor": normal_tensor,
|
|
564
|
+
},
|
|
565
|
+
batch_size=1,
|
|
566
|
+
)
|
|
567
|
+
|
|
568
|
+
msg = ZMQMessage(
|
|
569
|
+
request_type=ZMQRequestType.PUT_DATA,
|
|
570
|
+
sender_id="test",
|
|
571
|
+
receiver_id="test",
|
|
572
|
+
request_id="test",
|
|
573
|
+
timestamp=0.0,
|
|
574
|
+
body={"data": td},
|
|
575
|
+
)
|
|
576
|
+
|
|
577
|
+
encoded_msg = msg.serialize()
|
|
578
|
+
decoded_msg = ZMQMessage.deserialize(encoded_msg)
|
|
579
|
+
|
|
580
|
+
# Verify batch sizes
|
|
581
|
+
assert decoded_msg.body["data"].batch_size == td.batch_size
|
|
582
|
+
|
|
583
|
+
# Verify normal tensor
|
|
584
|
+
assert torch.allclose(decoded_msg.body["data"]["normal_tensor"], td["normal_tensor"])
|
|
585
|
+
assert decoded_msg.body["data"]["normal_tensor"].shape == td["normal_tensor"].shape
|
|
586
|
+
|
|
587
|
+
# Verify single nested tensor is properly reconstructed as nested
|
|
588
|
+
assert decoded_msg.body["data"]["single_nested_tensor"].is_nested
|
|
589
|
+
assert decoded_msg.body["data"]["single_nested_tensor"].layout == torch.strided
|
|
590
|
+
assert len(decoded_msg.body["data"]["single_nested_tensor"].unbind()) == 1
|
|
591
|
+
assert torch.allclose(decoded_msg.body["data"]["single_nested_tensor"][0], td["single_nested_tensor"][0])
|
|
592
|
+
|
|
593
|
+
# Ensure the nested tensor with single element is correctly distinguished from regular tensor
|
|
594
|
+
# Both should have the same data but different types
|
|
595
|
+
assert not decoded_msg.body["data"]["normal_tensor"].is_nested
|
|
596
|
+
assert decoded_msg.body["data"]["single_nested_tensor"].is_nested
|
transfer_queue/controller.py
CHANGED
|
@@ -32,6 +32,7 @@ from transfer_queue.metadata import (
|
|
|
32
32
|
SampleMeta,
|
|
33
33
|
)
|
|
34
34
|
from transfer_queue.sampler import BaseSampler, SequentialSampler
|
|
35
|
+
from transfer_queue.utils.perf_utils import IntervalPerfMonitor
|
|
35
36
|
from transfer_queue.utils.utils import (
|
|
36
37
|
ProductionStatus,
|
|
37
38
|
TransferQueueRole,
|
|
@@ -584,7 +585,7 @@ class TransferQueueController:
|
|
|
584
585
|
|
|
585
586
|
self.partitions[partition_id] = DataPartitionStatus(partition_id=partition_id)
|
|
586
587
|
|
|
587
|
-
logger.info(f"Created partition {partition_id}
|
|
588
|
+
logger.info(f"Created partition {partition_id}")
|
|
588
589
|
return True
|
|
589
590
|
|
|
590
591
|
def get_partition(self, partition_id: str) -> Optional[DataPartitionStatus]:
|
|
@@ -1008,7 +1009,7 @@ class TransferQueueController:
|
|
|
1008
1009
|
poller = zmq.Poller()
|
|
1009
1010
|
poller.register(self.handshake_socket, zmq.POLLIN)
|
|
1010
1011
|
|
|
1011
|
-
logger.info(f"
|
|
1012
|
+
logger.info(f"Controller {self.controller_id} started waiting for storage connections...")
|
|
1012
1013
|
|
|
1013
1014
|
while True:
|
|
1014
1015
|
socks = dict(poller.poll(TQ_CONTROLLER_CONNECTION_CHECK_INTERVAL * 1000))
|
|
@@ -1036,23 +1037,23 @@ class TransferQueueController:
|
|
|
1036
1037
|
self._connected_storage_managers.add(storage_manager_id)
|
|
1037
1038
|
storage_manager_type = request_msg.body.get("storage_manager_type", "Unknown")
|
|
1038
1039
|
logger.info(
|
|
1039
|
-
f"
|
|
1040
|
+
f"Controller {self.controller_id} received handshake from "
|
|
1040
1041
|
f"storage manager {storage_manager_id} (type: {storage_manager_type}). "
|
|
1041
1042
|
f"Total connected: {len(self._connected_storage_managers)}"
|
|
1042
1043
|
)
|
|
1043
1044
|
else:
|
|
1044
1045
|
logger.debug(
|
|
1045
|
-
f"
|
|
1046
|
+
f"Controller {self.controller_id} received duplicate handshake from "
|
|
1046
1047
|
f"storage manager {storage_manager_id}. Resending ACK."
|
|
1047
1048
|
)
|
|
1048
1049
|
|
|
1049
1050
|
except Exception as e:
|
|
1050
|
-
logger.error(f"
|
|
1051
|
+
logger.error(f"Controller {self.controller_id} error processing handshake: {e}")
|
|
1051
1052
|
|
|
1052
1053
|
def _start_process_handshake(self):
|
|
1053
1054
|
"""Start the handshake process thread."""
|
|
1054
1055
|
self.wait_connection_thread = Thread(
|
|
1055
|
-
target=self._wait_connection, name="
|
|
1056
|
+
target=self._wait_connection, name="TransferQueueControllerWaitConnectionThread", daemon=True
|
|
1056
1057
|
)
|
|
1057
1058
|
self.wait_connection_thread.start()
|
|
1058
1059
|
|
|
@@ -1060,7 +1061,7 @@ class TransferQueueController:
|
|
|
1060
1061
|
"""Start the data status update processing thread."""
|
|
1061
1062
|
self.process_update_data_status_thread = Thread(
|
|
1062
1063
|
target=self._update_data_status,
|
|
1063
|
-
name="
|
|
1064
|
+
name="TransferQueueControllerProcessUpdateDataStatusThread",
|
|
1064
1065
|
daemon=True,
|
|
1065
1066
|
)
|
|
1066
1067
|
self.process_update_data_status_thread.start()
|
|
@@ -1068,12 +1069,17 @@ class TransferQueueController:
|
|
|
1068
1069
|
def _start_process_request(self):
|
|
1069
1070
|
"""Start the request processing thread."""
|
|
1070
1071
|
self.process_request_thread = Thread(
|
|
1071
|
-
target=self._process_request, name="
|
|
1072
|
+
target=self._process_request, name="TransferQueueControllerProcessRequestThread", daemon=True
|
|
1072
1073
|
)
|
|
1073
1074
|
self.process_request_thread.start()
|
|
1074
1075
|
|
|
1075
1076
|
def _process_request(self):
|
|
1076
1077
|
"""Main request processing loop - adapted for partition-based operations."""
|
|
1078
|
+
|
|
1079
|
+
logger.info(f"[{self.controller_id}]: start processing requests...")
|
|
1080
|
+
|
|
1081
|
+
perf_monitor = IntervalPerfMonitor(caller_name=self.controller_id)
|
|
1082
|
+
|
|
1077
1083
|
while True:
|
|
1078
1084
|
messages = self.request_handle_socket.recv_multipart()
|
|
1079
1085
|
identity = messages.pop(0)
|
|
@@ -1081,88 +1087,96 @@ class TransferQueueController:
|
|
|
1081
1087
|
request_msg = ZMQMessage.deserialize(serialized_msg)
|
|
1082
1088
|
|
|
1083
1089
|
if request_msg.request_type == ZMQRequestType.GET_META:
|
|
1084
|
-
|
|
1085
|
-
|
|
1086
|
-
|
|
1087
|
-
|
|
1088
|
-
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
|
|
1092
|
-
|
|
1093
|
-
|
|
1094
|
-
|
|
1095
|
-
response_msg = ZMQMessage.create(
|
|
1096
|
-
request_type=ZMQRequestType.GET_META_RESPONSE,
|
|
1097
|
-
sender_id=self.controller_id,
|
|
1098
|
-
receiver_id=request_msg.sender_id,
|
|
1099
|
-
body={"metadata": metadata},
|
|
1100
|
-
)
|
|
1101
|
-
|
|
1102
|
-
elif request_msg.request_type == ZMQRequestType.GET_CLEAR_META:
|
|
1103
|
-
params = request_msg.body
|
|
1104
|
-
partition_id = params["partition_id"]
|
|
1105
|
-
|
|
1106
|
-
metadata = self.get_metadata(
|
|
1107
|
-
data_fields=[],
|
|
1108
|
-
partition_id=partition_id,
|
|
1109
|
-
mode="insert",
|
|
1110
|
-
)
|
|
1111
|
-
response_msg = ZMQMessage.create(
|
|
1112
|
-
request_type=ZMQRequestType.GET_CLEAR_META_RESPONSE,
|
|
1113
|
-
sender_id=self.controller_id,
|
|
1114
|
-
receiver_id=request_msg.sender_id,
|
|
1115
|
-
body={"metadata": metadata},
|
|
1116
|
-
)
|
|
1117
|
-
elif request_msg.request_type == ZMQRequestType.CLEAR_META:
|
|
1118
|
-
params = request_msg.body
|
|
1119
|
-
partition_id = params["partition_id"]
|
|
1090
|
+
with perf_monitor.measure(op_type="GET_META"):
|
|
1091
|
+
params = request_msg.body
|
|
1092
|
+
|
|
1093
|
+
metadata = self.get_metadata(
|
|
1094
|
+
data_fields=params["data_fields"],
|
|
1095
|
+
batch_size=params["batch_size"],
|
|
1096
|
+
partition_id=params["partition_id"],
|
|
1097
|
+
mode=params.get("mode", "fetch"),
|
|
1098
|
+
task_name=params.get("task_name"),
|
|
1099
|
+
sampling_config=params.get("sampling_config"),
|
|
1100
|
+
)
|
|
1120
1101
|
|
|
1121
|
-
clear_success = self.clear(partition_id)
|
|
1122
|
-
if clear_success:
|
|
1123
1102
|
response_msg = ZMQMessage.create(
|
|
1124
|
-
request_type=ZMQRequestType.
|
|
1103
|
+
request_type=ZMQRequestType.GET_META_RESPONSE,
|
|
1125
1104
|
sender_id=self.controller_id,
|
|
1126
1105
|
receiver_id=request_msg.sender_id,
|
|
1127
|
-
body={"
|
|
1106
|
+
body={"metadata": metadata},
|
|
1107
|
+
)
|
|
1108
|
+
|
|
1109
|
+
elif request_msg.request_type == ZMQRequestType.GET_CLEAR_META:
|
|
1110
|
+
with perf_monitor.measure(op_type="GET_CLEAR_META"):
|
|
1111
|
+
params = request_msg.body
|
|
1112
|
+
partition_id = params["partition_id"]
|
|
1113
|
+
|
|
1114
|
+
metadata = self.get_metadata(
|
|
1115
|
+
data_fields=[],
|
|
1116
|
+
partition_id=partition_id,
|
|
1117
|
+
mode="insert",
|
|
1128
1118
|
)
|
|
1129
|
-
else:
|
|
1130
1119
|
response_msg = ZMQMessage.create(
|
|
1131
|
-
request_type=ZMQRequestType.
|
|
1120
|
+
request_type=ZMQRequestType.GET_CLEAR_META_RESPONSE,
|
|
1132
1121
|
sender_id=self.controller_id,
|
|
1133
1122
|
receiver_id=request_msg.sender_id,
|
|
1134
|
-
body={"
|
|
1123
|
+
body={"metadata": metadata},
|
|
1135
1124
|
)
|
|
1125
|
+
elif request_msg.request_type == ZMQRequestType.CLEAR_META:
|
|
1126
|
+
with perf_monitor.measure(op_type="CLEAR_META"):
|
|
1127
|
+
params = request_msg.body
|
|
1128
|
+
partition_id = params["partition_id"]
|
|
1129
|
+
|
|
1130
|
+
clear_success = self.clear(partition_id)
|
|
1131
|
+
if clear_success:
|
|
1132
|
+
response_msg = ZMQMessage.create(
|
|
1133
|
+
request_type=ZMQRequestType.CLEAR_META_RESPONSE,
|
|
1134
|
+
sender_id=self.controller_id,
|
|
1135
|
+
receiver_id=request_msg.sender_id,
|
|
1136
|
+
body={"message": f"Clear operation completed by controller {self.controller_id}"},
|
|
1137
|
+
)
|
|
1138
|
+
else:
|
|
1139
|
+
response_msg = ZMQMessage.create(
|
|
1140
|
+
request_type=ZMQRequestType.CLEAR_META_RESPONSE,
|
|
1141
|
+
sender_id=self.controller_id,
|
|
1142
|
+
receiver_id=request_msg.sender_id,
|
|
1143
|
+
body={"error": f"Clear operation failed for partition {partition_id}"},
|
|
1144
|
+
)
|
|
1136
1145
|
|
|
1137
1146
|
elif request_msg.request_type == ZMQRequestType.CHECK_CONSUMPTION:
|
|
1138
|
-
|
|
1139
|
-
|
|
1140
|
-
|
|
1141
|
-
|
|
1142
|
-
|
|
1143
|
-
|
|
1144
|
-
|
|
1145
|
-
|
|
1146
|
-
|
|
1147
|
-
|
|
1148
|
-
|
|
1149
|
-
|
|
1150
|
-
|
|
1151
|
-
|
|
1152
|
-
|
|
1153
|
-
|
|
1154
|
-
|
|
1155
|
-
|
|
1156
|
-
|
|
1157
|
-
|
|
1158
|
-
|
|
1159
|
-
|
|
1160
|
-
|
|
1161
|
-
|
|
1147
|
+
with perf_monitor.measure(op_type="CHECK_CONSUMPTION"):
|
|
1148
|
+
# Handle consumption status checks
|
|
1149
|
+
params = request_msg.body
|
|
1150
|
+
|
|
1151
|
+
consumption_status = self.get_consumption_status(params["partition_id"], params["task_name"])
|
|
1152
|
+
sample_filter = params.get("sample_filter")
|
|
1153
|
+
|
|
1154
|
+
if consumption_status is not None and sample_filter:
|
|
1155
|
+
batch_status = consumption_status[sample_filter]
|
|
1156
|
+
consumed = torch.all(batch_status == 1).item()
|
|
1157
|
+
elif consumption_status is not None:
|
|
1158
|
+
batch_status = consumption_status
|
|
1159
|
+
consumed = torch.all(batch_status == 1).item()
|
|
1160
|
+
else:
|
|
1161
|
+
consumed = False
|
|
1162
|
+
|
|
1163
|
+
response_msg = ZMQMessage.create(
|
|
1164
|
+
request_type=ZMQRequestType.CONSUMPTION_RESPONSE,
|
|
1165
|
+
sender_id=self.controller_id,
|
|
1166
|
+
receiver_id=request_msg.sender_id,
|
|
1167
|
+
body={
|
|
1168
|
+
"partition_id": params["partition_id"],
|
|
1169
|
+
"consumed": consumed,
|
|
1170
|
+
},
|
|
1171
|
+
)
|
|
1162
1172
|
self.request_handle_socket.send_multipart([identity, *response_msg.serialize()])
|
|
1163
1173
|
|
|
1164
1174
|
def _update_data_status(self):
|
|
1165
1175
|
"""Process data status update messages from storage units - adapted for partitions."""
|
|
1176
|
+
logger.info(f"[{self.controller_id}]: start receiving update_data_status requests...")
|
|
1177
|
+
|
|
1178
|
+
perf_monitor = IntervalPerfMonitor(caller_name=self.controller_id)
|
|
1179
|
+
|
|
1166
1180
|
while True:
|
|
1167
1181
|
messages = self.data_status_update_socket.recv_multipart()
|
|
1168
1182
|
identity = messages.pop(0)
|
|
@@ -1170,32 +1184,33 @@ class TransferQueueController:
|
|
|
1170
1184
|
request_msg = ZMQMessage.deserialize(serialized_msg)
|
|
1171
1185
|
|
|
1172
1186
|
if request_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE:
|
|
1173
|
-
|
|
1174
|
-
|
|
1175
|
-
|
|
1176
|
-
|
|
1177
|
-
|
|
1178
|
-
|
|
1179
|
-
|
|
1180
|
-
|
|
1181
|
-
|
|
1182
|
-
|
|
1183
|
-
|
|
1187
|
+
with perf_monitor.measure(op_type="NOTIFY_DATA_UPDATE"):
|
|
1188
|
+
message_data = request_msg.body
|
|
1189
|
+
partition_id = message_data.get("partition_id")
|
|
1190
|
+
|
|
1191
|
+
# Update production status
|
|
1192
|
+
success = self.update_production_status(
|
|
1193
|
+
partition_id=partition_id,
|
|
1194
|
+
global_indexes=message_data.get("global_indexes", []),
|
|
1195
|
+
field_names=message_data.get("fields", []),
|
|
1196
|
+
dtypes=message_data.get("dtypes", {}),
|
|
1197
|
+
shapes=message_data.get("shapes", {}),
|
|
1198
|
+
)
|
|
1184
1199
|
|
|
1185
|
-
|
|
1186
|
-
|
|
1187
|
-
|
|
1188
|
-
|
|
1189
|
-
|
|
1190
|
-
|
|
1191
|
-
|
|
1192
|
-
|
|
1193
|
-
|
|
1194
|
-
|
|
1195
|
-
|
|
1196
|
-
|
|
1197
|
-
|
|
1198
|
-
|
|
1200
|
+
if success:
|
|
1201
|
+
logger.info(f"Updated production status for partition {partition_id}")
|
|
1202
|
+
|
|
1203
|
+
# Send acknowledgment
|
|
1204
|
+
response_msg = ZMQMessage.create(
|
|
1205
|
+
request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK,
|
|
1206
|
+
sender_id=self.controller_id,
|
|
1207
|
+
body={
|
|
1208
|
+
"controller_id": self.controller_id,
|
|
1209
|
+
"partition_id": partition_id,
|
|
1210
|
+
"success": success,
|
|
1211
|
+
},
|
|
1212
|
+
)
|
|
1213
|
+
self.data_status_update_socket.send_multipart([identity, *response_msg.serialize()])
|
|
1199
1214
|
|
|
1200
1215
|
def get_zmq_server_info(self) -> ZMQServerInfo:
|
|
1201
1216
|
"""Get ZMQ server connection information."""
|
transfer_queue/metadata.py
CHANGED
|
@@ -261,6 +261,28 @@ class BatchMeta:
|
|
|
261
261
|
object.__setattr__(self, "_is_ready", all(sample.is_ready for sample in self.samples))
|
|
262
262
|
return self
|
|
263
263
|
|
|
264
|
+
def select_samples(self, sample_indices: list[int]) -> "BatchMeta":
|
|
265
|
+
"""
|
|
266
|
+
Select specific samples from this batch.
|
|
267
|
+
This will construct a new BatchMeta instance containing only the specified samples.
|
|
268
|
+
|
|
269
|
+
Args:
|
|
270
|
+
sample_indices (list[int]): List of sample indices to retain.
|
|
271
|
+
|
|
272
|
+
Returns:
|
|
273
|
+
BatchMeta: A new BatchMeta instance containing only the specified samples.
|
|
274
|
+
"""
|
|
275
|
+
|
|
276
|
+
if any(i < 0 or i >= len(self.samples) for i in sample_indices):
|
|
277
|
+
raise ValueError(f"Sample indices must be in range [0, {len(self.samples)})")
|
|
278
|
+
|
|
279
|
+
selected_samples = [self.samples[i] for i in sample_indices]
|
|
280
|
+
|
|
281
|
+
# construct new BatchMeta instance
|
|
282
|
+
selected_batch_meta = BatchMeta(samples=selected_samples, extra_info=self.extra_info.copy())
|
|
283
|
+
|
|
284
|
+
return selected_batch_meta
|
|
285
|
+
|
|
264
286
|
def select_fields(self, field_names: list[str]) -> "BatchMeta":
|
|
265
287
|
"""
|
|
266
288
|
Select specific fields from all samples in this batch.
|
|
@@ -287,7 +309,7 @@ class BatchMeta:
|
|
|
287
309
|
def __getitem__(self, item):
|
|
288
310
|
if isinstance(item, int | np.integer):
|
|
289
311
|
sample_meta = self.samples[item] if self.samples else []
|
|
290
|
-
return BatchMeta(samples=[sample_meta], extra_info=self.extra_info)
|
|
312
|
+
return BatchMeta(samples=[sample_meta], extra_info=self.extra_info.copy())
|
|
291
313
|
else:
|
|
292
314
|
raise TypeError(f"Indexing with {type(item)} is not supported now!")
|
|
293
315
|
|
|
@@ -508,6 +530,13 @@ class BatchMeta:
|
|
|
508
530
|
extra_info = {}
|
|
509
531
|
return cls(samples=[], extra_info=extra_info)
|
|
510
532
|
|
|
533
|
+
def __str__(self):
|
|
534
|
+
sample_strs = ", ".join(str(sample) for sample in self.samples)
|
|
535
|
+
return (
|
|
536
|
+
f"BatchMeta(size={self.size}, field_names={self.field_names}, is_ready={self.is_ready}, "
|
|
537
|
+
f"samples=[{sample_strs}], extra_info={self.extra_info})"
|
|
538
|
+
)
|
|
539
|
+
|
|
511
540
|
|
|
512
541
|
def _union_fields(fields1: dict[str, FieldMeta], fields2: dict[str, FieldMeta]) -> dict[str, FieldMeta]:
|
|
513
542
|
"""Union two sample's fields. If fields overlap, the fields in fields1 will be replaced by fields2."""
|
|
@@ -173,6 +173,8 @@ class AsyncSimpleStorageManager(TransferQueueStorageManager):
|
|
|
173
173
|
metadata: BatchMeta containing storage location information.
|
|
174
174
|
"""
|
|
175
175
|
|
|
176
|
+
logger.info(f"{__class__.__name__}: receive put_data request, putting {metadata.size} samples.")
|
|
177
|
+
|
|
176
178
|
# group samples by storage unit
|
|
177
179
|
storage_meta_groups = build_storage_meta_groups(
|
|
178
180
|
metadata, self.global_index_storage_unit_mapping, self.global_index_local_index_mapping
|
|
@@ -228,7 +230,8 @@ class AsyncSimpleStorageManager(TransferQueueStorageManager):
|
|
|
228
230
|
else NonTensorStack(*transfer_data["field_data"][field])
|
|
229
231
|
)
|
|
230
232
|
for field in transfer_data["field_data"]
|
|
231
|
-
}
|
|
233
|
+
},
|
|
234
|
+
batch_size=len(local_indexes),
|
|
232
235
|
)
|
|
233
236
|
|
|
234
237
|
request_msg = ZMQMessage.create(
|
|
@@ -263,6 +266,8 @@ class AsyncSimpleStorageManager(TransferQueueStorageManager):
|
|
|
263
266
|
TensorDict containing the retrieved data.
|
|
264
267
|
"""
|
|
265
268
|
|
|
269
|
+
logger.info(f"{__class__.__name__}: receive get_data request, getting {metadata.size} samples.")
|
|
270
|
+
|
|
266
271
|
# group samples by storage unit
|
|
267
272
|
storage_meta_groups = build_storage_meta_groups(
|
|
268
273
|
metadata, self.global_index_storage_unit_mapping, self.global_index_local_index_mapping
|
|
@@ -28,14 +28,14 @@ from ray.util import get_node_ip_address
|
|
|
28
28
|
from tensordict import NonTensorStack, TensorDict
|
|
29
29
|
|
|
30
30
|
from transfer_queue.metadata import SampleMeta
|
|
31
|
+
from transfer_queue.utils.perf_utils import IntervalPerfMonitor
|
|
31
32
|
from transfer_queue.utils.utils import TransferQueueRole
|
|
32
33
|
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo, create_zmq_socket, get_free_port
|
|
33
34
|
|
|
34
35
|
logger = logging.getLogger(__name__)
|
|
35
36
|
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
|
|
36
37
|
|
|
37
|
-
|
|
38
|
-
TQ_STORAGE_POLLER_TIMEOUT = int(os.environ.get("TQ_STORAGE_POLLER_TIMEOUT", 5))
|
|
38
|
+
TQ_STORAGE_POLLER_TIMEOUT = int(os.environ.get("TQ_STORAGE_POLLER_TIMEOUT", 5)) # in seconds
|
|
39
39
|
|
|
40
40
|
|
|
41
41
|
class StorageUnitData:
|
|
@@ -200,7 +200,7 @@ class SimpleStorageUnit:
|
|
|
200
200
|
def _start_process_put_get(self) -> None:
|
|
201
201
|
"""Create a daemon thread and start put/get process."""
|
|
202
202
|
self.process_put_get_thread = Thread(
|
|
203
|
-
target=self._process_put_get, name=f"StorageUnitProcessPutGetThread-{self.
|
|
203
|
+
target=self._process_put_get, name=f"StorageUnitProcessPutGetThread-{self.storage_unit_id}", daemon=True
|
|
204
204
|
)
|
|
205
205
|
self.process_put_get_thread.start()
|
|
206
206
|
|
|
@@ -209,6 +209,10 @@ class SimpleStorageUnit:
|
|
|
209
209
|
poller = zmq.Poller()
|
|
210
210
|
poller.register(self.put_get_socket, zmq.POLLIN)
|
|
211
211
|
|
|
212
|
+
logger.info(f"[{self.storage_unit_id}]: start processing put/get requests...")
|
|
213
|
+
|
|
214
|
+
perf_monitor = IntervalPerfMonitor(caller_name=self.storage_unit_id)
|
|
215
|
+
|
|
212
216
|
while True:
|
|
213
217
|
socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT * 1000))
|
|
214
218
|
|
|
@@ -219,29 +223,32 @@ class SimpleStorageUnit:
|
|
|
219
223
|
request_msg = ZMQMessage.deserialize(serialized_msg)
|
|
220
224
|
operation = request_msg.request_type
|
|
221
225
|
try:
|
|
222
|
-
logger.debug(f"[{self.
|
|
226
|
+
logger.debug(f"[{self.storage_unit_id}]: receive operation: {operation}, message: {request_msg}")
|
|
223
227
|
|
|
224
228
|
if operation == ZMQRequestType.PUT_DATA:
|
|
225
|
-
|
|
229
|
+
with perf_monitor.measure(op_type="PUT_DATA"):
|
|
230
|
+
response_msg = self._handle_put(request_msg)
|
|
226
231
|
elif operation == ZMQRequestType.GET_DATA:
|
|
227
|
-
|
|
232
|
+
with perf_monitor.measure(op_type="GET_DATA"):
|
|
233
|
+
response_msg = self._handle_get(request_msg)
|
|
228
234
|
elif operation == ZMQRequestType.CLEAR_DATA:
|
|
229
|
-
|
|
235
|
+
with perf_monitor.measure(op_type="CLEAR_DATA"):
|
|
236
|
+
response_msg = self._handle_clear(request_msg)
|
|
230
237
|
else:
|
|
231
238
|
response_msg = ZMQMessage.create(
|
|
232
239
|
request_type=ZMQRequestType.PUT_GET_OPERATION_ERROR,
|
|
233
|
-
sender_id=self.
|
|
240
|
+
sender_id=self.storage_unit_id,
|
|
234
241
|
body={
|
|
235
|
-
"message": f"Storage unit id #{self.
|
|
242
|
+
"message": f"Storage unit id #{self.storage_unit_id} "
|
|
236
243
|
f"receive invalid operation: {operation}."
|
|
237
244
|
},
|
|
238
245
|
)
|
|
239
246
|
except Exception as e:
|
|
240
247
|
response_msg = ZMQMessage.create(
|
|
241
248
|
request_type=ZMQRequestType.PUT_GET_ERROR,
|
|
242
|
-
sender_id=self.
|
|
249
|
+
sender_id=self.storage_unit_id,
|
|
243
250
|
body={
|
|
244
|
-
"message": f"Storage unit id #{self.
|
|
251
|
+
"message": f"Storage unit id #{self.storage_unit_id} occur error in processing "
|
|
245
252
|
f"put/get/clear request, detail error message: {str(e)}."
|
|
246
253
|
},
|
|
247
254
|
)
|
|
@@ -268,17 +275,17 @@ class SimpleStorageUnit:
|
|
|
268
275
|
|
|
269
276
|
# After put operation finish, send a message to the client
|
|
270
277
|
response_msg = ZMQMessage.create(
|
|
271
|
-
request_type=ZMQRequestType.PUT_DATA_RESPONSE, sender_id=self.
|
|
278
|
+
request_type=ZMQRequestType.PUT_DATA_RESPONSE, sender_id=self.storage_unit_id, body={}
|
|
272
279
|
)
|
|
273
280
|
|
|
274
281
|
return response_msg
|
|
275
282
|
except Exception as e:
|
|
276
283
|
return ZMQMessage.create(
|
|
277
284
|
request_type=ZMQRequestType.PUT_ERROR,
|
|
278
|
-
sender_id=self.
|
|
285
|
+
sender_id=self.storage_unit_id,
|
|
279
286
|
body={
|
|
280
287
|
"message": f"Failed to put data into storage unit id "
|
|
281
|
-
f"#{self.
|
|
288
|
+
f"#{self.storage_unit_id}, detail error message: {str(e)}"
|
|
282
289
|
},
|
|
283
290
|
)
|
|
284
291
|
|
|
@@ -300,7 +307,7 @@ class SimpleStorageUnit:
|
|
|
300
307
|
|
|
301
308
|
response_msg = ZMQMessage.create(
|
|
302
309
|
request_type=ZMQRequestType.GET_DATA_RESPONSE,
|
|
303
|
-
sender_id=self.
|
|
310
|
+
sender_id=self.storage_unit_id,
|
|
304
311
|
body={
|
|
305
312
|
"data": result_data,
|
|
306
313
|
},
|
|
@@ -308,9 +315,9 @@ class SimpleStorageUnit:
|
|
|
308
315
|
except Exception as e:
|
|
309
316
|
response_msg = ZMQMessage.create(
|
|
310
317
|
request_type=ZMQRequestType.GET_ERROR,
|
|
311
|
-
sender_id=self.
|
|
318
|
+
sender_id=self.storage_unit_id,
|
|
312
319
|
body={
|
|
313
|
-
"message": f"Failed to get data from storage unit id #{self.
|
|
320
|
+
"message": f"Failed to get data from storage unit id #{self.storage_unit_id}, "
|
|
314
321
|
f"detail error message: {str(e)}"
|
|
315
322
|
},
|
|
316
323
|
)
|
|
@@ -333,15 +340,15 @@ class SimpleStorageUnit:
|
|
|
333
340
|
|
|
334
341
|
response_msg = ZMQMessage.create(
|
|
335
342
|
request_type=ZMQRequestType.CLEAR_DATA_RESPONSE,
|
|
336
|
-
sender_id=self.
|
|
337
|
-
body={"message": f"Clear data in storage unit id #{self.
|
|
343
|
+
sender_id=self.storage_unit_id,
|
|
344
|
+
body={"message": f"Clear data in storage unit id #{self.storage_unit_id} successfully."},
|
|
338
345
|
)
|
|
339
346
|
except Exception as e:
|
|
340
347
|
response_msg = ZMQMessage.create(
|
|
341
348
|
request_type=ZMQRequestType.CLEAR_DATA_ERROR,
|
|
342
|
-
sender_id=self.
|
|
349
|
+
sender_id=self.storage_unit_id,
|
|
343
350
|
body={
|
|
344
|
-
"message": f"Failed to clear data in storage unit id #{self.
|
|
351
|
+
"message": f"Failed to clear data in storage unit id #{self.storage_unit_id}, "
|
|
345
352
|
f"detail error message: {str(e)}"
|
|
346
353
|
},
|
|
347
354
|
)
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
import time
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
from contextlib import contextmanager
|
|
6
|
+
|
|
7
|
+
# Module logger; its verbosity can be overridden via the TQ_LOGGING_LEVEL environment variable.
logger = logging.getLogger(__name__)
logger.setLevel(os.environ.get("TQ_LOGGING_LEVEL", logging.INFO))

# Minimum number of seconds between two aggregated performance log flushes
# (read once at import time from TQ_PERF_LOG_FLUSH_INTERVAL, default 10).
TQ_PERF_LOG_FLUSH_INTERVAL = float(os.getenv("TQ_PERF_LOG_FLUSH_INTERVAL", 10))  # in seconds
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class IntervalPerfMonitor:
    """
    Monitors and logs performance statistics for operations over configurable time intervals.

    This class tracks the number of successfully completed operations and their processing
    times per operation type, and periodically logs summary statistics such as request counts,
    rates, and timing metrics (average, max, min) per operation type.

    Usage:
        monitor = IntervalPerfMonitor("Your Class")
        with monitor.measure("method_name"):
            ...  # perform the operation

    Each time a measured operation completes successfully, the monitor checks whether at least
    TQ_PERF_LOG_FLUSH_INTERVAL seconds have elapsed since the last flush; if so it logs the
    aggregated statistics at INFO level and resets its counters. Operations whose body raises
    are neither counted nor timed, and the exception propagates unchanged.

    Args:
        caller_name (str): Name of the component or caller using the monitor, included in logs.
    """

    def __init__(self, caller_name: str):
        self.caller_name = caller_name
        self.last_flush_time = time.perf_counter()

        # op_type -> number of successfully completed operations in the current interval
        self.success_counts: dict[str, int] = defaultdict(int)
        # op_type -> durations in seconds (perf_counter deltas) of those operations
        self.process_time: dict[str, list[float]] = defaultdict(list)

    def _flush_logs(self):
        """
        Conditionally log the aggregated performance statistics and reset them.

        If at least TQ_PERF_LOG_FLUSH_INTERVAL seconds have passed since the last flush, this
        logs (at INFO level) the total number of successful requests, requests per minute, the
        average processing time across all operations, and per-operation count, rate, and
        average/max/min processing time. All statistics are then cleared and the flush timer
        is restarted. Otherwise this is a no-op.
        """
        now = time.perf_counter()
        elapsed = now - self.last_flush_time

        # only flush if the interval has passed
        if elapsed < TQ_PERF_LOG_FLUSH_INTERVAL:
            return

        # Building the message walks every recorded duration; skip that work when INFO
        # records would be discarded anyway. Counters are reset either way.
        if logger.isEnabledFor(logging.INFO):
            logger.info(self._format_stats(elapsed / 60))

        # reset counts
        self.success_counts.clear()
        self.process_time.clear()
        self.last_flush_time = now

    def _format_stats(self, minutes: float) -> str:
        """
        Build the performance summary message for a window lasting ``minutes`` minutes.

        Args:
            minutes (float): Length of the aggregation window in minutes. A non-positive value
                yields request rates of 0.0 instead of raising ZeroDivisionError.

        Returns:
            str: The formatted, human-readable statistics message.
        """

        def per_minute(count: int) -> float:
            # A zero-length window is possible when TQ_PERF_LOG_FLUSH_INTERVAL <= 0 and two
            # perf_counter() readings coincide; report 0.0 instead of crashing the caller.
            return count / minutes if minutes > 0 else 0.0

        total_requests = sum(self.success_counts.values())
        total_process_time = sum(sum(time_list) for time_list in self.process_time.values())
        total_avg_process_time = total_process_time / total_requests if total_requests > 0 else 0.0

        # max/min/avg time for each operation type
        op_detail_stats = []
        for op_type, count in self.success_counts.items():
            times = self.process_time.get(op_type, [])
            if times:
                op_avg = sum(times) / len(times)
                op_max = max(times)
                op_min = min(times)
            else:
                op_avg = op_max = op_min = 0.0

            op_detail_stats.append(
                f"{op_type}: req_count={count}, req/min={per_minute(count):.2f}, "
                f"avg_time={op_avg:.6f}s, max_time={op_max:.6f}s, min_time={op_min:.6f}s"
            )

        return (
            f"{self.caller_name}: [Performance] "
            f"Total success requests: {total_requests}, "
            f"Total req/min: {per_minute(total_requests):.2f}, "
            f"Total avg process time: {total_avg_process_time:.4f}s; \n"
            f"Time range: last {minutes:.2f} minutes; \n"
            f"Per-operation statistics: {'; '.join(op_detail_stats)}"
        )

    @contextmanager
    def measure(self, op_type: str):
        """
        Time the enclosed block and record it under ``op_type`` if it completes without raising.

        Args:
            op_type (str): Label the measurement is aggregated under (e.g. "PUT_DATA").

        Raises:
            Any exception raised inside the ``with`` block is re-raised unchanged; such failed
            operations are not counted as successes and their duration is not recorded.
        """
        start_time = time.perf_counter()
        # If the body raises, contextmanager re-throws the exception at this ``yield`` and the
        # bookkeeping below is skipped, so failures are never reported as successes.
        yield
        cost = time.perf_counter() - start_time
        self.success_counts[op_type] += 1
        self.process_time[op_type].append(cost)

        # try flush logs
        self._flush_logs()
|
|
@@ -22,6 +22,7 @@ from dataclasses import dataclass
|
|
|
22
22
|
from typing import Any, Optional, TypeAlias
|
|
23
23
|
from uuid import uuid4
|
|
24
24
|
|
|
25
|
+
import numpy as np
|
|
25
26
|
import psutil
|
|
26
27
|
import torch
|
|
27
28
|
import zmq
|
|
@@ -162,15 +163,15 @@ class ZMQMessage:
|
|
|
162
163
|
tensor_list = tensor.unbind()
|
|
163
164
|
tensor_count = len(tensor_list)
|
|
164
165
|
serialized_tensors = [_encoder.encode(inner_tensor) for inner_tensor in tensor_list]
|
|
165
|
-
return tensor_count, serialized_tensors
|
|
166
|
+
return tensor_count, serialized_tensors # tensor_count may equal to 1 for single nested tensor
|
|
166
167
|
else:
|
|
167
|
-
return 1, [_encoder.encode(tensor)]
|
|
168
|
+
return -1, [_encoder.encode(tensor)] # use -1 to indicate regular single tensor
|
|
168
169
|
|
|
169
170
|
# Use map to process all tensors in parallel-like fashion
|
|
170
171
|
nested_tensor_info_and_serialized_tensors = list(map(process_tensor, tensors))
|
|
171
172
|
|
|
172
173
|
# Extract nested_tensor_info and flatten serialized tensors using itertools
|
|
173
|
-
nested_tensor_info = [info for info, _ in nested_tensor_info_and_serialized_tensors]
|
|
174
|
+
nested_tensor_info = np.array([info for info, _ in nested_tensor_info_and_serialized_tensors])
|
|
174
175
|
double_layer_serialized_tensors: list[list[bytestr]] = list(
|
|
175
176
|
itertools.chain.from_iterable(serialized for _, serialized in nested_tensor_info_and_serialized_tensors)
|
|
176
177
|
)
|
|
@@ -209,14 +210,14 @@ class ZMQMessage:
|
|
|
209
210
|
f"When TQ_ZERO_COPY_SERIALIZATION is enabled, input data should be a list, but got {type(data)}."
|
|
210
211
|
)
|
|
211
212
|
|
|
212
|
-
tensor_nums =
|
|
213
|
+
tensor_nums = np.abs(nested_tensor_info).sum()
|
|
213
214
|
if tensor_nums != len(single_tensors):
|
|
214
215
|
raise ValueError(f"Expecting {tensor_nums} tensors, but got {len(single_tensors)}.")
|
|
215
216
|
|
|
216
217
|
tensors = [None] * len(nested_tensor_info)
|
|
217
218
|
current_idx = 0
|
|
218
219
|
for i, tensor_num in enumerate(nested_tensor_info):
|
|
219
|
-
if tensor_num == 1:
|
|
220
|
+
if tensor_num == -1:
|
|
220
221
|
tensors[i] = single_tensors[current_idx]
|
|
221
222
|
current_idx += 1
|
|
222
223
|
else:
|
transfer_queue/version/version
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
0.1.4.
|
|
1
|
+
0.1.4.dev1
|
|
@@ -1,6 +1,7 @@
|
|
|
1
|
-
performance_test.py,sha256=
|
|
1
|
+
performance_test.py,sha256=Yl4k-ln5iHAJ5HIfK--Lx-tlEXQcI71Juzve-lBesUI,14434
|
|
2
2
|
serial_profiling_demo.py,sha256=b4GEoF8bSIawQlWOIK2eg9Tgn_-Q8n_KMsfRHLdK1mc,3961
|
|
3
3
|
serial_profiling_demo_nested_non_continues_test.py,sha256=ahXMGDsRc_bwE3lbVE6L-8_guezv6rpG24pKKIxVKF8,3849
|
|
4
|
+
verify_fix.py,sha256=IcZ52jbYoxFb4MRwZMx0f52E8ATFn7GbUp8fmVY31RU,3522
|
|
4
5
|
recipe/simple_use_case/async_demo.py,sha256=wGQCnXAzElE1D6-POwJVhR4UbT3kEJdCnnF-ZL5jPKg,13709
|
|
5
6
|
recipe/simple_use_case/sync_demo.py,sha256=Rfatf205t-gHaxShk2n8LavMVoMLWZpguHRKpKgMDNo,8680
|
|
6
7
|
tests/test_async_simple_storage_manager.py,sha256=qYmSJV6LYSXAzjnlzWZ5HboOhr4B1gMgNwOc5_esVaU,12366
|
|
@@ -8,21 +9,21 @@ tests/test_client.py,sha256=Wj1Eswt9qVfL0-192Mwx5_ICaQO7iso0WMJxKpGEvFg,15546
|
|
|
8
9
|
tests/test_controller.py,sha256=ZcvFCC3jSnNN_fEerjA37RQv0SSO0Xh8vjcL2mvF03o,11084
|
|
9
10
|
tests/test_controller_data_partitions.py,sha256=RQExMFcuXblpUZE_LaFaoncbwX4-YlPcBi69siRfvzY,19363
|
|
10
11
|
tests/test_kv_storage_manager.py,sha256=j45VZ14H8MY8-c4CJI1MRgljoHijaK-tkQAD_-9lmzw,4012
|
|
11
|
-
tests/test_metadata.py,sha256=
|
|
12
|
+
tests/test_metadata.py,sha256=q0X8UuxTmx-JjZlXR8nqr7cm2YbGFHWNbO9kZHgFFYU,32415
|
|
12
13
|
tests/test_samplers.py,sha256=CvYqfmbHEWWa1RyymztCAn0GcitAPOBbfJ4ud1VvO2o,19168
|
|
13
|
-
tests/test_serial_utils_on_cpu.py,sha256=
|
|
14
|
+
tests/test_serial_utils_on_cpu.py,sha256=iZII_-oVBu3KQ8Afpf7roqET9mIMegrsQ7cwg1XPXQo,23595
|
|
14
15
|
tests/test_simple_storage_unit.py,sha256=29mrQwIkS63D6-b1lNZRhUlZ2nkmjpXtQVGHPvYq_ug,16595
|
|
15
16
|
tests/test_storage_client_factory.py,sha256=lZr7SRY4rpzQB-ZgG7gbjPF2Pcde55nwweumSJT7Yd0,2363
|
|
16
17
|
transfer_queue/__init__.py,sha256=68c0sBfqHPqTa7OdzO4sAZB52XvwtjpwLqP9BWAh4fA,1535
|
|
17
18
|
transfer_queue/client.py,sha256=vH9stFyDCXtLYujdNVYjz815NjXbSUJOmxcFOZLIU1s,25831
|
|
18
|
-
transfer_queue/controller.py,sha256=
|
|
19
|
-
transfer_queue/metadata.py,sha256=
|
|
19
|
+
transfer_queue/controller.py,sha256=uIhgSEdHQDfmQicqdPXnyXQuJ6Cn5AH3Wyue0SlYdrY,50853
|
|
20
|
+
transfer_queue/metadata.py,sha256=W71tN_-AVfizFDFM-2mIWt1RYvIkN9J1pwgUwBeDPpo,22664
|
|
20
21
|
transfer_queue/sampler/__init__.py,sha256=1oauDy2Dwb5GXhKi7tl5DWAHv8i4t2MQK1S4U36Sy4g,788
|
|
21
22
|
transfer_queue/sampler/base.py,sha256=wFti4dNJb3YArYpGzxA_YDfyUTdTG8wVz6HclPDyZPw,3299
|
|
22
23
|
transfer_queue/sampler/grpo_group_n_sampler.py,sha256=Kq3hGAz8mboBNvw4Dj0P8lP6Qs8TDojx81fxSh57w28,6566
|
|
23
24
|
transfer_queue/sampler/sequential_sampler.py,sha256=TY0eB-uFLUskwoNMgu3AvuF4G2KDkgjOkrlXZHy4Pls,2780
|
|
24
25
|
transfer_queue/storage/__init__.py,sha256=559q9ZOMLLhHXil5-iY3aLPnACoJLnZnKf-E0lvpQdk,978
|
|
25
|
-
transfer_queue/storage/simple_backend.py,sha256=
|
|
26
|
+
transfer_queue/storage/simple_backend.py,sha256=3bGqalWk0FX7SYjSxEWQeH1j-XAl_6heFWJvRbTgaII,19358
|
|
26
27
|
transfer_queue/storage/clients/__init__.py,sha256=WCa6pcijAixpopvelkZZ9ZRTwF_P3fYMmSEBb04CQZ4,915
|
|
27
28
|
transfer_queue/storage/clients/base.py,sha256=xXd9JBeTmW8tN4wsPocHhW-ERUEzx2YyYHZrtuQQIdI,690
|
|
28
29
|
transfer_queue/storage/clients/factory.py,sha256=lPOG8oMAgaTbrzkogcOULPJnGywa0F-m4vskkOQZhnU,2137
|
|
@@ -30,15 +31,16 @@ transfer_queue/storage/clients/yuanrong_client.py,sha256=MskYioa0BHHGRqzYbAg9DCn
|
|
|
30
31
|
transfer_queue/storage/managers/__init__.py,sha256=bkgaIN4Xa3IF26JJt4BK4bqct0SESdwA32wX5SLnuY0,959
|
|
31
32
|
transfer_queue/storage/managers/base.py,sha256=iSBethCS-pq0tHcRsolXtGD0V_0PtBmInNF4Gi-flfw,21628
|
|
32
33
|
transfer_queue/storage/managers/factory.py,sha256=58kp2mCKz1K8Ea7RWMsWxdDhN3y4ZhgE-G647AKq7-I,1752
|
|
33
|
-
transfer_queue/storage/managers/simple_backend_manager.py,sha256=
|
|
34
|
+
transfer_queue/storage/managers/simple_backend_manager.py,sha256=8F7fW7Z3SDeR22KCZqnFWPyC4qil5w_sp7QXHfvJsfY,28015
|
|
34
35
|
transfer_queue/storage/managers/yuanrong_manager.py,sha256=RsCmVVDNcTaMU9J9vwN1gu0-srdBCWG-W3Q_Si91uio,1250
|
|
35
36
|
transfer_queue/utils/__init__.py,sha256=vki-5RVaRBKxVc6Q7XPQox3VNPio2DvJYvRz0SZtu-w,586
|
|
37
|
+
transfer_queue/utils/perf_utils.py,sha256=WUl8AW9eHS5P9G3zq8g52MgMsyZqTqFXuBkXyBLmBLc,4100
|
|
36
38
|
transfer_queue/utils/serial_utils.py,sha256=J94wrNKVEJtZg22o7GByMs9e_UuwOgRqt1faC5Sy7DY,6048
|
|
37
39
|
transfer_queue/utils/utils.py,sha256=EE5S8YtyLNduohj1egKLHQlG4K2nrN-yAa8klBx9Nro,4846
|
|
38
|
-
transfer_queue/utils/zmq_utils.py,sha256=
|
|
39
|
-
transfer_queue/version/version,sha256=
|
|
40
|
-
transferqueue-0.1.4.
|
|
41
|
-
transferqueue-0.1.4.
|
|
42
|
-
transferqueue-0.1.4.
|
|
43
|
-
transferqueue-0.1.4.
|
|
44
|
-
transferqueue-0.1.4.
|
|
40
|
+
transfer_queue/utils/zmq_utils.py,sha256=ecJO1GV_AEAAKnnts-0t7jl19j9jUpBT07N6ZSR8op0,10730
|
|
41
|
+
transfer_queue/version/version,sha256=BEnC3jt-HrAwaHIIQhet48H4zzl05lM_-XlEH_IAuRc,11
|
|
42
|
+
transferqueue-0.1.4.dev1.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
|
|
43
|
+
transferqueue-0.1.4.dev1.dist-info/METADATA,sha256=pUiuMKBnGXQUsrtba_G1VAllyhJr64aCOuxyt3UMyQo,19502
|
|
44
|
+
transferqueue-0.1.4.dev1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
45
|
+
transferqueue-0.1.4.dev1.dist-info/top_level.txt,sha256=4MQO9VzdR-IUYG4xAidtwDNiWECIQZ_zx0G5KflYJkE,131
|
|
46
|
+
transferqueue-0.1.4.dev1.dist-info/RECORD,,
|
verify_fix.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
"""
|
|
3
|
+
验证脚本:测试单元素nested tensor的序列化/反序列化修复
|
|
4
|
+
|
|
5
|
+
此脚本验证了在TQ_ZERO_COPY_SERIALIZATION=True时,
|
|
6
|
+
序列化只有1个tensor的nested tensor能够正确区分于普通tensor。
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
import torch
|
|
11
|
+
from tensordict import TensorDict
|
|
12
|
+
|
|
13
|
+
# Enable zero-copy serialization. This is set before the transfer_queue import below;
# NOTE(review): presumably because zmq_utils reads the flag at import time — confirm.
os.environ["TQ_ZERO_COPY_SERIALIZATION"] = "True"
|
|
15
|
+
|
|
16
|
+
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType
|
|
17
|
+
|
|
18
|
+
def test_single_nested_tensor_fix():
    """Verify that a single-element nested tensor survives a zero-copy serialization round trip.

    Builds a TensorDict holding a one-element nested tensor and a regular tensor with the same
    batch size, serializes it through ``ZMQMessage`` and deserializes it again. It then checks
    that the nested tensor is still nested, the regular tensor is still regular, and both carry
    the original values. Progress and results are printed to stdout.

    Returns:
        bool: True if every check passed, False otherwise.
    """
    print("=" * 80)
    print("测试:单元素nested tensor序列化/反序列化修复")
    print("=" * 80)

    # Build a single-element nested tensor and a regular tensor
    single_nested = torch.nested.as_nested_tensor([torch.randn(4, 3)], layout=torch.strided)
    normal_tensor = torch.randn(1, 4, 3)

    print("\n1. 创建测试数据:")
    print(f" - 单元素nested tensor: {single_nested.shape}, is_nested={single_nested.is_nested}")
    print(f" - 普通tensor: {normal_tensor.shape}, is_nested={normal_tensor.is_nested}")

    # Wrap both tensors in a TensorDict
    td = TensorDict(
        {
            "single_nested_tensor": single_nested,
            "normal_tensor": normal_tensor,
        },
        batch_size=1,
    )

    print("\n2. 创建ZMQMessage并序列化:")
    msg = ZMQMessage(
        request_type=ZMQRequestType.PUT_DATA,
        sender_id="test_sender",
        receiver_id="test_receiver",
        body={"data": td},
    )

    # Serialize
    serialized_data = msg.serialize()
    print(f" - 序列化完成,数据列表长度: {len(serialized_data)}")

    # Deserialize
    print("\n3. 反序列化数据:")
    decoded_msg = ZMQMessage.deserialize(serialized_data)
    decoded_data = decoded_msg.body["data"]
    decoded_nested = decoded_data["single_nested_tensor"]
    decoded_normal = decoded_data["normal_tensor"]

    print(" - 反序列化完成")
    print(f" - decoded_msg.body['data']['single_nested_tensor'].is_nested = {decoded_nested.is_nested}")
    print(f" - decoded_msg.body['data']['normal_tensor'].is_nested = {decoded_normal.is_nested}")

    # Validate the results
    print("\n4. 验证结果:")
    success = True

    # The single-element nested tensor must still be nested
    if decoded_nested.is_nested:
        print(" ✓ 单元素nested tensor正确保持为nested类型")
    else:
        print(" ✗ 单元素nested tensor错误地变成了普通tensor类型")
        success = False

    # The regular tensor must still be a regular (non-nested) tensor
    if not decoded_normal.is_nested:
        print(" ✓ 普通tensor正确保持为普通tensor类型")
    else:
        print(" ✗ 普通tensor错误地变成了nested类型")
        success = False

    # Compare values. Shapes are checked first because torch.allclose raises on shapes that
    # cannot broadcast, which would abort the script instead of reporting a failed check.
    original_first = single_nested[0]
    decoded_first = decoded_nested[0]
    if decoded_first.shape == original_first.shape and torch.allclose(decoded_first, original_first):
        print(" ✓ 单元素nested tensor数据内容正确")
    else:
        print(" ✗ 单元素nested tensor数据内容不正确")
        success = False

    # A nested result cannot be compared element-wise with a regular tensor; treat it as a mismatch.
    if (
        not decoded_normal.is_nested
        and decoded_normal.shape == normal_tensor.shape
        and torch.allclose(decoded_normal, normal_tensor)
    ):
        print(" ✓ 普通tensor数据内容正确")
    else:
        print(" ✗ 普通tensor数据内容不正确")
        success = False

    print("\n" + "=" * 80)
    if success:
        print("✓ 所有测试通过!修复有效。")
    else:
        print("✗ 测试失败!修复可能存在问题。")
    print("=" * 80)

    return success
|
|
107
|
+
|
|
108
|
+
if __name__ == "__main__":
    import sys

    # Propagate the verification result as the process exit status so that shell/CI callers
    # can detect a failed check (previously the script always exited with status 0).
    sys.exit(0 if test_single_nested_tensor_fix() else 1)
|
|
File without changes
|
|
File without changes
|