sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +23 -2
- sglang/bench_serving.py +6 -4
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +37 -5
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +80 -11
- sglang/srt/disaggregation/mini_lb.py +58 -123
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +585 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
- sglang/srt/disaggregation/prefill.py +82 -22
- sglang/srt/disaggregation/utils.py +46 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +42 -13
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +430 -257
- sglang/srt/layers/attention/flashinfer_backend.py +18 -9
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +18 -3
- sglang/srt/layers/moe/ep_moe/layer.py +15 -29
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +63 -45
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/fp8.py +131 -136
- sglang/srt/layers/quantization/fp8_kernel.py +328 -46
- sglang/srt/layers/quantization/fp8_utils.py +206 -253
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
- sglang/srt/layers/quantization/w8a8_int8.py +8 -7
- sglang/srt/layers/radix_attention.py +28 -1
- sglang/srt/layers/rotary_embedding.py +15 -3
- sglang/srt/layers/sampler.py +5 -10
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +255 -97
- sglang/srt/managers/mm_utils.py +7 -5
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
- sglang/srt/managers/schedule_batch.py +64 -25
- sglang/srt/managers/scheduler.py +80 -82
- sglang/srt/managers/tokenizer_manager.py +18 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +21 -3
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +9 -6
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +67 -35
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +494 -366
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +6 -5
- sglang/srt/models/llama4.py +101 -34
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +30 -200
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +102 -29
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +5 -1
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +15 -13
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +55 -19
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +10 -9
- sglang/srt/utils.py +136 -10
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +224 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/disaggregation/conn.py +0 -81
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_utils.py

@@ -3,9 +3,20 @@ from typing import List, Optional, Tuple
 
 import torch
 
+from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
+
+try:
+    from vllm import _custom_ops as ops
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    VLLM_AVAILABLE = False
+
 from sglang.srt.layers.quantization.fp8_kernel import (
     _enable_jit_deepgemm,
     per_token_group_quant_fp8,
+    scaled_fp8_quant,
+    sglang_per_token_quant_fp8,
     static_quant_fp8,
     w8a8_block_fp8_matmul,
 )
@@ -17,30 +28,20 @@ from sglang.srt.utils import (
     is_hip,
 )
 
-try:
-    import vllm
-    from vllm import _custom_ops as ops
-
-    VLLM_AVAILABLE = True
-except ImportError:
-    VLLM_AVAILABLE = False
-
-use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
-
 _is_hip = is_hip()
+_is_cuda = is_cuda()
+
 if _is_hip and get_bool_env_var("CK_MOE"):
     from aiter import gemm_a8w8_blockscale
 
-_is_cuda = is_cuda()
 if _is_cuda:
     from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm
 
-
-from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8
+use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
 
 # Input scaling factors are no longer optional in _scaled_mm starting
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
-TORCH_DEVICE_IDENTITY =
+TORCH_DEVICE_IDENTITY = None
 
 _TORCH_VERSION = torch.__version__.split("+")[0]
 try:
@@ -143,7 +144,7 @@ def apply_w8a8_block_fp8_linear(
            gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
     else:
         if _enable_jit_deepgemm:
-            q_input, x_scale =
+            q_input, x_scale = sglang_per_token_group_quant_fp8(
                 input_2d,
                 block_size[1],
                 column_major_scales=True,
@@ -168,12 +169,13 @@ def input_to_float8(
     """This function quantizes input values to float8 values with tensor-wise quantization."""
     finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
-    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+    amax = torch.maximum(min_val.abs(), max_val.abs()).float().clamp(min=1e-12)
     fp8_max = finfo.max
     if _is_hip:
+        dtype = torch.float8_e4m3fnuz
         fp8_max = 224.0
     scale = fp8_max / amax
-    x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
+    x_scl_sat = (x.float() * scale).clamp(min=-fp8_max, max=fp8_max)
     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
 
 
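The `input_to_float8` change above adds `.float()` so that the absolute maximum and the saturated values are computed in fp32 rather than in the (possibly bf16) input dtype, and it switches to the e4m3fnuz dtype on HIP. As a rough, hedged illustration of the tensor-wise scheme this function implements (not code from the package; `tensorwise_fp8_quant` is an invented name, and it assumes a PyTorch build that provides `torch.float8_e4m3fn`):

```python
import torch

def tensorwise_fp8_quant(x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    # compute amax in fp32 (mirrors the added .float() call, which matters for bf16 inputs)
    amax = torch.maximum(min_val.abs(), max_val.abs()).float().clamp(min=1e-12)
    scale = finfo.max / amax
    x_q = (x.float() * scale).clamp(min=-finfo.max, max=finfo.max).to(dtype)
    return x_q, scale.reciprocal()  # dequantize later as x_q.float() * inv_scale

x = torch.randn(4, 8)
x_q, inv_scale = tensorwise_fp8_quant(x)
print((x - x_q.float() * inv_scale).abs().max())  # small tensor-wise quantization error
```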
@@ -212,10 +214,64 @@ def block_quant_to_tensor_quant(
         for j in range(n_tiles):
             x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
 
-    x_q_tensor, scale =
+    x_q_tensor, scale = (
+        scaled_fp8_quant(x_dq_block)
+        if _is_cuda
+        else input_to_float8(x_dq_block, dtype=x_q_block.dtype)
+    )
     return x_q_tensor, scale
 
 
+def channel_quant_to_tensor_quant(
+    x_q_channel: torch.Tensor,
+    x_s: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    x_dq_channel = x_q_channel.to(torch.float32) * x_s
+    x_q_tensor, scale = (
+        scaled_fp8_quant(x_dq_channel)
+        if _is_cuda
+        else input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
+    )
+    return x_q_tensor, scale
+
+
+def _process_scaled_mm_output(output, input_2d_shape, output_shape):
+    if type(output) is tuple and len(output) == 2:
+        output = output[0]
+    return torch.narrow(output, 0, 0, input_2d_shape[0]).view(*output_shape)
+
+
+def _apply_fallback_scaled_mm(
+    qinput,
+    weight,
+    x_scale,
+    weight_scale,
+    input_2d_shape,
+    output_shape,
+    bias,
+    input_dtype,
+):
+    global TORCH_DEVICE_IDENTITY
+    if TORCH_DEVICE_IDENTITY is None:
+        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32, device=weight.device)
+
+    output = torch._scaled_mm(
+        qinput,
+        weight,
+        scale_a=TORCH_DEVICE_IDENTITY,
+        scale_b=TORCH_DEVICE_IDENTITY,
+        out_dtype=torch.float32,
+    )
+
+    output = _process_scaled_mm_output(output, input_2d_shape, output_shape)
+    x_scale = torch.narrow(x_scale, 0, 0, input_2d_shape[0])
+
+    output = output * x_scale * weight_scale.t()
+    if bias is not None:
+        output = output + bias
+    return output.to(dtype=input_dtype)
+
+
 def apply_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -223,206 +279,33 @@ def apply_fp8_linear(
     input_scale: Optional[torch.Tensor] = None,
     input_scale_ub: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
-    cutlass_fp8_supported: bool =
+    cutlass_fp8_supported: bool = cutlass_fp8_supported(),
     use_per_token_if_dynamic: bool = False,
+    pad_output: Optional[bool] = None,
+    compressed_tensor_quant: bool = False,
 ) -> torch.Tensor:
+    # Note: we pad the input because torch._scaled_mm is more performant
+    # for matrices with batch dimension > 16.
+    # This could change in the future.
+    # We also don't pad when using torch.compile,
+    # as it breaks with dynamic shapes.
+    if pad_output is None:
+        pad_output = not get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
+    output_padding = 17 if pad_output else None
+
     # View input as 2D matrix for fp8 methods
     input_2d = input.view(-1, input.shape[-1])
     output_shape = [*input.shape[:-1], weight.shape[1]]
 
-
-    if input_scale is not None:
-        assert input_scale.numel() == 1
-        # broadcast per-tensor scale to per-token scale when supporting cutlass
-        qinput, x_scale = static_quant_fp8(
-            input_2d, input_scale, repeat_scale=cutlass_fp8_supported
-        )
-    else:
-        # default use per-token quantization if dynamic
-        if _is_cuda:
-            qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
-        else:
-            qinput, x_scale = per_token_group_quant_fp8(
-                input_2d, group_size=input_2d.shape[1]
-            )
-
-    if cutlass_fp8_supported:
-        try:
-            if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
-                # Fall back to vllm cutlass w8a8 fp8 kernel
-                output = ops.cutlass_scaled_mm(
-                    qinput,
-                    weight,
-                    out_dtype=input.dtype,
-                    scale_a=x_scale,
-                    scale_b=weight_scale,
-                    bias=bias,
-                )
-            else:
-                assert (
-                    weight_scale.numel() == weight.shape[1]
-                ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
-                output = fp8_scaled_mm(
-                    qinput,
-                    weight,
-                    x_scale,
-                    weight_scale,
-                    out_dtype=input.dtype,
-                    bias=bias,
-                )
-            return output.view(*output_shape)
-        except (ImportError, NameError, AttributeError):
-            pass
-
-    # torch.scaled_mm supports per tensor weights + activations only
-    # so fallback to naive if per channel or per token
-    else:
-        per_tensor_weights = weight_scale.numel() == 1
-        per_tensor_activations = x_scale.numel() == 1
-
-        if per_tensor_weights and per_tensor_activations:
-            # Fused GEMM_DQ
-            output = torch._scaled_mm(
-                qinput,
-                weight,
-                out_dtype=input.dtype,
-                scale_a=x_scale,
-                scale_b=weight_scale,
-                bias=bias,
-            )
-            # A fix for discrepancy in scaled_mm which returns tuple
-            # for torch < 2.5 and a single value in torch >= 2.5
-            if type(output) is tuple and len(output) == 2:
-                output = output[0]
-
-            return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
-
-        else:
-            # Fallback for channelwise case, where we use unfused DQ
-            # due to limitations with scaled_mm
-
-            # Symmetric quantized GEMM by definition computes the following:
-            # C = (s_x * X) (s_w * W) + bias
-            # This is equivalent to dequantizing the weights and activations
-            # before applying a GEMM.
-            #
-            # In order to compute quantized operands, a quantized kernel
-            # will rewrite the above like so:
-            # C = s_w * s_x * (X * W) + bias
-            #
-            # For the scaled_mm fallback case, we break this down, since it
-            # does not support s_w being a vector.
-
-            # Making sure the dummy tensor is on the same device as the weight
-            global TORCH_DEVICE_IDENTITY
-            if TORCH_DEVICE_IDENTITY.device != weight.device:
-                TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
-
-            # GEMM
-            # This computes C = (X * W).
-            # Output in fp32 to allow subsequent ops to happen in-place
-            output = torch._scaled_mm(
-                qinput,
-                weight,
-                scale_a=TORCH_DEVICE_IDENTITY,
-                scale_b=TORCH_DEVICE_IDENTITY,
-                out_dtype=torch.float32,
-            )
-            # A fix for discrepancy in scaled_mm which returns tuple
-            # for torch < 2.5 and a single value in torch >= 2.5
-            if type(output) is tuple and len(output) == 2:
-                output = output[0]
-            # Unpad (undo num_token_padding)
-            output = torch.narrow(output, 0, 0, input_2d.shape[0])
-            x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
-
-            # DQ
-            # C = sw * sx * (X * W) + bias
-            output = output * x_scale * weight_scale.t()
-            if bias is not None:
-                output = output + bias
-            return output.to(dtype=input.dtype).view(*output_shape)
-
-
-def maybe_create_device_identity():
-    # Allocate dummy ones tensor for torch._scaled_mm
-    global TORCH_DEVICE_IDENTITY
-    if TORCH_DEVICE_IDENTITY is None:
-        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
-
-
-# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
-# TODO(luka): follow similar pattern for marlin and block-fp8-linear
-# https://github.com/vllm-project/vllm/issues/14397
-class Fp8LinearOp:
-    """
-    This class executes a FP8 linear layer using cutlass if supported and
-    torch.scaled_mm otherwise.
-    It needs to be a class instead of a method so that config can be read
-    in the __init__ method, as reading config is not allowed inside forward.
-    """
-
-    def __init__(
-        self,
-        cutlass_fp8_supported: bool = cutlass_fp8_supported(),
-        use_per_token_if_dynamic: bool = False,
-        pad_output: Optional[bool] = None,
-    ):
-        self.cutlass_fp8_supported = cutlass_fp8_supported
-        self.use_per_token_if_dynamic = use_per_token_if_dynamic
-
-        # Note: we pad the input because torch._scaled_mm is more performant
-        # for matrices with batch dimension > 16.
-        # This could change in the future.
-        # We also don't pad when using torch.compile,
-        # as it breaks with dynamic shapes.
-        if pad_output is None:
-            enable_torch_compile = os.environ.get(
-                "SGLANG_ENABLE_TORCH_COMPILE", "0"
-            ).lower() in ("1", "true", "yes")
-            pad_output = not enable_torch_compile
-        self.output_padding = 17 if pad_output else None
-
-    def apply(
-        self,
-        input: torch.Tensor,
-        weight: torch.Tensor,
-        weight_scale: torch.Tensor,
-        input_scale: Optional[torch.Tensor] = None,
-        input_scale_ub: Optional[torch.Tensor] = None,
-        bias: Optional[torch.Tensor] = None,
-        # TODO(luka) remove this parameter in favor of __init__
-        use_per_token_if_dynamic: Optional[bool] = None,
-    ) -> torch.Tensor:
-        # ops.scaled_fp8_quant supports both dynamic and static quant.
-        # If dynamic, layer.input_scale is None and x_scale computed from x.
-        # If static, layer.input_scale is scalar and x_scale is input_scale.
-
-        # View input as 2D matrix for fp8 methods
-        input_2d = input.view(-1, input.shape[-1])
-        output_shape = [*input.shape[:-1], weight.shape[1]]
-
-        # TODO(luka) this is here because currently MLA only decides this
-        # during the forward method instead of in __init__.
-        if use_per_token_if_dynamic is None:
-            use_per_token_if_dynamic = self.use_per_token_if_dynamic
-
+    if compressed_tensor_quant:
         # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
         # for sgl-kernel fp8_scaled_mm, it support per channel W now
-        if
-
-
-
-
-
-        )
-        else:
-            qinput, x_scale = ops.scaled_fp8_quant(
-                input_2d,
-                input_scale,
-                scale_ub=input_scale_ub,
-                use_per_token_if_dynamic=use_per_token_if_dynamic,
-            )
+        if cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
+            qinput, x_scale = scaled_fp8_quant(
+                input_2d,
+                input_scale,
+                use_per_token_if_dynamic=use_per_token_if_dynamic,
+            )
 
             # Fused GEMM_DQ
             if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
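The comments introduced above carry the reasoning for `output_padding = 17`: `torch._scaled_mm` tends to be faster once the row (token) dimension exceeds 16, so the quantization step can pad the activation and the result is later trimmed back with `torch.narrow`. A toy sketch of that pad-then-narrow shape handling, using an ordinary matmul as a stand-in for the fp8 GEMM (illustrative only, not package code):

```python
import torch
import torch.nn.functional as F

x = torch.randn(3, 8)                                         # only 3 tokens, i.e. batch <= 16
num_token_padding = 17                                        # value used when pad_output is True
padded = F.pad(x, (0, 0, 0, num_token_padding - x.shape[0]))  # pad the row dimension up to 17
y = padded @ torch.randn(8, 4)                                # stand-in for the padded GEMM
y = torch.narrow(y, 0, 0, x.shape[0])                         # drop the padded rows again
assert y.shape == (3, 4)
```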
@@ -453,20 +336,21 @@ class Fp8LinearOp:
         # so fallback to naive if per channel or per token
         else:
             # Maybe apply padding to output, see comment in __init__
-
-
+            qinput, x_scale = (
+                scaled_fp8_quant(
                     input_2d,
                     input_scale,
-                    num_token_padding=
+                    num_token_padding=output_padding,
                     use_per_token_if_dynamic=use_per_token_if_dynamic,
                 )
-
-
+                if _is_cuda
+                else ops.scaled_fp8_quant(
                     input_2d,
                     input_scale,
-                    num_token_padding=
+                    num_token_padding=output_padding,
                     use_per_token_if_dynamic=use_per_token_if_dynamic,
                 )
+            )
 
             per_tensor_weights = weight_scale.numel() == 1
             per_tensor_activations = x_scale.numel() == 1
@@ -481,12 +365,7 @@ class Fp8LinearOp:
                     scale_b=weight_scale,
                     bias=bias,
                 )
-
-                # for torch < 2.5 and a single value in torch >= 2.5
-                if type(output) is tuple and len(output) == 2:
-                    output = output[0]
-
-                return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
+                return _process_scaled_mm_output(output, input_2d.shape, output_shape)
 
             elif (
                 use_per_token_if_dynamic
@@ -509,10 +388,7 @@ class Fp8LinearOp:
                     scale_b=weight_scale.t(),
                     bias=bias,
                 )
-
-                output = torch.narrow(output, 0, 0, input_2d.shape[0])
-                output = output.view(*output_shape)
-                return output
+                return _process_scaled_mm_output(output, input_2d.shape, output_shape)
 
             else:
                 # Fallback for channelwise case, where we use unfused DQ
@@ -529,33 +405,110 @@ class Fp8LinearOp:
                 #
                 # For the scaled_mm fallback case, we break this down, since it
                 # does not support s_w being a vector.
-
-                # GEMM
-                # This computes C = (X * W).
-                # Output in fp32 to allow subsequent ops to happen in-place
-
-                global TORCH_DEVICE_IDENTITY
-                if TORCH_DEVICE_IDENTITY.device != weight.device:
-                    TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
-
-                output = torch._scaled_mm(
+                return _apply_fallback_scaled_mm(
                     qinput,
                     weight,
-
-
-
+                    x_scale,
+                    weight_scale,
+                    input_2d.shape,
+                    output_shape,
+                    bias,
+                    input.dtype,
                 )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    else:
+        # cutlass w8a8 fp8 sgl-kernel only supports per-token scale
+        if input_scale is not None:
+            assert input_scale.numel() == 1
+            # broadcast per-tensor scale to per-token scale when supporting cutlass
+            qinput, x_scale = static_quant_fp8(
+                input_2d, input_scale, repeat_scale=cutlass_fp8_supported
+            )
+        else:
+            # default use per-token quantization if dynamic
+            if _is_cuda:
+                qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
+            else:
+                # TODO(kkhuang): temporarily enforce per-tensor activation scaling if weight is per-tensor scaling
+                # final solution should be: 1. add support to per-tensor activation scaling.
+                # 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
+                if _is_hip and weight_scale.numel() == 1:
+                    qinput, x_scale = ops.scaled_fp8_quant(
+                        input_2d,
+                        input_scale,
+                        use_per_token_if_dynamic=use_per_token_if_dynamic,
+                    )
+                else:
+                    qinput, x_scale = per_token_group_quant_fp8(
+                        input_2d, group_size=input_2d.shape[1]
+                    )
+
+        if cutlass_fp8_supported:
+            try:
+                if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
+                    # Fall back to vllm cutlass w8a8 fp8 kernel
+                    output = ops.cutlass_scaled_mm(
+                        qinput,
+                        weight,
+                        out_dtype=input.dtype,
+                        scale_a=x_scale,
+                        scale_b=weight_scale,
+                        bias=bias,
+                    )
+                else:
+                    assert (
+                        weight_scale.numel() == weight.shape[1]
+                    ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
+                    output = fp8_scaled_mm(
+                        qinput,
+                        weight,
+                        x_scale,
+                        weight_scale,
+                        out_dtype=input.dtype,
+                        bias=bias,
+                    )
+                return output.view(*output_shape)
+            except (ImportError, NameError, AttributeError):
+                pass
+
+        # torch.scaled_mm supports per tensor weights + activations only
+        # so fallback to naive if per channel or per token
+        per_tensor_weights = weight_scale.numel() == 1
+        per_tensor_activations = x_scale.numel() == 1
+
+        if per_tensor_weights and per_tensor_activations:
+            # Fused GEMM_DQ
+            output = torch._scaled_mm(
+                qinput,
+                weight,
+                out_dtype=input.dtype,
+                scale_a=x_scale,
+                scale_b=weight_scale,
+                bias=bias,
+            )
+            return _process_scaled_mm_output(output, input_2d.shape, output_shape)
+
+        else:
+            # Fallback for channelwise case, where we use unfused DQ
+            # due to limitations with scaled_mm
+
+            # Symmetric quantized GEMM by definition computes the following:
+            # C = (s_x * X) (s_w * W) + bias
+            # This is equivalent to dequantizing the weights and activations
+            # before applying a GEMM.
+            #
+            # In order to compute quantized operands, a quantized kernel
+            # will rewrite the above like so:
+            # C = s_w * s_x * (X * W) + bias
+            #
+            # For the scaled_mm fallback case, we break this down, since it
+            # does not support s_w being a vector.
+            return _apply_fallback_scaled_mm(
+                qinput,
+                weight,
+                x_scale,
+                weight_scale,
+                input_2d.shape,
+                output_shape,
+                bias,
+                input.dtype,
+            )
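Both branches of the rewritten `apply_fp8_linear` now funnel the channelwise case into the new `_apply_fallback_scaled_mm` helper, which relies on the identity spelled out in the comments: because `torch._scaled_mm` cannot take a per-channel (vector) weight scale, the GEMM is run with identity scales and `output * x_scale * weight_scale.t()` applies the per-token and per-channel scales afterwards. A small numeric check of that identity, with plain float tensors standing in for the quantized operands (illustrative only, not package code):

```python
import torch

M, K, N = 4, 8, 3
X_q = torch.randint(-8, 8, (M, K)).float()   # stand-in for the quantized activation
W_q = torch.randint(-8, 8, (K, N)).float()   # stand-in for the quantized weight
s_x = torch.rand(M, 1)                       # per-token activation scale
s_w = torch.rand(N, 1)                       # per-channel weight scale
bias = torch.rand(N)

dequant_first = (X_q * s_x) @ (W_q * s_w.t()) + bias  # dequantize, then GEMM
gemm_first = (X_q @ W_q) * s_x * s_w.t() + bias       # GEMM, then scale (the fallback path)
assert torch.allclose(dequant_first, gemm_first, atol=1e-5)
```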
sglang/srt/layers/quantization/kv_cache.py

@@ -8,6 +8,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import is_hip
 
 _is_hip = is_hip()
@@ -17,7 +18,7 @@ logger = logging.getLogger(__name__)
 
 class BaseKVCacheMethod(QuantizeMethodBase):
     """
-    Quant method that adds `
+    Quant method that adds `k_scale` and `v_scale` attributes to the
     Attention layer to support loading those scaling factors from checkpoints.
     The k/v_scale will be used to:
     - quantize k/v_cache entries before saving them to the cache
@@ -36,8 +37,12 @@ class BaseKVCacheMethod(QuantizeMethodBase):
         # Initialize the KV cache scales to -1.0, which is an invalid value.
         # If the k/v_scale appears in the checkpoint, it will be
         # overwritten when loading weights.
-        layer.k_scale = torch.nn.Parameter(
-
+        layer.k_scale = torch.nn.Parameter(
+            torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
+        )
+        layer.v_scale = torch.nn.Parameter(
+            torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
+        )
 
     @classmethod
     def is_fp8_fnuz(cls) -> bool:
@@ -47,52 +52,38 @@ class BaseKVCacheMethod(QuantizeMethodBase):
     def apply(self, layer: torch.nn.Module) -> torch.Tensor:
         raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
 
-    def process_weights_after_loading(self, layer:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        # These are used in the final Attention.forward()
-        layer._k_scale.copy_(k_scale)
-        layer._v_scale.copy_(v_scale)
-        layer._k_scale_float = k_scale
-        layer._v_scale_float = v_scale
-        if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
-            logger.warning(
-                "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
-                "may cause accuracy issues. Please make sure k/v_scale "
-                "scaling factors are available in the fp8 checkpoint."
-            )
-
-        del layer.k_scale
-        del layer.v_scale
+    def process_weights_after_loading(self, layer: RadixAttention) -> None:
+        if layer.k_scale > 0.0 and layer.v_scale > 0.0:
+            # We prefer to use separate k_scale and v_scale if present
+            k_scale = layer.k_scale.to("cpu").tolist()
+            v_scale = layer.v_scale.to("cpu").tolist()
+            if _is_hip and self.is_fp8_fnuz():
+                k_scale *= 2
+                v_scale *= 2
+        elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
+            # If no scales were loaded (both scales are invalid negative
+            # values), use the default value of 1.0
+            k_scale = 1.0
+            v_scale = 1.0
+        else:
+            # If we find a single kv_scale in the checkpoint, we remap
+            # kv_scale to k_scale during weight loading, and duplicate
+            # k_scale to v_scale here
+            assert layer.k_scale > 0.0
+            scale_to_duplicate = max(layer.k_scale, layer.v_scale)
+            k_scale = scale_to_duplicate.to("cpu").tolist()
+            v_scale = scale_to_duplicate.to("cpu").tolist()
+            if _is_hip and self.is_fp8_fnuz():
+                k_scale *= 2
+                v_scale *= 2
+
+        if not isinstance(k_scale, float) or not isinstance(v_scale, float):
+            raise ValueError(
+                "Only support per-tensor scaling factor " "for fp8 KV cache"
+            )
+
+        # These are used in the final Attention.forward()
+        layer.k_scale.copy_(k_scale)
+        layer.v_scale.copy_(v_scale)
+        layer.k_scale_float = k_scale
+        layer.v_scale_float = v_scale
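The new `process_weights_after_loading` resolves the per-tensor KV-cache scales from the sentinel-initialised `k_scale`/`v_scale` parameters (a negative value means the scale was absent from the checkpoint). The branch rule can be summarised as a small pure function; this is only a sketch, and `resolve_kv_scales`/`fnuz_double` are invented names, with `fnuz_double` standing in for the ROCm fp8-fnuz case in which the scales are doubled:

```python
def resolve_kv_scales(k_scale: float, v_scale: float, fnuz_double: bool = False):
    if k_scale > 0.0 and v_scale > 0.0:
        # both scales were loaded from the checkpoint: use them as-is
        if fnuz_double:
            k_scale, v_scale = k_scale * 2, v_scale * 2
    elif k_scale < 0.0 and v_scale < 0.0:
        # neither scale was loaded: fall back to 1.0
        k_scale, v_scale = 1.0, 1.0
    else:
        # a single kv_scale was remapped to k_scale during loading: duplicate it
        k_scale = v_scale = max(k_scale, v_scale)
        if fnuz_double:
            k_scale, v_scale = k_scale * 2, v_scale * 2
    return k_scale, v_scale

assert resolve_kv_scales(-1.0, -1.0) == (1.0, 1.0)
assert resolve_kv_scales(0.5, -1.0) == (0.5, 0.5)
assert resolve_kv_scales(0.5, 0.25, fnuz_double=True) == (1.0, 0.5)
```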