sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +376 -48
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/utils.py
@@ -9,18 +9,89 @@ TRITON_PAD_NUM_PAGE_PER_BLOCK = 64

  @triton.jit
  def create_flashinfer_kv_indices_triton(
- req_to_token_ptr, # [max_batch, max_context_len]
+ req_to_token_ptr,
  req_pool_indices_ptr,
  page_kernel_lens_ptr,
  kv_indptr,
  kv_start_idx,
  kv_indices_ptr,
  req_to_token_ptr_stride: tl.constexpr,
+ PAGE_SIZE: tl.constexpr = 1,
  ):
+ """
+ Create KV indices for FlashInfer attention backend.
+
+ This Triton kernel builds a lookup table that maps from logical request/token
+ coordinates to physical token locations in the global KV cache pool. It's used
+ by FlashInfer attention backends to efficiently access scattered KV cache data.
+
+ The kernel processes each request in parallel and converts the req_to_token
+ lookup table into a flat list of token indices that can be used by attention kernels.
+
+ general idea:
+ blocktables/kv_indices_ptr = [batch_size * max_pages(for graph mode with
+ fixed number of pages)]
+ max_pages = max_context_len / PAGED_SIZE
+ kv_indices_ptr will store the flat list of the pages used by each request
+ Args:
+ Inputs Arguments (non mutable):
+
+ req_to_token_ptr: Request to token location look up table
+ Shape: [max_batch, max_context_len]
+ req_pool_indices_ptr: Request to pool index look up table. Each request uses
+ one pool.
+ Shape: [batch_size]
+ page_kernel_lens_ptr: sequence lengths per request
+ Shape: [batch_size]
+ kv_indptr: Should be computed based on number of pages used by each request.
+ It is used by flashinfer attention kernels to index into the kv_indices_ptr.
+ per request.
+ Shape: [batch_size + 1]
+ kv_indptr[i] = start index in kv_indices for request i
+ kv_start_idx: Pointer to array containing start offsets for each request in SGL.
+ Can be None. If provided, adds offset to token positions.
+
+ req_to_token_ptr_stride: Stride for the second dimension of req_to_token.
+ Equal to max_context_len.
+
+ PAGED_SIZE: Number of tokens per page. Default is 1 for FlashInfer.
+
+ Outputs:
+ kv_indices_ptr: Pointer to output array where KV indices will be stored.
+ Shape:[total-num-pages],
+ where total_num_pages = sum(seq_lens // PAGED_SIZE)
+
+ Example:
+ If we have:
+ - req_pool_indices = [0, 1] (request 0 uses pool 0, request 1 uses pool 1)
+ - page_kernel_lens = [3, 2] (request 0 has 3 tokens, request 1 has 2 tokens)
+ - req_to_token = [[10, 11, 12, -1], [20, 21, -1, -1]] (tokens are the elements
+ in radix tree, use them as a pointer to the token location in the kv_indices_ptr)
+
+ The kernel will output:
+ If PAGE_SIZE = 1:
+ packed
+ - kv_indptr (passed in as input arg): [0,3,5]
+ - kv_indices = [10, 11, 12, 20, 21]
+ padded - max_pages is 10 tokens per req
+ - kv_indptr (passed in as input arg): [0,10, 20]
+ - kv_indices = [10, 11, 12, -1, -1, -1, -1, -1, -1, -1,
+ 20, 21, -1, -1, -1, -1, -1, -1, -1, -1]
+
+ If PAGE_SIZE = 2
+ packed:
+ - kv_indptr (passed in as input arg): [0,3,4]
+ - kv_indices = [5,6,10]
+ padded: max_pages is 4
+ - kv_indptr (passed in as input arg): [0,4,8,..] (note that 4 is the max_pages)
+ - kv_indices = [5, 6, -1, -1,
+ 10, -1, -1, -1]
+ This allows attention kernels to directly access the correct KV cache
+ entries for each request's tokens.
+ """
  BLOCK_SIZE: tl.constexpr = 512
+ NUM_PAGES_PER_BLOCK: tl.constexpr = BLOCK_SIZE // PAGE_SIZE
  pid = tl.program_id(axis=0)
-
- # find the req pool idx, this is for batch to token
  req_pool_index = tl.load(req_pool_indices_ptr + pid)
  kv_indices_offset = tl.load(kv_indptr + pid)

@@ -31,19 +102,27 @@ def create_flashinfer_kv_indices_triton(
  kv_end = kv_start
  kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)

- num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
- for i in range(num_loop):
- # index into req_to_token_ptr needs to be int64
- offset = tl.arange(0, BLOCK_SIZE).to(tl.int64) + i * BLOCK_SIZE
- mask = offset < kv_end - kv_start
- data = tl.load(
- req_to_token_ptr
- + req_pool_index * req_to_token_ptr_stride
- + kv_start
- + offset,
- mask=mask,
+ kv_range = kv_end - kv_start
+ num_pages = tl.cdiv(kv_range, PAGE_SIZE)
+ num_loops = tl.cdiv(kv_range, BLOCK_SIZE)
+ req_to_token_block_start = (
+ req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + kv_start
+ )
+ for i in range(num_loops):
+ token_offsets_in_block = (
+ tl.arange(0, NUM_PAGES_PER_BLOCK).to(tl.int64) + i * NUM_PAGES_PER_BLOCK
+ ) * PAGE_SIZE
+ page_offsets_in_block = token_offsets_in_block // PAGE_SIZE
+ valid_tokens = token_offsets_in_block < kv_range
+ valid_pages = page_offsets_in_block < num_pages
+ token_numbers = tl.load(
+ req_to_token_block_start + token_offsets_in_block, mask=valid_tokens
+ )
+ tl.store(
+ kv_indices_ptr + kv_indices_offset + page_offsets_in_block,
+ token_numbers // PAGE_SIZE, # write the page numbers to kv_indices_ptr
+ mask=valid_pages,
  )
- tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)


  @triton.jit
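
For illustration, a minimal host-side sketch of the packed layout the updated kernel produces: kv_indptr counts pages per request and kv_indices stores the page number of every page a request touches. The helper name and the plain-PyTorch loop below are hypothetical (not part of sglang); they only mirror the docstring example above.

    import torch

    def build_kv_page_indices(req_to_token, req_pool_indices, seq_lens, page_size=1):
        # Packed layout: kv_indptr[i] marks where request i's pages start in kv_indices.
        num_pages = (seq_lens + page_size - 1) // page_size
        kv_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int64)
        kv_indptr[1:] = torch.cumsum(num_pages, dim=0)
        kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int64)
        starts = kv_indptr.tolist()
        for i, (pool, seq_len) in enumerate(zip(req_pool_indices.tolist(), seq_lens.tolist())):
            tokens = req_to_token[pool, :seq_len]
            # The first token slot of each page, divided by page_size, is the page number.
            kv_indices[starts[i] : starts[i + 1]] = tokens[::page_size] // page_size
        return kv_indptr, kv_indices

    req_to_token = torch.tensor([[10, 11, 12, -1], [20, 21, -1, -1]])
    req_pool_indices = torch.tensor([0, 1])
    seq_lens = torch.tensor([3, 2])

    indptr, indices = build_kv_page_indices(req_to_token, req_pool_indices, seq_lens, page_size=1)
    print(indptr.tolist(), indices.tolist())  # [0, 3, 5] [10, 11, 12, 20, 21]
    _, indices = build_kv_page_indices(req_to_token, req_pool_indices, seq_lens, page_size=2)
    print(indices.tolist())  # [5, 6, 10]
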
sglang/srt/layers/attention/vision.py
@@ -12,7 +12,12 @@ import torch.nn.functional as F
  from einops import rearrange

  from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
- from sglang.srt.utils import is_cuda, print_info_once
+ from sglang.srt.utils import (
+ get_device_capability,
+ is_blackwell,
+ is_cuda,
+ print_info_once,
+ )

  _is_cuda = is_cuda()

@@ -20,7 +25,6 @@ if _is_cuda:
  from sgl_kernel.flash_attn import flash_attn_varlen_func

  from sglang.srt.distributed import (
- parallel_state,
  split_tensor_along_last_dim,
  tensor_model_parallel_all_gather,
  )
@@ -402,18 +406,14 @@ class VisionAttention(nn.Module):
  self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
  )

- # priority: server_args > passed qkv_backend > sdpa
- if global_server_args_dict["mm_attention_backend"] is None:
- if qkv_backend is None:
- if is_cuda():
- # Double prefill throughput by setting attn backend to Triton on CUDA
- qkv_backend = "triton_attn"
- else:
- qkv_backend = "sdpa"
+ # Select attention backend via a unified method
+ _passed_backend = qkv_backend
+ qkv_backend = self._determine_attention_backend(_passed_backend)
+ if (
+ global_server_args_dict["mm_attention_backend"] is None
+ and _passed_backend is None
+ ):
  print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
- else:
- qkv_backend = global_server_args_dict["mm_attention_backend"]
-
  print_info_once(f"Using {qkv_backend} as multimodal attention backend.")

  self.customized_position_embedding_applier = (
@@ -461,6 +461,33 @@ class VisionAttention(nn.Module):
  prefix=add_prefix("proj", prefix),
  )

+ def _determine_attention_backend(self, passed_backend: Optional[str]) -> str:
+ """Decide the multimodal attention backend string.
+
+ Priority: server args override > constructor arg > platform default.
+
+ Platform defaults:
+ - CUDA: "triton_attn"
+ - Non-CUDA: "sdpa"
+ """
+ override_backend = global_server_args_dict["mm_attention_backend"]
+ if override_backend is not None:
+ backend = override_backend
+ elif passed_backend is not None:
+ backend = passed_backend
+ elif is_cuda():
+ major, minor = get_device_capability()
+ if major == 9:
+ backend = "fa3"
+ else:
+ backend = "triton_attn"
+ else:
+ backend = "sdpa"
+ if backend == "fa3" and is_blackwell():
+ raise ValueError("The 'fa3' backend is not supported on Blackwell GPUs")
+
+ return backend
+
  def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
  """apply qk norm for internvl vit attn"""
  q = q.flatten(1, 2)
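
The backend selection added in _determine_attention_backend boils down to a priority chain (server-args override, then the constructor argument, then a platform default, with a guard against fa3 on Blackwell). The standalone restatement below is illustrative only; the argument names are hypothetical and it is not the code that ships in vision.py.

    from typing import Optional

    def pick_mm_attention_backend(
        server_override: Optional[str],
        constructor_arg: Optional[str],
        on_cuda: bool,
        sm_major: int,
        on_blackwell: bool,
    ) -> str:
        # Priority: explicit server override > constructor argument > platform default
        # (fa3 on SM 9.x, triton_attn on other CUDA devices, sdpa elsewhere).
        if server_override is not None:
            backend = server_override
        elif constructor_arg is not None:
            backend = constructor_arg
        elif on_cuda:
            backend = "fa3" if sm_major == 9 else "triton_attn"
        else:
            backend = "sdpa"
        if backend == "fa3" and on_blackwell:
            raise ValueError("The 'fa3' backend is not supported on Blackwell GPUs")
        return backend

    # An SM 9.0 (Hopper) GPU with no overrides selects "fa3".
    assert pick_mm_attention_backend(None, None, True, 9, False) == "fa3"
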
sglang/srt/layers/attention/vision_utils.py (new file)
@@ -0,0 +1,65 @@
+ """Utility functions for vision attention layers."""
+
+ import torch
+
+ from sglang.srt.layers.dp_attention import get_attention_tp_size
+
+
+ def update_vit_attn_dummy_heads_config(config):
+ """Update HF config to ensure vision attention num_attention_heads is divisible by tp_size"""
+ tp_size = get_attention_tp_size()
+ num_heads = getattr(
+ config.vision_config,
+ "num_heads",
+ getattr(config.vision_config, "num_attention_heads", None),
+ )
+ head_dim = config.vision_config.hidden_size // num_heads
+ num_dummy_heads = 0
+
+ if num_heads % tp_size != 0:
+ num_dummy_heads = ((num_heads + tp_size - 1) // tp_size) * tp_size - num_heads
+
+ setattr(config.vision_config, "head_dim", head_dim)
+ setattr(config.vision_config, "num_dummy_heads", num_dummy_heads)
+
+
+ def pad_vit_attn_dummy_heads(config, name: str, loaded_weight: torch.Tensor):
+ """Pad attention qkv weights for dummy heads"""
+ num_dummy_heads = config.vision_config.num_dummy_heads
+ if num_dummy_heads == 0:
+ return loaded_weight
+ head_dim = config.vision_config.head_dim
+
+ if "attn.qkv_proj" in name:
+ wq, wk, wv = loaded_weight.chunk(3, dim=0)
+ if name.endswith(".weight"):
+ dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
+ elif name.endswith(".bias"):
+ dummy_shape = [num_dummy_heads, head_dim]
+ else:
+ raise RuntimeError(f"Unsupported weight with name={name}")
+ pad_func = lambda x: torch.cat(
+ [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
+ ).flatten(0, 1)
+ wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
+ loaded_weight = torch.cat([wq, wk, wv], dim=0)
+ elif any([_ in name for _ in ["attn.q_proj", "attn.k_proj", "attn.v_proj"]]):
+ if name.endswith(".weight"):
+ dummy_shape = [num_dummy_heads, head_dim, loaded_weight.shape[-1]]
+ elif name.endswith(".bias"):
+ dummy_shape = [num_dummy_heads, head_dim]
+ else:
+ raise RuntimeError(f"Unsupported weight with name={name}")
+ padded_weight = loaded_weight.new_zeros(dummy_shape)
+ loaded_weight = torch.cat(
+ [loaded_weight.unflatten(0, (-1, head_dim)), padded_weight], dim=0
+ ).flatten(0, 1)
+ elif "attn.proj.weight" in name:
+ padded_weight = loaded_weight.new_zeros(
+ loaded_weight.shape[0], head_dim * num_dummy_heads
+ )
+ loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
+ elif "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
+ padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
+ loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
+ return loaded_weight
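
The dummy-head count used by these helpers is just the distance from the real head count to the next multiple of the attention tensor-parallel size; the padded heads are zero weights that keep per-rank shapes uniform. A quick standalone illustration with made-up sizes:

    def num_dummy_heads(num_heads: int, tp_size: int) -> int:
        # Round num_heads up to a multiple of tp_size and return the padding amount.
        if num_heads % tp_size == 0:
            return 0
        return ((num_heads + tp_size - 1) // tp_size) * tp_size - num_heads

    # A 25-head ViT sharded over 4 attention-TP ranks needs 3 dummy heads (25 -> 28),
    # so every rank owns exactly 7 heads; qkv/proj weights are zero-padded to match.
    assert num_dummy_heads(25, 4) == 3
    assert num_dummy_heads(16, 8) == 0
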
sglang/srt/layers/communicator.py
@@ -17,7 +17,7 @@ from enum import Enum, auto
  from functools import partial
  from typing import Dict, Optional

- import torch.distributed
+ import torch

  from sglang.srt.distributed import (
  get_tensor_model_parallel_world_size,
@@ -34,6 +34,11 @@ from sglang.srt.layers.dp_attention import (
  get_attention_tp_size,
  get_global_dp_buffer,
  get_local_dp_buffer,
+ is_dp_attention_enabled,
+ )
+ from sglang.srt.layers.moe import (
+ get_moe_a2a_backend,
+ should_use_flashinfer_cutlass_moe_fp4_allgather,
  )
  from sglang.srt.layers.utils import is_sm100_supported
  from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -43,6 +48,8 @@ from sglang.srt.utils import is_cuda, is_flashinfer_available
  _is_flashinfer_available = is_flashinfer_available()
  _is_sm100_supported = is_cuda() and is_sm100_supported()

+ FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
+

  class ScatterMode(Enum):
  """
@@ -111,7 +118,11 @@ class LayerScatterModes:
  if context.is_layer_sparse:
  return (
  ScatterMode.SCATTERED
- if not global_server_args_dict["moe_a2a_backend"].is_standard()
+ if (
+ # Token dispatch/combine will be handled outside of LayerCommunicator for these modes.
+ not get_moe_a2a_backend().is_none()
+ or should_use_flashinfer_cutlass_moe_fp4_allgather()
+ )
  else ScatterMode.FULL
  )
  else:
@@ -154,11 +165,13 @@ class LayerCommunicator:
  post_attention_layernorm: torch.nn.Module,
  # Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable for models which have that implemented. Remove flag once done for all models that use LayerCommunicator.
  allow_reduce_scatter: bool = False,
+ is_last_layer: bool = False,
  ):
  self.layer_scatter_modes = layer_scatter_modes
  self.input_layernorm = input_layernorm
  self.post_attention_layernorm = post_attention_layernorm
  self.allow_reduce_scatter = allow_reduce_scatter
+ self.is_last_layer = is_last_layer

  self._context = CommunicateContext.init_new()
  self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
@@ -256,6 +269,41 @@ class LayerCommunicator:
  and forward_batch.dp_padding_mode.is_max_len()
  )

+ def should_fuse_mlp_allreduce_with_next_layer(
+ self, forward_batch: ForwardBatch
+ ) -> bool:
+ speculative_algo = global_server_args_dict.get("speculative_algorithm", None)
+ if (
+ is_dp_attention_enabled()
+ and speculative_algo is not None
+ and speculative_algo.is_eagle()
+ ):
+ return False
+
+ batch_size = (
+ forward_batch.input_ids.shape[0]
+ if hasattr(forward_batch, "input_ids")
+ else 0
+ )
+ if batch_size > FUSE_ALLREDUCE_MAX_BATCH_SIZE:
+ return False
+
+ static_conditions_met = (
+ (not self.is_last_layer)
+ and (self._context.tp_size > 1)
+ and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
+ and _is_flashinfer_available
+ )
+
+ if not static_conditions_met:
+ return False
+
+ return (
+ batch_size > 0
+ and batch_size <= FUSE_ALLREDUCE_MAX_BATCH_SIZE
+ and (not self.is_last_layer)
+ )
+

  @dataclass
  class CommunicateContext:
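
Condensed, the new should_fuse_mlp_allreduce_with_next_layer gate reduces to a handful of checks. The sketch below is an illustrative simplification (it folds the static and per-batch conditions together and omits the EAGLE + DP-attention early-out); the argument names are hypothetical.

    FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048

    def should_fuse_allreduce(
        batch_size: int,
        tp_size: int,
        is_last_layer: bool,
        fusion_flag_enabled: bool,
        flashinfer_available: bool,
    ) -> bool:
        # Fuse the post-MLP all-reduce into the next layer's residual + RMSNorm only when
        # there is a next layer, TP is actually used, the server flag is on, FlashInfer is
        # installed, and the batch is small enough for the fused kernel.
        return (
            not is_last_layer
            and tp_size > 1
            and fusion_flag_enabled
            and flashinfer_available
            and 0 < batch_size <= FUSE_ALLREDUCE_MAX_BATCH_SIZE
        )

    assert should_fuse_allreduce(64, 8, False, True, True)
    assert not should_fuse_allreduce(4096, 8, False, True, True)
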
@@ -382,7 +430,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
  )

  raise NotImplementedError(
- f"{hidden_states_input_mode=} {residual_input_mode=} {residual_output_mode=} {residual_output_mode=}"
+ f"{hidden_states_input_mode=} {residual_input_mode=} {hidden_states_output_mode=} {residual_output_mode=}"
  )

  @staticmethod
sglang/srt/layers/dp_attention.py
@@ -72,6 +72,7 @@ class _DpGatheredBufferWrapper:
  _device: torch.device
  _global_dp_buffer_len: int
  _local_dp_buffer_len: int
+ _global_num_tokens: Optional[List[int]]

  @classmethod
  def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
@@ -80,9 +81,15 @@ class _DpGatheredBufferWrapper:
  cls._device = device

  @classmethod
- def set_dp_buffer_len(cls, global_dp_buffer_len: int, local_dp_buffer_len: int):
+ def set_dp_buffer_len(
+ cls,
+ global_dp_buffer_len: int,
+ local_dp_buffer_len: int,
+ global_num_tokens: Optional[List[int]] = None,
+ ):
  cls._global_dp_buffer_len = global_dp_buffer_len
  cls._local_dp_buffer_len = local_dp_buffer_len
+ cls._global_num_tokens = global_num_tokens

  @classmethod
  def get_global_dp_buffer(cls) -> torch.Tensor:
@@ -108,10 +115,18 @@ class _DpGatheredBufferWrapper:
  def get_local_dp_buffer_len(cls) -> int:
  return cls._local_dp_buffer_len

+ @classmethod
+ def get_dp_global_num_tokens(cls) -> List[int]:
+ return cls._global_num_tokens
+

- def set_dp_buffer_len(global_dp_buffer_len: int, local_dp_buffer_len: int):
+ def set_dp_buffer_len(
+ global_dp_buffer_len: int,
+ local_dp_buffer_len: int,
+ global_num_tokens: Optional[List[int]] = None,
+ ):
  _DpGatheredBufferWrapper.set_dp_buffer_len(
- global_dp_buffer_len, local_dp_buffer_len
+ global_dp_buffer_len, local_dp_buffer_len, global_num_tokens
  )


@@ -131,6 +146,10 @@ def get_local_dp_buffer_len() -> int:
  return _DpGatheredBufferWrapper.get_local_dp_buffer_len()


+ def get_dp_global_num_tokens() -> List[int]:
+ return _DpGatheredBufferWrapper.get_dp_global_num_tokens()
+
+
  def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
  if not enable_dp_attention:
  return tp_rank, tp_size, 0
@@ -215,7 +234,7 @@ def initialize_dp_attention(
  _DpGatheredBufferWrapper.set_metadata(
  hidden_size=model_config.hidden_size,
  dtype=model_config.dtype,
- device=torch.device("cuda"),
+ device=torch.device(server_args.device),
  )


sglang/srt/layers/elementwise.py
@@ -486,3 +486,97 @@ def gelu_and_mul_triton(
  return out_hidden_states, out_scales
  else:
  return out_hidden_states, None
+
+
+ # silu on first half of vector
+ @triton.jit
+ def silu_and_mul_kernel(
+ out_hidden_states_ptr, # (bs, hidden_dim)
+ out_scales_ptr, # (bs,)
+ hidden_states_ptr, # (bs, hidden_dim * 2)
+ quant_max: tl.constexpr,
+ static_scale: tl.constexpr,
+ hidden_dim: tl.constexpr, # the output hidden_dim
+ BLOCK_SIZE: tl.constexpr,
+ ):
+ pid = tl.program_id(axis=0)
+
+ input_start = pid * hidden_dim * 2
+ output_start = pid * hidden_dim
+
+ input1_offs = tl.arange(0, BLOCK_SIZE)
+ mask = tl.arange(0, BLOCK_SIZE) < hidden_dim # shared for input1, input3, output
+ input3_offs = hidden_dim + tl.arange(0, BLOCK_SIZE)
+ output_offs = tl.arange(0, BLOCK_SIZE)
+
+ x1 = tl.load(
+ hidden_states_ptr + input_start + input1_offs, mask=mask, other=0.0
+ ).to(tl.float32)
+ x3 = tl.load(
+ hidden_states_ptr + input_start + input3_offs, mask=mask, other=0.0
+ ).to(tl.float32)
+
+ # silu
+ # cast down before mul to better match training?
+ silu_x1 = x1 * tl.sigmoid(x1)
+ out = x3 * silu_x1.to(hidden_states_ptr.dtype.element_ty)
+
+ if quant_max is not None:
+ raise NotImplementedError()
+
+ tl.store(out_hidden_states_ptr + output_start + output_offs, out, mask=mask)
+
+
+ def silu_and_mul_triton(
+ hidden_states,
+ scales=None,
+ quantize=None, # dtype to quantize to
+ out=None,
+ ):
+ bs, in_hidden_dim = hidden_states.shape
+ hidden_dim = in_hidden_dim // 2
+
+ if out is None:
+ out_hidden_states = torch.empty(
+ (bs, hidden_dim),
+ dtype=quantize or hidden_states.dtype,
+ device=hidden_states.device,
+ )
+ else:
+ assert out.shape == (bs, hidden_dim)
+ assert out.dtype == (quantize or hidden_states.dtype)
+ out_hidden_states = out
+ out_scales = None
+ static_scale = False
+ if quantize is not None:
+ if scales is None:
+ out_scales = torch.empty(
+ (bs,), dtype=torch.float32, device=hidden_states.device
+ )
+ else:
+ out_scales = scales
+ static_scale = True
+
+ max_warps = 16 if _is_hip else 32
+ config = {
+ # 8 ele per thread (not tuned)
+ "num_warps": max(
+ min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4
+ ),
+ }
+
+ silu_and_mul_kernel[(bs,)](
+ out_hidden_states,
+ out_scales,
+ hidden_states,
+ quant_max=torch.finfo(quantize).max if quantize is not None else None,
+ static_scale=static_scale,
+ hidden_dim=hidden_dim,
+ BLOCK_SIZE=triton.next_power_of_2(hidden_dim),
+ **config,
+ )
+
+ if quantize is not None:
+ return out_hidden_states, out_scales
+ else:
+ return out_hidden_states, None
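
The new kernel computes the usual SwiGLU-style gating: SiLU over the first half of the last dimension, multiplied into the second half, with the SiLU evaluated in float32 and cast back down before the multiply. A plain-PyTorch reference (illustrative, not part of the diff) that the non-quantized path of silu_and_mul_triton should closely match:

    import torch
    import torch.nn.functional as F

    def silu_and_mul_reference(hidden_states: torch.Tensor) -> torch.Tensor:
        # Input is (bs, 2 * hidden_dim); output is (bs, hidden_dim).
        x1, x3 = hidden_states.chunk(2, dim=-1)
        # Mirror the kernel: SiLU in float32, cast down, then gate the second half.
        return x3 * F.silu(x1.float()).to(hidden_states.dtype)

    x = torch.randn(4, 256, dtype=torch.float16)
    print(silu_and_mul_reference(x).shape)  # torch.Size([4, 128])
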
sglang/srt/layers/flashinfer_comm_fusion.py
@@ -5,7 +5,11 @@ import torch
  import torch.distributed as dist

  from sglang.srt.distributed import get_tensor_model_parallel_world_size
- from sglang.srt.utils import is_flashinfer_available
+ from sglang.srt.utils import (
+ direct_register_custom_op,
+ is_flashinfer_available,
+ supports_custom_op,
+ )

  logger = logging.getLogger(__name__)

@@ -196,6 +200,30 @@ def flashinfer_allreduce_residual_rmsnorm(
  return norm_out, residual_out


+ def fake_flashinfer_allreduce_residual_rmsnorm(
+ input_tensor: torch.Tensor,
+ residual: torch.Tensor,
+ weight: torch.Tensor,
+ eps: float = 1e-6,
+ max_token_num: int = 2048,
+ use_oneshot: Optional[bool] = None,
+ trigger_completion_at_end: bool = False,
+ fp32_acc: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ residual_out = torch.empty_like(residual)
+ norm_out = torch.empty_like(input_tensor)
+ return norm_out, residual_out
+
+
+ if supports_custom_op():
+ direct_register_custom_op(
+ "flashinfer_allreduce_residual_rmsnorm",
+ flashinfer_allreduce_residual_rmsnorm,
+ mutates_args=["input_tensor", "residual", "weight"],
+ fake_impl=fake_flashinfer_allreduce_residual_rmsnorm,
+ )
+
+
  def cleanup_flashinfer_workspace():
  global _workspace_manager
  if _workspace_manager is not None:
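
This registration follows the standard PyTorch custom-op pattern: the real implementation calls the fused FlashInfer collective, while the fake implementation only allocates outputs of the right shape and dtype so the op can be traced (for example by torch.compile) without running any communication. A minimal standalone sketch of the same pattern, using the public torch.library API (PyTorch 2.4+) and a toy single-GPU body instead of sglang's direct_register_custom_op helper:

    import torch

    @torch.library.custom_op("demo::residual_rmsnorm", mutates_args=())
    def residual_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
        # Real implementation (a single-GPU stand-in for the fused allreduce + RMSNorm).
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

    @residual_rmsnorm.register_fake
    def _(x, weight, eps):
        # Fake/meta implementation: correct shape and dtype only, no computation.
        return torch.empty_like(x)

    y = residual_rmsnorm(torch.randn(2, 8), torch.ones(8), 1e-6)
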
sglang/srt/layers/layernorm.py
@@ -27,6 +27,7 @@ from sglang.srt.utils import (
  is_cuda,
  is_hip,
  is_npu,
+ supports_custom_op,
  )

  _is_cuda = is_cuda()
@@ -202,8 +203,14 @@ class RMSNorm(CustomOp):
  flashinfer_allreduce_residual_rmsnorm,
  )

+ fused_op = (
+ torch.ops.sglang.flashinfer_allreduce_residual_rmsnorm
+ if supports_custom_op()
+ else flashinfer_allreduce_residual_rmsnorm
+ )
+
  if get_tensor_model_parallel_world_size() > 1:
- fused_result = flashinfer_allreduce_residual_rmsnorm(
+ fused_result = fused_op(
  input_tensor=x,
  residual=residual,
  weight=self.weight,
sglang/srt/layers/linear.py
@@ -110,6 +110,20 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
  return param[shard_id], loaded_weight


+ def adjust_shard_offsets(shard_offsets, loaded_weight, dim):
+ actual_weight_size = loaded_weight.size(dim)
+ target_weight_size = shard_offsets[-1][-1] + shard_offsets[-1][-2]
+ if actual_weight_size != target_weight_size:
+ new_shard_offsets = []
+ new_offset = 0
+ for shard_id, shard_offset, shard_size in shard_offsets:
+ actual_shard_size = actual_weight_size * shard_size // target_weight_size
+ new_shard_offsets.append((shard_id, new_offset, actual_shard_size))
+ new_offset += actual_shard_size
+ return new_shard_offsets
+ return shard_offsets
+
+
  class LinearBase(torch.nn.Module):
  """Base linear layer.

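
The new adjust_shard_offsets helper rescales each fused shard proportionally when, on the CPU path, the loaded checkpoint tensor is smaller along the sharded dimension than the fused parameter expects, and recomputes the offsets so the shards stay contiguous. A worked illustration with made-up sizes (this variant takes the size directly instead of a tensor so it stays self-contained):

    def rescale_shard_offsets(shard_offsets, actual_weight_size):
        # shard_offsets: list of (shard_id, shard_offset, shard_size) tuples.
        target = shard_offsets[-1][1] + shard_offsets[-1][2]
        if actual_weight_size == target:
            return shard_offsets
        new_offsets, offset = [], 0
        for shard_id, _, shard_size in shard_offsets:
            actual = actual_weight_size * shard_size // target
            new_offsets.append((shard_id, offset, actual))
            offset += actual
        return new_offsets

    # Two 4096-wide gate/up shards (target 8192) against a 4096-wide loaded weight:
    # each shard shrinks to 2048 and the offsets are recomputed contiguously.
    print(rescale_shard_offsets([("gate", 0, 4096), ("up", 4096, 4096)], 4096))
    # [('gate', 0, 2048), ('up', 2048, 2048)]
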
@@ -535,6 +549,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
  packed_dim = getattr(param, "packed_dim", None)

  use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+ if _is_cpu:
+ shard_offsets = adjust_shard_offsets(
+ shard_offsets, loaded_weight, output_dim
+ )
+
  for shard_id, shard_offset, shard_size in shard_offsets:
  # Special case for Quantization.
  # If quantized, we need to adjust the offset and size to account
@@ -977,6 +996,11 @@ class QKVParallelLinear(ColumnParallelLinear):
  use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

  packed_dim = getattr(param, "packed_dim", None)
+ if _is_cpu:
+ shard_offsets = adjust_shard_offsets(
+ shard_offsets, loaded_weight, output_dim
+ )
+
  for shard_id, shard_offset, shard_size in shard_offsets:
  # Special case for Quantized Weights.
  # If quantized, we need to adjust the offset and size to account
sglang/srt/layers/logits_processor.py
@@ -191,7 +191,11 @@ class LogitsMetadata:
  else:
  self.global_dp_buffer_len = self.global_dp_buffer_len

- set_dp_buffer_len(self.global_dp_buffer_len, self.dp_local_num_tokens)
+ set_dp_buffer_len(
+ self.global_dp_buffer_len,
+ self.dp_local_num_tokens,
+ self.global_num_tokens_for_logprob_cpu,
+ )


  class LogitsProcessor(nn.Module):