sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +168 -22
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +49 -0
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +35 -0
- sglang/srt/custom_op.py +7 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -6
- sglang/srt/disaggregation/mooncake/conn.py +289 -48
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +100 -52
- sglang/srt/disaggregation/prefill.py +5 -4
- sglang/srt/disaggregation/utils.py +13 -12
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +45 -9
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/openai/protocol.py +51 -6
- sglang/srt/entrypoints/openai/serving_chat.py +52 -76
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +7 -0
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +56 -23
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +18 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +41 -0
- sglang/srt/layers/linear.py +99 -12
- sglang/srt/layers/logits_processor.py +15 -6
- sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
- sglang/srt/layers/moe/ep_moe/layer.py +115 -25
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +36 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +44 -0
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +6 -6
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +105 -13
- sglang/srt/layers/vocab_parallel_embedding.py +19 -2
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +60 -15
- sglang/srt/managers/mm_utils.py +73 -59
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +80 -79
- sglang/srt/managers/scheduler.py +153 -63
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/memory_pool.py +289 -3
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +3 -2
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +302 -58
- sglang/srt/model_loader/loader.py +86 -10
- sglang/srt/model_loader/weight_utils.py +160 -3
- sglang/srt/models/deepseek_nextn.py +5 -4
- sglang/srt/models/deepseek_v2.py +305 -26
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1010 -0
- sglang/srt/models/gemma3n_mm.py +495 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +43 -11
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/multimodal/processors/gemma3n.py +82 -0
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +85 -24
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +204 -28
- sglang/srt/utils.py +369 -138
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_utils.py +15 -3
- sglang/version.py +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/ascend_backend.py (new file)

@@ -0,0 +1,219 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional
+
+import torch
+import torch_npu
+from torch.nn.functional import scaled_dot_product_attention
+
+from sglang.srt.configs.model_config import AttentionArch
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
+from sglang.srt.layers.radix_attention import AttentionType
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+
+@dataclass
+class ForwardMetadata:
+
+    # calculated map for kv positions [bs * maxseqlen]
+    block_tables: Optional[torch.Tensor] = None
+
+    # seq len inputs
+    extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
+    seq_lens_cpu_int: Optional[torch.Tensor] = None
+
+
+class AscendAttnBackend(AttentionBackend):
+
+    def gen_attention_mask(self, max_seq_len: int, dtype=torch.float16):
+        mask_flag = torch.tril(
+            torch.ones((max_seq_len, max_seq_len), dtype=torch.bool)
+        ).view(max_seq_len, max_seq_len)
+        mask_flag = ~mask_flag
+        if dtype == torch.float16:
+            mask_value = torch.finfo(torch.float32).min
+        else:
+            mask_value = 1
+        self.mask = (
+            torch.masked_fill(
+                torch.zeros(size=(max_seq_len, max_seq_len)), mask_flag, mask_value
+            )
+            .to(dtype)
+            .to(self.device)
+        )
+        self.mask_len = max_seq_len
+
+    def __init__(self, model_runner: ModelRunner):
+        super().__init__()
+        self.forward_metadata = ForwardMetadata()
+        self.device = model_runner.device
+        self.gen_attention_mask(128, model_runner.dtype)
+        self.page_size = model_runner.page_size
+        self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
+        if self.use_mla:
+            self.kv_lora_rank = model_runner.model_config.kv_lora_rank
+            self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+        self.native_attn = TorchNativeAttnBackend(model_runner)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Init the metadata for a forward pass."""
+        self.forward_metadata.block_tables = (
+            forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : forward_batch.seq_lens.max()
+            ][:, :: self.page_size]
+            // self.page_size
+        )
+        if forward_batch.extend_seq_lens is not None:
+            self.forward_metadata.extend_seq_lens_cpu_int = (
+                forward_batch.extend_seq_lens.cpu().int()
+            )
+        self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
+
+    def forward_extend(
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
+
+        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+        v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+
+        if not self.use_mla:
+            query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
+            output = torch.empty(
+                (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
+                dtype=query.dtype,
+                device=query.device,
+            )
+
+            torch_npu._npu_flash_attention_qlens(
+                query=query,
+                key_cache=k_cache,
+                value_cache=v_cache,
+                mask=self.mask,
+                block_table=self.forward_metadata.block_tables,
+                seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
+                context_lens=self.forward_metadata.seq_lens_cpu_int,
+                scale_value=layer.scaling,
+                num_heads=layer.tp_q_head_num,
+                num_kv_heads=layer.tp_k_head_num,
+                out=output,
+            )
+            return output
+        else:
+            if layer.qk_head_dim != layer.v_head_dim:
+                o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+            else:
+                o = torch.empty_like(q)
+
+            use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+
+            q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+            o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+
+            causal = True
+            if (
+                layer.is_cross_attention
+                or layer.attn_type == AttentionType.ENCODER_ONLY
+            ):
+                causal = False
+
+            self.native_attn._run_sdpa_forward_extend(
+                q_,
+                o_,
+                k_cache.view(
+                    -1, layer.tp_k_head_num, (self.kv_lora_rank + self.qk_rope_head_dim)
+                ),
+                v_cache.view(-1, layer.tp_v_head_num, self.kv_lora_rank),
+                forward_batch.req_to_token_pool.req_to_token,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.extend_prefix_lens,
+                forward_batch.extend_seq_lens,
+                scaling=layer.scaling,
+                enable_gqa=use_gqa,
+                causal=causal,
+            )
+            return o
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
+        if not self.use_mla:
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+
+            query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+            num_tokens = query.shape[0]
+            output = torch.empty(
+                (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
+                dtype=query.dtype,
+                device=query.device,
+            )
+
+            torch_npu._npu_paged_attention(
+                query=query,
+                key_cache=k_cache,
+                value_cache=v_cache,
+                num_heads=layer.tp_q_head_num,
+                num_kv_heads=layer.tp_k_head_num,
+                scale_value=layer.scaling,
+                block_table=self.forward_metadata.block_tables,
+                context_lens=self.forward_metadata.seq_lens_cpu_int,
+                out=output,
+            )
+            return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
+        else:
+            query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            num_tokens = query.shape[0]
+            kv_c_and_k_pe_cache = forward_batch.token_to_kv_pool.get_key_buffer(
+                layer.layer_id
+            )
+            kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
+                -1,
+                self.page_size,
+                layer.tp_k_head_num,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+
+            attn_output = torch.empty(
+                [num_tokens, layer.tp_q_head_num, self.kv_lora_rank],
+                dtype=q.dtype,
+                device=q.device,
+            )
+            torch_npu._npu_paged_attention_mla(
+                query=query,
+                key_cache=kv_c_and_k_pe_cache,
+                num_kv_heads=layer.tp_k_head_num,
+                num_heads=layer.tp_q_head_num,
+                scale_value=layer.scaling,
+                block_table=self.forward_metadata.block_tables,
+                context_lens=self.forward_metadata.seq_lens_cpu_int,
+                mla_vheadsize=self.kv_lora_rank,
+                out=attn_output,
+            )
+            return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)
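For readers unfamiliar with the masking convention used by the new backend above, here is a minimal standalone sketch of the same lower-triangular causal-mask construction, using plain torch only (no torch_npu). The function name and the default max_seq_len of 8 are illustrative assumptions, not part of the package.

```python
import torch

def build_causal_mask(max_seq_len: int = 8, dtype: torch.dtype = torch.float16) -> torch.Tensor:
    # Positions above the diagonal (future tokens) are the ones to be masked out.
    mask_flag = ~torch.tril(torch.ones((max_seq_len, max_seq_len), dtype=torch.bool))
    # fp16 uses a large negative fill value; other dtypes use 1, mirroring the diff above.
    mask_value = torch.finfo(torch.float32).min if dtype == torch.float16 else 1
    return torch.zeros((max_seq_len, max_seq_len)).masked_fill(mask_flag, mask_value).to(dtype)

if __name__ == "__main__":
    print(build_causal_mask(4))
```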
sglang/srt/layers/attention/flashattention_backend.py

@@ -9,6 +9,7 @@ import torch
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.mem_cache.memory_pool import SWAKVPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 

@@ -320,6 +321,11 @@ class FlashAttentionBackend(AttentionBackend):
         self.page_size = model_runner.page_size
         self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
         self.skip_prefill = skip_prefill
+        self.is_hybrid = model_runner.is_hybrid
+        if self.is_hybrid:
+            self.full_to_swa_index_mapping = (
+                model_runner.token_to_kv_pool.full_to_swa_index_mapping
+            )
         self.topk = model_runner.server_args.speculative_eagle_topk or 0
         self.speculative_num_steps = speculative_num_steps
         self.speculative_num_draft_tokens = (

@@ -428,7 +434,7 @@ class FlashAttentionBackend(AttentionBackend):
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
             # TODO: we need to test this part for llama 4 eagle case
-            self._init_local_attn_metadata(metadata, device)
+            self._init_local_attn_metadata(forward_batch, metadata, device)
         elif forward_batch.forward_mode.is_target_verify():
             if self.topk <= 1:
                 metadata.cache_seqlens_int32 = (

@@ -456,7 +462,7 @@ class FlashAttentionBackend(AttentionBackend):
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
 
-            self._init_local_attn_metadata(metadata, device)
+            self._init_local_attn_metadata(forward_batch, metadata, device)
         else:
             metadata.cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
             metadata.max_seq_len_q = self.speculative_num_draft_tokens

@@ -575,7 +581,7 @@ class FlashAttentionBackend(AttentionBackend):
 
         # Setup local attention if enabled
         if forward_batch.forward_mode == ForwardMode.EXTEND:
-            self._init_local_attn_metadata(metadata, device)
+            self._init_local_attn_metadata(forward_batch, metadata, device)
 
         # Encoder metadata for cross attention
         if forward_batch.encoder_lens is not None:

@@ -657,12 +663,16 @@ class FlashAttentionBackend(AttentionBackend):
         )
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
-        # has corresponding quantization method so that layer.k_scale is not None
-
-
-
-
+        # has corresponding quantization method so that layer.k_scale is not None,
+        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
+        if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
+            if layer.k_scale is not None:
+                descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+                k_descale = layer.k_scale.expand(descale_shape)
+                v_descale = layer.v_scale.expand(descale_shape)
         q = q.to(self.kv_cache_dtype)
+        q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
+        k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
         causal = not layer.is_cross_attention
 
         # Check if we should use local attention
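As a hedged illustration of the per-head k/v descale expansion in the hunk above: a small scale tensor stored on the layer is broadcast to a (batch_size, num_kv_heads) shape before being handed to the FA3 kernel. A minimal sketch with invented sizes (the real inputs come from forward_batch and the RadixAttention layer):

```python
import torch

batch_size, num_kv_heads = 4, 8            # illustrative sizes, not from the diff
k_scale = torch.tensor([0.05])             # stands in for layer.k_scale
descale_shape = (batch_size, num_kv_heads)
k_descale = k_scale.expand(descale_shape)  # broadcast view, no copy
assert k_descale.shape == descale_shape
```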
sglang/srt/layers/attention/flashattention_backend.py (continued)

@@ -776,8 +786,8 @@ class FlashAttentionBackend(AttentionBackend):
 
                 output, lse, *rest = flash_attn_varlen_func(
                     q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                     cu_seqlens_q=metadata.cu_seqlens_q,
                     cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
                     max_seqlen_q=metadata.max_seq_len_q,

@@ -790,8 +800,8 @@ class FlashAttentionBackend(AttentionBackend):
                 # MHA for extend part of sequence without attending prefix kv cache
                 output, lse, *rest = flash_attn_varlen_func(
                     q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                     cu_seqlens_q=metadata.cu_seqlens_q,
                     cu_seqlens_k=metadata.cu_seqlens_q,
                     max_seqlen_q=metadata.max_seq_len_q,

@@ -803,7 +813,9 @@ class FlashAttentionBackend(AttentionBackend):
                 return output, lse
         else:
             # Do absorbed multi-latent attention
-            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
+                layer.layer_id
+            ).to(q.dtype)
             k_rope = kv_cache[:, :, layer.v_head_dim :]
             c_kv = kv_cache[:, :, : layer.v_head_dim]
             k_rope_cache = k_rope.view(

@@ -933,14 +945,16 @@ class FlashAttentionBackend(AttentionBackend):
 
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
-        # has corresponding quantization method so that layer.k_scale is not None
-
+        # has corresponding quantization method so that layer.k_scale is not None,
+        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
+        if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
             if layer.k_scale is not None:
                 descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
                 k_descale = layer.k_scale.expand(descale_shape)
                 v_descale = layer.v_scale.expand(descale_shape)
         q = q.to(self.kv_cache_dtype)
-
+        q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
+        k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
         if not self.use_mla:
             # Do multi-head attention
 

@@ -1048,7 +1062,9 @@ class FlashAttentionBackend(AttentionBackend):
                 o = result
         else:
             # Do absorbed multi-latent attention
-            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
+                q.dtype
+            )
             k_rope = kv_cache[:, :, layer.v_head_dim :]
             c_kv = kv_cache[:, :, : layer.v_head_dim]
             k_rope_cache = k_rope.view(

@@ -1578,7 +1594,7 @@ class FlashAttentionBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
-        out_cache_loc: torch.Tensor = None,
+        out_cache_loc: Optional[torch.Tensor] = None,
     ):
         """Initialize forward metadata for replaying CUDA graph."""
         seq_lens = seq_lens[:bs]

@@ -1663,7 +1679,10 @@ class FlashAttentionBackend(AttentionBackend):
                     self.page_size,
                 )
 
-                self._update_local_attn_metadata_for_replay(
+                self._update_local_attn_metadata_for_replay(
+                    metadata,
+                    bs,
+                )
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
                 metadata = self.target_verify_metadata[bs]

@@ -1819,7 +1838,9 @@ class FlashAttentionBackend(AttentionBackend):
         """Get the fill value for sequence length in CUDA graph."""
         return 1
 
-    def _init_local_attn_metadata(
+    def _init_local_attn_metadata(
+        self, forwardbatch: ForwardBatch, metadata: FlashAttentionMetadata, device
+    ):
         """Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
         if self.attention_chunk_size is None:
             metadata.local_attn_metadata = None

@@ -1827,7 +1848,12 @@ class FlashAttentionBackend(AttentionBackend):
 
         cu_seqlens_q = metadata.cu_seqlens_q
         cache_seqlens_int32 = metadata.cache_seqlens_int32
-
+        if self.is_hybrid:
+            page_table = self.full_to_swa_index_mapping[metadata.page_table].to(
+                torch.int32
+            )
+        else:
+            page_table = metadata.page_table
         if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None:
             metadata.local_attn_metadata = None
             return

@@ -1913,7 +1939,9 @@ class FlashAttentionBackend(AttentionBackend):
         )
 
     def _update_local_attn_metadata_for_replay(
-        self,
+        self,
+        metadata: FlashAttentionMetadata,
+        bs: int,
     ):
         """Update preallocated local attention metadata in-place before CUDA graph replay."""
         if self.attention_chunk_size is None:

@@ -1944,7 +1972,12 @@ class FlashAttentionBackend(AttentionBackend):
         # Without this slicing, the pre-allocated page_table may contain zeros or invalid indices
         # beyond the actual sequence length, leading to incorrect attention calculations
         max_seq_len = int(seqlens.max().item())
-
+        if self.is_hybrid:
+            sliced_page_table = self.full_to_swa_index_mapping[
+                metadata.page_table[:bs, :max_seq_len]
+            ].to(torch.int32)
+        else:
+            sliced_page_table = metadata.page_table[:bs, :max_seq_len]
 
         cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
         seqlens_np = seqlens.cpu().numpy()
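The hybrid (SWA) branches added above remap a page table expressed in full-attention KV-pool indices into sliding-window-pool indices by indexing a mapping tensor. A small self-contained sketch of that indexing pattern, with invented sizes and a hypothetical mapping (the real mapping comes from the model runner's token_to_kv_pool):

```python
import torch

# Hypothetical full-pool -> SWA-pool index mapping (illustrative values only).
full_to_swa_index_mapping = torch.tensor([0, 3, 1, 4, 2, 5], dtype=torch.int64)

# A (batch, pages) page table expressed in full-pool indices.
page_table = torch.tensor([[0, 2, 4], [1, 3, 5]])

# Advanced indexing translates every entry element-wise, as in the diff above.
swa_page_table = full_to_swa_index_mapping[page_table].to(torch.int32)
print(swa_page_table)  # tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.int32)
```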
sglang/srt/layers/attention/tbo_backend.py

@@ -119,21 +119,27 @@ class TboAttnBackend(AttentionBackend):
         replay_seq_lens_sum: int = None,
         replay_seq_lens_cpu: Optional[torch.Tensor] = None,
     ):
+        token_num_per_seq = two_batch_overlap.get_token_num_per_seq(
+            forward_mode=forward_mode, spec_info=spec_info
+        )
         if fn_name == "init_forward_metadata_capture_cuda_graph":
-            assert
-
+            assert (
+                capture_num_tokens == bs * token_num_per_seq
+            ), "For target-verify or decode mode, num_tokens should be equal to token_num_per_seq * bs"
+            num_tokens = bs * token_num_per_seq
 
         tbo_split_seq_index, tbo_split_token_index = (
             two_batch_overlap.compute_split_indices_for_cuda_graph_replay(
                 forward_mode=forward_mode,
                 cuda_graph_num_tokens=num_tokens,
+                spec_info=spec_info,
             )
         )
 
         num_tokens_child_left = tbo_split_token_index
         num_tokens_child_right = num_tokens - tbo_split_token_index
-        bs_child_left =
-        bs_child_right =
+        bs_child_left = tbo_split_seq_index
+        bs_child_right = bs - bs_child_left
 
         assert (
             num_tokens_child_left > 0 and num_tokens_child_right > 0

@@ -190,16 +196,36 @@ def _init_forward_metadata_cuda_graph_split(
     seq_lens: torch.Tensor,
     encoder_lens: Optional[torch.Tensor],
     forward_mode: "ForwardMode",
-    spec_info: Optional[
+    spec_info: Optional[EagleVerifyInput],
     # capture args
     capture_num_tokens: int = None,
     # replay args
     replay_seq_lens_sum: int = None,
     replay_seq_lens_cpu: Optional[torch.Tensor] = None,
 ):
+    token_num_per_seq = two_batch_overlap.get_token_num_per_seq(
+        forward_mode=forward_mode, spec_info=spec_info
+    )
     assert encoder_lens is None, "encoder_lens is not supported yet"
-
+    if spec_info is not None:
+        output_spec_info = two_batch_overlap.split_spec_info(
+            spec_info=spec_info,
+            start_seq_index=seq_slice.start if seq_slice.start is not None else 0,
+            end_seq_index=seq_slice.stop if seq_slice.stop is not None else bs,
+            start_token_index=(
+                seq_slice.start * token_num_per_seq
+                if seq_slice.start is not None
+                else 0
+            ),
+            end_token_index=(
+                seq_slice.stop * token_num_per_seq
+                if seq_slice.stop is not None
+                else bs * token_num_per_seq
+            ),
+        )
 
+    else:
+        output_spec_info = None
     ans = dict(
         bs=output_bs,
         req_pool_indices=req_pool_indices[seq_slice],

@@ -208,14 +234,16 @@ def _init_forward_metadata_cuda_graph_split(
         forward_mode=forward_mode,
         # ignore
         encoder_lens=None,
-        spec_info=
+        spec_info=output_spec_info,
     )
 
     if fn_name == "init_forward_metadata_capture_cuda_graph":
-        assert
+        assert (
+            capture_num_tokens == bs * token_num_per_seq
+        ), "Only support num_tokens==bs * token_num_per_seq for target-verify or decode mode"
         ans.update(
             dict(
-                num_tokens=output_bs,
+                num_tokens=output_bs * token_num_per_seq,
            )
        )
    elif fn_name == "init_forward_metadata_replay_cuda_graph":
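The token accounting in these hunks is simple arithmetic: with `token_num_per_seq` tokens per sequence, a CUDA-graph batch of `bs` sequences carries `bs * token_num_per_seq` tokens, and the two child micro-batches split at a sequence index and the corresponding token index. A hedged sketch of that bookkeeping with illustrative numbers; it simplifies by assuming every sequence contributes exactly `token_num_per_seq` tokens (the real split index comes from `compute_split_indices_for_cuda_graph_replay`):

```python
def split_two_batch(bs: int, token_num_per_seq: int, tbo_split_seq_index: int):
    # Total tokens for a target-verify/decode CUDA graph batch.
    num_tokens = bs * token_num_per_seq
    # Token split index implied by the sequence split under the uniform-length assumption.
    tbo_split_token_index = tbo_split_seq_index * token_num_per_seq
    bs_child_left = tbo_split_seq_index
    bs_child_right = bs - bs_child_left
    num_tokens_child_left = tbo_split_token_index
    num_tokens_child_right = num_tokens - tbo_split_token_index
    return (bs_child_left, num_tokens_child_left), (bs_child_right, num_tokens_child_right)

print(split_two_batch(bs=8, token_num_per_seq=4, tbo_split_seq_index=3))
# ((3, 12), (5, 20))
```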
sglang/srt/layers/communicator.py

@@ -32,8 +32,13 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
 )
+from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import is_cuda, is_flashinfer_available
+
+_is_flashinfer_available = is_flashinfer_available()
+_is_sm100_supported = is_cuda() and is_sm100_supported()
 
 
 class ScatterMode(Enum):

@@ -397,8 +402,19 @@ class CommunicateWithAllReduceAndLayerNormFn:
             if hidden_states.shape[0] != 0:
                 hidden_states = layernorm(hidden_states)
         else:
-
-
+            if (
+                _is_sm100_supported
+                and _is_flashinfer_available
+                and hasattr(layernorm, "forward_with_allreduce_fusion")
+                and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
+                and hidden_states.shape[0] <= 1024
+            ):
+                hidden_states, residual = layernorm.forward_with_allreduce_fusion(
+                    hidden_states, residual
+                )
+            else:
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+                hidden_states, residual = layernorm(hidden_states, residual)
         return hidden_states, residual
 
     @staticmethod
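The gating added above chooses between the fused flashinfer allreduce+layernorm path and a plain all-reduce followed by layernorm. A hedged restatement of that decision as a standalone predicate, with every input passed explicitly (the real code reads module-level flags and global_server_args_dict; this helper is illustrative only):

```python
def use_allreduce_layernorm_fusion(
    sm100_supported: bool,
    flashinfer_available: bool,
    layernorm_has_fused_path: bool,
    fusion_enabled_in_server_args: bool,
    num_tokens: int,
) -> bool:
    # Mirrors the condition in the diff: every capability flag must hold and
    # hidden_states must have at most 1024 rows.
    return (
        sm100_supported
        and flashinfer_available
        and layernorm_has_fused_path
        and fusion_enabled_in_server_args
        and num_tokens <= 1024
    )

print(use_allreduce_layernorm_fusion(True, True, True, True, 512))   # True
print(use_allreduce_layernorm_fusion(True, True, True, True, 4096))  # False
```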
sglang/srt/layers/dp_attention.py

@@ -79,14 +79,12 @@ def initialize_dp_attention(
     )
 
     if enable_dp_attention:
-        local_rank = tp_rank % (tp_size // dp_size)
         _ATTN_DP_SIZE = dp_size
         if moe_dense_tp_size is None:
             _LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE
         else:
             _LOCAL_ATTN_DP_SIZE = max(1, dp_size // (tp_size // moe_dense_tp_size))
     else:
-        local_rank = tp_rank
         _ATTN_DP_SIZE = 1
         _LOCAL_ATTN_DP_SIZE = 1
 

@@ -96,7 +94,7 @@ def initialize_dp_attention(
             list(range(head, head + _ATTN_TP_SIZE))
             for head in range(0, pp_size * tp_size, _ATTN_TP_SIZE)
         ],
-        local_rank,
+        tp_group.local_rank,
         torch.distributed.get_backend(tp_group.device_group),
         use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
         use_pymscclpp=False,

@@ -239,6 +237,10 @@ def _dp_gather(
     assert (
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between global_tokens and local_tokens not allowed"
+
+    # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
+    # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
+    # actual size of the accepted tokens.
     if forward_batch.forward_mode.is_draft_extend():
         shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
         local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)

@@ -293,6 +295,10 @@ def dp_scatter(
     assert (
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between local_tokens and global_tokens not allowed"
+
+    # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
+    # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
+    # actual size of the accepted tokens.
     if forward_batch.forward_mode.is_draft_extend():
         shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
         local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
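The NOTE added to _dp_gather and dp_scatter describes clamping local_num_tokens down to the number of accepted tokens during draft extend. A minimal sketch of that clamp using the same new_full/minimum pattern; the concrete numbers are invented for illustration:

```python
import torch

# Padded count, e.g. num_tokens * (speculative_num_steps + 1) per the NOTE above (invented value).
local_num_tokens = torch.tensor(48)

# local_tokens holds only the accepted tokens; suppose 37 of them.
local_tokens = torch.zeros(37, 16)

# Build a scalar tensor with the same dtype/device as local_num_tokens, then clamp.
shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
print(local_num_tokens)  # tensor(37)
```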