sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +14 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +301 -64
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +94 -46
- sglang/srt/disaggregation/prefill.py +20 -15
- sglang/srt/disaggregation/utils.py +47 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +27 -31
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +897 -0
- sglang/srt/entrypoints/openai/serving_completions.py +425 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +28 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +43 -23
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +44 -2
- sglang/srt/layers/linear.py +18 -1
- sglang/srt/layers/logits_processor.py +14 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +286 -13
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
- sglang/srt/layers/moe/topk.py +117 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_utils.py +5 -4
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +144 -12
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/layers/vocab_parallel_embedding.py +14 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/expert_distribution.py +21 -0
- sglang/srt/managers/io_struct.py +19 -14
- sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
- sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
- sglang/srt/managers/schedule_batch.py +49 -32
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +189 -68
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +77 -46
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +27 -8
- sglang/srt/model_loader/loader.py +50 -8
- sglang/srt/model_loader/weight_utils.py +100 -2
- sglang/srt/models/deepseek_nextn.py +35 -30
- sglang/srt/models/deepseek_v2.py +255 -30
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1009 -0
- sglang/srt/models/gemma3n_mm.py +511 -0
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +51 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -1
- sglang/srt/utils.py +248 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/nixl/conn.py
CHANGED
@@ -31,23 +31,19 @@ from sglang.srt.utils import get_local_ip_by_remote

 logger = logging.getLogger(__name__)

-NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]]
-
 GUARD = "NixlMsgGuard".encode("ascii")


 @dataclasses.dataclass
 class TransferInfo:
+    """Contains indices for a transfer, sent by KVReceiver. Received by prefill bootstrap thread."""
+
     room: int
     endpoint: str
     dst_port: int
-    agent_metadata: bytes
     agent_name: str
-    dst_kv_ptrs: list[int]
     dst_kv_indices: npt.NDArray[np.int32]
-    dst_aux_ptrs: list[int]
     dst_aux_index: int
-    dst_gpu_id: int
     required_dst_info_num: int

     def is_dummy(self):
@@ -59,14 +55,37 @@ class TransferInfo:
             room=int(msg[0].decode("ascii")),
             endpoint=msg[1].decode("ascii"),
             dst_port=int(msg[2].decode("ascii")),
-
-
+            agent_name=msg[3].decode("ascii"),
+            dst_kv_indices=np.frombuffer(msg[4], dtype=np.int32),
+            dst_aux_index=int(msg[5].decode("ascii")),
+            required_dst_info_num=int(msg[6].decode("ascii")),
+        )
+
+
+@dataclasses.dataclass
+class KVArgsRegisterInfo:
+    """Contains base pointers and other info which only needs to be sent once by KVReceiver. Received by prefill bootstrap thread."""
+
+    room: str
+    endpoint: str
+    dst_port: int
+    agent_name: str
+    agent_metadata: bytes
+    dst_kv_ptrs: list[int]
+    dst_aux_ptrs: list[int]
+    gpu_id: int
+
+    @classmethod
+    def from_zmq(cls, msg: List[bytes]):
+        return cls(
+            room=str(msg[0].decode("ascii")),
+            endpoint=msg[1].decode("ascii"),
+            dst_port=int(msg[2].decode("ascii")),
+            agent_name=msg[3].decode("ascii"),
+            agent_metadata=msg[4],
             dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
-
-
-            dst_aux_index=int(msg[8].decode("ascii")),
-            dst_gpu_id=int(msg[9].decode("ascii")),
-            required_dst_info_num=int(msg[10].decode("ascii")),
+            dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
+            gpu_id=int(msg[7].decode("ascii")),
         )


@@ -109,9 +128,9 @@ class NixlKVManager(CommonKVManager):
         self.register_buffer_to_engine()

         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            self.request_status = {}
-            self.transfer_infos: Dict[int, TransferInfo] = {}
-            self.
+            self.request_status: Dict[int, KVPoll] = {}
+            self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
+            self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
             self._start_bootstrap_thread()
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
@@ -154,10 +173,13 @@ class NixlKVManager(CommonKVManager):
         if not self.aux_descs:
             raise Exception("NIXL memory registration failed for aux tensors")

-    def
-
-
-
+    def _add_remote_peer(self, decode_kv_args: KVArgsRegisterInfo):
+        agent_name = decode_kv_args.agent_name
+        if agent_name in self.decode_kv_args_table:
+            logger.info(f"Peer {agent_name} was already registered, ignoring.")
+            return
+        self.decode_kv_args_table[agent_name] = decode_kv_args
+        self.agent.add_remote_agent(decode_kv_args.agent_metadata)

     def send_kvcache(
         self,
@@ -262,17 +284,17 @@ class NixlKVManager(CommonKVManager):
             if req.is_dummy():
                 continue

-            peer_name = self._add_remote(req.agent_name, req.agent_metadata)
             chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
             assert len(chunked_dst_kv_indice) == len(kv_indices)
+            assert req.agent_name in self.decode_kv_args_table

             notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
             kv_xfer_handle = self.send_kvcache(
-
+                req.agent_name,
                 kv_indices,
-                req.dst_kv_ptrs,
+                self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
                 chunked_dst_kv_indice,
-                req.
+                self.decode_kv_args_table[req.agent_name].gpu_id,
                 notif,
             )
             handles.append(kv_xfer_handle)
@@ -280,13 +302,15 @@ class NixlKVManager(CommonKVManager):
             if is_last:
                 assert aux_index is not None
                 aux_xfer_handle = self.send_aux(
-
+                    req.agent_name,
                     aux_index,
-                    req.dst_aux_ptrs,
+                    self.decode_kv_args_table[req.agent_name].dst_aux_ptrs,
                     req.dst_aux_index,
                     str(req.room) + "_aux",
                 )
                 handles.append(aux_xfer_handle)
+        if is_last:
+            del self.transfer_infos[bootstrap_room]
         return handles

     def update_transfer_status(self):
@@ -328,16 +352,23 @@ class NixlKVManager(CommonKVManager):
             ), f"First message should be {GUARD}. Foreign traffic?"
             waiting_req_bytes = waiting_req_bytes[1:]
             room = waiting_req_bytes[0].decode("ascii")
-
-
+            agent_name = waiting_req_bytes[3].decode("ascii")
+            if room == "None":
+                # Register new peer and save KV base pointers.
+                self._add_remote_peer(
+                    KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
+                )
+                logger.debug(f"Register KVArgs from {agent_name} successfully")
+                continue
             room = int(room)
-            agent_name = waiting_req_bytes[4].decode("ascii")
             if room not in self.transfer_infos:
                 self.transfer_infos[room] = {}
             self.transfer_infos[room][agent_name] = TransferInfo.from_zmq(
                 waiting_req_bytes
             )
-
+            required_dst_info_num = self.transfer_infos[room][
+                agent_name
+            ].required_dst_info_num
             logger.debug(f"got info {room=} {agent_name=} {required_dst_info_num=}")
             if len(self.transfer_infos[room]) == required_dst_info_num:
                 logger.debug(f"{room=} is bootstrapped")
@@ -391,6 +422,7 @@ class NixlKVSender(BaseKVSender):
         self.chunk_id += 1
         if is_last:
             self.has_sent = True
+            del self.kv_mgr.request_status[self.bootstrap_room]

     def poll(self) -> KVPoll:
         if not self.has_sent:
@@ -415,6 +447,7 @@ class NixlKVReceiver(CommonKVReceiver):
         data_parallel_rank: Optional[int] = None,
     ):
         self.started_transfer = False
+        self.conclude_state = None
         super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)

     def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
@@ -426,17 +459,8 @@ class NixlKVReceiver(CommonKVReceiver):
                 f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
             )
             is_dummy = bootstrap_info["is_dummy"]
-
-            # TODO: send_kv_args earlier
-            packed_kv_data_ptrs = b"".join(
-                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
-            )
-            packed_aux_data_ptrs = b"".join(
-                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
-            )
-
             logger.debug(
-                f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room}"
+                f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room} {is_dummy=}"
             )
             sock, lock = self._connect("tcp://" + self.prefill_server_url)
             with lock:
@@ -446,13 +470,9 @@ class NixlKVReceiver(CommonKVReceiver):
                         str(self.bootstrap_room).encode("ascii"),
                         get_local_ip_by_remote().encode("ascii"),
                         str(self.kv_mgr.rank_port).encode("ascii"),
-                        self.kv_mgr.agent.get_agent_metadata(),
                         self.kv_mgr.agent.name.encode("ascii"),
-                        packed_kv_data_ptrs,
                         kv_indices.tobytes() if not is_dummy else b"",
-                        packed_aux_data_ptrs,
                         str(aux_index).encode("ascii"),
-                        str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
                         str(self.required_dst_info_num).encode("ascii"),
                     ]
                 )
@@ -460,17 +480,45 @@ class NixlKVReceiver(CommonKVReceiver):
             self.started_transfer = True

     def poll(self) -> KVPoll:
+        if self.conclude_state is not None:
+            return self.conclude_state
         if not self.started_transfer:
             return KVPoll.WaitingForInput  # type: ignore

         self.kv_mgr.update_transfer_status()
-
         if self.kv_mgr.check_transfer_done(self.bootstrap_room):  # type: ignore
+            self.conclude_state = KVPoll.Success
+            del self.kv_mgr.transfer_statuses[self.bootstrap_room]
             return KVPoll.Success  # type: ignore
         return KVPoll.WaitingForInput  # type: ignore

     def _register_kv_args(self):
-
+        for bootstrap_info in self.bootstrap_infos:
+            self.prefill_server_url = (
+                f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
+            )
+            packed_kv_data_ptrs = b"".join(
+                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
+            )
+            packed_aux_data_ptrs = b"".join(
+                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
+            )
+
+            sock, lock = self._connect("tcp://" + self.prefill_server_url)
+            with lock:
+                sock.send_multipart(
+                    [
+                        GUARD,
+                        "None".encode("ascii"),
+                        get_local_ip_by_remote().encode("ascii"),
+                        str(self.kv_mgr.rank_port).encode("ascii"),
+                        self.kv_mgr.agent.name.encode("ascii"),
+                        self.kv_mgr.agent.get_agent_metadata(),
+                        packed_kv_data_ptrs,
+                        packed_aux_data_ptrs,
+                        str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
+                    ]
+                )

     def failure_exception(self):
         raise Exception("Fake KVReceiver Exception")
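Note: the hunks above split the old single bootstrap message into two. A receiver now sends a one-time registration message whose room field is the literal "None" (parsed by KVArgsRegisterInfo.from_zmq, registering agent metadata and base pointers once per peer) and a lighter per-request transfer message that carries only indices (parsed by TransferInfo.from_zmq). The sketch below only illustrates the multipart field order implied by those two parsers; every value is a hypothetical placeholder.

    import struct

    import numpy as np

    GUARD = "NixlMsgGuard".encode("ascii")

    # One-time registration from the decode KVReceiver (room sentinel "None").
    register_msg = [
        GUARD,
        b"None",                                   # room -> registration path
        b"10.0.0.2",                               # endpoint (local IP)
        b"19000",                                  # dst_port (rank port)
        b"decode-agent-0",                         # agent_name
        b"<nixl agent metadata>",                  # agent_metadata (opaque bytes)
        b"".join(struct.pack("Q", p) for p in (0x7000, 0x7800)),  # dst_kv_ptrs
        b"".join(struct.pack("Q", p) for p in (0x9000,)),         # dst_aux_ptrs
        b"0",                                      # gpu_id
    ]

    # Per-request transfer message; only indices travel with each request now.
    transfer_msg = [
        GUARD,
        b"12345",                                  # room (bootstrap room id)
        b"10.0.0.2",                               # endpoint
        b"19000",                                  # dst_port
        b"decode-agent-0",                         # agent_name
        np.array([3, 4, 5], dtype=np.int32).tobytes(),  # dst_kv_indices
        b"7",                                      # dst_aux_index
        b"1",                                      # required_dst_info_num
    ]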
sglang/srt/disaggregation/prefill.py
CHANGED
@@ -25,7 +25,6 @@ from collections import deque
 from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional

-import numpy as np
 import torch

 from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
@@ -45,6 +44,7 @@ from sglang.srt.disaggregation.utils import (
 )
 from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.utils import require_mlp_sync

 if TYPE_CHECKING:
     from torch.distributed import ProcessGroup
@@ -93,8 +93,6 @@ class PrefillBootstrapQueue:
         self.gpu_id = gpu_id
         self.bootstrap_port = bootstrap_port
         self.queue: List[Req] = []
-        self.pp_rank = pp_rank
-        self.pp_size = pp_size
         self.gloo_group = gloo_group
         self.max_total_num_tokens = max_total_num_tokens
         self.scheduler = scheduler
@@ -124,6 +122,9 @@ class PrefillBootstrapQueue:
         kv_args.kv_data_ptrs = kv_data_ptrs
         kv_args.kv_data_lens = kv_data_lens
         kv_args.kv_item_lens = kv_item_lens
+        if not self.is_mla_backend:
+            kv_args.kv_head_num = self.token_to_kv_pool.head_num
+        kv_args.page_size = self.token_to_kv_pool.page_size

         kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
             self.metadata_buffers.get_buf_infos()
@@ -274,12 +275,8 @@ class SchedulerDisaggregationPrefillMixin:
             self.process_prefill_chunk()
             batch = self.get_new_batch_prefill()

-
-
-                self.server_args.enable_dp_attention
-                or self.server_args.enable_sp_layernorm
-            ):
-                batch, _ = self.prepare_dp_attn_batch(batch)
+            if require_mlp_sync(self.server_args):
+                batch, _ = self.prepare_mlp_sync_batch(batch)
             self.cur_batch = batch

             if batch:
@@ -312,12 +309,8 @@ class SchedulerDisaggregationPrefillMixin:
             self.process_prefill_chunk()
             batch = self.get_new_batch_prefill()

-
-
-                self.server_args.enable_dp_attention
-                or self.server_args.enable_sp_layernorm
-            ):
-                batch, _ = self.prepare_dp_attn_batch(batch)
+            if require_mlp_sync(self.server_args):
+                batch, _ = self.prepare_mlp_sync_batch(batch)
             self.cur_batch = batch
             if batch:
                 result = self.run_batch(batch)
@@ -393,6 +386,8 @@ class SchedulerDisaggregationPrefillMixin:
             logits_output.input_token_logprobs = tuple(
                 logits_output.input_token_logprobs.tolist()
             )
+
+        hidden_state_offset = 0
         for i, (req, next_token_id) in enumerate(
             zip(batch.reqs, next_token_ids, strict=True)
         ):
@@ -402,6 +397,16 @@ class SchedulerDisaggregationPrefillMixin:
             req.output_ids.append(next_token_id)
             self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
             self.disagg_prefill_inflight_queue.append(req)
+            if logits_output.hidden_states is not None:
+                last_hidden_index = (
+                    hidden_state_offset + extend_input_len_per_req[i] - 1
+                )
+                req.hidden_states_tensor = (
+                    logits_output.hidden_states[last_hidden_index].cpu().clone()
+                )
+                hidden_state_offset += extend_input_len_per_req[i]
+            else:
+                req.hidden_states_tensor = None
             if req.return_logprob:
                 assert extend_logprob_start_len_per_req is not None
                 assert extend_input_len_per_req is not None
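Note: the last hunk above attaches the final-token hidden state to each request when the model returns hidden states. The flat hidden_states tensor is the concatenation of every request's extend tokens, so the last row for request i sits at sum(lens[:i]) + lens[i] - 1. A minimal, self-contained sketch of that offset arithmetic with hypothetical lengths:

    import torch

    extend_input_len_per_req = [4, 2, 3]          # hypothetical per-request lengths
    hidden_states = torch.arange(9).unsqueeze(1)  # stand-in for [sum(lens), hidden]

    hidden_state_offset = 0
    last_hidden = []
    for length in extend_input_len_per_req:
        # Last extend token of this request in the packed tensor.
        last_hidden.append(hidden_states[hidden_state_offset + length - 1])
        hidden_state_offset += length

    # The three requests end at flat positions 3, 5 and 8.
    assert [int(h) for h in last_hidden] == [3, 5, 8]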
sglang/srt/disaggregation/utils.py
CHANGED
@@ -6,6 +6,7 @@ import random
 import threading
 import warnings
 from collections import deque
+from contextlib import nullcontext
 from enum import Enum
 from typing import TYPE_CHECKING, List, Optional

@@ -84,24 +85,43 @@ class ReqToMetadataIdxAllocator:


 class MetadataBuffers:
-    def __init__(
-
-
-
-
-
-
-
-
-        self.
-
-
-
-
-
-
-
-
+    def __init__(
+        self,
+        size: int,
+        hidden_size: int,
+        dtype: torch.dtype,
+        max_top_logprobs_num: int = 128,
+        custom_mem_pool: torch.cuda.MemPool = None,
+    ):
+        self.custom_mem_pool = custom_mem_pool
+        device = "cuda" if self.custom_mem_pool else "cpu"
+
+        with (
+            torch.cuda.use_mem_pool(self.custom_mem_pool)
+            if self.custom_mem_pool
+            else nullcontext()
+        ):
+            # TODO: abort top_logprobs_num > 128 in PD
+
+            # We transfer the metadata of first output token to decode
+            # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
+            self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
+
+            self.output_token_logprobs_val = torch.zeros(
+                (size, 16), dtype=torch.float32, device=device
+            )
+            self.output_token_logprobs_idx = torch.zeros(
+                (size, 16), dtype=torch.int32, device=device
+            )
+            self.output_top_logprobs_val = torch.zeros(
+                (size, max_top_logprobs_num), dtype=torch.float32, device=device
+            )
+            self.output_top_logprobs_idx = torch.zeros(
+                (size, max_top_logprobs_num), dtype=torch.int32, device=device
+            )
+            self.output_hidden_states = torch.zeros(
+                (size, hidden_size), dtype=dtype, device=device
+            )

     def get_buf_infos(self):
         ptrs = [
@@ -110,6 +130,7 @@ class MetadataBuffers:
             self.output_token_logprobs_idx.data_ptr(),
             self.output_top_logprobs_val.data_ptr(),
             self.output_top_logprobs_idx.data_ptr(),
+            self.output_hidden_states.data_ptr(),
         ]
         data_lens = [
             self.output_ids.nbytes,
@@ -117,6 +138,7 @@ class MetadataBuffers:
             self.output_token_logprobs_idx.nbytes,
             self.output_top_logprobs_val.nbytes,
             self.output_top_logprobs_idx.nbytes,
+            self.output_hidden_states.nbytes,
         ]
         item_lens = [
             self.output_ids[0].nbytes,
@@ -124,6 +146,7 @@ class MetadataBuffers:
             self.output_token_logprobs_idx[0].nbytes,
             self.output_top_logprobs_val[0].nbytes,
             self.output_top_logprobs_idx[0].nbytes,
+            self.output_hidden_states[0].nbytes,
         ]
         return ptrs, data_lens, item_lens

@@ -134,6 +157,7 @@ class MetadataBuffers:
             self.output_token_logprobs_idx[idx],
             self.output_top_logprobs_val[idx],
             self.output_top_logprobs_idx[idx],
+            self.output_hidden_states[idx],
         )

     def set_buf(self, req: Req):
@@ -161,6 +185,11 @@ class MetadataBuffers:
             ] = torch.tensor(
                 req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
             )
+        # for PD + spec decode
+        if req.hidden_states_tensor is not None:
+            self.output_hidden_states[req.metadata_buffer_index].copy_(
+                req.hidden_states_tensor
+            )


 #########################
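Note: a usage sketch of the reworked MetadataBuffers constructor above, assuming the class is imported from sglang.srt.disaggregation.utils as the file list suggests; size, hidden size, and dtype are hypothetical. Without a custom CUDA memory pool the buffers are allocated on CPU, and get_buf_infos() now reports six buffers per slot, the sixth being the new per-request hidden-state row used for PD with speculative decoding.

    import torch

    # Assumed import path, matching the file listed above.
    from sglang.srt.disaggregation.utils import MetadataBuffers

    buffers = MetadataBuffers(size=64, hidden_size=4096, dtype=torch.bfloat16)
    ptrs, data_lens, item_lens = buffers.get_buf_infos()

    # output_ids, token logprob val/idx, top logprob val/idx, hidden states.
    assert len(ptrs) == len(data_lens) == len(item_lens) == 6
    assert buffers.output_hidden_states.shape == (64, 4096)
    assert buffers.output_ids.device.type == "cpu"  # no custom_mem_pool supplied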
sglang/srt/distributed/parallel_state.py
CHANGED
@@ -523,17 +523,25 @@ class GroupCoordinator:
         self,
         input_: torch.Tensor,
         dim: int = -1,
-
+        output_tensor_list: Optional[List[torch.Tensor]] = None,
     ) -> torch.Tensor:
         world_size = self.world_size
         # Bypass the function if we are using only 1 GPU.
         if world_size == 1:
-
+            if output_tensor_list is not None:
+                logger.warning(
+                    "Performing in-place all-gather with a group size of 1. "
+                    "This may be unnecessary; consider bypassing it for better efficiency."
+                )
+                output_tensor_list[0].copy_(input_)
+                return None
+            else:
+                return input_

-        if
+        if output_tensor_list is not None:
             # TODO(ch-wan): support other backends
             return torch.distributed.all_gather(
-
+                output_tensor_list, input_, group=self.device_group
             )

         assert (
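Note: the hunk above gives GroupCoordinator.all_gather an optional output_tensor_list parameter: when supplied, results are gathered into the caller's buffers (and the world_size == 1 path copies in place and returns None) instead of returning a tensor concatenated along dim. A standalone sketch of just that dispatch logic, not the sglang class itself:

    from typing import List, Optional

    import torch


    def all_gather_sketch(
        input_: torch.Tensor,
        world_size: int,
        output_tensor_list: Optional[List[torch.Tensor]] = None,
    ) -> Optional[torch.Tensor]:
        # Mirrors the world_size == 1 branch added in the hunk above.
        if world_size == 1:
            if output_tensor_list is not None:
                output_tensor_list[0].copy_(input_)
                return None
            return input_
        # The multi-rank case would call torch.distributed.all_gather(
        #     output_tensor_list, input_, group=...) when buffers are provided.
        raise NotImplementedError


    x = torch.randn(4)
    buf = [torch.empty(4)]
    assert all_gather_sketch(x, world_size=1, output_tensor_list=buf) is None
    assert torch.equal(buf[0], x)
    assert all_gather_sketch(x, world_size=1) is x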
sglang/srt/entrypoints/engine.py
CHANGED
@@ -37,7 +37,6 @@ setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 import torch
 import uvloop

-from sglang.srt.code_completion_parser import load_completion_template_for_openai_api
 from sglang.srt.entrypoints.EngineBase import EngineBase
 from sglang.srt.managers.data_parallel_controller import (
     run_data_parallel_controller_process,
@@ -58,11 +57,8 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
+from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
-from sglang.srt.openai_api.adapter import (
-    guess_chat_template_name_from_model_path,
-    load_chat_template_for_openai_api,
-)
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
@@ -119,21 +115,22 @@ class Engine(EngineBase):
         atexit.register(self.shutdown)

         # Allocate ports for inter-process communications
-        port_args = PortArgs.init_new(server_args)
+        self.port_args = PortArgs.init_new(server_args)
         logger.info(f"{server_args=}")

         # Launch subprocesses
-        tokenizer_manager, scheduler_info = _launch_subprocesses(
+        tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
             server_args=server_args,
-            port_args=port_args,
+            port_args=self.port_args,
         )
         self.server_args = server_args
         self.tokenizer_manager = tokenizer_manager
+        self.template_manager = template_manager
         self.scheduler_info = scheduler_info

         context = zmq.Context(2)
         self.send_to_rpc = get_zmq_socket(
-            context, zmq.DEALER, port_args.rpc_ipc_name, True
+            context, zmq.DEALER, self.port_args.rpc_ipc_name, True
         )

     def generate(
@@ -175,7 +172,7 @@ class Engine(EngineBase):
         """
         if self.server_args.enable_dp_attention:
             if data_parallel_rank is None:
-                logger.
+                logger.debug("data_parallel_rank not provided, using default dispatch")
             elif data_parallel_rank < 0:
                 raise ValueError("data_parallel_rank must be non-negative")
             elif data_parallel_rank >= self.server_args.dp_size:
@@ -245,6 +242,7 @@ class Engine(EngineBase):
         token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
         lora_path: Optional[List[Optional[str]]] = None,
         custom_logit_processor: Optional[Union[List[str], str]] = None,
+        return_hidden_states: bool = False,
         stream: bool = False,
         bootstrap_host: Optional[Union[List[str], str]] = None,
         bootstrap_port: Optional[Union[List[int], int]] = None,
@@ -258,7 +256,7 @@ class Engine(EngineBase):

         if self.server_args.enable_dp_attention:
             if data_parallel_rank is None:
-                logger.
+                logger.debug("data_parallel_rank not provided, using default dispatch")
             elif data_parallel_rank < 0:
                 raise ValueError("data_parallel_rank must be non-negative")
             elif data_parallel_rank >= self.server_args.dp_size:
@@ -277,6 +275,7 @@ class Engine(EngineBase):
             top_logprobs_num=top_logprobs_num,
             token_ids_logprob=token_ids_logprob,
             lora_path=lora_path,
+            return_hidden_states=return_hidden_states,
             stream=stream,
             custom_logit_processor=custom_logit_processor,
             bootstrap_host=bootstrap_host,
@@ -479,17 +478,15 @@ class Engine(EngineBase):
             self.tokenizer_manager.get_weights_by_name(obj, None)
         )

-    def release_memory_occupation(self):
-
-        obj = ReleaseMemoryOccupationReqInput()
+    def release_memory_occupation(self, tags: Optional[List[str]] = None):
+        obj = ReleaseMemoryOccupationReqInput(tags=tags)
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
             self.tokenizer_manager.release_memory_occupation(obj, None)
         )

-    def resume_memory_occupation(self):
-
-        obj = ResumeMemoryOccupationReqInput()
+    def resume_memory_occupation(self, tags: Optional[List[str]] = None):
+        obj = ResumeMemoryOccupationReqInput(tags=tags)
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
             self.tokenizer_manager.resume_memory_occupation(obj, None)
@@ -649,7 +646,7 @@ def _set_envs_and_config(server_args: ServerArgs):

 def _launch_subprocesses(
     server_args: ServerArgs, port_args: Optional[PortArgs] = None
-) -> Tuple[TokenizerManager, Dict]:
+) -> Tuple[TokenizerManager, TemplateManager, Dict]:
     """
     Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
     """
@@ -670,11 +667,9 @@ def _launch_subprocesses(

     scheduler_procs = []
     if server_args.dp_size == 1:
-        # Launch tensor parallel scheduler processes
         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=server_args.enable_memory_saver
         )
-
         scheduler_pipe_readers = []

         nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
@@ -710,6 +705,7 @@ def _launch_subprocesses(
                     writer,
                 ),
             )
+
             with memory_saver_adapter.configure_subprocess():
                 proc.start()
             scheduler_procs.append(proc)
@@ -735,7 +731,7 @@ def _launch_subprocesses(

         if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
             # When using `Engine` as a Python API, we don't want to block here.
-            return None, None
+            return None, None, None

         launch_dummy_health_check_server(server_args.host, server_args.port)

@@ -744,7 +740,7 @@ def _launch_subprocesses(
                 logger.error(
                     f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
                 )
-                return None, None
+                return None, None, None

     # Launch detokenizer process
     detoken_proc = mp.Process(
@@ -758,15 +754,15 @@ def _launch_subprocesses(

     # Launch tokenizer process
     tokenizer_manager = TokenizerManager(server_args, port_args)
-    if server_args.chat_template:
-        load_chat_template_for_openai_api(
-            tokenizer_manager, server_args.chat_template, server_args.model_path
-        )
-    else:
-        guess_chat_template_name_from_model_path(server_args.model_path)

-
-
+    # Initialize templates
+    template_manager = TemplateManager()
+    template_manager.initialize_templates(
+        tokenizer_manager=tokenizer_manager,
+        model_path=server_args.model_path,
+        chat_template=server_args.chat_template,
+        completion_template=server_args.completion_template,
+    )

     # Wait for the model to finish loading
     scheduler_infos = []
@@ -790,4 +786,4 @@ def _launch_subprocesses(
     # Assume all schedulers have the same scheduler_info
     scheduler_info = scheduler_infos[0]
     tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
-    return tokenizer_manager, scheduler_info
+    return tokenizer_manager, template_manager, scheduler_info
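Note: a usage sketch of the Engine-facing changes above: release_memory_occupation() and resume_memory_occupation() now take an optional list of tags, and return_hidden_states is forwarded per request. The model path and the "kv_cache" tag are hypothetical examples, and the synchronous generate() call assumes it exposes the same return_hidden_states flag that the hunk forwards into the request.

    import sglang as sgl

    engine = sgl.Engine(
        model_path="Qwen/Qwen2.5-7B-Instruct",  # hypothetical model
        enable_memory_saver=True,
    )

    # Hidden states can now be requested per call.
    out = engine.generate("The capital of France is", return_hidden_states=True)

    # Memory occupation can be released and resumed selectively by tag.
    engine.release_memory_occupation(tags=["kv_cache"])
    engine.resume_memory_occupation(tags=["kv_cache"])

    engine.shutdown()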