sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/utils.py

@@ -1,20 +1,21 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
 
+from __future__ import annotations
+
+import re
+from copy import deepcopy
 from types import MappingProxyType
-from typing import List, Mapping, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union
 
+import numpy
 import torch
 
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-from sglang.srt.
-
-_is_cuda = is_cuda()
-_is_npu = is_npu()
-_is_cpu_amx_available = cpu_has_amx_support()
-_is_cpu = is_cpu()
+from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
 
-if
-from
+if TYPE_CHECKING:
+    from sglang.srt.layers.quantization.base_config import QuantizationConfig
 
 
 def is_layer_skipped(
@@ -143,3 +144,333 @@ def replace_parameter(
     if not isinstance(new, torch.nn.Parameter):
         new = torch.nn.Parameter(new, requires_grad=False)
     mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))
+
+
+# Match dynamic rules with module name (prefix) and override quantize
+# config if module (prefix) matches a rule
+def override_config(config: QuantizationConfig, prefix: str):
+    weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits)
+    if isinstance(weight_bits, int):
+        config.weight_bits = weight_bits
+    group_size = get_dynamic_override(config, prefix, "group_size", config.group_size)
+    if isinstance(group_size, int):
+        config.group_size = group_size
+    desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act)
+    if isinstance(desc_act, bool):
+        config.desc_act = desc_act
+
+    config.pack_factor = 32 // config.weight_bits  # packed into int32
+    if config.get_name() == "gptq_marlin":
+        is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
+        if isinstance(is_sym, bool):
+            config.is_sym = is_sym
+
+        if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
+            raise ValueError(
+                "Unsupported quantization config: "
+                f"bits={config.weight_bits}, sym={config.is_sym}"
+            )
+
+        config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)]
+    elif config.get_name() == "gptq":
+        if config.weight_bits not in [2, 3, 4, 8]:
+            raise ValueError(
+                "Currently, only 2/3/4/8-bit weight quantization is "
+                f"supported for GPTQ, but got {config.weight_bits} bits."
+            )
+
+
+def get_dynamic_override(
+    config: QuantizationConfig,
+    layer_name: str,
+    key: Optional[str] = None,
+    default_value: Union[int, bool, None] = None,
+) -> Union[Dict, int, bool, None]:
+    for pattern, pattern_dict in config.dynamic.items():
+        # Negative match: matched modules are excluded from quantized init
+        if pattern.startswith("-:"):
+            if re.match(pattern.removeprefix("-:"), layer_name):
+                return False
+        # Positive match: matched modules have quant properties overrides
+        # base quant config
+        elif re.match(pattern.removeprefix("+:"), layer_name):
+            if key is None:
+                return pattern_dict
+            else:
+                return pattern_dict.get(key, default_value)
+    return default_value
+
+
+def get_linear_quant_method(
+    config: QuantizationConfig,
+    layer: torch.nn.Module,
+    prefix: str,
+    linear_method_cls: type,
+):
+    from sglang.srt.layers.linear import LinearBase
+    from sglang.srt.layers.quantization.unquant import (
+        UnquantizedEmbeddingMethod,
+        UnquantizedLinearMethod,
+    )
+    from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+
+    cloned_config = deepcopy(config)
+    parallel_lm_head_quantized = (
+        isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
+    )
+
+    if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
+        # False = skip module, None = no override, else = Positive match
+        if get_dynamic_override(cloned_config, layer_name=prefix) is False:
+            if parallel_lm_head_quantized:
+                return UnquantizedEmbeddingMethod()
+            return UnquantizedLinearMethod()
+
+        if prefix:
+            # Dynamic per module/layer rules may override base config
+            override_config(cloned_config, prefix=prefix)
+
+        return linear_method_cls(cloned_config)
+    return None
+
+
+def get_pack_factor(num_bits):
+    assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
+    return 32 // num_bits
+
+
+def permute_rows(
+    q_w: torch.Tensor,
+    w_ref: torch.Tensor,
+    group_size: int,
+    test_perm: Optional[torch.Tensor] = None,
+):
+    assert q_w.shape == w_ref.shape
+
+    orig_device = q_w.device
+    k_size, _ = q_w.shape
+
+    g_idx = torch.zeros((k_size,), dtype=torch.int32)
+    for i in range(k_size):
+        g_idx[i] = i // group_size
+
+    # Simulate act_order by doing a random permutation on K
+    rand_perm = test_perm if test_perm is not None else torch.randperm(k_size)
+
+    g_idx = g_idx[rand_perm].contiguous()
+    q_w = q_w[rand_perm, :].contiguous()
+    w_ref = w_ref[rand_perm, :].contiguous()
+
+    return (
+        w_ref.to(device=orig_device),
+        q_w.to(device=orig_device),
+        g_idx.to(device=orig_device),
+        rand_perm.to(device=orig_device),
+    )
+
+
+def pack_cols(
+    q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    assert q_w.shape == (size_k, size_n)
+
+    pack_factor = get_pack_factor(num_bits)
+    assert size_n % pack_factor == 0
+
+    orig_device = q_w.device
+
+    q_w = q_w.cpu().numpy().astype(numpy.uint32)
+
+    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
+
+    for i in range(pack_factor):
+        q_res |= q_w[:, i::pack_factor] << num_bits * i
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    q_res = q_res.contiguous()
+
+    return q_res
+
+
+def unpack_cols(
+    packed_q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    pack_factor = get_pack_factor(num_bits)
+    assert size_n % pack_factor == 0
+    assert packed_q_w.shape == (
+        size_k,
+        size_n // pack_factor,
+    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
+        packed_q_w.shape, size_k, size_n, pack_factor
+    )
+
+    orig_device = packed_q_w.device
+
+    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
+    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
+
+    mask = (1 << num_bits) - 1
+    for i in range(pack_factor):
+        vals = packed_q_w_cpu & mask
+        packed_q_w_cpu >>= num_bits
+        q_res[:, i::pack_factor] = vals
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    q_res = q_res.contiguous()
+
+    return q_res
+
+
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
+def quantize_weights(
+    w: torch.Tensor,
+    quant_type: ScalarType,
+    group_size: Optional[int],
+    zero_points: bool = False,
+    ref_zero_points_after_scales: bool = False,
+):
+    assert (
+        quant_type.is_integer()
+    ), "Floating point quantization may work but has not been tested"
+    assert not zero_points or group_size is not None, (
+        "to have group zero points, group_size must be provided "
+        "(-1 group_size is channelwise)"
+    )
+
+    orig_device = w.device
+    orig_type = w.dtype
+    size_k, size_n = w.shape
+
+    assert w.is_floating_point(), "w must be float"
+
+    if group_size == -1:
+        group_size = size_k
+
+    # Reshape to [groupsize, -1]
+    if group_size is not None and group_size < size_k:
+        w = w.reshape((-1, group_size, size_n))
+        w = w.permute(1, 0, 2)
+        w = w.reshape((group_size, -1))
+
+    # Compute scale for each group
+    max_val = torch.max(w, 0, keepdim=True).values
+    min_val = torch.min(w, 0, keepdim=True).values
+
+    max_q_val = quant_type.max()
+    min_q_val = quant_type.min()
+
+    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
+    maybe_w_zp = None
+    if group_size is not None:
+        if zero_points:
+            assert not quant_type.is_signed() and quant_type.max() > 0
+            w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
+            maybe_w_zp = (
+                torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
+            )
+        else:
+            # If the bias is such that there are no possible negative/positive
+            # values, set the max value to inf to avoid divide by 0
+            w_s = torch.max(
+                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
+                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
+            )
+
+    # Quantize
+    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
+    w_q = torch.clamp(w_q, min_q_val, max_q_val)
+
+    # Compute ref (dequantized)
+    # For some kernels (namely Machete) the zero-points are applied after the
+    # scales are applied, for this case computing the reference in similar way
+    # allows us to use tighter error tolerances in our unit tests.
+    if ref_zero_points_after_scales and maybe_w_zp is not None:
+        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
+    else:
+        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
+
+    if quant_type.has_bias():
+        w_q += quant_type.bias
+
+    # Restore original shapes
+    if group_size is not None and group_size < size_k:
+
+        def reshape_w(w):
+            w = w.reshape((group_size, -1, size_n))
+            w = w.permute(1, 0, 2)
+            w = w.reshape((size_k, size_n)).contiguous()
+            return w
+
+        w_q = reshape_w(w_q)
+        w_ref = reshape_w(w_ref)
+        w_s = w_s.reshape((-1, size_n)).contiguous()
+
+    if maybe_w_zp is not None:
+        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
+        maybe_w_zp = maybe_w_zp.to(device=orig_device)
+
+    return (
+        w_ref.to(device=orig_device),
+        w_q.to(device=orig_device),
+        w_s if group_size is not None else None,
+        maybe_w_zp,
+    )
+
+
+SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
+SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
+
+
+def gptq_quantize_weights(
+    w: torch.Tensor,
+    quant_type: ScalarType,
+    group_size: int,
+    act_order: bool,
+    test_perm: Optional[torch.Tensor] = None,
+):
+    size_k, _ = w.shape
+
+    assert w.is_floating_point(), "w must be float"
+    assert (
+        quant_type in SUPPORTED_GPTQ_QUANT_TYPES
+    ), f"Unsupported gptq type = {quant_type}"
+    assert group_size in SUPPORTED_GROUP_SIZES + [
+        size_k
+    ], f"Unsupported groupsize = {group_size}"
+
+    w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)
+
+    # Apply act_order
+    g_idx = torch.empty(0, dtype=torch.int, device=w.device)
+    rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
+    if act_order:
+        assert (
+            group_size < size_k
+        ), "For act_order, groupsize = {} must be less than size_k = {}".format(
+            group_size, size_k
+        )
+
+        w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm)
+
+    return w_ref, w_q, w_s, g_idx, rand_perm
+
+
+def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
+    orig_device = q_w.device
+
+    sort_indices = torch.argsort(g_idx).to(dtype=torch.int32)  # Sort based on g_idx
+
+    g_idx = g_idx[sort_indices].contiguous()
+    q_w = q_w[sort_indices, :].contiguous()
+
+    return (
+        q_w.to(device=orig_device),
+        g_idx.to(device=orig_device),
+        sort_indices.to(device=orig_device),
+    )
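The hunk above adds the GPTQ/Marlin helper functions (adapted from vLLM's `quant_utils.py`) to sglang's quantization utils. As a quick orientation, the sketch below exercises the `get_pack_factor` / `pack_cols` / `unpack_cols` round trip; it assumes these helpers are importable from `sglang.srt.layers.quantization.utils` in 0.4.9.post4, as the file listing above suggests.

```python
import torch

# Assumed import path, based on the file listing above (utils.py +340 -9).
from sglang.srt.layers.quantization.utils import get_pack_factor, pack_cols, unpack_cols

num_bits, size_k, size_n = 4, 16, 64
pack_factor = get_pack_factor(num_bits)  # 32 // 4 == 8 values per int32

# Random 4-bit codes laid out as a (size_k, size_n) matrix.
q_w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)

packed = pack_cols(q_w, num_bits, size_k, size_n)
assert packed.shape == (size_k, size_n // pack_factor)

# unpack_cols reverses the packing exactly for values < 2**num_bits.
assert torch.equal(unpack_cols(packed, num_bits, size_k, size_n), q_w)
```

The `32 // num_bits` pack factor here is the same quantity stored as `config.pack_factor = 32 // config.weight_bits` in `override_config` above.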
sglang/srt/layers/quantization/w4afp8.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 from typing import Any, Dict, List, Optional
 
@@ -5,12 +7,13 @@ import torch
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 
-from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
 from sglang.srt.layers.quantization.base_config import (
+    FusedMoEMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
+from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import set_weight_attrs
 
@@ -62,7 +65,7 @@ class W4AFp8Config(QuantizationConfig):
         return []
 
     @classmethod
-    def from_config(cls, config: Dict[str, Any]) ->
+    def from_config(cls, config: Dict[str, Any]) -> W4AFp8Config:
         quant_method = cls.get_from_keys(config, ["quant_method"])
         is_checkpoint_fp8_serialized = "fp8" in quant_method
         is_checkpoint_w4afp8_serialized = "w4afp8" in quant_method
@@ -79,7 +82,8 @@ class W4AFp8Config(QuantizationConfig):
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[
+    ) -> Optional[QuantizeMethodBase]:
+        from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
         if isinstance(layer, LinearBase):
@@ -94,7 +98,7 @@ class W4AFp8Config(QuantizationConfig):
         return []
 
 
-class W4AFp8MoEMethod:
+class W4AFp8MoEMethod(FusedMoEMethodBase):
 
     def __init__(self, quant_config: W4AFp8Config):
         self.quant_config = quant_config
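A small, generic illustration (not sglang code) of the pattern introduced by the `from __future__ import annotations` hunks above: with postponed evaluation of annotations, a classmethod can annotate its return type with the class being defined (here `-> W4AFp8Config`) without string quoting or extra imports. `ExampleConfig` below is a hypothetical stand-in.

```python
from __future__ import annotations


class ExampleConfig:
    """Hypothetical stand-in for a quantization config class."""

    def __init__(self, weight_bits: int) -> None:
        self.weight_bits = weight_bits

    @classmethod
    def from_config(cls, config: dict) -> ExampleConfig:  # no quotes needed
        return cls(weight_bits=config.get("bits", 8))


print(ExampleConfig.from_config({"bits": 4}).weight_bits)  # prints 4
```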
sglang/srt/layers/quantization/w8a8_fp8.py

@@ -1,11 +1,14 @@
-from
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
 
-from sglang.srt.layers.linear import LinearMethodBase
 from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
 from sglang.srt.layers.quantization.base_config import (
+    FusedMoEMethodBase,
+    LinearMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
@@ -22,6 +25,9 @@ from sglang.srt.layers.quantization.fp8_utils import (
 )
 from sglang.srt.utils import set_weight_attrs
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
 _is_fp8_fnuz = is_fp8_fnuz()
 
 
@@ -64,7 +70,7 @@ class W8A8Fp8Config(QuantizationConfig):
         return []
 
     @classmethod
-    def from_config(cls, config: Dict[str, Any]) ->
+    def from_config(cls, config: Dict[str, Any]) -> W8A8Fp8Config:
         quant_method = cls.get_from_keys(config, ["quant_method"])
         is_checkpoint_fp8_serialized = (
             "compressed-tensors" in quant_method or "w8a8_fp8" in quant_method
@@ -75,7 +81,7 @@ class W8A8Fp8Config(QuantizationConfig):
         self,
         layer: torch.nn.Module,
         prefix: str,
-    ) -> Optional[
+    ) -> Optional[QuantizeMethodBase]:
         from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
@@ -183,7 +189,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
         )
 
 
-class W8A8FP8MoEMethod:
+class W8A8FP8MoEMethod(FusedMoEMethodBase):
     """MoE method for FP8.
     Supports loading FP8 checkpoints with static weight scale and
     dynamic/static activation scale.
@@ -194,25 +200,7 @@ class W8A8FP8MoEMethod:
         quant_config: The quantization config.
     """
 
-    def
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
-        if not hasattr(cls, "_initialized"):
-            original_init = cls.__init__
-            new_cls = type(
-                cls.__name__,
-                (FusedMoEMethodBase,),
-                {
-                    "__init__": original_init,
-                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                },
-            )
-            obj = super(new_cls, new_cls).__new__(new_cls)
-            obj.__init__(*args, **kwargs)
-            return obj
-        return super().__new__(cls)
-
-    def __init__(self, quant_config):
+    def __init__(self, quant_config: W8A8Fp8Config):
         self.quant_config = quant_config
 
     def create_weights(
@@ -281,45 +269,23 @@ class W8A8FP8MoEMethod:
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-
-
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-        from sglang.srt.layers.moe.topk import select_experts
-
-        # Expert selection
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )
 
         return fused_experts(
             x,
             layer.w13_weight,
             layer.w2_weight,
-
-            topk_ids=topk_ids,
+            topk_output=topk_output,
             inplace=inplace,
+            apply_router_weight_on_input=apply_router_weight_on_input,
             activation=activation,
             use_fp8_w8a8=True,
             per_channel_quant=True,