sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +14 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +301 -64
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +94 -46
- sglang/srt/disaggregation/prefill.py +20 -15
- sglang/srt/disaggregation/utils.py +47 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +27 -31
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +897 -0
- sglang/srt/entrypoints/openai/serving_completions.py +425 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +28 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +43 -23
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +44 -2
- sglang/srt/layers/linear.py +18 -1
- sglang/srt/layers/logits_processor.py +14 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +286 -13
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
- sglang/srt/layers/moe/topk.py +117 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_utils.py +5 -4
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +144 -12
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/layers/vocab_parallel_embedding.py +14 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/expert_distribution.py +21 -0
- sglang/srt/managers/io_struct.py +19 -14
- sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
- sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
- sglang/srt/managers/schedule_batch.py +49 -32
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +189 -68
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +77 -46
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +27 -8
- sglang/srt/model_loader/loader.py +50 -8
- sglang/srt/model_loader/weight_utils.py +100 -2
- sglang/srt/models/deepseek_nextn.py +35 -30
- sglang/srt/models/deepseek_v2.py +255 -30
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1009 -0
- sglang/srt/models/gemma3n_mm.py +511 -0
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +51 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -1
- sglang/srt/utils.py +248 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/mimo_mtp.py
CHANGED
@@ -7,33 +7,17 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig
 
-from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-    split_tensor_along_last_dim,
-    tensor_model_parallel_all_gather,
-)
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.models.
-from sglang.srt.models.qwen2 import (
-    Qwen2Attention,
-    Qwen2DecoderLayer,
-    Qwen2MLP,
-    Qwen2Model,
-)
-from sglang.srt.utils import add_prefix
+from sglang.srt.models.qwen2 import Qwen2DecoderLayer
 
 
 class MiMoMultiTokenPredictorLayer(nn.Module):
sglang/srt/reasoning_parser.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Dict, Tuple
+from typing import Dict, Optional, Tuple, Type
 
 
 class StreamingParseResult:
@@ -32,17 +32,26 @@ class BaseReasoningFormatDetector:
         One-time parsing: Detects and parses reasoning sections in the provided text.
         Returns both reasoning content and normal text separately.
         """
-
-
+        in_reasoning = self._in_reasoning or text.startswith(self.think_start_token)
+
+        if not in_reasoning:
+            return StreamingParseResult(normal_text=text)
+
+        # The text is considered to be in a reasoning block.
+        processed_text = text.replace(self.think_start_token, "").strip()
+
+        if self.think_end_token not in processed_text:
             # Assume reasoning was truncated before `</think>` token
-            return StreamingParseResult(reasoning_text=
+            return StreamingParseResult(reasoning_text=processed_text)
 
         # Extract reasoning content
-        splits =
+        splits = processed_text.split(self.think_end_token, maxsplit=1)
         reasoning_text = splits[0]
-
+        normal_text = splits[1].strip()
 
-        return StreamingParseResult(
+        return StreamingParseResult(
+            normal_text=normal_text, reasoning_text=reasoning_text
+        )
 
     def parse_streaming_increment(self, new_text: str) -> StreamingParseResult:
         """
@@ -61,6 +70,7 @@ class BaseReasoningFormatDetector:
         if not self.stripped_think_start and self.think_start_token in current_text:
             current_text = current_text.replace(self.think_start_token, "")
             self.stripped_think_start = True
+            self._in_reasoning = True
 
         # Handle end of reasoning block
         if self._in_reasoning and self.think_end_token in current_text:
@@ -131,11 +141,11 @@ class Qwen3Detector(BaseReasoningFormatDetector):
     """
 
     def __init__(self, stream_reasoning: bool = True):
-        # Qwen3
+        # Qwen3 won't be in reasoning mode when user passes `enable_thinking=False`
         super().__init__(
            "<think>",
            "</think>",
-            force_reasoning=
+            force_reasoning=False,
            stream_reasoning=stream_reasoning,
        )
 
@@ -151,12 +161,12 @@ class ReasoningParser:
         If True, streams reasoning content as it arrives.
     """
 
-    DetectorMap: Dict[str, BaseReasoningFormatDetector] = {
+    DetectorMap: Dict[str, Type[BaseReasoningFormatDetector]] = {
         "deepseek-r1": DeepSeekR1Detector,
         "qwen3": Qwen3Detector,
     }
 
-    def __init__(self, model_type: str = None, stream_reasoning: bool = True):
+    def __init__(self, model_type: Optional[str] = None, stream_reasoning: bool = True):
         if not model_type:
             raise ValueError("Model type must be specified")
 
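The reworked one-shot parser above treats the input as reasoning only when the detector is already inside a think block or the text starts with the think-start token; otherwise the whole input passes through as normal text. A standalone sketch of that control flow (it mirrors the hunk above and is not the sglang API itself; the function name and defaults here are illustrative):

def split_reasoning(text, in_reasoning=False, start="<think>", end="</think>"):
    # Mirrors the new one-shot parsing logic shown in the diff above.
    in_reasoning = in_reasoning or text.startswith(start)
    if not in_reasoning:
        return "", text  # no reasoning block: everything is normal text
    body = text.replace(start, "").strip()
    if end not in body:
        return body, ""  # reasoning truncated before the end token
    reasoning, normal = body.split(end, maxsplit=1)
    return reasoning, normal.strip()

# split_reasoning("<think>add 2 and 2</think>The answer is 4.")
# -> ("add 2 and 2", "The answer is 4.")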
sglang/srt/server_args.py
CHANGED
@@ -47,6 +47,7 @@ class ServerArgs:
     tokenizer_mode: str = "auto"
     skip_tokenizer_init: bool = False
     load_format: str = "auto"
+    model_loader_extra_config: str = "{}"
     trust_remote_code: bool = False
     dtype: str = "auto"
     kv_cache_dtype: str = "auto"
@@ -152,6 +153,7 @@ class ServerArgs:
     ep_size: int = 1
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
+    enable_flashinfer_moe: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
     ep_num_redundant_experts: int = 0
     ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
@@ -234,6 +236,10 @@ class ServerArgs:
     num_reserved_decode_tokens: int = 512  # used for decode kv cache offload in PD
     pdlb_url: Optional[str] = None
 
+    # For model weight update
+    custom_weight_loader: Optional[List[str]] = None
+    weight_loader_disable_mmap: bool = False
+
     def __post_init__(self):
         # Expert parallelism
         if self.enable_ep_moe:
@@ -241,7 +247,15 @@ class ServerArgs:
             logger.warning(
                 f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )
-
+        if self.enable_flashinfer_moe:
+            assert (
+                self.quantization == "modelopt_fp4"
+            ), "modelopt_fp4 quantization is required for Flashinfer MOE"
+            os.environ["TRTLLM_ENABLE_PDL"] = "1"
+            self.disable_shared_experts_fusion = True
+            logger.warning(
+                f"Flashinfer MoE is enabled. Shared expert fusion is disabled."
+            )
         # Set missing default values
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
@@ -384,7 +398,6 @@ class ServerArgs:
         ), "Please enable dp attention when setting enable_dp_attention. "
 
         # DeepEP MoE
-        self.enable_sp_layernorm = False
         if self.enable_deepep_moe:
             if self.deepep_mode == "auto":
                 assert (
@@ -394,9 +407,6 @@ class ServerArgs:
                 logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
                 self.disable_cuda_graph = True
             self.ep_size = self.tp_size
-            self.enable_sp_layernorm = (
-                self.dp_size < self.tp_size if self.enable_dp_attention else True
-            )
             logger.warning(
                 f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )
@@ -538,6 +548,9 @@ class ServerArgs:
             "1" if self.disable_outlines_disk_cache else "0"
         )
 
+        if self.custom_weight_loader is None:
+            self.custom_weight_loader = []
+
     def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
         larger_tp = max(decode_tp, prefill_tp)
         smaller_tp = min(decode_tp, prefill_tp)
@@ -551,6 +564,7 @@ class ServerArgs:
         # Model and port args
         parser.add_argument(
             "--model-path",
+            "--model",
             type=str,
             help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
             required=True,
@@ -620,6 +634,13 @@ class ServerArgs:
             "layer before loading another to make the peak memory envelope "
             "smaller.",
         )
+        parser.add_argument(
+            "--model-loader-extra-config",
+            type=str,
+            help="Extra config for model loader. "
+            "This will be passed to the model loader corresponding to the chosen load_format.",
+            default=ServerArgs.model_loader_extra_config,
+        )
         parser.add_argument(
             "--trust-remote-code",
             action="store_true",
@@ -1160,6 +1181,11 @@ class ServerArgs:
             action="store_true",
             help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
         )
+        parser.add_argument(
+            "--enable-flashinfer-moe",
+            action="store_true",
+            help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
+        )
         parser.add_argument(
             "--enable-deepep-moe",
             action="store_true",
@@ -1576,6 +1602,18 @@ class ServerArgs:
             default=None,
             help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
         )
+        parser.add_argument(
+            "--custom-weight-loader",
+            type=str,
+            nargs="*",
+            default=None,
+            help="The custom dataloader which used to update the model. Should be set with a valid import path, such as my_package.weight_load_func",
+        )
+        parser.add_argument(
+            "--weight-loader-disable-mmap",
+            action="store_true",
+            help="Disable mmap while loading weight using safetensors.",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -1663,6 +1701,9 @@ class PortArgs:
     # The ipc filename for rpc call between Engine and Scheduler
     rpc_ipc_name: str
 
+    # The ipc filename for Scheduler to send metrics
+    metrics_ipc_name: str
+
     @staticmethod
     def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
         port = server_args.port + random.randint(100, 1000)
@@ -1682,6 +1723,7 @@ class PortArgs:
                 detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                 nccl_port=port,
                 rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
+                metrics_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
             )
         else:
             # DP attention. Use TCP + port to handle both single-node and multi-node.
@@ -1700,11 +1742,10 @@ class PortArgs:
             dist_init_host, dist_init_port = dist_init_addr
             port_base = int(dist_init_port) + 1
             if dp_rank is None:
-
-
-                ) # TokenizerManager to DataParallelController
+                # TokenizerManager to DataParallelController
+                scheduler_input_port = port_base + 4
             else:
-                scheduler_input_port = port_base +
+                scheduler_input_port = port_base + 4 + 1 + dp_rank
 
             return PortArgs(
                 tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
@@ -1712,6 +1753,7 @@ class PortArgs:
                 detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
                 nccl_port=port,
                 rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}",
+                metrics_ipc_name=f"tcp://{dist_init_host}:{port_base + 3}",
             )
 
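For reference, a minimal sketch of how the new ServerArgs fields from this diff could be set programmatically; only the field names come from the diff above, and the model path and loader path are placeholder values:

from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="my-org/my-model",                            # placeholder model
    model_loader_extra_config="{}",                          # JSON string handed to the model loader
    custom_weight_loader=["my_package.weight_load_func"],    # import path, as in the CLI help text
    weight_loader_disable_mmap=True,                         # skip mmap when reading safetensors
)

The same fields are exposed on the command line as --model-loader-extra-config, --custom-weight-loader, and --weight-loader-disable-mmap; --enable-flashinfer-moe additionally requires modelopt_fp4 quantization, per the __post_init__ check above.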
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
CHANGED
@@ -20,6 +20,12 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )
 from sglang.srt.speculative.eagle_utils import EagleDraftInput
+from sglang.srt.utils import (
+    require_attn_tp_gather,
+    require_gathered_buffer,
+    require_mlp_sync,
+    require_mlp_tp_gather,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_worker import EAGLEWorker
@@ -38,6 +44,12 @@ class EAGLEDraftCudaGraphRunner:
         self.output_buffers = {}
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
+        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
+        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
+        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
+        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
+        self.dp_size = self.model_runner.dp_size
         self.tp_size = self.model_runner.tp_size
         self.topk = model_runner.server_args.speculative_eagle_topk
         self.speculative_num_steps = model_runner.server_args.speculative_num_steps
@@ -53,7 +65,9 @@ class EAGLEDraftCudaGraphRunner:
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
-        self.model_runner.draft_attn_backend.init_cuda_graph_state(
+        self.model_runner.draft_attn_backend.init_cuda_graph_state(
+            self.max_bs, self.max_num_token
+        )
         self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
             0
         ].get_cuda_graph_seq_len_fill_value()
@@ -78,10 +92,32 @@ class EAGLEDraftCudaGraphRunner:
         self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
         self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
         self.hidden_states = torch.zeros(
-            (self.
+            (self.max_bs, self.model_runner.model_config.hidden_size),
             dtype=self.model_runner.dtype,
         )
 
+        if self.require_gathered_buffer:
+            self.gathered_buffer = torch.zeros(
+                (
+                    self.max_num_token,
+                    self.model_runner.model_config.hidden_size,
+                ),
+                dtype=self.model_runner.dtype,
+            )
+            if self.require_mlp_tp_gather:
+                self.global_num_tokens_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+                self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+            else:
+                assert self.require_attn_tp_gather
+                self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
+                self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    (1,), dtype=torch.int32
+                )
+
         # Capture
         try:
             with model_capture_mode():
@@ -92,11 +128,24 @@ class EAGLEDraftCudaGraphRunner:
             )
 
     def can_run(self, forward_batch: ForwardBatch):
+        if self.require_mlp_tp_gather:
+            cuda_graph_bs = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
+            )
+        else:
+            cuda_graph_bs = forward_batch.batch_size
+
         is_bs_supported = (
-
+            cuda_graph_bs in self.graphs
             if self.disable_padding
-            else
+            else cuda_graph_bs <= self.max_bs
         )
+
+        if self.require_mlp_sync:
+            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
+
         return is_bs_supported
 
     def capture(self):
@@ -116,8 +165,58 @@ class EAGLEDraftCudaGraphRunner:
         topk_index = self.topk_index[:num_seqs]
         hidden_states = self.hidden_states[:num_seqs]
 
+        if self.require_mlp_tp_gather:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
+        elif self.require_attn_tp_gather:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
+        else:
+            global_num_tokens = None
+            gathered_buffer = None
+            global_num_tokens_for_logprob = None
+
         spec_info = EagleDraftInput(
-            topk_p=topk_p,
+            topk_p=topk_p,
+            topk_index=topk_index,
+            hidden_states=hidden_states,
+            capture_hidden_mode=CaptureHiddenMode.LAST,
         )
 
         # Forward batch
@@ -133,11 +232,14 @@ class EAGLEDraftCudaGraphRunner:
             seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
+            global_num_tokens_gpu=global_num_tokens,
+            gathered_buffer=gathered_buffer,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=(
                 spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
             ),
+            global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
         )
 
         # Attention backend
@@ -147,6 +249,9 @@ class EAGLEDraftCudaGraphRunner:
 
         # Run and capture
         def run_once():
+            # Clean intermediate result cache for DP attention
+            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+
             # Backup two fields, which will be modified in-place in `draft_forward`.
             output_cache_loc_backup = forward_batch.out_cache_loc
             hidden_states_backup = forward_batch.spec_info.hidden_states
@@ -184,12 +289,19 @@ class EAGLEDraftCudaGraphRunner:
         raw_num_token = raw_bs * self.num_tokens_per_bs
 
         # Pad
-
+        if self.require_mlp_tp_gather:
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
+            )
+            index = bisect.bisect_left(self.capture_bs, total_batch_size)
+        else:
+            index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
-            self.positions.zero_()
 
         num_tokens = bs * self.num_tokens_per_bs
 
@@ -204,6 +316,13 @@ class EAGLEDraftCudaGraphRunner:
         self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
         self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
 
+        if self.require_gathered_buffer:
+            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                forward_batch.global_num_tokens_for_logprob_gpu
+            )
+            forward_batch.gathered_buffer = self.gathered_buffer
+
         # Attention backend
         if bs != raw_bs:
             forward_batch.batch_size = bs
@@ -212,14 +331,16 @@ class EAGLEDraftCudaGraphRunner:
             forward_batch.positions = self.positions[:num_tokens]
 
         # Special handle for seq_len_cpu used when flashinfer mla is used
-        if forward_batch.seq_lens_cpu is not None
-
+        if forward_batch.seq_lens_cpu is not None:
+            if bs != raw_bs:
+                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
             forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]
 
         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
             forward_batch, bs
         )
+        # TODO: The forward_batch.seq_len_sum might need to be updated to reflect the padding in the cuda graph
 
         # Replay
         self.graphs[bs].replay()