sglang 0.4.1.post7__py3-none-any.whl → 0.4.2.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +17 -11
- sglang/bench_one_batch.py +14 -6
- sglang/bench_serving.py +47 -44
- sglang/lang/chat_template.py +31 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +5 -2
- sglang/srt/entrypoints/engine.py +5 -2
- sglang/srt/entrypoints/http_server.py +24 -0
- sglang/srt/function_call_parser.py +494 -0
- sglang/srt/layers/activation.py +5 -5
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +6 -0
- sglang/srt/layers/attention/vision.py +243 -40
- sglang/srt/layers/dp_attention.py +3 -1
- sglang/srt/layers/layernorm.py +5 -5
- sglang/srt/layers/linear.py +24 -9
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +20 -12
- sglang/srt/layers/moe/fused_moe_native.py +17 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +9 -0
- sglang/srt/layers/parameter.py +16 -7
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/fp8.py +11 -1
- sglang/srt/layers/rotary_embedding.py +34 -13
- sglang/srt/layers/sampler.py +33 -10
- sglang/srt/layers/torchao_utils.py +12 -6
- sglang/srt/managers/detokenizer_manager.py +1 -0
- sglang/srt/managers/image_processor.py +77 -38
- sglang/srt/managers/io_struct.py +36 -5
- sglang/srt/managers/schedule_batch.py +31 -25
- sglang/srt/managers/scheduler.py +78 -38
- sglang/srt/managers/tokenizer_manager.py +4 -0
- sglang/srt/mem_cache/base_prefix_cache.py +4 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +30 -1
- sglang/srt/model_executor/cuda_graph_runner.py +23 -25
- sglang/srt/model_executor/forward_batch_info.py +5 -7
- sglang/srt/model_executor/model_runner.py +7 -4
- sglang/srt/model_loader/loader.py +75 -0
- sglang/srt/model_loader/weight_utils.py +91 -5
- sglang/srt/models/commandr.py +14 -2
- sglang/srt/models/dbrx.py +9 -1
- sglang/srt/models/deepseek_v2.py +3 -3
- sglang/srt/models/gemma2.py +9 -1
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/minicpm3.py +3 -3
- sglang/srt/models/minicpmv.py +129 -76
- sglang/srt/models/mllama.py +16 -56
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_vl.py +18 -8
- sglang/srt/models/torch_native_llama.py +17 -4
- sglang/srt/openai_api/adapter.py +139 -37
- sglang/srt/openai_api/protocol.py +5 -4
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
- sglang/srt/sampling/sampling_batch_info.py +4 -14
- sglang/srt/server.py +2 -2
- sglang/srt/server_args.py +26 -1
- sglang/srt/speculative/eagle_utils.py +37 -15
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/utils.py +62 -67
- sglang/test/test_programs.py +1 -0
- sglang/test/test_utils.py +81 -22
- sglang/utils.py +42 -0
- sglang/version.py +1 -1
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/METADATA +8 -8
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/RECORD +78 -67
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/vision.py
CHANGED
@@ -4,6 +4,7 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from einops import rearrange, repeat
 
 from sglang.srt.distributed import parallel_state
@@ -63,7 +64,20 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.T
 
 
 class VisionAttention(nn.Module):
-    """
+    r"""
+    Multi-headed attention without any cache, mostly used for ViT.
+
+
+    Args:
+        use_qkv_parallel (bool, optional): If True, use QKV-parallel attention.
+        use_context_forward (bool, default to True):
+            if ``True``, a flash_attn style attention will be applied
+            Otherwise, a full-sequence attention will be applied.
+        use_full_precision_softmax (bool, default to False):
+            if ``True``, the softmax will be performed in full-precision
+            Otherwise, it will be performed in half-precision
+
+    """
 
     def __init__(
         self,
@@ -72,25 +86,39 @@ class VisionAttention(nn.Module):
         projection_size: int,
         use_qkv_parallel: bool,
         quant_config: Optional[QuantizationConfig] = None,
+        dropout: float = 0.0,
+        use_context_forward: bool = True,
+        use_full_precision_softmax: bool = False,
+        flatten_batch: bool = False,
         prefix: str = "",
     ):
         super().__init__()
+        self.use_context_forward = use_context_forward
         world_size = parallel_state.get_tensor_model_parallel_world_size()
-
+        self.dropout = dropout
+        self.head_size = embed_dim // num_heads
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads
         )
         self.num_attention_heads_per_partition = dist_utils.divide(
             num_heads, world_size
         )
-
-
+
+        if self.use_context_forward:
+            self.qkv_backend = VisionTritonAttention()
+        else:
+            self.qkv_backend = VisionSdpaAttention(
+                head_size=self.head_size,
+                dropout=dropout,
+                flatten_batch=flatten_batch,
+                use_full_precision_softmax=use_full_precision_softmax,
+            )
+
         self.use_qkv_parallel = use_qkv_parallel
         if use_qkv_parallel:
-            self.head_dim = embed_dim // num_heads
             self.qkv_proj = QKVParallelLinear(
                 hidden_size=embed_dim,
-                head_size=self.head_dim,
+                head_size=self.head_size,
                 total_num_heads=num_heads,
                 quant_config=quant_config,
                 prefix=f"{prefix}.qkv_proj",
@@ -114,12 +142,15 @@ class VisionAttention(nn.Module):
         x: torch.Tensor,
         cu_seqlens: Optional[torch.Tensor] = None,
         rotary_pos_emb: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        r"""
+        Args:
+            x: [b, s, embed_dim]
+            cu_seqlens: [b]
+        Returns:
+            [s, b, num_heads * head]
         """
-        Input shape: [b, s, embed_dim]
-        Output shape: [s, b, num_heads * head_size]
-        """
-
         bsz, s, _ = x.shape
         if self.use_qkv_parallel:
             # [b, s, embed_dim] --> [b, s, embed_dim]
@@ -136,19 +167,19 @@ class VisionAttention(nn.Module):
         else:
             # [b, s, embed_dim] --> [s, b, embed_dim]
             x = rearrange(x, "b s ... -> s b ...")
-            # [s, b, embed_dim] --> [s, b, head * 3 * head_dim]
+            # [s, b, embed_dim] --> [s, b, head * 3 * head_size]
             qkv, _ = self.qkv_proj(x)
-            # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
+            # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
             new_x_shape = qkv.size()[:-1] + (
                 self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head,
             )
             qkv = qkv.view(*new_x_shape)
 
-            # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
+            # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
             q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
 
-            # [s, b, head, head_dim] --> [b, s, head, head_dim]
+            # [s, b, head, head_size] --> [b, s, head, head_size]
             q, k, v = [
                 rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
             ]
@@ -160,45 +191,217 @@ class VisionAttention(nn.Module):
         if self.use_qkv_parallel:
             pass
         else:
-            # [b, s, head, head_dim] --> [b * s, head, head_dim]
+            # [b, s, head, head_size] --> [b * s, head, head_size]
             q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
 
-
-        output = torch.empty_like(q)
-
-        seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cuda()
-        max_seqlen = seq_lens.max().item()
-
-        context_attention_fwd(
-            q,
-            k,
-            v,
-            output,
-            cu_seqlens.cuda(),
-            seq_lens,
-            max_seqlen,
-            is_causal=False,
-        )
+        output = self.qkv_backend.forward(q, k, v, bsz, cu_seqlens, attention_mask)
 
         if self.use_qkv_parallel:
-
-            # [b * s, head, head_dim] --> [b, s, head * head_dim]
+            # [b * s, h, head_size] --> [b, s, h * head_size]
             output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)
 
-            # [b, s, head * head_dim] --> [b, s, head * head_dim]
+            # [b, s, h * head_size] --> [b, s, h * head_size]
             output, _ = self.proj(output)
         else:
-            # [b * s, head, head_dim] --> [b, s, head, head_dim]
-            context_layer = rearrange(output, "(b s) ... -> b s ...", b=bsz)
-
-            # [s, b, num_heads * head_size]
+            # [b * s, h, head_size] --> [s, b, h * head_size]
             context_layer = rearrange(
-                context_layer, "b s h d -> s b (h d)"
+                output, "(b s) h d -> s b (h d)", b=bsz, s=s
             ).contiguous()
 
-            # [s, b, head * head_dim] --> [s, b, head * head_dim]
+            # [s, b, h * head_size] --> [s, b, h * head_size]
             output, _ = self.proj(context_layer)
 
+        # [s, b, h * head_size] --> [b, s, h * head_size]
         output = output.view(bsz, s, -1)
 
         return output
+
+
+class VisionSdpaAttention(nn.Module):
+    r"""
+    Scaled Dot Product Attention inner product
+
+    """
+
+    # TODO: Should it be released after used?
+    _mask_cache = {}
+
+    def __init__(
+        self,
+        head_size: int,
+        dropout: float = 0.0,
+        flatten_batch: bool = False,
+        use_full_precision_softmax: bool = False,
+    ):
+        super().__init__()
+        self.head_size = head_size
+        self.flatten_batch = flatten_batch
+        self.use_full_precision_softmax = use_full_precision_softmax
+        self.dropout = dropout
+
+    def generate_patch_attention_mask(
+        self,
+        s: int,
+        bsz: int,
+        device,
+        cu_seqlens: Optional[torch.Tensor],
+        flatten_batch: bool = False,
+        dtype=torch.bfloat16,
+    ) -> torch.Tensor:
+        r"""
+        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
+
+        When `flatten_batch` is True:
+            - All sequences in the batch are flattened into a single dimension
+            - `s` represents the total number of tokens across all sequences in the batch
+            - Returns a unified mask of shape `(1, 1, s, s)`
+
+        When `flatten_batch` is False:
+            - Each sequence has its own attention mask
+            - `s` represents the maximum sequence length in the batch
+            - Returns separate masks of shape `(b, 1, s, s)`
+
+        Args:
+            flatten_batch: (bool):
+                If True, treats all sequences in the batch as a single flattened sequence
+                If False, generates separate masks for each sequence
+
+        Returns:
+            Tensor of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
+        """
+
+        cache_key = (s, bsz, flatten_batch, tuple(cu_seqlens.cpu().tolist()))
+
+        if cache_key in VisionSdpaAttention._mask_cache:
+            cached_mask = VisionSdpaAttention._mask_cache[cache_key]
+            # print(f"cache hit for key: {cache_key}")
+            return cached_mask.to(device=device, dtype=dtype)
+
+        if cu_seqlens is None:
+            raise ValueError("Internal Error: cu_seqlens cannot be None")
+
+        if flatten_batch:
+            mask = torch.zeros([1, s, s], device=device, dtype=torch.bool)
+            for i in range(1, len(cu_seqlens)):
+                start = cu_seqlens[i - 1]
+                end = cu_seqlens[i]
+                mask[
+                    ...,
+                    start:end,
+                    start:end,
+                ] = True
+        else:
+            # [1, 1, 1, s]
+            row_indices = torch.arange(s, device=device).view(1, 1, 1, s)
+            # [1, 1, s, 1]
+            col_indices = torch.arange(s, device=device).view(1, 1, s, 1)
+            # [b, 1, 1, 1]
+            seq_lens = (
+                (cu_seqlens[1:] - cu_seqlens[:-1]).to(device=device).view(-1, 1, 1, 1)
+            )
+
+            mask = (row_indices < seq_lens) & (col_indices < seq_lens)
+
+        # Convert to attention mask format (False -> 0, True -> -inf)
+        mask = (~mask).to(dtype) * torch.finfo(dtype).min
+
+        VisionSdpaAttention._mask_cache[cache_key] = mask
+
+        return mask
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz: int,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            cu_seqlens: [b]
+        Returns:
+            [b * s, h, head_size]
+        """
+
+        s = q.shape[0] // bsz
+
+        # [b, 1, s, s]
+        if attention_mask is None:
+            attention_mask = self.generate_patch_attention_mask(
+                s, bsz, q.device, cu_seqlens, self.flatten_batch, q.dtype
+            )
+        q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
+        # [b, 1, s]
+        if self.use_full_precision_softmax:
+            scale = self.head_size**-0.5
+            k_transposed = rearrange(k, "b h s d -> b h d s")
+            attn_weights = torch.matmul(q, k_transposed) * scale
+            del k, k_transposed
+            attn_weights = attn_weights + attention_mask
+            del attention_mask
+            # full-precision
+            attn_weights = nn.functional.softmax(
+                attn_weights, dim=-1, dtype=torch.float32
+            ).to(q.dtype)
+            attn_weights = nn.functional.dropout(
+                attn_weights, p=self.dropout, training=False
+            )
+            output = torch.matmul(attn_weights, v)
+            del attn_weights, v
+        else:
+            # SDPA
+            # [b, h, s, head_size]
+            output = F.scaled_dot_product_attention(
+                q, k, v, attention_mask, dropout_p=self.dropout
+            )
+
+        # [b, h, s, head_size] --> [b * s, h, head_size]
+        output = rearrange(output, "b h s d -> (b s) h d")
+
+        return output
+
+
+class VisionTritonAttention(nn.Module):
+    """
+    Triton-implemented attention without a causal mask
+    """
+
+    def __init__(
+        self,
+    ):
+        super().__init__()
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        _bsz: int,
+        cu_seqlens: Optional[torch.Tensor],
+        **kwargs,
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            cu_seqlens: [b]
+        Returns:
+            [b * s, h, head_size]
+        """
+
+        # [b * s, head, head_size]
+        output = torch.empty_like(q)
+        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
+        max_seqlen = seq_lens.max().item()
+        context_attention_fwd(
+            q,
+            k,
+            v,
+            output,
+            cu_seqlens.cuda(),
+            seq_lens.cuda(),
+            max_seqlen,
+            is_causal=False,
+        )
+
+        return output
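
Illustrative sketch (not part of the package diff): the new VisionSdpaAttention above builds a block-diagonal, non-causal mask from cu_seqlens so that patches only attend within their own image. The standalone helper below mirrors that logic; its name and the toy values are assumptions for demonstration only.

import torch

def build_patch_mask(cu_seqlens: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    s = int(cu_seqlens[-1])  # total tokens across all flattened sequences
    allowed = torch.zeros(1, 1, s, s, dtype=torch.bool)
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        # each image's patches attend only to patches of the same image
        allowed[..., start:end, start:end] = True
    # False -> -inf, True -> 0, so the mask can be added to attention scores
    return (~allowed).to(dtype) * torch.finfo(dtype).min

mask = build_patch_mask(torch.tensor([0, 3, 7]))  # two images with 3 and 4 patches
print(mask.shape)  # torch.Size([1, 1, 7, 7])
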
sglang/srt/layers/dp_attention.py
CHANGED
@@ -22,6 +22,8 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_si
 def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
 
+    from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
+
     _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
@@ -35,7 +37,7 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
         ],
         tp_rank,
         torch.distributed.get_backend(tp_group.device_group),
-        False,
+        SYNC_TOKEN_IDS_ACROSS_TP,
         False,
         False,
         False,
sglang/srt/layers/layernorm.py
CHANGED
@@ -19,10 +19,10 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn
 
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import is_cuda_available
 
-if is_flashinfer_available():
-    from flashinfer.norm import (
+if is_cuda_available():
+    from sgl_kernel import (
         fused_add_rmsnorm,
         gemma_fused_add_rmsnorm,
         gemma_rmsnorm,
@@ -121,8 +121,8 @@ class GemmaRMSNorm(CustomOp):
         return out
 
 
-if not is_flashinfer_available():
+if not is_cuda_available():
     logger.info(
-        "FlashInfer is not available on Non-NV platforms. Fallback to other kernel libraries."
+        "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
     )
     from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
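
Illustrative sketch (not part of the package diff): for context, an un-fused, pure-PyTorch equivalent of the add-then-RMSNorm pattern that kernels such as fused_add_rmsnorm accelerate. The semantics shown here are an assumption for illustration, not the sgl_kernel or vLLM implementation.

import torch

def add_rmsnorm_reference(x: torch.Tensor, residual: torch.Tensor,
                          weight: torch.Tensor, eps: float = 1e-6):
    residual = x + residual                       # residual connection first
    variance = residual.float().pow(2).mean(-1, keepdim=True)
    normed = residual.float() * torch.rsqrt(variance + eps)
    return normed.to(x.dtype) * weight, residual  # normalized output, updated residual

hidden = torch.randn(2, 8, 16, dtype=torch.float16)
res = torch.randn_like(hidden)
w = torch.ones(16, dtype=torch.float16)
out, new_res = add_rmsnorm_reference(hidden, res, w)
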
sglang/srt/layers/linear.py
CHANGED
@@ -329,12 +329,14 @@ class ColumnParallelLinear(LinearBase):
         prefix: str = "",
         tp_rank: Optional[int] = None,
         tp_size: Optional[int] = None,
+        use_presharded_weights: bool = False,
     ):
         super().__init__(
             input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
         )
 
         self.gather_output = gather_output
+        self.use_presharded_weights = use_presharded_weights
 
         # Divide the weight matrix along the last dimension.
         if tp_rank is None:
@@ -402,7 +404,8 @@ class ColumnParallelLinear(LinearBase):
         if output_dim is not None and not use_bitsandbytes_4bit:
             shard_size = param_data.shape[output_dim]
             start_idx = self.tp_rank * shard_size
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+            if not self.use_presharded_weights:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
 
         # Special case for loading scales off disk, which often do not
         # have a shape (such as in the case of AutoFP8).
@@ -418,7 +421,11 @@ class ColumnParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
-        param.load_column_parallel_weight(loaded_weight, tp_rank=self.tp_rank)
+        param.load_column_parallel_weight(
+            loaded_weight,
+            tp_rank=self.tp_rank,
+            use_presharded_weights=self.use_presharded_weights,
+        )
 
     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None
@@ -499,7 +506,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             prefix=prefix,
             tp_rank=tp_rank,
             tp_size=tp_size,
+            use_presharded_weights=use_presharded_weights,
         )
+        self.prefix = prefix
 
     def weight_loader(
         self,
@@ -743,6 +752,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         prefix: str = "",
         tp_rank: Optional[int] = None,
         tp_size: Optional[int] = None,
+        load_presharded_attn: bool = False,
     ):
         self.hidden_size = hidden_size
         self.head_size = head_size
@@ -772,6 +782,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             self.num_kv_heads * self.head_size * tp_size,  # k_proj
             self.num_kv_heads * self.head_size * tp_size,  # v_proj
         ]
+        self.use_presharded_weights = load_presharded_attn
 
         super().__init__(
             input_size=input_size,
@@ -784,6 +795,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             prefix=prefix,
             tp_rank=tp_rank,
             tp_size=tp_size,
+            use_presharded_weights=self.use_presharded_weights,
         )
 
     def _get_shard_offset_mapping(self, loaded_shard_id: str):
@@ -842,9 +854,10 @@ class QKVParallelLinear(ColumnParallelLinear):
                 shard_size=shard_size, shard_offset=shard_offset
             )
 
-            loaded_weight_shard = loaded_weight.narrow(
-                param.output_dim, shard_offset, shard_size
-            )
+            if not self.use_presharded_weights:
+                loaded_weight_shard = loaded_weight.narrow(
+                    param.output_dim, shard_offset, shard_size
+                )
             self.weight_loader_v2(param, loaded_weight_shard, shard_id)
 
     def weight_loader_v2(
@@ -882,6 +895,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             shard_offset=shard_offset,
             shard_size=shard_size,
             tp_rank=self.tp_rank,
+            use_presharded_weights=self.use_presharded_weights,
         )
 
     def weight_loader(
@@ -987,9 +1001,10 @@ class QKVParallelLinear(ColumnParallelLinear):
                 param, orig_qkv_offsets, shard_id
             )
 
-            loaded_weight_shard = loaded_weight.narrow(
-                output_dim, shard_offset, shard_size
-            )
+            if not self.use_presharded_weights:
+                loaded_weight_shard = loaded_weight.narrow(
+                    output_dim, shard_offset, shard_size
+                )
             self.weight_loader(param, loaded_weight_shard, shard_id)
             return
 
@@ -1049,7 +1064,7 @@ class QKVParallelLinear(ColumnParallelLinear):
 
         # bitsandbytes loads the weights of the specific portion
         # no need to narrow here
-        if not use_bitsandbytes_4bit:
+        if not use_bitsandbytes_4bit and not self.use_presharded_weights:
             loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
 
         # Special case for for AQLM codebooks.
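
Illustrative sketch (not part of the package diff): the linear.py changes above thread a use_presharded_weights flag through the loaders so that the per-rank narrow() is skipped when the checkpoint already stores one tensor-parallel slice per rank. The helper below shows the two loading paths; its name and shapes are hypothetical.

import torch

def load_shard(weight: torch.Tensor, tp_rank: int, shard_size: int,
               output_dim: int = 0, use_presharded_weights: bool = False) -> torch.Tensor:
    if use_presharded_weights:
        # Checkpoint already holds only this rank's slice; use it as-is.
        return weight
    # Otherwise slice this rank's portion out of the full weight.
    return weight.narrow(output_dim, tp_rank * shard_size, shard_size)

full = torch.randn(8, 4)
print(load_shard(full, tp_rank=1, shard_size=4).shape)                   # torch.Size([4, 4])
print(load_shard(full[4:8], 1, 4, use_presharded_weights=True).shape)    # torch.Size([4, 4])
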
@@ -296,7 +296,7 @@ def fused_softcap_kernel(
     n_elements,
     BLOCK_SIZE: tl.constexpr,
 ):
-    pid = tl.program_id(0)
+    pid = tl.program_id(0).to(tl.int64)
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
     mask = offsets < n_elements
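
Illustrative sketch (not part of the package diff): why the kernel above casts tl.program_id(0) to int64 before computing block_start. The block index times BLOCK_SIZE can exceed the int32 range for very large tensors; the numbers below are hypothetical.

INT32_MAX = 2**31 - 1
BLOCK_SIZE = 512
pid = 5_000_000                 # plausible for tensors with billions of elements
block_start = pid * BLOCK_SIZE  # 2_560_000_000
print(block_start > INT32_MAX)  # True -> would wrap around in 32-bit arithmetic
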
sglang/srt/layers/moe/ep_moe/layer.py
CHANGED
@@ -114,6 +114,8 @@ class EPMoE(torch.nn.Module):
         tp_size: Optional[int] = None,
         prefix: str = "",
         correction_bias: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        activation: str = "silu",
     ):
         super().__init__()
 
@@ -140,6 +142,8 @@ class EPMoE(torch.nn.Module):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.correction_bias = correction_bias
+        self.custom_routing_function = custom_routing_function
+        self.activation = activation
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -166,6 +170,7 @@ class EPMoE(torch.nn.Module):
 
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         assert self.quant_method is not None
+        assert self.activation == "silu"
 
         if self.grouped_gemm_runner is None:
             self.grouped_gemm_runner = GroupedGemmRunner(
@@ -181,6 +186,7 @@ class EPMoE(torch.nn.Module):
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
             correction_bias=self.correction_bias,
+            custom_routing_function=self.custom_routing_function,
         )
 
         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
@@ -254,16 +260,20 @@ class EPMoE(torch.nn.Module):
             dtype=torch.float32,
             device=hidden_states.device,
         )
-        silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-            gateup_output,
-            down_input,
-            gateup_output.shape[1],
-            reorder_topk_ids,
-            self.w2_input_scale,
-            self.start_expert_id,
-            self.end_expert_id,
-            BLOCK_SIZE=512,
-        )
+
+        if self.activation == "silu":
+            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
+                gateup_output,
+                down_input,
+                gateup_output.shape[1],
+                reorder_topk_ids,
+                self.w2_input_scale,
+                self.start_expert_id,
+                self.end_expert_id,
+                BLOCK_SIZE=512,
+            )
+        else:
+            raise ValueError(f"Unsupported activation: {self.activation=}")
 
         # GroupGemm-1
         down_output = torch.empty(
@@ -309,7 +319,6 @@ class EPMoE(torch.nn.Module):
         ckpt_up_proj_name: str,
         num_experts: int,
     ) -> List[Tuple[str, str, int, str]]:
-
         return [
             # (param_name, weight_name, expert_id, shard_id)
             (
@@ -354,7 +363,6 @@ class EPMoE(torch.nn.Module):
             )
             return
 
-        expert_data = param.data[expert_id]
         if shard_id == "w2":
             param.data[expert_id] = loaded_weight
         elif shard_id == "w1":
sglang/srt/layers/moe/fused_moe_native.py
CHANGED
@@ -8,7 +8,7 @@ from typing import Callable, Optional
 import torch
 from torch.nn import functional as F
 
-from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
 from sglang.srt.layers.moe.topk import select_experts
 
 
@@ -23,6 +23,7 @@ def fused_moe_forward_native(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
+    activation: str = "silu",
 ) -> torch.Tensor:
     topk_weights, topk_ids = select_experts(
         hidden_states=x,
@@ -41,7 +42,12 @@ def fused_moe_forward_native(
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = layer.w2_weight[topk_ids]
     x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
-    x1 = F.silu(x1)
+    if activation == "silu":
+        x1 = F.silu(x1)
+    elif activation == "gelu":
+        x1 = F.gelu(x1)
+    else:
+        raise ValueError(f"Unsupported activation: {activation=}")
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
@@ -58,6 +64,7 @@ def moe_forward_native(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
+    activation: str = "silu",
 ) -> torch.Tensor:
 
     topk_weights, topk_ids = select_experts(
@@ -84,6 +91,13 @@ def moe_forward_native(
     sorted_tokens = x[idxs // topk_ids.shape[1]]
     tokens_per_expert = tokens_per_expert.cpu().numpy()
 
+    if activation == "silu":
+        act = SiluAndMul()
+    elif activation == "gelu":
+        act = GeluAndMul()
+    else:
+        raise ValueError(f"Unsupported activation: {activation=}")
+
     outputs = []
     start_idx = 0
     for i, num_tokens in enumerate(tokens_per_expert):
@@ -96,7 +110,7 @@ def moe_forward_native(
         layer_w2_weight = layer.w2_weight[i]
 
         gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
-        gate_up = SiluAndMul()(gate_up)
+        gate_up = act(gate_up)
         expert_out = F.linear(gate_up, layer_w2_weight)
         outputs.append(expert_out)
         start_idx = end_idx
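
Illustrative sketch (not part of the package diff): the fused_moe_native changes above select SiluAndMul or GeluAndMul based on the new `activation` argument. The standalone helper below shows the gated-activation pattern those modules implement (first half gates, second half scales); it is an assumption for illustration, not sglang's class.

import torch
import torch.nn.functional as F

def gated_act(gate_up: torch.Tensor, activation: str = "silu") -> torch.Tensor:
    gate, up = gate_up.chunk(2, dim=-1)   # [t, 2d] -> two [t, d] halves
    if activation == "silu":
        return F.silu(gate) * up
    elif activation == "gelu":
        return F.gelu(gate) * up
    raise ValueError(f"Unsupported activation: {activation=}")

tokens = torch.randn(4, 2 * 16)
print(gated_act(tokens, "gelu").shape)    # torch.Size([4, 16])
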