sglang 0.4.4.post3__py3-none-any.whl → 0.4.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_serving.py +49 -7
  2. sglang/lang/chat_template.py +24 -0
  3. sglang/srt/_custom_ops.py +59 -92
  4. sglang/srt/configs/model_config.py +5 -0
  5. sglang/srt/constrained/base_grammar_backend.py +5 -1
  6. sglang/srt/conversation.py +29 -4
  7. sglang/srt/custom_op.py +5 -0
  8. sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
  9. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  10. sglang/srt/entrypoints/engine.py +0 -5
  11. sglang/srt/layers/attention/flashattention_backend.py +678 -83
  12. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  13. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  14. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  15. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  16. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  17. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  18. sglang/srt/layers/moe/fused_moe_native.py +5 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +416 -50
  30. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  31. sglang/srt/layers/moe/topk.py +49 -3
  32. sglang/srt/layers/quantization/__init__.py +5 -1
  33. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  34. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  35. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  36. sglang/srt/layers/quantization/fp8.py +3 -1
  37. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  38. sglang/srt/layers/quantization/moe_wna16.py +503 -0
  39. sglang/srt/layers/quantization/utils.py +1 -1
  40. sglang/srt/layers/quantization/w8a8_int8.py +2 -0
  41. sglang/srt/layers/radix_attention.py +2 -0
  42. sglang/srt/layers/rotary_embedding.py +63 -12
  43. sglang/srt/managers/cache_controller.py +34 -11
  44. sglang/srt/managers/mm_utils.py +202 -156
  45. sglang/srt/managers/multimodal_processor.py +0 -2
  46. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  47. sglang/srt/managers/multimodal_processors/clip.py +7 -26
  48. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  49. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  50. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  51. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  52. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  53. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  54. sglang/srt/managers/multimodal_processors/mllama4.py +161 -0
  55. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  56. sglang/srt/managers/schedule_batch.py +185 -128
  57. sglang/srt/managers/scheduler.py +4 -4
  58. sglang/srt/managers/tokenizer_manager.py +1 -1
  59. sglang/srt/managers/utils.py +1 -6
  60. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  61. sglang/srt/mem_cache/memory_pool.py +72 -6
  62. sglang/srt/mem_cache/paged_allocator.py +39 -0
  63. sglang/srt/metrics/collector.py +23 -53
  64. sglang/srt/model_executor/cuda_graph_runner.py +8 -6
  65. sglang/srt/model_executor/forward_batch_info.py +10 -10
  66. sglang/srt/model_executor/model_runner.py +60 -57
  67. sglang/srt/model_loader/loader.py +8 -0
  68. sglang/srt/models/clip.py +12 -7
  69. sglang/srt/models/deepseek_janus_pro.py +10 -15
  70. sglang/srt/models/deepseek_v2.py +212 -121
  71. sglang/srt/models/deepseek_vl2.py +105 -104
  72. sglang/srt/models/gemma3_mm.py +14 -80
  73. sglang/srt/models/llama.py +16 -5
  74. sglang/srt/models/llama4.py +420 -0
  75. sglang/srt/models/llava.py +31 -19
  76. sglang/srt/models/llavavid.py +16 -7
  77. sglang/srt/models/minicpmo.py +63 -147
  78. sglang/srt/models/minicpmv.py +17 -27
  79. sglang/srt/models/mllama.py +29 -14
  80. sglang/srt/models/mllama4.py +154 -0
  81. sglang/srt/models/qwen2.py +9 -6
  82. sglang/srt/models/qwen2_5_vl.py +21 -31
  83. sglang/srt/models/qwen2_vl.py +20 -21
  84. sglang/srt/openai_api/adapter.py +18 -6
  85. sglang/srt/platforms/interface.py +371 -0
  86. sglang/srt/server_args.py +99 -14
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  88. sglang/srt/speculative/eagle_utils.py +140 -28
  89. sglang/srt/speculative/eagle_worker.py +93 -24
  90. sglang/srt/utils.py +104 -51
  91. sglang/test/test_custom_ops.py +55 -0
  92. sglang/test/test_utils.py +13 -26
  93. sglang/utils.py +2 -2
  94. sglang/version.py +1 -1
  95. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/METADATA +4 -3
  96. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/RECORD +99 -84
  97. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/top_level.txt +0 -0
@@ -128,6 +128,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
  custom_routing_function: Optional[Callable] = None,
  correction_bias: Optional[torch.Tensor] = None,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  inplace: bool = True,
  no_combine: bool = False,
  ) -> torch.Tensor:
@@ -143,6 +144,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
  custom_routing_function=custom_routing_function,
  correction_bias=correction_bias,
  activation=activation,
+ apply_router_weight_on_input=apply_router_weight_on_input,
  inplace=inplace,
  no_combine=no_combine,
  )
@@ -160,6 +162,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
  custom_routing_function: Optional[Callable] = None,
  correction_bias: Optional[torch.Tensor] = None,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  inplace: bool = True,
  no_combine: bool = False,
  ) -> torch.Tensor:
@@ -200,6 +203,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
  topk_ids=topk_ids,
  inplace=inplace and not no_combine,
  activation=activation,
+ apply_router_weight_on_input=apply_router_weight_on_input,
  no_combine=no_combine,
  )

@@ -276,6 +280,7 @@ class FusedMoE(torch.nn.Module):
  custom_routing_function: Optional[Callable] = None,
  correction_bias: Optional[torch.Tensor] = None,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  use_presharded_weights: bool = False,
  inplace: bool = True,
  no_combine: bool = False,
@@ -302,6 +307,7 @@ class FusedMoE(torch.nn.Module):
  self.custom_routing_function = custom_routing_function
  self.correction_bias = correction_bias
  self.activation = activation
+ self.apply_router_weight_on_input = apply_router_weight_on_input
  self.use_presharded_weights = use_presharded_weights
  self.inplace = inplace
  self.no_combine = no_combine
@@ -630,6 +636,7 @@ class FusedMoE(torch.nn.Module):
  custom_routing_function=self.custom_routing_function,
  correction_bias=self.correction_bias,
  activation=self.activation,
+ apply_router_weight_on_input=self.apply_router_weight_on_input,
  )

  if self.reduce_results and self.tp_size > 1:
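Note: the hunks above only thread the new apply_router_weight_on_input flag through the FusedMoE call chain; its effect lives in the fused kernels. As a hedged illustration (toy code, not from the package, and assuming the flag means the routing weight scales the expert input rather than the expert output, as is typical for top-1 routing):

# Toy illustration only; toy_moe and its arguments are invented for this note.
import torch

def toy_moe(x, expert, router_weight, apply_router_weight_on_input):
    # x: [tokens, hidden]; router_weight: [tokens, 1] for the single chosen expert.
    if apply_router_weight_on_input:
        return expert(x * router_weight)   # weight folded into the expert input
    return expert(x) * router_weight       # weight applied to the expert output

expert = torch.nn.Linear(8, 8, bias=False)
x, w = torch.randn(4, 8), torch.rand(4, 1)
# With a purely linear expert the two placements coincide; with gated MLPs they do not.
assert torch.allclose(toy_moe(x, expert, w, True), toy_moe(x, expert, w, False), atol=1e-5)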
@@ -12,12 +12,14 @@
  # limitations under the License.
  # ==============================================================================

+ import os
  from typing import Callable, Optional

  import torch
  import torch.nn.functional as F

  from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip

  _is_cuda = is_cuda()
@@ -102,11 +104,13 @@ def grouped_topk(
  renormalize: bool,
  num_expert_group: int = 0,
  topk_group: int = 0,
+ n_share_experts_fusion: int = 0,
  ):
  assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

  scores = torch.softmax(gating_output, dim=-1)
  num_token = scores.shape[0]
+ num_experts = scores.shape[1]
  group_scores = (
  scores.view(num_token, num_expert_group, -1).max(dim=-1).values
  ) # [n, n_group]
@@ -122,9 +126,25 @@ def grouped_topk(
  ) # [n, e]
  tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
  topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+ if n_share_experts_fusion:
+ topk_ids[:, -1] = torch.randint(
+ low=num_experts,
+ high=num_experts + n_share_experts_fusion,
+ size=(topk_ids.size(0),),
+ dtype=topk_ids.dtype,
+ device=topk_ids.device,
+ )
+ topk_weights[:, -1] = (
+ topk_weights[:, :-1].sum(dim=-1) / 2.5
+ ) # 2.5 is the routed_scaling_factor.

  if renormalize:
- topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+ topk_weights_sum = (
+ topk_weights.sum(dim=-1, keepdim=True)
+ if n_share_experts_fusion == 0
+ else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+ )
+ topk_weights = topk_weights / topk_weights_sum

  return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
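Note: the hunk above is the new shared-expert fusion path. When n_share_experts_fusion is set, the last top-k slot is redirected to one of the replicated shared experts (indices num_experts .. num_experts + n_share_experts_fusion - 1), its weight becomes the sum of the routed weights divided by 2.5 (the routed_scaling_factor, per the in-code comment), and renormalization then divides by the routed slots only. A standalone restatement of that arithmetic with plain tensors:

# Standalone sketch of the shared-expert fusion step; no sglang imports.
import torch

num_tokens, num_experts, topk, n_share = 4, 8, 3, 2
scores = torch.softmax(torch.rand(num_tokens, num_experts), dim=-1)
topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False)

if n_share:
    # The last slot points at a randomly chosen replicated shared expert.
    topk_ids[:, -1] = torch.randint(
        num_experts, num_experts + n_share, (num_tokens,), dtype=topk_ids.dtype
    )
    topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / 2.5

# Renormalize over the routed slots only when a shared slot is fused in.
denom = (
    topk_weights.sum(dim=-1, keepdim=True)
    if n_share == 0
    else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
)
topk_weights = topk_weights / denom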
@@ -137,11 +157,13 @@ def biased_grouped_topk_impl(
  renormalize: bool,
  num_expert_group: int = 0,
  topk_group: int = 0,
+ n_share_experts_fusion: int = 0,
  ):
  assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

  scores = gating_output.sigmoid()
  num_token = scores.shape[0]
+ num_experts = scores.shape[1]
  scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0)
  group_scores = (
  scores_for_choice.view(num_token, num_expert_group, -1)
@@ -164,8 +186,25 @@ def biased_grouped_topk_impl(
  _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
  topk_weights = scores.gather(1, topk_ids)

+ if n_share_experts_fusion:
+ topk_ids[:, -1] = torch.randint(
+ low=num_experts,
+ high=num_experts + n_share_experts_fusion,
+ size=(topk_ids.size(0),),
+ dtype=topk_ids.dtype,
+ device=topk_ids.device,
+ )
+ topk_weights[:, -1] = (
+ topk_weights[:, :-1].sum(dim=-1) / 2.5
+ ) # 2.5 is the routed_scaling_factor.
+
  if renormalize:
- topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+ topk_weights_sum = (
+ topk_weights.sum(dim=-1, keepdim=True)
+ if n_share_experts_fusion == 0
+ else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+ )
+ topk_weights = topk_weights / topk_weights_sum

  return topk_weights.to(torch.float32), topk_ids.to(torch.int32)

@@ -179,6 +218,7 @@ def biased_grouped_topk(
  num_expert_group: int = 0,
  topk_group: int = 0,
  compiled: bool = True,
+ n_share_experts_fusion: int = 0,
  ):
  biased_grouped_topk_fn = (
  torch.compile(
@@ -195,6 +235,7 @@ def biased_grouped_topk(
  renormalize,
  num_expert_group,
  topk_group,
+ n_share_experts_fusion=n_share_experts_fusion,
  )
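Note: biased_grouped_topk keeps its optional torch.compile wrapper and now forwards n_share_experts_fusion through both the compiled and eager paths. A minimal sketch of that dispatch pattern (the routing body below is a placeholder, not the sglang implementation):

# Sketch only: conditionally compile a routing function and forward a new kwarg.
import torch

def routing_impl(scores, topk, n_share_experts_fusion=0):
    # Placeholder body; the real function also handles the fused shared-expert slot.
    return torch.topk(scores, k=topk, dim=-1, sorted=False)

compiled = True
routing_fn = torch.compile(routing_impl, dynamic=True) if compiled else routing_impl
weights, ids = routing_fn(torch.rand(4, 8), 2, n_share_experts_fusion=1)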
@@ -210,7 +251,10 @@ def select_experts(
  correction_bias: Optional[torch.Tensor] = None,
  torch_native: bool = False,
  ):
- # DeekSeekv2 uses grouped_top_k
+ n_share_experts_fusion = 0
+ if global_server_args_dict["n_share_experts_fusion"] is not None:
+ n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+ # DeekSeek V2/V3/R1 serices models uses grouped_top_k
  if use_grouped_topk:
  assert topk_group is not None
  assert num_expert_group is not None
@@ -222,6 +266,7 @@ def select_experts(
  renormalize=renormalize,
  num_expert_group=num_expert_group,
  topk_group=topk_group,
+ n_share_experts_fusion=n_share_experts_fusion,
  )
  else:
  topk_weights, topk_ids = biased_grouped_topk(
@@ -232,6 +277,7 @@ def select_experts(
  renormalize=renormalize,
  num_expert_group=num_expert_group,
  topk_group=topk_group,
+ n_share_experts_fusion=n_share_experts_fusion,
  )
  elif torch_native and custom_routing_function is None:
  topk_weights, topk_ids = fused_topk_native(
@@ -51,7 +51,6 @@ except ImportError:


  from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
- from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
  from sglang.srt.layers.quantization.awq import AWQConfig
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
@@ -61,6 +60,7 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
  from sglang.srt.layers.quantization.fp8 import Fp8Config
  from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
  from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
+ from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
  from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
  from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
  from sglang.srt.layers.vocab_parallel_embedding import (
@@ -75,6 +75,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
  "modelopt": ModelOptFp8Config,
  "w8a8_int8": W8A8Int8Config,
  "w8a8_fp8": W8A8Fp8Config,
+ "moe_wna16": MoeWNA16Config,
  "compressed-tensors": CompressedTensorsConfig,
  }

@@ -201,6 +202,8 @@ def get_linear_quant_method(


  def gptq_get_quant_method(self, layer, prefix):
+ from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+
  if isinstance(layer, FusedMoE):
  return GPTQMarlinMoEMethod(self)
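Note: this release moves several FusedMoE imports from module level into the functions that use them; deferring the import to call time breaks the circular dependency between the quantization package and the fused-MoE layer module. A generic, runnable sketch of the idiom (the imported class here is a stdlib stand-in, not FusedMoE):

def get_quant_method(layer):
    # Importing the MoE class here, at call time, avoids completing an import
    # cycle that a module-level import would create. OrderedDict is only a
    # stand-in target so this snippet runs on its own.
    from collections import OrderedDict as FusedMoEStandIn

    if isinstance(layer, FusedMoEStandIn):
        return "moe-method"
    return "linear-method"

print(get_quant_method({}))  # -> linear-method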
@@ -277,6 +280,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
  custom_routing_function: Optional[Callable] = None,
  correction_bias: Optional[torch.Tensor] = None,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  inplace: bool = True,
  no_combine: bool = False,
  ):
@@ -370,6 +370,7 @@ class BlockInt8MoEMethod:
  custom_routing_function: Optional[Callable] = None,
  correction_bias: Optional[torch.Tensor] = None,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  inplace: bool = True,
  no_combine: bool = False,
  ) -> torch.Tensor:
@@ -398,6 +399,7 @@ class BlockInt8MoEMethod:
  topk_ids=topk_ids,
  inplace=inplace,
  activation=activation,
+ apply_router_weight_on_input=apply_router_weight_on_input,
  use_int8_w8a8=True,
  w1_scale=(layer.w13_weight_scale_inv),
  w2_scale=(layer.w2_weight_scale_inv),
@@ -23,7 +23,6 @@ from sglang.srt.layers.linear import (
  LinearMethodBase,
  UnquantizedLinearMethod,
  )
- from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
  from sglang.srt.layers.quantization.base_config import (
  QuantizationConfig,
  QuantizeMethodBase,
@@ -123,6 +122,8 @@ class CompressedTensorsConfig(QuantizationConfig):
  return UnquantizedLinearMethod()
  layer.scheme = scheme
  return CompressedTensorsLinearMethod(self)
+ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
  if isinstance(layer, FusedMoE):
  return CompressedTensorsMoEMethod.get_moe_method(self)
  return None
@@ -4,18 +4,19 @@
  import enum
  import logging
  from enum import Enum
- from typing import Callable, List, Optional
+ from typing import TYPE_CHECKING, Callable, List, Optional

  import torch
  from compressed_tensors import CompressionFormat
  from compressed_tensors.quantization import QuantizationStrategy

- from sglang.srt.layers.moe.fused_moe_triton import (
- FusedMoE,
- FusedMoEMethodBase,
- FusedMoeWeightScaleSupported,
- )
- from sglang.srt.layers.moe.topk import select_experts
+ if TYPE_CHECKING:
+ from sglang.srt.layers.moe.fused_moe_triton import (
+ FusedMoE,
+ FusedMoEMethodBase,
+ FusedMoeWeightScaleSupported,
+ )
+
  from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
  from sglang.srt.layers.quantization.utils import (
  all_close_1d,
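Note: the same circular import is avoided in the compressed-tensors MoE module by moving the fused_moe_triton imports under typing.TYPE_CHECKING, so type checkers still see them while the runtime import happens lazily inside the methods that need it. A generic sketch of that idiom (stand-in names, not the sglang modules):

# Sketch of the TYPE_CHECKING idiom; the import target is a stdlib stand-in.
from __future__ import annotations  # annotations stay unevaluated at runtime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers, never executed at runtime, so it
    # cannot participate in an import cycle.
    from collections import OrderedDict as FusedMoEStandIn

def describe(layer: FusedMoEStandIn) -> str:
    # The runtime import is deferred to the call site instead.
    from collections import OrderedDict as FusedMoEStandIn

    return "moe" if isinstance(layer, FusedMoEStandIn) else "other"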
@@ -55,7 +56,13 @@ __all__ = [
  ]


- class CompressedTensorsMoEMethod(FusedMoEMethodBase):
+ class CompressedTensorsMoEMethod:
+ def __new__(cls, *args, **kwargs):
+ from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+ if cls is CompressedTensorsMoEMethod:
+ return super().__new__(cls)
+ return super().__new__(cls)

  @staticmethod
  def get_moe_method(
@@ -85,6 +92,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
  def __init__(
  self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
  ):
+ from sglang.srt.layers.moe.fused_moe_triton import (
+ FusedMoEMethodBase,
+ FusedMoeWeightScaleSupported,
+ )
+
  self.quant_config = quant_config
  self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
  self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
@@ -112,6 +124,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
  params_dtype: torch.dtype,
  **extra_weight_attrs,
  ):
+ from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

  params_dtype = torch.float8_e4m3fn

@@ -270,8 +283,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
  scoring_func: str = "softmax",
  correction_bias: Optional[torch.Tensor] = None,
  activation: str = "silu",
+ inplace: bool = True,
+ no_combine: bool = False,
  ) -> torch.Tensor:
- from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+ from sglang.srt.layers.moe.fused_moe_triton import fused_experts
+ from sglang.srt.layers.moe.topk import select_experts

  topk_weights, topk_ids = select_experts(
  hidden_states=x,
@@ -291,7 +307,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
  layer.w2_weight,
  topk_weights=topk_weights,
  topk_ids=topk_ids,
- inplace=True,
+ inplace=inplace,
  activation=activation,
  use_fp8_w8a8=True,
  w1_scale=layer.w13_weight_scale,
@@ -306,6 +322,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
  def __init__(
  self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
  ):
+ from sglang.srt.layers.moe.fused_moe_triton import (
+ FusedMoEMethodBase,
+ FusedMoeWeightScaleSupported,
+ )
+
  self.quant_config = quant_config
  # TODO: @dsikka: refactor this to use schemes as other kernels
  # are supported + check if the layer is being ignored.
@@ -617,6 +638,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
  correction_bias: Optional[torch.Tensor] = None,
  activation: str = "silu",
  ) -> torch.Tensor:
+ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+ from sglang.srt.layers.moe.topk import select_experts
+
  assert activation == "silu", "Only SiLU activation is supported."
  if not VLLM_AVAILABLE:
  raise ImportError(
@@ -860,7 +860,7 @@ class Fp8MoEMethod:
  layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
  layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]

- def process_weights_hip_scale_padding(self, layer: Module, padding_size: int):
+ def process_weights_hip_scale_padding(self, layer: Module):
  from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
  padding_size, # Avoid circular import
  )
@@ -905,6 +905,7 @@ class Fp8MoEMethod:
  custom_routing_function: Optional[Callable] = None,
  correction_bias: Optional[torch.Tensor] = None,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  inplace: bool = True,
  no_combine: bool = False,
  ) -> torch.Tensor:
@@ -975,6 +976,7 @@ class Fp8MoEMethod:
  topk_ids=topk_ids,
  inplace=inplace and not no_combine,
  activation=activation,
+ apply_router_weight_on_input=apply_router_weight_on_input,
  use_fp8_w8a8=True,
  w1_scale=(
  layer.w13_weight_scale_inv
@@ -457,12 +457,9 @@ class Fp8LinearOp:
  qinput, x_scale = sgl_scaled_fp8_quant(
  input_2d,
  input_scale,
+ num_token_padding=self.output_padding,
  use_per_token_if_dynamic=use_per_token_if_dynamic,
  )
- if self.output_padding:
- pad_size = max(self.output_padding - qinput.shape[0], 0)
- if pad_size > 0:
- qinput = torch.nn.functional.pad(qinput, (0, 0, 0, pad_size))
  else:
  qinput, x_scale = ops.scaled_fp8_quant(
  input_2d,
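Note: the final hunk folds output padding into sgl_scaled_fp8_quant via num_token_padding instead of padding the quantized tensor afterwards. For reference, the removed branch grew the token dimension up to self.output_padding rows; a standalone restatement of that padding arithmetic:

# Standalone sketch of what the removed padding branch did.
import torch
import torch.nn.functional as F

output_padding = 8              # minimum number of token rows expected downstream
qinput = torch.zeros(5, 16)     # stand-in for the quantized activations

pad_size = max(output_padding - qinput.shape[0], 0)
if pad_size > 0:
    # F.pad pairs run last-dim-first: (left, right, top, bottom) for a 2-D tensor,
    # so this appends pad_size zero rows along the token dimension.
    qinput = F.pad(qinput, (0, 0, 0, pad_size))

print(qinput.shape)  # torch.Size([8, 16])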