sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. sglang/lang/chat_template.py +21 -0
  2. sglang/srt/_custom_ops.py +29 -1
  3. sglang/srt/configs/internvl.py +3 -0
  4. sglang/srt/configs/model_config.py +5 -1
  5. sglang/srt/constrained/base_grammar_backend.py +10 -2
  6. sglang/srt/constrained/xgrammar_backend.py +7 -5
  7. sglang/srt/conversation.py +17 -2
  8. sglang/srt/debug_utils/__init__.py +0 -0
  9. sglang/srt/debug_utils/dump_comparator.py +131 -0
  10. sglang/srt/debug_utils/dumper.py +108 -0
  11. sglang/srt/debug_utils/text_comparator.py +172 -0
  12. sglang/srt/disaggregation/common/conn.py +34 -6
  13. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
  14. sglang/srt/disaggregation/mini_lb.py +3 -2
  15. sglang/srt/disaggregation/mooncake/conn.py +65 -20
  16. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  17. sglang/srt/disaggregation/nixl/conn.py +17 -13
  18. sglang/srt/disaggregation/prefill.py +13 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  21. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  23. sglang/srt/distributed/parallel_state.py +70 -15
  24. sglang/srt/entrypoints/engine.py +5 -9
  25. sglang/srt/entrypoints/http_server.py +20 -32
  26. sglang/srt/entrypoints/openai/protocol.py +3 -3
  27. sglang/srt/entrypoints/openai/serving_chat.py +148 -72
  28. sglang/srt/function_call/base_format_detector.py +74 -12
  29. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  30. sglang/srt/function_call/ebnf_composer.py +105 -66
  31. sglang/srt/function_call/function_call_parser.py +6 -4
  32. sglang/srt/function_call/glm4_moe_detector.py +164 -0
  33. sglang/srt/function_call/kimik2_detector.py +41 -16
  34. sglang/srt/function_call/llama32_detector.py +6 -3
  35. sglang/srt/function_call/mistral_detector.py +11 -3
  36. sglang/srt/function_call/pythonic_detector.py +16 -14
  37. sglang/srt/function_call/qwen25_detector.py +12 -3
  38. sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
  39. sglang/srt/layers/activation.py +11 -3
  40. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  41. sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
  42. sglang/srt/layers/attention/vision.py +56 -8
  43. sglang/srt/layers/communicator.py +12 -12
  44. sglang/srt/layers/dp_attention.py +72 -24
  45. sglang/srt/layers/layernorm.py +26 -1
  46. sglang/srt/layers/logits_processor.py +46 -25
  47. sglang/srt/layers/moe/ep_moe/layer.py +172 -206
  48. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
  51. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
  52. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
  53. sglang/srt/layers/moe/topk.py +88 -34
  54. sglang/srt/layers/multimodal.py +11 -8
  55. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
  56. sglang/srt/layers/quantization/fp8.py +25 -247
  57. sglang/srt/layers/quantization/fp8_kernel.py +78 -48
  58. sglang/srt/layers/quantization/modelopt_quant.py +33 -14
  59. sglang/srt/layers/quantization/unquant.py +24 -76
  60. sglang/srt/layers/quantization/utils.py +0 -9
  61. sglang/srt/layers/quantization/w4afp8.py +68 -17
  62. sglang/srt/layers/radix_attention.py +5 -3
  63. sglang/srt/lora/lora_manager.py +133 -169
  64. sglang/srt/lora/lora_registry.py +188 -0
  65. sglang/srt/lora/mem_pool.py +2 -2
  66. sglang/srt/managers/cache_controller.py +62 -13
  67. sglang/srt/managers/io_struct.py +19 -1
  68. sglang/srt/managers/mm_utils.py +154 -35
  69. sglang/srt/managers/multimodal_processor.py +3 -14
  70. sglang/srt/managers/schedule_batch.py +27 -11
  71. sglang/srt/managers/scheduler.py +48 -26
  72. sglang/srt/managers/tokenizer_manager.py +62 -28
  73. sglang/srt/managers/tp_worker.py +5 -4
  74. sglang/srt/mem_cache/allocator.py +67 -7
  75. sglang/srt/mem_cache/hicache_storage.py +17 -1
  76. sglang/srt/mem_cache/hiradix_cache.py +35 -18
  77. sglang/srt/mem_cache/memory_pool_host.py +3 -0
  78. sglang/srt/model_executor/cuda_graph_runner.py +61 -25
  79. sglang/srt/model_executor/forward_batch_info.py +201 -29
  80. sglang/srt/model_executor/model_runner.py +109 -37
  81. sglang/srt/models/deepseek_v2.py +63 -30
  82. sglang/srt/models/glm4_moe.py +1035 -0
  83. sglang/srt/models/glm4_moe_nextn.py +167 -0
  84. sglang/srt/models/interns1.py +328 -0
  85. sglang/srt/models/internvl.py +143 -47
  86. sglang/srt/models/llava.py +9 -5
  87. sglang/srt/models/minicpmo.py +4 -1
  88. sglang/srt/models/mllama4.py +10 -3
  89. sglang/srt/models/qwen2_moe.py +2 -6
  90. sglang/srt/models/qwen3_moe.py +6 -8
  91. sglang/srt/multimodal/processors/base_processor.py +20 -6
  92. sglang/srt/multimodal/processors/clip.py +2 -2
  93. sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
  94. sglang/srt/multimodal/processors/gemma3.py +2 -2
  95. sglang/srt/multimodal/processors/gemma3n.py +2 -2
  96. sglang/srt/multimodal/processors/internvl.py +21 -8
  97. sglang/srt/multimodal/processors/janus_pro.py +2 -2
  98. sglang/srt/multimodal/processors/kimi_vl.py +2 -2
  99. sglang/srt/multimodal/processors/llava.py +4 -4
  100. sglang/srt/multimodal/processors/minicpm.py +2 -3
  101. sglang/srt/multimodal/processors/mlama.py +2 -2
  102. sglang/srt/multimodal/processors/mllama4.py +18 -111
  103. sglang/srt/multimodal/processors/phi4mm.py +2 -2
  104. sglang/srt/multimodal/processors/pixtral.py +2 -2
  105. sglang/srt/multimodal/processors/qwen_audio.py +2 -2
  106. sglang/srt/multimodal/processors/qwen_vl.py +2 -2
  107. sglang/srt/multimodal/processors/vila.py +3 -1
  108. sglang/srt/reasoning_parser.py +48 -5
  109. sglang/srt/sampling/sampling_batch_info.py +6 -5
  110. sglang/srt/server_args.py +132 -60
  111. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  112. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
  113. sglang/srt/speculative/eagle_utils.py +51 -23
  114. sglang/srt/speculative/eagle_worker.py +59 -44
  115. sglang/srt/two_batch_overlap.py +9 -5
  116. sglang/srt/utils.py +113 -69
  117. sglang/srt/weight_sync/utils.py +119 -0
  118. sglang/test/runners.py +4 -0
  119. sglang/test/test_activation.py +50 -1
  120. sglang/test/test_utils.py +65 -5
  121. sglang/utils.py +19 -0
  122. sglang/version.py +1 -1
  123. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
  124. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
  125. sglang/srt/debug_utils.py +0 -74
  126. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
  127. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
  128. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py CHANGED
@@ -20,10 +20,10 @@ import logging
 import os
 import random
 import tempfile
-from token import OP
 from typing import List, Literal, Optional, Union
 
 from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
+from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.utils import (
     LORA_TARGET_ALL_MODULES,
@@ -80,7 +80,7 @@ class ServerArgs:
     schedule_policy: str = "fcfs"
     schedule_conservativeness: float = 1.0
     cpu_offload_gb: int = 0
-    page_size: int = 1
+    page_size: Optional[int] = None
     hybrid_kvcache_ratio: Optional[float] = None
     swa_full_tokens_ratio: float = 0.8
     disable_hybrid_swa_memory: bool = False
@@ -145,12 +145,14 @@ class ServerArgs:
     enable_lora: Optional[bool] = None
     max_lora_rank: Optional[int] = None
     lora_target_modules: Optional[Union[set[str], List[str]]] = None
-    lora_paths: Optional[Union[dict[str, str], List[str]]] = None
+    lora_paths: Optional[Union[dict[str, str], dict[str, LoRARef], List[str]]] = None
     max_loras_per_batch: int = 8
     lora_backend: str = "triton"
 
     # Kernel backend
     attention_backend: Optional[str] = None
+    decode_attention_backend: Optional[str] = None
+    prefill_attention_backend: Optional[str] = None
     sampling_backend: Optional[str] = None
     grammar_backend: Optional[str] = None
     mm_attention_backend: Optional[str] = None
@@ -169,7 +171,8 @@ class ServerArgs:
     ep_size: int = 1
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
-    enable_flashinfer_moe: bool = False
+    enable_flashinfer_cutlass_moe: bool = False
+    enable_flashinfer_trtllm_moe: bool = False
     enable_flashinfer_allreduce_fusion: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
     ep_num_redundant_experts: int = 0
@@ -266,31 +269,20 @@ class ServerArgs:
 
     def __post_init__(self):
         # Expert parallelism
+        # We put it here first due to some internal ckpt conversation issues.
         if self.enable_ep_moe:
             self.ep_size = self.tp_size
             logger.warning(
                 f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )
-        if self.enable_flashinfer_moe:
-            assert (
-                self.quantization == "modelopt_fp4"
-            ), "modelopt_fp4 quantization is required for Flashinfer MOE"
-            os.environ["TRTLLM_ENABLE_PDL"] = "1"
-            self.disable_shared_experts_fusion = True
-            logger.warning(
-                f"Flashinfer MoE is enabled. Shared expert fusion is disabled."
-            )
 
         # Set missing default values
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
-
-        if self.device is None:
-            self.device = get_device()
-
         if self.served_model_name is None:
             self.served_model_name = self.model_path
-
+        if self.device is None:
+            self.device = get_device()
         if self.random_seed is None:
             self.random_seed = random.randint(0, 1 << 30)
 
@@ -359,7 +351,6 @@ class ServerArgs:
                 self.chunked_prefill_size = 16384
             else:
                 self.chunked_prefill_size = 4096
-        assert self.chunked_prefill_size % self.page_size == 0
 
         # Set cuda graph max batch size
         if self.cuda_graph_max_bs is None:
@@ -398,18 +389,32 @@ class ServerArgs:
             )
             self.page_size = 128
 
-        if self.attention_backend == "flashmla":
+        if (
+            self.attention_backend == "flashmla"
+            or self.decode_attention_backend == "flashmla"
+        ):
             logger.warning(
                 "FlashMLA only supports a page_size of 64, change page_size to 64."
             )
             self.page_size = 64
 
-        if self.attention_backend == "cutlass_mla":
+        if (
+            self.attention_backend == "cutlass_mla"
+            or self.decode_attention_backend == "cutlass_mla"
+        ):
             logger.warning(
                 "Cutlass MLA only supports a page_size of 128, change page_size to 128."
             )
             self.page_size = 128
 
+        # Set page size
+        if self.page_size is None:
+            self.page_size = 1
+
+        # AMD-specific Triton attention KV splits default number
+        if is_hip():
+            self.triton_attention_num_kv_splits = 16
+
         # Choose grammar backend
         if self.grammar_backend is None:
             self.grammar_backend = "xgrammar"
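Taken together with the earlier change of the page_size default to None, this hunk resolves the page size only after the attention backends are known. A minimal sketch of the resolution order implied by the diff (the standalone helper below is illustrative, not part of sglang):

from typing import Optional

def resolve_page_size(
    page_size: Optional[int],
    attention_backend: Optional[str],
    decode_attention_backend: Optional[str],
) -> int:
    # Backend constraints win over any user-provided value.
    if "flashmla" in (attention_backend, decode_attention_backend):
        return 64   # FlashMLA only supports a page_size of 64
    if "cutlass_mla" in (attention_backend, decode_attention_backend):
        return 128  # Cutlass MLA only supports a page_size of 128
    # Otherwise keep the user value, falling back to the old default of 1.
    return page_size if page_size is not None else 1

assert resolve_page_size(None, "flashmla", None) == 64
assert resolve_page_size(None, None, None) == 1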
@@ -431,6 +436,17 @@ class ServerArgs:
                 self.enable_dp_attention
             ), "Please enable dp attention when setting enable_dp_lm_head. "
 
+        # MoE kernel
+        if self.enable_flashinfer_cutlass_moe:
+            assert (
+                self.quantization == "modelopt_fp4"
+            ), "modelopt_fp4 quantization is required for Flashinfer MOE"
+            os.environ["TRTLLM_ENABLE_PDL"] = "1"
+
+        if self.enable_flashinfer_trtllm_moe:
+            assert self.enable_ep_moe, "EP MoE is required for Flashinfer TRTLLM MOE"
+            logger.warning(f"Flashinfer TRTLLM MoE is enabled.")
+
         # DeepEP MoE
         if self.enable_deepep_moe:
             if self.deepep_mode == "normal":
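The old --enable-flashinfer-moe flag is split into two flags with the constraints asserted above. A small pre-flight check mirroring them, as a hedged sketch (the function name and flat argument list are illustrative, not sglang's API):

import os
from typing import Optional

def validate_flashinfer_moe_flags(
    enable_flashinfer_cutlass_moe: bool,
    enable_flashinfer_trtllm_moe: bool,
    quantization: Optional[str],
    enable_ep_moe: bool,
) -> None:
    if enable_flashinfer_cutlass_moe:
        # The CUTLASS MoE path is tied to modelopt_fp4 quantization.
        assert quantization == "modelopt_fp4", "modelopt_fp4 quantization is required"
        os.environ["TRTLLM_ENABLE_PDL"] = "1"
    if enable_flashinfer_trtllm_moe:
        # The TRTLLM MoE path requires expert parallelism.
        assert enable_ep_moe, "EP MoE is required for Flashinfer TRTLLM MoE"

validate_flashinfer_moe_flags(False, True, None, enable_ep_moe=True)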
@@ -455,6 +471,9 @@ class ServerArgs:
                 "EPLB is enabled or init_expert_location is provided. ep_dispatch_algorithm is configured."
             )
 
+        if self.enable_eplb:
+            assert self.enable_ep_moe or self.enable_deepep_moe
+
         if self.enable_expert_distribution_metrics and (
             self.expert_distribution_recorder_mode is None
         ):
@@ -494,7 +513,7 @@ class ServerArgs:
             )
 
             model_arch = self.get_hf_config().architectures[0]
-            if model_arch == "DeepseekV3ForCausalLM":
+            if model_arch in ["DeepseekV3ForCausalLM", "Glm4MoeForCausalLM"]:
                 # Auto set draft_model_path DeepSeek-V3/R1
                 if self.speculative_draft_model_path is None:
                     self.speculative_draft_model_path = self.model_path
@@ -502,14 +521,6 @@ class ServerArgs:
                     logger.warning(
                         "DeepSeek MTP does not require setting speculative_draft_model_path."
                     )
-            elif "Llama4" in model_arch:
-                # TODO: remove this after Llama4 supports in other backends
-                if self.attention_backend != "fa3":
-                    self.attention_backend = "fa3"
-                    logger.warning(
-                        "Llama4 requires using fa3 attention backend. "
-                        "Attention backend is automatically set to fa3."
-                    )
 
             # Auto choose parameters
             if self.speculative_num_steps is None:
@@ -542,12 +553,11 @@ class ServerArgs:
         ) and check_gguf_file(self.model_path):
             self.quantization = self.load_format = "gguf"
 
+        # Model loading
         if is_remote_url(self.model_path):
             self.load_format = "remote"
-
-        # AMD-specific Triton attention KV splits default number
-        if is_hip():
-            self.triton_attention_num_kv_splits = 16
+        if self.custom_weight_loader is None:
+            self.custom_weight_loader = []
 
         # PD disaggregation
         if self.disaggregation_mode == "decode":
@@ -572,6 +582,7 @@ class ServerArgs:
             self.disable_cuda_graph = True
             logger.warning("Cuda graph is disabled for prefill server")
 
+        # Propagate env vars
         os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
             "1" if self.enable_torch_compile else "0"
         )
@@ -580,9 +591,6 @@ class ServerArgs:
             "1" if self.disable_outlines_disk_cache else "0"
         )
 
-        if self.custom_weight_loader is None:
-            self.custom_weight_loader = []
-
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and tokenizer
@@ -1099,10 +1107,11 @@ class ServerArgs:
                 "deepseekv3",
                 "pythonic",
                 "kimi_k2",
-                "qwen3",
+                "qwen3_coder",
+                "glm45",
             ],
             default=ServerArgs.tool_call_parser,
-            help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', and 'kimi_k2'.",
+            help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', and 'qwen3_coder'.",
         )
 
         # Data parallelism
@@ -1213,6 +1222,35 @@ class ServerArgs:
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
         )
+        parser.add_argument(
+            "--decode-attention-backend",
+            type=str,
+            choices=[
+                "flashinfer",
+                "triton",
+                "torch_native",
+                "fa3",
+                "flashmla",
+                "cutlass_mla",
+            ],
+            default=ServerArgs.decode_attention_backend,
+            help="Choose the kernels for decode attention layers (have priority over --attention-backend).",
+        )
+
+        parser.add_argument(
+            "--prefill-attention-backend",
+            type=str,
+            choices=[
+                "flashinfer",
+                "triton",
+                "torch_native",
+                "fa3",
+                "flashmla",
+                "cutlass_mla",
+            ],
+            default=ServerArgs.prefill_attention_backend,
+            help="Choose the kernels for prefill attention layers (have priority over --attention-backend).",
+        )
         parser.add_argument(
             "--sampling-backend",
             type=str,
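Per the new help strings, the phase-specific flags take priority over the generic --attention-backend. A hedged sketch of that precedence (the helper and dataclass below are illustrative; the actual resolution lives inside sglang's model runner):

from dataclasses import dataclass
from typing import Optional

@dataclass
class _Args:
    attention_backend: Optional[str] = None
    decode_attention_backend: Optional[str] = None
    prefill_attention_backend: Optional[str] = None

def pick_attention_backends(args: _Args, default: str = "flashinfer"):
    # Phase-specific flag first, then the generic flag, then a default.
    prefill = args.prefill_attention_backend or args.attention_backend or default
    decode = args.decode_attention_backend or args.attention_backend or default
    return prefill, decode

# e.g. --attention-backend fa3 --decode-attention-backend flashmla
assert pick_attention_backends(_Args("fa3", "flashmla", None)) == ("fa3", "flashmla")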
@@ -1227,6 +1265,13 @@ class ServerArgs:
             default=ServerArgs.grammar_backend,
             help="Choose the backend for grammar-guided decoding.",
         )
+        parser.add_argument(
+            "--mm-attention-backend",
+            type=str,
+            choices=["sdpa", "fa3", "triton_attn"],
+            default=ServerArgs.mm_attention_backend,
+            help="Set multimodal attention backend.",
+        )
 
         # Speculative decoding
         parser.add_argument(
@@ -1276,13 +1321,6 @@ class ServerArgs:
             help="The path of the draft model's small vocab table.",
             default=ServerArgs.speculative_token_map,
         )
-        parser.add_argument(
-            "--mm-attention-backend",
-            type=str,
-            choices=["sdpa", "fa3", "triton_attn"],
-            default=ServerArgs.mm_attention_backend,
-            help="Set multimodal attention backend.",
-        )
 
         # Expert parallelism
         parser.add_argument(
@@ -1298,10 +1336,15 @@ class ServerArgs:
             help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
         )
         parser.add_argument(
-            "--enable-flashinfer-moe",
+            "--enable-flashinfer-cutlass-moe",
             action="store_true",
             help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
         )
+        parser.add_argument(
+            "--enable-flashinfer-trtllm-moe",
+            action="store_true",
+            help="Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP with --enable-ep-moe",
+        )
         parser.add_argument(
             "--enable-flashinfer-allreduce-fusion",
             action="store_true",
@@ -1530,11 +1573,6 @@ class ServerArgs:
             action="store_true",
             help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
         )
-        parser.add_argument(
-            "--disable-overlap-cg-plan",
-            action="store_true",
-            help="Disable the overlap optimization for cudagraph preparation in eagle verify.",
-        )
         parser.add_argument(
             "--enable-mixed-chunk",
             action="store_true",
@@ -1792,11 +1830,11 @@ class ServerArgs:
         return hf_config
 
     def check_server_args(self):
+        # Check parallel size constraints
         assert (
             self.tp_size * self.pp_size
         ) % self.nnodes == 0, "tp_size must be divisible by number of nodes"
 
-        # FIXME pp constraints
         if self.pp_size > 1:
             assert (
                 self.disable_overlap_schedule
@@ -1807,11 +1845,7 @@ class ServerArgs:
         assert not (
             self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
         ), "multi-node data parallel is not supported unless dp attention!"
-        assert (
-            self.max_loras_per_batch > 0
-            # FIXME
-            and (self.lora_paths is None or self.disable_radix_cache)
-        ), "compatibility of lora and radix attention is in progress"
+
         assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
         assert self.gpu_id_step >= 1, "gpu_id_step must be positive"
 
@@ -1820,9 +1854,32 @@ class ServerArgs:
             None,
         }, "moe_dense_tp_size only support 1 and None currently"
 
+        # Check model architecture
+        model_arch = self.get_hf_config().architectures[0]
+        if "Llama4" in model_arch:
+            assert self.attention_backend == "fa3", "fa3 is required for Llama4 model"
+
+        # Check LoRA
         self.check_lora_server_args()
 
+        # Check speculative decoding
+        if self.speculative_algorithm is not None:
+            assert (
+                not self.enable_mixed_chunk
+            ), "enable_mixed_chunk is required for speculative decoding"
+
+        # Check chunked prefill
+        assert (
+            self.chunked_prefill_size % self.page_size == 0
+        ), "chunked_prefill_size must be divisible by page_size"
+
     def check_lora_server_args(self):
+        assert (
+            self.max_loras_per_batch > 0
+            # FIXME
+            and (self.lora_paths is None or self.disable_radix_cache)
+        ), "compatibility of lora and radix attention is in progress"
+
         # Enable LoRA if any LoRA paths are provided for backward compatibility.
         if self.lora_paths:
             if self.enable_lora is None:
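The divisibility assertion removed from __post_init__ earlier in this diff now runs in check_server_args above, after page_size has been resolved. A quick illustration of values it accepts and rejects (the numbers are made up):

# chunked_prefill_size must be a multiple of the resolved page_size.
for chunked_prefill_size, page_size in [(16384, 1), (16384, 64), (8192, 128)]:
    assert chunked_prefill_size % page_size == 0
# A non-multiple pairing such as (4097, 64) would trip the assertion at startup.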
@@ -1843,9 +1900,24 @@ class ServerArgs:
             for lora_path in lora_paths:
                 if "=" in lora_path:
                     name, path = lora_path.split("=", 1)
-                    self.lora_paths[name] = path
+                    self.lora_paths[name] = LoRARef(lora_name=name, lora_path=path)
                 else:
-                    self.lora_paths[lora_path] = lora_path
+                    self.lora_paths[lora_path] = LoRARef(
+                        lora_name=lora_path,
+                        lora_path=lora_path,
+                    )
+        elif isinstance(self.lora_paths, dict):
+            self.lora_paths = {
+                k: LoRARef(lora_name=k, lora_path=v)
+                for k, v in self.lora_paths.items()
+            }
+        elif self.lora_paths is None:
+            self.lora_paths = {}
+        else:
+            raise ValueError(
+                f"Invalid type for --lora-paths: {type(self.lora_paths)}. "
+                "Expected a list or a dictionary."
+            )
 
         # Expand target modules
         if self.lora_target_modules:
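After this change, every --lora-paths entry is normalized into a LoRARef keyed by adapter name, whether it arrives as "name=path", a bare path, or a dict. A self-contained sketch of the three accepted input shapes (the LoRARef below is a simplified stand-in, not the class from sglang.srt.lora.lora_registry, and the paths are made up):

from dataclasses import dataclass

@dataclass
class LoRARef:  # simplified stand-in for sglang.srt.lora.lora_registry.LoRARef
    lora_name: str
    lora_path: str

def normalize_lora_paths(lora_paths):
    # CLI form: ["name=path", "bare/path"]
    if isinstance(lora_paths, list):
        out = {}
        for item in lora_paths:
            name, _, path = item.partition("=")
            name, path = (name, path) if path else (item, item)
            out[name] = LoRARef(lora_name=name, lora_path=path)
        return out
    # Engine form: {"name": "path"}
    if isinstance(lora_paths, dict):
        return {k: LoRARef(lora_name=k, lora_path=v) for k, v in lora_paths.items()}
    if lora_paths is None:
        return {}
    raise ValueError(f"Invalid type for --lora-paths: {type(lora_paths)}")

refs = normalize_lora_paths(["sql=/adapters/sql-lora", "/adapters/chat-lora"])
assert refs["sql"].lora_path == "/adapters/sql-lora"
assert refs["/adapters/chat-lora"].lora_name == "/adapters/chat-lora"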
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py CHANGED
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Callable
 
 import torch
 
+from sglang.srt.layers.dp_attention import DPPaddingMode
 from sglang.srt.model_executor.cuda_graph_runner import (
     CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,
@@ -97,13 +98,6 @@ class EAGLEDraftCudaGraphRunner:
         )
 
         if self.require_gathered_buffer:
-            self.gathered_buffer = torch.zeros(
-                (
-                    self.max_num_token,
-                    self.model_runner.model_config.hidden_size,
-                ),
-                dtype=self.model_runner.dtype,
-            )
             if self.require_mlp_tp_gather:
                 self.global_num_tokens_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
@@ -111,12 +105,30 @@ class EAGLEDraftCudaGraphRunner:
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token * self.dp_size,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
             else:
                 assert self.require_attn_tp_gather
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (1,), dtype=torch.int32
                 )
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
+        else:
+            self.global_num_tokens_gpu = None
+            self.global_num_tokens_for_logprob_gpu = None
+            self.gathered_buffer = None
 
         # Capture
         try:
@@ -130,9 +142,9 @@ class EAGLEDraftCudaGraphRunner:
     def can_run(self, forward_batch: ForwardBatch):
         if self.require_mlp_tp_gather:
             cuda_graph_bs = (
-                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
-                else sum(forward_batch.global_num_tokens_cpu)
+                else max(forward_batch.global_num_tokens_cpu)
             )
         else:
             cuda_graph_bs = forward_batch.batch_size
@@ -168,26 +180,20 @@ class EAGLEDraftCudaGraphRunner:
         if self.require_mlp_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
-                    [
-                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
-                        for i in range(self.dp_size)
-                    ],
+                    [num_tokens] * self.dp_size,
                     dtype=torch.int32,
                     device=self.input_ids.device,
                 )
            )
             self.global_num_tokens_for_logprob_gpu.copy_(
                 torch.tensor(
-                    [
-                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
-                        for i in range(self.dp_size)
-                    ],
+                    [num_tokens] * self.dp_size,
                     dtype=torch.int32,
                     device=self.input_ids.device,
                 )
             )
             global_num_tokens = self.global_num_tokens_gpu
-            gathered_buffer = self.gathered_buffer[:num_tokens]
+            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
             global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
         elif self.require_attn_tp_gather:
             self.global_num_tokens_gpu.copy_(
@@ -233,6 +239,7 @@ class EAGLEDraftCudaGraphRunner:
             return_logprob=False,
             positions=positions,
             global_num_tokens_gpu=global_num_tokens,
+            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
             gathered_buffer=gathered_buffer,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
@@ -290,12 +297,13 @@ class EAGLEDraftCudaGraphRunner:
 
         # Pad
         if self.require_mlp_tp_gather:
-            total_batch_size = (
-                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+            max_num_tokens = max(forward_batch.global_num_tokens_cpu)
+            max_batch_size = (
+                max_num_tokens // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
-                else sum(forward_batch.global_num_tokens_cpu)
+                else max_num_tokens
             )
-            index = bisect.bisect_left(self.capture_bs, total_batch_size)
+            index = bisect.bisect_left(self.capture_bs, max_batch_size)
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
@@ -316,12 +324,10 @@ class EAGLEDraftCudaGraphRunner:
         self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
         self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
 
+        # TODO(ch-wan): support num_token_non_padded
         if self.require_gathered_buffer:
-            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
-            self.global_num_tokens_for_logprob_gpu.copy_(
-                forward_batch.global_num_tokens_for_logprob_gpu
-            )
-            forward_batch.gathered_buffer = self.gathered_buffer
+            self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
+            self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
 
         # Attention backend
         if bs != raw_bs:
@@ -330,7 +336,6 @@ class EAGLEDraftCudaGraphRunner:
            forward_batch.req_pool_indices = self.req_pool_indices[:bs]
            forward_batch.positions = self.positions[:num_tokens]
 
-        # Special handle for seq_len_cpu used when flashinfer mla is used
         if forward_batch.seq_lens_cpu is not None:
            if bs != raw_bs:
                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
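The recurring change in this file is the data-parallel padding scheme for CUDA-graph replay: instead of spreading the captured token count unevenly across DP ranks, every rank is now padded to the same count, and the gathered buffer is sliced to hold all padded ranks. A small numeric illustration of the two schemes (dp_size and num_tokens are made-up values):

dp_size = 4
num_tokens = 7  # tokens captured for this graph (made-up value)

# Old scheme: num_tokens was spread across DP ranks as evenly as possible.
old_per_rank = [
    num_tokens // dp_size + (i < num_tokens % dp_size) for i in range(dp_size)
]
assert old_per_rank == [2, 2, 2, 1] and sum(old_per_rank) == num_tokens

# New scheme: every rank is padded to the same count, so the replayed graph
# always sees a fixed per-rank shape; the gathered buffer is sliced to
# num_tokens * dp_size rows to hold all of the padded ranks.
new_per_rank = [num_tokens] * dp_size   # [7, 7, 7, 7]
gathered_rows = num_tokens * dp_size    # 28
assert gathered_rows == sum(new_per_rank)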