sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post5__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +2 -0
- sglang/bench_serving.py +18 -1
- sglang/lang/interpreter.py +71 -1
- sglang/lang/ir.py +2 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/chatglm.py +78 -0
- sglang/srt/configs/dbrx.py +279 -0
- sglang/srt/configs/model_config.py +1 -1
- sglang/srt/hf_transformers_utils.py +9 -14
- sglang/srt/layers/attention/__init__.py +22 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
- sglang/srt/layers/attention/flashinfer_backend.py +215 -83
- sglang/srt/layers/attention/torch_native_backend.py +1 -38
- sglang/srt/layers/attention/triton_backend.py +20 -11
- sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
- sglang/srt/layers/linear.py +159 -55
- sglang/srt/layers/logits_processor.py +170 -215
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +198 -29
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -7
- sglang/srt/layers/parameter.py +431 -0
- sglang/srt/layers/quantization/__init__.py +3 -2
- sglang/srt/layers/quantization/fp8.py +3 -3
- sglang/srt/layers/quantization/modelopt_quant.py +174 -0
- sglang/srt/layers/sampler.py +57 -21
- sglang/srt/layers/torchao_utils.py +17 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -1
- sglang/srt/managers/cache_controller.py +307 -0
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/io_struct.py +1 -2
- sglang/srt/managers/schedule_batch.py +33 -3
- sglang/srt/managers/schedule_policy.py +159 -90
- sglang/srt/managers/scheduler.py +68 -28
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +27 -21
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/memory_pool.py +206 -1
- sglang/srt/metrics/collector.py +22 -30
- sglang/srt/model_executor/cuda_graph_runner.py +129 -77
- sglang/srt/model_executor/forward_batch_info.py +51 -21
- sglang/srt/model_executor/model_runner.py +72 -64
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek_v2.py +34 -7
- sglang/srt/models/grok.py +109 -29
- sglang/srt/models/llama.py +9 -2
- sglang/srt/openai_api/adapter.py +0 -17
- sglang/srt/openai_api/protocol.py +3 -3
- sglang/srt/sampling/sampling_batch_info.py +22 -0
- sglang/srt/sampling/sampling_params.py +9 -1
- sglang/srt/server.py +20 -13
- sglang/srt/server_args.py +120 -58
- sglang/srt/speculative/build_eagle_tree.py +347 -0
- sglang/srt/speculative/eagle_utils.py +626 -0
- sglang/srt/speculative/eagle_worker.py +184 -0
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/utils.py +47 -7
- sglang/test/test_programs.py +23 -1
- sglang/test/test_utils.py +36 -7
- sglang/version.py +1 -1
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/METADATA +12 -12
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/RECORD +86 -57
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/WHEEL +1 -1
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -17,15 +17,21 @@ from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip

-
+is_hip_flag = False
 if not is_hip():
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size

-
+    is_hip_flag = False
+else:
+    is_hip_flag = True

 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0

+enable_moe_align_block_size_triton = bool(
+    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
+)
+

 @triton.jit
 def fused_moe_kernel(
@@ -222,6 +228,139 @@ def fused_moe_kernel(
     tl.store(c_ptrs, accumulator, mask=c_mask)


+def ceil_div(a, b):
+    return (a + b - 1) // b
+
+
+@triton.jit
+def moe_align_block_size_stage1(
+    topk_ids_ptr,
+    tokens_cnts_ptr,
+    num_experts: tl.constexpr,
+    numel: tl.constexpr,
+    tokens_per_thread: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    start_idx = pid * tokens_per_thread
+
+    off_c = (pid + 1) * num_experts
+
+    for i in range(tokens_per_thread):
+        if start_idx + i < numel:
+            idx = tl.load(topk_ids_ptr + start_idx + i)
+            token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
+            tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
+
+
+@triton.jit
+def moe_align_block_size_stage2(
+    tokens_cnts_ptr,
+    num_experts: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    last_cnt = 0
+    for i in range(1, num_experts + 1):
+        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
+        last_cnt = last_cnt + token_cnt
+        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
+
+
+@triton.jit
+def moe_align_block_size_stage3(
+    total_tokens_post_pad_ptr,
+    tokens_cnts_ptr,
+    cumsum_ptr,
+    num_experts: tl.constexpr,
+    block_size: tl.constexpr,
+):
+    last_cumsum = 0
+    off_cnt = num_experts * num_experts
+    for i in range(1, num_experts + 1):
+        token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
+        last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
+        tl.store(cumsum_ptr + i, last_cumsum)
+    tl.store(total_tokens_post_pad_ptr, last_cumsum)
+
+
+@triton.jit
+def moe_align_block_size_stage4(
+    topk_ids_ptr,
+    sorted_token_ids_ptr,
+    expert_ids_ptr,
+    tokens_cnts_ptr,
+    cumsum_ptr,
+    num_experts: tl.constexpr,
+    block_size: tl.constexpr,
+    numel: tl.constexpr,
+    tokens_per_thread: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    start_idx = tl.load(cumsum_ptr + pid)
+    end_idx = tl.load(cumsum_ptr + pid + 1)
+
+    for i in range(start_idx, end_idx, block_size):
+        tl.store(expert_ids_ptr + i // block_size, pid)
+
+    start_idx = pid * tokens_per_thread
+    off_t = pid * num_experts
+
+    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
+        expert_id = tl.load(topk_ids_ptr + i)
+        token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
+        rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
+        tl.store(sorted_token_ids_ptr + rank_post_pad, i)
+        tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
+
+
+def moe_align_block_size_triton(
+    topk_ids: torch.Tensor,
+    num_experts: int,
+    block_size: int,
+    sorted_token_ids: torch.Tensor,
+    expert_ids: torch.Tensor,
+    num_tokens_post_pad: torch.Tensor,
+) -> None:
+    numel = topk_ids.numel()
+    grid = (num_experts,)
+    tokens_cnts = torch.zeros(
+        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
+    )
+    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
+    tokens_per_thread = ceil_div(numel, num_experts)
+
+    moe_align_block_size_stage1[grid](
+        topk_ids,
+        tokens_cnts,
+        num_experts,
+        numel,
+        tokens_per_thread,
+    )
+    moe_align_block_size_stage2[grid](
+        tokens_cnts,
+        num_experts,
+    )
+    moe_align_block_size_stage3[(1,)](
+        num_tokens_post_pad,
+        tokens_cnts,
+        cumsum,
+        num_experts,
+        block_size,
+    )
+    moe_align_block_size_stage4[grid](
+        topk_ids,
+        sorted_token_ids,
+        expert_ids,
+        tokens_cnts,
+        cumsum,
+        num_experts,
+        block_size,
+        numel,
+        tokens_per_thread,
+    )
+
+
 def moe_align_block_size(
     topk_ids: torch.Tensor, block_size: int, num_experts: int
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -272,24 +411,36 @@ def moe_align_block_size(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    if
-
-    (
-
-
-
-
+    if num_experts >= 224:
+        if enable_moe_align_block_size_triton or is_hip_flag:
+            moe_align_block_size_triton(
+                topk_ids,
+                num_experts,
+                block_size,
+                sorted_ids,
+                expert_ids,
+                num_tokens_post_pad,
+            )
+        else:
+            token_cnts_buffer = torch.empty(
+                (num_experts + 1) * num_experts,
+                dtype=torch.int32,
+                device=topk_ids.device,
+            )
+            cumsum_buffer = torch.empty(
+                num_experts + 1, dtype=torch.int32, device=topk_ids.device
+            )

-
-
-
-
-
-
-
-
-
-
+            sgl_moe_align_block_size(
+                topk_ids,
+                num_experts,
+                block_size,
+                sorted_ids,
+                expert_ids,
+                num_tokens_post_pad,
+                token_cnts_buffer,
+                cumsum_buffer,
+            )
     else:
         ops.moe_align_block_size(
             topk_ids,
@@ -326,9 +477,9 @@ def invoke_fused_moe_kernel(

     padded_size = 0
     if use_fp8_w8a8:
-        padded_size = padding_size
         assert B_scale is not None
         if block_shape is None:
+            padded_size = padding_size
             A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         else:
             assert len(block_shape) == 2
@@ -463,7 +614,7 @@ def get_default_config(
                 "BLOCK_SIZE_K": 128,
                 "GROUP_SIZE_M": 32,
                 "num_warps": 8,
-                "num_stages": 4,
+                "num_stages": 2 if is_hip_flag else 4,
             }
             if M <= E:
                 config = {
@@ -472,7 +623,7 @@ def get_default_config(
                     "BLOCK_SIZE_K": 128,
                     "GROUP_SIZE_M": 1,
                     "num_warps": 4,
-                    "num_stages": 4,
+                    "num_stages": 2 if is_hip_flag else 4,
                 }
         else:
             # Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1]
@@ -482,7 +633,7 @@ def get_default_config(
                 "BLOCK_SIZE_K": block_shape[1],
                 "GROUP_SIZE_M": 32,
                 "num_warps": 4,
-                "num_stages": 3,
+                "num_stages": 2 if is_hip_flag else 3,
             }
     else:
         config = {
@@ -727,7 +878,7 @@ def fused_experts_impl(
     block_shape: Optional[List[int]] = None,
 ):
     padded_size = padding_size
-    if not use_fp8_w8a8:
+    if not use_fp8_w8a8 or block_shape is not None:
        padded_size = 0

     # Check constraints.
@@ -854,11 +1005,29 @@ def fused_experts_impl(
             block_shape=block_shape,
         )

-
-
-
-
-
+        if is_hip_flag:
+            ops.moe_sum(
+                intermediate_cache3.view(*intermediate_cache3.shape),
+                out_hidden_states[begin_chunk_idx:end_chunk_idx],
+            )
+        else:
+            if topk_ids.shape[1] == 1:
+                out_hidden_states[begin_chunk_idx:end_chunk_idx].copy_(
+                    intermediate_cache3[:, 0]
+                )
+            elif topk_ids.shape[1] == 2:
+                torch.add(
+                    intermediate_cache3[:, 0],
+                    intermediate_cache3[:, 1],
+                    out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                ).squeeze(dim=1)
+            elif topk_ids.shape[1] > 2:
+                torch.sum(
+                    intermediate_cache3.view(*intermediate_cache3.shape),
+                    dim=1,
+                    out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                )
+
     return out_hidden_states

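For readers skimming the hunks above: the new `moe_align_block_size_triton` path groups routed token indices by expert and pads each expert's slot count up to a multiple of `block_size`, so that every tile of the fused kernel works on a single expert. `moe_align_block_size` now takes this path for `num_experts >= 224` when `ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON=1` or on HIP, and otherwise keeps the existing `sgl_moe_align_block_size` / `ops.moe_align_block_size` custom ops. The sketch below is a hypothetical pure-PyTorch reference for that alignment semantics, written for this summary only; `align_block_size_reference` is not part of sglang and ignores the device-side concerns the Triton stages exist for.

```python
import torch


def align_block_size_reference(topk_ids: torch.Tensor, block_size: int, num_experts: int):
    """Reference-only sketch of the alignment semantics (not the sglang kernels)."""
    numel = topk_ids.numel()
    flat = topk_ids.flatten()
    counts = torch.bincount(flat, minlength=num_experts)             # tokens routed to each expert
    padded = ((counts + block_size - 1) // block_size) * block_size  # round each expert up to block_size
    cumsum = torch.cat([torch.zeros(1, dtype=torch.long), padded.cumsum(0)])
    total_padded = int(cumsum[-1])

    # Pad slots are filled with `numel` as an out-of-range sentinel (illustrative choice).
    sorted_token_ids = torch.full((total_padded,), numel, dtype=torch.long)
    expert_ids = torch.empty(total_padded // block_size, dtype=torch.long)
    for e in range(num_experts):
        start, end = int(cumsum[e]), int(cumsum[e + 1])
        expert_ids[start // block_size : end // block_size] = e      # every block belongs to one expert
        idx = (flat == e).nonzero(as_tuple=True)[0]
        sorted_token_ids[start : start + idx.numel()] = idx          # scatter this expert's token indices
    return sorted_token_ids, expert_ids, total_padded


# 3 tokens, top-2 routing over 3 experts.
topk_ids = torch.tensor([[0, 2], [1, 2], [0, 0]])
print(align_block_size_reference(topk_ids, block_size=4, num_experts=3))
```

On the GPU path these output buffers are preallocated by `moe_align_block_size` and filled in place by the four Triton stages: per-expert counts, per-expert prefix sums, block-aligned cumulative offsets, and finally the scatter of token indices.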
sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -204,6 +204,7 @@ class FusedMoE(torch.nn.Module):
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        use_presharded_weights: bool = False,
     ):
         super().__init__()

@@ -243,6 +244,7 @@ class FusedMoE(torch.nn.Module):
             params_dtype=params_dtype,
             weight_loader=self.weight_loader,
         )
+        self.use_presharded_weights = use_presharded_weights

     def _load_per_tensor_weight_scale(
         self,
@@ -321,9 +323,12 @@ class FusedMoE(torch.nn.Module):
         # Index the loaded weight for tp sharding.
         # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
         shard_size = expert_data.shape[shard_dim] // 2
-
-
-
+
+        if not self.use_presharded_weights:
+            loaded_weight = loaded_weight.narrow(
+                shard_dim, shard_size * tp_rank, shard_size
+            )
+
         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
         if shard_id == "w1":
@@ -347,9 +352,12 @@ class FusedMoE(torch.nn.Module):
         # down_proj: "RowParallel" so tp sharding on input_dim
         # Narrow parameter and load.
         shard_size = expert_data.shape[shard_dim]
-
-
-
+
+        if not self.use_presharded_weights:
+            loaded_weight = loaded_weight.narrow(
+                shard_dim, shard_size * tp_rank, shard_size
+            )
+
         # w2, down_proj: Load into only logical weight of w2.
         expert_data.copy_(loaded_weight)

@@ -390,7 +398,6 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         expert_id: int,
     ) -> None:
-
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
         # against known CompressionFormat enum values that have this quality