sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in that public registry.
Files changed (166)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +23 -2
  3. sglang/bench_serving.py +6 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/configs/model_config.py +37 -5
  12. sglang/srt/constrained/base_grammar_backend.py +26 -5
  13. sglang/srt/constrained/llguidance_backend.py +1 -0
  14. sglang/srt/constrained/outlines_backend.py +1 -0
  15. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  16. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  17. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  18. sglang/srt/constrained/xgrammar_backend.py +27 -4
  19. sglang/srt/custom_op.py +0 -62
  20. sglang/srt/disaggregation/base/__init__.py +8 -0
  21. sglang/srt/disaggregation/base/conn.py +113 -0
  22. sglang/srt/disaggregation/decode.py +80 -11
  23. sglang/srt/disaggregation/mini_lb.py +58 -123
  24. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  25. sglang/srt/disaggregation/mooncake/conn.py +585 -0
  26. sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
  27. sglang/srt/disaggregation/prefill.py +82 -22
  28. sglang/srt/disaggregation/utils.py +46 -0
  29. sglang/srt/entrypoints/EngineBase.py +53 -0
  30. sglang/srt/entrypoints/engine.py +36 -8
  31. sglang/srt/entrypoints/http_server.py +37 -8
  32. sglang/srt/entrypoints/http_server_engine.py +142 -0
  33. sglang/srt/entrypoints/verl_engine.py +42 -13
  34. sglang/srt/hf_transformers_utils.py +4 -0
  35. sglang/srt/layers/activation.py +6 -8
  36. sglang/srt/layers/attention/flashattention_backend.py +430 -257
  37. sglang/srt/layers/attention/flashinfer_backend.py +18 -9
  38. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  39. sglang/srt/layers/attention/triton_backend.py +6 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  41. sglang/srt/layers/attention/vision.py +1 -1
  42. sglang/srt/layers/dp_attention.py +2 -4
  43. sglang/srt/layers/elementwise.py +15 -2
  44. sglang/srt/layers/layernorm.py +1 -1
  45. sglang/srt/layers/linear.py +18 -3
  46. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  48. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
  56. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  57. sglang/srt/layers/moe/router.py +7 -1
  58. sglang/srt/layers/moe/topk.py +63 -45
  59. sglang/srt/layers/parameter.py +0 -2
  60. sglang/srt/layers/quantization/__init__.py +13 -5
  61. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  62. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
  64. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  65. sglang/srt/layers/quantization/fp8.py +131 -136
  66. sglang/srt/layers/quantization/fp8_kernel.py +328 -46
  67. sglang/srt/layers/quantization/fp8_utils.py +206 -253
  68. sglang/srt/layers/quantization/kv_cache.py +43 -52
  69. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  70. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  71. sglang/srt/layers/quantization/utils.py +5 -11
  72. sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
  73. sglang/srt/layers/quantization/w8a8_int8.py +8 -7
  74. sglang/srt/layers/radix_attention.py +28 -1
  75. sglang/srt/layers/rotary_embedding.py +15 -3
  76. sglang/srt/layers/sampler.py +5 -10
  77. sglang/srt/lora/backend/base_backend.py +18 -2
  78. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  79. sglang/srt/lora/backend/triton_backend.py +1 -1
  80. sglang/srt/lora/layers.py +1 -1
  81. sglang/srt/lora/lora.py +1 -1
  82. sglang/srt/lora/lora_manager.py +1 -1
  83. sglang/srt/managers/detokenizer_manager.py +0 -1
  84. sglang/srt/managers/io_struct.py +255 -97
  85. sglang/srt/managers/mm_utils.py +7 -5
  86. sglang/srt/managers/multimodal_processor.py +0 -2
  87. sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
  88. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  89. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  90. sglang/srt/managers/schedule_batch.py +64 -25
  91. sglang/srt/managers/scheduler.py +80 -82
  92. sglang/srt/managers/tokenizer_manager.py +18 -3
  93. sglang/srt/managers/tp_worker.py +1 -0
  94. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  95. sglang/srt/mem_cache/memory_pool.py +21 -3
  96. sglang/srt/metrics/collector.py +9 -0
  97. sglang/srt/model_executor/cuda_graph_runner.py +9 -6
  98. sglang/srt/model_executor/forward_batch_info.py +234 -15
  99. sglang/srt/model_executor/model_runner.py +67 -35
  100. sglang/srt/model_loader/loader.py +31 -4
  101. sglang/srt/model_loader/weight_utils.py +4 -2
  102. sglang/srt/models/baichuan.py +2 -0
  103. sglang/srt/models/bert.py +398 -0
  104. sglang/srt/models/chatglm.py +1 -0
  105. sglang/srt/models/commandr.py +1 -0
  106. sglang/srt/models/dbrx.py +1 -0
  107. sglang/srt/models/deepseek.py +2 -1
  108. sglang/srt/models/deepseek_nextn.py +74 -70
  109. sglang/srt/models/deepseek_v2.py +494 -366
  110. sglang/srt/models/exaone.py +1 -0
  111. sglang/srt/models/gemma.py +1 -0
  112. sglang/srt/models/gemma2.py +1 -0
  113. sglang/srt/models/gemma3_causal.py +1 -0
  114. sglang/srt/models/gpt2.py +1 -0
  115. sglang/srt/models/gpt_bigcode.py +1 -0
  116. sglang/srt/models/granite.py +1 -0
  117. sglang/srt/models/grok.py +1 -0
  118. sglang/srt/models/internlm2.py +1 -0
  119. sglang/srt/models/llama.py +6 -5
  120. sglang/srt/models/llama4.py +101 -34
  121. sglang/srt/models/minicpm.py +1 -0
  122. sglang/srt/models/minicpm3.py +30 -200
  123. sglang/srt/models/mixtral.py +1 -0
  124. sglang/srt/models/mixtral_quant.py +1 -0
  125. sglang/srt/models/mllama.py +51 -8
  126. sglang/srt/models/mllama4.py +102 -29
  127. sglang/srt/models/olmo.py +1 -0
  128. sglang/srt/models/olmo2.py +1 -0
  129. sglang/srt/models/olmoe.py +1 -0
  130. sglang/srt/models/phi3_small.py +1 -0
  131. sglang/srt/models/qwen.py +1 -0
  132. sglang/srt/models/qwen2.py +5 -1
  133. sglang/srt/models/qwen2_5_vl.py +35 -70
  134. sglang/srt/models/qwen2_moe.py +15 -13
  135. sglang/srt/models/qwen2_vl.py +27 -25
  136. sglang/srt/models/qwen3.py +335 -0
  137. sglang/srt/models/qwen3_moe.py +423 -0
  138. sglang/srt/models/stablelm.py +1 -0
  139. sglang/srt/models/xverse.py +1 -0
  140. sglang/srt/models/xverse_moe.py +1 -0
  141. sglang/srt/openai_api/adapter.py +4 -1
  142. sglang/srt/patch_torch.py +11 -0
  143. sglang/srt/reasoning_parser.py +0 -1
  144. sglang/srt/sampling/sampling_batch_info.py +2 -3
  145. sglang/srt/server_args.py +55 -19
  146. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  147. sglang/srt/speculative/eagle_utils.py +1 -11
  148. sglang/srt/speculative/eagle_worker.py +10 -9
  149. sglang/srt/utils.py +136 -10
  150. sglang/test/attention/test_flashattn_backend.py +259 -221
  151. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  152. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  153. sglang/test/runners.py +5 -1
  154. sglang/test/test_block_fp8.py +224 -0
  155. sglang/test/test_custom_ops.py +1 -1
  156. sglang/test/test_utils.py +19 -8
  157. sglang/version.py +1 -1
  158. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
  159. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
  160. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  161. sglang/lang/__init__.py +0 -0
  162. sglang/srt/disaggregation/conn.py +0 -81
  163. sglang/srt/lora/backend/__init__.py +0 -25
  164. sglang/srt/server.py +0 -18
  165. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  166. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_utils.py

@@ -3,9 +3,20 @@ from typing import List, Optional, Tuple
 
 import torch
 
+from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
+
+try:
+    from vllm import _custom_ops as ops
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    VLLM_AVAILABLE = False
+
 from sglang.srt.layers.quantization.fp8_kernel import (
     _enable_jit_deepgemm,
     per_token_group_quant_fp8,
+    scaled_fp8_quant,
+    sglang_per_token_quant_fp8,
     static_quant_fp8,
     w8a8_block_fp8_matmul,
 )
@@ -17,30 +28,20 @@ from sglang.srt.utils import (
     is_hip,
 )
 
-try:
-    import vllm
-    from vllm import _custom_ops as ops
-
-    VLLM_AVAILABLE = True
-except ImportError:
-    VLLM_AVAILABLE = False
-
-use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
-
 _is_hip = is_hip()
+_is_cuda = is_cuda()
+
 if _is_hip and get_bool_env_var("CK_MOE"):
     from aiter import gemm_a8w8_blockscale
 
-_is_cuda = is_cuda()
 if _is_cuda:
     from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm
 
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-    from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8
+use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
 
 # Input scaling factors are no longer optional in _scaled_mm starting
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
-TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
+TORCH_DEVICE_IDENTITY = None
 
 _TORCH_VERSION = torch.__version__.split("+")[0]
 try:
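The hunk above replaces the eagerly created TORCH_DEVICE_IDENTITY dummy tensor with a None placeholder that is created lazily on the weight's device (see _apply_fallback_scaled_mm further down). A minimal standalone sketch of that lazy-initialization pattern, with hypothetical names and not taken from the package:

import torch

# torch._scaled_mm requires explicit scale tensors from PyTorch 2.5 onward,
# so a ones(1) dummy is passed when no real scale should be applied.
_DEVICE_IDENTITY = None  # created on first use, on the caller's device


def _get_device_identity(device: torch.device) -> torch.Tensor:
    global _DEVICE_IDENTITY
    if _DEVICE_IDENTITY is None or _DEVICE_IDENTITY.device != device:
        _DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32, device=device)
    return _DEVICE_IDENTITY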
@@ -143,7 +144,7 @@ def apply_w8a8_block_fp8_linear(
         gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
     else:
         if _enable_jit_deepgemm:
-            q_input, x_scale = per_token_group_quant_fp8(
+            q_input, x_scale = sglang_per_token_group_quant_fp8(
                 input_2d,
                 block_size[1],
                 column_major_scales=True,
@@ -168,12 +169,13 @@ def input_to_float8(
     """This function quantizes input values to float8 values with tensor-wise quantization."""
     finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
-    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+    amax = torch.maximum(min_val.abs(), max_val.abs()).float().clamp(min=1e-12)
     fp8_max = finfo.max
     if _is_hip:
+        dtype = torch.float8_e4m3fnuz
         fp8_max = 224.0
     scale = fp8_max / amax
-    x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
+    x_scl_sat = (x.float() * scale).clamp(min=-fp8_max, max=fp8_max)
     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
 
 
@@ -212,10 +214,64 @@ def block_quant_to_tensor_quant(
         for j in range(n_tiles):
             x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
 
-    x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
+    x_q_tensor, scale = (
+        scaled_fp8_quant(x_dq_block)
+        if _is_cuda
+        else input_to_float8(x_dq_block, dtype=x_q_block.dtype)
+    )
     return x_q_tensor, scale
 
 
+def channel_quant_to_tensor_quant(
+    x_q_channel: torch.Tensor,
+    x_s: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    x_dq_channel = x_q_channel.to(torch.float32) * x_s
+    x_q_tensor, scale = (
+        scaled_fp8_quant(x_dq_channel)
+        if _is_cuda
+        else input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
+    )
+    return x_q_tensor, scale
+
+
+def _process_scaled_mm_output(output, input_2d_shape, output_shape):
+    if type(output) is tuple and len(output) == 2:
+        output = output[0]
+    return torch.narrow(output, 0, 0, input_2d_shape[0]).view(*output_shape)
+
+
+def _apply_fallback_scaled_mm(
+    qinput,
+    weight,
+    x_scale,
+    weight_scale,
+    input_2d_shape,
+    output_shape,
+    bias,
+    input_dtype,
+):
+    global TORCH_DEVICE_IDENTITY
+    if TORCH_DEVICE_IDENTITY is None:
+        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32, device=weight.device)
+
+    output = torch._scaled_mm(
+        qinput,
+        weight,
+        scale_a=TORCH_DEVICE_IDENTITY,
+        scale_b=TORCH_DEVICE_IDENTITY,
+        out_dtype=torch.float32,
+    )
+
+    output = _process_scaled_mm_output(output, input_2d_shape, output_shape)
+    x_scale = torch.narrow(x_scale, 0, 0, input_2d_shape[0])
+
+    output = output * x_scale * weight_scale.t()
+    if bias is not None:
+        output = output + bias
+    return output.to(dtype=input_dtype)
+
+
 def apply_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
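For context on what these helpers compute: block_quant_to_tensor_quant and the new channel_quant_to_tensor_quant both dequantize to float32 and then requantize with a single tensor-wide scale (via scaled_fp8_quant on CUDA, input_to_float8 elsewhere). A small self-contained round-trip sketch of that tensor-wise scheme, assuming a PyTorch build with float8_e4m3fn support; the names are illustrative only:

import torch


def quantize_per_tensor_fp8(x: torch.Tensor):
    # One scale for the whole tensor, derived from its absolute maximum,
    # mirroring what input_to_float8 does.
    finfo = torch.finfo(torch.float8_e4m3fn)
    amax = x.abs().amax().float().clamp(min=1e-12)
    scale = finfo.max / amax
    x_q = (x.float() * scale).clamp(-finfo.max, finfo.max).to(torch.float8_e4m3fn)
    return x_q, scale.reciprocal()  # quantized tensor + dequantization scale


x = torch.randn(16, 32)
x_q, dq_scale = quantize_per_tensor_fp8(x)
x_approx = x_q.float() * dq_scale  # dequantize
print((x - x_approx).abs().max())  # small quantization error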
@@ -223,206 +279,33 @@ def apply_fp8_linear(
     input_scale: Optional[torch.Tensor] = None,
     input_scale_ub: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
-    cutlass_fp8_supported: bool = True,
+    cutlass_fp8_supported: bool = cutlass_fp8_supported(),
     use_per_token_if_dynamic: bool = False,
+    pad_output: Optional[bool] = None,
+    compressed_tensor_quant: bool = False,
 ) -> torch.Tensor:
+    # Note: we pad the input because torch._scaled_mm is more performant
+    # for matrices with batch dimension > 16.
+    # This could change in the future.
+    # We also don't pad when using torch.compile,
+    # as it breaks with dynamic shapes.
+    if pad_output is None:
+        pad_output = not get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
+    output_padding = 17 if pad_output else None
+
     # View input as 2D matrix for fp8 methods
     input_2d = input.view(-1, input.shape[-1])
     output_shape = [*input.shape[:-1], weight.shape[1]]
 
-    # cutlass w8a8 fp8 sgl-kernel only supports per-token scale
-    if input_scale is not None:
-        assert input_scale.numel() == 1
-        # broadcast per-tensor scale to per-token scale when supporting cutlass
-        qinput, x_scale = static_quant_fp8(
-            input_2d, input_scale, repeat_scale=cutlass_fp8_supported
-        )
-    else:
-        # default use per-token quantization if dynamic
-        if _is_cuda:
-            qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
-        else:
-            qinput, x_scale = per_token_group_quant_fp8(
-                input_2d, group_size=input_2d.shape[1]
-            )
-
-    if cutlass_fp8_supported:
-        try:
-            if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
-                # Fall back to vllm cutlass w8a8 fp8 kernel
-                output = ops.cutlass_scaled_mm(
-                    qinput,
-                    weight,
-                    out_dtype=input.dtype,
-                    scale_a=x_scale,
-                    scale_b=weight_scale,
-                    bias=bias,
-                )
-            else:
-                assert (
-                    weight_scale.numel() == weight.shape[1]
-                ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
-                output = fp8_scaled_mm(
-                    qinput,
-                    weight,
-                    x_scale,
-                    weight_scale,
-                    out_dtype=input.dtype,
-                    bias=bias,
-                )
-            return output.view(*output_shape)
-        except (ImportError, NameError, AttributeError):
-            pass
-
-    # torch.scaled_mm supports per tensor weights + activations only
-    # so fallback to naive if per channel or per token
-    else:
-        per_tensor_weights = weight_scale.numel() == 1
-        per_tensor_activations = x_scale.numel() == 1
-
-        if per_tensor_weights and per_tensor_activations:
-            # Fused GEMM_DQ
-            output = torch._scaled_mm(
-                qinput,
-                weight,
-                out_dtype=input.dtype,
-                scale_a=x_scale,
-                scale_b=weight_scale,
-                bias=bias,
-            )
-            # A fix for discrepancy in scaled_mm which returns tuple
-            # for torch < 2.5 and a single value in torch >= 2.5
-            if type(output) is tuple and len(output) == 2:
-                output = output[0]
-
-            return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
-
-        else:
-            # Fallback for channelwise case, where we use unfused DQ
-            # due to limitations with scaled_mm
-
-            # Symmetric quantized GEMM by definition computes the following:
-            # C = (s_x * X) (s_w * W) + bias
-            # This is equivalent to dequantizing the weights and activations
-            # before applying a GEMM.
-            #
-            # In order to compute quantized operands, a quantized kernel
-            # will rewrite the above like so:
-            # C = s_w * s_x * (X * W) + bias
-            #
-            # For the scaled_mm fallback case, we break this down, since it
-            # does not support s_w being a vector.
-
-            # Making sure the dummy tensor is on the same device as the weight
-            global TORCH_DEVICE_IDENTITY
-            if TORCH_DEVICE_IDENTITY.device != weight.device:
-                TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
-
-            # GEMM
-            # This computes C = (X * W).
-            # Output in fp32 to allow subsequent ops to happen in-place
-            output = torch._scaled_mm(
-                qinput,
-                weight,
-                scale_a=TORCH_DEVICE_IDENTITY,
-                scale_b=TORCH_DEVICE_IDENTITY,
-                out_dtype=torch.float32,
-            )
-            # A fix for discrepancy in scaled_mm which returns tuple
-            # for torch < 2.5 and a single value in torch >= 2.5
-            if type(output) is tuple and len(output) == 2:
-                output = output[0]
-            # Unpad (undo num_token_padding)
-            output = torch.narrow(output, 0, 0, input_2d.shape[0])
-            x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
-
-            # DQ
-            # C = sw * sx * (X * W) + bias
-            output = output * x_scale * weight_scale.t()
-            if bias is not None:
-                output = output + bias
-            return output.to(dtype=input.dtype).view(*output_shape)
-
-
-def maybe_create_device_identity():
-    # Allocate dummy ones tensor for torch._scaled_mm
-    global TORCH_DEVICE_IDENTITY
-    if TORCH_DEVICE_IDENTITY is None:
-        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
-
-
-# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
-# TODO(luka): follow similar pattern for marlin and block-fp8-linear
-# https://github.com/vllm-project/vllm/issues/14397
-class Fp8LinearOp:
-    """
-    This class executes a FP8 linear layer using cutlass if supported and
-    torch.scaled_mm otherwise.
-    It needs to be a class instead of a method so that config can be read
-    in the __init__ method, as reading config is not allowed inside forward.
-    """
-
-    def __init__(
-        self,
-        cutlass_fp8_supported: bool = cutlass_fp8_supported(),
-        use_per_token_if_dynamic: bool = False,
-        pad_output: Optional[bool] = None,
-    ):
-        self.cutlass_fp8_supported = cutlass_fp8_supported
-        self.use_per_token_if_dynamic = use_per_token_if_dynamic
-
-        # Note: we pad the input because torch._scaled_mm is more performant
-        # for matrices with batch dimension > 16.
-        # This could change in the future.
-        # We also don't pad when using torch.compile,
-        # as it breaks with dynamic shapes.
-        if pad_output is None:
-            enable_torch_compile = os.environ.get(
-                "SGLANG_ENABLE_TORCH_COMPILE", "0"
-            ).lower() in ("1", "true", "yes")
-            pad_output = not enable_torch_compile
-        self.output_padding = 17 if pad_output else None
-
-    def apply(
-        self,
-        input: torch.Tensor,
-        weight: torch.Tensor,
-        weight_scale: torch.Tensor,
-        input_scale: Optional[torch.Tensor] = None,
-        input_scale_ub: Optional[torch.Tensor] = None,
-        bias: Optional[torch.Tensor] = None,
-        # TODO(luka) remove this parameter in favor of __init__
-        use_per_token_if_dynamic: Optional[bool] = None,
-    ) -> torch.Tensor:
-        # ops.scaled_fp8_quant supports both dynamic and static quant.
-        # If dynamic, layer.input_scale is None and x_scale computed from x.
-        # If static, layer.input_scale is scalar and x_scale is input_scale.
-
-        # View input as 2D matrix for fp8 methods
-        input_2d = input.view(-1, input.shape[-1])
-        output_shape = [*input.shape[:-1], weight.shape[1]]
-
-        # TODO(luka) this is here because currently MLA only decides this
-        # during the forward method instead of in __init__.
-        if use_per_token_if_dynamic is None:
-            use_per_token_if_dynamic = self.use_per_token_if_dynamic
-
+    if compressed_tensor_quant:
         # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
         # for sgl-kernel fp8_scaled_mm, it support per channel W now
-        if self.cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
-            if _is_cuda:
-                qinput, x_scale = sgl_scaled_fp8_quant(
-                    input_2d,
-                    input_scale,
-                    use_per_token_if_dynamic=use_per_token_if_dynamic,
-                )
-            else:
-                qinput, x_scale = ops.scaled_fp8_quant(
-                    input_2d,
-                    input_scale,
-                    scale_ub=input_scale_ub,
-                    use_per_token_if_dynamic=use_per_token_if_dynamic,
-                )
+        if cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
+            qinput, x_scale = scaled_fp8_quant(
+                input_2d,
+                input_scale,
+                use_per_token_if_dynamic=use_per_token_if_dynamic,
+            )
 
             # Fused GEMM_DQ
             if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
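The pad_output logic above exists because, per the in-code note, torch._scaled_mm is faster when the batch (token) dimension exceeds 16, so the quantization kernels are asked to pad to 17 rows and the result is later trimmed back with torch.narrow. A rough sketch of that pad-then-trim idea on an ordinary 2D activation tensor (illustrative only; in the package the padding happens inside the quantization kernels):

import torch


def pad_rows(x_2d: torch.Tensor, min_rows: int = 17) -> torch.Tensor:
    # Ensure the row (token) dimension is at least min_rows before the GEMM.
    if x_2d.shape[0] >= min_rows:
        return x_2d
    pad = x_2d.new_zeros(min_rows - x_2d.shape[0], x_2d.shape[1])
    return torch.cat([x_2d, pad], dim=0)


x = torch.randn(3, 8)                          # fewer than 17 tokens
y_padded = pad_rows(x) @ torch.randn(8, 4)
y = torch.narrow(y_padded, 0, 0, x.shape[0])   # trim back to the real tokens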
@@ -453,20 +336,21 @@ class Fp8LinearOp:
         # so fallback to naive if per channel or per token
         else:
             # Maybe apply padding to output, see comment in __init__
-            if _is_cuda:
-                qinput, x_scale = sgl_scaled_fp8_quant(
+            qinput, x_scale = (
+                scaled_fp8_quant(
                     input_2d,
                     input_scale,
-                    num_token_padding=self.output_padding,
+                    num_token_padding=output_padding,
                     use_per_token_if_dynamic=use_per_token_if_dynamic,
                 )
-            else:
-                qinput, x_scale = ops.scaled_fp8_quant(
+                if _is_cuda
+                else ops.scaled_fp8_quant(
                     input_2d,
                     input_scale,
-                    num_token_padding=self.output_padding,
+                    num_token_padding=output_padding,
                     use_per_token_if_dynamic=use_per_token_if_dynamic,
                 )
+            )
 
             per_tensor_weights = weight_scale.numel() == 1
             per_tensor_activations = x_scale.numel() == 1
@@ -481,12 +365,7 @@ class Fp8LinearOp:
                 scale_b=weight_scale,
                 bias=bias,
             )
-            # A fix for discrepancy in scaled_mm which returns tuple
-            # for torch < 2.5 and a single value in torch >= 2.5
-            if type(output) is tuple and len(output) == 2:
-                output = output[0]
-
-            return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
+            return _process_scaled_mm_output(output, input_2d.shape, output_shape)
 
         elif (
             use_per_token_if_dynamic
@@ -509,10 +388,7 @@ class Fp8LinearOp:
                 scale_b=weight_scale.t(),
                 bias=bias,
             )
-
-            output = torch.narrow(output, 0, 0, input_2d.shape[0])
-            output = output.view(*output_shape)
-            return output
+            return _process_scaled_mm_output(output, input_2d.shape, output_shape)
 
         else:
             # Fallback for channelwise case, where we use unfused DQ
@@ -529,33 +405,110 @@ class Fp8LinearOp:
             #
             # For the scaled_mm fallback case, we break this down, since it
             # does not support s_w being a vector.
-
-            # GEMM
-            # This computes C = (X * W).
-            # Output in fp32 to allow subsequent ops to happen in-place
-
-            global TORCH_DEVICE_IDENTITY
-            if TORCH_DEVICE_IDENTITY.device != weight.device:
-                TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
-
-            output = torch._scaled_mm(
+            return _apply_fallback_scaled_mm(
                 qinput,
                 weight,
-                scale_a=TORCH_DEVICE_IDENTITY,
-                scale_b=TORCH_DEVICE_IDENTITY,
-                out_dtype=torch.float32,
+                x_scale,
+                weight_scale,
+                input_2d.shape,
+                output_shape,
+                bias,
+                input.dtype,
             )
-            # A fix for discrepancy in scaled_mm which returns tuple
-            # for torch < 2.5 and a single value in torch >= 2.5
-            if type(output) is tuple and len(output) == 2:
-                output = output[0]
-            # Unpad (undo num_token_padding)
-            output = torch.narrow(output, 0, 0, input_2d.shape[0])
-            x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
-
-            # DQ
-            # C = sw * sx * (X * W) + bias
-            output = output * x_scale * weight_scale.t()
-            if bias is not None:
-                output = output + bias
-            return output.to(dtype=input.dtype).view(*output_shape)
+    else:
+        # cutlass w8a8 fp8 sgl-kernel only supports per-token scale
+        if input_scale is not None:
+            assert input_scale.numel() == 1
+            # broadcast per-tensor scale to per-token scale when supporting cutlass
+            qinput, x_scale = static_quant_fp8(
+                input_2d, input_scale, repeat_scale=cutlass_fp8_supported
+            )
+        else:
+            # default use per-token quantization if dynamic
+            if _is_cuda:
+                qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
+            else:
+                # TODO(kkhuang): temporarily enforce per-tensor activation scaling if weight is per-tensor scaling
+                # final solution should be: 1. add support to per-tensor activation scaling.
+                # 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
+                if _is_hip and weight_scale.numel() == 1:
+                    qinput, x_scale = ops.scaled_fp8_quant(
+                        input_2d,
+                        input_scale,
+                        use_per_token_if_dynamic=use_per_token_if_dynamic,
+                    )
+                else:
+                    qinput, x_scale = per_token_group_quant_fp8(
+                        input_2d, group_size=input_2d.shape[1]
+                    )
+
+        if cutlass_fp8_supported:
+            try:
+                if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
+                    # Fall back to vllm cutlass w8a8 fp8 kernel
+                    output = ops.cutlass_scaled_mm(
+                        qinput,
+                        weight,
+                        out_dtype=input.dtype,
+                        scale_a=x_scale,
+                        scale_b=weight_scale,
+                        bias=bias,
+                    )
+                else:
+                    assert (
+                        weight_scale.numel() == weight.shape[1]
+                    ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
+                    output = fp8_scaled_mm(
+                        qinput,
+                        weight,
+                        x_scale,
+                        weight_scale,
+                        out_dtype=input.dtype,
+                        bias=bias,
+                    )
+                return output.view(*output_shape)
+            except (ImportError, NameError, AttributeError):
+                pass
+
+        # torch.scaled_mm supports per tensor weights + activations only
+        # so fallback to naive if per channel or per token
+        per_tensor_weights = weight_scale.numel() == 1
+        per_tensor_activations = x_scale.numel() == 1
+
+        if per_tensor_weights and per_tensor_activations:
+            # Fused GEMM_DQ
+            output = torch._scaled_mm(
+                qinput,
+                weight,
+                out_dtype=input.dtype,
+                scale_a=x_scale,
+                scale_b=weight_scale,
+                bias=bias,
+            )
+            return _process_scaled_mm_output(output, input_2d.shape, output_shape)
+
+        else:
+            # Fallback for channelwise case, where we use unfused DQ
+            # due to limitations with scaled_mm
+
+            # Symmetric quantized GEMM by definition computes the following:
+            # C = (s_x * X) (s_w * W) + bias
+            # This is equivalent to dequantizing the weights and activations
+            # before applying a GEMM.
+            #
+            # In order to compute quantized operands, a quantized kernel
+            # will rewrite the above like so:
+            # C = s_w * s_x * (X * W) + bias
+            #
+            # For the scaled_mm fallback case, we break this down, since it
+            # does not support s_w being a vector.
+            return _apply_fallback_scaled_mm(
+                qinput,
+                weight,
+                x_scale,
+                weight_scale,
+                input_2d.shape,
+                output_shape,
+                bias,
+                input.dtype,
+            )
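The channelwise fallback above leans on the scale-factoring identity spelled out in the comments; a quick numeric check of that identity with ordinary float tensors (shapes chosen arbitrarily, not taken from the package):

import torch

X = torch.randn(4, 8)        # activations
W = torch.randn(8, 3)        # weights
s_x = torch.rand(4, 1)       # per-token activation scales
s_w = torch.rand(1, 3)       # per-channel weight scales
bias = torch.randn(3)

# C = (s_x * X) @ (s_w * W) + bias  ==  s_x * s_w * (X @ W) + bias,
# which is why _apply_fallback_scaled_mm can run the GEMM on the raw
# quantized operands and apply both scales afterwards.
lhs = (s_x * X) @ (s_w * W) + bias
rhs = (X @ W) * s_x * s_w + bias
assert torch.allclose(lhs, rhs, atol=1e-5)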
sglang/srt/layers/quantization/kv_cache.py

@@ -8,6 +8,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import is_hip
 
 _is_hip = is_hip()
@@ -17,7 +18,7 @@ logger = logging.getLogger(__name__)
 
 class BaseKVCacheMethod(QuantizeMethodBase):
     """
-    Quant method that adds `_k_scale` and `_v_scale` attributes to the
+    Quant method that adds `k_scale` and `v_scale` attributes to the
     Attention layer to support loading those scaling factors from checkpoints.
     The k/v_scale will be used to:
     - quantize k/v_cache entries before saving them to the cache
@@ -36,8 +37,12 @@ class BaseKVCacheMethod(QuantizeMethodBase):
         # Initialize the KV cache scales to -1.0, which is an invalid value.
         # If the k/v_scale appears in the checkpoint, it will be
        # overwritten when loading weights.
-        layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
-        layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
+        layer.k_scale = torch.nn.Parameter(
+            torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
+        )
+        layer.v_scale = torch.nn.Parameter(
+            torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
+        )
 
     @classmethod
     def is_fp8_fnuz(cls) -> bool:
@@ -47,52 +52,38 @@ class BaseKVCacheMethod(QuantizeMethodBase):
     def apply(self, layer: torch.nn.Module) -> torch.Tensor:
         raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
 
-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
-        # regardless whether the kv-scale is available in the checkpoint.
-        # No need to process kv scales after loading if we are going to
-        # calculate them on the fly.
-        if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales:
-            if layer.k_scale > 0.0 and layer.v_scale > 0.0:
-                # We prefer to use separate k_scale and v_scale if present
-                k_scale = layer.k_scale.to("cpu").tolist()
-                v_scale = layer.v_scale.to("cpu").tolist()
-                if _is_hip and self.is_fp8_fnuz():
-                    k_scale *= 2
-                    v_scale *= 2
-            elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
-                # If no scales were loaded (both scales are invalid negative
-                # values), use the default value of 1.0
-                k_scale = 1.0
-                v_scale = 1.0
-            else:
-                # If we find a single kv_scale in the checkpoint, we remap
-                # kv_scale to k_scale during weight loading, and duplicate
-                # k_scale to v_scale here
-                assert layer.k_scale > 0.0
-                scale_to_duplicate = max(layer.k_scale, layer.v_scale)
-                k_scale = scale_to_duplicate.to("cpu").tolist()
-                v_scale = scale_to_duplicate.to("cpu").tolist()
-                if _is_hip and self.is_fp8_fnuz():
-                    k_scale *= 2
-                    v_scale *= 2
-
-            if not isinstance(k_scale, float) or not isinstance(v_scale, float):
-                raise ValueError(
-                    "Only support per-tensor scaling factor " "for fp8 KV cache"
-                )
-
-            # These are used in the final Attention.forward()
-            layer._k_scale.copy_(k_scale)
-            layer._v_scale.copy_(v_scale)
-            layer._k_scale_float = k_scale
-            layer._v_scale_float = v_scale
-            if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
-                logger.warning(
-                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
-                    "may cause accuracy issues. Please make sure k/v_scale "
-                    "scaling factors are available in the fp8 checkpoint."
-                )
-
-        del layer.k_scale
-        del layer.v_scale
+    def process_weights_after_loading(self, layer: RadixAttention) -> None:
+        if layer.k_scale > 0.0 and layer.v_scale > 0.0:
+            # We prefer to use separate k_scale and v_scale if present
+            k_scale = layer.k_scale.to("cpu").tolist()
+            v_scale = layer.v_scale.to("cpu").tolist()
+            if _is_hip and self.is_fp8_fnuz():
+                k_scale *= 2
+                v_scale *= 2
+        elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
+            # If no scales were loaded (both scales are invalid negative
+            # values), use the default value of 1.0
+            k_scale = 1.0
+            v_scale = 1.0
+        else:
+            # If we find a single kv_scale in the checkpoint, we remap
+            # kv_scale to k_scale during weight loading, and duplicate
+            # k_scale to v_scale here
+            assert layer.k_scale > 0.0
+            scale_to_duplicate = max(layer.k_scale, layer.v_scale)
+            k_scale = scale_to_duplicate.to("cpu").tolist()
+            v_scale = scale_to_duplicate.to("cpu").tolist()
+            if _is_hip and self.is_fp8_fnuz():
+                k_scale *= 2
+                v_scale *= 2
+
+        if not isinstance(k_scale, float) or not isinstance(v_scale, float):
+            raise ValueError(
+                "Only support per-tensor scaling factor " "for fp8 KV cache"
+            )
+
+        # These are used in the final Attention.forward()
+        layer.k_scale.copy_(k_scale)
+        layer.v_scale.copy_(v_scale)
+        layer.k_scale_float = k_scale
+        layer.v_scale_float = v_scale
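As the class docstring says, the resolved k_scale and v_scale are used to quantize entries before they are written to an fp8 KV cache and to dequantize them on read. A minimal sketch of that convention, assuming the usual x_q = x / scale mapping and a PyTorch build with float8_e4m3fn; in the package the actual reads and writes happen inside the attention kernels:

import torch


def quantize_kv(k: torch.Tensor, k_scale: float) -> torch.Tensor:
    # Scale down, then cast to fp8 for storage in the KV cache.
    return (k / k_scale).to(torch.float8_e4m3fn)


def dequantize_kv(k_fp8: torch.Tensor, k_scale: float) -> torch.Tensor:
    # Invert the mapping when the cached entry is read back.
    return k_fp8.float() * k_scale


k = torch.randn(2, 8)
k_restored = dequantize_kv(quantize_kv(k, k_scale=1.0), k_scale=1.0)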