sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -1
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +8 -7
- sglang/srt/disaggregation/decode.py +8 -4
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +68 -5
- sglang/srt/entrypoints/openai/protocol.py +2 -9
- sglang/srt/entrypoints/openai/serving_chat.py +60 -265
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +55 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +24 -27
- sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +11 -13
- sglang/srt/layers/dp_attention.py +118 -27
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +12 -18
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +10 -35
- sglang/srt/layers/quantization/awq.py +15 -16
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/gptq.py +12 -17
- sglang/srt/layers/quantization/marlin_utils.py +15 -5
- sglang/srt/layers/quantization/modelopt_quant.py +58 -41
- sglang/srt/layers/quantization/mxfp4.py +20 -3
- sglang/srt/layers/quantization/utils.py +52 -2
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +66 -116
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +24 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +43 -49
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +18 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +53 -44
- sglang/srt/mem_cache/allocator.py +39 -214
- sglang/srt/mem_cache/allocator_ascend.py +158 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -23
- sglang/srt/model_executor/forward_batch_info.py +33 -14
- sglang/srt/model_executor/model_runner.py +179 -81
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_nextn.py +2 -1
- sglang/srt/models/deepseek_v2.py +79 -38
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +11 -11
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +142 -20
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +10 -27
- sglang/srt/models/llama4.py +19 -6
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +20 -5
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_classification.py +78 -0
- sglang/srt/models/qwen3_moe.py +18 -5
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +6 -2
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/operations.py +17 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +142 -140
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +16 -12
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
@@ -29,29 +29,25 @@ from sglang.srt.layers.quantization.marlin_utils import (
     verify_marlin_supported,
     verify_marlin_supports_shape,
 )
-from sglang.srt.layers.quantization.scalar_type import scalar_types
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
-from sglang.srt.layers.quantization.utils import replace_parameter
+from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput
 
-try:
-    from vllm import _custom_ops as ops
-
-    warnings.warn(
-        f"Using kernels directly from vllm. This might lead to performance degradation or "
-        f"missing functionalities as certain kernels may not be optimized. "
-    )
-except ImportError:
-    ops = None
-
 from sglang.srt.utils import is_cuda, is_hip
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 if _is_cuda:
-    from sgl_kernel import
+    from sgl_kernel import (
+        awq_dequantize,
+        awq_marlin_moe_repack,
+        awq_marlin_repack,
+        fused_marlin_moe,
+    )
+
+
 elif _is_hip:
     from sglang.srt.layers.quantization.awq_triton import (
         awq_dequantize_triton as awq_dequantize,
@@ -64,6 +60,9 @@ else:
 logger = logging.getLogger(__name__)
 
 
+ScalarType, scalar_types = get_scalar_types()
+
+
 def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
     return any(module_name in prefix for module_name in modules_to_not_convert)
 
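Note: the pattern above recurs in awq.py, gptq.py, and marlin_utils.py, since scalar_type.py itself is removed in this release (see the file list). A rough sketch of what a `get_scalar_types()` helper of this kind typically looks like is shown below; this is a hypothetical illustration only, the actual implementation lives in sglang/srt/layers/quantization/utils.py and may differ.

```python
# Hypothetical sketch, not the real sglang implementation: resolve the
# ScalarType machinery lazily so importing modules no longer hard-depend
# on the removed scalar_type.py (or on vllm._custom_ops).
def get_scalar_types():
    try:
        # assumed location of the kernel-side scalar type definitions
        from sgl_kernel.scalar_type import ScalarType, scalar_types
        return ScalarType, scalar_types
    except ImportError:
        class _DummyScalarType:  # placeholder so annotations still resolve
            pass
        return _DummyScalarType, None
```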
@@ -516,7 +515,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
         layer.workspace = marlin_make_workspace(device)
 
         # Repack weights from AWQ format to marlin format.
-        marlin_qweight =
+        marlin_qweight = awq_marlin_repack(
             layer.qweight,
             size_k=layer.input_size_per_partition,
             size_n=layer.output_size_per_partition,
@@ -684,7 +683,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
             requires_grad=False,
         )
 
-        marlin_w13_qweight =
+        marlin_w13_qweight = awq_marlin_moe_repack(
             layer.w13_qweight,
             layer.w13_g_idx_sort_indices,
             size_k=layer.w13_qweight.shape[1],
@@ -693,7 +692,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
         )
         replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
 
-        marlin_w2_qweight =
+        marlin_w2_qweight = awq_marlin_moe_repack(
             layer.w2_qweight,
             layer.w2_g_idx_sort_indices,
             size_k=layer.w2_qweight.shape[1],
@@ -16,7 +16,6 @@ from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_qu
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
-    cpu_has_amx_support,
     per_tensor_dequantize,
     replace_parameter,
 )
@@ -1356,3 +1356,280 @@ def per_token_group_quant_fp8_hopper_moe_mn_major(
         expert_tokens_alignment,
     )
     return a_q, sfa
+
+
+@triton.jit
+def _per_group_transpose(
+    data_ptr: torch.Tensor,
+    trans_data_ptr: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    k: int,
+    M_ALIGNMENT: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+):
+    expert_id = tl.program_id(0)
+    m_id = tl.program_id(1)
+    k_id = tl.program_id(2)
+
+    curr_expert_offset = tl.load(expert_offsets + expert_id)
+    next_expert_offset = tl.load(expert_offsets + expert_id + 1)
+    num_tokens_of_expert = next_expert_offset - curr_expert_offset
+    tl.multiple_of(curr_expert_offset, M_ALIGNMENT)
+    tl.multiple_of(next_expert_offset, M_ALIGNMENT)
+
+    data_start_ptr = data_ptr + curr_expert_offset * k
+    trans_data_start_ptr = trans_data_ptr + curr_expert_offset * k
+
+    k_coord = k_id * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
+    k_mask = k_coord < k
+    for start_m in tl.range(0, num_tokens_of_expert, BLOCK_SIZE_M * tl.num_programs(1)):
+        m_coord = start_m + m_id * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        m_mask = m_coord < num_tokens_of_expert
+        off = m_coord[:, None] * k + k_coord[None, :]
+        trans_off = m_coord[:, None] + k_coord[None, :] * num_tokens_of_expert
+        mask = m_mask[:, None] & k_mask[None, :]
+
+        data = tl.load(data_start_ptr + off, mask=mask)
+        tl.store(trans_data_start_ptr + trans_off, data, mask=mask)
+
+
+def per_group_transpose(
+    a: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    M_ALIGNMENT: int = 1,
+) -> torch.Tensor:
+    assert a.dim() == 2
+    assert a.is_contiguous(), "`a` is not contiguous"
+
+    m, k = a.size()
+    trans_a = torch.empty_like(a)
+    num_experts = expert_offsets.size(0) - 1
+
+    grid = lambda META: (
+        num_experts,
+        triton.cdiv((m + num_experts - 1) // num_experts, META["BLOCK_SIZE_M"]),
+        triton.cdiv(k, META["BLOCK_SIZE_K"]),
+    )
+    _per_group_transpose[grid](
+        a, trans_a, expert_offsets, k, M_ALIGNMENT, BLOCK_SIZE_M=16, BLOCK_SIZE_K=8
+    )
+    return trans_a
+
+
+def is_weak_contiguous(x: torch.Tensor):
+    strides = x.stride()
+    sizes = x.shape
+    is_not_transpose = strides[0] == 1 and (strides[1] >= max(1, sizes[0]))
+    is_transpose = strides[1] == 1 and (strides[0] >= max(1, sizes[1]))
+    return is_transpose or is_not_transpose
+
+
+@triton.jit
+def scaled_mm_kernel(
+    a_ptr,
+    b_ptr,
+    scale_a_ptr,
+    scale_b_ptr,
+    c_ptr,
+    bias_ptr,
+    M,
+    N,
+    K,
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    ACCUMULATOR_DTYPE: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    BLOCK_SIZE_SCALE_A: tl.constexpr,
+    BLOCK_SIZE_SCALE_B: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+
+    pid_m = pid // num_pid_n
+    pid_n = pid % num_pid_n
+
+    accumulator_dtype = ACCUMULATOR_DTYPE
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype)
+
+    # NOTE: Some tensor inputs are so large, they will cause int32 overflow
+    # so it is necessary to use tl.int64 for all the offsets, else SEGV will
+    # eventually occur.
+
+    # Offsets and masks.
+    offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+    masks_am = offsets_am < M
+
+    offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
+    masks_bn = offsets_bn < N
+
+    offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
+    offsets_a = stride_am * offsets_am[:, None] + stride_ak * offsets_k[None, :]
+    offsets_b = stride_bk * offsets_k[:, None] + stride_bn * offsets_bn[None, :]
+
+    # NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so need to create
+    # appropriate offsets and masks for each case. Same goes for
+    # BLOCK_SIZE_SCALE_B.
+    offsets_scale_am = (
+        tl.arange(0, BLOCK_SIZE_SCALE_A)
+        + (BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M
+    )
+    masks_scale_am = offsets_scale_am < M
+
+    offsets_scale_bn = (
+        tl.arange(0, BLOCK_SIZE_SCALE_B)
+        + (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N
+    )
+    masks_scale_bn = offsets_scale_bn < N
+
+    a_ptrs = a_ptr + offsets_a
+    b_ptrs = b_ptr + offsets_b
+
+    scale_a_ptrs = scale_a_ptr + offsets_scale_am
+    scale_b_ptrs = scale_b_ptr + offsets_scale_bn
+
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        masks_k = offsets_k < K
+        masks_a = masks_am[:, None] & masks_k[None, :]
+        a = tl.load(a_ptrs, mask=masks_a)
+
+        masks_b = masks_k[:, None] & masks_bn[None, :]
+        b = tl.load(b_ptrs, mask=masks_b)
+
+        # Accumulate results.
+        accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)
+
+        offsets_k += BLOCK_SIZE_K
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    # Apply scale at end.
+    masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None]
+    scale_a = tl.load(scale_a_ptrs[:, None], masks_scale_a)
+    # Need to broadcast to the appropriate size, if scale_a is already
+    # (BLOCK_SIZE_M, 1) then it will broadcast to its own shape. Same goes
+    # for scale_b below.
+    scale_a = scale_a.broadcast_to((BLOCK_SIZE_M, 1))
+    accumulator = scale_a * accumulator.to(tl.float32)
+
+    masks_scale_b = masks_scale_bn[:, None] & (tl.arange(0, 1) < 1)[None, :]
+    scale_b = tl.load(scale_b_ptrs[:, None], masks_scale_b)
+    scale_b = scale_b.broadcast_to((BLOCK_SIZE_N, 1))
+    accumulator = scale_b.T * accumulator.to(tl.float32)
+
+    # Convert to output format.
+    c = accumulator.to(c_ptr.type.element_ty)
+
+    # Add bias, it's already in output format, so add it after conversion.
+    if bias_ptr:
+        offsets_bias = offsets_bn
+        bias_ptrs = bias_ptr + offsets_bias
+        bias_mask = offsets_bias < N
+        bias = tl.load(bias_ptrs, bias_mask)
+        c += bias
+
+    # Save output
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
+    offs_cm = offs_cm.to(tl.int64)
+    offs_cn = offs_cn.to(tl.int64)
+    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+
+    tl.store(c_ptrs, c, mask=c_mask)
+
+
+# input - [M, K]
+# weight - [K, N]
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py
+def triton_scaled_mm(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    scale_a: torch.Tensor,
+    scale_b: torch.Tensor,
+    out_dtype: type[torch.dtype],
+    bias: Optional[torch.Tensor] = None,
+    block_size_m: int = 32,
+    block_size_n: int = 32,
+    block_size_k: int = 32,
+    use_heuristic=True,
+) -> torch.Tensor:
+    M, K = input.shape
+    N = weight.shape[1]
+
+    assert N > 0 and K > 0 and M > 0
+    assert weight.shape[0] == K
+    assert input.dtype == weight.dtype
+
+    scale_a = scale_a.reshape(-1, 1) if scale_a.dim() <= 1 else scale_a
+    scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b
+
+    assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
+    assert scale_a.shape[1] == 1 and (scale_a.shape[0] == 1 or scale_a.shape[0] == M)
+    assert scale_b.shape[1] == 1 and (scale_b.shape[0] == 1 or scale_b.shape[0] == N)
+    assert out_dtype.is_floating_point
+    assert bias is None or bias.is_floating_point()
+    assert is_weak_contiguous(input)
+    assert is_weak_contiguous(weight)
+
+    grid = lambda META: (
+        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+    )
+
+    result = torch.empty((M, N), dtype=out_dtype, device=input.device)
+
+    has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1
+
+    if use_heuristic:
+        is_small_N = N < 8192
+        next_power_of_2_M = max(32, triton.next_power_of_2(M))
+        if next_power_of_2_M <= 32:
+            tile_shape = (64, 64, 256) if is_small_N else (64, 128, 256)
+        elif next_power_of_2_M <= 64:
+            tile_shape = (64, 64, 256)
+        elif next_power_of_2_M <= 128:
+            tile_shape = (64, 128, 128)
+        else:
+            tile_shape = (128, 128, 128)
+
+        block_size_m, block_size_n, block_size_k = tile_shape
+
+    block_size_sa = 1 if has_scalar(scale_a) else block_size_m
+    block_size_sb = 1 if has_scalar(scale_b) else block_size_n
+
+    accumulator_dtype = tl.float32 if input.is_floating_point() else tl.int32
+
+    # A = input, B = weight, C = result
+    # A = M x K, B = K x N, C = M x N
+    scaled_mm_kernel[grid](
+        input,
+        weight,
+        scale_a,
+        scale_b,
+        result,
+        bias,
+        M,
+        N,
+        K,
+        input.stride(0),
+        input.stride(1),
+        weight.stride(0),
+        weight.stride(1),
+        result.stride(0),
+        result.stride(1),
+        accumulator_dtype,
+        BLOCK_SIZE_M=block_size_m,
+        BLOCK_SIZE_N=block_size_n,
+        BLOCK_SIZE_K=block_size_k,
+        BLOCK_SIZE_SCALE_A=block_size_sa,
+        BLOCK_SIZE_SCALE_B=block_size_sb,
+    )
+
+    return result.to(out_dtype)
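The hunk above adds a pure-Triton scaled matmul, `triton_scaled_mm` (adapted from vLLM), which `apply_fp8_linear` later uses as a fallback when the weight shape is not 16-aligned and the cutlass `fp8_scaled_mm` path cannot be used (see the `apply_fp8_linear` hunk further down). A minimal usage sketch follows; it assumes a CUDA device and uses int8 operands with per-token/per-channel scales, with shapes chosen purely for illustration.

```python
import torch
from sglang.srt.layers.quantization.fp8_kernel import triton_scaled_mm

M, K, N = 48, 256, 500  # N deliberately not a multiple of 16
x_q = torch.randint(-128, 127, (M, K), dtype=torch.int8, device="cuda")
w_q = torch.randint(-128, 127, (K, N), dtype=torch.int8, device="cuda")
x_scale = torch.rand(M, 1, dtype=torch.float32, device="cuda")  # per-token scales
w_scale = torch.rand(N, 1, dtype=torch.float32, device="cuda")  # per-channel scales

# out = (x_scale * x_q) @ (w_q * w_scale), accumulated in int32, cast to bf16
out = triton_scaled_mm(x_q, w_q, x_scale, w_scale, torch.bfloat16)
assert out.shape == (M, N) and out.dtype == torch.bfloat16
```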
@@ -22,6 +22,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     scaled_fp8_quant,
     sglang_per_token_quant_fp8,
     static_quant_fp8,
+    triton_scaled_mm,
     w8a8_block_fp8_matmul_deepgemm,
     w8a8_block_fp8_matmul_triton,
 )
@@ -161,16 +162,16 @@ def flashinfer_gemm_w8a8_block_fp8_linear(
     output_shape = [*input.shape[:-1], weight.shape[0]]
 
     q_input, x_scale = sglang_per_token_group_quant_fp8(
-        input_2d, block_size[1], column_major_scales=
+        input_2d, block_size[1], column_major_scales=True
     )
-
+    # TRTLLM requires column-major scaling factors
     output = gemm_fp8_nt_groupwise(
         q_input,
         weight,
         x_scale,
         weight_scale,
-        scale_major_mode="K",
         out_dtype=input_2d.dtype,
+        backend="trtllm",
     )
 
     if bias is not None:
@@ -586,14 +587,25 @@ def apply_fp8_linear(
         assert (
             weight_scale.numel() == weight.shape[1]
         ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
-
-
-            weight
-            x_scale,
-            weight_scale,
-            out_dtype=input.dtype,
-            bias=bias,
+
+        cutlass_compatible_b = (
+            weight.shape[0] % 16 == 0 and weight.shape[1] % 16 == 0
         )
+        if not cutlass_compatible_b:
+            # Massage the input to be 2D
+            qinput = qinput.view(-1, qinput.shape[-1])
+            output = triton_scaled_mm(
+                qinput, weight, x_scale, weight_scale, input.dtype, bias
+            )
+        else:
+            output = fp8_scaled_mm(
+                qinput,
+                weight,
+                x_scale,
+                weight_scale,
+                out_dtype=input.dtype,
+                bias=bias,
+            )
         return output.view(*output_shape)
 
     # torch.scaled_mm supports per tensor weights + activations only
@@ -36,9 +36,9 @@ from sglang.srt.layers.quantization.marlin_utils import (
     marlin_zero_points,
     verify_marlin_supported,
 )
-from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
 from sglang.srt.layers.quantization.utils import (
     get_linear_quant_method,
+    get_scalar_types,
     replace_parameter,
     unpack_cols,
 )
@@ -46,20 +46,16 @@ from sglang.srt.layers.quantization.utils import (
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput
 
-try:
-    from vllm import _custom_ops as ops
-except ImportError:
-    ops = None
-
 from sglang.srt.utils import is_cuda
 
 _is_cuda = is_cuda()
 
 if _is_cuda:
-    from sgl_kernel import fused_marlin_moe
+    from sgl_kernel import fused_marlin_moe, gptq_gemm, gptq_marlin_repack, gptq_shuffle
 
 
 logger = logging.getLogger(__name__)
+ScalarType, scalar_types = get_scalar_types()
 
 
 def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool:
@@ -85,9 +81,7 @@ def gptq_marlin_moe_repack(
         dtype=b_q_weight.dtype,
     )
     for e in range(num_experts):
-        output[e] =
-            b_q_weight[e], perm[e], size_k, size_n, num_bits
-        )
+        output[e] = gptq_marlin_repack(b_q_weight[e], perm[e], size_k, size_n, num_bits)
     return output
 
 
@@ -204,11 +198,12 @@ class GPTQConfig(QuantizationConfig):
         from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
-        if isinstance(layer,
-            return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
-        elif isinstance(layer, FusedMoE):
+        if isinstance(layer, FusedMoE):
             raise TypeError("GPTQ Method does not support MoE, please use gptq_marlin")
-
+        else:
+            return get_linear_quant_method(
+                self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
+            )
 
 
 class GPTQMarlinConfig(QuantizationConfig):
@@ -530,7 +525,7 @@ class GPTQLinearMethod(LinearMethodBase):
             layer.g_idx.data = torch.empty(
                 (0,), dtype=torch.int, device=layer.g_idx.device
             )
-
+        gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits)
 
     def apply(
         self,
@@ -541,7 +536,7 @@ class GPTQLinearMethod(LinearMethodBase):
         out_shape = x.shape[:-1] + (layer.qweight.shape[-1],)
         reshaped_x = x.reshape(-1, x.shape[-1])
 
-        output =
+        output = gptq_gemm(
             reshaped_x,
             layer.qweight,
             layer.qzeros,
@@ -726,7 +721,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
         def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
-            x.data =
+            x.data = gptq_marlin_repack(
                x.data.contiguous(),
                perm=layer.g_idx_sort_indices,
                size_k=c.partition_weight_shape[0],
@@ -19,9 +19,12 @@ from sglang.srt.layers.quantization.base_config import (
     LinearMethodBase,
     QuantizationConfig,
 )
-from sglang.srt.layers.quantization.
-
-
+from sglang.srt.layers.quantization.utils import (
+    get_scalar_types,
+    pack_cols,
+    unpack_cols,
+)
+from sglang.srt.utils import get_device_capability, is_cuda
 
 if TYPE_CHECKING:
     from sglang.srt.layers.linear import LinearBase
@@ -31,8 +34,15 @@ try:
 except ImportError:
     ops = None
 
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sgl_kernel import gptq_marlin_gemm
+
 logger = logging.getLogger(__name__)
 
+ScalarType, scalar_types = get_scalar_types()
+
 GPTQ_MARLIN_TILE = 16
 GPTQ_MARLIN_MIN_THREAD_N = 64
 GPTQ_MARLIN_MIN_THREAD_K = 128
@@ -453,7 +463,7 @@ def apply_gptq_marlin_linear(
         dtype=input.dtype,
     )
 
-    output =
+    output = gptq_marlin_gemm(
         reshaped_x,
         None,
         weight,
@@ -504,7 +514,7 @@ def apply_awq_marlin_linear(
        dtype=input.dtype,
    )
 
-    output =
+    output = gptq_marlin_gemm(
        reshaped_x,
        None,
        weight,