sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +208 -295
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +9 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +143 -6
- sglang/srt/managers/schedule_batch.py +238 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +681 -259
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +224 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +44 -18
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +94 -36
- sglang/srt/model_executor/cuda_graph_runner.py +55 -24
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +209 -28
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -29
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +136 -52
- sglang/srt/speculative/build_eagle_tree.py +2 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
- sglang/srt/speculative/eagle_utils.py +92 -58
- sglang/srt/speculative/eagle_worker.py +186 -94
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/__init__.py (+111 -7):

```diff
@@ -1,5 +1,7 @@
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
-
+import re
+from copy import deepcopy
+from typing import Callable, Dict, Optional, Type, Union
 
 import torch
 from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
@@ -16,15 +18,15 @@ from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
 from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
 from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
 from vllm.model_executor.layers.quantization.gguf import GGUFConfig
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
 from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.layers.quantization.qqq import QQQConfig
 from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
 from sglang.srt.layers.quantization.fp8 import Fp8Config
+from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
 
@@ -34,6 +36,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "deepspeedfp": DeepSpeedFPConfig,
     "tpu_int8": Int8TpuConfig,
     "fp8": Fp8Config,
+    "blockwise_int8": BlockInt8Config,
     "fbgemm_fp8": FBGEMMFp8Config,
     "marlin": MarlinConfig,
     "modelopt": ModelOptFp8Config,
@@ -59,19 +62,119 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     return QUANTIZATION_METHODS[quantization]
 
 
+# Match dynamic rules with module name (prefix) and override quantize
+# config if module (prefix) matches a rule
+def override_config(config: QuantizationConfig, prefix: str):
+    weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits)
+    if isinstance(weight_bits, int):
+        config.weight_bits = weight_bits
+    group_size = get_dynamic_override(config, prefix, "group_size", config.group_size)
+    if isinstance(group_size, int):
+        config.group_size = group_size
+    desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act)
+    if isinstance(desc_act, bool):
+        config.desc_act = desc_act
+
+    config.pack_factor = 32 // config.weight_bits  # packed into int32
+    if config.get_name() == "gptq_marlin":
+        is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
+        if isinstance(is_sym, bool):
+            config.is_sym = is_sym
+
+        if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
+            raise ValueError(
+                "Unsupported quantization config: "
+                f"bits={config.weight_bits}, sym={config.is_sym}"
+            )
+
+        config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)]
+    elif config.get_name() == "gptq":
+        if config.weight_bits not in [2, 3, 4, 8]:
+            raise ValueError(
+                "Currently, only 2/3/4/8-bit weight quantization is "
+                f"supported for GPTQ, but got {config.weight_bits} bits."
+            )
+
+
+def get_dynamic_override(
+    config: QuantizationConfig,
+    layer_name: str,
+    key: Optional[str] = None,
+    default_value: Union[int, bool, None] = None,
+) -> Union[Dict, int, bool, None]:
+    for pattern, pattern_dict in config.dynamic.items():
+        # Negative match: matched modules are excluded from quantized init
+        if pattern.startswith("-:"):
+            if re.match(pattern.removeprefix("-:"), layer_name):
+                return False
+        # Positive match: matched modules have quant properties overrides
+        # base quant config
+        elif re.match(pattern.removeprefix("+:"), layer_name):
+            if key is None:
+                return pattern_dict
+            else:
+                return pattern_dict.get(key, default_value)
+    return default_value
+
+
+def get_linear_quant_method(
+    config: QuantizationConfig,
+    layer: torch.nn.Module,
+    prefix: str,
+    linear_method_cls: type,
+):
+
+    from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+    from sglang.srt.layers.vocab_parallel_embedding import (
+        ParallelLMHead,
+        UnquantizedEmbeddingMethod,
+    )
+
+    cloned_config = deepcopy(config)
+    parallel_lm_head_quantized = (
+        isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
+    )
+
+    if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
+        # False = skip module, None = no override, else = Positive match
+        if (
+            get_dynamic_override(  # noqa: E712
+                cloned_config, layer_name=prefix  # noqa: E712
+            )
+            == False
+        ):  # noqa: E712
+            if parallel_lm_head_quantized:
+                return UnquantizedEmbeddingMethod()
+            return UnquantizedLinearMethod()
+
+        if prefix:
+            # Dynamic per module/layer rules may override base config
+            override_config(cloned_config, prefix=prefix)
+
+        return linear_method_cls(cloned_config)
+    return None
+
+
 def gptq_get_quant_method(self, layer, prefix):
+    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
     from vllm.model_executor.layers.quantization.gptq_marlin import (
         GPTQMarlinLinearMethod,
         GPTQMarlinMoEMethod,
     )
 
-    from sglang.srt.layers.linear import LinearBase
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 
-    if isinstance(layer, LinearBase):
-        return GPTQMarlinLinearMethod(self)
-    elif isinstance(layer, FusedMoE):
+    if isinstance(layer, FusedMoE):
         return GPTQMarlinMoEMethod(self)
+
+    if isinstance(self, GPTQConfig):
+        return get_linear_quant_method(
+            self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
+        )
+    elif isinstance(self, GPTQMarlinConfig):
+        return get_linear_quant_method(
+            self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
+        )
     return None
 
 
@@ -153,6 +256,7 @@ def apply_monkey_patches():
     from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
 
     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
+    setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
     setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
     setattr(AWQMoEMethod, "apply", awq_moe_method_apply)
 
```
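For illustration, here is a minimal standalone sketch of how the dynamic-override rules consumed by `get_dynamic_override` behave: a pattern prefixed with `-:` excludes matching modules from quantized initialization, while a pattern prefixed with `+:` (or no prefix) overrides per-module quantization properties. The rule patterns and module names below are hypothetical examples, not taken from any real checkpoint, and the helper re-implements only the matching logic shown in the diff.

```python
import re
from typing import Dict, Optional, Union

# Hypothetical "dynamic" rules (e.g. from a GPTQ quantize_config):
# keep lm_head unquantized, force 8-bit / group_size 64 on the first two layers.
dynamic: Dict[str, Dict] = {
    "-:.*lm_head.*": {},
    "+:model\\.layers\\.[01]\\..*": {"bits": 8, "group_size": 64},
}

def dynamic_override(layer_name: str, key: Optional[str] = None,
                     default: Union[int, bool, None] = None):
    """Standalone re-implementation of the matching in get_dynamic_override."""
    for pattern, overrides in dynamic.items():
        if pattern.startswith("-:"):  # negative rule: skip quantization entirely
            if re.match(pattern.removeprefix("-:"), layer_name):
                return False
        elif re.match(pattern.removeprefix("+:"), layer_name):  # positive rule
            return overrides if key is None else overrides.get(key, default)
    return default

print(dynamic_override("lm_head"))                                   # False -> unquantized method
print(dynamic_override("model.layers.0.mlp.down_proj", "bits", 4))   # 8 (rule override)
print(dynamic_override("model.layers.5.mlp.down_proj", "bits", 4))   # 4 (no rule matched)
```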
sglang/srt/layers/quantization/blockwise_int8.py (new file, +409 lines):

```python
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py

import logging
from typing import Any, Callable, Dict, List, Optional

import torch
from torch.nn import Module
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter
from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
from sglang.srt.utils import set_weight_attrs

ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = logging.getLogger(__name__)


class BlockInt8Config(QuantizationConfig):
    """Config class for INT8."""

    def __init__(
        self,
        is_checkpoint_int8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[List[str]] = None,
        weight_block_size: List[int] = None,
    ) -> None:
        self.is_checkpoint_int8_serialized = is_checkpoint_int8_serialized
        if is_checkpoint_int8_serialized:
            logger.warning(
                "Detected int8 checkpoint. Please note that the "
                "format is experimental and subject to change."
            )
        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []
        if weight_block_size is not None:
            if not is_checkpoint_int8_serialized:
                raise ValueError(
                    f"The block-wise quantization only supports int8-serialized checkpoint for now."
                )
            if len(weight_block_size) != 2:
                raise ValueError(
                    f"The quantization block size of weight must have 2 dimensions, but got {len(weight_block_size)} dimensions."
                )
            if activation_scheme != "dynamic":
                raise ValueError(
                    f"The block-wise quantization only supports dynamic activation scheme for now, but got {activation_scheme} activation scheme."
                )
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> str:
        return "blockwise_int8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "BlockInt8Config":
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_int8_serialized = "int8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        return cls(
            is_checkpoint_int8_serialized=is_checkpoint_int8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.ignored_layers):
                return UnquantizedLinearMethod()
            return BlockInt8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return BlockInt8MoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class BlockInt8LinearMethod(LinearMethodBase):
    """Linear method for INT8.
    Supports loading INT8 checkpoints with static weight scale and
    dynamic activation scale.

    Limitations:
    Only support block-wise int8 quantization and int8 checkpoint

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: BlockInt8Config):
        self.quant_config = quant_config
        assert self.quant_config.weight_block_size is not None
        assert self.quant_config.is_checkpoint_int8_serialized

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        tp_size = get_tensor_model_parallel_world_size()

        block_n, block_k = (
            self.quant_config.weight_block_size[0],
            self.quant_config.weight_block_size[1],
        )
        # Required by row parallel
        if tp_size > 1 and input_size // input_size_per_partition == tp_size:
            if input_size_per_partition % block_k != 0:
                raise ValueError(
                    f"Weight input_size_per_partition = "
                    f"{input_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )
        # Required by collum parallel or enabling merged weights
        if (tp_size > 1 and output_size // output_size_per_partition == tp_size) or len(
            output_partition_sizes
        ) > 1:
            for output_partition_size in output_partition_sizes:
                if output_partition_size % block_n != 0:
                    raise ValueError(
                        f"Weight output_partition_size = "
                        f"{output_partition_size} is not divisible by "
                        f"weight quantization block_n = {block_n}."
                    )

        layer.logical_widths = output_partition_sizes

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
        weight_dtype = (
            torch.int8
            if self.quant_config.is_checkpoint_int8_serialized
            else params_dtype
        )

        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition, input_size_per_partition, dtype=weight_dtype
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE

        scale = BlockQuantScaleParameter(
            data=torch.empty(
                (output_size_per_partition + block_n - 1) // block_n,
                (input_size_per_partition + block_k - 1) // block_k,
                dtype=torch.float32,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale_inv", scale)

        # INPUT ACTIVATION SCALE
        assert self.quant_config.activation_scheme == "dynamic"
        layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        # Block quant doesn't need to process weights after loading
        # Use torch Parameter to avoid cuda graph capturing issue
        layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
        layer.weight_scale_inv = torch.nn.Parameter(
            layer.weight_scale_inv.data, requires_grad=False
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return apply_w8a8_block_int8_linear(
            input=x,
            weight=layer.weight,
            block_size=self.quant_config.weight_block_size,
            weight_scale=layer.weight_scale_inv,
            input_scale=None,
            bias=bias,
        )


class BlockInt8MoEMethod:
    """MoE method for INT8.
    Supports loading INT8 checkpoints with static weight scale and
    dynamic activation scale.

    Limitations:
    Only support block-wise int8 quantization and int8 checkpoint

    Args:
        quant_config: The quantization config.
    """

    def __new__(cls, *args, **kwargs):
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase

        if not hasattr(cls, "_initialized"):
            original_init = cls.__init__
            new_cls = type(
                cls.__name__,
                (FusedMoEMethodBase,),
                {
                    "__init__": original_init,
                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
                },
            )
            obj = super(new_cls, new_cls).__new__(new_cls)
            obj.__init__(*args, **kwargs)
            return obj
        return super().__new__(cls)

    def __init__(self, quant_config):
        self.quant_config = quant_config
        assert self.quant_config.weight_block_size is not None
        assert self.quant_config.is_checkpoint_int8_serialized

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        if self.quant_config.is_checkpoint_int8_serialized:
            params_dtype = torch.int8
        tp_size = get_tensor_model_parallel_world_size()

        block_n, block_k = (
            self.quant_config.weight_block_size[0],
            self.quant_config.weight_block_size[1],
        )
        # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
        # Required by collum parallel or enabling merged weights
        if intermediate_size % block_n != 0:
            raise ValueError(
                f"The output_size of gate's and up's weight = "
                f"{intermediate_size} is not divisible by "
                f"weight quantization block_n = {block_n}."
            )
        if tp_size > 1:
            # Required by row parallel
            if intermediate_size % block_k != 0:
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, hidden_size, intermediate_size, dtype=params_dtype
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(
                num_experts,
                2 * ((intermediate_size + block_n - 1) // block_n),
                (hidden_size + block_k - 1) // block_k,
                dtype=torch.float32,
            ),
            requires_grad=False,
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(
                num_experts,
                (hidden_size + block_n - 1) // block_n,
                (intermediate_size + block_k - 1) // block_k,
                dtype=torch.float32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
        layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        assert self.quant_config.activation_scheme == "dynamic"
        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        # Block quant doesn't need to process weights after loading
        return

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
        inplace: bool = True,
        no_combine: bool = False,
    ) -> torch.Tensor:
        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
        from sglang.srt.layers.moe.topk import select_experts

        # Expert selection
        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            correction_bias=correction_bias,
        )

        # Expert fusion with INT8 quantization
        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=inplace,
            activation=activation,
            use_int8_w8a8=True,
            w1_scale=(layer.w13_weight_scale_inv),
            w2_scale=(layer.w2_weight_scale_inv),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.quant_config.weight_block_size,
            no_combine=no_combine,
        )
```
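As a rough usage sketch: the new method is registered under the `"blockwise_int8"` key in `QUANTIZATION_METHODS`, so a checkpoint whose quantization config carries the keys read by `BlockInt8Config.from_config` resolves to it. The config values below are illustrative assumptions rather than a documented checkpoint format, and the snippet assumes sglang 0.4.3.post3 (with its torch/vllm dependencies) is installed.

```python
# Illustrative only: resolve and build the new block-wise INT8 config by name.
from sglang.srt.layers.quantization import get_quantization_config

cfg_cls = get_quantization_config("blockwise_int8")  # -> BlockInt8Config
cfg = cfg_cls.from_config(
    {
        "quant_method": "blockwise_int8",  # "int8" in the name marks the checkpoint as int8-serialized
        "activation_scheme": "dynamic",    # the block-wise path only supports dynamic activations
        "weight_block_size": [128, 128],   # [block_n, block_k], matching the tuned kernel configs above
    }
)
print(cfg.get_name(), cfg.weight_block_size, cfg.activation_scheme)
```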