sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +302 -414
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +13 -8
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +144 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +773 -334
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +225 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +68 -37
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +102 -36
- sglang/srt/model_executor/cuda_graph_runner.py +56 -31
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +280 -81
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +135 -60
- sglang/srt/speculative/build_eagle_tree.py +8 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
- sglang/srt/speculative/eagle_utils.py +92 -57
- sglang/srt/speculative/eagle_worker.py +238 -111
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8.py

@@ -51,6 +51,10 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
 
 is_hip_ = is_hip()
 
+if is_hip_:
+    from aiter.fused_moe_bf16_asm import asm_moe
+    from aiter.ops.shuffle import shuffle_weight
+
 logger = logging.getLogger(__name__)
 
 
@@ -533,6 +537,20 @@ class Fp8MoEMethod:
         )
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
         layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+        if is_hip_ and get_bool_env_var("CK_MOE"):
+            # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
+            w13_weight_scale1 = torch.nn.Parameter(
+                torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
+                requires_grad=False,
+            )
+            w2_weight_scale1 = torch.nn.Parameter(
+                torch.ones(num_experts, hidden_size, dtype=torch.float32),
+                requires_grad=False,
+            )
+            layer.register_parameter("w13_weight_scale1", w13_weight_scale1)
+            layer.register_parameter("w2_weight_scale1", w2_weight_scale1)
+
         # Add the quantization method used (per tensor/grouped/channel)
         # to ensure the weight scales are loaded in properly
         extra_weight_attrs.update(
@@ -602,6 +620,15 @@ class Fp8MoEMethod:
                 w2_weight_scale, requires_grad=False
            )
            layer.w2_input_scale = None
+
+            if get_bool_env_var("CK_MOE"):
+                # Pre-shuffle weights
+                layer.w13_weight.data = shuffle_weight(
+                    layer.w13_weight.contiguous(), (16, 16)
+                )
+                layer.w2_weight.data = shuffle_weight(
+                    layer.w2_weight.contiguous(), (16, 16)
+                )
            return
        # If checkpoint is fp16 or bfloat16, quantize in place.
        if not self.quant_config.is_checkpoint_fp8_serialized:
@@ -640,6 +667,9 @@ class Fp8MoEMethod:
                    requires_grad=False,
                )
                torch.cuda.empty_cache()
+                # ROCm (CK_MOE): using column-wise scaling
+                layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
+                layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
            elif get_bool_env_var("MOE_PADDING"):
                # If ROCm, apply weight padding (min. Mem channel contention) only if set
                layer.w13_weight = torch.nn.Parameter(
@@ -744,6 +774,9 @@ class Fp8MoEMethod:
                    requires_grad=False,
                )
                torch.cuda.empty_cache()
+                # ROCm (CK_MOE): using column-wise scaling
+                layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
+                layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
            elif get_bool_env_var("MOE_PADDING"):
                # If ROCm, apply weight padding (min. Mem channel contention) only if set
                layer.w13_weight = torch.nn.Parameter(
@@ -771,6 +804,8 @@ class Fp8MoEMethod:
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
+        inplace: bool = True,
+        no_combine: bool = False,
    ) -> torch.Tensor:
        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
        from sglang.srt.layers.moe.topk import select_experts
@@ -788,33 +823,38 @@ class Fp8MoEMethod:
            correction_bias=correction_bias,
        )
 
-        if is_hip_ and get_bool_env_var("CK_MOE"):
-        [removed lines 792-817 appear without content in the source diff view]
+        if is_hip_ and get_bool_env_var("CK_MOE") and activation == "silu":
+            # TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
+            assert not no_combine, f"{no_combine=} is not supported."
+            if self.block_quant:
+                return asm_moe(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights,
+                    topk_ids,
+                    layer.w13_weight_scale_inv,
+                    layer.w2_weight_scale_inv,
+                    None,
+                    None,
+                    False,
+                    None,
+                    block_shape=tuple(self.quant_config.weight_block_size),
+                    expert_mask=None,
+                )
+            else:
+                return asm_moe(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights,
+                    topk_ids,
+                    layer.w13_weight_scale1,
+                    layer.w2_weight_scale1,
+                    None,
+                    None,
+                    False,
+                )
        else:
            # Expert fusion with FP8 quantization
            return fused_experts(
@@ -823,7 +863,7 @@ class Fp8MoEMethod:
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
-                inplace=
+                inplace=inplace and not no_combine,
                activation=activation,
                use_fp8_w8a8=True,
                w1_scale=(
@@ -839,6 +879,7 @@ class Fp8MoEMethod:
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
                block_shape=self.quant_config.weight_block_size,
+                no_combine=no_combine,
            )
 
 
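The ROCm-specific branches above (the aiter imports, the new *_weight_scale1 parameters, the shuffle_weight pre-shuffling, and the asm_moe dispatch in apply()) are all gated on the CK_MOE environment variable, read through get_bool_env_var. A minimal sketch of opting in on a ROCm machine, assuming only what the diff shows (the flag name and the helper in sglang.srt.utils); the exact truthy values accepted by the helper are an assumption:

import os

# Illustrative opt-in: the flag must be set before the fp8 modules are imported,
# since the aiter imports above run at module import time.
os.environ["CK_MOE"] = "1"

from sglang.srt.utils import get_bool_env_var  # same helper used by the diff above

# Assuming the helper treats "1" as true, the quantization code now takes the
# aiter-based path on ROCm instead of fused_experts.
assert get_bool_env_var("CK_MOE")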
sglang/srt/layers/quantization/fp8_utils.py

@@ -1,3 +1,4 @@
+import os
 from typing import List, Optional, Tuple
 
 import torch
@@ -7,9 +8,12 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
     w8a8_block_fp8_matmul,
 )
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import get_bool_env_var, is_hip
 
 is_hip_ = is_hip()
+if is_hip_ and get_bool_env_var("CK_MOE"):
+    from aiter import gemm_a8w8_blockscale
+
 _is_cuda = torch.cuda.is_available() and torch.version.cuda
 if _is_cuda:
     from sgl_kernel import fp8_blockwise_scaled_mm
@@ -40,6 +44,8 @@ def normalize_e4m3fn_to_e4m3fnuz(
 
 
 def cutlass_block_fp8_supported() -> bool:
+    if os.environ.get("SUPPORT_CUTLASS_BLOCK_FP8") is None:
+        return False
     if _is_cuda:
         major, minor = torch.cuda.get_device_capability()
         sm_version = major * 10 + minor
@@ -75,6 +81,16 @@ def apply_w8a8_block_fp8_linear(
        output = fp8_blockwise_scaled_mm(
            q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
        )
+    elif is_hip_ and get_bool_env_var("CK_MOE"):
+        q_input, x_scale = per_token_group_quant_fp8(
+            input_2d, block_size[1], column_major_scales=False
+        )
+        output = torch.zeros(
+            [q_input.shape[0], weight.shape[0]],
+            dtype=input.dtype,
+            device=q_input.device,
+        )
+        gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
    else:
        q_input, x_scale = per_token_group_quant_fp8(
            input_2d, block_size[1], column_major_scales=False
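One behavioral change worth noting: cutlass_block_fp8_supported() now returns False whenever the SUPPORT_CUTLASS_BLOCK_FP8 environment variable is unset, before the existing CUDA capability check runs. A minimal sketch of re-enabling the cutlass path; the import path is taken from the file above, everything else is illustrative:

import os

# Any value passes the new guard (it only checks whether the variable is set);
# the pre-existing capability / sm_version check in the function still applies.
os.environ["SUPPORT_CUTLASS_BLOCK_FP8"] = "1"

from sglang.srt.layers.quantization.fp8_utils import cutlass_block_fp8_supported

print(cutlass_block_fp8_supported())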
sglang/srt/layers/quantization/gptq.py

@@ -0,0 +1,416 @@
+import logging
+from fractions import Fraction
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+from vllm.scalar_type import scalar_types
+
+from sglang.srt.layers.linear import LinearBase
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+
+logger = logging.getLogger(__name__)
+
+
+class GPTQConfig(QuantizationConfig):
+    """Config class for GPTQ.
+
+    Reference: https://arxiv.org/abs/2210.17323
+    """
+
+    def __init__(
+        self,
+        weight_bits: int,
+        group_size: int,
+        desc_act: bool,
+        lm_head_quantized: bool,
+        dynamic: Dict[str, Dict[str, Union[int, bool]]],
+    ) -> None:
+        # GPTQModel use `dynamic` config property to allow per module
+        # quantization config so each module can be individually optimized.
+        # Format is Dict[str, Dict] where key is a regex string that can
+        # perform both positive ("+:" prefixed) or negative ("-:" prefixed)
+        # matching of a module.
+        # Default to positive match, override base quant config mode, if no
+        # prefix is used. Value is in dict format of field key and override
+        # value.
+        # Negative matching will skip quantization init for this module
+        # entirely:
+        # non-quantized inference. More details and quantization examples can be
+        # found at: https://github.com/ModelCloud/GPTQModel
+        # Example:
+        # # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
+        # # last 1/4 of the layers 16-21 has 8bit and group_size 64
+        # dynamic = {
+        #     #`.*\.` matches the layers_node prefix
+        #     # positive match layer 10-15
+        #     r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
+        #     # positive match layer 16-21
+        #     r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
+        #     r"-:.*\.moe\..*": {},  # negative match (skip) all `moe` layers
+        # }
+        super().__init__()
+        self.dynamic = dynamic
+
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+        self.desc_act = desc_act
+        self.lm_head_quantized = lm_head_quantized
+        self.pack_factor = Fraction(32, self.weight_bits)
+        if self.weight_bits not in [2, 3, 4, 8]:
+            raise ValueError(
+                "Currently, only 2/3/4/8-bit weight quantization is "
+                f"supported for GPTQ, but got {self.weight_bits} bits."
+            )
+
+    def __repr__(self) -> str:
+        return (
+            f"GPTQConfig(weight_bits={self.weight_bits}, "
+            f"group_size={self.group_size}, "
+            f"desc_act={self.desc_act}),"
+            f"lm_head_quantized={self.lm_head_quantized}), "
+            f"dynamic={self.dynamic}"
+        )
+
+    def get_scaled_act_names(self) -> List[str]:
+        """Returns the activation function names that should be post-scaled.
+
+        For now, this is only used by AWQ.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "gptq"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.half]
+
+    @classmethod
+    # Need to figure it out
+    def get_min_capability(cls) -> int:
+        return 60
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["quantize_config.json"]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
+        dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
+        dynamic = {} if dynamic is None else dynamic
+
+        weight_bits = cls.get_from_keys(config, ["bits"])
+        group_size = cls.get_from_keys(config, ["group_size"])
+        desc_act = cls.get_from_keys(config, ["desc_act"])
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
+        return cls(weight_bits, group_size, desc_act, lm_head_quantized, dynamic)
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["GPTQLinearMethod"]:
+        from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+
+        from sglang.srt.layers.quantization import get_linear_quant_method
+
+        return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
+
+
+class GPTQMarlinConfig(QuantizationConfig):
+    """Config class for GPTQ Marlin"""
+
+    # (num_bits, is_sym) -> quant_type
+    TYPE_MAP = {
+        (4, True): scalar_types.uint4b8,
+        (8, True): scalar_types.uint8b128,
+    }
+
+    def __init__(
+        self,
+        weight_bits: int,
+        group_size: int,
+        desc_act: bool,
+        is_sym: bool,
+        lm_head_quantized: bool,
+        dynamic: Dict[str, Dict[str, Union[int, bool]]],
+        full_config: Dict[str, Any],
+    ) -> None:
+        super().__init__()
+        if desc_act and group_size == -1:
+            # In this case, act_order == True is the same as act_order == False
+            # (since we have only one group per output channel)
+            desc_act = False
+
+        # GPTQModel use `dynamic` config property to allow per module
+        # quantization config so each module can be individually optimized.
+        # Format is Dict[str, Dict] where key is a regex string that can
+        # perform both positive ("+:" prefixed) or negative ("-:" prefixed)
+        # matching of a module.
+        # Default to positive match, override base quant config mode, if no
+        # prefix is used. Value is in dict format of field key and override
+        # value.
+        # Negative matching will skip quantization init for this module
+        # entirely:
+        # non-quantized inference. More details and quantization examples can be
+        # found at: https://github.com/ModelCloud/GPTQModel
+        # Example:
+        # # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
+        # # last 1/4 of the layers 16-21 has 8bit and group_size 64
+        # dynamic = {
+        #     #`.*\.` matches the layers_node prefix
+        #     # positive match layer 10-15
+        #     r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
+        #     # positive match layer 16-21
+        #     r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
+        #     r"-:.*\.moe\..*": {},  # negative match (skip) all `moe` layers
+        # }
+        self.dynamic = dynamic
+
+        self.weight_bits = weight_bits
+        self.is_sym = is_sym
+
+        self.pack_factor = 32 // weight_bits  # packed into int32
+        self.group_size = group_size
+        self.desc_act = desc_act
+        self.lm_head_quantized = lm_head_quantized
+        self.full_config = full_config
+
+        if (weight_bits, is_sym) not in self.TYPE_MAP:
+            raise ValueError(
+                "Unsupported quantization config: " f"bits={weight_bits}, sym={is_sym}"
+            )
+
+        self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
+
+    def __repr__(self) -> str:
+        return (
+            f"GPTQMarlinConfig(quant_type={self.quant_type}, "
+            f"group_size={self.group_size}, "
+            f"desc_act={self.desc_act}, "
+            f"lm_head_quantized={self.lm_head_quantized}), "
+            f"dynamic={self.dynamic}"
+        )
+
+    def get_scaled_act_names(self) -> List[str]:
+        """Returns the activation function names that should be post-scaled.
+
+        For now, this is only used by AWQ.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "gptq_marlin"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.half, torch.bfloat16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 80
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["quantize_config.json"]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
+        dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
+        dynamic = {} if dynamic is None else dynamic
+
+        weight_bits = cls.get_from_keys(config, ["bits"])
+        group_size = cls.get_from_keys(config, ["group_size"])
+        desc_act = cls.get_from_keys(config, ["desc_act"])
+        is_sym = cls.get_from_keys(config, ["sym"])
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
+        return cls(
+            weight_bits,
+            group_size,
+            desc_act,
+            is_sym,
+            lm_head_quantized,
+            dynamic,
+            config,
+        )
+
+    @classmethod
+    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
+        can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
+
+        is_valid_user_quant = (
+            user_quant is None or user_quant == "marlin" or user_quant == "gptq_marlin"
+        )
+
+        if can_convert and is_valid_user_quant:
+            msg = (
+                "The model is convertible to {} during runtime."
+                " Using {} kernel.".format(cls.get_name(), cls.get_name())
+            )
+            logger.info(msg)
+            return cls.get_name()
+
+        if can_convert and user_quant == "gptq":
+            logger.info(
+                "Detected that the model can run with gptq_marlin"
+                ", however you specified quantization=gptq explicitly,"
+                " so forcing gptq. Use quantization=gptq_marlin for"
+                " faster inference"
+            )
+        return None
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        from vllm.model_executor.layers.quantization.gptq_marlin import (
+            GPTQMarlinLinearMethod,
+            GPTQMarlinMoEMethod,
+        )
+
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.quantization import get_linear_quant_method
+
+        if isinstance(layer, FusedMoE):
+            return GPTQMarlinMoEMethod(self)
+            # TODO: re-enable after SGLang syncs with vllm >= 0.7.3
+            # if layer.num_experts > 32:
+            #     # For MoEs with many experts the moe_wna16 kernel is faster
+            #     return MoeWNA16Config.from_config(self.full_config).get_quant_method(
+            #         layer, prefix
+            #     )
+            # else:
+            #     return GPTQMarlinMoEMethod(self)
+        return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod)
+
+    @classmethod
+    def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
+        quant_method = quant_config.get("quant_method", "").lower()
+        num_bits = quant_config.get("bits")
+        group_size = quant_config.get("group_size")
+        sym = quant_config.get("sym")
+        desc_act = quant_config.get("desc_act")
+
+        from vllm.model_executor.layers.quantization.utils.marlin_utils import (
+            check_marlin_supported,
+        )
+        from vllm.platforms import current_platform
+
+        if not current_platform.is_cuda():
+            return False
+
+        if quant_method != "gptq":
+            return False
+
+        # Marlin conversion is only valid if required properties are found
+        if num_bits is None or group_size is None or sym is None or desc_act is None:
+            return False
+
+        if (num_bits, sym) not in cls.TYPE_MAP:
+            return False
+
+        return check_marlin_supported(
+            quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size
+        )
+
+
+class MarlinConfig(QuantizationConfig):
+    """Config class for Marlin.
+
+    Reference: https://github.com/IST-DASLab/marlin/tree/master
+    """
+
+    def __init__(
+        self,
+        group_size: int,
+        lm_head_quantized: bool,
+    ) -> None:
+        # Group size for the quantization.
+        self.group_size = group_size
+        self.lm_head_quantized = lm_head_quantized
+        if self.group_size != 128 and self.group_size != -1:
+            raise ValueError(
+                "Currently, only group size 128 and -1 (channelwise) "
+                "is supported for Marlin, but got group_size of "
+                f"{self.group_size}"
+            )
+
+        # 4 Bits packed into 32 bit datatype.
+        self.pack_factor = 32 // 4
+
+        # Tile size used by marlin kernels.
+        self.tile_size = 16
+
+        # Min out_features dim
+        self.min_n_threads = 64
+
+        # Min in_features dim
+        self.min_k_threads = 128
+
+        # Max parallel problems to solve at once (improves large
+        # batch performance)
+        self.max_parallel = 16
+
+        # Permutation length used by the marlin kernels.
+        self.perm_len = 1024
+
+    def __repr__(self) -> str:
+        return (
+            f"MarlinConfig(group_size={self.group_size}, "
+            f"lm_head_quantized={self.lm_head_quantized})"
+        )
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "marlin"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.half]
+
+    @classmethod
+    # Need to figure it out
+    def get_min_capability(cls) -> int:
+        return 80
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["quantize_config.json"]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
+        group_size = cls.get_from_keys(config, ["group_size"])
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
+        return cls(group_size, lm_head_quantized)
+
+    @classmethod
+    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
+        # compat: autogptq >=0.8.0 use checkpoint_format: str
+        # compat: autogptq <=0.7.1 is_marlin_format: bool
+        is_marlin_format = hf_quant_cfg.get(
+            "checkpoint_format"
+        ) == "marlin" or hf_quant_cfg.get("is_marlin_format", False)
+
+        is_valid_user_quant = (
+            user_quant is None or user_quant == "gptq" or user_quant == "marlin"
+        )
+
+        if is_marlin_format and is_valid_user_quant:
+            msg = "The model is serialized in {} format. Using {} kernel.".format(
+                cls.get_name(), cls.get_name()
+            )
+            logger.info(msg)
+            return cls.get_name()
+
+        return None
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["MarlinLinearMethod"]:
+        from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
+
+        if isinstance(layer, LinearBase) or (
+            isinstance(layer, ParallelLMHead) and self.lm_head_quantized
+        ):
+            return MarlinLinearMethod(self)
+        return None