sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +220 -378
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +143 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +681 -259
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +224 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +44 -18
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +94 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +208 -28
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +136 -52
  181. sglang/srt/speculative/build_eagle_tree.py +2 -8
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  183. sglang/srt/speculative/eagle_utils.py +92 -58
  184. sglang/srt/speculative/eagle_worker.py +186 -94
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,17 @@ sglang/srt/layers/quantization/int8_kernel.py
+ import functools
+ import json
+ import logging
+ import os
+ from typing import Any, Dict, List, Optional, Tuple
+ 
  import torch
  import triton
  import triton.language as tl
  
+ from sglang.srt.utils import get_device_name
+ 
+ logger = logging.getLogger(__name__)
+ 
  
  @triton.jit
  def _per_token_quant_int8(
@@ -52,3 +62,320 @@ def per_token_quant_int8(x):
      )
  
      return x_q, scales
+ 
+ 
+ @triton.jit
+ def _per_token_group_quant_int8(
+     # Pointers to inputs and output
+     y_ptr,
+     y_q_ptr,
+     y_s_ptr,
+     # Stride of input
+     y_stride,
+     # Columns of input
+     N,
+     # Avoid division by zero
+     eps,
+     # Information for int8
+     int8_min,
+     int8_max,
+     # Meta-parameters
+     BLOCK: tl.constexpr,
+ ):
+     """A Triton-accelerated function to perform per-token-group quantization on a
+     tensor.
+ 
+     This function converts the tensor values into int8 values.
+     """
+     # Map the program id to the row of X and Y it should compute.
+     g_id = tl.program_id(0)
+     y_ptr += g_id * y_stride
+     y_q_ptr += g_id * y_stride
+     y_s_ptr += g_id
+ 
+     cols = tl.arange(0, BLOCK)  # N <= BLOCK
+     mask = cols < N
+ 
+     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
+     # Quant
+     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+     y_s = _absmax / int8_max
+     y_q = tl.clamp(y / y_s, int8_min, int8_max).to(y_q_ptr.dtype.element_ty)
+ 
+     tl.store(y_q_ptr + cols, y_q, mask=mask)
+     tl.store(y_s_ptr, y_s)
+ 
+ 
+ def per_token_group_quant_int8(
+     x: torch.Tensor,
+     group_size: int,
+     eps: float = 1e-10,
+     dtype: torch.dtype = torch.int8,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """Function to perform per-token-group quantization on an input tensor `x`.
+ 
+     It converts the tensor values into signed int8 values and returns the
+     quantized tensor along with the scaling factor used for quantization.
+ 
+     Args:
+         x: The input tensor with ndim >= 2.
+         group_size: The group size used for quantization.
+         eps: The minimum value used to avoid division by zero.
+         dtype: The dtype of the output tensor. Note that only `torch.int8` is supported for now.
+ 
+     Returns:
+         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
+     """
+     assert (
+         x.shape[-1] % group_size == 0
+     ), "the last dimension of `x` must be divisible by `group_size`"
+     assert x.is_contiguous(), "`x` is not contiguous"
+ 
+     iinfo = torch.iinfo(dtype)
+     int8_max = iinfo.max
+     int8_min = iinfo.min
+ 
+     x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+     M = x.numel() // group_size
+     N = group_size
+     x_s = torch.empty(
+         x.shape[:-1] + (x.shape[-1] // group_size,),
+         device=x.device,
+         dtype=torch.float32,
+     )
+ 
+     BLOCK = triton.next_power_of_2(N)
+     # heuristics for number of warps
+     num_warps = min(max(BLOCK // 256, 1), 8)
+     num_stages = 1
+     _per_token_group_quant_int8[(M,)](
+         x,
+         x_q,
+         x_s,
+         group_size,
+         N,
+         eps,
+         int8_min=int8_min,
+         int8_max=int8_max,
+         BLOCK=BLOCK,
+         num_warps=num_warps,
+         num_stages=num_stages,
+     )
+ 
+     return x_q, x_s
+ 
+ 
+ @triton.jit
+ def _w8a8_block_int8_matmul(
+     # Pointers to inputs and output
+     A,
+     B,
+     C,
+     As,
+     Bs,
+     # Shape for matmul
+     M,
+     N,
+     K,
+     # Block size for block-wise quantization
+     group_n,
+     group_k,
+     # Stride for inputs and output
+     stride_am,
+     stride_ak,
+     stride_bk,
+     stride_bn,
+     stride_cm,
+     stride_cn,
+     stride_As_m,
+     stride_As_k,
+     stride_Bs_k,
+     stride_Bs_n,
+     # Meta-parameters
+     BLOCK_SIZE_M: tl.constexpr,
+     BLOCK_SIZE_N: tl.constexpr,
+     BLOCK_SIZE_K: tl.constexpr,
+     GROUP_SIZE_M: tl.constexpr,
+ ):
+     """Triton-accelerated function used to perform linear operations (dot
+     product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
+     tensor `C`.
+     """
+ 
+     pid = tl.program_id(axis=0)
+     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+     num_pid_in_group = GROUP_SIZE_M * num_pid_n
+     group_id = pid // num_pid_in_group
+     first_pid_m = group_id * GROUP_SIZE_M
+     group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+     pid_m = first_pid_m + (pid % group_size_m)
+     pid_n = (pid % num_pid_in_group) // group_size_m
+ 
+     offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+     offs_k = tl.arange(0, BLOCK_SIZE_K)
+     a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+     b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+ 
+     As_ptrs = As + offs_am * stride_As_m
+     offs_bsn = offs_bn // group_n
+     Bs_ptrs = Bs + offs_bsn * stride_Bs_n
+ 
+     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+ 
+         k_start = k * BLOCK_SIZE_K
+         offs_ks = k_start // group_k
+         a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
+         b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
+ 
+         accumulator += tl.dot(a, b).to(tl.float32) * a_s[:, None] * b_s[None, :]
+         a_ptrs += BLOCK_SIZE_K * stride_ak
+         b_ptrs += BLOCK_SIZE_K * stride_bk
+ 
+     if C.dtype.element_ty == tl.bfloat16:
+         c = accumulator.to(tl.bfloat16)
+     elif C.dtype.element_ty == tl.float16:
+         c = accumulator.to(tl.float16)
+     else:
+         c = accumulator.to(tl.float32)
+ 
+     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+     c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+     tl.store(c_ptrs, c, mask=c_mask)
+ 
+ 
+ @functools.lru_cache
+ def get_w8a8_block_int8_configs(
+     N: int, K: int, block_n: int, block_k: int
+ ) -> Optional[Dict[int, Any]]:
+     """
+     Return optimized configurations for the w8a8 block int8 kernel.
+ 
+     The return value will be a dictionary that maps an irregular grid of
+     batch sizes to configurations of the w8a8 block int8 kernel. To evaluate the
+     kernel on a given batch size bs, the closest batch size in the grid should
+     be picked and the associated configuration chosen to invoke the kernel.
+     """
+ 
+     # First look up if an optimized configuration is available in the configs
+     # directory
+     device_name = get_device_name().replace(" ", "_")
+     json_file_name = f"N={N},K={K},device_name={device_name},dtype=int8_w8a8,block_shape=[{block_n}, {block_k}].json"
+ 
+     config_file_path = os.path.join(
+         os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
+     )
+     if os.path.exists(config_file_path):
+         with open(config_file_path) as f:
+             logger.info(
+                 "Using configuration from %s for W8A8 Block INT8 kernel.",
+                 config_file_path,
+             )
+             # If a configuration has been found, return it
+             return {int(key): val for key, val in json.load(f).items()}
+ 
+     # If no optimized configuration is available, we will use the default
+     # configuration
+     logger.warning(
+         (
+             "Using default W8A8 Block INT8 kernel config. Performance might be sub-optimal! "
+             "Config file not found at %s"
+         ),
+         config_file_path,
+     )
+     return None
+ 
+ 
+ def w8a8_block_int8_matmul(
+     A: torch.Tensor,
+     B: torch.Tensor,
+     As: torch.Tensor,
+     Bs: torch.Tensor,
+     block_size: List[int],
+     output_dtype: torch.dtype = torch.float16,
+ ) -> torch.Tensor:
+     """This function performs matrix multiplication with block-wise quantization.
+ 
+     It takes two input tensors `A` and `B` with scales `As` and `Bs`.
+     The output is returned in the specified `output_dtype`.
+ 
+     Args:
+         A: The input tensor, e.g., activation.
+         B: The input tensor, e.g., weight.
+         As: The per-token-group quantization scale for `A`.
+         Bs: The per-block quantization scale for `B`.
+         block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128].
+         output_dtype: The dtype of the returned tensor.
+ 
+     Returns:
+         torch.Tensor: The result of matmul.
+     """
+     assert len(block_size) == 2
+     block_n, block_k = block_size[0], block_size[1]
+ 
+     assert A.shape[-1] == B.shape[-1]
+     assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
+     assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
+     M = A.numel() // A.shape[-1]
+ 
+     assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
+     N, K = B.shape
+     assert triton.cdiv(N, block_n) == Bs.shape[0]
+     assert triton.cdiv(K, block_k) == Bs.shape[1]
+ 
+     C_shape = A.shape[:-1] + (N,)
+     C = A.new_empty(C_shape, dtype=output_dtype)
+ 
+     configs = get_w8a8_block_int8_configs(N, K, block_size[0], block_size[1])
+     if configs:
+         # If an optimal configuration map has been found, look up the
+         # optimal config
+         config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+     else:
+         # Default config
+         # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
+         config = {
+             "BLOCK_SIZE_M": 64,
+             "BLOCK_SIZE_N": block_size[0],
+             "BLOCK_SIZE_K": block_size[1],
+             "GROUP_SIZE_M": 32,
+             "num_warps": 4,
+             "num_stages": 3,
+         }
+ 
+     def grid(META):
+         return (
+             triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+         )
+ 
+     _w8a8_block_int8_matmul[grid](
+         A,
+         B,
+         C,
+         As,
+         Bs,
+         M,
+         N,
+         K,
+         block_n,
+         block_k,
+         A.stride(-2),
+         A.stride(-1),
+         B.stride(1),
+         B.stride(0),
+         C.stride(-2),
+         C.stride(-1),
+         As.stride(-2),
+         As.stride(-1),
+         Bs.stride(1),
+         Bs.stride(0),
+         **config,
+     )
+ 
+     return C
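
For orientation, here is a minimal usage sketch of the two new entry points above (not part of the package diff); the shapes, the pre-quantized weight w_q, and its per-block scales w_s are illustrative assumptions, and the real wiring lives in apply_w8a8_block_int8_linear in sglang/srt/layers/quantization/int8_utils.py (next file):

import torch
from sglang.srt.layers.quantization.int8_kernel import (
    per_token_group_quant_int8,
    w8a8_block_int8_matmul,
)

x = torch.randn(4, 512, device="cuda", dtype=torch.float16)                   # activations
w_q = torch.randint(-128, 128, (1024, 512), device="cuda", dtype=torch.int8)  # weight, assumed already block-quantized
w_s = torch.rand(8, 4, device="cuda", dtype=torch.float32)                    # per-[128, 128]-block scales: ceil(1024/128) x ceil(512/128)

x_q, x_s = per_token_group_quant_int8(x, group_size=128)                      # int8 activations + per-token-group scales
y = w8a8_block_int8_matmul(x_q, w_q, x_s, w_s, [128, 128], output_dtype=torch.float16)  # shape (4, 1024)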
@@ -0,0 +1,73 @@ sglang/srt/layers/quantization/int8_utils.py
+ from typing import List, Optional, Tuple
+ 
+ import torch
+ 
+ from sglang.srt.layers.quantization.int8_kernel import (
+     per_token_group_quant_int8,
+     w8a8_block_int8_matmul,
+ )
+ 
+ 
+ def apply_w8a8_block_int8_linear(
+     input: torch.Tensor,
+     weight: torch.Tensor,
+     block_size: List[int],
+     weight_scale: torch.Tensor,
+     input_scale: Optional[torch.Tensor] = None,
+     bias: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+     assert input_scale is None
+     # View input as a 2D matrix for the quantization kernels
+     input_2d = input.view(-1, input.shape[-1])
+     output_shape = [*input.shape[:-1], weight.shape[0]]
+ 
+     q_input, x_scale = per_token_group_quant_int8(input_2d, block_size[1])
+     output = w8a8_block_int8_matmul(
+         q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
+     )
+ 
+     if bias is not None:
+         output = output + bias
+     return output.to(dtype=input.dtype).view(*output_shape)
+ 
+ 
+ def input_to_int8(
+     x: torch.Tensor, dtype: torch.dtype = torch.int8
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """This function quantizes input values to int8 values with tensor-wise quantization."""
+     iinfo = torch.iinfo(dtype)
+     min_val, max_val = x.aminmax()
+     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+     int8_min, int8_max = iinfo.min, iinfo.max
+     scale = int8_max / amax
+     x_scl_sat = (x * scale).clamp(min=int8_min, max=int8_max)
+     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
+ 
+ 
+ def block_dequant(
+     x_q_block: torch.Tensor,
+     x_s: torch.Tensor,
+     block_size: List[int],
+ ) -> torch.Tensor:
+     """This function performs block-wise dequantization.
+     The inputs are the block-wise quantized tensor `x_q_block`, the block-wise
+     quantization scale `x_s`, and the block size.
+     The output is the dequantized tensor.
+     """
+     block_n, block_k = block_size[0], block_size[1]
+     n, k = x_q_block.shape
+     n_tiles = (n + block_n - 1) // block_n
+     k_tiles = (k + block_k - 1) // block_k
+     assert n_tiles == x_s.shape[0]
+     assert k_tiles == x_s.shape[1]
+ 
+     x_dq_block = x_q_block.to(torch.float32)
+ 
+     for i in range(k_tiles):
+         for j in range(n_tiles):
+             x_dq_block[
+                 j * block_n : min((j + 1) * block_n, n),
+                 i * block_k : min((i + 1) * block_k, k),
+             ] *= x_s[j][i]
+ 
+     return x_dq_block
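
As another small sketch (again not part of the diff), block_dequant can rebuild a float copy of a block-quantized weight, e.g. for testing; the weight and scale values here are made up:

import torch
from sglang.srt.layers.quantization.int8_utils import block_dequant

w_q = torch.randint(-128, 128, (256, 512), dtype=torch.int8)   # block-quantized weight
w_s = torch.rand(2, 4, dtype=torch.float32)                    # one scale per [128, 128] block
w_fp32 = block_dequant(w_q, w_s, [128, 128])                   # float32, same shape as w_q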
@@ -5,12 +5,14 @@ from typing import Any, Dict, List, Optional
  
  import torch
  from torch.nn.parameter import Parameter
+ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
  from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
      apply_fp8_linear,
      cutlass_fp8_supported,
      requantize_with_max_scale,
  )
  
+ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
  from sglang.srt.layers.linear import LinearBase, LinearMethodBase
  from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
  from sglang.srt.layers.quantization.base_config import (
@@ -70,7 +72,13 @@ class ModelOptFp8Config(QuantizationConfig):
      def get_quant_method(
          self, layer: torch.nn.Module, prefix: str
      ) -> Optional["QuantizeMethodBase"]:
-         return ModelOptFp8LinearMethod(self) if isinstance(layer, LinearBase) else None
+ 
+         if isinstance(layer, LinearBase):
+             return ModelOptFp8LinearMethod(self)
+         if isinstance(layer, AttentionBackend):
+             return ModelOptFp8KVCacheMethod(self)
+ 
+         return None
  
      def get_scaled_act_names(self) -> List[str]:
          return []
@@ -171,3 +179,12 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
              bias=bias,
              cutlass_fp8_supported=self.cutlass_fp8_supported,
          )
+ 
+ 
+ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
+     """
+     Handles loading FP8 kv-cache scaling factors from modelopt quantized checkpoints.
+     """
+ 
+     def __init__(self, quant_config: ModelOptFp8Config):
+         super().__init__(quant_config)
@@ -34,6 +34,7 @@ class RadixAttention(nn.Module):
          v_head_dim: int = -1,
          sliding_window_size: int = -1,
          is_cross_attention: bool = False,
+         prefix: str = "",
      ):
          super().__init__()
          self.tp_q_head_num = num_heads
@@ -707,7 +707,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
          cos = freqs.cos() * self.mscale
          sin = freqs.sin() * self.mscale
          cache = torch.cat((cos, sin), dim=-1)
-         print("Cache shape", cache.shape)
          return cache
  
      def forward(
@@ -1,5 +1,5 @@ sglang/srt/layers/sampler.py
  import logging
- from typing import List
+ from typing import List, Optional
  
  import torch
  import torch.distributed as dist
@@ -29,7 +29,7 @@ SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
  class Sampler(nn.Module):
      def __init__(self):
          super().__init__()
-         self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"]
+         self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
          self.tp_sync_group = get_tensor_model_parallel_group().device_group
  
          if global_server_args_dict["enable_dp_attention"]:
@@ -41,14 +41,28 @@ class Sampler(nn.Module):
          sampling_info: SamplingBatchInfo,
          return_logprob: bool,
          top_logprobs_nums: List[int],
+         token_ids_logprobs: List[List[int]],
+         batch_next_token_ids: Optional[torch.Tensor] = None,
      ):
+         """Run a sampler & compute logprobs and update logits_output accordingly.
+ 
+         Args:
+             logits_output: The logits from the model forward
+             sampling_info: Metadata for sampling
+             return_logprob: If set, store the output logprob information to
+                 logits_output
+             top_logprobs_nums: Number of top logprobs per sequence in a batch
+             batch_next_token_ids: Next token IDs. If set, skip sampling and only
+                 compute the output logprobs. It is used for speculative decoding,
+                 which performs sampling in the draft workers.
+         """
          logits = logits_output.next_token_logits
  
          # Apply the custom logit processors if registered in the sampling info.
          if sampling_info.has_custom_logit_processor:
              self._apply_custom_logit_processor(logits, sampling_info)
  
-         if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
+         if self.use_nan_detection and torch.any(torch.isnan(logits)):
              logger.warning("Detected errors during sampling! NaN in the logits.")
              logits = torch.where(
                  torch.isnan(logits), torch.full_like(logits, -1e5), logits
@@ -58,13 +72,15 @@ class Sampler(nn.Module):
  
          if sampling_info.is_all_greedy:
              # Use torch.argmax if all requests use greedy sampling
-             batch_next_token_ids = torch.argmax(logits, -1)
+             if batch_next_token_ids is None:
+                 batch_next_token_ids = torch.argmax(logits, -1)
              if return_logprob:
                  logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
          else:
              # Post process logits
              logits.div_(sampling_info.temperatures)
-             probs = torch.softmax(logits, dim=-1)
+             logits[:] = torch.softmax(logits, dim=-1)
+             probs = logits
              del logits
  
              if global_server_args_dict["sampling_backend"] == "flashinfer":
@@ -78,38 +94,43 @@ class Sampler(nn.Module):
                      top_p_normalize_probs_torch(probs, sampling_info.top_ps)
                  ).clamp(min=torch.finfo(probs.dtype).min)
  
-                 max_top_k_round, batch_size = 32, probs.shape[0]
-                 uniform_samples = torch.rand(
-                     (max_top_k_round, batch_size), device=probs.device
-                 )
-                 if sampling_info.need_min_p_sampling:
-                     probs = top_k_renorm_prob(probs, sampling_info.top_ks)
-                     probs = top_p_renorm_prob(probs, sampling_info.top_ps)
-                     batch_next_token_ids = min_p_sampling_from_probs(
-                         probs, uniform_samples, sampling_info.min_ps
+                 if batch_next_token_ids is None:
+                     max_top_k_round, batch_size = 32, probs.shape[0]
+                     uniform_samples = torch.rand(
+                         (max_top_k_round, batch_size), device=probs.device
                      )
-                 else:
-                     batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                     if sampling_info.need_min_p_sampling:
+                         probs = top_k_renorm_prob(probs, sampling_info.top_ks)
+                         probs = top_p_renorm_prob(probs, sampling_info.top_ps)
+                         batch_next_token_ids = min_p_sampling_from_probs(
+                             probs, uniform_samples, sampling_info.min_ps
+                         )
+                     else:
+                         batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                             probs,
+                             uniform_samples,
+                             sampling_info.top_ks,
+                             sampling_info.top_ps,
+                             filter_apply_order="joint",
+                         )
+ 
+                     if self.use_nan_detection and not torch.all(success):
+                         logger.warning("Detected errors during sampling!")
+                         batch_next_token_ids = torch.zeros_like(
+                             batch_next_token_ids
+                         )
+ 
+             elif global_server_args_dict["sampling_backend"] == "pytorch":
+                 if batch_next_token_ids is None:
+                     # A slower fallback implementation with torch native operations.
+                     batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
                          probs,
-                         uniform_samples,
                          sampling_info.top_ks,
                          sampling_info.top_ps,
-                         filter_apply_order="joint",
+                         sampling_info.min_ps,
+                         sampling_info.need_min_p_sampling,
                      )
  
-                 if self.use_nan_detectioin and not torch.all(success):
-                     logger.warning("Detected errors during sampling!")
-                     batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
- 
-             elif global_server_args_dict["sampling_backend"] == "pytorch":
-                 # A slower fallback implementation with torch native operations.
-                 batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
-                     probs,
-                     sampling_info.top_ks,
-                     sampling_info.top_ps,
-                     sampling_info.min_ps,
-                     sampling_info.need_min_p_sampling,
-                 )
          if return_logprob:
              # clamp to avoid -inf
              logprobs = torch.log(
@@ -128,6 +149,12 @@ class Sampler(nn.Module):
                      logits_output.next_token_top_logprobs_idx,
                  ) = get_top_logprobs(logprobs, top_logprobs_nums)
  
+             if any(x is not None for x in token_ids_logprobs):
+                 (
+                     logits_output.next_token_token_ids_logprobs_val,
+                     logits_output.next_token_token_ids_logprobs_idx,
+                 ) = get_token_ids_logprobs(logprobs, token_ids_logprobs)
+ 
              logits_output.next_token_logprobs = logprobs[
                  torch.arange(len(batch_next_token_ids), device=sampling_info.device),
                  batch_next_token_ids,
@@ -223,6 +250,10 @@ def top_p_normalize_probs_torch(
  
  
  def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
+     assert len(top_logprobs_nums) == logprobs.shape[0], (
+         len(top_logprobs_nums),
+         logprobs.shape[0],
+     )
      max_k = max(top_logprobs_nums)
      ret = logprobs.topk(max_k, dim=1)
      values = ret.values.tolist()
@@ -234,3 +265,17 @@ def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
          output_top_logprobs_val.append(values[i][:k])
          output_top_logprobs_idx.append(indices[i][:k])
      return output_top_logprobs_val, output_top_logprobs_idx
+ 
+ 
+ def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List[int]]):
+     output_token_ids_logprobs_val = []
+     output_token_ids_logprobs_idx = []
+     for i, token_ids in enumerate(token_ids_logprobs):
+         if token_ids is not None:
+             output_token_ids_logprobs_val.append(logprobs[i, token_ids].tolist())
+             output_token_ids_logprobs_idx.append(token_ids)
+         else:
+             output_token_ids_logprobs_val.append([])
+             output_token_ids_logprobs_idx.append([])
+ 
+     return output_token_ids_logprobs_val, output_token_ids_logprobs_idx
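
A minimal sketch (not part of the diff) of how the new get_token_ids_logprobs helper behaves; the tensor shapes are illustrative:

import torch
from sglang.srt.layers.sampler import get_token_ids_logprobs

logprobs = torch.log_softmax(torch.randn(2, 8), dim=-1)       # 2 sequences, vocabulary of 8
vals, idx = get_token_ids_logprobs(logprobs, [[1, 3], None])
# vals[0] contains logprobs[0, 1] and logprobs[0, 3]; the second request asked for
# no token ids, so vals[1] == [] and idx[1] == []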