TransferQueue 0.1.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- recipe/simple_use_case/async_demo.py +331 -0
- recipe/simple_use_case/sync_demo.py +220 -0
- tests/test_async_simple_storage_manager.py +339 -0
- tests/test_client.py +423 -0
- tests/test_controller.py +274 -0
- tests/test_controller_data_partitions.py +513 -0
- tests/test_kv_storage_manager.py +92 -0
- tests/test_put.py +327 -0
- tests/test_samplers.py +492 -0
- tests/test_serial_utils_on_cpu.py +202 -0
- tests/test_simple_storage_unit.py +443 -0
- tests/test_storage_client_factory.py +45 -0
- transfer_queue/__init__.py +48 -0
- transfer_queue/client.py +611 -0
- transfer_queue/controller.py +1187 -0
- transfer_queue/metadata.py +460 -0
- transfer_queue/sampler/__init__.py +19 -0
- transfer_queue/sampler/base.py +74 -0
- transfer_queue/sampler/grpo_group_n_sampler.py +157 -0
- transfer_queue/sampler/sequential_sampler.py +75 -0
- transfer_queue/storage/__init__.py +25 -0
- transfer_queue/storage/clients/__init__.py +24 -0
- transfer_queue/storage/clients/base.py +22 -0
- transfer_queue/storage/clients/factory.py +55 -0
- transfer_queue/storage/clients/yuanrong_client.py +118 -0
- transfer_queue/storage/managers/__init__.py +23 -0
- transfer_queue/storage/managers/base.py +460 -0
- transfer_queue/storage/managers/factory.py +43 -0
- transfer_queue/storage/managers/simple_backend_manager.py +611 -0
- transfer_queue/storage/managers/yuanrong_manager.py +18 -0
- transfer_queue/storage/simple_backend.py +451 -0
- transfer_queue/utils/__init__.py +13 -0
- transfer_queue/utils/serial_utils.py +240 -0
- transfer_queue/utils/utils.py +132 -0
- transfer_queue/utils/zmq_utils.py +170 -0
- transfer_queue/version/version +1 -0
- transferqueue-0.1.1.dev0.dist-info/METADATA +327 -0
- transferqueue-0.1.1.dev0.dist-info/RECORD +41 -0
- transferqueue-0.1.1.dev0.dist-info/WHEEL +5 -0
- transferqueue-0.1.1.dev0.dist-info/licenses/LICENSE +202 -0
- transferqueue-0.1.1.dev0.dist-info/top_level.txt +4 -0
|
@@ -0,0 +1,240 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
# Copyright 2025 The vLLM project
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# This implementation is inspired by https://github.com/vllm-project/vllm/blob/main/vllm/v1/serial_utils.py
|
|
17
|
+
|
|
18
|
+
import os
|
|
19
|
+
import pickle
|
|
20
|
+
from collections.abc import Sequence
|
|
21
|
+
from inspect import isclass
|
|
22
|
+
from types import FunctionType
|
|
23
|
+
from typing import Any, Optional, TypeAlias
|
|
24
|
+
|
|
25
|
+
import cloudpickle
|
|
26
|
+
import torch
|
|
27
|
+
import zmq
|
|
28
|
+
from msgspec import msgpack
|
|
29
|
+
from tensordict import NonTensorData, TensorDict
|
|
30
|
+
|
|
31
|
+
# Per-tensor size (in bytes) at or above which MsgpackEncoder returns the tensor
# payload as a separate out-of-band buffer instead of inlining it in the main
# msgpack buffer. Overridable through the environment variable of the same name.
TQ_MSGPACK_ZERO_COPY_THRESHOLD = int(os.getenv("TQ_MSGPACK_ZERO_COPY_THRESHOLD", "256"))

# msgpack extension type codes shared by MsgpackEncoder and MsgpackDecoder.
CUSTOM_TYPE_PICKLE = 1
CUSTOM_TYPE_CLOUDPICKLE = 2
CUSTOM_TYPE_RAW_VIEW = 3

# Any buffer-like object that can be produced by the encoder / fed to the decoder.
bytestr: TypeAlias = bytes | bytearray | memoryview | zmq.Frame
# Encoded single tensor: (dtype name, shape, inline data or aux-buffer index).
tensorenc = tuple[str, tuple[int, ...], int | memoryview]
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class MsgpackEncoder:
    """Encoder with custom torch tensor and TensorDict serialization.

    Note that unlike vanilla `msgspec` Encoders, this interface is generally
    not thread-safe when encoding tensors: out-of-band buffers are stashed on
    the instance (`self.aux_buffers`) for the duration of each call.

    Tensors smaller than `size_threshold` bytes (defaults to
    `TQ_MSGPACK_ZERO_COPY_THRESHOLD`) are serialized inline. Larger tensors are
    returned as separate buffers after the main msgpack buffer and referenced by
    index, so their data is not copied. Note that this is a per-tensor limit.
    """

    def __init__(self, size_threshold: Optional[int] = None):
        if size_threshold is None:
            size_threshold = TQ_MSGPACK_ZERO_COPY_THRESHOLD
        self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
        # This is used as a local stash of buffers that we can then access from
        # our custom `msgspec` hook, `enc_hook`. We don't have a way to
        # pass custom data to the hook otherwise.
        self.aux_buffers: Optional[list[bytestr]] = None
        self.size_threshold = size_threshold

    def encode(self, obj: Any) -> Sequence[bytestr]:
        """Encode `obj` into a list of buffers.

        Returns:
            A list whose first element is the msgpack-encoded top-level buffer.
            Any following elements are out-of-band payloads (views of large
            tensor storage, or pickled bytes) referenced by index from the first.
        """
        try:
            self.aux_buffers = bufs = [b""]
            bufs[0] = self.encoder.encode(obj)
            # This `bufs` list allows us to collect direct pointers to backing
            # buffers of tensors, and return them along with the top-level
            # encoded buffer instead of copying their data into the new buffer.
            return bufs
        finally:
            self.aux_buffers = None

    def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
        """Like `encode`, but writes the top-level msgpack data into `buf`.

        Returns:
            A list whose first element is `buf`, followed by any out-of-band
            buffers collected while encoding.
        """
        try:
            self.aux_buffers = bufs = [buf]
            self.encoder.encode_into(obj, buf)
            return bufs
        finally:
            self.aux_buffers = None

    def enc_hook(self, obj: Any) -> Any:
        """msgspec fallback for objects msgpack cannot encode natively."""
        if isinstance(obj, TensorDict):
            return self._encode_tensordict(obj)

        if isinstance(obj, torch.Tensor):
            return self._encode_tensor(obj)

        if isinstance(obj, FunctionType):
            # `pickle` is generally faster than cloudpickle, but can have
            # problems serializing functions such as lambdas and closures.
            return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))

        return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))

    def _encode_tensordict(self, obj: TensorDict) -> tuple[tuple[int, ...], Optional[str], dict[str, tuple[str, Any]]]:
        """Encode a TensorDict as (batch_size, device string or None, {key: (kind, payload)}).

        `kind` is "tensor", "non_tensor_data" or "other"; "other" values are
        pickled into an aux buffer and `payload` is that buffer's index.
        """
        assert self.aux_buffers is not None
        encoded_items: dict[str, tuple[str, Any]] = {}
        for k, v in obj.items():
            if isinstance(v, torch.Tensor):
                encoded_items[k] = ("tensor", self._encode_tensor(v))
            elif isinstance(v, NonTensorData):
                encoded_items[k] = ("non_tensor_data", self._encode_non_tensor_data(v))
            else:
                # Everything else (presumably including NonTensorStack, which has
                # no dedicated encoding) is pickled out-of-band.
                data = len(self.aux_buffers)
                self.aux_buffers.append(pickle.dumps(v, protocol=pickle.HIGHEST_PROTOCOL))
                encoded_items[k] = ("other", data)
        batch_size = tuple(obj.batch_size)
        device = str(obj.device) if obj.device is not None else None
        return batch_size, device, encoded_items

    def _encode_tensor(self, obj: torch.Tensor) -> tuple[str, list[tensorenc]] | tensorenc:
        """Encode a regular tensor as a 3-tuple, or a nested tensor as (layout, [3-tuples])."""
        if not obj.is_nested:
            return self._encode_single_tensor(obj)
        layout = str(obj.layout).removeprefix("torch.")
        data = [self._encode_single_tensor(tensor) for tensor in obj.unbind()]
        return layout, data

    def _encode_single_tensor(self, obj: torch.Tensor) -> tensorenc:
        """Encode one non-nested tensor as (dtype name, shape, inline Ext or aux index)."""
        assert self.aux_buffers is not None
        # `.numpy()` only accepts CPU tensors that do not require grad, so detach
        # and move to host first. Both calls return the tensor's own storage for a
        # plain CPU tensor, so the zero-copy path is unaffected. The decoder builds
        # tensors with `torch.frombuffer` (i.e. on CPU) either way.
        host = obj.detach().cpu()
        # view the tensor as a contiguous 1D array of bytes
        arr = host.flatten().contiguous().view(torch.uint8).numpy()
        if obj.nbytes < self.size_threshold:
            # Smaller tensors are encoded inline.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr.data)
        dtype = str(obj.dtype).removeprefix("torch.")
        return dtype, obj.shape, data

    def _encode_non_tensor_data(self, obj: NonTensorData) -> tuple[tuple[int, ...], Optional[str], int]:
        """Encode NonTensorData as (batch_size, device string or None, aux index of pickled data)."""
        assert self.aux_buffers is not None
        batch_size = tuple(obj.batch_size)
        device = str(obj.device) if obj.device is not None else None
        data = len(self.aux_buffers)
        self.aux_buffers.append(pickle.dumps(obj.data, protocol=pickle.HIGHEST_PROTOCOL))
        return batch_size, device, data
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
class MsgpackDecoder:
    """Decoder with custom torch tensor and TensorDict deserialization.

    Counterpart of `MsgpackEncoder`. Note that unlike vanilla `msgspec`
    Decoders, this interface is generally not thread-safe when decoding
    tensors: out-of-band buffers are stashed on the instance
    (`self.aux_buffers`) for the duration of each `decode` call.

    Security note: pickle / cloudpickle payloads are unpickled here, which can
    execute arbitrary code. Only decode data received from trusted peers.
    """

    def __init__(self, t: Optional[Any] = None):
        # `t` is an optional msgspec target type (e.g. TensorDict) for typed decoding.
        args = () if t is None else (t,)
        self.decoder = msgpack.Decoder(*args, ext_hook=self.ext_hook, dec_hook=self.dec_hook)
        self.aux_buffers: Sequence[bytestr] = ()

    def decode(self, bufs: bytestr | Sequence[bytestr]) -> Any:
        """Decode one buffer, or the buffer list produced by `MsgpackEncoder.encode`."""
        if isinstance(bufs, bytestr):
            # A lone buffer carries no out-of-band data.
            return self.decoder.decode(bufs)

        self.aux_buffers = bufs
        try:
            return self.decoder.decode(bufs[0])  # type: ignore[index]
        finally:
            self.aux_buffers = ()

    def dec_hook(self, t: type, obj: Any) -> Any:
        # Given native types in `obj`, convert to type `t`.
        if isclass(t):
            if issubclass(t, TensorDict):
                return self._decode_tensordict(obj)
            if issubclass(t, torch.Tensor):
                return self._decode_tensor(obj)
        return obj

    def _decode_tensordict(self, arr: Any) -> TensorDict:
        """Rebuild a TensorDict from (batch_size, device, {key: (kind, payload)}).

        Raises:
            ValueError: If an item carries an unknown encoding kind.
        """
        batch_size, device, encoded_items = arr
        decoded_items: dict[str, Any] = {}

        for k, (v_type, v) in encoded_items.items():
            if v_type == "tensor":
                decoded_items[k] = self._decode_tensor(v)
            elif v_type == "non_tensor_data":
                decoded_items[k] = self._decode_non_tensor_data(v)
            elif v_type == "other":
                decoded_items[k] = pickle.loads(self.aux_buffers[v])
            else:
                # Fail loudly instead of silently dropping the entry, which would
                # hide an encoder/decoder format mismatch.
                raise ValueError(f"Unknown TensorDict item encoding {v_type!r} for key {k!r}")

        batch_size = torch.Size(batch_size)
        torch_device = torch.device(device) if device is not None else None

        return TensorDict(source=decoded_items, batch_size=batch_size, device=torch_device)

    def _decode_tensor(self, arr: Any) -> torch.Tensor:
        """Decode a 3-tuple (regular tensor) or a (layout, [3-tuples]) pair (nested tensor)."""
        if len(arr) == 3:
            # decode single tensor
            return self._decode_single_tensor(arr)
        elif len(arr) == 2:
            # decode nested tensor
            layout, data = arr
            torch_layout = getattr(torch, layout)
            return torch.nested.as_nested_tensor(
                [self._decode_single_tensor(tensor) for tensor in data], layout=torch_layout
            )
        else:
            raise ValueError(f"Invalid tensor encoding format, expected length 2 or 3, got {len(arr)}")

    def _decode_single_tensor(self, arr: Any) -> torch.Tensor:
        """Decode (dtype name, shape, inline data or aux-buffer index) into a CPU tensor."""
        dtype, shape, data = arr
        # Inline data is copied into a bytearray to decouple the tensor from the
        # message buffer and to avoid handing Torch a read-only memoryview.
        # Aux buffers are wrapped without copying, so the tensor aliases that buffer.
        buffer = self.aux_buffers[data] if isinstance(data, int) else bytearray(data)
        torch_dtype = getattr(torch, dtype)
        assert isinstance(torch_dtype, torch.dtype)
        if not buffer:  # torch.frombuffer doesn't like empty buffers
            assert 0 in shape
            return torch.empty(shape, dtype=torch_dtype)
        # Create uint8 array
        arr = torch.frombuffer(buffer, dtype=torch.uint8)
        # Convert back to proper shape & type
        return arr.view(torch_dtype).view(shape)

    def _decode_non_tensor_data(self, arr: Any) -> NonTensorData:
        """Rebuild NonTensorData from (batch_size, device, aux index of pickled data)."""
        batch_size, device, data = arr
        buffer = self.aux_buffers[data]
        batch_size = torch.Size(batch_size)
        torch_device = torch.device(device) if device is not None else None
        non_tensor_data = pickle.loads(buffer)
        return NonTensorData(data=non_tensor_data, batch_size=batch_size, device=torch_device)

    def ext_hook(self, code: int, data: memoryview) -> Any:
        """Resolve the msgpack extension types emitted by `MsgpackEncoder`."""
        if code == CUSTOM_TYPE_RAW_VIEW:
            return data
        if code == CUSTOM_TYPE_PICKLE:
            return pickle.loads(data)
        if code == CUSTOM_TYPE_CLOUDPICKLE:
            return cloudpickle.loads(data)

        raise NotImplementedError(f"Extension type code {code} is not supported")
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from contextlib import contextmanager
|
|
16
|
+
from enum import Enum
|
|
17
|
+
from typing import Optional
|
|
18
|
+
|
|
19
|
+
import psutil
|
|
20
|
+
import ray
|
|
21
|
+
import torch
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ExplicitEnum(str, Enum):
    """String-mixin Enum whose failed lookups list every accepted value."""

    @classmethod
    def _missing_(cls, value):
        # Enum calls this when `cls(value)` matches no member; replace the
        # default terse error with one naming all valid values.
        accepted = list(cls._value2member_map_)
        raise ValueError(f"{value} is not a valid {cls.__name__}, please select one of {accepted}")
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class TransferQueueRole(ExplicitEnum):
    """Names of the TransferQueue component roles."""

    CONTROLLER = "TransferQueueController"  # metadata / coordination role
    STORAGE = "TransferQueueStorage"  # data storage role
    CLIENT = "TransferQueueClient"  # producer / consumer role
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class ProductionStatus(ExplicitEnum):
    """Production status of a sample: 0 = not produced, 1 = ready for consume, 2 = consumed.

    NOTE(review): ExplicitEnum mixes in ``str``, so the integer values below are
    coerced to the strings "0", "1" and "2". ``ProductionStatus("0")`` works but
    ``ProductionStatus(0)`` raises ValueError, and members compare unequal to the
    ints 0/1/2. Confirm this is intended before comparing against raw integers.
    """

    NOT_PRODUCED = 0  # not produced yet
    READY_FOR_CONSUME = 1  # produced and ready to be consumed
    CONSUMED = 2  # already consumed
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def get_placement_group(num_ray_actors: int, num_cpus_per_actor: int = 1):
    """
    Reserve one CPU bundle per Ray actor with the SPREAD strategy and wait for it.

    Args:
        num_ray_actors (int): Number of bundles to reserve (one per actor).
        num_cpus_per_actor (int): CPUs requested in each bundle.

    Returns:
        placement_group: The placement group, already ready (``ray.get`` on
            ``ready()`` has returned).
    """
    # The same bundle dict is repeated for every actor.
    bundles = [{"CPU": num_cpus_per_actor}] * num_ray_actors
    pg = ray.util.placement_group(bundles, strategy="SPREAD")
    ray.get(pg.ready())
    return pg
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def sequential_sampler(
    ready_for_consume_idx: list[int],
    batch_size: int,
    get_n_samples: bool,
    n_samples_per_prompt: int,
) -> list[int]:
    """
    Sequentially samples a batch of indices from global indexes ready_for_consume_idx.

    Args:
        ready_for_consume_idx: A sorted list of available indices for sampling.
            - When get_n_samples=True:
              Expected to be grouped by prompts, e.g.,
              [0,1,2,3, 8,9,10,11, 12,13,14,15] (3 groups of 4 samples each)
            - When get_n_samples=False:
              Can be any ordered list, e.g., [0,3,5,6,7,8]
        batch_size: Total number of samples to return
        get_n_samples: Flag indicating the sampling mode
        n_samples_per_prompt: Number of samples per prompt (used when get_n_samples=True)

    Returns:
        list[int]: The first ``batch_size`` indices, as Python ints (empty when
            ``batch_size <= 0``). In n-samples mode this is exactly the first
            ``batch_size // n_samples_per_prompt`` whole groups.

    Raises:
        AssertionError: In n-samples mode, if the number of available indices or
            ``batch_size`` is not a multiple of ``n_samples_per_prompt``.
        IndexError: If fewer than ``batch_size`` indices are available.
    """
    if get_n_samples:
        assert len(ready_for_consume_idx) % n_samples_per_prompt == 0, (
            f"{len(ready_for_consume_idx)} ready indices is not a multiple of "
            f"n_samples_per_prompt={n_samples_per_prompt}"
        )
        assert batch_size % n_samples_per_prompt == 0, (
            f"batch_size={batch_size} is not a multiple of n_samples_per_prompt={n_samples_per_prompt}"
        )

    if batch_size <= 0:
        return []
    if batch_size > len(ready_for_consume_idx):
        raise IndexError(
            f"Requested batch_size={batch_size} but only {len(ready_for_consume_idx)} indices are ready"
        )

    # Groups are stored contiguously, so taking whole groups from the front is the
    # same as taking the first `batch_size` indices. Slicing avoids an int32 tensor
    # round-trip, which overflowed for indices >= 2**31.
    return [int(idx) for idx in ready_for_consume_idx[:batch_size]]
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
@contextmanager
def limit_pytorch_auto_parallel_threads(target_num_threads: Optional[int] = None):
    """Prevent PyTorch from overdoing the automatic parallelism during torch.stack() operation.

    Temporarily lowers ``torch.get_num_threads()`` to ``target_num_threads`` while
    the context is active and restores the previous value on exit. If PyTorch is
    already using no more threads than the target, nothing is changed.

    Args:
        target_num_threads: Thread cap to apply. If None, uses the number of
            physical cores capped at 16. If psutil cannot determine the physical
            core count, falls back to the logical core count, then to the current
            PyTorch thread count.

    Raises:
        RuntimeError: If ``target_num_threads`` exceeds the logical CPU core count.
    """
    pytorch_current_num_threads = torch.get_num_threads()
    # psutil.cpu_count() returns None when the count cannot be determined.
    logical_cores = psutil.cpu_count(logical=True)
    physical_cores = psutil.cpu_count(logical=False)

    if target_num_threads is None:
        # auto determine target_num_threads
        available_cores = physical_cores or logical_cores or pytorch_current_num_threads
        target_num_threads = min(available_cores, 16)

    if logical_cores is not None and target_num_threads > logical_cores:
        raise RuntimeError(
            f"target_num_threads {target_num_threads} should not exceed total logical CPU cores {logical_cores}"
        )

    if pytorch_current_num_threads <= target_num_threads:
        # No need to change settings
        yield
        return

    torch.set_num_threads(target_num_threads)
    try:
        yield
    finally:
        # Restore the original number of threads
        torch.set_num_threads(pytorch_current_num_threads)
|
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import pickle
|
|
16
|
+
import socket
|
|
17
|
+
import time
|
|
18
|
+
from dataclasses import dataclass
|
|
19
|
+
from typing import Any, Optional
|
|
20
|
+
from uuid import uuid4
|
|
21
|
+
|
|
22
|
+
import psutil
|
|
23
|
+
import zmq
|
|
24
|
+
|
|
25
|
+
from transfer_queue.utils.utils import (
|
|
26
|
+
ExplicitEnum,
|
|
27
|
+
TransferQueueRole,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class ZMQRequestType(ExplicitEnum):
    """Message types carried in ``ZMQMessage.request_type``.

    Member order is part of the enum's iteration order and is kept stable.
    """

    # --- Handshake ---
    HANDSHAKE = "HANDSHAKE"  # TransferQueueStorageUnit -> TransferQueueController
    HANDSHAKE_ACK = "HANDSHAKE_ACK"  # TransferQueueController -> TransferQueueStorageUnit

    # --- Data operations ---
    # NOTE: GET_DATA / PUT_DATA use the short wire values "GET" / "PUT".
    GET_DATA = "GET"
    PUT_DATA = "PUT"
    GET_DATA_RESPONSE = "GET_DATA_RESPONSE"
    PUT_DATA_RESPONSE = "PUT_DATA_RESPONSE"
    CLEAR_DATA = "CLEAR_DATA"
    CLEAR_DATA_RESPONSE = "CLEAR_DATA_RESPONSE"

    # --- Data operation errors ---
    PUT_GET_OPERATION_ERROR = "PUT_GET_OPERATION_ERROR"
    PUT_GET_ERROR = "PUT_GET_ERROR"
    PUT_ERROR = "PUT_ERROR"
    GET_ERROR = "GET_ERROR"
    CLEAR_DATA_ERROR = "CLEAR_DATA_ERROR"

    # --- Metadata operations ---
    GET_META = "GET_META"
    GET_META_RESPONSE = "GET_META_RESPONSE"
    GET_CLEAR_META = "GET_CLEAR_META"
    GET_CLEAR_META_RESPONSE = "GET_CLEAR_META_RESPONSE"
    CLEAR_META = "CLEAR_META"
    CLEAR_META_RESPONSE = "CLEAR_META_RESPONSE"

    # --- Consumption checks ---
    CHECK_CONSUMPTION = "CHECK_CONSUMPTION"
    CONSUMPTION_RESPONSE = "CONSUMPTION_RESPONSE"

    # --- Data-update notifications ---
    NOTIFY_DATA_UPDATE = "NOTIFY_DATA_UPDATE"
    NOTIFY_DATA_UPDATE_ACK = "NOTIFY_DATA_UPDATE_ACK"
    NOTIFY_DATA_UPDATE_ERROR = "NOTIFY_DATA_UPDATE_ERROR"
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class ZMQServerInfo:
    """Connection details for one TransferQueue component's ZMQ server.

    Attributes:
        role: The component role this server belongs to.
        id: Identifier of the component.
        ip: IP address (IPv4 or IPv6 literal) or hostname used to build endpoints.
        ports: Mapping from a socket name to its port; values are formatted
            with ``str()`` so ints work as well as strings.
    """

    def __init__(self, role: "TransferQueueRole", id: str, ip: str, ports: dict[str, str]):
        self.role = role
        self.id = id
        self.ip = ip
        self.ports = ports

    def to_addr(self, port_name: str) -> str:
        """Return the ``tcp://host:port`` endpoint for ``port_name``.

        IPv6 literals are wrapped in brackets, as ZMQ tcp endpoints require.

        Raises:
            KeyError: If ``port_name`` is not present in ``self.ports``.
        """
        host = self.ip
        if ":" in host and not host.startswith("["):
            host = f"[{host}]"
        return f"tcp://{host}:{self.ports[port_name]}"

    def to_dict(self):
        """Return the server info as a plain dict (``ports`` is the same dict object)."""
        return {
            "role": self.role,
            "id": self.id,
            "ip": self.ip,
            "ports": self.ports,
        }

    def __str__(self) -> str:
        # Use the real class name (this previously printed "ZMQSocketInfo").
        return f"{type(self).__name__}(role={self.role}, id={self.id}, ip={self.ip}, ports={self.ports})"
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@dataclass
|
|
91
|
+
class ZMQMessage:
|
|
92
|
+
request_type: ZMQRequestType
|
|
93
|
+
sender_id: str
|
|
94
|
+
receiver_id: str | None
|
|
95
|
+
body: dict[str, Any]
|
|
96
|
+
request_id: str
|
|
97
|
+
timestamp: float
|
|
98
|
+
|
|
99
|
+
@classmethod
|
|
100
|
+
def create(
|
|
101
|
+
cls,
|
|
102
|
+
request_type: ZMQRequestType,
|
|
103
|
+
sender_id: str,
|
|
104
|
+
body: dict[str, Any],
|
|
105
|
+
receiver_id: Optional[str] = None,
|
|
106
|
+
) -> "ZMQMessage":
|
|
107
|
+
return cls(
|
|
108
|
+
request_type=request_type,
|
|
109
|
+
sender_id=sender_id,
|
|
110
|
+
receiver_id=receiver_id,
|
|
111
|
+
body=body,
|
|
112
|
+
request_id=str(uuid4().hex[:8]),
|
|
113
|
+
timestamp=time.time(),
|
|
114
|
+
)
|
|
115
|
+
|
|
116
|
+
def serialize(self) -> bytes:
|
|
117
|
+
"""Using pickle to serialize ZMQMessage objects"""
|
|
118
|
+
return pickle.dumps(self)
|
|
119
|
+
|
|
120
|
+
@classmethod
|
|
121
|
+
def deserialize(cls, data: bytes | list[bytes]):
|
|
122
|
+
"""Using pickle to deserialize ZMQMessage objects"""
|
|
123
|
+
if isinstance(data, list):
|
|
124
|
+
# Process multiple byte streams by deserializing each in sequence
|
|
125
|
+
result = []
|
|
126
|
+
for d in data:
|
|
127
|
+
result.append(pickle.loads(d))
|
|
128
|
+
return result
|
|
129
|
+
else:
|
|
130
|
+
# Single byte stream case
|
|
131
|
+
return pickle.loads(data)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def get_free_port() -> int:
    """Ask the OS for a currently unused TCP port and return its number.

    The return annotation was previously ``str``, but ``getsockname()`` yields
    an ``int`` port.

    Note: the probe socket is closed before returning, so another process may
    take the port before the caller binds it. Callers should tolerate
    ``EADDRINUSE`` (e.g. by retrying).
    """
    with socket.socket() as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def create_zmq_socket(
    ctx: zmq.Context,
    socket_type: Any,
    identity: Optional[bytes] = None,
) -> zmq.Socket:
    """Create a ZMQ socket with unbounded high-water marks and memory-aware buffers.

    Receiving sockets (PULL/DEALER/ROUTER) get RCVHWM=0 and RCVBUF set.
    Sending sockets (PUSH/DEALER/ROUTER) get SNDHWM=0 and SNDBUF set.
    Buffers are 0.5 GiB on hosts with more than 32 GiB total and more than
    16 GiB available memory; otherwise they are -1 (system default).

    Args:
        ctx: ZMQ context that creates the socket.
        socket_type: ZMQ socket type constant (e.g. ``zmq.PUSH``).
        identity: Optional ``zmq.IDENTITY`` to set on the socket.

    Returns:
        The configured, unconnected socket.
    """
    vm = psutil.virtual_memory()
    sock = ctx.socket(socket_type)

    gib = 1024**3
    # Large buffers improve throughput; only use them when memory is plentiful.
    plenty_of_memory = vm.total / gib > 32 and vm.available / gib > 16
    buf_size = int(0.5 * gib) if plenty_of_memory else -1

    is_receiver = socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER)
    is_sender = socket_type in (zmq.PUSH, zmq.DEALER, zmq.ROUTER)

    if is_receiver:
        sock.setsockopt(zmq.RCVHWM, 0)  # 0 = no limit on queued inbound messages
        sock.setsockopt(zmq.RCVBUF, buf_size)
    if is_sender:
        sock.setsockopt(zmq.SNDHWM, 0)  # 0 = no limit on queued outbound messages
        sock.setsockopt(zmq.SNDBUF, buf_size)

    if identity is not None:
        sock.setsockopt(zmq.IDENTITY, identity)
    return sock
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
0.1.1.dev
|