sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +5 -4
- sglang/bench_one_batch_server.py +23 -15
- sglang/bench_serving.py +133 -57
- sglang/compile_deep_gemm.py +4 -4
- sglang/srt/configs/model_config.py +39 -28
- sglang/srt/conversation.py +1 -1
- sglang/srt/disaggregation/decode.py +122 -133
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +11 -2
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +9 -19
- sglang/srt/disaggregation/prefill.py +126 -44
- sglang/srt/disaggregation/utils.py +116 -5
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +28 -8
- sglang/srt/entrypoints/http_server.py +6 -4
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +63 -17
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/utils.py +2 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +0 -10
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
- sglang/srt/layers/moe/ep_moe/layer.py +104 -50
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +66 -9
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +7 -2
- sglang/srt/layers/quantization/deep_gemm.py +5 -3
- sglang/srt/layers/quantization/fp8.py +90 -0
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +18 -5
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +16 -3
- sglang/srt/managers/mm_utils.py +293 -139
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +3 -3
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +49 -21
- sglang/srt/managers/schedule_policy.py +4 -5
- sglang/srt/managers/scheduler.py +92 -50
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +99 -24
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +74 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +2 -2
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +20 -9
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +4 -0
- sglang/srt/model_executor/model_runner.py +144 -54
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_v2.py +297 -343
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama4.py +10 -2
- sglang/srt/models/llava.py +26 -18
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/openai_api/adapter.py +28 -16
- sglang/srt/openai_api/protocol.py +6 -0
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/server_args.py +134 -24
- sglang/srt/speculative/eagle_utils.py +131 -0
- sglang/srt/speculative/eagle_worker.py +47 -2
- sglang/srt/utils.py +68 -12
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_utils.py +2 -36
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/configs/model_config.py
CHANGED
@@ -22,7 +22,11 @@ from typing import List, Optional, Set, Union
 import torch
 from transformers import PretrainedConfig
 
-from sglang.srt.hf_transformers_utils import
+from sglang.srt.hf_transformers_utils import (
+    get_config,
+    get_context_length,
+    get_hf_text_config,
+)
 from sglang.srt.layers.quantization import QUANTIZATION_METHODS
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import get_bool_env_var, is_hip
@@ -69,6 +73,7 @@ class ModelConfig:
             model_override_args=self.model_override_args,
             **kwargs,
         )
+
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.attention_chunk_size = getattr(
             self.hf_text_config, "attention_chunk_size", None
@@ -93,6 +98,8 @@ class ModelConfig:
         ):
             self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"
 
+        if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
+            self.hf_config.architectures[0] = "MiMoMTP"
         # Check model type
         self.is_generation = is_generation_model(
             self.hf_config.architectures, is_embedding
@@ -109,6 +116,10 @@ class ModelConfig:
         self.is_audio_model = enable_multimodal and is_audio_model(
             self.hf_config.architectures
         )
+        self.is_multimodal_chunked_prefill_supported = (
+            enable_multimodal
+            and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
+        )
         self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
 
@@ -209,7 +220,13 @@ class ModelConfig:
 
         # Cache attributes
         self.hf_eos_token_id = self.get_hf_eos_token_id()
-
+
+        config = self.hf_config
+
+        # multimodal
+        self.image_token_id = getattr(config, "image_token_id", None) or getattr(
+            config, "image_token_index", None
+        )
 
     @staticmethod
     def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):
@@ -332,6 +349,7 @@ class ModelConfig:
             "w8a8_int8",
             "w8a8_fp8",
             "moe_wna16",
+            "qoq",
         ]
         compatible_quantization_methods = {
             "modelopt_fp4": ["modelopt"],
@@ -423,31 +441,6 @@ class ModelConfig:
             self.model_path = client.get_local_dir()
 
 
-def get_hf_text_config(config: PretrainedConfig):
-    """Get the "sub" config relevant to llm for multi modal models.
-    No op for pure text models.
-    """
-    class_name = config.architectures[0]
-    if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
-        # We support non-hf version of llava models, so we do not want to
-        # read the wrong values from the unused default text_config.
-        # NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as
-        # `torch.float16` is default used for image features in `python/sglang/srt/models/llava.py`.
-        setattr(config, "torch_dtype", torch.float16)
-        return config
-
-    if hasattr(config, "text_config"):
-        # The code operates under the assumption that text_config should have
-        # `num_attention_heads` (among others). Assert here to fail early
-        # if transformers config doesn't align with this assumption.
-        assert hasattr(config.text_config, "num_attention_heads")
-        return config.text_config
-    if hasattr(config, "language_config"):
-        return config.language_config
-    else:
-        return config
-
-
 # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
 _STR_DTYPE_TO_TORCH_DTYPE = {
     "half": torch.float16,
@@ -466,6 +459,8 @@ def _get_and_verify_dtype(
     # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
     # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
+    if isinstance(config_dtype, str):
+        config_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(config_dtype, None)
     if config_dtype is None:
         config_dtype = torch.float32
 
@@ -537,6 +532,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
 
 
 multimodal_model_archs = [
+    "CLIPModel",
     "DeepseekVL2ForCausalLM",
     "Gemma3ForConditionalGeneration",
     "Grok1VForCausalLM",
@@ -549,11 +545,11 @@ multimodal_model_archs = [
     "LlavaVidForCausalLM",
     "MiniCPMO",
     "MiniCPMV",
+    "Mistral3ForConditionalGeneration",
     "MultiModalityCausalLM",
     "MllamaForConditionalGeneration",
     "Qwen2VLForConditionalGeneration",
     "Qwen2_5_VLForConditionalGeneration",
-    "CLIPModel",
     "KimiVLForConditionalGeneration",
     "InternVLChatModel",
 ]
@@ -585,6 +581,21 @@ def is_encoder_decoder_model(model_architectures: List[str]):
     return "MllamaForConditionalGeneration" in model_architectures
 
 
+def is_multimodal_chunked_prefill_supported(model_architectures: List[str]):
+    """Check if chunked prefill is supported for a MultiModal model."""
+    unsupported = [
+        "Grok1VForCausalLM",
+        "Grok1AForCausalLM",
+        "LlavaLlamaForCausalLM",
+        "MllamaForConditionalGeneration",
+        "CLIPModel",
+    ]
+    if any(multi_model_arch in unsupported for multi_model_arch in model_architectures):
+        return False
+    else:
+        return True
+
+
 def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     if scale <= 1:
         return 1.0
sglang/srt/conversation.py
CHANGED
@@ -781,7 +781,7 @@ register_conv_template(
     Conversation(
         name="gemma-it",
         system_message="You are a helpful assistant.",
-        system_template="<start_of_turn>user{system_message}\n\n",
+        system_template="<start_of_turn>user\n{system_message}\n\n",
         roles=("<start_of_turn>user\n", "<start_of_turn>model\n"),
         sep="<end_of_turn>\n",
         sep_style=SeparatorStyle.GEMMA3,
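The one-character gemma-it fix above inserts a newline between the `<start_of_turn>user` tag and the system message, matching the `\n` already present in the user role template. A small illustration using plain Python string formatting (not sglang's Conversation class):

    old_template = "<start_of_turn>user{system_message}\n\n"
    new_template = "<start_of_turn>user\n{system_message}\n\n"

    msg = "You are a helpful assistant."
    print(repr(old_template.format(system_message=msg)))
    # '<start_of_turn>userYou are a helpful assistant.\n\n'
    print(repr(new_template.format(system_message=msg)))
    # '<start_of_turn>user\nYou are a helpful assistant.\n\n'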
sglang/srt/disaggregation/decode.py
CHANGED
@@ -24,6 +24,7 @@ import logging
 import os
 from collections import deque
 from dataclasses import dataclass
+from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import numpy as np
@@ -35,25 +36,25 @@ from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
     FakeBootstrapHost,
     KVClassType,
+    MetadataBuffers,
     ReqToMetadataIdxAllocator,
     TransferBackend,
     get_kv_class,
     is_mla_backend,
     kv_to_page_indices,
     poll_and_all_reduce,
+    prepare_abort,
 )
+from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
-    from sglang.srt.
-    from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+    from sglang.srt.managers.schedule_batch import Req
     from sglang.srt.managers.scheduler import Scheduler
-    from sglang.srt.server_args import ServerArgs
 
 
 @dataclass
@@ -73,9 +74,9 @@ class DecodePreallocQueue:
         self,
         req_to_token_pool: ReqToTokenPool,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        draft_token_to_kv_pool: Optional[KVCache],
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
-        metadata_buffers:
-        aux_dtype: torch.dtype,
+        metadata_buffers: MetadataBuffers,
         scheduler: Scheduler,
         transfer_queue: DecodeTransferQueue,
         tree_cache: BasePrefixCache,
@@ -88,8 +89,8 @@ class DecodePreallocQueue:
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache()
+        self.draft_token_to_kv_pool = draft_token_to_kv_pool
         self.is_mla_backend = is_mla_backend(self.token_to_kv_pool)
-        self.aux_dtype = aux_dtype
         self.metadata_buffers = metadata_buffers
         self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
         self.scheduler = scheduler
@@ -116,19 +117,21 @@ class DecodePreallocQueue:
             self.token_to_kv_pool.get_contiguous_buf_infos()
         )
 
+        if self.draft_token_to_kv_pool is not None:
+            draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
+                self.draft_token_to_kv_pool.get_contiguous_buf_infos()
+            )
+            kv_data_ptrs += draft_kv_data_ptrs
+            kv_data_lens += draft_kv_data_lens
+            kv_item_lens += draft_kv_item_lens
+
         kv_args.kv_data_ptrs = kv_data_ptrs
         kv_args.kv_data_lens = kv_data_lens
         kv_args.kv_item_lens = kv_item_lens
 
-        kv_args.aux_data_ptrs =
-
-
-        kv_args.aux_data_lens = [
-            metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
-        ]
-        kv_args.aux_item_lens = [
-            metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
-        ]
+        kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
+            self.metadata_buffers.get_buf_infos()
+        )
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
@@ -178,7 +181,17 @@ class DecodePreallocQueue:
             elif poll == KVPoll.WaitingForInput:
                 decode_req.waiting_for_input = True
             elif poll == KVPoll.Failed:
-
+                error_message = f"Decode handshake failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
+                try:
+                    decode_req.kv_receiver.failure_exception()
+                except Exception as e:
+                    error_message += f" with exception {e}"
+                logger.error(error_message)
+                prepare_abort(
+                    decode_req.req,
+                    error_message,
+                    status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
+                )
 
     def pop_preallocated(self) -> List[DecodeRequest]:
         """Pop the preallocated requests from the pending queue (FIFO)."""
@@ -188,7 +201,18 @@ class DecodePreallocQueue:
         indices_to_remove = set()
         allocatable_tokens = self._allocatable_tokens()
 
+        # First, remove all failed requests from the queue
         for i, decode_req in enumerate(self.queue):
+            if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
+                self.scheduler.stream_output(
+                    [decode_req.req], decode_req.req.return_logprob
+                )
+                indices_to_remove.add(i)
+
+        for i, decode_req in enumerate(self.queue):
+            if i in indices_to_remove:
+                continue
+
             if not decode_req.waiting_for_input:
                 continue
 
@@ -308,18 +332,22 @@ class DecodeTransferQueue:
         self,
         gloo_group: ProcessGroup,
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
-        metadata_buffers:
+        metadata_buffers: MetadataBuffers,
+        scheduler: Scheduler,
+        tree_cache: BasePrefixCache,
     ):
         self.queue: List[DecodeRequest] = []
         self.gloo_group = gloo_group
         self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
         self.metadata_buffers = metadata_buffers
+        self.scheduler = scheduler
+        self.tree_cache = tree_cache
 
-    def add(self,
-        self.queue.append(
+    def add(self, decode_req: DecodeRequest) -> None:
+        self.queue.append(decode_req)
 
-    def extend(self,
-        self.queue.extend(
+    def extend(self, decode_reqs: List[DecodeRequest]) -> None:
+        self.queue.extend(decode_reqs)
 
     def pop_transferred(self) -> List[DecodeRequest]:
         if not self.queue:
@@ -333,18 +361,56 @@ class DecodeTransferQueue:
         indices_to_remove = set()
         for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
             if poll == KVPoll.Failed:
-
+                error_message = f"Decode transfer failed for request {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
+                try:
+                    decode_req.kv_receiver.failure_exception()
+                except Exception as e:
+                    error_message += f" with exception {e}"
+                logger.error(error_message)
+                prepare_abort(
+                    decode_req.req,
+                    error_message,
+                    status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
+                )
+                self.scheduler.stream_output(
+                    [decode_req.req], decode_req.req.return_logprob
+                )
+                # unlock the kv cache or it will have memory leak
+                self.tree_cache.cache_finished_req(decode_req.req)
+                indices_to_remove.add(i)
+                continue
             elif poll == KVPoll.Success:
-
+
                 idx = decode_req.metadata_buffer_index
-
-
-
-
-
-
-
-
+                (
+                    output_id,
+                    output_token_logprobs_val,
+                    output_token_logprobs_idx,
+                    output_top_logprobs_val,
+                    output_top_logprobs_idx,
+                ) = self.metadata_buffers.get_buf(idx)
+
+                decode_req.req.output_ids.append(output_id[0].item())
+
+                if decode_req.req.return_logprob:
+                    decode_req.req.output_token_logprobs_val.append(
+                        output_token_logprobs_val[0].item()
+                    )
+                    decode_req.req.output_token_logprobs_idx.append(
+                        output_token_logprobs_idx[0].item()
+                    )
+                    decode_req.req.output_top_logprobs_val.append(
+                        output_top_logprobs_val[
+                            : decode_req.req.top_logprobs_num
+                        ].tolist()
+                    )
+                    decode_req.req.output_top_logprobs_idx.append(
+                        output_top_logprobs_idx[
+                            : decode_req.req.top_logprobs_num
+                        ].tolist()
+                    )
+
+                transferred_reqs.append(decode_req.req)
                 indices_to_remove.add(i)
             elif poll in [
                 KVPoll.Bootstrapping,
@@ -367,95 +433,6 @@ class DecodeTransferQueue:
         return transferred_reqs
 
 
-class ScheduleBatchDisaggregationDecodeMixin:
-
-    def prepare_for_prebuilt_extend(self: ScheduleBatch):
-        """
-        Prepare a prebuilt extend by populate metadata
-        Adapted from .prepare_for_extend().
-        """
-
-        self.forward_mode = ForwardMode.EXTEND
-        reqs = self.reqs
-        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
-        extend_num_tokens = sum(len(ids) for ids in input_ids)
-        seq_lens = []
-        pre_lens = []
-        req_pool_indices = []
-
-        # Pre-calculate total size
-        total_size = sum(req.extend_input_len for req in reqs)
-        out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device)
-
-        # Fill the tensor in one pass
-        offset = 0
-        for i, req in enumerate(reqs):
-            req_pool_indices.append(req.req_pool_idx)
-
-            chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][
-                : req.extend_input_len
-            ]
-            assert (
-                offset + req.extend_input_len <= total_size
-            ), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}"
-            out_cache_loc[offset : offset + req.extend_input_len] = chunk
-            offset += req.extend_input_len
-
-            pre_len = len(req.prefix_indices)
-            seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1)
-            seq_lens.append(seq_len)
-            if len(req.output_ids) == 0:
-                assert (
-                    seq_len - pre_len == req.extend_input_len
-                ), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}"
-
-            req.cached_tokens += pre_len - req.already_computed
-            req.already_computed = seq_len
-            req.is_retracted = False
-            pre_lens.append(pre_len)
-            req.extend_logprob_start_len = 0
-
-        extend_input_logprob_token_ids = None
-
-        # Set fields
-        self.input_ids = torch.tensor(
-            sum(input_ids, []), dtype=torch.int32, device=self.device
-        )
-        self.req_pool_indices = torch.tensor(
-            req_pool_indices, dtype=torch.int64, device=self.device
-        )
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
-        self.out_cache_loc = out_cache_loc
-        self.seq_lens_sum = sum(seq_lens)
-        self.extend_num_tokens = extend_num_tokens
-        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
-        self.extend_lens = [r.extend_input_len for r in reqs]
-        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
-        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
-
-        # Build sampling info
-        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
-            self,
-            self.model_config.vocab_size,
-        )
-
-    def process_prebuilt_extend(
-        self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
-    ):
-        """Assign the buffered last input id to schedule batch"""
-        self.output_ids = []
-        for req in self.reqs:
-            if req.output_ids and len(req.output_ids) > 0:
-                # resumed retracted req
-                self.output_ids.append(req.output_ids[-1])
-            else:
-                assert req.transferred_output_id is not None
-                req.output_ids.append(req.transferred_output_id)
-                self.output_ids.append(req.transferred_output_id)
-            self.tree_cache.cache_unfinished_req(req)
-        self.output_ids = torch.tensor(self.output_ids, device=self.device)
-
-
 class SchedulerDisaggregationDecodeMixin:
 
     def _prepare_idle_batch_and_run(self, batch, delay_process=False):
@@ -488,7 +465,9 @@ class SchedulerDisaggregationDecodeMixin:
             # Generate fake extend output.
             if batch.forward_mode.is_extend():
                 # Note: Logprobs should be handled on the prefill engine.
-                self.stream_output(
+                self.stream_output(
+                    batch.reqs, any(req.return_logprob for req in batch.reqs)
+                )
                 if prepare_dp_attn_flag:
                     self._prepare_idle_batch_and_run(None)
             else:
@@ -534,7 +513,9 @@ class SchedulerDisaggregationDecodeMixin:
             # Generate fake extend output.
             if batch.forward_mode.is_extend():
                 # Note: Logprobs should be handled on the prefill engine.
-                self.stream_output(
+                self.stream_output(
+                    batch.reqs, any(req.return_logprob for req in batch.reqs)
+                )
                 if prepare_dp_attn_flag:
                     batch_, result = self._prepare_idle_batch_and_run(
                         None, delay_process=True
@@ -547,7 +528,18 @@ class SchedulerDisaggregationDecodeMixin:
                     self.prepare_dp_attn_batch(batch)
                 result = self.run_batch(batch)
                 result_queue.append((batch.copy(), result))
+
+                if (self.last_batch is None) or (not self.last_batch_in_queue):
+                    # Create a dummy first batch to start the pipeline for overlap schedule.
+                    # It is now used for triggering the sampling_info_done event.
+                    tmp_batch = ScheduleBatch(
+                        reqs=None,
+                        forward_mode=ForwardMode.DUMMY_FIRST,
+                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
+                    )
+                    self.set_next_batch_sampling_info_done(tmp_batch)
                 last_batch_in_queue = True
+
             elif prepare_dp_attn_flag:
                 batch, result = self._prepare_idle_batch_and_run(
                     None, delay_process=True
@@ -559,6 +551,9 @@ class SchedulerDisaggregationDecodeMixin:
             # Process the results of the previous batch but skip if the last batch is extend
             if self.last_batch and self.last_batch_in_queue:
                 tmp_batch, tmp_result = result_queue.popleft()
+                tmp_batch.next_batch_sampling_info = (
+                    self.tp_worker.cur_sampling_info if batch else None
+                )
                 self.process_batch_result(tmp_batch, tmp_result)
 
             if batch is None and (
@@ -607,6 +602,9 @@ class SchedulerDisaggregationDecodeMixin:
 
     def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
         """Create a schedulebatch for fake completed prefill"""
+        if self.grammar_queue:
+            self.move_ready_grammar_requests()
+
         if len(self.waiting_queue) == 0:
             return None
 
@@ -632,8 +630,6 @@ class SchedulerDisaggregationDecodeMixin:
         self.waiting_queue = waiting_queue
         if len(can_run_list) == 0:
             return None
-        # local import to avoid circular import
-        from sglang.srt.managers.schedule_batch import ScheduleBatch
 
         # construct a schedule batch with those requests and mark as decode
         new_batch = ScheduleBatch.init_new(
@@ -655,15 +651,8 @@ class SchedulerDisaggregationDecodeMixin:
 
     def process_decode_queue(self: Scheduler):
         req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
-
-        def _num_pre_alloc(req):
-            return len(req.req.origin_input_ids) + max(len(req.req.output_ids) - 1, 0)
-
-        self.num_tokens_pre_allocated += sum(_num_pre_alloc(req) for req in req_conns)
         self.disagg_decode_transfer_queue.extend(req_conns)
         alloc_reqs = (
            self.disagg_decode_transfer_queue.pop_transferred()
         ) # the requests which kv has arrived
-        self.
-
-        self.waiting_queue.extend([req.req for req in alloc_reqs])
+        self.waiting_queue.extend(alloc_reqs)