sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff covers publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
- sglang/bench_one_batch.py +113 -17
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +11 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +4 -3
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +71 -0
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/vision.py +13 -5
- sglang/srt/layers/communicator.py +21 -4
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +2 -7
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +77 -73
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +3 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +55 -30
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +28 -7
- sglang/srt/managers/scheduler.py +26 -12
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +24 -6
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +7 -6
- sglang/srt/model_executor/forward_batch_info.py +35 -14
- sglang/srt/model_executor/model_runner.py +19 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +72 -33
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +24 -12
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +142 -7
- sglang/srt/two_batch_overlap.py +157 -5
- sglang/srt/utils.py +38 -2
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,557 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import fnmatch
+import logging
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, cast
+
+import aiter
+import torch
+import torch.nn.functional as F
+from aiter import ActivationType, QuantType, dtypes
+from aiter.fused_moe import fused_moe
+from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
+from aiter.ops.gemm_op_a4w4 import gemm_a4w4
+from aiter.ops.quant import get_torch_quant
+from aiter.ops.shuffle import shuffle_weight
+from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
+from aiter.ops.triton.quant import dynamic_mxfp4_quant
+from aiter.utility.fp4_utils import e8m0_shuffle
+from torch.nn import Module
+
+from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+from sglang.srt.layers.parameter import ModelWeightParameter
+from sglang.srt.layers.quantization.base_config import (
+    FusedMoEMethodBase,
+    LinearMethodBase,
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
+from sglang.srt.layers.quantization.quark.schemes import QuarkScheme, QuarkW4A4MXFP4
+from sglang.srt.layers.quantization.quark.utils import deep_compare, should_ignore_layer
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.utils import (
+    get_bool_env_var,
+    get_device_capability,
+    log_info_on_rank0,
+    mxfp_supported,
+    set_weight_attrs,
+)
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
+logger = logging.getLogger(__name__)
+
+use_dynamic_mxfp4_linear = get_bool_env_var("SGLANG_USE_DYNAMIC_MXFP4_linear")
+
+OCP_MX_BLOCK_SIZE = 32
+
+
+class Mxfp4Config(QuantizationConfig):
+
+    def __init__(self, ignored_layers: Optional[list[str]] = None):
+        super().__init__()
+        self.ignored_layers = ignored_layers
+
+    @classmethod
+    def from_config(cls, config):
+        return cls()
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 80
+
+    @classmethod
+    def get_name(cls) -> QuantizationMethods:
+        return "mxfp4"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+        return [torch.bfloat16]
+
+    @classmethod
+    def get_config_filenames(cls) -> list[str]:
+        return []
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        from vllm.attention.layer import Attention  # Avoid circular import
+
+        if isinstance(layer, LinearBase):
+            if self.ignored_layers and is_layer_skipped(
+                prefix=prefix,
+                ignored_layers=self.ignored_layers,
+                fused_mapping=self.packed_modules_mapping,
+            ):
+                return UnquantizedLinearMethod()
+            raise NotImplementedError("Mxfp4 linear layer is not implemented")
+        elif isinstance(layer, FusedMoE):
+            return Mxfp4MoEMethod(layer.moe_config)
+        elif isinstance(layer, Attention):
+            raise NotImplementedError("Mxfp4 attention layer is not implemented")
+        return None
+
+
+class MxFp4LinearMethod(LinearMethodBase):
+
+    def __init__(self, quantization_config: MxFp4Config):
+        self.quantization_config = quantization_config
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        return
+        # if self.quantization_config.is_checkpoint_fp4_serialized:
+        #     layer.scheme.process_weights_after_loading(layer)
+        # else:
+        #     #w, w_scales = dynamic_mxfp4_quant(layer.weight.data)
+        #     ##log_info_on_rank0(logger, f"w.shape: {w.shape}")
+
+        #     #wshuffle = w#shuffle_weight(w, layout=(16, 16))
+        #     #w_scales_shuffle = w_scales#e8m0_shuffle(w_scales).view(dtypes.fp8_e8m0)
+
+        #     quant_func = aiter.get_triton_quant(aiter.QuantType.per_1x32)
+
+        #     w, w_scales_shuffle = quant_func(layer.weight.data, shuffle=True)
+
+        #     wshuffle = shuffle_weight(w, layout=(16, 16))
+
+        #     layer.weight = torch.nn.Parameter(wshuffle,
+        #                                       requires_grad=False)
+        #     layer.weight_scale = torch.nn.Parameter(w_scales_shuffle,
+        #                                             requires_grad=False)
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: list[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        """
+        Use the CompressedTensorsScheme associated with each layer to create
+        the necessary parameters for the layer. See LinearMethodBase for param
+        details
+        """
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        if self.quantization_config.is_checkpoint_fp4_serialized:
+            layer.scheme.create_weights(
+                layer=layer,
+                input_size=input_size,
+                input_size_per_partition=input_size_per_partition,
+                output_partition_sizes=output_partition_sizes,
+                output_size=output_size,
+                params_dtype=params_dtype,
+                weight_loader=weight_loader,
+            )
+        else:
+            output_size_per_partition = sum(output_partition_sizes)
+            layer.logical_widths = output_partition_sizes
+            layer.input_size_per_partition = input_size_per_partition
+            layer.output_size_per_partition = output_size_per_partition
+            layer.orig_dtype = params_dtype
+
+            weight_dtype = params_dtype
+
+            weight = ModelWeightParameter(
+                data=torch.empty(
+                    output_size_per_partition,
+                    input_size_per_partition,
+                    dtype=weight_dtype,
+                ),
+                input_dim=1,
+                output_dim=0,
+                weight_loader=weight_loader,
+            )
+
+            layer.register_parameter("weight", weight)
+            layer.register_parameter("weight_scale", None)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ):
+        """
+        Use the output of create_weights and the CompressedTensorsScheme
+        associated with the layer to apply the forward pass with the
+        layer input. See LinearMethodBase for param details
+
+        """
+        if self.quantization_config.is_checkpoint_fp4_serialized:
+            scheme = layer.scheme
+            if scheme is None:
+                raise ValueError("A scheme must be defined for each layer")
+            return scheme.apply_weights(layer, x, bias=bias)
+        else:
+            out_dtype = x.dtype
+
+            # ck or asm implement
+            # M = x.shape[0]
+            # N = layer.weight.shape[0]
+
+            # quant_func = aiter.get_triton_quant(aiter.QuantType.per_1x32)
+
+            # x, x_scales_shuffle = quant_func(x, shuffle=True)
+
+            # y = torch.zeros((M + 255) // 256 * 256, N, device=x.device, dtype=out_dtype)
+
+            # out = gemm_a4w4(x, layer.weight.data, x_scales_shuffle, layer.weight_scale.data, y, bias=bias)
+
+            # return out[:M]
+
+            # triton implement
+            x_q, x_s = dynamic_mxfp4_quant(x)
+            y = torch.empty(
+                x_q.shape[0], layer.weight.shape[0], device=x_q.device, dtype=out_dtype
+            )
+
+            out = gemm_afp4wfp4(
+                x_q, layer.weight, x_s, layer.weight_scale, out_dtype, y
+            )
+
+            return out
+
+
+class MxFp4MoEMethod:
+    def __new__(cls, *args, **kwargs):
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    @staticmethod
+    def get_moe_method(
+        quant_config: "MxFp4Config",  # type: ignore # noqa E501 # noqa F821
+        module: torch.nn.Module,
+        layer_name: str,
+    ) -> "MxFp4MoEMethod":
+
+        if quant_config.is_checkpoint_fp4_serialized:
+            layer_quant_config = quant_config._find_matched_config(layer_name, module)
+
+            if layer_quant_config.get("output_tensors") or layer_quant_config.get(
+                "bias"
+            ):
+                raise NotImplementedError(
+                    "Currently, Quark models with "
+                    "output_tensors and bias "
+                    "quantized are not supported"
+                )
+            weight_config = layer_quant_config.get("weight")
+            input_config = layer_quant_config.get("input_tensors")
+
+            if quant_config._is_mx_fp4(weight_config, input_config):
+                return W4A4MXFp4MoEStaticMethod(weight_config, input_config)
+            else:
+                raise RuntimeError("Unsupported FusedMoe scheme")
+        else:
+            return W4A4MXFp4MoEDynamicMethod(quant_config)
+
+
+class W4A4MXFp4MoEDynamicMethod(MxFp4MoEMethod):
+    def __init__(self, quant_config):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # Allocate 2 scales for w1 and w3 respectively.
+        # They will be combined to a single scale after weight loading.
+        w13_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+        )
+        w2_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+        # Add the quantization method used (per tensor/grouped/channel)
+        # to ensure the weight scales are loaded in properly
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+        )
+
+        layer.w13_input_scale = None
+        layer.w2_input_scale = None
+
+    def mxfp4_quantize(self, w):
+        w_shape = w.shape
+        w_need_reshape = True if w.dim() != 2 else False
+
+        if w_need_reshape:
+            w_last_dim_size = w_shape[-1]
+            w = w.view(-1, w_last_dim_size)
+
+        # log_info_on_rank0(logger, f"[Pre-quant] w.shape: {w.shape}")
+        w, mx_scales = dynamic_mxfp4_quant(w)
+        # log_info_on_rank0(logger, f"[Post-quant] w.shape: {w.shape} mx_scales.shape: {mx_scales.shape}")
+
+        if w_need_reshape:
+            w_new_shape = w_shape[:-1] + (w.shape[-1],)
+            w = w.view(w_new_shape)
+
+        # log_info_on_rank0(logger, f"[re-shape] w.shape: {w.shape} mx_scales.shape: {mx_scales.shape}")
+
+        mx_scales = e8m0_shuffle(mx_scales)
+
+        return w, mx_scales
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        w13, w13_mx_scales = self.mxfp4_quantize(layer.w13_weight.data)
+        w2, w2_mx_scales = self.mxfp4_quantize(layer.w2_weight.data)
+
+        layer.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
+        layer.w13_weight_scale = torch.nn.Parameter(w13_mx_scales, requires_grad=False)
+
+        layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
+        layer.w2_weight_scale = torch.nn.Parameter(w2_mx_scales, requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        topk_output: TopKOutput,
+        *,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+        topk_weights, topk_ids, _ = topk_output
+
+        return fused_moe(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights,
+            topk_ids,
+            quant_type=QuantType.per_1x32,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            activation=(
+                ActivationType.Silu if activation == "silu" else ActivationType.Gelu
+            ),
+            doweight_stage1=False,
+        )
+
+
+class W4A4MXFp4MoEStaticMethod(MxFp4MoEMethod):
+
+    def __init__(self, weight_config: dict[str, Any], input_config: dict[str, Any]):
+        self.weight_quant = weight_config
+        self.input_quant = input_config
+
+        weight_qscheme = self.weight_quant.get("qscheme")
+        input_qscheme = self.input_quant.get("qscheme")
+        if not (weight_qscheme == "per_group" and input_qscheme == "per_group"):
+            raise ValueError(
+                "For MX(FP4) Fused MoE layers, only per-group scales "
+                "for weights and activations are supported. Found "
+                f"{weight_qscheme=}, {input_qscheme=}"
+            )  # noqa E501
+
+        self.static_input_scales = not self.input_quant.get("is_dynamic")
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        # Add the quantization method used (per tensor/grouped/channel)
+        # to ensure the weight scales are loaded in properly
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
+        )
+
+        params_dtype = torch.uint8
+
+        # WEIGHTS
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size // 2,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition // 2,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # WEIGHT_SCALES
+        w13_weight_scale = torch.nn.Parameter(
+            torch.ones(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size // OCP_MX_BLOCK_SIZE,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        w2_weight_scale = torch.nn.Parameter(
+            torch.ones(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition // OCP_MX_BLOCK_SIZE,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        float_dtype = torch.get_default_dtype()
+
+        # Pre-shuffle weight scales
+        s0, s1, _ = layer.w13_weight_scale.shape
+        w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1)
+        w13_weight_scale = e8m0_shuffle(w13_weight_scale)
+        layer.w13_weight_scale.data = w13_weight_scale.view(s0, s1, -1)
+
+        s0, s1, _ = layer.w2_weight_scale.shape
+        w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1)
+        w2_weight_scale = e8m0_shuffle(w2_weight_scale)
+        layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        topk_output: TopKOutput,
+        *,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+        topk_weights, topk_ids, _ = topk_output
+
+        return fused_moe(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights,
+            topk_ids,
+            quant_type=QuantType.per_1x32,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            activation=(
+                ActivationType.Silu if activation == "silu" else ActivationType.Gelu
+            ),
+            doweight_stage1=False,
+        )
+
+
+class MxFp4KVCacheMethod(BaseKVCacheMethod):
+    """
+    Supports loading kv-cache scaling factors from quark checkpoints.
+    """
+
+    def __init__(self, quant_config: MxFp4Config):
+        self.validate_kv_cache_config(quant_config.kv_cache_config)
+        super().__init__(quant_config)

+    @staticmethod
+    def validate_kv_cache_config(kv_cache_config: Optional[dict[str, Any]]):
+        """
+        Validator for the kv cache configuration. Useful for controlling the
+        kv cache quantization schemes, that are being supported in vLLM
+        :param kv_cache_config: the quark kv cache scheme
+        """
+        if kv_cache_config is None:
+            return
+
+        dtype = kv_cache_config.get("dtype")
+        if dtype != "fp8_e4m3":
+            raise NotImplementedError(
+                "Currently supported kv cache quantization is "
+                f"dtype=fp8_e4m3, however received {dtype}"
+            )
+
+        qscheme = kv_cache_config.get("qscheme")
+        if qscheme != "per_tensor":
+            raise NotImplementedError(
+                "Only support per-tensor scaling factor "
+                "for quark KV cache. "
+                f"Expected qscheme: per_tensor, found qscheme: {qscheme}"
+            )
@@ -63,7 +63,7 @@ from sglang.srt.layers.quantization.utils import (
     per_tensor_dequantize,
     requantize_with_max_scale,
 )
-from sglang.srt.layers.utils import is_sm100_supported
+from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -98,9 +98,6 @@ if _is_hip and (_use_aiter or _use_hip_int4):
     from aiter.fused_moe import fused_moe
     from aiter.ops.shuffle import shuffle_weight
 
-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
-    from vllm._custom_ops import scaled_fp8_quant
-
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
@@ -619,7 +616,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         if (
             get_bool_env_var("SGLANG_CUTLASS_MOE")
             and self.cutlass_fp8_supported
-            and is_sm100_supported()
+            and (is_sm100_supported() or is_sm90_supported())
         ):
             self.ab_strides1 = torch.full(
                 (num_experts,),
@@ -1034,7 +1031,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             get_bool_env_var("SGLANG_CUTLASS_MOE")
             and self.cutlass_fp8_supported
             and self.block_quant
-            and is_sm100_supported()
+            and (is_sm100_supported() or is_sm90_supported())
         ):
             from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
 
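The two `Fp8MoEMethod` hunks above widen the CUTLASS FP8 MoE gate from SM100-only to SM90 or SM100, using the newly imported `is_sm90_supported`. As a rough idea of what such capability helpers typically check (a hypothetical sketch, not sglang's actual `sglang.srt.layers.utils` implementation, which may also consult the CUDA toolkit version):

```python
import torch


def _capability() -> tuple[int, int]:
    # (major, minor) compute capability of the current CUDA device.
    return torch.cuda.get_device_capability()


def is_sm90_supported() -> bool:
    # Hopper-class GPUs (e.g. H100/H200) report compute capability 9.x.
    return torch.cuda.is_available() and _capability()[0] == 9


def is_sm100_supported() -> bool:
    # Blackwell-class GPUs report compute capability 10.x.
    return torch.cuda.is_available() and _capability()[0] == 10
```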
@@ -4,6 +4,7 @@ import torch
 
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
+from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
 from sglang.srt.layers.utils import is_sm100_supported
 
 try:
@@ -26,6 +27,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 )
 from sglang.srt.utils import (
     align,
+    ceil_div,
     get_bool_env_var,
     get_cuda_version,
     get_device_capability,
@@ -307,6 +309,33 @@ def triton_w8a8_block_fp8_linear(
     return output.to(dtype=input_2d.dtype).view(*output_shape)
 
 
+def dequant_mxfp4(
+    w_block: torch.Tensor,
+    w_scale: torch.Tensor,
+    out_dtype,
+) -> torch.Tensor:
+    """
+    :param w_block: (batch, n, k, 16), uint8, pack two mxfp4 into one byte
+    :param w_scale: (batch, n, k), uint8
+    :return: (batch, n, k * 32), float32
+    """
+
+    assert w_block.dtype == torch.uint8
+    assert w_scale.dtype == torch.uint8
+
+    batch, n, k, pack_dim = w_block.shape
+    batch_, n_, k_ = w_scale.shape
+    assert pack_dim == 16
+    assert batch == batch_
+    assert n == n_
+    assert k == k_
+
+    out_raw = MXFP4QuantizeUtil.dequantize(
+        quantized_data=w_block, scale=w_scale, dtype=out_dtype, block_sizes=[32]
+    )
+    return out_raw.reshape(batch, n, k * 32)
+
+
 def input_to_float8(
     x: torch.Tensor, dtype: torch.dtype = fp8_dtype
 ) -> Tuple[torch.Tensor, torch.Tensor]: