sglang 0.4.9__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. sglang/bench_serving.py +2 -2
  2. sglang/srt/configs/model_config.py +12 -1
  3. sglang/srt/conversation.py +35 -1
  4. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  5. sglang/srt/entrypoints/http_server_engine.py +1 -1
  6. sglang/srt/layers/communicator.py +3 -1
  7. sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
  8. sglang/srt/layers/layernorm.py +2 -2
  9. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  10. sglang/srt/layers/moe/ep_moe/kernels.py +58 -0
  11. sglang/srt/layers/moe/ep_moe/layer.py +140 -2
  12. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
  13. sglang/srt/layers/moe/fused_moe_triton/layer.py +135 -58
  14. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  15. sglang/srt/layers/quantization/__init__.py +2 -0
  16. sglang/srt/layers/quantization/fp8.py +28 -7
  17. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  18. sglang/srt/layers/quantization/w4afp8.py +264 -0
  19. sglang/srt/layers/vocab_parallel_embedding.py +9 -3
  20. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  21. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  22. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  23. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  24. sglang/srt/managers/cache_controller.py +41 -195
  25. sglang/srt/managers/io_struct.py +8 -1
  26. sglang/srt/managers/mm_utils.py +4 -2
  27. sglang/srt/managers/schedule_batch.py +1 -1
  28. sglang/srt/managers/scheduler.py +17 -5
  29. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  30. sglang/srt/mem_cache/memory_pool.py +113 -63
  31. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  32. sglang/srt/mem_cache/radix_cache.py +8 -4
  33. sglang/srt/models/deepseek_v2.py +16 -2
  34. sglang/srt/models/mllama4.py +360 -79
  35. sglang/srt/multimodal/mm_utils.py +2 -2
  36. sglang/srt/multimodal/processors/mllama4.py +62 -60
  37. sglang/srt/server_args.py +15 -0
  38. sglang/srt/two_batch_overlap.py +3 -0
  39. sglang/srt/utils.py +37 -17
  40. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  41. sglang/utils.py +5 -5
  42. sglang/version.py +1 -1
  43. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +4 -3
  44. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +47 -43
  45. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  46. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  47. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py (new file)
@@ -0,0 +1,176 @@
+# Adapted from https://github.com/vllm-project/vllm/pull/18595/files#diff-f426a6de78c82ffec568eff6811bfbf0043dab5f87f1a8c0cffdbdcb8a81e035
+from typing import Optional
+
+import torch
+from sgl_kernel import gelu_and_mul, silu_and_mul
+from triton_kernels.matmul_ogs import matmul_ogs
+from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
+
+from sglang.srt.utils import direct_register_custom_op
+
+
+def triton_kernel_moe_forward(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    inplace: bool = False,
+    activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    use_fp8_w8a8: bool = False,
+    per_channel_quant: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+) -> torch.Tensor:
+
+    if not renormalize:
+        gating_output = torch.softmax(gating_output, dim=-1)
+    routing_data, gather_idx, scatter_idx = routing(gating_output, topk, renormalize)
+
+    return triton_kernel_fused_experts(
+        hidden_states,
+        w1,
+        w2,
+        routing_data,
+        gather_idx,
+        scatter_idx,
+        inplace=inplace,
+        activation=activation,
+        apply_router_weight_on_input=apply_router_weight_on_input,
+        use_fp8_w8a8=use_fp8_w8a8,
+        per_channel_quant=per_channel_quant,
+        global_num_experts=global_num_experts,
+        expert_map=expert_map,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+        block_shape=block_shape,
+    )
+
+
+# This is a triton implementation of the fused_experts function
+def triton_kernel_fused_experts(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    routing_data: RoutingData,
+    gather_indx: GatherIndx,
+    scatter_indx: ScatterIndx,
+    inplace: bool = False,
+    activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    use_fp8_w8a8: bool = False,
+    per_channel_quant: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+) -> torch.Tensor:
+
+    assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
+    assert per_channel_quant == False, "per_channel_quant is not supported"
+    assert expert_map == None, "expert_map is not supported"
+    assert w1_scale == None, "w1_scale is not supported"
+    assert w2_scale == None, "w2_scale is not supported"
+    assert a1_scale == None, "a1_scale is not supported"
+    assert a2_scale == None, "a2_scale is not supported"
+    assert block_shape == None, "block_shape is not supported"
+
+    # type check
+    assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
+    assert w1.dtype == torch.bfloat16, "w1 must be bfloat16"
+    assert w2.dtype == torch.bfloat16, "w2 must be bfloat16"
+
+    # Shape check
+    assert hidden_states.ndim == 2, "hidden_states must be 2D"
+    assert (
+        hidden_states.shape[-1] == w1.shape[-2]
+    ), f"hidden_states shape[-1] {hidden_states.shape} must be equal to w1 shape[-2] {w1.shape}"
+    assert (
+        w2.shape[-1] == w1.shape[1]
+    ), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
+
+    # feature check
+    assert inplace == False, "Inplace is not supported in new triton MoE kernel"
+
+    M, K = hidden_states.shape
+    E, _, N = w1.shape
+    n_expts_act = routing_data.n_expts_act
+    dtype = hidden_states.dtype
+
+    if global_num_experts == -1:
+        global_num_experts = E
+
+    # consistent with default implementation
+    intermediate_cache2 = torch.empty(
+        (M * n_expts_act, N // 2), device="cuda", dtype=dtype
+    )
+
+    intermediate_cache1 = matmul_ogs(
+        hidden_states,
+        w1,
+        None,
+        routing_data,
+        gather_indx=gather_indx,
+        gammas=routing_data.gate_scal if apply_router_weight_on_input else None,
+    )
+
+    if activation == "silu":
+        silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+    elif activation == "gelu":
+        gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+    else:
+        raise ValueError(f"Unsupported FusedMoe activation: {activation}")
+
+    intermediate_cache3 = matmul_ogs(
+        intermediate_cache2,
+        w2,
+        None,
+        routing_data,
+        scatter_indx=scatter_indx,
+        gammas=None if apply_router_weight_on_input else routing_data.gate_scal,
+    )
+
+    return intermediate_cache3
+
+
+def triton_kernel_moe_forward_fake(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    inplace: bool = False,
+    activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    use_fp8_w8a8: bool = False,
+    per_channel_quant: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+) -> torch.Tensor:
+    return torch.empty_like(hidden_states)
+
+
+direct_register_custom_op(
+    op_name="forward_cuda_triton",
+    op_func=triton_kernel_moe_forward,
+    mutates_args=[],
+    fake_impl=triton_kernel_moe_forward_fake,
+)
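
The new module above routes tokens with triton_kernels.routing, runs the gathered expert GEMM with matmul_ogs, applies the fused activation, and scatters the results back through a second matmul_ogs, then registers the whole path as the custom op "forward_cuda_triton". Below is a minimal, untested sketch of how the entry point could be driven standalone; the sizes are illustrative (not from the diff), and it assumes a CUDA device plus the sgl_kernel and triton_kernels packages.

import torch

from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
    triton_kernel_moe_forward,
)

# Illustrative sizes: tokens, hidden size, intermediate size, experts, top-k.
M, hidden, inter, E, topk = 8, 1024, 2048, 16, 2

# The asserts in triton_kernel_fused_experts require bfloat16, 2-D hidden states,
# w1 of shape (E, hidden, 2 * inter) and w2 of shape (E, inter, hidden).
hidden_states = torch.randn(M, hidden, device="cuda", dtype=torch.bfloat16)
w1 = torch.randn(E, hidden, 2 * inter, device="cuda", dtype=torch.bfloat16)
w2 = torch.randn(E, inter, hidden, device="cuda", dtype=torch.bfloat16)
gating_output = torch.randn(M, E, device="cuda", dtype=torch.bfloat16)

out = triton_kernel_moe_forward(
    hidden_states, w1, w2, gating_output, topk=topk, renormalize=True
)
print(out.shape)  # expected (M, hidden) after the scatter back to token order
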
sglang/srt/layers/quantization/__init__.py
@@ -68,6 +68,7 @@ from sglang.srt.layers.quantization.modelopt_quant import (
 )
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.qoq import QoQConfig
+from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
 
@@ -82,6 +83,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "moe_wna16": MoeWNA16Config,
     "compressed-tensors": CompressedTensorsConfig,
     "qoq": QoQConfig,
+    "w4afp8": W4AFp8Config,
 }
 
 # VLLM-dependent quantization methods
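
With the entry above, the method name is resolved to its config class by a plain dictionary lookup against BASE_QUANTIZATION_METHODS. A trivial sketch, assuming an sglang installation (not part of the diff):

from sglang.srt.layers.quantization import BASE_QUANTIZATION_METHODS

# "w4afp8" now resolves to the config class registered above.
quant_cls = BASE_QUANTIZATION_METHODS["w4afp8"]
print(quant_cls.__name__)  # W4AFp8Config
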
sglang/srt/layers/quantization/fp8.py
@@ -1,7 +1,7 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
 
 import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -88,7 +88,7 @@ _is_fp8_fnuz = is_fp8_fnuz()
 _use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
-if _is_hip:
+if _is_hip and (_use_aiter or _use_hip_int4):
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
     from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
@@ -200,7 +200,7 @@ class Fp8LinearMethod(LinearMethodBase):
         quant_config: The quantization config.
     """
 
-    def __init__(self, quant_config: Fp8Config):
+    def __init__(self, quant_config: Union["Fp8Config", "W4AFp8Config"]):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()
 
@@ -286,7 +286,10 @@
         if self.quant_config.is_checkpoint_fp8_serialized:
             # WEIGHT SCALE
             if self.block_quant:
-                assert self.quant_config.activation_scheme == "dynamic"
+                if hasattr(self.quant_config, "activation_scheme"):
+                    assert self.quant_config.activation_scheme == "dynamic"
+                elif hasattr(self.quant_config, "linear_activation_scheme"):
+                    assert self.quant_config.linear_activation_scheme == "dynamic"
                 scale = BlockQuantScaleParameter(
                     data=torch.empty(
                         (output_size_per_partition + block_n - 1) // block_n,
@@ -308,7 +311,13 @@
             layer.register_parameter("weight_scale", scale)
 
             # INPUT ACTIVATION SCALE
-            if self.quant_config.activation_scheme == "static":
+            if (
+                hasattr(self.quant_config, "activation_scheme")
+                and self.quant_config.activation_scheme == "static"
+            ) or (
+                hasattr(self.quant_config, "linear_activation_scheme")
+                and self.quant_config.linear_activation_scheme == "static"
+            ):
                 scale = PerTensorScaleParameter(
                     data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                     weight_loader=weight_loader,
@@ -371,7 +380,13 @@
             layer.weight_scale = torch.nn.Parameter(
                 layer.weight_scale.data, requires_grad=False
             )
-            if self.quant_config.activation_scheme == "static":
+            if (
+                hasattr(self.quant_config, "activation_scheme")
+                and self.quant_config.activation_scheme == "static"
+            ) or (
+                hasattr(self.quant_config, "linear_activation_scheme")
+                and self.quant_config.linear_activation_scheme == "static"
+            ):
                 layer.input_scale = torch.nn.Parameter(
                     layer.input_scale.data, requires_grad=False
                 )
@@ -405,7 +420,13 @@
             # Update layer with new values.
             layer.weight = Parameter(weight.t(), requires_grad=False)
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)
-            if self.quant_config.activation_scheme == "static":
+            if (
+                hasattr(self.quant_config, "activation_scheme")
+                and self.quant_config.activation_scheme == "static"
+            ) or (
+                hasattr(self.quant_config, "linear_activation_scheme")
+                and self.quant_config.linear_activation_scheme == "static"
+            ):
                 layer.input_scale = Parameter(
                     layer.input_scale.max(), requires_grad=False
                 )
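
The same hasattr-based condition now appears in three places in Fp8LinearMethod because W4AFp8Config exposes linear_activation_scheme where Fp8Config exposes activation_scheme. A hypothetical helper (not in the diff) that captures the shared predicate:

def _is_static_activation_scheme(quant_config) -> bool:
    # Hypothetical refactoring sketch: equivalent to the repeated inline checks,
    # covering both Fp8Config.activation_scheme and
    # W4AFp8Config.linear_activation_scheme.
    return (
        getattr(quant_config, "activation_scheme", None) == "static"
        or getattr(quant_config, "linear_activation_scheme", None) == "static"
    )
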
sglang/srt/layers/quantization/modelopt_quant.py
@@ -26,6 +26,7 @@ from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
 from sglang.srt.layers.quantization.utils import (
     convert_to_channelwise,
     is_layer_skipped,
+    per_tensor_dequantize,
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -110,7 +111,12 @@ class ModelOptFp8Config(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
         if self.exclude_modules and any(
-            module in prefix for module in self.exclude_modules
+            module in prefix
+            or (
+                prefix.startswith("language_model.")
+                and module in prefix.removeprefix("language_model.")
+            )
+            for module in self.exclude_modules
         ):
             return None
 
@@ -119,6 +125,12 @@
         if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
             return ModelOptFp8KVCacheMethod(self)
 
+        # Add MoE support
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
+        if isinstance(layer, FusedMoE):
+            return ModelOptFp8MoEMethod(self)
+
         return None
 
     def get_scaled_act_names(self) -> List[str]:
@@ -234,6 +246,237 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
         super().__init__(quant_config)
 
 
+class ModelOptFp8MoEMethod:
+    """MoE method for ModelOpt FP8.
+    Supports loading FP8 checkpoints with static weight scale and activation scale.
+
+    Args:
+        quant_config: The ModelOpt quantization config.
+    """
+
+    def __new__(cls, *args, **kwargs):
+        """
+        Dynamic class composition pattern.
+
+        This allows us to effectively "inject" FusedMoEMethodBase as a parent class
+        at runtime while avoiding circular import issues.
+        """
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    def __init__(self, quant_config: ModelOptFp8Config):
+        self.quant_config = quant_config
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        # Use FP8 dtype if checkpoint is serialized, otherwise use the default dtype
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quant_config.is_checkpoint_fp8_serialized
+            else params_dtype
+        )
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        w13_weight = ModelWeightParameter(
+            data=torch.empty(
+                num_experts, 2 * intermediate_size, hidden_size, dtype=weight_dtype
+            ),
+            input_dim=2,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+
+        w2_weight = ModelWeightParameter(
+            data=torch.empty(
+                num_experts, hidden_size, intermediate_size, dtype=weight_dtype
+            ),
+            input_dim=2,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            # WEIGHT SCALES - Per-tensor scaling for ModelOpts
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They will be combined to a single scale after weight loading.
+            w13_weight_scale = PerTensorScaleParameter(
+                data=torch.full(
+                    (num_experts, 2),
+                    torch.finfo(torch.float32).min,
+                    dtype=torch.float32,
+                ),
+                weight_loader=weight_loader,
+            )
+            w2_weight_scale = PerTensorScaleParameter(
+                data=torch.full(
+                    (num_experts,), torch.finfo(torch.float32).min, dtype=torch.float32
+                ),
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+            # Set weight loader attributes for scales
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+            )
+
+            # INPUT SCALES - Per-tensor scaling for ModelOpt
+            w13_input_scale = PerTensorScaleParameter(
+                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            w2_input_scale = PerTensorScaleParameter(
+                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("w13_input_scale", w13_input_scale)
+            layer.register_parameter("w2_input_scale", w2_input_scale)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        """Process FP8 MoE weights after loading from serialized checkpoint.
+
+        Only supports pre-quantized checkpoints with FP8 weights and scales.
+        """
+
+        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
+        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
+
+        # Handle scale parameters
+        if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
+            # Fp8 moe kernel needs single weight scale for w13 per expert.
+            # We take the max of the w1 and w3 scales then dequant and requant each expert.
+            if layer.w13_weight_scale.dim() == 2:  # Shape: (num_experts, 2)
+                from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
+
+                # Get the maximum scale across w1 and w3 for each expert
+                max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+
+                # Requantize each expert's weights using the combined scale
+                # w13_weight has shape (num_experts, 2 * intermediate_size, hidden_size)
+                # where the first intermediate_size rows are w1, the next are w3
+                intermediate_size = layer.w13_weight.shape[1] // 2
+                for expert_id in range(layer.w13_weight.shape[0]):
+                    start = 0
+                    for shard_id in range(2):  # w1 and w3
+                        # Dequantize using the original scale for this shard
+                        dq_weight = per_tensor_dequantize(
+                            layer.w13_weight[expert_id][
+                                start : start + intermediate_size, :
+                            ],
+                            layer.w13_weight_scale[expert_id][shard_id],
+                        )
+                        # Requantize using the combined max scale
+                        (
+                            layer.w13_weight[expert_id][
+                                start : start + intermediate_size, :
+                            ],
+                            _,
+                        ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+
+                        start += intermediate_size
+
+                # Update the scale parameter to be per-expert instead of per-shard
+                layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
+            else:
+                layer.w13_weight_scale = Parameter(
+                    layer.w13_weight_scale.data, requires_grad=False
+                )
+
+        if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
+            layer.w2_weight_scale = Parameter(
+                layer.w2_weight_scale.data, requires_grad=False
+            )
+        if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
+            layer.w13_input_scale = Parameter(
+                layer.w13_input_scale.max(), requires_grad=False
+            )
+        if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
+            layer.w2_input_scale = Parameter(
+                layer.w2_input_scale.max(), requires_grad=False
+            )
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        num_fused_shared_experts: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.topk import select_experts
+
+        # Expert selection
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            num_fused_shared_experts=num_fused_shared_experts,
+            custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
+        )
+
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=inplace,
+            activation=activation,
+            use_fp8_w8a8=True,
+            per_channel_quant=False,  # ModelOpt uses per-tensor quantization
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+            no_combine=no_combine,
+        )
+
+
 class ModelOptFp4Config(QuantizationConfig):
     """Config class for FP4."""
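
ModelOptFp8MoEMethod.__new__ above rebuilds the class with FusedMoEMethodBase injected as a base only when the first instance is created, which keeps modelopt_quant.py from importing the fused-MoE package at module import time. A self-contained toy of that pattern (illustrative names only, not sglang code):

class LateBase:
    # Stand-in for FusedMoEMethodBase: only needed once an instance exists.
    def shared_helper(self) -> str:
        return "helper inherited from the injected base"


class Method:
    def __new__(cls, *args, **kwargs):
        # Rebuild the class with LateBase as a parent at instantiation time;
        # copying the class dict (minus per-class slots) carries __init__ and
        # the other methods over to the new class.
        namespace = {
            k: v
            for k, v in cls.__dict__.items()
            if k not in ("__dict__", "__weakref__")
        }
        new_cls = type(cls.__name__, (LateBase,), namespace)
        obj = super(new_cls, new_cls).__new__(new_cls)
        obj.__init__(*args, **kwargs)  # called by hand; Python skips it for foreign types
        return obj

    def __init__(self, name: str):
        self.name = name


m = Method("fp8-moe")
print(isinstance(m, LateBase))  # True: the base class was injected at runtime
print(m.name, "-", m.shared_helper())
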