sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/metrics/collector.py
CHANGED
@@ -145,6 +145,7 @@ class SchedulerStats:
     num_prefill_infight_queue_reqs: int = 0
     num_decode_prealloc_queue_reqs: int = 0
     num_decode_transfer_queue_reqs: int = 0
+    total_retracted_reqs: int = 0


 class SchedulerMetricsCollector:
@@ -219,6 +220,13 @@ class SchedulerMetricsCollector:
             multiprocess_mode="mostrecent",
         )

+        self.total_retracted_reqs = Gauge(
+            name="sglang:total_retracted_reqs",
+            documentation="The total number of retracted requests due to kvcache full.",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+
         # Disaggregation queue metrics
         self.num_prefill_prealloc_queue_reqs = Gauge(
             name="sglang:num_prefill_prealloc_queue_reqs",
@@ -279,6 +287,7 @@
         self._log_gauge(self.num_grammar_queue_reqs, stats.num_grammar_queue_reqs)
         self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
         self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
+        self._log_gauge(self.total_retracted_reqs, stats.total_retracted_reqs)

         # Disaggregation metrics
         self._log_gauge(
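The new gauge follows the same prometheus_client pattern as the existing scheduler gauges. A minimal, self-contained sketch of that pattern (the model_name label and the use of the default registry are illustrative assumptions, not sglang's actual setup):

from prometheus_client import Gauge, generate_latest

# Mirrors the construction in the diff; ':' is legal in Prometheus metric names.
total_retracted_reqs = Gauge(
    name="sglang:total_retracted_reqs",
    documentation="The total number of retracted requests due to kvcache full.",
    labelnames=["model_name"],  # assumed label set for this sketch
)

total_retracted_reqs.labels(model_name="demo-model").set(3)
print(generate_latest().decode())  # exposition text includes the gauge sample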
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -29,9 +29,9 @@ from torch.profiler import ProfilerActivity, profile
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
+from sglang.srt.layers.dp_attention import DPPaddingMode, get_attention_tp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
@@ -167,8 +167,15 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     # is very small. We add more values here to make sure we capture the maximum bs.
     capture_bs += [model_runner.req_to_token_pool.size]

+    mul_base = 1
+
     if server_args.enable_two_batch_overlap:
-        capture_bs = [bs for bs in capture_bs if bs % 2 == 0]
+        mul_base *= 2
+
+    if require_gathered_buffer(server_args):
+        mul_base *= get_attention_tp_size()
+
+    capture_bs = [bs for bs in capture_bs if bs % mul_base == 0]

     if server_args.cuda_graph_max_bs:
         capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
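Net effect of the hunk above: a batch size survives the capture list only if it is divisible by every active constraint (2 under two-batch overlap, the attention TP size when a gathered buffer is required). A small sketch of that rule with stand-in flags rather than sglang's real ServerArgs:

def filter_capture_bs(capture_bs, enable_two_batch_overlap, require_gather, attn_tp_size):
    # compose one divisibility base from all active constraints
    mul_base = 1
    if enable_two_batch_overlap:
        mul_base *= 2  # each captured graph must split into two micro-batches
    if require_gather:
        mul_base *= attn_tp_size  # gathered buffers are sharded across attn TP ranks
    return [bs for bs in capture_bs if bs % mul_base == 0]

print(filter_capture_bs([1, 2, 4, 8, 12, 16], True, True, 4))  # -> [8, 16]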
@@ -264,7 +271,7 @@ class CudaGraphRunner:
         if self.enable_torch_compile:
             set_torch_compile_config()

-        if self.model_runner.server_args.lora_paths is not None:
+        if self.model_runner.server_args.enable_lora:
             self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)

         # Graph inputs
@@ -306,20 +313,37 @@
         self.encoder_lens = None

         if self.require_gathered_buffer:
-            self.gathered_buffer = torch.zeros(
-                (
-                    self.max_num_token,
-                    self.model_runner.model_config.hidden_size,
-                ),
-                dtype=self.model_runner.dtype,
-            )
             if self.require_mlp_tp_gather:
                 self.global_num_tokens_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
+                self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token * self.dp_size,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
             else:
                 assert self.require_attn_tp_gather
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
+                self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    (1,), dtype=torch.int32
+                )
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
+        else:
+            self.global_num_tokens_gpu = None
+            self.global_num_tokens_for_logprob_gpu = None
+            self.gathered_buffer = None

         self.custom_mask = torch.ones(
             (
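Rough sizing arithmetic implied by the two branches above, with illustrative numbers rather than sglang defaults: under MLP TP gather every DP rank pads to the same token count before an all-gather, so the capture-time buffer needs dp_size times as many rows as the attention-TP-only path.

max_num_token, dp_size, hidden_size = 160, 4, 4096

mlp_tp_gather_rows = max_num_token * dp_size   # all ranks' tokens are gathered
attn_tp_gather_rows = max_num_token            # only this rank's tokens

print("mlp_tp_gather buffer:", (mlp_tp_gather_rows, hidden_size))    # (640, 4096)
print("attn_tp_gather buffer:", (attn_tp_gather_rows, hidden_size))  # (160, 4096)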
@@ -342,9 +366,9 @@
     def can_run(self, forward_batch: ForwardBatch):
         if self.require_mlp_tp_gather:
             cuda_graph_bs = (
-                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
-                else sum(forward_batch.global_num_tokens_cpu)
+                else max(forward_batch.global_num_tokens_cpu)
             )
         else:
             cuda_graph_bs = forward_batch.batch_size
@@ -480,16 +504,19 @@
         if self.require_mlp_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
-                    [
-                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
-                        for i in range(self.dp_size)
-                    ],
+                    [num_tokens] * self.dp_size,
                     dtype=torch.int32,
                     device=input_ids.device,
                 )
             )
-            global_num_tokens = self.global_num_tokens_gpu
-            gathered_buffer = self.gathered_buffer[:num_tokens]
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [num_tokens] * self.dp_size,
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            )
+            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
         elif self.require_attn_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
@@ -498,10 +525,15 @@
                     device=input_ids.device,
                 )
             )
-            global_num_tokens = self.global_num_tokens_gpu
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            )
             gathered_buffer = self.gathered_buffer[:num_tokens]
         else:
-            global_num_tokens = None
             gathered_buffer = None

         spec_info = self.get_spec_info(num_tokens)
@@ -510,11 +542,10 @@
             spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
         )

-        if self.model_runner.server_args.lora_paths is not None:
-            #
-            #
-
-            lora_paths = [next(iter(self.model_runner.server_args.lora_paths))] * bs
+        if self.model_runner.server_args.enable_lora:
+            # It is safe to capture CUDA graph using empty LoRA path, as the LoRA kernels will always be launched whenever
+            # `--enable-lora` is set to True (and return immediately if the LoRA path is empty for perf optimization).
+            lora_paths = [None] * bs
         else:
             lora_paths = None

@@ -532,7 +563,9 @@
             encoder_lens=encoder_lens,
             return_logprob=False,
             positions=positions,
-            global_num_tokens_gpu=global_num_tokens,
+            global_num_tokens_gpu=self.global_num_tokens_gpu,
+            global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
+            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
             gathered_buffer=gathered_buffer,
             mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,
@@ -636,12 +669,13 @@

         # Pad
         if self.require_mlp_tp_gather:
-
-
+            max_num_tokens = max(forward_batch.global_num_tokens_cpu)
+            max_batch_size = (
+                max_num_tokens / self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
-                else
+                else max_num_tokens
             )
-            index = bisect.bisect_left(self.capture_bs,
+            index = bisect.bisect_left(self.capture_bs, max_batch_size)
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
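During replay the runner snaps the (padded) batch size up to the smallest captured graph that can hold it via bisect_left. A tiny standalone illustration:

import bisect

capture_bs = [1, 2, 4, 8, 16, 32]
for raw_bs in (3, 8, 9):
    bs = capture_bs[bisect.bisect_left(capture_bs, raw_bs)]
    print(raw_bs, "->", bs)  # 3 -> 4, 8 -> 8, 9 -> 16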
@@ -671,7 +705,8 @@
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
         if self.require_gathered_buffer:
-            self.global_num_tokens_gpu.
+            self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
+            self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
         if enable_num_token_non_padded(self.model_runner.server_args):
             self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
         if self.enable_two_batch_overlap:
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -38,6 +38,11 @@ import torch
 import triton
 import triton.language as tl

+from sglang.srt.layers.dp_attention import (
+    DPPaddingMode,
+    get_attention_dp_rank,
+    get_attention_tp_size,
+)
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.utils import (
     flatten_nested_list,
@@ -48,6 +53,7 @@ from sglang.srt.utils import (

 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+    from sglang.srt.layers.logits_processor import LogitsProcessorOutput
     from sglang.srt.managers.schedule_batch import ModelWorkerBatch, MultimodalInputs
     from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -78,6 +84,9 @@ class ForwardMode(IntEnum):
     # It is now used for triggering the sampling_info_done event for the first prefill batch.
     DUMMY_FIRST = auto()

+    # Split Prefill for PD multiplexing
+    SPLIT_PREFILL = auto()
+
     def is_prefill(self):
         return self.is_extend()

@@ -98,6 +107,9 @@
     def is_idle(self):
         return self == ForwardMode.IDLE

+    def is_decode_or_idle(self):
+        return self == ForwardMode.DECODE or self == ForwardMode.IDLE
+
     def is_target_verify(self):
         return self == ForwardMode.TARGET_VERIFY

@@ -121,8 +133,8 @@
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST

-    def
-        return self == ForwardMode.
+    def is_split_prefill(self):
+        return self == ForwardMode.SPLIT_PREFILL


 @total_ordering
@@ -194,6 +206,14 @@ class ForwardBatch:
     extend_logprob_start_lens_cpu: Optional[List[int]] = None
     extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None

+    # For split prefill
+    # intermediate values for split prefill
+    hidden_states: torch.Tensor = None
+    residual: torch.Tensor = None
+    model_specific_states: Dict[str, any] = None
+    # current split index of layer
+    split_index: int = 0
+
     # For MLA chunked prefix cache used in chunked prefill
     # Tell attention backend whether the kv cache needs to be attended in current pass
     attn_attend_prefix_cache: Optional[bool] = None
@@ -229,7 +249,7 @@
     lora_paths: Optional[List[str]] = None

     # For input embeddings
-    input_embeds: Optional[torch.tensor] = None
+    input_embeds: Optional[torch.Tensor] = None

     # For cross-encoder model
     token_type_ids: Optional[torch.Tensor] = None
@@ -248,6 +268,8 @@
     # Has to be None when cuda graph is captured.
     global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
     global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
+    # The padding mode for DP attention
+    dp_padding_mode: Optional[DPPaddingMode] = None
     # for extend, local start pos and num tokens is different in logits processor
     # this will be computed in get_dp_local_info
     # this will be recomputed in LogitsMetadata.from_forward_batch
@@ -273,7 +295,7 @@
     # For two-batch overlap
     tbo_split_seq_index: Optional[int] = None
     tbo_parent_token_range: Optional[Tuple[int, int]] = None
-    tbo_children: Optional[List[
+    tbo_children: Optional[List[ForwardBatch]] = None

     @classmethod
     def init_new(
@@ -327,20 +349,38 @@
                 len(batch.input_ids), dtype=torch.int32
             ).to(device, non_blocking=True)

-        # For
+        # For MLP sync
         if batch.global_num_tokens is not None:
-
-
-
-                if batch.spec_num_draft_tokens is not None
-                else 1
+            from sglang.srt.speculative.eagle_utils import (
+                EagleDraftInput,
+                EagleVerifyInput,
             )
-
-
-
-
-
-
+
+            assert batch.global_num_tokens_for_logprob is not None
+            # process global_num_tokens and global_num_tokens_for_logprob
+            if batch.spec_info is not None:
+                if isinstance(batch.spec_info, EagleDraftInput):
+                    global_num_tokens = [
+                        x * batch.spec_info.num_tokens_per_batch
+                        for x in batch.global_num_tokens
+                    ]
+                    global_num_tokens_for_logprob = [
+                        x * batch.spec_info.num_tokens_for_logprob_per_batch
+                        for x in batch.global_num_tokens_for_logprob
+                    ]
+                else:
+                    assert isinstance(batch.spec_info, EagleVerifyInput)
+                    global_num_tokens = [
+                        x * batch.spec_info.draft_token_num
+                        for x in batch.global_num_tokens
+                    ]
+                    global_num_tokens_for_logprob = [
+                        x * batch.spec_info.draft_token_num
+                        for x in batch.global_num_tokens_for_logprob
+                    ]
+            else:
+                global_num_tokens = batch.global_num_tokens
+                global_num_tokens_for_logprob = batch.global_num_tokens_for_logprob

             ret.global_num_tokens_cpu = global_num_tokens
             ret.global_num_tokens_gpu = torch.tensor(
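The multiplication above accounts for speculative decoding, where each scheduled request contributes several tokens per forward pass. A sketch with assumed numbers (not taken from a real run):

global_num_tokens = [3, 5, 4, 6]  # decode batch sizes per DP rank
draft_token_num = 4               # e.g. an EAGLE verify pass with 4 draft tokens

# during target-verify every request carries draft_token_num tokens
verify_tokens = [x * draft_token_num for x in global_num_tokens]
print(verify_tokens)  # [12, 20, 16, 24]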
@@ -352,15 +392,8 @@
                 global_num_tokens_for_logprob, dtype=torch.int64
             ).to(device, non_blocking=True)

-            sum_len = sum(global_num_tokens)
-            ret.gathered_buffer = torch.zeros(
-                (sum_len, model_runner.model_config.hidden_size),
-                dtype=model_runner.dtype,
-                device=device,
-            )
-
         if ret.forward_mode.is_idle():
-            ret.positions = torch.empty((0,), device=device)
+            ret.positions = torch.empty((0,), dtype=torch.int64, device=device)
             TboForwardBatchPreparer.prepare(
                 ret, is_draft_worker=model_runner.is_draft_worker
             )
@@ -405,7 +438,7 @@
         ret._compute_mrope_positions(model_runner, batch)

         # Init lora information
-        if model_runner.server_args.lora_paths is not None:
+        if model_runner.server_args.enable_lora:
             model_runner.lora_manager.prepare_lora_batch(ret)

         TboForwardBatchPreparer.prepare(
@@ -560,6 +593,158 @@
             )
             self.prefix_chunk_kv_indices.append(chunk_kv_indices)

+    def _pad_tensor_to_size(self, tensor: torch.Tensor, size: int, *, value: int = 0):
+        if value == 0:
+            return torch.cat(
+                [tensor, tensor.new_zeros(size - tensor.shape[0], *tensor.shape[1:])],
+                dim=0,
+            )
+        else:
+            return torch.cat(
+                [
+                    tensor,
+                    tensor.new_full((size - tensor.shape[0], *tensor.shape[1:]), value),
+                ],
+                dim=0,
+            )
+
+    def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
+
+        from sglang.srt.speculative.eagle_utils import EagleDraftInput
+
+        assert self.global_num_tokens_cpu is not None
+        assert self.global_num_tokens_for_logprob_cpu is not None
+
+        global_num_tokens = self.global_num_tokens_cpu
+        sync_group_size = len(global_num_tokens)
+        attn_tp_size = get_attention_tp_size()
+
+        for i in range(sync_group_size):
+            # make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
+            # there is no reduce-scatter in LM logprob, so we do not need to adjust the padded length for logprob
+            global_num_tokens[i] = (
+                (global_num_tokens[i] - 1) // attn_tp_size + 1
+            ) * attn_tp_size
+
+        dp_padding_mode = DPPaddingMode.get_dp_padding_mode(global_num_tokens)
+        self.dp_padding_mode = dp_padding_mode
+
+        if dp_padding_mode.is_max_len():
+            # when DP gather mode is all gather, we will use all_gather_into_tensor to gather hidden states,
+            # where transferred tokens should be padded to the same length.
+            max_num_tokens = max(global_num_tokens)
+            global_num_tokens = [max_num_tokens] * sync_group_size
+            buffer_len = max_num_tokens * sync_group_size
+        else:
+            buffer_len = sum(global_num_tokens)
+
+        self.gathered_buffer = torch.zeros(
+            (buffer_len, model_runner.model_config.hidden_size),
+            dtype=model_runner.dtype,
+            device=model_runner.device,
+        )
+
+        bs = self.batch_size
+        if len(global_num_tokens) > 1:
+            num_tokens = global_num_tokens[get_attention_dp_rank()]
+        else:
+            num_tokens = global_num_tokens[0]
+
+        # padding
+        self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
+        self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
+
+        seq_len_fill_value = (
+            model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
+        )
+        self.seq_lens = self._pad_tensor_to_size(
+            self.seq_lens, bs, value=seq_len_fill_value
+        )
+        if self.seq_lens_cpu is not None:
+            self.seq_lens_cpu = self._pad_tensor_to_size(
+                self.seq_lens_cpu, bs, value=seq_len_fill_value
+            )
+
+        self.out_cache_loc = self._pad_tensor_to_size(self.out_cache_loc, num_tokens)
+        if self.encoder_lens is not None:
+            self.encoder_lens = self._pad_tensor_to_size(self.encoder_lens, bs)
+        self.positions = self._pad_tensor_to_size(self.positions, num_tokens)
+        self.global_num_tokens_cpu = global_num_tokens
+        self.global_num_tokens_gpu = self.global_num_tokens_gpu.new_tensor(
+            global_num_tokens
+        )
+
+        if self.mrope_positions is not None:
+            self.mrope_positions = self._pad_tensor_to_size(self.mrope_positions, bs)
+
+        if self.extend_seq_lens is not None:
+            self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)
+
+        if self.spec_info is not None and isinstance(self.spec_info, EagleDraftInput):
+            spec_info = self.spec_info
+            self.output_cache_loc_backup = self.out_cache_loc
+            self.hidden_states_backup = spec_info.hidden_states
+            if spec_info.topk_p is not None:
+                spec_info.topk_p = self._pad_tensor_to_size(spec_info.topk_p, bs)
+            if spec_info.topk_index is not None:
+                spec_info.topk_index = self._pad_tensor_to_size(
+                    spec_info.topk_index, bs
+                )
+            if spec_info.accept_length is not None:
+                spec_info.accept_length = self._pad_tensor_to_size(
+                    spec_info.accept_length, bs
+                )
+            spec_info.hidden_states = self._pad_tensor_to_size(
+                spec_info.hidden_states, num_tokens
+            )
+
+    def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
+
+        bs = self.batch_size
+
+        if self.spec_info is not None:
+            if self.forward_mode.is_decode():  # draft
+                num_tokens = self.hidden_states_backup.shape[0]
+                self.positions = self.positions[:num_tokens]
+                self.seq_lens = self.seq_lens[:bs]
+                self.req_pool_indices = self.req_pool_indices[:bs]
+                if self.seq_lens_cpu is not None:
+                    self.seq_lens_cpu = self.seq_lens_cpu[:bs]
+                logits_output.next_token_logits = logits_output.next_token_logits[
+                    :num_tokens
+                ]
+                logits_output.hidden_states = logits_output.hidden_states[:num_tokens]
+            elif self.forward_mode.is_target_verify():  # verify
+                num_tokens = bs * self.spec_info.draft_token_num
+                logits_output.next_token_logits = logits_output.next_token_logits[
+                    :num_tokens
+                ]
+                logits_output.hidden_states = logits_output.hidden_states[:num_tokens]
+            elif self.forward_mode.is_draft_extend():  # draft extend
+                self.spec_info.accept_length = self.spec_info.accept_length[:bs]
+                logits_output.next_token_logits = logits_output.next_token_logits[:bs]
+                logits_output.hidden_states = logits_output.hidden_states[:bs]
+            elif self.forward_mode.is_extend() or self.forward_mode.is_idle():
+                logits_output.next_token_logits = logits_output.next_token_logits[:bs]
+                logits_output.hidden_states = logits_output.hidden_states[:bs]
+
+            if hasattr(self, "hidden_states_backup"):
+                self.spec_info.hidden_states = self.hidden_states_backup
+            if hasattr(self, "output_cache_loc_backup"):
+                self.out_cache_loc = self.output_cache_loc_backup
+
+        elif self.forward_mode.is_decode() or self.forward_mode.is_idle():
+            logits_output.next_token_logits = logits_output.next_token_logits[:bs]
+            if logits_output.hidden_states is not None:
+                logits_output.hidden_states = logits_output.hidden_states[:bs]
+        elif self.forward_mode.is_extend():
+            num_tokens = self.seq_lens_sum
+            logits_output.next_token_logits = logits_output.next_token_logits[
+                :num_tokens
+            ]
+            if logits_output.hidden_states is not None:
+                logits_output.hidden_states = logits_output.hidden_states[:num_tokens]
+
     # Here we suppose the length of each chunk is equal
     # For example, if we have 4 sequences with prefix length [256, 512, 768, 1024], prefix_chunk_len = 256
     # num_prefix_chunks = cdiv(1024, 256) = 4