sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/test/test_marlin_moe.py ADDED
@@ -0,0 +1,286 @@
+ import types
+ from typing import Optional
+
+ import pytest
+ import torch
+ from sgl_kernel import fused_marlin_moe
+
+ from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
+ from sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize
+
+
+ def stack_and_dev(tensors: list[torch.Tensor]):
+     dev = tensors[0].device
+     return torch.stack(tensors, dim=0).to(dev)
+
+
+ def torch_experts(
+     a: torch.Tensor,
+     w1: torch.Tensor,
+     w2: torch.Tensor,
+     topk_weight: torch.Tensor,
+     topk_ids: torch.Tensor,
+     global_num_experts: int = -1,
+     expert_map: Optional[torch.Tensor] = None,
+     quant_dtype: Optional[torch.dtype] = None,
+     apply_router_weights_on_input: bool = False,
+ ) -> torch.Tensor:
+     assert (
+         global_num_experts == -1
+         or (global_num_experts == w1.shape[0] and expert_map is None)
+         or (expert_map is not None and global_num_experts == expert_map.shape[0])
+     )
+
+     M, K = a.shape
+     topk = topk_ids.shape[1]
+     print("quant_dtype", quant_dtype)
+     # exit(0)
+     if apply_router_weights_on_input:
+         assert topk == 1
+         a = a * topk_weight.to(a.dtype)
+
+     a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
+
+     out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
+
+     num_experts = w1.shape[0]
+
+     topk_ids = topk_ids.view(-1)
+     if expert_map is not None:
+         topk_ids = expert_map[topk_ids]
+
+     f32 = torch.float32
+
+     for i in range(num_experts):
+         mask = topk_ids == i
+         if mask.sum():
+             if quant_dtype is None:
+                 tmp1 = a[mask] @ w1[i].transpose(0, 1)
+                 tmp2 = SiluAndMul()(tmp1)
+                 out[mask] = tmp2 @ w2[i].transpose(0, 1)
+
+     if apply_router_weights_on_input:
+         return out
+     else:
+         return (
+             (out.view(M, -1, w2.shape[1]).to(f32) * topk_weight.view(M, -1, 1))
+             .sum(dim=1)
+             .to(out.dtype)
+         )
+
+
+ def torch_moe(
+     a: torch.Tensor,
+     w1: torch.Tensor,
+     w2: torch.Tensor,
+     score: torch.Tensor,
+     topk: int,
+     global_num_experts: int = -1,
+     expert_map: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+     score = torch.softmax(score, dim=-1, dtype=torch.float32)
+     topk_weight, topk_ids = torch.topk(score, topk)
+     return torch_experts(
+         a, w1, w2, topk_weight, topk_ids, global_num_experts, expert_map
+     )
+
+
+ def marlin_moe_generate_valid_test_cases():
+     import itertools
+
+     m_list = [1, 123, 666]
+     n_list = [128, 1024]
+     k_list = [256, 2048]
+     e_list = [4, 12]
+     topk_list = [2, 3]
+     dtype_list = [torch.half, torch.bfloat16]
+     group_size_list = [128]
+     act_order_list = [True, False]
+     quant_type_list = [
+         scalar_types.uint4,
+         scalar_types.uint4b8,
+     ]
+     is_k_full_list = [True, False]
+
+     all_combinations = itertools.product(
+         m_list,
+         n_list,
+         k_list,
+         e_list,
+         topk_list,
+         dtype_list,
+         group_size_list,
+         act_order_list,
+         quant_type_list,
+         is_k_full_list,
+     )
+
+     def is_invalid(
+         m, n, k, e, topk, dtype, group_size, act_order, quant_type, is_k_full
+     ):
+
+         # Filter act_order
+         if act_order:
+             if group_size in (-1, k, n):
+                 return False
+             if quant_type not in [scalar_types.uint4b8]:
+                 return False
+         elif not is_k_full:
+             return False
+
+         return True
+
+     cases = []
+     for case in all_combinations:
+         if is_invalid(*case):
+             cases.append(case)
+     return cases
+
+
+ @pytest.mark.flaky(reruns=2)
+ @pytest.mark.parametrize(
+     ("m, n, k, e, topk, dtype, group_size," "act_order, quant_type, is_k_full"),
+     marlin_moe_generate_valid_test_cases(),
+ )
+ def test_fused_marlin_moe(
+     m: int,
+     n: int,
+     k: int,
+     e: int,
+     topk: int,
+     dtype: torch.dtype,
+     group_size: int,
+     act_order: bool,
+     quant_type: ScalarType,
+     is_k_full: bool,
+ ):
+     if not torch.cuda.is_available():
+         pytest.skip("CUDA device not available")
+
+     torch.manual_seed(0)
+
+     has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
+
+     # Filter act_order
+     if act_order:
+         if group_size == -1:
+             return
+         if group_size in (k, n):
+             return
+         if has_zp:
+             return
+     else:
+         if not is_k_full:
+             return
+
+     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
+     w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
+
+     e_map = None
+
+     w_ref1_l = []
+     qweight1_l = []
+     scales1_l = []
+     zeros1_l = []
+     g_idx1_l = []
+     sort_indices1_l = []
+
+     for i in range(w1.shape[0]):
+         if has_zp:
+             w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize(
+                 w1[i].transpose(1, 0), quant_type, group_size
+             )
+
+             w_ref1_l.append(w_ref1.T)
+             qweight1_l.append(qweight1)
+             scales1_l.append(scales1)
+             zeros1_l.append(zeros1)
+         else:
+             test_perm = torch.randperm(k)
+             w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
+                 w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
+             )
+
+             w_ref1_l.append(w_ref1.T)
+             qweight1_l.append(qweight1)
+             scales1_l.append(scales1)
+             g_idx1_l.append(g_idx1)
+             sort_indices1_l.append(sort_indices1)
+
+     w_ref1 = stack_and_dev(w_ref1_l)
+     qweight1 = stack_and_dev(qweight1_l).contiguous()
+     scales1 = stack_and_dev(scales1_l)
+     g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
+     zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
+     sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
+
+     w_ref2_l = []
+     qweight2_l = []
+     scales2_l = []
+     zeros2_l = []
+     g_idx2_l = []
+     sort_indices2_l = []
+
+     for i in range(w2.shape[0]):
+         if has_zp:
+             w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize(
+                 w2[i].transpose(1, 0), quant_type, group_size
+             )
+
+             w_ref2_l.append(w_ref2.T)
+             qweight2_l.append(qweight2)
+             scales2_l.append(scales2)
+             zeros2_l.append(zeros2)
+         else:
+             test_perm = torch.randperm(n)
+             w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
+                 w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
+             )
+
+             w_ref2_l.append(w_ref2.T)
+             qweight2_l.append(qweight2)
+             scales2_l.append(scales2)
+             g_idx2_l.append(g_idx2)
+             sort_indices2_l.append(sort_indices2)
+
+     w_ref2 = stack_and_dev(w_ref2_l)
+     qweight2 = stack_and_dev(qweight2_l).contiguous()
+     scales2 = stack_and_dev(scales2_l)
+     g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
+     zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
+     sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
+
+     score = torch.randn((m, e), device="cuda", dtype=dtype)
+     from sglang.srt.layers.moe.topk import fused_topk_torch_native
+
+     topk_weights, topk_ids = fused_topk_torch_native(a, score, topk, False)
+
+     torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map)
+
+     marlin_output = fused_marlin_moe(
+         a,
+         qweight1,
+         qweight2,
+         scales1,
+         scales2,
+         score,
+         topk_weights,
+         topk_ids,
+         g_idx1=g_idx1,
+         g_idx2=g_idx2,
+         sort_indices1=sort_indices1,
+         sort_indices2=sort_indices2,
+         w1_zeros=zeros1,
+         w2_zeros=zeros2,
+         num_bits=4,
+         is_k_full=is_k_full,
+     )
+
+     torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
+
+
+ if __name__ == "__main__":
+     # Run the specific test function directly
+     pytest.main([__file__])
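For orientation, the reference path above treats `w1` as the fused gate/up projection with shape `(E, 2N, K)` and `w2` as the down projection with shape `(E, K, N)`, so each of the `M` tokens routes to `topk` experts and the result collapses back to `(M, K)`. Below is a minimal sketch (not part of the package) of exercising only the pure-PyTorch reference with illustrative shapes; it assumes an installed sglang with `sgl-kernel` present, since importing the test module pulls in `fused_marlin_moe`, and a CUDA device, since `SiluAndMul` may dispatch to a fused kernel:

```python
import torch

# torch_moe is the pure-PyTorch reference defined in the new test file above.
from sglang.test.test_marlin_moe import torch_moe

# Illustrative sizes: 4 experts, 8 tokens, hidden K=64, intermediate N=32.
E, M, K, N, topk = 4, 8, 64, 32, 2
a = torch.randn(M, K, dtype=torch.half, device="cuda") / 10
w1 = torch.randn(E, 2 * N, K, dtype=torch.half, device="cuda") / 20  # gate+up
w2 = torch.randn(E, K, N, dtype=torch.half, device="cuda") / 20      # down proj
score = torch.randn(M, E, dtype=torch.half, device="cuda")           # router logits

out = torch_moe(a, w1, w2, score, topk)
print(out.shape)  # torch.Size([8, 64]) -- one K-dim output row per token
```

The test compares this reference against `fused_marlin_moe` with `atol=5e-2`, which is a loose tolerance chosen to absorb 4-bit quantization error.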
sglang/test/test_marlin_utils.py ADDED
@@ -0,0 +1,171 @@
+ """
+ Adapted from
+ https://github.com/vllm-project/vllm/blob/020f58abcdea65302225663130d08fd8f4dd755a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py
+ """
+
+ # SPDX-License-Identifier: Apache-2.0
+ """Utility functions used for tests and benchmarks"""
+
+ from typing import Optional
+
+ import numpy as np
+ import torch
+
+ from sglang.srt.layers.quantization.marlin_utils import (
+     GPTQ_MARLIN_TILE,
+     marlin_permute_scales,
+     marlin_zero_points,
+ )
+ from sglang.srt.layers.quantization.scalar_type import ScalarType
+ from sglang.srt.layers.quantization.utils import (
+     get_pack_factor,
+     gptq_quantize_weights,
+     quantize_weights,
+     sort_weights,
+ )
+
+
+ class MarlinWorkspace:
+
+     def __init__(self, out_features, min_thread_n, max_parallel):
+         assert (
+             out_features % min_thread_n == 0
+         ), "out_features = {} is undivisible by min_thread_n = {}".format(
+             out_features, min_thread_n
+         )
+
+         max_workspace_size = (out_features // min_thread_n) * max_parallel
+
+         self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda")
+
+
+ def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
+     assert q_w.shape == (size_k, size_n)
+     assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
+     assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"
+
+     # Permute weights to 16x64 marlin tiles
+     q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
+     q_w = q_w.permute((0, 2, 1, 3))
+     q_w = q_w.reshape((size_k // tile, size_n * tile))
+
+     q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)
+
+     return q_w
+
+
+ def marlin_weights(q_w, size_k, size_n, num_bits, perm):
+     # Permute
+     q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
+
+     # Pack
+     pack_factor = get_pack_factor(num_bits)
+     orig_device = q_w.device
+
+     q_w = q_w.cpu().numpy().astype(np.uint32)
+
+     q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
+     for i in range(pack_factor):
+         q_packed |= q_w[:, i::pack_factor] << num_bits * i
+
+     q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)
+
+     return q_packed
+
+
+ def get_weight_perm(num_bits: int):
+     perm_list: list[int] = []
+     for i in range(32):
+         perm1: list[int] = []
+         col = i // 4
+         for block in [0, 1]:
+             for row in [
+                 2 * (i % 4),
+                 2 * (i % 4) + 1,
+                 2 * (i % 4 + 4),
+                 2 * (i % 4 + 4) + 1,
+             ]:
+                 perm1.append(16 * row + col + 8 * block)
+         for j in range(4):
+             perm_list.extend([p + 256 * j for p in perm1])
+
+     perm = np.array(perm_list)
+
+     if num_bits == 4:
+         interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
+     elif num_bits == 8:
+         interleave = np.array([0, 2, 1, 3])
+     else:
+         raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
+
+     perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
+     perm = torch.from_numpy(perm)
+     return perm
+
+
+ def marlin_quantize(
+     w: torch.Tensor,
+     quant_type: ScalarType,
+     group_size: int,
+     act_order: bool,
+     test_perm: Optional[torch.Tensor] = None,
+ ):
+     size_k, size_n = w.shape
+     num_bits = quant_type.size_bits
+
+     # Normalize group_size
+     if group_size == -1:
+         group_size = size_k
+     assert group_size <= size_k
+
+     # Quantize (and apply act_order if provided)
+     w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
+         w, quant_type, group_size, act_order, test_perm
+     )
+
+     # For act_order, sort the "weights" and "g_idx" so that group ids are
+     # increasing
+     sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
+     if act_order:
+         q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
+
+     # Reformat to marlin
+     weight_perm = get_weight_perm(num_bits)
+     marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
+     marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
+
+     # Create result
+     res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
+     for i in range(len(res_list)):
+         res_list[i] = res_list[i].to(w.device)
+
+     return res_list
+
+
+ def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int):
+     size_k, size_n = w.shape
+
+     # Normalize group_size
+     if group_size == -1:
+         group_size = size_k
+     assert group_size <= size_k
+
+     # Detect num groups
+     assert size_k % group_size == 0
+     num_groups = size_k // group_size
+
+     # Quantize with zp
+     w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True)
+
+     # Reformat to marlin
+     weight_perm = get_weight_perm(quant_type.size_bits)
+     marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm)
+     marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
+     marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits)
+
+     # Create result
+     res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]
+     for i in range(len(res_list)):
+         res_list[i] = res_list[i].to(w.device)
+
+     return res_list
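To make the helper contracts concrete, here is a small sketch (not part of the package; shapes are illustrative) of quantizing one `(K, N)` weight matrix with the two entry points above. `marlin_quantize` covers the GPTQ-style symmetric path (`uint4b8`) and `awq_marlin_quantize` the zero-point path (`uint4`); both K and N must be multiples of the 16-wide marlin tile:

```python
import torch

from sglang.srt.layers.quantization.scalar_type import scalar_types
from sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize

# Illustrative sizes: tile-aligned, and group_size divides K.
k, n, group_size = 256, 128, 128
w = torch.randn(k, n, dtype=torch.half, device="cuda")

# GPTQ-style symmetric 4-bit: returns the dequantized reference weights plus
# marlin-format qweight, scales, g_idx, and sort indices (g_idx/sort_indices
# are only meaningful when act_order=True).
w_ref, q_w, s, g_idx, sort_indices, rand_perm = marlin_quantize(
    w, scalar_types.uint4b8, group_size, act_order=False
)

# AWQ-style 4-bit with zero points.
w_ref_zp, q_w_zp, s_zp, zp = awq_marlin_quantize(w, scalar_types.uint4, group_size)

print(q_w.shape, s.shape, zp.shape)
```

This mirrors the per-expert loop in `test_fused_marlin_moe`, which applies the same helpers to each expert's transposed weight before stacking with `stack_and_dev`.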
sglang/test/test_utils.py CHANGED
@@ -2,6 +2,7 @@
 
  import argparse
  import copy
+ import json
  import logging
  import os
  import random
@@ -102,6 +103,15 @@ def is_in_amd_ci():
      return get_bool_env_var("SGLANG_AMD_CI")
 
 
+ def _use_cached_default_models(model_repo: str):
+     cache_dir = os.getenv("DEFAULT_MODEL_CACHE_DIR")
+     if cache_dir and model_repo:
+         model_path = os.path.join(cache_dir, model_repo)
+         if os.path.isdir(model_path):
+             return os.path.abspath(model_path)
+     return ""
+
+
  if is_in_ci():
      DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
          5000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) * 100
@@ -419,6 +429,31 @@ def get_call_select(args: argparse.Namespace):
      return func
 
 
+ def _get_default_models():
+     import inspect
+
+     current_module = inspect.getmodule(_get_default_models)
+     default_models = set()
+     for name, value in current_module.__dict__.items():
+         if (
+             isinstance(name, str)
+             and "DEFAULT_" in name
+             and "MODEL_" in name
+             and isinstance(value, str)
+         ):
+             if "," in value:
+                 parts = [part.strip() for part in value.split(",")]
+                 default_models.update(parts)
+             else:
+                 default_models.add(value.strip())
+     return json.dumps(list(default_models))
+
+
+ def try_cached_model(model_repo: str):
+     model_dir = _use_cached_default_models(model_repo)
+     return model_dir if model_dir else model_repo
+
+
  def popen_launch_server(
      model: str,
      base_url: str,
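The new helpers let CI resolve model repos against a local cache: `try_cached_model` returns an absolute local path when `$DEFAULT_MODEL_CACHE_DIR/<repo>` exists as a directory, and otherwise falls back to the repo id unchanged. A small usage sketch (the cache path and repo name are illustrative, not defaults from the package):

```python
import os

from sglang.test.test_utils import try_cached_model

# Illustrative cache layout: /models/meta-llama/Llama-3.1-8B-Instruct/...
os.environ["DEFAULT_MODEL_CACHE_DIR"] = "/models"

model = try_cached_model("meta-llama/Llama-3.1-8B-Instruct")
# -> "/models/meta-llama/Llama-3.1-8B-Instruct" if that directory exists,
#    otherwise "meta-llama/Llama-3.1-8B-Instruct", so the result can be passed
#    straight to popen_launch_server(model, ...) either way.
print(model)
```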
sglang/version.py CHANGED
@@ -1 +1 @@
- __version__ = "0.4.9.post2"
+ __version__ = "0.4.9.post4"
{sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: sglang
- Version: 0.4.9.post2
+ Version: 0.4.9.post4
  Summary: SGLang is yet another fast serving framework for large language models and vision language models.
  License: Apache License
           Version 2.0, January 2004
@@ -246,20 +246,20 @@ Requires-Dist: sentencepiece; extra == "runtime-common"
  Requires-Dist: soundfile==0.13.1; extra == "runtime-common"
  Requires-Dist: scipy; extra == "runtime-common"
  Requires-Dist: torchao==0.9.0; extra == "runtime-common"
- Requires-Dist: transformers==4.53.0; extra == "runtime-common"
+ Requires-Dist: transformers==4.53.2; extra == "runtime-common"
  Requires-Dist: timm==1.0.16; extra == "runtime-common"
  Requires-Dist: uvicorn; extra == "runtime-common"
  Requires-Dist: uvloop; extra == "runtime-common"
  Requires-Dist: xgrammar==0.1.21; extra == "runtime-common"
  Provides-Extra: srt
  Requires-Dist: sglang[runtime_common]; extra == "srt"
- Requires-Dist: sgl-kernel==0.2.5; extra == "srt"
+ Requires-Dist: sgl-kernel==0.2.7; extra == "srt"
  Requires-Dist: torch==2.7.1; extra == "srt"
  Requires-Dist: torchaudio==2.7.1; extra == "srt"
  Requires-Dist: torchvision==0.22.1; extra == "srt"
  Requires-Dist: cuda-python; extra == "srt"
  Requires-Dist: einops; extra == "srt"
- Requires-Dist: flashinfer_python==0.2.7.post1; extra == "srt"
+ Requires-Dist: flashinfer_python==0.2.9rc1; extra == "srt"
  Provides-Extra: blackwell
  Requires-Dist: sglang[runtime_common]; extra == "blackwell"
  Requires-Dist: sgl-kernel; extra == "blackwell"
@@ -268,11 +268,11 @@ Requires-Dist: torchaudio==2.7.1; extra == "blackwell"
  Requires-Dist: torchvision==0.22.1; extra == "blackwell"
  Requires-Dist: cuda-python; extra == "blackwell"
  Requires-Dist: einops; extra == "blackwell"
- Requires-Dist: flashinfer_python==0.2.7.post1; extra == "blackwell"
+ Requires-Dist: flashinfer_python==0.2.9rc1; extra == "blackwell"
  Provides-Extra: srt-hip
  Requires-Dist: sglang[runtime_common]; extra == "srt-hip"
  Requires-Dist: torch; extra == "srt-hip"
- Requires-Dist: vllm==0.6.7.dev2; extra == "srt-hip"
+ Requires-Dist: petit_kernel==0.0.2; extra == "srt-hip"
  Provides-Extra: srt-xpu
  Requires-Dist: sglang[runtime_common]; extra == "srt-xpu"
  Provides-Extra: srt-hpu
@@ -381,14 +381,14 @@ Dynamic: license-file
  - [2025/05] 🔥 Deploying DeepSeek with PD Disaggregation and Large-scale Expert Parallelism on 96 H100 GPUs ([blog](https://lmsys.org/blog/2025-05-05-large-scale-ep/)).
  - [2025/03] Supercharge DeepSeek-R1 Inference on AMD Instinct MI300X ([AMD blog](https://rocm.blogs.amd.com/artificial-intelligence/DeepSeekR1-Part2/README.html))
  - [2025/03] SGLang Joins PyTorch Ecosystem: Efficient LLM Serving Engine ([PyTorch blog](https://pytorch.org/blog/sglang-joins-pytorch/))
- - [2025/01] 🔥 SGLang provides day one support for DeepSeek V3/R1 models on NVIDIA and AMD GPUs with DeepSeek-specific optimizations. ([instructions](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3), [AMD blog](https://www.amd.com/en/developer/resources/technical-articles/amd-instinct-gpus-power-deepseek-v3-revolutionizing-ai-development-with-sglang.html), [10+ other companies](https://x.com/lmsysorg/status/1887262321636221412))
- - [2024/12] 🔥 v0.4 Release: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)).
+ - [2024/12] v0.4 Release: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)).
  - [2024/07] v0.2 Release: Faster Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)).
 
  <details>
  <summary>More</summary>
 
  - [2025/02] Unlock DeepSeek-R1 Inference Performance on AMD Instinct™ MI300X GPU ([AMD blog](https://rocm.blogs.amd.com/artificial-intelligence/DeepSeekR1_Perf/README.html))
+ - [2025/01] SGLang provides day one support for DeepSeek V3/R1 models on NVIDIA and AMD GPUs with DeepSeek-specific optimizations. ([instructions](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3), [AMD blog](https://www.amd.com/en/developer/resources/technical-articles/amd-instinct-gpus-power-deepseek-v3-revolutionizing-ai-development-with-sglang.html), [10+ other companies](https://x.com/lmsysorg/status/1887262321636221412))
  - [2024/10] The First SGLang Online Meetup ([slides](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#the-first-sglang-online-meetup)).
  - [2024/09] v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)).
  - [2024/02] SGLang enables **3x faster JSON decoding** with compressed finite state machine ([blog](https://lmsys.org/blog/2024-02-05-compressed-fsm/)).
@@ -415,10 +415,10 @@ The core features include:
  - [Contribution Guide](https://docs.sglang.ai/references/contribution_guide.html)
 
  ## Benchmark and Performance
- Learn more in the release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/), [v0.3 blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/), [v0.4 blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/).
+ Learn more in the release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/), [v0.3 blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/), [v0.4 blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/), [Large-scale expert parallelism](https://lmsys.org/blog/2025-05-05-large-scale-ep/).
 
  ## Roadmap
- [Development Roadmap (2025 H1)](https://github.com/sgl-project/sglang/issues/4042)
+ [Development Roadmap (2025 H2)](https://github.com/sgl-project/sglang/issues/7736)
 
  ## Adoption and Sponsorship
  SGLang has been deployed at large scale, generating trillions of tokens in production each day. It is trusted and adopted by a wide range of leading enterprises and institutions, including xAI, AMD, NVIDIA, Intel, LinkedIn, Cursor, Oracle Cloud, Google Cloud, Microsoft Azure, AWS, Atlas Cloud, Voltage Park, Nebius, DataCrunch, Novita, InnoMatrix, MIT, UCLA, the University of Washington, Stanford, UC Berkeley, Tsinghua University, Jam & Tea Studios, Baseten, and other major technology organizations across North America and Asia. As an open-source LLM inference engine, SGLang has become the de facto industry standard, with deployments running on over 1,000,000 GPUs worldwide.