sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (203)
  1. sglang/bench_one_batch.py +0 -7
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +25 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -2
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +29 -4
  24. sglang/srt/entrypoints/http_server.py +76 -0
  25. sglang/srt/entrypoints/openai/protocol.py +4 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +23 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +10 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +14 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
  37. sglang/srt/layers/attention/triton_backend.py +109 -73
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +58 -10
  46. sglang/srt/layers/dp_attention.py +137 -27
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +16 -18
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  71. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  72. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  73. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  75. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  76. sglang/srt/layers/moe/router.py +15 -9
  77. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  78. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  79. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  80. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  81. sglang/srt/layers/moe/topk.py +167 -83
  82. sglang/srt/layers/moe/utils.py +159 -18
  83. sglang/srt/layers/multimodal.py +156 -40
  84. sglang/srt/layers/quantization/__init__.py +18 -46
  85. sglang/srt/layers/quantization/awq.py +22 -23
  86. sglang/srt/layers/quantization/base_config.py +2 -6
  87. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  88. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
  89. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  90. sglang/srt/layers/quantization/fp8.py +127 -119
  91. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  92. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  93. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  94. sglang/srt/layers/quantization/gptq.py +17 -21
  95. sglang/srt/layers/quantization/marlin_utils.py +26 -8
  96. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  97. sglang/srt/layers/quantization/modelopt_quant.py +217 -98
  98. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  99. sglang/srt/layers/quantization/mxfp4.py +222 -39
  100. sglang/srt/layers/quantization/quark/quark.py +390 -0
  101. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  102. sglang/srt/layers/quantization/unquant.py +34 -70
  103. sglang/srt/layers/quantization/utils.py +77 -2
  104. sglang/srt/layers/quantization/w4afp8.py +7 -8
  105. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  106. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  107. sglang/srt/layers/radix_attention.py +6 -0
  108. sglang/srt/layers/rotary_embedding.py +1 -0
  109. sglang/srt/layers/sampler.py +5 -2
  110. sglang/srt/lora/layers.py +6 -2
  111. sglang/srt/lora/lora_manager.py +21 -22
  112. sglang/srt/lora/lora_registry.py +3 -3
  113. sglang/srt/lora/mem_pool.py +26 -24
  114. sglang/srt/lora/utils.py +10 -12
  115. sglang/srt/managers/cache_controller.py +80 -19
  116. sglang/srt/managers/detokenizer_manager.py +10 -2
  117. sglang/srt/managers/io_struct.py +23 -0
  118. sglang/srt/managers/mm_utils.py +1 -1
  119. sglang/srt/managers/schedule_batch.py +22 -48
  120. sglang/srt/managers/scheduler.py +28 -20
  121. sglang/srt/managers/session_controller.py +1 -1
  122. sglang/srt/managers/template_manager.py +7 -5
  123. sglang/srt/managers/tokenizer_manager.py +88 -39
  124. sglang/srt/managers/tp_worker.py +1 -0
  125. sglang/srt/managers/utils.py +59 -1
  126. sglang/srt/mem_cache/allocator.py +10 -157
  127. sglang/srt/mem_cache/allocator_ascend.py +147 -0
  128. sglang/srt/mem_cache/chunk_cache.py +1 -1
  129. sglang/srt/mem_cache/hicache_storage.py +14 -4
  130. sglang/srt/mem_cache/memory_pool.py +3 -3
  131. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  132. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  133. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  134. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  135. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  136. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  137. sglang/srt/model_executor/cuda_graph_runner.py +33 -33
  138. sglang/srt/model_executor/forward_batch_info.py +11 -10
  139. sglang/srt/model_executor/model_runner.py +93 -78
  140. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  141. sglang/srt/model_loader/loader.py +24 -6
  142. sglang/srt/models/dbrx.py +12 -6
  143. sglang/srt/models/deepseek.py +2 -1
  144. sglang/srt/models/deepseek_nextn.py +5 -2
  145. sglang/srt/models/deepseek_v2.py +226 -223
  146. sglang/srt/models/ernie4.py +2 -2
  147. sglang/srt/models/glm4_moe.py +27 -65
  148. sglang/srt/models/glm4_moe_nextn.py +2 -1
  149. sglang/srt/models/glm4v.py +52 -1
  150. sglang/srt/models/glm4v_moe.py +8 -11
  151. sglang/srt/models/gpt_oss.py +41 -76
  152. sglang/srt/models/granitemoe.py +0 -1
  153. sglang/srt/models/grok.py +376 -48
  154. sglang/srt/models/interns1.py +12 -47
  155. sglang/srt/models/internvl.py +6 -51
  156. sglang/srt/models/llama.py +10 -2
  157. sglang/srt/models/llama4.py +18 -7
  158. sglang/srt/models/minicpm3.py +0 -1
  159. sglang/srt/models/mixtral.py +0 -2
  160. sglang/srt/models/nemotron_nas.py +435 -0
  161. sglang/srt/models/olmoe.py +0 -1
  162. sglang/srt/models/phi4mm.py +3 -21
  163. sglang/srt/models/qwen2.py +2 -2
  164. sglang/srt/models/qwen2_5_vl.py +2 -0
  165. sglang/srt/models/qwen2_moe.py +23 -23
  166. sglang/srt/models/qwen3.py +2 -2
  167. sglang/srt/models/qwen3_classification.py +84 -0
  168. sglang/srt/models/qwen3_moe.py +27 -43
  169. sglang/srt/models/step3_vl.py +8 -3
  170. sglang/srt/models/xverse_moe.py +11 -5
  171. sglang/srt/multimodal/processors/base_processor.py +3 -3
  172. sglang/srt/multimodal/processors/internvl.py +7 -2
  173. sglang/srt/multimodal/processors/llava.py +11 -7
  174. sglang/srt/offloader.py +433 -0
  175. sglang/srt/operations.py +22 -2
  176. sglang/srt/reasoning_parser.py +4 -3
  177. sglang/srt/sampling/sampling_batch_info.py +7 -4
  178. sglang/srt/server_args.py +264 -105
  179. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
  180. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  181. sglang/srt/speculative/eagle_utils.py +36 -13
  182. sglang/srt/speculative/eagle_worker.py +56 -3
  183. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  184. sglang/srt/two_batch_overlap.py +20 -19
  185. sglang/srt/utils.py +68 -70
  186. sglang/test/runners.py +8 -5
  187. sglang/test/test_block_fp8.py +5 -6
  188. sglang/test/test_block_fp8_ep.py +13 -19
  189. sglang/test/test_cutlass_moe.py +4 -6
  190. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  191. sglang/test/test_fp4_moe.py +4 -3
  192. sglang/test/test_marlin_moe.py +1 -1
  193. sglang/test/test_marlin_utils.py +1 -1
  194. sglang/test/test_utils.py +7 -0
  195. sglang/utils.py +0 -1
  196. sglang/version.py +1 -1
  197. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
  198. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
  199. sglang/srt/layers/quantization/fp4.py +0 -557
  200. sglang/srt/layers/quantization/scalar_type.py +0 -352
  201. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  202. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  203. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/utils.py
@@ -9,18 +9,89 @@ TRITON_PAD_NUM_PAGE_PER_BLOCK = 64
 
 @triton.jit
 def create_flashinfer_kv_indices_triton(
-    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_to_token_ptr,
     req_pool_indices_ptr,
     page_kernel_lens_ptr,
     kv_indptr,
     kv_start_idx,
     kv_indices_ptr,
     req_to_token_ptr_stride: tl.constexpr,
+    PAGE_SIZE: tl.constexpr = 1,
 ):
+    """
+    Create KV indices for the FlashInfer attention backend.
+
+    This Triton kernel builds a lookup table that maps logical request/token
+    coordinates to physical token locations in the global KV cache pool. It is
+    used by FlashInfer attention backends to efficiently access scattered KV
+    cache data.
+
+    The kernel processes each request in parallel and converts the req_to_token
+    lookup table into a flat list of token indices that attention kernels can use.
+
+    General idea:
+        blocktables/kv_indices_ptr = [batch_size * max_pages] (for graph mode
+        with a fixed number of pages), where max_pages = max_context_len / PAGE_SIZE.
+        kv_indices_ptr stores the flat list of the pages used by each request.
+
+    Args:
+        Input arguments (not mutated):
+
+        req_to_token_ptr: Request-to-token-location lookup table.
+            Shape: [max_batch, max_context_len]
+        req_pool_indices_ptr: Request-to-pool-index lookup table. Each request
+            uses one pool.
+            Shape: [batch_size]
+        page_kernel_lens_ptr: Sequence length of each request.
+            Shape: [batch_size]
+        kv_indptr: Per-request start index into kv_indices_ptr, computed from
+            the number of pages used by each request. FlashInfer attention
+            kernels use it to index into kv_indices_ptr.
+            Shape: [batch_size + 1]
+            kv_indptr[i] = start index in kv_indices for request i
+        kv_start_idx: Pointer to an array of per-request start offsets.
+            Can be None. If provided, the offset is added to token positions.
+
+        req_to_token_ptr_stride: Stride of the second dimension of req_to_token.
+            Equal to max_context_len.
+
+        PAGE_SIZE: Number of tokens per page. Default is 1 for FlashInfer.
+
+    Outputs:
+        kv_indices_ptr: Pointer to the output array where KV indices are stored.
+            Shape: [total_num_pages],
+            where total_num_pages = sum(ceil(seq_len / PAGE_SIZE)) over requests
+
+    Example:
+        If we have:
+        - req_pool_indices = [0, 1] (request 0 uses pool 0, request 1 uses pool 1)
+        - page_kernel_lens = [3, 2] (request 0 has 3 tokens, request 1 has 2 tokens)
+        - req_to_token = [[10, 11, 12, -1], [20, 21, -1, -1]] (the stored values
+          are token locations in the KV cache pool; -1 marks unused slots)
+
+        The kernel will output:
+        If PAGE_SIZE = 1:
+            packed:
+            - kv_indptr (passed in as input arg): [0, 3, 5]
+            - kv_indices = [10, 11, 12, 20, 21]
+            padded (max_pages is 10 pages per request):
+            - kv_indptr (passed in as input arg): [0, 10, 20]
+            - kv_indices = [10, 11, 12, -1, -1, -1, -1, -1, -1, -1,
+                            20, 21, -1, -1, -1, -1, -1, -1, -1, -1]
+
+        If PAGE_SIZE = 2:
+            packed:
+            - kv_indptr (passed in as input arg): [0, 2, 3]
+            - kv_indices = [5, 6, 10]
+            padded (max_pages is 4):
+            - kv_indptr (passed in as input arg): [0, 4, 8] (note that 4 is max_pages)
+            - kv_indices = [5, 6, -1, -1,
+                            10, -1, -1, -1]
+
+    This allows attention kernels to directly access the correct KV cache
+    entries for each request's tokens.
+    """
     BLOCK_SIZE: tl.constexpr = 512
+    NUM_PAGES_PER_BLOCK: tl.constexpr = BLOCK_SIZE // PAGE_SIZE
     pid = tl.program_id(axis=0)
-
-    # find the req pool idx, this is for batch to token
     req_pool_index = tl.load(req_pool_indices_ptr + pid)
     kv_indices_offset = tl.load(kv_indptr + pid)
 
@@ -31,19 +102,27 @@ def create_flashinfer_kv_indices_triton(
     kv_end = kv_start
     kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
 
-    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
-    for i in range(num_loop):
-        # index into req_to_token_ptr needs to be int64
-        offset = tl.arange(0, BLOCK_SIZE).to(tl.int64) + i * BLOCK_SIZE
-        mask = offset < kv_end - kv_start
-        data = tl.load(
-            req_to_token_ptr
-            + req_pool_index * req_to_token_ptr_stride
-            + kv_start
-            + offset,
-            mask=mask,
+    kv_range = kv_end - kv_start
+    num_pages = tl.cdiv(kv_range, PAGE_SIZE)
+    num_loops = tl.cdiv(kv_range, BLOCK_SIZE)
+    req_to_token_block_start = (
+        req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + kv_start
+    )
+    for i in range(num_loops):
+        token_offsets_in_block = (
+            tl.arange(0, NUM_PAGES_PER_BLOCK).to(tl.int64) + i * NUM_PAGES_PER_BLOCK
+        ) * PAGE_SIZE
+        page_offsets_in_block = token_offsets_in_block // PAGE_SIZE
+        valid_tokens = token_offsets_in_block < kv_range
+        valid_pages = page_offsets_in_block < num_pages
+        token_numbers = tl.load(
+            req_to_token_block_start + token_offsets_in_block, mask=valid_tokens
+        )
+        tl.store(
+            kv_indices_ptr + kv_indices_offset + page_offsets_in_block,
+            token_numbers // PAGE_SIZE,  # write the page numbers to kv_indices_ptr
+            mask=valid_pages,
         )
-        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
 
 
 @triton.jit
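
For readers who want to sanity-check the paging math, here is a host-side PyTorch sketch of what the rewritten kernel computes. `build_kv_indices_reference` is a hypothetical helper written for this note, not part of the diff:

import torch

def build_kv_indices_reference(
    req_to_token: torch.Tensor,      # [max_batch, max_context_len]
    req_pool_indices: torch.Tensor,  # [batch_size]
    seq_lens: torch.Tensor,          # [batch_size]
    kv_indptr: torch.Tensor,         # [batch_size + 1]
    kv_indices: torch.Tensor,        # flat output, one entry per page
    page_size: int = 1,
) -> None:
    for i in range(seq_lens.numel()):
        pool = int(req_pool_indices[i])
        tokens = req_to_token[pool, : int(seq_lens[i])]
        # The first token of every page, divided by page_size, is the page number.
        pages = tokens[::page_size] // page_size
        start = int(kv_indptr[i])
        kv_indices[start : start + pages.numel()] = pages

# The PAGE_SIZE = 2 example from the docstring:
req_to_token = torch.tensor([[10, 11, 12, -1], [20, 21, -1, -1]])
kv_indices = torch.full((3,), -1)
build_kv_indices_reference(
    req_to_token,
    torch.tensor([0, 1]),   # req_pool_indices
    torch.tensor([3, 2]),   # seq_lens
    torch.tensor([0, 2, 3]),  # kv_indptr
    kv_indices,
    page_size=2,
)
assert kv_indices.tolist() == [5, 6, 10]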
sglang/srt/layers/attention/vision.py
@@ -12,7 +12,12 @@ import torch.nn.functional as F
 from einops import rearrange
 
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
-from sglang.srt.utils import is_cuda, print_info_once
+from sglang.srt.utils import (
+    get_device_capability,
+    is_blackwell,
+    is_cuda,
+    print_info_once,
+)
 
 _is_cuda = is_cuda()
 
@@ -20,7 +25,6 @@ if _is_cuda:
     from sgl_kernel.flash_attn import flash_attn_varlen_func
 
 from sglang.srt.distributed import (
-    parallel_state,
     split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
 )
@@ -402,18 +406,14 @@ class VisionAttention(nn.Module):
             self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
         )
 
-        # priority: server_args > passed qkv_backend > sdpa
-        if global_server_args_dict["mm_attention_backend"] is None:
-            if qkv_backend is None:
-                if is_cuda():
-                    # Double prefill throughput by setting attn backend to Triton on CUDA
-                    qkv_backend = "triton_attn"
-                else:
-                    qkv_backend = "sdpa"
+        # Select attention backend via a unified method
+        _passed_backend = qkv_backend
+        qkv_backend = self._determine_attention_backend(_passed_backend)
+        if (
+            global_server_args_dict["mm_attention_backend"] is None
+            and _passed_backend is None
+        ):
             print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
-        else:
-            qkv_backend = global_server_args_dict["mm_attention_backend"]
-
         print_info_once(f"Using {qkv_backend} as multimodal attention backend.")
 
         self.customized_position_embedding_applier = (
@@ -461,6 +461,33 @@
             prefix=add_prefix("proj", prefix),
         )
 
+    def _determine_attention_backend(self, passed_backend: Optional[str]) -> str:
+        """Decide the multimodal attention backend string.
+
+        Priority: server args override > constructor arg > platform default.
+
+        Platform defaults:
+        - CUDA, SM90 (Hopper): "fa3"
+        - Other CUDA: "triton_attn"
+        - Non-CUDA: "sdpa"
+        """
+        override_backend = global_server_args_dict["mm_attention_backend"]
+        if override_backend is not None:
+            backend = override_backend
+        elif passed_backend is not None:
+            backend = passed_backend
+        elif is_cuda():
+            major, minor = get_device_capability()
+            if major == 9:
+                backend = "fa3"
+            else:
+                backend = "triton_attn"
+        else:
+            backend = "sdpa"
+        if backend == "fa3" and is_blackwell():
+            raise ValueError("The 'fa3' backend is not supported on Blackwell GPUs")
+
+        return backend
+
     def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
         """apply qk norm for internvl vit attn"""
         q = q.flatten(1, 2)
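
As a sanity check on the priority order, here is a self-contained toy mirror of the helper's logic, written for this note and not part of the diff:

from typing import Optional

def resolve_mm_backend(
    server_override: Optional[str],
    constructor_arg: Optional[str],
    on_cuda: bool,
    sm_major: int = 0,
) -> str:
    if server_override is not None:   # --mm-attention-backend wins
        return server_override
    if constructor_arg is not None:   # then the qkv_backend constructor arg
        return constructor_arg
    if on_cuda:                       # then the platform default
        return "fa3" if sm_major == 9 else "triton_attn"
    return "sdpa"

assert resolve_mm_backend("sdpa", "triton_attn", True, 9) == "sdpa"
assert resolve_mm_backend(None, None, True, 9) == "fa3"          # Hopper default
assert resolve_mm_backend(None, None, True, 8) == "triton_attn"
assert resolve_mm_backend(None, None, False) == "sdpa"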
sglang/srt/layers/attention/vision_utils.py (new file)
@@ -0,0 +1,65 @@
+"""Utility functions for vision attention layers."""
+
+import torch
+
+from sglang.srt.layers.dp_attention import get_attention_tp_size
+
+
+def update_vit_attn_dummy_heads_config(config):
+    """Update HF config to ensure vision attention num_attention_heads is divisible by tp_size"""
+    tp_size = get_attention_tp_size()
+    num_heads = getattr(
+        config.vision_config,
+        "num_heads",
+        getattr(config.vision_config, "num_attention_heads", None),
+    )
+    head_dim = config.vision_config.hidden_size // num_heads
+    num_dummy_heads = 0
+
+    if num_heads % tp_size != 0:
+        num_dummy_heads = ((num_heads + tp_size - 1) // tp_size) * tp_size - num_heads
+
+    setattr(config.vision_config, "head_dim", head_dim)
+    setattr(config.vision_config, "num_dummy_heads", num_dummy_heads)
+
+
+def pad_vit_attn_dummy_heads(config, name: str, loaded_weight: torch.Tensor):
+    """Pad attention qkv weights for dummy heads"""
+    num_dummy_heads = config.vision_config.num_dummy_heads
+    if num_dummy_heads == 0:
+        return loaded_weight
+    head_dim = config.vision_config.head_dim
+
+    if "attn.qkv_proj" in name:
+        wq, wk, wv = loaded_weight.chunk(3, dim=0)
+        if name.endswith(".weight"):
+            dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
+        elif name.endswith(".bias"):
+            dummy_shape = [num_dummy_heads, head_dim]
+        else:
+            raise RuntimeError(f"Unsupported weight with name={name}")
+        pad_func = lambda x: torch.cat(
+            [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
+        ).flatten(0, 1)
+        wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
+        loaded_weight = torch.cat([wq, wk, wv], dim=0)
+    elif any([_ in name for _ in ["attn.q_proj", "attn.k_proj", "attn.v_proj"]]):
+        if name.endswith(".weight"):
+            dummy_shape = [num_dummy_heads, head_dim, loaded_weight.shape[-1]]
+        elif name.endswith(".bias"):
+            dummy_shape = [num_dummy_heads, head_dim]
+        else:
+            raise RuntimeError(f"Unsupported weight with name={name}")
+        padded_weight = loaded_weight.new_zeros(dummy_shape)
+        loaded_weight = torch.cat(
+            [loaded_weight.unflatten(0, (-1, head_dim)), padded_weight], dim=0
+        ).flatten(0, 1)
+    elif "attn.proj.weight" in name:
+        padded_weight = loaded_weight.new_zeros(
+            loaded_weight.shape[0], head_dim * num_dummy_heads
+        )
+        loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
+    elif "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
+        padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
+        loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
+    return loaded_weight
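
A quick worked example of the padding math above. The 25-head, 3200-wide ViT numbers are illustrative (in the spirit of InternVL's vision tower), not taken from the diff:

# num_heads=25 does not divide tp_size=8, so pad up to the next multiple of 8.
num_heads, tp_size, hidden_size = 25, 8, 3200
head_dim = hidden_size // num_heads                                # 128
num_dummy_heads = ((num_heads + tp_size - 1) // tp_size) * tp_size - num_heads
assert num_dummy_heads == 7                                        # 32 - 25
assert (num_heads + num_dummy_heads) % tp_size == 0                # 4 heads per rank
# pad_vit_attn_dummy_heads then zero-pads each of wq/wk/wv with
# num_dummy_heads * head_dim = 896 extra rows (and attn.proj with 896 extra
# columns), so the zero heads contribute nothing to the output.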
sglang/srt/layers/communicator.py
@@ -17,7 +17,7 @@ from enum import Enum, auto
 from functools import partial
 from typing import Dict, Optional
 
-import torch.distributed
+import torch
 
 from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
@@ -32,6 +32,13 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
+    get_global_dp_buffer,
+    get_local_dp_buffer,
+    is_dp_attention_enabled,
+)
+from sglang.srt.layers.moe import (
+    get_moe_a2a_backend,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -41,6 +48,8 @@ from sglang.srt.utils import is_cuda, is_flashinfer_available
 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
 
+FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
+
 
 class ScatterMode(Enum):
     """
@@ -109,7 +118,11 @@ class LayerScatterModes:
         if context.is_layer_sparse:
             return (
                 ScatterMode.SCATTERED
-                if not global_server_args_dict["moe_a2a_backend"].is_standard()
+                if (
+                    # Token dispatch/combine will be handled outside of LayerCommunicator for these modes.
+                    not get_moe_a2a_backend().is_none()
+                    or should_use_flashinfer_cutlass_moe_fp4_allgather()
+                )
                 else ScatterMode.FULL
             )
         else:
@@ -152,11 +165,13 @@
         post_attention_layernorm: torch.nn.Module,
         # Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable for models which have that implemented. Remove flag once done for all models that use LayerCommunicator.
         allow_reduce_scatter: bool = False,
+        is_last_layer: bool = False,
     ):
         self.layer_scatter_modes = layer_scatter_modes
         self.input_layernorm = input_layernorm
         self.post_attention_layernorm = post_attention_layernorm
         self.allow_reduce_scatter = allow_reduce_scatter
+        self.is_last_layer = is_last_layer
 
         self._context = CommunicateContext.init_new()
         self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
@@ -254,6 +269,41 @@
             and forward_batch.dp_padding_mode.is_max_len()
         )
 
+    def should_fuse_mlp_allreduce_with_next_layer(
+        self, forward_batch: ForwardBatch
+    ) -> bool:
+        speculative_algo = global_server_args_dict.get("speculative_algorithm", None)
+        if (
+            is_dp_attention_enabled()
+            and speculative_algo is not None
+            and speculative_algo.is_eagle()
+        ):
+            return False
+
+        batch_size = (
+            forward_batch.input_ids.shape[0]
+            if hasattr(forward_batch, "input_ids")
+            else 0
+        )
+        if batch_size > FUSE_ALLREDUCE_MAX_BATCH_SIZE:
+            return False
+
+        static_conditions_met = (
+            (not self.is_last_layer)
+            and (self._context.tp_size > 1)
+            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
+            and _is_flashinfer_available
+        )
+
+        if not static_conditions_met:
+            return False
+
+        return (
+            batch_size > 0
+            and batch_size <= FUSE_ALLREDUCE_MAX_BATCH_SIZE
+            and (not self.is_last_layer)
+        )
+
 
 @dataclass
 class CommunicateContext:
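
For context, a self-contained toy mirror of the gate above (illustrative only; it omits the EAGLE speculative-decoding early-out and reuses the 2048-token cap from the diff):

FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048

def can_fuse(batch_size: int, tp_size: int, flashinfer_ok: bool,
             fusion_enabled: bool, is_last_layer: bool) -> bool:
    if batch_size <= 0 or batch_size > FUSE_ALLREDUCE_MAX_BATCH_SIZE:
        return False
    return (not is_last_layer) and tp_size > 1 and fusion_enabled and flashinfer_ok

assert can_fuse(512, 8, True, True, False)       # typical batch: fuse allreduce+norm
assert not can_fuse(4096, 8, True, True, False)  # too large: plain all-reduce
assert not can_fuse(512, 1, True, True, False)   # no TP: nothing to reduce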
319
369
  context: CommunicateContext,
320
370
  ) -> torch.Tensor:
321
371
  hidden_states, local_hidden_states = (
322
- forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
372
+ get_local_dp_buffer(),
323
373
  hidden_states,
324
374
  )
325
375
  attn_tp_all_gather_into_tensor(
@@ -380,7 +430,7 @@
         )
 
         raise NotImplementedError(
-            f"{hidden_states_input_mode=} {residual_input_mode=} {residual_output_mode=} {residual_output_mode=}"
+            f"{hidden_states_input_mode=} {residual_input_mode=} {hidden_states_output_mode=} {residual_output_mode=}"
         )
 
     @staticmethod
@@ -408,9 +458,7 @@
     ):
         if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
             residual, local_residual = (
-                torch.empty_like(
-                    forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]]
-                ),
+                get_local_dp_buffer(),
                 residual,
             )
             attn_tp_all_gather_into_tensor(residual, local_residual)
@@ -424,7 +472,7 @@
             residual = hidden_states
             hidden_states = layernorm(hidden_states)
             hidden_states, local_hidden_states = (
-                torch.empty_like(forward_batch.gathered_buffer),
+                get_global_dp_buffer(),
                 hidden_states,
             )
             dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
@@ -548,7 +596,7 @@
         allow_reduce_scatter: bool = False,
     ):
         hidden_states, global_hidden_states = (
-            forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+            get_local_dp_buffer(),
             hidden_states,
         )
         if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
@@ -569,7 +617,7 @@
             hidden_states += residual
             residual = None
         hidden_states, local_hidden_states = (
-            forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+            get_local_dp_buffer(),
             hidden_states,
         )
         attn_tp_all_gather_into_tensor(
sglang/srt/layers/dp_attention.py
@@ -4,7 +4,7 @@ import functools
 import logging
 from contextlib import contextmanager
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import torch
 import triton
@@ -18,21 +18,26 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.configs.model_config import ModelConfig
+    from sglang.srt.server_args import ServerArgs
+
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
-_ATTN_TP_GROUP = None
-_ATTN_TP_RANK = None
-_ATTN_TP_SIZE = None
-_ATTN_DP_RANK = None
-_ATTN_DP_SIZE = None
-_LOCAL_ATTN_DP_SIZE = None
-_LOCAL_ATTN_DP_RANK = None
+_ATTN_TP_GROUP: Optional[GroupCoordinator] = None
+_ATTN_TP_RANK: Optional[int] = None
+_ATTN_TP_SIZE: Optional[int] = None
+_ATTN_DP_RANK: Optional[int] = None
+_ATTN_DP_SIZE: Optional[int] = None
+_LOCAL_ATTN_DP_SIZE: Optional[int] = None
+_LOCAL_ATTN_DP_RANK: Optional[int] = None
+_ENABLE_DP_ATTENTION_FLAG: bool = False
 
 
-class DPPaddingMode(IntEnum):
+class DpPaddingMode(IntEnum):
 
     # Padding tokens to max length and then gather tokens using `all_gather_into_tensor`
     MAX_LEN = auto()
@@ -40,13 +45,13 @@ class DPPaddingMode(IntEnum):
     SUM_LEN = auto()
 
     def is_max_len(self):
-        return self == DPPaddingMode.MAX_LEN
+        return self == DpPaddingMode.MAX_LEN
 
     def is_sum_len(self):
-        return self == DPPaddingMode.SUM_LEN
+        return self == DpPaddingMode.SUM_LEN
 
     @classmethod
-    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DPPaddingMode:
+    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DpPaddingMode:
         # we choose the mode that minimizes the communication cost
         max_len = max(global_num_tokens)
         sum_len = sum(global_num_tokens)
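
The trade-off being weighed, with illustrative numbers (the actual comparison is in the elided lines of this method):

global_num_tokens = [7, 3, 2]   # tokens contributed by each DP rank
max_len, sum_len = max(global_num_tokens), sum(global_num_tokens)
padded_rows = max_len * len(global_num_tokens)  # MAX_LEN: 21 rows in the gathered buffer
exact_rows = sum_len                            # SUM_LEN: 12 rows, but a variable-length gather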
@@ -56,10 +61,95 @@
         return cls.SUM_LEN
 
     @classmethod
-    def get_default_mode_in_cuda_graph(cls) -> DPPaddingMode:
+    def get_default_mode_in_cuda_graph(cls) -> DpPaddingMode:
         return cls.MAX_LEN
 
 
+class _DpGatheredBufferWrapper:
+
+    _hidden_size: int
+    _dtype: torch.dtype
+    _device: torch.device
+    _global_dp_buffer_len: int
+    _local_dp_buffer_len: int
+    _global_num_tokens: Optional[List[int]]
+
+    @classmethod
+    def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
+        cls._hidden_size = hidden_size
+        cls._dtype = dtype
+        cls._device = device
+
+    @classmethod
+    def set_dp_buffer_len(
+        cls,
+        global_dp_buffer_len: int,
+        local_dp_buffer_len: int,
+        global_num_tokens: Optional[List[int]] = None,
+    ):
+        cls._global_dp_buffer_len = global_dp_buffer_len
+        cls._local_dp_buffer_len = local_dp_buffer_len
+        cls._global_num_tokens = global_num_tokens
+
+    @classmethod
+    def get_global_dp_buffer(cls) -> torch.Tensor:
+        return torch.empty(
+            (cls._global_dp_buffer_len, cls._hidden_size),
+            dtype=cls._dtype,
+            device=cls._device,
+        )
+
+    @classmethod
+    def get_local_dp_buffer(cls) -> torch.Tensor:
+        return torch.empty(
+            (cls._local_dp_buffer_len, cls._hidden_size),
+            dtype=cls._dtype,
+            device=cls._device,
+        )
+
+    @classmethod
+    def get_global_dp_buffer_len(cls) -> int:
+        return cls._global_dp_buffer_len
+
+    @classmethod
+    def get_local_dp_buffer_len(cls) -> int:
+        return cls._local_dp_buffer_len
+
+    @classmethod
+    def get_dp_global_num_tokens(cls) -> List[int]:
+        return cls._global_num_tokens
+
+
+def set_dp_buffer_len(
+    global_dp_buffer_len: int,
+    local_dp_buffer_len: int,
+    global_num_tokens: Optional[List[int]] = None,
+):
+    _DpGatheredBufferWrapper.set_dp_buffer_len(
+        global_dp_buffer_len, local_dp_buffer_len, global_num_tokens
+    )
+
+
+def get_global_dp_buffer() -> torch.Tensor:
+    return _DpGatheredBufferWrapper.get_global_dp_buffer()
+
+
+def get_local_dp_buffer() -> torch.Tensor:
+    return _DpGatheredBufferWrapper.get_local_dp_buffer()
+
+
+def get_global_dp_buffer_len() -> int:
+    return _DpGatheredBufferWrapper.get_global_dp_buffer_len()
+
+
+def get_local_dp_buffer_len() -> int:
+    return _DpGatheredBufferWrapper.get_local_dp_buffer_len()
+
+
+def get_dp_global_num_tokens() -> List[int]:
+    return _DpGatheredBufferWrapper.get_dp_global_num_tokens()
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
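
A minimal sketch of the intended call sequence, assuming the per-batch `set_dp_buffer_len` call happens during forward-batch preparation (that call site is outside this diff, and the sizes below are illustrative):

import torch

# Once, at startup (initialize_dp_attention does this with real model values):
_DpGatheredBufferWrapper.set_metadata(
    hidden_size=4096, dtype=torch.bfloat16, device=torch.device("cuda")
)
# Once per forward batch, before the layers run:
set_dp_buffer_len(global_dp_buffer_len=8192, local_dp_buffer_len=2048)
# Inside layers, replacing the old forward_batch.gathered_buffer slicing:
buf = get_local_dp_buffer()   # fresh (2048, 4096) bf16 tensor on the GPU
assert buf.shape == (2048, 4096)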
@@ -89,18 +179,24 @@ compute_dp_attention_local_info(
 
 
 def initialize_dp_attention(
-    enable_dp_attention: bool,
-    tp_rank: int,
-    tp_size: int,
-    dp_size: int,
-    moe_dense_tp_size: int,
-    pp_size: int,
+    server_args: ServerArgs,
+    model_config: ModelConfig,
 ):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
-    global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK
+    global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK, _ENABLE_DP_ATTENTION_FLAG
 
     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
 
+    enable_dp_attention = server_args.enable_dp_attention
+    tp_size = server_args.tp_size
+    dp_size = server_args.dp_size
+    moe_dense_tp_size = server_args.moe_dense_tp_size
+    pp_size = server_args.pp_size
+
+    tp_rank = get_tensor_model_parallel_rank()
+
+    _ENABLE_DP_ATTENTION_FLAG = enable_dp_attention
+
     _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
@@ -135,38 +231,48 @@
         group_name="attention_tp",
     )
 
+    _DpGatheredBufferWrapper.set_metadata(
+        hidden_size=model_config.hidden_size,
+        dtype=model_config.dtype,
+        device=torch.device(server_args.device),
+    )
+
+
+def is_dp_attention_enabled() -> bool:
+    return _ENABLE_DP_ATTENTION_FLAG
 
-def get_attention_tp_group():
+
+def get_attention_tp_group() -> GroupCoordinator:
     assert _ATTN_TP_GROUP is not None, "dp attention not initialized!"
     return _ATTN_TP_GROUP
 
 
-def get_attention_tp_rank():
+def get_attention_tp_rank() -> int:
     assert _ATTN_TP_RANK is not None, "dp attention not initialized!"
     return _ATTN_TP_RANK
 
 
-def get_attention_tp_size():
+def get_attention_tp_size() -> int:
     assert _ATTN_TP_SIZE is not None, "dp attention not initialized!"
     return _ATTN_TP_SIZE
 
 
-def get_attention_dp_rank():
+def get_attention_dp_rank() -> int:
     assert _ATTN_DP_RANK is not None, "dp attention not initialized!"
     return _ATTN_DP_RANK
 
 
-def get_attention_dp_size():
+def get_attention_dp_size() -> int:
     assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
     return _ATTN_DP_SIZE
 
 
-def get_local_attention_dp_rank():
+def get_local_attention_dp_rank() -> int:
     assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
     return _LOCAL_ATTN_DP_RANK
 
 
-def get_local_attention_dp_size():
+def get_local_attention_dp_size() -> int:
     assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
     return _LOCAL_ATTN_DP_SIZE
 
@@ -292,6 +398,10 @@ _dp_gather_via_all_gather(
     forward_batch: ForwardBatch,
     is_partial: bool,
 ):
+    if get_attention_tp_size() == 1:
+        get_tp_group().all_gather_into_tensor(global_tokens, local_tokens)
+        return
+
     if not is_partial:
         if get_attention_tp_rank() != 0:
             local_tokens.fill_(0)