sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +0 -4
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +26 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/decode.py +62 -6
- sglang/srt/disaggregation/mini_lb.py +5 -1
- sglang/srt/disaggregation/mooncake/conn.py +32 -62
- sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
- sglang/srt/disaggregation/prefill.py +40 -4
- sglang/srt/disaggregation/utils.py +15 -0
- sglang/srt/entrypoints/verl_engine.py +7 -5
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +114 -71
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +17 -3
- sglang/srt/layers/moe/ep_moe/layer.py +15 -29
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +27 -30
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +1 -0
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +8 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/fp8.py +115 -132
- sglang/srt/layers/quantization/fp8_kernel.py +213 -57
- sglang/srt/layers/quantization/fp8_utils.py +187 -262
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -7
- sglang/srt/layers/radix_attention.py +15 -0
- sglang/srt/layers/rotary_embedding.py +3 -2
- sglang/srt/layers/sampler.py +5 -10
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +1 -0
- sglang/srt/managers/mm_utils.py +4 -3
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
- sglang/srt/managers/schedule_batch.py +2 -4
- sglang/srt/managers/scheduler.py +12 -71
- sglang/srt/managers/tokenizer_manager.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +7 -2
- sglang/srt/model_executor/cuda_graph_runner.py +2 -2
- sglang/srt/model_executor/model_runner.py +20 -27
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +289 -348
- sglang/srt/models/llama.py +5 -5
- sglang/srt/models/minicpm3.py +29 -201
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_moe.py +14 -13
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +34 -32
- sglang/srt/speculative/eagle_worker.py +4 -7
- sglang/srt/utils.py +16 -1
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +167 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +3 -3
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +92 -91
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_utils.py
@@ -3,9 +3,20 @@ from typing import List, Optional, Tuple
 
 import torch
 
+from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
+
+try:
+    from vllm import _custom_ops as ops
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    VLLM_AVAILABLE = False
+
 from sglang.srt.layers.quantization.fp8_kernel import (
     _enable_jit_deepgemm,
     per_token_group_quant_fp8,
+    scaled_fp8_quant,
+    sglang_per_token_quant_fp8,
     static_quant_fp8,
     w8a8_block_fp8_matmul,
 )
@@ -17,30 +28,20 @@ from sglang.srt.utils import (
     is_hip,
 )
 
-try:
-    import vllm
-    from vllm import _custom_ops as ops
-
-    VLLM_AVAILABLE = True
-except ImportError:
-    VLLM_AVAILABLE = False
-
-use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
-
 _is_hip = is_hip()
+_is_cuda = is_cuda()
+
 if _is_hip and get_bool_env_var("CK_MOE"):
     from aiter import gemm_a8w8_blockscale
 
-_is_cuda = is_cuda()
 if _is_cuda:
     from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm
 
-
-    from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8
+use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
 
 # Input scaling factors are no longer optional in _scaled_mm starting
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
-TORCH_DEVICE_IDENTITY =
+TORCH_DEVICE_IDENTITY = None
 
 _TORCH_VERSION = torch.__version__.split("+")[0]
 try:
@@ -143,7 +144,7 @@ def apply_w8a8_block_fp8_linear(
         gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
     else:
         if _enable_jit_deepgemm:
-            q_input, x_scale =
+            q_input, x_scale = sglang_per_token_group_quant_fp8(
                 input_2d,
                 block_size[1],
                 column_major_scales=True,
@@ -214,7 +215,7 @@ def block_quant_to_tensor_quant(
         x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
 
     x_q_tensor, scale = (
-
+        scaled_fp8_quant(x_dq_block)
         if _is_cuda
         else input_to_float8(x_dq_block, dtype=x_q_block.dtype)
     )
@@ -227,13 +228,50 @@ def channel_quant_to_tensor_quant(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     x_dq_channel = x_q_channel.to(torch.float32) * x_s
     x_q_tensor, scale = (
-
+        scaled_fp8_quant(x_dq_channel)
         if _is_cuda
         else input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
     )
     return x_q_tensor, scale
 
 
+def _process_scaled_mm_output(output, input_2d_shape, output_shape):
+    if type(output) is tuple and len(output) == 2:
+        output = output[0]
+    return torch.narrow(output, 0, 0, input_2d_shape[0]).view(*output_shape)
+
+
+def _apply_fallback_scaled_mm(
+    qinput,
+    weight,
+    x_scale,
+    weight_scale,
+    input_2d_shape,
+    output_shape,
+    bias,
+    input_dtype,
+):
+    global TORCH_DEVICE_IDENTITY
+    if TORCH_DEVICE_IDENTITY is None:
+        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32, device=weight.device)
+
+    output = torch._scaled_mm(
+        qinput,
+        weight,
+        scale_a=TORCH_DEVICE_IDENTITY,
+        scale_b=TORCH_DEVICE_IDENTITY,
+        out_dtype=torch.float32,
+    )
+
+    output = _process_scaled_mm_output(output, input_2d_shape, output_shape)
+    x_scale = torch.narrow(x_scale, 0, 0, input_2d_shape[0])
+
+    output = output * x_scale * weight_scale.t()
+    if bias is not None:
+        output = output + bias
+    return output.to(dtype=input_dtype)
+
+
 def apply_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -241,216 +279,33 @@ def apply_fp8_linear(
     input_scale: Optional[torch.Tensor] = None,
     input_scale_ub: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
-    cutlass_fp8_supported: bool =
+    cutlass_fp8_supported: bool = cutlass_fp8_supported(),
     use_per_token_if_dynamic: bool = False,
+    pad_output: Optional[bool] = None,
+    compressed_tensor_quant: bool = False,
 ) -> torch.Tensor:
+    # Note: we pad the input because torch._scaled_mm is more performant
+    # for matrices with batch dimension > 16.
+    # This could change in the future.
+    # We also don't pad when using torch.compile,
+    # as it breaks with dynamic shapes.
+    if pad_output is None:
+        pad_output = not get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
+    output_padding = 17 if pad_output else None
+
     # View input as 2D matrix for fp8 methods
     input_2d = input.view(-1, input.shape[-1])
     output_shape = [*input.shape[:-1], weight.shape[1]]
 
-
-    if input_scale is not None:
-        assert input_scale.numel() == 1
-        # broadcast per-tensor scale to per-token scale when supporting cutlass
-        qinput, x_scale = static_quant_fp8(
-            input_2d, input_scale, repeat_scale=cutlass_fp8_supported
-        )
-    else:
-        # default use per-token quantization if dynamic
-        if _is_cuda:
-            qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
-        else:
-            # TODO(kkhuang): temporarily enforce per-tensor activation scaling if weight is per-tensor scaling
-            # final solution should be: 1. add support to per-tensor activation scaling.
-            # 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
-            if _is_hip and weight_scale.numel() == 1:
-                qinput, x_scale = ops.scaled_fp8_quant(
-                    input_2d,
-                    input_scale,
-                    use_per_token_if_dynamic=use_per_token_if_dynamic,
-                )
-            else:
-                qinput, x_scale = per_token_group_quant_fp8(
-                    input_2d, group_size=input_2d.shape[1]
-                )
-
-    if cutlass_fp8_supported:
-        try:
-            if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
-                # Fall back to vllm cutlass w8a8 fp8 kernel
-                output = ops.cutlass_scaled_mm(
-                    qinput,
-                    weight,
-                    out_dtype=input.dtype,
-                    scale_a=x_scale,
-                    scale_b=weight_scale,
-                    bias=bias,
-                )
-            else:
-                assert (
-                    weight_scale.numel() == weight.shape[1]
-                ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
-                output = fp8_scaled_mm(
-                    qinput,
-                    weight,
-                    x_scale,
-                    weight_scale,
-                    out_dtype=input.dtype,
-                    bias=bias,
-                )
-            return output.view(*output_shape)
-        except (ImportError, NameError, AttributeError):
-            pass
-
-    # torch.scaled_mm supports per tensor weights + activations only
-    # so fallback to naive if per channel or per token
-    else:
-        per_tensor_weights = weight_scale.numel() == 1
-        per_tensor_activations = x_scale.numel() == 1
-
-        if per_tensor_weights and per_tensor_activations:
-            # Fused GEMM_DQ
-            output = torch._scaled_mm(
-                qinput,
-                weight,
-                out_dtype=input.dtype,
-                scale_a=x_scale,
-                scale_b=weight_scale,
-                bias=bias,
-            )
-            # A fix for discrepancy in scaled_mm which returns tuple
-            # for torch < 2.5 and a single value in torch >= 2.5
-            if type(output) is tuple and len(output) == 2:
-                output = output[0]
-
-            return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
-
-        else:
-            # Fallback for channelwise case, where we use unfused DQ
-            # due to limitations with scaled_mm
-
-            # Symmetric quantized GEMM by definition computes the following:
-            # C = (s_x * X) (s_w * W) + bias
-            # This is equivalent to dequantizing the weights and activations
-            # before applying a GEMM.
-            #
-            # In order to compute quantized operands, a quantized kernel
-            # will rewrite the above like so:
-            # C = s_w * s_x * (X * W) + bias
-            #
-            # For the scaled_mm fallback case, we break this down, since it
-            # does not support s_w being a vector.
-
-            # Making sure the dummy tensor is on the same device as the weight
-            global TORCH_DEVICE_IDENTITY
-            if TORCH_DEVICE_IDENTITY.device != weight.device:
-                TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
-
-            # GEMM
-            # This computes C = (X * W).
-            # Output in fp32 to allow subsequent ops to happen in-place
-            output = torch._scaled_mm(
-                qinput,
-                weight,
-                scale_a=TORCH_DEVICE_IDENTITY,
-                scale_b=TORCH_DEVICE_IDENTITY,
-                out_dtype=torch.float32,
-            )
-            # A fix for discrepancy in scaled_mm which returns tuple
-            # for torch < 2.5 and a single value in torch >= 2.5
-            if type(output) is tuple and len(output) == 2:
-                output = output[0]
-            # Unpad (undo num_token_padding)
-            output = torch.narrow(output, 0, 0, input_2d.shape[0])
-            x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
-
-            # DQ
-            # C = sw * sx * (X * W) + bias
-            output = output * x_scale * weight_scale.t()
-            if bias is not None:
-                output = output + bias
-            return output.to(dtype=input.dtype).view(*output_shape)
-
-
-def maybe_create_device_identity():
-    # Allocate dummy ones tensor for torch._scaled_mm
-    global TORCH_DEVICE_IDENTITY
-    if TORCH_DEVICE_IDENTITY is None:
-        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
-
-
-# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
-# TODO(luka): follow similar pattern for marlin and block-fp8-linear
-# https://github.com/vllm-project/vllm/issues/14397
-class Fp8LinearOp:
-    """
-    This class executes a FP8 linear layer using cutlass if supported and
-    torch.scaled_mm otherwise.
-    It needs to be a class instead of a method so that config can be read
-    in the __init__ method, as reading config is not allowed inside forward.
-    """
-
-    def __init__(
-        self,
-        cutlass_fp8_supported: bool = cutlass_fp8_supported(),
-        use_per_token_if_dynamic: bool = False,
-        pad_output: Optional[bool] = None,
-    ):
-        self.cutlass_fp8_supported = cutlass_fp8_supported
-        self.use_per_token_if_dynamic = use_per_token_if_dynamic
-
-        # Note: we pad the input because torch._scaled_mm is more performant
-        # for matrices with batch dimension > 16.
-        # This could change in the future.
-        # We also don't pad when using torch.compile,
-        # as it breaks with dynamic shapes.
-        if pad_output is None:
-            enable_torch_compile = os.environ.get(
-                "SGLANG_ENABLE_TORCH_COMPILE", "0"
-            ).lower() in ("1", "true", "yes")
-            pad_output = not enable_torch_compile
-        self.output_padding = 17 if pad_output else None
-
-    def apply(
-        self,
-        input: torch.Tensor,
-        weight: torch.Tensor,
-        weight_scale: torch.Tensor,
-        input_scale: Optional[torch.Tensor] = None,
-        input_scale_ub: Optional[torch.Tensor] = None,
-        bias: Optional[torch.Tensor] = None,
-        # TODO(luka) remove this parameter in favor of __init__
-        use_per_token_if_dynamic: Optional[bool] = None,
-    ) -> torch.Tensor:
-        # ops.scaled_fp8_quant supports both dynamic and static quant.
-        # If dynamic, layer.input_scale is None and x_scale computed from x.
-        # If static, layer.input_scale is scalar and x_scale is input_scale.
-
-        # View input as 2D matrix for fp8 methods
-        input_2d = input.view(-1, input.shape[-1])
-        output_shape = [*input.shape[:-1], weight.shape[1]]
-
-        # TODO(luka) this is here because currently MLA only decides this
-        # during the forward method instead of in __init__.
-        if use_per_token_if_dynamic is None:
-            use_per_token_if_dynamic = self.use_per_token_if_dynamic
-
+    if compressed_tensor_quant:
         # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
         # for sgl-kernel fp8_scaled_mm, it support per channel W now
-        if
-
-
-
-
-
-            )
-        else:
-            qinput, x_scale = ops.scaled_fp8_quant(
-                input_2d,
-                input_scale,
-                scale_ub=input_scale_ub,
-                use_per_token_if_dynamic=use_per_token_if_dynamic,
-            )
+        if cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
+            qinput, x_scale = scaled_fp8_quant(
+                input_2d,
+                input_scale,
+                use_per_token_if_dynamic=use_per_token_if_dynamic,
+            )
 
             # Fused GEMM_DQ
             if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
@@ -481,20 +336,21 @@ class Fp8LinearOp:
         # so fallback to naive if per channel or per token
         else:
             # Maybe apply padding to output, see comment in __init__
-
-
+            qinput, x_scale = (
+                scaled_fp8_quant(
                     input_2d,
                     input_scale,
-                    num_token_padding=
+                    num_token_padding=output_padding,
                     use_per_token_if_dynamic=use_per_token_if_dynamic,
                 )
-
-
+                if _is_cuda
+                else ops.scaled_fp8_quant(
                     input_2d,
                     input_scale,
-                    num_token_padding=
+                    num_token_padding=output_padding,
                     use_per_token_if_dynamic=use_per_token_if_dynamic,
                 )
+            )
 
             per_tensor_weights = weight_scale.numel() == 1
             per_tensor_activations = x_scale.numel() == 1
@@ -509,12 +365,7 @@ class Fp8LinearOp:
                     scale_b=weight_scale,
                     bias=bias,
                 )
-
-                # for torch < 2.5 and a single value in torch >= 2.5
-                if type(output) is tuple and len(output) == 2:
-                    output = output[0]
-
-                return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
+                return _process_scaled_mm_output(output, input_2d.shape, output_shape)
 
             elif (
                 use_per_token_if_dynamic
@@ -537,10 +388,7 @@ class Fp8LinearOp:
                     scale_b=weight_scale.t(),
                     bias=bias,
                 )
-
-                output = torch.narrow(output, 0, 0, input_2d.shape[0])
-                output = output.view(*output_shape)
-                return output
+                return _process_scaled_mm_output(output, input_2d.shape, output_shape)
 
             else:
                 # Fallback for channelwise case, where we use unfused DQ
@@ -557,33 +405,110 @@ class Fp8LinearOp:
                 #
                 # For the scaled_mm fallback case, we break this down, since it
                 # does not support s_w being a vector.
-
-                # GEMM
-                # This computes C = (X * W).
-                # Output in fp32 to allow subsequent ops to happen in-place
-
-                global TORCH_DEVICE_IDENTITY
-                if TORCH_DEVICE_IDENTITY.device != weight.device:
-                    TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
-
-                output = torch._scaled_mm(
+                return _apply_fallback_scaled_mm(
                     qinput,
                     weight,
-
-
-
+                    x_scale,
+                    weight_scale,
+                    input_2d.shape,
+                    output_shape,
+                    bias,
+                    input.dtype,
                 )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    else:
+        # cutlass w8a8 fp8 sgl-kernel only supports per-token scale
+        if input_scale is not None:
+            assert input_scale.numel() == 1
+            # broadcast per-tensor scale to per-token scale when supporting cutlass
+            qinput, x_scale = static_quant_fp8(
+                input_2d, input_scale, repeat_scale=cutlass_fp8_supported
+            )
+        else:
+            # default use per-token quantization if dynamic
+            if _is_cuda:
+                qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
+            else:
+                # TODO(kkhuang): temporarily enforce per-tensor activation scaling if weight is per-tensor scaling
+                # final solution should be: 1. add support to per-tensor activation scaling.
+                # 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
+                if _is_hip and weight_scale.numel() == 1:
+                    qinput, x_scale = ops.scaled_fp8_quant(
+                        input_2d,
+                        input_scale,
+                        use_per_token_if_dynamic=use_per_token_if_dynamic,
+                    )
+                else:
+                    qinput, x_scale = per_token_group_quant_fp8(
+                        input_2d, group_size=input_2d.shape[1]
+                    )
+
+        if cutlass_fp8_supported:
+            try:
+                if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
+                    # Fall back to vllm cutlass w8a8 fp8 kernel
+                    output = ops.cutlass_scaled_mm(
+                        qinput,
+                        weight,
+                        out_dtype=input.dtype,
+                        scale_a=x_scale,
+                        scale_b=weight_scale,
+                        bias=bias,
+                    )
+                else:
+                    assert (
+                        weight_scale.numel() == weight.shape[1]
+                    ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
+                    output = fp8_scaled_mm(
+                        qinput,
+                        weight,
+                        x_scale,
+                        weight_scale,
+                        out_dtype=input.dtype,
+                        bias=bias,
+                    )
+                return output.view(*output_shape)
+            except (ImportError, NameError, AttributeError):
+                pass
+
+        # torch.scaled_mm supports per tensor weights + activations only
+        # so fallback to naive if per channel or per token
+        per_tensor_weights = weight_scale.numel() == 1
+        per_tensor_activations = x_scale.numel() == 1
+
+        if per_tensor_weights and per_tensor_activations:
+            # Fused GEMM_DQ
+            output = torch._scaled_mm(
+                qinput,
+                weight,
+                out_dtype=input.dtype,
+                scale_a=x_scale,
+                scale_b=weight_scale,
+                bias=bias,
+            )
+            return _process_scaled_mm_output(output, input_2d.shape, output_shape)
+
+        else:
+            # Fallback for channelwise case, where we use unfused DQ
+            # due to limitations with scaled_mm
+
+            # Symmetric quantized GEMM by definition computes the following:
+            # C = (s_x * X) (s_w * W) + bias
+            # This is equivalent to dequantizing the weights and activations
+            # before applying a GEMM.
+            #
+            # In order to compute quantized operands, a quantized kernel
+            # will rewrite the above like so:
+            # C = s_w * s_x * (X * W) + bias
+            #
+            # For the scaled_mm fallback case, we break this down, since it
+            # does not support s_w being a vector.
+            return _apply_fallback_scaled_mm(
+                qinput,
+                weight,
+                x_scale,
+                weight_scale,
+                input_2d.shape,
+                output_shape,
+                bias,
+                input.dtype,
+            )
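The fp8_utils.py hunks above fold the removed Fp8LinearOp fallback and the old inline fallback in apply_fp8_linear into the shared _apply_fallback_scaled_mm helper. That helper depends on the dequantization identity spelled out in the comments: running the GEMM on the raw quantized operands and applying the per-token and per-channel scales afterwards matches dequantizing first. A minimal sketch of that identity, with plain float tensors standing in for the FP8 operands so it runs anywhere (illustrative code, not part of the package):

import torch

torch.manual_seed(0)
X = torch.randn(4, 8)      # stand-in for the quantized activation matrix (M x K)
W = torch.randn(8, 16)     # stand-in for the quantized weight matrix (K x N)
s_x = torch.rand(4, 1)     # per-token activation scales (M x 1)
s_w = torch.rand(16, 1)    # per-channel weight scales (N x 1)

# Dequantize first, then multiply: C = (s_x * X) @ (s_w * W)
fused = (X * s_x) @ (W * s_w.t())

# GEMM first, then scale, as the fallback path does: C = (X @ W) * s_x * s_w.T
unfused = (X @ W) * s_x * s_w.t()

assert torch.allclose(fused, unfused, atol=1e-5)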
sglang/srt/layers/quantization/moe_wna16.py
@@ -347,6 +347,7 @@ class MoeWNA16Method:
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         # avoid circular import
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -363,6 +364,7 @@ class MoeWNA16Method:
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
         weight_bits = self.quant_config.weight_bits
sglang/srt/layers/quantization/utils.py
@@ -1,18 +1,17 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
 
 from types import MappingProxyType
-from typing import List, Mapping,
+from typing import List, Mapping, Tuple, Union
 
 import torch
 
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.utils import is_cuda
 
 _is_cuda = is_cuda()
 
-if _is_cuda:
-    from
-else:
-    from vllm import _custom_ops as vllm_ops
+if not _is_cuda:
+    from vllm._custom_ops import scaled_fp8_quant
 
 
 def is_fp8_fnuz() -> bool:
@@ -116,12 +115,7 @@ def requantize_with_max_scale(
     for idx, logical_width in enumerate(logical_widths):
         end = start + logical_width
         weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
-
-            weight[start:end, :], _ = sgl_scaled_fp8_quant(weight_dq, max_w_scale)
-        else:
-            weight[start:end, :], _ = vllm_ops.scaled_fp8_quant(
-                weight_dq, max_w_scale
-            )
+        weight[start:end, :], _ = scaled_fp8_quant(weight_dq, max_w_scale)
         start = end
 
     return max_w_scale, weight
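The utils.py change collapses the old sgl_scaled_fp8_quant / vllm_ops.scaled_fp8_quant branch in requantize_with_max_scale into a single scaled_fp8_quant symbol. The surrounding loop re-expresses fused shards, each quantized with its own per-tensor scale, against the single largest scale. A rough sketch of that idea using plain float tensors instead of FP8 storage (illustrative only; the real code calls scaled_fp8_quant):

import torch

def requantize_to_max_scale(weight_q, weight_scale, logical_widths):
    # Re-express every shard against the shared maximum scale so the fused
    # weight can later be dequantized with one scalar.
    max_w_scale = weight_scale.max()
    start = 0
    for idx, logical_width in enumerate(logical_widths):
        end = start + logical_width
        weight_dq = weight_q[start:end] * weight_scale[idx]   # per-shard dequantize
        weight_q[start:end] = weight_dq / max_w_scale         # requantize with max scale
        start = end
    return max_w_scale, weight_q

w = torch.randn(6, 4)
scales = torch.tensor([0.5, 2.0])
max_scale, w_req = requantize_to_max_scale(w.clone(), scales, [3, 3])
# Dequantizing with the shared scale reproduces the per-shard dequantized values.
assert torch.allclose(w_req * max_scale, torch.cat([w[:3] * 0.5, w[3:] * 2.0]))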
sglang/srt/layers/quantization/w8a8_fp8.py
@@ -294,6 +294,7 @@ class W8A8FP8MoEMethod:
         activation: str = "silu",
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -309,6 +310,7 @@ class W8A8FP8MoEMethod:
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
         return fused_experts(
sglang/srt/layers/quantization/w8a8_int8.py
@@ -1,13 +1,6 @@
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
-
-from sglang.srt.utils import is_cuda_available, set_weight_attrs
-
-is_cuda = is_cuda_available()
-if is_cuda:
-    from sgl_kernel import int8_scaled_mm
-
 from torch.nn.parameter import Parameter
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
@@ -18,6 +11,11 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+from sglang.srt.utils import is_cuda_available, set_weight_attrs
+
+is_cuda = is_cuda_available()
+if is_cuda:
+    from sgl_kernel import int8_scaled_mm
 
 
 class W8A8Int8Config(QuantizationConfig):
@@ -233,6 +231,7 @@ class W8A8Int8MoEMethod:
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -248,6 +247,7 @@ class W8A8Int8MoEMethod:
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
        )
 
         return fused_experts(
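The MoE-related hunks above (moe_wna16.py, w8a8_fp8.py, w8a8_int8.py) all make the same change: each quantization method's apply() gains an optional routed_scaling_factor and forwards it to select_experts. How the factor is consumed is not shown in this diff; a plausible sketch, assuming it simply rescales the normalized routing weights (hypothetical helper, not sglang's implementation):

from typing import Optional

import torch

def scale_topk_weights(
    topk_weights: torch.Tensor, routed_scaling_factor: Optional[float] = None
) -> torch.Tensor:
    # Multiply the per-token routing weights by a model-specific factor
    # (e.g. DeepSeek-style routed scaling) when one is configured.
    if routed_scaling_factor is not None:
        topk_weights = topk_weights * routed_scaling_factor
    return topk_weights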