sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/activation.py
CHANGED
@@ -33,6 +33,7 @@ from sglang.srt.utils import (
     cpu_has_amx_support,
     is_cpu,
     is_cuda,
+    is_hip,
     is_npu,
     set_weight_attrs,
 )
@@ -42,9 +43,12 @@ _is_cuda = is_cuda()
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_hip = is_hip()

 if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
+elif _is_hip:
+    from sgl_kernel import gelu_and_mul, gelu_quick, gelu_tanh_and_mul, silu_and_mul

 if is_npu():
     import torch_npu
@@ -110,14 +114,29 @@ class NewGELU(CustomOp):
         return self.forward_native(x)


+class ReLU2(nn.Module):
+    """
+    Applies the squared Rectified Linear Unit function.
+    y = max(0, x)^2
+    """
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = F.relu(x)
+        return x * x
+
+
 class QuickGELU(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         return x * torch.sigmoid(1.702 * x)

     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        # TODO(zhyncs): Implement the CUDA kernel for QuickGELU in sgl-kernel
         return self.forward_native(x)

+    def forward_hip(self, x: torch.Tensor) -> torch.Tensor:
+        out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
+        gelu_quick(x, out)
+        return out
+

 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.
@@ -164,6 +183,8 @@ class ScaledActivation(nn.Module):
 _ACTIVATION_REGISTRY = {
     "gelu": nn.GELU(),
     "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
+    "gelu_new": NewGELU(),
+    "relu2": ReLU2(),
 }


@@ -209,8 +230,8 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
     return nn.Identity()


-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
     logger.info(
-        "sgl-kernel is not available on Non-NV platforms or Non-AMX CPUs. Fallback to other kernel libraries."
+        "sgl-kernel is not available on Non-NV, Non-AMD platforms or Non-AMX CPUs. Fallback to other kernel libraries."
     )
     from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
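The new "relu2" registry entry maps to the ReLU2 module shown in the hunk above. A minimal standalone sketch of that squared-ReLU activation, taken directly from the diff (the sample input values are illustrative):

import torch
import torch.nn.functional as F
from torch import nn


class ReLU2(nn.Module):
    """y = max(0, x)^2, as defined in the hunk above."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.relu(x)
        return x * x


if __name__ == "__main__":
    act = ReLU2()
    x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
    print(act(x))  # tensor([0.0000, 0.0000, 0.0000, 0.2500, 4.0000])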
sglang/srt/layers/attention/base_attn_backend.py
CHANGED
@@ -65,7 +65,9 @@ class AttentionBackend(ABC):
         **kwargs,
     ):
         """Run forward on an attention layer."""
-        if forward_batch.forward_mode.is_decode():
+        if forward_batch.forward_mode.is_idle():
+            return q.new_empty(q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+        elif forward_batch.forward_mode.is_decode():
             return self.forward_decode(
                 q,
                 k,
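A standalone sketch of the early return added above: for an idle batch the backend returns an empty output of shape (num_tokens, tp_q_head_num * v_head_dim) instead of calling an attention kernel. The FakeLayer stand-in and its values are hypothetical; only the shape contract comes from the diff.

from dataclasses import dataclass

import torch


@dataclass
class FakeLayer:  # hypothetical stand-in for the attention layer object
    tp_q_head_num: int = 8
    v_head_dim: int = 128


def forward_idle(q: torch.Tensor, layer: FakeLayer) -> torch.Tensor:
    # No tokens to attend to: allocate an empty output with the expected width.
    return q.new_empty(q.shape[0], layer.tp_q_head_num * layer.v_head_dim)


q = torch.empty(0, 8 * 128)  # an idle batch carries zero query tokens
print(forward_idle(q, FakeLayer()).shape)  # torch.Size([0, 1024])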
sglang/srt/layers/attention/flashattention_backend.py
CHANGED
@@ -1617,7 +1617,7 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.max_seq_len_k + self.page_size - 1
             ) // self.page_size

-
+            normal_decode_set_metadata(
                 metadata.cache_seqlens_int32,
                 metadata.cu_seqlens_k,
                 metadata.page_table,
@@ -1666,7 +1666,7 @@ class FlashAttentionBackend(AttentionBackend):
             max_seq_pages = (max_len + self.page_size - 1) // self.page_size
             metadata.max_seq_len_k = max_len

-
+            normal_decode_set_metadata(
                 metadata.cache_seqlens_int32,
                 metadata.cu_seqlens_k,
                 metadata.page_table,
@@ -2089,7 +2089,7 @@ class FlashAttentionMultiStepBackend:
 # @torch.compile(dynamic=True, backend=get_compiler_backend())
 # TODO: fuse these kernels
 # NOTE: torch.compile makes it slower in speculative decoding
-def
+def normal_decode_set_metadata(
     cache_seqlens_int32: torch.Tensor,
     cu_seqlens_k: torch.Tensor,
     page_table: torch.Tensor,
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -25,7 +25,9 @@ from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.layers.utils import is_sm100_supported
+from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 from sglang.srt.utils import is_flashinfer_available, next_power_of_2
@@ -485,12 +487,20 @@ class FlashInferAttnBackend(AttentionBackend):
                 v_scale=layer.v_scale,
             )
         else:
+            causal = True
+            if layer.attn_type == AttentionType.ENCODER_ONLY:
+                save_kv_cache = False
+                causal = False
+
             if self.forward_metadata.extend_no_prefix:
+                # NOTE: FlashInfer currently has limitations with head_dim = 32 or other dimensions
+                # The FlashInfer head_dim limitation itself is tracked here:
+                # https://github.com/flashinfer-ai/flashinfer/issues/1048
                 o = self.prefill_wrapper_ragged.forward(
                     q.view(-1, layer.tp_q_head_num, layer.head_dim),
                     k.view(-1, layer.tp_k_head_num, layer.head_dim),
                     v.view(-1, layer.tp_v_head_num, layer.head_dim),
-                    causal=True,
+                    causal=causal,
                     sm_scale=layer.scaling,
                     logits_soft_cap=logits_soft_cap,
                 )
@@ -589,6 +599,7 @@ class FlashInferIndicesUpdaterDecode:
         self.kv_indptr = attn_backend.kv_indptr
         self.kv_last_page_len = attn_backend.kv_last_page_len
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator

         # Dispatch the update function
         if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
@@ -655,6 +666,10 @@
                 paged_kernel_lens_sum_tmp = seq_lens_sum
                 kv_start_idx_tmp = None

+            use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
+                self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
+            )
+
             self.call_begin_forward(
                 decode_wrappers[wrapper_id],
                 req_pool_indices,
@@ -663,6 +678,7 @@
                 self.kv_indptr[wrapper_id],
                 kv_start_idx_tmp,
                 spec_info,
+                use_sliding_window_kv_pool=use_sliding_window_kv_pool,
             )

     def update_cross_attention(
@@ -704,6 +720,7 @@
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        use_sliding_window_kv_pool: bool = False,
     ):
         if spec_info is None:
             bs = len(req_pool_indices)
@@ -731,6 +748,14 @@
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1

+        if use_sliding_window_kv_pool:
+            kv_last_index = kv_indptr[-1]
+            kv_indices[:kv_last_index] = (
+                self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+                    kv_indices[:kv_last_index]
+                )
+            )
+
         wrapper.begin_forward(
             kv_indptr,
             kv_indices,
@@ -765,6 +790,7 @@ class FlashInferIndicesUpdaterPrefill:
         self.kv_last_page_len = attn_backend.kv_last_page_len
         self.qo_indptr = attn_backend.qo_indptr
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
         self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged

         # Dispatch the update function
@@ -848,6 +874,9 @@
             paged_kernel_lens_sum = seq_lens_sum

         kv_start_idx = seq_lens - paged_kernel_lens
+        use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
+            self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
+        )

         self.call_begin_forward(
             self.prefill_wrapper_ragged,
@@ -862,6 +891,7 @@
             self.qo_indptr[wrapper_id],
             use_ragged,
             spec_info,
+            use_sliding_window_kv_pool=use_sliding_window_kv_pool,
         )

     def update_cross_attention(
@@ -916,6 +946,7 @@
         qo_indptr: torch.Tensor,
         use_ragged: bool,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        use_sliding_window_kv_pool: bool = False,
     ):
         bs = len(seq_lens)
         if spec_info is None:
@@ -964,6 +995,14 @@
                 q_data_type=self.q_data_type,
             )

+        if use_sliding_window_kv_pool:
+            kv_last_index = kv_indptr[-1]
+            kv_indices[:kv_last_index] = (
+                self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+                    kv_indices[:kv_last_index]
+                )
+            )
+
         # cached part
         wrapper_paged.begin_forward(
             qo_indptr,
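The sliding-window handling added above only rewrites KV indices for wrapper 0 when the allocator is a SWATokenToKVPoolAllocator. A single-process sketch of that gating, with a stand-in allocator (the method name translate_loc_from_full_to_swa comes from the diff; FakeSWAAllocator and its offset mapping are purely illustrative):

import torch


class FakeSWAAllocator:
    """Illustrative stand-in for SWATokenToKVPoolAllocator."""

    def __init__(self, offset: int = 100):
        self.offset = offset

    def translate_loc_from_full_to_swa(self, loc: torch.Tensor) -> torch.Tensor:
        return loc + self.offset  # the real mapping lives in sglang/srt/mem_cache/allocator.py


def maybe_translate(kv_indptr, kv_indices, wrapper_id, allocator):
    # Only wrapper 0 (the sliding-window wrapper) and only with an SWA allocator.
    use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
        allocator, FakeSWAAllocator
    )
    if use_sliding_window_kv_pool:
        kv_last_index = kv_indptr[-1]
        kv_indices[:kv_last_index] = allocator.translate_loc_from_full_to_swa(
            kv_indices[:kv_last_index]
        )
    return kv_indices


kv_indptr = torch.tensor([0, 3, 5])
kv_indices = torch.arange(5)
print(maybe_translate(kv_indptr, kv_indices, 0, FakeSWAAllocator()))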
sglang/srt/layers/communicator.py
CHANGED
@@ -24,8 +24,8 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.dp_attention import (
-
-
+    attn_tp_all_gather_into_tensor,
+    attn_tp_reduce_scatter_tensor,
     dp_gather_partial,
     dp_scatter,
     get_attention_dp_size,
@@ -309,8 +309,8 @@ class CommunicateSimpleFn:
                 forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                 hidden_states,
             )
-
-
+            attn_tp_all_gather_into_tensor(
+                hidden_states,
                 local_hidden_states,
             )
             return hidden_states
@@ -400,9 +400,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
                 ].clone(),
                 residual,
             )
-
-                list(residual.tensor_split(context.attn_tp_size)), local_residual
-            )
+            attn_tp_all_gather_into_tensor(residual, local_residual)
             if context.attn_dp_size != 1:
                 if context.attn_tp_rank == 0:
                     hidden_states += residual
@@ -442,9 +440,11 @@
         *,
         residual_input_mode,
     ):
-
-        hidden_states =
-
+        input_hidden_states = hidden_states
+        hidden_states = hidden_states.tensor_split(context.attn_tp_size)[
+            context.attn_tp_rank
+        ]
+        attn_tp_reduce_scatter_tensor(hidden_states, input_hidden_states)
         if residual_input_mode == ScatterMode.TP_ATTN_FULL:
             residual = residual.tensor_split(context.attn_tp_size)[context.attn_tp_rank]
         if hidden_states.shape[0] != 0:
@@ -547,8 +547,8 @@ class CommunicateSummableTensorPairFn:
                 forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                 hidden_states,
             )
-
-
+            attn_tp_all_gather_into_tensor(
+                hidden_states,
                 local_hidden_states,
             )
             return hidden_states, residual
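The communicator changes above switch the attention-TP collectives from list-based helpers to the tensor variants. A single-process sketch of the shape contract those helpers assume, simulated with plain tensor ops rather than a real process group (tp_size, the rank, and the tensors are illustrative):

import torch

tp_size = 4
full = torch.randn(tp_size * 3, 8)  # hidden_states covering all attn-TP ranks

# attn_tp_reduce_scatter_tensor: each rank keeps one contiguous 1/tp_size slice
# of the (summed) input. With a single process the "sum" is just the slice.
rank = 1
local = full.tensor_split(tp_size)[rank]  # shape: (3, 8)

# attn_tp_all_gather_into_tensor: the inverse, reassembling the full tensor
# from the per-rank slices.
gathered = torch.cat([full.tensor_split(tp_size)[r] for r in range(tp_size)], dim=0)
assert torch.equal(gathered, full)
print(local.shape, gathered.shape)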
sglang/srt/layers/dp_attention.py
CHANGED
@@ -3,7 +3,8 @@ from __future__ import annotations
 import functools
 import logging
 from contextlib import contextmanager
-from
+from enum import IntEnum, auto
+from typing import TYPE_CHECKING, List, Tuple

 import torch
 import triton
@@ -30,6 +31,34 @@ _LOCAL_ATTN_DP_SIZE = None
 _LOCAL_ATTN_DP_RANK = None


+class DPPaddingMode(IntEnum):
+
+    # Padding tokens to max length and then gather tokens using `all_gather_into_tensor`
+    MAX_LEN = auto()
+    # Padding tokens to sum length and then gather tokens using `all_reduce`
+    SUM_LEN = auto()
+
+    def is_max_len(self):
+        return self == DPPaddingMode.MAX_LEN
+
+    def is_sum_len(self):
+        return self == DPPaddingMode.SUM_LEN
+
+    @classmethod
+    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DPPaddingMode:
+        # we choose the mode that minimizes the communication cost
+        max_len = max(global_num_tokens)
+        sum_len = sum(global_num_tokens)
+        if sum_len * 2 > max_len * get_attention_dp_size():
+            return cls.MAX_LEN
+        else:
+            return cls.SUM_LEN
+
+    @classmethod
+    def get_default_mode_in_cuda_graph(cls) -> DPPaddingMode:
+        return cls.MAX_LEN
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
@@ -162,7 +191,7 @@ def disable_dp_size():
         _ATTN_DP_SIZE = old_dp_size


-def get_dp_local_info(forward_batch: ForwardBatch):
+def get_dp_local_info(forward_batch: ForwardBatch) -> Tuple[torch.Tensor, torch.Tensor]:
     # `get_dp_local_info` is only called in global DP gather and scatter. We use global DP rank here.
     dp_rank = get_attention_dp_rank()

@@ -221,7 +250,7 @@ def memcpy_triton(dst, src, dim, offset, sz, offset_src):
     memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)


-def _dp_gather(
+def _dp_gather_via_all_reduce(
     global_tokens: torch.Tensor,
     local_tokens: torch.Tensor,
     forward_batch: ForwardBatch,
@@ -238,13 +267,6 @@ def _dp_gather(
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between global_tokens and local_tokens not allowed"

-    # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
-    # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
-    # actual size of the accepted tokens.
-    if forward_batch.forward_mode.is_draft_extend():
-        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
-        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
-
     memcpy_triton(
         global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
     )
@@ -263,6 +285,38 @@
     global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)


+def _dp_gather_via_all_gather(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+    is_partial: bool,
+):
+    if not is_partial:
+        if get_attention_tp_rank() != 0:
+            local_tokens.fill_(0)
+    scattered_local_tokens = local_tokens.tensor_split(get_attention_tp_size())[
+        get_attention_tp_rank()
+    ]
+    get_attention_tp_group().reduce_scatter_tensor(scattered_local_tokens, local_tokens)
+    get_tp_group().all_gather_into_tensor(global_tokens, scattered_local_tokens)
+
+
+def _dp_gather(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+    is_partial: bool,
+):
+    if forward_batch.dp_padding_mode.is_max_len():
+        _dp_gather_via_all_gather(
+            global_tokens, local_tokens, forward_batch, is_partial
+        )
+    else:
+        _dp_gather_via_all_reduce(
+            global_tokens, local_tokens, forward_batch, is_partial
+        )
+
+
 def dp_gather_partial(
     global_tokens: torch.Tensor,
     local_tokens: torch.Tensor,
@@ -296,24 +350,18 @@ def dp_scatter(
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between local_tokens and global_tokens not allowed"

-    # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
-    # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
-    # actual size of the accepted tokens.
-    if forward_batch.forward_mode.is_draft_extend():
-        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
-        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
-
     memcpy_triton(
         local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
     )


-def
-
-
-
-
+def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
+    return get_attention_tp_group().reduce_scatter_tensor(output, input)
+
+
+def attn_tp_all_gather_into_tensor(output: torch.Tensor, input: torch.Tensor):
+    return get_attention_tp_group().all_gather_into_tensor(output, input)


-def attn_tp_all_gather(output_list: List[torch.Tensor],
-    return get_attention_tp_group().all_gather(
+def attn_tp_all_gather(output_list: List[torch.Tensor], input: torch.Tensor):
+    return get_attention_tp_group().all_gather(input, output_tensor_list=output_list)
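The padding-mode choice added in DPPaddingMode above is a pure cost comparison: gathering via all_gather moves roughly max_len * dp_size tokens, gathering via all_reduce roughly 2 * sum_len, and the cheaper option wins. A standalone sketch of that rule (dp_size and the sample token counts are illustrative):

from typing import List


def choose_padding_mode(global_num_tokens: List[int], dp_size: int) -> str:
    # Mirrors DPPaddingMode.get_dp_padding_mode from the diff above.
    max_len = max(global_num_tokens)
    sum_len = sum(global_num_tokens)
    # all_gather cost ~ max_len * dp_size; all_reduce cost ~ 2 * sum_len.
    return "MAX_LEN" if sum_len * 2 > max_len * dp_size else "SUM_LEN"


print(choose_padding_mode([7, 8, 8, 8], dp_size=4))   # balanced ranks -> MAX_LEN
print(choose_padding_mode([1, 1, 1, 29], dp_size=4))  # skewed ranks   -> SUM_LEN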
sglang/srt/layers/linear.py
CHANGED
@@ -1,12 +1,12 @@
 """Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""

+from __future__ import annotations
+
 import itertools
 import logging
-from
-from typing import Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

 import torch
-import torch.nn.functional as F
 from torch.nn.parameter import Parameter, UninitializedParameter

 from sglang.srt.distributed import (
@@ -17,7 +17,6 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
-from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     BlockQuantScaleParameter,
@@ -27,17 +26,14 @@ from sglang.srt.layers.parameter import (
     RowvLLMParameter,
     _ColumnvLLMParameter,
 )
-from sglang.srt.layers.quantization.
-
-
-
-from sglang.srt.
-
-
-
-    set_weight_attrs,
-    use_intel_amx_backend,
-)
+from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
+from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.quantization.base_config import (
+        QuantizationConfig,
+        QuantizeMethodBase,
+    )

 logger = logging.getLogger(__name__)

@@ -57,9 +53,9 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "ModelOptFp8LinearMethod",
     "ModelOptFp4LinearMethod",
     "IPEXAWQLinearMethod",
+    "PetitNvFp4LinearMethod",
 ]

-_is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _is_npu = is_npu()

@@ -110,91 +106,6 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
     return param[shard_id], loaded_weight


-class LinearMethodBase(QuantizeMethodBase):
-    """Base class for different (maybe quantized) linear methods."""
-
-    @abstractmethod
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        input_size_per_partition: int,
-        output_partition_sizes: List[int],
-        input_size: int,
-        output_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        """Create weights for a linear layer.
-        The weights will be set as attributes of the layer.
-
-        Args:
-            layer: The layer that is using the LinearMethodBase factory.
-            input_size_per_partition: Size of the weight input dim on rank X.
-            output_partition_sizes: Sizes of the output dim of each logical
-                weight on rank X. E.g., output_partition_sizes for QKVLinear
-                is a list contains the width of Wq, Wk, Wv on rank X.
-            input_size: Size of the input dim of the weight across all ranks.
-            output_size: Size of the output dim of the weight across all ranks.
-            params_dtype: Datatype of the parameters.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        bias: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        """Apply the weights in layer to the input tensor.
-        Expects create_weights to have been called before on the layer."""
-        raise NotImplementedError
-
-
-class UnquantizedLinearMethod(LinearMethodBase):
-    """Linear method without quantization."""
-
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        input_size_per_partition: int,
-        output_partition_sizes: List[int],
-        input_size: int,
-        output_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        weight = Parameter(
-            torch.empty(
-                sum(output_partition_sizes),
-                input_size_per_partition,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
-        layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, extra_weight_attrs)
-
-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        if _is_cpu and _is_cpu_amx_available:
-            _amx_process_weight_after_loading(layer, ["weight"])
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        bias: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-
-        if use_intel_amx_backend(layer):
-            return torch.ops.sgl_kernel.weight_packed_linear(
-                x, layer.weight, bias, True  # is_vnni
-            )
-
-        return F.linear(x, layer.weight, bias)
-
-
 class LinearBase(torch.nn.Module):
     """Base linear layer.

@@ -310,7 +221,7 @@ class ReplicatedLinear(LinearBase):
         assert param.size() == loaded_weight.size()
         param.data.copy_(loaded_weight)

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         bias = self.bias if not self.skip_bias_add else None
         assert self.quant_method is not None
         output = self.quant_method.apply(self, x, bias)