sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/chat_template.py +21 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +5 -1
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +17 -2
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +65 -20
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +70 -15
- sglang/srt/entrypoints/engine.py +5 -9
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +148 -72
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +105 -66
- sglang/srt/function_call/function_call_parser.py +6 -4
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
- sglang/srt/layers/activation.py +11 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +46 -25
- sglang/srt/layers/moe/ep_moe/layer.py +172 -206
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/topk.py +88 -34
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +33 -14
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/utils.py +0 -9
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/lora/lora_manager.py +133 -169
- sglang/srt/lora/lora_registry.py +188 -0
- sglang/srt/lora/mem_pool.py +2 -2
- sglang/srt/managers/cache_controller.py +62 -13
- sglang/srt/managers/io_struct.py +19 -1
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +27 -11
- sglang/srt/managers/scheduler.py +48 -26
- sglang/srt/managers/tokenizer_manager.py +62 -28
- sglang/srt/managers/tp_worker.py +5 -4
- sglang/srt/mem_cache/allocator.py +67 -7
- sglang/srt/mem_cache/hicache_storage.py +17 -1
- sglang/srt/mem_cache/hiradix_cache.py +35 -18
- sglang/srt/mem_cache/memory_pool_host.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +61 -25
- sglang/srt/model_executor/forward_batch_info.py +201 -29
- sglang/srt/model_executor/model_runner.py +109 -37
- sglang/srt/models/deepseek_v2.py +63 -30
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/mllama4.py +10 -3
- sglang/srt/models/qwen2_moe.py +2 -6
- sglang/srt/models/qwen3_moe.py +6 -8
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/reasoning_parser.py +48 -5
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/server_args.py +132 -60
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils.py +113 -69
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_activation.py +50 -1
- sglang/test/test_utils.py +65 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py
CHANGED
@@ -20,10 +20,10 @@ import logging
 import os
 import random
 import tempfile
-from token import OP
 from typing import List, Literal, Optional, Union
 
 from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
+from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.utils import (
     LORA_TARGET_ALL_MODULES,
@@ -80,7 +80,7 @@ class ServerArgs:
     schedule_policy: str = "fcfs"
     schedule_conservativeness: float = 1.0
     cpu_offload_gb: int = 0
-    page_size: int = 1
+    page_size: Optional[int] = None
     hybrid_kvcache_ratio: Optional[float] = None
     swa_full_tokens_ratio: float = 0.8
     disable_hybrid_swa_memory: bool = False
@@ -145,12 +145,14 @@ class ServerArgs:
     enable_lora: Optional[bool] = None
     max_lora_rank: Optional[int] = None
     lora_target_modules: Optional[Union[set[str], List[str]]] = None
-    lora_paths: Optional[Union[dict[str, str], List[str]]] = None
+    lora_paths: Optional[Union[dict[str, str], dict[str, LoRARef], List[str]]] = None
     max_loras_per_batch: int = 8
     lora_backend: str = "triton"
 
     # Kernel backend
     attention_backend: Optional[str] = None
+    decode_attention_backend: Optional[str] = None
+    prefill_attention_backend: Optional[str] = None
     sampling_backend: Optional[str] = None
     grammar_backend: Optional[str] = None
     mm_attention_backend: Optional[str] = None
@@ -169,7 +171,8 @@ class ServerArgs:
     ep_size: int = 1
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
-    enable_flashinfer_moe: bool = False
+    enable_flashinfer_cutlass_moe: bool = False
+    enable_flashinfer_trtllm_moe: bool = False
     enable_flashinfer_allreduce_fusion: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
     ep_num_redundant_experts: int = 0
@@ -266,31 +269,20 @@ class ServerArgs:
 
     def __post_init__(self):
         # Expert parallelism
+        # We put it here first due to some internal ckpt conversation issues.
         if self.enable_ep_moe:
             self.ep_size = self.tp_size
             logger.warning(
                 f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )
-        if self.enable_flashinfer_moe:
-            assert (
-                self.quantization == "modelopt_fp4"
-            ), "modelopt_fp4 quantization is required for Flashinfer MOE"
-            os.environ["TRTLLM_ENABLE_PDL"] = "1"
-            self.disable_shared_experts_fusion = True
-            logger.warning(
-                f"Flashinfer MoE is enabled. Shared expert fusion is disabled."
-            )
 
         # Set missing default values
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
-
-        if self.device is None:
-            self.device = get_device()
-
         if self.served_model_name is None:
             self.served_model_name = self.model_path
-
+        if self.device is None:
+            self.device = get_device()
         if self.random_seed is None:
             self.random_seed = random.randint(0, 1 << 30)
 
@@ -359,7 +351,6 @@ class ServerArgs:
             self.chunked_prefill_size = 16384
         else:
             self.chunked_prefill_size = 4096
-        assert self.chunked_prefill_size % self.page_size == 0
 
         # Set cuda graph max batch size
         if self.cuda_graph_max_bs is None:
@@ -398,18 +389,32 @@ class ServerArgs:
             )
             self.page_size = 128
 
-        if self.attention_backend == "flashmla":
+        if (
+            self.attention_backend == "flashmla"
+            or self.decode_attention_backend == "flashmla"
+        ):
             logger.warning(
                 "FlashMLA only supports a page_size of 64, change page_size to 64."
             )
             self.page_size = 64
 
-        if self.attention_backend == "cutlass_mla":
+        if (
+            self.attention_backend == "cutlass_mla"
+            or self.decode_attention_backend == "cutlass_mla"
+        ):
             logger.warning(
                 "Cutlass MLA only supports a page_size of 128, change page_size to 128."
             )
             self.page_size = 128
 
+        # Set page size
+        if self.page_size is None:
+            self.page_size = 1
+
+        # AMD-specific Triton attention KV splits default number
+        if is_hip():
+            self.triton_attention_num_kv_splits = 16
+
         # Choose grammar backend
         if self.grammar_backend is None:
             self.grammar_backend = "xgrammar"
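For reference, the page-size resolution introduced above can be condensed into a small standalone sketch. The helper below is illustrative only (not sglang API); it mirrors the order of the checks in __post_init__: backend-specific overrides first, then a default of 1 when --page-size was not given.

from typing import Optional


def resolve_page_size(
    page_size: Optional[int],
    attention_backend: Optional[str],
    decode_attention_backend: Optional[str],
) -> int:
    # FlashMLA only supports a page size of 64.
    if "flashmla" in (attention_backend, decode_attention_backend):
        return 64
    # Cutlass MLA only supports a page size of 128.
    if "cutlass_mla" in (attention_backend, decode_attention_backend):
        return 128
    # page_size is now Optional and falls back to 1 when unset.
    return page_size if page_size is not None else 1


assert resolve_page_size(None, None, "flashmla") == 64
assert resolve_page_size(None, "triton", None) == 1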
@@ -431,6 +436,17 @@ class ServerArgs:
             self.enable_dp_attention
         ), "Please enable dp attention when setting enable_dp_lm_head. "
 
+        # MoE kernel
+        if self.enable_flashinfer_cutlass_moe:
+            assert (
+                self.quantization == "modelopt_fp4"
+            ), "modelopt_fp4 quantization is required for Flashinfer MOE"
+            os.environ["TRTLLM_ENABLE_PDL"] = "1"
+
+        if self.enable_flashinfer_trtllm_moe:
+            assert self.enable_ep_moe, "EP MoE is required for Flashinfer TRTLLM MOE"
+            logger.warning(f"Flashinfer TRTLLM MoE is enabled.")
+
         # DeepEP MoE
         if self.enable_deepep_moe:
             if self.deepep_mode == "normal":
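A minimal standalone sketch of the two new flag constraints enforced above (the helper name is hypothetical; the error messages are taken from the assertions in the hunk):

from typing import Optional


def validate_moe_flags(
    enable_flashinfer_cutlass_moe: bool,
    enable_flashinfer_trtllm_moe: bool,
    quantization: Optional[str],
    enable_ep_moe: bool,
) -> None:
    # The CUTLASS path is tied to modelopt_fp4 quantization.
    if enable_flashinfer_cutlass_moe and quantization != "modelopt_fp4":
        raise ValueError("modelopt_fp4 quantization is required for Flashinfer MOE")
    # The TRTLLM path builds on expert parallelism (--enable-ep-moe).
    if enable_flashinfer_trtllm_moe and not enable_ep_moe:
        raise ValueError("EP MoE is required for Flashinfer TRTLLM MOE")


validate_moe_flags(True, False, "modelopt_fp4", False)  # ok
validate_moe_flags(False, True, None, True)  # ok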
@@ -455,6 +471,9 @@ class ServerArgs:
                 "EPLB is enabled or init_expert_location is provided. ep_dispatch_algorithm is configured."
             )
 
+        if self.enable_eplb:
+            assert self.enable_ep_moe or self.enable_deepep_moe
+
         if self.enable_expert_distribution_metrics and (
             self.expert_distribution_recorder_mode is None
         ):
@@ -494,7 +513,7 @@ class ServerArgs:
             )
 
         model_arch = self.get_hf_config().architectures[0]
-        if model_arch
+        if model_arch in ["DeepseekV3ForCausalLM", "Glm4MoeForCausalLM"]:
             # Auto set draft_model_path DeepSeek-V3/R1
             if self.speculative_draft_model_path is None:
                 self.speculative_draft_model_path = self.model_path
@@ -502,14 +521,6 @@ class ServerArgs:
                 logger.warning(
                     "DeepSeek MTP does not require setting speculative_draft_model_path."
                 )
-        elif "Llama4" in model_arch:
-            # TODO: remove this after Llama4 supports in other backends
-            if self.attention_backend != "fa3":
-                self.attention_backend = "fa3"
-                logger.warning(
-                    "Llama4 requires using fa3 attention backend. "
-                    "Attention backend is automatically set to fa3."
-                )
 
         # Auto choose parameters
         if self.speculative_num_steps is None:
@@ -542,12 +553,11 @@ class ServerArgs:
         ) and check_gguf_file(self.model_path):
             self.quantization = self.load_format = "gguf"
 
+        # Model loading
         if is_remote_url(self.model_path):
             self.load_format = "remote"
-
-
-        if is_hip():
-            self.triton_attention_num_kv_splits = 16
+        if self.custom_weight_loader is None:
+            self.custom_weight_loader = []
 
         # PD disaggregation
         if self.disaggregation_mode == "decode":
@@ -572,6 +582,7 @@ class ServerArgs:
             self.disable_cuda_graph = True
             logger.warning("Cuda graph is disabled for prefill server")
 
+        # Propagate env vars
         os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
             "1" if self.enable_torch_compile else "0"
         )
@@ -580,9 +591,6 @@ class ServerArgs:
             "1" if self.disable_outlines_disk_cache else "0"
         )
 
-        if self.custom_weight_loader is None:
-            self.custom_weight_loader = []
-
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and tokenizer
@@ -1099,10 +1107,11 @@ class ServerArgs:
                 "deepseekv3",
                 "pythonic",
                 "kimi_k2",
-                "qwen3",
+                "qwen3_coder",
+                "glm45",
             ],
             default=ServerArgs.tool_call_parser,
-            help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', and '
+            help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', and 'qwen3_coder'.",
         )
 
         # Data parallelism
@@ -1213,6 +1222,35 @@ class ServerArgs:
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
         )
+        parser.add_argument(
+            "--decode-attention-backend",
+            type=str,
+            choices=[
+                "flashinfer",
+                "triton",
+                "torch_native",
+                "fa3",
+                "flashmla",
+                "cutlass_mla",
+            ],
+            default=ServerArgs.decode_attention_backend,
+            help="Choose the kernels for decode attention layers (have priority over --attention-backend).",
+        )
+
+        parser.add_argument(
+            "--prefill-attention-backend",
+            type=str,
+            choices=[
+                "flashinfer",
+                "triton",
+                "torch_native",
+                "fa3",
+                "flashmla",
+                "cutlass_mla",
+            ],
+            default=ServerArgs.prefill_attention_backend,
+            help="Choose the kernels for prefill attention layers (have priority over --attention-backend).",
+        )
         parser.add_argument(
             "--sampling-backend",
             type=str,
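Per their help strings, the two new flags take priority over --attention-backend for their respective phase. A standalone sketch of that resolution order (hypothetical helper, not sglang internals):

from typing import Optional


def pick_attention_backend(
    phase: str,
    attention_backend: Optional[str],
    decode_attention_backend: Optional[str],
    prefill_attention_backend: Optional[str],
) -> Optional[str]:
    # Per-phase flags win when set; otherwise fall back to --attention-backend.
    if phase == "decode" and decode_attention_backend is not None:
        return decode_attention_backend
    if phase == "prefill" and prefill_attention_backend is not None:
        return prefill_attention_backend
    return attention_backend


assert pick_attention_backend("decode", "triton", "flashmla", None) == "flashmla"
assert pick_attention_backend("prefill", "triton", "flashmla", None) == "triton"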
@@ -1227,6 +1265,13 @@ class ServerArgs:
             default=ServerArgs.grammar_backend,
             help="Choose the backend for grammar-guided decoding.",
         )
+        parser.add_argument(
+            "--mm-attention-backend",
+            type=str,
+            choices=["sdpa", "fa3", "triton_attn"],
+            default=ServerArgs.mm_attention_backend,
+            help="Set multimodal attention backend.",
+        )
 
         # Speculative decoding
         parser.add_argument(
@@ -1276,13 +1321,6 @@ class ServerArgs:
             help="The path of the draft model's small vocab table.",
             default=ServerArgs.speculative_token_map,
         )
-        parser.add_argument(
-            "--mm-attention-backend",
-            type=str,
-            choices=["sdpa", "fa3", "triton_attn"],
-            default=ServerArgs.mm_attention_backend,
-            help="Set multimodal attention backend.",
-        )
 
         # Expert parallelism
         parser.add_argument(
@@ -1298,10 +1336,15 @@ class ServerArgs:
             help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
         )
         parser.add_argument(
-            "--enable-flashinfer-moe",
+            "--enable-flashinfer-cutlass-moe",
             action="store_true",
             help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
         )
+        parser.add_argument(
+            "--enable-flashinfer-trtllm-moe",
+            action="store_true",
+            help="Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP with --enable-ep-moe",
+        )
         parser.add_argument(
             "--enable-flashinfer-allreduce-fusion",
             action="store_true",
@@ -1530,11 +1573,6 @@ class ServerArgs:
             action="store_true",
             help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
         )
-        parser.add_argument(
-            "--disable-overlap-cg-plan",
-            action="store_true",
-            help="Disable the overlap optimization for cudagraph preparation in eagle verify.",
-        )
         parser.add_argument(
             "--enable-mixed-chunk",
             action="store_true",
@@ -1792,11 +1830,11 @@ class ServerArgs:
         return hf_config
 
     def check_server_args(self):
+        # Check parallel size constraints
         assert (
             self.tp_size * self.pp_size
         ) % self.nnodes == 0, "tp_size must be divisible by number of nodes"
 
-        # FIXME pp constraints
         if self.pp_size > 1:
             assert (
                 self.disable_overlap_schedule
@@ -1807,11 +1845,7 @@ class ServerArgs:
         assert not (
             self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
         ), "multi-node data parallel is not supported unless dp attention!"
-        assert (
-            self.max_loras_per_batch > 0
-            # FIXME
-            and (self.lora_paths is None or self.disable_radix_cache)
-        ), "compatibility of lora and radix attention is in progress"
+
         assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
         assert self.gpu_id_step >= 1, "gpu_id_step must be positive"
 
@@ -1820,9 +1854,32 @@ class ServerArgs:
             None,
         }, "moe_dense_tp_size only support 1 and None currently"
 
+        # Check model architecture
+        model_arch = self.get_hf_config().architectures[0]
+        if "Llama4" in model_arch:
+            assert self.attention_backend == "fa3", "fa3 is required for Llama4 model"
+
+        # Check LoRA
         self.check_lora_server_args()
 
+        # Check speculative decoding
+        if self.speculative_algorithm is not None:
+            assert (
+                not self.enable_mixed_chunk
+            ), "enable_mixed_chunk is required for speculative decoding"
+
+        # Check chunked prefill
+        assert (
+            self.chunked_prefill_size % self.page_size == 0
+        ), "chunked_prefill_size must be divisible by page_size"
+
     def check_lora_server_args(self):
+        assert (
+            self.max_loras_per_batch > 0
+            # FIXME
+            and (self.lora_paths is None or self.disable_radix_cache)
+        ), "compatibility of lora and radix attention is in progress"
+
         # Enable LoRA if any LoRA paths are provided for backward compatibility.
         if self.lora_paths:
             if self.enable_lora is None:
@@ -1843,9 +1900,24 @@ class ServerArgs:
             for lora_path in lora_paths:
                 if "=" in lora_path:
                     name, path = lora_path.split("=", 1)
-                    self.lora_paths[name] = path
+                    self.lora_paths[name] = LoRARef(lora_name=name, lora_path=path)
                 else:
-                    self.lora_paths[lora_path] = lora_path
+                    self.lora_paths[lora_path] = LoRARef(
+                        lora_name=lora_path,
+                        lora_path=lora_path,
+                    )
+        elif isinstance(self.lora_paths, dict):
+            self.lora_paths = {
+                k: LoRARef(lora_name=k, lora_path=v)
+                for k, v in self.lora_paths.items()
+            }
+        elif self.lora_paths is None:
+            self.lora_paths = {}
+        else:
+            raise ValueError(
+                f"Invalid type for --lora-paths: {type(self.lora_paths)}. "
+                "Expected a list or a dictionary."
+            )
 
         # Expand target modules
         if self.lora_target_modules:
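The normalization above accepts --lora-paths either as a list of "name=path" strings or as a dict, and maps both onto LoRARef records keyed by adapter name. A standalone sketch of the same mapping, using a hypothetical stand-in for sglang.srt.lora.lora_registry.LoRARef:

from dataclasses import dataclass
from typing import Dict, List, Union


@dataclass
class LoRARefSketch:  # hypothetical stand-in, only lora_name/lora_path as in the diff
    lora_name: str
    lora_path: str


def normalize_lora_paths(
    lora_paths: Union[List[str], Dict[str, str], None],
) -> Dict[str, LoRARefSketch]:
    if isinstance(lora_paths, list):
        out = {}
        for item in lora_paths:
            if "=" in item:
                name, path = item.split("=", 1)
                out[name] = LoRARefSketch(lora_name=name, lora_path=path)
            else:
                out[item] = LoRARefSketch(lora_name=item, lora_path=item)
        return out
    if isinstance(lora_paths, dict):
        return {k: LoRARefSketch(lora_name=k, lora_path=v) for k, v in lora_paths.items()}
    if lora_paths is None:
        return {}
    raise ValueError(f"Invalid type for --lora-paths: {type(lora_paths)}.")


assert normalize_lora_paths(["alice=/ckpt/alice"])["alice"].lora_path == "/ckpt/alice"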
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
CHANGED
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Callable
 
 import torch
 
+from sglang.srt.layers.dp_attention import DPPaddingMode
 from sglang.srt.model_executor.cuda_graph_runner import (
     CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,
@@ -97,13 +98,6 @@ class EAGLEDraftCudaGraphRunner:
         )
 
         if self.require_gathered_buffer:
-            self.gathered_buffer = torch.zeros(
-                (
-                    self.max_num_token,
-                    self.model_runner.model_config.hidden_size,
-                ),
-                dtype=self.model_runner.dtype,
-            )
             if self.require_mlp_tp_gather:
                 self.global_num_tokens_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
@@ -111,12 +105,30 @@ class EAGLEDraftCudaGraphRunner:
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token * self.dp_size,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
             else:
                 assert self.require_attn_tp_gather
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (1,), dtype=torch.int32
                 )
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
+        else:
+            self.global_num_tokens_gpu = None
+            self.global_num_tokens_for_logprob_gpu = None
+            self.gathered_buffer = None
 
         # Capture
         try:
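The allocation sizes above follow from the new padding scheme: with MLP TP gather every DP rank is padded to the same token count, so the gathered buffer needs max_num_token rows per rank, while the attention-TP-gather-only case keeps a single rank's worth. A tiny illustrative helper (names are hypothetical):

def gathered_buffer_rows(
    max_num_token: int, dp_size: int, require_mlp_tp_gather: bool
) -> int:
    # Every DP rank is padded to max_num_token tokens under MLP TP gather.
    return max_num_token * dp_size if require_mlp_tp_gather else max_num_token


assert gathered_buffer_rows(max_num_token=8, dp_size=4, require_mlp_tp_gather=True) == 32
assert gathered_buffer_rows(max_num_token=8, dp_size=4, require_mlp_tp_gather=False) == 8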
@@ -130,9 +142,9 @@ class EAGLEDraftCudaGraphRunner:
     def can_run(self, forward_batch: ForwardBatch):
         if self.require_mlp_tp_gather:
             cuda_graph_bs = (
-
+                max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
-                else
+                else max(forward_batch.global_num_tokens_cpu)
             )
         else:
             cuda_graph_bs = forward_batch.batch_size
@@ -168,26 +180,20 @@ class EAGLEDraftCudaGraphRunner:
         if self.require_mlp_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
-                    [
-                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
-                        for i in range(self.dp_size)
-                    ],
+                    [num_tokens] * self.dp_size,
                     dtype=torch.int32,
                     device=self.input_ids.device,
                 )
             )
             self.global_num_tokens_for_logprob_gpu.copy_(
                 torch.tensor(
-                    [
-                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
-                        for i in range(self.dp_size)
-                    ],
+                    [num_tokens] * self.dp_size,
                     dtype=torch.int32,
                     device=self.input_ids.device,
                 )
             )
             global_num_tokens = self.global_num_tokens_gpu
-            gathered_buffer = self.gathered_buffer[:num_tokens]
+            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
             global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
         elif self.require_attn_tp_gather:
             self.global_num_tokens_gpu.copy_(
@@ -233,6 +239,7 @@ class EAGLEDraftCudaGraphRunner:
             return_logprob=False,
             positions=positions,
             global_num_tokens_gpu=global_num_tokens,
+            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
             gathered_buffer=gathered_buffer,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
@@ -290,12 +297,13 @@ class EAGLEDraftCudaGraphRunner:
 
         # Pad
         if self.require_mlp_tp_gather:
-
-
+            max_num_tokens = max(forward_batch.global_num_tokens_cpu)
+            max_batch_size = (
+                max_num_tokens // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
-                else
+                else max_num_tokens
             )
-            index = bisect.bisect_left(self.capture_bs,
+            index = bisect.bisect_left(self.capture_bs, max_batch_size)
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
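The padding step above picks the smallest captured batch size that can hold the padded batch via a bisect over the captured sizes. A self-contained sketch of that selection (the values are illustrative only):

import bisect

capture_bs = [1, 2, 4, 8, 16, 32]  # sizes the CUDA graphs were captured with
max_batch_size = 5                 # padded batch size derived above
bs = capture_bs[bisect.bisect_left(capture_bs, max_batch_size)]
assert bs == 8                     # replay the graph captured for bs=8 and pad up to it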
@@ -316,12 +324,10 @@ class EAGLEDraftCudaGraphRunner:
         self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
         self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
 
+        # TODO(ch-wan): support num_token_non_padded
         if self.require_gathered_buffer:
-            self.global_num_tokens_gpu.
-            self.global_num_tokens_for_logprob_gpu.
-                forward_batch.global_num_tokens_for_logprob_gpu
-            )
-            forward_batch.gathered_buffer = self.gathered_buffer
+            self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
+            self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
 
         # Attention backend
         if bs != raw_bs:
@@ -330,7 +336,6 @@ class EAGLEDraftCudaGraphRunner:
         forward_batch.req_pool_indices = self.req_pool_indices[:bs]
         forward_batch.positions = self.positions[:num_tokens]
 
-        # Special handle for seq_len_cpu used when flashinfer mla is used
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
                 self.seq_lens_cpu.fill_(self.seq_len_fill_value)