sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (180)
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +376 -48
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
@@ -113,7 +113,7 @@ if supports_custom_op():


  @triton.jit
- def _per_token_group_quant_fp8(
+ def _per_token_group_quant_8bit(
  # Pointers to inputs and output
  y_ptr,
  y_q_ptr,
@@ -125,8 +125,8 @@ def _per_token_group_quant_fp8(
  # Avoid to divide zero
  eps,
  # Information for float8
- fp8_min,
- fp8_max,
+ bit8_min,
+ bit8_max,
  # Meta-parameters
  BLOCK: tl.constexpr,
  ):
@@ -147,16 +147,16 @@ def _per_token_group_quant_fp8(
  y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
  # Quant
  _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
- y_s = _absmax / fp8_max
+ y_s = _absmax / bit8_max
  y_s_inv = 1.0 / y_s
- y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+ y_q = tl.clamp(y * y_s_inv, bit8_min, bit8_max).to(y_q_ptr.dtype.element_ty)

  tl.store(y_q_ptr + cols, y_q, mask=mask)
  tl.store(y_s_ptr, y_s)


  @triton.jit
- def _per_token_group_quant_fp8_colmajor(
+ def _per_token_group_quant_8bit_colmajor(
  # Pointers to inputs and output
  y_ptr,
  y_q_ptr,
@@ -169,8 +169,8 @@ def _per_token_group_quant_fp8_colmajor(
  # Avoid to divide zero
  eps,
  # Information for float8
- fp8_min,
- fp8_max,
+ bit8_min,
+ bit8_max,
  # Meta-parameters
  BLOCK: tl.constexpr,
  SCALE_UE8M0: tl.constexpr,
@@ -197,19 +197,20 @@ def _per_token_group_quant_fp8_colmajor(
  y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
  # Quant
  _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
- y_s = _absmax / fp8_max
+ y_s = _absmax / bit8_max
  if SCALE_UE8M0:
  y_s = tl.exp2(tl.ceil(tl.log2(tl.abs(y_s))))
- y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+ y_q = tl.clamp(y / y_s, bit8_min, bit8_max).to(y_q_ptr.dtype.element_ty)

  tl.store(y_q_ptr + cols, y_q, mask=mask)
  tl.store(y_s_ptr, y_s)


- def per_token_group_quant_fp8(
+ def _per_token_group_quant_8bit_raw(
  x: torch.Tensor,
  group_size: int,
  eps: float = 1e-10,
+ dtype: torch.dtype = fp8_dtype,
  column_major_scales: bool = False,
  scale_tma_aligned: bool = False,
  scale_ue8m0: bool = False,
@@ -223,6 +224,7 @@ def per_token_group_quant_fp8(
  x: The input tenosr with ndim >= 2.
  group_size: The group size used for quantization.
  eps: The minimum to avoid dividing zero.
+ dtype: The dype of output tensor.

  Returns:
  Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
@@ -232,7 +234,21 @@ def per_token_group_quant_fp8(
  ), "the last dimension of `x` cannot be divisible by `group_size`"
  assert x.is_contiguous(), "`x` is not contiguous"

- x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
+ if _is_hip:
+ if dtype == torch.int8:
+ bit8_max = 127.0
+ else:
+ bit8_max = 224.0
+ bit8_min = -bit8_max # TODO incorrect for int8
+ else:
+ if dtype == torch.int8:
+ info = torch.iinfo(dtype)
+ else:
+ info = torch.finfo(dtype)
+ bit8_max = info.max
+ bit8_min = info.min
+
+ x_q = torch.empty_like(x, device=x.device, dtype=dtype)
  x_s = create_per_token_group_quant_fp8_output_scale(
  x_shape=x.shape,
  device=x.device,
@@ -250,7 +266,7 @@ def per_token_group_quant_fp8(
  num_warps = min(max(BLOCK // 256, 1), 8)
  num_stages = 1
  if column_major_scales:
- _per_token_group_quant_fp8_colmajor[(M,)](
+ _per_token_group_quant_8bit_colmajor[(M,)](
  x,
  x_q,
  x_s,
@@ -258,8 +274,8 @@ def per_token_group_quant_fp8(
  x.shape[1],
  x_s.stride(1),
  eps,
- fp8_min=fp8_min,
- fp8_max=fp8_max,
+ bit8_min=bit8_min,
+ bit8_max=bit8_max,
  BLOCK=BLOCK,
  num_warps=num_warps,
  num_stages=num_stages,
@@ -267,15 +283,15 @@ def per_token_group_quant_fp8(
  )
  else:
  assert not scale_ue8m0
- _per_token_group_quant_fp8[(M,)](
+ _per_token_group_quant_8bit[(M,)](
  x,
  x_q,
  x_s,
  group_size,
  N,
  eps,
- fp8_min=fp8_min,
- fp8_max=fp8_max,
+ bit8_min=bit8_min,
+ bit8_max=bit8_max,
  BLOCK=BLOCK,
  num_warps=num_warps,
  num_stages=num_stages,
@@ -297,6 +313,117 @@ def per_token_group_quant_fp8(
  return x_q, x_s


+ # backward compatibility
+ per_token_group_quant_fp8 = _per_token_group_quant_8bit_raw
+
+
+ def _per_token_group_quant_8bit_fuse_silu_and_mul(
+ x: torch.Tensor,
+ group_size: int,
+ dst_dtype: torch.dtype,
+ column_major_scales: bool,
+ scale_tma_aligned: bool,
+ scale_ue8m0: bool,
+ masked_m: Optional[torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Another way to implement (can be used in e.g. comparison tests)
+ # from sgl_kernel import silu_and_mul
+ # x_after_silu_and_mul = silu_and_mul(x)
+ # return per_token_group_quant_fp8(
+ # x_after_silu_and_mul,
+ # group_size=group_size,
+ # eps=eps,
+ # column_major_scales=column_major_scales,
+ # scale_tma_aligned=scale_tma_aligned,
+ # scale_ue8m0=scale_ue8m0,
+ # )
+
+ from deep_gemm.utils.layout import transform_sf_into_required_layout
+
+ from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd
+
+ assert column_major_scales
+ assert scale_tma_aligned
+ assert scale_ue8m0
+
+ needs_unsqueeze = x.dim() == 2
+ if needs_unsqueeze:
+ num_tokens, _ = x.shape
+ x = x.unsqueeze(0)
+ assert masked_m is None
+ masked_m = torch.tensor([num_tokens], device=x.device, dtype=torch.int32)
+
+ # Use `zeros` for easier testing
+ output = torch.zeros(
+ (*x.shape[:-1], x.shape[-1] // 2),
+ device=x.device,
+ dtype=dst_dtype,
+ )
+ # Use `zeros` for easier testing
+ output_scale_for_kernel = torch.zeros(
+ (*x.shape[:-1], x.shape[-1] // 2 // group_size),
+ device=x.device,
+ dtype=torch.float32,
+ )
+ silu_and_mul_masked_post_quant_fwd(
+ input=x,
+ output=output,
+ output_scale=output_scale_for_kernel,
+ quant_group_size=group_size,
+ masked_m=masked_m,
+ scale_ue8m0=scale_ue8m0,
+ )
+
+ assert group_size == 128
+ output_scale = transform_sf_into_required_layout(
+ output_scale_for_kernel,
+ num_groups=output.shape[0],
+ mn=output.shape[-2],
+ k=output.shape[-1],
+ recipe=(1, group_size, group_size),
+ is_sfa=True,
+ )
+
+ if needs_unsqueeze:
+ output = output.squeeze(0)
+ output_scale = output_scale.squeeze(0)
+
+ return output, output_scale
+
+
+ def per_token_group_quant_8bit(
+ x: torch.Tensor,
+ group_size: int,
+ dst_dtype: torch.dtype,
+ eps: float = 1e-10,
+ column_major_scales: bool = False,
+ scale_tma_aligned: bool = False,
+ scale_ue8m0: bool = False,
+ fuse_silu_and_mul: bool = False,
+ masked_m: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ if fuse_silu_and_mul:
+ return _per_token_group_quant_8bit_fuse_silu_and_mul(
+ x=x,
+ group_size=group_size,
+ dst_dtype=dst_dtype,
+ column_major_scales=column_major_scales,
+ scale_tma_aligned=scale_tma_aligned,
+ scale_ue8m0=scale_ue8m0,
+ masked_m=masked_m,
+ )
+ else:
+ return _per_token_group_quant_8bit_raw(
+ x=x,
+ group_size=group_size,
+ eps=eps,
+ column_major_scales=column_major_scales,
+ scale_tma_aligned=scale_tma_aligned,
+ scale_ue8m0=scale_ue8m0,
+ dtype=dst_dtype,
+ )
+
+
  def create_per_token_group_quant_fp8_output_scale(
  x_shape,
  device,
@@ -307,16 +434,16 @@ def create_per_token_group_quant_fp8_output_scale(
  ):
  if scale_ue8m0:
  assert column_major_scales and scale_tma_aligned
- x_q_mn, x_q_k = x_shape
+ *x_batch, x_q_mn, x_q_k = x_shape
  x_s_mn, x_s_k = x_q_mn, x_q_k // 128
  aligned_mn = align(x_s_mn, 4)
  aligned_k = align(x_s_k, 4)
  # TODO(FIXME): Fix cuda kernel and recover here to empty.
- return torch.zeros(
- (aligned_k // 4, aligned_mn),
+ return torch.empty(
+ (*x_batch, aligned_k // 4, aligned_mn),
  device=device,
  dtype=torch.int,
- ).transpose(0, 1)[:x_s_mn, :]
+ ).transpose(-1, -2)[..., :x_s_mn, :]
  elif column_major_scales:
  if scale_tma_aligned:
  # TODO extract "align" function
@@ -348,15 +475,19 @@ def sglang_per_token_group_quant_fp8(
  column_major_scales: bool = False,
  scale_tma_aligned: bool = False,
  scale_ue8m0: bool = False,
+ fuse_silu_and_mul: bool = False,
+ masked_m: Optional[torch.Tensor] = None,
  ):
  assert (
  x.shape[-1] % group_size == 0
  ), "the last dimension of `x` cannot be divisible by `group_size`"
  assert x.is_contiguous(), "`x` is not contiguous"

- x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
+ out_shape = (*x.shape[:-1], x.shape[-1] // (2 if fuse_silu_and_mul else 1))
+
+ x_q = torch.empty(out_shape, device=x.device, dtype=fp8_dtype)
  x_s = create_per_token_group_quant_fp8_output_scale(
- x_shape=x.shape,
+ x_shape=out_shape,
  device=x.device,
  group_size=group_size,
  column_major_scales=column_major_scales,
@@ -372,6 +503,46 @@ def sglang_per_token_group_quant_fp8(
  return x_q, x_s


+ # TODO maybe unify int8 and fp8 code later
+ def sglang_per_token_group_quant_8bit(
+ x: torch.Tensor,
+ group_size: int,
+ dst_dtype: torch.dtype,
+ eps: float = 1e-10,
+ column_major_scales: bool = False,
+ scale_tma_aligned: bool = False,
+ scale_ue8m0: bool = False,
+ fuse_silu_and_mul: bool = False,
+ masked_m: Optional[torch.Tensor] = None,
+ ):
+ from sglang.srt.layers.quantization.int8_kernel import (
+ sglang_per_token_group_quant_int8,
+ )
+
+ if dst_dtype == torch.int8:
+ assert not column_major_scales
+ assert not scale_tma_aligned
+ assert not fuse_silu_and_mul
+ assert masked_m is None
+ return sglang_per_token_group_quant_int8(
+ x=x,
+ group_size=group_size,
+ eps=eps,
+ dtype=dst_dtype,
+ )
+
+ return sglang_per_token_group_quant_fp8(
+ x=x,
+ group_size=group_size,
+ eps=eps,
+ column_major_scales=column_major_scales,
+ scale_tma_aligned=scale_tma_aligned,
+ scale_ue8m0=scale_ue8m0,
+ fuse_silu_and_mul=fuse_silu_and_mul,
+ masked_m=masked_m,
+ )
+
+
  def sglang_per_token_quant_fp8(
  x: torch.Tensor,
  dtype: torch.dtype = fp8_dtype,
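The hunks above generalize the FP8-only per-token-group quantization helpers into dtype-parameterized 8-bit variants (FP8 or INT8) and add an optional fused SiLU-and-mul path. Below is a minimal usage sketch, assuming these entry points are exposed by sglang/srt/layers/quantization/fp8_kernel.py (the +195/-24 entry in the file list); the signatures come from the diff, the tensor shapes are illustrative.

```python
import torch

# Assumed import path; the module header is not shown in this extract, but the
# file list above points at sglang/srt/layers/quantization/fp8_kernel.py.
from sglang.srt.layers.quantization.fp8_kernel import (
    per_token_group_quant_8bit,
    sglang_per_token_group_quant_8bit,
)

x = torch.randn(16, 512, device="cuda", dtype=torch.bfloat16)

# FP8 path: same behavior as the old per_token_group_quant_fp8 (kept as an alias),
# now parameterized by the destination dtype.
x_q_fp8, x_s_fp8 = per_token_group_quant_8bit(
    x, group_size=128, dst_dtype=torch.float8_e4m3fn
)

# INT8 path: the dispatcher added at the end of the hunks routes int8 requests
# to sglang_per_token_group_quant_int8.
x_q_i8, x_s_i8 = sglang_per_token_group_quant_8bit(
    x, group_size=128, dst_dtype=torch.int8
)
```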
@@ -53,6 +53,7 @@ if _is_cuda:
  from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm

  use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
+ use_triton_w8a8_fp8_kernel = get_bool_env_var("USE_TRITON_W8A8_FP8_KERNEL")

  # Input scaling factors are no longer optional in _scaled_mm starting
  # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
@@ -113,6 +114,7 @@ def normalize_e4m3fn_to_e4m3fnuz(
  return weight, weight_scale, input_scale


+ # TODO(ch-wan): define these backends in --moe-runner-backend
  def cutlass_block_fp8_supported() -> bool:
  if not get_bool_env_var("SGLANG_SUPPORT_CUTLASS_BLOCK_FP8"):
  return False
@@ -555,7 +557,10 @@ def apply_fp8_linear(
  # We also don't pad when using torch.compile,
  # as it breaks with dynamic shapes.
  if pad_output is None:
- pad_output = not get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
+ pad_output = (
+ not get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
+ and not cutlass_fp8_supported
+ )
  output_padding = 17 if pad_output else None

  # View input as 2D matrix for fp8 methods
@@ -591,7 +596,7 @@ def apply_fp8_linear(
  cutlass_compatible_b = (
  weight.shape[0] % 16 == 0 and weight.shape[1] % 16 == 0
  )
- if not cutlass_compatible_b:
+ if not cutlass_compatible_b or use_triton_w8a8_fp8_kernel:
  # Massage the input to be 2D
  qinput = qinput.view(-1, qinput.shape[-1])
  output = triton_scaled_mm(
@@ -734,14 +739,25 @@ def apply_fp8_linear(
  assert (
  weight_scale.numel() == weight.shape[1]
  ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
- output = fp8_scaled_mm(
- qinput,
- weight,
- x_scale,
- weight_scale,
- out_dtype=input.dtype,
- bias=bias,
+
+ cutlass_compatible_b = (
+ weight.shape[0] % 16 == 0 and weight.shape[1] % 16 == 0
  )
+ if not cutlass_compatible_b or use_triton_w8a8_fp8_kernel:
+ # Massage the input to be 2D
+ qinput = qinput.view(-1, qinput.shape[-1])
+ output = triton_scaled_mm(
+ qinput, weight, x_scale, weight_scale, input.dtype, bias
+ )
+ else:
+ output = fp8_scaled_mm(
+ qinput,
+ weight,
+ x_scale,
+ weight_scale,
+ out_dtype=input.dtype,
+ bias=bias,
+ )
  return output.view(*output_shape)
  except (ImportError, NameError, AttributeError):
  pass
@@ -788,3 +804,12 @@ def apply_fp8_linear(
  bias,
  input.dtype,
  )
+
+
+ def can_auto_enable_marlin_fp8() -> bool:
+ try:
+ major, minor = get_device_capability()
+ sm = major * 10 + minor
+ return 80 <= sm < 89
+ except Exception:
+ return False
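The fp8_utils.py hunks above add a USE_TRITON_W8A8_FP8_KERNEL environment variable that forces the Triton scaled-mm fallback even when the weight shape is CUTLASS-compatible, and a can_auto_enable_marlin_fp8() helper that detects SM 8.0-8.8 GPUs where weight-only FP8 via Marlin is the preferred path. A minimal sketch of combining the helper with the existing SGLANG_FORCE_FP8_MARLIN override, mirroring how the new fbgemm_fp8 config further down uses it:

```python
# Sketch only; assumes can_auto_enable_marlin_fp8 is importable from
# sglang.srt.layers.quantization.fp8_utils as added in the last hunk above.
from sglang.srt.layers.quantization.fp8_utils import can_auto_enable_marlin_fp8
from sglang.srt.utils import get_bool_env_var, is_cuda

use_marlin = False
if is_cuda():
    # Honor the explicit override first, then fall back to capability detection
    # (SM 8.0-8.8, where native FP8 GEMMs are unavailable).
    force_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
    use_marlin = force_marlin or can_auto_enable_marlin_fp8()
print(use_marlin)
```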
@@ -0,0 +1,203 @@
+ # SPDX-License-Identifier: Apache-2.0
+ from __future__ import annotations
+
+ import logging
+ from typing import Any, Optional
+
+ import torch
+ from torch.nn import Module
+ from torch.nn.parameter import Parameter
+
+ from sglang.srt.layers.linear import LinearBase
+ from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
+ from sglang.srt.layers.quantization.base_config import (
+ FusedMoEMethodBase,
+ LinearMethodBase,
+ QuantizationConfig,
+ QuantizeMethodBase,
+ )
+ from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
+ from sglang.srt.layers.quantization.fp8_utils import (
+ apply_fp8_linear,
+ can_auto_enable_marlin_fp8,
+ cutlass_fp8_supported,
+ normalize_e4m3fn_to_e4m3fnuz,
+ )
+ from sglang.srt.layers.quantization.marlin_utils_fp8 import (
+ apply_fp8_marlin_linear,
+ prepare_fp8_layer_for_marlin,
+ )
+ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
+ from sglang.srt.layers.quantization.utils import is_layer_skipped, replace_parameter
+ from sglang.srt.utils import get_bool_env_var, is_cuda
+
+ _is_cuda = is_cuda()
+ _is_fp8_fnuz = is_fp8_fnuz()
+
+ logger = logging.getLogger(__name__)
+
+
+ class FBGEMMFp8Config(QuantizationConfig):
+ """Config class for FBGEMM Fp8."""
+
+ def __init__(self, ignore_list: list[str], input_scale_ub: float):
+ super().__init__()
+ self.ignore_list = ignore_list if ignore_list else []
+ self.input_scale_ub = input_scale_ub
+
+ # For GPUs that lack FP8 hardware suspport, we can leverage the Marlin
+ # kernel for fast weight-only FP8 quantization
+ # self.use_marlin = not marlin_fp8_supported()
+ self.use_marlin = False
+ if _is_cuda:
+ force_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
+ auto_enable = can_auto_enable_marlin_fp8()
+ self.use_marlin = force_marlin or auto_enable
+
+ @classmethod
+ def get_name(cls) -> str:
+ return "fbgemm_fp8"
+
+ @classmethod
+ def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+ return [torch.bfloat16, torch.float16]
+
+ @classmethod
+ def get_min_capability(cls) -> int:
+ return 80
+
+ @classmethod
+ def get_config_filenames(cls) -> list[str]:
+ return []
+
+ @classmethod
+ def from_config(cls, config: dict[str, Any]) -> FBGEMMFp8Config:
+ ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"])
+ input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
+ return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)
+
+ def get_quant_method(
+ self, layer: torch.nn.Module, prefix: str
+ ) -> Optional[QuantizeMethodBase]:
+ if isinstance(layer, LinearBase):
+ if is_layer_skipped(
+ prefix=prefix,
+ ignored_layers=self.ignore_list,
+ fused_mapping=self.packed_modules_mapping,
+ ):
+ return UnquantizedLinearMethod()
+ return FBGEMMFp8LinearMethod(self)
+ return None
+
+ def get_scaled_act_names(self) -> List[str]:
+ return []
+
+
+ class FBGEMMFp8LinearMethod(LinearMethodBase):
+
+ def __init__(self, quant_config: FBGEMMFp8Config):
+ self.quant_config = quant_config
+ # self.fp8_linear = Fp8LinearOp(
+ # act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN)
+ self.out_dtype = torch.get_default_dtype()
+ self.cutlass_fp8_supported = cutlass_fp8_supported()
+
+ def create_weights(
+ self,
+ layer: torch.nn.Module,
+ input_size_per_partition: int,
+ output_partition_sizes: list[int],
+ input_size: int,
+ output_size: int,
+ params_dtype: torch.dtype,
+ **extra_weight_attrs,
+ ):
+ # maybe_create_device_identity()
+ weight_loader = extra_weight_attrs.get("weight_loader")
+ del input_size, output_size
+ output_size_per_partition = sum(output_partition_sizes)
+
+ layer.logical_widths = output_partition_sizes
+
+ layer.input_size_per_partition = input_size_per_partition
+ layer.output_size_per_partition = output_size_per_partition
+ layer.orig_dtype = params_dtype
+
+ # WEIGHT
+ weight = ModelWeightParameter(
+ data=torch.empty(
+ output_size_per_partition,
+ input_size_per_partition,
+ dtype=torch.float8_e4m3fn,
+ ),
+ input_dim=1,
+ output_dim=0,
+ weight_loader=weight_loader,
+ )
+ layer.register_parameter("weight", weight)
+
+ # WEIGHT SCALE
+ weight_scale = ChannelQuantScaleParameter(
+ data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
+ output_dim=0,
+ weight_loader=weight_loader,
+ )
+ weight_scale[:] = torch.finfo(torch.float32).min
+ layer.register_parameter("weight_scale", weight_scale)
+
+ # INPUT SCALE UPPER BOUND
+ input_scale_ub = torch.nn.Parameter(
+ torch.tensor((self.quant_config.input_scale_ub), dtype=torch.float32),
+ requires_grad=False,
+ )
+ layer.input_scale_ub = input_scale_ub
+
+ def process_weights_after_loading(self, layer: Module) -> None:
+ # required by torch.compile
+ layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
+ layer.weight = Parameter(layer.weight.data, requires_grad=False)
+
+ weight = layer.weight
+
+ if _is_fp8_fnuz:
+ weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+ weight=weight, weight_scale=layer.weight_scale, input_scale=None
+ )
+ if input_scale is not None:
+ layer.input_scale = Parameter(input_scale, requires_grad=False)
+ layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+
+ layer.weight = Parameter(weight.t(), requires_grad=False)
+ if self.quant_config.use_marlin:
+ prepare_fp8_layer_for_marlin(layer)
+ # Activations not quantized for marlin.
+ del layer.input_scale_ub
+
+ def apply(
+ self,
+ layer: torch.nn.Module,
+ x: torch.Tensor,
+ bias: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+
+ if self.quant_config.use_marlin:
+ return apply_fp8_marlin_linear(
+ input=x,
+ weight=layer.weight,
+ weight_scale=layer.weight_scale,
+ workspace=layer.workspace,
+ size_n=layer.output_size_per_partition,
+ size_k=layer.input_size_per_partition,
+ bias=bias,
+ )
+
+ return apply_fp8_linear(
+ input=x,
+ weight=layer.weight,
+ weight_scale=layer.weight_scale,
+ input_scale=None,
+ input_scale_ub=layer.input_scale_ub,
+ bias=bias,
+ cutlass_fp8_supported=self.cutlass_fp8_supported,
+ use_per_token_if_dynamic=False,
+ )
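The new module above registers an fbgemm_fp8 quantization method: per-channel FP8 weight scales plus an activation scale upper bound, with an optional Marlin weight-only path. A small sketch of building the config from a checkpoint's quantization_config dict; only the two keys read by from_config above are used, and the concrete values are placeholders.

```python
# Assumed module path based on the sglang/srt/layers/quantization/fpgemm_fp8.py
# entry in the file list; key names follow FBGEMMFp8Config.from_config above,
# the values below are illustrative placeholders.
from sglang.srt.layers.quantization.fpgemm_fp8 import FBGEMMFp8Config

hf_quant_config = {
    "modules_to_not_convert": ["lm_head"],  # layers to keep unquantized
    "activation_scale_ub": 1200.0,          # upper bound for dynamic activation scales
}

cfg = FBGEMMFp8Config.from_config(hf_quant_config)
print(cfg.get_name())   # "fbgemm_fp8"
print(cfg.use_marlin)   # True when forced via SGLANG_FORCE_FP8_MARLIN or on SM 8.0-8.8
```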
@@ -44,6 +44,7 @@ from sglang.srt.layers.quantization.utils import (
  )

  if TYPE_CHECKING:
+ from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
  from sglang.srt.layers.moe.topk import TopKOutput

  from sglang.srt.utils import is_cuda
@@ -1056,13 +1057,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
  layer: torch.nn.Module,
  x: torch.Tensor,
  topk_output: TopKOutput,
- *,
- activation: str = "silu",
- **kwargs,
+ moe_runner_config: MoeRunnerConfig,
  ) -> torch.Tensor:
  # Delay the import to avoid circular dependency

- assert activation == "silu", "Only SiLU activation is supported."
+ assert (
+ moe_runner_config.activation == "silu"
+ ), "Only SiLU activation is supported."

  # The input must currently be float16
  orig_dtype = x.dtype
@@ -28,6 +28,7 @@ from sglang.srt.utils import get_device_capability, is_cuda

  if TYPE_CHECKING:
  from sglang.srt.layers.linear import LinearBase
+ from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

  try:
  from vllm import _custom_ops as ops
@@ -216,13 +217,13 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
  )[0]


- def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
+ def check_moe_marlin_supports_layer(layer: FusedMoE, group_size: int) -> bool:
  hidden_size = layer.hidden_size
  intermediate_size_per_partition = layer.intermediate_size_per_partition
  # apply_router_weight_on_input is not supported for moe marlin
- supports_router_weight = not layer.apply_router_weight_on_input
+ supports_router_weight = not layer.moe_runner_config.apply_router_weight_on_input
  # moe marlin requires the activation to be silu
- supports_activation = layer.activation == "silu"
+ supports_activation = layer.moe_runner_config.activation == "silu"

  # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
  # down: (n, k) = (hidden_size, intermediate_size_per_partition)
@@ -305,6 +306,13 @@ def marlin_permute_scales(
  return s


+ def marlin_permute_bias(s: torch.Tensor) -> torch.Tensor:
+ origin_shape = s.shape
+ _, scale_perm_single = get_scale_perms()
+ s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
+ return s.reshape(*origin_shape).contiguous()
+
+
  def marlin_moe_permute_scales(
  s: torch.Tensor,
  size_k: int,
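marlin_permute_bias above applies the same single-group permutation used for scales so a bias vector matches Marlin's permuted output-channel layout. A usage sketch, assuming the helper lives next to get_scale_perms in sglang/srt/layers/quantization/marlin_utils.py as the surrounding hunks indicate:

```python
import torch

# Assumed import path; see the marlin_utils.py entry (+11/-3) in the file list.
from sglang.srt.layers.quantization.marlin_utils import marlin_permute_bias

# The length must be a multiple of the single-scale permutation width returned by
# get_scale_perms (a typical hidden size such as 4096 satisfies this).
bias = torch.randn(4096, dtype=torch.half, device="cuda")
bias_marlin = marlin_permute_bias(bias)  # same shape, Marlin channel order
assert bias_marlin.shape == bias.shape
```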