sglang 0.4.6__py3-none-any.whl → 0.4.6.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/disaggregation/decode.py +8 -2
- sglang/srt/disaggregation/fake/__init__.py +1 -0
- sglang/srt/disaggregation/fake/conn.py +88 -0
- sglang/srt/disaggregation/prefill.py +12 -3
- sglang/srt/disaggregation/utils.py +16 -2
- sglang/srt/entrypoints/engine.py +9 -0
- sglang/srt/entrypoints/http_server.py +27 -2
- sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
- sglang/srt/layers/attention/utils.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
- sglang/srt/layers/quantization/fp8.py +20 -22
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/managers/schedule_batch.py +9 -0
- sglang/srt/managers/scheduler.py +10 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +25 -9
- sglang/srt/managers/tp_worker.py +3 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +9 -4
- sglang/srt/model_executor/model_runner.py +8 -1
- sglang/srt/openai_api/adapter.py +32 -3
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/reasoning_parser.py +25 -1
- sglang/srt/server_args.py +16 -2
- sglang/srt/utils.py +3 -0
- sglang/test/send_one.py +84 -28
- sglang/test/test_utils.py +38 -0
- sglang/version.py +1 -1
- {sglang-0.4.6.dist-info → sglang-0.4.6.post1.dist-info}/METADATA +2 -2
- {sglang-0.4.6.dist-info → sglang-0.4.6.post1.dist-info}/RECORD +44 -29
- {sglang-0.4.6.dist-info → sglang-0.4.6.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.6.dist-info → sglang-0.4.6.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.dist-info → sglang-0.4.6.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8.py
CHANGED
@@ -72,8 +72,8 @@ _is_hip = is_hip()
 _is_cuda = is_cuda()

 if _is_hip:
-    from aiter import ActivationType
-    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
+    from aiter import ActivationType, QuantType
+    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight

 if not _is_cuda:
@@ -484,7 +484,7 @@ class Fp8MoEMethod:
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = (
                 torch.uint32
-                if get_bool_env_var("
+                if get_bool_env_var("SGLANG_INT4_WEIGHT")
                 else torch.float8_e4m3fn
             )
         tp_size = get_tensor_model_parallel_world_size()
@@ -511,7 +511,7 @@ class Fp8MoEMethod:
         )

         # WEIGHTS
-        if _is_hip and get_bool_env_var("
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
             # INT4 MoE weight - INT32 packed
             w13_weight = torch.nn.Parameter(
                 torch.empty(
@@ -585,7 +585,7 @@ class Fp8MoEMethod:

         if (
             _is_hip
-        ):  # and get_bool_env_var("
+        ):  # and get_bool_env_var("SGLANG_AITER_MOE"): TODO: add check back after triton kernel
             # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
             w13_weight_scale1 = torch.nn.Parameter(
                 torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
@@ -612,7 +612,7 @@ class Fp8MoEMethod:
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)

-        if _is_hip and get_bool_env_var("
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
             extra_weight_attrs.update(
                 {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
             )
@@ -644,7 +644,7 @@ class Fp8MoEMethod:
             layer.w2_input_scale = None

     def process_weights_after_loading(self, layer: Module) -> None:
-        if _is_hip and get_bool_env_var("
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
             self.process_weights_hip_int4(layer)
             return

@@ -675,7 +675,7 @@ class Fp8MoEMethod:
             )
             layer.w2_input_scale = None

-            if get_bool_env_var("
+            if get_bool_env_var("SGLANG_AITER_MOE"):
                 # Pre-shuffle weights
                 layer.w13_weight.data = shuffle_weight(
                     layer.w13_weight.contiguous(), (16, 16)
@@ -798,17 +798,15 @@ class Fp8MoEMethod:
             return

     def process_weights_hip_int4(self, layer: Module):
-        # TODO: and get_bool_env_var("
+        # TODO: and get_bool_env_var("SGLANG_AITER_MOE"): add after triton kernel added
         # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
         # Weight Permutation
         layer.w13_weight = torch.nn.Parameter(
-            # permute_weight(layer.w13_weight.data),
             shuffle_weight(layer.w13_weight.data, (16, 16)),
             requires_grad=False,
         )
         torch.cuda.empty_cache()
         layer.w2_weight = torch.nn.Parameter(
-            # permute_weight(layer.w2_weight.data),
             shuffle_weight(layer.w2_weight.data, (16, 16)),
             requires_grad=False,
         )
@@ -847,23 +845,21 @@ class Fp8MoEMethod:
             padding_size,  # Avoid circular import
         )

-        if get_bool_env_var("
+        if get_bool_env_var("SGLANG_AITER_MOE"):
            layer.w13_weight = torch.nn.Parameter(
-                # permute_weight(layer.w13_weight.data),
                shuffle_weight(layer.w13_weight.data, (16, 16)),
                requires_grad=False,
            )
            torch.cuda.empty_cache()
            layer.w2_weight = torch.nn.Parameter(
-                # permute_weight(layer.w2_weight.data),
                shuffle_weight(layer.w2_weight.data, (16, 16)),
                requires_grad=False,
            )
            torch.cuda.empty_cache()
-            # ROCm (
+            # ROCm (SGLANG_AITER_MOE): using column-wise scaling
            layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
            layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
-        elif get_bool_env_var("
+        elif get_bool_env_var("SGLANG_MOE_PADDING"):
            # If ROCm, apply weight padding (min. Mem channel contention) only if set
            layer.w13_weight = torch.nn.Parameter(
                F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
@@ -912,15 +908,16 @@ class Fp8MoEMethod:
         )

         if _is_hip:
-            if get_bool_env_var("
-                # TODO: add triton kernel and add check get_bool_env_var("
+            if get_bool_env_var("SGLANG_INT4_WEIGHT"):
+                # TODO: add triton kernel and add check get_bool_env_var("SGLANG_AITER_MOE")
                 assert not no_combine, f"{no_combine=} is not supported."
-                return
+                return ck_moe_2stages(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
                     topk_weights,
                     topk_ids,
+                    QuantType.per_Token,
                     layer.w13_weight_scale1,
                     layer.w2_weight_scale1,
                     activation=(
@@ -930,13 +927,13 @@ class Fp8MoEMethod:
                     ),
                 )

-            if get_bool_env_var("
+            if get_bool_env_var("SGLANG_AITER_MOE"):
                 assert not no_combine, f"{no_combine=} is not supported."
                 if self.block_quant:
-                    # TODO(
+                    # TODO(SGLANG_AITER_MOE): FP8 block_quant only supports 'silu' for the time-being.
                     assert (
                         activation == "silu"
-                    ), f"
+                    ), f"SGLANG_AITER_MOE: FP8 bloack_quant {activation=} will be supported later, unset SGLANG_AITER_MOE"
                     return asm_moe(
                         x,
                         layer.w13_weight,
@@ -955,6 +952,7 @@ class Fp8MoEMethod:
                     layer.w2_weight,
                     topk_weights,
                     topk_ids,
+                    QuantType.per_Token,
                     layer.w13_weight_scale1,
                     layer.w2_weight_scale1,
                     activation=(
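The ROCm MoE code paths above are all gated on SGLANG_-prefixed boolean environment variables (SGLANG_AITER_MOE, SGLANG_INT4_WEIGHT, SGLANG_MOE_PADDING). As an illustrative sketch of how such a gate is evaluated (not part of the diff; read_bool_env_var below is a stand-in for sglang's get_bool_env_var helper, and the accepted truthy spellings are an assumption):

import os

def read_bool_env_var(name: str, default: str = "false") -> bool:
    # Minimal stand-in for sglang.srt.utils.get_bool_env_var.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

# The ROCm-only branches above are taken only when the corresponding flag is set,
# e.g. by exporting SGLANG_AITER_MOE=1 in the server environment.
if read_bool_env_var("SGLANG_AITER_MOE"):
    print("aiter fused-MoE path would be selected")
else:
    print("default MoE path would be used")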
sglang/srt/layers/quantization/fp8_utils.py
CHANGED
@@ -31,7 +31,7 @@ from sglang.srt.utils import (
 _is_hip = is_hip()
 _is_cuda = is_cuda()

-if _is_hip and get_bool_env_var("
+if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
     from aiter import gemm_a8w8_blockscale

 if _is_cuda:
@@ -132,7 +132,7 @@ def apply_w8a8_block_fp8_linear(
         output = fp8_blockwise_scaled_mm(
             q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
         )
-    elif _is_hip and get_bool_env_var("
+    elif _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
         q_input, x_scale = per_token_group_quant_fp8(
             input_2d, block_size[1], column_major_scales=False
         )
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -35,6 +35,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 import copy
 import dataclasses
 import logging
+import threading
 from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

 import numpy as np
@@ -724,6 +725,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # This is an optimization to reduce the overhead of the prefill check.
     batch_is_full: bool = False

+    # Events
+    launch_done: Optional[threading.Event] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None
     next_batch_sampling_info: SamplingBatchInfo = None
@@ -1511,6 +1515,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             )
             or global_server_args_dict["attention_backend"] == "flashmla"
             or global_server_args_dict["attention_backend"] == "fa3"
+            or global_server_args_dict["attention_backend"] == "cutlass_mla"
         ):
             seq_lens_cpu = self.seq_lens.cpu()
         else:
@@ -1565,6 +1570,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 )
             ),
             extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
+            launch_done=self.launch_done,
         )

     def copy(self):
@@ -1647,6 +1653,9 @@ class ModelWorkerBatch:
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None

+    # Overlap event
+    launch_done: Optional[threading.Event] = None
+

 @triton.jit
 def write_req_to_token_pool_triton(
sglang/srt/managers/scheduler.py
CHANGED
@@ -248,9 +248,6 @@ class Scheduler(
         if not self.is_generation:
             self.enable_overlap = False
             logger.info("Overlap scheduler is disabled for embedding models.")
-        if self.model_config.is_multimodal:
-            self.enable_overlap = False
-            logger.info("Overlap scheduler is disabled for multimodal models.")

         # Launch a tensor parallel worker
         if self.enable_overlap:
@@ -645,6 +642,7 @@ class Scheduler(
         self.cur_batch = batch

         if batch:
+            batch.launch_done = threading.Event()
             result = self.run_batch(batch)
             self.result_queue.append((batch.copy(), result))

@@ -656,7 +654,7 @@
                     forward_mode=ForwardMode.DUMMY_FIRST,
                     next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                 )
-                self.process_batch_result(tmp_batch, None)
+                self.process_batch_result(tmp_batch, None, batch.launch_done)

         if self.last_batch:
             # Process the results of the last batch
@@ -664,7 +662,10 @@
             tmp_batch.next_batch_sampling_info = (
                 self.tp_worker.cur_sampling_info if batch else None
             )
-
+            # NOTE: we should use current launched batch's launch_done event Instead of the last batch's
+            self.process_batch_result(
+                tmp_batch, tmp_result, batch.launch_done if batch else None
+            )
         elif batch is None:
             # When the server is idle, do self-check and re-init some states
             self.check_memory()
@@ -1417,14 +1418,15 @@
         self,
         batch: ScheduleBatch,
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
+        launch_done: Optional[threading.Event] = None,
     ):
         if batch.forward_mode.is_decode():
-            self.process_batch_result_decode(batch, result)
+            self.process_batch_result_decode(batch, result, launch_done)
         elif batch.forward_mode.is_extend():
-            self.process_batch_result_prefill(batch, result)
+            self.process_batch_result_prefill(batch, result, launch_done)
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
-                self.tp_worker.
+                self.tp_worker.resolve_last_batch_result(launch_done)
                 if batch.next_batch_sampling_info:
                     batch.next_batch_sampling_info.update_regex_vocab_mask()
                     self.current_stream.synchronize()
sglang/srt/managers/scheduler_output_processor_mixin.py
CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import threading
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -11,6 +12,7 @@ if TYPE_CHECKING:
         EmbeddingBatchResult,
         GenerationBatchResult,
         ScheduleBatch,
+        Scheduler,
     )


@@ -21,9 +23,10 @@ class SchedulerOutputProcessorMixin:
     """

     def process_batch_result_prefill(
-        self,
+        self: Scheduler,
         batch: ScheduleBatch,
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
+        launch_done: Optional[threading.Event] = None,
     ):
         skip_stream_req = None

@@ -43,7 +46,11 @@
             )

             if self.enable_overlap:
-                logits_output, next_token_ids =
+                logits_output, next_token_ids = (
+                    self.tp_worker.resolve_last_batch_result(
+                        launch_done,
+                    )
+                )
             else:
                 # Move next_token_ids and logprobs to cpu
                 next_token_ids = next_token_ids.tolist()
@@ -175,9 +182,10 @@
         self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

     def process_batch_result_decode(
-        self,
+        self: Scheduler,
         batch: ScheduleBatch,
         result: GenerationBatchResult,
+        launch_done: Optional[threading.Event] = None,
     ):
         logits_output, next_token_ids, bid = (
             result.logits_output,
@@ -187,7 +195,9 @@
         self.num_generated_tokens += len(batch.reqs)

         if self.enable_overlap:
-            logits_output, next_token_ids = self.tp_worker.
+            logits_output, next_token_ids = self.tp_worker.resolve_last_batch_result(
+                launch_done
+            )
             next_token_logprobs = logits_output.next_token_logprobs
         elif batch.spec_algorithm.is_none():
             # spec decoding handles output logprobs inside verify process.
@@ -271,7 +281,7 @@
             self.log_decode_stats()

     def add_input_logprob_return_values(
-        self,
+        self: Scheduler,
         i: int,
         req: Req,
         output: LogitsProcessorOutput,
@@ -405,7 +415,7 @@
         assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len

     def add_logprob_return_values(
-        self,
+        self: Scheduler,
         i: int,
         req: Req,
         pt: int,
@@ -436,7 +446,10 @@
         return num_input_logprobs

     def stream_output(
-        self
+        self: Scheduler,
+        reqs: List[Req],
+        return_logprob: bool,
+        skip_req: Optional[Req] = None,
     ):
         """Stream the output to detokenizer."""
         if self.is_generation:
@@ -445,7 +458,10 @@
             self.stream_output_embedding(reqs)

     def stream_output_generation(
-        self
+        self: Scheduler,
+        reqs: List[Req],
+        return_logprob: bool,
+        skip_req: Optional[Req] = None,
     ):
         rids = []
         finished_reasons: List[BaseFinishReason] = []
@@ -593,7 +609,7 @@
                 )
             )

-    def stream_output_embedding(self, reqs: List[Req]):
+    def stream_output_embedding(self: Scheduler, reqs: List[Req]):
         rids = []
         finished_reasons: List[BaseFinishReason] = []

sglang/srt/managers/tp_worker.py
CHANGED
@@ -170,13 +170,13 @@ class TpModelWorker:
     def forward_batch_generation(
         self,
         model_worker_batch: ModelWorkerBatch,
-        launch_done: Optional[threading.Event] = None,
         skip_sample: bool = False,
     ) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
-
-
+
+        if model_worker_batch.launch_done is not None:
+            model_worker_batch.launch_done.set()

         if skip_sample:
             next_token_ids = None
sglang/srt/managers/tp_worker_overlap_thread.py
CHANGED
@@ -132,7 +132,6 @@ class TpModelWorkerClient:
             batch_pt += 1

             # Create event
-            self.launch_done = threading.Event()
             copy_done = torch.get_device_module(self.device).Event()

             # Resolve future tokens in the input
@@ -141,7 +140,7 @@

             # Run forward
             logits_output, next_token_ids = self.worker.forward_batch_generation(
-                model_worker_batch
+                model_worker_batch
             )

             # Update the future token ids map
@@ -168,10 +167,16 @@

             self.output_queue.put((copy_done, logits_output, next_token_ids))

-    def
+    def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
+        """
+        This function is called to resolve the last batch result and
+        wait for the current batch to be launched. Used in overlap mode.
+        """
         copy_done, logits_output, next_token_ids = self.output_queue.get()
+
+        if launch_done is not None:
+            launch_done.wait()
         copy_done.synchronize()
-        self.launch_done.wait()

         if logits_output.next_token_logprobs is not None:
             logits_output.next_token_logprobs = (
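Taken together with the scheduler and schedule_batch changes above, a per-batch launch_done event replaces the worker client's single self.launch_done: the worker sets the event once a batch's forward pass has been launched, and the output processor waits on the event for the batch it is about to resolve. A minimal, self-contained sketch of that handshake (queue, thread, and function names here are illustrative, not sglang's internals):

import queue
import threading
from typing import Optional

output_queue = queue.Queue()

def forward_thread(batch_id: int, launch_done: threading.Event) -> None:
    # ... the real worker launches the GPU forward pass here ...
    launch_done.set()  # signal that this batch has been launched
    output_queue.put((launch_done, f"logits for batch {batch_id}"))

def resolve_last_batch_result(launch_done: Optional[threading.Event]) -> str:
    _, result = output_queue.get()
    if launch_done is not None:
        launch_done.wait()  # wait for the current batch to be launched
    return result

launch_done = threading.Event()  # analogous to batch.launch_done in the scheduler
threading.Thread(target=forward_thread, args=(0, launch_done)).start()
print(resolve_last_batch_result(launch_done))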
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -271,6 +271,7 @@ class ModelRunner:
                 "fa3",
                 "triton",
                 "flashmla",
+                "cutlass_mla",
             ]:
                 logger.info(
                     f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
@@ -926,6 +927,12 @@ class ModelRunner:
             )

             self.attn_backend = FlashAttentionBackend(self)
+        elif self.server_args.attention_backend == "cutlass_mla":
+            from sglang.srt.layers.attention.cutlass_mla_backend import (
+                CutlassMLABackend,
+            )
+
+            self.attn_backend = CutlassMLABackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"
@@ -968,7 +975,7 @@ class ModelRunner:
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. "
-            f"
+            f"mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
         )

     def apply_torch_tp(self):
sglang/srt/openai_api/adapter.py
CHANGED
@@ -971,6 +971,8 @@ def v1_chat_generate_request(
         )

         for message in request.messages:
+            if message.content is None:
+                message.content = ""
             if isinstance(message.content, str):
                 openai_compatible_messages.append(
                     {"role": message.role, "content": message.content}
@@ -1001,6 +1003,11 @@ def v1_chat_generate_request(
                 tokenize=True,
                 add_generation_prompt=True,
                 tools=tools,
+                **(
+                    request.chat_template_kwargs
+                    if request.chat_template_kwargs
+                    else {}
+                ),
             )
         except:
             # This except branch will be triggered when the chosen model
@@ -1012,6 +1019,11 @@ def v1_chat_generate_request(
                 tokenize=True,
                 add_generation_prompt=True,
                 tools=tools,
+                **(
+                    request.chat_template_kwargs
+                    if request.chat_template_kwargs
+                    else {}
+                ),
             )

         if assistant_prefix:
@@ -1179,6 +1191,7 @@ def v1_chat_generate_request(
         modalities=modalities_list,
         lora_path=lora_paths,
         bootstrap_host=all_requests[0].bootstrap_host,
+        bootstrap_port=all_requests[0].bootstrap_port,
         bootstrap_room=all_requests[0].bootstrap_room,
     )

@@ -1245,16 +1258,34 @@ def v1_chat_generate_response(
         tool_calls = None
         text = ret_item["text"]

+        enable_thinking = True
         if isinstance(request, list):
             tool_choice = request[idx].tool_choice
             tools = request[idx].tools
             separate_reasoning = request[idx].separate_reasoning
+
+            if (
+                request[idx].chat_template_kwargs
+                and request[idx].chat_template_kwargs.get("enable_thinking") is not None
+            ):
+                enable_thinking = request[idx].chat_template_kwargs.get(
+                    "enable_thinking", True
+                )
         else:
             tool_choice = request.tool_choice
             tools = request.tools
             separate_reasoning = request.separate_reasoning

-
+            if (
+                request.chat_template_kwargs
+                and request.chat_template_kwargs.get("enable_thinking") is not None
+            ):
+                enable_thinking = request.chat_template_kwargs.get(
+                    "enable_thinking", True
+                )
+
+        reasoning_text = None
+        if reasoning_parser and separate_reasoning and enable_thinking:
             try:
                 parser = ReasoningParser(
                     model_type=reasoning_parser, stream_reasoning=False
@@ -1266,8 +1297,6 @@ def v1_chat_generate_response(
                     HTTPStatus.BAD_REQUEST,
                     "Failed to parse reasoning related info to json format!",
                 )
-            else:
-                reasoning_text = None

         if tool_choice != "none" and tools:
             parser = FunctionCallParser(tools, tool_call_parser)
sglang/srt/openai_api/protocol.py
CHANGED
@@ -361,9 +361,11 @@ class ChatCompletionRequest(BaseModel):
     session_params: Optional[Dict] = None
     separate_reasoning: bool = True
     stream_reasoning: bool = True
+    chat_template_kwargs: Optional[Dict] = None

     # For PD disaggregation
     bootstrap_host: Optional[str] = None
+    bootstrap_port: Optional[int] = None
     bootstrap_room: Optional[int] = None

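The new chat_template_kwargs request field is splatted into the tokenizer's apply_chat_template call and is also consulted for an enable_thinking switch that gates reasoning separation (see the adapter.py hunks above). A hedged request sketch against a locally running OpenAI-compatible sglang server; the URL, port, and model name are placeholders, and the response shape assumes the standard chat-completions format:

import requests

resp = requests.post(
    "http://127.0.0.1:30000/v1/chat/completions",  # placeholder endpoint
    json={
        "model": "Qwen/Qwen3-8B",  # placeholder model name
        "messages": [{"role": "user", "content": "What is 2 + 2?"}],
        "separate_reasoning": True,                          # existing field
        "chat_template_kwargs": {"enable_thinking": False},  # new in this release
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])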
sglang/srt/reasoning_parser.py
CHANGED
@@ -117,6 +117,29 @@ class DeepSeekR1Detector(BaseReasoningFormatDetector):
     # https://github.com/sgl-project/sglang/pull/3202#discussion_r1950153599


+class Qwen3Detector(BaseReasoningFormatDetector):
+    """
+    Detector for Qwen3 model.
+    Assumes reasoning format:
+      (<think>)*(.*)</think>
+    Returns all the text before the </think> tag as `reasoning_text`
+    and the rest of the text as `normal_text`.
+
+    Args:
+        stream_reasoning (bool): If False, accumulates reasoning content until the end tag.
+            If True, streams reasoning content as it arrives.
+    """
+
+    def __init__(self, stream_reasoning: bool = True):
+        # Qwen3 is assumed to be reasoning until `</think>` token
+        super().__init__(
+            "<think>",
+            "</think>",
+            force_reasoning=True,
+            stream_reasoning=stream_reasoning,
+        )
+
+
 class ReasoningParser:
     """
     Parser that handles both streaming and non-streaming scenarios for extracting
@@ -129,7 +152,8 @@ class ReasoningParser:
     """

     DetectorMap: Dict[str, BaseReasoningFormatDetector] = {
-        "deepseek-r1": DeepSeekR1Detector
+        "deepseek-r1": DeepSeekR1Detector,
+        "qwen3": Qwen3Detector,
     }

     def __init__(self, model_type: str = None, stream_reasoning: bool = True):
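The new detector registers under the "qwen3" key of DetectorMap, so it is picked up when the server is configured with the qwen3 reasoning parser. As a rough, standalone illustration of the split it performs (a simplification, not sglang's BaseReasoningFormatDetector API):

from typing import Tuple

def split_think_block(text: str) -> Tuple[str, str]:
    # Everything before the closing </think> tag is treated as reasoning,
    # the rest as the final answer; with force_reasoning semantics, text
    # without a closing tag is still considered reasoning-in-progress.
    end_tag = "</think>"
    if end_tag not in text:
        return text.removeprefix("<think>"), ""
    reasoning, normal = text.split(end_tag, 1)
    return reasoning.removeprefix("<think>").strip(), normal.strip()

reasoning, answer = split_think_block("<think>2 + 2 = 4</think>The answer is 4.")
assert reasoning == "2 + 2 = 4" and answer == "The answer is 4."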
sglang/srt/server_args.py
CHANGED
@@ -256,6 +256,12 @@ class ServerArgs:
             )
             self.page_size = 64

+        if self.attention_backend == "cutlass_mla":
+            logger.warning(
+                "Cutlass MLA only supports a page_size of 128, change page_size to 128."
+            )
+            self.page_size = 128
+
         # Set cuda graph max batch size
         if self.cuda_graph_max_bs is None:
             # Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.
@@ -420,7 +426,7 @@ class ServerArgs:
         parser.add_argument(
             "--skip-tokenizer-init",
             action="store_true",
-            help="If set, skip init tokenizer and pass input_ids in generate request",
+            help="If set, skip init tokenizer and pass input_ids in generate request.",
         )
         parser.add_argument(
             "--enable-tokenizer-batch-encode",
@@ -559,6 +565,7 @@ class ServerArgs:
             "name, a tag name, or a commit id. If unspecified, will use "
             "the default version.",
         )
+
         # Memory and scheduling
         parser.add_argument(
             "--mem-fraction-static",
@@ -823,7 +830,14 @@ class ServerArgs:
         parser.add_argument(
             "--attention-backend",
             type=str,
-            choices=[
+            choices=[
+                "flashinfer",
+                "triton",
+                "torch_native",
+                "fa3",
+                "flashmla",
+                "cutlass_mla",
+            ],
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
         )
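Per the hunks above, the new backend is selected with --attention-backend cutlass_mla, and ServerArgs then forces page_size to 128. A minimal sketch of that normalization, grounded in the hunk but with an illustrative function name (not sglang's API):

import logging

logger = logging.getLogger(__name__)

def normalize_page_size(attention_backend: str, page_size: int) -> int:
    # Mirrors the check added to ServerArgs above.
    if attention_backend == "cutlass_mla" and page_size != 128:
        logger.warning(
            "Cutlass MLA only supports a page_size of 128, change page_size to 128."
        )
        return 128
    return page_size

assert normalize_page_size("cutlass_mla", 64) == 128
assert normalize_page_size("fa3", 64) == 64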
sglang/srt/utils.py
CHANGED
@@ -1970,8 +1970,11 @@ def is_fa3_default_architecture(hf_config):
         "Llama4ForConditionalGeneration",
         "LlamaForCausalLM",
         "MistralForCausalLM",
+        "MixtralForCausalLM",
         "Gemma2ForCausalLM",
         "Gemma3ForConditionalGeneration",
+        "Qwen3ForCausalLM",
+        "Qwen3MoeForCausalLM",
     }
     return architectures[0] in default_archs

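This helper appears to gate which architectures default to the FA3 attention backend; the release adds Mixtral and the Qwen3 dense and MoE architectures to that set. An illustrative stand-in (the real function takes an hf_config and reads its architectures attribute; this version takes the list directly):

from typing import List

FA3_DEFAULT_ARCHS = {
    "Llama4ForConditionalGeneration",
    "LlamaForCausalLM",
    "MistralForCausalLM",
    "MixtralForCausalLM",       # newly added
    "Gemma2ForCausalLM",
    "Gemma3ForConditionalGeneration",
    "Qwen3ForCausalLM",         # newly added
    "Qwen3MoeForCausalLM",      # newly added
}

def is_fa3_default_architecture(architectures: List[str]) -> bool:
    # The check keys off the first declared architecture, as in the hunk above.
    return bool(architectures) and architectures[0] in FA3_DEFAULT_ARCHS

assert is_fa3_default_architecture(["Qwen3ForCausalLM"])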