sglang 0.4.5.post2__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99) hide show
  1. sglang/bench_one_batch.py +19 -3
  2. sglang/bench_serving.py +8 -8
  3. sglang/compile_deep_gemm.py +177 -0
  4. sglang/lang/backend/openai.py +5 -1
  5. sglang/lang/backend/runtime_endpoint.py +5 -1
  6. sglang/srt/code_completion_parser.py +1 -1
  7. sglang/srt/configs/deepseekvl2.py +1 -1
  8. sglang/srt/configs/model_config.py +11 -2
  9. sglang/srt/constrained/llguidance_backend.py +78 -61
  10. sglang/srt/constrained/xgrammar_backend.py +1 -0
  11. sglang/srt/conversation.py +34 -1
  12. sglang/srt/disaggregation/decode.py +96 -5
  13. sglang/srt/disaggregation/mini_lb.py +113 -15
  14. sglang/srt/disaggregation/mooncake/conn.py +199 -32
  15. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  16. sglang/srt/disaggregation/nixl/conn.py +622 -0
  17. sglang/srt/disaggregation/prefill.py +119 -20
  18. sglang/srt/disaggregation/utils.py +17 -0
  19. sglang/srt/entrypoints/engine.py +4 -0
  20. sglang/srt/entrypoints/http_server.py +11 -9
  21. sglang/srt/function_call_parser.py +132 -0
  22. sglang/srt/layers/activation.py +2 -2
  23. sglang/srt/layers/attention/base_attn_backend.py +3 -0
  24. sglang/srt/layers/attention/flashattention_backend.py +809 -160
  25. sglang/srt/layers/attention/flashmla_backend.py +8 -11
  26. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
  28. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  29. sglang/srt/layers/attention/vision.py +2 -0
  30. sglang/srt/layers/dp_attention.py +1 -1
  31. sglang/srt/layers/layernorm.py +42 -5
  32. sglang/srt/layers/logits_processor.py +2 -2
  33. sglang/srt/layers/moe/ep_moe/layer.py +2 -0
  34. sglang/srt/layers/moe/fused_moe_native.py +2 -4
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
  38. sglang/srt/layers/pooler.py +6 -0
  39. sglang/srt/layers/quantization/awq.py +5 -1
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  41. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  42. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  43. sglang/srt/layers/quantization/deep_gemm.py +385 -0
  44. sglang/srt/layers/quantization/fp8_kernel.py +7 -38
  45. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  46. sglang/srt/layers/quantization/gptq.py +13 -7
  47. sglang/srt/layers/quantization/int8_kernel.py +32 -1
  48. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  49. sglang/srt/layers/quantization/w8a8_int8.py +3 -3
  50. sglang/srt/layers/radix_attention.py +13 -3
  51. sglang/srt/layers/rotary_embedding.py +176 -132
  52. sglang/srt/layers/sampler.py +2 -2
  53. sglang/srt/managers/data_parallel_controller.py +17 -4
  54. sglang/srt/managers/io_struct.py +21 -3
  55. sglang/srt/managers/mm_utils.py +85 -28
  56. sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
  57. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
  58. sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
  59. sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
  60. sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
  61. sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
  62. sglang/srt/managers/schedule_batch.py +42 -12
  63. sglang/srt/managers/scheduler.py +47 -26
  64. sglang/srt/managers/tokenizer_manager.py +120 -30
  65. sglang/srt/managers/tp_worker.py +1 -0
  66. sglang/srt/mem_cache/hiradix_cache.py +40 -32
  67. sglang/srt/mem_cache/memory_pool.py +118 -13
  68. sglang/srt/model_executor/cuda_graph_runner.py +16 -10
  69. sglang/srt/model_executor/forward_batch_info.py +51 -95
  70. sglang/srt/model_executor/model_runner.py +29 -27
  71. sglang/srt/models/deepseek.py +12 -2
  72. sglang/srt/models/deepseek_nextn.py +101 -6
  73. sglang/srt/models/deepseek_v2.py +153 -76
  74. sglang/srt/models/deepseek_vl2.py +9 -4
  75. sglang/srt/models/gemma3_causal.py +1 -1
  76. sglang/srt/models/llama4.py +0 -1
  77. sglang/srt/models/minicpm3.py +2 -2
  78. sglang/srt/models/minicpmo.py +22 -7
  79. sglang/srt/models/mllama4.py +2 -2
  80. sglang/srt/models/qwen2_5_vl.py +3 -6
  81. sglang/srt/models/qwen2_vl.py +3 -7
  82. sglang/srt/models/roberta.py +178 -0
  83. sglang/srt/openai_api/adapter.py +87 -10
  84. sglang/srt/openai_api/protocol.py +6 -1
  85. sglang/srt/server_args.py +65 -60
  86. sglang/srt/speculative/build_eagle_tree.py +2 -2
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  88. sglang/srt/speculative/eagle_utils.py +2 -2
  89. sglang/srt/speculative/eagle_worker.py +2 -7
  90. sglang/srt/torch_memory_saver_adapter.py +10 -1
  91. sglang/srt/utils.py +48 -6
  92. sglang/test/runners.py +6 -13
  93. sglang/test/test_utils.py +39 -19
  94. sglang/version.py +1 -1
  95. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/METADATA +6 -7
  96. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/RECORD +99 -92
  97. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
  98. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,622 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import dataclasses
5
+ import logging
6
+ import queue
7
+ import socket
8
+ import struct
9
+ import threading
10
+ import uuid
11
+ from collections import defaultdict
12
+ from functools import cache
13
+ from typing import Dict, List, Optional, Tuple, Union
14
+
15
+ import numpy as np
16
+ import numpy.typing as npt
17
+ import requests
18
+ import zmq
19
+ from aiohttp import web
20
+
21
+ from sglang.srt.disaggregation.base.conn import (
22
+ BaseKVBootstrapServer,
23
+ BaseKVManager,
24
+ BaseKVReceiver,
25
+ BaseKVSender,
26
+ KVArgs,
27
+ KVPoll,
28
+ )
29
+ from sglang.srt.disaggregation.utils import DisaggregationMode
30
+ from sglang.srt.server_args import ServerArgs
31
+ from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
@dataclasses.dataclass
class TransferInfo:
    """Destination description for one bootstrap room, decoded from a multipart message."""

    room: int
    endpoint: str
    dst_port: int
    agent_metadata: bytes
    dst_kv_ptrs: list[int]
    dst_kv_indices: npt.NDArray[np.int64]
    dst_aux_ptrs: list[int]
    dst_aux_index: int
    dst_gpu_id: int

    @classmethod
    def from_zmq(cls, msg: List[bytes]):
        """Decode the first nine frames of ``msg`` into a ``TransferInfo``.

        Frames 0, 2, 7 and 8 are ASCII decimal integers, frame 1 is ASCII
        text, frame 3 is kept as raw bytes, frames 4 and 6 are packed
        unsigned 64-bit values (native ``Q`` layout) and frame 5 is a raw
        ``int64`` buffer.
        """

        def as_text(frame: bytes) -> str:
            return frame.decode("ascii")

        def as_u64_list(frame: bytes) -> list:
            # One "Q" per 8 bytes of payload.
            return list(struct.unpack(f"{len(frame)//8}Q", frame))

        room = int(as_text(msg[0]))
        endpoint = as_text(msg[1])
        dst_port = int(as_text(msg[2]))
        agent_metadata = msg[3]
        dst_kv_ptrs = as_u64_list(msg[4])
        dst_kv_indices = np.frombuffer(msg[5], dtype=np.int64)
        dst_aux_ptrs = as_u64_list(msg[6])
        dst_aux_index = int(as_text(msg[7]))
        dst_gpu_id = int(as_text(msg[8]))
        return cls(
            room=room,
            endpoint=endpoint,
            dst_port=dst_port,
            agent_metadata=agent_metadata,
            dst_kv_ptrs=dst_kv_ptrs,
            dst_kv_indices=dst_kv_indices,
            dst_aux_ptrs=dst_aux_ptrs,
            dst_aux_index=dst_aux_index,
            dst_gpu_id=dst_gpu_id,
        )
61
+
62
+
63
@dataclasses.dataclass
class TransferStatus:
    """Used by KV Receiver to know when a transfer is done."""

    # KV chunk IDs that have been received.
    # Builtin ``set[int]`` is used because ``typing.Set`` is not imported by
    # this module; this matches the ``list[int]`` annotations used elsewhere.
    received_kvs: set[int] = dataclasses.field(default_factory=set)
    # Number of kv chunks to expect, will know this after last chunk is received.
    num_kvs_expected: Optional[int] = None
    # Whether aux data has been received.
    received_aux: bool = False

    def is_done(self):
        """Return whether every expected KV chunk and the aux data have arrived.

        Until the expected chunk count is known the transfer is never
        considered done.
        """
        if self.num_kvs_expected is None:
            return False
        return self.num_kvs_expected == len(self.received_kvs) and self.received_aux
78
+
79
+
80
class NixlKVManager(BaseKVManager):
    """Owns the NIXL agent and drives KV / aux transfers for one engine rank.

    In PREFILL mode it binds a ZMQ PULL socket, starts a thread that collects
    ``TransferInfo`` messages, registers itself with the bootstrap server and
    posts NIXL ``WRITE`` transfers. In DECODE mode it keeps a pool of
    bootstrap info and a per-room ``TransferStatus`` updated from NIXL
    notifications.
    """

    def __init__(
        self,
        args: KVArgs,
        disaggregation_mode: DisaggregationMode,
        server_args: ServerArgs,
    ):
        """Create the NIXL agent, register buffers and set up mode-specific state.

        Raises:
            ImportError: if the ``nixl`` package cannot be imported.
            ValueError: if ``disaggregation_mode`` is neither PREFILL nor DECODE.
        """
        try:
            from nixl._api import nixl_agent
        except ImportError as e:
            raise ImportError(
                "Please install NIXL by following the instructions at "
                "https://github.com/ai-dynamo/nixl/blob/main/README.md "
                "to run SGLang with NixlTransferEngine."
            ) from e
        self.agent = nixl_agent(str(uuid.uuid4()))
        self.kv_args = args
        self.disaggregation_mode = disaggregation_mode
        # for p/d multi node infer
        self.bootstrap_port = server_args.disaggregation_bootstrap_port
        self.dist_init_addr = server_args.dist_init_addr
        self.rank_port = None
        self.server_socket = zmq.Context().socket(zmq.PULL)
        self.register_buffer_to_engine()

        self.rank_port = get_free_port()
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            self.transfer_infos: Dict[int, TransferInfo] = {}
            # Guards transfer_infos; signalled by the bootstrap thread.
            self.condition = threading.Condition()
            self.peer_names: Dict[int, str] = {}
            self._start_bootstrap_thread()
            self._register_to_bootstrap()
        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
            self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
                TransferStatus
            )
        else:
            raise ValueError(
                f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
            )

    def register_buffer_to_engine(self):
        """Register the KV buffers as "VRAM" and the aux buffers as "DRAM" with NIXL.

        Raises:
            Exception: if either registration returns a falsy descriptor.
        """
        kv_addrs = [
            (kv_data_ptr, kv_data_len, self.kv_args.gpu_id, "")
            for kv_data_ptr, kv_data_len in zip(
                self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
            )
        ]
        self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=True)
        if not self.kv_descs:
            raise Exception("NIXL memory registration failed for kv tensors")
        # Aux buffers are registered with device id 0.
        aux_addrs = [
            (aux_data_ptr, aux_data_len, 0, "")
            for aux_data_ptr, aux_data_len in zip(
                self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
            )
        ]
        self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=True)
        if not self.aux_descs:
            raise Exception("NIXL memory registration failed for aux tensors")

    # NOTE(review): functools.cache on a method keys on ``self`` and keeps the
    # instance alive for the lifetime of the cache; acceptable only if the
    # manager lives as long as the process - confirm.
    @cache
    def _connect(self, endpoint: str):
        """Return a connected ZMQ PUSH socket for ``endpoint``, cached per endpoint."""
        # Named ``sock`` so the ``socket`` module is not shadowed.
        sock = zmq.Context().socket(zmq.PUSH)
        sock.connect(endpoint)
        return sock

    def _add_remote(self, room: int, agent_metadata: bytes):
        """Add the remote agent for ``room`` once and return its peer name."""
        if room not in self.peer_names:
            self.peer_names[room] = self.agent.add_remote_agent(agent_metadata)
        return self.peer_names[room]

    def _post_write(self, src_descs, dst_descs, peer_name: str, notif: str):
        """Initialize and post one NIXL WRITE transfer; return its handle.

        Raises:
            Exception: if the transfer cannot be created or posting reports "ERR".
        """
        xfer_handle = self.agent.initialize_xfer(
            "WRITE",
            src_descs,
            dst_descs,
            peer_name,
            notif.encode("ascii"),
        )
        if not xfer_handle:
            raise Exception("KVSender failed to create transfer")
        state = self.agent.transfer(xfer_handle)
        if state == "ERR":
            raise Exception("KVSender failed to post transfer")
        return xfer_handle

    def send_kvcache(
        self,
        peer_name: str,
        prefill_kv_indices: npt.NDArray[np.int64],
        dst_kv_ptrs: list[int],
        dst_kv_indices: npt.NDArray[np.int64],
        dst_gpu_id: int,
        notif: str,
    ):
        """Post a WRITE of the selected KV items, for every layer, to ``peer_name``.

        For each entry of ``kv_data_ptrs`` one descriptor pair is built per
        (prefill index, decode index) pair, each ``kv_item_lens[layer]`` bytes
        long. ``notif`` is attached to the transfer as an ASCII notification.

        Returns:
            The NIXL transfer handle.
        """
        # Make descs
        src_addrs = []
        dst_addrs = []
        for layer_id, src_ptr in enumerate(self.kv_args.kv_data_ptrs):
            dst_ptr = dst_kv_ptrs[layer_id]
            item_len = self.kv_args.kv_item_lens[layer_id]

            for prefill_index, decode_index in zip(prefill_kv_indices, dst_kv_indices):
                src_addr = src_ptr + int(prefill_index) * item_len
                dst_addr = dst_ptr + int(decode_index) * item_len
                src_addrs.append((src_addr, item_len, self.kv_args.gpu_id))
                dst_addrs.append((dst_addr, item_len, dst_gpu_id))
        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=True)
        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=True)
        # Transfer data
        return self._post_write(src_descs, dst_descs, peer_name, notif)

    def send_aux(
        self,
        peer_name: str,
        prefill_aux_index: int,
        dst_aux_ptrs: list[int],
        dst_aux_index: int,
        notif: str,
    ):
        """Post a WRITE of one aux item from the first aux buffer to ``peer_name``.

        Only ``aux_data_ptrs[0]`` / ``dst_aux_ptrs[0]`` are used, with item
        size ``aux_item_lens[0]``.

        Returns:
            The NIXL transfer handle.
        """
        # Make descs
        aux_item_len = self.kv_args.aux_item_lens[0]
        prefill_aux_addr = (
            self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len
        )
        decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
        src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
        dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
        src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=True)
        dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=True)
        # Transfer data
        return self._post_write(src_descs, dst_descs, peer_name, notif)

    def add_transfer_request(
        self,
        bootstrap_room: int,
        kv_indices: npt.NDArray[np.int64],
        index_slice: slice,
        is_last: bool,
        chunk_id: int,
        aux_index: Optional[int] = None,
    ):
        """Send one KV chunk (and, for the last chunk, the aux data) for a room.

        Blocks until the bootstrap thread has stored the ``TransferInfo`` for
        ``bootstrap_room``.

        Returns:
            A list with the KV transfer handle, followed by the aux transfer
            handle when ``is_last`` is true.
        """
        assert self.disaggregation_mode == DisaggregationMode.PREFILL
        assert not is_last or aux_index is not None

        # Wait for transfer info to be populated by bootstrap thread.
        with self.condition:
            self.condition.wait_for(lambda: bootstrap_room in self.transfer_infos)
            req = self.transfer_infos[bootstrap_room]
        assert bootstrap_room == req.room

        peer_name = self._add_remote(bootstrap_room, req.agent_metadata)
        chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
        assert len(chunked_dst_kv_indice) == len(kv_indices)

        # Notification layout: "<room>_kv_<chunk_id>_<0|1>"; parsed again in
        # update_transfer_status on the receiving side.
        notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
        kv_xfer_handle = self.send_kvcache(
            peer_name,
            kv_indices,
            req.dst_kv_ptrs,
            chunked_dst_kv_indice,
            req.dst_gpu_id,
            notif,
        )
        handles = [kv_xfer_handle]
        # Only the last chunk we need to send the aux data.
        if is_last:
            aux_xfer_handle = self.send_aux(
                peer_name,
                aux_index,
                req.dst_aux_ptrs,
                req.dst_aux_index,
                str(req.room) + "_aux",
            )
            handles.append(aux_xfer_handle)
        return handles

    def update_transfer_status(self):
        """Fold newly received NIXL notifications into ``transfer_statuses``.

        Handles the two notification layouts produced by
        ``add_transfer_request``: ``"<room>_kv_<chunk_id>_<0|1>"`` and
        ``"<room>_aux"``.
        """
        # Process notifications from received transfers.
        notif_map = self.agent.get_new_notifs()
        # We could also check that self.bootstrap_info['agent_name'] matches
        # the message sender. But the bootstrap room alone should be
        # sufficient to map the status.
        for messages in notif_map.values():
            for msg in messages:
                components = msg.decode("ascii").split("_")
                room = int(components[0])
                if components[1] == "kv":
                    chunk_id = int(components[2])
                    # The flag is the text "0" or "1"; bool("0") would be
                    # True, so it has to be parsed as an integer first.
                    is_last = bool(int(components[3]))
                    self.transfer_statuses[room].received_kvs.add(chunk_id)
                    if is_last:
                        self.transfer_statuses[room].num_kvs_expected = chunk_id + 1
                elif components[1] == "aux":
                    self.transfer_statuses[room].received_aux = True

    def check_transfer_done(self, room: int):
        """Return whether the transfer for ``room`` is complete.

        The membership test avoids creating an empty status entry in the
        ``defaultdict`` for rooms that have not been seen.
        """
        if room not in self.transfer_statuses:
            return False
        return self.transfer_statuses[room].is_done()

    def _register_to_bootstrap(self):
        """Register KVSender to bootstrap server via HTTP PUT.

        Failures are logged, not raised.
        """
        if self.dist_init_addr:
            ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
        else:
            ip_address = get_ip()

        bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
        url = f"http://{bootstrap_server_url}/route"
        payload = {
            "role": "Prefill",
            "rank_ip": get_local_ip_by_remote(),
            "rank_port": self.rank_port,
            "engine_rank": self.kv_args.engine_rank,
            "agent_name": self.agent.name,
        }

        try:
            response = requests.put(url, json=payload)
            if response.status_code == 200:
                logger.debug("Prefill successfully registered to bootstrap server.")
            else:
                logger.error(
                    f"Prefill Failed to connect to bootstrap server: {response.status_code}, {response.text}"
                )
        except Exception as e:
            logger.error(f"Prefill Failed to register to bootstrap server: {e}")

    def _start_bootstrap_thread(self):
        """Bind the PULL socket and start the thread that collects transfer info."""
        self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")

        def bootstrap_thread():
            """This thread recvs transfer info from the decode engine"""
            while True:
                waiting_req_bytes = self.server_socket.recv_multipart()
                room = waiting_req_bytes[0].decode("ascii")
                # Messages whose room frame is the text "None" are skipped.
                if room == "None":
                    continue
                room = int(room)
                with self.condition:
                    self.transfer_infos[room] = TransferInfo.from_zmq(waiting_req_bytes)
                    # Wake any add_transfer_request call waiting on this room.
                    self.condition.notify_all()

        threading.Thread(target=bootstrap_thread).start()
337
+
338
+
339
class NixlKVSender(BaseKVSender):
    """Per-request sender that forwards KV chunks to the manager and tracks their handles."""

    def __init__(self, mgr: NixlKVManager, bootstrap_addr: str, bootstrap_room: int):
        """Remember the manager, bootstrap address and room; no transfer is started."""
        self.kv_mgr = mgr
        self.bootstrap_room = bootstrap_room
        self.bootstrap_server_url = bootstrap_addr
        self.aux_index = None
        # Transfer handles accumulated across all send() calls.
        self.xfer_handles = []
        # Becomes True once the chunk flagged as last has been submitted.
        self.has_sent = False
        # Sequence number handed to the manager for the next chunk.
        self.chunk_id = 0

    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
        """Record the number of KV indices and the aux index for this request."""
        self.num_kv_indices = num_kv_indices
        self.aux_index = aux_index

    def send(
        self,
        kv_indices: npt.NDArray[np.int64],
        index_slice: slice,
        is_last: bool,
    ):
        """Submit one chunk through the manager and keep the returned handles."""
        handles = self.kv_mgr.add_transfer_request(
            self.bootstrap_room,
            kv_indices,
            index_slice,
            is_last,
            self.chunk_id,
            self.aux_index,
        )
        self.xfer_handles.extend(handles)
        self.chunk_id += 1
        if is_last:
            self.has_sent = True

    def poll(self) -> KVPoll:
        """Report the transfer state.

        Returns ``KVPoll.WaitingForInput`` until the last chunk has been sent
        and while any handle is still pending, ``KVPoll.Success`` when every
        handle reports "DONE".

        Raises:
            Exception: if any handle reports "ERR".
        """
        if not self.has_sent:
            return KVPoll.WaitingForInput

        check_state = self.kv_mgr.agent.check_xfer_state
        # Query every handle before deciding, as a list rather than lazily.
        states = [check_state(handle) for handle in self.xfer_handles]
        if all(state == "DONE" for state in states):
            return KVPoll.Success
        if any(state == "ERR" for state in states):
            raise Exception("KVSender transfer encountered an error.")
        return KVPoll.WaitingForInput

    def failure_exception(self):
        """Always raise a placeholder exception."""
        raise Exception("Fake KVSender Exception")
386
+
387
+
388
class NixlKVReceiver(BaseKVReceiver):
    """Per-request receiver that tells the prefill side where to write KV / aux data."""

    def __init__(
        self,
        mgr: NixlKVManager,
        bootstrap_addr: str,
        bootstrap_room: Optional[int] = None,
    ):
        """Resolve the prefill bootstrap info, using the manager's pool as a cache.

        Raises:
            AssertionError: if no bootstrap info could be obtained.
        """
        self.bootstrap_room = bootstrap_room
        self.bootstrap_addr = bootstrap_addr
        self.kv_mgr = mgr
        self.started_transfer = False

        # NOTE: key distinguished by bootstrap_addr and engine_rank
        bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"

        if bootstrap_key not in self.kv_mgr.connection_pool:
            self.bootstrap_info = self._get_bootstrap_info_from_server(
                self.kv_mgr.kv_args.engine_rank
            )
            if self.bootstrap_info is None:
                logger.error(
                    f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
                )
            else:
                self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info
        else:
            self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key]

        assert self.bootstrap_info is not None

    def _get_bootstrap_info_from_server(self, engine_rank):
        """Fetch the bootstrap info from the bootstrap server.

        Returns the decoded JSON body on HTTP 200, otherwise ``None``; errors
        are logged rather than raised.
        """
        try:
            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
            response = requests.get(url)
            if response.status_code == 200:
                bootstrap_info = response.json()
                return bootstrap_info
            else:
                logger.error(
                    f"Failed to get prefill server info: {response.status_code}, {response.text}"
                )
                return None
        except Exception as e:
            logger.error(f"Error fetching prefill info from bootstrap: {e}")
            return None

    def _connect(self, endpoint: str):
        """Return a connected ZMQ PUSH socket for ``endpoint``.

        Delegates to the manager's cached ``_connect``. Caching on this class
        with ``functools.cache`` would key on ``self``, so every receiver
        would open its own context and socket and be kept alive by the cache.
        """
        return self.kv_mgr._connect(endpoint)

    def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
        """Send this receiver's destination description to the prefill rank.

        The multipart message has nine frames: room, local IP, manager rank
        port, NIXL agent metadata, packed KV base pointers, raw KV index
        bytes, packed aux base pointers, aux index and GPU id.
        """
        self.prefill_server_url = (
            f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
        )
        logger.debug(
            f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
        )

        # Pointers are packed as native unsigned 64-bit values, back to back.
        packed_kv_data_ptrs = b"".join(
            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
        )
        packed_aux_data_ptrs = b"".join(
            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
        )
        # NOTE(review): when aux_index is None the frame carries the text
        # "None"; confirm callers always pass a real index.
        self._connect("tcp://" + self.prefill_server_url).send_multipart(
            [
                str(self.bootstrap_room).encode("ascii"),
                get_local_ip_by_remote().encode("ascii"),
                str(self.kv_mgr.rank_port).encode("ascii"),
                self.kv_mgr.agent.get_agent_metadata(),
                packed_kv_data_ptrs,
                kv_indices.tobytes(),
                packed_aux_data_ptrs,
                str(aux_index).encode("ascii"),
                str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
            ]
        )
        self.started_transfer = True

    def poll(self) -> KVPoll:
        """Refresh the manager's transfer status and report this room's state.

        Returns ``KVPoll.WaitingForInput`` before ``init`` has run or while
        the room is incomplete, ``KVPoll.Success`` once it is done.
        """
        if not self.started_transfer:
            return KVPoll.WaitingForInput

        self.kv_mgr.update_transfer_status()

        if self.kv_mgr.check_transfer_done(self.bootstrap_room):
            return KVPoll.Success
        return KVPoll.WaitingForInput

    def failure_exception(self):
        """Always raise a placeholder exception."""
        raise Exception("Fake KVReceiver Exception")
483
+
484
+
485
class NixlKVBootstrapServer(BaseKVBootstrapServer):
    """aiohttp server, run on its own thread, exposing ``/metadata`` and ``/route``.

    ``/metadata`` is a key/value store of raw request bodies; ``/route`` lets
    prefill ranks register their address (PUT) and lets clients look one up
    by ``engine_rank`` (GET).
    """

    def __init__(self, port: int):
        """Set up routes and state, then start the server thread."""
        self.port = port
        self.app = web.Application()
        self.store = dict()
        self.lock = asyncio.Lock()
        self._setup_routes()
        self.prefill_port_table: Dict[int, Dict[str, Union[str, int]]] = {}
        # Set by the server thread; defined up front so close() and the
        # cleanup path never hit a missing attribute.
        self._loop = None
        self._runner = None

        # Start bootstrap server
        self.thread = threading.Thread(target=self._run_server, daemon=True)
        self.run()

    def run(self):
        """Start the server thread."""
        self.thread.start()

    def _setup_routes(self):
        """Attach the two handlers for every HTTP method; they dispatch themselves."""
        self.app.router.add_route("*", "/metadata", self._handle_metadata)
        self.app.router.add_route("*", "/route", self._handle_route)

    async def _handle_metadata(self, request: web.Request):
        """Dispatch ``/metadata`` by method (GET / PUT / DELETE), else answer 405."""
        key = request.query.get("key", "")

        if request.method == "GET":
            return await self._handle_metadata_get(key)
        elif request.method == "PUT":
            return await self._handle_metadata_put(key, request)
        elif request.method == "DELETE":
            return await self._handle_metadata_delete(key)
        return web.Response(
            text="Method not allowed", status=405, content_type="application/json"
        )

    async def _handle_metadata_get(self, key):
        """Return the stored body for ``key`` or 404 when absent."""
        async with self.lock:
            value = self.store.get(key)
        if value is None:
            return web.Response(
                text="metadata not found", status=404, content_type="application/json"
            )
        return web.Response(body=value, status=200, content_type="application/json")

    async def _handle_metadata_put(self, key, request):
        """Store the raw request body under ``key``."""
        data = await request.read()
        async with self.lock:
            self.store[key] = data
        return web.Response(
            text="metadata updated", status=200, content_type="application/json"
        )

    async def _handle_metadata_delete(self, key):
        """Delete ``key`` from the store, or answer 404 when absent."""
        async with self.lock:
            if key not in self.store:
                return web.Response(
                    text="metadata not found",
                    status=404,
                    content_type="application/json",
                )
            del self.store[key]
        return web.Response(
            text="metadata deleted", status=200, content_type="application/json"
        )

    async def _handle_route(self, request: web.Request):
        """Dispatch ``/route`` by method (PUT / GET), else answer 405."""
        method = request.method
        if method == "PUT":
            return await self._handle_route_put(request)
        elif method == "GET":
            return await self._handle_route_get(request)
        else:
            return web.Response(
                text="Method not allowed", status=405, content_type="application/json"
            )

    async def _handle_route_put(self, request: web.Request):
        """Record a rank's address from the JSON body; only role "Prefill" is stored."""
        data = await request.json()
        role = data["role"]
        rank_ip = data["rank_ip"]
        rank_port = int(data["rank_port"])
        engine_rank = int(data["engine_rank"])
        agent_name = data["agent_name"]

        if role == "Prefill":
            # Take the same lock the GET handler uses when reading the table.
            async with self.lock:
                self.prefill_port_table[engine_rank] = {
                    "rank_ip": rank_ip,
                    "rank_port": rank_port,
                    "agent_name": agent_name,
                }
            logger.info(
                f"Registered Prefill boostrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port} and name: {agent_name}"
            )

        return web.Response(text="OK", status=200)

    async def _handle_route_get(self, request: web.Request):
        """Return the registered info for ``engine_rank`` as JSON; 400 / 404 otherwise."""
        engine_rank = request.query.get("engine_rank")
        if not engine_rank:
            return web.Response(text="Missing rank", status=400)

        # Find corresponding prefill info
        async with self.lock:
            bootstrap_info = self.prefill_port_table.get(int(engine_rank))
        if bootstrap_info is not None:
            return web.json_response(bootstrap_info, status=200)
        else:
            return web.Response(text="Not Found", status=404)

    def _run_server(self):
        """Thread target: create an event loop, serve on ``self.port`` until stopped."""
        try:
            # Event Loop
            self._loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self._loop)

            self._runner = web.AppRunner(self.app)
            self._loop.run_until_complete(self._runner.setup())

            site = web.TCPSite(self._runner, port=self.port)
            self._loop.run_until_complete(site.start())
            self._loop.run_forever()
        except Exception as e:
            logger.error(f"Server error: {str(e)}")
        finally:
            # Cleanup; skip whatever was never created so an early failure is
            # not masked by an AttributeError raised here.
            if self._loop is not None:
                if self._runner is not None:
                    self._loop.run_until_complete(self._runner.cleanup())
                self._loop.close()

    def close(self):
        """Shutdown"""
        if self._loop is not None and self._loop.is_running():
            # stop() must be scheduled from this thread onto the server loop.
            self._loop.call_soon_threadsafe(self._loop.stop)
            logger.info("Stopping server loop...")

        if self.thread.is_alive():
            self.thread.join(timeout=2)
            logger.info("Server thread stopped")

    def poll(self) -> KVPoll: ...