sglang 0.3.6.post3__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. sglang/__init__.py +1 -1
  2. sglang/bench_one_batch.py +4 -0
  3. sglang/bench_serving.py +13 -0
  4. sglang/check_env.py +1 -1
  5. sglang/srt/_custom_ops.py +118 -0
  6. sglang/srt/configs/device_config.py +17 -0
  7. sglang/srt/configs/load_config.py +84 -0
  8. sglang/srt/configs/model_config.py +161 -4
  9. sglang/srt/configs/qwen2vl.py +5 -8
  10. sglang/srt/constrained/outlines_backend.py +11 -1
  11. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  12. sglang/srt/constrained/xgrammar_backend.py +5 -5
  13. sglang/srt/distributed/__init__.py +3 -0
  14. sglang/srt/distributed/communication_op.py +34 -0
  15. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  16. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  19. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  20. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  21. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  23. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  24. sglang/srt/distributed/parallel_state.py +1275 -0
  25. sglang/srt/distributed/utils.py +223 -0
  26. sglang/srt/hf_transformers_utils.py +37 -1
  27. sglang/srt/layers/attention/__init__.py +5 -2
  28. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  29. sglang/srt/layers/attention/flashinfer_backend.py +33 -20
  30. sglang/srt/layers/attention/torch_native_backend.py +299 -0
  31. sglang/srt/layers/attention/triton_backend.py +22 -8
  32. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  33. sglang/srt/layers/ep_moe/__init__.py +0 -0
  34. sglang/srt/layers/ep_moe/kernels.py +349 -0
  35. sglang/srt/layers/ep_moe/layer.py +661 -0
  36. sglang/srt/layers/fused_moe_patch.py +20 -11
  37. sglang/srt/layers/linear.py +1 -0
  38. sglang/srt/layers/logits_processor.py +17 -3
  39. sglang/srt/layers/quantization/__init__.py +36 -2
  40. sglang/srt/layers/quantization/fp8.py +559 -0
  41. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  42. sglang/srt/layers/radix_attention.py +4 -2
  43. sglang/srt/layers/sampler.py +2 -0
  44. sglang/srt/layers/torchao_utils.py +23 -45
  45. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  46. sglang/srt/lora/lora.py +1 -1
  47. sglang/srt/managers/io_struct.py +48 -2
  48. sglang/srt/managers/schedule_batch.py +19 -14
  49. sglang/srt/managers/schedule_policy.py +7 -4
  50. sglang/srt/managers/scheduler.py +145 -85
  51. sglang/srt/managers/tokenizer_manager.py +166 -68
  52. sglang/srt/managers/tp_worker.py +36 -3
  53. sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
  54. sglang/srt/mem_cache/memory_pool.py +5 -1
  55. sglang/srt/model_executor/cuda_graph_runner.py +30 -7
  56. sglang/srt/model_executor/forward_batch_info.py +9 -4
  57. sglang/srt/model_executor/model_runner.py +146 -153
  58. sglang/srt/model_loader/__init__.py +34 -0
  59. sglang/srt/model_loader/loader.py +1139 -0
  60. sglang/srt/model_loader/utils.py +41 -0
  61. sglang/srt/model_loader/weight_utils.py +640 -0
  62. sglang/srt/model_parallel.py +1 -5
  63. sglang/srt/models/baichuan.py +9 -10
  64. sglang/srt/models/chatglm.py +6 -15
  65. sglang/srt/models/commandr.py +4 -5
  66. sglang/srt/models/dbrx.py +2 -3
  67. sglang/srt/models/deepseek.py +4 -11
  68. sglang/srt/models/deepseek_v2.py +90 -18
  69. sglang/srt/models/exaone.py +2 -3
  70. sglang/srt/models/gemma.py +2 -6
  71. sglang/srt/models/gemma2.py +3 -14
  72. sglang/srt/models/gemma2_reward.py +0 -1
  73. sglang/srt/models/gpt2.py +5 -12
  74. sglang/srt/models/gpt_bigcode.py +6 -22
  75. sglang/srt/models/grok.py +3 -8
  76. sglang/srt/models/internlm2.py +2 -3
  77. sglang/srt/models/internlm2_reward.py +0 -1
  78. sglang/srt/models/llama.py +96 -31
  79. sglang/srt/models/llama_classification.py +1 -2
  80. sglang/srt/models/llama_embedding.py +1 -2
  81. sglang/srt/models/llama_reward.py +2 -3
  82. sglang/srt/models/llava.py +1 -4
  83. sglang/srt/models/llavavid.py +1 -2
  84. sglang/srt/models/minicpm.py +4 -7
  85. sglang/srt/models/minicpm3.py +6 -19
  86. sglang/srt/models/mixtral.py +24 -14
  87. sglang/srt/models/mixtral_quant.py +2 -3
  88. sglang/srt/models/mllama.py +3 -7
  89. sglang/srt/models/olmo.py +2 -8
  90. sglang/srt/models/olmo2.py +0 -1
  91. sglang/srt/models/olmoe.py +3 -5
  92. sglang/srt/models/phi3_small.py +8 -13
  93. sglang/srt/models/qwen.py +2 -3
  94. sglang/srt/models/qwen2.py +10 -9
  95. sglang/srt/models/qwen2_moe.py +4 -16
  96. sglang/srt/models/qwen2_vl.py +2 -6
  97. sglang/srt/models/registry.py +99 -0
  98. sglang/srt/models/stablelm.py +2 -3
  99. sglang/srt/models/torch_native_llama.py +6 -17
  100. sglang/srt/models/xverse.py +2 -4
  101. sglang/srt/models/xverse_moe.py +4 -11
  102. sglang/srt/models/yivl.py +2 -3
  103. sglang/srt/openai_api/adapter.py +9 -5
  104. sglang/srt/openai_api/protocol.py +1 -0
  105. sglang/srt/sampling/sampling_batch_info.py +9 -8
  106. sglang/srt/server.py +270 -173
  107. sglang/srt/server_args.py +102 -29
  108. sglang/srt/utils.py +295 -28
  109. sglang/test/test_utils.py +7 -0
  110. sglang/version.py +1 -1
  111. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
  112. sglang-0.4.0.post1.dist-info/RECORD +189 -0
  113. sglang-0.3.6.post3.dist-info/RECORD +0 -162
  114. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
  115. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
  116. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py
@@ -23,6 +23,7 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
 )
 
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 
 
@@ -163,7 +164,7 @@ class LogitsProcessor(nn.Module):
         self,
         input_ids,
         hidden_states,
-        weight,
+        lm_head: VocabParallelEmbedding,
         logits_metadata: Union[LogitsMetadata, ForwardBatch],
     ):
         if isinstance(logits_metadata, ForwardBatch):
@@ -178,7 +179,7 @@ class LogitsProcessor(nn.Module):
             last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
             last_hidden = hidden_states[last_index]
 
-        last_logits = torch.matmul(last_hidden, weight.T)
+        last_logits = self._get_logits(last_hidden, lm_head)
         if self.do_tensor_parallel_all_gather:
             last_logits = tensor_model_parallel_all_gather(last_logits)
         last_logits = last_logits[:, : self.config.vocab_size].float()
@@ -229,7 +230,7 @@ class LogitsProcessor(nn.Module):
 
             # Compute the logits and logprobs for all required tokens
            states = torch.cat(states, dim=0)
-            all_logits = torch.matmul(states, weight.T)
+            all_logits = self._get_logits(states, lm_head)
            if self.do_tensor_parallel_all_gather:
                all_logits = tensor_model_parallel_all_gather(all_logits)
            all_logits = all_logits[:, : self.config.vocab_size].float()
@@ -276,6 +277,19 @@ class LogitsProcessor(nn.Module):
                 output_top_logprobs=output_top_logprobs,
             )
 
+    def _get_logits(
+        self,
+        hidden_states: torch.Tensor,
+        lm_head: VocabParallelEmbedding,
+        embedding_bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if hasattr(lm_head, "weight"):
+            logits = torch.matmul(hidden_states, lm_head.weight.T)
+        else:
+            # GGUF models
+            logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
+        return logits
+
 
 def test():
     all_logprobs = torch.tensor(
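Note: the hunks above change LogitsProcessor so that callers pass the lm_head module (a VocabParallelEmbedding) instead of a raw weight tensor, and logits now go through the new _get_logits helper, which falls back to the head's linear_method for GGUF-quantized heads that expose no .weight attribute. The standalone sketch below is not part of the diff; TiedLMHead is an illustrative stand-in, but the dispatch mirrors _get_logits as shown in the hunk.

# Illustrative sketch of the _get_logits dispatch, runnable on its own.
import torch

class TiedLMHead(torch.nn.Module):
    """Hypothetical stand-in for VocabParallelEmbedding: exposes a .weight."""
    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(vocab_size, hidden_size))

def get_logits(hidden_states: torch.Tensor, lm_head: torch.nn.Module) -> torch.Tensor:
    if hasattr(lm_head, "weight"):
        # Dense / tied heads: plain matmul against the embedding weight.
        return torch.matmul(hidden_states, lm_head.weight.T)
    # GGUF-style heads carry no .weight; delegate to their quantized linear method.
    return lm_head.linear_method.apply(lm_head, hidden_states, None)

hidden = torch.randn(4, 16)
print(get_logits(hidden, TiedLMHead(vocab_size=32, hidden_size=16)).shape)  # torch.Size([4, 32])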
sglang/srt/layers/quantization/__init__.py
@@ -13,7 +13,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
 from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
 from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
 from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from vllm.model_executor.layers.quantization.gguf import GGUFConfig
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
@@ -23,6 +22,7 @@ from vllm.model_executor.layers.quantization.qqq import QQQConfig
 from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 
 QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "aqlm": AQLMConfig,
@@ -100,13 +100,13 @@ def fp8_moe_apply(
 def fp8_get_quant_method(self, layer, prefix):
     """Enhanced get_quant_method for FP8 config."""
     from vllm.model_executor.layers.linear import LinearBase
-    from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
     from vllm.model_executor.layers.quantization.utils.quant_utils import (
         is_layer_skipped,
     )
 
     from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
     from sglang.srt.layers.linear import UnquantizedLinearMethod
+    from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
 
     if isinstance(layer, LinearBase):
         if is_layer_skipped(prefix, self.ignored_layers):
@@ -117,10 +117,44 @@ def fp8_get_quant_method(self, layer, prefix):
     return None
 
 
+def gptq_get_quant_method(self, layer, prefix):
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.quantization.gptq_marlin import (
+        GPTQMarlinLinearMethod,
+        GPTQMarlinMoEMethod,
+    )
+
+    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+
+    if isinstance(layer, LinearBase):
+        return GPTQMarlinLinearMethod(self)
+    elif isinstance(layer, FusedMoE):
+        return GPTQMarlinMoEMethod(self)
+    return None
+
+
+def awq_get_quant_method(self, layer, prefix):
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.quantization.awq_marlin import (
+        AWQMarlinLinearMethod,
+        AWQMoEMethod,
+    )
+
+    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+
+    if isinstance(layer, LinearBase):
+        return AWQMarlinLinearMethod(self)
+    elif isinstance(layer, FusedMoE):
+        return AWQMoEMethod(self)
+    return None
+
+
 def apply_monkey_patches():
     """Apply all monkey patches in one place."""
     setattr(Fp8MoEMethod, "apply", fp8_moe_apply)
     setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
+    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
+    setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
 
 
 # Apply patches when module is imported
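Note: the new gptq_get_quant_method and awq_get_quant_method functions are installed onto vLLM's GPTQMarlinConfig and AWQMarlinConfig via setattr, so that sglang's own FusedMoE layers receive the Marlin MoE methods. The minimal, self-contained sketch below (made-up classes, not the real vLLM configs) shows the same monkey-patch pattern: a plain function whose first parameter is self becomes the class's method once assigned.

# Minimal sketch of the setattr-based monkey patch used above (illustrative classes only).
class FakeQuantConfig:
    def get_quant_method(self, layer, prefix):
        return "original"

class FakeLinear: ...
class FakeFusedMoE: ...

def patched_get_quant_method(self, layer, prefix):
    # Dispatch on the layer type, mirroring gptq_get_quant_method / awq_get_quant_method.
    if isinstance(layer, FakeLinear):
        return "linear-method"
    elif isinstance(layer, FakeFusedMoE):
        return "moe-method"
    return None

def apply_monkey_patches():
    # Replacing the function on the class rebinds it for every instance.
    setattr(FakeQuantConfig, "get_quant_method", patched_get_quant_method)

apply_monkey_patches()
print(FakeQuantConfig().get_quant_method(FakeFusedMoE(), prefix=""))  # moe-method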
sglang/srt/layers/quantization/fp8.py (new file)
@@ -0,0 +1,559 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
+
+import logging
+from typing import Any, Callable, Dict, List, Optional
+
+import torch
+from torch.nn import Module
+from torch.nn.parameter import Parameter
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.linear import LinearBase
+from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+    apply_fp8_marlin_linear,
+    prepare_fp8_layer_for_marlin,
+)
+from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    all_close_1d,
+    apply_fp8_linear,
+    convert_to_channelwise,
+    cutlass_fp8_supported,
+    per_tensor_dequantize,
+    requantize_with_max_scale,
+)
+from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
+
+from sglang.srt.layers.fused_moe_triton import (
+    FusedMoE,
+    FusedMoEMethodBase,
+    FusedMoeWeightScaleSupported,
+)
+from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
+from sglang.srt.utils import (
+    get_bool_env_var,
+    is_hip,
+    print_warning_once,
+    set_weight_attrs,
+)
+
+ACTIVATION_SCHEMES = ["static", "dynamic"]
+
+logger = logging.getLogger(__name__)
+
+
+class Fp8Config(QuantizationConfig):
+    """Config class for FP8."""
+
+    def __init__(
+        self,
+        is_checkpoint_fp8_serialized: bool = False,
+        activation_scheme: str = "dynamic",
+        ignored_layers: Optional[List[str]] = None,
+    ) -> None:
+        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+        if is_checkpoint_fp8_serialized:
+            logger.warning(
+                "Detected fp8 checkpoint. Please note that the "
+                "format is experimental and subject to change."
+            )
+        if activation_scheme not in ACTIVATION_SCHEMES:
+            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
+        self.activation_scheme = activation_scheme
+        self.ignored_layers = ignored_layers or []
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "fp8"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.bfloat16, torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 80
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return []
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
+        quant_method = cls.get_from_keys(config, ["quant_method"])
+        is_checkpoint_fp8_serialized = "fp8" in quant_method
+        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
+        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
+        return cls(
+            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
+            activation_scheme=activation_scheme,
+            ignored_layers=ignored_layers,
+        )
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        from vllm.attention.layer import Attention  # Avoid circular import
+
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped(prefix, self.ignored_layers):
+                return UnquantizedLinearMethod()
+            return Fp8LinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return Fp8MoEMethod(self)
+        elif isinstance(layer, Attention):
+            return Fp8KVCacheMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class Fp8LinearMethod(LinearMethodBase):
+    """Linear method for FP8.
+    Supports loading FP8 checkpoints with static weight scale and
+    dynamic/static activation scale.
+
+    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
+    activation scaling. The weight scaling factor will be initialized after
+    the model weights are loaded.
+
+    Limitations:
+    1. Only support per-tensor quantization due to torch._scaled_mm support.
+    2. Only support float8_e4m3fn data type due to the limitation of
+       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
+
+    Args:
+        quant_config: The quantization config.
+    """
+
+    def __init__(self, quant_config: Fp8Config):
+        self.quant_config = quant_config
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
+
+        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
+        # kernel for fast weight-only FP8 quantization
+        self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
+        # Disable marlin for ROCm
+        if is_hip():
+            self.use_marlin = False
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        del input_size, output_size
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        layer.logical_widths = output_partition_sizes
+
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+        layer.orig_dtype = params_dtype
+
+        # WEIGHT
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quant_config.is_checkpoint_fp8_serialized
+            else params_dtype
+        )
+
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition, input_size_per_partition, dtype=weight_dtype
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        # If checkpoint is serialized fp8, load them.
+        # Otherwise, wait until process_weights_after_loading.
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            # WEIGHT SCALE
+            scale = PerTensorScaleParameter(
+                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+
+            scale[:] = torch.finfo(torch.float32).min
+            layer.register_parameter("weight_scale", scale)
+
+            # INPUT ACTIVATION SCALE
+            if self.quant_config.activation_scheme == "static":
+                scale = PerTensorScaleParameter(
+                    data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+                    weight_loader=weight_loader,
+                )
+
+                scale[:] = torch.finfo(torch.float32).min
+                layer.register_parameter("input_scale", scale)
+            else:
+                layer.register_parameter("input_scale", None)
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
+        # If checkpoint not serialized fp8, quantize the weights.
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
+
+            # If using marlin (w8a16), kernel uses channelwise weights,
+            # so extend the weight scales to be channelwise.
+            if self.use_marlin:
+                assert weight_scale.numel() == 1
+                weight_scale = convert_to_channelwise(
+                    weight_scale.expand(len(layer.logical_widths)), layer.logical_widths
+                )
+
+            # Update the layer with the new values.
+            layer.weight = Parameter(qweight.t(), requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+            layer.input_scale = None
+
+        # If checkpoint is fp8, handle that there are N scales for N
+        # shards in a fused module
+        else:
+            layer.weight_scale = torch.nn.Parameter(
+                layer.weight_scale.data, requires_grad=False
+            )
+            if self.quant_config.activation_scheme == "static":
+                layer.input_scale = torch.nn.Parameter(
+                    layer.input_scale.data, requires_grad=False
+                )
+            # If using marlin (w8a16), kernel uses channelwise weights,
+            # so extend the weight scales to be channelwise.
+            if self.use_marlin:
+                weight = layer.weight
+                weight_scale = convert_to_channelwise(
+                    layer.weight_scale, layer.logical_widths
+                )
+
+            # If using w8a8, torch._scaled_mm needs per tensor, so
+            # requantize the logical shards as a single weight.
+            else:
+                # Dequant -> Quant with max scale so we can run per tensor.
+                weight = layer.weight
+                weight_scale = layer.weight_scale
+
+                # If ROCm, normalize the weights and scales to e4m3fnuz
+                if is_hip():
+                    weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                        weight=weight,
+                        weight_scale=weight_scale,
+                        input_scale=layer.input_scale,
+                    )
+                    if input_scale is not None:
+                        layer.input_scale = Parameter(input_scale, requires_grad=False)
+
+                weight_scale, weight = requantize_with_max_scale(
+                    weight=weight,
+                    weight_scale=weight_scale,
+                    logical_widths=layer.logical_widths,
+                )
+
+            # Update layer with new values.
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+            if self.quant_config.activation_scheme == "static":
+                layer.input_scale = Parameter(
+                    layer.input_scale.max(), requires_grad=False
+                )
+
+        if self.use_marlin:
+            prepare_fp8_layer_for_marlin(layer)
+            # Activations not quantized for marlin.
+            del layer.input_scale
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        if self.use_marlin:
+            return apply_fp8_marlin_linear(
+                input=x,
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                workspace=layer.workspace,
+                size_n=layer.output_size_per_partition,
+                size_k=layer.input_size_per_partition,
+                bias=bias,
+            )
+
+        return apply_fp8_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            input_scale=layer.input_scale,
+            bias=bias,
+            cutlass_fp8_supported=self.cutlass_fp8_supported,
+            use_per_token_if_dynamic=False,
+        )
+
+
+class Fp8MoEMethod(FusedMoEMethodBase):
+    """MoE method for FP8.
+    Supports loading FP8 checkpoints with static weight scale and
+    dynamic/static activation scale.
+
+    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
+    activation scaling. The weight scaling factor will be initialized after
+    the model weights are loaded.
+
+    Args:
+        quant_config: The quantization config.
+    """
+
+    def __init__(self, quant_config: Fp8Config):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            params_dtype = torch.float8_e4m3fn
+
+        # WEIGHTS
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, hidden_size, intermediate_size, dtype=params_dtype
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # WEIGHT_SCALES
+        # Allocate 2 scales for w1 and w3 respectively.
+        # They will be combined to a single scale after weight loading.
+        w13_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+
+        w2_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+        )
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+        # Add the quantization method used (per tensor/grouped/channel)
+        # to ensure the weight scales are loaded in properly
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+        )
+        # If loading fp8 checkpoint, pass the weight loaders.
+        # If loading an fp16 checkpoint, do not (we will quantize in
+        # process_weights_after_loading()
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        # INPUT_SCALES
+        if self.quant_config.activation_scheme == "static":
+            if not self.quant_config.is_checkpoint_fp8_serialized:
+                raise ValueError(
+                    "Found static activation scheme for checkpoint that "
+                    "was not serialized fp8."
+                )
+
+            w13_input_scale = torch.nn.Parameter(
+                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+            )
+            layer.register_parameter("w13_input_scale", w13_input_scale)
+            set_weight_attrs(w13_input_scale, extra_weight_attrs)
+
+            w2_input_scale = torch.nn.Parameter(
+                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+            )
+            layer.register_parameter("w2_input_scale", w2_input_scale)
+            set_weight_attrs(w2_input_scale, extra_weight_attrs)
+
+        else:
+            layer.w13_input_scale = None
+            layer.w2_input_scale = None
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+
+        # If checkpoint is fp16, quantize in place.
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
+            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
+            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
+
+            # Re-initialize w13_scale because we directly quantize
+            # merged w13 weights and generate a single scaling factor.
+            layer.w13_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    layer.num_experts, dtype=torch.float32, device=w13_weight.device
+                ),
+                requires_grad=False,
+            )
+            for expert in range(layer.num_experts):
+                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                )
+                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                    ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                )
+            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+            return
+
+        # If checkpoint is fp8, we need to handle that the
+        # MoE kernels require single activation scale and single weight
+        # scale for w13 per expert.
+        else:
+            # Fp8 moe kernels require a single activation scale.
+            # We take the max of all the scales in case they differ.
+            if self.quant_config.activation_scheme == "static":
+                if layer.w13_input_scale is None or layer.w2_input_scale is None:
+                    raise ValueError(
+                        "QuantConfig has static quantization, but found "
+                        "activation scales are None."
+                    )
+                if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
+                    layer.w2_input_scale
+                ):
+                    print_warning_once(
+                        "Found input_scales that are not equal for "
+                        "fp8 MoE layer. Using the maximum across experts "
+                        "for each layer. "
+                    )
+                layer.w13_input_scale = torch.nn.Parameter(
+                    layer.w13_input_scale.max(), requires_grad=False
+                )
+                layer.w2_input_scale = torch.nn.Parameter(
+                    layer.w2_input_scale.max(), requires_grad=False
+                )
+            # If ROCm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # Normalize the weights and scales
+                w13_weight, w13_weight_scale, w13_input_scale = (
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
+                    )
+                )
+                w2_weight, w2_weight_scale, w2_input_scale = (
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
+                    )
+                )
+                # Reset the parameter
+                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+                layer.w13_weight_scale = torch.nn.Parameter(
+                    w13_weight_scale, requires_grad=False
+                )
+                if w13_input_scale is not None:
+                    layer.w13_input_scale = torch.nn.Parameter(
+                        w13_input_scale, requires_grad=False
+                    )
+                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+                layer.w2_weight_scale = torch.nn.Parameter(
+                    w2_weight_scale, requires_grad=False
+                )
+                if w2_input_scale is not None:
+                    layer.w2_input_scale = torch.nn.Parameter(
+                        w2_input_scale, requires_grad=False
+                    )
+
+            # Fp8 moe kernel needs single weight scale for w13 per expert.
+            # We take the max then dequant and requant each expert.
+            assert layer.w13_weight_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+            for expert_id in range(layer.num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start : start + shard_size, :],
+                        layer.w13_weight_scale[expert_id][shard_id],
+                    )
+                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
+                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+                    )
+                    start += shard_size
+
+            layer.w13_weight_scale = torch.nn.Parameter(
+                max_w13_scales, requires_grad=False
+            )
+            return
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+    ) -> torch.Tensor:
+
+        from vllm.model_executor.layers.fused_moe import fused_experts
+
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+        )
+
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            use_fp8_w8a8=True,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+        )
+
+
+class Fp8KVCacheMethod(BaseKVCacheMethod):
+    """
+    Supports loading kv-cache scaling factors from FP8 checkpoints.
+    """
+
+    def __init__(self, quant_config: Fp8Config):
+        super().__init__(quant_config)
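Note: when a checkpoint is not already FP8-serialized, both Fp8LinearMethod and Fp8MoEMethod above quantize the loaded weights with ops.scaled_fp8_quant and keep one float32 scale per tensor (or per expert shard). The sketch below is a rough pure-PyTorch approximation of that per-tensor quantization step, for intuition only; the real path uses vLLM's fused op, and fused shards additionally go through requantize_with_max_scale.

# Rough per-tensor FP8 quantization in plain PyTorch (needs torch >= 2.1 for float8 dtypes).
import torch

def per_tensor_fp8_quant(weight: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    amax = weight.abs().max().float().clamp(min=1e-12)
    scale = amax / finfo.max                      # one scale for the whole tensor
    qweight = (weight.float() / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return qweight, scale

w = torch.randn(128, 64, dtype=torch.float16)
qw, s = per_tensor_fp8_quant(w)
# Dequantize to check that the round-trip error stays small.
err = (qw.float() * s - w.float()).abs().max()
print(qw.dtype, s.item(), err.item())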