sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +1 -11
- sglang/bench_serving.py +149 -1
- sglang/lang/chat_template.py +44 -0
- sglang/srt/configs/deepseekvl2.py +3 -0
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +17 -0
- sglang/srt/constrained/xgrammar_backend.py +11 -19
- sglang/srt/conversation.py +30 -3
- sglang/srt/disaggregation/decode.py +4 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +9 -18
- sglang/srt/disaggregation/nixl/conn.py +241 -71
- sglang/srt/disaggregation/utils.py +44 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +14 -2
- sglang/srt/entrypoints/http_server.py +28 -1
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +146 -50
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
- sglang/srt/layers/moe/ep_moe/layer.py +120 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +4 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +5 -0
- sglang/srt/layers/quantization/fp8.py +108 -95
- sglang/srt/layers/quantization/fp8_kernel.py +79 -60
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/lora/lora_manager.py +10 -13
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/schedule_batch.py +19 -1
- sglang/srt/managers/schedule_policy.py +11 -5
- sglang/srt/managers/scheduler.py +28 -13
- sglang/srt/managers/tokenizer_manager.py +24 -13
- sglang/srt/managers/tp_worker.py +9 -12
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/model_executor/model_runner.py +44 -33
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +55 -20
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +1 -1
- sglang/srt/models/llama4.py +53 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +24 -40
- sglang/srt/openai_api/protocol.py +28 -16
- sglang/srt/reasoning_parser.py +2 -2
- sglang/srt/sampling/sampling_batch_info.py +54 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +30 -6
- sglang/srt/utils.py +35 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +14 -6
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +90 -80
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -18,8 +18,9 @@ import torch
 import triton
 
 if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
-    import
+    import logging
 
+    torch._logging.set_logs(dynamo=logging.ERROR)
     torch._dynamo.config.suppress_errors = True
 
 from sglang.global_config import global_config
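The gate above only takes effect when torch compile is switched on through the environment. A minimal illustrative sketch of the behaviour the two added lines produce (not part of the wheel; `os.environ.get` is used here only so the snippet also runs when the variable is unset, whereas the module reads the key directly):

import logging
import os

import torch

if os.environ.get("SGLANG_ENABLE_TORCH_COMPILE") == "1":
    # Silence torch.compile/dynamo logging below ERROR and fall back to eager
    # execution instead of raising when a graph fails to compile.
    torch._logging.set_logs(dynamo=logging.ERROR)
    torch._dynamo.config.suppress_errors = True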
@@ -338,23 +339,39 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
 
         cache_loc = forward_batch.out_cache_loc
         logits_soft_cap = layer.logit_cap
         prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
-        qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
         k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
 
         # Save kv cache
         if save_kv_cache and k is not None:
             assert v is not None
             if save_kv_cache:
-
+                if k_rope is not None:
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                        layer, cache_loc, k, k_rope
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+        if q_rope is not None:
+            q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
 
         if self.forward_metadata.use_ragged:
             # ragged prefill
-
+            if q_rope is not None:
+                q = torch.cat([q, q_rope], dim=-1)
+            qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            if k_rope is not None:
+                k = torch.cat([k, k_rope], dim=-1)
+            o = self.prefill_wrapper_ragged.forward(
                 qall,
                 k.view(-1, layer.tp_k_head_num, layer.head_dim),
                 v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
@@ -364,11 +381,19 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             )
         else:
             # mla paged prefill
+            if q_rope is None:
+                qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+                q, q_rope = (
+                    qall[:, :, : layer.v_head_dim],
+                    qall[:, :, layer.v_head_dim :],
+                )
+            o = q.new_empty(q.shape)
             o = prefill_wrapper_paged.run(
-
-
+                q,
+                q_rope,
                 k_buf[:, :, : layer.v_head_dim],
                 k_buf[:, :, layer.v_head_dim :],
+                out=o,
             )
 
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
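The new q_rope/k_rope arguments let callers hand the rotary ("rope") slice of an MLA query or key to the backend separately instead of packed into one tensor; when they are omitted, the backend slices the packed tensor at v_head_dim itself, as in the paged branch above. A self-contained sketch of that shape relationship (the widths 512/64 are illustrative DeepSeek-style values chosen for this example, not read from sglang):

import torch

num_tokens, tp_q_head_num = 4, 16
v_head_dim, rope_dim = 512, 64            # assumed widths for illustration
head_dim = v_head_dim + rope_dim          # packed width used by the ragged path

# Packed layout: the "nope" part and the rope part live side by side.
qall = torch.randn(num_tokens, tp_q_head_num, head_dim)
q_nope = qall[:, :, :v_head_dim]          # what the backend treats as q when q_rope is passed
q_rope = qall[:, :, v_head_dim:]          # rotary slice, now passable as a separate argument

# Concatenating the two slices reproduces the packed tensor that the ragged
# prefill path rebuilds with torch.cat([q, q_rope], dim=-1).
assert torch.equal(torch.cat([q_nope, q_rope], dim=-1), qall)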
@@ -381,6 +406,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
         decode_wrapper = self.forward_metadata.decode_wrapper
         cache_loc = forward_batch.out_cache_loc
@@ -388,23 +416,42 @@
         if k is not None:
             assert v is not None
             if save_kv_cache:
-
-
-
-
-
-
+                if k_rope is not None:
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        k_rope,
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
 
         # Reshape inputs
-
+        if q_rope is not None:
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+        else:
+            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = reshaped_q[:, :, : layer.v_head_dim]
+            q_rope = reshaped_q[:, :, layer.v_head_dim :]
+
         k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
 
+        o = q_nope.new_empty(q_nope.shape)
         # Direct call to run without the wrapper
         o = decode_wrapper.run(
-
-
+            q_nope,
+            q_rope,
             k_buffer[:, :, : layer.v_head_dim],
             k_buffer[:, :, layer.v_head_dim :],
+            out=o,
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
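On the decode side the same convention extends to the KV cache: the slicing above implies that the per-layer key buffer packs the compressed KV (v_head_dim wide) next to the rotary key, and the wrapper reads the two slices separately. A small sketch of that layout (the buffer shape, the single latent head, and the widths are assumptions for illustration only):

import torch

num_slots, v_head_dim, head_dim = 8, 512, 576   # assumed MLA-style widths
k_buffer = torch.randn(num_slots, 1, head_dim)  # assumed: one latent head per cache slot

ckv_cache = k_buffer[:, :, :v_head_dim]   # compressed-KV slice passed to the wrapper
kpe_cache = k_buffer[:, :, v_head_dim:]   # rotary-key slice passed alongside it
assert kpe_cache.shape[-1] == head_dim - v_head_dim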
sglang/srt/layers/attention/merge_state.py (new file)
@@ -0,0 +1,46 @@
+from typing import Optional, Tuple
+
+import torch
+from sgl_kernel import merge_state_v2
+
+from sglang.srt.layers.attention.triton_ops.merge_state import merge_state_triton
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+
+
+# Automatically fallback to the Triton kernel in some cases
+# (e.g., for AMD GPUs, when the head dimension is not a multiple
+# of 4 or 8, and in FP8 precision)
+def _supported_dtypes(o: torch.Tensor) -> bool:
+    return o.dtype in [torch.float32, torch.half, torch.bfloat16]
+
+
+def _supported_headdim(o: torch.Tensor) -> bool:
+    headdim = o.shape[2]  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
+    if o.dtype == torch.float32:
+        return headdim % 4 == 0
+    return headdim % 8 == 0
+
+
+def merge_state(
+    prefix_output: torch.Tensor,
+    prefix_lse: torch.Tensor,
+    suffix_output: torch.Tensor,
+    suffix_lse: torch.Tensor,
+    output: Optional[torch.Tensor] = None,
+    output_lse: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    if (
+        _is_cuda
+        and _supported_dtypes(prefix_output)
+        and _supported_headdim(prefix_output)
+    ):
+        return merge_state_v2(
+            prefix_output, prefix_lse, suffix_output, suffix_lse, output, output_lse
+        )
+    else:
+        # Fallback to Triton kernel
+        return merge_state_triton(
+            prefix_output, prefix_lse, suffix_output, suffix_lse, output, output_lse
+        )
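The new merge_state helper dispatches to the CUDA merge_state_v2 kernel when running on CUDA with a supported dtype and head size, and otherwise falls back to the Triton kernel below. A hedged usage sketch (the tensor sizes, the CUDA device, and the float32 LSE dtype are assumptions for this example; shapes follow the [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] / [NUM_TOKENS, NUM_HEADS] layout annotated in the kernel, and the CUDA path requires an sgl_kernel build):

import torch

from sglang.srt.layers.attention.merge_state import merge_state

num_tokens, num_heads, head_size = 32, 8, 128    # illustrative sizes
prefix_out = torch.randn(num_tokens, num_heads, head_size,
                         device="cuda", dtype=torch.bfloat16)
suffix_out = torch.randn_like(prefix_out)
prefix_lse = torch.randn(num_tokens, num_heads, device="cuda", dtype=torch.float32)
suffix_lse = torch.randn_like(prefix_lse)

# Merge two partial attention results (e.g. from two KV partitions) into one.
merged_out, merged_lse = merge_state(prefix_out, prefix_lse, suffix_out, suffix_lse)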
sglang/srt/layers/attention/triton_ops/merge_state.py (new file)
@@ -0,0 +1,96 @@
+from typing import Optional, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def merge_state_kernel(
+    output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_merged
+    output_lse,  # [NUM_TOKENS, NUM_HEADS] s_merged
+    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_a
+    prefix_lse,  # [NUM_TOKENS, NUM_HEADS] s_a
+    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_b
+    suffix_lse,  # [NUM_TOKENS, NUM_HEADS] s_b
+    HEAD_SIZE: tl.constexpr,
+    PADDED_HEAD_SIZE: tl.constexpr,
+    OUTPUT_LSE: tl.constexpr,
+):
+    token_idx = tl.program_id(0)
+    num_tokens = tl.num_programs(0)
+    head_idx = tl.program_id(1)
+    num_heads = tl.num_programs(1)
+
+    p_lse = tl.load(prefix_lse + token_idx * num_heads + head_idx)
+    s_lse = tl.load(suffix_lse + token_idx * num_heads + head_idx)
+    p_lse = float("-inf") if p_lse == float("inf") else p_lse
+    s_lse = float("-inf") if s_lse == float("inf") else s_lse
+
+    max_lse = tl.maximum(p_lse, s_lse)
+    p_lse = p_lse - max_lse
+    s_lse = s_lse - max_lse
+    out_se = tl.exp(p_lse) + tl.exp(s_lse)
+
+    if OUTPUT_LSE:
+        out_lse = tl.log(out_se) + max_lse
+        tl.store(output_lse + token_idx * num_heads + head_idx, out_lse)
+
+    head_arange = tl.arange(0, PADDED_HEAD_SIZE)
+    head_mask = head_arange < HEAD_SIZE
+    p_out = tl.load(
+        prefix_output
+        + token_idx * num_heads * HEAD_SIZE
+        + head_idx * HEAD_SIZE
+        + head_arange,
+        mask=head_mask,
+    )
+    s_out = tl.load(
+        suffix_output
+        + token_idx * num_heads * HEAD_SIZE
+        + head_idx * HEAD_SIZE
+        + head_arange,
+        mask=head_mask,
+    )
+
+    p_scale = tl.exp(p_lse) / out_se
+    s_scale = tl.exp(s_lse) / out_se
+    out = p_out * p_scale + s_out * s_scale
+    tl.store(
+        output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
+        out,
+        mask=head_mask,
+    )
+
+
+def merge_state_triton(
+    prefix_output: torch.Tensor,
+    prefix_lse: torch.Tensor,
+    suffix_output: torch.Tensor,
+    suffix_lse: torch.Tensor,
+    output: Optional[torch.Tensor] = None,
+    output_lse: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    # Avoid creating new tensors if they are already provided
+    if output is None:
+        output = torch.empty_like(prefix_output)
+    if output_lse is None:
+        output_lse = torch.empty_like(prefix_lse)
+
+    num_tokens = output.shape[0]
+    num_query_heads = output.shape[1]
+    head_size = output.shape[2]
+    padded_head_size = triton.next_power_of_2(head_size)
+
+    merge_state_kernel[(num_tokens, num_query_heads)](
+        output,
+        output_lse,
+        prefix_output,
+        prefix_lse,
+        suffix_output,
+        suffix_lse,
+        head_size,
+        padded_head_size,
+        output_lse is not None,
+    )
+    return output, output_lse
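For readability, here is a plain-PyTorch sketch of the arithmetic the Triton kernel performs (a reference of ours, not code from the wheel): each partial attention output is weighted by the normalized exponential of its log-sum-exp, and the merged LSE is the log of the summed weights shifted back by the shared maximum.

import torch


def merge_state_reference(v_a, s_a, v_b, s_b):
    """v_*: [tokens, heads, head_size]; s_*: [tokens, heads] log-sum-exp values."""
    # Mirror the kernel's guard that rewrites a +inf LSE to -inf before merging.
    s_a = torch.where(s_a == float("inf"), torch.full_like(s_a, float("-inf")), s_a)
    s_b = torch.where(s_b == float("inf"), torch.full_like(s_b, float("-inf")), s_b)

    m = torch.maximum(s_a, s_b)                      # shared max for numerical stability
    w_a, w_b = torch.exp(s_a - m), torch.exp(s_b - m)
    denom = w_a + w_b                                # "out_se" in the kernel

    merged = v_a * (w_a / denom).unsqueeze(-1) + v_b * (w_b / denom).unsqueeze(-1)
    merged_lse = torch.log(denom) + m
    return merged, merged_lse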