sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registries.
- sglang/bench_one_batch.py +113 -17
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +11 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +4 -3
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +71 -0
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/vision.py +13 -5
- sglang/srt/layers/communicator.py +21 -4
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +2 -7
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +77 -73
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +3 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +55 -30
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +28 -7
- sglang/srt/managers/scheduler.py +26 -12
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +24 -6
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +7 -6
- sglang/srt/model_executor/forward_batch_info.py +35 -14
- sglang/srt/model_executor/model_runner.py +19 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +72 -33
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +24 -12
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +142 -7
- sglang/srt/two_batch_overlap.py +157 -5
- sglang/srt/utils.py +38 -2
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/trtllm_mha_backend.py
ADDED
@@ -0,0 +1,332 @@
+from __future__ import annotations
+
+"""
+Support attention backend for TRTLLM MHA kernels from flashinfer.
+The kernel supports sm100 only, with sliding window and attention sink features.
+"""
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional
+
+import torch
+
+from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.utils import is_flashinfer_available
+
+if is_flashinfer_available():
+    import flashinfer
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.spec_info import SpecInfo
+
+# Constants
+DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
+
+# Reuse this workspace buffer across all TRTLLM MHA wrappers
+global_workspace_buffer = None
+
+
+@dataclass
+class TRTLLMMHAMetadata:
+    # Sequence lengths for the forward batch
+    cache_seqlens_int32: torch.Tensor = None
+    # Maximum sequence length for query
+    max_seq_len_q: int = 1
+    # Maximum sequence length for key
+    max_seq_len_k: int = 0
+    # Cumulative sequence lengths for query
+    cu_seqlens_q: torch.Tensor = None
+    # Cumulative sequence lengths for key
+    cu_seqlens_k: torch.Tensor = None
+    # Page table, the index of KV Cache Tables/Blocks
+    page_table: torch.Tensor = None
+
+
+class TRTLLMHAAttnBackend(FlashInferAttnBackend):
+    """TRTLLM MHA attention kernel from flashinfer."""
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+        q_indptr_decode_buf: Optional[torch.Tensor] = None,
+    ):
+        super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)
+
+        config = model_runner.model_config
+
+        # MHA-specific dimensions
+        self.max_context_len = model_runner.model_config.context_len
+        self.hidden_size = config.hidden_size
+
+        # Runtime parameters
+        self.data_type = model_runner.kv_cache_dtype
+        self.q_data_type = model_runner.dtype
+        self.page_size = model_runner.page_size
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.device = model_runner.device
+
+        # Workspace allocation
+        self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
+        # Allocate buffers
+        global global_workspace_buffer
+        if global_workspace_buffer is None:
+            global_workspace_buffer = torch.empty(
+                self.workspace_size,
+                dtype=torch.uint8,
+                device=model_runner.device,
+            )
+        self.workspace_buffer = global_workspace_buffer
+
+        # CUDA graph state
+        self.decode_cuda_graph_metadata = {}
+
+        # Forward metadata
+        self.forward_metadata: Optional[TRTLLMMHAMetadata] = None
+
+    def init_cuda_graph_state(
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
+    ):
+        """Initialize CUDA graph state for TRTLLM MHA."""
+        self.decode_cuda_graph_metadata = {
+            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+            "page_table": torch.zeros(
+                max_bs,
+                (self.max_context_len + self.page_size - 1) // self.page_size,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "strided_indices": torch.arange(
+                0, self.max_context_len, self.page_size, device=self.device
+            ),
+        }
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+    ):
+        """Initialize metadata for CUDA graph capture."""
+        metadata = TRTLLMMHAMetadata()
+
+        # Get sequence information
+        metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)
+
+        # Precompute maximum sequence length
+        metadata.max_seq_len_k = self.max_context_len
+
+        # Precompute page table
+        metadata.page_table = self.decode_cuda_graph_metadata["page_table"][:bs, :]
+        self.decode_cuda_graph_metadata[bs] = metadata
+        self.forward_metadata = metadata
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        """Replay CUDA graph with new inputs."""
+        seq_lens = seq_lens[:bs]
+        seq_lens_cpu = seq_lens_cpu[:bs]
+        req_pool_indices = req_pool_indices[:bs]
+        device = seq_lens.device
+        metadata = None
+
+        # Normal Decode
+        metadata = self.decode_cuda_graph_metadata[bs]
+        max_len = seq_lens_cpu.max().item()
+        max_seq_pages = (max_len + self.page_size - 1) // self.page_size
+        metadata.max_seq_len_k = self.max_context_len
+
+        metadata.cache_seqlens_int32.copy_(seq_lens)
+        page_indices = self.req_to_token[
+            req_pool_indices[:, None],
+            self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][None, :],
+        ]
+        metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
+        self.forward_metadata = metadata
+
+    def get_cuda_graph_seq_len_fill_value(self) -> int:
+        """Get the fill value for sequence lengths in CUDA graph."""
+        return 1
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Initialize the metadata for a forward pass."""
+
+        metadata = TRTLLMMHAMetadata()
+        seqlens_in_batch = forward_batch.seq_lens
+        batch_size = forward_batch.batch_size
+        device = seqlens_in_batch.device
+
+        if forward_batch.forward_mode.is_decode_or_idle():
+            # Normal Decode
+            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : metadata.max_seq_len_k
+            ]
+        else:
+            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+            metadata.cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+            )
+            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : metadata.max_seq_len_k
+            ]
+
+            if any(forward_batch.extend_prefix_lens_cpu):
+                extend_seq_lens = forward_batch.extend_seq_lens
+                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
+                metadata.cu_seqlens_q = torch.nn.functional.pad(
+                    torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
+                )
+            else:
+                metadata.max_seq_len_q = metadata.max_seq_len_k
+                metadata.cu_seqlens_q = metadata.cu_seqlens_k
+
+        # Convert the page table to a strided format
+        if self.page_size > 1:
+            self.strided_indices = torch.arange(
+                0, metadata.page_table.shape[1], self.page_size, device=self.device
+            )
+            metadata.page_table = (
+                metadata.page_table[:, self.strided_indices] // self.page_size
+            )
+
+        self.forward_metadata = metadata
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        **kwargs,
+    ) -> torch.Tensor:
+        """Run forward for decode using TRTLLM MHA kernel."""
+        cache_loc = forward_batch.out_cache_loc
+        if save_kv_cache and k is not None:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+            )
+
+        q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+        # shape conversion:
+        # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
+        k_cache = k_cache.view(
+            -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+        ).permute(0, 2, 1, 3)
+        v_cache = v_cache.view(
+            -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+        ).permute(0, 2, 1, 3)
+        kv_cache = (k_cache, v_cache)
+
+        # TODO: add support for quantization
+        q_scale = 1.0
+        k_scale = (
+            layer.k_scale_float
+            if getattr(layer, "k_scale_float", None) is not None
+            else 1.0
+        )
+        bmm1_scale = q_scale * k_scale * layer.scaling
+        bmm2_scale = 1.0
+        # sink: additional value per head in the denominator of the softmax.
+        attention_sink = kwargs.get("sinks", None)
+
+        # Call TRT-LLM kernel
+        # raw_out: like q, [bs, acc_q_len, num_q_heads, head_dim] but with output dtype
+        o = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
+            query=q,
+            kv_cache=kv_cache,
+            workspace_buffer=self.workspace_buffer,
+            block_tables=self.forward_metadata.page_table,
+            seq_lens=self.forward_metadata.cache_seqlens_int32,
+            max_seq_len=self.forward_metadata.max_seq_len_k,
+            bmm1_scale=bmm1_scale,
+            bmm2_scale=bmm2_scale,
+            window_left=layer.sliding_window_size,
+            # TODO: add attention_sink operation or nvfp4 scale factor if needed
+            sinks=attention_sink,
+        )
+
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+        **kwargs,
+    ):
+        cache_loc = forward_batch.out_cache_loc
+        if save_kv_cache and k is not None:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+            )
+        q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+        # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
+        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+        k_cache = k_cache.view(
+            -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+        ).permute(0, 2, 1, 3)
+        v_cache = v_cache.view(
+            -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+        ).permute(0, 2, 1, 3)
+        kv_cache = (k_cache, v_cache)
+
+        # sink: additional value per head in the denominator of the softmax.
+        attention_sink = kwargs.get("sinks", None)
+        # TODO: add support for quantization
+        q_scale = 1.0
+        k_scale = (
+            layer.k_scale_float
+            if getattr(layer, "k_scale_float", None) is not None
+            else 1.0
+        )
+        bmm1_scale = q_scale * k_scale * layer.scaling
+        bmm2_scale = 1.0
+
+        o = flashinfer.prefill.trtllm_batch_context_with_kv_cache(
+            query=q,
+            kv_cache=kv_cache,
+            workspace_buffer=self.workspace_buffer,
+            block_tables=self.forward_metadata.page_table,
+            seq_lens=self.forward_metadata.cache_seqlens_int32,
+            max_q_len=self.forward_metadata.max_seq_len_q,
+            max_kv_len=self.forward_metadata.max_seq_len_k,
+            bmm1_scale=bmm1_scale,
+            bmm2_scale=bmm2_scale,
+            batch_size=forward_batch.batch_size,
+            cum_seq_lens_q=self.forward_metadata.cu_seqlens_q,
+            cum_seq_lens_kv=self.forward_metadata.cu_seqlens_k,
+            window_left=layer.sliding_window_size,
+            # TODO: add attention_sink operation or nvfp4 scale factor if needed
+            sinks=attention_sink,
+        )
+
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
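The backend's metadata setup is mostly page-table bookkeeping: `init_forward_metadata` slices `req_to_token` per request and, when `page_size > 1`, keeps every `page_size`-th column and divides by `page_size` to obtain block indices for the TRT-LLM kernels. Below is a minimal standalone sketch of that conversion; the tensor values are made up for illustration, while the real code operates on the request and KV pools shown above.

```python
import torch

def build_block_table(req_to_token: torch.Tensor, req_pool_indices: torch.Tensor,
                      max_seq_len: int, page_size: int) -> torch.Tensor:
    """Toy version of the page-table construction in TRTLLMHAAttnBackend.

    req_to_token maps (request slot, token position) -> global KV-cache slot.
    The TRT-LLM kernels want one block index per page, so we keep every
    page_size-th slot and divide by page_size.
    """
    token_table = req_to_token[req_pool_indices, :max_seq_len]
    if page_size == 1:
        return token_table
    strided = torch.arange(0, token_table.shape[1], page_size)
    return token_table[:, strided] // page_size

# Example: 2 requests, page_size=4, contiguous KV slots for illustration.
req_to_token = torch.arange(2 * 16).reshape(2, 16)
print(build_block_table(req_to_token, torch.tensor([0, 1]), max_seq_len=12, page_size=4))
```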
sglang/srt/layers/attention/vision.py
CHANGED
@@ -11,6 +11,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 
+from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.utils import is_cuda, print_info_once
 
 _is_cuda = is_cuda()
@@ -365,19 +366,20 @@ class VisionAttention(nn.Module):
         **kwargs,
     ):
         super().__init__()
-
-
-        self.
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+        self.tp_size = attn_tp_size
+        self.tp_rank = attn_tp_rank
         self.dropout = dropout
         self.head_size = embed_dim // num_heads
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads
         )
         self.num_attention_heads_per_partition = dist_utils.divide(
-            num_dummy_heads + num_heads,
+            num_dummy_heads + num_heads, self.tp_size
         )
         self.num_attention_kv_heads_per_partition = dist_utils.divide(
-            num_dummy_heads + num_heads,
+            num_dummy_heads + num_heads, self.tp_size
         )
 
         self.q_size = self.num_attention_heads_per_partition * self.head_size
@@ -427,6 +429,8 @@ class VisionAttention(nn.Module):
                 total_num_kv_heads=num_dummy_heads + num_heads,
                 bias=qkv_bias,
                 quant_config=quant_config,
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
                 prefix=add_prefix("qkv_proj", prefix),
             )
         else:
@@ -435,6 +439,8 @@ class VisionAttention(nn.Module):
                 output_size=3 * self.dummy_dim,
                 bias=qkv_bias,
                 quant_config=quant_config,
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
                 prefix=add_prefix("qkv_proj", prefix),
             )
         self.proj = RowParallelLinear(
@@ -442,6 +448,8 @@ class VisionAttention(nn.Module):
             output_size=embed_dim,
             bias=proj_bias,
             quant_config=quant_config,
+            tp_rank=self.tp_rank,
+            tp_size=self.tp_size,
             prefix=add_prefix("proj", prefix),
         )
 
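The constructor now partitions heads by the attention-TP size from `get_attention_tp_size()` and passes the same rank/size into the QKV and output projections. A small illustrative calculation of the resulting per-rank sizes (the `divide` helper here only mirrors the even-split contract of `dist_utils.divide`; the numbers are hypothetical):

```python
def divide(numerator: int, denominator: int) -> int:
    # Same contract as dist_utils.divide: the split must be exact.
    assert numerator % denominator == 0
    return numerator // denominator

# Hypothetical config: 16 heads, no dummy heads, head_size 80, attention-TP size 4.
num_heads, num_dummy_heads, head_size, attn_tp_size = 16, 0, 80, 4
heads_per_partition = divide(num_dummy_heads + num_heads, attn_tp_size)  # 4 heads per rank
q_size = heads_per_partition * head_size                                 # 320 channels per rank
print(heads_per_partition, q_size)
```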
sglang/srt/layers/communicator.py
CHANGED
@@ -27,6 +27,7 @@ from sglang.srt.layers.dp_attention import (
     attn_tp_all_gather_into_tensor,
     attn_tp_reduce_scatter_tensor,
     dp_gather_partial,
+    dp_reduce_scatter_tensor,
     dp_scatter,
     get_attention_dp_size,
     get_attention_tp_rank,
@@ -149,10 +150,13 @@ class LayerCommunicator:
         layer_scatter_modes: LayerScatterModes,
         input_layernorm: torch.nn.Module,
         post_attention_layernorm: torch.nn.Module,
+        # Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable for models which have that implemented. Remove flag once done for all models that use LayerCommunicator.
+        allow_reduce_scatter: bool = False,
     ):
         self.layer_scatter_modes = layer_scatter_modes
         self.input_layernorm = input_layernorm
         self.post_attention_layernorm = post_attention_layernorm
+        self.allow_reduce_scatter = allow_reduce_scatter
 
         self._context = CommunicateContext.init_new()
         self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
@@ -239,6 +243,15 @@ class LayerCommunicator:
             residual=residual,
             forward_batch=forward_batch,
             context=self._context,
+            allow_reduce_scatter=self.allow_reduce_scatter,
+        )
+
+    def should_use_reduce_scatter(self, forward_batch: ForwardBatch):
+        return (
+            self.allow_reduce_scatter
+            and self._communicate_summable_tensor_pair_fn
+            is CommunicateSummableTensorPairFn._scatter_hidden_states
+            and forward_batch.dp_padding_mode.is_max_len()
         )
 
 
@@ -524,6 +537,7 @@ class CommunicateSummableTensorPairFn:
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
         context: CommunicateContext,
+        **kwargs,
     ):
         return hidden_states, residual
 
@@ -533,15 +547,17 @@ class CommunicateSummableTensorPairFn:
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
         context: CommunicateContext,
+        allow_reduce_scatter: bool = False,
     ):
-        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
-        # important: forward_batch.gathered_buffer is used both after scatter and after gather.
-        # be careful about this!
         hidden_states, global_hidden_states = (
             forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
            hidden_states,
         )
-
+        if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
+            # When using padding, all_reduce is skipped after MLP and MOE and reduce scatter is used here instead.
+            dp_reduce_scatter_tensor(hidden_states, global_hidden_states)
+        else:
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
         return hidden_states, residual
 
     @staticmethod
@@ -550,6 +566,7 @@ class CommunicateSummableTensorPairFn:
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
         context: CommunicateContext,
+        **kwargs,
     ):
         hidden_states += residual
         residual = None
sglang/srt/layers/dp_attention.py
CHANGED
@@ -12,6 +12,7 @@ import triton.language as tl
 
 from sglang.srt.distributed import (
     GroupCoordinator,
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     get_tp_group,
     tensor_model_parallel_all_reduce,
@@ -355,6 +356,17 @@ def dp_scatter(
     )
 
 
+def dp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
+    if get_tensor_model_parallel_world_size() == get_attention_dp_size():
+        get_tp_group().reduce_scatter_tensor(output, input)
+    else:
+        scattered_local_tokens = input.tensor_split(
+            get_tensor_model_parallel_world_size()
+        )[get_tensor_model_parallel_rank()]
+        get_tp_group().reduce_scatter_tensor(scattered_local_tokens, input)
+        get_attention_tp_group().all_gather_into_tensor(output, scattered_local_tokens)
+
+
 def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
     return get_attention_tp_group().reduce_scatter_tensor(output, input)
 
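`dp_reduce_scatter_tensor` either reduce-scatters directly over the TP group or, when the TP and attention-DP sizes differ, reduce-scatters over TP and then all-gathers the local shard inside the attention-TP group. The single-process sketch below simulates the ranks with plain tensors to show why that two-step path is equivalent to an all-reduce followed by a scatter; it is purely illustrative and creates no process groups.

```python
import torch

# Simulate 4 TP ranks, each holding its own partial sum of the same hidden states.
tp_size, tokens, hidden = 4, 8, 2
partials = [torch.randn(tokens, hidden) for _ in range(tp_size)]

# Step 1 (reduce-scatter over TP): rank r ends up with the summed shard r of the token dim.
summed = torch.stack(partials).sum(dim=0)
shards = list(summed.tensor_split(tp_size, dim=0))

# Step 2 (all-gather inside the attention-TP sub-group): rebuild the slice each
# attention-DP rank needs. Here we gather everything, standing in for the case
# where the attention-TP group spans all TP ranks.
rebuilt = torch.cat(shards, dim=0)

# The result matches a plain all-reduce followed by dp_scatter-style slicing.
assert torch.allclose(rebuilt, summed)
print("reduce-scatter + all-gather == all-reduce (then slice)")
```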
sglang/srt/layers/linear.py
CHANGED
@@ -1191,11 +1191,6 @@ class RowParallelLinear(LinearBase):
|
|
1191
1191
|
else self.weight_loader
|
1192
1192
|
),
|
1193
1193
|
)
|
1194
|
-
if not reduce_results and (bias and not skip_bias_add):
|
1195
|
-
raise ValueError(
|
1196
|
-
"When not reduce the results, adding bias to the "
|
1197
|
-
"results can lead to incorrect results"
|
1198
|
-
)
|
1199
1194
|
|
1200
1195
|
if bias:
|
1201
1196
|
self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
|
@@ -1282,7 +1277,7 @@ class RowParallelLinear(LinearBase):
|
|
1282
1277
|
# It does not support additional parameters.
|
1283
1278
|
param.load_row_parallel_weight(loaded_weight)
|
1284
1279
|
|
1285
|
-
def forward(self, input_,
|
1280
|
+
def forward(self, input_, skip_all_reduce=False):
|
1286
1281
|
if self.input_is_parallel:
|
1287
1282
|
input_parallel = input_
|
1288
1283
|
else:
|
@@ -1299,7 +1294,7 @@ class RowParallelLinear(LinearBase):
|
|
1299
1294
|
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
|
1300
1295
|
output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
|
1301
1296
|
sm.tag(output_parallel)
|
1302
|
-
if self.reduce_results and self.tp_size > 1 and not
|
1297
|
+
if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
|
1303
1298
|
output = tensor_model_parallel_all_reduce(output_parallel)
|
1304
1299
|
else:
|
1305
1300
|
output = output_parallel
|
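The new `skip_all_reduce` argument pairs with the `LayerCommunicator.should_use_reduce_scatter` helper added above: a model can skip the row-parallel all-reduce when the communicator will reduce-scatter the layer output instead. A hedged sketch of how a decoder MLP might wire the two together follows; the function and parameter names here are illustrative and not taken from this diff.

```python
import torch

def mlp_with_optional_reduce_scatter(
    hidden_states: torch.Tensor,
    forward_batch,        # sglang ForwardBatch
    layer_communicator,   # sglang LayerCommunicator built with allow_reduce_scatter=True
    gate_up_proj,         # column-parallel projection
    act_fn,               # activation (e.g. SiluAndMul)
    down_proj,            # sglang RowParallelLinear (see forward() above)
):
    """Skip the all-reduce in the down projection when the communicator
    will reduce-scatter the MLP output in its postprocess step instead."""
    use_reduce_scatter = layer_communicator.should_use_reduce_scatter(forward_batch)
    gate_up, _ = gate_up_proj(hidden_states)
    activated = act_fn(gate_up)
    output, _ = down_proj(activated, skip_all_reduce=use_reduce_scatter)
    return output
```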
sglang/srt/layers/moe/cutlass_moe.py
CHANGED
@@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
 import torch
 
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
+from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.utils import is_cuda
 
 _is_cuda = is_cuda()
@@ -123,6 +124,7 @@ def cutlass_fused_experts_fp8(
 
     if is_cuda:
         from sglang.srt.layers.quantization.fp8_kernel import (
+            per_token_group_quant_fp8_hopper_moe_mn_major,
            sglang_per_token_group_quant_fp8,
         )
 
@@ -133,9 +135,7 @@ def cutlass_fused_experts_fp8(
     n = w2_q.size(1)
 
     topk = topk_ids.size(1)
-
-    a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
-    device = a_q.device
+    device = a.device
 
     a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
     c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
@@ -152,8 +152,16 @@ def cutlass_fused_experts_fp8(
         k,
     )
 
-
-
+    if is_sm100_supported():
+        a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
+        rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
+        rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
+    else:
+        rep_a = shuffle_rows(a, a_map, (m * topk, k))
+        rep_a_q, rep_a1_scales = per_token_group_quant_fp8_hopper_moe_mn_major(
+            rep_a, expert_offsets, problem_sizes1, 128
+        )
+    w1_scale = w1_scale.contiguous()
 
     c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
     c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
@@ -185,7 +193,13 @@ def cutlass_fused_experts_fp8(
     intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
     silu_and_mul(c1, intermediate)
 
-
+    if is_sm100_supported():
+        intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
+    else:
+        intemediate_q, a2_scale = per_token_group_quant_fp8_hopper_moe_mn_major(
+            intermediate, expert_offsets, problem_sizes2, 128
+        )
+    w2_scale = w2_scale.contiguous()
 
     fp8_blockwise_scaled_grouped_mm(
         c2,
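Both activation-quantization branches above work on 128-wide groups along the hidden dimension. As a rough, reference-only sketch of what per-token-group FP8 quantization computes (one scale per 128-element group so that the group's max maps to the float8_e4m3 range; the real kernels differ in layout, fusion, and rounding details):

```python
import torch

def per_token_group_quant_fp8_reference(x: torch.Tensor, group_size: int = 128):
    """Reference-only sketch: quantize each 128-wide group of every token row to fp8.

    Returns the fp8 tensor and one float32 scale per (token, group).
    """
    m, k = x.shape
    assert k % group_size == 0
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    groups = x.view(m, k // group_size, group_size)
    scales = groups.abs().amax(dim=-1).clamp(min=1e-10) / fp8_max
    q = (groups / scales.unsqueeze(-1)).to(torch.float8_e4m3fn)
    return q.view(m, k), scales

x = torch.randn(4, 256)
x_q, x_scale = per_token_group_quant_fp8_reference(x)
print(x_q.dtype, x_scale.shape)  # torch.float8_e4m3fn, torch.Size([4, 2])
```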