sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective registries. It is provided for informational purposes only.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +168 -22
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +49 -0
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +35 -0
- sglang/srt/custom_op.py +7 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -6
- sglang/srt/disaggregation/mooncake/conn.py +289 -48
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +100 -52
- sglang/srt/disaggregation/prefill.py +5 -4
- sglang/srt/disaggregation/utils.py +13 -12
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +45 -9
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/openai/protocol.py +51 -6
- sglang/srt/entrypoints/openai/serving_chat.py +52 -76
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +7 -0
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +56 -23
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +18 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +41 -0
- sglang/srt/layers/linear.py +99 -12
- sglang/srt/layers/logits_processor.py +15 -6
- sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
- sglang/srt/layers/moe/ep_moe/layer.py +115 -25
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +36 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +44 -0
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +6 -6
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +105 -13
- sglang/srt/layers/vocab_parallel_embedding.py +19 -2
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +60 -15
- sglang/srt/managers/mm_utils.py +73 -59
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +80 -79
- sglang/srt/managers/scheduler.py +153 -63
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/memory_pool.py +289 -3
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +3 -2
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +302 -58
- sglang/srt/model_loader/loader.py +86 -10
- sglang/srt/model_loader/weight_utils.py +160 -3
- sglang/srt/models/deepseek_nextn.py +5 -4
- sglang/srt/models/deepseek_v2.py +305 -26
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1010 -0
- sglang/srt/models/gemma3n_mm.py +495 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +43 -11
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/multimodal/processors/gemma3n.py +82 -0
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +85 -24
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +204 -28
- sglang/srt/utils.py +369 -138
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_utils.py +15 -3
- sglang/version.py +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/nixl/conn.py

@@ -31,23 +31,19 @@ from sglang.srt.utils import get_local_ip_by_remote

 logger = logging.getLogger(__name__)

-NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]]
-
 GUARD = "NixlMsgGuard".encode("ascii")


 @dataclasses.dataclass
 class TransferInfo:
+    """Contains indices for a transfer, sent by KVReceiver. Received by prefill bootstrap thread."""
+
     room: int
     endpoint: str
     dst_port: int
-    agent_metadata: bytes
     agent_name: str
-    dst_kv_ptrs: list[int]
     dst_kv_indices: npt.NDArray[np.int32]
-    dst_aux_ptrs: list[int]
     dst_aux_index: int
-    dst_gpu_id: int
     required_dst_info_num: int

     def is_dummy(self):

@@ -59,14 +55,37 @@ class TransferInfo:
             room=int(msg[0].decode("ascii")),
             endpoint=msg[1].decode("ascii"),
             dst_port=int(msg[2].decode("ascii")),
-            agent_metadata=msg[3],
-            agent_name=msg[4].decode("ascii"),
+            agent_name=msg[3].decode("ascii"),
+            dst_kv_indices=np.frombuffer(msg[4], dtype=np.int32),
+            dst_aux_index=int(msg[5].decode("ascii")),
+            required_dst_info_num=int(msg[6].decode("ascii")),
+        )
+
+
+@dataclasses.dataclass
+class KVArgsRegisterInfo:
+    """Contains base pointers and other info which only needs to be sent once by KVReceiver. Received by prefill bootstrap thread."""
+
+    room: str
+    endpoint: str
+    dst_port: int
+    agent_name: str
+    agent_metadata: bytes
+    dst_kv_ptrs: list[int]
+    dst_aux_ptrs: list[int]
+    gpu_id: int
+
+    @classmethod
+    def from_zmq(cls, msg: List[bytes]):
+        return cls(
+            room=str(msg[0].decode("ascii")),
+            endpoint=msg[1].decode("ascii"),
+            dst_port=int(msg[2].decode("ascii")),
+            agent_name=msg[3].decode("ascii"),
+            agent_metadata=msg[4],
             dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
-            dst_kv_indices=np.frombuffer(msg[6], dtype=np.int32),
-            dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])),
-            dst_aux_index=int(msg[8].decode("ascii")),
-            dst_gpu_id=int(msg[9].decode("ascii")),
-            required_dst_info_num=int(msg[10].decode("ascii")),
+            dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
+            gpu_id=int(msg[7].decode("ascii")),
         )

@@ -109,9 +128,9 @@ class NixlKVManager(CommonKVManager):
         self.register_buffer_to_engine()

         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            self.request_status = {}
-            self.transfer_infos: Dict[int, TransferInfo] = {}
-            self.peer_names: Dict[str, str] = {}
+            self.request_status: Dict[int, KVPoll] = {}
+            self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
+            self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
             self._start_bootstrap_thread()
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(

@@ -140,7 +159,7 @@ class NixlKVManager(CommonKVManager):
             self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
         ):
             kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
-        self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=True)
+        self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=False)
         logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
         if not self.kv_descs:
             raise Exception("NIXL memory registration failed for kv tensors")

@@ -149,15 +168,18 @@ class NixlKVManager(CommonKVManager):
             self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
         ):
             aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
-        self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=True)
+        self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=False)
         logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
         if not self.aux_descs:
             raise Exception("NIXL memory registration failed for aux tensors")

-    def _add_remote(self, agent_name: str, agent_metadata: bytes):
-        if agent_name not in self.peer_names:
-            self.peer_names[agent_name] = self.agent.add_remote_agent(agent_metadata)
-        return self.peer_names[agent_name]
+    def _add_remote_peer(self, decode_kv_args: KVArgsRegisterInfo):
+        agent_name = decode_kv_args.agent_name
+        if agent_name in self.decode_kv_args_table:
+            logger.info(f"Peer {agent_name} was already registered, ignoring.")
+            return
+        self.decode_kv_args_table[agent_name] = decode_kv_args
+        self.agent.add_remote_agent(decode_kv_args.agent_metadata)

     def send_kvcache(
         self,

@@ -193,8 +215,8 @@ class NixlKVManager(CommonKVManager):
         logger.debug(
             f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
         )
-        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=True)
-        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=True)
+        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=False)
+        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=False)
         # Transfer data
         xfer_handle = self.agent.initialize_xfer(
             "WRITE",

@@ -226,8 +248,8 @@ class NixlKVManager(CommonKVManager):
         decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
         src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
         dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
-        src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=True)
-        dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=True)
+        src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=False)
+        dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=False)
         # Transfer data
         xfer_handle = self.agent.initialize_xfer(
             "WRITE",

@@ -262,17 +284,17 @@ class NixlKVManager(CommonKVManager):
             if req.is_dummy():
                 continue

-            peer_name = self._add_remote(req.agent_name, req.agent_metadata)
             chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
             assert len(chunked_dst_kv_indice) == len(kv_indices)
+            assert req.agent_name in self.decode_kv_args_table

             notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
             kv_xfer_handle = self.send_kvcache(
-                peer_name,
+                req.agent_name,
                 kv_indices,
-                req.dst_kv_ptrs,
+                self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
                 chunked_dst_kv_indice,
-                req.dst_gpu_id,
+                self.decode_kv_args_table[req.agent_name].gpu_id,
                 notif,
             )
             handles.append(kv_xfer_handle)

@@ -280,13 +302,15 @@ class NixlKVManager(CommonKVManager):
             if is_last:
                 assert aux_index is not None
                 aux_xfer_handle = self.send_aux(
-                    peer_name,
+                    req.agent_name,
                     aux_index,
-                    req.dst_aux_ptrs,
+                    self.decode_kv_args_table[req.agent_name].dst_aux_ptrs,
                     req.dst_aux_index,
                     str(req.room) + "_aux",
                 )
                 handles.append(aux_xfer_handle)
+        if is_last:
+            del self.transfer_infos[bootstrap_room]
         return handles

     def update_transfer_status(self):

@@ -328,16 +352,23 @@ class NixlKVManager(CommonKVManager):
                 ), f"First message should be {GUARD}. Foreign traffic?"
                 waiting_req_bytes = waiting_req_bytes[1:]
                 room = waiting_req_bytes[0].decode("ascii")
-                if room == "None":
-                    continue
+                agent_name = waiting_req_bytes[3].decode("ascii")
+                if room == "None":
+                    # Register new peer and save KV base pointers.
+                    self._add_remote_peer(
+                        KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
+                    )
+                    logger.debug(f"Register KVArgs from {agent_name} successfully")
+                    continue
                 room = int(room)
-                agent_name = waiting_req_bytes[4].decode("ascii")
                 if room not in self.transfer_infos:
                     self.transfer_infos[room] = {}
                 self.transfer_infos[room][agent_name] = TransferInfo.from_zmq(
                     waiting_req_bytes
                 )
-                required_dst_info_num = self.transfer_infos[room][agent_name].required_dst_info_num
+                required_dst_info_num = self.transfer_infos[room][
+                    agent_name
+                ].required_dst_info_num
                 logger.debug(f"got info {room=} {agent_name=} {required_dst_info_num=}")
                 if len(self.transfer_infos[room]) == required_dst_info_num:
                     logger.debug(f"{room=} is bootstrapped")

@@ -391,6 +422,7 @@ class NixlKVSender(BaseKVSender):
         self.chunk_id += 1
         if is_last:
             self.has_sent = True
+            del self.kv_mgr.request_status[self.bootstrap_room]

     def poll(self) -> KVPoll:
         if not self.has_sent:

@@ -415,6 +447,7 @@ class NixlKVReceiver(CommonKVReceiver):
         data_parallel_rank: Optional[int] = None,
     ):
         self.started_transfer = False
+        self.conclude_state = None
         super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)

     def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):

@@ -426,17 +459,8 @@ class NixlKVReceiver(CommonKVReceiver):
             f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
         )
         is_dummy = bootstrap_info["is_dummy"]
-
-        # TODO: send_kv_args earlier
-        packed_kv_data_ptrs = b"".join(
-            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
-        )
-        packed_aux_data_ptrs = b"".join(
-            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
-        )
-
         logger.debug(
-            f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room}"
+            f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room} {is_dummy=}"
         )
         sock, lock = self._connect("tcp://" + self.prefill_server_url)
         with lock:

@@ -446,13 +470,9 @@ class NixlKVReceiver(CommonKVReceiver):
                     str(self.bootstrap_room).encode("ascii"),
                     get_local_ip_by_remote().encode("ascii"),
                     str(self.kv_mgr.rank_port).encode("ascii"),
-                    self.kv_mgr.agent.get_agent_metadata(),
                     self.kv_mgr.agent.name.encode("ascii"),
-                    packed_kv_data_ptrs,
                     kv_indices.tobytes() if not is_dummy else b"",
-                    packed_aux_data_ptrs,
                     str(aux_index).encode("ascii"),
-                    str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
                     str(self.required_dst_info_num).encode("ascii"),
                 ]
             )

@@ -460,17 +480,45 @@ class NixlKVReceiver(CommonKVReceiver):
         self.started_transfer = True

     def poll(self) -> KVPoll:
+        if self.conclude_state is not None:
+            return self.conclude_state
         if not self.started_transfer:
             return KVPoll.WaitingForInput  # type: ignore

         self.kv_mgr.update_transfer_status()
-
         if self.kv_mgr.check_transfer_done(self.bootstrap_room):  # type: ignore
+            self.conclude_state = KVPoll.Success
+            del self.kv_mgr.transfer_statuses[self.bootstrap_room]
             return KVPoll.Success  # type: ignore
         return KVPoll.WaitingForInput  # type: ignore

     def _register_kv_args(self):
-        pass
+        for bootstrap_info in self.bootstrap_infos:
+            self.prefill_server_url = (
+                f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
+            )
+            packed_kv_data_ptrs = b"".join(
+                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
+            )
+            packed_aux_data_ptrs = b"".join(
+                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
+            )
+
+            sock, lock = self._connect("tcp://" + self.prefill_server_url)
+            with lock:
+                sock.send_multipart(
+                    [
+                        GUARD,
+                        "None".encode("ascii"),
+                        get_local_ip_by_remote().encode("ascii"),
+                        str(self.kv_mgr.rank_port).encode("ascii"),
+                        self.kv_mgr.agent.name.encode("ascii"),
+                        self.kv_mgr.agent.get_agent_metadata(),
+                        packed_kv_data_ptrs,
+                        packed_aux_data_ptrs,
+                        str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
+                    ]
+                )

     def failure_exception(self):
         raise Exception("Fake KVReceiver Exception")
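The net effect of this change is a two-phase handshake: each decode-side KVReceiver now sends its NIXL agent metadata and KV/aux base pointers exactly once, in a registration frame whose room field is the literal string "None", and every subsequent per-request TransferInfo frame carries only destination indices. The sketch below is a standalone approximation of the two multipart frame layouts and the dispatch on the room field, not the sglang classes themselves; the IP address and pointer values are placeholders.

```python
import struct

import numpy as np

GUARD = "NixlMsgGuard".encode("ascii")


def pack_register_msg(rank_port, agent_name, agent_metadata, kv_ptrs, aux_ptrs, gpu_id):
    """Build the one-time registration frame; its room field is the literal "None"."""
    packed_kv = b"".join(struct.pack("Q", p) for p in kv_ptrs)
    packed_aux = b"".join(struct.pack("Q", p) for p in aux_ptrs)
    return [
        GUARD,
        b"None",                         # room == "None" marks a registration
        b"10.0.0.1",                     # placeholder for get_local_ip_by_remote()
        str(rank_port).encode("ascii"),
        agent_name.encode("ascii"),
        agent_metadata,
        packed_kv,
        packed_aux,
        str(gpu_id).encode("ascii"),
    ]


def parse_bootstrap_msg(msg):
    """Mimic the prefill bootstrap thread: dispatch on the room field."""
    assert msg[0] == GUARD, "Foreign traffic?"
    msg = msg[1:]
    room = msg[0].decode("ascii")
    if room == "None":
        # One-time peer registration: recover the KV base pointers.
        kv_ptrs = list(struct.unpack(f"{len(msg[5]) // 8}Q", msg[5]))
        return ("register", msg[3].decode("ascii"), kv_ptrs)
    # Per-request transfer info: only indices travel on this path now.
    return ("transfer", int(room), np.frombuffer(msg[4], dtype=np.int32))


frames = pack_register_msg(7001, "decode-agent-0", b"nixl-meta", [0x1000, 0x2000], [0x3000], 0)
print(parse_bootstrap_msg(frames))  # ('register', 'decode-agent-0', [4096, 8192])
```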
sglang/srt/disaggregation/prefill.py

@@ -93,8 +93,6 @@ class PrefillBootstrapQueue:
         self.gpu_id = gpu_id
         self.bootstrap_port = bootstrap_port
         self.queue: List[Req] = []
-        self.pp_rank = pp_rank
-        self.pp_size = pp_size
         self.gloo_group = gloo_group
         self.max_total_num_tokens = max_total_num_tokens
         self.scheduler = scheduler

@@ -124,6 +122,9 @@ class PrefillBootstrapQueue:
         kv_args.kv_data_ptrs = kv_data_ptrs
         kv_args.kv_data_lens = kv_data_lens
         kv_args.kv_item_lens = kv_item_lens
+        if not self.is_mla_backend:
+            kv_args.kv_head_num = self.token_to_kv_pool.head_num
+        kv_args.page_size = self.token_to_kv_pool.page_size

         kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
             self.metadata_buffers.get_buf_infos()

@@ -275,7 +276,7 @@ class SchedulerDisaggregationPrefillMixin:
         batch = self.get_new_batch_prefill()

         if require_mlp_sync(self.server_args):
-            batch, _ = self.prepare_mlp_sync_batch(batch)
+            batch = self.prepare_mlp_sync_batch(batch)
         self.cur_batch = batch

         if batch:

@@ -309,7 +310,7 @@ class SchedulerDisaggregationPrefillMixin:
         batch = self.get_new_batch_prefill()

         if require_mlp_sync(self.server_args):
-            batch, _ = self.prepare_mlp_sync_batch(batch)
+            batch = self.prepare_mlp_sync_batch(batch)
         self.cur_batch = batch
         if batch:
             result = self.run_batch(batch)
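The one-line fixes in the two event loops matter because prepare_mlp_sync_batch returns the batch to run rather than mutating its argument; discarding the return value would run the wrong (or no) batch. A toy sketch of that contract, using illustrative names rather than the sglang scheduler API:

```python
from typing import Optional


class ToyScheduler:
    """Illustrative stand-in for the prefill scheduler's event-loop step."""

    def __init__(self, mlp_sync: bool):
        self.mlp_sync = mlp_sync
        self.cur_batch: Optional[str] = None

    def prepare_mlp_sync_batch(self, batch: Optional[str]) -> Optional[str]:
        # All data-parallel workers must step together; a worker without work
        # receives an idle batch so the collectives stay aligned.
        return batch if batch is not None else "idle-batch"

    def step(self, batch: Optional[str]) -> Optional[str]:
        if self.mlp_sync:
            batch = self.prepare_mlp_sync_batch(batch)  # rebind, don't discard
        self.cur_batch = batch
        return batch


print(ToyScheduler(mlp_sync=True).step(None))  # idle-batch
```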
sglang/srt/disaggregation/utils.py

@@ -74,7 +74,7 @@ class ReqToMetadataIdxAllocator:
     def available_size(self):
         return len(self.free_slots)

-    def alloc(self) -> int:
+    def alloc(self) -> Optional[int]:
         if len(self.free_slots) == 0:
             return None

@@ -107,9 +107,6 @@ class MetadataBuffers:
         # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
         self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)

-        self.output_hidden_states = torch.zeros(
-            (size, hidden_size), dtype=dtype, device=device
-        )
         self.output_token_logprobs_val = torch.zeros(
             (size, 16), dtype=torch.float32, device=device
         )

@@ -122,51 +119,50 @@ class MetadataBuffers:
         self.output_top_logprobs_idx = torch.zeros(
             (size, max_top_logprobs_num), dtype=torch.int32, device=device
         )
+        self.output_hidden_states = torch.zeros(
+            (size, hidden_size), dtype=dtype, device=device
+        )

     def get_buf_infos(self):
         ptrs = [
             self.output_ids.data_ptr(),
-            self.output_hidden_states.data_ptr(),  # TODO: set None to avoid transfer hidden_states when spec_algorithm is None
             self.output_token_logprobs_val.data_ptr(),
             self.output_token_logprobs_idx.data_ptr(),
             self.output_top_logprobs_val.data_ptr(),
             self.output_top_logprobs_idx.data_ptr(),
+            self.output_hidden_states.data_ptr(),
         ]
         data_lens = [
             self.output_ids.nbytes,
-            self.output_hidden_states.nbytes,
             self.output_token_logprobs_val.nbytes,
             self.output_token_logprobs_idx.nbytes,
             self.output_top_logprobs_val.nbytes,
             self.output_top_logprobs_idx.nbytes,
+            self.output_hidden_states.nbytes,
         ]
         item_lens = [
             self.output_ids[0].nbytes,
-            self.output_hidden_states[0].nbytes,
             self.output_token_logprobs_val[0].nbytes,
             self.output_token_logprobs_idx[0].nbytes,
             self.output_top_logprobs_val[0].nbytes,
             self.output_top_logprobs_idx[0].nbytes,
+            self.output_hidden_states[0].nbytes,
         ]
         return ptrs, data_lens, item_lens

     def get_buf(self, idx: int):
         return (
             self.output_ids[idx],
-            self.output_hidden_states[idx],
             self.output_token_logprobs_val[idx],
             self.output_token_logprobs_idx[idx],
             self.output_top_logprobs_val[idx],
             self.output_top_logprobs_idx[idx],
+            self.output_hidden_states[idx],
         )

     def set_buf(self, req: Req):

         self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
-        if req.hidden_states_tensor is not None:
-            self.output_hidden_states[req.metadata_buffer_index].copy_(
-                req.hidden_states_tensor
-            )
         if req.return_logprob:
             if req.output_token_logprobs_val:  # not none or empty list
                 self.output_token_logprobs_val[req.metadata_buffer_index][0] = (

@@ -189,6 +185,11 @@ class MetadataBuffers:
                 ] = torch.tensor(
                     req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
                 )
+        # for PD + spec decode
+        if req.hidden_states_tensor is not None:
+            self.output_hidden_states[req.metadata_buffer_index].copy_(
+                req.hidden_states_tensor
+            )


 #########################
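Two details of MetadataBuffers are worth spelling out: each per-request slot is padded past the 64-byte RDMA minimum (16 int32s is exactly 64 bytes for output_ids), and output_hidden_states now sits at the end of every pointer/length list, so the other buffers keep stable positions and the hidden-states slot, used only for PD with speculative decoding, comes last. A minimal sketch of that layout with toy sizes:

```python
import torch

size, hidden_size = 4, 8  # toy values; the real ones come from the model config

# 16 int32s = 64 bytes per slot, clearing the 64-byte RDMA minimum.
output_ids = torch.zeros((size, 16), dtype=torch.int32)
# Hidden states go last so the preceding buffers keep stable indices and this
# slot can be skipped when speculative decoding is off.
output_hidden_states = torch.zeros((size, hidden_size), dtype=torch.float32)

ptrs = [output_ids.data_ptr(), output_hidden_states.data_ptr()]
data_lens = [output_ids.nbytes, output_hidden_states.nbytes]
item_lens = [output_ids[0].nbytes, output_hidden_states[0].nbytes]

assert item_lens[0] == 64  # one request's output_ids slot meets the RDMA floor
print(ptrs, data_lens, item_lens)
```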
sglang/srt/distributed/parallel_state.py

@@ -42,8 +42,10 @@ from torch.distributed import Backend, ProcessGroup
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
+    get_int_env_var,
     is_cuda_alike,
     is_npu,
+    is_shm_available,
     supports_custom_op,
 )

@@ -222,6 +224,7 @@ class GroupCoordinator:
         self.local_rank = local_rank
         self.device_group = None
         self.cpu_group = None
+        self.local_size = get_int_env_var("LOCAL_SIZE", 0)

         for ranks in group_ranks:
             device_group = torch.distributed.new_group(

@@ -440,9 +443,12 @@ class GroupCoordinator:
             return input_

         if input_.is_cpu:
-            import intel_extension_for_pytorch as ipex
-
-            ipex.distributed.all_reduce(input_, group=self.device_group)
+            if is_shm_available(input_.dtype, self.world_size, self.local_size):
+                torch.ops.sgl_kernel.shm_allreduce(
+                    input_, torch.distributed.ReduceOp.SUM
+                )
+            else:
+                torch.distributed.all_reduce(input_, group=self.device_group)
             return input_

         if not supports_custom_op():

@@ -570,6 +576,16 @@ class GroupCoordinator:
         output_tensor = torch.empty(
             output_size, dtype=input_.dtype, device=input_.device
         )
+
+        if input_.is_cpu:
+            if is_shm_available(input_.dtype, self.world_size, self.local_size):
+                return torch.ops.sgl_kernel.shm_allgather(input_, dim)
+            else:
+                torch.distributed.all_gather_into_tensor(
+                    output_tensor, input_, group=self.device_group
+                )
+                return output_tensor
+
         # All-gather.
         self.all_gather_into_tensor(output_tensor, input_)
         # Reshape

@@ -683,18 +699,25 @@ class GroupCoordinator:
         )

         # Serialize object to tensor and get the size as well
-        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
+        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda(
+            device=torch.cuda.current_device()
+        )

         size_tensor = torch.tensor(
-            [object_tensor.numel()], dtype=torch.long, device="cpu"
+            [object_tensor.numel()],
+            dtype=torch.long,
+            device=torch.cuda.current_device(),
         )

         # Send object size
-        torch.distributed.send(
-            size_tensor, dst=self.ranks[dst], group=self.cpu_group)
+        torch.distributed.send(
+            size_tensor, dst=self.ranks[dst], group=self.device_group
+        )

         # Send object
-        torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group)
+        torch.distributed.send(
+            object_tensor, dst=self.ranks[dst], group=self.device_group
+        )

         return None

@@ -708,29 +731,31 @@ class GroupCoordinator:
             src != self.rank_in_group
         ), "Invalid source rank. Source rank is the same as the current rank."

-        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
+        size_tensor = torch.empty(
+            1, dtype=torch.long, device=torch.cuda.current_device()
+        )

         # Receive object size
         rank_size = torch.distributed.recv(
-            size_tensor, src=self.ranks[src], group=self.cpu_group
+            size_tensor, src=self.ranks[src], group=self.device_group
         )

         # Tensor to receive serialized objects into.
         object_tensor = torch.empty(  # type: ignore[call-overload]
             size_tensor.item(),  # type: ignore[arg-type]
             dtype=torch.uint8,
-            device="cpu",
+            device=torch.cuda.current_device(),
         )

         rank_object = torch.distributed.recv(
-            object_tensor, src=self.ranks[src], group=self.cpu_group
+            object_tensor, src=self.ranks[src], group=self.device_group
         )

         assert (
             rank_object == rank_size
         ), "Received object sender rank does not match the size sender rank."

-        obj = pickle.loads(object_tensor.numpy().tobytes())
+        obj = pickle.loads(object_tensor.cpu().numpy().tobytes())

         return obj

@@ -841,14 +866,16 @@ class GroupCoordinator:
         dst = (self.rank_in_group + 1) % self.world_size
         assert dst < self.world_size, f"Invalid dst rank ({dst})"

-        metadata_list: List[Tuple[Any, Any]] = []
         assert isinstance(
             tensor_dict, dict
         ), f"Expecting a dictionary, got {type(tensor_dict)}"
         metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
-        # `metadata_list` lives in CPU memory.
-        # `send_object_list` has serialization & deserialization,
-        # all happening on CPU. Therefore, we can use the CPU group.
+        # Note: While switching to Device-to-Device (D2D) would introduce an extra
+        # Device-to-Host (D2H) memory copy overhead for serialization, our benchmarks
+        # show better overall transmission performance with D2D due to:
+        # 1. Superior D2D transfer bandwidth
+        # 2. Ability to overlap send and recv operations
+        # Thus the net performance gain justifies this approach.
         self.send_object(metadata_list, dst=dst)
         for tensor in tensor_list:
             if tensor.numel() == 0:
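send_object and recv_object now stage both the size header and the pickled payload on the current CUDA device, so both messages travel over the device group; the receiver must copy the payload back to host memory before unpickling, which is what the new .cpu() accomplishes. A self-contained sketch of that serialize/ship/deserialize round trip, with the torch.distributed send/recv calls elided and the CUDA staging skipped when no GPU is present:

```python
import pickle

import torch

obj = {"metadata": [("k", (2, 3))], "step": 7}

# Sender: pickle into a uint8 tensor; the new code then moves it to the GPU
# so the payload can be sent over the device group (e.g. NCCL) instead of gloo.
object_tensor = torch.frombuffer(bytearray(pickle.dumps(obj)), dtype=torch.uint8)
if torch.cuda.is_available():
    object_tensor = object_tensor.cuda(device=torch.cuda.current_device())
size_tensor = torch.tensor([object_tensor.numel()], dtype=torch.long)

# Receiver: torch.distributed.recv would fill a same-sized uint8 tensor here;
# the bytes must return to host memory before unpickling, hence the .cpu().
roundtrip = pickle.loads(object_tensor.cpu().numpy().tobytes())
assert roundtrip == obj
```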
sglang/srt/entrypoints/EngineBase.py

@@ -48,6 +48,14 @@ class EngineBase(ABC):
         """Update model weights with in-memory tensor data."""
         pass

+    def load_lora_adapter(self, lora_name: str, lora_path: str):
+        """Load a new LoRA adapter without re-launching the engine."""
+        pass
+
+    def unload_lora_adapter(self, lora_name: str):
+        """Unload a LoRA adapter without re-launching the engine."""
+        pass
+
     @abstractmethod
     def release_memory_occupation(self):
         """Release GPU memory occupation temporarily."""