sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff shows the changes between publicly released versions of this package as they appear in their public registry. It is provided for informational purposes only.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +164 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +62 -23
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +26 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +4 -5
- sglang/srt/managers/data_parallel_controller.py +31 -9
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -31
- sglang/srt/managers/scheduler.py +325 -38
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +27 -8
- sglang/srt/mem_cache/memory_pool.py +258 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +85 -28
- sglang/srt/model_executor/forward_batch_info.py +81 -15
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +326 -192
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +145 -47
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +104 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
Expanded diffs are shown below for the quantization modules included in this view.

sglang/srt/layers/quantization/fp8_utils.py

```diff
@@ -1,8 +1,10 @@
+import os
 from typing import List, Optional, Tuple
 
 import torch
 
 from sglang.srt.layers.quantization.fp8_kernel import (
+    _enable_jit_deepgemm,
     per_token_group_quant_fp8,
     static_quant_fp8,
     w8a8_block_fp8_matmul,
@@ -15,6 +17,14 @@ from sglang.srt.utils import (
     is_hip,
 )
 
+try:
+    import vllm
+    from vllm import _custom_ops as ops
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    VLLM_AVAILABLE = False
+
 use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
 
 _is_hip = is_hip()
@@ -23,19 +33,29 @@ if _is_hip and get_bool_env_var("CK_MOE"):
 
 _is_cuda = is_cuda()
 if _is_cuda:
-    from sgl_kernel import fp8_blockwise_scaled_mm
+    from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm
 
+    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
     from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8
 
-    if use_vllm_cutlass_w8a8_fp8_kernel:
-        from vllm import _custom_ops as ops
-    else:
-        from sgl_kernel import fp8_scaled_mm
-
 # Input scaling factors are no longer optional in _scaled_mm starting
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
 TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
 
+_TORCH_VERSION = torch.__version__.split("+")[0]
+try:
+    _TORCH_VERSION_TUPLE = tuple(map(int, _TORCH_VERSION.split(".")[:3]))
+except ValueError:
+    _TORCH_VERSION_TUPLE = (0, 0, 0)
+
+# The condition to determine if it is on a platform that supports
+# torch._scaled_mm rowwise feature.
+# The condition is determined once as the operations
+# are time consuming.
+USE_ROWWISE_TORCH_SCALED_MM = (
+    _is_hip and get_device_capability() >= (9, 4) and _TORCH_VERSION_TUPLE >= (2, 7, 0)
+)
+
 
 def cutlass_fp8_supported():
     if not _is_cuda:
@@ -74,7 +94,7 @@ def normalize_e4m3fn_to_e4m3fnuz(
 
 
 def cutlass_block_fp8_supported() -> bool:
-    if get_bool_env_var("SUPPORT_CUTLASS_BLOCK_FP8"):
+    if not get_bool_env_var("SUPPORT_CUTLASS_BLOCK_FP8"):
         return False
     if _is_cuda:
         major, minor = torch.cuda.get_device_capability()
@@ -122,9 +142,17 @@ def apply_w8a8_block_fp8_linear(
         )
         gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
     else:
-        q_input, x_scale = per_token_group_quant_fp8(
-            input_2d, block_size[1], column_major_scales=False
-        )
+        if _enable_jit_deepgemm:
+            q_input, x_scale = per_token_group_quant_fp8(
+                input_2d,
+                block_size[1],
+                column_major_scales=True,
+                scale_tma_aligned=True,
+            )
+        else:
+            q_input, x_scale = per_token_group_quant_fp8(
+                input_2d, block_size[1], column_major_scales=False
+            )
         output = w8a8_block_fp8_matmul(
             q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
         )
@@ -219,24 +247,32 @@ def apply_fp8_linear(
         )
 
     if cutlass_fp8_supported:
-        [old lines 222-239 removed; their content is not rendered in the source view]
+        try:
+            if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
+                # Fall back to vllm cutlass w8a8 fp8 kernel
+                output = ops.cutlass_scaled_mm(
+                    qinput,
+                    weight,
+                    out_dtype=input.dtype,
+                    scale_a=x_scale,
+                    scale_b=weight_scale,
+                    bias=bias,
+                )
+            else:
+                assert (
+                    weight_scale.numel() == weight.shape[1]
+                ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
+                output = fp8_scaled_mm(
+                    qinput,
+                    weight,
+                    x_scale,
+                    weight_scale,
+                    out_dtype=input.dtype,
+                    bias=bias,
+                )
+            return output.view(*output_shape)
+        except (ImportError, NameError, AttributeError):
+            pass
 
     # torch.scaled_mm supports per tensor weights + activations only
     # so fallback to naive if per channel or per token
@@ -306,3 +342,223 @@
         if bias is not None:
             output = output + bias
         return output.to(dtype=input.dtype).view(*output_shape)
+
+
+def maybe_create_device_identity():
+    # Allocate dummy ones tensor for torch._scaled_mm
+    global TORCH_DEVICE_IDENTITY
+    if TORCH_DEVICE_IDENTITY is None:
+        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
+
+
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+# TODO(luka): follow similar pattern for marlin and block-fp8-linear
+# https://github.com/vllm-project/vllm/issues/14397
+class Fp8LinearOp:
+    """
+    This class executes a FP8 linear layer using cutlass if supported and
+    torch.scaled_mm otherwise.
+    It needs to be a class instead of a method so that config can be read
+    in the __init__ method, as reading config is not allowed inside forward.
+    """
+
+    def __init__(
+        self,
+        cutlass_fp8_supported: bool = cutlass_fp8_supported(),
+        use_per_token_if_dynamic: bool = False,
+        pad_output: Optional[bool] = None,
+    ):
+        self.cutlass_fp8_supported = cutlass_fp8_supported
+        self.use_per_token_if_dynamic = use_per_token_if_dynamic
+
+        # Note: we pad the input because torch._scaled_mm is more performant
+        # for matrices with batch dimension > 16.
+        # This could change in the future.
+        # We also don't pad when using torch.compile,
+        # as it breaks with dynamic shapes.
+        if pad_output is None:
+            enable_torch_compile = os.environ.get(
+                "SGLANG_ENABLE_TORCH_COMPILE", "0"
+            ).lower() in ("1", "true", "yes")
+            pad_output = not enable_torch_compile
+        self.output_padding = 17 if pad_output else None
+
+    def apply(
+        self,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        input_scale: Optional[torch.Tensor] = None,
+        input_scale_ub: Optional[torch.Tensor] = None,
+        bias: Optional[torch.Tensor] = None,
+        # TODO(luka) remove this parameter in favor of __init__
+        use_per_token_if_dynamic: Optional[bool] = None,
+    ) -> torch.Tensor:
+        # ops.scaled_fp8_quant supports both dynamic and static quant.
+        # If dynamic, layer.input_scale is None and x_scale computed from x.
+        # If static, layer.input_scale is scalar and x_scale is input_scale.
+
+        # View input as 2D matrix for fp8 methods
+        input_2d = input.view(-1, input.shape[-1])
+        output_shape = [*input.shape[:-1], weight.shape[1]]
+
+        # TODO(luka) this is here because currently MLA only decides this
+        # during the forward method instead of in __init__.
+        if use_per_token_if_dynamic is None:
+            use_per_token_if_dynamic = self.use_per_token_if_dynamic
+
+        # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
+        # for sgl-kernel fp8_scaled_mm, it support per channel W now
+        if self.cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
+            if _is_cuda:
+                qinput, x_scale = sgl_scaled_fp8_quant(
+                    input_2d,
+                    input_scale,
+                    use_per_token_if_dynamic=use_per_token_if_dynamic,
+                )
+            else:
+                qinput, x_scale = ops.scaled_fp8_quant(
+                    input_2d,
+                    input_scale,
+                    scale_ub=input_scale_ub,
+                    use_per_token_if_dynamic=use_per_token_if_dynamic,
+                )
+
+            # Fused GEMM_DQ
+            if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
+                # Fall back to vllm cutlass w8a8 fp8 kernel
+                output = ops.cutlass_scaled_mm(
+                    qinput,
+                    weight,
+                    out_dtype=input.dtype,
+                    scale_a=x_scale,
+                    scale_b=weight_scale,
+                    bias=bias,
+                )
+            else:
+                assert (
+                    weight_scale.numel() == weight.shape[1]
+                ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
+                output = fp8_scaled_mm(
+                    qinput,
+                    weight,
+                    x_scale,
+                    weight_scale,
+                    out_dtype=input.dtype,
+                    bias=bias,
+                )
+            return output.view(*output_shape)
+
+        # torch.scaled_mm supports per tensor weights + activations only
+        # so fallback to naive if per channel or per token
+        else:
+            # Maybe apply padding to output, see comment in __init__
+            if _is_cuda:
+                qinput, x_scale = sgl_scaled_fp8_quant(
+                    input_2d,
+                    input_scale,
+                    use_per_token_if_dynamic=use_per_token_if_dynamic,
+                )
+                if self.output_padding:
+                    pad_size = max(self.output_padding - qinput.shape[0], 0)
+                    if pad_size > 0:
+                        qinput = torch.nn.functional.pad(qinput, (0, 0, 0, pad_size))
+            else:
+                qinput, x_scale = ops.scaled_fp8_quant(
+                    input_2d,
+                    input_scale,
+                    num_token_padding=self.output_padding,
+                    use_per_token_if_dynamic=use_per_token_if_dynamic,
+                )
+
+            per_tensor_weights = weight_scale.numel() == 1
+            per_tensor_activations = x_scale.numel() == 1
+
+            if per_tensor_weights and per_tensor_activations:
+                # Fused GEMM_DQ
+                output = torch._scaled_mm(
+                    qinput,
+                    weight,
+                    out_dtype=input.dtype,
+                    scale_a=x_scale,
+                    scale_b=weight_scale,
+                    bias=bias,
+                )
+                # A fix for discrepancy in scaled_mm which returns tuple
+                # for torch < 2.5 and a single value in torch >= 2.5
+                if type(output) is tuple and len(output) == 2:
+                    output = output[0]
+
+                return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
+
+            elif (
+                use_per_token_if_dynamic
+                and not per_tensor_weights
+                and not per_tensor_activations
+                and USE_ROWWISE_TORCH_SCALED_MM
+            ):
+                # For now validated on ROCm platform
+                # fp8 rowwise scaling in torch._scaled_mm is introduced in
+                # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
+                # and ROCm 6.3, which only exists in torch 2.7 and above.
+                # For CUDA platform please validate if the
+                # torch._scaled_mm support rowwise scaled GEMM
+                # Fused GEMM_DQ Rowwise GEMM
+                output = torch._scaled_mm(
+                    qinput,
+                    weight,
+                    out_dtype=input.dtype,
+                    scale_a=x_scale,
+                    scale_b=weight_scale.t(),
+                    bias=bias,
+                )
+
+                output = torch.narrow(output, 0, 0, input_2d.shape[0])
+                output = output.view(*output_shape)
+                return output
+
+            else:
+                # Fallback for channelwise case, where we use unfused DQ
+                # due to limitations with scaled_mm
+
+                # Symmetric quantized GEMM by definition computes the following:
+                # C = (s_x * X) (s_w * W) + bias
+                # This is equivalent to dequantizing the weights and activations
+                # before applying a GEMM.
+                #
+                # In order to compute quantized operands, a quantized kernel
+                # will rewrite the above like so:
+                # C = s_w * s_x * (X * W) + bias
+                #
+                # For the scaled_mm fallback case, we break this down, since it
+                # does not support s_w being a vector.
+
+                # GEMM
+                # This computes C = (X * W).
+                # Output in fp32 to allow subsequent ops to happen in-place
+
+                global TORCH_DEVICE_IDENTITY
+                if TORCH_DEVICE_IDENTITY.device != weight.device:
+                    TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
+
+                output = torch._scaled_mm(
+                    qinput,
+                    weight,
+                    scale_a=TORCH_DEVICE_IDENTITY,
+                    scale_b=TORCH_DEVICE_IDENTITY,
+                    out_dtype=torch.float32,
+                )
+                # A fix for discrepancy in scaled_mm which returns tuple
+                # for torch < 2.5 and a single value in torch >= 2.5
+                if type(output) is tuple and len(output) == 2:
+                    output = output[0]
+                # Unpad (undo num_token_padding)
+                output = torch.narrow(output, 0, 0, input_2d.shape[0])
+                x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
+
+                # DQ
+                # C = sw * sx * (X * W) + bias
+                output = output * x_scale * weight_scale.t()
+                if bias is not None:
+                    output = output + bias
+                return output.to(dtype=input.dtype).view(*output_shape)
```
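The diff above introduces `Fp8LinearOp`, which chooses between the cutlass path (sgl-kernel or the vLLM fallback) and `torch._scaled_mm` based on state decided once at construction time. Below is a minimal usage sketch, not taken from the package: it assumes a layer object that already carries an fp8 `weight` shaped `[in_features, out_features]` (as implied by `output_shape` using `weight.shape[1]`) plus `weight_scale` and optional `input_scale` tensors, and the class name `SketchFp8LinearMethod` is purely illustrative.

```python
# Illustrative only: a quant method holding one Fp8LinearOp and reusing it per
# forward call, matching the class docstring ("config can be read in __init__").
from typing import Optional

import torch

from sglang.srt.layers.quantization.fp8_utils import Fp8LinearOp


class SketchFp8LinearMethod:  # hypothetical name, not part of sglang
    def __init__(self):
        # Cutlass support, output padding, and the per-token policy are decided
        # once here, outside the forward pass.
        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # `layer.weight` is assumed to be an fp8 tensor of shape
        # [in_features, out_features]; Fp8LinearOp.apply derives the output
        # width from weight.shape[1].
        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=getattr(layer, "input_scale", None),
            bias=bias,
        )
```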
sglang/srt/layers/quantization/gptq.py

```diff
@@ -3,11 +3,19 @@ from fractions import Fraction
 from typing import Any, Dict, List, Optional, Union
 
 import torch
-from vllm.scalar_type import scalar_types
 
 from sglang.srt.layers.linear import LinearBase
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+
+try:
+    import vllm
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    VLLM_AVAILABLE = False
 
 logger = logging.getLogger(__name__)
 
@@ -110,6 +118,9 @@ class GPTQConfig(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["GPTQLinearMethod"]:
+        if not VLLM_AVAILABLE:
+            raise ImportError("vllm is not installed")
+
         from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
 
         from sglang.srt.layers.quantization import get_linear_quant_method
@@ -120,11 +131,16 @@ class GPTQConfig(QuantizationConfig):
 class GPTQMarlinConfig(QuantizationConfig):
     """Config class for GPTQ Marlin"""
 
-    # (num_bits, is_sym) -> quant_type
-    TYPE_MAP = {
-        (4, True): scalar_types.uint4b8,
-        (8, True): scalar_types.uint8b128,
-    }
+    if VLLM_AVAILABLE:
+        from vllm.scalar_type import scalar_types
+
+        # (num_bits, is_sym) -> quant_type
+        TYPE_MAP = {
+            (4, True): scalar_types.uint4b8,
+            (8, True): scalar_types.uint8b128,
+        }
+    else:
+        raise ImportError("vllm is not installed")
 
     def __init__(
         self,
@@ -263,6 +279,9 @@ class GPTQMarlinConfig(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
+        if not VLLM_AVAILABLE:
+            raise ImportError("vllm is not installed")
+
         from vllm.model_executor.layers.quantization.gptq_marlin import (
             GPTQMarlinLinearMethod,
             GPTQMarlinMoEMethod,
@@ -285,6 +304,9 @@ class GPTQMarlinConfig(QuantizationConfig):
 
     @classmethod
     def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
+        if not VLLM_AVAILABLE:
+            return False
+
         quant_method = quant_config.get("quant_method", "").lower()
         num_bits = quant_config.get("bits")
         group_size = quant_config.get("group_size")
@@ -294,9 +316,8 @@ class GPTQMarlinConfig(QuantizationConfig):
         from vllm.model_executor.layers.quantization.utils.marlin_utils import (
             check_marlin_supported,
         )
-        from vllm.platforms import current_platform
 
-        if not
+        if not _is_cuda:
             return False
 
         if quant_method != "gptq":
@@ -407,8 +428,14 @@ class MarlinConfig(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["MarlinLinearMethod"]:
+        if not VLLM_AVAILABLE:
+            raise ImportError("vllm is not installed")
+
         from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
 
+        # Delay import to avoid circular dependency
+        from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+
         if isinstance(layer, LinearBase) or (
             isinstance(layer, ParallelLMHead) and self.lm_head_quantized
         ):
```
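The recurring change in this file is that vLLM becomes an optional dependency: the module probes for it once at import time and every vLLM-backed entry point is guarded, instead of failing at module import as the removed top-level `from vllm.scalar_type import scalar_types` did in 0.4.4. A minimal sketch of the same pattern follows; the helper name `require_vllm` is illustrative and not part of sglang.

```python
# Sketch of the optional-dependency guard used throughout gptq.py (illustrative).
try:
    import vllm  # noqa: F401

    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False


def require_vllm(feature: str) -> None:
    # Hypothetical helper: raise only when a vLLM-backed path is actually
    # requested, rather than when the module is imported.
    if not VLLM_AVAILABLE:
        raise ImportError(f"vllm is not installed, but {feature} requires it")


# Usage: call the guard at the top of any method that then imports from vllm,
# e.g. require_vllm("GPTQLinearMethod") before importing GPTQLinearMethod.
```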
sglang/srt/layers/quantization/kv_cache.py (new file)

```diff
@@ -0,0 +1,98 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/kv_cache.py
+
+import logging
+
+import torch
+
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.utils import is_hip
+
+_is_hip = is_hip()
+
+logger = logging.getLogger(__name__)
+
+
+class BaseKVCacheMethod(QuantizeMethodBase):
+    """
+    Quant method that adds `_k_scale` and `_v_scale` attributes to the
+    Attention layer to support loading those scaling factors from checkpoints.
+    The k/v_scale will be used to:
+        - quantize k/v_cache entries before saving them to the cache
+        - dequantize k/v_cache entries before fetching them from the cache
+
+    :param quant_config: the appropriate QuantizationConfig
+    """
+
+    def __init__(self, quant_config: QuantizationConfig):
+        self.quant_config = quant_config
+
+    def create_weights(self, layer: torch.nn.Module):
+        """
+        Create "weight" (aka k_scale and v_scale) for an attention layer.
+        """
+        # Initialize the KV cache scales to -1.0, which is an invalid value.
+        # If the k/v_scale appears in the checkpoint, it will be
+        # overwritten when loading weights.
+        layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
+        layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
+
+    @classmethod
+    def is_fp8_fnuz(cls) -> bool:
+        # only device 0 is checked, this assumes MI300 platforms are homogeneous
+        return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
+
+    def apply(self, layer: torch.nn.Module) -> torch.Tensor:
+        raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
+        # regardless whether the kv-scale is available in the checkpoint.
+        # No need to process kv scales after loading if we are going to
+        # calculate them on the fly.
+        if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales:
+            if layer.k_scale > 0.0 and layer.v_scale > 0.0:
+                # We prefer to use separate k_scale and v_scale if present
+                k_scale = layer.k_scale.to("cpu").tolist()
+                v_scale = layer.v_scale.to("cpu").tolist()
+                if _is_hip and self.is_fp8_fnuz():
+                    k_scale *= 2
+                    v_scale *= 2
+            elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
+                # If no scales were loaded (both scales are invalid negative
+                # values), use the default value of 1.0
+                k_scale = 1.0
+                v_scale = 1.0
+            else:
+                # If we find a single kv_scale in the checkpoint, we remap
+                # kv_scale to k_scale during weight loading, and duplicate
+                # k_scale to v_scale here
+                assert layer.k_scale > 0.0
+                scale_to_duplicate = max(layer.k_scale, layer.v_scale)
+                k_scale = scale_to_duplicate.to("cpu").tolist()
+                v_scale = scale_to_duplicate.to("cpu").tolist()
+                if _is_hip and self.is_fp8_fnuz():
+                    k_scale *= 2
+                    v_scale *= 2
+
+            if not isinstance(k_scale, float) or not isinstance(v_scale, float):
+                raise ValueError(
+                    "Only support per-tensor scaling factor " "for fp8 KV cache"
+                )
+
+            # These are used in the final Attention.forward()
+            layer._k_scale.copy_(k_scale)
+            layer._v_scale.copy_(v_scale)
+            layer._k_scale_float = k_scale
+            layer._v_scale_float = v_scale
+            if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
+                logger.warning(
+                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
+                    "may cause accuracy issues. Please make sure k/v_scale "
+                    "scaling factors are available in the fp8 checkpoint."
+                )
+
+        del layer.k_scale
+        del layer.v_scale
```
sglang/srt/layers/quantization/modelopt_quant.py

```diff
@@ -5,12 +5,6 @@ from typing import Any, Dict, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
-from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    convert_to_channelwise,
-    cutlass_fp8_supported,
-    requantize_with_max_scale,
-)
 
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.linear import LinearBase, LinearMethodBase
@@ -19,7 +13,15 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8_utils import
+from sglang.srt.layers.quantization.fp8_utils import (
+    apply_fp8_linear,
+    cutlass_fp8_supported,
+)
+from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
+from sglang.srt.layers.quantization.utils import (
+    convert_to_channelwise,
+    requantize_with_max_scale,
+)
 
 # Initialize logger for the module
 logger = logging.getLogger(__name__)
```
|