sglang 0.4.5.post3__py3-none-any.whl → 0.4.6.post1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their respective public registries.
- sglang/bench_one_batch.py +19 -3
- sglang/bench_serving.py +8 -9
- sglang/compile_deep_gemm.py +45 -4
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +1 -1
- sglang/srt/configs/model_config.py +9 -3
- sglang/srt/constrained/llguidance_backend.py +78 -61
- sglang/srt/conversation.py +34 -1
- sglang/srt/disaggregation/decode.py +67 -13
- sglang/srt/disaggregation/fake/__init__.py +1 -0
- sglang/srt/disaggregation/fake/conn.py +88 -0
- sglang/srt/disaggregation/mini_lb.py +45 -8
- sglang/srt/disaggregation/mooncake/conn.py +198 -31
- sglang/srt/disaggregation/prefill.py +36 -12
- sglang/srt/disaggregation/utils.py +16 -2
- sglang/srt/entrypoints/engine.py +9 -0
- sglang/srt/entrypoints/http_server.py +35 -4
- sglang/srt/function_call_parser.py +77 -5
- sglang/srt/layers/attention/base_attn_backend.py +3 -0
- sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
- sglang/srt/layers/attention/flashattention_backend.py +28 -10
- sglang/srt/layers/attention/flashmla_backend.py +8 -11
- sglang/srt/layers/attention/utils.py +1 -1
- sglang/srt/layers/attention/vision.py +2 -0
- sglang/srt/layers/layernorm.py +38 -16
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/fused_moe_native.py +2 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +20 -17
- sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
- sglang/srt/layers/pooler.py +6 -0
- sglang/srt/layers/quantization/awq.py +5 -1
- sglang/srt/layers/quantization/deep_gemm.py +17 -10
- sglang/srt/layers/quantization/fp8.py +20 -22
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/int8_kernel.py +32 -1
- sglang/srt/layers/radix_attention.py +13 -3
- sglang/srt/layers/rotary_embedding.py +170 -126
- sglang/srt/managers/data_parallel_controller.py +10 -3
- sglang/srt/managers/io_struct.py +7 -0
- sglang/srt/managers/mm_utils.py +85 -28
- sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
- sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
- sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
- sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
- sglang/srt/managers/schedule_batch.py +38 -12
- sglang/srt/managers/scheduler.py +41 -28
- sglang/srt/managers/scheduler_output_processor_mixin.py +25 -9
- sglang/srt/managers/tokenizer_manager.py +5 -1
- sglang/srt/managers/tp_worker.py +3 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +9 -4
- sglang/srt/mem_cache/memory_pool.py +87 -0
- sglang/srt/model_executor/cuda_graph_runner.py +4 -3
- sglang/srt/model_executor/forward_batch_info.py +51 -95
- sglang/srt/model_executor/model_runner.py +19 -25
- sglang/srt/models/deepseek.py +12 -2
- sglang/srt/models/deepseek_nextn.py +101 -6
- sglang/srt/models/deepseek_v2.py +144 -70
- sglang/srt/models/deepseek_vl2.py +9 -4
- sglang/srt/models/gemma3_causal.py +1 -1
- sglang/srt/models/llama4.py +0 -1
- sglang/srt/models/minicpmo.py +5 -1
- sglang/srt/models/mllama4.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +3 -6
- sglang/srt/models/qwen2_vl.py +3 -7
- sglang/srt/models/roberta.py +178 -0
- sglang/srt/openai_api/adapter.py +50 -11
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/reasoning_parser.py +25 -1
- sglang/srt/server_args.py +31 -24
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/torch_memory_saver_adapter.py +10 -1
- sglang/srt/utils.py +5 -1
- sglang/test/runners.py +6 -13
- sglang/test/send_one.py +84 -28
- sglang/test/test_utils.py +74 -18
- sglang/version.py +1 -1
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/METADATA +5 -6
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/RECORD +97 -80
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/WHEEL +1 -1
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -35,6 +35,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 import copy
 import dataclasses
 import logging
+import threading
 from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

 import numpy as np
@@ -285,6 +286,7 @@ class MultimodalInputs:
     num_image_tokens: Optional[int] = None

     # QWen2-VL related
+    mrope_positions: Optional[torch.Tensor] = None
     mrope_position_delta: Optional[torch.Tensor] = None

     # image
@@ -310,16 +312,12 @@ class MultimodalInputs:
         assert isinstance(ret.mm_items, list)
         ret.mm_items = [item for item in ret.mm_items if item.is_valid()]

-        assert len(ret.mm_items) != 0
-
-        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
-        # Please note that if the `input_ids` is later used in the model forward,
-        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
-        # errors in cuda kernels. See also llava.py for example.
         for item in ret.mm_items:
             item.set_pad_value()

         optional_args = [
+            "mrope_positions",
+            "mrope_position_delta",
             "im_token_id",
             "im_start_id",
             "im_end_id",
@@ -350,11 +348,6 @@ class MultimodalInputs:
         merge image inputs when requests are being merged
         """

-        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
-        # Please note that if the `input_ids` is later used in the model forward,
-        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
-        # errors in cuda kernels. See also llava.py for example.
-
         # args needed to be merged
         optional_args = [
             "mm_items",
@@ -364,6 +357,30 @@ class MultimodalInputs:
             self_arg = getattr(self, arg, None)
             if self_arg is not None:
                 setattr(self, arg, self_arg + getattr(other, arg))
+
+        mrope_positions = self.mrope_positions
+        if mrope_positions is not None:
+            if other.mrope_positions is None:
+                self.mrope_positions = mrope_positions
+            else:
+                self.mrope_positions = torch.cat(
+                    [self.mrope_positions, other.mrope_positions], dim=1
+                )
+
+        mrope_position_delta = self.mrope_position_delta
+        if mrope_position_delta is not None:
+            if other.mrope_position_delta is None:
+                self.mrope_position_delta = mrope_position_delta
+            else:
+                self.mrope_position_delta = torch.cat(
+                    [self.mrope_position_delta, other.mrope_position_delta], dim=0
+                )
+
+        for key, val in other.__dict__.items():
+            if "_id" in key:
+                # set token_ids
+                if getattr(self, key, None) is None:
+                    setattr(self, key, getattr(other, key, None))
         # other args would be kept intact


@@ -388,6 +405,7 @@ class Req:
         return_hidden_states: bool = False,
         eos_token_ids: Optional[Set[int]] = None,
         bootstrap_host: Optional[str] = None,
+        bootstrap_port: Optional[int] = None,
         bootstrap_room: Optional[int] = None,
     ):
         # Input and output info
@@ -523,6 +541,7 @@ class Req:

         # For disaggregation
         self.bootstrap_host: str = bootstrap_host
+        self.bootstrap_port: Optional[int] = bootstrap_port
         self.bootstrap_room: Optional[int] = bootstrap_room
         self.disagg_kv_sender: Optional[BaseKVSender] = None

@@ -706,6 +725,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # This is an optimization to reduce the overhead of the prefill check.
     batch_is_full: bool = False

+    # Events
+    launch_done: Optional[threading.Event] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None
     next_batch_sampling_info: SamplingBatchInfo = None
@@ -1450,7 +1472,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if self.model_config.is_encoder_decoder:
             self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
             self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
-
         self.req_pool_indices = torch.cat(
             [self.req_pool_indices, other.req_pool_indices]
         )
@@ -1494,6 +1515,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             )
             or global_server_args_dict["attention_backend"] == "flashmla"
             or global_server_args_dict["attention_backend"] == "fa3"
+            or global_server_args_dict["attention_backend"] == "cutlass_mla"
         ):
             seq_lens_cpu = self.seq_lens.cpu()
         else:
@@ -1548,6 +1570,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 )
             ),
             extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
+            launch_done=self.launch_done,
         )

     def copy(self):
@@ -1630,6 +1653,9 @@ class ModelWorkerBatch:
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None

+    # Overlap event
+    launch_done: Optional[threading.Event] = None
+

 @triton.jit
 def write_req_to_token_pool_triton(
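The `merge()` changes above concatenate the Qwen2-VL M-RoPE state when two requests' multimodal inputs are combined: `mrope_positions` grows along the token axis, while `mrope_position_delta` is stacked along the request axis. A minimal sketch of that concatenation outside of sglang (the `(3, seq_len)` position layout and the per-request delta shape are assumptions for illustration, not taken from the diff):

```python
import torch

# Hypothetical shapes for illustration: M-RoPE tracks one position per
# (temporal, height, width) axis, so positions are assumed to be (3, seq_len),
# and each request contributes a small per-request delta tensor.
pos_a = torch.arange(12).reshape(3, 4)         # request A, 4 tokens
pos_b = torch.arange(6).reshape(3, 2) + 100    # request B, 2 tokens
delta_a = torch.tensor([[0]])
delta_b = torch.tensor([[5]])

# Mirrors the merge logic in the hunk above: positions are concatenated along
# the token axis (dim=1), deltas along the request axis (dim=0).
merged_pos = torch.cat([pos_a, pos_b], dim=1)        # shape (3, 6)
merged_delta = torch.cat([delta_a, delta_b], dim=0)  # shape (2, 1)
print(merged_pos.shape, merged_delta.shape)
```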
sglang/srt/managers/scheduler.py
CHANGED
@@ -248,9 +248,6 @@ class Scheduler(
         if not self.is_generation:
             self.enable_overlap = False
             logger.info("Overlap scheduler is disabled for embedding models.")
-        if self.model_config.is_multimodal:
-            self.enable_overlap = False
-            logger.info("Overlap scheduler is disabled for multimodal models.")

         # Launch a tensor parallel worker
         if self.enable_overlap:
@@ -578,6 +575,10 @@ class Scheduler(
                 bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                 transfer_backend=self.transfer_backend,
             )
+
+            # Metric for pre-allocation
+            self.num_tokens_pre_allocated = 0
+
         elif self.disaggregation_mode == DisaggregationMode.PREFILL:
             # *2 for the headroom.
             buffer_size = self.max_running_requests * 2
@@ -593,7 +594,7 @@ class Scheduler(
             )
             metadata_buffers = [output_id_buffer]

-            self.
+            self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                 token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                 req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                 metadata_buffers=metadata_buffers,
@@ -641,6 +642,7 @@ class Scheduler(
             self.cur_batch = batch

             if batch:
+                batch.launch_done = threading.Event()
                 result = self.run_batch(batch)
                 self.result_queue.append((batch.copy(), result))

@@ -652,7 +654,7 @@ class Scheduler(
                         forward_mode=ForwardMode.DUMMY_FIRST,
                         next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                     )
-                    self.process_batch_result(tmp_batch, None)
+                    self.process_batch_result(tmp_batch, None, batch.launch_done)

             if self.last_batch:
                 # Process the results of the last batch
@@ -660,7 +662,10 @@ class Scheduler(
                 tmp_batch.next_batch_sampling_info = (
                     self.tp_worker.cur_sampling_info if batch else None
                 )
-                self.process_batch_result(tmp_batch, tmp_result)
+                # NOTE: we should use current launched batch's launch_done event Instead of the last batch's
+                self.process_batch_result(
+                    tmp_batch, tmp_result, batch.launch_done if batch else None
+                )
             elif batch is None:
                 # When the server is idle, do self-check and re-init some states
                 self.check_memory()
@@ -787,6 +792,7 @@ class Scheduler(
             return_hidden_states=recv_req.return_hidden_states,
             eos_token_ids=self.model_config.hf_eos_token_id,
             bootstrap_host=recv_req.bootstrap_host,
+            bootstrap_port=recv_req.bootstrap_port,
             bootstrap_room=recv_req.bootstrap_room,
         )
         req.tokenizer = self.tokenizer
@@ -901,7 +907,7 @@ class Scheduler(
     def _add_request_to_queue(self, req: Req):
         req.queue_time_start = time.time()
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            self.
+            self.disagg_prefill_bootstrap_queue.add(req)
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.disagg_decode_prealloc_queue.add(req)
         else:
@@ -991,8 +997,15 @@ class Scheduler(
             f"#cached-token: {adder.log_hit_tokens}, "
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
             f"#running-req: {running_bs}, "
-            f"#queue-req: {len(self.waiting_queue)}, "
         )
+
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
+            f += f"#queue-req: {len(self.waiting_queue)}, "
+            f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)} "
+        else:
+            f += f"#queue-req: {len(self.waiting_queue)}"
+
         logger.info(f)

         if self.enable_metrics:
@@ -1028,15 +1041,14 @@ class Scheduler(
             gap_latency / self.server_args.decode_log_interval
         )

+        msg = (
+            f"Decode batch. "
+            f"#running-req: {num_running_reqs}, "
+            f"#token: {num_used}, "
+            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+        )
+
         if self.spec_algorithm.is_none():
-            msg = (
-                f"Decode batch. "
-                f"#running-req: {num_running_reqs}, "
-                f"#token: {num_used}, "
-                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-                f"#queue-req: {len(self.waiting_queue)}, "
-            )
             spec_accept_length = 0
         else:
             spec_accept_length = (
@@ -1045,15 +1057,15 @@ class Scheduler(
             self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
             self.cum_spec_accept_count += self.spec_num_total_forward_ct
             self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
-            msg
-
-
-
-
-
-
-
-
+            msg += f"accept len: {spec_accept_length:.2f}, "
+
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
+
+        msg += (
+            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
+            f"#queue-req: {len(self.waiting_queue)}"
+        )

         logger.info(msg)
         if self.enable_metrics:
@@ -1406,14 +1418,15 @@ class Scheduler(
         self,
         batch: ScheduleBatch,
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
+        launch_done: Optional[threading.Event] = None,
     ):
         if batch.forward_mode.is_decode():
-            self.process_batch_result_decode(batch, result)
+            self.process_batch_result_decode(batch, result, launch_done)
         elif batch.forward_mode.is_extend():
-            self.process_batch_result_prefill(batch, result)
+            self.process_batch_result_prefill(batch, result, launch_done)
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
-                self.tp_worker.
+                self.tp_worker.resolve_last_batch_result(launch_done)
             if batch.next_batch_sampling_info:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
                 self.current_stream.synchronize()
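Together with the `launch_done` field added to `ScheduleBatch`/`ModelWorkerBatch` above and the worker-side changes further below, the scheduler now creates one `threading.Event` per launched batch and threads it through `process_batch_result`. A minimal, sglang-free sketch of the handoff under the overlap scheduler (names here are illustrative, not the real APIs): the worker signals once the current batch's forward has been launched, and result processing for the previous batch waits on that signal before blocking on the queued output.

```python
import queue
import threading

results: "queue.Queue[str]" = queue.Queue()

def launch_batch(batch_id: int, launch_done: threading.Event) -> None:
    # Stand-in for the worker's forward path: mark the launch as issued,
    # then publish the (eventual) result of this batch.
    launch_done.set()
    results.put(f"tokens-for-batch-{batch_id}")

def resolve_previous_batch(current_launch_done: threading.Event) -> str:
    # Stand-in for resolve_last_batch_result: wait until the *current*
    # batch has been launched before blocking on the previous result, so
    # CPU-side postprocessing overlaps with GPU work.
    current_launch_done.wait()
    return results.get()

launch_done = threading.Event()
worker = threading.Thread(target=launch_batch, args=(0, launch_done))
worker.start()
print(resolve_previous_batch(launch_done))
worker.join()
```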
sglang/srt/managers/scheduler_output_processor_mixin.py
CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import threading
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -11,6 +12,7 @@ if TYPE_CHECKING:
         EmbeddingBatchResult,
         GenerationBatchResult,
         ScheduleBatch,
+        Scheduler,
     )


@@ -21,9 +23,10 @@ class SchedulerOutputProcessorMixin:
     """

     def process_batch_result_prefill(
-        self,
+        self: Scheduler,
         batch: ScheduleBatch,
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
+        launch_done: Optional[threading.Event] = None,
     ):
         skip_stream_req = None

@@ -43,7 +46,11 @@ class SchedulerOutputProcessorMixin:
             )

             if self.enable_overlap:
-                logits_output, next_token_ids =
+                logits_output, next_token_ids = (
+                    self.tp_worker.resolve_last_batch_result(
+                        launch_done,
+                    )
+                )
             else:
                 # Move next_token_ids and logprobs to cpu
                 next_token_ids = next_token_ids.tolist()
@@ -175,9 +182,10 @@ class SchedulerOutputProcessorMixin:
         self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

     def process_batch_result_decode(
-        self,
+        self: Scheduler,
         batch: ScheduleBatch,
         result: GenerationBatchResult,
+        launch_done: Optional[threading.Event] = None,
     ):
         logits_output, next_token_ids, bid = (
             result.logits_output,
@@ -187,7 +195,9 @@ class SchedulerOutputProcessorMixin:
         self.num_generated_tokens += len(batch.reqs)

         if self.enable_overlap:
-            logits_output, next_token_ids = self.tp_worker.
+            logits_output, next_token_ids = self.tp_worker.resolve_last_batch_result(
+                launch_done
+            )
             next_token_logprobs = logits_output.next_token_logprobs
         elif batch.spec_algorithm.is_none():
             # spec decoding handles output logprobs inside verify process.
@@ -271,7 +281,7 @@ class SchedulerOutputProcessorMixin:
             self.log_decode_stats()

     def add_input_logprob_return_values(
-        self,
+        self: Scheduler,
         i: int,
         req: Req,
         output: LogitsProcessorOutput,
@@ -405,7 +415,7 @@ class SchedulerOutputProcessorMixin:
         assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len

     def add_logprob_return_values(
-        self,
+        self: Scheduler,
         i: int,
         req: Req,
         pt: int,
@@ -436,7 +446,10 @@ class SchedulerOutputProcessorMixin:
         return num_input_logprobs

     def stream_output(
-        self
+        self: Scheduler,
+        reqs: List[Req],
+        return_logprob: bool,
+        skip_req: Optional[Req] = None,
     ):
         """Stream the output to detokenizer."""
         if self.is_generation:
@@ -445,7 +458,10 @@ class SchedulerOutputProcessorMixin:
             self.stream_output_embedding(reqs)

     def stream_output_generation(
-        self
+        self: Scheduler,
+        reqs: List[Req],
+        return_logprob: bool,
+        skip_req: Optional[Req] = None,
     ):
         rids = []
         finished_reasons: List[BaseFinishReason] = []
@@ -593,7 +609,7 @@ class SchedulerOutputProcessorMixin:
                 )
             )

-    def stream_output_embedding(self, reqs: List[Req]):
+    def stream_output_embedding(self: Scheduler, reqs: List[Req]):
         rids = []
         finished_reasons: List[BaseFinishReason] = []

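Annotating the mixin's methods as `self: Scheduler`, together with the `TYPE_CHECKING`-only import added above, lets static type checkers resolve attributes that only the concrete `Scheduler` defines, without creating a runtime import cycle. A small generic sketch of the pattern (the method and attribute used here are illustrative):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Resolved only by type checkers, mirroring the hunk above; nothing is
    # imported at runtime, so no circular import is introduced.
    from sglang.srt.managers.scheduler import Scheduler


class OutputProcessorMixin:
    """Methods here are mixed into Scheduler, so `self` really is a Scheduler."""

    def count_waiting(self: Scheduler) -> int:
        # The `self: Scheduler` annotation lets type checkers see attributes
        # such as `waiting_queue` that the mixin itself never defines;
        # runtime behavior is unchanged.
        return len(self.waiting_queue)
```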
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -419,7 +419,10 @@ class TokenizerManager:
             input_ids = self.tokenizer.encode(input_text)

         image_inputs: Dict = await self.mm_processor.process_mm_data_async(
-            obj.image_data,
+            image_data=obj.image_data,
+            input_text=input_text or input_ids,
+            request_obj=obj,
+            max_req_input_len=self.max_req_input_len,
         )
         if image_inputs and "input_ids" in image_inputs:
             input_ids = image_inputs["input_ids"]
@@ -495,6 +498,7 @@ class TokenizerManager:
             token_ids_logprob,
             obj.stream,
             bootstrap_host=obj.bootstrap_host,
+            bootstrap_port=obj.bootstrap_port,
             bootstrap_room=obj.bootstrap_room,
             lora_path=obj.lora_path,
             input_embeds=input_embeds,
sglang/srt/managers/tp_worker.py
CHANGED
@@ -170,13 +170,13 @@ class TpModelWorker:
     def forward_batch_generation(
         self,
         model_worker_batch: ModelWorkerBatch,
-        launch_done: Optional[threading.Event] = None,
         skip_sample: bool = False,
     ) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
-
-
+
+        if model_worker_batch.launch_done is not None:
+            model_worker_batch.launch_done.set()

         if skip_sample:
             next_token_ids = None
sglang/srt/managers/tp_worker_overlap_thread.py
CHANGED
@@ -132,7 +132,6 @@ class TpModelWorkerClient:
             batch_pt += 1

             # Create event
-            self.launch_done = threading.Event()
             copy_done = torch.get_device_module(self.device).Event()

             # Resolve future tokens in the input
@@ -141,7 +140,7 @@ class TpModelWorkerClient:

             # Run forward
             logits_output, next_token_ids = self.worker.forward_batch_generation(
-                model_worker_batch
+                model_worker_batch
             )

             # Update the future token ids map
@@ -168,10 +167,16 @@ class TpModelWorkerClient:

         self.output_queue.put((copy_done, logits_output, next_token_ids))

-    def
+    def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
+        """
+        This function is called to resolve the last batch result and
+        wait for the current batch to be launched. Used in overlap mode.
+        """
         copy_done, logits_output, next_token_ids = self.output_queue.get()
+
+        if launch_done is not None:
+            launch_done.wait()
         copy_done.synchronize()
-        self.launch_done.wait()

         if logits_output.next_token_logprobs is not None:
             logits_output.next_token_logprobs = (
sglang/srt/mem_cache/memory_pool.py
CHANGED
@@ -34,6 +34,8 @@ from typing import List, Optional, Tuple, Union
 import numpy as np
 import psutil
 import torch
+import triton
+import triton.language as tl

 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import debug_timing, get_compiler_backend
@@ -405,6 +407,72 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
     dst_2[loc] = src_2.to(dtype).view(store_dtype)


+@triton.jit
+def set_mla_kv_buffer_kernel(
+    kv_buffer_ptr,
+    cache_k_nope_ptr,
+    cache_k_rope_ptr,
+    loc_ptr,
+    buffer_stride: tl.constexpr,
+    nope_stride: tl.constexpr,
+    rope_stride: tl.constexpr,
+    nope_dim: tl.constexpr,
+    rope_dim: tl.constexpr,
+    BLOCK: tl.constexpr,
+):
+    pid_loc = tl.program_id(0)
+    pid_blk = tl.program_id(1)
+
+    base = pid_blk * BLOCK
+    offs = base + tl.arange(0, BLOCK)
+    total_dim = nope_dim + rope_dim
+    mask = offs < total_dim
+
+    loc = tl.load(loc_ptr + pid_loc)
+    dst_ptr = kv_buffer_ptr + loc * buffer_stride + offs
+
+    if base + BLOCK <= nope_dim:
+        src = tl.load(
+            cache_k_nope_ptr + pid_loc * nope_stride + offs,
+            mask=mask,
+        )
+    else:
+        offs_rope = offs - nope_dim
+        src = tl.load(
+            cache_k_rope_ptr + pid_loc * rope_stride + offs_rope,
+            mask=mask,
+        )
+
+    tl.store(dst_ptr, src, mask=mask)
+
+
+def set_mla_kv_buffer_triton(
+    kv_buffer: torch.Tensor,
+    loc: torch.Tensor,
+    cache_k_nope: torch.Tensor,
+    cache_k_rope: torch.Tensor,
+):
+    nope_dim = cache_k_nope.shape[-1]
+    rope_dim = cache_k_rope.shape[-1]
+    total_dim = nope_dim + rope_dim
+    BLOCK = 128
+    n_loc = loc.numel()
+    grid = (n_loc, triton.cdiv(total_dim, BLOCK))
+
+    set_mla_kv_buffer_kernel[grid](
+        kv_buffer,
+        cache_k_nope,
+        cache_k_rope,
+        loc,
+        kv_buffer.stride(0),
+        cache_k_nope.stride(0),
+        cache_k_rope.stride(0),
+        nope_dim,
+        rope_dim,
+        BLOCK=BLOCK,
+    )
+
+
 class MLATokenToKVPool(KVCache):
     def __init__(
         self,
@@ -504,6 +572,25 @@ class MLATokenToKVPool(KVCache):
         else:
             self.kv_buffer[layer_id][loc] = cache_k

+    def set_mla_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k_nope: torch.Tensor,
+        cache_k_rope: torch.Tensor,
+    ):
+        layer_id = layer.layer_id
+        if cache_k_nope.dtype != self.dtype:
+            cache_k_nope = cache_k_nope.to(self.dtype)
+            cache_k_rope = cache_k_rope.to(self.dtype)
+        if self.store_dtype != self.dtype:
+            cache_k_nope = cache_k_nope.view(self.store_dtype)
+            cache_k_rope = cache_k_rope.view(self.store_dtype)
+
+        set_mla_kv_buffer_triton(
+            self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
+        )
+
     def get_flat_data(self, indices):
         # prepare a large chunk of contiguous data for efficient transfer
         return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])
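The new `set_mla_kv_buffer_kernel` writes each token's MLA entry into one contiguous cache slot, the "nope" (latent) part first and the rope part appended after it, one 128-wide block per program. A small call sketch for the module-level wrapper; it needs a CUDA device with Triton installed, and the shapes used here (512 + 64 = 576 per slot, three destination slots) are illustrative assumptions rather than values taken from the diff:

```python
import torch

from sglang.srt.mem_cache.memory_pool import set_mla_kv_buffer_triton

device = "cuda"
# One row per cache slot; 576 = 512 assumed "nope" dims + 64 assumed rope dims.
kv_buffer = torch.zeros(1024, 576, dtype=torch.bfloat16, device=device)
loc = torch.tensor([3, 7, 42], dtype=torch.int64, device=device)
cache_k_nope = torch.randn(3, 512, dtype=torch.bfloat16, device=device)
cache_k_rope = torch.randn(3, 64, dtype=torch.bfloat16, device=device)

# Row i of cache_k_nope / cache_k_rope lands in kv_buffer[loc[i]]:
# nope part in columns [0, 512), rope part in columns [512, 576).
set_mla_kv_buffer_triton(kv_buffer, loc, cache_k_nope, cache_k_rope)
assert torch.equal(kv_buffer[3, :512], cache_k_nope[0])
assert torch.equal(kv_buffer[42, 512:], cache_k_rope[2])
```

Inside the pool, `MLATokenToKVPool.set_mla_kv_buffer` performs the same per-layer write after the usual dtype/store-dtype conversion shown in the hunk above.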
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -134,7 +134,8 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     )

     gpu_mem = get_device_memory_capacity()
-
+    # Batch size of each rank will not become so large when DP is on
+    if gpu_mem is not None and gpu_mem > 81920 and server_args.dp_size == 1:
         capture_bs += list(range(160, 257, 8))

     if max(capture_bs) > model_runner.req_to_token_pool.size:
@@ -278,9 +279,9 @@ class CudaGraphRunner:
                 f"Capture cuda graph failed: {e}\n"
                 "Possible solutions:\n"
                 "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
-                "2. set --cuda-graph-max-bs to a smaller value (e.g.,
+                "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
                 "3. disable torch compile by not using --enable-torch-compile\n"
-                "4. disable cuda graph by --disable-cuda-graph\n"
+                "4. disable cuda graph by --disable-cuda-graph. (Not recommonded. Huge perf loss)\n"
                 "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
             )

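For reference, the extra capture sizes added by the first hunk are easy to enumerate; the 81920 threshold suggests `get_device_memory_capacity()` reports MiB (an assumption here), and the new `dp_size == 1` guard skips them when data parallelism already splits the batch across ranks:

```python
# Batch sizes additionally captured on large-memory GPUs when dp_size == 1.
extra_capture_bs = list(range(160, 257, 8))
print(len(extra_capture_bs), extra_capture_bs[0], extra_capture_bs[-1])
# 13 extra graphs: 160, 168, ..., 256
```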