sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +23 -2
- sglang/bench_serving.py +6 -4
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +37 -5
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +80 -11
- sglang/srt/disaggregation/mini_lb.py +58 -123
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +585 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
- sglang/srt/disaggregation/prefill.py +82 -22
- sglang/srt/disaggregation/utils.py +46 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +42 -13
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +430 -257
- sglang/srt/layers/attention/flashinfer_backend.py +18 -9
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +18 -3
- sglang/srt/layers/moe/ep_moe/layer.py +15 -29
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +63 -45
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/fp8.py +131 -136
- sglang/srt/layers/quantization/fp8_kernel.py +328 -46
- sglang/srt/layers/quantization/fp8_utils.py +206 -253
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
- sglang/srt/layers/quantization/w8a8_int8.py +8 -7
- sglang/srt/layers/radix_attention.py +28 -1
- sglang/srt/layers/rotary_embedding.py +15 -3
- sglang/srt/layers/sampler.py +5 -10
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +255 -97
- sglang/srt/managers/mm_utils.py +7 -5
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
- sglang/srt/managers/schedule_batch.py +64 -25
- sglang/srt/managers/scheduler.py +80 -82
- sglang/srt/managers/tokenizer_manager.py +18 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +21 -3
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +9 -6
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +67 -35
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +494 -366
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +6 -5
- sglang/srt/models/llama4.py +101 -34
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +30 -200
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +102 -29
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +5 -1
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +15 -13
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +55 -19
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +10 -9
- sglang/srt/utils.py +136 -10
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +224 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/disaggregation/conn.py +0 -81
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/topk.py
CHANGED
@@ -12,7 +12,7 @@
 # limitations under the License.
 # ==============================================================================
 
-import
+import math
 from typing import Callable, Optional
 
 import torch
@@ -25,6 +25,12 @@ from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 
+if _is_cuda:
+    from sgl_kernel import moe_fused_gate
+
+if _is_cuda or _is_hip:
+    from sgl_kernel import topk_softmax
+
 
 expert_distribution_recorder = ExpertDistributionRecorder()
 
@@ -56,11 +62,6 @@ def fused_topk(
     topk: int,
     renormalize: bool,
 ):
-    if _is_cuda or _is_hip:
-        from sgl_kernel import topk_softmax
-    else:
-        from vllm import _custom_ops as vllm_ops
-
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
     M, _ = hidden_states.shape
@@ -73,20 +74,12 @@ def fused_topk(
         M, topk, dtype=torch.int32, device=hidden_states.device
     )
 
-    if _is_cuda or _is_hip:
-        topk_softmax(
-            topk_weights,
-            topk_ids,
-            token_expert_indicies,
-            gating_output.float(),
-        )
-    else:
-        vllm_ops.topk_softmax(
-            topk_weights,
-            topk_ids,
-            token_expert_indicies,
-            gating_output.float(),
-        )
+    topk_softmax(
+        topk_weights,
+        topk_ids,
+        token_expert_indicies,
+        gating_output.float(),
+    )
     del token_expert_indicies
 
     if renormalize:
@@ -105,6 +98,7 @@ def grouped_topk(
     num_expert_group: int = 0,
     topk_group: int = 0,
     n_share_experts_fusion: int = 0,
+    routed_scaling_factor: Optional[float] = None,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
@@ -134,9 +128,7 @@ def grouped_topk(
             dtype=topk_ids.dtype,
             device=topk_ids.device,
         )
-        topk_weights[:, -1] = (
-            topk_weights[:, :-1].sum(dim=-1) / 2.5
-        )  # 2.5 is the routed_scaling_factor.
+        topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
 
     if renormalize:
         topk_weights_sum = (
@@ -158,6 +150,7 @@ def biased_grouped_topk_impl(
     num_expert_group: int = 0,
     topk_group: int = 0,
     n_share_experts_fusion: int = 0,
+    routed_scaling_factor: Optional[float] = None,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
@@ -194,9 +187,7 @@ def biased_grouped_topk_impl(
            dtype=topk_ids.dtype,
            device=topk_ids.device,
        )
-        topk_weights[:, -1] = (
-            topk_weights[:, :-1].sum(dim=-1) / 2.5
-        )  # 2.5 is the routed_scaling_factor.
+        topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
 
     if renormalize:
         topk_weights_sum = (
@@ -209,6 +200,10 @@ def biased_grouped_topk_impl(
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
 
 
+def is_power_of_two(n):
+    return n > 0 and math.log2(n).is_integer()
+
+
 def biased_grouped_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -219,24 +214,46 @@ def biased_grouped_topk(
     topk_group: int = 0,
     compiled: bool = True,
     n_share_experts_fusion: int = 0,
+    routed_scaling_factor: Optional[float] = None,
 ):
-    biased_grouped_topk_fn = (
-        torch.compile(
-            biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
-        )
-        if compiled
-        else biased_grouped_topk_impl
-    )
-    return biased_grouped_topk_fn(
-        hidden_states,
-        gating_output,
-        correction_bias,
-        topk,
-        renormalize,
-        num_expert_group,
-        topk_group,
-        n_share_experts_fusion=n_share_experts_fusion,
-    )
+    assert (
+        routed_scaling_factor is not None
+    ), "routed_scaling_factor is required for biased_grouped_topk"
+    # TODO: moe_fused_gate kernel is not supported for n_share_experts_fusion > 0 now.
+    if (
+        _is_cuda
+        and gating_output.shape[1] // num_expert_group
+        <= 32  # moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
+        and is_power_of_two(correction_bias.shape[0])
+    ):
+        return moe_fused_gate(
+            gating_output,
+            correction_bias,
+            num_expert_group,
+            topk_group,
+            topk,
+            n_share_experts_fusion,
+            routed_scaling_factor,
+        )
+    else:
+        biased_grouped_topk_fn = (
+            torch.compile(
+                biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
+            )
+            if compiled
+            else biased_grouped_topk_impl
+        )
+        return biased_grouped_topk_fn(
+            hidden_states,
+            gating_output,
+            correction_bias,
+            topk,
+            renormalize,
+            num_expert_group,
+            topk_group,
+            n_share_experts_fusion=n_share_experts_fusion,
+            routed_scaling_factor=routed_scaling_factor,
+        )
 
 
 def select_experts(
@@ -250,10 +267,9 @@ def select_experts(
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
     torch_native: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ):
-    n_share_experts_fusion =
-    if global_server_args_dict["n_share_experts_fusion"] is not None:
-        n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+    n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
     # DeekSeek V2/V3/R1 serices models uses grouped_top_k
     if use_grouped_topk:
         assert topk_group is not None
@@ -267,6 +283,7 @@ def select_experts(
                 num_expert_group=num_expert_group,
                 topk_group=topk_group,
                 n_share_experts_fusion=n_share_experts_fusion,
+                routed_scaling_factor=routed_scaling_factor,
             )
         else:
            topk_weights, topk_ids = biased_grouped_topk(
@@ -278,6 +295,7 @@ def select_experts(
                 num_expert_group=num_expert_group,
                 topk_group=topk_group,
                 n_share_experts_fusion=n_share_experts_fusion,
+                routed_scaling_factor=routed_scaling_factor,
             )
     elif torch_native and custom_routing_function is None:
         topk_weights, topk_ids = fused_topk_native(
sglang/srt/layers/quantization/__init__.py
CHANGED
@@ -59,20 +59,20 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
-from sglang.srt.layers.quantization.modelopt_quant import
+from sglang.srt.layers.quantization.modelopt_quant import (
+    ModelOptFp4Config,
+    ModelOptFp8Config,
+)
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
-from sglang.srt.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    UnquantizedEmbeddingMethod,
-)
 
 # Base quantization methods that don't depend on vllm
 BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "fp8": Fp8Config,
     "blockwise_int8": BlockInt8Config,
     "modelopt": ModelOptFp8Config,
+    "modelopt_fp4": ModelOptFp4Config,
     "w8a8_int8": W8A8Int8Config,
     "w8a8_fp8": W8A8Fp8Config,
     "moe_wna16": MoeWNA16Config,
@@ -176,6 +176,13 @@ def get_linear_quant_method(
     prefix: str,
     linear_method_cls: type,
 ):
+    # Move import here to avoid circular import. This is only used in monkey patching
+    # of vllm's QuantizationConfig.
+    from sglang.srt.layers.vocab_parallel_embedding import (
+        ParallelLMHead,
+        UnquantizedEmbeddingMethod,
+    )
+
     cloned_config = deepcopy(config)
     parallel_lm_head_quantized = (
         isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
@@ -283,6 +290,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ):
         assert activation == "silu"
         assert inplace and not no_combine
sglang/srt/layers/quantization/blockwise_int8.py
CHANGED
@@ -373,6 +373,7 @@ class BlockInt8MoEMethod:
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -388,6 +389,7 @@ class BlockInt8MoEMethod:
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
         # Expert fusion with INT8 quantization
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
CHANGED
@@ -1,4 +1,4 @@
-# Adapted from https://github.com/vllm-project/vllm/tree/
+# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
 # SPDX-License-Identifier: Apache-2.0
 
 import logging
@@ -39,7 +39,13 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
     is_activation_quantization_format,
     should_ignore_layer,
 )
-
+
+try:
+    import vllm
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    VLLM_AVAILABLE = False
 
 logger = logging.getLogger(__name__)
 
@@ -77,6 +83,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         sparsity_ignore_list: List[str],
         kv_cache_scheme: Optional[Dict[str, Any]] = None,
         config: Optional[Dict[str, Any]] = None,
+        packed_modules_mapping: Dict[str, List[str]] = {},
     ):
         super().__init__()
         self.ignore = ignore
@@ -87,6 +94,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         self.sparsity_scheme_map = sparsity_scheme_map
         self.sparsity_ignore_list = sparsity_ignore_list
         self.config = config
+        self.packed_modules_mapping = packed_modules_mapping
 
     def get_linear_method(self) -> "CompressedTensorsLinearMethod":
         return CompressedTensorsLinearMethod(self)
@@ -136,6 +144,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
             config=config
         )
+        packed_modules_mapping = config.get("packed_modules_mapping", {})
 
         return cls(
             target_scheme_map=target_scheme_map,
@@ -144,6 +153,7 @@ class CompressedTensorsConfig(QuantizationConfig):
             sparsity_scheme_map=sparsity_scheme_map,
             sparsity_ignore_list=sparsity_ignore_list,
             config=config,
+            packed_modules_mapping=packed_modules_mapping,
         )
 
     @classmethod
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
CHANGED
@@ -1,22 +1,16 @@
-# Adapted from https://github.com/vllm-project/vllm/tree/
+# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
 # SPDX-License-Identifier: Apache-2.0
 
 import enum
 import logging
 from enum import Enum
-from typing import
+from typing import Callable, List, Optional
 
 import torch
 from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import QuantizationStrategy
 
-
-from sglang.srt.layers.moe.fused_moe_triton import (
-    FusedMoE,
-    FusedMoEMethodBase,
-    FusedMoeWeightScaleSupported,
-)
-
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
@@ -29,10 +23,9 @@ from sglang.srt.utils import set_weight_attrs
 
 _is_cuda = is_cuda()
 
-if _is_cuda:
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-else:
+if not _is_cuda:
     from vllm import _custom_ops as vllm_ops
+    from vllm._custom_ops import scaled_fp8_quant
 
 try:
     import vllm
@@ -58,8 +51,6 @@ __all__ = [
 
 class CompressedTensorsMoEMethod:
     def __new__(cls, *args, **kwargs):
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
         if cls is CompressedTensorsMoEMethod:
             return super().__new__(cls)
         return super().__new__(cls)
@@ -76,7 +67,7 @@ class CompressedTensorsMoEMethod:
         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
             if not VLLM_AVAILABLE:
                 raise ImportError(
-                    "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm"
+                    "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm."
                 )
             return CompressedTensorsWNA16MoEMethod(quant_config)
         elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
@@ -92,27 +83,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
     ):
-        from sglang.srt.layers.moe.fused_moe_triton import (
-            FusedMoEMethodBase,
-            FusedMoeWeightScaleSupported,
-        )
-
         self.quant_config = quant_config
         self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
             "input_activations"
         )
 
-        if not (
-            self.weight_quant.strategy == QuantizationStrategy.TENSOR
-            and self.input_quant.strategy == QuantizationStrategy.TENSOR
-        ):
-            raise ValueError(
-                "For FP8 Fused MoE layers, only per-tensor scales "
-                "for weights and activations are supported. Found "
-                f"{self.weight_quant}, {self.input_quant}"
-            )
-
         self.static_input_scales = not self.input_quant.dynamic
 
     def create_weights(
@@ -154,27 +130,50 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         # WEIGHT_SCALES
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
-        )
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
-
-        w2_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
-        )
+        # per-tensor quantization
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They will be combined to a single scale after weight loading.
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+            )
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+            )
+            weight_quant_method = FusedMoeWeightScaleSupported.TENSOR.value
+        elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    num_experts,
+                    2 * intermediate_size_per_partition,
+                    1,
+                    dtype=torch.float32,
+                ),
+                requires_grad=False,
+            )
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
+                requires_grad=False,
+            )
+            weight_quant_method = FusedMoeWeightScaleSupported.CHANNEL.value
+        else:
+            raise ValueError(
+                f"Unsupported weight quantization strategy: {self.weight_quant.strategy}"
+            )
 
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
         layer.register_parameter("w2_weight_scale", w2_weight_scale)
         # Add the quantization method used (per tensor/grouped/channel)
         # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
-        )
+        extra_weight_attrs.update({"quant_method": weight_quant_method})
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
         # INPUT_SCALES
         if self.static_input_scales:
+            assert (
+                self.input_quant.strategy == QuantizationStrategy.TENSOR
+            ), "Only per-tensor quantization is supported for static input scales"
             w13_input_scale = torch.nn.Parameter(
                 torch.ones(num_experts, dtype=torch.float32), requires_grad=False
             )
@@ -241,31 +240,29 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             layer.w2_input_scale = torch.nn.Parameter(
                 w2_input_scale, requires_grad=False
             )
 
-        # Fp8 moe kernel needs single weight scale for w13 per expert.
-        # We take the max then dequant and requant each expert.
-        assert layer.w13_weight_scale is not None
-        shard_size = layer.intermediate_size_per_partition
-        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-        for expert_id in range(layer.local_num_experts):
-            start = 0
-            for shard_id in range(2):
-                dq_weight = per_tensor_dequantize(
-                    layer.w13_weight[expert_id][start : start + shard_size, :],
-                    layer.w13_weight_scale[expert_id][shard_id],
-                )
-
-                if _is_cuda:
-                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
-                        sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                    )
-                else:
-                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
-                        vllm_ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                    )
-                start += shard_size
-
-        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            # Fp8 moe kernel needs single weight scale for w13 per expert.
+            # We take the max then dequant and requant each expert.
+            assert layer.w13_weight_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+            for expert_id in range(layer.local_num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start : start + shard_size, :],
+                        layer.w13_weight_scale[expert_id][shard_id],
+                    )
+                    (
+                        layer.w13_weight[expert_id][start : start + shard_size, :],
+                        _,
+                    ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+
+                    start += shard_size
+
+            layer.w13_weight_scale = torch.nn.Parameter(
+                max_w13_scales, requires_grad=False
+            )
 
     def apply(
         self,
@@ -285,6 +282,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         activation: str = "silu",
         inplace: bool = True,
         no_combine: bool = False,
+        apply_router_weight_on_input: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -299,6 +298,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
         return fused_experts(
@@ -310,10 +310,13 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             inplace=inplace,
             activation=activation,
             use_fp8_w8a8=True,
+            per_channel_quant=self.weight_quant.strategy
+            == QuantizationStrategy.CHANNEL,
             w1_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
+            apply_router_weight_on_input=apply_router_weight_on_input,
         )
 
 
@@ -322,11 +325,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
     ):
-        from sglang.srt.layers.moe.fused_moe_triton import (
-            FusedMoEMethodBase,
-            FusedMoeWeightScaleSupported,
-        )
-
         self.quant_config = quant_config
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
@@ -586,7 +584,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             requires_grad=False,
         )
 
-        marlin_w13_qweight =
+        marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
             layer.w13_weight_packed,
             layer.w13_g_idx_sort_indices,
             layer.w13_weight_packed.shape[1] * self.packed_factor,
@@ -594,7 +592,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             self.num_bits,
         )
         replace_tensor("w13_weight_packed", marlin_w13_qweight)
-        marlin_w2_qweight =
+        marlin_w2_qweight = vllm_ops.gptq_marlin_moe_repack(
             layer.w2_weight_packed,
             layer.w2_g_idx_sort_indices,
             layer.w2_weight_packed.shape[1] * self.packed_factor,
@@ -637,15 +635,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         scoring_func: str = "softmax",
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
         from sglang.srt.layers.moe.topk import select_experts
 
         assert activation == "silu", "Only SiLU activation is supported."
-        if not VLLM_AVAILABLE:
-            raise ImportError(
-                "vllm is not installed, to use fused_marlin_moe, please install vllm"
-            )
         if expert_map is not None:
             raise NotImplementedError(
                 "Expert Parallelism is not supported for " "fused Marlin MoE method."
@@ -662,6 +656,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
         return torch.ops.vllm.fused_marlin_moe(
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
CHANGED
@@ -16,8 +16,7 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
 )
 from sglang.srt.layers.quantization.fp8_utils import (
-    Fp8LinearOp,
-    maybe_create_device_identity,
+    apply_fp8_linear,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
@@ -30,7 +29,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
     def __init__(self, strategy: str, is_static_input_scheme: bool):
         self.strategy = strategy
         self.is_static_input_scheme = is_static_input_scheme
-        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -99,8 +97,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         weight_loader: Callable,
         **kwargs,
     ):
-        maybe_create_device_identity()
-
         output_size_per_partition = sum(output_partition_sizes)
         layer.logical_widths = output_partition_sizes
 
@@ -152,11 +148,12 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
-        return self.fp8_linear.apply(
+        return apply_fp8_linear(
             input=x,
             weight=layer.weight,
             weight_scale=layer.weight_scale,
             input_scale=layer.input_scale,
             bias=bias,
+            use_per_token_if_dynamic=True,
+            compressed_tensor_quant=True,
         )