sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +170 -24
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +60 -1
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +69 -1
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  10. sglang/srt/disaggregation/nixl/conn.py +6 -6
  11. sglang/srt/disaggregation/prefill.py +2 -2
  12. sglang/srt/disaggregation/utils.py +1 -1
  13. sglang/srt/distributed/parallel_state.py +44 -17
  14. sglang/srt/entrypoints/EngineBase.py +8 -0
  15. sglang/srt/entrypoints/engine.py +40 -6
  16. sglang/srt/entrypoints/http_server.py +111 -24
  17. sglang/srt/entrypoints/http_server_engine.py +1 -1
  18. sglang/srt/entrypoints/openai/protocol.py +4 -2
  19. sglang/srt/eplb/__init__.py +0 -0
  20. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  21. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  22. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  24. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  25. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  26. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  27. sglang/srt/hf_transformers_utils.py +2 -1
  28. sglang/srt/layers/activation.py +2 -2
  29. sglang/srt/layers/amx_utils.py +86 -0
  30. sglang/srt/layers/attention/ascend_backend.py +219 -0
  31. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  32. sglang/srt/layers/attention/tbo_backend.py +37 -9
  33. sglang/srt/layers/communicator.py +20 -2
  34. sglang/srt/layers/dp_attention.py +9 -3
  35. sglang/srt/layers/elementwise.py +76 -12
  36. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  37. sglang/srt/layers/layernorm.py +26 -0
  38. sglang/srt/layers/linear.py +84 -14
  39. sglang/srt/layers/logits_processor.py +4 -4
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  41. sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
  42. sglang/srt/layers/moe/ep_moe/layer.py +176 -15
  43. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  44. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
  46. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  47. sglang/srt/layers/moe/router.py +60 -22
  48. sglang/srt/layers/moe/topk.py +10 -28
  49. sglang/srt/layers/parameter.py +67 -7
  50. sglang/srt/layers/quantization/__init__.py +2 -0
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  52. sglang/srt/layers/quantization/fp8.py +72 -7
  53. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  54. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  55. sglang/srt/layers/quantization/gptq.py +5 -1
  56. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  57. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  58. sglang/srt/layers/quantization/quant_utils.py +166 -0
  59. sglang/srt/layers/quantization/w4afp8.py +264 -0
  60. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  61. sglang/srt/layers/rotary_embedding.py +2 -2
  62. sglang/srt/layers/vocab_parallel_embedding.py +20 -10
  63. sglang/srt/lora/lora.py +4 -5
  64. sglang/srt/lora/lora_manager.py +73 -20
  65. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  66. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  67. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  68. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  69. sglang/srt/managers/cache_controller.py +41 -195
  70. sglang/srt/managers/configure_logging.py +1 -1
  71. sglang/srt/managers/io_struct.py +58 -14
  72. sglang/srt/managers/mm_utils.py +77 -61
  73. sglang/srt/managers/multimodal_processor.py +2 -6
  74. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  75. sglang/srt/managers/schedule_batch.py +78 -85
  76. sglang/srt/managers/scheduler.py +130 -64
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  78. sglang/srt/managers/session_controller.py +12 -3
  79. sglang/srt/managers/tokenizer_manager.py +314 -103
  80. sglang/srt/managers/tp_worker.py +13 -1
  81. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  82. sglang/srt/mem_cache/allocator.py +290 -0
  83. sglang/srt/mem_cache/chunk_cache.py +34 -2
  84. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  85. sglang/srt/mem_cache/memory_pool.py +402 -66
  86. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  87. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  88. sglang/srt/mem_cache/radix_cache.py +8 -4
  89. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  90. sglang/srt/model_executor/forward_batch_info.py +17 -4
  91. sglang/srt/model_executor/model_runner.py +297 -56
  92. sglang/srt/model_loader/loader.py +41 -0
  93. sglang/srt/model_loader/weight_utils.py +72 -4
  94. sglang/srt/models/deepseek_nextn.py +1 -3
  95. sglang/srt/models/deepseek_v2.py +195 -45
  96. sglang/srt/models/deepseek_vl2.py +3 -5
  97. sglang/srt/models/gemma3_causal.py +1 -2
  98. sglang/srt/models/gemma3n_causal.py +4 -3
  99. sglang/srt/models/gemma3n_mm.py +4 -20
  100. sglang/srt/models/hunyuan.py +1 -1
  101. sglang/srt/models/kimi_vl.py +1 -2
  102. sglang/srt/models/llama.py +10 -4
  103. sglang/srt/models/llama4.py +32 -45
  104. sglang/srt/models/llama_eagle3.py +61 -11
  105. sglang/srt/models/llava.py +5 -5
  106. sglang/srt/models/minicpmo.py +2 -2
  107. sglang/srt/models/mistral.py +1 -1
  108. sglang/srt/models/mllama4.py +402 -89
  109. sglang/srt/models/phi4mm.py +1 -3
  110. sglang/srt/models/pixtral.py +3 -7
  111. sglang/srt/models/qwen2.py +31 -3
  112. sglang/srt/models/qwen2_5_vl.py +1 -3
  113. sglang/srt/models/qwen2_audio.py +200 -0
  114. sglang/srt/models/qwen2_moe.py +32 -6
  115. sglang/srt/models/qwen2_vl.py +1 -4
  116. sglang/srt/models/qwen3.py +94 -25
  117. sglang/srt/models/qwen3_moe.py +68 -21
  118. sglang/srt/models/vila.py +3 -8
  119. sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  129. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  130. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  131. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
  132. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  133. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  134. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  135. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  136. sglang/srt/operations_strategy.py +6 -2
  137. sglang/srt/reasoning_parser.py +26 -0
  138. sglang/srt/sampling/sampling_batch_info.py +39 -1
  139. sglang/srt/server_args.py +84 -22
  140. sglang/srt/speculative/build_eagle_tree.py +57 -18
  141. sglang/srt/speculative/eagle_worker.py +6 -4
  142. sglang/srt/two_batch_overlap.py +203 -27
  143. sglang/srt/utils.py +343 -163
  144. sglang/srt/warmup.py +12 -3
  145. sglang/test/runners.py +10 -1
  146. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  147. sglang/test/test_utils.py +15 -3
  148. sglang/utils.py +5 -5
  149. sglang/version.py +1 -1
  150. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
  151. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
  152. sglang/math_utils.py +0 -8
  153. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  154. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  155. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  156. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  157. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  158. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/ascend_backend.py
@@ -0,0 +1,219 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional
+
+import torch
+import torch_npu
+from torch.nn.functional import scaled_dot_product_attention
+
+from sglang.srt.configs.model_config import AttentionArch
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
+from sglang.srt.layers.radix_attention import AttentionType
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+
+@dataclass
+class ForwardMetadata:
+
+    # calculated map for kv positions [bs * maxseqlen]
+    block_tables: Optional[torch.Tensor] = None
+
+    # seq len inputs
+    extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
+    seq_lens_cpu_int: Optional[torch.Tensor] = None
+
+
+class AscendAttnBackend(AttentionBackend):
+
+    def gen_attention_mask(self, max_seq_len: int, dtype=torch.float16):
+        mask_flag = torch.tril(
+            torch.ones((max_seq_len, max_seq_len), dtype=torch.bool)
+        ).view(max_seq_len, max_seq_len)
+        mask_flag = ~mask_flag
+        if dtype == torch.float16:
+            mask_value = torch.finfo(torch.float32).min
+        else:
+            mask_value = 1
+        self.mask = (
+            torch.masked_fill(
+                torch.zeros(size=(max_seq_len, max_seq_len)), mask_flag, mask_value
+            )
+            .to(dtype)
+            .to(self.device)
+        )
+        self.mask_len = max_seq_len
+
+    def __init__(self, model_runner: ModelRunner):
+        super().__init__()
+        self.forward_metadata = ForwardMetadata()
+        self.device = model_runner.device
+        self.gen_attention_mask(128, model_runner.dtype)
+        self.page_size = model_runner.page_size
+        self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
+        if self.use_mla:
+            self.kv_lora_rank = model_runner.model_config.kv_lora_rank
+            self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+        self.native_attn = TorchNativeAttnBackend(model_runner)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Init the metadata for a forward pass."""
+        self.forward_metadata.block_tables = (
+            forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : forward_batch.seq_lens.max()
+            ][:, :: self.page_size]
+            // self.page_size
+        )
+        if forward_batch.extend_seq_lens is not None:
+            self.forward_metadata.extend_seq_lens_cpu_int = (
+                forward_batch.extend_seq_lens.cpu().int()
+            )
+        self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
+
+    def forward_extend(
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
+
+        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+        v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+
+        if not self.use_mla:
+            query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
+            output = torch.empty(
+                (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
+                dtype=query.dtype,
+                device=query.device,
+            )
+
+            torch_npu._npu_flash_attention_qlens(
+                query=query,
+                key_cache=k_cache,
+                value_cache=v_cache,
+                mask=self.mask,
+                block_table=self.forward_metadata.block_tables,
+                seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
+                context_lens=self.forward_metadata.seq_lens_cpu_int,
+                scale_value=layer.scaling,
+                num_heads=layer.tp_q_head_num,
+                num_kv_heads=layer.tp_k_head_num,
+                out=output,
+            )
+            return output
+        else:
+            if layer.qk_head_dim != layer.v_head_dim:
+                o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+            else:
+                o = torch.empty_like(q)
+
+            use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+
+            q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+            o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+
+            causal = True
+            if (
+                layer.is_cross_attention
+                or layer.attn_type == AttentionType.ENCODER_ONLY
+            ):
+                causal = False
+
+            self.native_attn._run_sdpa_forward_extend(
+                q_,
+                o_,
+                k_cache.view(
+                    -1, layer.tp_k_head_num, (self.kv_lora_rank + self.qk_rope_head_dim)
+                ),
+                v_cache.view(-1, layer.tp_v_head_num, self.kv_lora_rank),
+                forward_batch.req_to_token_pool.req_to_token,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.extend_prefix_lens,
+                forward_batch.extend_seq_lens,
+                scaling=layer.scaling,
+                enable_gqa=use_gqa,
+                causal=causal,
+            )
+            return o
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
+        if not self.use_mla:
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+
+            query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+            num_tokens = query.shape[0]
+            output = torch.empty(
+                (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
+                dtype=query.dtype,
+                device=query.device,
+            )
+
+            torch_npu._npu_paged_attention(
+                query=query,
+                key_cache=k_cache,
+                value_cache=v_cache,
+                num_heads=layer.tp_q_head_num,
+                num_kv_heads=layer.tp_k_head_num,
+                scale_value=layer.scaling,
+                block_table=self.forward_metadata.block_tables,
+                context_lens=self.forward_metadata.seq_lens_cpu_int,
+                out=output,
+            )
+            return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
+        else:
+            query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            num_tokens = query.shape[0]
+            kv_c_and_k_pe_cache = forward_batch.token_to_kv_pool.get_key_buffer(
+                layer.layer_id
+            )
+            kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
+                -1,
+                self.page_size,
+                layer.tp_k_head_num,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+
+            attn_output = torch.empty(
+                [num_tokens, layer.tp_q_head_num, self.kv_lora_rank],
+                dtype=q.dtype,
+                device=q.device,
+            )
+            torch_npu._npu_paged_attention_mla(
+                query=query,
+                key_cache=kv_c_and_k_pe_cache,
+                num_kv_heads=layer.tp_k_head_num,
+                num_heads=layer.tp_q_head_num,
+                scale_value=layer.scaling,
+                block_table=self.forward_metadata.block_tables,
+                context_lens=self.forward_metadata.seq_lens_cpu_int,
+                mla_vheadsize=self.kv_lora_rank,
+                out=attn_output,
+            )
+            return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)
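The additive mask that `gen_attention_mask` precomputes above is a standard causal mask: positions above the diagonal are filled with a very large negative value so they vanish after softmax. A minimal sketch of the same construction in plain `torch` (no `torch_npu` or NPU device needed), shown only to illustrate the pattern:

```python
import torch

# Toy reproduction of the mask in gen_attention_mask for max_seq_len=4, fp16.
max_seq_len, dtype = 4, torch.float16

# True above the diagonal = positions a query must not attend to.
mask_flag = ~torch.tril(torch.ones((max_seq_len, max_seq_len), dtype=torch.bool))

# fp16 uses the float32 minimum as fill (it saturates to -inf in half precision),
# matching the torch.finfo(torch.float32).min value used in the backend above.
mask_value = torch.finfo(torch.float32).min if dtype == torch.float16 else 1
mask = (
    torch.zeros((max_seq_len, max_seq_len))
    .masked_fill(mask_flag, mask_value)
    .to(dtype)
)
print(mask)  # row i is 0 for columns <= i and -inf-like elsewhere
```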
sglang/srt/layers/attention/flashattention_backend.py
@@ -9,6 +9,7 @@ import torch
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.mem_cache.memory_pool import SWAKVPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput

@@ -320,6 +321,11 @@ class FlashAttentionBackend(AttentionBackend):
         self.page_size = model_runner.page_size
         self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
         self.skip_prefill = skip_prefill
+        self.is_hybrid = model_runner.is_hybrid
+        if self.is_hybrid:
+            self.full_to_swa_index_mapping = (
+                model_runner.token_to_kv_pool.full_to_swa_index_mapping
+            )
         self.topk = model_runner.server_args.speculative_eagle_topk or 0
         self.speculative_num_steps = speculative_num_steps
         self.speculative_num_draft_tokens = (
@@ -428,7 +434,7 @@
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
             # TODO: we need to test this part for llama 4 eagle case
-            self._init_local_attn_metadata(metadata, device)
+            self._init_local_attn_metadata(forward_batch, metadata, device)
         elif forward_batch.forward_mode.is_target_verify():
             if self.topk <= 1:
                 metadata.cache_seqlens_int32 = (
@@ -456,7 +462,7 @@
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]

-                self._init_local_attn_metadata(metadata, device)
+                self._init_local_attn_metadata(forward_batch, metadata, device)
             else:
                 metadata.cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
                 metadata.max_seq_len_q = self.speculative_num_draft_tokens
@@ -575,7 +581,7 @@

         # Setup local attention if enabled
         if forward_batch.forward_mode == ForwardMode.EXTEND:
-            self._init_local_attn_metadata(metadata, device)
+            self._init_local_attn_metadata(forward_batch, metadata, device)

         # Encoder metadata for cross attention
         if forward_batch.encoder_lens is not None:
@@ -1588,7 +1594,7 @@
         forward_mode: ForwardMode,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
-        out_cache_loc: torch.Tensor = None,
+        out_cache_loc: Optional[torch.Tensor] = None,
     ):
         """Initialize forward metadata for replaying CUDA graph."""
         seq_lens = seq_lens[:bs]
@@ -1673,7 +1679,10 @@
                 self.page_size,
             )

-            self._update_local_attn_metadata_for_replay(metadata, bs)
+            self._update_local_attn_metadata_for_replay(
+                metadata,
+                bs,
+            )
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
                 metadata = self.target_verify_metadata[bs]
@@ -1829,7 +1838,9 @@
         """Get the fill value for sequence length in CUDA graph."""
         return 1

-    def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
+    def _init_local_attn_metadata(
+        self, forwardbatch: ForwardBatch, metadata: FlashAttentionMetadata, device
+    ):
         """Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
         if self.attention_chunk_size is None:
             metadata.local_attn_metadata = None
@@ -1837,7 +1848,12 @@

         cu_seqlens_q = metadata.cu_seqlens_q
         cache_seqlens_int32 = metadata.cache_seqlens_int32
-        page_table = metadata.page_table
+        if self.is_hybrid:
+            page_table = self.full_to_swa_index_mapping[metadata.page_table].to(
+                torch.int32
+            )
+        else:
+            page_table = metadata.page_table
         if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None:
             metadata.local_attn_metadata = None
             return
@@ -1923,7 +1939,9 @@
         )

     def _update_local_attn_metadata_for_replay(
-        self, metadata: FlashAttentionMetadata, bs: int
+        self,
+        metadata: FlashAttentionMetadata,
+        bs: int,
     ):
         """Update preallocated local attention metadata in-place before CUDA graph replay."""
         if self.attention_chunk_size is None:
@@ -1954,7 +1972,12 @@
         # Without this slicing, the pre-allocated page_table may contain zeros or invalid indices
         # beyond the actual sequence length, leading to incorrect attention calculations
         max_seq_len = int(seqlens.max().item())
-        sliced_page_table = metadata.page_table[:bs, :max_seq_len]
+        if self.is_hybrid:
+            sliced_page_table = self.full_to_swa_index_mapping[
+                metadata.page_table[:bs, :max_seq_len]
+            ].to(torch.int32)
+        else:
+            sliced_page_table = metadata.page_table[:bs, :max_seq_len]

         cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
         seqlens_np = seqlens.cpu().numpy()
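The hybrid (sliding-window) changes in the hunks above all reduce to one operation: translating page indices from the full KV pool into the sliding-window pool with a tensor gather, `full_to_swa_index_mapping[page_table]`. A small sketch with made-up values (the mapping and page table below are illustrative, not real sglang state):

```python
import torch

# Hypothetical mapping from full-pool page indices to sliding-window-pool indices.
full_to_swa_index_mapping = torch.tensor([0, 5, 6, 7, 2, 3], dtype=torch.int64)

# Per-request page table expressed in full-pool indices.
page_table = torch.tensor([[1, 2, 3], [4, 5, 0]])

# Advanced indexing performs the gather; the result is cast to int32 as in the diff.
swa_page_table = full_to_swa_index_mapping[page_table].to(torch.int32)
print(swa_page_table)  # tensor([[5, 6, 7], [2, 3, 0]], dtype=torch.int32)
```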
sglang/srt/layers/attention/tbo_backend.py
@@ -119,21 +119,27 @@ class TboAttnBackend(AttentionBackend):
         replay_seq_lens_sum: int = None,
         replay_seq_lens_cpu: Optional[torch.Tensor] = None,
     ):
+        token_num_per_seq = two_batch_overlap.get_token_num_per_seq(
+            forward_mode=forward_mode, spec_info=spec_info
+        )
         if fn_name == "init_forward_metadata_capture_cuda_graph":
-            assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
-            num_tokens = bs
+            assert (
+                capture_num_tokens == bs * token_num_per_seq
+            ), "For target-verify or decode mode, num_tokens should be equal to token_num_per_seq * bs"
+            num_tokens = bs * token_num_per_seq

             tbo_split_seq_index, tbo_split_token_index = (
                 two_batch_overlap.compute_split_indices_for_cuda_graph_replay(
                     forward_mode=forward_mode,
                     cuda_graph_num_tokens=num_tokens,
+                    spec_info=spec_info,
                 )
             )

             num_tokens_child_left = tbo_split_token_index
             num_tokens_child_right = num_tokens - tbo_split_token_index
-            bs_child_left = num_tokens_child_left
-            bs_child_right = num_tokens_child_right
+            bs_child_left = tbo_split_seq_index
+            bs_child_right = bs - bs_child_left

             assert (
                 num_tokens_child_left > 0 and num_tokens_child_right > 0
@@ -190,16 +196,36 @@ def _init_forward_metadata_cuda_graph_split(
     seq_lens: torch.Tensor,
     encoder_lens: Optional[torch.Tensor],
     forward_mode: "ForwardMode",
-    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+    spec_info: Optional[EagleVerifyInput],
     # capture args
     capture_num_tokens: int = None,
     # replay args
     replay_seq_lens_sum: int = None,
     replay_seq_lens_cpu: Optional[torch.Tensor] = None,
 ):
+    token_num_per_seq = two_batch_overlap.get_token_num_per_seq(
+        forward_mode=forward_mode, spec_info=spec_info
+    )
     assert encoder_lens is None, "encoder_lens is not supported yet"
-    assert spec_info is None, "spec_info is not supported yet"
+    if spec_info is not None:
+        output_spec_info = two_batch_overlap.split_spec_info(
+            spec_info=spec_info,
+            start_seq_index=seq_slice.start if seq_slice.start is not None else 0,
+            end_seq_index=seq_slice.stop if seq_slice.stop is not None else bs,
+            start_token_index=(
+                seq_slice.start * token_num_per_seq
+                if seq_slice.start is not None
+                else 0
+            ),
+            end_token_index=(
+                seq_slice.stop * token_num_per_seq
+                if seq_slice.stop is not None
+                else bs * token_num_per_seq
+            ),
+        )

+    else:
+        output_spec_info = None
     ans = dict(
         bs=output_bs,
         req_pool_indices=req_pool_indices[seq_slice],
@@ -208,14 +234,16 @@
         forward_mode=forward_mode,
         # ignore
         encoder_lens=None,
-        spec_info=None,
+        spec_info=output_spec_info,
     )

     if fn_name == "init_forward_metadata_capture_cuda_graph":
-        assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
+        assert (
+            capture_num_tokens == bs * token_num_per_seq
+        ), "Only support num_tokens==bs * token_num_per_seq for target-verify or decode mode"
         ans.update(
             dict(
-                num_tokens=output_bs,
+                num_tokens=output_bs * token_num_per_seq,
             )
         )
     elif fn_name == "init_forward_metadata_replay_cuda_graph":
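The two-batch-overlap changes above stop assuming one token per sequence: with speculative target-verify or decode, each sequence contributes `token_num_per_seq` tokens, so the batch is split by sequence index while token counts are derived separately. A worked example with illustrative numbers (not taken from a real batch, and assuming a uniform token count per sequence):

```python
# Illustrative split arithmetic for a target-verify CUDA graph batch.
bs = 8
token_num_per_seq = 4                      # e.g. speculative_num_draft_tokens
num_tokens = bs * token_num_per_seq        # what the CUDA graph was captured with

tbo_split_seq_index = 3                    # sequences assigned to the left child
tbo_split_token_index = tbo_split_seq_index * token_num_per_seq

bs_child_left = tbo_split_seq_index        # new behavior: split by sequences
bs_child_right = bs - bs_child_left
num_tokens_child_left = tbo_split_token_index
num_tokens_child_right = num_tokens - tbo_split_token_index

assert bs_child_left + bs_child_right == bs
assert num_tokens_child_left + num_tokens_child_right == num_tokens
```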
sglang/srt/layers/communicator.py
@@ -32,8 +32,13 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
 )
+from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import is_cuda, is_flashinfer_available
+
+_is_flashinfer_available = is_flashinfer_available()
+_is_sm100_supported = is_cuda() and is_sm100_supported()


 class ScatterMode(Enum):
@@ -397,8 +402,21 @@ class CommunicateWithAllReduceAndLayerNormFn:
             if hidden_states.shape[0] != 0:
                 hidden_states = layernorm(hidden_states)
         else:
-            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-            hidden_states, residual = layernorm(hidden_states, residual)
+            # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
+            # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
+            if (
+                _is_sm100_supported
+                and _is_flashinfer_available
+                and hasattr(layernorm, "forward_with_allreduce_fusion")
+                and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
+                and hidden_states.shape[0] <= 128
+            ):
+                hidden_states, residual = layernorm.forward_with_allreduce_fusion(
+                    hidden_states, residual
+                )
+            else:
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+                hidden_states, residual = layernorm(hidden_states, residual)
         return hidden_states, residual

     @staticmethod
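The new branch above only takes the fused allreduce+RMSNorm path when several conditions line up: an SM100 GPU, flashinfer present, the layernorm exposing a fused kernel, the server flag enabled, and at most 128 tokens (per the linked flashinfer discussion). A hedged sketch of that gate as a standalone predicate; the parameter names below are illustrative stand-ins, not sglang APIs:

```python
def use_allreduce_rmsnorm_fusion(
    num_tokens: int,
    sm100_supported: bool,
    flashinfer_available: bool,
    fusion_flag_enabled: bool,
    layernorm_has_fused_kernel: bool,
    max_fused_tokens: int = 128,  # one-shot (min-latency) cap noted in the diff comment
) -> bool:
    """Return True when the fused allreduce+RMSNorm path should be taken."""
    return (
        sm100_supported
        and flashinfer_available
        and layernorm_has_fused_kernel
        and fusion_flag_enabled
        and num_tokens <= max_fused_tokens
    )


# A 64-token batch with everything available takes the fused path; a 256-token
# batch falls back to the plain all-reduce followed by the layernorm.
assert use_allreduce_rmsnorm_fusion(64, True, True, True, True)
assert not use_allreduce_rmsnorm_fusion(256, True, True, True, True)
```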
sglang/srt/layers/dp_attention.py
@@ -79,14 +79,12 @@
     )

     if enable_dp_attention:
-        local_rank = tp_rank % (tp_size // dp_size)
         _ATTN_DP_SIZE = dp_size
         if moe_dense_tp_size is None:
             _LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE
         else:
             _LOCAL_ATTN_DP_SIZE = max(1, dp_size // (tp_size // moe_dense_tp_size))
     else:
-        local_rank = tp_rank
         _ATTN_DP_SIZE = 1
         _LOCAL_ATTN_DP_SIZE = 1

@@ -96,7 +94,7 @@
             list(range(head, head + _ATTN_TP_SIZE))
             for head in range(0, pp_size * tp_size, _ATTN_TP_SIZE)
         ],
-        local_rank,
+        tp_group.local_rank,
         torch.distributed.get_backend(tp_group.device_group),
         use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
         use_pymscclpp=False,
@@ -239,6 +237,10 @@ def _dp_gather(
     assert (
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between global_tokens and local_tokens not allowed"
+
+    # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
+    # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
+    # actual size of the accepted tokens.
     if forward_batch.forward_mode.is_draft_extend():
         shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
         local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
@@ -293,6 +295,10 @@ def dp_scatter(
     assert (
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between local_tokens and global_tokens not allowed"
+
+    # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
+    # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
+    # actual size of the accepted tokens.
     if forward_batch.forward_mode.is_draft_extend():
         shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
         local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
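Both `_dp_gather` and `dp_scatter` now clamp `local_num_tokens` during draft extend, because the gathered buffer is padded to `num_tokens * (speculative_num_steps + 1)` while only the accepted tokens are actually present. A toy sketch of the clamp itself (the values are made up):

```python
import torch

local_tokens = torch.randn(10, 16)     # 10 accepted tokens actually present
local_num_tokens = torch.tensor(24)    # padded size: num_tokens * (steps + 1)

# Same pattern as the diff: build a 0-dim tensor holding the real token count
# and take the element-wise minimum.
shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
assert local_num_tokens.item() == 10
```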
sglang/srt/layers/elementwise.py
@@ -8,6 +8,7 @@ from sglang.srt.utils import is_hip

 _is_hip = is_hip()

+
 fused_softcap_autotune = triton.autotune(
     configs=[
         triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
@@ -189,21 +190,16 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal
     assert x.shape == residual.shape and x.dtype == residual.dtype
     output, mid = torch.empty_like(x), torch.empty_like(x)
     bs, hidden_dim = x.shape
-
-    min_num_warps = 16 if _is_hip else 32
-
     if autotune:
         fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
             output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
         )
     else:
+        max_warps = 16 if _is_hip else 32
         config = {
             "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
             "num_warps": max(
-                min(
-                    triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps
-                ),
-                4,
+                min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
             ),
         }

@@ -260,13 +256,11 @@ def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
     else:
         output = torch.empty_like(x)
     bs, hidden_dim = x.shape
-
-    min_num_warps = 16 if _is_hip else 32
-
+    max_warps = 16 if _is_hip else 32
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
         ),
     }

@@ -331,6 +325,75 @@ class FusedDualResidualRMSNorm:
         return self.rmsnorm2.forward_native(residual), residual


+@triton.jit
+def experts_combine_kernel(
+    out_hidden_states,
+    moe_hidden_states,
+    mlp_hidden_states,
+    combine_k: tl.constexpr,
+    hidden_dim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    start_index_mlp = pid * hidden_dim
+    start_index_rmoe = pid * hidden_dim * combine_k
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < hidden_dim
+    combine_k_offsets = tl.arange(0, combine_k)
+
+    moe_x = tl.load(
+        moe_hidden_states
+        + start_index_rmoe
+        + combine_k_offsets[:, None] * hidden_dim
+        + offsets[None, :],
+        mask=mask[None, :],
+        other=0.0,
+    )
+    moe_x = tl.sum(moe_x, axis=0)
+    mlp_x = tl.load(mlp_hidden_states + start_index_mlp + offsets, mask=mask, other=0.0)
+    combined_x = (moe_x + mlp_x) / 1.4142135623730951
+
+    tl.store(out_hidden_states + start_index_mlp + offsets, combined_x, mask=mask)
+
+
+def experts_combine_triton(moe_hidden_states, mlp_hidden_states, output_buffer=None):
+    assert moe_hidden_states.is_contiguous()
+    assert mlp_hidden_states.is_contiguous()
+
+    if len(moe_hidden_states.shape) == 2:
+        combine_k = 1  # pre-combined
+    else:
+        combine_k = moe_hidden_states.shape[1]
+
+    if output_buffer is None:
+        out_hidden_states = torch.empty_like(mlp_hidden_states)
+    else:
+        flat_output_buffer = output_buffer.view(mlp_hidden_states.dtype).reshape(-1)
+        assert flat_output_buffer.numel() >= mlp_hidden_states.numel()
+        out_hidden_states = flat_output_buffer[: mlp_hidden_states.numel()].reshape(
+            mlp_hidden_states.shape
+        )
+
+    bs, hidden_dim = mlp_hidden_states.shape
+
+    config = {
+        "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 1024)), 8), 4
+        ),
+    }
+
+    experts_combine_kernel[(bs,)](
+        out_hidden_states,
+        moe_hidden_states,
+        mlp_hidden_states,
+        combine_k,
+        hidden_dim,
+        **config,
+    )
+    return out_hidden_states
+
+
 # gelu on first half of vector
 @triton.jit
 def gelu_and_mul_kernel(
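`experts_combine_kernel` sums the `combine_k` expert outputs for each token, adds the dense-MLP branch, and scales by 1/sqrt(2) (the 1.4142135623730951 constant). A plain-torch reference of that math, useful only as a shape and semantics sketch; it is not part of sglang:

```python
import math

import torch


def experts_combine_reference(moe_hidden_states, mlp_hidden_states):
    """Plain-torch equivalent of what experts_combine_kernel computes."""
    if moe_hidden_states.dim() == 3:        # [bs, combine_k, hidden_dim]
        moe_sum = moe_hidden_states.sum(dim=1)
    else:                                   # already combined: [bs, hidden_dim]
        moe_sum = moe_hidden_states
    return (moe_sum + mlp_hidden_states) / math.sqrt(2.0)


# Illustrative shapes only: bs=2, combine_k=4, hidden_dim=8.
moe = torch.randn(2, 4, 8)
mlp = torch.randn(2, 8)
out = experts_combine_reference(moe, mlp)
assert out.shape == mlp.shape
```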
@@ -400,10 +463,11 @@ def gelu_and_mul_triton(
         out_scales = scales
         static_scale = True

+    max_warps = 16 if _is_hip else 32
     config = {
         # 8 ele per thread (not tuned)
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), 32), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4
         ),
     }
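The remaining hunks rename `min_num_warps` to `max_warps`, which matches how the value is used: it is an upper bound inside `min(..., max_warps)`, floored at 4. A pure-Python sketch of the heuristic; the `cdiv` and `next_power_of_2` helpers below mirror Triton's utilities so the example runs without Triton installed:

```python
def cdiv(a: int, b: int) -> int:
    return -(-a // b)


def next_power_of_2(n: int) -> int:
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def pick_num_warps(hidden_dim: int, elems: int = 256, is_hip: bool = False) -> int:
    # Roughly one warp per `elems` elements, rounded up to a power of two,
    # capped at 16 (HIP) or 32 (CUDA) and floored at 4.
    max_warps = 16 if is_hip else 32
    return max(min(next_power_of_2(cdiv(hidden_dim, elems)), max_warps), 4)


print(pick_num_warps(4096))                 # 16
print(pick_num_warps(16384, is_hip=True))   # 16 (HIP cap)
print(pick_num_warps(16384))                # 32 (CUDA cap)
```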