sglang-0.3.6.post3-py3-none-any.whl → sglang-0.4.0.post1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/bench_one_batch.py +4 -0
- sglang/bench_serving.py +13 -0
- sglang/check_env.py +1 -1
- sglang/srt/_custom_ops.py +118 -0
- sglang/srt/configs/device_config.py +17 -0
- sglang/srt/configs/load_config.py +84 -0
- sglang/srt/configs/model_config.py +161 -4
- sglang/srt/configs/qwen2vl.py +5 -8
- sglang/srt/constrained/outlines_backend.py +11 -1
- sglang/srt/constrained/outlines_jump_forward.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +5 -5
- sglang/srt/distributed/__init__.py +3 -0
- sglang/srt/distributed/communication_op.py +34 -0
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang/srt/distributed/parallel_state.py +1275 -0
- sglang/srt/distributed/utils.py +223 -0
- sglang/srt/hf_transformers_utils.py +37 -1
- sglang/srt/layers/attention/__init__.py +5 -2
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
- sglang/srt/layers/attention/flashinfer_backend.py +33 -20
- sglang/srt/layers/attention/torch_native_backend.py +299 -0
- sglang/srt/layers/attention/triton_backend.py +22 -8
- sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
- sglang/srt/layers/ep_moe/__init__.py +0 -0
- sglang/srt/layers/ep_moe/kernels.py +349 -0
- sglang/srt/layers/ep_moe/layer.py +661 -0
- sglang/srt/layers/fused_moe_patch.py +20 -11
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +17 -3
- sglang/srt/layers/quantization/__init__.py +36 -2
- sglang/srt/layers/quantization/fp8.py +559 -0
- sglang/srt/layers/quantization/fp8_utils.py +27 -0
- sglang/srt/layers/radix_attention.py +4 -2
- sglang/srt/layers/sampler.py +2 -0
- sglang/srt/layers/torchao_utils.py +23 -45
- sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/managers/io_struct.py +48 -2
- sglang/srt/managers/schedule_batch.py +19 -14
- sglang/srt/managers/schedule_policy.py +7 -4
- sglang/srt/managers/scheduler.py +145 -85
- sglang/srt/managers/tokenizer_manager.py +166 -68
- sglang/srt/managers/tp_worker.py +36 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
- sglang/srt/mem_cache/memory_pool.py +5 -1
- sglang/srt/model_executor/cuda_graph_runner.py +30 -7
- sglang/srt/model_executor/forward_batch_info.py +9 -4
- sglang/srt/model_executor/model_runner.py +146 -153
- sglang/srt/model_loader/__init__.py +34 -0
- sglang/srt/model_loader/loader.py +1139 -0
- sglang/srt/model_loader/utils.py +41 -0
- sglang/srt/model_loader/weight_utils.py +640 -0
- sglang/srt/model_parallel.py +1 -5
- sglang/srt/models/baichuan.py +9 -10
- sglang/srt/models/chatglm.py +6 -15
- sglang/srt/models/commandr.py +4 -5
- sglang/srt/models/dbrx.py +2 -3
- sglang/srt/models/deepseek.py +4 -11
- sglang/srt/models/deepseek_v2.py +90 -18
- sglang/srt/models/exaone.py +2 -3
- sglang/srt/models/gemma.py +2 -6
- sglang/srt/models/gemma2.py +3 -14
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/gpt2.py +5 -12
- sglang/srt/models/gpt_bigcode.py +6 -22
- sglang/srt/models/grok.py +3 -8
- sglang/srt/models/internlm2.py +2 -3
- sglang/srt/models/internlm2_reward.py +0 -1
- sglang/srt/models/llama.py +96 -31
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_embedding.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +1 -4
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +4 -7
- sglang/srt/models/minicpm3.py +6 -19
- sglang/srt/models/mixtral.py +24 -14
- sglang/srt/models/mixtral_quant.py +2 -3
- sglang/srt/models/mllama.py +3 -7
- sglang/srt/models/olmo.py +2 -8
- sglang/srt/models/olmo2.py +0 -1
- sglang/srt/models/olmoe.py +3 -5
- sglang/srt/models/phi3_small.py +8 -13
- sglang/srt/models/qwen.py +2 -3
- sglang/srt/models/qwen2.py +10 -9
- sglang/srt/models/qwen2_moe.py +4 -16
- sglang/srt/models/qwen2_vl.py +2 -6
- sglang/srt/models/registry.py +99 -0
- sglang/srt/models/stablelm.py +2 -3
- sglang/srt/models/torch_native_llama.py +6 -17
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +4 -11
- sglang/srt/models/yivl.py +2 -3
- sglang/srt/openai_api/adapter.py +9 -5
- sglang/srt/openai_api/protocol.py +1 -0
- sglang/srt/sampling/sampling_batch_info.py +9 -8
- sglang/srt/server.py +270 -173
- sglang/srt/server_args.py +102 -29
- sglang/srt/utils.py +295 -28
- sglang/test/test_utils.py +7 -0
- sglang/version.py +1 -1
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
- sglang-0.4.0.post1.dist-info/RECORD +189 -0
- sglang-0.3.6.post3.dist-info/RECORD +0 -162
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py

@@ -23,6 +23,7 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
 )
 
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 
 
@@ -163,7 +164,7 @@ class LogitsProcessor(nn.Module):
         self,
         input_ids,
         hidden_states,
-
+        lm_head: VocabParallelEmbedding,
         logits_metadata: Union[LogitsMetadata, ForwardBatch],
     ):
         if isinstance(logits_metadata, ForwardBatch):
@@ -178,7 +179,7 @@ class LogitsProcessor(nn.Module):
             last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
             last_hidden = hidden_states[last_index]
 
-            last_logits =
+            last_logits = self._get_logits(last_hidden, lm_head)
             if self.do_tensor_parallel_all_gather:
                 last_logits = tensor_model_parallel_all_gather(last_logits)
             last_logits = last_logits[:, : self.config.vocab_size].float()
@@ -229,7 +230,7 @@ class LogitsProcessor(nn.Module):
 
             # Compute the logits and logprobs for all required tokens
             states = torch.cat(states, dim=0)
-            all_logits =
+            all_logits = self._get_logits(states, lm_head)
             if self.do_tensor_parallel_all_gather:
                 all_logits = tensor_model_parallel_all_gather(all_logits)
             all_logits = all_logits[:, : self.config.vocab_size].float()
@@ -276,6 +277,19 @@ class LogitsProcessor(nn.Module):
                 output_top_logprobs=output_top_logprobs,
             )
 
+    def _get_logits(
+        self,
+        hidden_states: torch.Tensor,
+        lm_head: VocabParallelEmbedding,
+        embedding_bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if hasattr(lm_head, "weight"):
+            logits = torch.matmul(hidden_states, lm_head.weight.T)
+        else:
+            # GGUF models
+            logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
+        return logits
+
 
 def test():
     all_logprobs = torch.tensor(
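The hunks above change the `LogitsProcessor` API: callers now pass the `lm_head` module (a `VocabParallelEmbedding`) instead of a bare weight tensor, and the new `_get_logits` helper decides how to project hidden states based on whether that module exposes a dense `.weight`. The standalone sketch below mirrors that dispatch outside of sglang; `FakeLMHead` and `get_logits` are illustrative stand-ins, not sglang classes.

```python
import torch


class FakeLMHead(torch.nn.Module):
    """Stand-in for VocabParallelEmbedding: a dense head exposing .weight."""

    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(vocab_size, hidden_size))


def get_logits(hidden_states: torch.Tensor, lm_head: torch.nn.Module) -> torch.Tensor:
    # Same dispatch as the new LogitsProcessor._get_logits: dense heads expose
    # .weight and are projected with a matmul; quantized (e.g. GGUF) heads are
    # expected to provide a linear_method instead.
    if hasattr(lm_head, "weight"):
        return torch.matmul(hidden_states, lm_head.weight.T)
    return lm_head.linear_method.apply(lm_head, hidden_states, None)


lm_head = FakeLMHead(vocab_size=32000, hidden_size=4096)
hidden = torch.randn(2, 4096)
print(get_logits(hidden, lm_head).shape)  # torch.Size([2, 32000])
```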
sglang/srt/layers/quantization/__init__.py

@@ -13,7 +13,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
 from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
 from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
 from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from vllm.model_executor.layers.quantization.gguf import GGUFConfig
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
@@ -23,6 +22,7 @@ from vllm.model_executor.layers.quantization.qqq import QQQConfig
 from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 
 QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "aqlm": AQLMConfig,
@@ -100,13 +100,13 @@ def fp8_moe_apply(
 def fp8_get_quant_method(self, layer, prefix):
     """Enhanced get_quant_method for FP8 config."""
     from vllm.model_executor.layers.linear import LinearBase
-    from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
     from vllm.model_executor.layers.quantization.utils.quant_utils import (
         is_layer_skipped,
     )
 
     from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
     from sglang.srt.layers.linear import UnquantizedLinearMethod
+    from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
 
     if isinstance(layer, LinearBase):
         if is_layer_skipped(prefix, self.ignored_layers):
@@ -117,10 +117,44 @@ def fp8_get_quant_method(self, layer, prefix):
     return None
 
 
+def gptq_get_quant_method(self, layer, prefix):
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.quantization.gptq_marlin import (
+        GPTQMarlinLinearMethod,
+        GPTQMarlinMoEMethod,
+    )
+
+    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+
+    if isinstance(layer, LinearBase):
+        return GPTQMarlinLinearMethod(self)
+    elif isinstance(layer, FusedMoE):
+        return GPTQMarlinMoEMethod(self)
+    return None
+
+
+def awq_get_quant_method(self, layer, prefix):
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.quantization.awq_marlin import (
+        AWQMarlinLinearMethod,
+        AWQMoEMethod,
+    )
+
+    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+
+    if isinstance(layer, LinearBase):
+        return AWQMarlinLinearMethod(self)
+    elif isinstance(layer, FusedMoE):
+        return AWQMoEMethod(self)
+    return None
+
+
 def apply_monkey_patches():
     """Apply all monkey patches in one place."""
     setattr(Fp8MoEMethod, "apply", fp8_moe_apply)
     setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
+    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
+    setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
 
 
 # Apply patches when module is imported
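The `apply_monkey_patches()` changes above all use the same mechanism: rebind `get_quant_method` on a vLLM config class so that linear layers and sglang's `FusedMoE` receive sglang-side method objects. A toy illustration of that setattr pattern follows; the class and strings here are made up, and only the mechanism matches the diff.

```python
class ThirdPartyConfig:
    """Pretend upstream config class (e.g. a quantization config)."""

    def get_quant_method(self, layer, prefix):
        return "upstream default"


def patched_get_quant_method(self, layer, prefix):
    # Dispatch on the layer type and fall back to None, like the patches above.
    if isinstance(layer, str):
        return f"patched handling for {layer!r}"
    return None


def apply_monkey_patches():
    # Rebinding the function on the class makes every existing and future
    # instance pick up the new behavior, without touching upstream code.
    setattr(ThirdPartyConfig, "get_quant_method", patched_get_quant_method)


apply_monkey_patches()
print(ThirdPartyConfig().get_quant_method("linear", prefix=""))  # patched handling for 'linear'
```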
sglang/srt/layers/quantization/fp8.py (new file)

@@ -0,0 +1,559 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
+
+import logging
+from typing import Any, Callable, Dict, List, Optional
+
+import torch
+from torch.nn import Module
+from torch.nn.parameter import Parameter
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.linear import LinearBase
+from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+    apply_fp8_marlin_linear,
+    prepare_fp8_layer_for_marlin,
+)
+from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    all_close_1d,
+    apply_fp8_linear,
+    convert_to_channelwise,
+    cutlass_fp8_supported,
+    per_tensor_dequantize,
+    requantize_with_max_scale,
+)
+from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
+
+from sglang.srt.layers.fused_moe_triton import (
+    FusedMoE,
+    FusedMoEMethodBase,
+    FusedMoeWeightScaleSupported,
+)
+from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
+from sglang.srt.utils import (
+    get_bool_env_var,
+    is_hip,
+    print_warning_once,
+    set_weight_attrs,
+)
+
+ACTIVATION_SCHEMES = ["static", "dynamic"]
+
+logger = logging.getLogger(__name__)
+
+
+class Fp8Config(QuantizationConfig):
+    """Config class for FP8."""
+
+    def __init__(
+        self,
+        is_checkpoint_fp8_serialized: bool = False,
+        activation_scheme: str = "dynamic",
+        ignored_layers: Optional[List[str]] = None,
+    ) -> None:
+        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+        if is_checkpoint_fp8_serialized:
+            logger.warning(
+                "Detected fp8 checkpoint. Please note that the "
+                "format is experimental and subject to change."
+            )
+        if activation_scheme not in ACTIVATION_SCHEMES:
+            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
+        self.activation_scheme = activation_scheme
+        self.ignored_layers = ignored_layers or []
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "fp8"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.bfloat16, torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 80
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return []
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
+        quant_method = cls.get_from_keys(config, ["quant_method"])
+        is_checkpoint_fp8_serialized = "fp8" in quant_method
+        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
+        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
+        return cls(
+            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
+            activation_scheme=activation_scheme,
+            ignored_layers=ignored_layers,
+        )
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        from vllm.attention.layer import Attention  # Avoid circular import
+
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped(prefix, self.ignored_layers):
+                return UnquantizedLinearMethod()
+            return Fp8LinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return Fp8MoEMethod(self)
+        elif isinstance(layer, Attention):
+            return Fp8KVCacheMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class Fp8LinearMethod(LinearMethodBase):
+    """Linear method for FP8.
+    Supports loading FP8 checkpoints with static weight scale and
+    dynamic/static activation scale.
+
+    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
+    activation scaling. The weight scaling factor will be initialized after
+    the model weights are loaded.
+
+    Limitations:
+    1. Only support per-tensor quantization due to torch._scaled_mm support.
+    2. Only support float8_e4m3fn data type due to the limitation of
+       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
+
+    Args:
+        quant_config: The quantization config.
+    """
+
+    def __init__(self, quant_config: Fp8Config):
+        self.quant_config = quant_config
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
+
+        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
+        # kernel for fast weight-only FP8 quantization
+        self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
+        # Disable marlin for ROCm
+        if is_hip():
+            self.use_marlin = False
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        del input_size, output_size
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        layer.logical_widths = output_partition_sizes
+
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+        layer.orig_dtype = params_dtype
+
+        # WEIGHT
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quant_config.is_checkpoint_fp8_serialized
+            else params_dtype
+        )
+
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition, input_size_per_partition, dtype=weight_dtype
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        # If checkpoint is serialized fp8, load them.
+        # Otherwise, wait until process_weights_after_loading.
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            # WEIGHT SCALE
+            scale = PerTensorScaleParameter(
+                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+
+            scale[:] = torch.finfo(torch.float32).min
+            layer.register_parameter("weight_scale", scale)
+
+            # INPUT ACTIVATION SCALE
+            if self.quant_config.activation_scheme == "static":
+                scale = PerTensorScaleParameter(
+                    data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+                    weight_loader=weight_loader,
+                )
+
+                scale[:] = torch.finfo(torch.float32).min
+                layer.register_parameter("input_scale", scale)
+            else:
+                layer.register_parameter("input_scale", None)
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
+        # If checkpoint not serialized fp8, quantize the weights.
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
+
+            # If using marlin (w8a16), kernel uses channelwise weights,
+            # so extend the weight scales to be channelwise.
+            if self.use_marlin:
+                assert weight_scale.numel() == 1
+                weight_scale = convert_to_channelwise(
+                    weight_scale.expand(len(layer.logical_widths)), layer.logical_widths
+                )
+
+            # Update the layer with the new values.
+            layer.weight = Parameter(qweight.t(), requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+            layer.input_scale = None
+
+        # If checkpoint is fp8, handle that there are N scales for N
+        # shards in a fused module
+        else:
+            layer.weight_scale = torch.nn.Parameter(
+                layer.weight_scale.data, requires_grad=False
+            )
+            if self.quant_config.activation_scheme == "static":
+                layer.input_scale = torch.nn.Parameter(
+                    layer.input_scale.data, requires_grad=False
+                )
+            # If using marlin (w8a16), kernel uses channelwise weights,
+            # so extend the weight scales to be channelwise.
+            if self.use_marlin:
+                weight = layer.weight
+                weight_scale = convert_to_channelwise(
+                    layer.weight_scale, layer.logical_widths
+                )
+
+            # If using w8a8, torch._scaled_mm needs per tensor, so
+            # requantize the logical shards as a single weight.
+            else:
+                # Dequant -> Quant with max scale so we can run per tensor.
+                weight = layer.weight
+                weight_scale = layer.weight_scale
+
+                # If ROCm, normalize the weights and scales to e4m3fnuz
+                if is_hip():
+                    weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                        weight=weight,
+                        weight_scale=weight_scale,
+                        input_scale=layer.input_scale,
+                    )
+                    if input_scale is not None:
+                        layer.input_scale = Parameter(input_scale, requires_grad=False)
+
+                weight_scale, weight = requantize_with_max_scale(
+                    weight=weight,
+                    weight_scale=weight_scale,
+                    logical_widths=layer.logical_widths,
+                )
+
+            # Update layer with new values.
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+            if self.quant_config.activation_scheme == "static":
+                layer.input_scale = Parameter(
+                    layer.input_scale.max(), requires_grad=False
+                )
+
+        if self.use_marlin:
+            prepare_fp8_layer_for_marlin(layer)
+            # Activations not quantized for marlin.
+            del layer.input_scale
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        if self.use_marlin:
+            return apply_fp8_marlin_linear(
+                input=x,
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                workspace=layer.workspace,
+                size_n=layer.output_size_per_partition,
+                size_k=layer.input_size_per_partition,
+                bias=bias,
+            )
+
+        return apply_fp8_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            input_scale=layer.input_scale,
+            bias=bias,
+            cutlass_fp8_supported=self.cutlass_fp8_supported,
+            use_per_token_if_dynamic=False,
+        )
+
+
+class Fp8MoEMethod(FusedMoEMethodBase):
+    """MoE method for FP8.
+    Supports loading FP8 checkpoints with static weight scale and
+    dynamic/static activation scale.
+
+    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
+    activation scaling. The weight scaling factor will be initialized after
+    the model weights are loaded.
+
+    Args:
+        quant_config: The quantization config.
+    """
+
+    def __init__(self, quant_config: Fp8Config):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            params_dtype = torch.float8_e4m3fn
+
+        # WEIGHTS
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, hidden_size, intermediate_size, dtype=params_dtype
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # WEIGHT_SCALES
+        # Allocate 2 scales for w1 and w3 respectively.
+        # They will be combined to a single scale after weight loading.
+        w13_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+
+        w2_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+        )
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+        # Add the quantization method used (per tensor/grouped/channel)
+        # to ensure the weight scales are loaded in properly
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+        )
+        # If loading fp8 checkpoint, pass the weight loaders.
+        # If loading an fp16 checkpoint, do not (we will quantize in
+        # process_weights_after_loading()
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        # INPUT_SCALES
+        if self.quant_config.activation_scheme == "static":
+            if not self.quant_config.is_checkpoint_fp8_serialized:
+                raise ValueError(
+                    "Found static activation scheme for checkpoint that "
+                    "was not serialized fp8."
+                )
+
+            w13_input_scale = torch.nn.Parameter(
+                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+            )
+            layer.register_parameter("w13_input_scale", w13_input_scale)
+            set_weight_attrs(w13_input_scale, extra_weight_attrs)
+
+            w2_input_scale = torch.nn.Parameter(
+                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+            )
+            layer.register_parameter("w2_input_scale", w2_input_scale)
+            set_weight_attrs(w2_input_scale, extra_weight_attrs)
+
+        else:
+            layer.w13_input_scale = None
+            layer.w2_input_scale = None
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+
+        # If checkpoint is fp16, quantize in place.
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
+            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
+            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
+
+            # Re-initialize w13_scale because we directly quantize
+            # merged w13 weights and generate a single scaling factor.
+            layer.w13_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    layer.num_experts, dtype=torch.float32, device=w13_weight.device
+                ),
+                requires_grad=False,
+            )
+            for expert in range(layer.num_experts):
+                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                )
+                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                    ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                )
+            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+            return
+
+        # If checkpoint is fp8, we need to handle that the
+        # MoE kernels require single activation scale and single weight
+        # scale for w13 per expert.
+        else:
+            # Fp8 moe kernels require a single activation scale.
+            # We take the max of all the scales in case they differ.
+            if self.quant_config.activation_scheme == "static":
+                if layer.w13_input_scale is None or layer.w2_input_scale is None:
+                    raise ValueError(
+                        "QuantConfig has static quantization, but found "
+                        "activation scales are None."
+                    )
+                if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
+                    layer.w2_input_scale
+                ):
+                    print_warning_once(
+                        "Found input_scales that are not equal for "
+                        "fp8 MoE layer. Using the maximum across experts "
+                        "for each layer. "
+                    )
+                layer.w13_input_scale = torch.nn.Parameter(
+                    layer.w13_input_scale.max(), requires_grad=False
+                )
+                layer.w2_input_scale = torch.nn.Parameter(
+                    layer.w2_input_scale.max(), requires_grad=False
+                )
+            # If ROCm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # Normalize the weights and scales
+                w13_weight, w13_weight_scale, w13_input_scale = (
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
+                    )
+                )
+                w2_weight, w2_weight_scale, w2_input_scale = (
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
+                    )
+                )
+                # Reset the parameter
+                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+                layer.w13_weight_scale = torch.nn.Parameter(
+                    w13_weight_scale, requires_grad=False
+                )
+                if w13_input_scale is not None:
+                    layer.w13_input_scale = torch.nn.Parameter(
+                        w13_input_scale, requires_grad=False
+                    )
+                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+                layer.w2_weight_scale = torch.nn.Parameter(
+                    w2_weight_scale, requires_grad=False
+                )
+                if w2_input_scale is not None:
+                    layer.w2_input_scale = torch.nn.Parameter(
+                        w2_input_scale, requires_grad=False
+                    )
+
+            # Fp8 moe kernel needs single weight scale for w13 per expert.
+            # We take the max then dequant and requant each expert.
+            assert layer.w13_weight_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+            for expert_id in range(layer.num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start : start + shard_size, :],
+                        layer.w13_weight_scale[expert_id][shard_id],
+                    )
+                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
+                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+                    )
+                    start += shard_size
+
+            layer.w13_weight_scale = torch.nn.Parameter(
+                max_w13_scales, requires_grad=False
+            )
+            return
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+    ) -> torch.Tensor:
+
+        from vllm.model_executor.layers.fused_moe import fused_experts
+
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+        )
+
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            use_fp8_w8a8=True,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+        )
+
+
+class Fp8KVCacheMethod(BaseKVCacheMethod):
+    """
+    Supports loading kv-cache scaling factors from FP8 checkpoints.
+    """
+
+    def __init__(self, quant_config: Fp8Config):
+        super().__init__(quant_config)
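The new `Fp8LinearMethod` and `Fp8MoEMethod` lean on per-tensor FP8 quantization (`ops.scaled_fp8_quant`): one scale per tensor maps the weight range onto float8_e4m3fn, the weight is stored in fp8, and the matmul consumes the fp8 weight plus the scale. Below is a rough standalone sketch of that idea; it is not vLLM's kernel, and it assumes a PyTorch build (2.1+) that ships the float8 dtypes.

```python
import torch


def per_tensor_fp8_quant(w: torch.Tensor):
    # Rough sketch of per-tensor FP8 quantization (not vllm's scaled_fp8_quant):
    # pick one scale so max(|w|) maps to the float8_e4m3fn max (~448), store the
    # weight in fp8, and keep the scale around for dequantization.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = w.abs().max().clamp(min=1e-12) / finfo.max
    q = (w / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q, scale


w = torch.randn(128, 256)
q, scale = per_tensor_fp8_quant(w)
w_hat = q.to(torch.float32) * scale  # dequantize for comparison
print(scale.item(), (w - w_hat).abs().max().item())
```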