sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +5 -4
- sglang/bench_one_batch_server.py +23 -15
- sglang/bench_serving.py +133 -57
- sglang/compile_deep_gemm.py +4 -4
- sglang/srt/configs/model_config.py +39 -28
- sglang/srt/conversation.py +1 -1
- sglang/srt/disaggregation/decode.py +122 -133
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +11 -2
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +9 -19
- sglang/srt/disaggregation/prefill.py +126 -44
- sglang/srt/disaggregation/utils.py +116 -5
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +28 -8
- sglang/srt/entrypoints/http_server.py +6 -4
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +63 -17
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/utils.py +2 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +0 -10
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
- sglang/srt/layers/moe/ep_moe/layer.py +104 -50
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +66 -9
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +7 -2
- sglang/srt/layers/quantization/deep_gemm.py +5 -3
- sglang/srt/layers/quantization/fp8.py +90 -0
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +18 -5
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +16 -3
- sglang/srt/managers/mm_utils.py +293 -139
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +3 -3
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +49 -21
- sglang/srt/managers/schedule_policy.py +4 -5
- sglang/srt/managers/scheduler.py +92 -50
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +99 -24
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +74 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +2 -2
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +20 -9
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +4 -0
- sglang/srt/model_executor/model_runner.py +144 -54
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_v2.py +297 -343
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama4.py +10 -2
- sglang/srt/models/llava.py +26 -18
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/openai_api/adapter.py +28 -16
- sglang/srt/openai_api/protocol.py +6 -0
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/server_args.py +134 -24
- sglang/srt/speculative/eagle_utils.py +131 -0
- sglang/srt/speculative/eagle_worker.py +47 -2
- sglang/srt/utils.py +68 -12
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_utils.py +2 -36
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py
CHANGED
@@ -46,7 +46,6 @@ class ServerArgs:
     tokenizer_path: Optional[str] = None
     tokenizer_mode: str = "auto"
     skip_tokenizer_init: bool = False
-    enable_tokenizer_batch_encode: bool = False
     load_format: str = "auto"
     trust_remote_code: bool = False
     dtype: str = "auto"
@@ -59,6 +58,7 @@
     chat_template: Optional[str] = None
     completion_template: Optional[str] = None
     is_embedding: bool = False
+    enable_multimodal: Optional[bool] = None
     revision: Optional[str] = None

     # Port for the HTTP server
@@ -97,8 +97,13 @@
     log_requests_level: int = 0
     show_time_cost: bool = False
     enable_metrics: bool = False
+    bucket_time_to_first_token: Optional[List[float]] = None
+    bucket_e2e_request_latency: Optional[List[float]] = None
+    bucket_inter_token_latency: Optional[List[float]] = None
+    collect_tokens_histogram: bool = False
     decode_log_interval: int = 40
     enable_request_time_stats_logging: bool = False
+    kv_events_config: Optional[str] = None

     # API related
     api_key: Optional[str] = None
@@ -120,6 +125,7 @@

     # Model override args in JSON
     json_model_override_args: str = "{}"
+    preferred_sampling_params: Optional[str] = None

     # LoRA
     lora_paths: Optional[List[str]] = None
@@ -154,9 +160,9 @@
     disable_cuda_graph: bool = False
     disable_cuda_graph_padding: bool = False
     enable_nccl_nvls: bool = False
+    enable_tokenizer_batch_encode: bool = False
     disable_outlines_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
-    enable_multimodal: Optional[bool] = None
     disable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
@@ -164,6 +170,17 @@
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
+    ep_num_redundant_experts: int = 0
+    ep_dispatch_algorithm: Optional[Literal["static", "dynamic"]] = None
+    init_expert_location: str = "trivial"
+    enable_eplb: bool = False
+    eplb_rebalance_num_iterations: int = 1000
+    expert_distribution_recorder_mode: Optional[
+        Literal["stat", "per_pass", "per_token"]
+    ] = None
+    expert_distribution_recorder_buffer_size: Optional[int] = None
+    enable_expert_distribution_metrics: bool = False
+    deepep_config: Optional[str] = None
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
     cuda_graph_max_bs: Optional[int] = None
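For orientation, the relocated and newly added fields above are plain dataclass attributes, so they can also be set when constructing ServerArgs directly in Python. A minimal sketch, assuming sglang is installed and that model_path remains the only required field as in previous releases (the model name is illustrative):

from sglang.srt.server_args import ServerArgs

# Illustrative values only; the field names come from the hunks above.
args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # assumed example model
    enable_multimodal=True,               # moved next to is_embedding in this release
    enable_tokenizer_batch_encode=False,  # moved into the optimization section
    ep_num_redundant_experts=0,           # new expert-parallel fields
    enable_eplb=False,
    expert_distribution_recorder_mode=None,
)
print(args.init_expert_location)  # "trivial" by default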
@@ -229,7 +246,7 @@
         # Set mem fraction static, which depends on the tensor parallelism size
         if self.mem_fraction_static is None:
             parallel_size = self.tp_size * self.pp_size
-            if gpu_mem <= 81920:
+            if gpu_mem is not None and gpu_mem <= 81920:
                 if parallel_size >= 16:
                     self.mem_fraction_static = 0.79
                 elif parallel_size >= 8:
@@ -242,7 +259,7 @@
                     self.mem_fraction_static = 0.88
             else:
                 self.mem_fraction_static = 0.88
-            if gpu_mem > 96 * 1024:
+            if gpu_mem is not None and gpu_mem > 96 * 1024:
                 mem_fraction = self.mem_fraction_static
                 self.mem_fraction_static = min(
                     mem_fraction + 48 * 1024 * (1 - mem_fraction) / gpu_mem,
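Both guards now tolerate a missing gpu_mem reading. For the large-GPU branch, a small worked example of the adjustment formula shown in the hunk (numbers are illustrative; the cap passed as min's second argument falls outside this hunk and 0.97 is assumed here):

# Illustrative only: a ~141 GB GPU with the default mem_fraction_static of 0.88.
gpu_mem = 141 * 1024  # MiB, compared against the 96 * 1024 threshold above
mem_fraction = 0.88
adjusted = min(mem_fraction + 48 * 1024 * (1 - mem_fraction) / gpu_mem, 0.97)
print(round(adjusted, 4))  # ~0.9209: the static fraction is raised on larger GPUs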
@@ -307,12 +324,6 @@
         if self.grammar_backend is None:
             self.grammar_backend = "xgrammar"

-        if self.pp_size > 1:
-            self.disable_overlap_schedule = True
-            logger.warning(
-                "Overlap scheduler is disabled because of using pipeline parallelism."
-            )
-
         # Data parallelism attention
         if self.enable_dp_attention:
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
@@ -354,6 +365,15 @@
                 "Pipeline parallelism is incompatible with overlap schedule."
             )

+        if self.expert_distribution_recorder_buffer_size is None:
+            # TODO pr-chain: enable this later
+            # if (x := self.eplb_rebalance_num_iterations) is not None:
+            #     self.expert_distribution_recorder_buffer_size = x
+            if False:
+                pass
+            elif self.expert_distribution_recorder_mode is not None:
+                self.expert_distribution_recorder_buffer_size = 1000
+
         # Speculative Decoding
         if self.speculative_algorithm == "NEXTN":
             # NEXTN shares the same implementation of EAGLE
@@ -474,11 +494,6 @@
             action="store_true",
             help="If set, skip init tokenizer and pass input_ids in generate request.",
         )
-        parser.add_argument(
-            "--enable-tokenizer-batch-encode",
-            action="store_true",
-            help="Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds.",
-        )
         parser.add_argument(
             "--load-format",
             type=str,
@@ -556,6 +571,7 @@
                 "w8a8_int8",
                 "w8a8_fp8",
                 "moe_wna16",
+                "qoq",
             ],
             help="The quantization method.",
         )
@@ -603,6 +619,12 @@
             action="store_true",
             help="Whether to use a CausalLM as an embedding model.",
         )
+        parser.add_argument(
+            "--enable-multimodal",
+            default=ServerArgs.enable_multimodal,
+            action="store_true",
+            help="Enable the multimodal functionality for the served model. If the model being served is not multimodal, nothing will happen",
+        )
         parser.add_argument(
             "--revision",
             type=str,
@@ -780,6 +802,39 @@
             action="store_true",
             help="Enable log prometheus metrics.",
         )
+        parser.add_argument(
+            "--bucket-time-to-first-token",
+            type=float,
+            nargs="+",
+            default=ServerArgs.bucket_time_to_first_token,
+            help="The buckets of time to first token, specified as a list of floats.",
+        )
+        parser.add_argument(
+            "--bucket-inter-token-latency",
+            type=float,
+            nargs="+",
+            default=ServerArgs.bucket_inter_token_latency,
+            help="The buckets of inter-token latency, specified as a list of floats.",
+        )
+        parser.add_argument(
+            "--bucket-e2e-request-latency",
+            type=float,
+            nargs="+",
+            default=ServerArgs.bucket_e2e_request_latency,
+            help="The buckets of end-to-end request latency, specified as a list of floats.",
+        )
+        parser.add_argument(
+            "--collect-tokens-histogram",
+            action="store_true",
+            default=ServerArgs.collect_tokens_histogram,
+            help="Collect prompt/generation tokens histogram.",
+        )
+        parser.add_argument(
+            "--kv-events-config",
+            type=str,
+            default=None,
+            help="Config in json format for NVIDIA dynamo KV event publishing. Publishing will be enabled if this flag is used.",
+        )
         parser.add_argument(
             "--decode-log-interval",
             type=int,
@@ -868,6 +923,11 @@
             help="A dictionary in JSON string format used to override default model configurations.",
             default=ServerArgs.json_model_override_args,
         )
+        parser.add_argument(
+            "--preferred-sampling-params",
+            type=str,
+            help="json-formatted sampling settings that will be returned in /get_model_info",
+        )

         # LoRA
         parser.add_argument(
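A minimal sketch of exercising the new metrics flags end to end. It assumes the existing ServerArgs.add_cli_args and ServerArgs.from_cli_args helpers, which are not changed by this diff, and an illustrative model path:

import argparse

from sglang.srt.server_args import ServerArgs

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args(
    [
        "--model-path", "meta-llama/Llama-3.1-8B-Instruct",
        "--enable-metrics",
        "--bucket-time-to-first-token", "0.1", "0.5", "1.0", "2.0",
        "--collect-tokens-histogram",
    ]
)
server_args = ServerArgs.from_cli_args(args)
print(server_args.bucket_time_to_first_token)  # [0.1, 0.5, 1.0, 2.0]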
@@ -896,6 +956,7 @@
             "--attention-backend",
             type=str,
             choices=[
+                "aiter",
                 "flashinfer",
                 "triton",
                 "torch_native",
@@ -1043,6 +1104,11 @@
             action="store_true",
             help="Enable NCCL NVLS for prefill heavy requests when available.",
         )
+        parser.add_argument(
+            "--enable-tokenizer-batch-encode",
+            action="store_true",
+            help="Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds.",
+        )
         parser.add_argument(
             "--disable-outlines-disk-cache",
             action="store_true",
@@ -1053,12 +1119,6 @@
             action="store_true",
             help="Disable the custom all-reduce kernel and fall back to NCCL.",
         )
-        parser.add_argument(
-            "--enable-multimodal",
-            default=ServerArgs.enable_multimodal,
-            action="store_true",
-            help="Enable the multimodal functionality for the served model. If the model being served is not multimodal, nothing will happen",
-        )
         parser.add_argument(
             "--disable-overlap-schedule",
             action="store_true",
@@ -1072,7 +1132,7 @@
         parser.add_argument(
             "--enable-dp-attention",
             action="store_true",
-            help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently DeepSeek-V2 is supported.",
+            help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently DeepSeek-V2 and Qwen 2/3 MoE models are supported.",
         )
         parser.add_argument(
             "--enable-dp-lm-head",
@@ -1212,6 +1272,58 @@
             default="auto",
             help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
         )
+        parser.add_argument(
+            "--ep-num-redundant-experts",
+            type=int,
+            default=ServerArgs.ep_num_redundant_experts,
+            help="Allocate this number of redundant experts in expert parallel.",
+        )
+        parser.add_argument(
+            "--ep-dispatch-algorithm",
+            type=str,
+            default=ServerArgs.ep_dispatch_algorithm,
+            help="The algorithm to choose ranks for redundant experts in expert parallel.",
+        )
+        parser.add_argument(
+            "--init-expert-location",
+            type=str,
+            default=ServerArgs.init_expert_location,
+            help="Initial location of EP experts.",
+        )
+        parser.add_argument(
+            "--enable-eplb",
+            action="store_true",
+            help="Enable EPLB algorithm",
+        )
+        parser.add_argument(
+            "--eplb-rebalance-num-iterations",
+            type=int,
+            default=ServerArgs.eplb_rebalance_num_iterations,
+            help="Number of iterations to automatically trigger a EPLB re-balance.",
+        )
+        parser.add_argument(
+            "--expert-distribution-recorder-mode",
+            type=str,
+            default=ServerArgs.expert_distribution_recorder_mode,
+            help="Mode of expert distribution recorder.",
+        )
+        parser.add_argument(
+            "--expert-distribution-recorder-buffer-size",
+            type=int,
+            default=ServerArgs.expert_distribution_recorder_buffer_size,
+            help="Circular buffer size of expert distribution recorder. Set to -1 to denote infinite buffer.",
+        )
+        parser.add_argument(
+            "--enable-expert-distribution-metrics",
+            action="store_true",
+            help="Enable logging metrics for expert balancedness",
+        )
+        parser.add_argument(
+            "--deepep-config",
+            type=str,
+            default=ServerArgs.deepep_config,
+            help="Tuned DeepEP config suitable for your own cluster.",
+        )

         parser.add_argument(
             "--n-share-experts-fusion",
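For a sense of scale, --ep-num-redundant-experts adds extra physical expert slots on top of the model's logical experts. The sizing below is an assumed illustration only, not something specified by this diff:

# Hypothetical sizing: a DeepSeek-V3-style MoE with 256 routed experts,
# expert parallelism over 8 ranks, and 32 redundant replicas requested.
num_logical_experts = 256
ep_size = 8
ep_num_redundant_experts = 32

num_physical_experts = num_logical_experts + ep_num_redundant_experts
print(num_physical_experts // ep_size)  # 36 physical expert slots per rank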
@@ -1326,8 +1438,6 @@

         # FIXME pp constraints
         if self.pp_size > 1:
-            logger.warning(f"Turn off overlap scheule for pipeline parallelism.")
-            self.disable_overlap_schedule = True
             assert (
                 self.disable_overlap_schedule
                 and self.speculative_algorithm is None
sglang/srt/speculative/eagle_utils.py
CHANGED
@@ -9,15 +9,18 @@ import torch.nn.functional as F
 import triton
 import triton.language as tl

+from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import (
+    Req,
     ScheduleBatch,
     get_last_loc,
     global_server_args_dict,
 )
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
 from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2
@@ -187,6 +190,7 @@ class EagleVerifyInput:
     draft_token_num: int
     spec_steps: int
     capture_hidden_mode: CaptureHiddenMode
+    grammar: BaseGrammarObject = None

     @classmethod
     def create(
@@ -307,6 +311,7 @@
         logits_output: torch.Tensor,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         page_size: int,
+        vocab_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
         Verify and find accepted tokens based on logits output and batch
@@ -343,6 +348,13 @@
                 torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
             )

+        # Apply grammar mask
+        if vocab_mask is not None:
+            assert self.grammar is not None
+            self.grammar.apply_vocab_mask(
+                logits=logits_output.next_token_logits, vocab_mask=vocab_mask
+            )
+
         # Sample tokens
         if batch.sampling_info.is_all_greedy:
             target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
@@ -440,6 +452,15 @@
                     break
                 else:
                     new_accept_index_.append(idx)
+                    # update grammar state
+                    if req.grammar is not None:
+                        try:
+                            req.grammar.accept_token(id)
+                        except ValueError as e:
+                            logger.info(
+                                f"{i=}, {req=}\n" f"{accept_index=}\n" f"{predict=}\n"
+                            )
+                            raise e
             if not req.finished():
                 new_accept_index.extend(new_accept_index_)
                 unfinished_index.append(i)
@@ -801,3 +822,113 @@ def _generate_simulated_accept_index(
     accept_length.fill_(simulate_acc_len - 1)
     predict.fill_(100)  # some legit token id
     return sim_accept_index
+
+
+def traverse_tree(
+    retrieve_next_token: torch.Tensor,
+    retrieve_next_sibling: torch.Tensor,
+    draft_tokens: torch.Tensor,
+    grammar: BaseGrammarObject,
+    allocate_token_bitmask: torch.Tensor,
+):
+    """
+    Traverse the tree constructed by the draft model to generate the logits mask.
+    """
+    assert (
+        retrieve_next_token.shape == retrieve_next_sibling.shape == draft_tokens.shape
+    )
+
+    allocate_token_bitmask.fill_(0)
+
+    def dfs(
+        curr: int,
+        retrieve_next_token: torch.Tensor,
+        retrieve_next_sibling: torch.Tensor,
+        parent_pos: int,
+    ):
+        if curr == 0:
+            # the first token generated by the target model, and thus it is always
+            # accepted from the previous iteration
+            accepted = True
+        else:
+            parent_bitmask = allocate_token_bitmask[parent_pos]
+            curr_token_id = draft_tokens[curr]
+            # 32 boolean bitmask values are packed into 32-bit integers
+            accepted = (
+                parent_bitmask[curr_token_id // 32] & (1 << (curr_token_id % 32))
+            ) != 0
+
+        if accepted:
+            if curr != 0:
+                # Accept the current token
+                grammar.accept_token(draft_tokens[curr])
+            if not grammar.is_terminated():
+                # Generate the bitmask for the current token
+                grammar.fill_vocab_mask(allocate_token_bitmask, curr)
+                if retrieve_next_token[curr] != -1:
+                    # Visit the child node
+                    dfs(
+                        retrieve_next_token[curr],
+                        retrieve_next_token,
+                        retrieve_next_sibling,
+                        curr,
+                    )
+
+            if curr != 0:
+                # Rollback the current token
+                grammar.rollback(1)
+
+        if retrieve_next_sibling[curr] != -1:
+            # Visit the sibling node
+            dfs(
+                retrieve_next_sibling[curr],
+                retrieve_next_token,
+                retrieve_next_sibling,
+                parent_pos,
+            )
+
+    dfs(0, retrieve_next_token, retrieve_next_sibling, -1)
+
+
+def generate_token_bitmask(
+    reqs: List[Req],
+    verify_input: EagleVerifyInput,
+    retrieve_next_token_cpu: torch.Tensor,
+    retrieve_next_sibling_cpu: torch.Tensor,
+    draft_tokens_cpu: torch.Tensor,
+    vocab_size: int,
+):
+    """
+    Generate the logit mask for structured output.
+    Draft model's token can be either valid or invalid with respect to the grammar.
+    We need to perform DFS to figure out:
+    1. which tokens are accepted by the grammar
+    2. what is the corresponding logit mask.
+    """
+
+    num_draft_tokens = draft_tokens_cpu.shape[-1]
+
+    allocate_token_bitmask = None
+    assert len(reqs) == retrieve_next_token_cpu.shape[0]
+    grammar = None
+    for i, req in enumerate(reqs):
+        if req.grammar is not None:
+            if allocate_token_bitmask is None:
+                allocate_token_bitmask = req.grammar.allocate_vocab_mask(
+                    vocab_size=vocab_size,
+                    batch_size=draft_tokens_cpu.numel(),
+                    device="cpu",
+                )
+            grammar = req.grammar
+            traverse_tree(
+                retrieve_next_token_cpu[i],
+                retrieve_next_sibling_cpu[i],
+                draft_tokens_cpu[i],
+                req.grammar,
+                allocate_token_bitmask[
+                    i * num_draft_tokens : (i + 1) * num_draft_tokens
+                ],
+            )
+
+    verify_input.grammar = grammar
+    return allocate_token_bitmask
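The acceptance test inside dfs above relies on a packed vocab-mask layout: 32 allow/deny bits per integer word, with a set bit meaning the token is allowed. A standalone sketch of that packing in plain Python, independent of the sglang classes:

vocab_size = 100
bitmask = [0] * ((vocab_size + 31) // 32)  # everything forbidden at first

def allow(token_id: int) -> None:
    bitmask[token_id // 32] |= 1 << (token_id % 32)

def is_allowed(token_id: int) -> bool:
    # The same test dfs() applies against the parent position's bitmask row.
    return (bitmask[token_id // 32] & (1 << (token_id % 32))) != 0

allow(42)
print(is_allowed(42), is_allowed(43))  # True False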
sglang/srt/speculative/eagle_worker.py
CHANGED
@@ -31,6 +31,7 @@ from sglang.srt.speculative.eagle_utils import (
     EagleVerifyInput,
     EagleVerifyOutput,
     assign_draft_cache_locs,
+    generate_token_bitmask,
     select_top_k_tokens,
 )
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -199,6 +200,19 @@ class EAGLEWorker(TpModelWorker):
             self.draft_extend_attn_backend = None
             self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = False
+        elif self.server_args.attention_backend == "flashmla":
+            from sglang.srt.layers.attention.flashmla_backend import (
+                FlashMLAMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = FlashMLAMultiStepDraftBackend(
+                self.draft_model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+            self.draft_extend_attn_backend = None
+            self.padded_static_len = self.speculative_num_steps + 1
+            self.has_prefill_wrapper_verify = False
         else:
             raise ValueError(
                 f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
@@ -215,7 +229,7 @@
             return

         # Capture draft
-        tic = time.time()
+        tic = time.perf_counter()
         before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
@@ -223,7 +237,7 @@
         self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
+            f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
         )

         # Capture extend
@@ -479,11 +493,41 @@
         batch.forward_mode = ForwardMode.TARGET_VERIFY
         batch.spec_info = spec_info
         model_worker_batch = batch.get_model_worker_batch()
+
+        if batch.has_grammar:
+            retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
+            retrieve_next_sibling_cpu = spec_info.retrive_next_sibling.cpu()
+            draft_tokens_cpu = spec_info.draft_token.view(
+                spec_info.retrive_next_token.shape
+            ).cpu()
+
+        # Forward
         logits_output, _, can_run_cuda_graph = (
             self.target_worker.forward_batch_generation(
                 model_worker_batch, skip_sample=True
             )
         )
+
+        vocab_mask = None
+        if batch.has_grammar:
+            # Generate the logit mask for structured output.
+            # Overlap the CPU operations for bitmask generation with the forward pass.
+            vocab_mask = generate_token_bitmask(
+                batch.reqs,
+                spec_info,
+                retrieve_next_token_cpu,
+                retrieve_next_sibling_cpu,
+                draft_tokens_cpu,
+                batch.sampling_info.vocab_size,
+            )
+
+            if vocab_mask is not None:
+                assert spec_info.grammar is not None
+                vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device)
+                # otherwise, this vocab mask will be the one from the previous extend stage
+                # and will be applied to produce wrong results
+                batch.sampling_info.vocab_mask = None
+
         self._detect_nan_if_needed(logits_output)
         spec_info.hidden_states = logits_output.hidden_states
         res: EagleVerifyOutput = spec_info.verify(
@@ -491,6 +535,7 @@
             logits_output,
             self.token_to_kv_pool_allocator,
             self.page_size,
+            vocab_mask,
         )

         # Post process based on verified outputs.
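To make the final wiring concrete: the mask built on the CPU while the target forward runs is moved to the device and handed to verify, which delegates to the grammar object's apply_vocab_mask. The helper below is only a hedged sketch of what applying a packed int32 mask to logits amounts to; it is not the sglang implementation:

import torch

def apply_packed_vocab_mask_sketch(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
    # Expand a [num_rows, ceil(vocab/32)] int32 bitmask and set disallowed
    # logits to -inf in place (illustration only).
    vocab_size = logits.shape[1]
    bit_idx = torch.arange(vocab_size, device=logits.device)
    words = vocab_mask[:, bit_idx // 32]          # pick the word holding each token's bit
    allowed = (words >> (bit_idx % 32)) & 1       # extract the per-token allow bit
    logits.masked_fill_(allowed == 0, float("-inf"))

logits = torch.randn(2, 64)
mask = torch.zeros(2, 2, dtype=torch.int32)
mask[:, 0] = (1 << 5) | (1 << 7)                  # allow only tokens 5 and 7
apply_packed_vocab_mask_sketch(logits, mask)
print(torch.isfinite(logits).sum(dim=-1))         # tensor([2, 2])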