sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,9 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
 
+from __future__ import annotations
+
 import logging
-from typing import
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -28,17 +30,14 @@ except ImportError:
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
-from sglang.srt.layers.linear import (
-    LinearBase,
-    LinearMethodBase,
-    UnquantizedLinearMethod,
-)
 from sglang.srt.layers.parameter import (
     BlockQuantScaleParameter,
     ModelWeightParameter,
     PerTensorScaleParameter,
 )
 from sglang.srt.layers.quantization.base_config import (
+    FusedMoEMethodBase,
+    LinearMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
@@ -56,6 +55,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
+from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
     convert_to_channelwise,
@@ -77,6 +77,10 @@ from sglang.srt.utils import (
     use_intel_amx_backend,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
+
 _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_npu = is_npu()
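This hunk pairs with the `from __future__ import annotations` added in the first hunk: under PEP 563 semantics, annotations are never evaluated at runtime, so names imported only inside `if TYPE_CHECKING:` (here `TopKOutput` and `W4AFp8Config`) can appear unquoted in the signatures rewritten below without creating an import cycle. A minimal sketch of the pattern, with a hypothetical module name:

```python
from __future__ import annotations  # annotations stay strings at runtime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only static type checkers execute this import; at runtime the
    # module is never loaded, so a cyclic dependency is harmless.
    from some_cyclic_module import SomeConfig  # hypothetical module


def from_config(config: dict) -> SomeConfig:  # legal: never evaluated
    ...
```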
@@ -91,10 +95,9 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 if _is_hip and (_use_aiter or _use_hip_int4):
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
-    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight
 
-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
     from vllm._custom_ops import scaled_fp8_quant
 
 
@@ -152,7 +155,7 @@ class Fp8Config(QuantizationConfig):
         return []
 
     @classmethod
-    def from_config(cls, config: Dict[str, Any]) ->
+    def from_config(cls, config: Dict[str, Any]) -> Fp8Config:
         quant_method = cls.get_from_keys(config, ["quant_method"])
         is_checkpoint_fp8_serialized = "fp8" in quant_method
         activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
@@ -167,7 +170,8 @@ class Fp8Config(QuantizationConfig):
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[
+    ) -> Optional[QuantizeMethodBase]:
+        from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
         if isinstance(layer, LinearBase):
@@ -200,7 +204,7 @@ class Fp8LinearMethod(LinearMethodBase):
         quant_config: The quantization config.
     """
 
-    def __init__(self, quant_config: Union[
+    def __init__(self, quant_config: Union[Fp8Config, W4AFp8Config]):
         self.quant_config = quant_config
         self.cutlass_fp8_supported = cutlass_fp8_supported()
 
@@ -486,7 +490,7 @@ class Fp8LinearMethod(LinearMethodBase):
         )
 
 
-class Fp8MoEMethod:
+class Fp8MoEMethod(FusedMoEMethodBase):
     """MoE method for FP8.
     Supports loading FP8 checkpoints with static weight scale and
     dynamic/static activation scale.
@@ -499,25 +503,7 @@ class Fp8MoEMethod:
         quant_config: The quantization config.
     """
 
-    def __new__(cls, *args, **kwargs):
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
-        if not hasattr(cls, "_initialized"):
-            original_init = cls.__init__
-            new_cls = type(
-                cls.__name__,
-                (FusedMoEMethodBase,),
-                {
-                    "__init__": original_init,
-                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                },
-            )
-            obj = super(new_cls, new_cls).__new__(new_cls)
-            obj.__init__(*args, **kwargs)
-            return obj
-        return super().__new__(cls)
-
-    def __init__(self, quant_config):
+    def __init__(self, quant_config: Fp8Config):
         self.quant_config = quant_config
         self.block_quant = self.quant_config.weight_block_size is not None
         self.cutlass_fp8_supported = cutlass_fp8_supported()
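The removed `__new__` hook grafted `FusedMoEMethodBase` onto the class at instantiation time by building a one-off subclass; the lazy import inside it suggests the base class could not be imported at class-definition time. With `FusedMoEMethodBase` now exported from `base_config` (see the import hunk above), plain inheritance works. A self-contained sketch of the two patterns, using stand-in classes rather than the real sglang ones:

```python
class MethodBase:
    """Stand-in for FusedMoEMethodBase."""


# Removed pattern: inject the base class when an instance is created.
# Note the quirk: instances are not instances of Dynamic itself.
class Dynamic:
    def __new__(cls, *args, **kwargs):
        new_cls = type(
            cls.__name__,
            (MethodBase,),
            {
                k: v
                for k, v in cls.__dict__.items()
                if k not in ("__dict__", "__weakref__")
            },
        )
        obj = super(new_cls, new_cls).__new__(new_cls)
        obj.__init__(*args, **kwargs)
        return obj

    def __init__(self, quant_config):
        self.quant_config = quant_config


# New pattern: ordinary inheritance, possible once the base class lives
# in a leaf module importable at class-definition time.
class Static(MethodBase):
    def __init__(self, quant_config):
        self.quant_config = quant_config


assert isinstance(Dynamic("cfg"), MethodBase)
assert not isinstance(Dynamic("cfg"), Dynamic)  # the grafting quirk
assert isinstance(Static("cfg"), MethodBase)
```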
@@ -985,15 +971,8 @@ class Fp8MoEMethod:
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
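Note the bare `*` inserted after `topk_output`: every remaining option (`activation`, `apply_router_weight_on_input`, `inplace`, ...) becomes keyword-only, so a call site still passing the old positional routing arguments fails loudly instead of silently binding them to the wrong parameters. A small illustration (parameters abbreviated, not the real signature):

```python
def apply(layer, x, topk_output, *, activation: str = "silu", inplace: bool = True):
    """Sketch of the new keyword-only boundary."""
    return activation, inplace


apply(None, None, None, activation="gelu")  # ok: options passed by keyword
# apply(None, None, None, "gelu")          # TypeError: takes 3 positional arguments
```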
@@ -1001,24 +980,15 @@ class Fp8MoEMethod:
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-        from sglang.srt.layers.moe.topk import select_experts
-
-        # Expert selection
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )
 
         if use_intel_amx_backend(layer):
+            from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
+
+            topk_weights, topk_ids, _ = topk_output
+            x, topk_weights = apply_topk_weights_cpu(
+                apply_router_weight_on_input, topk_weights, x
+            )
+
             return torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
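Across these hunks, the long list of routing parameters collapses into a single `topk_output` computed by the caller; the quantization method no longer performs expert selection itself. The only structural fact the diff guarantees is that the object unpacks as a 3-tuple (`topk_weights, topk_ids, _`). A minimal sketch with a hypothetical stand-in type (field names assumed; the real `TopKOutput` lives in `sglang.srt.layers.moe.topk`):

```python
from typing import NamedTuple

import torch


class TopKOutputSketch(NamedTuple):
    """Hypothetical stand-in for sglang.srt.layers.moe.topk.TopKOutput."""

    topk_weights: torch.Tensor   # [num_tokens, top_k] routing weights
    topk_ids: torch.Tensor       # [num_tokens, top_k] expert indices
    router_logits: torch.Tensor  # third field; name assumed here


# Routing is now computed once by the caller and threaded through apply():
logits = torch.randn(4, 8)  # 4 tokens, 8 experts
weights, ids = torch.topk(logits.softmax(-1), k=2, dim=-1)
topk_output = TopKOutputSketch(weights, ids, logits)

topk_weights, topk_ids, _ = topk_output  # as in the hunks above
```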
@@ -1040,8 +1010,7 @@ class Fp8MoEMethod:
             ret = self.maybe_apply_hip_fused_experts(
                 layer,
                 x,
-                topk_weights,
-                topk_ids,
+                topk_output,
                 activation,
                 no_combine,
             )
@@ -1056,6 +1025,7 @@ class Fp8MoEMethod:
         ):
             from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
 
+            topk_weights, topk_ids, _ = topk_output
             return cutlass_fused_experts_fp8(
                 x,
                 layer.w13_weight.transpose(1, 2),
@@ -1084,8 +1054,7 @@ class Fp8MoEMethod:
             x,
             layer.w13_weight,
             layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
+            topk_output=topk_output,
             inplace=inplace and not no_combine,
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
@@ -1109,11 +1078,11 @@ class Fp8MoEMethod:
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
+        topk_output: TopKOutput,
         activation: str = "silu",
         no_combine: bool = False,
     ) -> Optional[torch.Tensor]:
+        topk_weights, topk_ids, _ = topk_output
         if _use_hip_int4:
             # TODO: add triton kernel and add check _use_aiter
             assert not no_combine, f"{no_combine=} is not supported."
@@ -1169,6 +1138,248 @@ class Fp8MoEMethod:
         return None
 
 
+class Fp8EPMoEMethod(Fp8MoEMethod):
+    """MoE method for FP8.
+    Supports loading FP8 checkpoints with static weight scale and
+    dynamic/static activation scale.
+
+    Args:
+        quant_config: The quantization config.
+    """
+
+    def __init__(self, quant_config: Fp8Config):
+        self.quant_config = quant_config
+        self.block_quant = self.quant_config.weight_block_size is not None
+
+    def create_weights(
+        self,
+        layer: Module,
+        num_experts_per_partition: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            params_dtype = torch.float8_e4m3fn
+
+        tp_size = get_tensor_model_parallel_world_size()
+        if self.block_quant:
+            block_n, block_k = (
+                self.quant_config.weight_block_size[0],
+                self.quant_config.weight_block_size[1],
+            )
+            # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
+            # Required by column parallel or enabling merged weights
+            if intermediate_size % block_n != 0:
+                raise ValueError(
+                    f"The output_size of gate's and up's weight = "
+                    f"{intermediate_size} is not divisible by "
+                    f"weight quantization block_n = {block_n}."
+                )
+            if tp_size > 1:
+                # Required by row parallel
+                if intermediate_size % block_k != 0:
+                    raise ValueError(
+                        f"The input_size of down's weight = "
+                        f"{intermediate_size} is not divisible by "
+                        f"weight quantization block_k = {block_k}."
+                    )
+
+        # WEIGHTS
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts_per_partition,
+                2 * intermediate_size,
+                hidden_size,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts_per_partition,
+                hidden_size,
+                intermediate_size,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # WEIGHT_SCALES
+        if self.block_quant:
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    num_experts_per_partition,
+                    2 * ((intermediate_size + block_n - 1) // block_n),
+                    (hidden_size + block_k - 1) // block_k,
+                    dtype=torch.float32,
+                ),
+                requires_grad=False,
+            )
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    num_experts_per_partition,
+                    (hidden_size + block_n - 1) // block_n,
+                    (intermediate_size + block_k - 1) // block_k,
+                    dtype=torch.float32,
+                ),
+                requires_grad=False,
+            )
+            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
+            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
+            assert self.quant_config.activation_scheme == "dynamic"
+        else:
+            # WEIGHT_SCALES
+            # Allocate 2 scales for w1 and w3 respectively.
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
+                requires_grad=False,
+            )
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts_per_partition, dtype=torch.float32),
+                requires_grad=False,
+            )
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+        # Add the quantization method used (per tensor/grouped/channel)
+        # to ensure the weight scales are loaded in properly
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
+            if self.block_quant
+            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+        )
+        # If loading fp8 checkpoint, pass the weight loaders.
+        # If loading an fp16 checkpoint, do not (we will quantize in
+        # process_weights_after_loading()
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        # INPUT_SCALES
+        if self.quant_config.activation_scheme == "static":
+            if not self.quant_config.is_checkpoint_fp8_serialized:
+                raise ValueError(
+                    "Found static activation scheme for checkpoint that "
+                    "was not serialized fp8."
+                )
+
+            w13_input_scale = torch.nn.Parameter(
+                torch.ones(num_experts_per_partition, dtype=torch.float32),
+                requires_grad=False,
+            )
+            layer.register_parameter("w13_input_scale", w13_input_scale)
+            set_weight_attrs(w13_input_scale, extra_weight_attrs)
+
+            w2_input_scale = torch.nn.Parameter(
+                torch.ones(num_experts_per_partition, dtype=torch.float32),
+                requires_grad=False,
+            )
+            layer.register_parameter("w2_input_scale", w2_input_scale)
+            set_weight_attrs(w2_input_scale, extra_weight_attrs)
+
+        else:
+            layer.w13_input_scale = None
+            layer.w2_input_scale = None
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+
+        # If checkpoint is fp16, quantize in place.
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            # If rocm, use float8_e4m3fnuz as dtype
+            fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
+            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
+
+            layer.w13_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    layer.num_experts_per_partition,
+                    dtype=torch.float32,
+                    device=w13_weight.device,
+                ),
+                requires_grad=False,
+            )
+
+            for expert in range(layer.num_experts_per_partition):
+                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                    scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                )
+                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                    scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                )
+            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+            return
+
+        # If checkpoint is fp8, we need to handle that the
+        # MoE kernels require single activation scale and single weight
+        # scale for w13 per expert.
+        else:
+            if self.quant_config.activation_scheme == "static":
+                if layer.w13_input_scale is None or layer.w2_input_scale is None:
+                    raise ValueError(
+                        "QuantConfig has static quantization, but found "
+                        "activation scales are None."
+                    )
+                layer.w13_weight_scale = torch.nn.Parameter(
+                    torch.max(layer.w13_weight_scale, dim=1).values,
+                    requires_grad=False,
+                )
+            if self.block_quant:
+                # If ROCm, normalize the weights and scales to e4m3fnuz
+                if _is_fp8_fnuz:
+                    # activation_scheme: dynamic
+                    w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                        weight=layer.w13_weight,
+                        weight_scale=layer.w13_weight_scale_inv,
+                        input_scale=None,
+                    )
+                    w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                        weight=layer.w2_weight,
+                        weight_scale=layer.w2_weight_scale_inv,
+                        input_scale=None,
+                    )
+                    # Reset the parameter
+                    layer.w13_weight = torch.nn.Parameter(
+                        w13_weight, requires_grad=False
+                    )
+                    layer.w13_weight_scale_inv = torch.nn.Parameter(
+                        w13_weight_scale, requires_grad=False
+                    )
+                    layer.w13_input_scale = None
+                    layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+                    layer.w2_weight_scale_inv = torch.nn.Parameter(
+                        w2_weight_scale, requires_grad=False
+                    )
+                    layer.w2_input_scale = None
+                if _use_aiter:
+                    layer.w13_weight = torch.nn.Parameter(
+                        shuffle_weight(layer.w13_weight.data, (16, 16)),
+                        requires_grad=False,
+                    )
+                    layer.w2_weight = torch.nn.Parameter(
+                        shuffle_weight(layer.w2_weight.data, (16, 16)),
+                        requires_grad=False,
+                    )
+            return
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        hidden_states: torch.Tensor,
+        topk_output: TopKOutput,
+    ) -> torch.Tensor:
+        raise NotImplementedError
+
+
 class Fp8KVCacheMethod(BaseKVCacheMethod):
     """
     Supports loading kv-cache scaling factors from FP8 checkpoints.
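In the block-quant branch of `create_weights` above, one float32 scale is allocated per `block_n x block_k` tile, with ceiling division on both axes and a factor of 2 on w13 for the merged gate/up projections. A quick worked check of those shapes, using hypothetical model sizes and the `block_shape=[128, 128]` that matches the Triton configs bundled in this release:

```python
import torch

# Hypothetical sizes; block_shape=[128, 128] as in the shipped MoE configs.
num_experts_per_partition, hidden_size, intermediate_size = 8, 7168, 2048
block_n, block_k = 128, 128


def cdiv(a: int, b: int) -> int:
    """Ceiling division, mirroring (a + b - 1) // b in create_weights."""
    return (a + b - 1) // b


w13_scale = torch.ones(
    num_experts_per_partition,
    2 * cdiv(intermediate_size, block_n),  # gate and up, stacked
    cdiv(hidden_size, block_k),
)
w2_scale = torch.ones(
    num_experts_per_partition,
    cdiv(hidden_size, block_n),
    cdiv(intermediate_size, block_k),
)
assert w13_scale.shape == (8, 32, 56)  # 2 * (2048/128), 7168/128
assert w2_scale.shape == (8, 56, 16)   # 7168/128, 2048/128
```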