TransferQueue 0.1.4.dev0__py3-none-any.whl → 0.1.4.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
performance_test.py CHANGED
@@ -137,8 +137,8 @@ class RayBandwidthTester:
137
137
  RemoteDataStore = RemoteDataStoreRemote
138
138
 
139
139
  self.remote_store = RemoteDataStore.options(
140
- num_cpus=0.01,
141
- resources={f"node:{WORKER_NODE_IP}": 0.001}
140
+ num_cpus=1,
141
+ resources={f"node:{WORKER_NODE_IP}": 1}
142
142
  ).remote()
143
143
 
144
144
  logger.info(f"Remote data store created on worker node {WORKER_NODE_IP}")
@@ -232,15 +232,15 @@ class TQBandwidthTester:
232
232
  # 限制在远程节点
233
233
  for storage_unit_rank in range(self.config.num_data_storage_units):
234
234
  storage_node = SimpleStorageUnit.options(
235
- num_cpus=0.01,
236
- resources={f"node:{WORKER_NODE_IP}": 0.001},
235
+ num_cpus=1,
236
+ resources={f"node:{WORKER_NODE_IP}": 1},
237
237
  runtime_env={"env_vars": {"OMP_NUM_THREADS": "2"}},
238
238
  ).remote(
239
239
  storage_unit_size=math.ceil(total_storage_size / self.config.num_data_storage_units)
240
240
  )
241
241
  self.data_system_storage_units[storage_unit_rank] = storage_node
242
242
  else:
243
- storage_placement_group = get_placement_group(self.config.num_data_storage_units, num_cpus_per_actor=0.01)
243
+ storage_placement_group = get_placement_group(self.config.num_data_storage_units, num_cpus_per_actor=1)
244
244
  for storage_unit_rank in range(self.config.num_data_storage_units):
245
245
  storage_node = SimpleStorageUnit.options(
246
246
  placement_group=storage_placement_group,
tests/test_metadata.py CHANGED
@@ -535,6 +535,153 @@ class TestBatchMeta:
535
535
  assert selected_field.production_status == ProductionStatus.READY_FOR_CONSUME
536
536
  assert selected_field.name == "field1"
537
537
 
538
+ def test_batch_meta_select_samples(self):
539
+ """Example: Select specific samples from a batch."""
540
+ fields = {
541
+ "field1": FieldMeta(name="field1", dtype=torch.float32, shape=(2,)),
542
+ "field2": FieldMeta(name="field2", dtype=torch.int64, shape=(3,)),
543
+ }
544
+ samples = [
545
+ SampleMeta(partition_id="partition_0", global_index=0, fields=fields),
546
+ SampleMeta(partition_id="partition_0", global_index=1, fields=fields),
547
+ SampleMeta(partition_id="partition_0", global_index=2, fields=fields),
548
+ SampleMeta(partition_id="partition_0", global_index=3, fields=fields),
549
+ ]
550
+ batch = BatchMeta(samples=samples, extra_info={"test_key": "test_value"})
551
+
552
+ # Select samples at indices [0, 2]
553
+ selected_batch = batch.select_samples([0, 2])
554
+
555
+ # Check number of samples
556
+ assert len(selected_batch) == 2
557
+ # Check global indexes
558
+ assert selected_batch.global_indexes == [0, 2]
559
+ # Check fields are preserved
560
+ for sample in selected_batch.samples:
561
+ assert "field1" in sample.fields
562
+ assert "field2" in sample.fields
563
+ # Original batch is unchanged
564
+ assert len(batch) == 4
565
+ # Extra info is preserved
566
+ assert selected_batch.extra_info["test_key"] == "test_value"
567
+
568
+ def test_batch_meta_select_samples_all_indices(self):
569
+ """Example: Select all samples using complete index list."""
570
+ fields = {
571
+ "test_field": FieldMeta(
572
+ name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME
573
+ )
574
+ }
575
+ samples = [
576
+ SampleMeta(partition_id="partition_0", global_index=0, fields=fields),
577
+ SampleMeta(partition_id="partition_0", global_index=1, fields=fields),
578
+ SampleMeta(partition_id="partition_0", global_index=2, fields=fields),
579
+ ]
580
+ batch = BatchMeta(samples=samples, extra_info={"test_key": "test_value"})
581
+
582
+ # Select all samples
583
+ selected_batch = batch.select_samples([0, 1, 2])
584
+
585
+ # All samples are selected
586
+ assert len(selected_batch) == 3
587
+ assert selected_batch.global_indexes == [0, 1, 2]
588
+ # Extra info is preserved
589
+ assert selected_batch.extra_info["test_key"] == "test_value"
590
+
591
+ def test_batch_meta_select_samples_single_sample(self):
592
+ """Example: Select a single sample from batch."""
593
+ fields = {
594
+ "test_field": FieldMeta(
595
+ name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME
596
+ )
597
+ }
598
+ samples = [
599
+ SampleMeta(partition_id="partition_0", global_index=0, fields=fields),
600
+ SampleMeta(partition_id="partition_0", global_index=1, fields=fields),
601
+ SampleMeta(partition_id="partition_0", global_index=2, fields=fields),
602
+ ]
603
+ batch = BatchMeta(samples=samples)
604
+
605
+ # Select only the middle sample
606
+ selected_batch = batch.select_samples([1])
607
+
608
+ assert len(selected_batch) == 1
609
+ assert selected_batch.global_indexes == [1]
610
+ assert selected_batch.samples[0].batch_index == 0 # New batch index
611
+
612
+ def test_batch_meta_select_samples_empty_list(self):
613
+ """Example: Select with empty list returns empty batch."""
614
+ fields = {
615
+ "test_field": FieldMeta(
616
+ name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME
617
+ )
618
+ }
619
+ samples = [
620
+ SampleMeta(partition_id="partition_0", global_index=0, fields=fields),
621
+ SampleMeta(partition_id="partition_0", global_index=1, fields=fields),
622
+ ]
623
+ batch = BatchMeta(samples=samples, extra_info={"test_key": "test_value"})
624
+
625
+ # Select with empty list
626
+ selected_batch = batch.select_samples([])
627
+
628
+ assert len(selected_batch) == 0
629
+ assert selected_batch.global_indexes == []
630
+ # Extra info is still preserved
631
+ assert selected_batch.extra_info["test_key"] == "test_value"
632
+
633
+ def test_batch_meta_select_samples_reverse_order(self):
634
+ """Example: Select samples in reverse order."""
635
+ fields = {
636
+ "test_field": FieldMeta(
637
+ name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME
638
+ )
639
+ }
640
+ samples = [
641
+ SampleMeta(partition_id="partition_0", global_index=0, fields=fields),
642
+ SampleMeta(partition_id="partition_0", global_index=1, fields=fields),
643
+ SampleMeta(partition_id="partition_0", global_index=2, fields=fields),
644
+ ]
645
+ batch = BatchMeta(samples=samples)
646
+
647
+ # Select samples in reverse order
648
+ selected_batch = batch.select_samples([2, 1, 0])
649
+
650
+ assert len(selected_batch) == 3
651
+ assert selected_batch.global_indexes == [2, 1, 0]
652
+ # Batch indexes are re-assigned
653
+ assert selected_batch.samples[0].global_index == 2
654
+ assert selected_batch.samples[1].global_index == 1
655
+ assert selected_batch.samples[2].global_index == 0
656
+
657
+ def test_batch_meta_select_samples_with_extra_info(self):
658
+ """Example: Select samples preserves all extra info types."""
659
+ fields = {
660
+ "test_field": FieldMeta(
661
+ name="test_field", dtype=torch.float32, shape=(2,), production_status=ProductionStatus.READY_FOR_CONSUME
662
+ )
663
+ }
664
+ samples = [
665
+ SampleMeta(partition_id="partition_0", global_index=0, fields=fields),
666
+ SampleMeta(partition_id="partition_0", global_index=1, fields=fields),
667
+ ]
668
+ batch = BatchMeta(samples=samples)
669
+
670
+ # Add various extra info types
671
+ batch.extra_info["tensor"] = torch.randn(3, 4)
672
+ batch.extra_info["string"] = "test_string"
673
+ batch.extra_info["number"] = 42
674
+ batch.extra_info["list"] = [1, 2, 3]
675
+
676
+ # Select one sample
677
+ selected_batch = batch.select_samples([0])
678
+
679
+ # All extra info is preserved
680
+ assert "tensor" in selected_batch.extra_info
681
+ assert selected_batch.extra_info["string"] == "test_string"
682
+ assert selected_batch.extra_info["number"] == 42
683
+ assert selected_batch.extra_info["list"] == [1, 2, 3]
684
+
538
685
  def test_batch_meta_extra_info_operations(self):
539
686
  """Example: Extra info management operations."""
540
687
  fields = {
@@ -541,3 +541,56 @@ def test_nested_jagged_tensor_serialization(enable_zero_copy):
541
541
  # Verify individual components
542
542
  for i in range(len(outer_td["nested_jagged1"].unbind())):
543
543
  assert torch.allclose(decoded_msg.body["data"]["nested_jagged1"][i], outer_td["nested_jagged1"][i])
544
+
545
+
546
+ @pytest.mark.parametrize("enable_zero_copy", [True, False])
547
+ def test_single_nested_tensor_serialization(enable_zero_copy):
548
+ """Test serialization of nested tensor with only one element (edge case for zero-copy)."""
549
+ with patch("transfer_queue.utils.zmq_utils.TQ_ZERO_COPY_SERIALIZATION", enable_zero_copy):
550
+ from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType
551
+
552
+ # Create nested tensor with only one element
553
+ # This is the critical edge case where a nested tensor with 1 element
554
+ # must be distinguished from a regular tensor during deserialization
555
+ single_nested = torch.nested.as_nested_tensor([torch.randn(4, 3)], layout=torch.strided)
556
+ # For normal tensor, expand to batch_size=1 to match the nested tensor's batch dimension
557
+ normal_tensor = torch.randn(1, 4, 3)
558
+
559
+ # Create TensorDict with both types
560
+ td = TensorDict(
561
+ {
562
+ "single_nested_tensor": single_nested,
563
+ "normal_tensor": normal_tensor,
564
+ },
565
+ batch_size=1,
566
+ )
567
+
568
+ msg = ZMQMessage(
569
+ request_type=ZMQRequestType.PUT_DATA,
570
+ sender_id="test",
571
+ receiver_id="test",
572
+ request_id="test",
573
+ timestamp=0.0,
574
+ body={"data": td},
575
+ )
576
+
577
+ encoded_msg = msg.serialize()
578
+ decoded_msg = ZMQMessage.deserialize(encoded_msg)
579
+
580
+ # Verify batch sizes
581
+ assert decoded_msg.body["data"].batch_size == td.batch_size
582
+
583
+ # Verify normal tensor
584
+ assert torch.allclose(decoded_msg.body["data"]["normal_tensor"], td["normal_tensor"])
585
+ assert decoded_msg.body["data"]["normal_tensor"].shape == td["normal_tensor"].shape
586
+
587
+ # Verify single nested tensor is properly reconstructed as nested
588
+ assert decoded_msg.body["data"]["single_nested_tensor"].is_nested
589
+ assert decoded_msg.body["data"]["single_nested_tensor"].layout == torch.strided
590
+ assert len(decoded_msg.body["data"]["single_nested_tensor"].unbind()) == 1
591
+ assert torch.allclose(decoded_msg.body["data"]["single_nested_tensor"][0], td["single_nested_tensor"][0])
592
+
593
+ # Ensure the nested tensor with single element is correctly distinguished from regular tensor
594
+ # Both should have the same data but different types
595
+ assert not decoded_msg.body["data"]["normal_tensor"].is_nested
596
+ assert decoded_msg.body["data"]["single_nested_tensor"].is_nested
@@ -32,6 +32,7 @@ from transfer_queue.metadata import (
32
32
  SampleMeta,
33
33
  )
34
34
  from transfer_queue.sampler import BaseSampler, SequentialSampler
35
+ from transfer_queue.utils.perf_utils import IntervalPerfMonitor
35
36
  from transfer_queue.utils.utils import (
36
37
  ProductionStatus,
37
38
  TransferQueueRole,
@@ -584,7 +585,7 @@ class TransferQueueController:
584
585
 
585
586
  self.partitions[partition_id] = DataPartitionStatus(partition_id=partition_id)
586
587
 
587
- logger.info(f"Created partition {partition_id} with dynamic capacity")
588
+ logger.info(f"Created partition {partition_id}")
588
589
  return True
589
590
 
590
591
  def get_partition(self, partition_id: str) -> Optional[DataPartitionStatus]:
@@ -1008,7 +1009,7 @@ class TransferQueueController:
1008
1009
  poller = zmq.Poller()
1009
1010
  poller.register(self.handshake_socket, zmq.POLLIN)
1010
1011
 
1011
- logger.info(f"Dynamic Controller {self.controller_id} started waiting for storage connections...")
1012
+ logger.info(f"Controller {self.controller_id} started waiting for storage connections...")
1012
1013
 
1013
1014
  while True:
1014
1015
  socks = dict(poller.poll(TQ_CONTROLLER_CONNECTION_CHECK_INTERVAL * 1000))
@@ -1036,23 +1037,23 @@ class TransferQueueController:
1036
1037
  self._connected_storage_managers.add(storage_manager_id)
1037
1038
  storage_manager_type = request_msg.body.get("storage_manager_type", "Unknown")
1038
1039
  logger.info(
1039
- f"Dynamic Controller {self.controller_id} received handshake from "
1040
+ f"Controller {self.controller_id} received handshake from "
1040
1041
  f"storage manager {storage_manager_id} (type: {storage_manager_type}). "
1041
1042
  f"Total connected: {len(self._connected_storage_managers)}"
1042
1043
  )
1043
1044
  else:
1044
1045
  logger.debug(
1045
- f"Dynamic Controller {self.controller_id} received duplicate handshake from "
1046
+ f"Controller {self.controller_id} received duplicate handshake from "
1046
1047
  f"storage manager {storage_manager_id}. Resending ACK."
1047
1048
  )
1048
1049
 
1049
1050
  except Exception as e:
1050
- logger.error(f"Dynamic Controller {self.controller_id} error processing handshake: {e}")
1051
+ logger.error(f"Controller {self.controller_id} error processing handshake: {e}")
1051
1052
 
1052
1053
  def _start_process_handshake(self):
1053
1054
  """Start the handshake process thread."""
1054
1055
  self.wait_connection_thread = Thread(
1055
- target=self._wait_connection, name="DynamicTransferQueueControllerWaitConnectionThread", daemon=True
1056
+ target=self._wait_connection, name="TransferQueueControllerWaitConnectionThread", daemon=True
1056
1057
  )
1057
1058
  self.wait_connection_thread.start()
1058
1059
 
@@ -1060,7 +1061,7 @@ class TransferQueueController:
1060
1061
  """Start the data status update processing thread."""
1061
1062
  self.process_update_data_status_thread = Thread(
1062
1063
  target=self._update_data_status,
1063
- name="DynamicTransferQueueControllerProcessUpdateDataStatusThread",
1064
+ name="TransferQueueControllerProcessUpdateDataStatusThread",
1064
1065
  daemon=True,
1065
1066
  )
1066
1067
  self.process_update_data_status_thread.start()
@@ -1068,12 +1069,17 @@ class TransferQueueController:
1068
1069
  def _start_process_request(self):
1069
1070
  """Start the request processing thread."""
1070
1071
  self.process_request_thread = Thread(
1071
- target=self._process_request, name="DynamicTransferQueueControllerProcessRequestThread", daemon=True
1072
+ target=self._process_request, name="TransferQueueControllerProcessRequestThread", daemon=True
1072
1073
  )
1073
1074
  self.process_request_thread.start()
1074
1075
 
1075
1076
  def _process_request(self):
1076
1077
  """Main request processing loop - adapted for partition-based operations."""
1078
+
1079
+ logger.info(f"[{self.controller_id}]: start processing requests...")
1080
+
1081
+ perf_monitor = IntervalPerfMonitor(caller_name=self.controller_id)
1082
+
1077
1083
  while True:
1078
1084
  messages = self.request_handle_socket.recv_multipart()
1079
1085
  identity = messages.pop(0)
@@ -1081,88 +1087,96 @@ class TransferQueueController:
1081
1087
  request_msg = ZMQMessage.deserialize(serialized_msg)
1082
1088
 
1083
1089
  if request_msg.request_type == ZMQRequestType.GET_META:
1084
- params = request_msg.body
1085
-
1086
- metadata = self.get_metadata(
1087
- data_fields=params["data_fields"],
1088
- batch_size=params["batch_size"],
1089
- partition_id=params["partition_id"],
1090
- mode=params.get("mode", "fetch"),
1091
- task_name=params.get("task_name"),
1092
- sampling_config=params.get("sampling_config"),
1093
- )
1094
-
1095
- response_msg = ZMQMessage.create(
1096
- request_type=ZMQRequestType.GET_META_RESPONSE,
1097
- sender_id=self.controller_id,
1098
- receiver_id=request_msg.sender_id,
1099
- body={"metadata": metadata},
1100
- )
1101
-
1102
- elif request_msg.request_type == ZMQRequestType.GET_CLEAR_META:
1103
- params = request_msg.body
1104
- partition_id = params["partition_id"]
1105
-
1106
- metadata = self.get_metadata(
1107
- data_fields=[],
1108
- partition_id=partition_id,
1109
- mode="insert",
1110
- )
1111
- response_msg = ZMQMessage.create(
1112
- request_type=ZMQRequestType.GET_CLEAR_META_RESPONSE,
1113
- sender_id=self.controller_id,
1114
- receiver_id=request_msg.sender_id,
1115
- body={"metadata": metadata},
1116
- )
1117
- elif request_msg.request_type == ZMQRequestType.CLEAR_META:
1118
- params = request_msg.body
1119
- partition_id = params["partition_id"]
1090
+ with perf_monitor.measure(op_type="GET_META"):
1091
+ params = request_msg.body
1092
+
1093
+ metadata = self.get_metadata(
1094
+ data_fields=params["data_fields"],
1095
+ batch_size=params["batch_size"],
1096
+ partition_id=params["partition_id"],
1097
+ mode=params.get("mode", "fetch"),
1098
+ task_name=params.get("task_name"),
1099
+ sampling_config=params.get("sampling_config"),
1100
+ )
1120
1101
 
1121
- clear_success = self.clear(partition_id)
1122
- if clear_success:
1123
1102
  response_msg = ZMQMessage.create(
1124
- request_type=ZMQRequestType.CLEAR_META_RESPONSE,
1103
+ request_type=ZMQRequestType.GET_META_RESPONSE,
1125
1104
  sender_id=self.controller_id,
1126
1105
  receiver_id=request_msg.sender_id,
1127
- body={"message": f"Clear operation completed by controller {self.controller_id}"},
1106
+ body={"metadata": metadata},
1107
+ )
1108
+
1109
+ elif request_msg.request_type == ZMQRequestType.GET_CLEAR_META:
1110
+ with perf_monitor.measure(op_type="GET_CLEAR_META"):
1111
+ params = request_msg.body
1112
+ partition_id = params["partition_id"]
1113
+
1114
+ metadata = self.get_metadata(
1115
+ data_fields=[],
1116
+ partition_id=partition_id,
1117
+ mode="insert",
1128
1118
  )
1129
- else:
1130
1119
  response_msg = ZMQMessage.create(
1131
- request_type=ZMQRequestType.CLEAR_META_RESPONSE,
1120
+ request_type=ZMQRequestType.GET_CLEAR_META_RESPONSE,
1132
1121
  sender_id=self.controller_id,
1133
1122
  receiver_id=request_msg.sender_id,
1134
- body={"error": f"Clear operation failed for partition {partition_id}"},
1123
+ body={"metadata": metadata},
1135
1124
  )
1125
+ elif request_msg.request_type == ZMQRequestType.CLEAR_META:
1126
+ with perf_monitor.measure(op_type="CLEAR_META"):
1127
+ params = request_msg.body
1128
+ partition_id = params["partition_id"]
1129
+
1130
+ clear_success = self.clear(partition_id)
1131
+ if clear_success:
1132
+ response_msg = ZMQMessage.create(
1133
+ request_type=ZMQRequestType.CLEAR_META_RESPONSE,
1134
+ sender_id=self.controller_id,
1135
+ receiver_id=request_msg.sender_id,
1136
+ body={"message": f"Clear operation completed by controller {self.controller_id}"},
1137
+ )
1138
+ else:
1139
+ response_msg = ZMQMessage.create(
1140
+ request_type=ZMQRequestType.CLEAR_META_RESPONSE,
1141
+ sender_id=self.controller_id,
1142
+ receiver_id=request_msg.sender_id,
1143
+ body={"error": f"Clear operation failed for partition {partition_id}"},
1144
+ )
1136
1145
 
1137
1146
  elif request_msg.request_type == ZMQRequestType.CHECK_CONSUMPTION:
1138
- # Handle consumption status checks
1139
- params = request_msg.body
1140
-
1141
- consumption_status = self.get_consumption_status(params["partition_id"], params["task_name"])
1142
- sample_filter = params.get("sample_filter")
1143
-
1144
- if consumption_status is not None and sample_filter:
1145
- batch_status = consumption_status[sample_filter]
1146
- consumed = torch.all(batch_status == 1).item()
1147
- elif consumption_status is not None:
1148
- batch_status = consumption_status
1149
- consumed = torch.all(batch_status == 1).item()
1150
- else:
1151
- consumed = False
1152
-
1153
- response_msg = ZMQMessage.create(
1154
- request_type=ZMQRequestType.CONSUMPTION_RESPONSE,
1155
- sender_id=self.controller_id,
1156
- receiver_id=request_msg.sender_id,
1157
- body={
1158
- "partition_id": params["partition_id"],
1159
- "consumed": consumed,
1160
- },
1161
- )
1147
+ with perf_monitor.measure(op_type="CHECK_CONSUMPTION"):
1148
+ # Handle consumption status checks
1149
+ params = request_msg.body
1150
+
1151
+ consumption_status = self.get_consumption_status(params["partition_id"], params["task_name"])
1152
+ sample_filter = params.get("sample_filter")
1153
+
1154
+ if consumption_status is not None and sample_filter:
1155
+ batch_status = consumption_status[sample_filter]
1156
+ consumed = torch.all(batch_status == 1).item()
1157
+ elif consumption_status is not None:
1158
+ batch_status = consumption_status
1159
+ consumed = torch.all(batch_status == 1).item()
1160
+ else:
1161
+ consumed = False
1162
+
1163
+ response_msg = ZMQMessage.create(
1164
+ request_type=ZMQRequestType.CONSUMPTION_RESPONSE,
1165
+ sender_id=self.controller_id,
1166
+ receiver_id=request_msg.sender_id,
1167
+ body={
1168
+ "partition_id": params["partition_id"],
1169
+ "consumed": consumed,
1170
+ },
1171
+ )
1162
1172
  self.request_handle_socket.send_multipart([identity, *response_msg.serialize()])
1163
1173
 
1164
1174
  def _update_data_status(self):
1165
1175
  """Process data status update messages from storage units - adapted for partitions."""
1176
+ logger.info(f"[{self.controller_id}]: start receiving update_data_status requests...")
1177
+
1178
+ perf_monitor = IntervalPerfMonitor(caller_name=self.controller_id)
1179
+
1166
1180
  while True:
1167
1181
  messages = self.data_status_update_socket.recv_multipart()
1168
1182
  identity = messages.pop(0)
@@ -1170,32 +1184,33 @@ class TransferQueueController:
1170
1184
  request_msg = ZMQMessage.deserialize(serialized_msg)
1171
1185
 
1172
1186
  if request_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE:
1173
- message_data = request_msg.body
1174
- partition_id = message_data.get("partition_id")
1175
-
1176
- # Update production status
1177
- success = self.update_production_status(
1178
- partition_id=partition_id,
1179
- global_indexes=message_data.get("global_indexes", []),
1180
- field_names=message_data.get("fields", []),
1181
- dtypes=message_data.get("dtypes", {}),
1182
- shapes=message_data.get("shapes", {}),
1183
- )
1187
+ with perf_monitor.measure(op_type="NOTIFY_DATA_UPDATE"):
1188
+ message_data = request_msg.body
1189
+ partition_id = message_data.get("partition_id")
1190
+
1191
+ # Update production status
1192
+ success = self.update_production_status(
1193
+ partition_id=partition_id,
1194
+ global_indexes=message_data.get("global_indexes", []),
1195
+ field_names=message_data.get("fields", []),
1196
+ dtypes=message_data.get("dtypes", {}),
1197
+ shapes=message_data.get("shapes", {}),
1198
+ )
1184
1199
 
1185
- if success:
1186
- logger.info(f"Updated production status for partition {partition_id}")
1187
-
1188
- # Send acknowledgment
1189
- response_msg = ZMQMessage.create(
1190
- request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK,
1191
- sender_id=self.controller_id,
1192
- body={
1193
- "controller_id": self.controller_id,
1194
- "partition_id": partition_id,
1195
- "success": success,
1196
- },
1197
- )
1198
- self.data_status_update_socket.send_multipart([identity, *response_msg.serialize()])
1200
+ if success:
1201
+ logger.info(f"Updated production status for partition {partition_id}")
1202
+
1203
+ # Send acknowledgment
1204
+ response_msg = ZMQMessage.create(
1205
+ request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK,
1206
+ sender_id=self.controller_id,
1207
+ body={
1208
+ "controller_id": self.controller_id,
1209
+ "partition_id": partition_id,
1210
+ "success": success,
1211
+ },
1212
+ )
1213
+ self.data_status_update_socket.send_multipart([identity, *response_msg.serialize()])
1199
1214
 
1200
1215
  def get_zmq_server_info(self) -> ZMQServerInfo:
1201
1216
  """Get ZMQ server connection information."""
@@ -261,6 +261,28 @@ class BatchMeta:
261
261
  object.__setattr__(self, "_is_ready", all(sample.is_ready for sample in self.samples))
262
262
  return self
263
263
 
264
+ def select_samples(self, sample_indices: list[int]) -> "BatchMeta":
265
+ """
266
+ Select specific samples from this batch.
267
+ This will construct a new BatchMeta instance containing only the specified samples.
268
+
269
+ Args:
270
+ sample_indices (list[int]): List of sample indices to retain.
271
+
272
+ Returns:
273
+ BatchMeta: A new BatchMeta instance containing only the specified samples.
274
+ """
275
+
276
+ if any(i < 0 or i >= len(self.samples) for i in sample_indices):
277
+ raise ValueError(f"Sample indices must be in range [0, {len(self.samples)})")
278
+
279
+ selected_samples = [self.samples[i] for i in sample_indices]
280
+
281
+ # construct new BatchMeta instance
282
+ selected_batch_meta = BatchMeta(samples=selected_samples, extra_info=self.extra_info.copy())
283
+
284
+ return selected_batch_meta
285
+
264
286
  def select_fields(self, field_names: list[str]) -> "BatchMeta":
265
287
  """
266
288
  Select specific fields from all samples in this batch.
@@ -287,7 +309,7 @@ class BatchMeta:
287
309
  def __getitem__(self, item):
288
310
  if isinstance(item, int | np.integer):
289
311
  sample_meta = self.samples[item] if self.samples else []
290
- return BatchMeta(samples=[sample_meta], extra_info=self.extra_info)
312
+ return BatchMeta(samples=[sample_meta], extra_info=self.extra_info.copy())
291
313
  else:
292
314
  raise TypeError(f"Indexing with {type(item)} is not supported now!")
293
315
 
@@ -508,6 +530,13 @@ class BatchMeta:
508
530
  extra_info = {}
509
531
  return cls(samples=[], extra_info=extra_info)
510
532
 
533
+ def __str__(self):
534
+ sample_strs = ", ".join(str(sample) for sample in self.samples)
535
+ return (
536
+ f"BatchMeta(size={self.size}, field_names={self.field_names}, is_ready={self.is_ready}, "
537
+ f"samples=[{sample_strs}], extra_info={self.extra_info})"
538
+ )
539
+
511
540
 
512
541
  def _union_fields(fields1: dict[str, FieldMeta], fields2: dict[str, FieldMeta]) -> dict[str, FieldMeta]:
513
542
  """Union two sample's fields. If fields overlap, the fields in fields1 will be replaced by fields2."""
@@ -173,6 +173,8 @@ class AsyncSimpleStorageManager(TransferQueueStorageManager):
173
173
  metadata: BatchMeta containing storage location information.
174
174
  """
175
175
 
176
+ logger.info(f"{__class__.__name__}: receive put_data request, putting {metadata.size} samples.")
177
+
176
178
  # group samples by storage unit
177
179
  storage_meta_groups = build_storage_meta_groups(
178
180
  metadata, self.global_index_storage_unit_mapping, self.global_index_local_index_mapping
@@ -228,7 +230,8 @@ class AsyncSimpleStorageManager(TransferQueueStorageManager):
228
230
  else NonTensorStack(*transfer_data["field_data"][field])
229
231
  )
230
232
  for field in transfer_data["field_data"]
231
- }
233
+ },
234
+ batch_size=len(local_indexes),
232
235
  )
233
236
 
234
237
  request_msg = ZMQMessage.create(
@@ -263,6 +266,8 @@ class AsyncSimpleStorageManager(TransferQueueStorageManager):
263
266
  TensorDict containing the retrieved data.
264
267
  """
265
268
 
269
+ logger.info(f"{__class__.__name__}: receive get_data request, getting {metadata.size} samples.")
270
+
266
271
  # group samples by storage unit
267
272
  storage_meta_groups = build_storage_meta_groups(
268
273
  metadata, self.global_index_storage_unit_mapping, self.global_index_local_index_mapping
@@ -28,14 +28,14 @@ from ray.util import get_node_ip_address
28
28
  from tensordict import NonTensorStack, TensorDict
29
29
 
30
30
  from transfer_queue.metadata import SampleMeta
31
+ from transfer_queue.utils.perf_utils import IntervalPerfMonitor
31
32
  from transfer_queue.utils.utils import TransferQueueRole
32
33
  from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo, create_zmq_socket, get_free_port
33
34
 
34
35
  logger = logging.getLogger(__name__)
35
36
  logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
36
37
 
37
- # ZMQ timeouts (in seconds) and retry configurations
38
- TQ_STORAGE_POLLER_TIMEOUT = int(os.environ.get("TQ_STORAGE_POLLER_TIMEOUT", 5))
38
+ TQ_STORAGE_POLLER_TIMEOUT = int(os.environ.get("TQ_STORAGE_POLLER_TIMEOUT", 5)) # in seconds
39
39
 
40
40
 
41
41
  class StorageUnitData:
@@ -200,7 +200,7 @@ class SimpleStorageUnit:
200
200
  def _start_process_put_get(self) -> None:
201
201
  """Create a daemon thread and start put/get process."""
202
202
  self.process_put_get_thread = Thread(
203
- target=self._process_put_get, name=f"StorageUnitProcessPutGetThread-{self.zmq_server_info.id}", daemon=True
203
+ target=self._process_put_get, name=f"StorageUnitProcessPutGetThread-{self.storage_unit_id}", daemon=True
204
204
  )
205
205
  self.process_put_get_thread.start()
206
206
 
@@ -209,6 +209,10 @@ class SimpleStorageUnit:
209
209
  poller = zmq.Poller()
210
210
  poller.register(self.put_get_socket, zmq.POLLIN)
211
211
 
212
+ logger.info(f"[{self.storage_unit_id}]: start processing put/get requests...")
213
+
214
+ perf_monitor = IntervalPerfMonitor(caller_name=self.storage_unit_id)
215
+
212
216
  while True:
213
217
  socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT * 1000))
214
218
 
@@ -219,29 +223,32 @@ class SimpleStorageUnit:
219
223
  request_msg = ZMQMessage.deserialize(serialized_msg)
220
224
  operation = request_msg.request_type
221
225
  try:
222
- logger.debug(f"[{self.zmq_server_info.id}]: receive operation: {operation}, message: {request_msg}")
226
+ logger.debug(f"[{self.storage_unit_id}]: receive operation: {operation}, message: {request_msg}")
223
227
 
224
228
  if operation == ZMQRequestType.PUT_DATA:
225
- response_msg = self._handle_put(request_msg)
229
+ with perf_monitor.measure(op_type="PUT_DATA"):
230
+ response_msg = self._handle_put(request_msg)
226
231
  elif operation == ZMQRequestType.GET_DATA:
227
- response_msg = self._handle_get(request_msg)
232
+ with perf_monitor.measure(op_type="GET_DATA"):
233
+ response_msg = self._handle_get(request_msg)
228
234
  elif operation == ZMQRequestType.CLEAR_DATA:
229
- response_msg = self._handle_clear(request_msg)
235
+ with perf_monitor.measure(op_type="CLEAR_DATA"):
236
+ response_msg = self._handle_clear(request_msg)
230
237
  else:
231
238
  response_msg = ZMQMessage.create(
232
239
  request_type=ZMQRequestType.PUT_GET_OPERATION_ERROR,
233
- sender_id=self.zmq_server_info.id,
240
+ sender_id=self.storage_unit_id,
234
241
  body={
235
- "message": f"Storage unit id #{self.zmq_server_info.id} "
242
+ "message": f"Storage unit id #{self.storage_unit_id} "
236
243
  f"receive invalid operation: {operation}."
237
244
  },
238
245
  )
239
246
  except Exception as e:
240
247
  response_msg = ZMQMessage.create(
241
248
  request_type=ZMQRequestType.PUT_GET_ERROR,
242
- sender_id=self.zmq_server_info.id,
249
+ sender_id=self.storage_unit_id,
243
250
  body={
244
- "message": f"Storage unit id #{self.zmq_server_info.id} occur error in processing "
251
+ "message": f"Storage unit id #{self.storage_unit_id} occur error in processing "
245
252
  f"put/get/clear request, detail error message: {str(e)}."
246
253
  },
247
254
  )
@@ -268,17 +275,17 @@ class SimpleStorageUnit:
268
275
 
269
276
  # After put operation finish, send a message to the client
270
277
  response_msg = ZMQMessage.create(
271
- request_type=ZMQRequestType.PUT_DATA_RESPONSE, sender_id=self.zmq_server_info.id, body={}
278
+ request_type=ZMQRequestType.PUT_DATA_RESPONSE, sender_id=self.storage_unit_id, body={}
272
279
  )
273
280
 
274
281
  return response_msg
275
282
  except Exception as e:
276
283
  return ZMQMessage.create(
277
284
  request_type=ZMQRequestType.PUT_ERROR,
278
- sender_id=self.zmq_server_info.id,
285
+ sender_id=self.storage_unit_id,
279
286
  body={
280
287
  "message": f"Failed to put data into storage unit id "
281
- f"#{self.zmq_server_info.id}, detail error message: {str(e)}"
288
+ f"#{self.storage_unit_id}, detail error message: {str(e)}"
282
289
  },
283
290
  )
284
291
 
@@ -300,7 +307,7 @@ class SimpleStorageUnit:
300
307
 
301
308
  response_msg = ZMQMessage.create(
302
309
  request_type=ZMQRequestType.GET_DATA_RESPONSE,
303
- sender_id=self.zmq_server_info.id,
310
+ sender_id=self.storage_unit_id,
304
311
  body={
305
312
  "data": result_data,
306
313
  },
@@ -308,9 +315,9 @@ class SimpleStorageUnit:
308
315
  except Exception as e:
309
316
  response_msg = ZMQMessage.create(
310
317
  request_type=ZMQRequestType.GET_ERROR,
311
- sender_id=self.zmq_server_info.id,
318
+ sender_id=self.storage_unit_id,
312
319
  body={
313
- "message": f"Failed to get data from storage unit id #{self.zmq_server_info.id}, "
320
+ "message": f"Failed to get data from storage unit id #{self.storage_unit_id}, "
314
321
  f"detail error message: {str(e)}"
315
322
  },
316
323
  )
@@ -333,15 +340,15 @@ class SimpleStorageUnit:
333
340
 
334
341
  response_msg = ZMQMessage.create(
335
342
  request_type=ZMQRequestType.CLEAR_DATA_RESPONSE,
336
- sender_id=self.zmq_server_info.id,
337
- body={"message": f"Clear data in storage unit id #{self.zmq_server_info.id} successfully."},
343
+ sender_id=self.storage_unit_id,
344
+ body={"message": f"Clear data in storage unit id #{self.storage_unit_id} successfully."},
338
345
  )
339
346
  except Exception as e:
340
347
  response_msg = ZMQMessage.create(
341
348
  request_type=ZMQRequestType.CLEAR_DATA_ERROR,
342
- sender_id=self.zmq_server_info.id,
349
+ sender_id=self.storage_unit_id,
343
350
  body={
344
- "message": f"Failed to clear data in storage unit id #{self.zmq_server_info.id}, "
351
+ "message": f"Failed to clear data in storage unit id #{self.storage_unit_id}, "
345
352
  f"detail error message: {str(e)}"
346
353
  },
347
354
  )
@@ -0,0 +1,104 @@
1
+ import logging
2
+ import os
3
+ import time
4
+ from collections import defaultdict
5
+ from contextlib import contextmanager
6
+
7
+ logger = logging.getLogger(__name__)
8
+ logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.INFO))
9
+
10
+ TQ_PERF_LOG_FLUSH_INTERVAL = float(os.environ.get("TQ_PERF_LOG_FLUSH_INTERVAL", 10)) # in seconds
11
+
12
+
13
+ class IntervalPerfMonitor:
14
+ """
15
+ Monitors and logs performance statistics for operations over configurable time intervals.
16
+
17
+ This class is designed to be used in contexts where you want to track the number of successful
18
+ operations and their processing times, and periodically log summary statistics such as request
19
+ counts, rates, and timing metrics (average, max, min) per operation type.
20
+
21
+ Usage:
22
+ monitor = IntervalPerfMonitor("Your Class")
23
+ with monitor.measure("method_name"):
24
+ # perform upload operation
25
+
26
+ At each interval (controlled by TQ_PERF_LOG_FLUSH_INTERVAL), the monitor logs aggregated
27
+ statistics and resets its counters.
28
+
29
+ Args:
30
+ caller_name (str): Name of the component or caller using the monitor, included in logs.
31
+ """
32
+
33
+ def __init__(self, caller_name: str):
34
+ self.caller_name = caller_name
35
+ self.last_flush_time = time.perf_counter()
36
+
37
+ self.success_counts: dict[str, int] = defaultdict(int)
38
+ self.process_time: dict[str, list[float]] = defaultdict(list)
39
+
40
+ def _flush_logs(self):
41
+ """
42
+ Internal method to conditionally flush (log) aggregated performance statistics.
43
+
44
+ If the configured time interval (TQ_PERF_LOG_FLUSH_INTERVAL) has passed since the last flush,
45
+ this method logs:
46
+ - Total number of successful requests and requests per minute.
47
+ - Average processing time across all operations.
48
+ - For each operation type: request count, requests per minute, average, max, and min processing times.
49
+ After logging, all statistics are reset and the flush timer is updated.
50
+ """
51
+ now = time.perf_counter()
52
+
53
+ # only flush if the interval has passed
54
+ if (now - self.last_flush_time) >= TQ_PERF_LOG_FLUSH_INTERVAL:
55
+ minutes = (now - self.last_flush_time) / 60
56
+
57
+ total_requests = sum(self.success_counts.values())
58
+ total_process_time = sum(sum(time_list) for time_list in self.process_time.values())
59
+ total_avg_process_time = total_process_time / total_requests if total_requests > 0 else 0.0
60
+
61
+ # max/min/avg time for each operation type
62
+ op_detail_stats = []
63
+ for op_type, count in self.success_counts.items():
64
+ times = self.process_time[op_type]
65
+ if not times:
66
+ op_avg = op_max = op_min = 0.0
67
+ else:
68
+ op_avg = sum(times) / len(times)
69
+ op_max = max(times)
70
+ op_min = min(times)
71
+
72
+ op_detail_stats.append(
73
+ f"{op_type}: req_count={count}, req/min={count / minutes:.2f}, "
74
+ f"avg_time={op_avg:.6f}s, max_time={op_max:.6f}s, min_time={op_min:.6f}s"
75
+ )
76
+
77
+ log_msg = (
78
+ f"{self.caller_name}: [Performance] "
79
+ f"Total success requests: {total_requests}, "
80
+ f"Total req/min: {total_requests / minutes:.2f}, "
81
+ f"Total avg process time: {total_avg_process_time:.4f}s; \n"
82
+ f"Time range: last {minutes:.2f} minutes; \n"
83
+ f"Per-operation statistics: {'; '.join(op_detail_stats)}"
84
+ )
85
+
86
+ logger.info(log_msg)
87
+
88
+ # reset counts
89
+ self.success_counts.clear()
90
+ self.process_time.clear()
91
+ self.last_flush_time = now
92
+
93
+ @contextmanager
94
+ def measure(self, op_type: str):
95
+ start_time = time.perf_counter()
96
+ try:
97
+ yield
98
+ finally:
99
+ cost = time.perf_counter() - start_time
100
+ self.success_counts[op_type] += 1
101
+ self.process_time[op_type].append(cost)
102
+
103
+ # try flush logs
104
+ self._flush_logs()
@@ -22,6 +22,7 @@ from dataclasses import dataclass
22
22
  from typing import Any, Optional, TypeAlias
23
23
  from uuid import uuid4
24
24
 
25
+ import numpy as np
25
26
  import psutil
26
27
  import torch
27
28
  import zmq
@@ -162,15 +163,15 @@ class ZMQMessage:
162
163
  tensor_list = tensor.unbind()
163
164
  tensor_count = len(tensor_list)
164
165
  serialized_tensors = [_encoder.encode(inner_tensor) for inner_tensor in tensor_list]
165
- return tensor_count, serialized_tensors
166
+ return tensor_count, serialized_tensors # tensor_count may equal to 1 for single nested tensor
166
167
  else:
167
- return 1, [_encoder.encode(tensor)]
168
+ return -1, [_encoder.encode(tensor)] # use -1 to indicate regular single tensor
168
169
 
169
170
  # Use map to process all tensors in parallel-like fashion
170
171
  nested_tensor_info_and_serialized_tensors = list(map(process_tensor, tensors))
171
172
 
172
173
  # Extract nested_tensor_info and flatten serialized tensors using itertools
173
- nested_tensor_info = [info for info, _ in nested_tensor_info_and_serialized_tensors]
174
+ nested_tensor_info = np.array([info for info, _ in nested_tensor_info_and_serialized_tensors])
174
175
  double_layer_serialized_tensors: list[list[bytestr]] = list(
175
176
  itertools.chain.from_iterable(serialized for _, serialized in nested_tensor_info_and_serialized_tensors)
176
177
  )
@@ -209,14 +210,14 @@ class ZMQMessage:
209
210
  f"When TQ_ZERO_COPY_SERIALIZATION is enabled, input data should be a list, but got {type(data)}."
210
211
  )
211
212
 
212
- tensor_nums = sum(nested_tensor_info)
213
+ tensor_nums = np.abs(nested_tensor_info).sum()
213
214
  if tensor_nums != len(single_tensors):
214
215
  raise ValueError(f"Expecting {tensor_nums} tensors, but got {len(single_tensors)}.")
215
216
 
216
217
  tensors = [None] * len(nested_tensor_info)
217
218
  current_idx = 0
218
219
  for i, tensor_num in enumerate(nested_tensor_info):
219
- if tensor_num == 1:
220
+ if tensor_num == -1:
220
221
  tensors[i] = single_tensors[current_idx]
221
222
  current_idx += 1
222
223
  else:
@@ -1 +1 @@
1
- 0.1.4.dev0
1
+ 0.1.4.dev1
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: TransferQueue
3
- Version: 0.1.4.dev0
3
+ Version: 0.1.4.dev1
4
4
  Summary: TransferQueue: An Asynchronous Streaming Data Management Module
5
5
  Author-email: The TransferQueue Team <hanzy19@tsinghua.org.cn>
6
6
  License: Apache-2.0
@@ -1,6 +1,7 @@
1
- performance_test.py,sha256=tKM1m-IDlcOv003kaVgvrO37r04OMfCSQ7fbrefqV5A,14451
1
+ performance_test.py,sha256=Yl4k-ln5iHAJ5HIfK--Lx-tlEXQcI71Juzve-lBesUI,14434
2
2
  serial_profiling_demo.py,sha256=b4GEoF8bSIawQlWOIK2eg9Tgn_-Q8n_KMsfRHLdK1mc,3961
3
3
  serial_profiling_demo_nested_non_continues_test.py,sha256=ahXMGDsRc_bwE3lbVE6L-8_guezv6rpG24pKKIxVKF8,3849
4
+ verify_fix.py,sha256=IcZ52jbYoxFb4MRwZMx0f52E8ATFn7GbUp8fmVY31RU,3522
4
5
  recipe/simple_use_case/async_demo.py,sha256=wGQCnXAzElE1D6-POwJVhR4UbT3kEJdCnnF-ZL5jPKg,13709
5
6
  recipe/simple_use_case/sync_demo.py,sha256=Rfatf205t-gHaxShk2n8LavMVoMLWZpguHRKpKgMDNo,8680
6
7
  tests/test_async_simple_storage_manager.py,sha256=qYmSJV6LYSXAzjnlzWZ5HboOhr4B1gMgNwOc5_esVaU,12366
@@ -8,21 +9,21 @@ tests/test_client.py,sha256=Wj1Eswt9qVfL0-192Mwx5_ICaQO7iso0WMJxKpGEvFg,15546
8
9
  tests/test_controller.py,sha256=ZcvFCC3jSnNN_fEerjA37RQv0SSO0Xh8vjcL2mvF03o,11084
9
10
  tests/test_controller_data_partitions.py,sha256=RQExMFcuXblpUZE_LaFaoncbwX4-YlPcBi69siRfvzY,19363
10
11
  tests/test_kv_storage_manager.py,sha256=j45VZ14H8MY8-c4CJI1MRgljoHijaK-tkQAD_-9lmzw,4012
11
- tests/test_metadata.py,sha256=iPbYZIyDavtG2QMFPqbKXzRDTgNkcssjnnQkKPisim0,26044
12
+ tests/test_metadata.py,sha256=q0X8UuxTmx-JjZlXR8nqr7cm2YbGFHWNbO9kZHgFFYU,32415
12
13
  tests/test_samplers.py,sha256=CvYqfmbHEWWa1RyymztCAn0GcitAPOBbfJ4ud1VvO2o,19168
13
- tests/test_serial_utils_on_cpu.py,sha256=_-I88rjj_uzWWE_tiT_PcfH5_j15j-Q7iB_Gq-dU6B4,21122
14
+ tests/test_serial_utils_on_cpu.py,sha256=iZII_-oVBu3KQ8Afpf7roqET9mIMegrsQ7cwg1XPXQo,23595
14
15
  tests/test_simple_storage_unit.py,sha256=29mrQwIkS63D6-b1lNZRhUlZ2nkmjpXtQVGHPvYq_ug,16595
15
16
  tests/test_storage_client_factory.py,sha256=lZr7SRY4rpzQB-ZgG7gbjPF2Pcde55nwweumSJT7Yd0,2363
16
17
  transfer_queue/__init__.py,sha256=68c0sBfqHPqTa7OdzO4sAZB52XvwtjpwLqP9BWAh4fA,1535
17
18
  transfer_queue/client.py,sha256=vH9stFyDCXtLYujdNVYjz815NjXbSUJOmxcFOZLIU1s,25831
18
- transfer_queue/controller.py,sha256=pq5OwXbdFILOZD3IKR39KF-Y9B8yVLmfsztzoQ2EvMk,49839
19
- transfer_queue/metadata.py,sha256=zqxAOrFj2uRCC8jPWxkLjOR7hJuiY3AxF2FKYe3oxNU,21515
19
+ transfer_queue/controller.py,sha256=uIhgSEdHQDfmQicqdPXnyXQuJ6Cn5AH3Wyue0SlYdrY,50853
20
+ transfer_queue/metadata.py,sha256=W71tN_-AVfizFDFM-2mIWt1RYvIkN9J1pwgUwBeDPpo,22664
20
21
  transfer_queue/sampler/__init__.py,sha256=1oauDy2Dwb5GXhKi7tl5DWAHv8i4t2MQK1S4U36Sy4g,788
21
22
  transfer_queue/sampler/base.py,sha256=wFti4dNJb3YArYpGzxA_YDfyUTdTG8wVz6HclPDyZPw,3299
22
23
  transfer_queue/sampler/grpo_group_n_sampler.py,sha256=Kq3hGAz8mboBNvw4Dj0P8lP6Qs8TDojx81fxSh57w28,6566
23
24
  transfer_queue/sampler/sequential_sampler.py,sha256=TY0eB-uFLUskwoNMgu3AvuF4G2KDkgjOkrlXZHy4Pls,2780
24
25
  transfer_queue/storage/__init__.py,sha256=559q9ZOMLLhHXil5-iY3aLPnACoJLnZnKf-E0lvpQdk,978
25
- transfer_queue/storage/simple_backend.py,sha256=51M5QOXood6D5Sojh135L3jPHvXgzl_TcGLKR02mIiw,18988
26
+ transfer_queue/storage/simple_backend.py,sha256=3bGqalWk0FX7SYjSxEWQeH1j-XAl_6heFWJvRbTgaII,19358
26
27
  transfer_queue/storage/clients/__init__.py,sha256=WCa6pcijAixpopvelkZZ9ZRTwF_P3fYMmSEBb04CQZ4,915
27
28
  transfer_queue/storage/clients/base.py,sha256=xXd9JBeTmW8tN4wsPocHhW-ERUEzx2YyYHZrtuQQIdI,690
28
29
  transfer_queue/storage/clients/factory.py,sha256=lPOG8oMAgaTbrzkogcOULPJnGywa0F-m4vskkOQZhnU,2137
@@ -30,15 +31,16 @@ transfer_queue/storage/clients/yuanrong_client.py,sha256=MskYioa0BHHGRqzYbAg9DCn
30
31
  transfer_queue/storage/managers/__init__.py,sha256=bkgaIN4Xa3IF26JJt4BK4bqct0SESdwA32wX5SLnuY0,959
31
32
  transfer_queue/storage/managers/base.py,sha256=iSBethCS-pq0tHcRsolXtGD0V_0PtBmInNF4Gi-flfw,21628
32
33
  transfer_queue/storage/managers/factory.py,sha256=58kp2mCKz1K8Ea7RWMsWxdDhN3y4ZhgE-G647AKq7-I,1752
33
- transfer_queue/storage/managers/simple_backend_manager.py,sha256=vTOfc3Y163km-r-lv4fQ_KlUzTTPD3LBUT1E7dCcPcs,27759
34
+ transfer_queue/storage/managers/simple_backend_manager.py,sha256=8F7fW7Z3SDeR22KCZqnFWPyC4qil5w_sp7QXHfvJsfY,28015
34
35
  transfer_queue/storage/managers/yuanrong_manager.py,sha256=RsCmVVDNcTaMU9J9vwN1gu0-srdBCWG-W3Q_Si91uio,1250
35
36
  transfer_queue/utils/__init__.py,sha256=vki-5RVaRBKxVc6Q7XPQox3VNPio2DvJYvRz0SZtu-w,586
37
+ transfer_queue/utils/perf_utils.py,sha256=WUl8AW9eHS5P9G3zq8g52MgMsyZqTqFXuBkXyBLmBLc,4100
36
38
  transfer_queue/utils/serial_utils.py,sha256=J94wrNKVEJtZg22o7GByMs9e_UuwOgRqt1faC5Sy7DY,6048
37
39
  transfer_queue/utils/utils.py,sha256=EE5S8YtyLNduohj1egKLHQlG4K2nrN-yAa8klBx9Nro,4846
38
- transfer_queue/utils/zmq_utils.py,sha256=gsu9xbfxrEXWihxsyW-3hhPcmUIQL0G0eZgJuzjF8gY,10590
39
- transfer_queue/version/version,sha256=_j45IFlkEXFXEmud89gRme3qSSoOPxu3Gk5uOxHL9eo,11
40
- transferqueue-0.1.4.dev0.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
41
- transferqueue-0.1.4.dev0.dist-info/METADATA,sha256=hRPOl4vOhcBv86t4fjMicXkZR5UnrmQwyw4ngvETTmw,19502
42
- transferqueue-0.1.4.dev0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
43
- transferqueue-0.1.4.dev0.dist-info/top_level.txt,sha256=6qfRszcN5Zyq8fWzDWI_wDo9N3Dg8k-8CsXeMLkwuXo,120
44
- transferqueue-0.1.4.dev0.dist-info/RECORD,,
40
+ transfer_queue/utils/zmq_utils.py,sha256=ecJO1GV_AEAAKnnts-0t7jl19j9jUpBT07N6ZSR8op0,10730
41
+ transfer_queue/version/version,sha256=BEnC3jt-HrAwaHIIQhet48H4zzl05lM_-XlEH_IAuRc,11
42
+ transferqueue-0.1.4.dev1.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
43
+ transferqueue-0.1.4.dev1.dist-info/METADATA,sha256=pUiuMKBnGXQUsrtba_G1VAllyhJr64aCOuxyt3UMyQo,19502
44
+ transferqueue-0.1.4.dev1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
45
+ transferqueue-0.1.4.dev1.dist-info/top_level.txt,sha256=4MQO9VzdR-IUYG4xAidtwDNiWECIQZ_zx0G5KflYJkE,131
46
+ transferqueue-0.1.4.dev1.dist-info/RECORD,,
@@ -5,3 +5,4 @@ serial_profiling_demo
5
5
  serial_profiling_demo_nested_non_continues_test
6
6
  tests
7
7
  transfer_queue
8
+ verify_fix
verify_fix.py ADDED
@@ -0,0 +1,109 @@
1
+ #!/usr/bin/env python
2
+ """
3
+ 验证脚本:测试单元素nested tensor的序列化/反序列化修复
4
+
5
+ 此脚本验证了在TQ_ZERO_COPY_SERIALIZATION=True时,
6
+ 序列化只有1个tensor的nested tensor能够正确区分于普通tensor。
7
+ """
8
+
9
+ import os
10
+ import torch
11
+ from tensordict import TensorDict
12
+
13
+ # 启用零拷贝序列化
14
+ os.environ["TQ_ZERO_COPY_SERIALIZATION"] = "True"
15
+
16
+ from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType
17
+
18
+ def test_single_nested_tensor_fix():
19
+ """验证单元素nested tensor修复"""
20
+ print("=" * 80)
21
+ print("测试:单元素nested tensor序列化/反序列化修复")
22
+ print("=" * 80)
23
+
24
+ # 创建单元素nested tensor和普通tensor
25
+ single_nested = torch.nested.as_nested_tensor([torch.randn(4, 3)], layout=torch.strided)
26
+ normal_tensor = torch.randn(1, 4, 3)
27
+
28
+ print("\n1. 创建测试数据:")
29
+ print(f" - 单元素nested tensor: {single_nested.shape}, is_nested={single_nested.is_nested}")
30
+ print(f" - 普通tensor: {normal_tensor.shape}, is_nested={normal_tensor.is_nested}")
31
+
32
+ # 创建TensorDict
33
+ td = TensorDict(
34
+ {
35
+ "single_nested_tensor": single_nested,
36
+ "normal_tensor": normal_tensor,
37
+ },
38
+ batch_size=1,
39
+ )
40
+
41
+ print("\n2. 创建ZMQMessage并序列化:")
42
+ msg = ZMQMessage(
43
+ request_type=ZMQRequestType.PUT_DATA,
44
+ sender_id="test_sender",
45
+ receiver_id="test_receiver",
46
+ body={"data": td},
47
+ )
48
+
49
+ # 序列化
50
+ serialized_data = msg.serialize()
51
+ print(f" - 序列化完成,数据列表长度: {len(serialized_data)}")
52
+
53
+ # 反序列化
54
+ print("\n3. 反序列化数据:")
55
+ decoded_msg = ZMQMessage.deserialize(serialized_data)
56
+
57
+ print(f" - 反序列化完成")
58
+ print(f" - decoded_msg.body['data']['single_nested_tensor'].is_nested = {decoded_msg.body['data']['single_nested_tensor'].is_nested}")
59
+ print(f" - decoded_msg.body['data']['normal_tensor'].is_nested = {decoded_msg.body['data']['normal_tensor'].is_nested}")
60
+
61
+ # 验证结果
62
+ print("\n4. 验证结果:")
63
+ success = True
64
+
65
+ # 检查单元素nested tensor
66
+ if decoded_msg.body["data"]["single_nested_tensor"].is_nested:
67
+ print(" ✓ 单元素nested tensor正确保持为nested类型")
68
+ else:
69
+ print(" ✗ 单元素nested tensor错误地变成了普通tensor类型")
70
+ success = False
71
+
72
+ # 检查普通tensor
73
+ if not decoded_msg.body["data"]["normal_tensor"].is_nested:
74
+ print(" ✓ 普通tensor正确保持为普通tensor类型")
75
+ else:
76
+ print(" ✗ 普通tensor错误地变成了nested类型")
77
+ success = False
78
+
79
+ # 检查数据内容
80
+ import torch
81
+ if torch.allclose(
82
+ decoded_msg.body["data"]["single_nested_tensor"][0],
83
+ single_nested[0]
84
+ ):
85
+ print(" ✓ 单元素nested tensor数据内容正确")
86
+ else:
87
+ print(" ✗ 单元素nested tensor数据内容不正确")
88
+ success = False
89
+
90
+ if torch.allclose(
91
+ decoded_msg.body["data"]["normal_tensor"],
92
+ normal_tensor
93
+ ):
94
+ print(" ✓ 普通tensor数据内容正确")
95
+ else:
96
+ print(" ✗ 普通tensor数据内容不正确")
97
+ success = False
98
+
99
+ print("\n" + "=" * 80)
100
+ if success:
101
+ print("✓ 所有测试通过!修复有效。")
102
+ else:
103
+ print("✗ 测试失败!修复可能存在问题。")
104
+ print("=" * 80)
105
+
106
+ return success
107
+
108
+ if __name__ == "__main__":
109
+ test_single_nested_tensor_fix()