sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/test/test_marlin_moe.py
ADDED
@@ -0,0 +1,286 @@
+import types
+from typing import Optional
+
+import pytest
+import torch
+from sgl_kernel import fused_marlin_moe
+
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
+from sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize
+
+
+def stack_and_dev(tensors: list[torch.Tensor]):
+    dev = tensors[0].device
+    return torch.stack(tensors, dim=0).to(dev)
+
+
+def torch_experts(
+    a: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weight: torch.Tensor,
+    topk_ids: torch.Tensor,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    quant_dtype: Optional[torch.dtype] = None,
+    apply_router_weights_on_input: bool = False,
+) -> torch.Tensor:
+    assert (
+        global_num_experts == -1
+        or (global_num_experts == w1.shape[0] and expert_map is None)
+        or (expert_map is not None and global_num_experts == expert_map.shape[0])
+    )
+
+    M, K = a.shape
+    topk = topk_ids.shape[1]
+    print("quant_dtype", quant_dtype)
+    # exit(0)
+    if apply_router_weights_on_input:
+        assert topk == 1
+        a = a * topk_weight.to(a.dtype)
+
+    a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
+
+    out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
+
+    num_experts = w1.shape[0]
+
+    topk_ids = topk_ids.view(-1)
+    if expert_map is not None:
+        topk_ids = expert_map[topk_ids]
+
+    f32 = torch.float32
+
+    for i in range(num_experts):
+        mask = topk_ids == i
+        if mask.sum():
+            if quant_dtype is None:
+                tmp1 = a[mask] @ w1[i].transpose(0, 1)
+                tmp2 = SiluAndMul()(tmp1)
+                out[mask] = tmp2 @ w2[i].transpose(0, 1)
+
+    if apply_router_weights_on_input:
+        return out
+    else:
+        return (
+            (out.view(M, -1, w2.shape[1]).to(f32) * topk_weight.view(M, -1, 1))
+            .sum(dim=1)
+            .to(out.dtype)
+        )
+
+
+def torch_moe(
+    a: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    score: torch.Tensor,
+    topk: int,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    score = torch.softmax(score, dim=-1, dtype=torch.float32)
+    topk_weight, topk_ids = torch.topk(score, topk)
+    return torch_experts(
+        a, w1, w2, topk_weight, topk_ids, global_num_experts, expert_map
+    )
+
+
+def marlin_moe_generate_valid_test_cases():
+    import itertools
+
+    m_list = [1, 123, 666]
+    n_list = [128, 1024]
+    k_list = [256, 2048]
+    e_list = [4, 12]
+    topk_list = [2, 3]
+    dtype_list = [torch.half, torch.bfloat16]
+    group_size_list = [128]
+    act_order_list = [True, False]
+    quant_type_list = [
+        scalar_types.uint4,
+        scalar_types.uint4b8,
+    ]
+    is_k_full_list = [True, False]
+
+    all_combinations = itertools.product(
+        m_list,
+        n_list,
+        k_list,
+        e_list,
+        topk_list,
+        dtype_list,
+        group_size_list,
+        act_order_list,
+        quant_type_list,
+        is_k_full_list,
+    )
+
+    def is_invalid(
+        m, n, k, e, topk, dtype, group_size, act_order, quant_type, is_k_full
+    ):
+
+        # Filter act_order
+        if act_order:
+            if group_size in (-1, k, n):
+                return False
+            if quant_type not in [scalar_types.uint4b8]:
+                return False
+        elif not is_k_full:
+            return False
+
+        return True
+
+    cases = []
+    for case in all_combinations:
+        if is_invalid(*case):
+            cases.append(case)
+    return cases
+
+
+@pytest.mark.flaky(reruns=2)
+@pytest.mark.parametrize(
+    ("m, n, k, e, topk, dtype, group_size," "act_order, quant_type, is_k_full"),
+    marlin_moe_generate_valid_test_cases(),
+)
+def test_fused_marlin_moe(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    dtype: torch.dtype,
+    group_size: int,
+    act_order: bool,
+    quant_type: ScalarType,
+    is_k_full: bool,
+):
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA device not available")
+
+    torch.manual_seed(0)
+
+    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
+
+    # Filter act_order
+    if act_order:
+        if group_size == -1:
+            return
+        if group_size in (k, n):
+            return
+        if has_zp:
+            return
+    else:
+        if not is_k_full:
+            return
+
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
+    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
+
+    e_map = None
+
+    w_ref1_l = []
+    qweight1_l = []
+    scales1_l = []
+    zeros1_l = []
+    g_idx1_l = []
+    sort_indices1_l = []
+
+    for i in range(w1.shape[0]):
+        if has_zp:
+            w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize(
+                w1[i].transpose(1, 0), quant_type, group_size
+            )
+
+            w_ref1_l.append(w_ref1.T)
+            qweight1_l.append(qweight1)
+            scales1_l.append(scales1)
+            zeros1_l.append(zeros1)
+        else:
+            test_perm = torch.randperm(k)
+            w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
+                w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
+            )
+
+            w_ref1_l.append(w_ref1.T)
+            qweight1_l.append(qweight1)
+            scales1_l.append(scales1)
+            g_idx1_l.append(g_idx1)
+            sort_indices1_l.append(sort_indices1)
+
+    w_ref1 = stack_and_dev(w_ref1_l)
+    qweight1 = stack_and_dev(qweight1_l).contiguous()
+    scales1 = stack_and_dev(scales1_l)
+    g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
+    zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
+    sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
+
+    w_ref2_l = []
+    qweight2_l = []
+    scales2_l = []
+    zeros2_l = []
+    g_idx2_l = []
+    sort_indices2_l = []
+
+    for i in range(w2.shape[0]):
+        if has_zp:
+            w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize(
+                w2[i].transpose(1, 0), quant_type, group_size
+            )
+
+            w_ref2_l.append(w_ref2.T)
+            qweight2_l.append(qweight2)
+            scales2_l.append(scales2)
+            zeros2_l.append(zeros2)
+        else:
+            test_perm = torch.randperm(n)
+            w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
+                w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
+            )
+
+            w_ref2_l.append(w_ref2.T)
+            qweight2_l.append(qweight2)
+            scales2_l.append(scales2)
+            g_idx2_l.append(g_idx2)
+            sort_indices2_l.append(sort_indices2)
+
+    w_ref2 = stack_and_dev(w_ref2_l)
+    qweight2 = stack_and_dev(qweight2_l).contiguous()
+    scales2 = stack_and_dev(scales2_l)
+    g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
+    zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
+    sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
+
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
+    from sglang.srt.layers.moe.topk import fused_topk_torch_native
+
+    topk_weights, topk_ids = fused_topk_torch_native(a, score, topk, False)
+
+    torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map)
+
+    marlin_output = fused_marlin_moe(
+        a,
+        qweight1,
+        qweight2,
+        scales1,
+        scales2,
+        score,
+        topk_weights,
+        topk_ids,
+        g_idx1=g_idx1,
+        g_idx2=g_idx2,
+        sort_indices1=sort_indices1,
+        sort_indices2=sort_indices2,
+        w1_zeros=zeros1,
+        w2_zeros=zeros2,
+        num_bits=4,
+        is_k_full=is_k_full,
+    )
+
+    torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
+
+
+if __name__ == "__main__":
+    # Run the specific test function directly
+    pytest.main([__file__])
sglang/test/test_marlin_utils.py
ADDED
@@ -0,0 +1,171 @@
+"""
+Adapted from
+https://github.com/vllm-project/vllm/blob/020f58abcdea65302225663130d08fd8f4dd755a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py
+"""
+
+# SPDX-License-Identifier: Apache-2.0
+"""Utility functions used for tests and benchmarks"""
+
+from typing import Optional
+
+import numpy as np
+import torch
+
+from sglang.srt.layers.quantization.marlin_utils import (
+    GPTQ_MARLIN_TILE,
+    marlin_permute_scales,
+    marlin_zero_points,
+)
+from sglang.srt.layers.quantization.scalar_type import ScalarType
+from sglang.srt.layers.quantization.utils import (
+    get_pack_factor,
+    gptq_quantize_weights,
+    quantize_weights,
+    sort_weights,
+)
+
+
+class MarlinWorkspace:
+
+    def __init__(self, out_features, min_thread_n, max_parallel):
+        assert (
+            out_features % min_thread_n == 0
+        ), "out_features = {} is undivisible by min_thread_n = {}".format(
+            out_features, min_thread_n
+        )
+
+        max_workspace_size = (out_features // min_thread_n) * max_parallel
+
+        self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda")
+
+
+def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
+    assert q_w.shape == (size_k, size_n)
+    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
+    assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"
+
+    # Permute weights to 16x64 marlin tiles
+    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
+    q_w = q_w.permute((0, 2, 1, 3))
+    q_w = q_w.reshape((size_k // tile, size_n * tile))
+
+    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)
+
+    return q_w
+
+
+def marlin_weights(q_w, size_k, size_n, num_bits, perm):
+    # Permute
+    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
+
+    # Pack
+    pack_factor = get_pack_factor(num_bits)
+    orig_device = q_w.device
+
+    q_w = q_w.cpu().numpy().astype(np.uint32)
+
+    q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
+    for i in range(pack_factor):
+        q_packed |= q_w[:, i::pack_factor] << num_bits * i
+
+    q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)
+
+    return q_packed
+
+
+def get_weight_perm(num_bits: int):
+    perm_list: list[int] = []
+    for i in range(32):
+        perm1: list[int] = []
+        col = i // 4
+        for block in [0, 1]:
+            for row in [
+                2 * (i % 4),
+                2 * (i % 4) + 1,
+                2 * (i % 4 + 4),
+                2 * (i % 4 + 4) + 1,
+            ]:
+                perm1.append(16 * row + col + 8 * block)
+        for j in range(4):
+            perm_list.extend([p + 256 * j for p in perm1])
+
+    perm = np.array(perm_list)
+
+    if num_bits == 4:
+        interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
+    elif num_bits == 8:
+        interleave = np.array([0, 2, 1, 3])
+    else:
+        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
+
+    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
+    perm = torch.from_numpy(perm)
+    return perm
+
+
+def marlin_quantize(
+    w: torch.Tensor,
+    quant_type: ScalarType,
+    group_size: int,
+    act_order: bool,
+    test_perm: Optional[torch.Tensor] = None,
+):
+    size_k, size_n = w.shape
+    num_bits = quant_type.size_bits
+
+    # Normalize group_size
+    if group_size == -1:
+        group_size = size_k
+    assert group_size <= size_k
+
+    # Quantize (and apply act_order if provided)
+    w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
+        w, quant_type, group_size, act_order, test_perm
+    )
+
+    # For act_order, sort the "weights" and "g_idx" so that group ids are
+    # increasing
+    sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
+    if act_order:
+        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
+
+    # Reformat to marlin
+    weight_perm = get_weight_perm(num_bits)
+    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
+    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
+
+    # Create result
+    res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
+    for i in range(len(res_list)):
+        res_list[i] = res_list[i].to(w.device)
+
+    return res_list
+
+
+def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int):
+    size_k, size_n = w.shape
+
+    # Normalize group_size
+    if group_size == -1:
+        group_size = size_k
+    assert group_size <= size_k
+
+    # Detect num groups
+    assert size_k % group_size == 0
+    num_groups = size_k // group_size
+
+    # Quantize with zp
+    w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True)
+
+    # Reformat to marlin
+    weight_perm = get_weight_perm(quant_type.size_bits)
+    marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm)
+    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
+    marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits)
+
+    # Create result
+    res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]
+    for i in range(len(res_list)):
+        res_list[i] = res_list[i].to(w.device)
+
+    return res_list
sglang/test/test_utils.py
CHANGED
@@ -2,6 +2,7 @@
 
 import argparse
 import copy
+import json
 import logging
 import os
 import random
@@ -102,6 +103,15 @@ def is_in_amd_ci():
     return get_bool_env_var("SGLANG_AMD_CI")
 
 
+def _use_cached_default_models(model_repo: str):
+    cache_dir = os.getenv("DEFAULT_MODEL_CACHE_DIR")
+    if cache_dir and model_repo:
+        model_path = os.path.join(cache_dir, model_repo)
+        if os.path.isdir(model_path):
+            return os.path.abspath(model_path)
+    return ""
+
+
 if is_in_ci():
     DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
         5000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) * 100
@@ -419,6 +429,31 @@ def get_call_select(args: argparse.Namespace):
     return func
 
 
+def _get_default_models():
+    import inspect
+
+    current_module = inspect.getmodule(_get_default_models)
+    default_models = set()
+    for name, value in current_module.__dict__.items():
+        if (
+            isinstance(name, str)
+            and "DEFAULT_" in name
+            and "MODEL_" in name
+            and isinstance(value, str)
+        ):
+            if "," in value:
+                parts = [part.strip() for part in value.split(",")]
+                default_models.update(parts)
+            else:
+                default_models.add(value.strip())
+    return json.dumps(list(default_models))
+
+
+def try_cached_model(model_repo: str):
+    model_dir = _use_cached_default_models(model_repo)
+    return model_dir if model_dir else model_repo
+
+
 def popen_launch_server(
     model: str,
     base_url: str,
sglang/version.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.4.9.post2"
+__version__ = "0.4.9.post4"
{sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sglang
-Version: 0.4.9.post2
+Version: 0.4.9.post4
 Summary: SGLang is yet another fast serving framework for large language models and vision language models.
 License: Apache License
 Version 2.0, January 2004
@@ -246,20 +246,20 @@ Requires-Dist: sentencepiece; extra == "runtime-common"
 Requires-Dist: soundfile==0.13.1; extra == "runtime-common"
 Requires-Dist: scipy; extra == "runtime-common"
 Requires-Dist: torchao==0.9.0; extra == "runtime-common"
-Requires-Dist: transformers==4.53.
+Requires-Dist: transformers==4.53.2; extra == "runtime-common"
 Requires-Dist: timm==1.0.16; extra == "runtime-common"
 Requires-Dist: uvicorn; extra == "runtime-common"
 Requires-Dist: uvloop; extra == "runtime-common"
 Requires-Dist: xgrammar==0.1.21; extra == "runtime-common"
 Provides-Extra: srt
 Requires-Dist: sglang[runtime_common]; extra == "srt"
-Requires-Dist: sgl-kernel==0.2.
+Requires-Dist: sgl-kernel==0.2.7; extra == "srt"
 Requires-Dist: torch==2.7.1; extra == "srt"
 Requires-Dist: torchaudio==2.7.1; extra == "srt"
 Requires-Dist: torchvision==0.22.1; extra == "srt"
 Requires-Dist: cuda-python; extra == "srt"
 Requires-Dist: einops; extra == "srt"
-Requires-Dist: flashinfer_python==0.2.
+Requires-Dist: flashinfer_python==0.2.9rc1; extra == "srt"
 Provides-Extra: blackwell
 Requires-Dist: sglang[runtime_common]; extra == "blackwell"
 Requires-Dist: sgl-kernel; extra == "blackwell"
@@ -268,11 +268,11 @@ Requires-Dist: torchaudio==2.7.1; extra == "blackwell"
 Requires-Dist: torchvision==0.22.1; extra == "blackwell"
 Requires-Dist: cuda-python; extra == "blackwell"
 Requires-Dist: einops; extra == "blackwell"
-Requires-Dist: flashinfer_python==0.2.
+Requires-Dist: flashinfer_python==0.2.9rc1; extra == "blackwell"
 Provides-Extra: srt-hip
 Requires-Dist: sglang[runtime_common]; extra == "srt-hip"
 Requires-Dist: torch; extra == "srt-hip"
-Requires-Dist:
+Requires-Dist: petit_kernel==0.0.2; extra == "srt-hip"
 Provides-Extra: srt-xpu
 Requires-Dist: sglang[runtime_common]; extra == "srt-xpu"
 Provides-Extra: srt-hpu
@@ -381,14 +381,14 @@ Dynamic: license-file
 - [2025/05] 🔥 Deploying DeepSeek with PD Disaggregation and Large-scale Expert Parallelism on 96 H100 GPUs ([blog](https://lmsys.org/blog/2025-05-05-large-scale-ep/)).
 - [2025/03] Supercharge DeepSeek-R1 Inference on AMD Instinct MI300X ([AMD blog](https://rocm.blogs.amd.com/artificial-intelligence/DeepSeekR1-Part2/README.html))
 - [2025/03] SGLang Joins PyTorch Ecosystem: Efficient LLM Serving Engine ([PyTorch blog](https://pytorch.org/blog/sglang-joins-pytorch/))
-- [
-- [2024/12] 🔥 v0.4 Release: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)).
+- [2024/12] v0.4 Release: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)).
 - [2024/07] v0.2 Release: Faster Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)).
 
 <details>
 <summary>More</summary>
 
 - [2025/02] Unlock DeepSeek-R1 Inference Performance on AMD Instinct™ MI300X GPU ([AMD blog](https://rocm.blogs.amd.com/artificial-intelligence/DeepSeekR1_Perf/README.html))
+- [2025/01] SGLang provides day one support for DeepSeek V3/R1 models on NVIDIA and AMD GPUs with DeepSeek-specific optimizations. ([instructions](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3), [AMD blog](https://www.amd.com/en/developer/resources/technical-articles/amd-instinct-gpus-power-deepseek-v3-revolutionizing-ai-development-with-sglang.html), [10+ other companies](https://x.com/lmsysorg/status/1887262321636221412))
 - [2024/10] The First SGLang Online Meetup ([slides](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#the-first-sglang-online-meetup)).
 - [2024/09] v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)).
 - [2024/02] SGLang enables **3x faster JSON decoding** with compressed finite state machine ([blog](https://lmsys.org/blog/2024-02-05-compressed-fsm/)).
@@ -415,10 +415,10 @@ The core features include:
 - [Contribution Guide](https://docs.sglang.ai/references/contribution_guide.html)
 
 ## Benchmark and Performance
-Learn more in the release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/), [v0.3 blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/), [v0.4 blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/).
+Learn more in the release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/), [v0.3 blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/), [v0.4 blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/), [Large-scale expert parallelism](https://lmsys.org/blog/2025-05-05-large-scale-ep/).
 
 ## Roadmap
-[Development Roadmap (2025
+[Development Roadmap (2025 H2)](https://github.com/sgl-project/sglang/issues/7736)
 
 ## Adoption and Sponsorship
 SGLang has been deployed at large scale, generating trillions of tokens in production each day. It is trusted and adopted by a wide range of leading enterprises and institutions, including xAI, AMD, NVIDIA, Intel, LinkedIn, Cursor, Oracle Cloud, Google Cloud, Microsoft Azure, AWS, Atlas Cloud, Voltage Park, Nebius, DataCrunch, Novita, InnoMatrix, MIT, UCLA, the University of Washington, Stanford, UC Berkeley, Tsinghua University, Jam & Tea Studios, Baseten, and other major technology organizations across North America and Asia. As an open-source LLM inference engine, SGLang has become the de facto industry standard, with deployments running on over 1,000,000 GPUs worldwide.