sglang 0.4.5.post2__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +3 -2
- sglang/compile_deep_gemm.py +136 -0
- sglang/lang/backend/openai.py +5 -1
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/srt/configs/model_config.py +4 -1
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/disaggregation/decode.py +43 -0
- sglang/srt/disaggregation/mini_lb.py +69 -8
- sglang/srt/disaggregation/mooncake/conn.py +1 -1
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +100 -16
- sglang/srt/disaggregation/utils.py +17 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +3 -7
- sglang/srt/function_call_parser.py +60 -0
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +781 -150
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +19 -4
- sglang/srt/layers/moe/ep_moe/layer.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/deep_gemm.py +378 -0
- sglang/srt/layers/quantization/fp8_kernel.py +7 -38
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/w8a8_int8.py +3 -3
- sglang/srt/layers/rotary_embedding.py +6 -6
- sglang/srt/layers/sampler.py +2 -2
- sglang/srt/managers/data_parallel_controller.py +7 -1
- sglang/srt/managers/io_struct.py +14 -3
- sglang/srt/managers/schedule_batch.py +13 -0
- sglang/srt/managers/scheduler.py +16 -6
- sglang/srt/managers/tokenizer_manager.py +115 -29
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +40 -32
- sglang/srt/mem_cache/memory_pool.py +31 -13
- sglang/srt/model_executor/cuda_graph_runner.py +13 -8
- sglang/srt/model_executor/model_runner.py +19 -4
- sglang/srt/models/deepseek_v2.py +9 -6
- sglang/srt/models/minicpm3.py +2 -2
- sglang/srt/models/minicpmo.py +17 -6
- sglang/srt/openai_api/adapter.py +71 -4
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/server_args.py +52 -40
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +2 -7
- sglang/srt/utils.py +46 -5
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +3 -3
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +62 -57
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,622 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import asyncio
|
4
|
+
import dataclasses
|
5
|
+
import logging
|
6
|
+
import queue
|
7
|
+
import socket
|
8
|
+
import struct
|
9
|
+
import threading
|
10
|
+
import uuid
|
11
|
+
from collections import defaultdict
|
12
|
+
from functools import cache
|
13
|
+
from typing import Dict, List, Optional, Tuple, Union
|
14
|
+
|
15
|
+
import numpy as np
|
16
|
+
import numpy.typing as npt
|
17
|
+
import requests
|
18
|
+
import zmq
|
19
|
+
from aiohttp import web
|
20
|
+
|
21
|
+
from sglang.srt.disaggregation.base.conn import (
|
22
|
+
BaseKVBootstrapServer,
|
23
|
+
BaseKVManager,
|
24
|
+
BaseKVReceiver,
|
25
|
+
BaseKVSender,
|
26
|
+
KVArgs,
|
27
|
+
KVPoll,
|
28
|
+
)
|
29
|
+
from sglang.srt.disaggregation.utils import DisaggregationMode
|
30
|
+
from sglang.srt.server_args import ServerArgs
|
31
|
+
from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote
|
32
|
+
|
33
|
+
logger = logging.getLogger(__name__)
|
34
|
+
|
35
|
+
|
36
|
+
@dataclasses.dataclass
class TransferInfo:
    """Destination layout for one request, pushed by the decode engine over ZMQ.

    Carries everything the prefill side needs to address the decode engine's
    buffers: the bootstrap room, the peer endpoint/port, the serialized NIXL
    agent metadata, and the destination KV/aux pointers and indices.
    """

    room: int
    endpoint: str
    dst_port: int
    agent_metadata: bytes
    dst_kv_ptrs: list[int]
    dst_kv_indices: npt.NDArray[np.int64]
    dst_aux_ptrs: list[int]
    dst_aux_index: int
    dst_gpu_id: int

    @classmethod
    def from_zmq(cls, msg: List[bytes]):
        """Decode a 9-part ZMQ multipart message into a TransferInfo."""
        # Pointer lists arrive packed as consecutive little-endian uint64s.
        kv_ptrs = struct.unpack(f"{len(msg[4])//8}Q", msg[4])
        aux_ptrs = struct.unpack(f"{len(msg[6])//8}Q", msg[6])
        return cls(
            room=int(msg[0].decode("ascii")),
            endpoint=msg[1].decode("ascii"),
            dst_port=int(msg[2].decode("ascii")),
            agent_metadata=msg[3],
            dst_kv_ptrs=list(kv_ptrs),
            dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64),
            dst_aux_ptrs=list(aux_ptrs),
            dst_aux_index=int(msg[7].decode("ascii")),
            dst_gpu_id=int(msg[8].decode("ascii")),
        )
|
61
|
+
|
62
|
+
|
63
|
+
@dataclasses.dataclass
class TransferStatus:
    """Used by KV Receiver to know when a transfer is done."""

    # KV chunk IDs that have been received.
    # FIX: the annotation previously used `Set[int]`, but `Set` is never
    # imported in this module (the typing import only brings in Dict, List,
    # Optional, Tuple, Union); use the builtin generic instead.
    received_kvs: set[int] = dataclasses.field(default_factory=set)
    # Number of kv chunks to expect, will know this after last chunk is received.
    num_kvs_expected: Optional[int] = None
    # Whether aux data has been received.
    received_aux: bool = False

    def is_done(self) -> bool:
        """Return True once every KV chunk and the aux payload have arrived."""
        if self.num_kvs_expected is None:
            # The last chunk (which tells us how many to expect) hasn't landed.
            return False
        return self.num_kvs_expected == len(self.received_kvs) and self.received_aux
|
78
|
+
|
79
|
+
|
80
|
+
class NixlKVManager(BaseKVManager):
    """Coordinates NIXL-based KV-cache transfers for one engine rank.

    PREFILL mode: registers this rank with the bootstrap server and runs a
    background thread that collects per-request destination info pushed by
    decode engines over ZMQ.
    DECODE mode: tracks per-room transfer completion from NIXL notifications.
    """

    def __init__(
        self,
        args: KVArgs,
        disaggregation_mode: DisaggregationMode,
        server_args: ServerArgs,
    ):
        try:
            from nixl._api import nixl_agent
        except ImportError as e:
            raise ImportError(
                "Please install NIXL by following the instructions at "
                "https://github.com/ai-dynamo/nixl/blob/main/README.md "
                "to run SGLang with NixlTransferEngine."
            ) from e
        self.agent = nixl_agent(str(uuid.uuid4()))
        self.kv_args = args
        self.disaggregation_mode = disaggregation_mode
        # for p/d multi node infer
        self.bootstrap_port = server_args.disaggregation_bootstrap_port
        self.dist_init_addr = server_args.dist_init_addr
        self.rank_port = None
        self.server_socket = zmq.Context().socket(zmq.PULL)
        self.register_buffer_to_engine()

        self.rank_port = get_free_port()
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            # room -> destination info pushed by the decode engine.
            self.transfer_infos: Dict[int, TransferInfo] = {}
            # Guards transfer_infos; add_transfer_request blocks on it until
            # the bootstrap thread has seen the matching room.
            self.condition = threading.Condition()
            # room -> NIXL peer name returned by add_remote_agent.
            self.peer_names: Dict[int, str] = {}
            self._start_bootstrap_thread()
            self._register_to_bootstrap()
        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
            self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
                TransferStatus
            )
        else:
            raise ValueError(
                f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
            )

    def register_buffer_to_engine(self):
        """Register the KV (VRAM) and aux (DRAM) buffers with the NIXL agent.

        Raises:
            Exception: if NIXL rejects either registration.
        """
        kv_addrs = []
        for kv_data_ptr, kv_data_len in zip(
            self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
        ):
            kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
        self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=True)
        if not self.kv_descs:
            raise Exception("NIXL memory registration failed for kv tensors")
        aux_addrs = []
        for aux_data_ptr, aux_data_len in zip(
            self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
        ):
            aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
        self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=True)
        if not self.aux_descs:
            raise Exception("NIXL memory registration failed for aux tensors")

    @cache
    def _connect(self, endpoint: str):
        # Cached so each distinct endpoint gets exactly one PUSH socket.
        # NOTE(review): @cache on a method also keys on `self` and keeps the
        # instance alive for the process lifetime (ruff B019); acceptable only
        # if the manager is a process-long singleton — confirm.
        socket = zmq.Context().socket(zmq.PUSH)
        socket.connect(endpoint)
        return socket

    def _add_remote(self, room: int, agent_metadata: bytes):
        """Register the remote agent for this room once; return its peer name."""
        if room not in self.peer_names:
            self.peer_names[room] = self.agent.add_remote_agent(agent_metadata)
        return self.peer_names[room]

    def send_kvcache(
        self,
        peer_name: str,
        prefill_kv_indices: npt.NDArray[np.int64],
        dst_kv_ptrs: list[int],
        dst_kv_indices: npt.NDArray[np.int64],
        dst_gpu_id: int,
        notif: str,
    ):
        """Post an async WRITE of KV items to the decode engine's VRAM.

        Returns the NIXL transfer handle; completion is signalled to the peer
        via the `notif` message.
        """
        # Make descs
        num_layers = len(self.kv_args.kv_data_ptrs)
        src_addrs = []
        dst_addrs = []
        for layer_id in range(num_layers):
            src_ptr = self.kv_args.kv_data_ptrs[layer_id]
            dst_ptr = dst_kv_ptrs[layer_id]
            item_len = self.kv_args.kv_item_lens[layer_id]

            for prefill_index, decode_index in zip(prefill_kv_indices, dst_kv_indices):
                src_addr = src_ptr + int(prefill_index) * item_len
                dst_addr = dst_ptr + int(decode_index) * item_len
                length = item_len
                src_addrs.append((src_addr, length, self.kv_args.gpu_id))
                dst_addrs.append((dst_addr, length, dst_gpu_id))
        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=True)
        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=True)
        # Transfer data
        xfer_handle = self.agent.initialize_xfer(
            "WRITE",
            src_descs,
            dst_descs,
            peer_name,
            notif.encode("ascii"),
        )
        if not xfer_handle:
            raise Exception("KVSender failed to create transfer")
        state = self.agent.transfer(xfer_handle)
        if state == "ERR":
            raise Exception("KVSender failed to post transfer")
        return xfer_handle

    def send_aux(
        self,
        peer_name: str,
        prefill_aux_index: int,
        dst_aux_ptrs: list[int],
        dst_aux_index: int,
        notif: str,
    ):
        """Post an async WRITE of one aux item (DRAM->DRAM); return the handle."""
        # Make descs
        aux_item_len = self.kv_args.aux_item_lens[0]
        prefill_aux_addr = (
            self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len
        )
        decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
        src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
        dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
        src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=True)
        dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=True)
        # Transfer data
        xfer_handle = self.agent.initialize_xfer(
            "WRITE",
            src_descs,
            dst_descs,
            peer_name,
            notif.encode("ascii"),
        )
        if not xfer_handle:
            raise Exception("KVSender failed to create transfer")
        state = self.agent.transfer(xfer_handle)
        if state == "ERR":
            raise Exception("KVSender failed to post transfer")
        return xfer_handle

    def add_transfer_request(
        self,
        bootstrap_room: int,
        kv_indices: npt.NDArray[np.int64],
        index_slice: slice,
        is_last: bool,
        chunk_id: int,
        aux_index: Optional[int] = None,
    ):
        """Send one KV chunk (and, on the last chunk, the aux data) to the peer.

        Blocks until the bootstrap thread has received the destination info
        for `bootstrap_room`. Returns the list of transfer handles posted.
        """
        assert self.disaggregation_mode == DisaggregationMode.PREFILL
        assert not is_last or (is_last and aux_index is not None)

        # Wait for transfer info to be populated by bootstrap thread.
        with self.condition:
            self.condition.wait_for(lambda: bootstrap_room in self.transfer_infos)
            req = self.transfer_infos[bootstrap_room]
        assert bootstrap_room == req.room

        peer_name = self._add_remote(bootstrap_room, req.agent_metadata)
        chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
        assert len(chunked_dst_kv_indice) == len(kv_indices)

        # Notification format: "<room>_kv_<chunk_id>_<is_last as 0/1>".
        notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
        kv_xfer_handle = self.send_kvcache(
            peer_name,
            kv_indices,
            req.dst_kv_ptrs,
            chunked_dst_kv_indice,
            req.dst_gpu_id,
            notif,
        )
        handles = [kv_xfer_handle]
        # Only the last chunk we need to send the aux data.
        if is_last:
            aux_xfer_handle = self.send_aux(
                peer_name,
                aux_index,
                req.dst_aux_ptrs,
                req.dst_aux_index,
                str(req.room) + "_aux",
            )
            handles.append(aux_xfer_handle)
        return handles

    def update_transfer_status(self):
        """Drain new NIXL notifications and update per-room transfer status."""
        # Process notifications from received transfers.
        notif_map = self.agent.get_new_notifs()
        for peer_name, messages in notif_map.items():
            # We could also check that self.bootstrap_info['agent_name'] matches
            # the message sender. But the bootstrap room alone should be
            # sufficient to map the status.
            for msg in messages:
                components = msg.decode("ascii").split("_")
                room = int(components[0])
                if components[1] == "kv":
                    chunk_id = int(components[2])
                    # BUG FIX: the flag is serialized as the string "0" or "1"
                    # (see add_transfer_request). bool() on any non-empty
                    # string is True, so the previous `bool(components[3])`
                    # marked *every* chunk as the last one and clobbered
                    # num_kvs_expected. Compare against "1" instead.
                    is_last = components[3] == "1"
                    self.transfer_statuses[room].received_kvs.add(chunk_id)
                    if is_last:
                        self.transfer_statuses[room].num_kvs_expected = chunk_id + 1
                elif components[1] == "aux":
                    self.transfer_statuses[room].received_aux = True

    def check_transfer_done(self, room: int):
        """Return True iff all KV chunks and aux data for `room` have arrived."""
        if room not in self.transfer_statuses:
            return False
        return self.transfer_statuses[room].is_done()

    def _register_to_bootstrap(self):
        """Register KVSender to bootstrap server via HTTP POST."""
        if self.dist_init_addr:
            ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
        else:
            ip_address = get_ip()

        bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
        url = f"http://{bootstrap_server_url}/route"
        payload = {
            "role": "Prefill",
            "rank_ip": get_local_ip_by_remote(),
            "rank_port": self.rank_port,
            "engine_rank": self.kv_args.engine_rank,
            "agent_name": self.agent.name,
        }

        try:
            response = requests.put(url, json=payload)
            if response.status_code == 200:
                logger.debug("Prefill successfully registered to bootstrap server.")
            else:
                logger.error(
                    f"Prefill Failed to connect to bootstrap server: {response.status_code}, {response.text}"
                )
        except Exception as e:
            logger.error(f"Prefill Failed to register to bootstrap server: {e}")

    def _start_bootstrap_thread(self):
        self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")

        def bootstrap_thread():
            """This thread recvs transfer info from the decode engine"""
            while True:
                waiting_req_bytes = self.server_socket.recv_multipart()
                room = waiting_req_bytes[0].decode("ascii")
                # Decode engines may probe with a "None" room; ignore those.
                if room == "None":
                    continue
                room = int(room)
                with self.condition:
                    self.transfer_infos[room] = TransferInfo.from_zmq(waiting_req_bytes)
                    self.condition.notify_all()

        # NOTE(review): not a daemon thread, so it keeps the process alive on
        # shutdown — confirm whether daemon=True is intended here.
        threading.Thread(target=bootstrap_thread).start()
|
337
|
+
|
338
|
+
|
339
|
+
class NixlKVSender(BaseKVSender):
    """Prefill-side handle that streams the KV chunks of a single request."""

    def __init__(self, mgr: NixlKVManager, bootstrap_addr: str, bootstrap_room: int):
        self.kv_mgr = mgr
        self.bootstrap_room = bootstrap_room
        self.aux_index = None
        self.bootstrap_server_url = bootstrap_addr
        # Handles of all NIXL transfers posted so far for this request.
        self.xfer_handles = []
        # Becomes True once the final chunk has been handed off.
        self.has_sent = False
        self.chunk_id = 0

    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
        """Record the total KV size and the aux slot for this request."""
        self.num_kv_indices = num_kv_indices
        self.aux_index = aux_index

    def send(
        self,
        kv_indices: npt.NDArray[np.int64],
        index_slice: slice,
        is_last: bool,
    ):
        """Hand one KV chunk to the manager; the last chunk also carries aux data."""
        new_xfer_handles = self.kv_mgr.add_transfer_request(
            self.bootstrap_room,
            kv_indices,
            index_slice,
            is_last,
            self.chunk_id,
            self.aux_index,
        )
        self.xfer_handles.extend(new_xfer_handles)
        self.chunk_id += 1
        if is_last:
            self.has_sent = True

    def poll(self) -> KVPoll:
        """Report progress: Success once every posted transfer is DONE."""
        if not self.has_sent:
            # Still waiting for the final chunk to be submitted.
            return KVPoll.WaitingForInput

        check_state = self.kv_mgr.agent.check_xfer_state
        states = [check_state(handle) for handle in self.xfer_handles]
        if any(state == "ERR" for state in states):
            raise Exception("KVSender transfer encountered an error.")
        if all(state == "DONE" for state in states):
            return KVPoll.Success
        return KVPoll.WaitingForInput

    def failure_exception(self):
        raise Exception("Fake KVSender Exception")
|
386
|
+
|
387
|
+
|
388
|
+
class NixlKVReceiver(BaseKVReceiver):
    """Decode-side handle that requests one request's KV data from prefill."""

    def __init__(
        self,
        mgr: NixlKVManager,
        bootstrap_addr: str,
        bootstrap_room: Optional[int] = None,
    ):
        self.bootstrap_room = bootstrap_room
        self.bootstrap_addr = bootstrap_addr
        self.kv_mgr = mgr
        self.started_transfer = False

        # NOTE: key distinguished by bootstrap_addr and engine_rank
        bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"

        cached = self.kv_mgr.connection_pool.get(bootstrap_key)
        if cached is not None:
            self.bootstrap_info = cached
        else:
            self.bootstrap_info = self._get_bootstrap_info_from_server(
                self.kv_mgr.kv_args.engine_rank
            )
            if self.bootstrap_info is None:
                logger.error(
                    f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
                )
            else:
                # Cache for later receivers targeting the same prefill rank.
                self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info

        assert self.bootstrap_info is not None

    def _get_bootstrap_info_from_server(self, engine_rank):
        """Fetch the bootstrap info from the bootstrap server."""
        try:
            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
            response = requests.get(url)
            if response.status_code != 200:
                logger.error(
                    f"Failed to get prefill server info: {response.status_code}, {response.text}"
                )
                return None
            return response.json()
        except Exception as e:
            logger.error(f"Error fetching prefill info from bootstrap: {e}")
            return None

    @cache
    def _connect(self, endpoint: str):
        # One PUSH socket per endpoint, created lazily and kept for reuse.
        socket = zmq.Context().socket(zmq.PUSH)
        socket.connect(endpoint)
        return socket

    def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
        """Push this request's destination layout to the prefill engine."""
        self.prefill_server_url = (
            f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
        )
        logger.debug(
            f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
        )

        kv_args = self.kv_mgr.kv_args
        # Pointers are packed as consecutive uint64s for the wire format
        # expected by TransferInfo.from_zmq on the prefill side.
        packed_kv_data_ptrs = b"".join(
            struct.pack("Q", ptr) for ptr in kv_args.kv_data_ptrs
        )
        packed_aux_data_ptrs = b"".join(
            struct.pack("Q", ptr) for ptr in kv_args.aux_data_ptrs
        )
        sock = self._connect("tcp://" + self.prefill_server_url)
        sock.send_multipart(
            [
                str(self.bootstrap_room).encode("ascii"),
                get_local_ip_by_remote().encode("ascii"),
                str(self.kv_mgr.rank_port).encode("ascii"),
                self.kv_mgr.agent.get_agent_metadata(),
                packed_kv_data_ptrs,
                kv_indices.tobytes(),
                packed_aux_data_ptrs,
                str(aux_index).encode("ascii"),
                str(kv_args.gpu_id).encode("ascii"),
            ]
        )
        self.started_transfer = True

    def poll(self) -> KVPoll:
        """Success once every KV chunk plus the aux payload has arrived."""
        if not self.started_transfer:
            return KVPoll.WaitingForInput

        # Drain any new NIXL notifications before checking completion.
        self.kv_mgr.update_transfer_status()

        if self.kv_mgr.check_transfer_done(self.bootstrap_room):
            return KVPoll.Success
        return KVPoll.WaitingForInput

    def failure_exception(self):
        raise Exception("Fake KVReceiver Exception")
|
483
|
+
|
484
|
+
|
485
|
+
class NixlKVBootstrapServer(BaseKVBootstrapServer):
    """aiohttp server where prefill engines register and decode engines look them up.

    Runs its own event loop on a daemon thread. Exposes:
      * /metadata — a simple keyed byte store (GET/PUT/DELETE).
      * /route    — PUT registers a prefill rank; GET resolves one.
    """

    def __init__(self, port: int):
        self.port = port
        self.app = web.Application()
        self.store = dict()
        self.lock = asyncio.Lock()
        self._setup_routes()
        self.prefill_port_table: Dict[int, Dict[str, Union[str, int]]] = {}
        # FIX: pre-initialize loop/runner so close() and the finally-cleanup in
        # _run_server cannot hit an AttributeError if they run before (or
        # because) server startup failed.
        self._loop = None
        self._runner = None

        # Start bootstrap server
        self.thread = threading.Thread(target=self._run_server, daemon=True)
        self.run()

    def run(self):
        """Start the server thread (called once from __init__)."""
        self.thread.start()

    def _setup_routes(self):
        self.app.router.add_route("*", "/metadata", self._handle_metadata)
        self.app.router.add_route("*", "/route", self._handle_route)

    async def _handle_metadata(self, request: web.Request):
        """Dispatch /metadata by HTTP method; 405 for anything unsupported."""
        key = request.query.get("key", "")

        if request.method == "GET":
            return await self._handle_metadata_get(key)
        elif request.method == "PUT":
            return await self._handle_metadata_put(key, request)
        elif request.method == "DELETE":
            return await self._handle_metadata_delete(key)
        return web.Response(
            text="Method not allowed", status=405, content_type="application/json"
        )

    async def _handle_metadata_get(self, key):
        async with self.lock:
            value = self.store.get(key)
        if value is None:
            return web.Response(
                text="metadata not found", status=404, content_type="application/json"
            )
        return web.Response(body=value, status=200, content_type="application/json")

    async def _handle_metadata_put(self, key, request):
        data = await request.read()
        async with self.lock:
            self.store[key] = data
        return web.Response(
            text="metadata updated", status=200, content_type="application/json"
        )

    async def _handle_metadata_delete(self, key):
        async with self.lock:
            if key not in self.store:
                return web.Response(
                    text="metadata not found",
                    status=404,
                    content_type="application/json",
                )
            del self.store[key]
        return web.Response(
            text="metadata deleted", status=200, content_type="application/json"
        )

    async def _handle_route(self, request: web.Request):
        """Dispatch /route by HTTP method; 405 for anything unsupported."""
        method = request.method
        if method == "PUT":
            return await self._handle_route_put(request)
        elif method == "GET":
            return await self._handle_route_get(request)
        else:
            return web.Response(
                text="Method not allowed", status=405, content_type="application/json"
            )

    async def _handle_route_put(self, request: web.Request):
        """Register a prefill engine's connection info."""
        data = await request.json()
        role = data["role"]
        rank_ip = data["rank_ip"]
        rank_port = int(data["rank_port"])
        engine_rank = int(data["engine_rank"])
        agent_name = data["agent_name"]

        if role == "Prefill":
            # FIX: the table is read under self.lock in _handle_route_get, but
            # was written here without it (the old comment said "Add lock to
            # make sure thread-safe" without doing so). Hold the lock for the
            # write as well.
            async with self.lock:
                self.prefill_port_table[engine_rank] = {
                    "rank_ip": rank_ip,
                    "rank_port": rank_port,
                    "agent_name": agent_name,
                }
            logger.info(
                f"Registered Prefill boostrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port} and name: {agent_name}"
            )

        return web.Response(text="OK", status=200)

    async def _handle_route_get(self, request: web.Request):
        """Resolve a prefill engine's connection info by engine_rank."""
        engine_rank = request.query.get("engine_rank")
        if not engine_rank:
            return web.Response(text="Missing rank", status=400)

        # Find corresponding prefill info
        async with self.lock:
            bootstrap_info = self.prefill_port_table.get(int(engine_rank))
        if bootstrap_info is not None:
            return web.json_response(bootstrap_info, status=200)
        else:
            return web.Response(text="Not Found", status=404)

    def _run_server(self):
        """Thread target: run the aiohttp app on a private event loop."""
        try:
            # Event Loop
            self._loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self._loop)

            self._runner = web.AppRunner(self.app)
            self._loop.run_until_complete(self._runner.setup())

            site = web.TCPSite(self._runner, port=self.port)
            self._loop.run_until_complete(site.start())
            self._loop.run_forever()
        except Exception as e:
            logger.error(f"Server error: {str(e)}")
        finally:
            # Cleanup — guard both, since startup may have failed part-way.
            if self._runner is not None and self._loop is not None:
                self._loop.run_until_complete(self._runner.cleanup())
            if self._loop is not None:
                self._loop.close()

    def close(self):
        """Shutdown"""
        if self._loop is not None and self._loop.is_running():
            self._loop.call_soon_threadsafe(self._loop.stop)
            logger.info("Stopping server loop...")

        if self.thread.is_alive():
            self.thread.join(timeout=2)
            logger.info("Server thread stopped")

    def poll(self) -> KVPoll: ...
|