sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +170 -24
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +60 -1
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +69 -1
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +20 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
- sglang/srt/layers/moe/ep_moe/layer.py +176 -15
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +72 -7
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +20 -10
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +58 -14
- sglang/srt/managers/mm_utils.py +77 -61
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +78 -85
- sglang/srt/managers/scheduler.py +130 -64
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +402 -66
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +195 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +402 -89
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +84 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +203 -27
- sglang/srt/utils.py +343 -163
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/test/test_utils.py +15 -3
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/ascend_backend.py
ADDED
@@ -0,0 +1,219 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional
+
+import torch
+import torch_npu
+from torch.nn.functional import scaled_dot_product_attention
+
+from sglang.srt.configs.model_config import AttentionArch
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
+from sglang.srt.layers.radix_attention import AttentionType
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+
+@dataclass
+class ForwardMetadata:
+
+    # calculated map for kv positions [bs * maxseqlen]
+    block_tables: Optional[torch.Tensor] = None
+
+    # seq len inputs
+    extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
+    seq_lens_cpu_int: Optional[torch.Tensor] = None
+
+
+class AscendAttnBackend(AttentionBackend):
+
+    def gen_attention_mask(self, max_seq_len: int, dtype=torch.float16):
+        mask_flag = torch.tril(
+            torch.ones((max_seq_len, max_seq_len), dtype=torch.bool)
+        ).view(max_seq_len, max_seq_len)
+        mask_flag = ~mask_flag
+        if dtype == torch.float16:
+            mask_value = torch.finfo(torch.float32).min
+        else:
+            mask_value = 1
+        self.mask = (
+            torch.masked_fill(
+                torch.zeros(size=(max_seq_len, max_seq_len)), mask_flag, mask_value
+            )
+            .to(dtype)
+            .to(self.device)
+        )
+        self.mask_len = max_seq_len
+
+    def __init__(self, model_runner: ModelRunner):
+        super().__init__()
+        self.forward_metadata = ForwardMetadata()
+        self.device = model_runner.device
+        self.gen_attention_mask(128, model_runner.dtype)
+        self.page_size = model_runner.page_size
+        self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
+        if self.use_mla:
+            self.kv_lora_rank = model_runner.model_config.kv_lora_rank
+            self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+        self.native_attn = TorchNativeAttnBackend(model_runner)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Init the metadata for a forward pass."""
+        self.forward_metadata.block_tables = (
+            forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : forward_batch.seq_lens.max()
+            ][:, :: self.page_size]
+            // self.page_size
+        )
+        if forward_batch.extend_seq_lens is not None:
+            self.forward_metadata.extend_seq_lens_cpu_int = (
+                forward_batch.extend_seq_lens.cpu().int()
+            )
+        self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
+
+    def forward_extend(
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
+
+        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+        v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+
+        if not self.use_mla:
+            query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
+            output = torch.empty(
+                (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
+                dtype=query.dtype,
+                device=query.device,
+            )
+
+            torch_npu._npu_flash_attention_qlens(
+                query=query,
+                key_cache=k_cache,
+                value_cache=v_cache,
+                mask=self.mask,
+                block_table=self.forward_metadata.block_tables,
+                seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
+                context_lens=self.forward_metadata.seq_lens_cpu_int,
+                scale_value=layer.scaling,
+                num_heads=layer.tp_q_head_num,
+                num_kv_heads=layer.tp_k_head_num,
+                out=output,
+            )
+            return output
+        else:
+            if layer.qk_head_dim != layer.v_head_dim:
+                o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+            else:
+                o = torch.empty_like(q)
+
+            use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+
+            q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+            o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+
+            causal = True
+            if (
+                layer.is_cross_attention
+                or layer.attn_type == AttentionType.ENCODER_ONLY
+            ):
+                causal = False
+
+            self.native_attn._run_sdpa_forward_extend(
+                q_,
+                o_,
+                k_cache.view(
+                    -1, layer.tp_k_head_num, (self.kv_lora_rank + self.qk_rope_head_dim)
+                ),
+                v_cache.view(-1, layer.tp_v_head_num, self.kv_lora_rank),
+                forward_batch.req_to_token_pool.req_to_token,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.extend_prefix_lens,
+                forward_batch.extend_seq_lens,
+                scaling=layer.scaling,
+                enable_gqa=use_gqa,
+                causal=causal,
+            )
+            return o
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
+        if not self.use_mla:
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+
+            query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+            num_tokens = query.shape[0]
+            output = torch.empty(
+                (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
+                dtype=query.dtype,
+                device=query.device,
+            )
+
+            torch_npu._npu_paged_attention(
+                query=query,
+                key_cache=k_cache,
+                value_cache=v_cache,
+                num_heads=layer.tp_q_head_num,
+                num_kv_heads=layer.tp_k_head_num,
+                scale_value=layer.scaling,
+                block_table=self.forward_metadata.block_tables,
+                context_lens=self.forward_metadata.seq_lens_cpu_int,
+                out=output,
+            )
+            return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
+        else:
+            query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            num_tokens = query.shape[0]
+            kv_c_and_k_pe_cache = forward_batch.token_to_kv_pool.get_key_buffer(
+                layer.layer_id
+            )
+            kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
+                -1,
+                self.page_size,
+                layer.tp_k_head_num,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+
+            attn_output = torch.empty(
+                [num_tokens, layer.tp_q_head_num, self.kv_lora_rank],
+                dtype=q.dtype,
+                device=q.device,
+            )
+            torch_npu._npu_paged_attention_mla(
+                query=query,
+                key_cache=kv_c_and_k_pe_cache,
+                num_kv_heads=layer.tp_k_head_num,
+                num_heads=layer.tp_q_head_num,
+                scale_value=layer.scaling,
+                block_table=self.forward_metadata.block_tables,
+                context_lens=self.forward_metadata.seq_lens_cpu_int,
+                mla_vheadsize=self.kv_lora_rank,
+                out=attn_output,
+            )
+            return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)
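The block-table construction in `AscendAttnBackend.init_forward_metadata` above derives per-page block ids by taking every `page_size`-th entry of the per-token KV slot map and integer-dividing by `page_size`. A minimal sketch of that indexing; the `req_to_token` layout and `page_size` below are toy values chosen purely for illustration:

```python
import torch

page_size = 4
# one request whose first 8 tokens occupy KV slots 20..27, i.e. pages 5 and 6
req_to_token = torch.arange(20, 28).unsqueeze(0)

# take every page_size-th slot, then divide by page_size to get page/block ids
block_tables = req_to_token[:, ::page_size] // page_size
print(block_tables)  # tensor([[5, 6]])
```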
sglang/srt/layers/attention/flashattention_backend.py
CHANGED
@@ -9,6 +9,7 @@ import torch
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.mem_cache.memory_pool import SWAKVPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 
@@ -320,6 +321,11 @@ class FlashAttentionBackend(AttentionBackend):
         self.page_size = model_runner.page_size
         self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
         self.skip_prefill = skip_prefill
+        self.is_hybrid = model_runner.is_hybrid
+        if self.is_hybrid:
+            self.full_to_swa_index_mapping = (
+                model_runner.token_to_kv_pool.full_to_swa_index_mapping
+            )
         self.topk = model_runner.server_args.speculative_eagle_topk or 0
         self.speculative_num_steps = speculative_num_steps
         self.speculative_num_draft_tokens = (
@@ -428,7 +434,7 @@ class FlashAttentionBackend(AttentionBackend):
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
             # TODO: we need to test this part for llama 4 eagle case
-            self._init_local_attn_metadata(metadata, device)
+            self._init_local_attn_metadata(forward_batch, metadata, device)
         elif forward_batch.forward_mode.is_target_verify():
             if self.topk <= 1:
                 metadata.cache_seqlens_int32 = (
@@ -456,7 +462,7 @@ class FlashAttentionBackend(AttentionBackend):
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
 
-            self._init_local_attn_metadata(metadata, device)
+            self._init_local_attn_metadata(forward_batch, metadata, device)
         else:
             metadata.cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
             metadata.max_seq_len_q = self.speculative_num_draft_tokens
@@ -575,7 +581,7 @@ class FlashAttentionBackend(AttentionBackend):
 
         # Setup local attention if enabled
         if forward_batch.forward_mode == ForwardMode.EXTEND:
-            self._init_local_attn_metadata(metadata, device)
+            self._init_local_attn_metadata(forward_batch, metadata, device)
 
         # Encoder metadata for cross attention
         if forward_batch.encoder_lens is not None:
@@ -1588,7 +1594,7 @@ class FlashAttentionBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
-        out_cache_loc: torch.Tensor = None,
+        out_cache_loc: Optional[torch.Tensor] = None,
     ):
         """Initialize forward metadata for replaying CUDA graph."""
         seq_lens = seq_lens[:bs]
@@ -1673,7 +1679,10 @@ class FlashAttentionBackend(AttentionBackend):
                 self.page_size,
             )
 
-            self._update_local_attn_metadata_for_replay(
+            self._update_local_attn_metadata_for_replay(
+                metadata,
+                bs,
+            )
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
                 metadata = self.target_verify_metadata[bs]
@@ -1829,7 +1838,9 @@ class FlashAttentionBackend(AttentionBackend):
         """Get the fill value for sequence length in CUDA graph."""
         return 1
 
-    def _init_local_attn_metadata(
+    def _init_local_attn_metadata(
+        self, forwardbatch: ForwardBatch, metadata: FlashAttentionMetadata, device
+    ):
         """Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
         if self.attention_chunk_size is None:
             metadata.local_attn_metadata = None
@@ -1837,7 +1848,12 @@ class FlashAttentionBackend(AttentionBackend):
 
         cu_seqlens_q = metadata.cu_seqlens_q
         cache_seqlens_int32 = metadata.cache_seqlens_int32
-
+        if self.is_hybrid:
+            page_table = self.full_to_swa_index_mapping[metadata.page_table].to(
+                torch.int32
+            )
+        else:
+            page_table = metadata.page_table
         if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None:
             metadata.local_attn_metadata = None
             return
@@ -1923,7 +1939,9 @@ class FlashAttentionBackend(AttentionBackend):
         )
 
     def _update_local_attn_metadata_for_replay(
-        self,
+        self,
+        metadata: FlashAttentionMetadata,
+        bs: int,
     ):
         """Update preallocated local attention metadata in-place before CUDA graph replay."""
         if self.attention_chunk_size is None:
@@ -1954,7 +1972,12 @@ class FlashAttentionBackend(AttentionBackend):
         # Without this slicing, the pre-allocated page_table may contain zeros or invalid indices
         # beyond the actual sequence length, leading to incorrect attention calculations
         max_seq_len = int(seqlens.max().item())
-
+        if self.is_hybrid:
+            sliced_page_table = self.full_to_swa_index_mapping[
+                metadata.page_table[:bs, :max_seq_len]
+            ].to(torch.int32)
+        else:
+            sliced_page_table = metadata.page_table[:bs, :max_seq_len]
 
         cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
         seqlens_np = seqlens.cpu().numpy()
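The new `is_hybrid` branches above remap the page table of the full KV pool into sliding-window (SWA) pool indices by fancy-indexing through `full_to_swa_index_mapping`. A small illustration of that remapping; the mapping values and page-table contents below are made up for the example:

```python
import torch

# hypothetical full-pool page id -> SWA-pool page id mapping
full_to_swa_index_mapping = torch.tensor([0, 3, 1, 2, 4])
page_table = torch.tensor([[1, 2, 4]])  # pages used by one request in the full pool

swa_page_table = full_to_swa_index_mapping[page_table].to(torch.int32)
print(swa_page_table)  # tensor([[3, 1, 4]], dtype=torch.int32)
```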
sglang/srt/layers/attention/tbo_backend.py
CHANGED
@@ -119,21 +119,27 @@ class TboAttnBackend(AttentionBackend):
         replay_seq_lens_sum: int = None,
         replay_seq_lens_cpu: Optional[torch.Tensor] = None,
     ):
+        token_num_per_seq = two_batch_overlap.get_token_num_per_seq(
+            forward_mode=forward_mode, spec_info=spec_info
+        )
         if fn_name == "init_forward_metadata_capture_cuda_graph":
-            assert
-
+            assert (
+                capture_num_tokens == bs * token_num_per_seq
+            ), "For target-verify or decode mode, num_tokens should be equal to token_num_per_seq * bs"
+            num_tokens = bs * token_num_per_seq
 
         tbo_split_seq_index, tbo_split_token_index = (
             two_batch_overlap.compute_split_indices_for_cuda_graph_replay(
                 forward_mode=forward_mode,
                 cuda_graph_num_tokens=num_tokens,
+                spec_info=spec_info,
             )
         )
 
         num_tokens_child_left = tbo_split_token_index
         num_tokens_child_right = num_tokens - tbo_split_token_index
-        bs_child_left =
-        bs_child_right =
+        bs_child_left = tbo_split_seq_index
+        bs_child_right = bs - bs_child_left
 
         assert (
             num_tokens_child_left > 0 and num_tokens_child_right > 0
@@ -190,16 +196,36 @@ def _init_forward_metadata_cuda_graph_split(
     seq_lens: torch.Tensor,
     encoder_lens: Optional[torch.Tensor],
     forward_mode: "ForwardMode",
-    spec_info: Optional[
+    spec_info: Optional[EagleVerifyInput],
     # capture args
     capture_num_tokens: int = None,
     # replay args
     replay_seq_lens_sum: int = None,
     replay_seq_lens_cpu: Optional[torch.Tensor] = None,
 ):
+    token_num_per_seq = two_batch_overlap.get_token_num_per_seq(
+        forward_mode=forward_mode, spec_info=spec_info
+    )
     assert encoder_lens is None, "encoder_lens is not supported yet"
-
+    if spec_info is not None:
+        output_spec_info = two_batch_overlap.split_spec_info(
+            spec_info=spec_info,
+            start_seq_index=seq_slice.start if seq_slice.start is not None else 0,
+            end_seq_index=seq_slice.stop if seq_slice.stop is not None else bs,
+            start_token_index=(
+                seq_slice.start * token_num_per_seq
+                if seq_slice.start is not None
+                else 0
+            ),
+            end_token_index=(
+                seq_slice.stop * token_num_per_seq
+                if seq_slice.stop is not None
+                else bs * token_num_per_seq
+            ),
+        )
 
+    else:
+        output_spec_info = None
     ans = dict(
         bs=output_bs,
         req_pool_indices=req_pool_indices[seq_slice],
@@ -208,14 +234,16 @@ def _init_forward_metadata_cuda_graph_split(
         forward_mode=forward_mode,
         # ignore
         encoder_lens=None,
-        spec_info=
+        spec_info=output_spec_info,
     )
 
     if fn_name == "init_forward_metadata_capture_cuda_graph":
-        assert
+        assert (
+            capture_num_tokens == bs * token_num_per_seq
+        ), "Only support num_tokens==bs * token_num_per_seq for target-verify or decode mode"
        ans.update(
            dict(
-                num_tokens=output_bs,
+                num_tokens=output_bs * token_num_per_seq,
            )
        )
    elif fn_name == "init_forward_metadata_replay_cuda_graph":
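The asserts added above encode a simple invariant for target-verify and decode CUDA-graph runs: every sequence contributes `token_num_per_seq` tokens, so the captured token count must equal `bs * token_num_per_seq`, and a split at sequence index `s` corresponds to a token split at `s * token_num_per_seq`. A back-of-the-envelope check with hypothetical numbers:

```python
bs = 8
token_num_per_seq = 4  # e.g. draft tokens per sequence during target verify
capture_num_tokens = bs * token_num_per_seq  # 32

tbo_split_seq_index = 3
tbo_split_token_index = tbo_split_seq_index * token_num_per_seq  # 12

bs_child_left = tbo_split_seq_index                                   # 3 sequences on the left
bs_child_right = bs - bs_child_left                                    # 5 sequences on the right
num_tokens_child_left = tbo_split_token_index                          # 12 tokens
num_tokens_child_right = capture_num_tokens - tbo_split_token_index    # 20 tokens
assert num_tokens_child_left > 0 and num_tokens_child_right > 0
```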
sglang/srt/layers/communicator.py
CHANGED
@@ -32,8 +32,13 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
 )
+from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import is_cuda, is_flashinfer_available
+
+_is_flashinfer_available = is_flashinfer_available()
+_is_sm100_supported = is_cuda() and is_sm100_supported()
 
 
 class ScatterMode(Enum):
@@ -397,8 +402,21 @@ class CommunicateWithAllReduceAndLayerNormFn:
         if hidden_states.shape[0] != 0:
             hidden_states = layernorm(hidden_states)
         else:
-            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-            hidden_states, residual = layernorm(hidden_states, residual)
+            # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
+            # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
+            if (
+                _is_sm100_supported
+                and _is_flashinfer_available
+                and hasattr(layernorm, "forward_with_allreduce_fusion")
+                and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
+                and hidden_states.shape[0] <= 128
+            ):
+                hidden_states, residual = layernorm.forward_with_allreduce_fusion(
+                    hidden_states, residual
+                )
+            else:
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+                hidden_states, residual = layernorm(hidden_states, residual)
         return hidden_states, residual
 
     @staticmethod
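The gating condition added above chooses between the fused flashinfer allreduce + RMSNorm path and the plain allreduce-then-layernorm path. A standalone restatement of that decision as a helper; the function name and boolean parameters are hypothetical stand-ins for the capability checks in the real code:

```python
def use_allreduce_rmsnorm_fusion(
    num_tokens: int,
    sm100_supported: bool,
    flashinfer_available: bool,
    fusion_flag_enabled: bool,
    max_fused_tokens: int = 128,  # one-shot/min-latency limit cited in the comment above
) -> bool:
    return (
        sm100_supported
        and flashinfer_available
        and fusion_flag_enabled
        and num_tokens <= max_fused_tokens
    )

assert use_allreduce_rmsnorm_fusion(64, True, True, True)
assert not use_allreduce_rmsnorm_fusion(256, True, True, True)
```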
sglang/srt/layers/dp_attention.py
CHANGED
@@ -79,14 +79,12 @@ def initialize_dp_attention(
     )
 
     if enable_dp_attention:
-        local_rank = tp_rank % (tp_size // dp_size)
         _ATTN_DP_SIZE = dp_size
         if moe_dense_tp_size is None:
             _LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE
         else:
             _LOCAL_ATTN_DP_SIZE = max(1, dp_size // (tp_size // moe_dense_tp_size))
     else:
-        local_rank = tp_rank
         _ATTN_DP_SIZE = 1
         _LOCAL_ATTN_DP_SIZE = 1
 
@@ -96,7 +94,7 @@ def initialize_dp_attention(
             list(range(head, head + _ATTN_TP_SIZE))
             for head in range(0, pp_size * tp_size, _ATTN_TP_SIZE)
         ],
-        local_rank,
+        tp_group.local_rank,
         torch.distributed.get_backend(tp_group.device_group),
         use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
         use_pymscclpp=False,
@@ -239,6 +237,10 @@ def _dp_gather(
     assert (
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between global_tokens and local_tokens not allowed"
+
+    # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
+    # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
+    # actual size of the accepted tokens.
     if forward_batch.forward_mode.is_draft_extend():
         shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
         local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
@@ -293,6 +295,10 @@ def dp_scatter(
     assert (
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between local_tokens and global_tokens not allowed"
+
+    # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
+    # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
+    # actual size of the accepted tokens.
     if forward_batch.forward_mode.is_draft_extend():
         shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
         local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
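The NOTE added to `_dp_gather` and `dp_scatter` describes clamping `local_num_tokens` to the number of tokens actually present in `local_tokens` during draft extend. A tiny illustration of that clamp; the shapes and counts are hypothetical:

```python
import torch

local_tokens = torch.zeros(6, 16)   # 6 accepted tokens with hidden size 16
local_num_tokens = torch.tensor(8)  # padded count from the speculative schedule

shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
print(local_num_tokens)  # tensor(6)
```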
sglang/srt/layers/elementwise.py
CHANGED
@@ -8,6 +8,7 @@ from sglang.srt.utils import is_hip
 
 _is_hip = is_hip()
 
+
 fused_softcap_autotune = triton.autotune(
     configs=[
         triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
@@ -189,21 +190,16 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal
     assert x.shape == residual.shape and x.dtype == residual.dtype
     output, mid = torch.empty_like(x), torch.empty_like(x)
     bs, hidden_dim = x.shape
-
-    min_num_warps = 16 if _is_hip else 32
-
     if autotune:
         fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
             output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
         )
     else:
+        max_warps = 16 if _is_hip else 32
         config = {
             "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
             "num_warps": max(
-                min(
-                    triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps
-                ),
-                4,
+                min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
             ),
         }
 
@@ -260,13 +256,11 @@ def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
     else:
         output = torch.empty_like(x)
     bs, hidden_dim = x.shape
-
-    min_num_warps = 16 if _is_hip else 32
-
+    max_warps = 16 if _is_hip else 32
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)),
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
         ),
     }
 
@@ -331,6 +325,75 @@ class FusedDualResidualRMSNorm:
         return self.rmsnorm2.forward_native(residual), residual
 
 
+@triton.jit
+def experts_combine_kernel(
+    out_hidden_states,
+    moe_hidden_states,
+    mlp_hidden_states,
+    combine_k: tl.constexpr,
+    hidden_dim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    start_index_mlp = pid * hidden_dim
+    start_index_rmoe = pid * hidden_dim * combine_k
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < hidden_dim
+    combine_k_offsets = tl.arange(0, combine_k)
+
+    moe_x = tl.load(
+        moe_hidden_states
+        + start_index_rmoe
+        + combine_k_offsets[:, None] * hidden_dim
+        + offsets[None, :],
+        mask=mask[None, :],
+        other=0.0,
+    )
+    moe_x = tl.sum(moe_x, axis=0)
+    mlp_x = tl.load(mlp_hidden_states + start_index_mlp + offsets, mask=mask, other=0.0)
+    combined_x = (moe_x + mlp_x) / 1.4142135623730951
+
+    tl.store(out_hidden_states + start_index_mlp + offsets, combined_x, mask=mask)
+
+
+def experts_combine_triton(moe_hidden_states, mlp_hidden_states, output_buffer=None):
+    assert moe_hidden_states.is_contiguous()
+    assert mlp_hidden_states.is_contiguous()
+
+    if len(moe_hidden_states.shape) == 2:
+        combine_k = 1  # pre-combined
+    else:
+        combine_k = moe_hidden_states.shape[1]
+
+    if output_buffer is None:
+        out_hidden_states = torch.empty_like(mlp_hidden_states)
+    else:
+        flat_output_buffer = output_buffer.view(mlp_hidden_states.dtype).reshape(-1)
+        assert flat_output_buffer.numel() >= mlp_hidden_states.numel()
+        out_hidden_states = flat_output_buffer[: mlp_hidden_states.numel()].reshape(
+            mlp_hidden_states.shape
+        )
+
+    bs, hidden_dim = mlp_hidden_states.shape
+
+    config = {
+        "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 1024)), 8), 4
+        ),
+    }
+
+    experts_combine_kernel[(bs,)](
+        out_hidden_states,
+        moe_hidden_states,
+        mlp_hidden_states,
+        combine_k,
+        hidden_dim,
+        **config,
+    )
+    return out_hidden_states
+
+
 # gelu on first half of vector
 @triton.jit
 def gelu_and_mul_kernel(
@@ -400,10 +463,11 @@ def gelu_and_mul_triton(
     out_scales = scales
     static_scale = True
 
+    max_warps = 16 if _is_hip else 32
     config = {
         # 8 ele per thread (not tuned)
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)),
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4
        ),
    }
 
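The new `experts_combine_triton` helper sums `combine_k` MoE expert outputs per token, adds the dense MLP branch, and scales the result by 1/sqrt(2). A hypothetical usage sketch (not taken from the package), assuming sglang 0.4.9 is installed and a CUDA device with Triton is available:

```python
import torch

from sglang.srt.layers.elementwise import experts_combine_triton

bs, combine_k, hidden_dim = 4, 2, 128
moe_hidden_states = torch.randn(bs, combine_k, hidden_dim, device="cuda", dtype=torch.float16)
mlp_hidden_states = torch.randn(bs, hidden_dim, device="cuda", dtype=torch.float16)

out = experts_combine_triton(moe_hidden_states, mlp_hidden_states)

# reference: sum the expert outputs, add the MLP branch, divide by sqrt(2)
ref = (moe_hidden_states.sum(dim=1) + mlp_hidden_states) / 2**0.5
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
```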