TransferQueue 0.1.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. recipe/simple_use_case/async_demo.py +331 -0
  2. recipe/simple_use_case/sync_demo.py +220 -0
  3. tests/test_async_simple_storage_manager.py +339 -0
  4. tests/test_client.py +423 -0
  5. tests/test_controller.py +274 -0
  6. tests/test_controller_data_partitions.py +513 -0
  7. tests/test_kv_storage_manager.py +92 -0
  8. tests/test_put.py +327 -0
  9. tests/test_samplers.py +492 -0
  10. tests/test_serial_utils_on_cpu.py +202 -0
  11. tests/test_simple_storage_unit.py +443 -0
  12. tests/test_storage_client_factory.py +45 -0
  13. transfer_queue/__init__.py +48 -0
  14. transfer_queue/client.py +611 -0
  15. transfer_queue/controller.py +1187 -0
  16. transfer_queue/metadata.py +460 -0
  17. transfer_queue/sampler/__init__.py +19 -0
  18. transfer_queue/sampler/base.py +74 -0
  19. transfer_queue/sampler/grpo_group_n_sampler.py +157 -0
  20. transfer_queue/sampler/sequential_sampler.py +75 -0
  21. transfer_queue/storage/__init__.py +25 -0
  22. transfer_queue/storage/clients/__init__.py +24 -0
  23. transfer_queue/storage/clients/base.py +22 -0
  24. transfer_queue/storage/clients/factory.py +55 -0
  25. transfer_queue/storage/clients/yuanrong_client.py +118 -0
  26. transfer_queue/storage/managers/__init__.py +23 -0
  27. transfer_queue/storage/managers/base.py +460 -0
  28. transfer_queue/storage/managers/factory.py +43 -0
  29. transfer_queue/storage/managers/simple_backend_manager.py +611 -0
  30. transfer_queue/storage/managers/yuanrong_manager.py +18 -0
  31. transfer_queue/storage/simple_backend.py +451 -0
  32. transfer_queue/utils/__init__.py +13 -0
  33. transfer_queue/utils/serial_utils.py +240 -0
  34. transfer_queue/utils/utils.py +132 -0
  35. transfer_queue/utils/zmq_utils.py +170 -0
  36. transfer_queue/version/version +1 -0
  37. transferqueue-0.1.1.dev0.dist-info/METADATA +327 -0
  38. transferqueue-0.1.1.dev0.dist-info/RECORD +41 -0
  39. transferqueue-0.1.1.dev0.dist-info/WHEEL +5 -0
  40. transferqueue-0.1.1.dev0.dist-info/licenses/LICENSE +202 -0
  41. transferqueue-0.1.1.dev0.dist-info/top_level.txt +4 -0
transfer_queue/utils/serial_utils.py
@@ -0,0 +1,240 @@
+ # Copyright 2025 The TransferQueue Team
+ # Copyright 2025 The vLLM project
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # This implementation is inspired by https://github.com/vllm-project/vllm/blob/main/vllm/v1/serial_utils.py
+
+ import os
+ import pickle
+ from collections.abc import Sequence
+ from inspect import isclass
+ from types import FunctionType
+ from typing import Any, Optional, TypeAlias
+
+ import cloudpickle
+ import torch
+ import zmq
+ from msgspec import msgpack
+ from tensordict import NonTensorData, TensorDict
+
+ TQ_MSGPACK_ZERO_COPY_THRESHOLD = int(os.environ.get("TQ_MSGPACK_ZERO_COPY_THRESHOLD", 256))
+ CUSTOM_TYPE_PICKLE = 1
+ CUSTOM_TYPE_CLOUDPICKLE = 2
+ CUSTOM_TYPE_RAW_VIEW = 3
+
+ bytestr: TypeAlias = bytes | bytearray | memoryview | zmq.Frame
+ tensorenc = tuple[str, tuple[int, ...], int | memoryview]
+
+
+ class MsgpackEncoder:
+     """Encoder with custom torch tensor and numpy array serialization.
+
+     Note that unlike vanilla `msgspec` Encoders, this interface is generally
+     not thread-safe when encoding tensors / numpy arrays.
+
+     By default, arrays below 256B are serialized inline. Larger arrays are sent
+     via dedicated messages. Note that this is a per-tensor limit.
+     """
+
+     def __init__(self, size_threshold: Optional[int] = None):
+         if size_threshold is None:
+             size_threshold = TQ_MSGPACK_ZERO_COPY_THRESHOLD
+         self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
+         # This is used as a local stash of buffers that we can then access from
+         # our custom `msgspec` hook, `enc_hook`. We don't have a way to
+         # pass custom data to the hook otherwise.
+         self.aux_buffers: Optional[list[bytestr]] = None
+         self.size_threshold = size_threshold
+
+     def encode(self, obj: Any) -> Sequence[bytestr]:
+         try:
+             self.aux_buffers = bufs = [b""]
+             bufs[0] = self.encoder.encode(obj)
+             # This `bufs` list allows us to collect direct pointers to backing
+             # buffers of tensors and np arrays, and return them along with the
+             # top-level encoded buffer instead of copying their data into the
+             # new buffer.
+             return bufs
+         finally:
+             self.aux_buffers = None
+
+     def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
+         try:
+             self.aux_buffers = [buf]
+             bufs = self.aux_buffers
+             self.encoder.encode_into(obj, buf)
+             return bufs
+         finally:
+             self.aux_buffers = None
+
+     def enc_hook(self, obj: Any) -> Any:
+         if isinstance(obj, TensorDict):
+             return self._encode_tensordict(obj)
+
+         if isinstance(obj, torch.Tensor):
+             return self._encode_tensor(obj)
+
+         if isinstance(obj, FunctionType):
+             # `pickle` is generally faster than cloudpickle, but can have
+             # problems serializing methods.
+             return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))
+
+         return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))
+
+     def _encode_tensordict(self, obj: TensorDict) -> tuple[tuple[int, ...], Optional[str], dict[str, tuple[str, Any]]]:
+         assert self.aux_buffers is not None
+         encoded_items: dict[str, tuple[str, Any]] = {}
+         for k, v in obj.items():
+             if isinstance(v, torch.Tensor):
+                 encoded_items[k] = ("tensor", self._encode_tensor(v))
+             # elif isinstance(v, NonTensorStack):
+             #     encoded_items[k] = ("non_tensor_stack", self._encode_non_tensor_stack(v))
+             elif isinstance(v, NonTensorData):
+                 encoded_items[k] = ("non_tensor_data", self._encode_non_tensor_data(v))
+             else:
+                 data = len(self.aux_buffers)
+                 self.aux_buffers.append(pickle.dumps(v, protocol=pickle.HIGHEST_PROTOCOL))
+                 encoded_items[k] = ("other", data)
+         batch_size = tuple(obj.batch_size)
+         device = str(obj.device) if obj.device is not None else None
+         return batch_size, device, encoded_items
+
+     def _encode_tensor(self, obj: torch.Tensor) -> tuple[str, list[tensorenc]] | tensorenc:
+         if not obj.is_nested:
+             return self._encode_single_tensor(obj)
+         else:
+             layout = str(obj.layout).removeprefix("torch.")
+             data = [self._encode_single_tensor(tensor) for tensor in obj.unbind()]
+             return layout, data
+
+     def _encode_single_tensor(self, obj: torch.Tensor) -> tensorenc:
+         assert self.aux_buffers is not None
+         # view the tensor as a contiguous 1D array of bytes
+         arr = obj.flatten().contiguous().view(torch.uint8).numpy()
+         if obj.nbytes < self.size_threshold:
+             # Smaller tensors are encoded inline, just like ndarrays.
+             data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
+         else:
+             # Otherwise encode index of backing buffer to avoid copy.
+             data = len(self.aux_buffers)
+             self.aux_buffers.append(arr.data)
+         dtype = str(obj.dtype).removeprefix("torch.")
+         return dtype, obj.shape, data
+
+     def _encode_non_tensor_data(self, obj: NonTensorData) -> tuple[tuple[int, ...], Optional[str], int]:
+         assert self.aux_buffers is not None
+         batch_size = tuple(obj.batch_size)
+         device = str(obj.device) if obj.device is not None else None
+         data = len(self.aux_buffers)
+         self.aux_buffers.append(pickle.dumps(obj.data, protocol=pickle.HIGHEST_PROTOCOL))
+         return batch_size, device, data
+
+
+ class MsgpackDecoder:
+     """Decoder with custom torch tensor and numpy array serialization.
+
+     Note that unlike vanilla `msgspec` Decoders, this interface is generally
+     not thread-safe when decoding tensors / numpy arrays.
+     """
+
+     def __init__(self, t: Optional[Any] = None):
+         args = () if t is None else (t,)
+         self.decoder = msgpack.Decoder(*args, ext_hook=self.ext_hook, dec_hook=self.dec_hook)
+         self.aux_buffers: Sequence[bytestr] = ()
+
+     def decode(self, bufs: bytestr | Sequence[bytestr]) -> Any:
+         if isinstance(bufs, bytestr):
+             return self.decoder.decode(bufs)
+
+         self.aux_buffers = bufs
+         try:
+             return self.decoder.decode(bufs[0])  # type: ignore[index]
+         finally:
+             self.aux_buffers = ()
+
+     def dec_hook(self, t: type, obj: Any) -> Any:
+         # Given native types in `obj`, convert to type `t`.
+         if isclass(t):
+             if issubclass(t, TensorDict):
+                 return self._decode_tensordict(obj)
+             if issubclass(t, torch.Tensor):
+                 return self._decode_tensor(obj)
+         return obj
+
+     def _decode_tensordict(self, arr: Any) -> TensorDict:
+         batch_size, device, encoded_items = arr
+         decoded_items: dict[str, Any] = {}
+
+         for k, (v_type, v) in encoded_items.items():
+             if v_type == "tensor":
+                 decoded_items[k] = self._decode_tensor(v)
+             # elif v_type == "non_tensor_stack":
+             #     decoded_items[k] = self._decode_non_tensor_stack(v)
+             elif v_type == "non_tensor_data":
+                 decoded_items[k] = self._decode_non_tensor_data(v)
+             elif v_type == "other":
+                 decoded_items[k] = pickle.loads(self.aux_buffers[v])
+
+         batch_size = torch.Size(batch_size)
+         torch_device = torch.device(device) if device is not None else None
+
+         return TensorDict(source=decoded_items, batch_size=batch_size, device=torch_device)
+
+     def _decode_tensor(self, arr: Any) -> torch.Tensor:
+         if len(arr) == 3:
+             # decode single tensor
+             return self._decode_single_tensor(arr)
+         elif len(arr) == 2:
+             # decode nested tensor
+             layout, data = arr
+             torch_layout = getattr(torch, layout)
+             return torch.nested.as_nested_tensor(
+                 [self._decode_single_tensor(tensor) for tensor in data], layout=torch_layout
+             )
+         else:
+             raise ValueError(f"Invalid tensor encoding format, expected length 2 or 3, got {len(arr)}")
+
+     def _decode_single_tensor(self, arr: Any) -> torch.Tensor:
+         dtype, shape, data = arr
+         # Copy from inline representation, to decouple the memory storage
+         # of the message from the original buffer. And also make Torch
+         # not complain about a readonly memoryview.
+         buffer = self.aux_buffers[data] if isinstance(data, int) else bytearray(data)
+         torch_dtype = getattr(torch, dtype)
+         assert isinstance(torch_dtype, torch.dtype)
+         if not buffer:  # torch.frombuffer doesn't like empty buffers
+             assert 0 in shape
+             return torch.empty(shape, dtype=torch_dtype)
+         # Create uint8 array
+         arr = torch.frombuffer(buffer, dtype=torch.uint8)
+         # Convert back to proper shape & type
+         return arr.view(torch_dtype).view(shape)
+
+     def _decode_non_tensor_data(self, arr: Any) -> NonTensorData:
+         batch_size, device, data = arr
+         buffer = self.aux_buffers[data]
+         batch_size = torch.Size(batch_size)
+         torch_device = torch.device(device) if device is not None else None
+         non_tensor_data = pickle.loads(buffer)
+         return NonTensorData(data=non_tensor_data, batch_size=batch_size, device=torch_device)
+
+     def ext_hook(self, code: int, data: memoryview) -> Any:
+         if code == CUSTOM_TYPE_RAW_VIEW:
+             return data
+         if code == CUSTOM_TYPE_PICKLE:
+             return pickle.loads(data)
+         if code == CUSTOM_TYPE_CLOUDPICKLE:
+             return cloudpickle.loads(data)
+
+         raise NotImplementedError(f"Extension type code {code} is not supported")
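The encoder/decoder pair above is built for multipart transport: encode() returns the msgpack payload plus zero-copy views of any tensor at or above the (default 256-byte) threshold, and decode() accepts that same frame list. The following round-trip sketch is not code from this wheel; the tensor names and the note about pyzmq's send_multipart/recv_multipart framing are illustrative assumptions.

```python
import torch
from tensordict import TensorDict

from transfer_queue.utils.serial_utils import MsgpackDecoder, MsgpackEncoder

encoder = MsgpackEncoder()            # tensors >= 256 bytes (default) are referenced, not copied
decoder = MsgpackDecoder(TensorDict)  # the expected top-level type drives dec_hook

# A hypothetical batch; "obs" is large enough to be shipped as a separate frame,
# "step" is small enough to be inlined as a RAW_VIEW extension.
batch = TensorDict(
    {"obs": torch.randn(4, 1024), "step": torch.arange(4)},
    batch_size=[4],
)

# encode() returns a list of buffers: frames[0] is the msgpack payload, the rest are
# zero-copy views of large tensor storage. These frames map naturally onto pyzmq's
# socket.send_multipart(frames) on the sender and socket.recv_multipart() on the receiver.
frames = encoder.encode(batch)

restored = decoder.decode(frames)
assert torch.equal(restored["obs"], batch["obs"])
assert restored.batch_size == torch.Size([4])
```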
transfer_queue/utils/utils.py
@@ -0,0 +1,132 @@
+ # Copyright 2025 The TransferQueue Team
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from contextlib import contextmanager
+ from enum import Enum
+ from typing import Optional
+
+ import psutil
+ import ray
+ import torch
+
+
+ class ExplicitEnum(str, Enum):
+     """
+     Enum with more explicit error message for missing values.
+     """
+
+     @classmethod
+     def _missing_(cls, value):
+         raise ValueError(
+             f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
+         )
+
+
+ class TransferQueueRole(ExplicitEnum):
+     CONTROLLER = "TransferQueueController"
+     STORAGE = "TransferQueueStorage"
+     CLIENT = "TransferQueueClient"
+
+
+ # production_status enum: 0: not produced, 1: ready for consume, 2: consumed
+ class ProductionStatus(ExplicitEnum):
+     NOT_PRODUCED = 0
+     READY_FOR_CONSUME = 1
+     CONSUMED = 2
+
+
+ def get_placement_group(num_ray_actors: int, num_cpus_per_actor: int = 1):
+     """
+     Create a placement group with SPREAD strategy for Ray actors.
+
+     Args:
+         num_ray_actors (int): Number of Ray actors to create.
+         num_cpus_per_actor (int): Number of CPUs to allocate per actor.
+
+     Returns:
+         placement_group: The created placement group.
+     """
+     bundle = {"CPU": num_cpus_per_actor}
+     placement_group = ray.util.placement_group([bundle for _ in range(num_ray_actors)], strategy="SPREAD")
+     ray.get(placement_group.ready())
+     return placement_group
+
+
+ def sequential_sampler(
+     ready_for_consume_idx: list[int],
+     batch_size: int,
+     get_n_samples: bool,
+     n_samples_per_prompt: int,
+ ) -> list[int]:
+     """
+     Sequentially samples a batch of indices from global indexes ready_for_consume_idx.
+
+     Args:
+         ready_for_consume_idx: A sorted list of available indices for sampling.
+             - When get_n_samples=True:
+               Expected to be grouped by prompts, e.g.,
+               [0,1,2,3, 8,9,10,11, 12,13,14,15] (3 groups of 4 samples each)
+             - When get_n_samples=False:
+               Can be any ordered list, e.g., [0,3,5,6,7,8]
+         batch_size: Total number of samples to return
+         get_n_samples: Flag indicating the sampling mode
+         n_samples_per_prompt: Number of samples per prompt (used when get_n_samples=True)
+
+     Returns:
+         list[int]: Sequentially sampled indices of length batch_size
+     """
+     if get_n_samples:
+         assert len(ready_for_consume_idx) % n_samples_per_prompt == 0
+         assert batch_size % n_samples_per_prompt == 0
+         batch_size_n_samples = batch_size // n_samples_per_prompt
+
+         group_ready_for_consume_idx = torch.tensor(ready_for_consume_idx, dtype=torch.int).view(
+             -1, n_samples_per_prompt
+         )
+
+         sampled_indexes = group_ready_for_consume_idx[list(range(batch_size_n_samples))].flatten().tolist()
+     else:
+         sampled_indexes = [int(ready_for_consume_idx[i]) for i in range(batch_size)]
+     return sampled_indexes
+
+
+ @contextmanager
+ def limit_pytorch_auto_parallel_threads(target_num_threads: Optional[int] = None):
+     """Prevent PyTorch from overdoing automatic parallelism during torch.stack() operations."""
+     pytorch_current_num_threads = torch.get_num_threads()
+     logical_cores = psutil.cpu_count(logical=True)
+     physical_cores = psutil.cpu_count(logical=False)
+
+     if target_num_threads is None:
+         # auto determine target_num_threads
+         if physical_cores >= 16:
+             target_num_threads = 16
+         else:
+             target_num_threads = physical_cores
+
+     if target_num_threads > logical_cores:
+         raise RuntimeError(
+             f"target_num_threads {target_num_threads} should not exceed total logical CPU cores {logical_cores}"
+         )
+
+     if pytorch_current_num_threads <= target_num_threads:
+         # No need to change settings
+         yield
+     else:
+         torch.set_num_threads(target_num_threads)
+         try:
+             yield
+         finally:
+             # Restore the original number of threads
+             torch.set_num_threads(pytorch_current_num_threads)
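The two modes of sequential_sampler and the intended use of limit_pytorch_auto_parallel_threads are easiest to see in a short driver. This is an illustrative sketch reusing the docstring's example indices, not code shipped in the package.

```python
import torch

from transfer_queue.utils.utils import limit_pytorch_auto_parallel_threads, sequential_sampler

# Grouped mode: the ready list holds 3 prompt groups of 4 samples each; asking for a
# batch of 8 returns the first two whole groups without splitting any group.
grouped = [0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15]
print(sequential_sampler(grouped, batch_size=8, get_n_samples=True, n_samples_per_prompt=4))
# -> [0, 1, 2, 3, 8, 9, 10, 11]

# Flat mode: simply the first batch_size ready indices.
print(sequential_sampler([0, 3, 5, 6, 7, 8], batch_size=4, get_n_samples=False, n_samples_per_prompt=1))
# -> [0, 3, 5, 6]

# The context manager caps intra-op threads (auto-selected here) around CPU-heavy
# collation and restores the previous setting on exit.
with limit_pytorch_auto_parallel_threads():
    stacked = torch.stack([torch.randn(128) for _ in range(64)])
```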
transfer_queue/utils/zmq_utils.py
@@ -0,0 +1,170 @@
+ # Copyright 2025 The TransferQueue Team
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import pickle
+ import socket
+ import time
+ from dataclasses import dataclass
+ from typing import Any, Optional
+ from uuid import uuid4
+
+ import psutil
+ import zmq
+
+ from transfer_queue.utils.utils import (
+     ExplicitEnum,
+     TransferQueueRole,
+ )
+
+
+ class ZMQRequestType(ExplicitEnum):
+     # HANDSHAKE
+     HANDSHAKE = "HANDSHAKE"  # TransferQueueStorageUnit -> TransferQueueController
+     HANDSHAKE_ACK = "HANDSHAKE_ACK"  # TransferQueueController -> TransferQueueStorageUnit
+
+     # DATA_OPERATION
+     GET_DATA = "GET"
+     PUT_DATA = "PUT"
+     GET_DATA_RESPONSE = "GET_DATA_RESPONSE"
+     PUT_DATA_RESPONSE = "PUT_DATA_RESPONSE"
+     CLEAR_DATA = "CLEAR_DATA"
+     CLEAR_DATA_RESPONSE = "CLEAR_DATA_RESPONSE"
+
+     PUT_GET_OPERATION_ERROR = "PUT_GET_OPERATION_ERROR"
+     PUT_GET_ERROR = "PUT_GET_ERROR"
+     PUT_ERROR = "PUT_ERROR"
+     GET_ERROR = "GET_ERROR"
+     CLEAR_DATA_ERROR = "CLEAR_DATA_ERROR"
+
+     # META_OPERATION
+     GET_META = "GET_META"
+     GET_META_RESPONSE = "GET_META_RESPONSE"
+     GET_CLEAR_META = "GET_CLEAR_META"
+     GET_CLEAR_META_RESPONSE = "GET_CLEAR_META_RESPONSE"
+     CLEAR_META = "CLEAR_META"
+     CLEAR_META_RESPONSE = "CLEAR_META_RESPONSE"
+
+     # CHECK_CONSUMPTION
+     CHECK_CONSUMPTION = "CHECK_CONSUMPTION"
+     CONSUMPTION_RESPONSE = "CONSUMPTION_RESPONSE"
+
+     # NOTIFY_DATA_UPDATE
+     NOTIFY_DATA_UPDATE = "NOTIFY_DATA_UPDATE"
+     NOTIFY_DATA_UPDATE_ACK = "NOTIFY_DATA_UPDATE_ACK"
+     NOTIFY_DATA_UPDATE_ERROR = "NOTIFY_DATA_UPDATE_ERROR"
+
+
+ class ZMQServerInfo:
+     def __init__(self, role: TransferQueueRole, id: str, ip: str, ports: dict[str, str]):
+         self.role = role
+         self.id = id
+         self.ip = ip
+         self.ports = ports
+
+     def to_addr(self, port_name: str) -> str:
+         return f"tcp://{self.ip}:{self.ports[port_name]}"
+
+     def to_dict(self):
+         return {
+             "role": self.role,
+             "id": self.id,
+             "ip": self.ip,
+             "ports": self.ports,
+         }
+
+     def __str__(self) -> str:
+         return f"ZMQServerInfo(role={self.role}, id={self.id}, ip={self.ip}, ports={self.ports})"
+
+
+ @dataclass
+ class ZMQMessage:
+     request_type: ZMQRequestType
+     sender_id: str
+     receiver_id: str | None
+     body: dict[str, Any]
+     request_id: str
+     timestamp: float
+
+     @classmethod
+     def create(
+         cls,
+         request_type: ZMQRequestType,
+         sender_id: str,
+         body: dict[str, Any],
+         receiver_id: Optional[str] = None,
+     ) -> "ZMQMessage":
+         return cls(
+             request_type=request_type,
+             sender_id=sender_id,
+             receiver_id=receiver_id,
+             body=body,
+             request_id=str(uuid4().hex[:8]),
+             timestamp=time.time(),
+         )
+
+     def serialize(self) -> bytes:
+         """Using pickle to serialize ZMQMessage objects"""
+         return pickle.dumps(self)
+
+     @classmethod
+     def deserialize(cls, data: bytes | list[bytes]):
+         """Using pickle to deserialize ZMQMessage objects"""
+         if isinstance(data, list):
+             # Process multiple byte streams by deserializing each in sequence
+             result = []
+             for d in data:
+                 result.append(pickle.loads(d))
+             return result
+         else:
+             # Single byte stream case
+             return pickle.loads(data)
+
+
+ def get_free_port() -> int:
+     with socket.socket() as sock:
+         sock.bind(("", 0))
+         return sock.getsockname()[1]
+
+
+ def create_zmq_socket(
+     ctx: zmq.Context,
+     socket_type: Any,
+     identity: Optional[bytes] = None,
+ ) -> zmq.Socket:
+     mem = psutil.virtual_memory()
+     socket = ctx.socket(socket_type)
+
+     # Calculate buffer size based on system memory
+     total_mem = mem.total / 1024**3
+     available_mem = mem.available / 1024**3
+     # For systems with substantial memory (>32GB total, >16GB available):
+     # - Set a large 0.5GB buffer to improve throughput
+     # For systems with less memory:
+     # - Use system default (-1) to avoid excessive memory consumption
+     if total_mem > 32 and available_mem > 16:
+         buf_size = int(0.5 * 1024**3)  # 0.5GB in bytes
+     else:
+         buf_size = -1  # Use system default buffer size
+
+     if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER):
+         socket.setsockopt(zmq.RCVHWM, 0)
+         socket.setsockopt(zmq.RCVBUF, buf_size)
+
+     if socket_type in (zmq.PUSH, zmq.DEALER, zmq.ROUTER):
+         socket.setsockopt(zmq.SNDHWM, 0)
+         socket.setsockopt(zmq.SNDBUF, buf_size)
+
+     if identity is not None:
+         socket.setsockopt(zmq.IDENTITY, identity)
+     return socket
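Putting these pieces together, a ZMQMessage envelope is pickled and shipped over sockets built by create_zmq_socket, with ZMQServerInfo carrying the named ports. The sketch below pairs a DEALER with a ROUTER purely as an illustration; the socket pattern, the port name "request", and the ids are assumptions, not the package's actual wiring (which lives in the controller/storage/client modules).

```python
import zmq

from transfer_queue.utils.utils import TransferQueueRole
from transfer_queue.utils.zmq_utils import (
    ZMQMessage,
    ZMQRequestType,
    ZMQServerInfo,
    create_zmq_socket,
    get_free_port,
)

ctx = zmq.Context()

# Advertise a single named port; ZMQServerInfo.to_addr() turns it into a tcp:// address.
server = ZMQServerInfo(
    role=TransferQueueRole.CONTROLLER,
    id="controller-0",
    ip="127.0.0.1",
    ports={"request": str(get_free_port())},
)

router = create_zmq_socket(ctx, zmq.ROUTER)  # server side
router.bind(server.to_addr("request"))

dealer = create_zmq_socket(ctx, zmq.DEALER, identity=b"client-0")  # client side
dealer.connect(server.to_addr("request"))

# Client -> server: a pickled ZMQMessage envelope in a single frame.
request = ZMQMessage.create(
    request_type=ZMQRequestType.HANDSHAKE,
    sender_id="client-0",
    body={"hello": "controller"},
)
dealer.send(request.serialize())

# Server side: ROUTER prepends the sender's identity frame to the payload.
identity, payload = router.recv_multipart()
message = ZMQMessage.deserialize(payload)
assert message.request_type is ZMQRequestType.HANDSHAKE

dealer.close()
router.close()
ctx.term()
```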
transfer_queue/version/version
@@ -0,0 +1 @@
+ 0.1.1.dev