sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +49 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +35 -0
  8. sglang/srt/custom_op.py +7 -1
  9. sglang/srt/disaggregation/base/conn.py +2 -0
  10. sglang/srt/disaggregation/decode.py +22 -6
  11. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  12. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  13. sglang/srt/disaggregation/nixl/conn.py +100 -52
  14. sglang/srt/disaggregation/prefill.py +5 -4
  15. sglang/srt/disaggregation/utils.py +13 -12
  16. sglang/srt/distributed/parallel_state.py +44 -17
  17. sglang/srt/entrypoints/EngineBase.py +8 -0
  18. sglang/srt/entrypoints/engine.py +45 -9
  19. sglang/srt/entrypoints/http_server.py +111 -24
  20. sglang/srt/entrypoints/openai/protocol.py +51 -6
  21. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  22. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  23. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  24. sglang/srt/eplb/__init__.py +0 -0
  25. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  26. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  27. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  28. sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
  29. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  30. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  31. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  32. sglang/srt/hf_transformers_utils.py +2 -1
  33. sglang/srt/layers/activation.py +7 -0
  34. sglang/srt/layers/amx_utils.py +86 -0
  35. sglang/srt/layers/attention/ascend_backend.py +219 -0
  36. sglang/srt/layers/attention/flashattention_backend.py +56 -23
  37. sglang/srt/layers/attention/tbo_backend.py +37 -9
  38. sglang/srt/layers/communicator.py +18 -2
  39. sglang/srt/layers/dp_attention.py +9 -3
  40. sglang/srt/layers/elementwise.py +76 -12
  41. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  42. sglang/srt/layers/layernorm.py +41 -0
  43. sglang/srt/layers/linear.py +99 -12
  44. sglang/srt/layers/logits_processor.py +15 -6
  45. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  46. sglang/srt/layers/moe/ep_moe/layer.py +115 -25
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  49. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
  50. sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
  51. sglang/srt/layers/moe/router.py +60 -22
  52. sglang/srt/layers/moe/topk.py +36 -28
  53. sglang/srt/layers/parameter.py +67 -7
  54. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  55. sglang/srt/layers/quantization/fp8.py +44 -0
  56. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  57. sglang/srt/layers/quantization/fp8_utils.py +6 -6
  58. sglang/srt/layers/quantization/gptq.py +5 -1
  59. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  60. sglang/srt/layers/quantization/quant_utils.py +166 -0
  61. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  62. sglang/srt/layers/rotary_embedding.py +105 -13
  63. sglang/srt/layers/vocab_parallel_embedding.py +19 -2
  64. sglang/srt/lora/lora.py +4 -5
  65. sglang/srt/lora/lora_manager.py +73 -20
  66. sglang/srt/managers/configure_logging.py +1 -1
  67. sglang/srt/managers/io_struct.py +60 -15
  68. sglang/srt/managers/mm_utils.py +73 -59
  69. sglang/srt/managers/multimodal_processor.py +2 -6
  70. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  71. sglang/srt/managers/schedule_batch.py +80 -79
  72. sglang/srt/managers/scheduler.py +153 -63
  73. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  74. sglang/srt/managers/session_controller.py +12 -3
  75. sglang/srt/managers/tokenizer_manager.py +314 -103
  76. sglang/srt/managers/tp_worker.py +13 -1
  77. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  78. sglang/srt/mem_cache/allocator.py +290 -0
  79. sglang/srt/mem_cache/chunk_cache.py +34 -2
  80. sglang/srt/mem_cache/memory_pool.py +289 -3
  81. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  82. sglang/srt/model_executor/cuda_graph_runner.py +3 -2
  83. sglang/srt/model_executor/forward_batch_info.py +17 -4
  84. sglang/srt/model_executor/model_runner.py +302 -58
  85. sglang/srt/model_loader/loader.py +86 -10
  86. sglang/srt/model_loader/weight_utils.py +160 -3
  87. sglang/srt/models/deepseek_nextn.py +5 -4
  88. sglang/srt/models/deepseek_v2.py +305 -26
  89. sglang/srt/models/deepseek_vl2.py +3 -5
  90. sglang/srt/models/gemma3_causal.py +1 -2
  91. sglang/srt/models/gemma3n_audio.py +949 -0
  92. sglang/srt/models/gemma3n_causal.py +1010 -0
  93. sglang/srt/models/gemma3n_mm.py +495 -0
  94. sglang/srt/models/hunyuan.py +771 -0
  95. sglang/srt/models/kimi_vl.py +1 -2
  96. sglang/srt/models/llama.py +10 -4
  97. sglang/srt/models/llama4.py +32 -45
  98. sglang/srt/models/llama_eagle3.py +61 -11
  99. sglang/srt/models/llava.py +5 -5
  100. sglang/srt/models/minicpmo.py +2 -2
  101. sglang/srt/models/mistral.py +1 -1
  102. sglang/srt/models/mllama4.py +43 -11
  103. sglang/srt/models/phi4mm.py +1 -3
  104. sglang/srt/models/pixtral.py +3 -7
  105. sglang/srt/models/qwen2.py +31 -3
  106. sglang/srt/models/qwen2_5_vl.py +1 -3
  107. sglang/srt/models/qwen2_audio.py +200 -0
  108. sglang/srt/models/qwen2_moe.py +32 -6
  109. sglang/srt/models/qwen2_vl.py +1 -4
  110. sglang/srt/models/qwen3.py +94 -25
  111. sglang/srt/models/qwen3_moe.py +68 -21
  112. sglang/srt/models/vila.py +3 -8
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  117. sglang/srt/multimodal/processors/gemma3n.py +82 -0
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  129. sglang/srt/operations_strategy.py +6 -2
  130. sglang/srt/reasoning_parser.py +26 -0
  131. sglang/srt/sampling/sampling_batch_info.py +39 -1
  132. sglang/srt/server_args.py +85 -24
  133. sglang/srt/speculative/build_eagle_tree.py +57 -18
  134. sglang/srt/speculative/eagle_worker.py +6 -4
  135. sglang/srt/two_batch_overlap.py +204 -28
  136. sglang/srt/utils.py +369 -138
  137. sglang/srt/warmup.py +12 -3
  138. sglang/test/runners.py +10 -1
  139. sglang/test/test_utils.py +15 -3
  140. sglang/version.py +1 -1
  141. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  142. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
  143. sglang/math_utils.py +0 -8
  144. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  145. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  146. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  147. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  148. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  149. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,219 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional
+
+import torch
+import torch_npu
+from torch.nn.functional import scaled_dot_product_attention
+
+from sglang.srt.configs.model_config import AttentionArch
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
+from sglang.srt.layers.radix_attention import AttentionType
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+
+@dataclass
+class ForwardMetadata:
+
+    # calculated map for kv positions [bs * maxseqlen]
+    block_tables: Optional[torch.Tensor] = None
+
+    # seq len inputs
+    extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
+    seq_lens_cpu_int: Optional[torch.Tensor] = None
+
+
+class AscendAttnBackend(AttentionBackend):
+
+    def gen_attention_mask(self, max_seq_len: int, dtype=torch.float16):
+        mask_flag = torch.tril(
+            torch.ones((max_seq_len, max_seq_len), dtype=torch.bool)
+        ).view(max_seq_len, max_seq_len)
+        mask_flag = ~mask_flag
+        if dtype == torch.float16:
+            mask_value = torch.finfo(torch.float32).min
+        else:
+            mask_value = 1
+        self.mask = (
+            torch.masked_fill(
+                torch.zeros(size=(max_seq_len, max_seq_len)), mask_flag, mask_value
+            )
+            .to(dtype)
+            .to(self.device)
+        )
+        self.mask_len = max_seq_len
+
+    def __init__(self, model_runner: ModelRunner):
+        super().__init__()
+        self.forward_metadata = ForwardMetadata()
+        self.device = model_runner.device
+        self.gen_attention_mask(128, model_runner.dtype)
+        self.page_size = model_runner.page_size
+        self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
+        if self.use_mla:
+            self.kv_lora_rank = model_runner.model_config.kv_lora_rank
+            self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+        self.native_attn = TorchNativeAttnBackend(model_runner)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Init the metadata for a forward pass."""
+        self.forward_metadata.block_tables = (
+            forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : forward_batch.seq_lens.max()
+            ][:, :: self.page_size]
+            // self.page_size
+        )
+        if forward_batch.extend_seq_lens is not None:
+            self.forward_metadata.extend_seq_lens_cpu_int = (
+                forward_batch.extend_seq_lens.cpu().int()
+            )
+        self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
+
+    def forward_extend(
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
+
+        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+        v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+
+        if not self.use_mla:
+            query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
+            output = torch.empty(
+                (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
+                dtype=query.dtype,
+                device=query.device,
+            )
+
+            torch_npu._npu_flash_attention_qlens(
+                query=query,
+                key_cache=k_cache,
+                value_cache=v_cache,
+                mask=self.mask,
+                block_table=self.forward_metadata.block_tables,
+                seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
+                context_lens=self.forward_metadata.seq_lens_cpu_int,
+                scale_value=layer.scaling,
+                num_heads=layer.tp_q_head_num,
+                num_kv_heads=layer.tp_k_head_num,
+                out=output,
+            )
+            return output
+        else:
+            if layer.qk_head_dim != layer.v_head_dim:
+                o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+            else:
+                o = torch.empty_like(q)
+
+            use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+
+            q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+            o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+
+            causal = True
+            if (
+                layer.is_cross_attention
+                or layer.attn_type == AttentionType.ENCODER_ONLY
+            ):
+                causal = False
+
+            self.native_attn._run_sdpa_forward_extend(
+                q_,
+                o_,
+                k_cache.view(
+                    -1, layer.tp_k_head_num, (self.kv_lora_rank + self.qk_rope_head_dim)
+                ),
+                v_cache.view(-1, layer.tp_v_head_num, self.kv_lora_rank),
+                forward_batch.req_to_token_pool.req_to_token,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.extend_prefix_lens,
+                forward_batch.extend_seq_lens,
+                scaling=layer.scaling,
+                enable_gqa=use_gqa,
+                causal=causal,
+            )
+            return o
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
+        if not self.use_mla:
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+
+            query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+            num_tokens = query.shape[0]
+            output = torch.empty(
+                (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
+                dtype=query.dtype,
+                device=query.device,
+            )
+
+            torch_npu._npu_paged_attention(
+                query=query,
+                key_cache=k_cache,
+                value_cache=v_cache,
+                num_heads=layer.tp_q_head_num,
+                num_kv_heads=layer.tp_k_head_num,
+                scale_value=layer.scaling,
+                block_table=self.forward_metadata.block_tables,
+                context_lens=self.forward_metadata.seq_lens_cpu_int,
+                out=output,
+            )
+            return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
+        else:
+            query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            num_tokens = query.shape[0]
+            kv_c_and_k_pe_cache = forward_batch.token_to_kv_pool.get_key_buffer(
+                layer.layer_id
+            )
+            kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
+                -1,
+                self.page_size,
+                layer.tp_k_head_num,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+
+            attn_output = torch.empty(
+                [num_tokens, layer.tp_q_head_num, self.kv_lora_rank],
+                dtype=q.dtype,
+                device=q.device,
+            )
+            torch_npu._npu_paged_attention_mla(
+                query=query,
+                key_cache=kv_c_and_k_pe_cache,
+                num_kv_heads=layer.tp_k_head_num,
+                num_heads=layer.tp_q_head_num,
+                scale_value=layer.scaling,
+                block_table=self.forward_metadata.block_tables,
+                context_lens=self.forward_metadata.seq_lens_cpu_int,
+                mla_vheadsize=self.kv_lora_rank,
+                out=attn_output,
+            )
+            return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)
@@ -9,6 +9,7 @@ import torch
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.mem_cache.memory_pool import SWAKVPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 
@@ -320,6 +321,11 @@ class FlashAttentionBackend(AttentionBackend):
         self.page_size = model_runner.page_size
         self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
         self.skip_prefill = skip_prefill
+        self.is_hybrid = model_runner.is_hybrid
+        if self.is_hybrid:
+            self.full_to_swa_index_mapping = (
+                model_runner.token_to_kv_pool.full_to_swa_index_mapping
+            )
         self.topk = model_runner.server_args.speculative_eagle_topk or 0
         self.speculative_num_steps = speculative_num_steps
         self.speculative_num_draft_tokens = (
@@ -428,7 +434,7 @@ class FlashAttentionBackend(AttentionBackend):
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
                 # TODO: we need to test this part for llama 4 eagle case
-                self._init_local_attn_metadata(metadata, device)
+                self._init_local_attn_metadata(forward_batch, metadata, device)
         elif forward_batch.forward_mode.is_target_verify():
             if self.topk <= 1:
                 metadata.cache_seqlens_int32 = (
@@ -456,7 +462,7 @@
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
 
-                self._init_local_attn_metadata(metadata, device)
+                self._init_local_attn_metadata(forward_batch, metadata, device)
             else:
                 metadata.cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
                 metadata.max_seq_len_q = self.speculative_num_draft_tokens
@@ -575,7 +581,7 @@
 
         # Setup local attention if enabled
         if forward_batch.forward_mode == ForwardMode.EXTEND:
-            self._init_local_attn_metadata(metadata, device)
+            self._init_local_attn_metadata(forward_batch, metadata, device)
 
         # Encoder metadata for cross attention
         if forward_batch.encoder_lens is not None:
@@ -657,12 +663,16 @@
         )
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
-        # has corresponding quantization method so that layer.k_scale is not None
-        if self.kv_cache_dtype_str != "auto" and layer.k_scale is not None:
-            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
-            k_descale = layer.k_scale.expand(descale_shape)
-            v_descale = layer.v_scale.expand(descale_shape)
+        # has corresponding quantization method so that layer.k_scale is not None,
+        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
+        if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
+            if layer.k_scale is not None:
+                descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+                k_descale = layer.k_scale.expand(descale_shape)
+                v_descale = layer.v_scale.expand(descale_shape)
             q = q.to(self.kv_cache_dtype)
+            q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
+            k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
         causal = not layer.is_cross_attention
 
         # Check if we should use local attention
@@ -776,8 +786,8 @@
 
                 output, lse, *rest = flash_attn_varlen_func(
                     q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                     cu_seqlens_q=metadata.cu_seqlens_q,
                     cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
                     max_seqlen_q=metadata.max_seq_len_q,
@@ -790,8 +800,8 @@
             # MHA for extend part of sequence without attending prefix kv cache
             output, lse, *rest = flash_attn_varlen_func(
                 q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                 cu_seqlens_q=metadata.cu_seqlens_q,
                 cu_seqlens_k=metadata.cu_seqlens_q,
                 max_seqlen_q=metadata.max_seq_len_q,
@@ -803,7 +813,9 @@
                 return output, lse
         else:
             # Do absorbed multi-latent attention
-            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
+                layer.layer_id
+            ).to(q.dtype)
             k_rope = kv_cache[:, :, layer.v_head_dim :]
             c_kv = kv_cache[:, :, : layer.v_head_dim]
             k_rope_cache = k_rope.view(
@@ -933,14 +945,16 @@
 
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
-        # has corresponding quantization method so that layer.k_scale is not None
-        if self.kv_cache_dtype_str != "auto":
+        # has corresponding quantization method so that layer.k_scale is not None,
+        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
+        if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
             if layer.k_scale is not None:
                 descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
                 k_descale = layer.k_scale.expand(descale_shape)
                 v_descale = layer.v_scale.expand(descale_shape)
             q = q.to(self.kv_cache_dtype)
-
+            q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
+            k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
         if not self.use_mla:
             # Do multi-head attention
 
@@ -1048,7 +1062,9 @@
                 o = result
         else:
             # Do absorbed multi-latent attention
-            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
+                q.dtype
+            )
             k_rope = kv_cache[:, :, layer.v_head_dim :]
             c_kv = kv_cache[:, :, : layer.v_head_dim]
             k_rope_cache = k_rope.view(
@@ -1578,7 +1594,7 @@
         forward_mode: ForwardMode,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
-        out_cache_loc: torch.Tensor = None,
+        out_cache_loc: Optional[torch.Tensor] = None,
     ):
         """Initialize forward metadata for replaying CUDA graph."""
         seq_lens = seq_lens[:bs]
@@ -1663,7 +1679,10 @@
                 self.page_size,
             )
 
-            self._update_local_attn_metadata_for_replay(metadata, bs)
+            self._update_local_attn_metadata_for_replay(
+                metadata,
+                bs,
+            )
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
                 metadata = self.target_verify_metadata[bs]
@@ -1819,7 +1838,9 @@
         """Get the fill value for sequence length in CUDA graph."""
         return 1
 
-    def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
+    def _init_local_attn_metadata(
+        self, forwardbatch: ForwardBatch, metadata: FlashAttentionMetadata, device
+    ):
         """Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
         if self.attention_chunk_size is None:
             metadata.local_attn_metadata = None
@@ -1827,7 +1848,12 @@
 
         cu_seqlens_q = metadata.cu_seqlens_q
         cache_seqlens_int32 = metadata.cache_seqlens_int32
-        page_table = metadata.page_table
+        if self.is_hybrid:
+            page_table = self.full_to_swa_index_mapping[metadata.page_table].to(
+                torch.int32
+            )
+        else:
+            page_table = metadata.page_table
         if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None:
             metadata.local_attn_metadata = None
             return
@@ -1913,7 +1939,9 @@
         )
 
     def _update_local_attn_metadata_for_replay(
-        self, metadata: FlashAttentionMetadata, bs: int
+        self,
+        metadata: FlashAttentionMetadata,
+        bs: int,
     ):
         """Update preallocated local attention metadata in-place before CUDA graph replay."""
         if self.attention_chunk_size is None:
@@ -1944,7 +1972,12 @@
         # Without this slicing, the pre-allocated page_table may contain zeros or invalid indices
         # beyond the actual sequence length, leading to incorrect attention calculations
         max_seq_len = int(seqlens.max().item())
-        sliced_page_table = metadata.page_table[:bs, :max_seq_len]
+        if self.is_hybrid:
+            sliced_page_table = self.full_to_swa_index_mapping[
+                metadata.page_table[:bs, :max_seq_len]
+            ].to(torch.int32)
+        else:
+            sliced_page_table = metadata.page_table[:bs, :max_seq_len]
 
         cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
         seqlens_np = seqlens.cpu().numpy()
@@ -119,21 +119,27 @@ class TboAttnBackend(AttentionBackend):
         replay_seq_lens_sum: int = None,
         replay_seq_lens_cpu: Optional[torch.Tensor] = None,
     ):
+        token_num_per_seq = two_batch_overlap.get_token_num_per_seq(
+            forward_mode=forward_mode, spec_info=spec_info
+        )
         if fn_name == "init_forward_metadata_capture_cuda_graph":
-            assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
-            num_tokens = bs
+            assert (
+                capture_num_tokens == bs * token_num_per_seq
+            ), "For target-verify or decode mode, num_tokens should be equal to token_num_per_seq * bs"
+            num_tokens = bs * token_num_per_seq
 
         tbo_split_seq_index, tbo_split_token_index = (
             two_batch_overlap.compute_split_indices_for_cuda_graph_replay(
                 forward_mode=forward_mode,
                 cuda_graph_num_tokens=num_tokens,
+                spec_info=spec_info,
             )
         )
 
         num_tokens_child_left = tbo_split_token_index
         num_tokens_child_right = num_tokens - tbo_split_token_index
-        bs_child_left = num_tokens_child_left
-        bs_child_right = num_tokens_child_right
+        bs_child_left = tbo_split_seq_index
+        bs_child_right = bs - bs_child_left
 
         assert (
             num_tokens_child_left > 0 and num_tokens_child_right > 0
@@ -190,16 +196,36 @@ def _init_forward_metadata_cuda_graph_split(
     seq_lens: torch.Tensor,
     encoder_lens: Optional[torch.Tensor],
     forward_mode: "ForwardMode",
-    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+    spec_info: Optional[EagleVerifyInput],
     # capture args
     capture_num_tokens: int = None,
     # replay args
     replay_seq_lens_sum: int = None,
     replay_seq_lens_cpu: Optional[torch.Tensor] = None,
 ):
+    token_num_per_seq = two_batch_overlap.get_token_num_per_seq(
+        forward_mode=forward_mode, spec_info=spec_info
+    )
     assert encoder_lens is None, "encoder_lens is not supported yet"
-    assert spec_info is None, "spec_info is not supported yet"
+    if spec_info is not None:
+        output_spec_info = two_batch_overlap.split_spec_info(
+            spec_info=spec_info,
+            start_seq_index=seq_slice.start if seq_slice.start is not None else 0,
+            end_seq_index=seq_slice.stop if seq_slice.stop is not None else bs,
+            start_token_index=(
+                seq_slice.start * token_num_per_seq
+                if seq_slice.start is not None
+                else 0
+            ),
+            end_token_index=(
+                seq_slice.stop * token_num_per_seq
+                if seq_slice.stop is not None
+                else bs * token_num_per_seq
+            ),
+        )
 
+    else:
+        output_spec_info = None
     ans = dict(
         bs=output_bs,
         req_pool_indices=req_pool_indices[seq_slice],
@@ -208,14 +234,16 @@
         forward_mode=forward_mode,
         # ignore
         encoder_lens=None,
-        spec_info=None,
+        spec_info=output_spec_info,
     )
 
     if fn_name == "init_forward_metadata_capture_cuda_graph":
-        assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
+        assert (
+            capture_num_tokens == bs * token_num_per_seq
+        ), "Only support num_tokens==bs * token_num_per_seq for target-verify or decode mode"
         ans.update(
             dict(
-                num_tokens=output_bs,
+                num_tokens=output_bs * token_num_per_seq,
             )
         )
     elif fn_name == "init_forward_metadata_replay_cuda_graph":
@@ -32,8 +32,13 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
 )
+from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import is_cuda, is_flashinfer_available
+
+_is_flashinfer_available = is_flashinfer_available()
+_is_sm100_supported = is_cuda() and is_sm100_supported()
 
 
 class ScatterMode(Enum):
@@ -397,8 +402,19 @@ class CommunicateWithAllReduceAndLayerNormFn:
         if hidden_states.shape[0] != 0:
             hidden_states = layernorm(hidden_states)
         else:
-            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-            hidden_states, residual = layernorm(hidden_states, residual)
+            if (
+                _is_sm100_supported
+                and _is_flashinfer_available
+                and hasattr(layernorm, "forward_with_allreduce_fusion")
+                and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
+                and hidden_states.shape[0] <= 1024
+            ):
+                hidden_states, residual = layernorm.forward_with_allreduce_fusion(
+                    hidden_states, residual
+                )
+            else:
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+                hidden_states, residual = layernorm(hidden_states, residual)
         return hidden_states, residual
 
     @staticmethod
@@ -79,14 +79,12 @@ def initialize_dp_attention(
     )
 
     if enable_dp_attention:
-        local_rank = tp_rank % (tp_size // dp_size)
         _ATTN_DP_SIZE = dp_size
         if moe_dense_tp_size is None:
             _LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE
         else:
             _LOCAL_ATTN_DP_SIZE = max(1, dp_size // (tp_size // moe_dense_tp_size))
     else:
-        local_rank = tp_rank
         _ATTN_DP_SIZE = 1
         _LOCAL_ATTN_DP_SIZE = 1
 
@@ -96,7 +94,7 @@ def initialize_dp_attention(
             list(range(head, head + _ATTN_TP_SIZE))
             for head in range(0, pp_size * tp_size, _ATTN_TP_SIZE)
         ],
-        local_rank,
+        tp_group.local_rank,
         torch.distributed.get_backend(tp_group.device_group),
         use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
         use_pymscclpp=False,
@@ -239,6 +237,10 @@ def _dp_gather(
     assert (
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between global_tokens and local_tokens not allowed"
+
+    # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
+    # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
+    # actual size of the accepted tokens.
     if forward_batch.forward_mode.is_draft_extend():
         shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
         local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
@@ -293,6 +295,10 @@ def dp_scatter(
     assert (
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between local_tokens and global_tokens not allowed"
+
+    # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
+    # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
+    # actual size of the accepted tokens.
     if forward_batch.forward_mode.is_draft_extend():
         shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
         local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)