sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -1
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +8 -7
  6. sglang/srt/disaggregation/decode.py +8 -4
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +68 -5
  14. sglang/srt/entrypoints/openai/protocol.py +2 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +60 -265
  16. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  18. sglang/srt/function_call/ebnf_composer.py +1 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  21. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  22. sglang/srt/function_call/kimik2_detector.py +3 -3
  23. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  24. sglang/srt/jinja_template_utils.py +6 -0
  25. sglang/srt/layers/attention/aiter_backend.py +370 -107
  26. sglang/srt/layers/attention/ascend_backend.py +3 -0
  27. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  28. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  29. sglang/srt/layers/attention/flashinfer_backend.py +55 -13
  30. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
  31. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  32. sglang/srt/layers/attention/triton_backend.py +24 -27
  33. sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
  34. sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
  35. sglang/srt/layers/attention/vision.py +9 -1
  36. sglang/srt/layers/attention/wave_backend.py +627 -0
  37. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  38. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  39. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  40. sglang/srt/layers/communicator.py +11 -13
  41. sglang/srt/layers/dp_attention.py +118 -27
  42. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  43. sglang/srt/layers/linear.py +1 -0
  44. sglang/srt/layers/logits_processor.py +12 -18
  45. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  46. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  47. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  48. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  63. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  64. sglang/srt/layers/moe/topk.py +4 -1
  65. sglang/srt/layers/multimodal.py +156 -40
  66. sglang/srt/layers/quantization/__init__.py +10 -35
  67. sglang/srt/layers/quantization/awq.py +15 -16
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
  69. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  70. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  71. sglang/srt/layers/quantization/gptq.py +12 -17
  72. sglang/srt/layers/quantization/marlin_utils.py +15 -5
  73. sglang/srt/layers/quantization/modelopt_quant.py +58 -41
  74. sglang/srt/layers/quantization/mxfp4.py +20 -3
  75. sglang/srt/layers/quantization/utils.py +52 -2
  76. sglang/srt/layers/quantization/w4afp8.py +20 -11
  77. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  78. sglang/srt/layers/rotary_embedding.py +281 -2
  79. sglang/srt/layers/sampler.py +5 -2
  80. sglang/srt/lora/backend/base_backend.py +3 -23
  81. sglang/srt/lora/layers.py +66 -116
  82. sglang/srt/lora/lora.py +17 -62
  83. sglang/srt/lora/lora_manager.py +12 -48
  84. sglang/srt/lora/lora_registry.py +20 -9
  85. sglang/srt/lora/mem_pool.py +20 -63
  86. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  87. sglang/srt/lora/utils.py +25 -58
  88. sglang/srt/managers/cache_controller.py +24 -29
  89. sglang/srt/managers/detokenizer_manager.py +1 -1
  90. sglang/srt/managers/io_struct.py +20 -6
  91. sglang/srt/managers/mm_utils.py +1 -2
  92. sglang/srt/managers/multimodal_processor.py +1 -1
  93. sglang/srt/managers/schedule_batch.py +43 -49
  94. sglang/srt/managers/schedule_policy.py +6 -6
  95. sglang/srt/managers/scheduler.py +18 -11
  96. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  97. sglang/srt/managers/tokenizer_manager.py +53 -44
  98. sglang/srt/mem_cache/allocator.py +39 -214
  99. sglang/srt/mem_cache/allocator_ascend.py +158 -0
  100. sglang/srt/mem_cache/chunk_cache.py +1 -1
  101. sglang/srt/mem_cache/hicache_storage.py +1 -1
  102. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  103. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  104. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  105. sglang/srt/mem_cache/radix_cache.py +2 -5
  106. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  107. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  108. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  109. sglang/srt/model_executor/cuda_graph_runner.py +29 -23
  110. sglang/srt/model_executor/forward_batch_info.py +33 -14
  111. sglang/srt/model_executor/model_runner.py +179 -81
  112. sglang/srt/model_loader/loader.py +18 -6
  113. sglang/srt/models/deepseek_nextn.py +2 -1
  114. sglang/srt/models/deepseek_v2.py +79 -38
  115. sglang/srt/models/gemma2.py +0 -34
  116. sglang/srt/models/gemma3n_mm.py +8 -9
  117. sglang/srt/models/glm4.py +6 -0
  118. sglang/srt/models/glm4_moe.py +11 -11
  119. sglang/srt/models/glm4_moe_nextn.py +2 -1
  120. sglang/srt/models/glm4v.py +589 -0
  121. sglang/srt/models/glm4v_moe.py +400 -0
  122. sglang/srt/models/gpt_oss.py +142 -20
  123. sglang/srt/models/granite.py +0 -25
  124. sglang/srt/models/llama.py +10 -27
  125. sglang/srt/models/llama4.py +19 -6
  126. sglang/srt/models/qwen2.py +2 -2
  127. sglang/srt/models/qwen2_5_vl.py +7 -3
  128. sglang/srt/models/qwen2_audio.py +10 -9
  129. sglang/srt/models/qwen2_moe.py +20 -5
  130. sglang/srt/models/qwen3.py +0 -24
  131. sglang/srt/models/qwen3_classification.py +78 -0
  132. sglang/srt/models/qwen3_moe.py +18 -5
  133. sglang/srt/models/registry.py +1 -1
  134. sglang/srt/models/step3_vl.py +6 -2
  135. sglang/srt/models/torch_native_llama.py +0 -24
  136. sglang/srt/multimodal/processors/base_processor.py +23 -13
  137. sglang/srt/multimodal/processors/glm4v.py +132 -0
  138. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  139. sglang/srt/operations.py +17 -2
  140. sglang/srt/reasoning_parser.py +316 -0
  141. sglang/srt/sampling/sampling_batch_info.py +7 -4
  142. sglang/srt/server_args.py +142 -140
  143. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
  144. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  145. sglang/srt/speculative/eagle_worker.py +16 -0
  146. sglang/srt/two_batch_overlap.py +16 -12
  147. sglang/srt/utils.py +3 -3
  148. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  149. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  150. sglang/test/doc_patch.py +59 -0
  151. sglang/test/few_shot_gsm8k.py +1 -1
  152. sglang/test/few_shot_gsm8k_engine.py +1 -1
  153. sglang/test/run_eval.py +4 -1
  154. sglang/test/simple_eval_common.py +6 -0
  155. sglang/test/simple_eval_gpqa.py +2 -0
  156. sglang/test/test_fp4_moe.py +118 -36
  157. sglang/test/test_marlin_moe.py +1 -1
  158. sglang/test/test_marlin_utils.py +1 -1
  159. sglang/utils.py +1 -1
  160. sglang/version.py +1 -1
  161. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
  162. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
  163. sglang/lang/backend/__init__.py +0 -0
  164. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  165. sglang/srt/layers/quantization/scalar_type.py +0 -352
  166. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  167. /sglang/{api.py → lang/api.py} +0 -0
  168. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
  169. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
  170. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/awq.py
@@ -29,29 +29,25 @@ from sglang.srt.layers.quantization.marlin_utils import (
      verify_marlin_supported,
      verify_marlin_supports_shape,
  )
- from sglang.srt.layers.quantization.scalar_type import scalar_types
  from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
- from sglang.srt.layers.quantization.utils import replace_parameter
+ from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter

  if TYPE_CHECKING:
      from sglang.srt.layers.moe.topk import TopKOutput

- try:
-     from vllm import _custom_ops as ops
-
-     warnings.warn(
-         f"Using kernels directly from vllm. This might lead to performance degradation or "
-         f"missing functionalities as certain kernels may not be optimized. "
-     )
- except ImportError:
-     ops = None
-
  from sglang.srt.utils import is_cuda, is_hip

  _is_cuda = is_cuda()
  _is_hip = is_hip()
  if _is_cuda:
-     from sgl_kernel import awq_dequantize, fused_marlin_moe
+     from sgl_kernel import (
+         awq_dequantize,
+         awq_marlin_moe_repack,
+         awq_marlin_repack,
+         fused_marlin_moe,
+     )
+
+
  elif _is_hip:
      from sglang.srt.layers.quantization.awq_triton import (
          awq_dequantize_triton as awq_dequantize,
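This import change (repeated in the GPTQ and Marlin hunks further down) replaces the old `scalar_type` module and the optional vllm `_custom_ops` fallback with a single `get_scalar_types()` helper; the next hunk binds `ScalarType` and `scalar_types` at module level. A minimal sketch of the calling convention as it appears in these hunks — the `uint4b8` attribute is an assumption used only for illustration and is not shown in this diff:

```python
# Sketch only: mirrors how the modules in this diff consume the helper at import time.
from sglang.srt.layers.quantization.utils import get_scalar_types

# Returns the ScalarType class plus the registry of predefined scalar types;
# what the helper does when optional kernels are missing is not part of this diff.
ScalarType, scalar_types = get_scalar_types()

# Downstream code keeps using the registry as before, e.g. a 4-bit weight type
# (attribute name assumed for illustration).
AWQ_QUANT_TYPE = scalar_types.uint4b8
```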
@@ -64,6 +60,9 @@ else:
  logger = logging.getLogger(__name__)


+ ScalarType, scalar_types = get_scalar_types()
+
+
  def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
      return any(module_name in prefix for module_name in modules_to_not_convert)

@@ -516,7 +515,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
          layer.workspace = marlin_make_workspace(device)

          # Repack weights from AWQ format to marlin format.
-         marlin_qweight = ops.awq_marlin_repack(
+         marlin_qweight = awq_marlin_repack(
              layer.qweight,
              size_k=layer.input_size_per_partition,
              size_n=layer.output_size_per_partition,
@@ -684,7 +683,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
              requires_grad=False,
          )

-         marlin_w13_qweight = ops.awq_marlin_moe_repack(
+         marlin_w13_qweight = awq_marlin_moe_repack(
              layer.w13_qweight,
              layer.w13_g_idx_sort_indices,
              size_k=layer.w13_qweight.shape[1],
@@ -693,7 +692,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
          )
          replace_parameter(layer, "w13_qweight", marlin_w13_qweight)

-         marlin_w2_qweight = ops.awq_marlin_moe_repack(
+         marlin_w2_qweight = awq_marlin_moe_repack(
              layer.w2_qweight,
              layer.w2_g_idx_sort_indices,
              size_k=layer.w2_qweight.shape[1],
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -16,7 +16,6 @@ from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_qu
  from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
  from sglang.srt.layers.quantization.utils import (
      all_close_1d,
-     cpu_has_amx_support,
      per_tensor_dequantize,
      replace_parameter,
  )
sglang/srt/layers/quantization/fp8_kernel.py
@@ -1356,3 +1356,280 @@ def per_token_group_quant_fp8_hopper_moe_mn_major(
          expert_tokens_alignment,
      )
      return a_q, sfa
+
+
+ @triton.jit
+ def _per_group_transpose(
+     data_ptr: torch.Tensor,
+     trans_data_ptr: torch.Tensor,
+     expert_offsets: torch.Tensor,
+     k: int,
+     M_ALIGNMENT: tl.constexpr,
+     BLOCK_SIZE_M: tl.constexpr,
+     BLOCK_SIZE_K: tl.constexpr,
+ ):
+     expert_id = tl.program_id(0)
+     m_id = tl.program_id(1)
+     k_id = tl.program_id(2)
+
+     curr_expert_offset = tl.load(expert_offsets + expert_id)
+     next_expert_offset = tl.load(expert_offsets + expert_id + 1)
+     num_tokens_of_expert = next_expert_offset - curr_expert_offset
+     tl.multiple_of(curr_expert_offset, M_ALIGNMENT)
+     tl.multiple_of(next_expert_offset, M_ALIGNMENT)
+
+     data_start_ptr = data_ptr + curr_expert_offset * k
+     trans_data_start_ptr = trans_data_ptr + curr_expert_offset * k
+
+     k_coord = k_id * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
+     k_mask = k_coord < k
+     for start_m in tl.range(0, num_tokens_of_expert, BLOCK_SIZE_M * tl.num_programs(1)):
+         m_coord = start_m + m_id * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+         m_mask = m_coord < num_tokens_of_expert
+         off = m_coord[:, None] * k + k_coord[None, :]
+         trans_off = m_coord[:, None] + k_coord[None, :] * num_tokens_of_expert
+         mask = m_mask[:, None] & k_mask[None, :]
+
+         data = tl.load(data_start_ptr + off, mask=mask)
+         tl.store(trans_data_start_ptr + trans_off, data, mask=mask)
+
+
+ def per_group_transpose(
+     a: torch.Tensor,
+     expert_offsets: torch.Tensor,
+     M_ALIGNMENT: int = 1,
+ ) -> torch.Tensor:
+     assert a.dim() == 2
+     assert a.is_contiguous(), "`a` is not contiguous"
+
+     m, k = a.size()
+     trans_a = torch.empty_like(a)
+     num_experts = expert_offsets.size(0) - 1
+
+     grid = lambda META: (
+         num_experts,
+         triton.cdiv((m + num_experts - 1) // num_experts, META["BLOCK_SIZE_M"]),
+         triton.cdiv(k, META["BLOCK_SIZE_K"]),
+     )
+     _per_group_transpose[grid](
+         a, trans_a, expert_offsets, k, M_ALIGNMENT, BLOCK_SIZE_M=16, BLOCK_SIZE_K=8
+     )
+     return trans_a
+
+
+ def is_weak_contiguous(x: torch.Tensor):
+     strides = x.stride()
+     sizes = x.shape
+     is_not_transpose = strides[0] == 1 and (strides[1] >= max(1, sizes[0]))
+     is_transpose = strides[1] == 1 and (strides[0] >= max(1, sizes[1]))
+     return is_transpose or is_not_transpose
+
+
+ @triton.jit
+ def scaled_mm_kernel(
+     a_ptr,
+     b_ptr,
+     scale_a_ptr,
+     scale_b_ptr,
+     c_ptr,
+     bias_ptr,
+     M,
+     N,
+     K,
+     stride_am,
+     stride_ak,
+     stride_bk,
+     stride_bn,
+     stride_cm,
+     stride_cn,
+     ACCUMULATOR_DTYPE: tl.constexpr,
+     BLOCK_SIZE_M: tl.constexpr,
+     BLOCK_SIZE_N: tl.constexpr,
+     BLOCK_SIZE_K: tl.constexpr,
+     BLOCK_SIZE_SCALE_A: tl.constexpr,
+     BLOCK_SIZE_SCALE_B: tl.constexpr,
+ ):
+     pid = tl.program_id(axis=0)
+
+     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+
+     pid_m = pid // num_pid_n
+     pid_n = pid % num_pid_n
+
+     accumulator_dtype = ACCUMULATOR_DTYPE
+     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype)
+
+     # NOTE: Some tensor inputs are so large, they will cause int32 overflow
+     # so it is necessary to use tl.int64 for all the offsets, else SEGV will
+     # eventually occur.
+
+     # Offsets and masks.
+     offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+     masks_am = offsets_am < M
+
+     offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
+     masks_bn = offsets_bn < N
+
+     offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
+     offsets_a = stride_am * offsets_am[:, None] + stride_ak * offsets_k[None, :]
+     offsets_b = stride_bk * offsets_k[:, None] + stride_bn * offsets_bn[None, :]
+
+     # NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so need to create
+     # appropriate offsets and masks for each case. Same goes for
+     # BLOCK_SIZE_SCALE_B.
+     offsets_scale_am = (
+         tl.arange(0, BLOCK_SIZE_SCALE_A)
+         + (BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M
+     )
+     masks_scale_am = offsets_scale_am < M
+
+     offsets_scale_bn = (
+         tl.arange(0, BLOCK_SIZE_SCALE_B)
+         + (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N
+     )
+     masks_scale_bn = offsets_scale_bn < N
+
+     a_ptrs = a_ptr + offsets_a
+     b_ptrs = b_ptr + offsets_b
+
+     scale_a_ptrs = scale_a_ptr + offsets_scale_am
+     scale_b_ptrs = scale_b_ptr + offsets_scale_bn
+
+     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+         masks_k = offsets_k < K
+         masks_a = masks_am[:, None] & masks_k[None, :]
+         a = tl.load(a_ptrs, mask=masks_a)
+
+         masks_b = masks_k[:, None] & masks_bn[None, :]
+         b = tl.load(b_ptrs, mask=masks_b)
+
+         # Accumulate results.
+         accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)
+
+         offsets_k += BLOCK_SIZE_K
+         a_ptrs += BLOCK_SIZE_K * stride_ak
+         b_ptrs += BLOCK_SIZE_K * stride_bk
+
+     # Apply scale at end.
+     masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None]
+     scale_a = tl.load(scale_a_ptrs[:, None], masks_scale_a)
+     # Need to broadcast to the appropriate size, if scale_a is already
+     # (BLOCK_SIZE_M, 1) then it will broadcast to its own shape. Same goes
+     # for scale_b below.
+     scale_a = scale_a.broadcast_to((BLOCK_SIZE_M, 1))
+     accumulator = scale_a * accumulator.to(tl.float32)
+
+     masks_scale_b = masks_scale_bn[:, None] & (tl.arange(0, 1) < 1)[None, :]
+     scale_b = tl.load(scale_b_ptrs[:, None], masks_scale_b)
+     scale_b = scale_b.broadcast_to((BLOCK_SIZE_N, 1))
+     accumulator = scale_b.T * accumulator.to(tl.float32)
+
+     # Convert to output format.
+     c = accumulator.to(c_ptr.type.element_ty)
+
+     # Add bias, it's already in output format, so add it after conversion.
+     if bias_ptr:
+         offsets_bias = offsets_bn
+         bias_ptrs = bias_ptr + offsets_bias
+         bias_mask = offsets_bias < N
+         bias = tl.load(bias_ptrs, bias_mask)
+         c += bias
+
+     # Save output
+     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
+     offs_cm = offs_cm.to(tl.int64)
+     offs_cn = offs_cn.to(tl.int64)
+     c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+
+     tl.store(c_ptrs, c, mask=c_mask)
+
+
+ # input - [M, K]
+ # weight - [K, N]
+ # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py
+ def triton_scaled_mm(
+     input: torch.Tensor,
+     weight: torch.Tensor,
+     scale_a: torch.Tensor,
+     scale_b: torch.Tensor,
+     out_dtype: type[torch.dtype],
+     bias: Optional[torch.Tensor] = None,
+     block_size_m: int = 32,
+     block_size_n: int = 32,
+     block_size_k: int = 32,
+     use_heuristic=True,
+ ) -> torch.Tensor:
+     M, K = input.shape
+     N = weight.shape[1]
+
+     assert N > 0 and K > 0 and M > 0
+     assert weight.shape[0] == K
+     assert input.dtype == weight.dtype
+
+     scale_a = scale_a.reshape(-1, 1) if scale_a.dim() <= 1 else scale_a
+     scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b
+
+     assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
+     assert scale_a.shape[1] == 1 and (scale_a.shape[0] == 1 or scale_a.shape[0] == M)
+     assert scale_b.shape[1] == 1 and (scale_b.shape[0] == 1 or scale_b.shape[0] == N)
+     assert out_dtype.is_floating_point
+     assert bias is None or bias.is_floating_point()
+     assert is_weak_contiguous(input)
+     assert is_weak_contiguous(weight)
+
+     grid = lambda META: (
+         triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+     )
+
+     result = torch.empty((M, N), dtype=out_dtype, device=input.device)
+
+     has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1
+
+     if use_heuristic:
+         is_small_N = N < 8192
+         next_power_of_2_M = max(32, triton.next_power_of_2(M))
+         if next_power_of_2_M <= 32:
+             tile_shape = (64, 64, 256) if is_small_N else (64, 128, 256)
+         elif next_power_of_2_M <= 64:
+             tile_shape = (64, 64, 256)
+         elif next_power_of_2_M <= 128:
+             tile_shape = (64, 128, 128)
+         else:
+             tile_shape = (128, 128, 128)
+
+         block_size_m, block_size_n, block_size_k = tile_shape
+
+     block_size_sa = 1 if has_scalar(scale_a) else block_size_m
+     block_size_sb = 1 if has_scalar(scale_b) else block_size_n
+
+     accumulator_dtype = tl.float32 if input.is_floating_point() else tl.int32
+
+     # A = input, B = weight, C = result
+     # A = M x K, B = K x N, C = M x N
+     scaled_mm_kernel[grid](
+         input,
+         weight,
+         scale_a,
+         scale_b,
+         result,
+         bias,
+         M,
+         N,
+         K,
+         input.stride(0),
+         input.stride(1),
+         weight.stride(0),
+         weight.stride(1),
+         result.stride(0),
+         result.stride(1),
+         accumulator_dtype,
+         BLOCK_SIZE_M=block_size_m,
+         BLOCK_SIZE_N=block_size_n,
+         BLOCK_SIZE_K=block_size_k,
+         BLOCK_SIZE_SCALE_A=block_size_sa,
+         BLOCK_SIZE_SCALE_B=block_size_sb,
+     )
+
+     return result.to(out_dtype)
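For orientation, a hedged usage sketch of the new `triton_scaled_mm` fallback added above (not part of the diff): it multiplies an already-quantized activation of shape [M, K] by a quantized weight of shape [K, N], applying a per-token scale on the activation and a per-channel scale on the weight. Shapes and the fp8 dtype are illustrative and assume a GPU and Triton build that support fp8 `tl.dot`; the same kernel also accepts int8 inputs, in which case it accumulates in int32.

```python
import torch
from sglang.srt.layers.quantization.fp8_kernel import triton_scaled_mm

# N is deliberately not 16-aligned -- the case that apply_fp8_linear
# (see the hunk further down) routes to this Triton fallback.
M, K, N = 4, 128, 24
qinput = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)  # [M, K] quantized activation
weight = torch.randn(K, N, device="cuda").to(torch.float8_e4m3fn)  # [K, N] quantized weight
x_scale = torch.rand(M, 1, device="cuda", dtype=torch.float32)     # per-token scale
w_scale = torch.rand(N, 1, device="cuda", dtype=torch.float32)     # per-channel scale

out = triton_scaled_mm(qinput, weight, x_scale, w_scale, torch.bfloat16)
assert out.shape == (M, N) and out.dtype == torch.bfloat16
```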
sglang/srt/layers/quantization/fp8_utils.py
@@ -22,6 +22,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
      scaled_fp8_quant,
      sglang_per_token_quant_fp8,
      static_quant_fp8,
+     triton_scaled_mm,
      w8a8_block_fp8_matmul_deepgemm,
      w8a8_block_fp8_matmul_triton,
  )
@@ -161,16 +162,16 @@ def flashinfer_gemm_w8a8_block_fp8_linear(
      output_shape = [*input.shape[:-1], weight.shape[0]]

      q_input, x_scale = sglang_per_token_group_quant_fp8(
-         input_2d, block_size[1], column_major_scales=False
+         input_2d, block_size[1], column_major_scales=True
      )
-
+     # TRTLLM requires column-major scaling factors
      output = gemm_fp8_nt_groupwise(
          q_input,
          weight,
          x_scale,
          weight_scale,
-         scale_major_mode="K",
          out_dtype=input_2d.dtype,
+         backend="trtllm",
      )

      if bias is not None:
@@ -586,14 +587,25 @@ def apply_fp8_linear(
          assert (
              weight_scale.numel() == weight.shape[1]
          ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
-         output = fp8_scaled_mm(
-             qinput,
-             weight,
-             x_scale,
-             weight_scale,
-             out_dtype=input.dtype,
-             bias=bias,
+
+         cutlass_compatible_b = (
+             weight.shape[0] % 16 == 0 and weight.shape[1] % 16 == 0
          )
+         if not cutlass_compatible_b:
+             # Massage the input to be 2D
+             qinput = qinput.view(-1, qinput.shape[-1])
+             output = triton_scaled_mm(
+                 qinput, weight, x_scale, weight_scale, input.dtype, bias
+             )
+         else:
+             output = fp8_scaled_mm(
+                 qinput,
+                 weight,
+                 x_scale,
+                 weight_scale,
+                 out_dtype=input.dtype,
+                 bias=bias,
+             )
          return output.view(*output_shape)

      # torch.scaled_mm supports per tensor weights + activations only
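Read together with the fp8_kernel.py hunk above, the rewritten branch keeps sgl-kernel's CUTLASS `fp8_scaled_mm` on the fast path and only falls back to `triton_scaled_mm` for weights whose dimensions are not 16-aligned. A condensed sketch of the dispatch rule; the helper name is illustrative, not from the diff:

```python
import torch

def _cutlass_compatible(weight: torch.Tensor) -> bool:
    # fp8_scaled_mm from sgl-kernel requires both weight dims to be multiples of 16;
    # ragged shapes go through the Triton scaled-mm fallback instead.
    return weight.shape[0] % 16 == 0 and weight.shape[1] % 16 == 0
```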
sglang/srt/layers/quantization/gptq.py
@@ -36,9 +36,9 @@ from sglang.srt.layers.quantization.marlin_utils import (
      marlin_zero_points,
      verify_marlin_supported,
  )
- from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
  from sglang.srt.layers.quantization.utils import (
      get_linear_quant_method,
+     get_scalar_types,
      replace_parameter,
      unpack_cols,
  )
@@ -46,20 +46,16 @@ from sglang.srt.layers.quantization.utils import (
  if TYPE_CHECKING:
      from sglang.srt.layers.moe.topk import TopKOutput

- try:
-     from vllm import _custom_ops as ops
- except ImportError:
-     ops = None
-
  from sglang.srt.utils import is_cuda

  _is_cuda = is_cuda()

  if _is_cuda:
-     from sgl_kernel import fused_marlin_moe
+     from sgl_kernel import fused_marlin_moe, gptq_gemm, gptq_marlin_repack, gptq_shuffle


  logger = logging.getLogger(__name__)
+ ScalarType, scalar_types = get_scalar_types()


  def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool:
@@ -85,9 +81,7 @@ def gptq_marlin_moe_repack(
          dtype=b_q_weight.dtype,
      )
      for e in range(num_experts):
-         output[e] = torch.ops.sgl_kernel.gptq_marlin_repack(
-             b_q_weight[e], perm[e], size_k, size_n, num_bits
-         )
+         output[e] = gptq_marlin_repack(b_q_weight[e], perm[e], size_k, size_n, num_bits)
      return output


@@ -204,11 +198,12 @@ class GPTQConfig(QuantizationConfig):
          from sglang.srt.layers.linear import LinearBase
          from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

-         if isinstance(layer, LinearBase):
-             return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
-         elif isinstance(layer, FusedMoE):
+         if isinstance(layer, FusedMoE):
              raise TypeError("GPTQ Method does not support MoE, please use gptq_marlin")
-         return None
+         else:
+             return get_linear_quant_method(
+                 self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
+             )


  class GPTQMarlinConfig(QuantizationConfig):
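The control flow in `GPTQConfig.get_quant_method` is inverted by this hunk: FusedMoE layers are rejected up front, and every other layer is routed through `get_linear_quant_method` (including layers that previously fell through and returned `None`). A hedged sketch of the resulting behavior; `gptq_config`, `linear_layer`, and `fused_moe_layer` are hypothetical objects used only for illustration:

```python
# Hypothetical objects, illustrating the branch added above.
method = gptq_config.get_quant_method(linear_layer, prefix="model.layers.0.self_attn.qkv_proj")
# -> a GPTQLinearMethod (or whatever get_linear_quant_method decides for this prefix)

gptq_config.get_quant_method(fused_moe_layer, prefix="model.layers.0.mlp.experts")
# -> raises TypeError: "GPTQ Method does not support MoE, please use gptq_marlin"
```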
@@ -530,7 +525,7 @@ class GPTQLinearMethod(LinearMethodBase):
          layer.g_idx.data = torch.empty(
              (0,), dtype=torch.int, device=layer.g_idx.device
          )
-         ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits)
+         gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits)

      def apply(
          self,
@@ -541,7 +536,7 @@ class GPTQLinearMethod(LinearMethodBase):
          out_shape = x.shape[:-1] + (layer.qweight.shape[-1],)
          reshaped_x = x.reshape(-1, x.shape[-1])

-         output = ops.gptq_gemm(
+         output = gptq_gemm(
              reshaped_x,
              layer.qweight,
              layer.qzeros,
@@ -726,7 +721,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
          def transform_w_q(x):
              assert isinstance(x, BasevLLMParameter)
              permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
-             x.data = torch.ops.sgl_kernel.gptq_marlin_repack(
+             x.data = gptq_marlin_repack(
                  x.data.contiguous(),
                  perm=layer.g_idx_sort_indices,
                  size_k=c.partition_weight_shape[0],
sglang/srt/layers/quantization/marlin_utils.py
@@ -19,9 +19,12 @@ from sglang.srt.layers.quantization.base_config import (
      LinearMethodBase,
      QuantizationConfig,
  )
- from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
- from sglang.srt.layers.quantization.utils import pack_cols, unpack_cols
- from sglang.srt.utils import get_device_capability
+ from sglang.srt.layers.quantization.utils import (
+     get_scalar_types,
+     pack_cols,
+     unpack_cols,
+ )
+ from sglang.srt.utils import get_device_capability, is_cuda

  if TYPE_CHECKING:
      from sglang.srt.layers.linear import LinearBase
@@ -31,8 +34,15 @@ try:
  except ImportError:
      ops = None

+ _is_cuda = is_cuda()
+
+ if _is_cuda:
+     from sgl_kernel import gptq_marlin_gemm
+
  logger = logging.getLogger(__name__)

+ ScalarType, scalar_types = get_scalar_types()
+
  GPTQ_MARLIN_TILE = 16
  GPTQ_MARLIN_MIN_THREAD_N = 64
  GPTQ_MARLIN_MIN_THREAD_K = 128
@@ -453,7 +463,7 @@ def apply_gptq_marlin_linear(
          dtype=input.dtype,
      )

-     output = ops.gptq_marlin_gemm(
+     output = gptq_marlin_gemm(
          reshaped_x,
          None,
          weight,
@@ -504,7 +514,7 @@ def apply_awq_marlin_linear(
          dtype=input.dtype,
      )

-     output = ops.gptq_marlin_gemm(
+     output = gptq_marlin_gemm(
          reshaped_x,
          None,
          weight,