sglang 0.5.3-py3-none-any.whl → 0.5.3.post1-py3-none-any.whl
- sglang/bench_one_batch.py +0 -2
- sglang/bench_serving.py +224 -127
- sglang/compile_deep_gemm.py +3 -0
- sglang/launch_server.py +0 -14
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/falcon_h1.py +12 -58
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +68 -31
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +11 -43
- sglang/srt/disaggregation/decode.py +7 -18
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/nixl/conn.py +55 -23
- sglang/srt/disaggregation/prefill.py +17 -32
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/entrypoints/grpc_request_manager.py +10 -23
- sglang/srt/entrypoints/grpc_server.py +220 -80
- sglang/srt/entrypoints/http_server.py +49 -1
- sglang/srt/entrypoints/openai/protocol.py +159 -31
- sglang/srt/entrypoints/openai/serving_chat.py +13 -71
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +4 -0
- sglang/srt/function_call/function_call_parser.py +8 -6
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
- sglang/srt/layers/attention/attention_registry.py +31 -22
- sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
- sglang/srt/layers/attention/flashattention_backend.py +0 -1
- sglang/srt/layers/attention/flashinfer_backend.py +223 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/triton_backend.py +1 -1
- sglang/srt/layers/logits_processor.py +136 -6
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
- sglang/srt/layers/moe/ep_moe/layer.py +8 -286
- sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/utils.py +7 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/modelopt_quant.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/w4afp8.py +2 -16
- sglang/srt/lora/lora_manager.py +0 -8
- sglang/srt/managers/overlap_utils.py +18 -16
- sglang/srt/managers/schedule_batch.py +119 -90
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +213 -126
- sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
- sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
- sglang/srt/managers/tokenizer_manager.py +270 -53
- sglang/srt/managers/tp_worker.py +39 -28
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +162 -68
- sglang/srt/mem_cache/radix_cache.py +8 -3
- sglang/srt/mem_cache/swa_radix_cache.py +70 -14
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/forward_batch_info.py +4 -18
- sglang/srt/model_executor/model_runner.py +55 -51
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +187 -6
- sglang/srt/model_loader/weight_utils.py +3 -0
- sglang/srt/models/falcon_h1.py +11 -9
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +11 -1
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/utils.py +5 -1
- sglang/srt/sampling/sampling_batch_info.py +11 -9
- sglang/srt/server_args.py +100 -33
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_utils.py +0 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils/common.py +18 -0
- sglang/srt/utils/hf_transformers_utils.py +2 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +40 -0
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +18 -2
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +63 -0
- sglang/test/test_utils.py +32 -11
- sglang/version.py +1 -1
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/moe_runner/runner.py
CHANGED
@@ -9,6 +9,7 @@ from sglang.srt.layers.moe.moe_runner.base import (
     MoeRunnerConfig,
     PermuteMethodPool,
 )
+from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmRunnerCore
 from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
 from sglang.srt.layers.moe.utils import get_moe_a2a_backend
 
@@ -30,6 +31,8 @@ class MoeRunner:
 
         if runner_backend.is_triton():
            self.runner_core = TritonRunnerCore(config)
+        elif runner_backend.is_deep_gemm():
+            self.runner_core = DeepGemmRunnerCore(config)
         else:
             raise NotImplementedError(f"Unsupported runner backend: {runner_backend}")
 
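The two hunks above register the new DeepGEMM runner core in MoeRunner's backend dispatch. A minimal, self-contained sketch of the same dispatch pattern (simplified stand-in classes, not the actual sglang implementation):

from enum import Enum

class MoeRunnerBackend(Enum):
    AUTO = "auto"
    DEEP_GEMM = "deep_gemm"
    TRITON = "triton"

    def is_deep_gemm(self):
        return self == MoeRunnerBackend.DEEP_GEMM

    def is_triton(self):
        return self == MoeRunnerBackend.TRITON

class TritonRunnerCore:        # stand-in for the real Triton runner core
    def __init__(self, config): self.config = config

class DeepGemmRunnerCore:      # stand-in for the new DeepGEMM runner core
    def __init__(self, config): self.config = config

class MoeRunner:
    def __init__(self, runner_backend, config):
        # Same branching as the diff: Triton first, then the new DeepGEMM path.
        if runner_backend.is_triton():
            self.runner_core = TritonRunnerCore(config)
        elif runner_backend.is_deep_gemm():
            self.runner_core = DeepGemmRunnerCore(config)
        else:
            raise NotImplementedError(f"Unsupported runner backend: {runner_backend}")

runner = MoeRunner(MoeRunnerBackend.DEEP_GEMM, config={})
assert isinstance(runner.runner_core, DeepGemmRunnerCore)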
sglang/srt/layers/moe/utils.py
CHANGED
@@ -44,6 +44,7 @@ class MoeA2ABackend(Enum):
 class MoeRunnerBackend(Enum):
 
     AUTO = "auto"
+    DEEP_GEMM = "deep_gemm"
     TRITON = "triton"
     TRITON_KERNEL = "triton_kernel"
     FLASHINFER_TRTLLM = "flashinfer_trtllm"
@@ -54,6 +55,9 @@ class MoeRunnerBackend(Enum):
     def is_auto(self):
         return self == MoeRunnerBackend.AUTO
 
+    def is_deep_gemm(self):
+        return self == MoeRunnerBackend.DEEP_GEMM
+
     def is_triton(self):
         return self == MoeRunnerBackend.TRITON
 
@@ -147,7 +151,9 @@ def get_moe_a2a_backend() -> MoeA2ABackend:
 def get_moe_runner_backend() -> MoeRunnerBackend:
     global MOE_RUNNER_BACKEND
     if MOE_RUNNER_BACKEND is None:
-        logger.warning(
+        logger.warning(
+            "MOE_RUNNER_BACKEND is not initialized, the backend will be automatically selected"
+        )
         MOE_RUNNER_BACKEND = MoeRunnerBackend.AUTO
     return MOE_RUNNER_BACKEND
 
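Because the enum members carry their configuration strings as values, a configured backend name maps straight onto the new member. A small standalone illustration (trimmed enum, not the sglang module):

from enum import Enum

class MoeRunnerBackend(Enum):
    AUTO = "auto"
    DEEP_GEMM = "deep_gemm"
    TRITON = "triton"
    TRITON_KERNEL = "triton_kernel"
    FLASHINFER_TRTLLM = "flashinfer_trtllm"

    def is_deep_gemm(self):
        return self == MoeRunnerBackend.DEEP_GEMM

# A string coming from configuration resolves to the new member by value.
backend = MoeRunnerBackend("deep_gemm")
assert backend is MoeRunnerBackend.DEEP_GEMM and backend.is_deep_gemm()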
sglang/srt/layers/quantization/__init__.py
CHANGED
@@ -72,7 +72,7 @@ if TYPE_CHECKING:
 BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "fp8": Fp8Config,
     "blockwise_int8": BlockInt8Config,
-    "
+    "modelopt_fp8": ModelOptFp8Config,
     "modelopt_fp4": ModelOptFp4Config,
     "w8a8_int8": W8A8Int8Config,
     "w8a8_fp8": W8A8Fp8Config,
sglang/srt/layers/quantization/fp8.py
CHANGED
@@ -31,8 +31,8 @@ except ImportError:
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmMoeQuantInfo
 from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
-from sglang.srt.layers.moe.token_dispatcher.base import DispatchOutputChecker
 from sglang.srt.layers.parameter import (
     BlockQuantScaleParameter,
     ModelWeightParameter,
@@ -1006,8 +1006,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
     def create_moe_runner(
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
     ):
+
+        from sglang.srt.layers.moe.utils import (
+            get_moe_a2a_backend,
+            get_moe_runner_backend,
+        )
+        from sglang.srt.layers.quantization import deep_gemm_wrapper
+
         self.moe_runner_config = moe_runner_config
-
+        moe_runner_backend = get_moe_runner_backend()
+
+        if moe_runner_backend.is_auto():
+            if (
+                deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+                and get_moe_a2a_backend().is_deepep()
+            ):
+                moe_runner_backend = MoeRunnerBackend.DEEP_GEMM
+            else:
+                moe_runner_backend = MoeRunnerBackend.TRITON
+        if moe_runner_backend.is_deep_gemm() or moe_runner_backend.is_triton():
+            self.runner = MoeRunner(moe_runner_backend, moe_runner_config)
+        else:
+            # TODO(cwan): refactor other backends
+            pass
 
     def apply(
         self,
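The AUTO branch added above resolves to DeepGEMM only when the JIT DeepGEMM kernels are enabled and the all-to-all backend is DeepEP; otherwise it falls back to Triton. A small standalone sketch of that resolution rule (plain strings instead of the enum, illustrative only):

def resolve_moe_runner_backend(requested: str, jit_deep_gemm: bool, a2a_is_deepep: bool) -> str:
    """Mirror of the AUTO-resolution branch added above (simplified to strings)."""
    if requested == "auto":
        # DeepGEMM is only picked when its JIT kernels are enabled and the
        # token dispatcher uses the DeepEP all-to-all backend.
        return "deep_gemm" if (jit_deep_gemm and a2a_is_deepep) else "triton"
    return requested

assert resolve_moe_runner_backend("auto", True, True) == "deep_gemm"
assert resolve_moe_runner_backend("auto", True, False) == "triton"
assert resolve_moe_runner_backend("deep_gemm", False, False) == "deep_gemm"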
@@ -1087,22 +1108,67 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             )
             return StandardCombineInput(hidden_states=output)
 
-        ... (16 removed lines not rendered in this diff view)
+        if self.runner.runner_backend.is_deep_gemm():
+
+            w13_weight = layer.w13_weight
+            w2_weight = layer.w2_weight
+
+            if self.block_quant:
+                block_shape = self.quant_config.weight_block_size
+                w13_scale = layer.w13_weight_scale_inv
+                w2_scale = layer.w2_weight_scale_inv
+            else:
+                # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
+                scale_block_size = 128
+                block_shape = [scale_block_size, scale_block_size]
+                w13_scale_n = (w13_weight.shape[1] - 1) // scale_block_size + 1
+                w13_scale_k = (w13_weight.shape[2] - 1) // scale_block_size + 1
+                w13_scale = (
+                    layer.w13_weight_scale.unsqueeze(1)
+                    .repeat_interleave(w13_scale_n, dim=1)
+                    .unsqueeze(2)
+                    .repeat_interleave(w13_scale_k, dim=2)
+                )
+                w2_scale_n = (w2_weight.shape[1] - 1) // scale_block_size + 1
+                w2_scale_k = (w2_weight.shape[2] - 1) // scale_block_size + 1
+                w2_scale = (
+                    layer.w2_weight_scale.unsqueeze(1)
+                    .repeat_interleave(w2_scale_n, dim=1)
+                    .unsqueeze(2)
+                    .repeat_interleave(w2_scale_k, dim=2)
+                )
+            quant_info = DeepGemmMoeQuantInfo(
+                w13_weight=w13_weight,
+                w2_weight=w2_weight,
+                use_fp8=True,
+                w13_scale=w13_scale,
+                w2_scale=w2_scale,
+                block_shape=block_shape,
+            )
+        elif self.runner.runner_backend.is_triton():
+            quant_info = TritonMoeQuantInfo(
+                w13_weight=layer.w13_weight,
+                w2_weight=layer.w2_weight,
+                use_fp8_w8a8=True,
+                w13_scale=(
+                    layer.w13_weight_scale_inv
+                    if self.block_quant
+                    else layer.w13_weight_scale
+                ),
+                w2_scale=(
+                    layer.w2_weight_scale_inv
+                    if self.block_quant
+                    else layer.w2_weight_scale
+                ),
+                a13_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+                block_shape=self.quant_config.weight_block_size,
+            )
+        else:
+            raise NotImplementedError(
+                "Unsupported runner backend: %s" % self.runner.runner_backend
+            )
 
         return self.runner.run(dispatch_output, quant_info)
 
     def apply_with_router_logits(
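The DeepGEMM branch above expects per-block (128×128) weight scales; when the checkpoint only stores a per-tensor scale per expert, the diff repeats that scalar over the block grid with repeat_interleave. A standalone sketch of that expansion (shapes are illustrative, not the sglang layer):

import torch

def expand_per_tensor_scale_to_blocks(weight: torch.Tensor, scale: torch.Tensor, block: int = 128) -> torch.Tensor:
    """Repeat a per-expert scalar scale over a ceil(N/block) x ceil(K/block) grid.

    weight: [num_experts, N, K], scale: [num_experts] -> [num_experts, n_blocks, k_blocks].
    """
    n_blocks = (weight.shape[1] - 1) // block + 1
    k_blocks = (weight.shape[2] - 1) // block + 1
    return (
        scale.unsqueeze(1)
        .repeat_interleave(n_blocks, dim=1)
        .unsqueeze(2)
        .repeat_interleave(k_blocks, dim=2)
    )

w13 = torch.empty(8, 512, 256)   # illustrative expert weight shapes
s13 = torch.rand(8)              # one scale per expert
print(expand_per_tensor_scale_to_blocks(w13, s13).shape)  # torch.Size([8, 4, 2])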
sglang/srt/layers/quantization/quark/quark.py
CHANGED
@@ -65,7 +65,9 @@ class QuarkConfig(QuantizationConfig):
         if should_ignore_layer(
             prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping
         ):
-
+            if isinstance(layer, LinearBase):
+                return UnquantizedLinearMethod()
+            return None
 
         if isinstance(layer, LinearBase):
             scheme = self.get_scheme(layer=layer, layer_name=prefix)
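The hunk above makes ignored layers that are still linear modules fall back to the unquantized linear method instead of returning nothing. A tiny standalone sketch of that fallback pattern (stand-in classes, not the sglang types):

from typing import Optional

class LinearBase:                 # stand-in for sglang's linear base class
    pass

class UnquantizedLinearMethod:    # stand-in for the unquantized fallback method
    pass

def quant_method_for_ignored_layer(layer) -> Optional[object]:
    # Ignored linear layers still get a (non-quantized) method; other
    # ignored layers get None.
    if isinstance(layer, LinearBase):
        return UnquantizedLinearMethod()
    return None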
sglang/srt/layers/quantization/w4afp8.py
CHANGED
@@ -21,7 +21,6 @@ from sglang.srt.utils import is_npu, set_weight_attrs
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe import MoeRunnerConfig
-    from sglang.srt.layers.moe.ep_moe.layer import EPMoE
     from sglang.srt.layers.moe.token_dispatcher import (
         CombineInput,
         StandardDispatchOutput,
@@ -94,9 +93,7 @@ class W4AFp8Config(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional[QuantizeMethodBase]:
         from sglang.srt.layers.linear import LinearBase
-        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
-        from sglang.srt.managers.schedule_batch import global_server_args_dict
 
         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.ignored_layers):
@@ -133,7 +130,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
 
     def create_weights(
         self,
-        layer:
+        layer: Module,
         num_experts: int,
         hidden_size: int,
         intermediate_size_per_partition: int,
@@ -292,7 +289,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
 
     def apply(
         self,
-        layer:
+        layer: Module,
         dispatch_output: StandardDispatchOutput,
     ) -> CombineInput:
 
@@ -303,18 +300,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         topk_output = dispatch_output.topk_output
 
         topk_weights, topk_ids, _ = topk_output
-        local_topk_ids = topk_ids
-        if get_moe_expert_parallel_world_size() > 1:
-            local_topk_ids = torch.where(
-                topk_ids == -1,
-                layer.num_experts,
-                topk_ids,
-            )
 
         output = cutlass_w4a8_moe(
-            layer.start_expert_id,
-            layer.end_expert_id,
-            layer.num_experts,
             x,
             layer.w13_weight,
             layer.w2_weight,
@@ -322,7 +309,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight_scale_inv,
             topk_weights,
             topk_ids,
-            local_topk_ids,
             self.a_strides1,
             self.b_strides1,
             self.c_strides1,
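For reference, the removed expert-parallel handling redirected routing ids of -1 (tokens with no local expert) to a dummy expert id before calling the kernel; the updated cutlass_w4a8_moe call no longer takes that tensor. A small standalone sketch of the old remapping (illustrative only, outside sglang):

import torch

def remap_missing_experts(topk_ids: torch.Tensor, num_experts: int) -> torch.Tensor:
    # Tokens routed to no local expert (-1) are pointed at a dummy expert slot.
    return torch.where(topk_ids == -1, torch.full_like(topk_ids, num_experts), topk_ids)

ids = torch.tensor([[0, 3], [-1, 2]])
print(remap_missing_experts(ids, num_experts=8))
# tensor([[0, 3],
#         [8, 2]])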
sglang/srt/lora/lora_manager.py
CHANGED
@@ -418,10 +418,6 @@ class LoRAManager:
         replace_submodule(self.base_model, module_name, lora_module)
         return lora_module
 
-    def should_skip_lora_for_vision_model(self, module_name):
-        # TODO: support different vision models
-        return module_name.find("vision_model.model") != -1
-
     def init_lora_modules(self):
         # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
         self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
@@ -439,10 +435,6 @@ class LoRAManager:
             ) and not self.base_model.should_apply_lora(module_name):
                 continue
 
-            # Skip vision model
-            if self.should_skip_lora_for_vision_model(module_name):
-                continue
-
             # The module should be converted if it is included in target_names
             if module_name.split(".")[-1] in self.target_modules:
                 layer_id = get_layer_id(module_name)
sglang/srt/managers/overlap_utils.py
CHANGED
@@ -1,3 +1,6 @@
+from dataclasses import dataclass
+from typing import Optional
+
 import torch
 
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
@@ -13,6 +16,12 @@ def _resolve_future_token_ids(input_ids, future_token_ids_map):
     )
 
 
+@dataclass
+class FutureIndices:
+    indices: torch.Tensor
+    interval: Optional[slice] = None
+
+
 class FutureMap:
     def __init__(
         self,
@@ -30,24 +39,17 @@ class FutureMap:
             (self.future_buffer_len,), dtype=torch.int64, device=self.device
         )
 
-    def
-        """Update the circular buffer pointer and
+    def alloc_future_indices(self, bs: int) -> FutureIndices:
+        """Update the circular buffer pointer and allocate future indices."""
         cur_future_ct = self.future_ct
         self.future_ct = (cur_future_ct + bs) % self.future_limit
-
+        start = cur_future_ct + 1
+        end = cur_future_ct + 1 + bs
+        indices = torch.arange(start, end, dtype=torch.int64, device=self.device)
+        return FutureIndices(indices=indices, interval=slice(start, end))
 
     def resolve_future(self, model_worker_batch: ModelWorkerBatch):
-        input_ids
-        _resolve_future_token_ids(input_ids, self.token_ids_buf)
-
-    def update_next_future(self, future_ct: int, bs: int):
-        return torch.arange(
-            -(future_ct + 1),
-            -(future_ct + 1 + bs),
-            -1,
-            dtype=torch.int64,
-            device=self.device,
-        )
+        _resolve_future_token_ids(model_worker_batch.input_ids, self.token_ids_buf)
 
-    def store_to_map(self,
-        self.token_ids_buf[
+    def store_to_map(self, future_indices: FutureIndices, next_token_ids: torch.Tensor):
+        self.token_ids_buf[future_indices.interval] = next_token_ids
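The reworked FutureMap hands out explicit FutureIndices (a tensor of slot indices plus the matching slice) instead of negative placeholder offsets; the scheduler stores sampled token ids into those slots and resolves them later. A minimal, self-contained sketch of that allocate/store/resolve cycle (simplified, ignores wrap-around of the circular buffer, and is not the real FutureMap):

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class FutureIndices:
    indices: torch.Tensor
    interval: Optional[slice] = None

class TinyFutureMap:
    def __init__(self, future_limit: int = 16):
        self.future_ct = 0
        self.future_limit = future_limit
        self.token_ids_buf = torch.zeros(future_limit * 2, dtype=torch.int64)

    def alloc_future_indices(self, bs: int) -> FutureIndices:
        # Advance the circular-buffer pointer and hand out the next `bs` slots.
        cur = self.future_ct
        self.future_ct = (cur + bs) % self.future_limit
        start, end = cur + 1, cur + 1 + bs
        return FutureIndices(indices=torch.arange(start, end), interval=slice(start, end))

    def store_to_map(self, fi: FutureIndices, next_token_ids: torch.Tensor):
        self.token_ids_buf[fi.interval] = next_token_ids

    def resolve(self, fi: FutureIndices) -> torch.Tensor:
        return self.token_ids_buf[fi.indices]

fm = TinyFutureMap()
fi = fm.alloc_future_indices(bs=3)                 # reserve slots before sampling finishes
fm.store_to_map(fi, torch.tensor([11, 22, 33]))    # fill them once token ids are known
print(fm.resolve(fi))                              # tensor([11, 22, 33])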
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -97,7 +97,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "ep_num_redundant_experts",
     "enable_nan_detection",
     "flashinfer_mla_disable_ragged",
-    "
+    "pp_max_micro_batch_size",
     "disable_shared_experts_fusion",
     "sampling_backend",
     "speculative_accept_threshold_single",
@@ -114,6 +114,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "enable_deterministic_inference",
     "nsa_prefill",
     "nsa_decode",
+    "multi_item_scoring_delimiter",
 ]
 
 # Put some global args for easy access
@@ -539,7 +540,7 @@ class Req:
 
         # Prefix info
         # The indices to kv cache for the shared prefix.
-        self.prefix_indices: torch.Tensor =
+        self.prefix_indices: torch.Tensor = torch.empty((0,), dtype=torch.int64)
         # Number of tokens to run prefill.
         self.extend_input_len = 0
         # The relative logprob_start_len in an extend batch
@@ -666,9 +667,11 @@ class Req:
     def is_prefill_only(self) -> bool:
         """Check if this request is prefill-only (no token generation needed)."""
         # NOTE: when spec is enabled, prefill_only optimizations are disabled
-
-
-
+        from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+
+        spec_alg = global_server_args_dict["speculative_algorithm"]
+        return self.sampling_params.max_new_tokens == 0 and (
+            spec_alg is None or spec_alg == SpeculativeAlgorithm.NONE
         )
 
     def add_latency(self, stage: RequestStage):
@@ -691,11 +694,16 @@ class Req:
         # Whether request reached finished condition
         return self.finished_reason is not None
 
-    def init_next_round_input(
-        self,
-        tree_cache: Optional[BasePrefixCache] = None,
-    ):
+    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
         self.fill_ids = self.origin_input_ids + self.output_ids
+        input_len = len(self.fill_ids)
+        # NOTE: the matched length is at most 1 less than the input length to enable logprob computation
+        max_prefix_len = input_len - 1
+        if self.return_logprob:
+            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
+        max_prefix_len = max(max_prefix_len, 0)
+        token_ids = self.fill_ids[:max_prefix_len]
+
         if tree_cache is not None:
             (
                 self.prefix_indices,
@@ -703,31 +711,11 @@ class Req:
                 self.last_host_node,
                 self.host_hit_length,
             ) = tree_cache.match_prefix(
-                key=RadixKey(
-                    token_ids=self.adjust_max_prefix_ids(), extra_key=self.extra_key
-                ),
+                key=RadixKey(token_ids=token_ids, extra_key=self.extra_key)
             )
             self.last_matched_prefix_len = len(self.prefix_indices)
             self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
 
-    def adjust_max_prefix_ids(self):
-        self.fill_ids = self.origin_input_ids + self.output_ids
-        input_len = len(self.fill_ids)
-
-        # FIXME: To work around some bugs in logprob computation, we need to ensure each
-        # request has at least one token. Later, we can relax this requirement and use `input_len`.
-        max_prefix_len = input_len - 1
-
-        if self.sampling_params.max_new_tokens > 0:
-            # Need at least one token to compute logits
-            max_prefix_len = min(max_prefix_len, input_len - 1)
-
-        if self.return_logprob:
-            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
-
-        max_prefix_len = max(max_prefix_len, 0)
-        return self.fill_ids[:max_prefix_len]
-
     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
     def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None
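init_next_round_input now inlines the former adjust_max_prefix_ids logic: the prefix-cache match key is capped at input_len - 1 (so at least one token is left for logit/logprob computation) and further clipped to logprob_start_len when logprobs are requested. A small sketch of that cap, assuming plain lists of token ids (illustrative, not the Req class):

def prefix_match_key(fill_ids, return_logprob: bool, logprob_start_len: int):
    input_len = len(fill_ids)
    # Keep at least one token out of the prefix match so logits can still be computed.
    max_prefix_len = input_len - 1
    if return_logprob:
        max_prefix_len = min(max_prefix_len, logprob_start_len)
    max_prefix_len = max(max_prefix_len, 0)
    return fill_ids[:max_prefix_len]

print(prefix_match_key([1, 2, 3, 4], return_logprob=False, logprob_start_len=0))  # [1, 2, 3]
print(prefix_match_key([1, 2, 3, 4], return_logprob=True, logprob_start_len=2))   # [1, 2]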
@@ -808,7 +796,7 @@ class Req:
             return
 
     def reset_for_retract(self):
-        self.prefix_indices =
+        self.prefix_indices = torch.empty((0,), dtype=torch.int64)
         self.last_node = None
         self.swa_uuid_for_lock = None
         self.extend_input_len = 0
@@ -886,15 +874,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # This is an optimization to reduce the overhead of the prefill check.
     batch_is_full: bool = False
 
-    # Events
-    launch_done: Optional[threading.Event] = None
-
     # For chunked prefill in PP
     chunked_req: Optional[Req] = None
 
     # Sampling info
     sampling_info: SamplingBatchInfo = None
-    next_batch_sampling_info: SamplingBatchInfo = None
 
     # Batched arguments to model runner
     input_ids: torch.Tensor = None  # shape: [b], int64
@@ -1128,6 +1112,47 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         else:
             return out_cache_loc
 
+    def write_cache_indices(
+        self,
+        req_pool_indices: List[int],
+        prefix_lens: List[int],
+        seq_lens: List[int],
+        extend_lens: List[int],
+        out_cache_loc: torch.Tensor,
+        req_pool_indices_tensor: torch.Tensor,
+        prefix_lens_tensor: torch.Tensor,
+        seq_lens_tensor: torch.Tensor,
+        extend_lens_tensor: torch.Tensor,
+        prefix_tensors: list[torch.Tensor],
+    ):
+        if support_triton(global_server_args_dict.get("attention_backend")):
+            prefix_pointers = torch.tensor(
+                [t.data_ptr() for t in prefix_tensors], device=self.device
+            )
+            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
+            write_req_to_token_pool_triton[(len(req_pool_indices),)](
+                self.req_to_token_pool.req_to_token,
+                req_pool_indices_tensor,
+                prefix_pointers,
+                prefix_lens_tensor,
+                seq_lens_tensor,
+                extend_lens_tensor,
+                out_cache_loc,
+                self.req_to_token_pool.req_to_token.shape[1],
+            )
+        else:
+            pt = 0
+            for i in range(len(req_pool_indices)):
+                self.req_to_token_pool.write(
+                    (req_pool_indices[i], slice(0, prefix_lens[i])),
+                    prefix_tensors[i],
+                )
+                self.req_to_token_pool.write(
+                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
+                    out_cache_loc[pt : pt + extend_lens[i]],
+                )
+                pt += extend_lens[i]
+
     def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
         self.encoder_lens_cpu = []
         self.encoder_cached = []
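write_cache_indices centralizes writing both the matched prefix locations and the newly allocated slots into the req_to_token pool: the Triton path passes per-request prefix pointers to the kernel, while the fallback performs two slice writes per request. A minimal sketch of the fallback loop against a plain 2-D tensor standing in for the pool (illustrative only):

import torch

def write_cache_indices_fallback(req_to_token, req_pool_indices, prefix_lens, seq_lens,
                                 extend_lens, out_cache_loc, prefix_tensors):
    # Per request: first copy the cached prefix locations, then the freshly
    # allocated locations for the extend tokens.
    pt = 0
    for i, req_idx in enumerate(req_pool_indices):
        req_to_token[req_idx, : prefix_lens[i]] = prefix_tensors[i]
        req_to_token[req_idx, prefix_lens[i] : seq_lens[i]] = out_cache_loc[pt : pt + extend_lens[i]]
        pt += extend_lens[i]

req_to_token = torch.zeros(4, 8, dtype=torch.int64)
write_cache_indices_fallback(
    req_to_token,
    req_pool_indices=[2],
    prefix_lens=[3],
    seq_lens=[5],
    extend_lens=[2],
    out_cache_loc=torch.tensor([40, 41]),
    prefix_tensors=[torch.tensor([10, 11, 12])],
)
print(req_to_token[2])  # tensor([10, 11, 12, 40, 41, 0, 0, 0])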
@@ -1205,10 +1230,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def prepare_for_extend(self):
         self.forward_mode = ForwardMode.EXTEND
 
-        # Allocate req slots
-        bs = len(self.reqs)
-        req_pool_indices = self.alloc_req_slots(bs, self.reqs)
-
         # Init tensors
         reqs = self.reqs
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
@@ -1222,9 +1243,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             r.token_type_ids for r in reqs if r.token_type_ids is not None
         ]
 
-        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
-            self.device, non_blocking=True
-        )
         input_ids_tensor = torch.tensor(
             list(chain.from_iterable(input_ids)), dtype=torch.int64
         ).to(self.device, non_blocking=True)
@@ -1248,7 +1266,49 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
         extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor
 
-        #
+        # Allocate req slots
+        bs = len(self.reqs)
+        req_pool_indices = self.alloc_req_slots(bs, self.reqs)
+        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
+            self.device, non_blocking=True
+        )
+
+        # Allocate memory
+        if self.token_to_kv_pool_allocator.page_size == 1:
+            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
+        else:
+            last_loc = [
+                (
+                    r.prefix_indices[-1:]
+                    if len(r.prefix_indices) > 0
+                    else torch.tensor([-1], device=self.device)
+                )
+                for r in self.reqs
+            ]
+            out_cache_loc = self.alloc_paged_token_slots_extend(
+                prefix_lens_tensor,
+                prefix_lens_cpu_tensor,
+                seq_lens_tensor,
+                seq_lens_cpu,
+                torch.cat(last_loc),
+                extend_num_tokens,
+            )
+
+        # Write allocated tokens to req_to_token_pool
+        self.write_cache_indices(
+            req_pool_indices,
+            prefix_lens,
+            seq_lens,
+            extend_lens,
+            out_cache_loc,
+            req_pool_indices_tensor,
+            prefix_lens_tensor,
+            seq_lens_tensor,
+            extend_lens_tensor,
+            [r.prefix_indices for r in reqs],
+        )
+
+        # Set fields
         input_embeds = []
         extend_input_logprob_token_ids = []
         multimodal_inputs = []
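When page_size > 1, the relocated allocation code first builds a per-request last_loc vector (the last already-cached KV location, or -1 when a request has no prefix) so the paged allocator can keep filling partially used pages. A small sketch of how that vector is assembled (plain torch, outside sglang):

import torch

def build_last_loc(prefix_indices_per_req):
    # One entry per request: the last cached location, or -1 when there is no prefix.
    last = [
        p[-1:] if len(p) > 0 else torch.tensor([-1])
        for p in prefix_indices_per_req
    ]
    return torch.cat(last)

print(build_last_loc([torch.tensor([5, 6, 7]), torch.tensor([], dtype=torch.int64)]))
# tensor([ 7, -1])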
@@ -1258,9 +1318,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             assert seq_len - pre_len == req.extend_input_len
 
             if pre_len > 0:
-                self.req_to_token_pool.write(
-                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
-                )
                 if isinstance(self.tree_cache, SWAChunkCache):
                     self.tree_cache.evict_swa(
                         req, pre_len, self.model_config.attention_chunk_size
@@ -1355,25 +1412,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         else:
             extend_input_logprob_token_ids = None
 
-        # Allocate memory
-        if self.token_to_kv_pool_allocator.page_size == 1:
-            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
-        else:
-            last_loc = get_last_loc(
-                self.req_to_token_pool.req_to_token,
-                req_pool_indices_tensor,
-                prefix_lens_tensor,
-            )
-            out_cache_loc = self.alloc_paged_token_slots_extend(
-                prefix_lens_tensor,
-                prefix_lens_cpu_tensor,
-                seq_lens_tensor,
-                seq_lens_cpu,
-                last_loc,
-                extend_num_tokens,
-            )
-
-        # Set fields
         self.input_ids = input_ids_tensor
         self.req_pool_indices = req_pool_indices_tensor
         self.seq_lens = seq_lens_tensor
@@ -1406,28 +1444,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.extend_lens = extend_lens
         self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
 
-        # Write to req_to_token_pool
-        if support_triton(global_server_args_dict.get("attention_backend")):
-            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
-
-            write_req_to_token_pool_triton[(bs,)](
-                self.req_to_token_pool.req_to_token,
-                req_pool_indices_tensor,
-                prefix_lens_tensor,
-                seq_lens_tensor,
-                extend_lens_tensor,
-                out_cache_loc,
-                self.req_to_token_pool.req_to_token.shape[1],
-            )
-        else:
-            pt = 0
-            for i in range(bs):
-                self.req_to_token_pool.write(
-                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
-                    out_cache_loc[pt : pt + extend_lens[i]],
-                )
-                pt += extend_lens[i]
-
         if self.model_config.is_encoder_decoder:
             self.prepare_encoder_info_extend(input_ids, seq_lens)
 
@@ -1877,7 +1893,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 )
             ),
             extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
-            launch_done=self.launch_done,
             is_prefill_only=self.is_prefill_only,
         )
 
@@ -2018,8 +2033,8 @@ class ModelWorkerBatch:
     capture_hidden_mode: CaptureHiddenMode = None
     hicache_consumer_index: int = -1
 
-    # Overlap
-
+    # Overlap scheduler related
+    delay_sample_launch: bool = False
 
     # Whether this batch is prefill-only (no token generation needed)
     is_prefill_only: bool = False
@@ -2029,6 +2044,7 @@ class ModelWorkerBatch:
 def write_req_to_token_pool_triton(
     req_to_token_ptr,  # [max_batch, max_context_len]
     req_pool_indices,
+    prefix_tensors,
     pre_lens,
     seq_lens,
     extend_lens,
@@ -2041,6 +2057,19 @@ def write_req_to_token_pool_triton(
     req_pool_index = tl.load(req_pool_indices + pid)
     pre_len = tl.load(pre_lens + pid)
     seq_len = tl.load(seq_lens + pid)
+    prefix_tensor = tl.load(prefix_tensors + pid).to(tl.pointer_type(tl.int64))
+
+    # write prefix
+    num_loop = tl.cdiv(pre_len, BLOCK_SIZE)
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = offset < pre_len
+        value = tl.load(prefix_tensor + offset, mask=mask)
+        tl.store(
+            req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + offset,
+            value,
+            mask=mask,
+        )
 
     # NOTE: This can be slow for large bs
     cumsum_start = tl.cast(0, tl.int64)
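The kernel now receives a per-request pointer to its prefix tensor and copies the prefix into req_to_token in BLOCK_SIZE chunks before the existing loop writes the newly allocated slots. In plain Python the added loop is equivalent to the blocked, masked copy below (illustrative only, not Triton):

def blocked_copy(dst, src, pre_len, block_size=128):
    # Same structure as the added Triton loop: iterate ceil(pre_len / block_size)
    # blocks and mask off the tail of the last block.
    num_loop = -(-pre_len // block_size)              # tl.cdiv equivalent
    for i in range(num_loop):
        start = i * block_size
        end = min(start + block_size, pre_len)        # the `mask = offset < pre_len` part
        dst[start:end] = src[start:end]

dst = [0] * 10
blocked_copy(dst, list(range(100, 110)), pre_len=10, block_size=4)
print(dst)  # [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]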