sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
- sglang/bench_serving.py +72 -10
- sglang/srt/_custom_ops.py +59 -92
- sglang/srt/configs/deepseekvl2.py +10 -1
- sglang/srt/configs/model_config.py +6 -16
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/custom_op.py +5 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +28 -80
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/parallel_state.py +32 -5
- sglang/srt/entrypoints/engine.py +0 -5
- sglang/srt/entrypoints/http_server.py +7 -1
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/attention/flashattention_backend.py +582 -125
- sglang/srt/layers/attention/flashinfer_backend.py +5 -7
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/dp_attention.py +12 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
- sglang/srt/layers/moe/ep_moe/layer.py +79 -80
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
- sglang/srt/layers/moe/topk.py +79 -6
- sglang/srt/layers/quantization/__init__.py +137 -165
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
- sglang/srt/layers/quantization/fp8_kernel.py +2 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -4
- sglang/srt/layers/quantization/gptq.py +30 -40
- sglang/srt/layers/quantization/moe_wna16.py +501 -0
- sglang/srt/layers/quantization/utils.py +1 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +19 -33
- sglang/srt/lora/lora_manager.py +20 -7
- sglang/srt/lora/mem_pool.py +12 -6
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +6 -0
- sglang/srt/managers/cache_controller.py +34 -11
- sglang/srt/managers/io_struct.py +4 -2
- sglang/srt/managers/mm_utils.py +202 -156
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
- sglang/srt/managers/multimodal_processors/clip.py +44 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
- sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
- sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
- sglang/srt/managers/multimodal_processors/llava.py +34 -14
- sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
- sglang/srt/managers/multimodal_processors/mlama.py +10 -23
- sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
- sglang/srt/managers/schedule_batch.py +185 -127
- sglang/srt/managers/scheduler.py +29 -23
- sglang/srt/managers/tokenizer_manager.py +1 -2
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/utils.py +1 -6
- sglang/srt/mem_cache/hiradix_cache.py +62 -52
- sglang/srt/mem_cache/memory_pool.py +72 -6
- sglang/srt/mem_cache/paged_allocator.py +39 -0
- sglang/srt/metrics/collector.py +23 -53
- sglang/srt/model_executor/cuda_graph_runner.py +16 -13
- sglang/srt/model_executor/forward_batch_info.py +10 -10
- sglang/srt/model_executor/model_runner.py +64 -59
- sglang/srt/model_loader/loader.py +19 -1
- sglang/srt/model_loader/weight_utils.py +6 -3
- sglang/srt/models/clip.py +568 -0
- sglang/srt/models/deepseek_janus_pro.py +12 -17
- sglang/srt/models/deepseek_v2.py +339 -123
- sglang/srt/models/deepseek_vl2.py +105 -104
- sglang/srt/models/gemma3_causal.py +12 -2
- sglang/srt/models/gemma3_mm.py +20 -80
- sglang/srt/models/llama.py +4 -1
- sglang/srt/models/llava.py +31 -19
- sglang/srt/models/llavavid.py +16 -7
- sglang/srt/models/minicpmo.py +63 -147
- sglang/srt/models/minicpmv.py +17 -27
- sglang/srt/models/mllama.py +29 -14
- sglang/srt/models/qwen2.py +9 -6
- sglang/srt/models/qwen2_5_vl.py +21 -31
- sglang/srt/models/qwen2_vl.py +20 -21
- sglang/srt/openai_api/adapter.py +106 -93
- sglang/srt/openai_api/protocol.py +10 -5
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/platforms/interface.py +371 -0
- sglang/srt/server_args.py +120 -25
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
- sglang/srt/speculative/eagle_utils.py +140 -28
- sglang/srt/speculative/eagle_worker.py +94 -25
- sglang/srt/utils.py +137 -51
- sglang/test/runners.py +27 -2
- sglang/test/test_custom_ops.py +55 -0
- sglang/test/test_utils.py +14 -27
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +10 -5
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +108 -99
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/sglang/srt/layers/quantization/awq.py
@@ -0,0 +1,200 @@
+# SPDX-License-Identifier: Apache-2.0
+import logging
+from typing import Any, Dict, List, Optional
+
+import torch
+from sgl_kernel import awq_dequantize
+
+from sglang.srt.layers.linear import (
+    LinearBase,
+    LinearMethodBase,
+    UnquantizedLinearMethod,
+)
+from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+
+logger = logging.getLogger(__name__)
+
+
+def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
+    return any(module_name in prefix for module_name in modules_to_not_convert)
+
+
+class AWQConfig(QuantizationConfig):
+    """Config class for AWQ.
+
+    Reference: https://arxiv.org/abs/2306.00978
+    """
+
+    def __init__(
+        self,
+        weight_bits: int,
+        group_size: int,
+        zero_point: bool,
+        modules_to_not_convert: Optional[List[str]] = None,
+    ) -> None:
+        super().__init__()
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+        self.zero_point = zero_point
+        self.modules_to_not_convert = modules_to_not_convert or []
+
+        if self.weight_bits != 4:
+            raise ValueError(
+                "Currently, only 4-bit weight quantization is supported for "
+                f"AWQ, but got {self.weight_bits} bits."
+            )
+        self.pack_factor = 32 // self.weight_bits
+
+    def __repr__(self) -> str:
+        return (
+            f"AWQConfig(weight_bits={self.weight_bits}, "
+            f"group_size={self.group_size}, "
+            f"zero_point={self.zero_point}, "
+            f"modules_to_not_convert={self.modules_to_not_convert})"
+        )
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+    def get_name(self) -> str:
+        return "awq"
+
+    def get_supported_act_dtypes(self) -> List[torch.dtype]:
+        return [torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        # The AWQ kernel only supports Turing or newer GPUs.
+        return 75
+
+    @staticmethod
+    def get_config_filenames() -> List[str]:
+        return [
+            "quant_config.json",  # E.g., casperhansen/vicuna-7b-v1.5-awq
+            # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
+            "quantize_config.json",
+        ]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
+        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
+        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
+        zero_point = cls.get_from_keys(config, ["zero_point"])
+        modules_to_not_convert = cls.get_from_keys_or(
+            config, ["modules_to_not_convert"], None
+        )
+        return cls(weight_bits, group_size, zero_point, modules_to_not_convert)
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["LinearMethodBase"]:
+
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+                return UnquantizedLinearMethod()
+            return AWQLinearMethod(self)
+        return None
+
+
+class AWQLinearMethod(LinearMethodBase):
+    """Linear method for AWQ.
+
+    Args:
+        quant_config: The AWQ quantization config.
+    """
+
+    def __init__(self, quant_config: AWQConfig):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        if input_size_per_partition % self.quant_config.group_size != 0:
+            raise ValueError(
+                "The input size is not aligned with the quantized "
+                "weight shape. This can be caused by too large "
+                "tensor parallel size."
+            )
+
+        output_size_per_partition = sum(output_partition_sizes)
+        if output_size_per_partition % self.quant_config.pack_factor != 0:
+            raise ValueError(
+                "The output size is not aligned with the quantized "
+                "weight shape. This can be caused by too large "
+                "tensor parallel size."
+            )
+
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
+                input_size_per_partition,
+                output_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader,
+        )
+
+        qzeros = PackedvLLMParameter(
+            data=torch.empty(
+                input_size_per_partition // self.quant_config.group_size,
+                output_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader,
+        )
+
+        scales = GroupQuantScaleParameter(
+            data=torch.empty(
+                input_size_per_partition // self.quant_config.group_size,
+                output_size_per_partition,
+                dtype=params_dtype,
+            ),
+            input_dim=0,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+
+        layer.register_parameter("qweight", qweight)
+        layer.register_parameter("qzeros", qzeros)
+        layer.register_parameter("scales", scales)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
+        layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
+        layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        qweight = layer.qweight
+        scales = layer.scales
+        qzeros = layer.qzeros
+        pack_factor = self.quant_config.pack_factor
+        out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
+        reshaped_x = x.reshape(-1, x.shape[-1])
+
+        out = awq_dequantize(qweight, scales, qzeros)
+        out = torch.matmul(reshaped_x, out)
+
+        if bias is not None:
+            out.add_(bias)
+        return out.reshape(out_shape)
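Note: the new module above registers AWQ 4-bit checkpoints as a quantization backend. A minimal usage sketch (not part of the diff; the config values are illustrative) of how AWQConfig.from_config consumes a Hugging Face style quantization config and what the packed parameter shapes work out to:

from sglang.srt.layers.quantization.awq import AWQConfig

# Illustrative checkpoint config; real checkpoints carry this in config.json
# under "quantization_config" or in quantize_config.json.
hf_quant_config = {
    "bits": 4,
    "group_size": 128,
    "zero_point": True,
    "modules_to_not_convert": ["lm_head"],
}

cfg = AWQConfig.from_config(hf_quant_config)
assert cfg.pack_factor == 8  # eight 4-bit values packed into one int32

# For a linear layer with input_size_per_partition=4096 and
# output_size_per_partition=11008, create_weights allocates:
#   qweight: [4096, 11008 // 8]        int32
#   qzeros:  [4096 // 128, 11008 // 8] int32
#   scales:  [4096 // 128, 11008]      activation dtype
# apply() then dequantizes with sgl_kernel.awq_dequantize and runs a plain matmul.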
--- a/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -23,7 +23,6 @@ from sglang.srt.layers.linear import (
     LinearMethodBase,
     UnquantizedLinearMethod,
 )
-from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -123,6 +122,8 @@ class CompressedTensorsConfig(QuantizationConfig):
                 return UnquantizedLinearMethod()
             layer.scheme = scheme
             return CompressedTensorsLinearMethod(self)
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
         if isinstance(layer, FusedMoE):
             return CompressedTensorsMoEMethod.get_moe_method(self)
         return None
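The only change here is moving the FusedMoE import from module scope into get_quant_method. A generic sketch of why that breaks an import cycle (the module names below are hypothetical, not sglang's real layout):

# pkg/quant.py  (hypothetical)
#   from pkg.moe import FusedMoE   # module-level: pkg.moe imports pkg.quant back,
#                                  # so whichever module loads first fails with a
#                                  # "partially initialized module" ImportError.
# pkg/moe.py    (hypothetical)
#   from pkg.quant import QuantizationConfig

# Fix used throughout this release: defer one side of the cycle to call time.
def get_quant_method(config, layer):
    from pkg.moe import FusedMoE  # resolved only when first called, after both
                                  # modules have finished importing

    if isinstance(layer, FusedMoE):
        return config.moe_method
    return None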
--- a/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -4,18 +4,19 @@
 import enum
 import logging
 from enum import Enum
-from typing import Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional
 
 import torch
 from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import QuantizationStrategy
 
-
-
-
-
-
-
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.fused_moe_triton import (
+        FusedMoE,
+        FusedMoEMethodBase,
+        FusedMoeWeightScaleSupported,
+    )
+
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
@@ -55,7 +56,13 @@ __all__ = [
 ]
 
 
-class CompressedTensorsMoEMethod
+class CompressedTensorsMoEMethod:
+    def __new__(cls, *args, **kwargs):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+        if cls is CompressedTensorsMoEMethod:
+            return super().__new__(cls)
+        return super().__new__(cls)
 
     @staticmethod
     def get_moe_method(
@@ -85,6 +92,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
     ):
+        from sglang.srt.layers.moe.fused_moe_triton import (
+            FusedMoEMethodBase,
+            FusedMoeWeightScaleSupported,
+        )
+
         self.quant_config = quant_config
         self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
@@ -112,6 +124,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 
         params_dtype = torch.float8_e4m3fn
 
@@ -270,8 +283,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         scoring_func: str = "softmax",
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        inplace: bool = True,
+        no_combine: bool = False,
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton
+        from sglang.srt.layers.moe.fused_moe_triton import fused_experts
+        from sglang.srt.layers.moe.topk import select_experts
 
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
@@ -291,7 +307,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             layer.w2_weight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            inplace=
+            inplace=inplace,
             activation=activation,
             use_fp8_w8a8=True,
             w1_scale=layer.w13_weight_scale,
@@ -306,6 +322,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
        self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
    ):
+        from sglang.srt.layers.moe.fused_moe_triton import (
+            FusedMoEMethodBase,
+            FusedMoeWeightScaleSupported,
+        )
+
         self.quant_config = quant_config
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
@@ -617,6 +638,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
     ) -> torch.Tensor:
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.moe.topk import select_experts
+
         assert activation == "silu", "Only SiLU activation is supported."
         if not VLLM_AVAILABLE:
             raise ImportError(
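The MoE hunks above pair the runtime imports shown earlier with an `if TYPE_CHECKING:` block that keeps the FusedMoE names visible to type checkers without importing them when the module loads. A self-contained sketch of that pattern, with a standard-library class standing in for FusedMoE:

from __future__ import annotations  # keep annotations unevaluated at runtime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only static type checkers execute this block; at runtime it is skipped,
    # so it cannot create an import cycle. (array stands in for FusedMoE.)
    from array import array


def weight_sum(layer: array) -> int:
    """The annotation uses the type-only name; the body re-imports it."""
    from array import array  # runtime import, as in the hunks above

    assert isinstance(layer, array)
    return sum(layer)


print("array" in globals())  # False: the guarded import never ran
print(weight_sum(__import__("array").array("i", [1, 2, 3])))  # 6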
--- a/sglang/srt/layers/quantization/fp8_kernel.py
+++ b/sglang/srt/layers/quantization/fp8_kernel.py
@@ -24,6 +24,7 @@ import triton.language as tl
 
 from sglang.srt.utils import (
     direct_register_custom_op,
+    get_bool_env_var,
     get_device_core_count,
     get_device_name,
     get_device_sm,
@@ -43,7 +44,7 @@ if _is_cuda:
     from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
 
     sm_version = get_device_sm()
-    if sm_version >= 90 and
+    if sm_version >= 90 and get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
         _enable_jit_deepgemm = True
 
 
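fp8_kernel.py now gates JIT DeepGEMM on both the GPU generation (SM90+) and an environment variable that defaults to enabled. The real helper is imported from sglang.srt.utils; the sketch below is only an approximation of how such a flag is typically parsed, not the exact implementation:

import os


def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Approximation; sglang's helper lives in sglang.srt.utils and may accept
    # additional spellings.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")


# Default-on behaviour: users opt out by exporting SGL_ENABLE_JIT_DEEPGEMM=false
# before starting the server.
print(get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"))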
--- a/sglang/srt/layers/quantization/fp8_utils.py
+++ b/sglang/srt/layers/quantization/fp8_utils.py
@@ -457,12 +457,9 @@ class Fp8LinearOp:
             qinput, x_scale = sgl_scaled_fp8_quant(
                 input_2d,
                 input_scale,
+                num_token_padding=self.output_padding,
                 use_per_token_if_dynamic=use_per_token_if_dynamic,
             )
-            if self.output_padding:
-                pad_size = max(self.output_padding - qinput.shape[0], 0)
-                if pad_size > 0:
-                    qinput = torch.nn.functional.pad(qinput, (0, 0, 0, pad_size))
         else:
             qinput, x_scale = ops.scaled_fp8_quant(
                 input_2d,
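Fp8LinearOp previously padded the quantized activations itself; the padding is now requested directly from sgl_scaled_fp8_quant via num_token_padding. The removed logic, reproduced as a standalone sketch (illustrative shapes), shows what the kernel argument now has to guarantee:

import torch
import torch.nn.functional as F


def pad_tokens(qinput: torch.Tensor, output_padding: int) -> torch.Tensor:
    # Pad the token (row) dimension up to output_padding; never truncate.
    pad_size = max(output_padding - qinput.shape[0], 0)
    if pad_size > 0:
        qinput = F.pad(qinput, (0, 0, 0, pad_size))
    return qinput


x = torch.zeros(3, 8)
print(pad_tokens(x, 8).shape)  # torch.Size([8, 8])
print(pad_tokens(x, 2).shape)  # torch.Size([3, 8]) -- already large enough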
--- a/sglang/srt/layers/quantization/gptq.py
+++ b/sglang/srt/layers/quantization/gptq.py
@@ -11,12 +11,29 @@ from sglang.srt.utils import is_cuda
 _is_cuda = is_cuda()
 
 try:
-    import
+    from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
+    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+    from vllm.model_executor.layers.quantization.gptq_marlin import (
+        GPTQMarlinLinearMethod,
+        GPTQMarlinMoEMethod,
+    )
+    from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
+    from vllm.model_executor.layers.quantization.utils.marlin_utils import (
+        check_marlin_supported,
+    )
+    from vllm.scalar_type import scalar_types
 
     VLLM_AVAILABLE = True
 except ImportError:
     VLLM_AVAILABLE = False
 
+    GPTQLinearMethod = MarlinLinearMethod = QuantizeMethodBase = Any
+
+    class scalar_types:
+        uint4b8 = "uint4b8"
+        uint8b128 = "uint8b128"
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -117,12 +134,8 @@ class GPTQConfig(QuantizationConfig):
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[
-
-            raise ImportError("vllm is not installed")
-
-        from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
-
+    ) -> Optional[GPTQLinearMethod]:
+        # Delay the import to avoid circular dependency
         from sglang.srt.layers.quantization import get_linear_quant_method
 
         return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
@@ -131,16 +144,11 @@ class GPTQConfig(QuantizationConfig):
 class GPTQMarlinConfig(QuantizationConfig):
     """Config class for GPTQ Marlin"""
 
-
-
-
-
-
-            (4, True): scalar_types.uint4b8,
-            (8, True): scalar_types.uint8b128,
-        }
-    else:
-        raise ImportError("vllm is not installed")
+    # (num_bits, is_sym) -> quant_type
+    TYPE_MAP = {
+        (4, True): scalar_types.uint4b8,
+        (8, True): scalar_types.uint8b128,
+    }
 
     def __init__(
         self,
@@ -197,6 +205,7 @@ class GPTQMarlinConfig(QuantizationConfig):
                 "Unsupported quantization config: " f"bits={weight_bits}, sym={is_sym}"
             )
 
+        # (num_bits, is_sym) -> quant_type
         self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
 
     def __repr__(self) -> str:
@@ -278,15 +287,8 @@ class GPTQMarlinConfig(QuantizationConfig):
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[
-
-            raise ImportError("vllm is not installed")
-
-        from vllm.model_executor.layers.quantization.gptq_marlin import (
-            GPTQMarlinLinearMethod,
-            GPTQMarlinMoEMethod,
-        )
-
+    ) -> Optional[QuantizeMethodBase]:
+        # Delay the import to avoid circular dependency
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
         from sglang.srt.layers.quantization import get_linear_quant_method
 
@@ -304,19 +306,12 @@ class GPTQMarlinConfig(QuantizationConfig):
 
     @classmethod
     def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
-        if not VLLM_AVAILABLE:
-            return False
-
         quant_method = quant_config.get("quant_method", "").lower()
         num_bits = quant_config.get("bits")
         group_size = quant_config.get("group_size")
         sym = quant_config.get("sym")
         desc_act = quant_config.get("desc_act")
 
-        from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-            check_marlin_supported,
-        )
-
         if not _is_cuda:
             return False
 
@@ -427,13 +422,8 @@ class MarlinConfig(QuantizationConfig):
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[
-
-            raise ImportError("vllm is not installed")
-
-        from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
-
-        # Delay import to avoid circular dependency
+    ) -> Optional[MarlinLinearMethod]:
+        # Delay the import to avoid circular dependency
         from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 
         if isinstance(layer, LinearBase) or (
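The rewritten import block in gptq.py gathers every vLLM symbol it may need into one try, and on ImportError substitutes Any placeholders plus a tiny scalar_types shim so that annotations and TYPE_MAP still resolve without vLLM installed. A self-contained sketch of the same optional-dependency pattern (the package and class names are hypothetical):

from typing import Any, Optional

try:
    from some_optional_pkg import FancyLinearMethod  # hypothetical dependency
    OPTIONAL_AVAILABLE = True
except ImportError:
    OPTIONAL_AVAILABLE = False

    # Placeholders keep annotations such as Optional[FancyLinearMethod]
    # importable even when the extra dependency is missing.
    FancyLinearMethod = Any

    class scalar_types:
        uint4b8 = "uint4b8"
        uint8b128 = "uint8b128"


def get_quant_method() -> Optional[FancyLinearMethod]:
    if not OPTIONAL_AVAILABLE:
        return None
    return FancyLinearMethod()


print(get_quant_method())  # None when the optional package is absent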
|