sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -7
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +25 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -2
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +29 -4
- sglang/srt/entrypoints/http_server.py +76 -0
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/entrypoints/openai/serving_chat.py +23 -6
- sglang/srt/entrypoints/openai/serving_completions.py +10 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +14 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
- sglang/srt/layers/attention/triton_backend.py +109 -73
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
- sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +58 -10
- sglang/srt/layers/dp_attention.py +137 -27
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +16 -18
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +18 -46
- sglang/srt/layers/quantization/awq.py +22 -23
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +17 -21
- sglang/srt/layers/quantization/marlin_utils.py +26 -8
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +217 -98
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +222 -39
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +77 -2
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/layers.py +6 -2
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +80 -19
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +23 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +22 -48
- sglang/srt/managers/scheduler.py +28 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +88 -39
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +10 -157
- sglang/srt/mem_cache/allocator_ascend.py +147 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +33 -33
- sglang/srt/model_executor/forward_batch_info.py +11 -10
- sglang/srt/model_executor/model_runner.py +93 -78
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +5 -2
- sglang/srt/models/deepseek_v2.py +226 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +27 -65
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +41 -76
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama.py +10 -2
- sglang/srt/models/llama4.py +18 -7
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +23 -23
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +84 -0
- sglang/srt/models/qwen3_moe.py +27 -43
- sglang/srt/models/step3_vl.py +8 -3
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +22 -2
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +264 -105
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +20 -19
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
- sglang/srt/layers/quantization/fp4.py +0 -557
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/gptq.py

@@ -36,30 +36,27 @@ from sglang.srt.layers.quantization.marlin_utils import (
     marlin_zero_points,
     verify_marlin_supported,
 )
-from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
 from sglang.srt.layers.quantization.utils import (
     get_linear_quant_method,
+    get_scalar_types,
     replace_parameter,
     unpack_cols,
 )
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
 
-try:
-    from vllm import _custom_ops as ops
-except ImportError:
-    ops = None
-
 from sglang.srt.utils import is_cuda
 
 _is_cuda = is_cuda()
 
 if _is_cuda:
-    from sgl_kernel import fused_marlin_moe
+    from sgl_kernel import fused_marlin_moe, gptq_gemm, gptq_marlin_repack, gptq_shuffle
 
 
 logger = logging.getLogger(__name__)
+ScalarType, scalar_types = get_scalar_types()
 
 
 def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool:
@@ -85,9 +82,7 @@ def gptq_marlin_moe_repack(
         dtype=b_q_weight.dtype,
     )
     for e in range(num_experts):
-        output[e] = ops.gptq_marlin_repack(
-            b_q_weight[e], perm[e], size_k, size_n, num_bits
-        )
+        output[e] = gptq_marlin_repack(b_q_weight[e], perm[e], size_k, size_n, num_bits)
     return output
 
 
@@ -204,11 +199,12 @@ class GPTQConfig(QuantizationConfig):
         from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
-        if isinstance(layer, LinearBase):
-            return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
-        elif isinstance(layer, FusedMoE):
+        if isinstance(layer, FusedMoE):
             raise TypeError("GPTQ Method does not support MoE, please use gptq_marlin")
-
+        else:
+            return get_linear_quant_method(
+                self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
+            )
 
 
 class GPTQMarlinConfig(QuantizationConfig):
@@ -530,7 +526,7 @@ class GPTQLinearMethod(LinearMethodBase):
             layer.g_idx.data = torch.empty(
                 (0,), dtype=torch.int, device=layer.g_idx.device
            )
-        ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits)
+        gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits)
 
     def apply(
         self,
@@ -541,7 +537,7 @@ class GPTQLinearMethod(LinearMethodBase):
         out_shape = x.shape[:-1] + (layer.qweight.shape[-1],)
         reshaped_x = x.reshape(-1, x.shape[-1])
 
-        output = ops.gptq_gemm(
+        output = gptq_gemm(
             reshaped_x,
             layer.qweight,
             layer.qzeros,
@@ -726,7 +722,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
         def transform_w_q(x):
             assert isinstance(x, BasevLLMParameter)
             permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
-            x.data = ops.gptq_marlin_repack(
+            x.data = gptq_marlin_repack(
                 x.data.contiguous(),
                 perm=layer.g_idx_sort_indices,
                 size_k=c.partition_weight_shape[0],
@@ -1061,13 +1057,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-
-        activation: str = "silu",
-        **kwargs,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         # Delay the import to avoid circular dependency
 
-        assert activation == "silu", "Only SiLU activation is supported."
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
 
         # The input must currently be float16
         orig_dtype = x.dtype
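The gptq.py hunks above drop the vllm `_custom_ops` fallback in favor of kernels imported from `sgl_kernel` (`gptq_gemm`, `gptq_shuffle`, `gptq_marlin_repack`), and change `GPTQMarlinMoEMethod.apply()` from an `activation=.../**kwargs` signature to a single `moe_runner_config: MoeRunnerConfig` argument. A minimal stand-in sketch of that config object, modeling only the two fields these diffs read (the real class lives in sglang.srt.layers.moe.moe_runner and may carry more):

from dataclasses import dataclass


@dataclass
class MoeRunnerConfigSketch:  # hypothetical stand-in, not the sglang class
    activation: str = "silu"
    apply_router_weight_on_input: bool = False


cfg = MoeRunnerConfigSketch()
# Mirrors the new assertion in GPTQMarlinMoEMethod.apply():
assert cfg.activation == "silu", "Only SiLU activation is supported."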
sglang/srt/layers/quantization/marlin_utils.py

@@ -19,20 +19,31 @@ from sglang.srt.layers.quantization.base_config import (
     LinearMethodBase,
     QuantizationConfig,
 )
-from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
-from sglang.srt.layers.quantization.utils import pack_cols, unpack_cols
-from sglang.srt.utils import get_device_capability
+from sglang.srt.layers.quantization.utils import (
+    get_scalar_types,
+    pack_cols,
+    unpack_cols,
+)
+from sglang.srt.utils import get_device_capability, is_cuda
 
 if TYPE_CHECKING:
     from sglang.srt.layers.linear import LinearBase
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 
 try:
     from vllm import _custom_ops as ops
 except ImportError:
     ops = None
 
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sgl_kernel import gptq_marlin_gemm
+
 logger = logging.getLogger(__name__)
 
+ScalarType, scalar_types = get_scalar_types()
+
 GPTQ_MARLIN_TILE = 16
 GPTQ_MARLIN_MIN_THREAD_N = 64
 GPTQ_MARLIN_MIN_THREAD_K = 128
@@ -206,13 +217,13 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
     )[0]
 
 
-def check_moe_marlin_supports_layer(layer:
+def check_moe_marlin_supports_layer(layer: FusedMoE, group_size: int) -> bool:
     hidden_size = layer.hidden_size
     intermediate_size_per_partition = layer.intermediate_size_per_partition
     # apply_router_weight_on_input is not supported for moe marlin
-    supports_router_weight = not layer.apply_router_weight_on_input
+    supports_router_weight = not layer.moe_runner_config.apply_router_weight_on_input
     # moe marlin requires the activation to be silu
-    supports_activation = layer.activation == "silu"
+    supports_activation = layer.moe_runner_config.activation == "silu"
 
     # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
     # down: (n, k) = (hidden_size, intermediate_size_per_partition)
@@ -295,6 +306,13 @@ def marlin_permute_scales(
     return s
 
 
+def marlin_permute_bias(s: torch.Tensor) -> torch.Tensor:
+    origin_shape = s.shape
+    _, scale_perm_single = get_scale_perms()
+    s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
+    return s.reshape(*origin_shape).contiguous()
+
+
 def marlin_moe_permute_scales(
     s: torch.Tensor,
     size_k: int,
@@ -453,7 +471,7 @@ def apply_gptq_marlin_linear(
         dtype=input.dtype,
     )
 
-    output = ops.gptq_marlin_gemm(
+    output = gptq_marlin_gemm(
         reshaped_x,
         None,
         weight,
@@ -504,7 +522,7 @@ def apply_awq_marlin_linear(
         dtype=input.dtype,
     )
 
-    output = ops.gptq_marlin_gemm(
+    output = gptq_marlin_gemm(
         reshaped_x,
         None,
         weight,
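Besides routing `gptq_marlin_gemm` through `sgl_kernel` and obtaining `ScalarType`/`scalar_types` via `get_scalar_types()`, marlin_utils.py adds a `marlin_permute_bias()` helper that reorders a per-output-channel bias into the layout the Marlin kernel expects; the new FP8 path below feeds the result to `gptq_marlin_gemm` as `b_bias`. A hedged usage sketch, with shape and dtype chosen only for illustration:

import torch

from sglang.srt.layers.quantization.marlin_utils import marlin_permute_bias

# One bias value per output channel; real layers use sizes divisible by the
# helper's internal single-group permutation length.
bias = torch.randn(4096, dtype=torch.half)
permuted = marlin_permute_bias(bias)  # same shape, Marlin channel ordering
assert permuted.shape == bias.shape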
sglang/srt/layers/quantization/marlin_utils_fp8.py (new file)

@@ -0,0 +1,352 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+from typing import Optional
+
+import torch
+
+from sglang.srt.layers.quantization.marlin_utils import (
+    USE_FP32_REDUCE_DEFAULT,
+    marlin_make_workspace,
+    marlin_permute_bias,
+    marlin_permute_scales,
+    should_use_atomic_add_reduce,
+)
+from sglang.srt.layers.quantization.utils import get_scalar_types
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+if _is_cuda:
+    from sgl_kernel import gptq_marlin_gemm, gptq_marlin_repack
+
+ScalarType, scalar_types = get_scalar_types()
+
+logger = logging.getLogger(__name__)
+
+
+def fp8_fused_exponent_bias_into_scales(scales):
+    fp8_exponent = 4
+    if scales.dtype == torch.half:
+        target_exponent = 5
+    elif scales.dtype == torch.bfloat16:
+        target_exponent = 8
+    # exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8
+    # exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120
+    exponent_bias = 2 ** (target_exponent - 1) - 2 ** (fp8_exponent - 1)
+    s = torch.ones_like(scales) * 2
+    s = s**exponent_bias
+    return scales * s
+
+
+def apply_fp8_marlin_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    workspace: torch.Tensor,
+    size_n: int,
+    size_k: int,
+    bias: Optional[torch.Tensor],
+    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
+) -> torch.Tensor:
+    # For GPUs that lack FP8 hardware support, we can leverage the
+    # Marlin kernel for fast weight-only FP8 quantization
+
+    reshaped_x = input.reshape(-1, input.shape[-1])
+    out_shape = input.shape[:-1] + (size_n,)
+
+    use_atomic_add = should_use_atomic_add_reduce(
+        m=reshaped_x.size(0), n=size_n, k=size_k, device=input.device, dtype=input.dtype
+    )
+
+    output = gptq_marlin_gemm(
+        a=reshaped_x,
+        c=None,
+        b_q_weight=weight,
+        b_bias=bias,
+        b_scales=weight_scale,
+        global_scale=None,
+        b_zeros=None,
+        g_idx=None,
+        perm=None,
+        workspace=workspace,
+        b_q_type=scalar_types.float8_e4m3fn,
+        size_m=reshaped_x.size(0),
+        size_n=size_n,
+        size_k=size_k,
+        use_atomic_add=use_atomic_add,
+        use_fp32_reduce=use_fp32_reduce,
+    )
+
+    return output.reshape(out_shape)
+
+
+def prepare_fp8_layer_for_marlin(
+    layer: torch.nn.Module, size_k_first: bool = True
+) -> None:
+    logger.warning_once(
+        "Your GPU does not have native support for FP8 computation but "
+        "FP8 quantization is being used. Weight-only FP8 compression will "
+        "be used leveraging the Marlin kernel. This may degrade "
+        "performance for compute-heavy workloads."
+    )
+
+    part_size_n = layer.output_size_per_partition
+    part_size_k = layer.input_size_per_partition
+    weight_block_size = getattr(layer, "weight_block_size", None)
+
+    if size_k_first:
+        assert layer.weight.shape == (part_size_k, part_size_n)
+    else:
+        assert layer.weight.shape == (part_size_n, part_size_k)
+
+    device = layer.weight.device
+
+    # WORKSPACE
+    layer.workspace = marlin_make_workspace(device)
+
+    # WEIGHT
+    # Repack weights to marlin format
+    perm = torch.empty(0, dtype=torch.int, device=device)
+    qweight = pack_fp8_to_int32(layer.weight, size_k_first)
+    if not size_k_first:
+        qweight = qweight.T.contiguous()
+
+    marlin_qweight = gptq_marlin_repack(
+        b_q_weight=qweight,
+        perm=perm,
+        size_k=part_size_k,
+        size_n=part_size_n,
+        num_bits=8,
+    )
+    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
+
+    # WEIGHT SCALES
+    # Permute scales
+    if "weight_scale" in dir(layer):
+        scales = layer.weight_scale.to(layer.orig_dtype)
+    elif "weight_scale_inv" in dir(layer):
+        scales = layer.weight_scale_inv.to(layer.orig_dtype)
+        del layer.weight_scale_inv
+
+    group_size = -1 if weight_block_size is None else weight_block_size[1]
+
+    # marlin kernel only support channel-wise and group-wise quantization
+    # we need to convert the scales
+    if weight_block_size is None:
+        if scales.nelement() == 1:
+            # tensor-wise quantization -> channel-wise quantization
+            # (1, 1) =>(repeat)=> (1, size_n)
+            scales = scales.view(1, 1).repeat_interleave(part_size_n, 1)
+        elif scales.nelement() > 1 and scales.nelement() != part_size_n:
+            assert part_size_n % scales.nelement() == 0
+            s_size = scales.nelement()
+            # tensor-wise quantization (for gate-up proj)
+            # -> channel-wise quantization
+            # (1, s_size) =>(repeat)=> (1, size_n)
+            scales = scales.view(1, s_size)
+            scales = scales.repeat_interleave(part_size_n // s_size, 1)
+        else:
+            # channel-wise quantization
+            # (1, size_n)
+            scales = scales.view(1, part_size_n)
+    else:
+        # block-wise quantization -> group-wise quantization
+        # (size_k // block_size[1], ceil(size_n / block_size[0]))
+        # =>(repeat)=> (size_k // block_size[1], size_n)
+        if not size_k_first:
+            scales = scales.T.contiguous()
+        block_n = weight_block_size[0]
+        scales = scales.repeat_interleave(block_n, 1)
+        # size_n may not divisible by block_size[0]
+        scales = scales[:, :part_size_n]
+
+    marlin_scales = marlin_permute_scales(
+        s=scales, size_k=part_size_k, size_n=part_size_n, group_size=group_size
+    )
+    marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
+    layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
+
+    if hasattr(layer, "bias") and layer.bias is not None:
+        assert layer.bias.shape == (part_size_n,)
+        bias = marlin_permute_bias(layer.bias)
+        layer.bias = torch.nn.Parameter(bias, requires_grad=False)
+
+
+def prepare_moe_fp8_layer_for_marlin(
+    layer: torch.nn.Module, size_k_first: bool = True
+) -> None:
+    logger.warning_once(
+        "Your GPU does not have native support for FP8 computation but "
+        "FP8 quantization is being used. Weight-only FP8 compression will "
+        "be used leveraging the Marlin kernel. This may degrade "
+        "performance for compute-heavy workloads."
+    )
+
+    e = layer.num_experts
+    k = layer.hidden_size
+    n = layer.intermediate_size_per_partition
+    weight_block_size = getattr(layer, "weight_block_size", None)
+
+    # WORKSPACE
+    device = layer.w13_weight.device
+    layer.workspace = marlin_make_workspace(device, 4)
+    perm = torch.empty(0, dtype=torch.int, device=device)
+
+    # WEIGHT
+    # Repack weights to marlin format
+    for name in ["w13_weight", "w2_weight"]:
+        weight = getattr(layer, name)
+        tensor_list = []
+        if "w13" in name:
+            size_n, size_k = n * 2, k
+        else:
+            size_n, size_k = k, n
+
+        if size_k_first:
+            assert weight.shape == (e, size_k, size_n)
+        else:
+            assert weight.shape == (e, size_n, size_k)
+
+        for i in range(e):
+            qweight = pack_fp8_to_int32(weight[i], size_k_first)
+            if not size_k_first:
+                qweight = qweight.T.contiguous()
+
+            marlin_qweight = gptq_marlin_repack(
+                b_q_weight=qweight, perm=perm, size_k=size_k, size_n=size_n, num_bits=8
+            )
+            tensor_list.append(marlin_qweight)
+
+        weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
+        weight = torch.nn.Parameter(weight, requires_grad=False)
+
+        setattr(layer, name, weight)
+
+    # WEIGHT SCALES
+    # Permute scales
+    group_size = -1 if weight_block_size is None else weight_block_size[1]
+
+    for name in ["w13", "w2"]:
+        if name + "_weight_scale" in dir(layer):
+            new_name = name + "_weight_scale"
+            scales = getattr(layer, new_name).to(layer.orig_dtype)
+            delattr(layer, new_name)
+        elif name + "_weight_scale_inv" in dir(layer):
+            new_name = name + "_weight_scale_inv"
+            scales = getattr(layer, new_name).to(layer.orig_dtype)
+            delattr(layer, new_name)
+
+        tensor_list = []
+        if "w13" in name:
+            size_n, size_k = n * 2, k
+        else:
+            size_n, size_k = k, n
+
+        # marlin kernel only support channel-wise and group-wise quantization
+        # we need to convert the scales
+        if weight_block_size is None:
+            if scales.nelement() == e:
+                # tensor-wise quantization -> channel-wise quantization
+                # (e, 1, 1) =>(repeat)=> (e, 1, size_n)
+                scales = scales.view(e, 1, 1).repeat_interleave(size_n, 2)
+            elif scales.nelement() > e and scales.nelement() != e * size_n:
+                assert (e * size_n) % scales.nelement() == 0
+                s_size = scales.nelement() // e
+                # tensor-wise quantization (for gate-up proj)
+                # -> channel-wise quantization
+                # (e, 1, s_size) =>(repeat)=> (e, 1, size_n)
+                scales = scales.view(e, 1, s_size)
+                scales = scales.repeat_interleave(size_n // s_size, 2)
+            else:
+                # channel-wise quantization
+                # (e, 1, size_n)
+                scales = scales.view(e, 1, size_n)
+        else:
+            # block-wise quantization -> group-wise quantization
+            # (e, size_k // block_size[1], ceil(size_n / block_size[0]))
+            # =>(repeat)=> (e, size_k // block_size[1], size_n)
+            if not size_k_first:
+                scales = scales.permute(0, 2, 1)
+            block_n = weight_block_size[0]
+            scales = scales.repeat_interleave(block_n, 2)
+            # size_n may not divisible by block_size[0]
+            scales = scales[..., :size_n].contiguous()
+
+        for i in range(e):
+            marlin_scales = marlin_permute_scales(
+                s=scales[i], size_k=size_k, size_n=size_n, group_size=group_size
+            )
+            tensor_list.append(marlin_scales)
+
+        scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
+        scales = fp8_fused_exponent_bias_into_scales(scales)
+        scales = torch.nn.Parameter(scales, requires_grad=False)
+
+        setattr(layer, name + "_weight_scale", scales)
+
+    # BIAS
+    # Permute bias
+    for name in ["w13_bias", "w2_bias"]:
+        if not hasattr(layer, name):
+            continue
+        bias = getattr(layer, name).to(layer.orig_dtype)
+
+        tensor_list = []
+        for i in range(e):
+            expert_bias = bias[i]
+
+            tensor_list.append(marlin_permute_bias(expert_bias))
+
+        bias = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
+        bias = torch.nn.Parameter(bias, requires_grad=False)
+        setattr(layer, name, bias)
+
+
+def pack_fp8_to_int32(
+    fp8_tensor: torch.Tensor, size_k_first: bool = True
+) -> torch.Tensor:
+    """
+    Repack FP8 weights to gptq format (packed int32 elements)
+    """
+    assert fp8_tensor.dtype == torch.float8_e4m3fn
+    assert fp8_tensor.ndim == 2
+
+    fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor
+    fp8_tensor = fp8_tensor.contiguous()
+    # fp8_tensor is contiguous and have shape (N, K) now
+    # with `.view(torch.int32)`, it become (N, K // 4)
+    int32_tensor = fp8_tensor.view(torch.int32)
+    return int32_tensor.T.contiguous() if size_k_first else int32_tensor
+
+
+def marlin_quant_fp8_torch(weight, group_size):
+    size_n, size_k = weight.shape
+    device = weight.device
+
+    if group_size != -1:
+        scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 448
+        repeated_scales = scales.repeat_interleave(group_size, 1)
+        fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
+        weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
+    else:
+        scales = weight.view(size_n, 1, group_size).abs().max(-1)[0] / 448
+        repeated_scales = scales.repeat_interleave(size_k, 1)
+        fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
+        weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
+
+    packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous()
+    marlin_qweight = gptq_marlin_repack(
+        b_q_weight=packed_weight,
+        perm=torch.empty(0, dtype=torch.int, device=device),
+        size_k=size_k,
+        size_n=size_n,
+        num_bits=8,
+    )
+
+    marlin_scales = marlin_permute_scales(
+        s=scales.T, size_k=size_k, size_n=size_n, group_size=group_size
+    )
+
+    marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
+
+    return weight_ref.T, marlin_qweight, marlin_scales
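The `fp8_fused_exponent_bias_into_scales()` helper in the new module folds the IEEE exponent-bias difference between FP8-E4M3 weights (bias 7) and the activation dtype (FP16 bias 15, BF16 bias 127) into the dequantization scales, which is why its comments quote the constants 8 and 120. A quick numeric check of those constants:

# Reproduces the constants in the code comments above:
# 2 ** (target_exponent - 1) - 2 ** (fp8_exponent - 1) is the bias difference.
fp8_exponent = 4  # FP8-E4M3 has 4 exponent bits, bias 2**3 - 1 = 7
for dtype_name, target_exponent in [("float16", 5), ("bfloat16", 8)]:
    exponent_bias = 2 ** (target_exponent - 1) - 2 ** (fp8_exponent - 1)
    print(f"{dtype_name}: scales multiplied by 2 ** {exponent_bias}")
# float16: scales multiplied by 2 ** 8
# bfloat16: scales multiplied by 2 ** 120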