sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/activation.py

@@ -33,6 +33,7 @@ from sglang.srt.utils import (
     cpu_has_amx_support,
     is_cpu,
     is_cuda,
+    is_hip,
     is_npu,
     set_weight_attrs,
 )
@@ -42,9 +43,12 @@ _is_cuda = is_cuda()
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_hip = is_hip()
 
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
+elif _is_hip:
+    from sgl_kernel import gelu_and_mul, gelu_quick, gelu_tanh_and_mul, silu_and_mul
 
 if is_npu():
     import torch_npu
@@ -110,14 +114,29 @@ class NewGELU(CustomOp):
         return self.forward_native(x)
 
 
+class ReLU2(nn.Module):
+    """
+    Applies the squared Rectified Linear Unit function.
+    y = max(0, x)^2
+    """
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = F.relu(x)
+        return x * x
+
+
 class QuickGELU(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         return x * torch.sigmoid(1.702 * x)
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        # TODO(zhyncs): Implement the CUDA kernel for QuickGELU in sgl-kernel
         return self.forward_native(x)
 
+    def forward_hip(self, x: torch.Tensor) -> torch.Tensor:
+        out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
+        gelu_quick(x, out)
+        return out
+
 
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.
@@ -164,6 +183,8 @@ class ScaledActivation(nn.Module):
 _ACTIVATION_REGISTRY = {
     "gelu": nn.GELU(),
     "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
+    "gelu_new": NewGELU(),
+    "relu2": ReLU2(),
 }
 
 
@@ -209,8 +230,8 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
     return nn.Identity()
 
 
-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
     logger.info(
-        "sgl-kernel is not available on Non-NV platforms or Non-AMX CPUs. Fallback to other kernel libraries."
+        "sgl-kernel is not available on Non-NV, Non-AMD platforms or Non-AMX CPUs. Fallback to other kernel libraries."
     )
     from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
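
The two new registry entries let a model resolve "gelu_new" and "relu2" from its hidden_act config string. A minimal standalone sketch of the squared-ReLU behavior added above (plain PyTorch, no sglang imports):

    import torch
    import torch.nn.functional as F

    def relu2(x: torch.Tensor) -> torch.Tensor:
        # Same math as the ReLU2 module registered above: y = max(0, x) ** 2
        r = F.relu(x)
        return r * r

    x = torch.tensor([-2.0, 0.5, 3.0])
    print(relu2(x))  # tensor([0.0000, 0.2500, 9.0000])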
sglang/srt/layers/attention/base_attn_backend.py

@@ -65,7 +65,9 @@ class AttentionBackend(ABC):
         **kwargs,
     ):
         """Run forward on an attention layer."""
-        if forward_batch.forward_mode.is_decode():
+        if forward_batch.forward_mode.is_idle():
+            return q.new_empty(q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+        elif forward_batch.forward_mode.is_decode():
             return self.forward_decode(
                 q,
                 k,
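
The new idle branch returns an empty tensor shaped like a real attention output instead of calling a kernel, so downstream layers see consistent shapes when a rank has no tokens to process. A sketch of the shape logic with illustrative head counts (the numbers below are made up, not taken from any model config):

    import torch

    num_tokens = 0                 # an idle batch carries no tokens
    tp_q_head_num = 8              # assumed per-rank query head count
    head_dim = v_head_dim = 128    # assumed head dimensions

    q = torch.empty(num_tokens, tp_q_head_num * head_dim)
    out = q.new_empty(q.shape[0], tp_q_head_num * v_head_dim)  # same early return as the is_idle() branch
    print(out.shape)  # torch.Size([0, 1024])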
sglang/srt/layers/attention/flashattention_backend.py

@@ -1617,7 +1617,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.max_seq_len_k + self.page_size - 1
         ) // self.page_size
 
-        normal_decode_set_medadata(
+        normal_decode_set_metadata(
             metadata.cache_seqlens_int32,
             metadata.cu_seqlens_k,
             metadata.page_table,
@@ -1666,7 +1666,7 @@ class FlashAttentionBackend(AttentionBackend):
         max_seq_pages = (max_len + self.page_size - 1) // self.page_size
         metadata.max_seq_len_k = max_len
 
-        normal_decode_set_medadata(
+        normal_decode_set_metadata(
             metadata.cache_seqlens_int32,
             metadata.cu_seqlens_k,
             metadata.page_table,
@@ -2089,7 +2089,7 @@ class FlashAttentionMultiStepBackend:
 # @torch.compile(dynamic=True, backend=get_compiler_backend())
 # TODO: fuse these kernels
 # NOTE: torch.compile makes it slower in speculative decoding
-def normal_decode_set_medadata(
+def normal_decode_set_metadata(
     cache_seqlens_int32: torch.Tensor,
     cu_seqlens_k: torch.Tensor,
     page_table: torch.Tensor,
sglang/srt/layers/attention/flashinfer_backend.py

@@ -25,7 +25,9 @@ from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.layers.utils import is_sm100_supported
+from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 from sglang.srt.utils import is_flashinfer_available, next_power_of_2
@@ -485,12 +487,20 @@ class FlashInferAttnBackend(AttentionBackend):
                 v_scale=layer.v_scale,
             )
         else:
+            causal = True
+            if layer.attn_type == AttentionType.ENCODER_ONLY:
+                save_kv_cache = False
+                causal = False
+
             if self.forward_metadata.extend_no_prefix:
+                # NOTE: FlashInfer currently has limitations with head_dim = 32 or other dimensions
+                # The FlashInfer head_dim limitation itself is tracked here:
+                # https://github.com/flashinfer-ai/flashinfer/issues/1048
                 o = self.prefill_wrapper_ragged.forward(
                     q.view(-1, layer.tp_q_head_num, layer.head_dim),
                     k.view(-1, layer.tp_k_head_num, layer.head_dim),
                     v.view(-1, layer.tp_v_head_num, layer.head_dim),
-                    causal=True,
+                    causal=causal,
                     sm_scale=layer.scaling,
                     logits_soft_cap=logits_soft_cap,
                 )
@@ -589,6 +599,7 @@ class FlashInferIndicesUpdaterDecode:
         self.kv_indptr = attn_backend.kv_indptr
         self.kv_last_page_len = attn_backend.kv_last_page_len
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
 
         # Dispatch the update function
         if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
@@ -655,6 +666,10 @@
                 paged_kernel_lens_sum_tmp = seq_lens_sum
                 kv_start_idx_tmp = None
 
+            use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
+                self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
+            )
+
             self.call_begin_forward(
                 decode_wrappers[wrapper_id],
                 req_pool_indices,
@@ -663,6 +678,7 @@
                 self.kv_indptr[wrapper_id],
                 kv_start_idx_tmp,
                 spec_info,
+                use_sliding_window_kv_pool=use_sliding_window_kv_pool,
             )
 
     def update_cross_attention(
@@ -704,6 +720,7 @@
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        use_sliding_window_kv_pool: bool = False,
     ):
         if spec_info is None:
             bs = len(req_pool_indices)
@@ -731,6 +748,14 @@
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1
 
+        if use_sliding_window_kv_pool:
+            kv_last_index = kv_indptr[-1]
+            kv_indices[:kv_last_index] = (
+                self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+                    kv_indices[:kv_last_index]
+                )
+            )
+
         wrapper.begin_forward(
             kv_indptr,
             kv_indices,
@@ -765,6 +790,7 @@ class FlashInferIndicesUpdaterPrefill:
         self.kv_last_page_len = attn_backend.kv_last_page_len
         self.qo_indptr = attn_backend.qo_indptr
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
         self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
 
         # Dispatch the update function
@@ -848,6 +874,9 @@ class FlashInferIndicesUpdaterPrefill:
                 paged_kernel_lens_sum = seq_lens_sum
 
             kv_start_idx = seq_lens - paged_kernel_lens
+            use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
+                self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
+            )
 
             self.call_begin_forward(
                 self.prefill_wrapper_ragged,
@@ -862,6 +891,7 @@
                 self.qo_indptr[wrapper_id],
                 use_ragged,
                 spec_info,
+                use_sliding_window_kv_pool=use_sliding_window_kv_pool,
             )
 
     def update_cross_attention(
@@ -916,6 +946,7 @@
         qo_indptr: torch.Tensor,
         use_ragged: bool,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        use_sliding_window_kv_pool: bool = False,
     ):
         bs = len(seq_lens)
         if spec_info is None:
@@ -964,6 +995,14 @@
                 q_data_type=self.q_data_type,
             )
 
+        if use_sliding_window_kv_pool:
+            kv_last_index = kv_indptr[-1]
+            kv_indices[:kv_last_index] = (
+                self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+                    kv_indices[:kv_last_index]
+                )
+            )
+
         # cached part
         wrapper_paged.begin_forward(
             qo_indptr,
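
When the KV pool is backed by a sliding-window (SWA) allocator, both updaters now remap the gathered KV indices from full-pool slots to SWA-pool slots before handing them to begin_forward. A toy sketch of that remapping step; translate_loc_from_full_to_swa below is a stand-in lookup table, not the real allocator method:

    import torch

    # Hypothetical full-pool -> SWA-pool slot mapping, for illustration only.
    full_to_swa = torch.tensor([0, 3, 1, 7, 2, 5, 4, 6])

    def translate_loc_from_full_to_swa(loc: torch.Tensor) -> torch.Tensor:
        return full_to_swa[loc]

    kv_indptr = torch.tensor([0, 2, 5])          # two requests with 2 and 3 cached tokens
    kv_indices = torch.tensor([1, 4, 0, 3, 6])   # slots in the full KV pool

    # Same pattern as the use_sliding_window_kv_pool branch above.
    kv_last_index = kv_indptr[-1]
    kv_indices[:kv_last_index] = translate_loc_from_full_to_swa(kv_indices[:kv_last_index])
    print(kv_indices)  # tensor([3, 2, 0, 7, 4])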
sglang/srt/layers/communicator.py

@@ -24,8 +24,8 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.dp_attention import (
-    attn_tp_all_gather,
-    attn_tp_reduce_scatter,
+    attn_tp_all_gather_into_tensor,
+    attn_tp_reduce_scatter_tensor,
     dp_gather_partial,
     dp_scatter,
     get_attention_dp_size,
@@ -309,8 +309,8 @@ class CommunicateSimpleFn:
             forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
             hidden_states,
         )
-        attn_tp_all_gather(
-            list(hidden_states.tensor_split(context.attn_tp_size)),
+        attn_tp_all_gather_into_tensor(
+            hidden_states,
             local_hidden_states,
         )
         return hidden_states
@@ -400,9 +400,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
             ].clone(),
             residual,
         )
-        attn_tp_all_gather(
-            list(residual.tensor_split(context.attn_tp_size)), local_residual
-        )
+        attn_tp_all_gather_into_tensor(residual, local_residual)
         if context.attn_dp_size != 1:
             if context.attn_tp_rank == 0:
                 hidden_states += residual
@@ -442,9 +440,11 @@
         *,
         residual_input_mode,
     ):
-        tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
-        hidden_states = tensor_list[context.attn_tp_rank]
-        attn_tp_reduce_scatter(hidden_states, tensor_list)
+        input_hidden_states = hidden_states
+        hidden_states = hidden_states.tensor_split(context.attn_tp_size)[
+            context.attn_tp_rank
+        ]
+        attn_tp_reduce_scatter_tensor(hidden_states, input_hidden_states)
         if residual_input_mode == ScatterMode.TP_ATTN_FULL:
             residual = residual.tensor_split(context.attn_tp_size)[context.attn_tp_rank]
         if hidden_states.shape[0] != 0:
@@ -547,8 +547,8 @@ class CommunicateSummableTensorPairFn:
             forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
             hidden_states,
         )
-        attn_tp_all_gather(
-            list(hidden_states.tensor_split(context.attn_tp_size)),
+        attn_tp_all_gather_into_tensor(
+            hidden_states,
             local_hidden_states,
        )
         return hidden_states, residual
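
The communicator now calls the tensor-based collectives instead of building Python lists of views over the output buffer. Both forms gather the same data when the buffer splits evenly across ranks; a hedged sketch of the difference using plain torch.distributed (assumes a process group has already been initialized, e.g. with init_process_group):

    import torch
    import torch.distributed as dist

    def all_gather_list_style(output: torch.Tensor, local: torch.Tensor, world_size: int):
        # Old pattern: gather into a list of views that alias the flat output buffer.
        dist.all_gather(list(output.tensor_split(world_size)), local)

    def all_gather_tensor_style(output: torch.Tensor, local: torch.Tensor):
        # New pattern: one collective that writes directly into the flat buffer.
        dist.all_gather_into_tensor(output, local)

The tensor variants avoid the per-rank view bookkeeping on the Python side; attn_tp_all_gather_into_tensor and attn_tp_reduce_scatter_tensor in dp_attention.py are thin wrappers over them.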
sglang/srt/layers/dp_attention.py

@@ -3,7 +3,8 @@ from __future__ import annotations
 import functools
 import logging
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, List
+from enum import IntEnum, auto
+from typing import TYPE_CHECKING, List, Tuple
 
 import torch
 import triton
@@ -30,6 +31,34 @@ _LOCAL_ATTN_DP_SIZE = None
 _LOCAL_ATTN_DP_RANK = None
 
 
+class DPPaddingMode(IntEnum):
+
+    # Padding tokens to max length and then gather tokens using `all_gather_into_tensor`
+    MAX_LEN = auto()
+    # Padding tokens to sum length and then gather tokens using `all_reduce`
+    SUM_LEN = auto()
+
+    def is_max_len(self):
+        return self == DPPaddingMode.MAX_LEN
+
+    def is_sum_len(self):
+        return self == DPPaddingMode.SUM_LEN
+
+    @classmethod
+    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DPPaddingMode:
+        # we choose the mode that minimizes the communication cost
+        max_len = max(global_num_tokens)
+        sum_len = sum(global_num_tokens)
+        if sum_len * 2 > max_len * get_attention_dp_size():
+            return cls.MAX_LEN
+        else:
+            return cls.SUM_LEN
+
+    @classmethod
+    def get_default_mode_in_cuda_graph(cls) -> DPPaddingMode:
+        return cls.MAX_LEN
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
@@ -162,7 +191,7 @@ def disable_dp_size():
         _ATTN_DP_SIZE = old_dp_size
 
 
-def get_dp_local_info(forward_batch: ForwardBatch):
+def get_dp_local_info(forward_batch: ForwardBatch) -> Tuple[torch.Tensor, torch.Tensor]:
     # `get_dp_local_info` is only called in global DP gather and scatter. We use global DP rank here.
     dp_rank = get_attention_dp_rank()
 
@@ -221,7 +250,7 @@ def memcpy_triton(dst, src, dim, offset, sz, offset_src):
     memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)
 
 
-def _dp_gather(
+def _dp_gather_via_all_reduce(
     global_tokens: torch.Tensor,
     local_tokens: torch.Tensor,
     forward_batch: ForwardBatch,
@@ -238,13 +267,6 @@ _dp_gather(
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between global_tokens and local_tokens not allowed"
 
-    # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
-    # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
-    # actual size of the accepted tokens.
-    if forward_batch.forward_mode.is_draft_extend():
-        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
-        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
-
     memcpy_triton(
         global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
     )
@@ -263,6 +285,38 @@
         global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)
 
 
+def _dp_gather_via_all_gather(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+    is_partial: bool,
+):
+    if not is_partial:
+        if get_attention_tp_rank() != 0:
+            local_tokens.fill_(0)
+    scattered_local_tokens = local_tokens.tensor_split(get_attention_tp_size())[
+        get_attention_tp_rank()
+    ]
+    get_attention_tp_group().reduce_scatter_tensor(scattered_local_tokens, local_tokens)
+    get_tp_group().all_gather_into_tensor(global_tokens, scattered_local_tokens)
+
+
+def _dp_gather(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+    is_partial: bool,
+):
+    if forward_batch.dp_padding_mode.is_max_len():
+        _dp_gather_via_all_gather(
+            global_tokens, local_tokens, forward_batch, is_partial
+        )
+    else:
+        _dp_gather_via_all_reduce(
+            global_tokens, local_tokens, forward_batch, is_partial
+        )
+
+
 def dp_gather_partial(
     global_tokens: torch.Tensor,
     local_tokens: torch.Tensor,
@@ -296,24 +350,18 @@ dp_scatter(
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between local_tokens and global_tokens not allowed"
 
-    # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
-    # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
-    # actual size of the accepted tokens.
-    if forward_batch.forward_mode.is_draft_extend():
-        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
-        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
-
     memcpy_triton(
         local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
    )
 
 
-def attn_tp_reduce_scatter(
-    output: torch.Tensor,
-    input_list: List[torch.Tensor],
-):
-    return get_attention_tp_group().reduce_scatter(output, input_list)
+def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
+    return get_attention_tp_group().reduce_scatter_tensor(output, input)
+
+
+def attn_tp_all_gather_into_tensor(output: torch.Tensor, input: torch.Tensor):
+    return get_attention_tp_group().all_gather_into_tensor(output, input)
 
 
-def attn_tp_all_gather(output_list: List[torch.Tensor], input: torch.Tensor):
-    return get_attention_tp_group().all_gather(input, output_tensor_list=output_list)
+def attn_tp_all_gather(output_list: List[torch.Tensor], input: torch.Tensor):
+    return get_attention_tp_group().all_gather(input, output_tensor_list=output_list)
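
DPPaddingMode picks whichever gather strategy is expected to move less data: pad every DP rank to the max length and all_gather, or pad to the summed length and all_reduce (the factor of two in the rule reflects that an all_reduce moves roughly twice the data of an all_gather over the same buffer). A standalone re-statement of the decision rule, with dp_size passed explicitly instead of read from the module's global state:

    from enum import IntEnum, auto
    from typing import List

    class DPPaddingMode(IntEnum):
        MAX_LEN = auto()  # pad to max length, gather with all_gather_into_tensor
        SUM_LEN = auto()  # pad to summed length, gather with all_reduce

    def get_dp_padding_mode(global_num_tokens: List[int], dp_size: int) -> DPPaddingMode:
        max_len, sum_len = max(global_num_tokens), sum(global_num_tokens)
        return DPPaddingMode.MAX_LEN if sum_len * 2 > max_len * dp_size else DPPaddingMode.SUM_LEN

    # Balanced ranks: max_len * dp_size is close to sum_len, so the all_gather path wins.
    print(get_dp_padding_mode([1024, 1024, 1024, 1024], dp_size=4))  # DPPaddingMode.MAX_LEN
    # Skewed ranks: padding everyone to 4096 would be wasteful, so the all_reduce path wins.
    print(get_dp_padding_mode([4096, 8, 8, 8], dp_size=4))           # DPPaddingMode.SUM_LEN

CUDA graph capture always uses MAX_LEN (get_default_mode_in_cuda_graph), presumably because replayed graphs need fixed buffer shapes.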
sglang/srt/layers/linear.py

@@ -1,12 +1,12 @@
 """Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""
 
+from __future__ import annotations
+
 import itertools
 import logging
-from abc import abstractmethod
-from typing import Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 
 import torch
-import torch.nn.functional as F
 from torch.nn.parameter import Parameter, UninitializedParameter
 
 from sglang.srt.distributed import (
@@ -17,7 +17,6 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
-from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     BlockQuantScaleParameter,
@@ -27,17 +26,14 @@ from sglang.srt.layers.parameter import (
     RowvLLMParameter,
     _ColumnvLLMParameter,
 )
-from sglang.srt.layers.quantization.base_config import (
-    QuantizationConfig,
-    QuantizeMethodBase,
-)
-from sglang.srt.utils import (
-    cpu_has_amx_support,
-    is_cpu,
-    is_npu,
-    set_weight_attrs,
-    use_intel_amx_backend,
-)
+from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
+from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.quantization.base_config import (
+        QuantizationConfig,
+        QuantizeMethodBase,
+    )
 
 logger = logging.getLogger(__name__)
 
@@ -57,9 +53,9 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "ModelOptFp8LinearMethod",
     "ModelOptFp4LinearMethod",
     "IPEXAWQLinearMethod",
+    "PetitNvFp4LinearMethod",
 ]
 
-_is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _is_npu = is_npu()
 
@@ -110,91 +106,6 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
     return param[shard_id], loaded_weight
 
 
-class LinearMethodBase(QuantizeMethodBase):
-    """Base class for different (maybe quantized) linear methods."""
-
-    @abstractmethod
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        input_size_per_partition: int,
-        output_partition_sizes: List[int],
-        input_size: int,
-        output_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        """Create weights for a linear layer.
-        The weights will be set as attributes of the layer.
-
-        Args:
-            layer: The layer that is using the LinearMethodBase factory.
-            input_size_per_partition: Size of the weight input dim on rank X.
-            output_partition_sizes: Sizes of the output dim of each logical
-                weight on rank X. E.g., output_partition_sizes for QKVLinear
-                is a list contains the width of Wq, Wk, Wv on rank X.
-            input_size: Size of the input dim of the weight across all ranks.
-            output_size: Size of the output dim of the weight across all ranks.
-            params_dtype: Datatype of the parameters.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        bias: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        """Apply the weights in layer to the input tensor.
-        Expects create_weights to have been called before on the layer."""
-        raise NotImplementedError
-
-
-class UnquantizedLinearMethod(LinearMethodBase):
-    """Linear method without quantization."""
-
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        input_size_per_partition: int,
-        output_partition_sizes: List[int],
-        input_size: int,
-        output_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        weight = Parameter(
-            torch.empty(
-                sum(output_partition_sizes),
-                input_size_per_partition,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
-        layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, extra_weight_attrs)
-
-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        if _is_cpu and _is_cpu_amx_available:
-            _amx_process_weight_after_loading(layer, ["weight"])
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        bias: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-
-        if use_intel_amx_backend(layer):
-            return torch.ops.sgl_kernel.weight_packed_linear(
-                x, layer.weight, bias, True  # is_vnni
-            )
-
-        return F.linear(x, layer.weight, bias)
-
-
 class LinearBase(torch.nn.Module):
     """Base linear layer.
 
@@ -310,7 +221,7 @@ class ReplicatedLinear(LinearBase):
         assert param.size() == loaded_weight.size()
         param.data.copy_(loaded_weight)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         bias = self.bias if not self.skip_bias_add else None
         assert self.quant_method is not None
         output = self.quant_method.apply(self, x, bias)
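
The updated annotation makes explicit that ReplicatedLinear.forward returns a pair: the layer output plus an optional bias that is handed back to the caller when skip_bias_add is set, so the bias can be fused into a later op. A simplified sketch of that calling convention, using a plain F.linear weight rather than the real quant_method path:

    from typing import Optional, Tuple

    import torch
    import torch.nn.functional as F

    class TinyReplicatedLinear(torch.nn.Module):
        # Hypothetical stand-in that mirrors only the (output, output_bias) return convention.
        def __init__(self, in_features: int, out_features: int, skip_bias_add: bool = False):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
            self.bias = torch.nn.Parameter(torch.zeros(out_features))
            self.skip_bias_add = skip_bias_add

        def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
            bias = self.bias if not self.skip_bias_add else None
            output = F.linear(x, self.weight, bias)
            # Second element lets the caller apply the bias elsewhere when skip_bias_add=True.
            output_bias = self.bias if self.skip_bias_add else None
            return output, output_bias

    out, out_bias = TinyReplicatedLinear(16, 32)(torch.randn(4, 16))
    print(out.shape, out_bias)  # torch.Size([4, 32]) None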