sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (170)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -1
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +8 -7
  6. sglang/srt/disaggregation/decode.py +8 -4
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +68 -5
  14. sglang/srt/entrypoints/openai/protocol.py +2 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +60 -265
  16. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  18. sglang/srt/function_call/ebnf_composer.py +1 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  21. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  22. sglang/srt/function_call/kimik2_detector.py +3 -3
  23. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  24. sglang/srt/jinja_template_utils.py +6 -0
  25. sglang/srt/layers/attention/aiter_backend.py +370 -107
  26. sglang/srt/layers/attention/ascend_backend.py +3 -0
  27. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  28. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  29. sglang/srt/layers/attention/flashinfer_backend.py +55 -13
  30. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
  31. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  32. sglang/srt/layers/attention/triton_backend.py +24 -27
  33. sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
  34. sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
  35. sglang/srt/layers/attention/vision.py +9 -1
  36. sglang/srt/layers/attention/wave_backend.py +627 -0
  37. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  38. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  39. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  40. sglang/srt/layers/communicator.py +11 -13
  41. sglang/srt/layers/dp_attention.py +118 -27
  42. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  43. sglang/srt/layers/linear.py +1 -0
  44. sglang/srt/layers/logits_processor.py +12 -18
  45. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  46. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  47. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  48. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  63. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  64. sglang/srt/layers/moe/topk.py +4 -1
  65. sglang/srt/layers/multimodal.py +156 -40
  66. sglang/srt/layers/quantization/__init__.py +10 -35
  67. sglang/srt/layers/quantization/awq.py +15 -16
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
  69. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  70. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  71. sglang/srt/layers/quantization/gptq.py +12 -17
  72. sglang/srt/layers/quantization/marlin_utils.py +15 -5
  73. sglang/srt/layers/quantization/modelopt_quant.py +58 -41
  74. sglang/srt/layers/quantization/mxfp4.py +20 -3
  75. sglang/srt/layers/quantization/utils.py +52 -2
  76. sglang/srt/layers/quantization/w4afp8.py +20 -11
  77. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  78. sglang/srt/layers/rotary_embedding.py +281 -2
  79. sglang/srt/layers/sampler.py +5 -2
  80. sglang/srt/lora/backend/base_backend.py +3 -23
  81. sglang/srt/lora/layers.py +66 -116
  82. sglang/srt/lora/lora.py +17 -62
  83. sglang/srt/lora/lora_manager.py +12 -48
  84. sglang/srt/lora/lora_registry.py +20 -9
  85. sglang/srt/lora/mem_pool.py +20 -63
  86. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  87. sglang/srt/lora/utils.py +25 -58
  88. sglang/srt/managers/cache_controller.py +24 -29
  89. sglang/srt/managers/detokenizer_manager.py +1 -1
  90. sglang/srt/managers/io_struct.py +20 -6
  91. sglang/srt/managers/mm_utils.py +1 -2
  92. sglang/srt/managers/multimodal_processor.py +1 -1
  93. sglang/srt/managers/schedule_batch.py +43 -49
  94. sglang/srt/managers/schedule_policy.py +6 -6
  95. sglang/srt/managers/scheduler.py +18 -11
  96. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  97. sglang/srt/managers/tokenizer_manager.py +53 -44
  98. sglang/srt/mem_cache/allocator.py +39 -214
  99. sglang/srt/mem_cache/allocator_ascend.py +158 -0
  100. sglang/srt/mem_cache/chunk_cache.py +1 -1
  101. sglang/srt/mem_cache/hicache_storage.py +1 -1
  102. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  103. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  104. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  105. sglang/srt/mem_cache/radix_cache.py +2 -5
  106. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  107. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  108. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  109. sglang/srt/model_executor/cuda_graph_runner.py +29 -23
  110. sglang/srt/model_executor/forward_batch_info.py +33 -14
  111. sglang/srt/model_executor/model_runner.py +179 -81
  112. sglang/srt/model_loader/loader.py +18 -6
  113. sglang/srt/models/deepseek_nextn.py +2 -1
  114. sglang/srt/models/deepseek_v2.py +79 -38
  115. sglang/srt/models/gemma2.py +0 -34
  116. sglang/srt/models/gemma3n_mm.py +8 -9
  117. sglang/srt/models/glm4.py +6 -0
  118. sglang/srt/models/glm4_moe.py +11 -11
  119. sglang/srt/models/glm4_moe_nextn.py +2 -1
  120. sglang/srt/models/glm4v.py +589 -0
  121. sglang/srt/models/glm4v_moe.py +400 -0
  122. sglang/srt/models/gpt_oss.py +142 -20
  123. sglang/srt/models/granite.py +0 -25
  124. sglang/srt/models/llama.py +10 -27
  125. sglang/srt/models/llama4.py +19 -6
  126. sglang/srt/models/qwen2.py +2 -2
  127. sglang/srt/models/qwen2_5_vl.py +7 -3
  128. sglang/srt/models/qwen2_audio.py +10 -9
  129. sglang/srt/models/qwen2_moe.py +20 -5
  130. sglang/srt/models/qwen3.py +0 -24
  131. sglang/srt/models/qwen3_classification.py +78 -0
  132. sglang/srt/models/qwen3_moe.py +18 -5
  133. sglang/srt/models/registry.py +1 -1
  134. sglang/srt/models/step3_vl.py +6 -2
  135. sglang/srt/models/torch_native_llama.py +0 -24
  136. sglang/srt/multimodal/processors/base_processor.py +23 -13
  137. sglang/srt/multimodal/processors/glm4v.py +132 -0
  138. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  139. sglang/srt/operations.py +17 -2
  140. sglang/srt/reasoning_parser.py +316 -0
  141. sglang/srt/sampling/sampling_batch_info.py +7 -4
  142. sglang/srt/server_args.py +142 -140
  143. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
  144. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  145. sglang/srt/speculative/eagle_worker.py +16 -0
  146. sglang/srt/two_batch_overlap.py +16 -12
  147. sglang/srt/utils.py +3 -3
  148. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  149. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  150. sglang/test/doc_patch.py +59 -0
  151. sglang/test/few_shot_gsm8k.py +1 -1
  152. sglang/test/few_shot_gsm8k_engine.py +1 -1
  153. sglang/test/run_eval.py +4 -1
  154. sglang/test/simple_eval_common.py +6 -0
  155. sglang/test/simple_eval_gpqa.py +2 -0
  156. sglang/test/test_fp4_moe.py +118 -36
  157. sglang/test/test_marlin_moe.py +1 -1
  158. sglang/test/test_marlin_utils.py +1 -1
  159. sglang/utils.py +1 -1
  160. sglang/version.py +1 -1
  161. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
  162. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
  163. sglang/lang/backend/__init__.py +0 -0
  164. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  165. sglang/srt/layers/quantization/scalar_type.py +0 -352
  166. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  167. /sglang/{api.py → lang/api.py} +0 -0
  168. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
  169. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
  170. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py CHANGED
@@ -24,7 +24,7 @@ import tempfile
  from typing import List, Literal, Optional, Union

  from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
- from sglang.srt.layers.utils import is_sm100_supported
+ from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
  from sglang.srt.lora.lora_registry import LoRARef
  from sglang.srt.reasoning_parser import ReasoningParser
  from sglang.srt.utils import (
@@ -37,7 +37,6 @@ from sglang.srt.utils import (
  is_hip,
  is_port_available,
  is_remote_url,
- is_triton_kernels_available,
  is_valid_ipv6_address,
  nullable_str,
  )
@@ -109,7 +108,7 @@ class ServerArgs:
  log_level: str = "info"
  log_level_http: Optional[str] = None
  log_requests: bool = False
- log_requests_level: int = 0
+ log_requests_level: int = 2
  crash_dump_folder: Optional[str] = None
  show_time_cost: bool = False
  enable_metrics: bool = False
@@ -125,12 +124,14 @@ class ServerArgs:
  # API related
  api_key: Optional[str] = None
  served_model_name: Optional[str] = None
+ weight_version: str = "default"
  chat_template: Optional[str] = None
  completion_template: Optional[str] = None
  file_storage_path: str = "sglang_storage"
  enable_cache_report: bool = False
  reasoning_parser: Optional[str] = None
  tool_call_parser: Optional[str] = None
+ tool_server: Optional[str] = None

  # Data parallelism
  dp_size: int = 1
@@ -278,15 +279,11 @@ class ServerArgs:
  enable_pdmux: bool = False
  sm_group_num: int = 3

- # For tool server
- tool_server: Optional[str] = None
-
  # Deprecated arguments
  enable_ep_moe: bool = False
  enable_deepep_moe: bool = False

  def __post_init__(self):
-
  # Check deprecated arguments
  def print_deprecated_warning(message: str):
  logger.warning(f"\033[33m{message}\033[0m")
@@ -392,6 +389,9 @@ class ServerArgs:
  self.attention_backend = "torch_native"
  self.sampling_backend = "pytorch"

+ # Model-specific adjustments
+ self.model_specific_adjustments()
+
  # Set kernel backends
  if self.device == "cpu":
  if self.attention_backend is None:
@@ -433,7 +433,10 @@ class ServerArgs:
  )
  self.page_size = 128

- if self.attention_backend == "trtllm_mla":
+ if (
+ self.attention_backend == "trtllm_mla"
+ or self.decode_attention_backend == "trtllm_mla"
+ ):
  if not is_sm100_supported():
  raise ValueError(
  "TRTLLM MLA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
@@ -444,11 +447,17 @@ class ServerArgs:
  f"TensorRT-LLM MLA only supports page_size of 32 or 64, changing page_size from {self.page_size} to 64."
  )
  self.page_size = 64
+
  if self.speculative_algorithm is not None:
  raise ValueError(
  "trtllm_mla backend does not support speculative decoding yet."
  )

+ if self.kv_cache_dtype not in ["fp8_e4m3", "auto"]:
+ raise ValueError(
+ "TensorRT-LLM MLA backend only supports kv-cache-dtype of fp8_e4m3 or auto."
+ )
+
  if (
  self.attention_backend == "trtllm_mha"
  or self.decode_attention_backend == "trtllm_mha"
@@ -470,55 +479,9 @@ class ServerArgs:
  "trtllm_mha backend does not support speculative decoding yet."
  )

- model_arch = self.get_hf_config().architectures[0]
- if model_arch in ["GptOssForCausalLM"]:
- if self.attention_backend is None:
- # default is triton, but we could have trtllm_mha as an option
- self.attention_backend = "triton"
- assert (
- self.attention_backend == "trtllm_mha"
- or self.attention_backend == "triton"
- )
- quantization_config = getattr(
- self.get_hf_config(), "quantization_config", None
- )
- is_mxfp4_quant_format = (
- quantization_config is not None
- and quantization_config.get("quant_method") == "mxfp4"
- )
-
- if is_sm100_supported() and is_mxfp4_quant_format:
- self.enable_flashinfer_mxfp4_moe = True
- self.enable_triton_kernel_moe = False
- logger.info(
- "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
- )
- else:
- if self.enable_triton_kernel_moe:
- assert (
- self.ep_size == 1
- ), "Triton kernel MoE is only supported when ep_size == 1"
- if not self.enable_triton_kernel_moe and self.ep_size == 1:
- self.enable_triton_kernel_moe = True
- logger.info(
- "Detected GPT-OSS model, enabling triton_kernels MOE kernel."
- )
-
- self.disable_hybrid_swa_memory = True
-
- if is_mxfp4_quant_format:
- # use bf16 for mxfp4 triton kernels
- self.dtype = "bfloat16"
-
  if self.attention_backend == "dual_chunk_flash_attn":
  logger.warning(
- "Mixed chunk is disabled because of using dual chunk flash attention backend"
- )
- logger.warning(
- "Radix cache is disabled because of using dual chunk flash attention backend"
- )
- logger.warning(
- "Cuda graph is disabled because of using dual chunk flash attention backend"
+ "Mixed chunk, radix cache, and cuda graphs are disabled because of using dual chunk flash attention backend"
  )
  self.enable_mixed_chunk = False
  self.disable_cuda_graph = True
@@ -583,7 +546,7 @@ class ServerArgs:

  if self.enable_eplb and (self.expert_distribution_recorder_mode is None):
  self.expert_distribution_recorder_mode = "stat"
- logger.info(
+ logger.warning(
  "EPLB is enabled. The expert_distribution_recorder_mode is automatically set."
  )

@@ -591,9 +554,6 @@ class ServerArgs:
  self.ep_dispatch_algorithm is None
  ):
  self.ep_dispatch_algorithm = "static"
- logger.info(
- "EPLB is enabled or init_expert_location is provided. ep_dispatch_algorithm is configured."
- )

  if self.enable_eplb:
  assert self.ep_size > 1 or self.moe_a2a_backend is not None
@@ -616,6 +576,12 @@ class ServerArgs:
  "Pipeline parallelism is incompatible with overlap schedule."
  )

+ # Hicache
+ if self.hicache_storage_backend == "mooncake":
+ # to use mooncake storage backend, the following conditions must be met:
+ self.hicache_io_backend = "kernel"
+ self.hicache_mem_layout = "page_first"
+
  # Speculative Decoding
  if self.speculative_algorithm == "NEXTN":
  # NEXTN shares the same implementation of EAGLE
@@ -1112,7 +1078,7 @@ class ServerArgs:
  parser.add_argument(
  "--log-requests-level",
  type=int,
- default=0,
+ default=ServerArgs.log_requests_level,
  help="0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output.",
  choices=[0, 1, 2, 3],
  )
@@ -1198,6 +1164,12 @@ class ServerArgs:
  default=ServerArgs.served_model_name,
  help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
  )
+ parser.add_argument(
+ "--weight-version",
+ type=str,
+ default=ServerArgs.weight_version,
+ help="Version identifier for the model weights. Defaults to 'default' if not specified.",
+ )
  parser.add_argument(
  "--chat-template",
  type=str,
@@ -1231,7 +1203,7 @@ class ServerArgs:
  parser.add_argument(
  "--tool-call-parser",
  type=str,
- choices=[
+ choices=[ # TODO: use FunctionCallParser.DetectorMap.keys()
  "qwen25",
  "mistral",
  "llama3",
@@ -1241,10 +1213,17 @@ class ServerArgs:
  "qwen3_coder",
  "glm45",
  "step3",
+ "gpt-oss",
  ],
  default=ServerArgs.tool_call_parser,
  help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', 'qwen3_coder', 'glm45', and 'step3'.",
  )
+ parser.add_argument(
+ "--tool-server",
+ type=str,
+ default=None,
+ help="Either 'demo' or a comma-separated list of tool server urls to use for the model. If not specified, no tool server will be used.",
+ )

  # Data parallelism
  parser.add_argument(
@@ -1344,55 +1323,46 @@ class ServerArgs:
  )

  # Kernel backend
+ ATTN_BACKENDS = [
+ # Common
+ "triton",
+ "torch_native",
+ # NVIDIA specific
+ "cutlass_mla",
+ "fa3",
+ "flashinfer",
+ "flashmla",
+ "trtllm_mla",
+ "trtllm_mha",
+ "dual_chunk_flash_attn",
+ # AMD specific
+ "aiter",
+ "wave",
+ # Other platforms
+ "intel_amx",
+ "ascend",
+ ]
  parser.add_argument(
  "--attention-backend",
  type=str,
- choices=[
- "aiter",
- "cutlass_mla",
- "fa3",
- "flashinfer",
- "flashmla",
- "intel_amx",
- "torch_native",
- "ascend",
- "triton",
- "trtllm_mla",
- "trtllm_mha",
- "dual_chunk_flash_attn",
- ],
+ choices=ATTN_BACKENDS,
  default=ServerArgs.attention_backend,
  help="Choose the kernels for attention layers.",
  )
- parser.add_argument(
- "--decode-attention-backend",
- type=str,
- choices=[
- "flashinfer",
- "triton",
- "torch_native",
- "fa3",
- "flashmla",
- "cutlass_mla",
- ],
- default=ServerArgs.decode_attention_backend,
- help="Choose the kernels for decode attention layers (have priority over --attention-backend).",
- )
-
  parser.add_argument(
  "--prefill-attention-backend",
  type=str,
- choices=[
- "flashinfer",
- "triton",
- "torch_native",
- "fa3",
- "flashmla",
- "cutlass_mla",
- ],
+ choices=ATTN_BACKENDS,
  default=ServerArgs.prefill_attention_backend,
  help="Choose the kernels for prefill attention layers (have priority over --attention-backend).",
  )
+ parser.add_argument(
+ "--decode-attention-backend",
+ type=str,
+ choices=ATTN_BACKENDS,
+ default=ServerArgs.decode_attention_backend,
+ help="Choose the kernels for decode attention layers (have priority over --attention-backend).",
+ )
  parser.add_argument(
  "--sampling-backend",
  type=str,
@@ -1493,7 +1463,7 @@ class ServerArgs:
  parser.add_argument(
  "--enable-flashinfer-allreduce-fusion",
  action="store_true",
- help="Enable FlashInfer allreduce fusion for Add_RMSNorm.",
+ help="Enable FlashInfer allreduce fusion with Residual RMSNorm.",
  )
  parser.add_argument(
  "--deepep-mode",
@@ -1612,7 +1582,6 @@ class ServerArgs:
  default=ServerArgs.hicache_mem_layout,
  help="The layout of host memory pool for hierarchical cache.",
  )
-
  parser.add_argument(
  "--hicache-storage-backend",
  type=str,
@@ -1985,14 +1954,6 @@ class ServerArgs:
  help="Disable mmap while loading weight using safetensors.",
  )

- # For tool server
- parser.add_argument(
- "--tool-server",
- type=str,
- default=None,
- help="Either 'demo' or a comma-separated list of tool server urls to use for the model. If not specified, no tool server will be used.",
- )
-
  # Deprecated arguments
  parser.add_argument(
  "--enable-ep-moe",
@@ -2056,25 +2017,6 @@ class ServerArgs:
  None,
  }, "moe_dense_tp_size only support 1 and None currently"

- # Check model architecture
- model_arch = self.get_hf_config().architectures[0]
- if "Llama4" in model_arch:
- assert self.attention_backend == "fa3", "fa3 is required for Llama4 model"
-
- if model_arch in [
- "Gemma2ForCausalLM",
- "Gemma3ForCausalLM",
- "Gemma3ForConditionalGeneration",
- "Gemma3nForCausalLM",
- "Gemma3nForConditionalGeneration",
- ]:
- # FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
- # It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
- logger.warning(
- f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
- )
- self.disable_hybrid_swa_memory = True
-
  # Check LoRA
  self.check_lora_server_args()

@@ -2085,22 +2027,20 @@ class ServerArgs:
  ), "enable_mixed_chunk is required for speculative decoding"

  # Check chunked prefill
- assert (
- self.chunked_prefill_size % self.page_size == 0
- ), "chunked_prefill_size must be divisible by page_size"
+ # Skip validation if chunked prefill is disabled (i.e., size <= 0).
+ if self.chunked_prefill_size > 0:
+ assert (
+ self.chunked_prefill_size % self.page_size == 0
+ ), "chunked_prefill_size must be divisible by page_size"

  def check_lora_server_args(self):
- assert (
- self.max_loras_per_batch > 0
- # FIXME
- and (self.lora_paths is None or self.disable_radix_cache)
- ), "compatibility of lora and radix attention is in progress"
+ assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"

  # Enable LoRA if any LoRA paths are provided for backward compatibility.
  if self.lora_paths:
  if self.enable_lora is None:
  self.enable_lora = True
- logger.info(
+ logger.warning(
  "--enable-lora is set to True because --lora-paths is provided."
  )
  elif self.enable_lora is False:
@@ -2172,6 +2112,72 @@ class ServerArgs:
  f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
  )

+ def model_specific_adjustments(self):
+ hf_config = self.get_hf_config()
+ model_arch = hf_config.architectures[0]
+ if model_arch in ["GptOssForCausalLM"]:
+ if self.attention_backend is None:
+ if is_sm100_supported():
+ self.attention_backend = "trtllm_mha"
+ elif is_sm90_supported():
+ self.attention_backend = "fa3"
+ else:
+ self.attention_backend = "triton"
+ supported_backends = ["triton", "trtllm_mha", "fa3"]
+ logger.info(
+ f"Use {self.attention_backend} as attention backend for GptOssForCausalLM"
+ )
+ assert (
+ self.attention_backend in supported_backends
+ ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
+
+ if is_sm100_supported():
+ self.enable_flashinfer_allreduce_fusion = True
+ logger.info(
+ "Enable FlashInfer AllReduce Fusion on sm100 for GptOssForCausalLM"
+ )
+ quantization_config = getattr(hf_config, "quantization_config", None)
+ is_mxfp4_quant_format = (
+ quantization_config is not None
+ and quantization_config.get("quant_method") == "mxfp4"
+ )
+
+ if is_sm100_supported() and is_mxfp4_quant_format:
+ self.enable_flashinfer_mxfp4_moe = True
+ self.enable_triton_kernel_moe = False
+ logger.warning(
+ "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
+ )
+ else:
+ if self.enable_triton_kernel_moe:
+ assert (
+ self.ep_size == 1
+ ), "Triton kernel MoE is only supported when ep_size == 1"
+ if not self.enable_triton_kernel_moe and self.ep_size == 1:
+ self.enable_triton_kernel_moe = True
+ logger.warning(
+ "Detected GPT-OSS model, enabling triton_kernels MOE kernel."
+ )
+ self.disable_hybrid_swa_memory = True
+ if is_mxfp4_quant_format:
+ # use bf16 for mxfp4 triton kernels
+ self.dtype = "bfloat16"
+ elif "Llama4" in model_arch:
+ assert self.attention_backend == "fa3", "fa3 is required for Llama4 model"
+ elif model_arch in [
+ "Gemma2ForCausalLM",
+ "Gemma3ForCausalLM",
+ "Gemma3ForConditionalGeneration",
+ "Gemma3nForCausalLM",
+ "Gemma3nForConditionalGeneration",
+ ]:
+ # FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
+ # It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
+ logger.warning(
+ f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
+ )
+ self.disable_hybrid_swa_memory = True
+
  def adjust_mem_fraction_for_vlm(self, model_config):
  vision_config = getattr(model_config.hf_config, "vision_config", None)
  if vision_config is None:
@@ -2209,10 +2215,6 @@ class ServerArgs:
  self.mem_fraction_static = (
  original_server_arg_mem_fraction * final_overall_factor
  )
- logger.warning(
- f"Multimodal model: Dynamically adjusted --mem-fraction-static "
- f"from: {original_server_arg_mem_fraction:.3f} to: {self.mem_fraction_static:.3f}."
- )


  def prepare_server_args(argv: List[str]) -> ServerArgs:
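
Taken together, the server_args.py hunks move the per-architecture logic into a new model_specific_adjustments() hook (GPT-OSS now defaults to trtllm_mha on SM100, fa3 on SM90, and triton otherwise), promote --weight-version and --tool-server to regular API arguments, share one ATTN_BACKENDS list across the three backend flags, and only enforce the page-size divisibility check when chunked prefill is actually enabled. A standalone sketch of the last two behaviors, using illustrative helper names that are not part of the sglang API:

    # Sketch of the new GPT-OSS default-backend choice and the relaxed
    # chunked-prefill check; helper names below are illustrative only.
    def pick_gpt_oss_attention_backend(sm100_supported: bool, sm90_supported: bool) -> str:
        # Mirrors model_specific_adjustments(): prefer trtllm_mha on Blackwell (SM100),
        # fa3 on Hopper (SM90), and fall back to triton elsewhere.
        if sm100_supported:
            return "trtllm_mha"
        if sm90_supported:
            return "fa3"
        return "triton"

    def check_chunked_prefill(chunked_prefill_size: int, page_size: int) -> None:
        # A non-positive size means chunked prefill is disabled, so skip the check.
        if chunked_prefill_size > 0:
            assert (
                chunked_prefill_size % page_size == 0
            ), "chunked_prefill_size must be divisible by page_size"

    print(pick_gpt_oss_attention_backend(sm100_supported=False, sm90_supported=True))  # fa3
    check_chunked_prefill(-1, 64)  # no error: chunked prefill disabled, check skipped
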
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py CHANGED
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Callable

  import torch

- from sglang.srt.layers.dp_attention import DPPaddingMode
+ from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
  from sglang.srt.model_executor.cuda_graph_runner import (
  CUDA_GRAPH_CAPTURE_FAILED_MSG,
  CudaGraphRunner,
@@ -105,30 +105,15 @@ class EAGLEDraftCudaGraphRunner:
  self.global_num_tokens_for_logprob_gpu = torch.zeros(
  (self.dp_size,), dtype=torch.int32
  )
- self.gathered_buffer = torch.zeros(
- (
- self.max_num_token * self.dp_size,
- self.model_runner.model_config.hidden_size,
- ),
- dtype=self.model_runner.dtype,
- )
  else:
  assert self.require_attn_tp_gather
  self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
  self.global_num_tokens_for_logprob_gpu = torch.zeros(
  (1,), dtype=torch.int32
  )
- self.gathered_buffer = torch.zeros(
- (
- self.max_num_token,
- self.model_runner.model_config.hidden_size,
- ),
- dtype=self.model_runner.dtype,
- )
  else:
  self.global_num_tokens_gpu = None
  self.global_num_tokens_for_logprob_gpu = None
- self.gathered_buffer = None

  # Capture
  try:
@@ -193,7 +178,7 @@ class EAGLEDraftCudaGraphRunner:
  )
  )
  global_num_tokens = self.global_num_tokens_gpu
- gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
+ global_dp_buffer_len = num_tokens * self.dp_size
  global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
  elif self.require_attn_tp_gather:
  self.global_num_tokens_gpu.copy_(
@@ -211,11 +196,11 @@ class EAGLEDraftCudaGraphRunner:
  )
  )
  global_num_tokens = self.global_num_tokens_gpu
- gathered_buffer = self.gathered_buffer[:num_tokens]
+ global_dp_buffer_len = num_tokens
  global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
  else:
  global_num_tokens = None
- gathered_buffer = None
+ global_dp_buffer_len = None
  global_num_tokens_for_logprob = None

  spec_info = EagleDraftInput(
@@ -239,8 +224,8 @@ class EAGLEDraftCudaGraphRunner:
  return_logprob=False,
  positions=positions,
  global_num_tokens_gpu=global_num_tokens,
- dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
- gathered_buffer=gathered_buffer,
+ dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+ global_dp_buffer_len=global_dp_buffer_len,
  spec_algorithm=self.model_runner.spec_algorithm,
  spec_info=spec_info,
  capture_hidden_mode=(
@@ -258,6 +243,7 @@ class EAGLEDraftCudaGraphRunner:
  def run_once():
  # Clean intermediate result cache for DP attention
  forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+ set_dp_buffer_len(global_dp_buffer_len, num_tokens)

  # Backup two fields, which will be modified in-place in `draft_forward`.
  output_cache_loc_backup = forward_batch.out_cache_loc
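
In this runner (and the draft-extend runner below) the preallocated gathered_buffer of shape (max_num_token * dp_size, hidden_size) is removed; the runner now records only the required length as global_dp_buffer_len, passes it into the ForwardBatch, and registers it with set_dp_buffer_len() before replay, apparently deferring the actual buffer handling to the shared dp_attention module (which itself changes by +118/-27 in this release). A toy before/after sketch of the pattern, with stand-in sizes and a stub in place of the real set_dp_buffer_len:

    import torch

    HIDDEN_SIZE, MAX_TOKENS, DP_SIZE, num_tokens = 16, 8, 2, 4

    # 0.5.0rc0 pattern: each CUDA-graph runner preallocated its own buffer and sliced it.
    gathered_buffer = torch.zeros(MAX_TOKENS * DP_SIZE, HIDDEN_SIZE)
    old_view = gathered_buffer[: num_tokens * DP_SIZE]

    # 0.5.0rc2 pattern: the runner records only the required length and hands it to
    # shared state; the stub below stands in for
    # sglang.srt.layers.dp_attention.set_dp_buffer_len.
    _state = {}
    def set_dp_buffer_len(global_dp_buffer_len: int, local_num_tokens: int) -> None:
        _state["global"], _state["local"] = global_dp_buffer_len, local_num_tokens

    global_dp_buffer_len = num_tokens * DP_SIZE
    set_dp_buffer_len(global_dp_buffer_len, num_tokens)
    print(old_view.shape, _state)  # torch.Size([8, 16]) {'global': 8, 'local': 4}
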
sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py CHANGED
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Callable

  import torch

- from sglang.srt.layers.dp_attention import DPPaddingMode
+ from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
  from sglang.srt.model_executor.cuda_graph_runner import (
  CUDA_GRAPH_CAPTURE_FAILED_MSG,
  CudaGraphRunner,
@@ -117,30 +117,15 @@ class EAGLEDraftExtendCudaGraphRunner:
  self.global_num_tokens_for_logprob_gpu = torch.zeros(
  (self.dp_size,), dtype=torch.int32
  )
- self.gathered_buffer = torch.zeros(
- (
- self.max_num_token * self.dp_size,
- self.model_runner.model_config.hidden_size,
- ),
- dtype=self.model_runner.dtype,
- )
  else:
  assert self.require_attn_tp_gather
  self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
  self.global_num_tokens_for_logprob_gpu = torch.zeros(
  (1,), dtype=torch.int32
  )
- self.gathered_buffer = torch.zeros(
- (
- self.max_num_token,
- self.model_runner.model_config.hidden_size,
- ),
- dtype=self.model_runner.dtype,
- )
  else:
  self.global_num_tokens_gpu = None
  self.global_num_tokens_for_logprob_gpu = None
- self.gathered_buffer = None

  if hasattr(
  self.model_runner.model_config.hf_config, "draft_vocab_size"
@@ -222,7 +207,7 @@ class EAGLEDraftExtendCudaGraphRunner:
  device=self.input_ids.device,
  )
  )
- gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
+ global_dp_buffer_len = num_tokens * self.dp_size
  elif self.require_attn_tp_gather:
  self.global_num_tokens_gpu.copy_(
  torch.tensor(
@@ -238,9 +223,9 @@ class EAGLEDraftExtendCudaGraphRunner:
  device=self.input_ids.device,
  )
  )
- gathered_buffer = self.gathered_buffer[:num_tokens]
+ global_dp_buffer_len = num_tokens
  else:
- gathered_buffer = None
+ global_dp_buffer_len = None

  spec_info = EagleDraftInput(
  hidden_states=hidden_states,
@@ -264,8 +249,8 @@ class EAGLEDraftExtendCudaGraphRunner:
  positions=positions,
  global_num_tokens_gpu=self.global_num_tokens_gpu,
  global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
- dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
- gathered_buffer=gathered_buffer,
+ dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+ global_dp_buffer_len=global_dp_buffer_len,
  spec_algorithm=self.model_runner.spec_algorithm,
  spec_info=spec_info,
  capture_hidden_mode=CaptureHiddenMode.LAST,
@@ -288,6 +273,7 @@ class EAGLEDraftExtendCudaGraphRunner:
  def run_once():
  # Clean intermediate result cache for DP attention
  forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+ set_dp_buffer_len(global_dp_buffer_len, num_tokens)

  # Backup two fields, which will be modified in-place in `draft_forward`.
  output_cache_loc_backup = forward_batch.out_cache_loc
sglang/srt/speculative/eagle_worker.py CHANGED
@@ -226,6 +226,22 @@ class EAGLEWorker(TpModelWorker):
  self.draft_model_runner,
  skip_prefill=False,
  )
+ elif self.server_args.attention_backend == "aiter":
+ from sglang.srt.layers.attention.aiter_backend import (
+ AiterAttnBackend,
+ AiterMultiStepDraftBackend,
+ )
+
+ self.draft_attn_backend = AiterMultiStepDraftBackend(
+ self.draft_model_runner,
+ self.topk,
+ self.speculative_num_steps,
+ )
+ self.draft_extend_attn_backend = AiterAttnBackend(
+ self.draft_model_runner,
+ skip_prefill=False,
+ )
+ self.has_prefill_wrapper_verify = False
  elif self.server_args.attention_backend == "fa3":
  from sglang.srt.layers.attention.flashattention_backend import (
  FlashAttentionBackend,
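
This hunk registers AMD's aiter backend for EAGLE speculative decoding alongside the existing branches: the worker builds an AiterMultiStepDraftBackend for the multi-step draft passes, an AiterAttnBackend for draft extend, and sets has_prefill_wrapper_verify to False. A structural sketch of that dispatch, with placeholder constructors standing in for the real backend classes:

    # Placeholder constructors; the real classes live in
    # sglang.srt.layers.attention.aiter_backend.
    def AiterMultiStepDraftBackend(runner, topk, num_steps):
        return ("aiter-multi-step-draft", topk, num_steps)

    def AiterAttnBackend(runner, skip_prefill):
        return ("aiter-attn", skip_prefill)

    def build_draft_backends(attention_backend: str, runner=None, topk=1, num_steps=3):
        # Mirrors the elif branch added to EAGLEWorker for the "aiter" backend.
        if attention_backend == "aiter":
            draft = AiterMultiStepDraftBackend(runner, topk, num_steps)
            draft_extend = AiterAttnBackend(runner, skip_prefill=False)
            has_prefill_wrapper_verify = False
            return draft, draft_extend, has_prefill_wrapper_verify
        raise ValueError(f"backend {attention_backend!r} not covered in this sketch")

    print(build_draft_backends("aiter"))
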