sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.
- sglang/bench_serving.py +1 -1
- sglang/lang/chat_template.py +29 -0
- sglang/srt/_custom_ops.py +19 -17
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/janus_pro.py +629 -0
- sglang/srt/configs/model_config.py +24 -14
- sglang/srt/conversation.py +80 -2
- sglang/srt/custom_op.py +64 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
- sglang/srt/distributed/parallel_state.py +10 -1
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/http_server.py +1 -1
- sglang/srt/function_call_parser.py +33 -2
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
- sglang/srt/layers/attention/triton_backend.py +1 -3
- sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
- sglang/srt/layers/attention/vision.py +43 -62
- sglang/srt/layers/dp_attention.py +30 -2
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/linear.py +1 -1
- sglang/srt/layers/logits_processor.py +1 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +25 -9
- sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
- sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/parameter.py +10 -0
- sglang/srt/layers/quantization/__init__.py +90 -68
- sglang/srt/layers/quantization/blockwise_int8.py +1 -2
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +174 -106
- sglang/srt/layers/quantization/fp8_kernel.py +210 -38
- sglang/srt/layers/quantization/fp8_utils.py +156 -15
- sglang/srt/layers/quantization/modelopt_quant.py +5 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
- sglang/srt/layers/quantization/w8a8_int8.py +152 -3
- sglang/srt/layers/rotary_embedding.py +5 -3
- sglang/srt/layers/sampler.py +29 -35
- sglang/srt/layers/vocab_parallel_embedding.py +0 -1
- sglang/srt/lora/backend/__init__.py +9 -12
- sglang/srt/managers/cache_controller.py +74 -8
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/image_processor.py +37 -631
- sglang/srt/managers/image_processors/base_image_processor.py +219 -0
- sglang/srt/managers/image_processors/janus_pro.py +79 -0
- sglang/srt/managers/image_processors/llava.py +152 -0
- sglang/srt/managers/image_processors/minicpmv.py +86 -0
- sglang/srt/managers/image_processors/mlama.py +60 -0
- sglang/srt/managers/image_processors/qwen_vl.py +161 -0
- sglang/srt/managers/io_struct.py +32 -15
- sglang/srt/managers/multi_modality_padding.py +134 -0
- sglang/srt/managers/schedule_batch.py +213 -118
- sglang/srt/managers/schedule_policy.py +40 -8
- sglang/srt/managers/scheduler.py +176 -683
- sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
- sglang/srt/managers/tokenizer_manager.py +6 -6
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
- sglang/srt/mem_cache/base_prefix_cache.py +6 -8
- sglang/srt/mem_cache/chunk_cache.py +12 -44
- sglang/srt/mem_cache/hiradix_cache.py +71 -34
- sglang/srt/mem_cache/memory_pool.py +81 -17
- sglang/srt/mem_cache/paged_allocator.py +283 -0
- sglang/srt/mem_cache/radix_cache.py +117 -36
- sglang/srt/model_executor/cuda_graph_runner.py +68 -20
- sglang/srt/model_executor/forward_batch_info.py +23 -10
- sglang/srt/model_executor/model_runner.py +63 -63
- sglang/srt/model_loader/loader.py +2 -1
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/deepseek_janus_pro.py +2127 -0
- sglang/srt/models/deepseek_nextn.py +23 -3
- sglang/srt/models/deepseek_v2.py +200 -191
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/minicpmv.py +28 -89
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/qwen2.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +25 -50
- sglang/srt/models/qwen2_vl.py +33 -49
- sglang/srt/openai_api/adapter.py +59 -35
- sglang/srt/openai_api/protocol.py +8 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
- sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
- sglang/srt/server_args.py +24 -16
- sglang/srt/speculative/eagle_worker.py +75 -39
- sglang/srt/utils.py +104 -9
- sglang/test/runners.py +104 -10
- sglang/test/test_block_fp8.py +106 -16
- sglang/test/test_custom_ops.py +88 -0
- sglang/test/test_utils.py +20 -4
- sglang/utils.py +0 -4
- sglang/version.py +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_kernel.py

```diff
@@ -22,17 +22,54 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.utils import
-
-
-
-
-
+from sglang.srt.utils import (
+    direct_register_custom_op,
+    get_device_core_count,
+    get_device_name,
+    is_cuda,
+    is_hip,
+    supports_custom_op,
+)
+
+_is_hip = is_hip()
+fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+
+_is_cuda = is_cuda()
 if _is_cuda:
-
+    import deep_gemm  # `pip install "sgl-kernel>=0.0.4.post3"`
+    from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
 
 logger = logging.getLogger(__name__)
 
+_enable_jit_deepgemm = int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "0"))
+
+if supports_custom_op():
+
+    def deep_gemm_fp8_fp8_bf16_nt(
+        A: torch.Tensor,
+        As: torch.Tensor,
+        B: torch.Tensor,
+        Bs: torch.Tensor,
+        C: torch.Tensor,
+    ) -> None:
+        deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+
+    def deep_gemm_fp8_fp8_bf16_nt_fake(
+        A: torch.Tensor,
+        As: torch.Tensor,
+        B: torch.Tensor,
+        Bs: torch.Tensor,
+        C: torch.Tensor,
+    ) -> None:
+        return
+
+    direct_register_custom_op(
+        op_name="deep_gemm_fp8_fp8_bf16_nt",
+        op_func=deep_gemm_fp8_fp8_bf16_nt,
+        mutates_args=["C"],
+        fake_impl=deep_gemm_fp8_fp8_bf16_nt_fake,
+    )
+
 
 @triton.jit
 def _per_token_group_quant_fp8(
```
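This hunk registers DeepGEMM's FP8→BF16 GEMM as a PyTorch custom op: `op_func` does the real work and mutates `C` in place, while the no-op `fake_impl` gives the dispatcher a meta implementation so the op can be traced or compiled without running DeepGEMM. Below is a minimal sketch of the resulting call path, mirroring the dispatch added later in `w8a8_block_fp8_matmul`; the wrapper function and tensor names are illustrative, not part of the package, and it assumes a CUDA build with `deep_gemm` available.

```python
import torch

import sglang.srt.layers.quantization.fp8_kernel  # noqa: F401  (importing registers the op)
from sglang.srt.utils import supports_custom_op


def fp8_bf16_gemm(A, As, B, Bs, C):
    """Illustrative dispatch helper (not part of sglang)."""
    if supports_custom_op():
        # Goes through the PyTorch dispatcher, so torch.compile/tracing sees one
        # opaque op whose metadata comes from the registered fake_impl.
        torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
    else:
        # Direct library call with identical math (no dispatcher integration).
        import deep_gemm

        deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
    return C
```

The remaining fp8_kernel.py hunks follow.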
sglang/srt/layers/quantization/fp8_kernel.py (continued)

```diff
@@ -70,7 +107,8 @@ def _per_token_group_quant_fp8(
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
     y_s = _absmax / fp8_max
-
+    y_s_inv = 1.0 / y_s
+    y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
 
     tl.store(y_q_ptr + cols, y_q, mask=mask)
     tl.store(y_s_ptr, y_s)
@@ -140,7 +178,7 @@ def per_token_group_quant_fp8(
         x: The input tenosr with ndim >= 2.
         group_size: The group size used for quantization.
         eps: The minimum to avoid dividing zero.
-        dtype: The dype of output tensor.
+        dtype: The dype of output tensor.
 
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
@@ -153,7 +191,7 @@ def per_token_group_quant_fp8(
     finfo = torch.finfo(dtype)
     fp8_max = finfo.max
 
-    if
+    if _is_hip:
         fp8_max = 224.0
 
     fp8_min = -fp8_max
@@ -241,6 +279,132 @@ def sglang_per_token_group_quant_fp8(
     return x_q, x_s
 
 
+def sglang_per_token_quant_fp8(
+    x: torch.Tensor,
+    dtype: torch.dtype = fp8_type_,
+):
+    assert x.is_contiguous(), "`x` is not contiguous"
+
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_s = torch.empty(
+        x.shape[0],
+        1,
+        device=x.device,
+        dtype=torch.float32,
+    )
+
+    sgl_per_token_quant_fp8(x, x_q, x_s)
+
+    return x_q, x_s
+
+
+@triton.jit
+def _static_quant_fp8(
+    # Pointers to inputs and output
+    y_ptr,
+    y_q_ptr,
+    y_s_ptr,
+    y_s_repeat_ptr,
+    # Stride of input
+    y_stride,
+    # Collums of input
+    N,
+    # Information for float8
+    fp8_min,
+    fp8_max,
+    # Meta-parameters
+    BLOCK: tl.constexpr,
+    REPEAT_SCALE: tl.constexpr,
+):
+    """A Triton-accelerated function to perform quantization using the given scale on a
+    tensor
+
+    This function converts the tensor values into float8 values.
+    """
+    # Map the program id to the row of X and Y it should compute.
+    g_id = tl.program_id(0)
+    y_ptr += g_id * y_stride
+    y_q_ptr += g_id * y_stride
+    if REPEAT_SCALE:
+        y_s_repeat_ptr += g_id
+
+    cols = tl.arange(0, BLOCK)  # N <= BLOCK
+    mask = cols < N
+
+    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
+    y_s = tl.load(y_s_ptr).to(tl.float32)
+    y_s_inv = 1.0 / y_s
+    y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+    tl.store(y_q_ptr + cols, y_q, mask=mask)
+    if REPEAT_SCALE:
+        tl.store(y_s_repeat_ptr, y_s)
+
+
+def static_quant_fp8(
+    x: torch.Tensor,
+    x_s: torch.Tensor,
+    repeat_scale: bool = False,
+    dtype: torch.dtype = fp8_type_,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Function to perform static quantization using the given scale on an input tensor `x`.
+
+    It converts the tensor values into signed float8 values and returns the
+    quantized tensor along with the scaling factor used for quantization.
+
+    Args:
+        x: The input tenosr with ndim >= 2.
+        x_s: The quantization scale.
+        repeat_scale: Whether to broadcast per-tensor scale to per-channel scale.
+        dtype: The dype of output tensor.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
+    """
+    assert x.is_contiguous(), "`x` is not contiguous"
+    assert x_s.numel() == 1, "only supports per-tensor scale"
+    finfo = torch.finfo(dtype)
+    fp8_max = finfo.max
+
+    if _is_hip:
+        fp8_max = 224.0
+
+    fp8_min = -fp8_max
+
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    M = x.numel() // x.shape[-1]
+    N = x.shape[-1]
+    if repeat_scale:
+        x_s_repeat = torch.empty(
+            (M, 1),
+            device=x.device,
+            dtype=torch.float32,
+        )
+    else:
+        x_s_repeat = None
+
+    BLOCK = triton.next_power_of_2(N)
+    # heuristics for number of warps
+    num_warps = min(max(BLOCK // 256, 1), 8)
+    num_stages = 1
+    _static_quant_fp8[(M,)](
+        x,
+        x_q,
+        x_s,
+        x_s_repeat,
+        N,
+        N,
+        fp8_min=fp8_min,
+        fp8_max=fp8_max,
+        BLOCK=BLOCK,
+        REPEAT_SCALE=repeat_scale,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
+    x_s = x_s_repeat if repeat_scale else x_s
+    return x_q, x_s
+
+
 @triton.jit
 def _w8a8_block_fp8_matmul(
     # Pointers to inputs and output
@@ -595,34 +759,42 @@ def w8a8_block_fp8_matmul(
     num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
         N, config["BLOCK_SIZE_N"]
     )
-    kernel = (
-        _w8a8_block_fp8_matmul_unrolledx4
-        if (is_hip_ == True and num_workgroups <= get_device_core_count())
-        else _w8a8_block_fp8_matmul
-    )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    # deepgemm only support bf16
+    if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
+        if supports_custom_op():
+            torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
+        else:
+            deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+    else:
+        kernel = (
+            _w8a8_block_fp8_matmul_unrolledx4
+            if (_is_hip == True and num_workgroups <= get_device_core_count())
+            else _w8a8_block_fp8_matmul
+        )
+
+        kernel[grid](
+            A,
+            B,
+            C,
+            As,
+            Bs,
+            M,
+            N,
+            K,
+            block_n,
+            block_k,
+            A.stride(-2),
+            A.stride(-1),
+            B.stride(1),
+            B.stride(0),
+            C.stride(-2),
+            C.stride(-1),
+            As.stride(-2),
+            As.stride(-1),
+            Bs.stride(1),
+            Bs.stride(0),
+            **config,
+        )
 
     return C
```
sglang/srt/layers/quantization/fp8_utils.py

```diff
@@ -1,23 +1,53 @@
-import os
 from typing import List, Optional, Tuple
 
 import torch
 
-from sglang.srt.layers.parameter import RowvLLMParameter, _ColumnvLLMParameter
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
+    static_quant_fp8,
     w8a8_block_fp8_matmul,
 )
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    get_bool_env_var,
+    get_cuda_version,
+    get_device_capability,
+    is_cuda,
+    is_hip,
+)
+
+use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
 
-
-if
+_is_hip = is_hip()
+if _is_hip and get_bool_env_var("CK_MOE"):
     from aiter import gemm_a8w8_blockscale
 
-_is_cuda =
+_is_cuda = is_cuda()
 if _is_cuda:
     from sgl_kernel import fp8_blockwise_scaled_mm
 
+    from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8
+
+    if use_vllm_cutlass_w8a8_fp8_kernel:
+        from vllm import _custom_ops as ops
+    else:
+        from sgl_kernel import fp8_scaled_mm
+
+# Input scaling factors are no longer optional in _scaled_mm starting
+# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
+TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
+
+
+def cutlass_fp8_supported():
+    if not _is_cuda:
+        return False
+    major, minor = get_device_capability()
+    cuda_version = get_cuda_version()
+    if major >= 9:
+        return cuda_version >= (12, 0)
+    elif major == 8 and minor == 9:
+        return cuda_version >= (12, 4)
+    return False
+
 
 def normalize_e4m3fn_to_e4m3fnuz(
     weight: torch.Tensor,
@@ -44,7 +74,7 @@ def normalize_e4m3fn_to_e4m3fnuz(
 
 
 def cutlass_block_fp8_supported() -> bool:
-    if
+    if get_bool_env_var("SUPPORT_CUTLASS_BLOCK_FP8"):
        return False
    if _is_cuda:
        major, minor = torch.cuda.get_device_capability()
@@ -81,7 +111,7 @@ def apply_w8a8_block_fp8_linear(
         output = fp8_blockwise_scaled_mm(
             q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
         )
-    elif
+    elif _is_hip and get_bool_env_var("CK_MOE"):
         q_input, x_scale = per_token_group_quant_fp8(
             input_2d, block_size[1], column_major_scales=False
         )
@@ -112,7 +142,7 @@ def input_to_float8(
     min_val, max_val = x.aminmax()
     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
     fp8_max = finfo.max
-    if
+    if _is_hip:
         fp8_max = 224.0
     scale = fp8_max / amax
     x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
@@ -158,10 +188,121 @@ def block_quant_to_tensor_quant(
     return x_q_tensor, scale
 
 
-
-
-
-
-
+def apply_fp8_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+    input_scale_ub: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+    cutlass_fp8_supported: bool = True,
+    use_per_token_if_dynamic: bool = False,
+) -> torch.Tensor:
+    # View input as 2D matrix for fp8 methods
+    input_2d = input.view(-1, input.shape[-1])
+    output_shape = [*input.shape[:-1], weight.shape[1]]
+
+    # cutlass w8a8 fp8 sgl-kernel only supports per-token scale
+    if input_scale is not None:
+        assert input_scale.numel() == 1
+        # broadcast per-tensor scale to per-token scale when supporting cutlass
+        qinput, x_scale = static_quant_fp8(
+            input_2d, input_scale, repeat_scale=cutlass_fp8_supported
+        )
+    else:
+        # default use per-token quantization if dynamic
+        if _is_cuda:
+            qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
+        else:
+            qinput, x_scale = per_token_group_quant_fp8(
+                input_2d, group_size=input_2d.shape[1]
+            )
+
+    if cutlass_fp8_supported:
+        if use_vllm_cutlass_w8a8_fp8_kernel:
+            # Fall back to vllm cutlass w8a8 fp8 kernel
+            output = ops.cutlass_scaled_mm(
+                qinput,
+                weight,
+                out_dtype=input.dtype,
+                scale_a=x_scale,
+                scale_b=weight_scale,
+                bias=bias,
+            )
+        else:
+            assert (
+                weight_scale.numel() == weight.shape[1]
+            ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
+            output = fp8_scaled_mm(
+                qinput, weight, x_scale, weight_scale, out_dtype=input.dtype, bias=bias
+            )
+        return output.view(*output_shape)
+
+    # torch.scaled_mm supports per tensor weights + activations only
+    # so fallback to naive if per channel or per token
+    else:
+        per_tensor_weights = weight_scale.numel() == 1
+        per_tensor_activations = x_scale.numel() == 1
+
+        if per_tensor_weights and per_tensor_activations:
+            # Fused GEMM_DQ
+            output = torch._scaled_mm(
+                qinput,
+                weight,
+                out_dtype=input.dtype,
+                scale_a=x_scale,
+                scale_b=weight_scale,
+                bias=bias,
+            )
+            # A fix for discrepancy in scaled_mm which returns tuple
+            # for torch < 2.5 and a single value in torch >= 2.5
+            if type(output) is tuple and len(output) == 2:
+                output = output[0]
+
+            return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
+
+        else:
+            # Fallback for channelwise case, where we use unfused DQ
+            # due to limitations with scaled_mm
+
+            # Symmetric quantized GEMM by definition computes the following:
+            # C = (s_x * X) (s_w * W) + bias
+            # This is equivalent to dequantizing the weights and activations
+            # before applying a GEMM.
+            #
+            # In order to compute quantized operands, a quantized kernel
+            # will rewrite the above like so:
+            # C = s_w * s_x * (X * W) + bias
+            #
+            # For the scaled_mm fallback case, we break this down, since it
+            # does not support s_w being a vector.
+
+            # Making sure the dummy tensor is on the same device as the weight
+            global TORCH_DEVICE_IDENTITY
+            if TORCH_DEVICE_IDENTITY.device != weight.device:
+                TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
+
+            # GEMM
+            # This computes C = (X * W).
+            # Output in fp32 to allow subsequent ops to happen in-place
+            output = torch._scaled_mm(
+                qinput,
+                weight,
+                scale_a=TORCH_DEVICE_IDENTITY,
+                scale_b=TORCH_DEVICE_IDENTITY,
+                out_dtype=torch.float32,
+            )
+            # A fix for discrepancy in scaled_mm which returns tuple
+            # for torch < 2.5 and a single value in torch >= 2.5
+            if type(output) is tuple and len(output) == 2:
+                output = output[0]
+            # Unpad (undo num_token_padding)
+            output = torch.narrow(output, 0, 0, input_2d.shape[0])
+            x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
 
-
+            # DQ
+            # C = sw * sx * (X * W) + bias
+            output = output * x_scale * weight_scale.t()
+            if bias is not None:
+                output = output + bias
+            return output.to(dtype=input.dtype).view(*output_shape)
```
sglang/srt/layers/quantization/modelopt_quant.py

```diff
@@ -7,7 +7,7 @@ import torch
 from torch.nn.parameter import Parameter
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-
+    convert_to_channelwise,
     cutlass_fp8_supported,
     requantize_with_max_scale,
 )
@@ -19,6 +19,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear
 
 # Initialize logger for the module
 logger = logging.getLogger(__name__)
@@ -161,6 +162,9 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
             layer.weight, layer.weight_scale, layer.logical_widths
         )
         layer.weight = Parameter(quantized_weight.t(), requires_grad=False)
+        # cutlass sgl-kernel only supports per-channel scale
+        if self.cutlass_fp8_supported:
+            max_w_scale = convert_to_channelwise(max_w_scale, layer.logical_widths)
         layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
         layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
 
```
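The functional change here is small: when the cutlass path is available, the requantized weight scale is expanded with vllm's `convert_to_channelwise` so that every output channel carries an explicit scale, because the cutlass sgl-kernel GEMM only accepts per-channel scales. Below is a rough sketch of that kind of expansion, written for one scale per fused shard; it is illustrative only, the real helper is the one imported from `vllm.model_executor.layers.quantization.utils.w8a8_utils`.

```python
from typing import List

import torch


def convert_to_channelwise_sketch(
    shard_scales: torch.Tensor, logical_widths: List[int]
) -> torch.Tensor:
    # One scale per fused shard (e.g. the Q/K/V parts of a merged projection)
    # becomes one scale per output channel, repeated across that shard's rows.
    return torch.cat(
        [
            scale.reshape(1, 1).expand(width, 1)
            for scale, width in zip(shard_scales.unbind(), logical_widths)
        ]
    )


# e.g. a fused QKV projection with 1024/256/256 output channels per shard:
scales = torch.tensor([0.011, 0.020, 0.018])
print(convert_to_channelwise_sketch(scales, [1024, 256, 256]).shape)  # torch.Size([1536, 1])
```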
sglang/srt/layers/quantization/w8a8_fp8.py (new file)

```diff
@@ -0,0 +1,128 @@
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.nn.parameter import Parameter
+
+from sglang.srt.layers.linear import LinearMethodBase
+from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.fp8_utils import (
+    apply_fp8_linear,
+    cutlass_fp8_supported,
+    normalize_e4m3fn_to_e4m3fnuz,
+)
+from sglang.srt.utils import is_hip
+
+_is_hip = is_hip()
+
+
+class W8A8Fp8Config(QuantizationConfig):
+    """Config class for W8A8 FP8 Quantization.
+
+    - Weight: static, per-channel, symmetric
+    - Activation: dynamic, per-token, symmetric
+    """
+
+    def __init__(self):
+        pass
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 89
+
+    @classmethod
+    def get_name(self) -> str:
+        return "w8a8_fp8"
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return []
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config":
+        return cls()
+
+    def get_quant_method(
+        self,
+        layer: torch.nn.Module,
+        prefix: str,
+    ) -> Optional["QuantizeMethodBase"]:
+        from sglang.srt.layers.linear import LinearBase
+
+        if isinstance(layer, LinearBase):
+            return W8A8Fp8LinearMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class W8A8Fp8LinearMethod(LinearMethodBase):
+
+    def __init__(self, quantization_config: W8A8Fp8Config):
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.quantization_config = quantization_config
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        weight = layer.weight
+        weight_scale = layer.weight_scale.detach()
+        if _is_hip:
+            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                weight=weight, weight_scale=weight_scale
+            )
+        layer.weight = Parameter(weight.t(), requires_grad=False)
+        layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs
+    ):
+
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        self.logical_widths = output_partition_sizes
+
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                sum(output_partition_sizes),
+                input_size_per_partition,
+                dtype=torch.float8_e4m3fn,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        weight_scale = ChannelQuantScaleParameter(
+            data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight_scale", weight_scale)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ):
+        return apply_fp8_linear(
+            x,
+            layer.weight,
+            layer.weight_scale,
+            bias=bias,
+            cutlass_fp8_supported=self.cutlass_fp8_supported,
+        )
```