sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. sglang/bench_one_batch.py +0 -2
  2. sglang/bench_serving.py +224 -127
  3. sglang/compile_deep_gemm.py +3 -0
  4. sglang/launch_server.py +0 -14
  5. sglang/srt/configs/__init__.py +2 -0
  6. sglang/srt/configs/falcon_h1.py +12 -58
  7. sglang/srt/configs/mamba_utils.py +117 -0
  8. sglang/srt/configs/model_config.py +68 -31
  9. sglang/srt/configs/nemotron_h.py +286 -0
  10. sglang/srt/configs/qwen3_next.py +11 -43
  11. sglang/srt/disaggregation/decode.py +7 -18
  12. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  13. sglang/srt/disaggregation/nixl/conn.py +55 -23
  14. sglang/srt/disaggregation/prefill.py +17 -32
  15. sglang/srt/entrypoints/engine.py +2 -2
  16. sglang/srt/entrypoints/grpc_request_manager.py +10 -23
  17. sglang/srt/entrypoints/grpc_server.py +220 -80
  18. sglang/srt/entrypoints/http_server.py +49 -1
  19. sglang/srt/entrypoints/openai/protocol.py +159 -31
  20. sglang/srt/entrypoints/openai/serving_chat.py +13 -71
  21. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  22. sglang/srt/environ.py +4 -0
  23. sglang/srt/function_call/function_call_parser.py +8 -6
  24. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  25. sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
  26. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
  27. sglang/srt/layers/attention/attention_registry.py +31 -22
  28. sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
  29. sglang/srt/layers/attention/flashattention_backend.py +0 -1
  30. sglang/srt/layers/attention/flashinfer_backend.py +223 -6
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
  32. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
  33. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  34. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
  35. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  36. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  37. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  38. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  39. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  40. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  41. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  42. sglang/srt/layers/attention/triton_backend.py +1 -1
  43. sglang/srt/layers/logits_processor.py +136 -6
  44. sglang/srt/layers/modelopt_utils.py +11 -0
  45. sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
  46. sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
  47. sglang/srt/layers/moe/ep_moe/layer.py +8 -286
  48. sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
  49. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  50. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  51. sglang/srt/layers/moe/utils.py +7 -1
  52. sglang/srt/layers/quantization/__init__.py +1 -1
  53. sglang/srt/layers/quantization/fp8.py +84 -18
  54. sglang/srt/layers/quantization/modelopt_quant.py +1 -1
  55. sglang/srt/layers/quantization/quark/quark.py +3 -1
  56. sglang/srt/layers/quantization/w4afp8.py +2 -16
  57. sglang/srt/lora/lora_manager.py +0 -8
  58. sglang/srt/managers/overlap_utils.py +18 -16
  59. sglang/srt/managers/schedule_batch.py +119 -90
  60. sglang/srt/managers/schedule_policy.py +1 -1
  61. sglang/srt/managers/scheduler.py +213 -126
  62. sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
  63. sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
  64. sglang/srt/managers/tokenizer_manager.py +270 -53
  65. sglang/srt/managers/tp_worker.py +39 -28
  66. sglang/srt/mem_cache/allocator.py +7 -2
  67. sglang/srt/mem_cache/chunk_cache.py +1 -1
  68. sglang/srt/mem_cache/memory_pool.py +162 -68
  69. sglang/srt/mem_cache/radix_cache.py +8 -3
  70. sglang/srt/mem_cache/swa_radix_cache.py +70 -14
  71. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  72. sglang/srt/model_executor/forward_batch_info.py +4 -18
  73. sglang/srt/model_executor/model_runner.py +55 -51
  74. sglang/srt/model_loader/__init__.py +1 -1
  75. sglang/srt/model_loader/loader.py +187 -6
  76. sglang/srt/model_loader/weight_utils.py +3 -0
  77. sglang/srt/models/falcon_h1.py +11 -9
  78. sglang/srt/models/gemma3_mm.py +16 -0
  79. sglang/srt/models/grok.py +5 -13
  80. sglang/srt/models/mixtral.py +1 -3
  81. sglang/srt/models/mllama4.py +11 -1
  82. sglang/srt/models/nemotron_h.py +514 -0
  83. sglang/srt/models/utils.py +5 -1
  84. sglang/srt/sampling/sampling_batch_info.py +11 -9
  85. sglang/srt/server_args.py +100 -33
  86. sglang/srt/speculative/eagle_worker.py +11 -13
  87. sglang/srt/speculative/ngram_worker.py +12 -11
  88. sglang/srt/speculative/spec_utils.py +0 -1
  89. sglang/srt/two_batch_overlap.py +1 -0
  90. sglang/srt/utils/common.py +18 -0
  91. sglang/srt/utils/hf_transformers_utils.py +2 -0
  92. sglang/test/longbench_v2/__init__.py +1 -0
  93. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  94. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  95. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  96. sglang/test/run_eval.py +40 -0
  97. sglang/test/simple_eval_longbench_v2.py +332 -0
  98. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  99. sglang/test/test_deterministic.py +18 -2
  100. sglang/test/test_deterministic_utils.py +81 -0
  101. sglang/test/test_disaggregation_utils.py +63 -0
  102. sglang/test/test_utils.py +32 -11
  103. sglang/version.py +1 -1
  104. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
  105. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
  106. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  107. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  108. sglang/test/test_block_fp8_ep.py +0 -358
  109. /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
  110. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  111. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  112. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -9,6 +9,7 @@ from sglang.srt.layers.moe.moe_runner.base import (
     MoeRunnerConfig,
     PermuteMethodPool,
 )
+from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmRunnerCore
 from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
 from sglang.srt.layers.moe.utils import get_moe_a2a_backend
 
@@ -30,6 +31,8 @@ class MoeRunner:
 
         if runner_backend.is_triton():
             self.runner_core = TritonRunnerCore(config)
+        elif runner_backend.is_deep_gemm():
+            self.runner_core = DeepGemmRunnerCore(config)
         else:
             raise NotImplementedError(f"Unsupported runner backend: {runner_backend}")
 
@@ -44,6 +44,7 @@ class MoeA2ABackend(Enum):
 class MoeRunnerBackend(Enum):
 
     AUTO = "auto"
+    DEEP_GEMM = "deep_gemm"
     TRITON = "triton"
     TRITON_KERNEL = "triton_kernel"
     FLASHINFER_TRTLLM = "flashinfer_trtllm"
@@ -54,6 +55,9 @@ class MoeRunnerBackend(Enum):
     def is_auto(self):
         return self == MoeRunnerBackend.AUTO
 
+    def is_deep_gemm(self):
+        return self == MoeRunnerBackend.DEEP_GEMM
+
     def is_triton(self):
         return self == MoeRunnerBackend.TRITON
 
@@ -147,7 +151,9 @@ def get_moe_a2a_backend() -> MoeA2ABackend:
 def get_moe_runner_backend() -> MoeRunnerBackend:
     global MOE_RUNNER_BACKEND
     if MOE_RUNNER_BACKEND is None:
-        logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend")
+        logger.warning(
+            "MOE_RUNNER_BACKEND is not initialized, the backend will be automatically selected"
+        )
         MOE_RUNNER_BACKEND = MoeRunnerBackend.AUTO
     return MOE_RUNNER_BACKEND
 
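The hunks above register a DeepGEMM MoE runner: a new DEEP_GEMM enum member, an is_deep_gemm() predicate, and an AUTO fallback in get_moe_runner_backend(). A tiny illustrative check of the new enum member (not part of the diff; the import path follows the re-export shown in the fp8.py hunk below):

    # Illustrative only; MoeRunnerBackend is the enum extended in the hunks above.
    from sglang.srt.layers.moe import MoeRunnerBackend

    backend = MoeRunnerBackend("deep_gemm")      # new member added in this release
    assert backend is MoeRunnerBackend.DEEP_GEMM
    assert backend.is_deep_gemm() and not backend.is_triton()

With AUTO, the actual choice between DeepGEMM and Triton is made in Fp8MoEMethod.create_moe_runner (see the fp8.py hunks below).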
@@ -72,7 +72,7 @@ if TYPE_CHECKING:
 BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "fp8": Fp8Config,
     "blockwise_int8": BlockInt8Config,
-    "modelopt": ModelOptFp8Config,
+    "modelopt_fp8": ModelOptFp8Config,
     "modelopt_fp4": ModelOptFp4Config,
     "w8a8_int8": W8A8Int8Config,
     "w8a8_fp8": W8A8Fp8Config,
@@ -31,8 +31,8 @@ except ImportError:
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmMoeQuantInfo
 from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
-from sglang.srt.layers.moe.token_dispatcher.base import DispatchOutputChecker
 from sglang.srt.layers.parameter import (
     BlockQuantScaleParameter,
     ModelWeightParameter,
@@ -1006,8 +1006,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
     def create_moe_runner(
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
     ):
+
+        from sglang.srt.layers.moe.utils import (
+            get_moe_a2a_backend,
+            get_moe_runner_backend,
+        )
+        from sglang.srt.layers.quantization import deep_gemm_wrapper
+
         self.moe_runner_config = moe_runner_config
-        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+        moe_runner_backend = get_moe_runner_backend()
+
+        if moe_runner_backend.is_auto():
+            if (
+                deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+                and get_moe_a2a_backend().is_deepep()
+            ):
+                moe_runner_backend = MoeRunnerBackend.DEEP_GEMM
+            else:
+                moe_runner_backend = MoeRunnerBackend.TRITON
+        if moe_runner_backend.is_deep_gemm() or moe_runner_backend.is_triton():
+            self.runner = MoeRunner(moe_runner_backend, moe_runner_config)
+        else:
+            # TODO(cwan): refactor other backends
+            pass
 
     def apply(
         self,
@@ -1087,22 +1108,67 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             )
             return StandardCombineInput(hidden_states=output)
 
-        quant_info = TritonMoeQuantInfo(
-            w13_weight=layer.w13_weight,
-            w2_weight=layer.w2_weight,
-            use_fp8_w8a8=True,
-            w13_scale=(
-                layer.w13_weight_scale_inv
-                if self.block_quant
-                else layer.w13_weight_scale
-            ),
-            w2_scale=(
-                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
-            ),
-            a13_scale=layer.w13_input_scale,
-            a2_scale=layer.w2_input_scale,
-            block_shape=self.quant_config.weight_block_size,
-        )
+        if self.runner.runner_backend.is_deep_gemm():
+
+            w13_weight = layer.w13_weight
+            w2_weight = layer.w2_weight
+
+            if self.block_quant:
+                block_shape = self.quant_config.weight_block_size
+                w13_scale = layer.w13_weight_scale_inv
+                w2_scale = layer.w2_weight_scale_inv
+            else:
+                # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
+                scale_block_size = 128
+                block_shape = [scale_block_size, scale_block_size]
+                w13_scale_n = (w13_weight.shape[1] - 1) // scale_block_size + 1
+                w13_scale_k = (w13_weight.shape[2] - 1) // scale_block_size + 1
+                w13_scale = (
+                    layer.w13_weight_scale.unsqueeze(1)
+                    .repeat_interleave(w13_scale_n, dim=1)
+                    .unsqueeze(2)
+                    .repeat_interleave(w13_scale_k, dim=2)
+                )
+                w2_scale_n = (w2_weight.shape[1] - 1) // scale_block_size + 1
+                w2_scale_k = (w2_weight.shape[2] - 1) // scale_block_size + 1
+                w2_scale = (
+                    layer.w2_weight_scale.unsqueeze(1)
+                    .repeat_interleave(w2_scale_n, dim=1)
+                    .unsqueeze(2)
+                    .repeat_interleave(w2_scale_k, dim=2)
+                )
+            quant_info = DeepGemmMoeQuantInfo(
+                w13_weight=w13_weight,
+                w2_weight=w2_weight,
+                use_fp8=True,
+                w13_scale=w13_scale,
+                w2_scale=w2_scale,
+                block_shape=block_shape,
+            )
+        elif self.runner.runner_backend.is_triton():
+            quant_info = TritonMoeQuantInfo(
+                w13_weight=layer.w13_weight,
+                w2_weight=layer.w2_weight,
+                use_fp8_w8a8=True,
+                w13_scale=(
+                    layer.w13_weight_scale_inv
+                    if self.block_quant
+                    else layer.w13_weight_scale
+                ),
+                w2_scale=(
+                    layer.w2_weight_scale_inv
+                    if self.block_quant
+                    else layer.w2_weight_scale
+                ),
+                a13_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+                block_shape=self.quant_config.weight_block_size,
+            )
+        else:
+            raise NotImplementedError(
+                "Unsupported runner backend: %s" % self.runner.runner_backend
+            )
+
         return self.runner.run(dispatch_output, quant_info)
 
     def apply_with_router_logits(
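In the DeepGEMM branch above, a per-tensor weight scale is expanded into a per-128x128-block scale grid with repeat_interleave so per-tensor checkpoints can be fed to the block-quantized DeepGEMM path. A standalone sketch of that expansion with illustrative shapes (not part of the diff):

    import torch

    # Illustrative shapes: 4 experts, N=256, K=384, scale blocks of 128.
    scale_block_size = 128
    w13_weight = torch.empty(4, 256, 384)   # [num_experts, N, K]
    w13_weight_scale = torch.rand(4)        # one per-tensor scale per expert

    scale_n = (w13_weight.shape[1] - 1) // scale_block_size + 1   # ceil(256 / 128) = 2
    scale_k = (w13_weight.shape[2] - 1) // scale_block_size + 1   # ceil(384 / 128) = 3
    w13_scale = (
        w13_weight_scale.unsqueeze(1)
        .repeat_interleave(scale_n, dim=1)
        .unsqueeze(2)
        .repeat_interleave(scale_k, dim=2)
    )
    print(w13_scale.shape)  # torch.Size([4, 2, 3]): one scale per 128x128 block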
@@ -113,7 +113,7 @@ class ModelOptFp8Config(QuantizationConfig):
 
     @classmethod
     def get_name(cls) -> str:
-        return "modelopt"
+        return "modelopt_fp8"
 
     @classmethod
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
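Together with the BASE_QUANTIZATION_METHODS hunk earlier, this renames the FP8 ModelOpt quantization key from "modelopt" to "modelopt_fp8". An illustrative check of the new key (not part of the diff; it assumes the dictionary is importable from the quantization package __init__ listed above):

    # Names taken from the hunks above; this snippet only illustrates the key rename.
    from sglang.srt.layers.quantization import BASE_QUANTIZATION_METHODS

    assert "modelopt" not in BASE_QUANTIZATION_METHODS       # old key removed
    cfg_cls = BASE_QUANTIZATION_METHODS["modelopt_fp8"]      # new key
    assert cfg_cls.get_name() == "modelopt_fp8"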
@@ -65,7 +65,9 @@ class QuarkConfig(QuantizationConfig):
         if should_ignore_layer(
             prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping
         ):
-            return UnquantizedLinearMethod()
+            if isinstance(layer, LinearBase):
+                return UnquantizedLinearMethod()
+            return None
 
         if isinstance(layer, LinearBase):
             scheme = self.get_scheme(layer=layer, layer_name=prefix)
@@ -21,7 +21,6 @@ from sglang.srt.utils import is_npu, set_weight_attrs
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe import MoeRunnerConfig
-    from sglang.srt.layers.moe.ep_moe.layer import EPMoE
     from sglang.srt.layers.moe.token_dispatcher import (
         CombineInput,
         StandardDispatchOutput,
@@ -94,9 +93,7 @@ class W4AFp8Config(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional[QuantizeMethodBase]:
         from sglang.srt.layers.linear import LinearBase
-        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
-        from sglang.srt.managers.schedule_batch import global_server_args_dict
 
         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.ignored_layers):
@@ -133,7 +130,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
 
     def create_weights(
         self,
-        layer: EPMoE,
+        layer: Module,
         num_experts: int,
         hidden_size: int,
         intermediate_size_per_partition: int,
@@ -292,7 +289,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
 
     def apply(
         self,
-        layer: EPMoE,
+        layer: Module,
         dispatch_output: StandardDispatchOutput,
     ) -> CombineInput:
 
@@ -303,18 +300,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         topk_output = dispatch_output.topk_output
 
         topk_weights, topk_ids, _ = topk_output
-        local_topk_ids = topk_ids
-        if get_moe_expert_parallel_world_size() > 1:
-            local_topk_ids = torch.where(
-                topk_ids == -1,
-                layer.num_experts,
-                topk_ids,
-            )
 
         output = cutlass_w4a8_moe(
-            layer.start_expert_id,
-            layer.end_expert_id,
-            layer.num_experts,
             x,
             layer.w13_weight,
             layer.w2_weight,
@@ -322,7 +309,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight_scale_inv,
             topk_weights,
             topk_ids,
-            local_topk_ids,
             self.a_strides1,
             self.b_strides1,
             self.c_strides1,
@@ -418,10 +418,6 @@ class LoRAManager:
         replace_submodule(self.base_model, module_name, lora_module)
         return lora_module
 
-    def should_skip_lora_for_vision_model(self, module_name):
-        # TODO: support different vision models
-        return module_name.find("vision_model.model") != -1
-
     def init_lora_modules(self):
         # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
         self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
@@ -439,10 +435,6 @@ class LoRAManager:
             ) and not self.base_model.should_apply_lora(module_name):
                 continue
 
-            # Skip vision model
-            if self.should_skip_lora_for_vision_model(module_name):
-                continue
-
             # The module should be converted if it is included in target_names
             if module_name.split(".")[-1] in self.target_modules:
                 layer_id = get_layer_id(module_name)
@@ -1,3 +1,6 @@
+from dataclasses import dataclass
+from typing import Optional
+
 import torch
 
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
@@ -13,6 +16,12 @@ def _resolve_future_token_ids(input_ids, future_token_ids_map):
     )
 
 
+@dataclass
+class FutureIndices:
+    indices: torch.Tensor
+    interval: Optional[slice] = None
+
+
 class FutureMap:
     def __init__(
         self,
@@ -30,24 +39,17 @@ class FutureMap:
             (self.future_buffer_len,), dtype=torch.int64, device=self.device
         )
 
-    def update_ct(self, bs: int) -> int:
-        """Update the circular buffer pointer and return the current pointer."""
+    def alloc_future_indices(self, bs: int) -> FutureIndices:
+        """Update the circular buffer pointer and allocate future indices."""
         cur_future_ct = self.future_ct
         self.future_ct = (cur_future_ct + bs) % self.future_limit
-        return cur_future_ct
+        start = cur_future_ct + 1
+        end = cur_future_ct + 1 + bs
+        indices = torch.arange(start, end, dtype=torch.int64, device=self.device)
+        return FutureIndices(indices=indices, interval=slice(start, end))
 
     def resolve_future(self, model_worker_batch: ModelWorkerBatch):
-        input_ids = model_worker_batch.input_ids
-        _resolve_future_token_ids(input_ids, self.token_ids_buf)
-
-    def update_next_future(self, future_ct: int, bs: int):
-        return torch.arange(
-            -(future_ct + 1),
-            -(future_ct + 1 + bs),
-            -1,
-            dtype=torch.int64,
-            device=self.device,
-        )
+        _resolve_future_token_ids(model_worker_batch.input_ids, self.token_ids_buf)
 
-    def store_to_map(self, future_ct: int, bs: int, next_token_ids: torch.Tensor):
-        self.token_ids_buf[future_ct + 1 : future_ct + bs + 1] = next_token_ids
+    def store_to_map(self, future_indices: FutureIndices, next_token_ids: torch.Tensor):
+        self.token_ids_buf[future_indices.interval] = next_token_ids
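The FutureMap refactor replaces raw counter arithmetic (update_ct, update_next_future, store_to_map(future_ct, bs, ...)) with a FutureIndices handle that carries both the index tensor and the buffer slice. A simplified, CPU-only sketch of the new allocate/store round trip (illustrative, not part of the diff; the real class takes more constructor arguments than shown):

    import torch

    # Simplified stand-in for FutureMap's circular token-id buffer.
    buf = torch.zeros(16, dtype=torch.int64)
    future_ct, future_limit = 0, 15

    def alloc(bs: int) -> slice:
        global future_ct
        cur = future_ct
        future_ct = (cur + bs) % future_limit
        return slice(cur + 1, cur + 1 + bs)   # plays the role of FutureIndices.interval

    interval = alloc(bs=3)
    buf[interval] = torch.tensor([101, 102, 103])  # store_to_map
    print(buf[1:4])  # tensor([101, 102, 103])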
@@ -97,7 +97,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "ep_num_redundant_experts",
     "enable_nan_detection",
     "flashinfer_mla_disable_ragged",
-    "max_micro_batch_size",
+    "pp_max_micro_batch_size",
     "disable_shared_experts_fusion",
     "sampling_backend",
     "speculative_accept_threshold_single",
@@ -114,6 +114,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "enable_deterministic_inference",
     "nsa_prefill",
     "nsa_decode",
+    "multi_item_scoring_delimiter",
 ]
 
 # Put some global args for easy access
@@ -539,7 +540,7 @@ class Req:
 
         # Prefix info
         # The indices to kv cache for the shared prefix.
-        self.prefix_indices: torch.Tensor = []
+        self.prefix_indices: torch.Tensor = torch.empty((0,), dtype=torch.int64)
         # Number of tokens to run prefill.
         self.extend_input_len = 0
         # The relative logprob_start_len in an extend batch
@@ -666,9 +667,11 @@ class Req:
     def is_prefill_only(self) -> bool:
         """Check if this request is prefill-only (no token generation needed)."""
         # NOTE: when spec is enabled, prefill_only optimizations are disabled
-        return (
-            self.sampling_params.max_new_tokens == 0
-            and global_server_args_dict["speculative_algorithm"] is None
+        from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+
+        spec_alg = global_server_args_dict["speculative_algorithm"]
+        return self.sampling_params.max_new_tokens == 0 and (
+            spec_alg is None or spec_alg == SpeculativeAlgorithm.NONE
         )
 
     def add_latency(self, stage: RequestStage):
@@ -691,11 +694,16 @@ class Req:
         # Whether request reached finished condition
         return self.finished_reason is not None
 
-    def init_next_round_input(
-        self,
-        tree_cache: Optional[BasePrefixCache] = None,
-    ):
+    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
         self.fill_ids = self.origin_input_ids + self.output_ids
+        input_len = len(self.fill_ids)
+        # NOTE: the matched length is at most 1 less than the input length to enable logprob computation
+        max_prefix_len = input_len - 1
+        if self.return_logprob:
+            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
+        max_prefix_len = max(max_prefix_len, 0)
+        token_ids = self.fill_ids[:max_prefix_len]
+
         if tree_cache is not None:
             (
                 self.prefix_indices,
@@ -703,31 +711,11 @@
                 self.last_host_node,
                 self.host_hit_length,
             ) = tree_cache.match_prefix(
-                key=RadixKey(
-                    token_ids=self.adjust_max_prefix_ids(), extra_key=self.extra_key
-                ),
+                key=RadixKey(token_ids=token_ids, extra_key=self.extra_key)
             )
             self.last_matched_prefix_len = len(self.prefix_indices)
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
 
-    def adjust_max_prefix_ids(self):
-        self.fill_ids = self.origin_input_ids + self.output_ids
-        input_len = len(self.fill_ids)
-
-        # FIXME: To work around some bugs in logprob computation, we need to ensure each
-        # request has at least one token. Later, we can relax this requirement and use `input_len`.
-        max_prefix_len = input_len - 1
-
-        if self.sampling_params.max_new_tokens > 0:
-            # Need at least one token to compute logits
-            max_prefix_len = min(max_prefix_len, input_len - 1)
-
-        if self.return_logprob:
-            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
-
-        max_prefix_len = max(max_prefix_len, 0)
-        return self.fill_ids[:max_prefix_len]
-
     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
     def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None
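The prefix-length logic formerly in adjust_max_prefix_ids is now inlined into init_next_round_input: the radix-cache match is capped one token short of the full input so at least one token is left to compute logits, and capped further at logprob_start_len when logprobs are requested. A small worked sketch of that cap (illustrative, not part of the diff):

    def max_prefix_len(fill_ids, return_logprob=False, logprob_start_len=0):
        # Mirror of the inlined logic in init_next_round_input (simplified).
        n = len(fill_ids) - 1            # leave at least one token to compute logits
        if return_logprob:
            n = min(n, logprob_start_len)
        return max(n, 0)

    fill_ids = [11, 12, 13, 14]
    print(max_prefix_len(fill_ids))                              # 3 -> match at most [11, 12, 13]
    print(max_prefix_len(fill_ids, True, logprob_start_len=2))   # 2 -> keep logprobs from index 2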
@@ -808,7 +796,7 @@ class Req:
             return
 
     def reset_for_retract(self):
-        self.prefix_indices = []
+        self.prefix_indices = torch.empty((0,), dtype=torch.int64)
         self.last_node = None
         self.swa_uuid_for_lock = None
         self.extend_input_len = 0
@@ -886,15 +874,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # This is an optimization to reduce the overhead of the prefill check.
     batch_is_full: bool = False
 
-    # Events
-    launch_done: Optional[threading.Event] = None
-
     # For chunked prefill in PP
     chunked_req: Optional[Req] = None
 
     # Sampling info
     sampling_info: SamplingBatchInfo = None
-    next_batch_sampling_info: SamplingBatchInfo = None
 
     # Batched arguments to model runner
     input_ids: torch.Tensor = None  # shape: [b], int64
@@ -1128,6 +1112,47 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         else:
             return out_cache_loc
 
+    def write_cache_indices(
+        self,
+        req_pool_indices: List[int],
+        prefix_lens: List[int],
+        seq_lens: List[int],
+        extend_lens: List[int],
+        out_cache_loc: torch.Tensor,
+        req_pool_indices_tensor: torch.Tensor,
+        prefix_lens_tensor: torch.Tensor,
+        seq_lens_tensor: torch.Tensor,
+        extend_lens_tensor: torch.Tensor,
+        prefix_tensors: list[torch.Tensor],
+    ):
+        if support_triton(global_server_args_dict.get("attention_backend")):
+            prefix_pointers = torch.tensor(
+                [t.data_ptr() for t in prefix_tensors], device=self.device
+            )
+            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
+            write_req_to_token_pool_triton[(len(req_pool_indices),)](
+                self.req_to_token_pool.req_to_token,
+                req_pool_indices_tensor,
+                prefix_pointers,
+                prefix_lens_tensor,
+                seq_lens_tensor,
+                extend_lens_tensor,
+                out_cache_loc,
+                self.req_to_token_pool.req_to_token.shape[1],
+            )
+        else:
+            pt = 0
+            for i in range(len(req_pool_indices)):
+                self.req_to_token_pool.write(
+                    (req_pool_indices[i], slice(0, prefix_lens[i])),
+                    prefix_tensors[i],
+                )
+                self.req_to_token_pool.write(
+                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
+                    out_cache_loc[pt : pt + extend_lens[i]],
+                )
+                pt += extend_lens[i]
+
     def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
         self.encoder_lens_cpu = []
         self.encoder_cached = []
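The new write_cache_indices centralizes writing both the cached prefix locations and the newly allocated slots into req_to_token_pool; on the Triton path the variable-length prefix_indices tensors are passed as a flat table of raw device pointers (t.data_ptr()). A toy sketch of the non-Triton fallback path, using a plain 2-D tensor in place of the real pool (illustrative, not part of the diff):

    import torch

    # Toy req_to_token table: 4 request slots x 16 token slots.
    req_to_token = torch.zeros(4, 16, dtype=torch.int64)

    def write_cache_indices_fallback(req_pool_indices, prefix_lens, seq_lens,
                                     extend_lens, out_cache_loc, prefix_tensors):
        pt = 0
        for i, slot in enumerate(req_pool_indices):
            # Copy the cached prefix locations, then the newly allocated ones.
            req_to_token[slot, : prefix_lens[i]] = prefix_tensors[i]
            req_to_token[slot, prefix_lens[i] : seq_lens[i]] = out_cache_loc[pt : pt + extend_lens[i]]
            pt += extend_lens[i]

    write_cache_indices_fallback(
        req_pool_indices=[2],
        prefix_lens=[3],
        seq_lens=[5],
        extend_lens=[2],
        out_cache_loc=torch.tensor([40, 41]),
        prefix_tensors=[torch.tensor([10, 11, 12])],
    )
    print(req_to_token[2, :5])  # tensor([10, 11, 12, 40, 41])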
@@ -1205,10 +1230,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def prepare_for_extend(self):
         self.forward_mode = ForwardMode.EXTEND
 
-        # Allocate req slots
-        bs = len(self.reqs)
-        req_pool_indices = self.alloc_req_slots(bs, self.reqs)
-
         # Init tensors
         reqs = self.reqs
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
@@ -1222,9 +1243,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             r.token_type_ids for r in reqs if r.token_type_ids is not None
         ]
 
-        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
-            self.device, non_blocking=True
-        )
         input_ids_tensor = torch.tensor(
            list(chain.from_iterable(input_ids)), dtype=torch.int64
         ).to(self.device, non_blocking=True)
@@ -1248,7 +1266,49 @@
 
         extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor
 
-        # Copy prefix and do some basic check
+        # Allocate req slots
+        bs = len(self.reqs)
+        req_pool_indices = self.alloc_req_slots(bs, self.reqs)
+        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
+            self.device, non_blocking=True
+        )
+
+        # Allocate memory
+        if self.token_to_kv_pool_allocator.page_size == 1:
+            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
+        else:
+            last_loc = [
+                (
+                    r.prefix_indices[-1:]
+                    if len(r.prefix_indices) > 0
+                    else torch.tensor([-1], device=self.device)
+                )
+                for r in self.reqs
+            ]
+            out_cache_loc = self.alloc_paged_token_slots_extend(
+                prefix_lens_tensor,
+                prefix_lens_cpu_tensor,
+                seq_lens_tensor,
+                seq_lens_cpu,
+                torch.cat(last_loc),
+                extend_num_tokens,
+            )
+
+        # Write allocated tokens to req_to_token_pool
+        self.write_cache_indices(
+            req_pool_indices,
+            prefix_lens,
+            seq_lens,
+            extend_lens,
+            out_cache_loc,
+            req_pool_indices_tensor,
+            prefix_lens_tensor,
+            seq_lens_tensor,
+            extend_lens_tensor,
+            [r.prefix_indices for r in reqs],
+        )
+
+        # Set fields
         input_embeds = []
         extend_input_logprob_token_ids = []
         multimodal_inputs = []
@@ -1258,9 +1318,6 @@
             assert seq_len - pre_len == req.extend_input_len
 
             if pre_len > 0:
-                self.req_to_token_pool.write(
-                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
-                )
                 if isinstance(self.tree_cache, SWAChunkCache):
                     self.tree_cache.evict_swa(
                         req, pre_len, self.model_config.attention_chunk_size
@@ -1355,25 +1412,6 @@
         else:
             extend_input_logprob_token_ids = None
 
-        # Allocate memory
-        if self.token_to_kv_pool_allocator.page_size == 1:
-            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
-        else:
-            last_loc = get_last_loc(
-                self.req_to_token_pool.req_to_token,
-                req_pool_indices_tensor,
-                prefix_lens_tensor,
-            )
-            out_cache_loc = self.alloc_paged_token_slots_extend(
-                prefix_lens_tensor,
-                prefix_lens_cpu_tensor,
-                seq_lens_tensor,
-                seq_lens_cpu,
-                last_loc,
-                extend_num_tokens,
-            )
-
-        # Set fields
         self.input_ids = input_ids_tensor
         self.req_pool_indices = req_pool_indices_tensor
         self.seq_lens = seq_lens_tensor
@@ -1406,28 +1444,6 @@
         self.extend_lens = extend_lens
         self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
 
-        # Write to req_to_token_pool
-        if support_triton(global_server_args_dict.get("attention_backend")):
-            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
-
-            write_req_to_token_pool_triton[(bs,)](
-                self.req_to_token_pool.req_to_token,
-                req_pool_indices_tensor,
-                prefix_lens_tensor,
-                seq_lens_tensor,
-                extend_lens_tensor,
-                out_cache_loc,
-                self.req_to_token_pool.req_to_token.shape[1],
-            )
-        else:
-            pt = 0
-            for i in range(bs):
-                self.req_to_token_pool.write(
-                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
-                    out_cache_loc[pt : pt + extend_lens[i]],
-                )
-                pt += extend_lens[i]
-
         if self.model_config.is_encoder_decoder:
             self.prepare_encoder_info_extend(input_ids, seq_lens)
 
@@ -1877,7 +1893,6 @@
                 )
             ),
             extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
-            launch_done=self.launch_done,
             is_prefill_only=self.is_prefill_only,
         )
 
@@ -2018,8 +2033,8 @@ class ModelWorkerBatch:
     capture_hidden_mode: CaptureHiddenMode = None
     hicache_consumer_index: int = -1
 
-    # Overlap event
-    launch_done: Optional[threading.Event] = None
+    # Overlap scheduler related
+    delay_sample_launch: bool = False
 
     # Whether this batch is prefill-only (no token generation needed)
     is_prefill_only: bool = False
@@ -2029,6 +2044,7 @@
 def write_req_to_token_pool_triton(
     req_to_token_ptr,  # [max_batch, max_context_len]
     req_pool_indices,
+    prefix_tensors,
     pre_lens,
     seq_lens,
     extend_lens,
@@ -2041,6 +2057,19 @@ def write_req_to_token_pool_triton(
     req_pool_index = tl.load(req_pool_indices + pid)
     pre_len = tl.load(pre_lens + pid)
     seq_len = tl.load(seq_lens + pid)
+    prefix_tensor = tl.load(prefix_tensors + pid).to(tl.pointer_type(tl.int64))
+
+    # write prefix
+    num_loop = tl.cdiv(pre_len, BLOCK_SIZE)
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = offset < pre_len
+        value = tl.load(prefix_tensor + offset, mask=mask)
+        tl.store(
+            req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + offset,
+            value,
+            mask=mask,
+        )
 
     # NOTE: This can be slow for large bs
     cumsum_start = tl.cast(0, tl.int64)
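The kernel addition above block-copies each request's prefix from a per-request pointer loaded out of prefix_tensors and cast with tl.pointer_type. A self-contained toy kernel showing the same pointer-table plus masked block-copy pattern (illustrative only, not the kernel from the diff; the names are made up and it needs a CUDA device with Triton installed):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def copy_from_ptr_table(out_ptr, ptr_table, lens, out_stride, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(0)
        # Load a raw int64 pointer for this row and reinterpret it as an int64* pointer.
        src = tl.load(ptr_table + pid).to(tl.pointer_type(tl.int64))
        n = tl.load(lens + pid)
        # Masked block-copy loop, same shape as the prefix write in the diff.
        for i in range(tl.cdiv(n, BLOCK_SIZE)):
            offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
            mask = offset < n
            value = tl.load(src + offset, mask=mask)
            tl.store(out_ptr + pid * out_stride + offset, value, mask=mask)

    rows = [torch.arange(3, dtype=torch.int64, device="cuda"),
            torch.arange(5, dtype=torch.int64, device="cuda")]
    ptr_table = torch.tensor([t.data_ptr() for t in rows], device="cuda")
    lens = torch.tensor([len(t) for t in rows], device="cuda")
    out = torch.zeros(2, 8, dtype=torch.int64, device="cuda")
    copy_from_ptr_table[(2,)](out, ptr_table, lens, out.stride(0), BLOCK_SIZE=4)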