sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
@@ -1356,3 +1356,280 @@ def per_token_group_quant_fp8_hopper_moe_mn_major(
         expert_tokens_alignment,
     )
     return a_q, sfa
+
+
+@triton.jit
+def _per_group_transpose(
+    data_ptr: torch.Tensor,
+    trans_data_ptr: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    k: int,
+    M_ALIGNMENT: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+):
+    expert_id = tl.program_id(0)
+    m_id = tl.program_id(1)
+    k_id = tl.program_id(2)
+
+    curr_expert_offset = tl.load(expert_offsets + expert_id)
+    next_expert_offset = tl.load(expert_offsets + expert_id + 1)
+    num_tokens_of_expert = next_expert_offset - curr_expert_offset
+    tl.multiple_of(curr_expert_offset, M_ALIGNMENT)
+    tl.multiple_of(next_expert_offset, M_ALIGNMENT)
+
+    data_start_ptr = data_ptr + curr_expert_offset * k
+    trans_data_start_ptr = trans_data_ptr + curr_expert_offset * k
+
+    k_coord = k_id * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
+    k_mask = k_coord < k
+    for start_m in tl.range(0, num_tokens_of_expert, BLOCK_SIZE_M * tl.num_programs(1)):
+        m_coord = start_m + m_id * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        m_mask = m_coord < num_tokens_of_expert
+        off = m_coord[:, None] * k + k_coord[None, :]
+        trans_off = m_coord[:, None] + k_coord[None, :] * num_tokens_of_expert
+        mask = m_mask[:, None] & k_mask[None, :]
+
+        data = tl.load(data_start_ptr + off, mask=mask)
+        tl.store(trans_data_start_ptr + trans_off, data, mask=mask)
+
+
+def per_group_transpose(
+    a: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    M_ALIGNMENT: int = 1,
+) -> torch.Tensor:
+    assert a.dim() == 2
+    assert a.is_contiguous(), "`a` is not contiguous"
+
+    m, k = a.size()
+    trans_a = torch.empty_like(a)
+    num_experts = expert_offsets.size(0) - 1
+
+    grid = lambda META: (
+        num_experts,
+        triton.cdiv((m + num_experts - 1) // num_experts, META["BLOCK_SIZE_M"]),
+        triton.cdiv(k, META["BLOCK_SIZE_K"]),
+    )
+    _per_group_transpose[grid](
+        a, trans_a, expert_offsets, k, M_ALIGNMENT, BLOCK_SIZE_M=16, BLOCK_SIZE_K=8
+    )
+    return trans_a
+
+
+def is_weak_contiguous(x: torch.Tensor):
+    strides = x.stride()
+    sizes = x.shape
+    is_not_transpose = strides[0] == 1 and (strides[1] >= max(1, sizes[0]))
+    is_transpose = strides[1] == 1 and (strides[0] >= max(1, sizes[1]))
+    return is_transpose or is_not_transpose
+
+
+@triton.jit
+def scaled_mm_kernel(
+    a_ptr,
+    b_ptr,
+    scale_a_ptr,
+    scale_b_ptr,
+    c_ptr,
+    bias_ptr,
+    M,
+    N,
+    K,
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    ACCUMULATOR_DTYPE: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    BLOCK_SIZE_SCALE_A: tl.constexpr,
+    BLOCK_SIZE_SCALE_B: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+
+    pid_m = pid // num_pid_n
+    pid_n = pid % num_pid_n
+
+    accumulator_dtype = ACCUMULATOR_DTYPE
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype)
+
+    # NOTE: Some tensor inputs are so large, they will cause int32 overflow
+    # so it is necessary to use tl.int64 for all the offsets, else SEGV will
+    # eventually occur.
+
+    # Offsets and masks.
+    offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+    masks_am = offsets_am < M
+
+    offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
+    masks_bn = offsets_bn < N
+
+    offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
+    offsets_a = stride_am * offsets_am[:, None] + stride_ak * offsets_k[None, :]
+    offsets_b = stride_bk * offsets_k[:, None] + stride_bn * offsets_bn[None, :]
+
+    # NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so need to create
+    # appropriate offsets and masks for each case. Same goes for
+    # BLOCK_SIZE_SCALE_B.
+    offsets_scale_am = (
+        tl.arange(0, BLOCK_SIZE_SCALE_A)
+        + (BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M
+    )
+    masks_scale_am = offsets_scale_am < M
+
+    offsets_scale_bn = (
+        tl.arange(0, BLOCK_SIZE_SCALE_B)
+        + (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N
+    )
+    masks_scale_bn = offsets_scale_bn < N
+
+    a_ptrs = a_ptr + offsets_a
+    b_ptrs = b_ptr + offsets_b
+
+    scale_a_ptrs = scale_a_ptr + offsets_scale_am
+    scale_b_ptrs = scale_b_ptr + offsets_scale_bn
+
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        masks_k = offsets_k < K
+        masks_a = masks_am[:, None] & masks_k[None, :]
+        a = tl.load(a_ptrs, mask=masks_a)
+
+        masks_b = masks_k[:, None] & masks_bn[None, :]
+        b = tl.load(b_ptrs, mask=masks_b)
+
+        # Accumulate results.
+        accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)
+
+        offsets_k += BLOCK_SIZE_K
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    # Apply scale at end.
+    masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None]
+    scale_a = tl.load(scale_a_ptrs[:, None], masks_scale_a)
+    # Need to broadcast to the appropriate size, if scale_a is already
+    # (BLOCK_SIZE_M, 1) then it will broadcast to its own shape. Same goes
+    # for scale_b below.
+    scale_a = scale_a.broadcast_to((BLOCK_SIZE_M, 1))
+    accumulator = scale_a * accumulator.to(tl.float32)
+
+    masks_scale_b = masks_scale_bn[:, None] & (tl.arange(0, 1) < 1)[None, :]
+    scale_b = tl.load(scale_b_ptrs[:, None], masks_scale_b)
+    scale_b = scale_b.broadcast_to((BLOCK_SIZE_N, 1))
+    accumulator = scale_b.T * accumulator.to(tl.float32)
+
+    # Convert to output format.
+    c = accumulator.to(c_ptr.type.element_ty)
+
+    # Add bias, it's already in output format, so add it after conversion.
+    if bias_ptr:
+        offsets_bias = offsets_bn
+        bias_ptrs = bias_ptr + offsets_bias
+        bias_mask = offsets_bias < N
+        bias = tl.load(bias_ptrs, bias_mask)
+        c += bias
+
+    # Save output
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
+    offs_cm = offs_cm.to(tl.int64)
+    offs_cn = offs_cn.to(tl.int64)
+    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+
+    tl.store(c_ptrs, c, mask=c_mask)
+
+
+# input - [M, K]
+# weight - [K, N]
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py
+def triton_scaled_mm(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    scale_a: torch.Tensor,
+    scale_b: torch.Tensor,
+    out_dtype: type[torch.dtype],
+    bias: Optional[torch.Tensor] = None,
+    block_size_m: int = 32,
+    block_size_n: int = 32,
+    block_size_k: int = 32,
+    use_heuristic=True,
+) -> torch.Tensor:
+    M, K = input.shape
+    N = weight.shape[1]
+
+    assert N > 0 and K > 0 and M > 0
+    assert weight.shape[0] == K
+    assert input.dtype == weight.dtype
+
+    scale_a = scale_a.reshape(-1, 1) if scale_a.dim() <= 1 else scale_a
+    scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b
+
+    assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
+    assert scale_a.shape[1] == 1 and (scale_a.shape[0] == 1 or scale_a.shape[0] == M)
+    assert scale_b.shape[1] == 1 and (scale_b.shape[0] == 1 or scale_b.shape[0] == N)
+    assert out_dtype.is_floating_point
+    assert bias is None or bias.is_floating_point()
+    assert is_weak_contiguous(input)
+    assert is_weak_contiguous(weight)
+
+    grid = lambda META: (
+        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+    )
+
+    result = torch.empty((M, N), dtype=out_dtype, device=input.device)
+
+    has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1
+
+    if use_heuristic:
+        is_small_N = N < 8192
+        next_power_of_2_M = max(32, triton.next_power_of_2(M))
+        if next_power_of_2_M <= 32:
+            tile_shape = (64, 64, 256) if is_small_N else (64, 128, 256)
+        elif next_power_of_2_M <= 64:
+            tile_shape = (64, 64, 256)
+        elif next_power_of_2_M <= 128:
+            tile_shape = (64, 128, 128)
+        else:
+            tile_shape = (128, 128, 128)
+
+        block_size_m, block_size_n, block_size_k = tile_shape
+
+    block_size_sa = 1 if has_scalar(scale_a) else block_size_m
+    block_size_sb = 1 if has_scalar(scale_b) else block_size_n
+
+    accumulator_dtype = tl.float32 if input.is_floating_point() else tl.int32
+
+    # A = input, B = weight, C = result
+    # A = M x K, B = K x N, C = M x N
+    scaled_mm_kernel[grid](
+        input,
+        weight,
+        scale_a,
+        scale_b,
+        result,
+        bias,
+        M,
+        N,
+        K,
+        input.stride(0),
+        input.stride(1),
+        weight.stride(0),
+        weight.stride(1),
+        result.stride(0),
+        result.stride(1),
+        accumulator_dtype,
+        BLOCK_SIZE_M=block_size_m,
+        BLOCK_SIZE_N=block_size_n,
+        BLOCK_SIZE_K=block_size_k,
+        BLOCK_SIZE_SCALE_A=block_size_sa,
+        BLOCK_SIZE_SCALE_B=block_size_sb,
+    )
+
+    return result.to(out_dtype)
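
The new triton_scaled_mm computes C = scale_a * (A @ B) * scale_b, with scale_a per-row or per-tensor and scale_b per-column or per-tensor. A minimal usage sketch, assuming a CUDA device and int8 operands; the shapes, random data, and tolerances below are illustrative only, not taken from the diff:

import torch
from sglang.srt.layers.quantization.fp8_kernel import triton_scaled_mm

M, K, N = 64, 512, 1024
a = torch.randint(-128, 128, (M, K), dtype=torch.int8, device="cuda")
b = torch.randint(-128, 128, (K, N), dtype=torch.int8, device="cuda")
scale_a = torch.rand(M, 1, dtype=torch.float32, device="cuda")  # per-token
scale_b = torch.rand(N, 1, dtype=torch.float32, device="cuda")  # per-channel

out = triton_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)

# Reference: dequantize in fp32, then matmul.
ref = (scale_a * (a.float() @ b.float()) * scale_b.t()).to(torch.bfloat16)
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)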
@@ -4,6 +4,7 @@ import torch
 
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
+from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
 from sglang.srt.layers.utils import is_sm100_supported
 
 try:
@@ -21,11 +22,13 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     scaled_fp8_quant,
     sglang_per_token_quant_fp8,
     static_quant_fp8,
+    triton_scaled_mm,
     w8a8_block_fp8_matmul_deepgemm,
     w8a8_block_fp8_matmul_triton,
 )
 from sglang.srt.utils import (
     align,
+    ceil_div,
     get_bool_env_var,
     get_cuda_version,
     get_device_capability,
@@ -159,16 +162,16 @@ def flashinfer_gemm_w8a8_block_fp8_linear(
     output_shape = [*input.shape[:-1], weight.shape[0]]
 
     q_input, x_scale = sglang_per_token_group_quant_fp8(
-        input_2d, block_size[1], column_major_scales=False
+        input_2d, block_size[1], column_major_scales=True
     )
-
+    # TRTLLM requires column-major scaling factors
     output = gemm_fp8_nt_groupwise(
         q_input,
         weight,
         x_scale,
         weight_scale,
-        scale_major_mode="K",
         out_dtype=input_2d.dtype,
+        backend="trtllm",
     )
 
     if bias is not None:
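
This hunk switches the activation scales to column-major order and selects the trtllm backend of gemm_fp8_nt_groupwise, which expects that layout. A rough illustration of the layout difference, using a hypothetical (M, K) activation, group size 128, and the fp8-e4m3 max of 448 only to form plausible scales; the real kernel may express the layout through strides rather than a materialized transpose:

import torch

M, K, group = 4, 256, 128
x = torch.randn(M, K, dtype=torch.bfloat16)

# One scale per (token, K-group): row-major logical shape (M, K // group).
s_row = x.view(M, K // group, group).abs().amax(dim=-1).float() / 448.0

# Column-major layout keeps the same values with the group axis outermost.
s_col = s_row.t().contiguous()  # (K // group, M)
assert s_col[1, 3] == s_row[3, 1]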
@@ -307,6 +310,33 @@ def triton_w8a8_block_fp8_linear(
     return output.to(dtype=input_2d.dtype).view(*output_shape)
 
 
+def dequant_mxfp4(
+    w_block: torch.Tensor,
+    w_scale: torch.Tensor,
+    out_dtype,
+) -> torch.Tensor:
+    """
+    :param w_block: (batch, n, k, 16), uint8, pack two mxfp4 into one byte
+    :param w_scale: (batch, n, k), uint8
+    :return: (batch, n, k * 32), float32
+    """
+
+    assert w_block.dtype == torch.uint8
+    assert w_scale.dtype == torch.uint8
+
+    batch, n, k, pack_dim = w_block.shape
+    batch_, n_, k_ = w_scale.shape
+    assert pack_dim == 16
+    assert batch == batch_
+    assert n == n_
+    assert k == k_
+
+    out_raw = MXFP4QuantizeUtil.dequantize(
+        quantized_data=w_block, scale=w_scale, dtype=out_dtype, block_sizes=[32]
+    )
+    return out_raw.reshape(batch, n, k * 32)
+
+
 def input_to_float8(
     x: torch.Tensor, dtype: torch.dtype = fp8_dtype
 ) -> Tuple[torch.Tensor, torch.Tensor]:
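
dequant_mxfp4 follows the OCP MX format: each uint8 packs two fp4-e2m1 values, so a (..., k, 16)-byte block holds the 32 elements that share one e8m0 scale, which is why the output widens to k * 32. Below is a hand-written reference decode for clarity; the low-nibble-first packing order is an assumption, and MXFP4QuantizeUtil is deliberately not used:

import torch

# fp4-e2m1 magnitudes indexed by the low 3 bits; the high bit is the sign.
E2M1_LUT = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def dequant_mxfp4_reference(w_block: torch.Tensor, w_scale: torch.Tensor) -> torch.Tensor:
    # w_block: (batch, n, k, 16) uint8, two fp4 values per byte (32 per block)
    # w_scale: (batch, n, k) uint8, one e8m0 exponent per 32-element block
    lut = E2M1_LUT.to(w_block.device)
    lo = (w_block & 0x0F).long()
    hi = (w_block >> 4).long()
    nibbles = torch.stack([lo, hi], dim=-1).flatten(-2)         # (batch, n, k, 32)
    magnitudes = lut[nibbles & 0x7]
    signs = torch.where((nibbles & 0x8) != 0, -1.0, 1.0)
    scales = torch.exp2(w_scale.float() - 127.0).unsqueeze(-1)  # e8m0: 2^(x - 127)
    return (signs * magnitudes * scales).flatten(-2)            # (batch, n, k * 32)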
@@ -557,14 +587,25 @@ def apply_fp8_linear(
         assert (
             weight_scale.numel() == weight.shape[1]
         ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
-        output = fp8_scaled_mm(
-            qinput,
-            weight,
-            x_scale,
-            weight_scale,
-            out_dtype=input.dtype,
-            bias=bias,
+
+        cutlass_compatible_b = (
+            weight.shape[0] % 16 == 0 and weight.shape[1] % 16 == 0
         )
+        if not cutlass_compatible_b:
+            # Massage the input to be 2D
+            qinput = qinput.view(-1, qinput.shape[-1])
+            output = triton_scaled_mm(
+                qinput, weight, x_scale, weight_scale, input.dtype, bias
+            )
+        else:
+            output = fp8_scaled_mm(
+                qinput,
+                weight,
+                x_scale,
+                weight_scale,
+                out_dtype=input.dtype,
+                bias=bias,
+            )
         return output.view(*output_shape)
 
     # torch.scaled_mm supports per tensor weights + activations only
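
The apply_fp8_linear change above is a shape-based dispatch: sgl-kernel's cutlass fp8 GEMM requires both weight dimensions to be multiples of 16, and all other shapes now fall back to triton_scaled_mm. A toy restatement of the predicate (the function name is hypothetical, for illustration):

def picks_cutlass(weight_shape) -> bool:
    # Mirrors cutlass_compatible_b in apply_fp8_linear.
    n, k = weight_shape
    return n % 16 == 0 and k % 16 == 0

assert picks_cutlass((4096, 4096))      # cutlass fp8_scaled_mm path
assert not picks_cutlass((4100, 4096))  # falls back to triton_scaled_mm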