sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +9 -7
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mooncake/conn.py +44 -56
- sglang/srt/distributed/parallel_state.py +33 -0
- sglang/srt/entrypoints/engine.py +30 -26
- sglang/srt/entrypoints/openai/serving_chat.py +21 -2
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/qwen3_detector.py +150 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +13 -0
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +187 -12
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +26 -108
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +343 -3
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +87 -53
- sglang/srt/lora/mem_pool.py +81 -33
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +241 -0
- sglang/srt/managers/io_struct.py +41 -29
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +150 -110
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +243 -61
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +11 -3
- sglang/srt/managers/tp_worker.py +14 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +7 -16
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +152 -0
- sglang/srt/mem_cache/hiradix_cache.py +179 -4
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +41 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +5 -6
- sglang/srt/model_executor/forward_batch_info.py +14 -1
- sglang/srt/model_executor/model_runner.py +109 -22
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +191 -171
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +3 -3
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -5
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +56 -18
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +393 -230
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils.py +27 -1
- sglang/test/runners.py +14 -3
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/modelopt_quant.py

@@ -1,19 +1,17 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
+from __future__ import annotations
 
 import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
 
-from sglang.srt.layers.linear import (
-    LinearBase,
-    LinearMethodBase,
-    UnquantizedLinearMethod,
-)
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
+    FusedMoEMethodBase,
+    LinearMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
@@ -23,6 +21,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     is_sm100_supported,
 )
 from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
+from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.layers.quantization.utils import (
     convert_to_channelwise,
     is_layer_skipped,
@@ -32,6 +31,9 @@ from sglang.srt.layers.quantization.utils import (
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import is_cuda, next_power_of_2
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
 if is_cuda():
     from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
 
@@ -86,7 +88,7 @@ class ModelOptFp8Config(QuantizationConfig):
         return ["hf_quant_config.json"]
 
     @classmethod
-    def from_config(cls, config: Dict[str, Any]) ->
+    def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config:
         quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
         kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get(
             "kv_cache_quant_algo"
@@ -109,7 +111,11 @@ class ModelOptFp8Config(QuantizationConfig):
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[
+    ) -> Optional[QuantizeMethodBase]:
+
+        from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
         if self.exclude_modules and any(
             module in prefix
             or (
@@ -125,9 +131,6 @@ class ModelOptFp8Config(QuantizationConfig):
         if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
             return ModelOptFp8KVCacheMethod(self)
 
-        # Add MoE support
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
-
         if isinstance(layer, FusedMoE):
             return ModelOptFp8MoEMethod(self)
 
@@ -246,7 +249,7 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
         super().__init__(quant_config)
 
 
-class ModelOptFp8MoEMethod:
+class ModelOptFp8MoEMethod(FusedMoEMethodBase):
     """MoE method for ModelOpt FP8.
     Supports loading FP8 checkpoints with static weight scale and activation scale.
 
@@ -254,30 +257,6 @@ class ModelOptFp8MoEMethod:
         quant_config: The ModelOpt quantization config.
     """
 
-    def __new__(cls, *args, **kwargs):
-        """
-        Dynamic class composition pattern.
-
-        This allows us to effectively "inject" FusedMoEMethodBase as a parent class
-        at runtime while avoiding circular import issues.
-        """
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
-        if not hasattr(cls, "_initialized"):
-            original_init = cls.__init__
-            new_cls = type(
-                cls.__name__,
-                (FusedMoEMethodBase,),
-                {
-                    "__init__": original_init,
-                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                },
-            )
-            obj = super(new_cls, new_cls).__new__(new_cls)
-            obj.__init__(*args, **kwargs)
-            return obj
-        return super().__new__(cls)
-
     def __init__(self, quant_config: ModelOptFp8Config):
         self.quant_config = quant_config
         self.cutlass_fp8_supported = cutlass_fp8_supported()
@@ -426,15 +405,8 @@ class ModelOptFp8MoEMethod:
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-
-
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: Optional[int] = None,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
@@ -442,29 +414,12 @@ class ModelOptFp8MoEMethod:
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-        from sglang.srt.layers.moe.topk import select_experts
-
-        # Expert selection
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )
 
         return fused_experts(
             x,
             layer.w13_weight,
             layer.w2_weight,
-
-            topk_ids=topk_ids,
+            topk_output=topk_output,
             inplace=inplace,
             activation=activation,
             use_fp8_w8a8=True,
@@ -514,7 +469,7 @@ class ModelOptFp4Config(QuantizationConfig):
         return ["hf_quant_config.json"]
 
     @classmethod
-    def from_config(cls, config: Dict[str, Any]) ->
+    def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
         quant_config = cls.get_from_keys(config, ["quantization"])
         quant_method = quant_config["quant_algo"]
         if not quant_method in ["FP8", "NVFP4"]:
@@ -559,7 +514,8 @@ class ModelOptFp4Config(QuantizationConfig):
 
     def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[
+    ) -> Optional[QuantizeMethodBase]:
+        from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
         if isinstance(layer, LinearBase):
@@ -740,31 +696,13 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         return out.view(*output_shape)
 
 
-class ModelOptNvFp4FusedMoEMethod:
+class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
     """
     MoE Method for FP4 Quantization with Blockscales and PerTensorScales
     Args:
         quant_config: NVFP4 Quant Config
     """
 
-    def __new__(cls, *args, **kwargs):
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
-        if not hasattr(cls, "_initialized"):
-            original_init = cls.__init__
-            new_cls = type(
-                cls.__name__,
-                (FusedMoEMethodBase,),
-                {
-                    "__init__": original_init,
-                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                },
-            )
-            obj = super(new_cls, new_cls).__new__(new_cls)
-            obj.__init__(*args, **kwargs)
-            return obj
-        return super().__new__(cls)
-
     def __init__(self, quant_config: ModelOptFp4Config):
         self.quant_config = quant_config
         if not is_sm100_supported():
@@ -1002,15 +940,8 @@ class ModelOptNvFp4FusedMoEMethod:
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-
-
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: Optional[int] = None,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
@@ -1023,21 +954,6 @@ class ModelOptNvFp4FusedMoEMethod:
     ) -> torch.Tensor:
 
         assert activation == "silu", "Only SiLU activation is supported."
-        from sglang.srt.layers.moe.topk import select_experts
-
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )
 
         if self.enable_flashinfer_moe:
             assert (
@@ -1045,6 +961,7 @@ class ModelOptNvFp4FusedMoEMethod:
             ), "apply_router_weight_on_input is not supported for Flashinfer"
             # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
             # and fp4 quantized weights loaded from the checkpoint
+            topk_weights, topk_ids, _ = topk_output
             output = flashinfer_cutlass_fused_moe(
                 x,
                 topk_ids.to(torch.int),
@@ -1070,6 +987,7 @@ class ModelOptNvFp4FusedMoEMethod:
 
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
 
+        topk_weights, topk_ids, _ = topk_output
         return cutlass_moe_fp4(
             a=x,
             a1_gscale=layer.w13_input_scale_quant,
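The most visible change in this file is the `apply()` contract for MoE quant methods: the long routing-argument list (`renormalize`, `use_grouped_topk`, `topk_group`, ..., `correction_bias`) and the per-method `select_experts()` call are replaced by a single precomputed `topk_output: TopKOutput`, and the `__new__` class-composition hack is replaced by directly subclassing `FusedMoEMethodBase` (now imported from `base_config`). A minimal sketch of the new calling shape follows; everything except `TopKOutput`, `apply()`, and the keyword-only flags shown in the diff is hypothetical.

# Sketch only: TopKOutputSketch stands in for sglang.srt.layers.moe.topk.TopKOutput,
# whose real definition is not part of this diff; field names here are guesses.
from typing import NamedTuple

import torch


class TopKOutputSketch(NamedTuple):
    topk_weights: torch.Tensor  # [num_tokens, top_k] routing weights
    topk_ids: torch.Tensor      # [num_tokens, top_k] expert indices
    extra: torch.Tensor         # third field; the ModelOpt methods unpack it as "_"


def run_moe_layer(method, layer, x, topk_output):
    # Before 0.4.9.post3 each quant method re-ran expert selection internally from
    # router_logits / top_k / renormalize / ...; now the caller computes routing once
    # and every method only consumes the shared TopKOutput.
    return method.apply(layer, x, topk_output, activation="silu", inplace=True)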
sglang/srt/layers/quantization/moe_wna16.py

@@ -1,23 +1,59 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/moe_wna16.py
+from __future__ import annotations
 
 import logging
-from typing import
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
+import numpy as np
 import torch
 
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import get_tp_group
-from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
 from sglang.srt.layers.quantization.awq import AWQConfig
 from sglang.srt.layers.quantization.base_config import (
+    FusedMoEMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
+from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.utils import get_device_capability, set_weight_attrs
 
 logger = logging.getLogger(__name__)
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
+
+def get_weight_perm(num_bits: int):
+    perm_list: List[int] = []
+    for i in range(32):
+        perm1: List[int] = []
+        col = i // 4
+        for block in [0, 1]:
+            for row in [
+                2 * (i % 4),
+                2 * (i % 4) + 1,
+                2 * (i % 4 + 4),
+                2 * (i % 4 + 4) + 1,
+            ]:
+                perm1.append(16 * row + col + 8 * block)
+        for j in range(4):
+            perm_list.extend([p + 256 * j for p in perm1])
+
+    perm = np.array(perm_list)
+
+    if num_bits == 4:
+        interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
+    elif num_bits == 8:
+        interleave = np.array([0, 2, 1, 3])
+    else:
+        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
+
+    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
+    perm = torch.from_numpy(perm)
+    return perm
+
 
 class MoeWNA16Config(QuantizationConfig):
     """Config class for MOE WNA16 (W8A16/W4A16) quantization."""
@@ -88,7 +124,7 @@ class MoeWNA16Config(QuantizationConfig):
         raise NotImplementedError
 
     @classmethod
-    def from_config(cls, config: Dict[str, Any]) ->
+    def from_config(cls, config: Dict[str, Any]) -> MoeWNA16Config:
         quant_method = cls.get_from_keys(config, ["quant_method"])
         weight_bits = cls.get_from_keys(config, ["bits"])
         group_size = cls.get_from_keys(config, ["group_size"])
@@ -147,8 +183,9 @@ class MoeWNA16Config(QuantizationConfig):
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[
+    ) -> Optional[QuantizeMethodBase]:
         # avoid circular import
+        from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 
         if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
@@ -179,32 +216,13 @@ def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]):
     return any(module_name in prefix for module_name in modules_to_not_convert)
 
 
-class MoeWNA16Method:
+class MoeWNA16Method(FusedMoEMethodBase):
     """Linear method for MOE WNA16 (W8A16/W4A16) quantization.
 
     Args:
         quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
     """
 
-    def __new__(cls, *args, **kwargs):
-        # avoid circular import
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
-        if not hasattr(cls, "_initialized"):
-            original_init = cls.__init__
-            new_cls = type(
-                cls.__name__,
-                (FusedMoEMethodBase,),
-                {
-                    "__init__": original_init,
-                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                },
-            )
-            obj = super(new_cls, new_cls).__new__(new_cls)
-            obj.__init__(*args, **kwargs)
-            return obj
-        return super().__new__(cls)
-
     def __init__(self, quant_config: MoeWNA16Config):
         self.quant_config = quant_config
 
@@ -334,15 +352,8 @@ class MoeWNA16Method:
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-
-
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
@@ -351,22 +362,8 @@ class MoeWNA16Method:
     ) -> torch.Tensor:
         # avoid circular import
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-        from sglang.srt.layers.moe.topk import select_experts
 
         assert activation == "silu", "Only SiLU activation is supported."
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            top_k=top_k,
-            use_grouped_topk=use_grouped_topk,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )
 
         weight_bits = self.quant_config.weight_bits
         has_zp = self.quant_config.has_zp
@@ -375,8 +372,7 @@ class MoeWNA16Method:
             x,
             layer.w13_qweight,
             layer.w2_qweight,
-
-            topk_ids=topk_ids,
+            topk_output=topk_output,
             inplace=inplace,
             apply_router_weight_on_input=apply_router_weight_on_input,
             use_int4_w4a16=weight_bits == 4,
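moe_wna16.py adopts the same `TopKOutput`-based `apply()` signature and also gains a module-level `get_weight_perm()` helper that builds the interleaved weight-permutation table used when repacking 4-bit and 8-bit weights. A quick sanity check of that helper, assuming sglang 0.4.9.post3 is installed; the expected shape and permutation property below are my own reading of the code above, not documented behavior.

# Assumes sglang 0.4.9.post3; get_weight_perm is the helper added in the hunk above.
from sglang.srt.layers.quantization.moe_wna16 import get_weight_perm

perm4 = get_weight_perm(num_bits=4)
perm8 = get_weight_perm(num_bits=8)
# 32 outer iterations x 8 positions x 4 tiles = 1024 entries, scattered by the
# 4-bit or 8-bit interleave pattern.
assert perm4.shape == (1024,) and perm8.shape == (1024,)
assert sorted(perm4.tolist()) == list(range(1024))  # a true permutation of 0..1023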
sglang/srt/layers/quantization/petit.py (new file)

@@ -0,0 +1,252 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
+
+
+import logging
+from typing import Any, Callable, Dict, List, Optional
+
+import regex as re
+import torch
+from torch.nn.parameter import Parameter
+
+from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
+from sglang.srt.layers.quantization.base_config import (
+    LinearMethodBase,
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.petit_utils import (
+    apply_petit_nvfp4_linear,
+    prepare_nvfp4_layer_for_petit,
+    verify_petit_nvfp4_supported,
+)
+from sglang.srt.layers.quantization.utils import is_layer_skipped
+from sglang.srt.utils import is_hip
+
+_is_hip = is_hip()
+
+# Initialize logger for the module
+logger = logging.getLogger(__name__)
+
+
+# Configuration class to support the NVFP4 quantized model generated by the ModelOpt quantization tool
+class PetitNvFp4Config(QuantizationConfig):
+    """Config class for Petit FP4."""
+
+    def __init__(
+        self,
+        is_checkpoint_nvfp4_serialized: bool = False,
+        kv_cache_quant_algo: str = None,
+        group_size: int = None,
+        exclude_modules: List[str] = None,
+    ) -> None:
+        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
+        if is_checkpoint_nvfp4_serialized:
+            logger.warning(
+                "Detected nvfp4 checkpoint. Please note that the "
+                "format is experimental and subject to change."
+            )
+        self.group_size = group_size
+        self.kv_cache_quant_algo = kv_cache_quant_algo
+        self.exclude_modules = exclude_modules
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "petit_nvfp4"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.bfloat16, torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        # Petit supports the gfx90a and gfx942 GPUs
+        return 90
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["hf_quant_config.json"]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "PetitNvFp4Config":
+        quant_config = cls.get_from_keys(config, ["quantization"])
+        quant_method = quant_config["quant_algo"]
+        group_size = quant_config.get("group_size", None)
+        verify_petit_nvfp4_supported(quant_method, group_size)
+
+        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
+        kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
+        if not kv_cache_quant_algo:
+            kv_cache_quant_algo = "auto"
+        exclude_modules = quant_config.get("exclude_modules", None)
+        if not (group_size and kv_cache_quant_algo and (exclude_modules is not None)):
+            logger.warning(
+                f"group_size: {group_size},"
+                f"kv_cache_quant_algo: {kv_cache_quant_algo},"
+                f"exclude_modules: {exclude_modules}"
+            )
+            raise ValueError(
+                "NVFP4 quantization requires group size and "
+                "kv_cache_quant_algo specified in "
+                "hf_quant_config.json"
+            )
+        return cls(
+            is_checkpoint_nvfp4_serialized,
+            kv_cache_quant_algo,
+            group_size,
+            exclude_modules,
+        )
+
+    @classmethod
+    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
+        can_convert = cls.is_petit_nvfp4_compatible(hf_quant_cfg)
+        if can_convert:
+            return cls.get_name()
+        return None
+
+    @classmethod
+    def is_petit_nvfp4_compatible(cls, quant_config: Dict[str, Any]) -> bool:
+        quant_method = quant_config.get("quant_method", "").lower()
+        return _is_hip and quant_method == "modelopt"
+
+    def is_layer_excluded(self, prefix: str, exclude_modules: list):
+        for pattern in exclude_modules:
+            regex_str = pattern.replace(".", r"\.").replace("*", r".*")
+            if re.fullmatch(regex_str, prefix):
+                return True
+        return False
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
+                prefix, self.exclude_modules
+            ):
+                return UnquantizedLinearMethod()
+            return PetitNvFp4LinearMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class PetitNvFp4LinearMethod(LinearMethodBase):
+    """Linear method for NVFP4.
+    Supports loading NVFP4 checkpoints with the following structure:
+
+    |Tensor Name | datatype | shape |
+    |----------------------------------------------------|
+    |input_scale | torch.float32 | scalar |
+    |weight | NVFP4(SE2M1) | [1, X, y/2] |
+    |weight_scale | FP8-E4M3 | [X, Y] |
+    |weight_scale_2 | torch.float32 | scalar |
+
+    The weights are quantized per block of 16 elements.
+    Args: quant_config: The ModelOpt quantization config.
+    """
+
+    def __init__(self, quant_config: PetitNvFp4Config):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        del input_size, output_size
+        if not self.quant_config.is_checkpoint_nvfp4_serialized:
+            raise ValueError(
+                "NVFP4 quantization was selected, "
+                " dynamic quantization is not supported."
+            )
+
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        layer.logical_widths = output_partition_sizes
+
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+        if input_size_per_partition % 16 != 0:
+            raise ValueError(
+                "Unsupported model when in features size is " "not multiple of 16"
+            )
+
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quant_config.is_checkpoint_nvfp4_serialized
+            else params_dtype
+        )
+
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                # 2 fp4 data is packed in one uint8 in the input dimension
+                output_size_per_partition,
+                input_size_per_partition // 2,
+                dtype=torch.uint8,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        input_scale = PerTensorScaleParameter(
+            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+            weight_loader=weight_loader,
+        )
+
+        layer.register_parameter("input_scale", input_scale)
+
+        weight_scale_2 = PerTensorScaleParameter(
+            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight_scale_2", weight_scale_2)
+
+        weight_scale = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition // self.quant_config.group_size,
+                dtype=weight_dtype,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+
+        layer.register_parameter("weight_scale", weight_scale)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        input_scale_2 = layer.input_scale.max().to(torch.float32)
+        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
+        layer.input_scale = Parameter(input_scale_2, requires_grad=False)
+        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
+        layer.alpha = Parameter(
+            layer.input_scale * layer.weight_scale_2, requires_grad=False
+        )
+
+        prepare_nvfp4_layer_for_petit(layer)
+        del layer.input_scale
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return apply_petit_nvfp4_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            weight_scale_2=layer.weight_scale_2,
+            size_n=layer.output_size_per_partition,
+            size_k=layer.input_size_per_partition,
+            bias=bias,
+        )
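petit.py adds a new ROCm path for ModelOpt NVFP4 checkpoints: `override_quantization_method()` redirects `quant_method == "modelopt"` to `petit_nvfp4` when running on HIP, and `PetitNvFp4LinearMethod` dispatches to the Petit kernels via petit_utils.py. A hedged sketch of the hf_quant_config.json payload that `from_config()` parses; the key names follow from the code above, the values are only illustrative, and `verify_petit_nvfp4_supported()` (defined in petit_utils.py, not shown in this diff) may reject unsupported setups.

# Illustrative only; field values are made up, key names come from from_config() above.
sample_hf_quant_config = {
    "quantization": {
        "quant_algo": "NVFP4",          # "NVFP4" in quant_algo => nvfp4-serialized checkpoint
        "group_size": 16,               # weights are quantized per block of 16 elements
        "kv_cache_quant_algo": "auto",  # falsy values also fall back to "auto"
        "exclude_modules": ["lm_head"], # hypothetical exclusion list
    }
}

from sglang.srt.layers.quantization.petit import PetitNvFp4Config

cfg = PetitNvFp4Config.from_config(sample_hf_quant_config)
assert cfg.is_checkpoint_nvfp4_serialized and cfg.group_size == 16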