sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_serving.py +3 -6
  4. sglang/compile_deep_gemm.py +136 -0
  5. sglang/lang/backend/anthropic.py +0 -4
  6. sglang/lang/backend/base_backend.py +1 -1
  7. sglang/lang/backend/openai.py +6 -2
  8. sglang/lang/backend/runtime_endpoint.py +5 -1
  9. sglang/lang/backend/vertexai.py +0 -1
  10. sglang/lang/compiler.py +1 -7
  11. sglang/lang/tracer.py +3 -7
  12. sglang/srt/_custom_ops.py +0 -2
  13. sglang/srt/configs/model_config.py +4 -1
  14. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  15. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  16. sglang/srt/constrained/xgrammar_backend.py +27 -4
  17. sglang/srt/custom_op.py +0 -62
  18. sglang/srt/disaggregation/decode.py +105 -6
  19. sglang/srt/disaggregation/mini_lb.py +74 -9
  20. sglang/srt/disaggregation/mooncake/conn.py +33 -63
  21. sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
  22. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  23. sglang/srt/disaggregation/nixl/conn.py +622 -0
  24. sglang/srt/disaggregation/prefill.py +137 -17
  25. sglang/srt/disaggregation/utils.py +32 -0
  26. sglang/srt/entrypoints/engine.py +4 -0
  27. sglang/srt/entrypoints/http_server.py +3 -7
  28. sglang/srt/entrypoints/verl_engine.py +7 -5
  29. sglang/srt/function_call_parser.py +60 -0
  30. sglang/srt/layers/activation.py +6 -8
  31. sglang/srt/layers/attention/flashattention_backend.py +883 -209
  32. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  33. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  34. sglang/srt/layers/attention/triton_backend.py +6 -0
  35. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  36. sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
  37. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  38. sglang/srt/layers/dp_attention.py +1 -1
  39. sglang/srt/layers/layernorm.py +20 -5
  40. sglang/srt/layers/linear.py +17 -3
  41. sglang/srt/layers/moe/ep_moe/layer.py +17 -29
  42. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  45. sglang/srt/layers/moe/topk.py +27 -30
  46. sglang/srt/layers/parameter.py +0 -2
  47. sglang/srt/layers/quantization/__init__.py +1 -0
  48. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  49. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  52. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  53. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  54. sglang/srt/layers/quantization/deep_gemm.py +378 -0
  55. sglang/srt/layers/quantization/fp8.py +115 -132
  56. sglang/srt/layers/quantization/fp8_kernel.py +213 -88
  57. sglang/srt/layers/quantization/fp8_utils.py +189 -264
  58. sglang/srt/layers/quantization/gptq.py +13 -7
  59. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  60. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  61. sglang/srt/layers/quantization/utils.py +5 -11
  62. sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
  63. sglang/srt/layers/quantization/w8a8_int8.py +7 -7
  64. sglang/srt/layers/radix_attention.py +15 -0
  65. sglang/srt/layers/rotary_embedding.py +9 -8
  66. sglang/srt/layers/sampler.py +7 -12
  67. sglang/srt/lora/backend/base_backend.py +18 -2
  68. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  69. sglang/srt/lora/backend/triton_backend.py +1 -1
  70. sglang/srt/lora/layers.py +1 -1
  71. sglang/srt/lora/lora.py +1 -1
  72. sglang/srt/lora/lora_manager.py +1 -1
  73. sglang/srt/managers/data_parallel_controller.py +7 -1
  74. sglang/srt/managers/detokenizer_manager.py +0 -1
  75. sglang/srt/managers/io_struct.py +15 -3
  76. sglang/srt/managers/mm_utils.py +4 -3
  77. sglang/srt/managers/multimodal_processor.py +0 -2
  78. sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
  79. sglang/srt/managers/schedule_batch.py +15 -4
  80. sglang/srt/managers/scheduler.py +28 -77
  81. sglang/srt/managers/tokenizer_manager.py +116 -29
  82. sglang/srt/managers/tp_worker.py +1 -0
  83. sglang/srt/mem_cache/hiradix_cache.py +41 -29
  84. sglang/srt/mem_cache/memory_pool.py +38 -15
  85. sglang/srt/model_executor/cuda_graph_runner.py +15 -10
  86. sglang/srt/model_executor/model_runner.py +39 -31
  87. sglang/srt/models/bert.py +398 -0
  88. sglang/srt/models/deepseek.py +1 -1
  89. sglang/srt/models/deepseek_nextn.py +74 -70
  90. sglang/srt/models/deepseek_v2.py +292 -348
  91. sglang/srt/models/llama.py +5 -5
  92. sglang/srt/models/minicpm3.py +31 -203
  93. sglang/srt/models/minicpmo.py +17 -6
  94. sglang/srt/models/qwen2.py +4 -1
  95. sglang/srt/models/qwen2_moe.py +14 -13
  96. sglang/srt/models/qwen3.py +335 -0
  97. sglang/srt/models/qwen3_moe.py +423 -0
  98. sglang/srt/openai_api/adapter.py +71 -4
  99. sglang/srt/openai_api/protocol.py +6 -1
  100. sglang/srt/reasoning_parser.py +0 -1
  101. sglang/srt/sampling/sampling_batch_info.py +2 -3
  102. sglang/srt/server_args.py +86 -72
  103. sglang/srt/speculative/build_eagle_tree.py +2 -2
  104. sglang/srt/speculative/eagle_utils.py +2 -2
  105. sglang/srt/speculative/eagle_worker.py +6 -14
  106. sglang/srt/utils.py +62 -6
  107. sglang/test/runners.py +5 -1
  108. sglang/test/test_block_fp8.py +167 -0
  109. sglang/test/test_custom_ops.py +1 -1
  110. sglang/test/test_utils.py +3 -1
  111. sglang/version.py +1 -1
  112. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
  113. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
  114. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
  115. sglang/lang/__init__.py +0 -0
  116. sglang/srt/lora/backend/__init__.py +0 -25
  117. sglang/srt/server.py +0 -18
  118. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
@@ -13,7 +13,6 @@
 # ==============================================================================
 
 import math
-import os
 from typing import Callable, Optional
 
 import torch
@@ -29,6 +28,10 @@ _is_hip = is_hip()
 if _is_cuda:
     from sgl_kernel import moe_fused_gate
 
+if _is_cuda or _is_hip:
+    from sgl_kernel import topk_softmax
+
+
 expert_distribution_recorder = ExpertDistributionRecorder()
 
 
@@ -59,11 +62,6 @@ def fused_topk(
     topk: int,
     renormalize: bool,
 ):
-    if _is_cuda or _is_hip:
-        from sgl_kernel import topk_softmax
-    else:
-        from vllm import _custom_ops as vllm_ops
-
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
     M, _ = hidden_states.shape
@@ -76,20 +74,12 @@ def fused_topk(
         M, topk, dtype=torch.int32, device=hidden_states.device
     )
 
-    if _is_cuda or _is_hip:
-        topk_softmax(
-            topk_weights,
-            topk_ids,
-            token_expert_indicies,
-            gating_output.float(),
-        )
-    else:
-        vllm_ops.topk_softmax(
-            topk_weights,
-            topk_ids,
-            token_expert_indicies,
-            gating_output.float(),
-        )
+    topk_softmax(
+        topk_weights,
+        topk_ids,
+        token_expert_indicies,
+        gating_output.float(),
+    )
     del token_expert_indicies
 
     if renormalize:
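
The hunks above hoist the topk_softmax import to module level and drop the vllm fallback, so on CUDA/HIP the sgl_kernel op is imported once and called unconditionally inside fused_topk. For orientation, a pure-PyTorch sketch of the semantics that topk_softmax plus the renormalize branch implement (a reading aid only, not the sgl_kernel implementation):

import torch

def topk_softmax_reference(gating_output: torch.Tensor, topk: int, renormalize: bool = False):
    # Softmax over the expert dimension, then keep the top-k experts per token.
    scores = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
    if renormalize:
        # fused_topk does this renormalization when renormalize=True.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids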
@@ -108,6 +98,7 @@ def grouped_topk(
     num_expert_group: int = 0,
     topk_group: int = 0,
     n_share_experts_fusion: int = 0,
+    routed_scaling_factor: Optional[float] = None,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
@@ -137,9 +128,7 @@ def grouped_topk(
             dtype=topk_ids.dtype,
             device=topk_ids.device,
         )
-        topk_weights[:, -1] = (
-            topk_weights[:, :-1].sum(dim=-1) / 2.5
-        )  # 2.5 is the routed_scaling_factor.
+        topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
 
     if renormalize:
         topk_weights_sum = (
@@ -161,6 +150,7 @@ def biased_grouped_topk_impl(
     num_expert_group: int = 0,
     topk_group: int = 0,
     n_share_experts_fusion: int = 0,
+    routed_scaling_factor: Optional[float] = None,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
@@ -197,9 +187,7 @@ def biased_grouped_topk_impl(
             dtype=topk_ids.dtype,
             device=topk_ids.device,
         )
-        topk_weights[:, -1] = (
-            topk_weights[:, :-1].sum(dim=-1) / 2.5
-        )  # 2.5 is the routed_scaling_factor.
+        topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
 
     if renormalize:
         topk_weights_sum = (
@@ -226,11 +214,16 @@ def biased_grouped_topk(
     topk_group: int = 0,
     compiled: bool = True,
     n_share_experts_fusion: int = 0,
+    routed_scaling_factor: Optional[float] = None,
 ):
+    assert (
+        routed_scaling_factor is not None
+    ), "routed_scaling_factor is required for biased_grouped_topk"
     # TODO: moe_fused_gate kernel is not supported for n_share_experts_fusion > 0 now.
     if (
         _is_cuda
-        and n_share_experts_fusion == 0
+        and gating_output.shape[1] // num_expert_group
+        <= 32  # moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
         and is_power_of_two(correction_bias.shape[0])
     ):
         return moe_fused_gate(
@@ -239,6 +232,8 @@ def biased_grouped_topk(
             num_expert_group,
             topk_group,
             topk,
+            n_share_experts_fusion,
+            routed_scaling_factor,
         )
     else:
         biased_grouped_topk_fn = (
@@ -257,6 +252,7 @@ def biased_grouped_topk(
             num_expert_group,
             topk_group,
             n_share_experts_fusion=n_share_experts_fusion,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
 
@@ -271,10 +267,9 @@ def select_experts(
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
     torch_native: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ):
-    n_share_experts_fusion = 0
-    if global_server_args_dict["n_share_experts_fusion"] is not None:
-        n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+    n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
     # DeekSeek V2/V3/R1 serices models uses grouped_top_k
     if use_grouped_topk:
         assert topk_group is not None
@@ -288,6 +283,7 @@ def select_experts(
                 num_expert_group=num_expert_group,
                 topk_group=topk_group,
                 n_share_experts_fusion=n_share_experts_fusion,
+                routed_scaling_factor=routed_scaling_factor,
             )
         else:
             topk_weights, topk_ids = biased_grouped_topk(
@@ -299,6 +295,7 @@ def select_experts(
                 num_expert_group=num_expert_group,
                 topk_group=topk_group,
                 n_share_experts_fusion=n_share_experts_fusion,
+                routed_scaling_factor=routed_scaling_factor,
             )
     elif torch_native and custom_routing_function is None:
         topk_weights, topk_ids = fused_topk_native(
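
The common thread in the topk.py hunks above is that the hard-coded divisor 2.5 becomes a routed_scaling_factor argument, which select_experts now forwards into grouped_topk / biased_grouped_topk (and, through the extra positional arguments, into the moe_fused_gate kernel). A standalone sketch of the shared-expert weight computation that changes, with a small numeric check (illustrative only; the real code updates the topk_weights tensor in place as shown in the diff):

import torch

def shared_expert_weight(topk_weights: torch.Tensor, routed_scaling_factor: float) -> torch.Tensor:
    # The last column is the fused shared-expert slot; its weight is the sum of the
    # routed weights divided by the model's routed_scaling_factor (previously the literal 2.5).
    out = topk_weights.clone()
    out[:, -1] = out[:, :-1].sum(dim=-1) / routed_scaling_factor
    return out

# Two tokens, three routed experts plus one shared-expert slot.
w = torch.tensor([[0.5, 0.3, 0.2, 0.0], [0.4, 0.4, 0.2, 0.0]])
print(shared_expert_weight(w, routed_scaling_factor=2.5))  # last column becomes 1.0 / 2.5 = 0.4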
@@ -7,8 +7,6 @@ from typing import Callable, Optional, Union
 import torch
 from torch.nn import Parameter
 
-from sglang.srt.distributed import get_tensor_model_parallel_rank
-
 __all__ = [
     "BasevLLMParameter",
     "PackedvLLMParameter",
@@ -290,6 +290,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ):
         assert activation == "silu"
         assert inplace and not no_combine
@@ -373,6 +373,7 @@ class BlockInt8MoEMethod:
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -388,6 +389,7 @@ class BlockInt8MoEMethod:
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
         # Expert fusion with INT8 quantization
@@ -1,4 +1,4 @@
-# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
 # SPDX-License-Identifier: Apache-2.0
 
 import logging
@@ -33,13 +33,20 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe im
 from sglang.srt.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
     CompressedTensorsW8A8Fp8,
+    CompressedTensorsW8A16Fp8,
 )
 from sglang.srt.layers.quantization.compressed_tensors.utils import (
     find_matched_target,
     is_activation_quantization_format,
     should_ignore_layer,
 )
-from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
+
+try:
+    import vllm
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    VLLM_AVAILABLE = False
 
 logger = logging.getLogger(__name__)
 
@@ -1,22 +1,16 @@
-# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
 # SPDX-License-Identifier: Apache-2.0
 
 import enum
 import logging
 from enum import Enum
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import Callable, List, Optional
 
 import torch
 from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import QuantizationStrategy
 
-if TYPE_CHECKING:
-    from sglang.srt.layers.moe.fused_moe_triton import (
-        FusedMoE,
-        FusedMoEMethodBase,
-        FusedMoeWeightScaleSupported,
-    )
-
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
@@ -29,10 +23,9 @@ from sglang.srt.utils import set_weight_attrs
 _is_cuda = is_cuda()
 
-if _is_cuda:
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-else:
+if not _is_cuda:
     from vllm import _custom_ops as vllm_ops
+    from vllm._custom_ops import scaled_fp8_quant
 
 try:
     import vllm
 
@@ -58,8 +51,6 @@ __all__ = [
 
 class CompressedTensorsMoEMethod:
     def __new__(cls, *args, **kwargs):
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
         if cls is CompressedTensorsMoEMethod:
             return super().__new__(cls)
         return super().__new__(cls)
@@ -76,7 +67,7 @@ class CompressedTensorsMoEMethod:
         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
             if not VLLM_AVAILABLE:
                 raise ImportError(
-                    "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm"
+                    "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm."
                 )
             return CompressedTensorsWNA16MoEMethod(quant_config)
         elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
@@ -92,11 +83,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
     ):
-        from sglang.srt.layers.moe.fused_moe_triton import (
-            FusedMoEMethodBase,
-            FusedMoeWeightScaleSupported,
-        )
-
         self.quant_config = quant_config
         self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
@@ -267,19 +253,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                         layer.w13_weight[expert_id][start : start + shard_size, :],
                         layer.w13_weight_scale[expert_id][shard_id],
                     )
+                    (
+                        layer.w13_weight[expert_id][start : start + shard_size, :],
+                        _,
+                    ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
 
-                    if _is_cuda:
-                        (
-                            layer.w13_weight[expert_id][start : start + shard_size, :],
-                            _,
-                        ) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                    else:
-                        (
-                            layer.w13_weight[expert_id][start : start + shard_size, :],
-                            _,
-                        ) = vllm_ops.scaled_fp8_quant(
-                            dq_weight, max_w13_scales[expert_id]
-                        )
                     start += shard_size
 
             layer.w13_weight_scale = torch.nn.Parameter(
@@ -305,6 +283,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         inplace: bool = True,
         no_combine: bool = False,
         apply_router_weight_on_input: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -319,6 +298,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
         return fused_experts(
@@ -345,11 +325,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
     ):
-        from sglang.srt.layers.moe.fused_moe_triton import (
-            FusedMoEMethodBase,
-            FusedMoeWeightScaleSupported,
-        )
-
         self.quant_config = quant_config
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
@@ -609,7 +584,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             requires_grad=False,
         )
 
-        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
+        marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
             layer.w13_weight_packed,
             layer.w13_g_idx_sort_indices,
             layer.w13_weight_packed.shape[1] * self.packed_factor,
@@ -617,7 +592,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             self.num_bits,
         )
         replace_tensor("w13_weight_packed", marlin_w13_qweight)
-        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
+        marlin_w2_qweight = vllm_ops.gptq_marlin_moe_repack(
            layer.w2_weight_packed,
            layer.w2_g_idx_sort_indices,
            layer.w2_weight_packed.shape[1] * self.packed_factor,
@@ -660,15 +635,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         scoring_func: str = "softmax",
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
         from sglang.srt.layers.moe.topk import select_experts
 
         assert activation == "silu", "Only SiLU activation is supported."
-        if not VLLM_AVAILABLE:
-            raise ImportError(
-                "vllm is not installed, to use fused_marlin_moe, please install vllm"
-            )
         if expert_map is not None:
             raise NotImplementedError(
                 "Expert Parallelism is not supported for " "fused Marlin MoE method."
@@ -685,6 +656,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
         return torch.ops.vllm.fused_marlin_moe(
@@ -2,8 +2,10 @@
 
 from .compressed_tensors_scheme import CompressedTensorsScheme
 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
+from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
 
 __all__ = [
     "CompressedTensorsScheme",
     "CompressedTensorsW8A8Fp8",
+    "CompressedTensorsW8A16Fp8",
 ]
@@ -0,0 +1,153 @@
+# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Callable, List, Optional
+
+import torch
+from compressed_tensors.quantization import QuantizationStrategy
+
+from sglang.srt.layers.parameter import (
+    ChannelQuantScaleParameter,
+    ModelWeightParameter,
+    PerTensorScaleParameter,
+)
+from sglang.srt.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme,
+)
+from sglang.srt.layers.quantization.utils import convert_to_channelwise
+
+try:
+    from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+        apply_fp8_marlin_linear,
+        prepare_fp8_layer_for_marlin,
+    )
+
+    MARLIN_FP8_AVAILABLE = True
+except ImportError:
+    MARLIN_FP8_AVAILABLE = False
+
+    def apply_fp8_marlin_linear(*args, **kwargs):
+        raise ImportError("vllm is not installed")
+
+    def prepare_fp8_layer_for_marlin(*args, **kwargs):
+        raise ImportError("vllm is not installed")
+
+
+__all__ = ["CompressedTensorsW8A16Fp8"]
+
+SUPPORTED_STRATEGIES = [QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR]
+
+
+class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
+    def __init__(self, strategy: str, is_static_input_scheme: bool):
+        self.strategy = strategy
+        self.is_static_input_scheme = is_static_input_scheme
+
+        if not MARLIN_FP8_AVAILABLE:
+            raise ImportError(
+                "vllm is not installed. To use CompressedTensorsW8A16Fp8, please install vllm"
+            )
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        # ampere and up
+        return 80
+
+    # W8A8-Fp8 kernels support only per-tensor and per-channel cases.
+    # So if we have a fused module (QKV, MLP) with per tensor scales,
+    # we expand each scale to its shard's channels.
+    def process_weights_after_loading(self, layer) -> None:
+        if self.strategy == QuantizationStrategy.TENSOR:
+            ws_channelwise = convert_to_channelwise(
+                layer.weight_scale, layer.logical_widths
+            )
+            layer.weight_scale = torch.nn.Parameter(ws_channelwise, requires_grad=False)
+        else:
+            # required by torch.compile to be torch.nn.Parameter
+            layer.weight_scale = torch.nn.Parameter(
+                layer.weight_scale.data, requires_grad=False
+            )
+
+        # Weights must be transposed for marlin
+        layer.weight = torch.nn.Parameter(layer.weight.t(), requires_grad=False)
+
+        if self.is_static_input_scheme:
+            # required by torch.compile to be torch.nn.Parameter
+            layer.input_scale = torch.nn.Parameter(
+                layer.input_scale.data, requires_grad=False
+            )
+        prepare_fp8_layer_for_marlin(layer, strategy="channel")
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size: int,
+        output_partition_sizes: List[int],
+        input_size_per_partition: int,
+        params_dtype: torch.dtype,
+        weight_loader: Callable,
+        **kwargs,
+    ):
+        output_size_per_partition = sum(output_partition_sizes)
+        layer.logical_widths = output_partition_sizes
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+        layer.orig_dtype = params_dtype
+
+        # WEIGHT
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition,
+                dtype=torch.float8_e4m3fn,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        # WEIGHT SCALE
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            weight_scale = ChannelQuantScaleParameter(
+                data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
+                output_dim=0,
+                weight_loader=weight_loader,
+            )
+        elif self.strategy == QuantizationStrategy.TENSOR:
+            weight_scale = PerTensorScaleParameter(
+                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+        else:
+            raise ValueError(
+                f"Unsupported weight strategy={self.strategy}, "
+                f"supported strategies are {SUPPORTED_STRATEGIES}"
+            )
+
+        weight_scale[:] = torch.finfo(torch.float32).min
+        layer.register_parameter("weight_scale", weight_scale)
+
+        # INPUT SCALE (to deal with converted checkpoints)
+        if self.is_static_input_scheme:
+            input_scale = PerTensorScaleParameter(
+                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("input_scale", input_scale)
+
+    def apply_weights(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return apply_fp8_marlin_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            workspace=layer.workspace,
+            size_n=layer.output_size_per_partition,
+            size_k=layer.input_size_per_partition,
+            bias=bias,
+        )
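
The comment above process_weights_after_loading captures the core of the new W8A16 scheme: when a fused module (e.g. QKV) was checkpointed with one FP8 scale per shard, each scale is broadcast across that shard's output channels so a single channelwise Marlin path can serve the whole fused weight. A hedged sketch of that expansion (standalone illustration; the scheme itself relies on convert_to_channelwise from sglang.srt.layers.quantization.utils):

import torch

def expand_per_tensor_scales(per_shard_scales: torch.Tensor, logical_widths: list) -> torch.Tensor:
    # One scale per fused shard -> one scale per output channel.
    channelwise = torch.empty(sum(logical_widths), 1, dtype=per_shard_scales.dtype)
    start = 0
    for scale, width in zip(per_shard_scales, logical_widths):
        channelwise[start : start + width] = scale
        start += width
    return channelwise

# Example: QKV fused with output widths 4096 / 1024 / 1024 and three per-shard scales.
scales = expand_per_tensor_scales(torch.tensor([0.02, 0.03, 0.01]), [4096, 1024, 1024])
print(scales.shape)  # torch.Size([6144, 1])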
@@ -16,8 +16,7 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
 )
 from sglang.srt.layers.quantization.fp8_utils import (
-    Fp8LinearOp,
-    maybe_create_device_identity,
+    apply_fp8_linear,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
@@ -30,7 +29,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
     def __init__(self, strategy: str, is_static_input_scheme: bool):
         self.strategy = strategy
         self.is_static_input_scheme = is_static_input_scheme
-        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -99,8 +97,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         weight_loader: Callable,
         **kwargs,
     ):
-        maybe_create_device_identity()
-
         output_size_per_partition = sum(output_partition_sizes)
         layer.logical_widths = output_partition_sizes
 
@@ -152,11 +148,12 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
-        return self.fp8_linear.apply(
+        return apply_fp8_linear(
             input=x,
             weight=layer.weight,
             weight_scale=layer.weight_scale,
             input_scale=layer.input_scale,
             bias=bias,
+            use_per_token_if_dynamic=True,
+            compressed_tensor_quant=True,
         )
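
The scheme no longer holds a stateful Fp8LinearOp; it calls the module-level apply_fp8_linear helper and passes use_per_token_if_dynamic (plus the new compressed_tensor_quant flag) at the call site. As background for what that path computes, a rough pure-PyTorch reference of a dynamic per-token W8A8 FP8 linear (a conceptual sketch under the usual e4m3 scaling convention, not sglang's fused kernel; requires a PyTorch build with float8 support):

import torch

def fp8_linear_reference(x, weight_fp8, weight_scale, bias=None):
    # Dynamic per-token activation quantization: one scale per row of x.
    finfo = torch.finfo(torch.float8_e4m3fn)
    x_scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / finfo.max
    x_fp8 = (x / x_scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    # Matmul in higher precision, then fold both scales back in
    # (weight_fp8 is [N, K], weight_scale is per output channel with shape [N, 1]).
    out = (x_fp8.float() @ weight_fp8.float().t()) * x_scale * weight_scale.t()
    if bias is not None:
        out = out + bias
    return out.to(x.dtype)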