sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/configs/deepseekvl2.py +11 -2
  4. sglang/srt/configs/internvl.py +3 -0
  5. sglang/srt/configs/janus_pro.py +3 -0
  6. sglang/srt/configs/model_config.py +9 -7
  7. sglang/srt/configs/update_config.py +3 -1
  8. sglang/srt/conversation.py +1 -0
  9. sglang/srt/custom_op.py +5 -2
  10. sglang/srt/disaggregation/decode.py +9 -1
  11. sglang/srt/disaggregation/mooncake/conn.py +44 -56
  12. sglang/srt/distributed/parallel_state.py +33 -0
  13. sglang/srt/entrypoints/engine.py +30 -26
  14. sglang/srt/entrypoints/openai/serving_chat.py +21 -2
  15. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  16. sglang/srt/function_call/function_call_parser.py +2 -0
  17. sglang/srt/function_call/qwen3_detector.py +150 -0
  18. sglang/srt/hf_transformers_utils.py +0 -1
  19. sglang/srt/layers/activation.py +13 -0
  20. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  21. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  22. sglang/srt/layers/linear.py +13 -102
  23. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  24. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  25. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  26. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  34. sglang/srt/layers/moe/topk.py +187 -12
  35. sglang/srt/layers/quantization/__init__.py +20 -134
  36. sglang/srt/layers/quantization/awq.py +578 -11
  37. sglang/srt/layers/quantization/awq_triton.py +339 -0
  38. sglang/srt/layers/quantization/base_config.py +85 -10
  39. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
  42. sglang/srt/layers/quantization/fp8.py +273 -62
  43. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  44. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  45. sglang/srt/layers/quantization/gptq.py +501 -143
  46. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +26 -108
  48. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  49. sglang/srt/layers/quantization/petit.py +252 -0
  50. sglang/srt/layers/quantization/petit_utils.py +104 -0
  51. sglang/srt/layers/quantization/qoq.py +7 -6
  52. sglang/srt/layers/quantization/scalar_type.py +352 -0
  53. sglang/srt/layers/quantization/unquant.py +422 -0
  54. sglang/srt/layers/quantization/utils.py +343 -3
  55. sglang/srt/layers/quantization/w4afp8.py +8 -4
  56. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  57. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  58. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  59. sglang/srt/lora/lora.py +0 -4
  60. sglang/srt/lora/lora_manager.py +87 -53
  61. sglang/srt/lora/mem_pool.py +81 -33
  62. sglang/srt/lora/utils.py +12 -5
  63. sglang/srt/managers/cache_controller.py +241 -0
  64. sglang/srt/managers/io_struct.py +41 -29
  65. sglang/srt/managers/mm_utils.py +7 -8
  66. sglang/srt/managers/schedule_batch.py +150 -110
  67. sglang/srt/managers/schedule_policy.py +68 -27
  68. sglang/srt/managers/scheduler.py +243 -61
  69. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  70. sglang/srt/managers/tokenizer_manager.py +11 -3
  71. sglang/srt/managers/tp_worker.py +14 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  73. sglang/srt/mem_cache/allocator.py +7 -16
  74. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  75. sglang/srt/mem_cache/chunk_cache.py +5 -2
  76. sglang/srt/mem_cache/hicache_storage.py +152 -0
  77. sglang/srt/mem_cache/hiradix_cache.py +179 -4
  78. sglang/srt/mem_cache/memory_pool.py +16 -1
  79. sglang/srt/mem_cache/memory_pool_host.py +41 -2
  80. sglang/srt/mem_cache/radix_cache.py +26 -0
  81. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  82. sglang/srt/metrics/collector.py +9 -0
  83. sglang/srt/model_executor/cuda_graph_runner.py +5 -6
  84. sglang/srt/model_executor/forward_batch_info.py +14 -1
  85. sglang/srt/model_executor/model_runner.py +109 -22
  86. sglang/srt/model_loader/loader.py +7 -1
  87. sglang/srt/model_loader/utils.py +4 -4
  88. sglang/srt/models/clip.py +1 -1
  89. sglang/srt/models/deepseek.py +9 -6
  90. sglang/srt/models/deepseek_janus_pro.py +1 -1
  91. sglang/srt/models/deepseek_v2.py +191 -171
  92. sglang/srt/models/deepseek_vl2.py +5 -5
  93. sglang/srt/models/gemma.py +48 -0
  94. sglang/srt/models/gemma2.py +52 -0
  95. sglang/srt/models/gemma3_causal.py +63 -0
  96. sglang/srt/models/gemma3_mm.py +1 -1
  97. sglang/srt/models/gemma3n_mm.py +2 -4
  98. sglang/srt/models/granitemoe.py +385 -0
  99. sglang/srt/models/grok.py +9 -3
  100. sglang/srt/models/hunyuan.py +63 -16
  101. sglang/srt/models/internvl.py +1 -1
  102. sglang/srt/models/kimi_vl.py +1 -1
  103. sglang/srt/models/llama.py +41 -0
  104. sglang/srt/models/llama4.py +11 -11
  105. sglang/srt/models/llava.py +2 -2
  106. sglang/srt/models/llavavid.py +1 -1
  107. sglang/srt/models/minicpm.py +0 -2
  108. sglang/srt/models/minicpmo.py +3 -7
  109. sglang/srt/models/minicpmv.py +1 -1
  110. sglang/srt/models/mistral.py +1 -1
  111. sglang/srt/models/mixtral.py +9 -2
  112. sglang/srt/models/mllama.py +3 -5
  113. sglang/srt/models/mllama4.py +3 -3
  114. sglang/srt/models/olmoe.py +8 -5
  115. sglang/srt/models/persimmon.py +330 -0
  116. sglang/srt/models/phi.py +321 -0
  117. sglang/srt/models/phi4mm.py +44 -4
  118. sglang/srt/models/phi4mm_audio.py +1260 -0
  119. sglang/srt/models/phi4mm_utils.py +1917 -0
  120. sglang/srt/models/phimoe.py +9 -3
  121. sglang/srt/models/qwen.py +37 -0
  122. sglang/srt/models/qwen2.py +41 -0
  123. sglang/srt/models/qwen2_5_vl.py +4 -4
  124. sglang/srt/models/qwen2_audio.py +1 -1
  125. sglang/srt/models/qwen2_moe.py +53 -5
  126. sglang/srt/models/qwen2_vl.py +4 -4
  127. sglang/srt/models/qwen3.py +65 -1
  128. sglang/srt/models/qwen3_moe.py +56 -18
  129. sglang/srt/models/vila.py +1 -1
  130. sglang/srt/multimodal/processors/base_processor.py +91 -97
  131. sglang/srt/multimodal/processors/clip.py +21 -19
  132. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  133. sglang/srt/multimodal/processors/gemma3.py +13 -17
  134. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  135. sglang/srt/multimodal/processors/internvl.py +9 -10
  136. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  137. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  138. sglang/srt/multimodal/processors/llava.py +4 -2
  139. sglang/srt/multimodal/processors/minicpm.py +35 -44
  140. sglang/srt/multimodal/processors/mlama.py +21 -18
  141. sglang/srt/multimodal/processors/mllama4.py +4 -5
  142. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  143. sglang/srt/multimodal/processors/pixtral.py +14 -35
  144. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  145. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  146. sglang/srt/multimodal/processors/vila.py +14 -14
  147. sglang/srt/sampling/sampling_params.py +8 -1
  148. sglang/srt/server_args.py +393 -230
  149. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
  150. sglang/srt/two_batch_overlap.py +1 -0
  151. sglang/srt/utils.py +27 -1
  152. sglang/test/runners.py +14 -3
  153. sglang/test/test_block_fp8.py +8 -3
  154. sglang/test/test_block_fp8_ep.py +1 -1
  155. sglang/test/test_custom_ops.py +12 -7
  156. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  157. sglang/test/test_fp4_moe.py +1 -3
  158. sglang/test/test_marlin_moe.py +286 -0
  159. sglang/test/test_marlin_utils.py +171 -0
  160. sglang/test/test_utils.py +35 -0
  161. sglang/version.py +1 -1
  162. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
  163. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
  164. sglang/srt/layers/quantization/quant_utils.py +0 -166
  165. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  166. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
  167. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
  168. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/function_call/qwen3_detector.py (new file)
@@ -0,0 +1,150 @@
+import ast
+import html
+import json
+import logging
+import re
+from typing import Any, Dict, List, Tuple
+
+from sglang.srt.entrypoints.openai.protocol import Tool
+from sglang.srt.function_call.base_format_detector import BaseFormatDetector
+from sglang.srt.function_call.core_types import (
+    StreamingParseResult,
+    StructureInfo,
+    ToolCallItem,
+    _GetInfoFunc,
+)
+from sglang.srt.function_call.ebnf_composer import EBNFComposer
+
+logger = logging.getLogger(__name__)
+
+
+def _safe_val(raw: str) -> Any:
+    raw = html.unescape(raw.strip())
+    try:
+        return json.loads(raw)
+    except Exception:
+        try:
+            return ast.literal_eval(raw)
+        except Exception:
+            return raw
+
+
+class Qwen3XMLDetector(BaseFormatDetector):
+    """
+    Detector for Qwen 3 models.
+    Assumes function call format:
+    <tool_call>
+    <function=execute_bash>
+    <parameter=command>
+    pwd && ls
+    </parameter>
+    </function>
+    </tool_call>
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.tool_call_start_token: str = "<tool_call>"
+        self.tool_call_end_token: str = "</tool_call>"
+        self.tool_call_prefix: str = "<function="
+        self.tool_call_regex = re.compile(
+            r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL
+        )
+        self.tool_call_function_regex = re.compile(
+            r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL
+        )
+        self.tool_call_parameter_regex = re.compile(
+            r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL
+        )
+        self._buf: str = ""
+
+    def has_tool_call(self, text: str) -> bool:
+        return self.tool_call_start_token in text
+
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+        normal, calls = self._extract(text, tools)
+        return StreamingParseResult(normal_text=normal, calls=calls)
+
+    def parse_streaming_increment(
+        self, new_text: str, tools: List[Tool]
+    ) -> StreamingParseResult:
+        self._buf += new_text
+        normal = ""
+        calls: List[ToolCallItem] = []
+        while True:
+            if self.tool_call_start_token not in self._buf:
+                normal += self._buf
+                self._buf = ""
+                break
+            s = self._buf.find(self.tool_call_start_token)
+            if s > 0:
+                normal += self._buf[:s]
+                self._buf = self._buf[s:]
+            e = self._buf.find(self.tool_call_end_token)
+            if e == -1:
+                break
+            block = self._buf[: e + len(self.tool_call_end_token)]
+            self._buf = self._buf[e + len(self.tool_call_end_token) :]
+            calls.extend(self._parse_block(block, tools))
+        return StreamingParseResult(normal_text=normal, calls=calls)
+
+    def _extract(self, text: str, tools: List[Tool]) -> Tuple[str, List[ToolCallItem]]:
+        normal_parts: List[str] = []
+        calls: List[ToolCallItem] = []
+        cursor = 0
+        while True:
+            s = text.find(self.tool_call_start_token, cursor)
+            if s == -1:
+                normal_parts.append(text[cursor:])
+                break
+            normal_parts.append(text[cursor:s])
+            e = text.find(self.tool_call_end_token, s)
+            if e == -1:
+                normal_parts.append(text[s:])
+                break
+            block = text[s : e + len(self.tool_call_end_token)]
+            cursor = e + len(self.tool_call_end_token)
+            calls.extend(self._parse_block(block, tools))
+        return "".join(normal_parts), calls
+
+    def _parse_block(self, block: str, tools: List[Tool]) -> List[ToolCallItem]:
+        res: List[ToolCallItem] = []
+        for m in self.tool_call_function_regex.findall(block):
+            txt = m[0] if m[0] else m[1]
+            if ">" not in txt:
+                continue
+            idx = txt.index(">")
+            fname = txt[:idx].strip()
+            body = txt[idx + 1 :]
+            params: Dict[str, Any] = {}
+            for pm in self.tool_call_parameter_regex.findall(body):
+                ptxt = pm[0] if pm[0] else pm[1]
+                if ">" not in ptxt:
+                    continue
+                pidx = ptxt.index(">")
+                pname = ptxt[:pidx].strip()
+                pval = ptxt[pidx + 1 :].lstrip("\n").rstrip("\n")
+                params[pname] = _safe_val(pval)
+            raw = {"name": fname, "arguments": params}
+            try:
+                res.extend(self.parse_base_json(raw, tools))
+            except Exception:
+                logger.warning("invalid tool call for %s dropped", fname)
+        return res
+
+    def structure_info(self) -> _GetInfoFunc:
+        return lambda n: StructureInfo(
+            begin=f"{self.tool_call_start_token}\n<function={n}>",
+            end=f"</function>\n{self.tool_call_end_token}",
+            trigger=self.tool_call_start_token,
+        )
+
+    # TODO: fake ebnf for xml + outlines backend
+    def build_ebnf(self, tools: List[Tool]):
+        return EBNFComposer.build_ebnf(
+            tools,
+            individual_call_start_token=self.tool_call_start_token.replace("\n", "\\n"),
+            individual_call_end_token=self.tool_call_end_token.replace("\n", "\\n"),
+            tool_call_separator="\\n",
+            function_format="json",
+        )
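Aside: the value coercion in _safe_val above falls back from JSON to Python literals to the raw string, which is what lets a shell command like pwd && ls survive as plain text. A standalone sketch of that chain (same logic as the new module, trimmed to the helper; not an sglang test):

import ast
import html
import json

def safe_val(raw: str):
    # Mirrors _safe_val: try JSON, then a Python literal, else keep the raw text.
    raw = html.unescape(raw.strip())
    try:
        return json.loads(raw)
    except Exception:
        try:
            return ast.literal_eval(raw)
        except Exception:
            return raw

print(safe_val('{"depth": 2}'))  # {'depth': 2} via json.loads
print(safe_val("('a', 'b')"))    # ('a', 'b') via ast.literal_eval
print(safe_val("pwd && ls"))     # 'pwd && ls' falls through as a plain string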
sglang/srt/hf_transformers_utils.py
@@ -167,7 +167,6 @@ def get_generation_config(
             model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
         )
     except OSError as e:
-        logging.info("model doesn't have generation_config.json")
        return None
 
 
sglang/srt/layers/activation.py
@@ -110,6 +110,17 @@ class NewGELU(CustomOp):
         return self.forward_native(x)
 
 
+class ReLU2(nn.Module):
+    """
+    Applies the squared Rectified Linear Unit function.
+    y = max(0, x)^2
+    """
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = F.relu(x)
+        return x * x
+
+
 class QuickGELU(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         return x * torch.sigmoid(1.702 * x)
@@ -164,6 +175,8 @@ class ScaledActivation(nn.Module):
 _ACTIVATION_REGISTRY = {
     "gelu": nn.GELU(),
     "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
+    "gelu_new": NewGELU(),
+    "relu2": ReLU2(),
 }
 
 
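A quick sanity check for the new "relu2" registry entry; this assumes only that torch is installed and mirrors ReLU2.forward (relu, then square) as a reference expression:

import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, -0.5, 0.0, 1.5])
y = F.relu(x) ** 2  # reference for ReLU2: max(0, x)^2
print(y)  # tensor([0.0000, 0.0000, 0.0000, 2.2500])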
sglang/srt/layers/attention/flashattention_backend.py
@@ -1617,7 +1617,7 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.max_seq_len_k + self.page_size - 1
             ) // self.page_size
 
-            normal_decode_set_medadata(
+            normal_decode_set_metadata(
                 metadata.cache_seqlens_int32,
                 metadata.cu_seqlens_k,
                 metadata.page_table,
@@ -1666,7 +1666,7 @@ class FlashAttentionBackend(AttentionBackend):
             max_seq_pages = (max_len + self.page_size - 1) // self.page_size
             metadata.max_seq_len_k = max_len
 
-            normal_decode_set_medadata(
+            normal_decode_set_metadata(
                 metadata.cache_seqlens_int32,
                 metadata.cu_seqlens_k,
                 metadata.page_table,
@@ -2089,7 +2089,7 @@ class FlashAttentionMultiStepBackend:
 # @torch.compile(dynamic=True, backend=get_compiler_backend())
 # TODO: fuse these kernels
 # NOTE: torch.compile makes it slower in speculative decoding
-def normal_decode_set_medadata(
+def normal_decode_set_metadata(
     cache_seqlens_int32: torch.Tensor,
     cu_seqlens_k: torch.Tensor,
     page_table: torch.Tensor,
sglang/srt/layers/attention/flashinfer_backend.py
@@ -25,7 +25,9 @@ from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.layers.utils import is_sm100_supported
+from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 from sglang.srt.utils import is_flashinfer_available, next_power_of_2
@@ -485,12 +487,20 @@ class FlashInferAttnBackend(AttentionBackend):
                     v_scale=layer.v_scale,
                 )
         else:
+            causal = True
+            if layer.attn_type == AttentionType.ENCODER_ONLY:
+                save_kv_cache = False
+                causal = False
+
             if self.forward_metadata.extend_no_prefix:
+                # NOTE: FlashInfer currently has limitations with head_dim = 32 or other dimensions
+                # The FlashInfer head_dim limitation itself is tracked here:
+                # https://github.com/flashinfer-ai/flashinfer/issues/1048
                 o = self.prefill_wrapper_ragged.forward(
                     q.view(-1, layer.tp_q_head_num, layer.head_dim),
                     k.view(-1, layer.tp_k_head_num, layer.head_dim),
                     v.view(-1, layer.tp_v_head_num, layer.head_dim),
-                    causal=True,
+                    causal=causal,
                     sm_scale=layer.scaling,
                     logits_soft_cap=logits_soft_cap,
                 )
@@ -589,6 +599,7 @@ class FlashInferIndicesUpdaterDecode:
         self.kv_indptr = attn_backend.kv_indptr
         self.kv_last_page_len = attn_backend.kv_last_page_len
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
 
         # Dispatch the update function
         if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
@@ -655,6 +666,10 @@
                 paged_kernel_lens_sum_tmp = seq_lens_sum
                 kv_start_idx_tmp = None
 
+            use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
+                self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
+            )
+
            self.call_begin_forward(
                 decode_wrappers[wrapper_id],
                 req_pool_indices,
@@ -663,6 +678,7 @@
                 self.kv_indptr[wrapper_id],
                 kv_start_idx_tmp,
                 spec_info,
+                use_sliding_window_kv_pool=use_sliding_window_kv_pool,
             )
 
     def update_cross_attention(
@@ -704,6 +720,7 @@
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        use_sliding_window_kv_pool: bool = False,
     ):
         if spec_info is None:
             bs = len(req_pool_indices)
@@ -731,6 +748,14 @@
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1
 
+        if use_sliding_window_kv_pool:
+            kv_last_index = kv_indptr[-1]
+            kv_indices[:kv_last_index] = (
+                self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+                    kv_indices[:kv_last_index]
+                )
+            )
+
         wrapper.begin_forward(
             kv_indptr,
             kv_indices,
@@ -765,6 +790,7 @@ class FlashInferIndicesUpdaterPrefill:
         self.kv_last_page_len = attn_backend.kv_last_page_len
         self.qo_indptr = attn_backend.qo_indptr
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
         self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
 
         # Dispatch the update function
@@ -848,6 +874,9 @@
                 paged_kernel_lens_sum = seq_lens_sum
 
            kv_start_idx = seq_lens - paged_kernel_lens
+            use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
+                self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
+            )
 
            self.call_begin_forward(
                 self.prefill_wrapper_ragged,
@@ -862,6 +891,7 @@
                 self.qo_indptr[wrapper_id],
                 use_ragged,
                 spec_info,
+                use_sliding_window_kv_pool=use_sliding_window_kv_pool,
             )
 
     def update_cross_attention(
@@ -916,6 +946,7 @@
         qo_indptr: torch.Tensor,
         use_ragged: bool,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        use_sliding_window_kv_pool: bool = False,
     ):
         bs = len(seq_lens)
         if spec_info is None:
@@ -964,6 +995,14 @@
             q_data_type=self.q_data_type,
         )
 
+        if use_sliding_window_kv_pool:
+            kv_last_index = kv_indptr[-1]
+            kv_indices[:kv_last_index] = (
+                self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+                    kv_indices[:kv_last_index]
+                )
+            )
+
         # cached part
         wrapper_paged.begin_forward(
             qo_indptr,
sglang/srt/layers/linear.py
@@ -1,12 +1,12 @@
 """Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""
 
+from __future__ import annotations
+
 import itertools
 import logging
-from abc import abstractmethod
-from typing import Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 
 import torch
-import torch.nn.functional as F
 from torch.nn.parameter import Parameter, UninitializedParameter
 
 from sglang.srt.distributed import (
@@ -17,7 +17,6 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
-from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     BlockQuantScaleParameter,
@@ -27,17 +26,14 @@ from sglang.srt.layers.parameter import (
     RowvLLMParameter,
     _ColumnvLLMParameter,
 )
-from sglang.srt.layers.quantization.base_config import (
-    QuantizationConfig,
-    QuantizeMethodBase,
-)
-from sglang.srt.utils import (
-    cpu_has_amx_support,
-    is_cpu,
-    is_npu,
-    set_weight_attrs,
-    use_intel_amx_backend,
-)
+from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
+from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.quantization.base_config import (
+        QuantizationConfig,
+        QuantizeMethodBase,
+    )
 
 logger = logging.getLogger(__name__)
 
@@ -57,9 +53,9 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "ModelOptFp8LinearMethod",
     "ModelOptFp4LinearMethod",
     "IPEXAWQLinearMethod",
+    "PetitNvFp4LinearMethod",
 ]
 
-_is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _is_npu = is_npu()
 
@@ -110,91 +106,6 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
     return param[shard_id], loaded_weight
 
 
-class LinearMethodBase(QuantizeMethodBase):
-    """Base class for different (maybe quantized) linear methods."""
-
-    @abstractmethod
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        input_size_per_partition: int,
-        output_partition_sizes: List[int],
-        input_size: int,
-        output_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        """Create weights for a linear layer.
-        The weights will be set as attributes of the layer.
-
-        Args:
-            layer: The layer that is using the LinearMethodBase factory.
-            input_size_per_partition: Size of the weight input dim on rank X.
-            output_partition_sizes: Sizes of the output dim of each logical
-                weight on rank X. E.g., output_partition_sizes for QKVLinear
-                is a list contains the width of Wq, Wk, Wv on rank X.
-            input_size: Size of the input dim of the weight across all ranks.
-            output_size: Size of the output dim of the weight across all ranks.
-            params_dtype: Datatype of the parameters.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        bias: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        """Apply the weights in layer to the input tensor.
-        Expects create_weights to have been called before on the layer."""
-        raise NotImplementedError
-
-
-class UnquantizedLinearMethod(LinearMethodBase):
-    """Linear method without quantization."""
-
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        input_size_per_partition: int,
-        output_partition_sizes: List[int],
-        input_size: int,
-        output_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        weight = Parameter(
-            torch.empty(
-                sum(output_partition_sizes),
-                input_size_per_partition,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
-        layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, extra_weight_attrs)
-
-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        if _is_cpu and _is_cpu_amx_available:
-            _amx_process_weight_after_loading(layer, ["weight"])
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        bias: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-
-        if use_intel_amx_backend(layer):
-            return torch.ops.sgl_kernel.weight_packed_linear(
-                x, layer.weight, bias, True  # is_vnni
-            )
-
-        return F.linear(x, layer.weight, bias)
-
-
 class LinearBase(torch.nn.Module):
     """Base linear layer.
 
@@ -310,7 +221,7 @@ class ReplicatedLinear(LinearBase):
         assert param.size() == loaded_weight.size()
         param.data.copy_(loaded_weight)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         bias = self.bias if not self.skip_bias_add else None
         assert self.quant_method is not None
         output = self.quant_method.apply(self, x, bias)
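Note that LinearMethodBase and UnquantizedLinearMethod are removed here but not from the package: the import hunk above pulls UnquantizedLinearMethod from the new sglang/srt/layers/quantization/unquant.py (file 53 in the list, +422 lines), and base_config is now imported under TYPE_CHECKING only, presumably to break a runtime import cycle. Downstream code only needs the new path:

# Old location, removed in this release:
#   from sglang.srt.layers.linear import UnquantizedLinearMethod
# New location, as imported by linear.py itself above:
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod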
sglang/srt/layers/moe/ep_moe/kernels.py
@@ -236,7 +236,8 @@ def pre_reorder_triton_kernel(
 ):
     OutDtype = gateup_input_ptr.dtype.element_ty
 
-    src_idx = tl.program_id(0)
+    src_idx_int32 = tl.program_id(0)
+    src_idx = src_idx_int32.to(tl.int64)
     src2dst_ptr = src2dst_ptr + src_idx * topk
     topk_ids_ptr = topk_ids_ptr + src_idx * topk
     src_ptr = input_ptr + src_idx * hidden_size
@@ -255,7 +256,8 @@ def pre_reorder_triton_kernel(
             else:
                 scale = 1.0
 
-            dst_idx = tl.load(src2dst_ptr + idx)
+            dst_idx_int32 = tl.load(src2dst_ptr + idx)
+            dst_idx = dst_idx_int32.to(tl.int64)
             dst_ptr = gateup_input_ptr + dst_idx * hidden_size
             for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
                 offset = start_offset + vec
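The int64 casts above guard the pointer arithmetic that follows them: tl.program_id(0) and the loaded src2dst entries are int32, so an offset like src_idx * hidden_size wraps once it passes 2**31 - 1 on large batches. A rough bound, with 7168 chosen purely for illustration:

# How many rows fit before an int32 element offset overflows.
hidden_size = 7168               # illustrative hidden dimension
max_int32 = 2**31 - 1
print(max_int32 // hidden_size)  # 299593 rows; beyond this, int32 offsets wrap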