sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +9 -7
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mooncake/conn.py +44 -56
- sglang/srt/distributed/parallel_state.py +33 -0
- sglang/srt/entrypoints/engine.py +30 -26
- sglang/srt/entrypoints/openai/serving_chat.py +21 -2
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/qwen3_detector.py +150 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +13 -0
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +187 -12
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +26 -108
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +343 -3
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +87 -53
- sglang/srt/lora/mem_pool.py +81 -33
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +241 -0
- sglang/srt/managers/io_struct.py +41 -29
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +150 -110
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +243 -61
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +11 -3
- sglang/srt/managers/tp_worker.py +14 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +7 -16
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +152 -0
- sglang/srt/mem_cache/hiradix_cache.py +179 -4
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +41 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +5 -6
- sglang/srt/model_executor/forward_batch_info.py +14 -1
- sglang/srt/model_executor/model_runner.py +109 -22
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +191 -171
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +3 -3
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -5
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +56 -18
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +393 -230
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils.py +27 -1
- sglang/test/runners.py +14 -3
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/function_call/qwen3_detector.py
ADDED
@@ -0,0 +1,150 @@
+import ast
+import html
+import json
+import logging
+import re
+from typing import Any, Dict, List, Tuple
+
+from sglang.srt.entrypoints.openai.protocol import Tool
+from sglang.srt.function_call.base_format_detector import BaseFormatDetector
+from sglang.srt.function_call.core_types import (
+    StreamingParseResult,
+    StructureInfo,
+    ToolCallItem,
+    _GetInfoFunc,
+)
+from sglang.srt.function_call.ebnf_composer import EBNFComposer
+
+logger = logging.getLogger(__name__)
+
+
+def _safe_val(raw: str) -> Any:
+    raw = html.unescape(raw.strip())
+    try:
+        return json.loads(raw)
+    except Exception:
+        try:
+            return ast.literal_eval(raw)
+        except Exception:
+            return raw
+
+
+class Qwen3XMLDetector(BaseFormatDetector):
+    """
+    Detector for Qwen 3 models.
+    Assumes function call format:
+        <tool_call>
+        <function=execute_bash>
+        <parameter=command>
+        pwd && ls
+        </parameter>
+        </function>
+        </tool_call>
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.tool_call_start_token: str = "<tool_call>"
+        self.tool_call_end_token: str = "</tool_call>"
+        self.tool_call_prefix: str = "<function="
+        self.tool_call_regex = re.compile(
+            r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL
+        )
+        self.tool_call_function_regex = re.compile(
+            r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL
+        )
+        self.tool_call_parameter_regex = re.compile(
+            r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL
+        )
+        self._buf: str = ""
+
+    def has_tool_call(self, text: str) -> bool:
+        return self.tool_call_start_token in text
+
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+        normal, calls = self._extract(text, tools)
+        return StreamingParseResult(normal_text=normal, calls=calls)
+
+    def parse_streaming_increment(
+        self, new_text: str, tools: List[Tool]
+    ) -> StreamingParseResult:
+        self._buf += new_text
+        normal = ""
+        calls: List[ToolCallItem] = []
+        while True:
+            if self.tool_call_start_token not in self._buf:
+                normal += self._buf
+                self._buf = ""
+                break
+            s = self._buf.find(self.tool_call_start_token)
+            if s > 0:
+                normal += self._buf[:s]
+                self._buf = self._buf[s:]
+            e = self._buf.find(self.tool_call_end_token)
+            if e == -1:
+                break
+            block = self._buf[: e + len(self.tool_call_end_token)]
+            self._buf = self._buf[e + len(self.tool_call_end_token) :]
+            calls.extend(self._parse_block(block, tools))
+        return StreamingParseResult(normal_text=normal, calls=calls)
+
+    def _extract(self, text: str, tools: List[Tool]) -> Tuple[str, List[ToolCallItem]]:
+        normal_parts: List[str] = []
+        calls: List[ToolCallItem] = []
+        cursor = 0
+        while True:
+            s = text.find(self.tool_call_start_token, cursor)
+            if s == -1:
+                normal_parts.append(text[cursor:])
+                break
+            normal_parts.append(text[cursor:s])
+            e = text.find(self.tool_call_end_token, s)
+            if e == -1:
+                normal_parts.append(text[s:])
+                break
+            block = text[s : e + len(self.tool_call_end_token)]
+            cursor = e + len(self.tool_call_end_token)
+            calls.extend(self._parse_block(block, tools))
+        return "".join(normal_parts), calls
+
+    def _parse_block(self, block: str, tools: List[Tool]) -> List[ToolCallItem]:
+        res: List[ToolCallItem] = []
+        for m in self.tool_call_function_regex.findall(block):
+            txt = m[0] if m[0] else m[1]
+            if ">" not in txt:
+                continue
+            idx = txt.index(">")
+            fname = txt[:idx].strip()
+            body = txt[idx + 1 :]
+            params: Dict[str, Any] = {}
+            for pm in self.tool_call_parameter_regex.findall(body):
+                ptxt = pm[0] if pm[0] else pm[1]
+                if ">" not in ptxt:
+                    continue
+                pidx = ptxt.index(">")
+                pname = ptxt[:pidx].strip()
+                pval = ptxt[pidx + 1 :].lstrip("\n").rstrip("\n")
+                params[pname] = _safe_val(pval)
+            raw = {"name": fname, "arguments": params}
+            try:
+                res.extend(self.parse_base_json(raw, tools))
+            except Exception:
+                logger.warning("invalid tool call for %s dropped", fname)
+        return res
+
+    def structure_info(self) -> _GetInfoFunc:
+        return lambda n: StructureInfo(
+            begin=f"{self.tool_call_start_token}\n<function={n}>",
+            end=f"</function>\n{self.tool_call_end_token}",
+            trigger=self.tool_call_start_token,
+        )
+
+    # TODO: fake ebnf for xml + outlines backend
+    def build_ebnf(self, tools: List[Tool]):
+        return EBNFComposer.build_ebnf(
+            tools,
+            individual_call_start_token=self.tool_call_start_token.replace("\n", "\\n"),
+            individual_call_end_token=self.tool_call_end_token.replace("\n", "\\n"),
+            tool_call_separator="\\n",
+            function_format="json",
+        )
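To make the parsing concrete: a minimal standalone sketch of how the two regexes above decompose the docstring's example block. This is plain Python for illustration only (it does not use the sglang detector classes or the Tool schema); str.partition stands in for the txt.index(">") split in _parse_block.

import re

# Same patterns as in Qwen3XMLDetector above.
function_regex = re.compile(r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL)
parameter_regex = re.compile(r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL)

block = """<tool_call>
<function=execute_bash>
<parameter=command>
pwd && ls
</parameter>
</function>
</tool_call>"""

for fn_match in function_regex.findall(block):
    txt = fn_match[0] or fn_match[1]
    name, _, body = txt.partition(">")          # function name, then parameter body
    args = {}
    for p_match in parameter_regex.findall(body):
        ptxt = p_match[0] or p_match[1]
        pname, _, pval = ptxt.partition(">")
        args[pname.strip()] = pval.strip("\n")  # the detector applies _safe_val here
    print({"name": name.strip(), "arguments": args})
    # -> {'name': 'execute_bash', 'arguments': {'command': 'pwd && ls'}}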
sglang/srt/layers/activation.py
CHANGED
@@ -110,6 +110,17 @@ class NewGELU(CustomOp):
         return self.forward_native(x)
 
 
+class ReLU2(nn.Module):
+    """
+    Applies the squared Rectified Linear Unit function.
+    y = max(0, x)^2
+    """
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = F.relu(x)
+        return x * x
+
+
 class QuickGELU(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         return x * torch.sigmoid(1.702 * x)
@@ -164,6 +175,8 @@ class ScaledActivation(nn.Module):
 _ACTIVATION_REGISTRY = {
     "gelu": nn.GELU(),
     "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
+    "gelu_new": NewGELU(),
+    "relu2": ReLU2(),
 }
 
 
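For reference, the new "relu2" registry entry is just ReLU followed by squaring. A quick standalone check with plain PyTorch (not using the module above):

import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
r = F.relu(x)       # clamp negatives to zero
print(r * r)        # tensor([0.0000, 0.0000, 0.0000, 0.2500, 4.0000])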
sglang/srt/layers/attention/flashattention_backend.py
CHANGED
@@ -1617,7 +1617,7 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.max_seq_len_k + self.page_size - 1
             ) // self.page_size
 
-
+            normal_decode_set_metadata(
                 metadata.cache_seqlens_int32,
                 metadata.cu_seqlens_k,
                 metadata.page_table,
@@ -1666,7 +1666,7 @@ class FlashAttentionBackend(AttentionBackend):
             max_seq_pages = (max_len + self.page_size - 1) // self.page_size
             metadata.max_seq_len_k = max_len
 
-
+            normal_decode_set_metadata(
                 metadata.cache_seqlens_int32,
                 metadata.cu_seqlens_k,
                 metadata.page_table,
@@ -2089,7 +2089,7 @@ class FlashAttentionMultiStepBackend:
 # @torch.compile(dynamic=True, backend=get_compiler_backend())
 # TODO: fuse these kernels
 # NOTE: torch.compile makes it slower in speculative decoding
-def
+def normal_decode_set_metadata(
     cache_seqlens_int32: torch.Tensor,
     cu_seqlens_k: torch.Tensor,
     page_table: torch.Tensor,
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -25,7 +25,9 @@ from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.layers.utils import is_sm100_supported
+from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 from sglang.srt.utils import is_flashinfer_available, next_power_of_2
@@ -485,12 +487,20 @@ class FlashInferAttnBackend(AttentionBackend):
                 v_scale=layer.v_scale,
             )
         else:
+            causal = True
+            if layer.attn_type == AttentionType.ENCODER_ONLY:
+                save_kv_cache = False
+                causal = False
+
             if self.forward_metadata.extend_no_prefix:
+                # NOTE: FlashInfer currently has limitations with head_dim = 32 or other dimensions
+                # The FlashInfer head_dim limitation itself is tracked here:
+                # https://github.com/flashinfer-ai/flashinfer/issues/1048
                 o = self.prefill_wrapper_ragged.forward(
                     q.view(-1, layer.tp_q_head_num, layer.head_dim),
                     k.view(-1, layer.tp_k_head_num, layer.head_dim),
                     v.view(-1, layer.tp_v_head_num, layer.head_dim),
-                    causal=
+                    causal=causal,
                     sm_scale=layer.scaling,
                     logits_soft_cap=logits_soft_cap,
                 )
@@ -589,6 +599,7 @@ class FlashInferIndicesUpdaterDecode:
         self.kv_indptr = attn_backend.kv_indptr
         self.kv_last_page_len = attn_backend.kv_last_page_len
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
 
         # Dispatch the update function
         if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
@@ -655,6 +666,10 @@ class FlashInferIndicesUpdaterDecode:
             paged_kernel_lens_sum_tmp = seq_lens_sum
             kv_start_idx_tmp = None
 
+        use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
+            self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
+        )
+
         self.call_begin_forward(
             decode_wrappers[wrapper_id],
             req_pool_indices,
@@ -663,6 +678,7 @@ class FlashInferIndicesUpdaterDecode:
             self.kv_indptr[wrapper_id],
             kv_start_idx_tmp,
             spec_info,
+            use_sliding_window_kv_pool=use_sliding_window_kv_pool,
         )
 
     def update_cross_attention(
@@ -704,6 +720,7 @@ class FlashInferIndicesUpdaterDecode:
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        use_sliding_window_kv_pool: bool = False,
     ):
         if spec_info is None:
             bs = len(req_pool_indices)
@@ -731,6 +748,14 @@ class FlashInferIndicesUpdaterDecode:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1
 
+        if use_sliding_window_kv_pool:
+            kv_last_index = kv_indptr[-1]
+            kv_indices[:kv_last_index] = (
+                self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+                    kv_indices[:kv_last_index]
+                )
+            )
+
         wrapper.begin_forward(
             kv_indptr,
             kv_indices,
@@ -765,6 +790,7 @@ class FlashInferIndicesUpdaterPrefill:
         self.kv_last_page_len = attn_backend.kv_last_page_len
         self.qo_indptr = attn_backend.qo_indptr
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
         self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
 
         # Dispatch the update function
@@ -848,6 +874,9 @@ class FlashInferIndicesUpdaterPrefill:
             paged_kernel_lens_sum = seq_lens_sum
 
         kv_start_idx = seq_lens - paged_kernel_lens
+        use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
+            self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
+        )
 
         self.call_begin_forward(
             self.prefill_wrapper_ragged,
@@ -862,6 +891,7 @@ class FlashInferIndicesUpdaterPrefill:
             self.qo_indptr[wrapper_id],
             use_ragged,
             spec_info,
+            use_sliding_window_kv_pool=use_sliding_window_kv_pool,
        )
 
     def update_cross_attention(
@@ -916,6 +946,7 @@ class FlashInferIndicesUpdaterPrefill:
         qo_indptr: torch.Tensor,
         use_ragged: bool,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        use_sliding_window_kv_pool: bool = False,
     ):
         bs = len(seq_lens)
         if spec_info is None:
@@ -964,6 +995,14 @@ class FlashInferIndicesUpdaterPrefill:
             q_data_type=self.q_data_type,
         )
 
+        if use_sliding_window_kv_pool:
+            kv_last_index = kv_indptr[-1]
+            kv_indices[:kv_last_index] = (
+                self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+                    kv_indices[:kv_last_index]
+                )
+            )
+
         # cached part
         wrapper_paged.begin_forward(
             qo_indptr,
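The sliding-window additions above remap kv_indices from the full KV pool's address space into the SWA pool's before begin_forward. A toy sketch of that translation step follows; the lookup table and the translate function are hypothetical stand-ins for what SWATokenToKVPoolAllocator.translate_loc_from_full_to_swa does internally, which this diff does not show.

import torch

# Hypothetical full-pool -> SWA-pool location map (the real allocator maintains this).
full_to_swa = torch.tensor([3, 0, 2, 1, 4, 5], dtype=torch.int64)

def translate_loc_from_full_to_swa(loc: torch.Tensor) -> torch.Tensor:
    return full_to_swa[loc]

kv_indptr = torch.tensor([0, 2, 5])          # two requests with 2 and 3 cached tokens
kv_indices = torch.tensor([0, 1, 2, 3, 4])   # locations in the full pool

# Mirrors the added hunks: rewrite in place up to the last used index.
kv_last_index = kv_indptr[-1]
kv_indices[:kv_last_index] = translate_loc_from_full_to_swa(kv_indices[:kv_last_index])
print(kv_indices)  # tensor([3, 0, 2, 1, 4])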
sglang/srt/layers/linear.py
CHANGED
@@ -1,12 +1,12 @@
 """Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""
 
+from __future__ import annotations
+
 import itertools
 import logging
-from
-from typing import Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 
 import torch
-import torch.nn.functional as F
 from torch.nn.parameter import Parameter, UninitializedParameter
 
 from sglang.srt.distributed import (
@@ -17,7 +17,6 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
-from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     BlockQuantScaleParameter,
@@ -27,17 +26,14 @@ from sglang.srt.layers.parameter import (
     RowvLLMParameter,
     _ColumnvLLMParameter,
 )
-from sglang.srt.layers.quantization.
-
-
-
-from sglang.srt.
-
-
-
-    set_weight_attrs,
-    use_intel_amx_backend,
-)
+from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
+from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.quantization.base_config import (
+        QuantizationConfig,
+        QuantizeMethodBase,
+    )
 
 logger = logging.getLogger(__name__)
 
@@ -57,9 +53,9 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "ModelOptFp8LinearMethod",
     "ModelOptFp4LinearMethod",
     "IPEXAWQLinearMethod",
+    "PetitNvFp4LinearMethod",
 ]
 
-_is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _is_npu = is_npu()
 
@@ -110,91 +106,6 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
     return param[shard_id], loaded_weight
 
 
-class LinearMethodBase(QuantizeMethodBase):
-    """Base class for different (maybe quantized) linear methods."""
-
-    @abstractmethod
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        input_size_per_partition: int,
-        output_partition_sizes: List[int],
-        input_size: int,
-        output_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        """Create weights for a linear layer.
-        The weights will be set as attributes of the layer.
-
-        Args:
-            layer: The layer that is using the LinearMethodBase factory.
-            input_size_per_partition: Size of the weight input dim on rank X.
-            output_partition_sizes: Sizes of the output dim of each logical
-                weight on rank X. E.g., output_partition_sizes for QKVLinear
-                is a list contains the width of Wq, Wk, Wv on rank X.
-            input_size: Size of the input dim of the weight across all ranks.
-            output_size: Size of the output dim of the weight across all ranks.
-            params_dtype: Datatype of the parameters.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        bias: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        """Apply the weights in layer to the input tensor.
-        Expects create_weights to have been called before on the layer."""
-        raise NotImplementedError
-
-
-class UnquantizedLinearMethod(LinearMethodBase):
-    """Linear method without quantization."""
-
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        input_size_per_partition: int,
-        output_partition_sizes: List[int],
-        input_size: int,
-        output_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        weight = Parameter(
-            torch.empty(
-                sum(output_partition_sizes),
-                input_size_per_partition,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
-        layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, extra_weight_attrs)
-
-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        if _is_cpu and _is_cpu_amx_available:
-            _amx_process_weight_after_loading(layer, ["weight"])
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        bias: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-
-        if use_intel_amx_backend(layer):
-            return torch.ops.sgl_kernel.weight_packed_linear(
-                x, layer.weight, bias, True  # is_vnni
-            )
-
-        return F.linear(x, layer.weight, bias)
-
-
 class LinearBase(torch.nn.Module):
     """Base linear layer.
 
@@ -310,7 +221,7 @@ class ReplicatedLinear(LinearBase):
         assert param.size() == loaded_weight.size()
         param.data.copy_(loaded_weight)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         bias = self.bias if not self.skip_bias_add else None
         assert self.quant_method is not None
         output = self.quant_method.apply(self, x, bias)
sglang/srt/layers/moe/ep_moe/kernels.py
CHANGED
@@ -236,7 +236,8 @@ def pre_reorder_triton_kernel(
 ):
     OutDtype = gateup_input_ptr.dtype.element_ty
 
-
+    src_idx_int32 = tl.program_id(0)
+    src_idx = src_idx_int32.to(tl.int64)
     src2dst_ptr = src2dst_ptr + src_idx * topk
     topk_ids_ptr = topk_ids_ptr + src_idx * topk
     src_ptr = input_ptr + src_idx * hidden_size
@@ -255,7 +256,8 @@ def pre_reorder_triton_kernel(
             else:
                 scale = 1.0
 
-
+            dst_idx_int32 = tl.load(src2dst_ptr + idx)
+            dst_idx = dst_idx_int32.to(tl.int64)
             dst_ptr = gateup_input_ptr + dst_idx * hidden_size
             for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
                 offset = start_offset + vec