sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +72 -10
- sglang/srt/_custom_ops.py +59 -92
- sglang/srt/configs/deepseekvl2.py +10 -1
- sglang/srt/configs/model_config.py +6 -16
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/custom_op.py +5 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +28 -80
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/parallel_state.py +32 -5
- sglang/srt/entrypoints/engine.py +0 -5
- sglang/srt/entrypoints/http_server.py +7 -1
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/attention/flashattention_backend.py +582 -125
- sglang/srt/layers/attention/flashinfer_backend.py +5 -7
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/dp_attention.py +12 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
- sglang/srt/layers/moe/ep_moe/layer.py +79 -80
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
- sglang/srt/layers/moe/topk.py +79 -6
- sglang/srt/layers/quantization/__init__.py +137 -165
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
- sglang/srt/layers/quantization/fp8_kernel.py +2 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -4
- sglang/srt/layers/quantization/gptq.py +30 -40
- sglang/srt/layers/quantization/moe_wna16.py +501 -0
- sglang/srt/layers/quantization/utils.py +1 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +19 -33
- sglang/srt/lora/lora_manager.py +20 -7
- sglang/srt/lora/mem_pool.py +12 -6
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +6 -0
- sglang/srt/managers/cache_controller.py +34 -11
- sglang/srt/managers/io_struct.py +4 -2
- sglang/srt/managers/mm_utils.py +202 -156
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
- sglang/srt/managers/multimodal_processors/clip.py +44 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
- sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
- sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
- sglang/srt/managers/multimodal_processors/llava.py +34 -14
- sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
- sglang/srt/managers/multimodal_processors/mlama.py +10 -23
- sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
- sglang/srt/managers/schedule_batch.py +185 -127
- sglang/srt/managers/scheduler.py +29 -23
- sglang/srt/managers/tokenizer_manager.py +1 -2
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/utils.py +1 -6
- sglang/srt/mem_cache/hiradix_cache.py +62 -52
- sglang/srt/mem_cache/memory_pool.py +72 -6
- sglang/srt/mem_cache/paged_allocator.py +39 -0
- sglang/srt/metrics/collector.py +23 -53
- sglang/srt/model_executor/cuda_graph_runner.py +16 -13
- sglang/srt/model_executor/forward_batch_info.py +10 -10
- sglang/srt/model_executor/model_runner.py +64 -59
- sglang/srt/model_loader/loader.py +19 -1
- sglang/srt/model_loader/weight_utils.py +6 -3
- sglang/srt/models/clip.py +568 -0
- sglang/srt/models/deepseek_janus_pro.py +12 -17
- sglang/srt/models/deepseek_v2.py +339 -123
- sglang/srt/models/deepseek_vl2.py +105 -104
- sglang/srt/models/gemma3_causal.py +12 -2
- sglang/srt/models/gemma3_mm.py +20 -80
- sglang/srt/models/llama.py +4 -1
- sglang/srt/models/llava.py +31 -19
- sglang/srt/models/llavavid.py +16 -7
- sglang/srt/models/minicpmo.py +63 -147
- sglang/srt/models/minicpmv.py +17 -27
- sglang/srt/models/mllama.py +29 -14
- sglang/srt/models/qwen2.py +9 -6
- sglang/srt/models/qwen2_5_vl.py +21 -31
- sglang/srt/models/qwen2_vl.py +20 -21
- sglang/srt/openai_api/adapter.py +106 -93
- sglang/srt/openai_api/protocol.py +10 -5
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/platforms/interface.py +371 -0
- sglang/srt/server_args.py +120 -25
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
- sglang/srt/speculative/eagle_utils.py +140 -28
- sglang/srt/speculative/eagle_worker.py +94 -25
- sglang/srt/utils.py +137 -51
- sglang/test/runners.py +27 -2
- sglang/test/test_custom_ops.py +55 -0
- sglang/test/test_utils.py +14 -27
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +10 -5
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +108 -99
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/moe_wna16.py (new file)
@@ -0,0 +1,501 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/moe_wna16.py
+
+import logging
+from typing import Any, Callable, Dict, List, Optional
+
+import torch
+
+from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.distributed.parallel_state import get_tp_group
+from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+from sglang.srt.layers.quantization.awq import AWQConfig
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
+from sglang.srt.utils import get_device_capability, set_weight_attrs
+
+logger = logging.getLogger(__name__)
+
+
+class MoeWNA16Config(QuantizationConfig):
+    """Config class for MOE WNA16 (W8A16/W4A16) quantization."""
+
+    def __init__(
+        self,
+        linear_quant_method: str,
+        weight_bits: int,
+        group_size: int,
+        has_zp: bool,
+        lm_head_quantized: bool,
+        modules_to_not_convert: Optional[List[str]],
+        full_config: Dict[str, Any],
+    ) -> None:
+        super().__init__()
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+        self.has_zp = has_zp
+        self.bit8_pack_factor = 8 // self.weight_bits
+        self.lm_head_quantized = lm_head_quantized
+        self.linear_quant_method = linear_quant_method
+        self.full_config = full_config
+        self.use_marlin = False
+        # Avoid circular import
+
+        if self.linear_quant_method == "gptq":
+            self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(full_config)
+        elif self.linear_quant_method == "awq":
+            capability_tuple = get_device_capability()
+            device_capability = (
+                -1
+                if capability_tuple is None
+                else capability_tuple[0] * 10 + capability_tuple[1]
+            )
+            awq_min_capability = AWQConfig.get_min_capability()
+            if device_capability < awq_min_capability:
+                raise ValueError(
+                    "The quantization method moe_wna16 + awq is not supported "
+                    "for the current GPU. "
+                    f"Minimum capability: {awq_min_capability}. "
+                    f"Current capability: {device_capability}."
+                )
+        else:
+            raise ValueError("moe_wna16 only support gptq and awq.")
+
+        if modules_to_not_convert is None:
+            self.modules_to_not_convert = []
+        else:
+            self.modules_to_not_convert = modules_to_not_convert
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "moe_wna16"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.bfloat16, torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 70
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["quantize_config.json"]
+
+    def get_scaled_act_names(self) -> List[str]:
+        raise NotImplementedError
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "MoeWNA16Config":
+        quant_method = cls.get_from_keys(config, ["quant_method"])
+        weight_bits = cls.get_from_keys(config, ["bits"])
+        group_size = cls.get_from_keys(config, ["group_size"])
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
+        if quant_method == "gptq":
+            has_zp = not cls.get_from_keys(config, ["sym"])
+            modules_to_not_convert = []
+        elif quant_method == "awq":
+            has_zp = cls.get_from_keys(config, ["zero_point"])
+            modules_to_not_convert = cls.get_from_keys_or(
+                config, ["modules_to_not_convert"], None
+            )
+        else:
+            raise ValueError("moe_wna16 only support gptq and awq.")
+
+        return cls(
+            quant_method,
+            weight_bits,
+            group_size,
+            has_zp,
+            lm_head_quantized,
+            modules_to_not_convert,
+            config,
+        )
+
+    @classmethod
+    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
+        can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
+        if can_convert and user_quant == "moe_wna16":
+            return cls.get_name()
+        return None
+
+    @classmethod
+    def is_moe_wna16_compatible(cls, quant_config: Dict[str, Any]):
+        # Extract data from quant config.
+        quant_method = quant_config.get("quant_method", "").lower()
+        num_bits = quant_config.get("bits")
+        desc_act = quant_config.get("desc_act")
+
+        capability_tuple = get_device_capability()
+        device_capability = (
+            -1
+            if capability_tuple is None
+            else capability_tuple[0] * 10 + capability_tuple[1]
+        )
+        # Avoid circular import
+        awq_min_capability = AWQConfig.get_min_capability()
+
+        gptq_compatible = quant_method == "gptq" and not desc_act and num_bits in [4, 8]
+        awq_compatible = (
+            quant_method == "awq"
+            and num_bits == 4
+            and device_capability >= awq_min_capability
+        )
+
+        return gptq_compatible or awq_compatible
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        # avoid circular import
+        from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+
+        if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
+            return UnquantizedLinearMethod()
+        elif isinstance(layer, LinearBase):
+
+            if self.linear_quant_method == "gptq":
+                if self.use_marlin:
+                    return GPTQMarlinConfig.from_config(
+                        self.full_config
+                    ).get_quant_method(layer, prefix)
+                else:
+                    return GPTQConfig.from_config(self.full_config).get_quant_method(
+                        layer, prefix
+                    )
+            elif self.linear_quant_method == "awq":
+                return AWQConfig.from_config(self.full_config).get_quant_method(
+                    layer, prefix
+                )
+            else:
+                raise ValueError("moe_wna16 only support gptq and awq.")
+        elif isinstance(layer, FusedMoE):
+            return MoeWNA16Method(self)
+        return None
+
+
+def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]):
+    return any(module_name in prefix for module_name in modules_to_not_convert)
+
+
+class MoeWNA16Method:
+    """Linear method for MOE WNA16 (W8A16/W4A16) quantization.
+
+    Args:
+        quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
+    """
+
+    def __new__(cls, *args, **kwargs):
+        # avoid circular import
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    def __init__(self, quant_config: MoeWNA16Config):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        layer.quant_config = self.quant_config
+        bit8_pack_factor = self.quant_config.bit8_pack_factor
+        group_size = self.quant_config.group_size
+        group_size_div_factor = 1
+
+        # make intermediate_size and hidden_size diviable by group_size
+        # we reduce the group size to ensure that
+        # and we would repeat the loaded_weight later
+        while intermediate_size_per_partition % group_size or hidden_size % group_size:
+            group_size = group_size // 2
+            group_size_div_factor *= 2
+            assert group_size >= 32
+        layer.group_size = group_size
+        layer.group_size_div_factor = group_size_div_factor
+
+        strategy = FusedMoeWeightScaleSupported.GROUP.value
+        extra_weight_attrs.update({"quant_method": strategy, "is_transposed": False})
+
+        assert "weight_loader" in extra_weight_attrs
+        weight_loader = extra_weight_attrs["weight_loader"]
+        wrapped_weight_loader = MoeWNA16Method.get_weight_loader(layer, weight_loader)
+        extra_weight_attrs["weight_loader"] = wrapped_weight_loader
+
+        # Fused gate_up_proj (column parallel)
+        w13_qweight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size // bit8_pack_factor,
+                dtype=torch.uint8,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_qweight", w13_qweight)
+        set_weight_attrs(w13_qweight, extra_weight_attrs)
+
+        # down_proj (row parallel)
+        w2_qweight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition // bit8_pack_factor,
+                dtype=torch.uint8,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_qweight", w2_qweight)
+        set_weight_attrs(w2_qweight, extra_weight_attrs)
+
+        w13_scales = torch.nn.Parameter(
+            torch.zeros(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size // group_size,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_scales", w13_scales)
+        set_weight_attrs(w13_scales, extra_weight_attrs)
+
+        w2_scales = torch.nn.Parameter(
+            torch.zeros(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition // group_size,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_scales", w2_scales)
+        set_weight_attrs(w2_scales, extra_weight_attrs)
+
+        if self.quant_config.has_zp:
+            w13_qzeros = torch.nn.Parameter(
+                torch.zeros(
+                    num_experts,
+                    2 * intermediate_size_per_partition // bit8_pack_factor,
+                    hidden_size // group_size,
+                    dtype=torch.uint8,
+                ),
+                requires_grad=False,
+            )
+            layer.register_parameter("w13_qzeros", w13_qzeros)
+            set_weight_attrs(w13_qzeros, extra_weight_attrs)
+
+            w2_qzeros = torch.nn.Parameter(
+                torch.zeros(
+                    num_experts,
+                    hidden_size // bit8_pack_factor,
+                    intermediate_size_per_partition // group_size,
+                    dtype=torch.uint8,
+                ),
+                requires_grad=False,
+            )
+            layer.register_parameter("w2_qzeros", w2_qzeros)
+            set_weight_attrs(w2_qzeros, extra_weight_attrs)
+
+        if self.quant_config.linear_quant_method == "gptq":
+            # some param are unused, but we need to init them in order to
+            # load weights
+            invalid_param_keys = ["w13_g_idx", "w2_g_idx"]
+            if not self.quant_config.has_zp:
+                invalid_param_keys += ["w13_qzeros", "w2_qzeros"]
+            for key in invalid_param_keys:
+                param = torch.nn.Parameter(
+                    torch.empty((0,), dtype=torch.int32), requires_grad=False
+                )
+                layer.register_parameter(key, param)
+                set_weight_attrs(param, extra_weight_attrs)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        inplace: bool = True,
+        no_combine: bool = False,
+    ) -> torch.Tensor:
+        # avoid circular import
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.topk import select_experts
+
+        assert activation == "silu", "Only SiLU activation is supported."
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            top_k=top_k,
+            use_grouped_topk=use_grouped_topk,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
+        )
+
+        weight_bits = self.quant_config.weight_bits
+        has_zp = self.quant_config.has_zp
+
+        return fused_experts(
+            x,
+            layer.w13_qweight,
+            layer.w2_qweight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=inplace,
+            use_int4_w4a16=weight_bits == 4,
+            use_int8_w8a16=weight_bits == 8,
+            w1_scale=layer.w13_scales,
+            w2_scale=layer.w2_scales,
+            w1_zp=layer.w13_qzeros if has_zp else None,
+            w2_zp=layer.w2_qzeros if has_zp else None,
+            block_shape=[0, layer.group_size],
+            no_combine=no_combine,
+        )
+
+    @staticmethod
+    def get_weight_loader(layer, weight_loader):
+
+        def convert_awq_tensor(tensor, tensor_type):
+            # convert awq qweight/qzeros to a standard format (assume int4)
+            # qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8)
+            # qzeros: (k // group_size, n // pack_factor_bit32) ->
+            #     (n // pack_factor_bit8, k // group_size)
+            # pack_factor_bit32 = 32 // weight_bits
+            # pack_factor_bit8 = 8 // weight_bits
+
+            # 0. suppose origin shape (a, b), dtype int32
+            # 1. convert to uint8, shape (a, b) -> (a, 4 * b)
+            size0 = tensor.size(0)
+            tensor = tensor.view(torch.uint8)
+
+            # 2. unpack to uint4 (only when weight_bits == 4)
+            #    shape (a, 4 * b) -> (a, 4 * b, 2)
+            shifter = torch.tensor([0, 4], dtype=torch.uint8, device=tensor.device)
+            tensor = (tensor[:, :, None] >> shifter) & 0xF
+
+            # 3. change order, see
+            #    https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py
+            #    shape -> (a, 4 * b * pack_factor_bit8)
+            reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7]
+            tensor = tensor.view(-1, 8)[:, reverse_awq_pack_order]
+            tensor = tensor.view(size0, -1)
+
+            # 4. transpose, shape -> (4 * b * pack_factor_bit8, a)
+            tensor = tensor.T.contiguous()
+
+            # 5. repack (only when weight_bits == 4)
+            #    qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8)
+            #    qzeros shape -> (4 * b, a)
+
+            if tensor_type == "qweight":
+                tensor = tensor[:, 1::2] * 16 + tensor[:, ::2]
+            elif tensor_type == "qzeros":
+                tensor = tensor[1::2, :] * 16 + tensor[::2, :]
+            return tensor
+
+        def convert_gptq_int4_qzeros(tensor):
+            tensor = tensor.view(torch.uint8)
+            shifter = torch.tensor([0, 4], dtype=torch.uint8, device=tensor.device)
+            tensor = (tensor[:, :, None] >> shifter) & 0xF
+            tensor = tensor + 1
+            tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16
+            return tensor
+
+        def moe_wna16_weight_loader(
+            param: torch.nn.Parameter,
+            loaded_weight: torch.Tensor,
+            weight_name: str,
+            shard_id: str,
+            expert_id: int,
+        ):
+            if "g_idx" in weight_name:
+                return
+            if not layer.quant_config.has_zp and "qzeros" in weight_name:
+                return
+
+            device = get_tp_group().device
+            tp_rank = get_tensor_model_parallel_rank()
+            loaded_weight = loaded_weight.to(device)
+            shard_size = layer.intermediate_size_per_partition
+
+            # convert gptq and awq weight to a standard format
+            if layer.quant_config.linear_quant_method == "awq":
+                assert layer.quant_config.weight_bits == 4
+                if "weight" in weight_name:
+                    loaded_weight = convert_awq_tensor(loaded_weight, "qweight")
+                elif "zeros" in weight_name:
+                    loaded_weight = convert_awq_tensor(loaded_weight, "qzeros")
+                else:
+                    loaded_weight = loaded_weight.T
+            elif layer.quant_config.linear_quant_method == "gptq":
+                assert layer.quant_config.weight_bits in [4, 8]
+                if "weight" in weight_name:
+                    loaded_weight = loaded_weight.T.contiguous().view(torch.uint8)
+                elif "zeros" in weight_name:
+                    # add 1 to gptq qzeros to align with awq
+                    loaded_weight = loaded_weight.view(torch.uint8)
+                    if layer.quant_config.weight_bits == 4:
+                        loaded_weight = convert_gptq_int4_qzeros(loaded_weight).T
+                    else:
+                        loaded_weight = loaded_weight.T + 1
+                else:
+                    loaded_weight = loaded_weight.T
+
+            # repeat the qzeros/scales to fit new group size
+            if (
+                layer.group_size_div_factor > 1
+                and "qzeros" in weight_name
+                or "scales" in weight_name
+            ):
+                loaded_weight = loaded_weight.repeat_interleave(
+                    layer.group_size_div_factor, 1
+                )
+
+            if "w13_qzeros" in weight_name:
+                tensor = loaded_weight.view(layer.tp_size, -1, loaded_weight.size(1))[
+                    tp_rank
+                ]
+                if shard_id == "w1":
+                    param.data[expert_id, : shard_size // 2] = tensor
+                else:
+                    param.data[expert_id, shard_size // 2 :] = tensor
+            elif "w2_qzeros" in weight_name:
+                param.data[expert_id] = loaded_weight.view(
+                    loaded_weight.size(0), layer.tp_size, -1
+                )[:, tp_rank]
+            else:
+                weight_loader(param, loaded_weight, weight_name, shard_id, expert_id)
+
+        return moe_wna16_weight_loader
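As an aside, the AWQ repacking done by convert_awq_tensor above is easiest to see on a toy tensor. The sketch below is not part of the diff; it uses made-up values, assumes a little-endian host (as the loader's uint8 view also does), and reproduces only the byte view, nibble unpack, and AWQ order reversal:

import torch

# One int32 packs eight 4-bit weights. AWQ stores them in the interleaved
# nibble order [w0, w2, w4, w6, w1, w3, w5, w7] (LSB first), so 0x75316420
# encodes the logical weights 0..7.
packed = torch.tensor([[0x75316420]], dtype=torch.int32)

# 1. reinterpret as uint8: (a, b) int32 -> (a, 4 * b) uint8
t = packed.view(torch.uint8)

# 2. split each byte into its low and high nibble: (a, 4b) -> (a, 4b, 2)
shifter = torch.tensor([0, 4], dtype=torch.uint8)
t = (t[:, :, None] >> shifter) & 0xF

# 3. undo AWQ's interleaving
reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7]
t = t.view(-1, 8)[:, reverse_awq_pack_order].view(1, -1)

print(t.tolist())  # [[0, 1, 2, 3, 4, 5, 6, 7]]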
sglang/srt/layers/quantization/utils.py
@@ -1,7 +1,7 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
 
 from types import MappingProxyType
-from typing import List, Mapping, Tuple, Union
+from typing import List, Mapping, Optional, Tuple, Union
 
 import torch
 
sglang/srt/layers/quantization/w8a8_fp8.py
@@ -37,7 +37,7 @@ class W8A8Fp8Config(QuantizationConfig):
     Note:
     - For models without offline quantization, weights will be quantized during model loading
     - If CUTLASS is supported: Per-channel weight quantization is used
-    - If CUTLASS is not supported: Falls back to per-
+    - If CUTLASS is not supported: Falls back to per-tensor weight quantization
     """
 
     def __init__(self, is_checkpoint_fp8_serialized: bool = False):
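The docstring fix above distinguishes per-channel from per-tensor FP8 weight quantization. A minimal sketch of that difference follows; it uses toy shapes, is not sglang code, and assumes a PyTorch build with torch.float8_e4m3fn (whose largest finite value is 448):

import torch

w = torch.randn(4, 8)          # (out_channels, in_features)
fp8_max = 448.0                # largest finite value of torch.float8_e4m3fn

per_tensor_scale = w.abs().max() / fp8_max         # one scale for the whole tensor
per_channel_scale = w.abs().amax(dim=1) / fp8_max  # one scale per output channel

# quantize with per-channel scales (broadcast over the input dimension)
w_q = (w / per_channel_scale[:, None]).to(torch.float8_e4m3fn)
print(per_tensor_scale.shape, per_channel_scale.shape, w_q.dtype)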
sglang/srt/lora/backend/base_backend.py
@@ -5,7 +5,7 @@ import torch
 from sglang.srt.lora.utils import LoRABatchInfo
 
 
-def 
+def get_fuse_output_add_from_name(name: str) -> bool:
     mapping = {
         "triton": True,
         "flashinfer": False,
@@ -28,14 +28,14 @@ class BaseLoRABackend:
     Args:
         name: name of backend
         batch_info: information of current batch for use
-
-            and the operation of 
+        fuse_output_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
+            and the operation of adding will be fused into kernel
     """
 
     def __init__(self, name: str, batch_info: LoRABatchInfo = None):
         self.name = name
         self.batch_info = batch_info
-        self.
+        self.fuse_output_add = get_fuse_output_add_from_name(name)
         self.fuse_stacked_lora_b = get_fuse_stacked_lora_b_from_name(name)
 
     def run_lora_a_sgemm(
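For context on the fuse_output_add flag documented above, here is a hedged sketch of the call pattern it implies. The wrapper function, its arguments, and the use of run_lora_b_sgemm are illustrative assumptions rather than sglang's actual dispatch code:

import torch

def apply_lora_b(backend, lora_a_out: torch.Tensor, lora_b: torch.Tensor,
                 base_output: torch.Tensor) -> torch.Tensor:
    if backend.fuse_output_add:
        # Triton-style path: the lora_b kernel receives the output buffer and
        # performs the addition itself.
        return backend.run_lora_b_sgemm(lora_a_out, lora_b, base_output=base_output)
    # FlashInfer-style path: the addition happens outside the kernel.
    return base_output + backend.run_lora_b_sgemm(lora_a_out, lora_b)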
sglang/srt/lora/backend/flashinfer_backend.py
@@ -37,13 +37,16 @@ class FlashInferLoRABackend(BaseLoRABackend):
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
     ) -> torch.Tensor:
 
-        return 
-
-
-
-
-
-
+        return (
+            self.segment_gemm.run(
+                x=x,
+                weights=weights,
+                batch_size=self.batch_info.bs,
+                weight_column_major=True,
+                seg_indptr=self.batch_info.seg_indptr,
+                weight_indices=self.batch_info.weight_indices,
+            )
+            * self.batch_info.scalings[0]
         )
 
     def run_qkv_lora(
@@ -90,7 +93,7 @@ class FlashInferLoRABackend(BaseLoRABackend):
             weights=kv_lora_b[1],
         )
 
-        return lora_output
+        return lora_output * self.batch_info.scalings[0]
 
     def run_gate_up_lora(
         self,
@@ -125,4 +128,4 @@ class FlashInferLoRABackend(BaseLoRABackend):
             weights=gate_up_lora_b[1],
         )
 
-        return lora_output
+        return lora_output * self.batch_info.scalings[0]
sglang/srt/lora/backend/triton_backend.py
@@ -25,11 +25,10 @@ class TritonLoRABackend(BaseLoRABackend):
         x: torch.Tensor,
         weights: torch.Tensor,
         base_output: torch.Tensor = None,
-        scaling: float = 1.0,
         *args,
         **kwargs
     ) -> torch.Tensor:
-        return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output
+        return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output)
 
     def run_qkv_lora(
         self,
@@ -39,7 +38,6 @@ class TritonLoRABackend(BaseLoRABackend):
         output_offset: torch.Tensor,
         max_qkv_out_dim: int,
         base_output: torch.Tensor = None,
-        scaling: float = 1.0,
         *args,
         **kwargs
     ) -> torch.Tensor:
@@ -49,7 +47,7 @@ class TritonLoRABackend(BaseLoRABackend):
         # qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
         assert isinstance(qkv_lora_b, torch.Tensor)
 
-        lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info)
+        lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info, stack_num=3)
         lora_output = qkv_lora_b_fwd(
             lora_a_output,
             qkv_lora_b,
@@ -57,7 +55,6 @@ class TritonLoRABackend(BaseLoRABackend):
             output_offset,
             max_qkv_out_dim,
             base_output,
-            scaling,
         )
         return lora_output
 
@@ -67,7 +64,6 @@ class TritonLoRABackend(BaseLoRABackend):
         gate_up_lora_a: torch.Tensor,
         gate_up_lora_b: torch.Tensor,
         base_output: torch.Tensor = None,
-        scaling: float = 1.0,
         *args,
         **kwargs
     ) -> torch.Tensor:
@@ -79,13 +75,14 @@ class TritonLoRABackend(BaseLoRABackend):
         output_dim = gate_up_lora_b.shape[-2] // 2
 
         # lora_a_output: (s, 2 * r)
-        lora_a_output = sgemm_lora_a_fwd(
+        lora_a_output = sgemm_lora_a_fwd(
+            x, gate_up_lora_a, self.batch_info, stack_num=2
+        )
         lora_output = gate_up_lora_b_fwd(
             lora_a_output,
             gate_up_lora_b,
             self.batch_info,
             output_dim,
             base_output,
-            scaling,
         )
         return lora_output