sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203)
  1. sglang/bench_one_batch.py +0 -7
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +25 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -2
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +29 -4
  24. sglang/srt/entrypoints/http_server.py +76 -0
  25. sglang/srt/entrypoints/openai/protocol.py +4 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +23 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +10 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +14 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
  37. sglang/srt/layers/attention/triton_backend.py +109 -73
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +58 -10
  46. sglang/srt/layers/dp_attention.py +137 -27
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +16 -18
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  71. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  72. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  73. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  75. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  76. sglang/srt/layers/moe/router.py +15 -9
  77. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  78. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  79. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  80. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  81. sglang/srt/layers/moe/topk.py +167 -83
  82. sglang/srt/layers/moe/utils.py +159 -18
  83. sglang/srt/layers/multimodal.py +156 -40
  84. sglang/srt/layers/quantization/__init__.py +18 -46
  85. sglang/srt/layers/quantization/awq.py +22 -23
  86. sglang/srt/layers/quantization/base_config.py +2 -6
  87. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  88. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
  89. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  90. sglang/srt/layers/quantization/fp8.py +127 -119
  91. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  92. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  93. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  94. sglang/srt/layers/quantization/gptq.py +17 -21
  95. sglang/srt/layers/quantization/marlin_utils.py +26 -8
  96. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  97. sglang/srt/layers/quantization/modelopt_quant.py +217 -98
  98. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  99. sglang/srt/layers/quantization/mxfp4.py +222 -39
  100. sglang/srt/layers/quantization/quark/quark.py +390 -0
  101. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  102. sglang/srt/layers/quantization/unquant.py +34 -70
  103. sglang/srt/layers/quantization/utils.py +77 -2
  104. sglang/srt/layers/quantization/w4afp8.py +7 -8
  105. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  106. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  107. sglang/srt/layers/radix_attention.py +6 -0
  108. sglang/srt/layers/rotary_embedding.py +1 -0
  109. sglang/srt/layers/sampler.py +5 -2
  110. sglang/srt/lora/layers.py +6 -2
  111. sglang/srt/lora/lora_manager.py +21 -22
  112. sglang/srt/lora/lora_registry.py +3 -3
  113. sglang/srt/lora/mem_pool.py +26 -24
  114. sglang/srt/lora/utils.py +10 -12
  115. sglang/srt/managers/cache_controller.py +80 -19
  116. sglang/srt/managers/detokenizer_manager.py +10 -2
  117. sglang/srt/managers/io_struct.py +23 -0
  118. sglang/srt/managers/mm_utils.py +1 -1
  119. sglang/srt/managers/schedule_batch.py +22 -48
  120. sglang/srt/managers/scheduler.py +28 -20
  121. sglang/srt/managers/session_controller.py +1 -1
  122. sglang/srt/managers/template_manager.py +7 -5
  123. sglang/srt/managers/tokenizer_manager.py +88 -39
  124. sglang/srt/managers/tp_worker.py +1 -0
  125. sglang/srt/managers/utils.py +59 -1
  126. sglang/srt/mem_cache/allocator.py +10 -157
  127. sglang/srt/mem_cache/allocator_ascend.py +147 -0
  128. sglang/srt/mem_cache/chunk_cache.py +1 -1
  129. sglang/srt/mem_cache/hicache_storage.py +14 -4
  130. sglang/srt/mem_cache/memory_pool.py +3 -3
  131. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  132. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  133. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  134. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  135. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  136. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  137. sglang/srt/model_executor/cuda_graph_runner.py +33 -33
  138. sglang/srt/model_executor/forward_batch_info.py +11 -10
  139. sglang/srt/model_executor/model_runner.py +93 -78
  140. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  141. sglang/srt/model_loader/loader.py +24 -6
  142. sglang/srt/models/dbrx.py +12 -6
  143. sglang/srt/models/deepseek.py +2 -1
  144. sglang/srt/models/deepseek_nextn.py +5 -2
  145. sglang/srt/models/deepseek_v2.py +226 -223
  146. sglang/srt/models/ernie4.py +2 -2
  147. sglang/srt/models/glm4_moe.py +27 -65
  148. sglang/srt/models/glm4_moe_nextn.py +2 -1
  149. sglang/srt/models/glm4v.py +52 -1
  150. sglang/srt/models/glm4v_moe.py +8 -11
  151. sglang/srt/models/gpt_oss.py +41 -76
  152. sglang/srt/models/granitemoe.py +0 -1
  153. sglang/srt/models/grok.py +376 -48
  154. sglang/srt/models/interns1.py +12 -47
  155. sglang/srt/models/internvl.py +6 -51
  156. sglang/srt/models/llama.py +10 -2
  157. sglang/srt/models/llama4.py +18 -7
  158. sglang/srt/models/minicpm3.py +0 -1
  159. sglang/srt/models/mixtral.py +0 -2
  160. sglang/srt/models/nemotron_nas.py +435 -0
  161. sglang/srt/models/olmoe.py +0 -1
  162. sglang/srt/models/phi4mm.py +3 -21
  163. sglang/srt/models/qwen2.py +2 -2
  164. sglang/srt/models/qwen2_5_vl.py +2 -0
  165. sglang/srt/models/qwen2_moe.py +23 -23
  166. sglang/srt/models/qwen3.py +2 -2
  167. sglang/srt/models/qwen3_classification.py +84 -0
  168. sglang/srt/models/qwen3_moe.py +27 -43
  169. sglang/srt/models/step3_vl.py +8 -3
  170. sglang/srt/models/xverse_moe.py +11 -5
  171. sglang/srt/multimodal/processors/base_processor.py +3 -3
  172. sglang/srt/multimodal/processors/internvl.py +7 -2
  173. sglang/srt/multimodal/processors/llava.py +11 -7
  174. sglang/srt/offloader.py +433 -0
  175. sglang/srt/operations.py +22 -2
  176. sglang/srt/reasoning_parser.py +4 -3
  177. sglang/srt/sampling/sampling_batch_info.py +7 -4
  178. sglang/srt/server_args.py +264 -105
  179. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
  180. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  181. sglang/srt/speculative/eagle_utils.py +36 -13
  182. sglang/srt/speculative/eagle_worker.py +56 -3
  183. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  184. sglang/srt/two_batch_overlap.py +20 -19
  185. sglang/srt/utils.py +68 -70
  186. sglang/test/runners.py +8 -5
  187. sglang/test/test_block_fp8.py +5 -6
  188. sglang/test/test_block_fp8_ep.py +13 -19
  189. sglang/test/test_cutlass_moe.py +4 -6
  190. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  191. sglang/test/test_fp4_moe.py +4 -3
  192. sglang/test/test_marlin_moe.py +1 -1
  193. sglang/test/test_marlin_utils.py +1 -1
  194. sglang/test/test_utils.py +7 -0
  195. sglang/utils.py +0 -1
  196. sglang/version.py +1 -1
  197. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
  198. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
  199. sglang/srt/layers/quantization/fp4.py +0 -557
  200. sglang/srt/layers/quantization/scalar_type.py +0 -352
  201. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  202. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  203. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/layers/quantization/gptq.py
+++ b/sglang/srt/layers/quantization/gptq.py
@@ -36,30 +36,27 @@ from sglang.srt.layers.quantization.marlin_utils import (
     marlin_zero_points,
     verify_marlin_supported,
 )
-from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
 from sglang.srt.layers.quantization.utils import (
     get_linear_quant_method,
+    get_scalar_types,
     replace_parameter,
     unpack_cols,
 )

 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput

-try:
-    from vllm import _custom_ops as ops
-except ImportError:
-    ops = None
-
 from sglang.srt.utils import is_cuda

 _is_cuda = is_cuda()

 if _is_cuda:
-    from sgl_kernel import fused_marlin_moe
+    from sgl_kernel import fused_marlin_moe, gptq_gemm, gptq_marlin_repack, gptq_shuffle


 logger = logging.getLogger(__name__)
+ScalarType, scalar_types = get_scalar_types()


 def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool:
@@ -85,9 +82,7 @@ def gptq_marlin_moe_repack(
         dtype=b_q_weight.dtype,
     )
     for e in range(num_experts):
-        output[e] = torch.ops.sgl_kernel.gptq_marlin_repack(
-            b_q_weight[e], perm[e], size_k, size_n, num_bits
-        )
+        output[e] = gptq_marlin_repack(b_q_weight[e], perm[e], size_k, size_n, num_bits)
     return output


@@ -204,11 +199,12 @@ class GPTQConfig(QuantizationConfig):
         from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

-        if isinstance(layer, LinearBase):
-            return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
-        elif isinstance(layer, FusedMoE):
+        if isinstance(layer, FusedMoE):
             raise TypeError("GPTQ Method does not support MoE, please use gptq_marlin")
-        return None
+        else:
+            return get_linear_quant_method(
+                self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
+            )


 class GPTQMarlinConfig(QuantizationConfig):
@@ -530,7 +526,7 @@ class GPTQLinearMethod(LinearMethodBase):
             layer.g_idx.data = torch.empty(
                 (0,), dtype=torch.int, device=layer.g_idx.device
             )
-        ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits)
+        gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits)

     def apply(
         self,
@@ -541,7 +537,7 @@ class GPTQLinearMethod(LinearMethodBase):
         out_shape = x.shape[:-1] + (layer.qweight.shape[-1],)
         reshaped_x = x.reshape(-1, x.shape[-1])

-        output = ops.gptq_gemm(
+        output = gptq_gemm(
             reshaped_x,
             layer.qweight,
             layer.qzeros,
@@ -726,7 +722,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
         def transform_w_q(x):
             assert isinstance(x, BasevLLMParameter)
             permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
-            x.data = torch.ops.sgl_kernel.gptq_marlin_repack(
+            x.data = gptq_marlin_repack(
                 x.data.contiguous(),
                 perm=layer.g_idx_sort_indices,
                 size_k=c.partition_weight_shape[0],
@@ -1061,13 +1057,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        **kwargs,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         # Delay the import to avoid circular dependency

-        assert activation == "silu", "Only SiLU activation is supported."
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."

         # The input must currently be float16
         orig_dtype = x.dtype
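Taken together, the gptq.py hunks above drop the vLLM _custom_ops fallback in favor of direct sgl_kernel imports (gptq_gemm, gptq_shuffle, gptq_marlin_repack) and pass MoE runtime options through a single MoeRunnerConfig object instead of loose keyword arguments. The snippet below is only an illustrative sketch of that config shape: the real class lives in sglang.srt.layers.moe.moe_runner, and the stand-in models just the two fields these hunks actually read.

from dataclasses import dataclass


@dataclass
class MoeRunnerConfigSketch:
    # Hypothetical stand-in for sglang.srt.layers.moe.moe_runner.MoeRunnerConfig;
    # only the two fields read by the hunks above are modeled here.
    activation: str = "silu"
    apply_router_weight_on_input: bool = False


cfg = MoeRunnerConfigSketch()
# GPTQMarlinMoEMethod.apply() now receives such a config instead of
# activation="silu" plus **kwargs, and asserts that the activation is SiLU.
assert cfg.activation == "silu"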
--- a/sglang/srt/layers/quantization/marlin_utils.py
+++ b/sglang/srt/layers/quantization/marlin_utils.py
@@ -19,20 +19,31 @@ from sglang.srt.layers.quantization.base_config import (
     LinearMethodBase,
     QuantizationConfig,
 )
-from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
-from sglang.srt.layers.quantization.utils import pack_cols, unpack_cols
-from sglang.srt.utils import get_device_capability
+from sglang.srt.layers.quantization.utils import (
+    get_scalar_types,
+    pack_cols,
+    unpack_cols,
+)
+from sglang.srt.utils import get_device_capability, is_cuda

 if TYPE_CHECKING:
     from sglang.srt.layers.linear import LinearBase
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

 try:
     from vllm import _custom_ops as ops
 except ImportError:
     ops = None

+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sgl_kernel import gptq_marlin_gemm
+
 logger = logging.getLogger(__name__)

+ScalarType, scalar_types = get_scalar_types()
+
 GPTQ_MARLIN_TILE = 16
 GPTQ_MARLIN_MIN_THREAD_N = 64
 GPTQ_MARLIN_MIN_THREAD_K = 128
@@ -206,13 +217,13 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
     )[0]


-def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
+def check_moe_marlin_supports_layer(layer: FusedMoE, group_size: int) -> bool:
     hidden_size = layer.hidden_size
     intermediate_size_per_partition = layer.intermediate_size_per_partition
     # apply_router_weight_on_input is not supported for moe marlin
-    supports_router_weight = not layer.apply_router_weight_on_input
+    supports_router_weight = not layer.moe_runner_config.apply_router_weight_on_input
     # moe marlin requires the activation to be silu
-    supports_activation = layer.activation == "silu"
+    supports_activation = layer.moe_runner_config.activation == "silu"

     # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
     # down: (n, k) = (hidden_size, intermediate_size_per_partition)
@@ -295,6 +306,13 @@ def marlin_permute_scales(
     return s


+def marlin_permute_bias(s: torch.Tensor) -> torch.Tensor:
+    origin_shape = s.shape
+    _, scale_perm_single = get_scale_perms()
+    s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
+    return s.reshape(*origin_shape).contiguous()
+
+
 def marlin_moe_permute_scales(
     s: torch.Tensor,
     size_k: int,
@@ -453,7 +471,7 @@ def apply_gptq_marlin_linear(
         dtype=input.dtype,
     )

-    output = ops.gptq_marlin_gemm(
+    output = gptq_marlin_gemm(
         reshaped_x,
         None,
         weight,
@@ -504,7 +522,7 @@ def apply_awq_marlin_linear(
         dtype=input.dtype,
     )

-    output = ops.gptq_marlin_gemm(
+    output = gptq_marlin_gemm(
         reshaped_x,
         None,
         weight,
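The marlin_utils.py hunks also introduce marlin_permute_bias, which reorders a per-output-channel bias using the same single-column permutation as the channel-wise scales (get_scale_perms()), so the bias matches the Marlin weight layout. An illustrative usage sketch follows, assuming sglang 0.5.1 imports cleanly on the host and the bias length is a multiple of the permutation width; it is not part of the diff.

import torch

from sglang.srt.layers.quantization.marlin_utils import marlin_permute_bias

# Per-output-channel bias for a layer with 4096 output features.
bias = torch.randn(4096, dtype=torch.half)
permuted = marlin_permute_bias(bias)  # same shape and dtype, channels reordered
assert permuted.shape == bias.shape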
--- /dev/null
+++ b/sglang/srt/layers/quantization/marlin_utils_fp8.py
@@ -0,0 +1,352 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+from typing import Optional
+
+import torch
+
+from sglang.srt.layers.quantization.marlin_utils import (
+    USE_FP32_REDUCE_DEFAULT,
+    marlin_make_workspace,
+    marlin_permute_bias,
+    marlin_permute_scales,
+    should_use_atomic_add_reduce,
+)
+from sglang.srt.layers.quantization.utils import get_scalar_types
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+if _is_cuda:
+    from sgl_kernel import gptq_marlin_gemm, gptq_marlin_repack
+
+ScalarType, scalar_types = get_scalar_types()
+
+logger = logging.getLogger(__name__)
+
+
+def fp8_fused_exponent_bias_into_scales(scales):
+    fp8_exponent = 4
+    if scales.dtype == torch.half:
+        target_exponent = 5
+    elif scales.dtype == torch.bfloat16:
+        target_exponent = 8
+    # exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8
+    # exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120
+    exponent_bias = 2 ** (target_exponent - 1) - 2 ** (fp8_exponent - 1)
+    s = torch.ones_like(scales) * 2
+    s = s**exponent_bias
+    return scales * s
+
+
+def apply_fp8_marlin_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    workspace: torch.Tensor,
+    size_n: int,
+    size_k: int,
+    bias: Optional[torch.Tensor],
+    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
+) -> torch.Tensor:
+    # For GPUs that lack FP8 hardware support, we can leverage the
+    # Marlin kernel for fast weight-only FP8 quantization
+
+    reshaped_x = input.reshape(-1, input.shape[-1])
+    out_shape = input.shape[:-1] + (size_n,)
+
+    use_atomic_add = should_use_atomic_add_reduce(
+        m=reshaped_x.size(0), n=size_n, k=size_k, device=input.device, dtype=input.dtype
+    )
+
+    output = gptq_marlin_gemm(
+        a=reshaped_x,
+        c=None,
+        b_q_weight=weight,
+        b_bias=bias,
+        b_scales=weight_scale,
+        global_scale=None,
+        b_zeros=None,
+        g_idx=None,
+        perm=None,
+        workspace=workspace,
+        b_q_type=scalar_types.float8_e4m3fn,
+        size_m=reshaped_x.size(0),
+        size_n=size_n,
+        size_k=size_k,
+        use_atomic_add=use_atomic_add,
+        use_fp32_reduce=use_fp32_reduce,
+    )
+
+    return output.reshape(out_shape)
+
+
+def prepare_fp8_layer_for_marlin(
+    layer: torch.nn.Module, size_k_first: bool = True
+) -> None:
+    logger.warning_once(
+        "Your GPU does not have native support for FP8 computation but "
+        "FP8 quantization is being used. Weight-only FP8 compression will "
+        "be used leveraging the Marlin kernel. This may degrade "
+        "performance for compute-heavy workloads."
+    )
+
+    part_size_n = layer.output_size_per_partition
+    part_size_k = layer.input_size_per_partition
+    weight_block_size = getattr(layer, "weight_block_size", None)
+
+    if size_k_first:
+        assert layer.weight.shape == (part_size_k, part_size_n)
+    else:
+        assert layer.weight.shape == (part_size_n, part_size_k)
+
+    device = layer.weight.device
+
+    # WORKSPACE
+    layer.workspace = marlin_make_workspace(device)
+
+    # WEIGHT
+    # Repack weights to marlin format
+    perm = torch.empty(0, dtype=torch.int, device=device)
+    qweight = pack_fp8_to_int32(layer.weight, size_k_first)
+    if not size_k_first:
+        qweight = qweight.T.contiguous()
+
+    marlin_qweight = gptq_marlin_repack(
+        b_q_weight=qweight,
+        perm=perm,
+        size_k=part_size_k,
+        size_n=part_size_n,
+        num_bits=8,
+    )
+    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
+
+    # WEIGHT SCALES
+    # Permute scales
+    if "weight_scale" in dir(layer):
+        scales = layer.weight_scale.to(layer.orig_dtype)
+    elif "weight_scale_inv" in dir(layer):
+        scales = layer.weight_scale_inv.to(layer.orig_dtype)
+        del layer.weight_scale_inv
+
+    group_size = -1 if weight_block_size is None else weight_block_size[1]
+
+    # marlin kernel only support channel-wise and group-wise quantization
+    # we need to convert the scales
+    if weight_block_size is None:
+        if scales.nelement() == 1:
+            # tensor-wise quantization -> channel-wise quantization
+            # (1, 1) =>(repeat)=> (1, size_n)
+            scales = scales.view(1, 1).repeat_interleave(part_size_n, 1)
+        elif scales.nelement() > 1 and scales.nelement() != part_size_n:
+            assert part_size_n % scales.nelement() == 0
+            s_size = scales.nelement()
+            # tensor-wise quantization (for gate-up proj)
+            # -> channel-wise quantization
+            # (1, s_size) =>(repeat)=> (1, size_n)
+            scales = scales.view(1, s_size)
+            scales = scales.repeat_interleave(part_size_n // s_size, 1)
+        else:
+            # channel-wise quantization
+            # (1, size_n)
+            scales = scales.view(1, part_size_n)
+    else:
+        # block-wise quantization -> group-wise quantization
+        # (size_k // block_size[1], ceil(size_n / block_size[0]))
+        # =>(repeat)=> (size_k // block_size[1], size_n)
+        if not size_k_first:
+            scales = scales.T.contiguous()
+        block_n = weight_block_size[0]
+        scales = scales.repeat_interleave(block_n, 1)
+        # size_n may not divisible by block_size[0]
+        scales = scales[:, :part_size_n]
+
+    marlin_scales = marlin_permute_scales(
+        s=scales, size_k=part_size_k, size_n=part_size_n, group_size=group_size
+    )
+    marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
+    layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
+
+    if hasattr(layer, "bias") and layer.bias is not None:
+        assert layer.bias.shape == (part_size_n,)
+        bias = marlin_permute_bias(layer.bias)
+        layer.bias = torch.nn.Parameter(bias, requires_grad=False)
+
+
+def prepare_moe_fp8_layer_for_marlin(
+    layer: torch.nn.Module, size_k_first: bool = True
+) -> None:
+    logger.warning_once(
+        "Your GPU does not have native support for FP8 computation but "
+        "FP8 quantization is being used. Weight-only FP8 compression will "
+        "be used leveraging the Marlin kernel. This may degrade "
+        "performance for compute-heavy workloads."
+    )
+
+    e = layer.num_experts
+    k = layer.hidden_size
+    n = layer.intermediate_size_per_partition
+    weight_block_size = getattr(layer, "weight_block_size", None)
+
+    # WORKSPACE
+    device = layer.w13_weight.device
+    layer.workspace = marlin_make_workspace(device, 4)
+    perm = torch.empty(0, dtype=torch.int, device=device)
+
+    # WEIGHT
+    # Repack weights to marlin format
+    for name in ["w13_weight", "w2_weight"]:
+        weight = getattr(layer, name)
+        tensor_list = []
+        if "w13" in name:
+            size_n, size_k = n * 2, k
+        else:
+            size_n, size_k = k, n
+
+        if size_k_first:
+            assert weight.shape == (e, size_k, size_n)
+        else:
+            assert weight.shape == (e, size_n, size_k)
+
+        for i in range(e):
+            qweight = pack_fp8_to_int32(weight[i], size_k_first)
+            if not size_k_first:
+                qweight = qweight.T.contiguous()
+
+            marlin_qweight = gptq_marlin_repack(
+                b_q_weight=qweight, perm=perm, size_k=size_k, size_n=size_n, num_bits=8
+            )
+            tensor_list.append(marlin_qweight)
+
+        weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
+        weight = torch.nn.Parameter(weight, requires_grad=False)
+
+        setattr(layer, name, weight)
+
+    # WEIGHT SCALES
+    # Permute scales
+    group_size = -1 if weight_block_size is None else weight_block_size[1]
+
+    for name in ["w13", "w2"]:
+        if name + "_weight_scale" in dir(layer):
+            new_name = name + "_weight_scale"
+            scales = getattr(layer, new_name).to(layer.orig_dtype)
+            delattr(layer, new_name)
+        elif name + "_weight_scale_inv" in dir(layer):
+            new_name = name + "_weight_scale_inv"
+            scales = getattr(layer, new_name).to(layer.orig_dtype)
+            delattr(layer, new_name)
+
+        tensor_list = []
+        if "w13" in name:
+            size_n, size_k = n * 2, k
+        else:
+            size_n, size_k = k, n
+
+        # marlin kernel only support channel-wise and group-wise quantization
+        # we need to convert the scales
+        if weight_block_size is None:
+            if scales.nelement() == e:
+                # tensor-wise quantization -> channel-wise quantization
+                # (e, 1, 1) =>(repeat)=> (e, 1, size_n)
+                scales = scales.view(e, 1, 1).repeat_interleave(size_n, 2)
+            elif scales.nelement() > e and scales.nelement() != e * size_n:
+                assert (e * size_n) % scales.nelement() == 0
+                s_size = scales.nelement() // e
+                # tensor-wise quantization (for gate-up proj)
+                # -> channel-wise quantization
+                # (e, 1, s_size) =>(repeat)=> (e, 1, size_n)
+                scales = scales.view(e, 1, s_size)
+                scales = scales.repeat_interleave(size_n // s_size, 2)
+            else:
+                # channel-wise quantization
+                # (e, 1, size_n)
+                scales = scales.view(e, 1, size_n)
+        else:
+            # block-wise quantization -> group-wise quantization
+            # (e, size_k // block_size[1], ceil(size_n / block_size[0]))
+            # =>(repeat)=> (e, size_k // block_size[1], size_n)
+            if not size_k_first:
+                scales = scales.permute(0, 2, 1)
+            block_n = weight_block_size[0]
+            scales = scales.repeat_interleave(block_n, 2)
+            # size_n may not divisible by block_size[0]
+            scales = scales[..., :size_n].contiguous()
+
+        for i in range(e):
+            marlin_scales = marlin_permute_scales(
+                s=scales[i], size_k=size_k, size_n=size_n, group_size=group_size
+            )
+            tensor_list.append(marlin_scales)
+
+        scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
+        scales = fp8_fused_exponent_bias_into_scales(scales)
+        scales = torch.nn.Parameter(scales, requires_grad=False)
+
+        setattr(layer, name + "_weight_scale", scales)
+
+    # BIAS
+    # Permute bias
+    for name in ["w13_bias", "w2_bias"]:
+        if not hasattr(layer, name):
+            continue
+        bias = getattr(layer, name).to(layer.orig_dtype)
+
+        tensor_list = []
+        for i in range(e):
+            expert_bias = bias[i]
+
+            tensor_list.append(marlin_permute_bias(expert_bias))
+
+        bias = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
+        bias = torch.nn.Parameter(bias, requires_grad=False)
+        setattr(layer, name, bias)
+
+
+def pack_fp8_to_int32(
+    fp8_tensor: torch.Tensor, size_k_first: bool = True
+) -> torch.Tensor:
+    """
+    Repack FP8 weights to gptq format (packed int32 elements)
+    """
+    assert fp8_tensor.dtype == torch.float8_e4m3fn
+    assert fp8_tensor.ndim == 2
+
+    fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor
+    fp8_tensor = fp8_tensor.contiguous()
+    # fp8_tensor is contiguous and have shape (N, K) now
+    # with `.view(torch.int32)`, it become (N, K // 4)
+    int32_tensor = fp8_tensor.view(torch.int32)
+    return int32_tensor.T.contiguous() if size_k_first else int32_tensor
+
+
+def marlin_quant_fp8_torch(weight, group_size):
+    size_n, size_k = weight.shape
+    device = weight.device
+
+    if group_size != -1:
+        scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 448
+        repeated_scales = scales.repeat_interleave(group_size, 1)
+        fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
+        weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
+    else:
+        scales = weight.view(size_n, 1, group_size).abs().max(-1)[0] / 448
+        repeated_scales = scales.repeat_interleave(size_k, 1)
+        fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
+        weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
+
+    packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous()
+    marlin_qweight = gptq_marlin_repack(
+        b_q_weight=packed_weight,
+        perm=torch.empty(0, dtype=torch.int, device=device),
+        size_k=size_k,
+        size_n=size_n,
+        num_bits=8,
+    )
+
+    marlin_scales = marlin_permute_scales(
+        s=scales.T, size_k=size_k, size_n=size_n, group_size=group_size
+    )
+
+    marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
+
+    return weight_ref.T, marlin_qweight, marlin_scales
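As a quick sanity check on the exponent-bias folding in the new marlin_utils_fp8.py: for fp16 scales the bias is 2**(5-1) - 2**(4-1) = 8, so each scale is multiplied by 2**8 = 256; for bf16 scales it is 2**(8-1) - 2**(4-1) = 120, i.e. a factor of 2**120. The toy snippet below merely mirrors that arithmetic without importing the new module, and is not part of the package.

import torch

def fold_exponent_bias(scales: torch.Tensor) -> torch.Tensor:
    # Mirrors fp8_fused_exponent_bias_into_scales from marlin_utils_fp8.py.
    fp8_exponent = 4
    target_exponent = 5 if scales.dtype == torch.half else 8
    exponent_bias = 2 ** (target_exponent - 1) - 2 ** (fp8_exponent - 1)
    return scales * (2.0 ** exponent_bias)

half_scales = torch.full((1, 4), 0.5, dtype=torch.half)
print(fold_exponent_bias(half_scales))  # each element becomes 0.5 * 2**8 = 128.0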