sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +119 -17
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +42 -7
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +14 -4
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +286 -160
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +15 -11
- sglang/srt/entrypoints/context.py +227 -0
- sglang/srt/entrypoints/engine.py +15 -9
- sglang/srt/entrypoints/harmony_utils.py +372 -0
- sglang/srt/entrypoints/http_server.py +74 -4
- sglang/srt/entrypoints/openai/protocol.py +218 -1
- sglang/srt/entrypoints/openai/serving_chat.py +41 -11
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +175 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +14 -1
- sglang/srt/layers/attention/aiter_backend.py +375 -115
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +22 -6
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +29 -14
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +3 -7
- sglang/srt/layers/moe/cutlass_moe.py +12 -3
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +135 -73
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +16 -4
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +27 -3
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +51 -10
- sglang/srt/layers/quantization/modelopt_quant.py +258 -68
- sglang/srt/layers/quantization/mxfp4.py +654 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +21 -12
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +506 -3
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +82 -62
- sglang/srt/lora/lora_registry.py +23 -11
- sglang/srt/lora/mem_pool.py +63 -68
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +75 -58
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -8
- sglang/srt/managers/mm_utils.py +6 -13
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +61 -25
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +41 -19
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +47 -30
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +80 -22
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +34 -36
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -9
- sglang/srt/model_executor/forward_batch_info.py +61 -19
- sglang/srt/model_executor/model_runner.py +148 -37
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +137 -59
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +38 -0
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +28 -16
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +1251 -0
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +332 -37
- sglang/srt/server_args.py +186 -75
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +169 -9
- sglang/srt/utils.py +41 -5
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/runners.py +2 -2
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_kernel.py
@@ -1356,3 +1356,280 @@ def per_token_group_quant_fp8_hopper_moe_mn_major(
         expert_tokens_alignment,
     )
     return a_q, sfa
+
+
+@triton.jit
+def _per_group_transpose(
+    data_ptr: torch.Tensor,
+    trans_data_ptr: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    k: int,
+    M_ALIGNMENT: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+):
+    expert_id = tl.program_id(0)
+    m_id = tl.program_id(1)
+    k_id = tl.program_id(2)
+
+    curr_expert_offset = tl.load(expert_offsets + expert_id)
+    next_expert_offset = tl.load(expert_offsets + expert_id + 1)
+    num_tokens_of_expert = next_expert_offset - curr_expert_offset
+    tl.multiple_of(curr_expert_offset, M_ALIGNMENT)
+    tl.multiple_of(next_expert_offset, M_ALIGNMENT)
+
+    data_start_ptr = data_ptr + curr_expert_offset * k
+    trans_data_start_ptr = trans_data_ptr + curr_expert_offset * k
+
+    k_coord = k_id * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
+    k_mask = k_coord < k
+    for start_m in tl.range(0, num_tokens_of_expert, BLOCK_SIZE_M * tl.num_programs(1)):
+        m_coord = start_m + m_id * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        m_mask = m_coord < num_tokens_of_expert
+        off = m_coord[:, None] * k + k_coord[None, :]
+        trans_off = m_coord[:, None] + k_coord[None, :] * num_tokens_of_expert
+        mask = m_mask[:, None] & k_mask[None, :]
+
+        data = tl.load(data_start_ptr + off, mask=mask)
+        tl.store(trans_data_start_ptr + trans_off, data, mask=mask)
+
+
+def per_group_transpose(
+    a: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    M_ALIGNMENT: int = 1,
+) -> torch.Tensor:
+    assert a.dim() == 2
+    assert a.is_contiguous(), "`a` is not contiguous"
+
+    m, k = a.size()
+    trans_a = torch.empty_like(a)
+    num_experts = expert_offsets.size(0) - 1
+
+    grid = lambda META: (
+        num_experts,
+        triton.cdiv((m + num_experts - 1) // num_experts, META["BLOCK_SIZE_M"]),
+        triton.cdiv(k, META["BLOCK_SIZE_K"]),
+    )
+    _per_group_transpose[grid](
+        a, trans_a, expert_offsets, k, M_ALIGNMENT, BLOCK_SIZE_M=16, BLOCK_SIZE_K=8
+    )
+    return trans_a
+
+
+def is_weak_contiguous(x: torch.Tensor):
+    strides = x.stride()
+    sizes = x.shape
+    is_not_transpose = strides[0] == 1 and (strides[1] >= max(1, sizes[0]))
+    is_transpose = strides[1] == 1 and (strides[0] >= max(1, sizes[1]))
+    return is_transpose or is_not_transpose
+
+
+@triton.jit
+def scaled_mm_kernel(
+    a_ptr,
+    b_ptr,
+    scale_a_ptr,
+    scale_b_ptr,
+    c_ptr,
+    bias_ptr,
+    M,
+    N,
+    K,
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    ACCUMULATOR_DTYPE: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    BLOCK_SIZE_SCALE_A: tl.constexpr,
+    BLOCK_SIZE_SCALE_B: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+
+    pid_m = pid // num_pid_n
+    pid_n = pid % num_pid_n
+
+    accumulator_dtype = ACCUMULATOR_DTYPE
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype)
+
+    # NOTE: Some tensor inputs are so large, they will cause int32 overflow
+    # so it is necessary to use tl.int64 for all the offsets, else SEGV will
+    # eventually occur.
+
+    # Offsets and masks.
+    offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+    masks_am = offsets_am < M
+
+    offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
+    masks_bn = offsets_bn < N
+
+    offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
+    offsets_a = stride_am * offsets_am[:, None] + stride_ak * offsets_k[None, :]
+    offsets_b = stride_bk * offsets_k[:, None] + stride_bn * offsets_bn[None, :]
+
+    # NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so need to create
+    # appropriate offsets and masks for each case. Same goes for
+    # BLOCK_SIZE_SCALE_B.
+    offsets_scale_am = (
+        tl.arange(0, BLOCK_SIZE_SCALE_A)
+        + (BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M
+    )
+    masks_scale_am = offsets_scale_am < M
+
+    offsets_scale_bn = (
+        tl.arange(0, BLOCK_SIZE_SCALE_B)
+        + (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N
+    )
+    masks_scale_bn = offsets_scale_bn < N
+
+    a_ptrs = a_ptr + offsets_a
+    b_ptrs = b_ptr + offsets_b
+
+    scale_a_ptrs = scale_a_ptr + offsets_scale_am
+    scale_b_ptrs = scale_b_ptr + offsets_scale_bn
+
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        masks_k = offsets_k < K
+        masks_a = masks_am[:, None] & masks_k[None, :]
+        a = tl.load(a_ptrs, mask=masks_a)
+
+        masks_b = masks_k[:, None] & masks_bn[None, :]
+        b = tl.load(b_ptrs, mask=masks_b)
+
+        # Accumulate results.
+        accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)
+
+        offsets_k += BLOCK_SIZE_K
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    # Apply scale at end.
+    masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None]
+    scale_a = tl.load(scale_a_ptrs[:, None], masks_scale_a)
+    # Need to broadcast to the appropriate size, if scale_a is already
+    # (BLOCK_SIZE_M, 1) then it will broadcast to its own shape. Same goes
+    # for scale_b below.
+    scale_a = scale_a.broadcast_to((BLOCK_SIZE_M, 1))
+    accumulator = scale_a * accumulator.to(tl.float32)
+
+    masks_scale_b = masks_scale_bn[:, None] & (tl.arange(0, 1) < 1)[None, :]
+    scale_b = tl.load(scale_b_ptrs[:, None], masks_scale_b)
+    scale_b = scale_b.broadcast_to((BLOCK_SIZE_N, 1))
+    accumulator = scale_b.T * accumulator.to(tl.float32)
+
+    # Convert to output format.
+    c = accumulator.to(c_ptr.type.element_ty)
+
+    # Add bias, it's already in output format, so add it after conversion.
+    if bias_ptr:
+        offsets_bias = offsets_bn
+        bias_ptrs = bias_ptr + offsets_bias
+        bias_mask = offsets_bias < N
+        bias = tl.load(bias_ptrs, bias_mask)
+        c += bias
+
+    # Save output
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
+    offs_cm = offs_cm.to(tl.int64)
+    offs_cn = offs_cn.to(tl.int64)
+    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+
+    tl.store(c_ptrs, c, mask=c_mask)
+
+
+# input  - [M, K]
+# weight - [K, N]
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py
+def triton_scaled_mm(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    scale_a: torch.Tensor,
+    scale_b: torch.Tensor,
+    out_dtype: type[torch.dtype],
+    bias: Optional[torch.Tensor] = None,
+    block_size_m: int = 32,
+    block_size_n: int = 32,
+    block_size_k: int = 32,
+    use_heuristic=True,
+) -> torch.Tensor:
+    M, K = input.shape
+    N = weight.shape[1]
+
+    assert N > 0 and K > 0 and M > 0
+    assert weight.shape[0] == K
+    assert input.dtype == weight.dtype
+
+    scale_a = scale_a.reshape(-1, 1) if scale_a.dim() <= 1 else scale_a
+    scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b
+
+    assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
+    assert scale_a.shape[1] == 1 and (scale_a.shape[0] == 1 or scale_a.shape[0] == M)
+    assert scale_b.shape[1] == 1 and (scale_b.shape[0] == 1 or scale_b.shape[0] == N)
+    assert out_dtype.is_floating_point
+    assert bias is None or bias.is_floating_point()
+    assert is_weak_contiguous(input)
+    assert is_weak_contiguous(weight)
+
+    grid = lambda META: (
+        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+    )
+
+    result = torch.empty((M, N), dtype=out_dtype, device=input.device)
+
+    has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1
+
+    if use_heuristic:
+        is_small_N = N < 8192
+        next_power_of_2_M = max(32, triton.next_power_of_2(M))
+        if next_power_of_2_M <= 32:
+            tile_shape = (64, 64, 256) if is_small_N else (64, 128, 256)
+        elif next_power_of_2_M <= 64:
+            tile_shape = (64, 64, 256)
+        elif next_power_of_2_M <= 128:
+            tile_shape = (64, 128, 128)
+        else:
+            tile_shape = (128, 128, 128)
+
+        block_size_m, block_size_n, block_size_k = tile_shape
+
+    block_size_sa = 1 if has_scalar(scale_a) else block_size_m
+    block_size_sb = 1 if has_scalar(scale_b) else block_size_n
+
+    accumulator_dtype = tl.float32 if input.is_floating_point() else tl.int32
+
+    # A = input, B = weight, C = result
+    # A = M x K, B = K x N, C = M x N
+    scaled_mm_kernel[grid](
+        input,
+        weight,
+        scale_a,
+        scale_b,
+        result,
+        bias,
+        M,
+        N,
+        K,
+        input.stride(0),
+        input.stride(1),
+        weight.stride(0),
+        weight.stride(1),
+        result.stride(0),
+        result.stride(1),
+        accumulator_dtype,
+        BLOCK_SIZE_M=block_size_m,
+        BLOCK_SIZE_N=block_size_n,
+        BLOCK_SIZE_K=block_size_k,
+        BLOCK_SIZE_SCALE_A=block_size_sa,
+        BLOCK_SIZE_SCALE_B=block_size_sb,
+    )
+
+    return result.to(out_dtype)
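
For context, the new `triton_scaled_mm` is a pure-Triton fallback for scaled matmul on quantized inputs: it multiplies a quantized activation matrix by a quantized weight matrix and applies a per-token scale (`scale_a`) and a per-channel scale (`scale_b`) to the accumulator. A minimal usage sketch, assuming a CUDA device with FP8-capable hardware; the shapes, dtypes, and random data below are illustrative only and not taken from sglang's tests:

    import torch
    from sglang.srt.layers.quantization.fp8_kernel import triton_scaled_mm

    M, K, N = 64, 256, 512
    a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)  # quantized activations [M, K]
    b = torch.randn(K, N, device="cuda").to(torch.float8_e4m3fn)  # quantized weight [K, N]
    scale_a = torch.rand(M, 1, device="cuda")  # one float32 scale per token (row)
    scale_b = torch.rand(N, 1, device="cuda")  # one float32 scale per output channel

    out = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
    assert out.shape == (M, N) and out.dtype == torch.bfloat16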
sglang/srt/layers/quantization/fp8_utils.py
@@ -4,6 +4,7 @@ import torch
 
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
+from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
 from sglang.srt.layers.utils import is_sm100_supported
 
 try:
@@ -21,11 +22,13 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     scaled_fp8_quant,
     sglang_per_token_quant_fp8,
     static_quant_fp8,
+    triton_scaled_mm,
     w8a8_block_fp8_matmul_deepgemm,
     w8a8_block_fp8_matmul_triton,
 )
 from sglang.srt.utils import (
     align,
+    ceil_div,
     get_bool_env_var,
     get_cuda_version,
     get_device_capability,
@@ -159,16 +162,16 @@ def flashinfer_gemm_w8a8_block_fp8_linear(
     output_shape = [*input.shape[:-1], weight.shape[0]]
 
     q_input, x_scale = sglang_per_token_group_quant_fp8(
-        input_2d, block_size[1], column_major_scales=False
+        input_2d, block_size[1], column_major_scales=True
     )
-
+    # TRTLLM requires column-major scaling factors
     output = gemm_fp8_nt_groupwise(
         q_input,
         weight,
         x_scale,
         weight_scale,
-        scale_major_mode="K",
         out_dtype=input_2d.dtype,
+        backend="trtllm",
     )
 
     if bias is not None:
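
The call-site change above switches the per-token-group quantization to emit its scaling factors in column-major layout, which the TRTLLM grouped-GEMM backend expects. A rough, standalone illustration of what that layout means in plain PyTorch (not the sglang API; the real scale-tensor shape depends on the model's hidden size and group size):

    import torch

    num_tokens, num_groups = 4, 2  # e.g. hidden=256 with group size 128
    scales = torch.rand(num_tokens, num_groups)      # row-major: token-major storage
    scales_cm = scales.t().contiguous().t()          # same values, column-major storage
    assert scales_cm.stride() == (1, num_tokens)     # column-major strides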
@@ -307,6 +310,33 @@ def triton_w8a8_block_fp8_linear(
     return output.to(dtype=input_2d.dtype).view(*output_shape)
 
 
+def dequant_mxfp4(
+    w_block: torch.Tensor,
+    w_scale: torch.Tensor,
+    out_dtype,
+) -> torch.Tensor:
+    """
+    :param w_block: (batch, n, k, 16), uint8, pack two mxfp4 into one byte
+    :param w_scale: (batch, n, k), uint8
+    :return: (batch, n, k * 32), float32
+    """
+
+    assert w_block.dtype == torch.uint8
+    assert w_scale.dtype == torch.uint8
+
+    batch, n, k, pack_dim = w_block.shape
+    batch_, n_, k_ = w_scale.shape
+    assert pack_dim == 16
+    assert batch == batch_
+    assert n == n_
+    assert k == k_
+
+    out_raw = MXFP4QuantizeUtil.dequantize(
+        quantized_data=w_block, scale=w_scale, dtype=out_dtype, block_sizes=[32]
+    )
+    return out_raw.reshape(batch, n, k * 32)
+
+
 def input_to_float8(
     x: torch.Tensor, dtype: torch.dtype = fp8_dtype
 ) -> Tuple[torch.Tensor, torch.Tensor]:
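
`dequant_mxfp4` unpacks MXFP4 weights, where each block stores 32 four-bit values in 16 bytes plus one uint8 exponent scale, back into a dense tensor. A shape-level usage sketch, assuming this hunk lands in `sglang/srt/layers/quantization/fp8_utils.py` (as the surrounding imports suggest) and that the MXFP4 helpers are available in the build; the random bytes are placeholders, not meaningful weights:

    import torch
    from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4

    batch, n, k_blocks = 2, 8, 4  # k_blocks MXFP4 blocks of 32 values each
    w_block = torch.randint(0, 256, (batch, n, k_blocks, 16), dtype=torch.uint8)  # packed nibbles
    w_scale = torch.randint(0, 256, (batch, n, k_blocks), dtype=torch.uint8)      # per-block scales

    w = dequant_mxfp4(w_block, w_scale, torch.float32)
    assert w.shape == (batch, n, k_blocks * 32)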
@@ -557,14 +587,25 @@ def apply_fp8_linear(
         assert (
             weight_scale.numel() == weight.shape[1]
         ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
-        output = fp8_scaled_mm(
-            qinput,
-            weight,
-            x_scale,
-            weight_scale,
-            out_dtype=input.dtype,
-            bias=bias,
+
+        cutlass_compatible_b = (
+            weight.shape[0] % 16 == 0 and weight.shape[1] % 16 == 0
         )
+        if not cutlass_compatible_b:
+            # Massage the input to be 2D
+            qinput = qinput.view(-1, qinput.shape[-1])
+            output = triton_scaled_mm(
+                qinput, weight, x_scale, weight_scale, input.dtype, bias
+            )
+        else:
+            output = fp8_scaled_mm(
+                qinput,
+                weight,
+                x_scale,
+                weight_scale,
+                out_dtype=input.dtype,
+                bias=bias,
+            )
         return output.view(*output_shape)
 
     # torch.scaled_mm supports per tensor weights + activations only