sglang 0.4.9__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +2 -2
- sglang/srt/configs/model_config.py +12 -1
- sglang/srt/conversation.py +35 -1
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/layers/communicator.py +3 -1
- sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
- sglang/srt/layers/layernorm.py +2 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +58 -0
- sglang/srt/layers/moe/ep_moe/layer.py +140 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +135 -58
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/fp8.py +28 -7
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/vocab_parallel_embedding.py +9 -3
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/io_struct.py +8 -1
- sglang/srt/managers/mm_utils.py +4 -2
- sglang/srt/managers/schedule_batch.py +1 -1
- sglang/srt/managers/scheduler.py +17 -5
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +113 -63
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/models/deepseek_v2.py +16 -2
- sglang/srt/models/mllama4.py +360 -79
- sglang/srt/multimodal/mm_utils.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +62 -60
- sglang/srt/server_args.py +15 -0
- sglang/srt/two_batch_overlap.py +3 -0
- sglang/srt/utils.py +37 -17
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +4 -3
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +47 -43
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py (+176 -0, new file)

```diff
@@ -0,0 +1,176 @@
+# Adapted from https://github.com/vllm-project/vllm/pull/18595/files#diff-f426a6de78c82ffec568eff6811bfbf0043dab5f87f1a8c0cffdbdcb8a81e035
+from typing import Optional
+
+import torch
+from sgl_kernel import gelu_and_mul, silu_and_mul
+from triton_kernels.matmul_ogs import matmul_ogs
+from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
+
+from sglang.srt.utils import direct_register_custom_op
+
+
+def triton_kernel_moe_forward(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    inplace: bool = False,
+    activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    use_fp8_w8a8: bool = False,
+    per_channel_quant: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+) -> torch.Tensor:
+
+    if not renormalize:
+        gating_output = torch.softmax(gating_output, dim=-1)
+    routing_data, gather_idx, scatter_idx = routing(gating_output, topk, renormalize)
+
+    return triton_kernel_fused_experts(
+        hidden_states,
+        w1,
+        w2,
+        routing_data,
+        gather_idx,
+        scatter_idx,
+        inplace=inplace,
+        activation=activation,
+        apply_router_weight_on_input=apply_router_weight_on_input,
+        use_fp8_w8a8=use_fp8_w8a8,
+        per_channel_quant=per_channel_quant,
+        global_num_experts=global_num_experts,
+        expert_map=expert_map,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+        block_shape=block_shape,
+    )
+
+
+# This is a triton implementation of the fused_experts function
+def triton_kernel_fused_experts(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    routing_data: RoutingData,
+    gather_indx: GatherIndx,
+    scatter_indx: ScatterIndx,
+    inplace: bool = False,
+    activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    use_fp8_w8a8: bool = False,
+    per_channel_quant: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+) -> torch.Tensor:
+
+    assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
+    assert per_channel_quant == False, "per_channel_quant is not supported"
+    assert expert_map == None, "expert_map is not supported"
+    assert w1_scale == None, "w1_scale is not supported"
+    assert w2_scale == None, "w2_scale is not supported"
+    assert a1_scale == None, "a1_scale is not supported"
+    assert a2_scale == None, "a2_scale is not supported"
+    assert block_shape == None, "block_shape is not supported"
+
+    # type check
+    assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
+    assert w1.dtype == torch.bfloat16, "w1 must be bfloat16"
+    assert w2.dtype == torch.bfloat16, "w2 must be bfloat16"
+
+    # Shape check
+    assert hidden_states.ndim == 2, "hidden_states must be 2D"
+    assert (
+        hidden_states.shape[-1] == w1.shape[-2]
+    ), f"hidden_states shape[-1] {hidden_states.shape} must be equal to w1 shape[-2] {w1.shape}"
+    assert (
+        w2.shape[-1] == w1.shape[1]
+    ), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
+
+    # feature check
+    assert inplace == False, "Inplace is not supported in new triton MoE kernel"
+
+    M, K = hidden_states.shape
+    E, _, N = w1.shape
+    n_expts_act = routing_data.n_expts_act
+    dtype = hidden_states.dtype
+
+    if global_num_experts == -1:
+        global_num_experts = E
+
+    # consistent with default implementation
+    intermediate_cache2 = torch.empty(
+        (M * n_expts_act, N // 2), device="cuda", dtype=dtype
+    )
+
+    intermediate_cache1 = matmul_ogs(
+        hidden_states,
+        w1,
+        None,
+        routing_data,
+        gather_indx=gather_indx,
+        gammas=routing_data.gate_scal if apply_router_weight_on_input else None,
+    )
+
+    if activation == "silu":
+        silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+    elif activation == "gelu":
+        gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+    else:
+        raise ValueError(f"Unsupported FusedMoe activation: {activation}")
+
+    intermediate_cache3 = matmul_ogs(
+        intermediate_cache2,
+        w2,
+        None,
+        routing_data,
+        scatter_indx=scatter_indx,
+        gammas=None if apply_router_weight_on_input else routing_data.gate_scal,
+    )
+
+    return intermediate_cache3
+
+
+def triton_kernel_moe_forward_fake(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    inplace: bool = False,
+    activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    use_fp8_w8a8: bool = False,
+    per_channel_quant: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+) -> torch.Tensor:
+    return torch.empty_like(hidden_states)
+
+
+direct_register_custom_op(
+    op_name="forward_cuda_triton",
+    op_func=triton_kernel_moe_forward,
+    mutates_args=[],
+    fake_impl=triton_kernel_moe_forward_fake,
+)
```
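For orientation, here is a minimal smoke-test sketch of the new entry point. It is not part of the diff: the tensor shapes and dtypes are guesses derived from the assertions above, and it assumes a CUDA device with the `triton_kernels` and `sgl_kernel` packages installed.

```python
# Hypothetical usage sketch for the new Triton-kernels MoE path (not from the diff).
import torch

from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
    triton_kernel_moe_forward,
)

M, K, E, topk = 4, 64, 8, 2  # tokens, hidden size, num experts, experts per token
N = 256                      # w1 output dim, i.e. 2 * intermediate size

# bfloat16 is required by the assertions in triton_kernel_fused_experts.
hidden_states = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
w1 = torch.randn(E, K, N, device="cuda", dtype=torch.bfloat16)       # gate/up proj
w2 = torch.randn(E, N // 2, K, device="cuda", dtype=torch.bfloat16)  # down proj
gating_output = torch.randn(M, E, device="cuda", dtype=torch.bfloat16)

out = triton_kernel_moe_forward(
    hidden_states, w1, w2, gating_output, topk=topk, renormalize=True
)
print(out.shape)  # expected to match hidden_states, i.e. (M, K)
```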
sglang/srt/layers/quantization/__init__.py (+2 -0)

```diff
@@ -68,6 +68,7 @@ from sglang.srt.layers.quantization.modelopt_quant import (
 )
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.qoq import QoQConfig
+from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
 
@@ -82,6 +83,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "moe_wna16": MoeWNA16Config,
     "compressed-tensors": CompressedTensorsConfig,
     "qoq": QoQConfig,
+    "w4afp8": W4AFp8Config,
 }
 
 # VLLM-dependent quantization methods
```
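The practical effect here is that `w4afp8` becomes a recognized key in the base quantization registry. A quick, hypothetical check (assuming `BASE_QUANTIZATION_METHODS` remains importable from the package, as the hunk header suggests):

```python
# Hypothetical lookup against the registry edited above (not part of the diff).
from sglang.srt.layers.quantization import BASE_QUANTIZATION_METHODS

quant_cls = BASE_QUANTIZATION_METHODS["w4afp8"]
print(quant_cls.__name__)  # expected: W4AFp8Config
```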
sglang/srt/layers/quantization/fp8.py (+28 -7)

```diff
@@ -1,7 +1,7 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
 
 import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -88,7 +88,7 @@ _is_fp8_fnuz = is_fp8_fnuz()
 _use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
-if _is_hip:
+if _is_hip and (_use_aiter or _use_hip_int4):
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
     from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
@@ -200,7 +200,7 @@ class Fp8LinearMethod(LinearMethodBase):
         quant_config: The quantization config.
     """
 
-    def __init__(self, quant_config: Fp8Config):
+    def __init__(self, quant_config: Union["Fp8Config", "W4AFp8Config"]):
         self.quant_config = quant_config
         self.cutlass_fp8_supported = cutlass_fp8_supported()
 
@@ -286,7 +286,10 @@ class Fp8LinearMethod(LinearMethodBase):
         if self.quant_config.is_checkpoint_fp8_serialized:
             # WEIGHT SCALE
             if self.block_quant:
-                assert self.quant_config.activation_scheme == "dynamic"
+                if hasattr(self.quant_config, "activation_scheme"):
+                    assert self.quant_config.activation_scheme == "dynamic"
+                elif hasattr(self.quant_config, "linear_activation_scheme"):
+                    assert self.quant_config.linear_activation_scheme == "dynamic"
                 scale = BlockQuantScaleParameter(
                     data=torch.empty(
                         (output_size_per_partition + block_n - 1) // block_n,
@@ -308,7 +311,13 @@ class Fp8LinearMethod(LinearMethodBase):
             layer.register_parameter("weight_scale", scale)
 
             # INPUT ACTIVATION SCALE
-            if self.quant_config.activation_scheme == "static":
+            if (
+                hasattr(self.quant_config, "activation_scheme")
+                and self.quant_config.activation_scheme == "static"
+            ) or (
+                hasattr(self.quant_config, "linear_activation_scheme")
+                and self.quant_config.linear_activation_scheme == "static"
+            ):
                 scale = PerTensorScaleParameter(
                     data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                     weight_loader=weight_loader,
@@ -371,7 +380,13 @@ class Fp8LinearMethod(LinearMethodBase):
             layer.weight_scale = torch.nn.Parameter(
                 layer.weight_scale.data, requires_grad=False
             )
-            if self.quant_config.activation_scheme == "static":
+            if (
+                hasattr(self.quant_config, "activation_scheme")
+                and self.quant_config.activation_scheme == "static"
+            ) or (
+                hasattr(self.quant_config, "linear_activation_scheme")
+                and self.quant_config.linear_activation_scheme == "static"
+            ):
                 layer.input_scale = torch.nn.Parameter(
                     layer.input_scale.data, requires_grad=False
                 )
@@ -405,7 +420,13 @@ class Fp8LinearMethod(LinearMethodBase):
         # Update layer with new values.
         layer.weight = Parameter(weight.t(), requires_grad=False)
         layer.weight_scale = Parameter(weight_scale, requires_grad=False)
-        if self.quant_config.activation_scheme == "static":
+        if (
+            hasattr(self.quant_config, "activation_scheme")
+            and self.quant_config.activation_scheme == "static"
+        ) or (
+            hasattr(self.quant_config, "linear_activation_scheme")
+            and self.quant_config.linear_activation_scheme == "static"
+        ):
             layer.input_scale = Parameter(
                 layer.input_scale.max(), requires_grad=False
             )
```
sglang/srt/layers/quantization/modelopt_quant.py (+244 -1)

```diff
@@ -26,6 +26,7 @@ from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
 from sglang.srt.layers.quantization.utils import (
     convert_to_channelwise,
     is_layer_skipped,
+    per_tensor_dequantize,
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -110,7 +111,12 @@ class ModelOptFp8Config(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
         if self.exclude_modules and any(
-            module in prefix for module in self.exclude_modules
+            module in prefix
+            or (
+                prefix.startswith("language_model.")
+                and module in prefix.removeprefix("language_model.")
+            )
+            for module in self.exclude_modules
         ):
             return None
 
@@ -119,6 +125,12 @@ class ModelOptFp8Config(QuantizationConfig):
         if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
             return ModelOptFp8KVCacheMethod(self)
 
+        # Add MoE support
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
+        if isinstance(layer, FusedMoE):
+            return ModelOptFp8MoEMethod(self)
+
         return None
 
     def get_scaled_act_names(self) -> List[str]:
@@ -234,6 +246,237 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
         super().__init__(quant_config)
 
 
+class ModelOptFp8MoEMethod:
+    """MoE method for ModelOpt FP8.
+    Supports loading FP8 checkpoints with static weight scale and activation scale.
+
+    Args:
+        quant_config: The ModelOpt quantization config.
+    """
+
+    def __new__(cls, *args, **kwargs):
+        """
+        Dynamic class composition pattern.
+
+        This allows us to effectively "inject" FusedMoEMethodBase as a parent class
+        at runtime while avoiding circular import issues.
+        """
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    def __init__(self, quant_config: ModelOptFp8Config):
+        self.quant_config = quant_config
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        # Use FP8 dtype if checkpoint is serialized, otherwise use the default dtype
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quant_config.is_checkpoint_fp8_serialized
+            else params_dtype
+        )
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        w13_weight = ModelWeightParameter(
+            data=torch.empty(
+                num_experts, 2 * intermediate_size, hidden_size, dtype=weight_dtype
+            ),
+            input_dim=2,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+
+        w2_weight = ModelWeightParameter(
+            data=torch.empty(
+                num_experts, hidden_size, intermediate_size, dtype=weight_dtype
+            ),
+            input_dim=2,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            # WEIGHT SCALES - Per-tensor scaling for ModelOpts
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They will be combined to a single scale after weight loading.
+            w13_weight_scale = PerTensorScaleParameter(
+                data=torch.full(
+                    (num_experts, 2),
+                    torch.finfo(torch.float32).min,
+                    dtype=torch.float32,
+                ),
+                weight_loader=weight_loader,
+            )
+            w2_weight_scale = PerTensorScaleParameter(
+                data=torch.full(
+                    (num_experts,), torch.finfo(torch.float32).min, dtype=torch.float32
+                ),
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+            # Set weight loader attributes for scales
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+            )
+
+            # INPUT SCALES - Per-tensor scaling for ModelOpt
+            w13_input_scale = PerTensorScaleParameter(
+                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            w2_input_scale = PerTensorScaleParameter(
+                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("w13_input_scale", w13_input_scale)
+            layer.register_parameter("w2_input_scale", w2_input_scale)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        """Process FP8 MoE weights after loading from serialized checkpoint.
+
+        Only supports pre-quantized checkpoints with FP8 weights and scales.
+        """
+
+        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
+        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
+
+        # Handle scale parameters
+        if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
+            # Fp8 moe kernel needs single weight scale for w13 per expert.
+            # We take the max of the w1 and w3 scales then dequant and requant each expert.
+            if layer.w13_weight_scale.dim() == 2:  # Shape: (num_experts, 2)
+                from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
+
+                # Get the maximum scale across w1 and w3 for each expert
+                max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+
+                # Requantize each expert's weights using the combined scale
+                # w13_weight has shape (num_experts, 2 * intermediate_size, hidden_size)
+                # where the first intermediate_size rows are w1, the next are w3
+                intermediate_size = layer.w13_weight.shape[1] // 2
+                for expert_id in range(layer.w13_weight.shape[0]):
+                    start = 0
+                    for shard_id in range(2):  # w1 and w3
+                        # Dequantize using the original scale for this shard
+                        dq_weight = per_tensor_dequantize(
+                            layer.w13_weight[expert_id][
+                                start : start + intermediate_size, :
+                            ],
+                            layer.w13_weight_scale[expert_id][shard_id],
+                        )
+                        # Requantize using the combined max scale
+                        (
+                            layer.w13_weight[expert_id][
+                                start : start + intermediate_size, :
+                            ],
+                            _,
+                        ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+
+                        start += intermediate_size
+
+                # Update the scale parameter to be per-expert instead of per-shard
+                layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
+            else:
+                layer.w13_weight_scale = Parameter(
+                    layer.w13_weight_scale.data, requires_grad=False
+                )
+
+        if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
+            layer.w2_weight_scale = Parameter(
+                layer.w2_weight_scale.data, requires_grad=False
+            )
+        if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
+            layer.w13_input_scale = Parameter(
+                layer.w13_input_scale.max(), requires_grad=False
+            )
+        if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
+            layer.w2_input_scale = Parameter(
+                layer.w2_input_scale.max(), requires_grad=False
+            )
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        num_fused_shared_experts: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.topk import select_experts
+
+        # Expert selection
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            num_fused_shared_experts=num_fused_shared_experts,
+            custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
+        )
+
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=inplace,
+            activation=activation,
+            use_fp8_w8a8=True,
+            per_channel_quant=False,  # ModelOpt uses per-tensor quantization
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+            no_combine=no_combine,
+        )
+
+
 class ModelOptFp4Config(QuantizationConfig):
     """Config class for FP4."""
 
```
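The `__new__` override in `ModelOptFp8MoEMethod` rebuilds the class at call time so that `FusedMoEMethodBase` is injected as a base without a module-level import. Below is a self-contained illustration of that pattern with stand-in classes; the names are invented, and it additionally filters `__weakref__`, a small deviation from the diff.

```python
# Stand-alone illustration of runtime base-class injection via type().
# Base plays the role of FusedMoEMethodBase, Method the role of ModelOptFp8MoEMethod.
class Base:
    def create_weights(self):
        raise NotImplementedError


class Method:
    def __new__(cls, *args, **kwargs):
        # Recreate the class with Base injected, copying the original namespace.
        new_cls = type(
            cls.__name__,
            (Base,),
            {
                k: v
                for k, v in cls.__dict__.items()
                if k not in ("__dict__", "__weakref__")
            },
        )
        obj = object.__new__(new_cls)
        obj.__init__(*args, **kwargs)  # __init__ is not re-run automatically here
        return obj

    def __init__(self, tag):
        self.tag = tag

    def create_weights(self):
        return f"weights for {self.tag}"


m = Method("fp8")
print(isinstance(m, Base), m.create_weights())  # True weights for fp8
```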