sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_serving.py +0 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  12. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  13. sglang/srt/constrained/xgrammar_backend.py +26 -4
  14. sglang/srt/custom_op.py +0 -62
  15. sglang/srt/disaggregation/decode.py +62 -6
  16. sglang/srt/disaggregation/mini_lb.py +5 -1
  17. sglang/srt/disaggregation/mooncake/conn.py +32 -62
  18. sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
  19. sglang/srt/disaggregation/prefill.py +40 -4
  20. sglang/srt/disaggregation/utils.py +15 -0
  21. sglang/srt/entrypoints/verl_engine.py +7 -5
  22. sglang/srt/layers/activation.py +6 -8
  23. sglang/srt/layers/attention/flashattention_backend.py +114 -71
  24. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  25. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  26. sglang/srt/layers/attention/triton_backend.py +6 -0
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  28. sglang/srt/layers/layernorm.py +1 -1
  29. sglang/srt/layers/linear.py +17 -3
  30. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  31. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  34. sglang/srt/layers/moe/topk.py +27 -30
  35. sglang/srt/layers/parameter.py +0 -2
  36. sglang/srt/layers/quantization/__init__.py +1 -0
  37. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  38. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +8 -2
  39. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
  40. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  41. sglang/srt/layers/quantization/fp8.py +115 -132
  42. sglang/srt/layers/quantization/fp8_kernel.py +213 -57
  43. sglang/srt/layers/quantization/fp8_utils.py +187 -262
  44. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  45. sglang/srt/layers/quantization/utils.py +5 -11
  46. sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
  47. sglang/srt/layers/quantization/w8a8_int8.py +7 -7
  48. sglang/srt/layers/radix_attention.py +15 -0
  49. sglang/srt/layers/rotary_embedding.py +3 -2
  50. sglang/srt/layers/sampler.py +5 -10
  51. sglang/srt/lora/backend/base_backend.py +18 -2
  52. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  53. sglang/srt/lora/backend/triton_backend.py +1 -1
  54. sglang/srt/lora/layers.py +1 -1
  55. sglang/srt/lora/lora.py +1 -1
  56. sglang/srt/lora/lora_manager.py +1 -1
  57. sglang/srt/managers/detokenizer_manager.py +0 -1
  58. sglang/srt/managers/io_struct.py +1 -0
  59. sglang/srt/managers/mm_utils.py +4 -3
  60. sglang/srt/managers/multimodal_processor.py +0 -2
  61. sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
  62. sglang/srt/managers/schedule_batch.py +2 -4
  63. sglang/srt/managers/scheduler.py +12 -71
  64. sglang/srt/managers/tokenizer_manager.py +1 -0
  65. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  66. sglang/srt/mem_cache/memory_pool.py +7 -2
  67. sglang/srt/model_executor/cuda_graph_runner.py +2 -2
  68. sglang/srt/model_executor/model_runner.py +20 -27
  69. sglang/srt/models/bert.py +398 -0
  70. sglang/srt/models/deepseek.py +1 -1
  71. sglang/srt/models/deepseek_nextn.py +74 -70
  72. sglang/srt/models/deepseek_v2.py +289 -348
  73. sglang/srt/models/llama.py +5 -5
  74. sglang/srt/models/minicpm3.py +29 -201
  75. sglang/srt/models/qwen2.py +4 -1
  76. sglang/srt/models/qwen2_moe.py +14 -13
  77. sglang/srt/models/qwen3.py +335 -0
  78. sglang/srt/models/qwen3_moe.py +423 -0
  79. sglang/srt/reasoning_parser.py +0 -1
  80. sglang/srt/sampling/sampling_batch_info.py +2 -3
  81. sglang/srt/server_args.py +34 -32
  82. sglang/srt/speculative/eagle_worker.py +4 -7
  83. sglang/srt/utils.py +16 -1
  84. sglang/test/runners.py +5 -1
  85. sglang/test/test_block_fp8.py +167 -0
  86. sglang/test/test_custom_ops.py +1 -1
  87. sglang/version.py +1 -1
  88. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +3 -3
  89. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +92 -91
  90. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  91. sglang/lang/__init__.py +0 -0
  92. sglang/srt/lora/backend/__init__.py +0 -25
  93. sglang/srt/server.py +0 -18
  94. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  95. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -131,6 +131,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
  apply_router_weight_on_input: bool = False,
  inplace: bool = True,
  no_combine: bool = False,
+ routed_scaling_factor: Optional[float] = None,
  ) -> torch.Tensor:
  return self.forward(
  x=x,
@@ -147,6 +148,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
  apply_router_weight_on_input=apply_router_weight_on_input,
  inplace=inplace,
  no_combine=no_combine,
+ routed_scaling_factor=routed_scaling_factor,
  )

  def forward_cuda(
@@ -165,6 +167,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
  apply_router_weight_on_input: bool = False,
  inplace: bool = True,
  no_combine: bool = False,
+ routed_scaling_factor: Optional[float] = None,
  ) -> torch.Tensor:
  topk_weights, topk_ids = select_experts(
  hidden_states=x,
@@ -176,6 +179,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
  num_expert_group=num_expert_group,
  custom_routing_function=custom_routing_function,
  correction_bias=correction_bias,
+ routed_scaling_factor=routed_scaling_factor,
  )

  if _is_hip and get_bool_env_var("CK_MOE"):
@@ -284,6 +288,7 @@ class FusedMoE(torch.nn.Module):
  use_presharded_weights: bool = False,
  inplace: bool = True,
  no_combine: bool = False,
+ routed_scaling_factor: Optional[float] = None,
  ):
  super().__init__()

@@ -293,6 +298,7 @@ class FusedMoE(torch.nn.Module):
  self.tp_size = (
  tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
  )
+ self.routed_scaling_factor = routed_scaling_factor
  self.top_k = top_k
  self.num_experts = num_experts
  assert intermediate_size % self.tp_size == 0
@@ -637,6 +643,7 @@ class FusedMoE(torch.nn.Module):
  correction_bias=self.correction_bias,
  activation=self.activation,
  apply_router_weight_on_input=self.apply_router_weight_on_input,
+ routed_scaling_factor=self.routed_scaling_factor,
  )

  if self.reduce_results and self.tp_size > 1:
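Every hunk in this file threads a single new optional keyword, routed_scaling_factor, from the FusedMoE constructor down through the quantization method's apply() into select_experts. A minimal toy of that threading pattern (ToyFusedMoE and select_experts_toy are illustrative placeholders, not the sglang API):

    from typing import Optional

    def select_experts_toy(scores, routed_scaling_factor: Optional[float] = None):
        # Stand-in for sglang's select_experts: it only needs to receive the factor;
        # what the real routing helpers do with it is shown in the topk.py hunks below.
        print("routing with routed_scaling_factor =", routed_scaling_factor)
        return scores

    class ToyFusedMoE:
        def __init__(self, routed_scaling_factor: Optional[float] = None):
            # Stored once at construction, mirroring self.routed_scaling_factor above.
            self.routed_scaling_factor = routed_scaling_factor

        def forward(self, scores):
            # Forwarded on every call instead of being hard-coded inside the kernels.
            return select_experts_toy(scores, routed_scaling_factor=self.routed_scaling_factor)

    ToyFusedMoE(routed_scaling_factor=2.5).forward([0.5, 0.3, 0.2])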
sglang/srt/layers/moe/topk.py

@@ -13,7 +13,6 @@
  # ==============================================================================

  import math
- import os
  from typing import Callable, Optional

  import torch
@@ -29,6 +28,10 @@ _is_hip = is_hip()
  if _is_cuda:
  from sgl_kernel import moe_fused_gate

+ if _is_cuda or _is_hip:
+ from sgl_kernel import topk_softmax
+
+
  expert_distribution_recorder = ExpertDistributionRecorder()


@@ -59,11 +62,6 @@ def fused_topk(
  topk: int,
  renormalize: bool,
  ):
- if _is_cuda or _is_hip:
- from sgl_kernel import topk_softmax
- else:
- from vllm import _custom_ops as vllm_ops
-
  assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

  M, _ = hidden_states.shape
@@ -76,20 +74,12 @@ def fused_topk(
  M, topk, dtype=torch.int32, device=hidden_states.device
  )

- if _is_cuda or _is_hip:
- topk_softmax(
- topk_weights,
- topk_ids,
- token_expert_indicies,
- gating_output.float(),
- )
- else:
- vllm_ops.topk_softmax(
- topk_weights,
- topk_ids,
- token_expert_indicies,
- gating_output.float(),
- )
+ topk_softmax(
+ topk_weights,
+ topk_ids,
+ token_expert_indicies,
+ gating_output.float(),
+ )
  del token_expert_indicies

  if renormalize:
@@ -108,6 +98,7 @@ def grouped_topk(
  num_expert_group: int = 0,
  topk_group: int = 0,
  n_share_experts_fusion: int = 0,
+ routed_scaling_factor: Optional[float] = None,
  ):
  assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -137,9 +128,7 @@ def grouped_topk(
  dtype=topk_ids.dtype,
  device=topk_ids.device,
  )
- topk_weights[:, -1] = (
- topk_weights[:, :-1].sum(dim=-1) / 2.5
- ) # 2.5 is the routed_scaling_factor.
+ topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor

  if renormalize:
  topk_weights_sum = (
@@ -161,6 +150,7 @@ def biased_grouped_topk_impl(
  num_expert_group: int = 0,
  topk_group: int = 0,
  n_share_experts_fusion: int = 0,
+ routed_scaling_factor: Optional[float] = None,
  ):
  assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -197,9 +187,7 @@ def biased_grouped_topk_impl(
  dtype=topk_ids.dtype,
  device=topk_ids.device,
  )
- topk_weights[:, -1] = (
- topk_weights[:, :-1].sum(dim=-1) / 2.5
- ) # 2.5 is the routed_scaling_factor.
+ topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor

  if renormalize:
  topk_weights_sum = (
@@ -226,11 +214,16 @@ def biased_grouped_topk(
  topk_group: int = 0,
  compiled: bool = True,
  n_share_experts_fusion: int = 0,
+ routed_scaling_factor: Optional[float] = None,
  ):
+ assert (
+ routed_scaling_factor is not None
+ ), "routed_scaling_factor is required for biased_grouped_topk"
  # TODO: moe_fused_gate kernel is not supported for n_share_experts_fusion > 0 now.
  if (
  _is_cuda
- and n_share_experts_fusion == 0
+ and gating_output.shape[1] // num_expert_group
+ <= 32 # moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
  and is_power_of_two(correction_bias.shape[0])
  ):
  return moe_fused_gate(
@@ -239,6 +232,8 @@ def biased_grouped_topk(
  num_expert_group,
  topk_group,
  topk,
+ n_share_experts_fusion,
+ routed_scaling_factor,
  )
  else:
  biased_grouped_topk_fn = (
@@ -257,6 +252,7 @@ def biased_grouped_topk(
  num_expert_group,
  topk_group,
  n_share_experts_fusion=n_share_experts_fusion,
+ routed_scaling_factor=routed_scaling_factor,
  )


@@ -271,10 +267,9 @@ def select_experts(
  custom_routing_function: Optional[Callable] = None,
  correction_bias: Optional[torch.Tensor] = None,
  torch_native: bool = False,
+ routed_scaling_factor: Optional[float] = None,
  ):
- n_share_experts_fusion = 0
- if global_server_args_dict["n_share_experts_fusion"] is not None:
- n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+ n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
  # DeekSeek V2/V3/R1 serices models uses grouped_top_k
  if use_grouped_topk:
  assert topk_group is not None
@@ -288,6 +283,7 @@ def select_experts(
  num_expert_group=num_expert_group,
  topk_group=topk_group,
  n_share_experts_fusion=n_share_experts_fusion,
+ routed_scaling_factor=routed_scaling_factor,
  )
  else:
  topk_weights, topk_ids = biased_grouped_topk(
@@ -299,6 +295,7 @@ def select_experts(
  num_expert_group=num_expert_group,
  topk_group=topk_group,
  n_share_experts_fusion=n_share_experts_fusion,
+ routed_scaling_factor=routed_scaling_factor,
  )
  elif torch_native and custom_routing_function is None:
  topk_weights, topk_ids = fused_topk_native(
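The substantive change in grouped_topk and biased_grouped_topk_impl replaces the hard-coded 2.5 with the caller-supplied routed_scaling_factor when computing the fused shared expert's weight. A standalone sketch of that single line on toy tensors (pure PyTorch, not the sglang kernels):

    import torch

    # Toy routing weights: the last column is the fused shared-expert slot.
    routed_scaling_factor = 2.5  # previously hard-coded; now read from the model config
    topk_weights = torch.tensor([[0.5, 0.3, 0.2, 0.0],
                                 [0.4, 0.4, 0.2, 0.0]])

    # The line from the diff: the shared expert's weight is the sum of the routed
    # weights divided by the scaling factor.
    topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
    print(topk_weights)  # last column becomes 1.0 / 2.5 = 0.4 in both rows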
sglang/srt/layers/parameter.py

@@ -7,8 +7,6 @@ from typing import Callable, Optional, Union
  import torch
  from torch.nn import Parameter

- from sglang.srt.distributed import get_tensor_model_parallel_rank
-
  __all__ = [
  "BasevLLMParameter",
  "PackedvLLMParameter",
sglang/srt/layers/quantization/__init__.py

@@ -290,6 +290,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
  apply_router_weight_on_input: bool = False,
  inplace: bool = True,
  no_combine: bool = False,
+ routed_scaling_factor: Optional[float] = None,
  ):
  assert activation == "silu"
  assert inplace and not no_combine
sglang/srt/layers/quantization/blockwise_int8.py

@@ -373,6 +373,7 @@ class BlockInt8MoEMethod:
  apply_router_weight_on_input: bool = False,
  inplace: bool = True,
  no_combine: bool = False,
+ routed_scaling_factor: Optional[float] = None,
  ) -> torch.Tensor:
  from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
  from sglang.srt.layers.moe.topk import select_experts
@@ -388,6 +389,7 @@ class BlockInt8MoEMethod:
  num_expert_group=num_expert_group,
  custom_routing_function=custom_routing_function,
  correction_bias=correction_bias,
+ routed_scaling_factor=routed_scaling_factor,
  )

  # Expert fusion with INT8 quantization
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py

@@ -1,4 +1,4 @@
- # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+ # Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
  # SPDX-License-Identifier: Apache-2.0

  import logging
@@ -39,7 +39,13 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
  is_activation_quantization_format,
  should_ignore_layer,
  )
- from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
+
+ try:
+ import vllm
+
+ VLLM_AVAILABLE = True
+ except ImportError:
+ VLLM_AVAILABLE = False

  logger = logging.getLogger(__name__)

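The removed BaseKVCacheMethod import gives way to an availability probe for the optional vllm dependency. The same guard pattern, shown standalone (require_vllm is an illustrative helper, not part of sglang; only the VLLM_AVAILABLE flag and the error message come from this diff):

    # Optional-dependency probe: attempt the import once at module load and remember the result.
    try:
        import vllm  # noqa: F401
        VLLM_AVAILABLE = True
    except ImportError:
        VLLM_AVAILABLE = False

    def require_vllm(feature: str) -> None:
        # Illustrative helper: fail fast with a clear message when a vllm-backed
        # code path (e.g. CompressedTensorsWNA16MoEMethod) is requested without vllm.
        if not VLLM_AVAILABLE:
            raise ImportError(
                f"vllm is not installed, to use {feature}, please install vllm."
            )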
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py

@@ -1,22 +1,16 @@
- # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+ # Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
  # SPDX-License-Identifier: Apache-2.0

  import enum
  import logging
  from enum import Enum
- from typing import TYPE_CHECKING, Callable, List, Optional
+ from typing import Callable, List, Optional

  import torch
  from compressed_tensors import CompressionFormat
  from compressed_tensors.quantization import QuantizationStrategy

- if TYPE_CHECKING:
- from sglang.srt.layers.moe.fused_moe_triton import (
- FusedMoE,
- FusedMoEMethodBase,
- FusedMoeWeightScaleSupported,
- )
-
+ from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
  from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
  from sglang.srt.layers.quantization.utils import (
  all_close_1d,
@@ -29,10 +23,9 @@ from sglang.srt.utils import set_weight_attrs
  _is_cuda = is_cuda()


- if _is_cuda:
- from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
- else:
+ if not _is_cuda:
  from vllm import _custom_ops as vllm_ops
+ from vllm._custom_ops import scaled_fp8_quant

  try:
  import vllm
@@ -58,8 +51,6 @@ __all__ = [

  class CompressedTensorsMoEMethod:
  def __new__(cls, *args, **kwargs):
- from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
  if cls is CompressedTensorsMoEMethod:
  return super().__new__(cls)
  return super().__new__(cls)
@@ -76,7 +67,7 @@ class CompressedTensorsMoEMethod:
  if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
  if not VLLM_AVAILABLE:
  raise ImportError(
- "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm"
+ "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm."
  )
  return CompressedTensorsWNA16MoEMethod(quant_config)
  elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
@@ -92,11 +83,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
  def __init__(
  self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
  ):
- from sglang.srt.layers.moe.fused_moe_triton import (
- FusedMoEMethodBase,
- FusedMoeWeightScaleSupported,
- )
-
  self.quant_config = quant_config
  self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
  self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
@@ -267,19 +253,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
  layer.w13_weight[expert_id][start : start + shard_size, :],
  layer.w13_weight_scale[expert_id][shard_id],
  )
+ (
+ layer.w13_weight[expert_id][start : start + shard_size, :],
+ _,
+ ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])

- if _is_cuda:
- (
- layer.w13_weight[expert_id][start : start + shard_size, :],
- _,
- ) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
- else:
- (
- layer.w13_weight[expert_id][start : start + shard_size, :],
- _,
- ) = vllm_ops.scaled_fp8_quant(
- dq_weight, max_w13_scales[expert_id]
- )
  start += shard_size

  layer.w13_weight_scale = torch.nn.Parameter(
@@ -305,6 +283,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
  inplace: bool = True,
  no_combine: bool = False,
  apply_router_weight_on_input: bool = False,
+ routed_scaling_factor: Optional[float] = None,
  ) -> torch.Tensor:
  from sglang.srt.layers.moe.fused_moe_triton import fused_experts
  from sglang.srt.layers.moe.topk import select_experts
@@ -319,6 +298,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
  num_expert_group=num_expert_group,
  custom_routing_function=custom_routing_function,
  correction_bias=correction_bias,
+ routed_scaling_factor=routed_scaling_factor,
  )

  return fused_experts(
@@ -345,11 +325,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
  def __init__(
  self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
  ):
- from sglang.srt.layers.moe.fused_moe_triton import (
- FusedMoEMethodBase,
- FusedMoeWeightScaleSupported,
- )
-
  self.quant_config = quant_config
  # TODO: @dsikka: refactor this to use schemes as other kernels
  # are supported + check if the layer is being ignored.
@@ -609,7 +584,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
  requires_grad=False,
  )

- marlin_w13_qweight = ops.gptq_marlin_moe_repack(
+ marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
  layer.w13_weight_packed,
  layer.w13_g_idx_sort_indices,
  layer.w13_weight_packed.shape[1] * self.packed_factor,
@@ -617,7 +592,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
  self.num_bits,
  )
  replace_tensor("w13_weight_packed", marlin_w13_qweight)
- marlin_w2_qweight = ops.gptq_marlin_moe_repack(
+ marlin_w2_qweight = vllm_ops.gptq_marlin_moe_repack(
  layer.w2_weight_packed,
  layer.w2_g_idx_sort_indices,
  layer.w2_weight_packed.shape[1] * self.packed_factor,
@@ -660,15 +635,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
  scoring_func: str = "softmax",
  correction_bias: Optional[torch.Tensor] = None,
  activation: str = "silu",
+ routed_scaling_factor: Optional[float] = None,
  ) -> torch.Tensor:
- from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
  from sglang.srt.layers.moe.topk import select_experts

  assert activation == "silu", "Only SiLU activation is supported."
- if not VLLM_AVAILABLE:
- raise ImportError(
- "vllm is not installed, to use fused_marlin_moe, please install vllm"
- )
  if expert_map is not None:
  raise NotImplementedError(
  "Expert Parallelism is not supported for " "fused Marlin MoE method."
@@ -685,6 +656,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
  custom_routing_function=custom_routing_function,
  scoring_func=scoring_func,
  correction_bias=correction_bias,
+ routed_scaling_factor=routed_scaling_factor,
  )

  return torch.ops.vllm.fused_marlin_moe(
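Earlier in this file the CUDA/non-CUDA branch around FP8 requantization collapses into a single scaled_fp8_quant name chosen at import time, so the weight-processing loop no longer branches per call. A hedged sketch of that aliasing pattern (the torch.cuda.is_available() check stands in for the module's own is_cuda() helper; both implementations return a (quantized_tensor, scale) pair, as the diff's call site shows):

    import torch

    # Pick one backend once; every call site then uses the same name.
    if torch.cuda.is_available():
        from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
    else:
        from vllm._custom_ops import scaled_fp8_quant

    def requantize_shard(dq_weight: torch.Tensor, max_scale: torch.Tensor) -> torch.Tensor:
        # Mirrors the diffed loop: requantize a dequantized shard to FP8 with a shared
        # max scale and discard the returned scale, keeping only the quantized weight.
        qweight, _ = scaled_fp8_quant(dq_weight, max_scale)
        return qweight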
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py

@@ -16,8 +16,7 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
  CompressedTensorsScheme,
  )
  from sglang.srt.layers.quantization.fp8_utils import (
- Fp8LinearOp,
- maybe_create_device_identity,
+ apply_fp8_linear,
  normalize_e4m3fn_to_e4m3fnuz,
  )
  from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
@@ -30,7 +29,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
  def __init__(self, strategy: str, is_static_input_scheme: bool):
  self.strategy = strategy
  self.is_static_input_scheme = is_static_input_scheme
- self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)

  @classmethod
  def get_min_capability(cls) -> int:
@@ -99,8 +97,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
  weight_loader: Callable,
  **kwargs,
  ):
- maybe_create_device_identity()
-
  output_size_per_partition = sum(output_partition_sizes)
  layer.logical_widths = output_partition_sizes

@@ -152,11 +148,12 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
  x: torch.Tensor,
  bias: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
-
- return self.fp8_linear.apply(
+ return apply_fp8_linear(
  input=x,
  weight=layer.weight,
  weight_scale=layer.weight_scale,
  input_scale=layer.input_scale,
  bias=bias,
+ use_per_token_if_dynamic=True,
+ compressed_tensor_quant=True,
  )
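With the stateful Fp8LinearOp removed, the scheme calls the functional apply_fp8_linear from sglang.srt.layers.quantization.fp8_utils and selects per-token dynamic quantization per call. A hedged sketch of the resulting forward path (the layer attributes are assumed to be the ones registered by the scheme's create_weights(); only the keyword arguments are taken from this diff):

    from typing import Optional

    import torch

    from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear

    def w8a8_fp8_forward(layer, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # weight / weight_scale / input_scale are parameters owned by the layer;
        # the two trailing flags replace the configuration previously baked into
        # Fp8LinearOp(use_per_token_if_dynamic=True).
        return apply_fp8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
            use_per_token_if_dynamic=True,
            compressed_tensor_quant=True,
        )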