sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +302 -414
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +13 -8
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +144 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +773 -334
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +225 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +68 -37
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +102 -36
- sglang/srt/model_executor/cuda_graph_runner.py +56 -31
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +280 -81
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +135 -60
- sglang/srt/speculative/build_eagle_tree.py +8 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
- sglang/srt/speculative/eagle_utils.py +92 -57
- sglang/srt/speculative/eagle_worker.py +238 -111
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/layer.py
CHANGED
@@ -29,6 +29,9 @@ import logging
 
 is_hip_ = is_hip()
 
+if is_hip_:
+    from aiter import ck_moe
+
 logger = logging.getLogger(__name__)
 
 
@@ -125,6 +128,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        inplace: bool = True,
+        no_combine: bool = False,
     ) -> torch.Tensor:
         return self.forward(
             x=x,
@@ -138,6 +143,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
             activation=activation,
+            inplace=inplace,
+            no_combine=no_combine,
         )
 
     def forward_cuda(
@@ -153,6 +160,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        inplace: bool = True,
+        no_combine: bool = False,
     ) -> torch.Tensor:
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
@@ -167,17 +176,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         )
 
         if is_hip_ and get_bool_env_var("CK_MOE"):
-
-
-
-
-
-
-
-
-
-
-
+            assert not no_combine, "unsupported"
+            return ck_moe(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                None,
+                None,
+                None,
+                None,
+                32,
+                None,
+                activation,
             )
         else:
             return fused_experts(
@@ -186,8 +198,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             w2=layer.w2_weight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            inplace=
+            inplace=inplace and not no_combine,
             activation=activation,
+            no_combine=no_combine,
         )
 
     def forward_cpu(
@@ -202,6 +215,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        inplace: bool = True,
     ) -> torch.Tensor:
         return moe_forward_native(
             layer,
@@ -241,6 +255,7 @@ class FusedMoE(torch.nn.Module):
         reduce_results: Whether to all all_reduce on the output of the layer
         renomalize: Whether to renormalize the logits in the fused_moe kernel
         quant_config: Quantization configure.
+        inplace: suggestion to compute inplace (modify input activation).
     """
 
     def __init__(
@@ -262,6 +277,8 @@ class FusedMoE(torch.nn.Module):
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
         use_presharded_weights: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
     ):
         super().__init__()
 
@@ -285,6 +302,9 @@ class FusedMoE(torch.nn.Module):
         self.custom_routing_function = custom_routing_function
         self.correction_bias = correction_bias
         self.activation = activation
+        self.use_presharded_weights = use_presharded_weights
+        self.inplace = inplace
+        self.no_combine = no_combine
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -304,7 +324,6 @@ class FusedMoE(torch.nn.Module):
             params_dtype=params_dtype,
             weight_loader=self.weight_loader,
         )
-        self.use_presharded_weights = use_presharded_weights
 
     def _load_per_tensor_weight_scale(
         self,
@@ -598,6 +617,8 @@ class FusedMoE(torch.nn.Module):
             custom_routing_function=self.custom_routing_function,
             correction_bias=self.correction_bias,
             activation=self.activation,
+            inplace=self.inplace,
+            no_combine=self.no_combine,
         )
 
         if self.reduce_results and self.tp_size > 1:
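The notable change above is the new `inplace`/`no_combine` pair threaded from `FusedMoE.__init__` down through `UnquantizedFusedMoEMethod.apply` to the kernel call, with in-place computation suppressed whenever `no_combine` is requested (`inplace=inplace and not no_combine`). The toy module below illustrates that plumbing in isolation; it is a minimal sketch, not sglang's actual kernel or API — `toy_fused_experts` and `ToyFusedMoE` are hypothetical stand-ins.

```python
import torch


def toy_fused_experts(x: torch.Tensor, inplace: bool = True,
                      no_combine: bool = False) -> torch.Tensor:
    # Hypothetical stand-in for the real fused kernel: when `inplace` is
    # set, the input activation buffer is reused for the output.
    # `no_combine` is accepted only for signature parity in this toy.
    out = x if inplace else torch.empty_like(x)
    out.copy_(x * 2.0)  # pretend this is the expert computation
    return out


class ToyFusedMoE(torch.nn.Module):
    # Mirrors the flag plumbing in the diff: the flags are stored on the
    # module and forwarded to the kernel, with `inplace` disabled
    # whenever `no_combine` is set.
    def __init__(self, inplace: bool = True, no_combine: bool = False):
        super().__init__()
        self.inplace = inplace
        self.no_combine = no_combine

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return toy_fused_experts(
            x,
            inplace=self.inplace and not self.no_combine,
            no_combine=self.no_combine,
        )


x = torch.randn(4, 8)
y = ToyFusedMoE(inplace=True)(x)
assert y.data_ptr() == x.data_ptr()  # output reuses the input buffer
```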
sglang/srt/layers/moe/topk.py
CHANGED
@@ -75,7 +75,6 @@ def fused_topk(
     return topk_weights, topk_ids
 
 
-# This is used by the Deepseek V2/V3/R1 series models
 @torch.compile(dynamic=True, backend=get_compiler_backend())
 def grouped_topk(
     hidden_states: torch.Tensor,
@@ -84,10 +83,17 @@ def grouped_topk(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
+    scoring_func: str = "softmax",
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
-
+    if scoring_func == "softmax":
+        scores = torch.softmax(gating_output, dim=-1)
+    elif scoring_func == "sigmoid":
+        scores = gating_output.sigmoid()
+    else:
+        raise ValueError(f"Scoring function '{scoring_func}' is not supported.")
+
     num_token = scores.shape[0]
     group_scores = (
         scores.view(num_token, num_expert_group, -1).max(dim=-1).values
@@ -111,6 +117,7 @@ def grouped_topk(
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
 
 
+# DeepSeek V2/V3/R1 uses biased_grouped_top
 @torch.compile(dynamic=True, backend=get_compiler_backend())
 def biased_grouped_topk(
     hidden_states: torch.Tensor,
@@ -141,7 +148,9 @@ def biased_grouped_topk(
         .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
         .reshape(num_token, -1)
     )  # [n, e]
-    tmp_scores = scores_for_choice.masked_fill(
+    tmp_scores = scores_for_choice.masked_fill(
+        ~score_mask.bool(), float("-inf")
+    )  # [n, e]
     _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
     topk_weights = scores.gather(1, topk_ids)
 
@@ -163,7 +172,7 @@ def select_experts(
     correction_bias: Optional[torch.Tensor] = None,
     torch_native: bool = False,
 ):
-    #
+    # DeepSeek V2/V3/R1 uses biased_grouped_top
     if use_grouped_topk:
         assert topk_group is not None
         assert num_expert_group is not None
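The `scoring_func` parameter added to `grouped_topk` selects how gating logits become expert scores (softmax by default, sigmoid of the kind DeepSeek-V3-style routing uses) before the group-limited top-k selection. The sketch below reproduces that flow standalone; `toy_grouped_topk` is a hypothetical simplification that omits the `renormalize` handling and dtype casts of the real function.

```python
import torch


def toy_grouped_topk(gating_output, topk, num_expert_group, topk_group,
                     scoring_func="softmax"):
    # Scoring dispatch as added in this diff: softmax (default) or sigmoid.
    if scoring_func == "softmax":
        scores = torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"Scoring function '{scoring_func}' is not supported.")

    num_token = scores.shape[0]
    # Rank expert groups by their best expert; keep the top `topk_group` groups.
    group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
    # Expand the group mask back to a per-expert mask.
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )
    # Pick the top-k experts inside the surviving groups.
    masked = scores.masked_fill(~score_mask.bool(), float("-inf"))
    topk_weights, topk_ids = torch.topk(masked, k=topk, dim=-1, sorted=False)
    return topk_weights, topk_ids


# 2 tokens, 8 experts arranged as 4 groups of 2; route each token to 2
# experts drawn from its 2 best groups, scored with sigmoid.
logits = torch.randn(2, 8)
w, ids = toy_grouped_topk(logits, topk=2, num_expert_group=4, topk_group=2,
                          scoring_func="sigmoid")
```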
sglang/srt/layers/quantization/__init__.py
CHANGED
@@ -1,5 +1,7 @@
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
-
+import re
+from copy import deepcopy
+from typing import Callable, Dict, Optional, Type, Union
 
 import torch
 from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
@@ -16,15 +18,15 @@ from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfi
 from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
 from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
 from vllm.model_executor.layers.quantization.gguf import GGUFConfig
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
 from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.layers.quantization.qqq import QQQConfig
 from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
 from sglang.srt.layers.quantization.fp8 import Fp8Config
+from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
 
@@ -34,6 +36,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "deepspeedfp": DeepSpeedFPConfig,
     "tpu_int8": Int8TpuConfig,
     "fp8": Fp8Config,
+    "blockwise_int8": BlockInt8Config,
     "fbgemm_fp8": FBGEMMFp8Config,
     "marlin": MarlinConfig,
     "modelopt": ModelOptFp8Config,
@@ -59,19 +62,119 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     return QUANTIZATION_METHODS[quantization]
 
 
+# Match dynamic rules with module name (prefix) and override quantize
+# config if module (prefix) matches a rule
+def override_config(config: QuantizationConfig, prefix: str):
+    weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits)
+    if isinstance(weight_bits, int):
+        config.weight_bits = weight_bits
+    group_size = get_dynamic_override(config, prefix, "group_size", config.group_size)
+    if isinstance(group_size, int):
+        config.group_size = group_size
+    desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act)
+    if isinstance(desc_act, bool):
+        config.desc_act = desc_act
+
+    config.pack_factor = 32 // config.weight_bits  # packed into int32
+    if config.get_name() == "gptq_marlin":
+        is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
+        if isinstance(is_sym, bool):
+            config.is_sym = is_sym
+
+        if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
+            raise ValueError(
+                "Unsupported quantization config: "
+                f"bits={config.weight_bits}, sym={config.is_sym}"
+            )
+
+        config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)]
+    elif config.get_name() == "gptq":
+        if config.weight_bits not in [2, 3, 4, 8]:
+            raise ValueError(
+                "Currently, only 2/3/4/8-bit weight quantization is "
+                f"supported for GPTQ, but got {config.weight_bits} bits."
+            )
+
+
+def get_dynamic_override(
+    config: QuantizationConfig,
+    layer_name: str,
+    key: Optional[str] = None,
+    default_value: Union[int, bool, None] = None,
+) -> Union[Dict, int, bool, None]:
+    for pattern, pattern_dict in config.dynamic.items():
+        # Negative match: matched modules are excluded from quantized init
+        if pattern.startswith("-:"):
+            if re.match(pattern.removeprefix("-:"), layer_name):
+                return False
+        # Positive match: matched modules have quant properties overrides
+        # base quant config
+        elif re.match(pattern.removeprefix("+:"), layer_name):
+            if key is None:
+                return pattern_dict
+            else:
+                return pattern_dict.get(key, default_value)
+    return default_value
+
+
+def get_linear_quant_method(
+    config: QuantizationConfig,
+    layer: torch.nn.Module,
+    prefix: str,
+    linear_method_cls: type,
+):
+
+    from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+    from sglang.srt.layers.vocab_parallel_embedding import (
+        ParallelLMHead,
+        UnquantizedEmbeddingMethod,
+    )
+
+    cloned_config = deepcopy(config)
+    parallel_lm_head_quantized = (
+        isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
+    )
+
+    if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
+        # False = skip module, None = no override, else = Positive match
+        if (
+            get_dynamic_override(  # noqa: E712
+                cloned_config, layer_name=prefix  # noqa: E712
+            )
+            == False
+        ):  # noqa: E712
+            if parallel_lm_head_quantized:
+                return UnquantizedEmbeddingMethod()
+            return UnquantizedLinearMethod()
+
+        if prefix:
+            # Dynamic per module/layer rules may override base config
+            override_config(cloned_config, prefix=prefix)
+
+        return linear_method_cls(cloned_config)
+    return None
+
+
 def gptq_get_quant_method(self, layer, prefix):
+    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
     from vllm.model_executor.layers.quantization.gptq_marlin import (
         GPTQMarlinLinearMethod,
         GPTQMarlinMoEMethod,
     )
 
-    from sglang.srt.layers.linear import LinearBase
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 
-    if isinstance(layer,
-        return GPTQMarlinLinearMethod(self)
-    elif isinstance(layer, FusedMoE):
+    if isinstance(layer, FusedMoE):
         return GPTQMarlinMoEMethod(self)
+
+    if isinstance(self, GPTQConfig):
+        return get_linear_quant_method(
+            self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
+        )
+    elif isinstance(self, GPTQMarlinConfig):
+        return get_linear_quant_method(
+            self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
+        )
     return None
 
 
@@ -153,6 +256,7 @@ def apply_monkey_patches():
     from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
 
     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
+    setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
    setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
     setattr(AWQMoEMethod, "apply", awq_moe_method_apply)
 
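The new `get_dynamic_override` / `override_config` / `get_linear_quant_method` helpers let a GPTQ-style checkpoint override quantization behavior per module: keys in the config's `dynamic` dict are regex patterns over module prefixes, where a `-:` prefix excludes matching modules from quantized initialization and a `+:` prefix (or a bare pattern) overrides properties such as `bits`, `group_size`, `desc_act`, and `sym`. Below is a minimal standalone sketch of that matching logic; the `DYNAMIC` rule set is hypothetical and the config object is factored out for brevity.

```python
import re
from typing import Dict, Optional, Union

# Hypothetical `dynamic` rules as they might appear in a quantization
# config: "-:" excludes matching modules from quantized init, "+:" (or
# a bare pattern) overrides individual quantization properties.
DYNAMIC: Dict[str, Dict] = {
    r"-:model\.layers\.0\..*": {},              # skip layer 0 entirely
    r"+:model\.layers\.[12]\..*": {"bits": 8},  # 8-bit for layers 1 and 2
}


def get_dynamic_override(
    layer_name: str,
    key: Optional[str] = None,
    default_value: Union[int, bool, None] = None,
) -> Union[Dict, int, bool, None]:
    # Same matching logic as the diff's helper, with the config object
    # replaced by the module-level DYNAMIC table.
    for pattern, pattern_dict in DYNAMIC.items():
        # Negative match: exclude this module from quantized init.
        if pattern.startswith("-:"):
            if re.match(pattern.removeprefix("-:"), layer_name):
                return False
        # Positive match: return the whole rule dict, or one key from it.
        elif re.match(pattern.removeprefix("+:"), layer_name):
            if key is None:
                return pattern_dict
            return pattern_dict.get(key, default_value)
    return default_value


assert get_dynamic_override("model.layers.0.mlp.gate_proj") is False
assert get_dynamic_override("model.layers.1.mlp.gate_proj", "bits", 4) == 8
assert get_dynamic_override("model.layers.7.mlp.gate_proj", "bits", 4) == 4
```

This mirrors how `get_linear_quant_method` in the diff treats the three outcomes: `False` means skip quantization for the module, `None` means no override, and anything else is a positive match whose values feed `override_config`.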