sglang 0.4.5__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (121)
  1. sglang/bench_one_batch.py +21 -0
  2. sglang/bench_serving.py +10 -4
  3. sglang/srt/configs/model_config.py +37 -5
  4. sglang/srt/constrained/base_grammar_backend.py +26 -5
  5. sglang/srt/constrained/llguidance_backend.py +1 -0
  6. sglang/srt/constrained/outlines_backend.py +1 -0
  7. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  8. sglang/srt/constrained/xgrammar_backend.py +1 -0
  9. sglang/srt/disaggregation/base/__init__.py +8 -0
  10. sglang/srt/disaggregation/base/conn.py +113 -0
  11. sglang/srt/disaggregation/decode.py +18 -5
  12. sglang/srt/disaggregation/mini_lb.py +53 -122
  13. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  14. sglang/srt/disaggregation/mooncake/conn.py +615 -0
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
  16. sglang/srt/disaggregation/prefill.py +43 -19
  17. sglang/srt/disaggregation/utils.py +31 -0
  18. sglang/srt/entrypoints/EngineBase.py +53 -0
  19. sglang/srt/entrypoints/engine.py +36 -8
  20. sglang/srt/entrypoints/http_server.py +37 -8
  21. sglang/srt/entrypoints/http_server_engine.py +142 -0
  22. sglang/srt/entrypoints/verl_engine.py +37 -10
  23. sglang/srt/hf_transformers_utils.py +4 -0
  24. sglang/srt/layers/attention/flashattention_backend.py +330 -200
  25. sglang/srt/layers/attention/flashinfer_backend.py +13 -7
  26. sglang/srt/layers/attention/vision.py +1 -1
  27. sglang/srt/layers/dp_attention.py +2 -4
  28. sglang/srt/layers/elementwise.py +15 -2
  29. sglang/srt/layers/linear.py +1 -0
  30. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +38 -21
  38. sglang/srt/layers/moe/router.py +7 -1
  39. sglang/srt/layers/moe/topk.py +37 -16
  40. sglang/srt/layers/quantization/__init__.py +12 -5
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
  42. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
  43. sglang/srt/layers/quantization/fp8.py +25 -13
  44. sglang/srt/layers/quantization/fp8_kernel.py +130 -4
  45. sglang/srt/layers/quantization/fp8_utils.py +34 -6
  46. sglang/srt/layers/quantization/kv_cache.py +43 -52
  47. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  48. sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
  49. sglang/srt/layers/quantization/w8a8_int8.py +1 -0
  50. sglang/srt/layers/radix_attention.py +13 -1
  51. sglang/srt/layers/rotary_embedding.py +12 -1
  52. sglang/srt/managers/io_struct.py +254 -97
  53. sglang/srt/managers/mm_utils.py +3 -2
  54. sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
  55. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  56. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  57. sglang/srt/managers/schedule_batch.py +62 -21
  58. sglang/srt/managers/scheduler.py +71 -14
  59. sglang/srt/managers/tokenizer_manager.py +17 -3
  60. sglang/srt/managers/tp_worker.py +1 -0
  61. sglang/srt/mem_cache/memory_pool.py +14 -1
  62. sglang/srt/metrics/collector.py +9 -0
  63. sglang/srt/model_executor/cuda_graph_runner.py +7 -4
  64. sglang/srt/model_executor/forward_batch_info.py +234 -15
  65. sglang/srt/model_executor/model_runner.py +48 -9
  66. sglang/srt/model_loader/loader.py +31 -4
  67. sglang/srt/model_loader/weight_utils.py +4 -2
  68. sglang/srt/models/baichuan.py +2 -0
  69. sglang/srt/models/chatglm.py +1 -0
  70. sglang/srt/models/commandr.py +1 -0
  71. sglang/srt/models/dbrx.py +1 -0
  72. sglang/srt/models/deepseek.py +1 -0
  73. sglang/srt/models/deepseek_v2.py +248 -61
  74. sglang/srt/models/exaone.py +1 -0
  75. sglang/srt/models/gemma.py +1 -0
  76. sglang/srt/models/gemma2.py +1 -0
  77. sglang/srt/models/gemma3_causal.py +1 -0
  78. sglang/srt/models/gpt2.py +1 -0
  79. sglang/srt/models/gpt_bigcode.py +1 -0
  80. sglang/srt/models/granite.py +1 -0
  81. sglang/srt/models/grok.py +1 -0
  82. sglang/srt/models/internlm2.py +1 -0
  83. sglang/srt/models/llama.py +1 -0
  84. sglang/srt/models/llama4.py +101 -34
  85. sglang/srt/models/minicpm.py +1 -0
  86. sglang/srt/models/minicpm3.py +2 -0
  87. sglang/srt/models/mixtral.py +1 -0
  88. sglang/srt/models/mixtral_quant.py +1 -0
  89. sglang/srt/models/mllama.py +51 -8
  90. sglang/srt/models/mllama4.py +102 -29
  91. sglang/srt/models/olmo.py +1 -0
  92. sglang/srt/models/olmo2.py +1 -0
  93. sglang/srt/models/olmoe.py +1 -0
  94. sglang/srt/models/phi3_small.py +1 -0
  95. sglang/srt/models/qwen.py +1 -0
  96. sglang/srt/models/qwen2.py +1 -0
  97. sglang/srt/models/qwen2_5_vl.py +35 -70
  98. sglang/srt/models/qwen2_moe.py +1 -0
  99. sglang/srt/models/qwen2_vl.py +27 -25
  100. sglang/srt/models/stablelm.py +1 -0
  101. sglang/srt/models/xverse.py +1 -0
  102. sglang/srt/models/xverse_moe.py +1 -0
  103. sglang/srt/openai_api/adapter.py +4 -1
  104. sglang/srt/patch_torch.py +11 -0
  105. sglang/srt/server_args.py +34 -0
  106. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  107. sglang/srt/speculative/eagle_utils.py +1 -11
  108. sglang/srt/speculative/eagle_worker.py +6 -2
  109. sglang/srt/utils.py +120 -9
  110. sglang/test/attention/test_flashattn_backend.py +259 -221
  111. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  112. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  113. sglang/test/test_block_fp8.py +57 -0
  114. sglang/test/test_utils.py +19 -8
  115. sglang/version.py +1 -1
  116. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
  117. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +120 -106
  118. sglang/srt/disaggregation/conn.py +0 -81
  119. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
  120. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
  121. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
sglang/test/attention/test_flashattn_mla_backend.py ADDED
@@ -0,0 +1,285 @@
+import unittest
+
+import torch
+
+from sglang.srt.configs.model_config import AttentionArch
+from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.test.test_utils import CustomTestCase
+
+
+class MockModelRunner:
+    def __init__(
+        self,
+        kv_lora_rank,
+        qk_rope_head_dim,
+    ):
+        attention_arch = AttentionArch.MLA
+        self.device = "cuda"
+        self.dtype = torch.float16
+        context_len = 2048
+        self.model_config = type(
+            "ModelConfig",
+            (),
+            {
+                "context_len": context_len,
+                "attention_arch": attention_arch,
+            },
+        )
+        self.sliding_window_size = None
+
+        batch_size = 160
+        # Create a proper req_to_token_pool with the req_to_token attribute
+        self.req_to_token_pool = type(
+            "TokenPool",
+            (),
+            {
+                # A typical max_bs * max_context_len for cuda graph decode
+                "size": batch_size,
+                # Add req_to_token attribute
+                "req_to_token": torch.zeros(
+                    batch_size, context_len, dtype=torch.int32, device=self.device
+                ),
+            },
+        )
+        self.page_size = 1
+        max_total_num_tokens = batch_size * context_len
+        self.token_to_kv_pool = MLATokenToKVPool(
+            size=max_total_num_tokens,
+            page_size=self.page_size,
+            dtype=self.dtype,
+            kv_lora_rank=kv_lora_rank,
+            qk_rope_head_dim=qk_rope_head_dim,
+            layer_num=1,  # only consider layer=1 for unit test
+            device=self.device,
+            enable_memory_saver=False,
+        )
+
+
+class MockReqToTokenPool:
+    def __init__(self, batch_size, seq_len, device):
+        self.req_to_token = (
+            torch.arange(batch_size * seq_len, device=device)
+            .reshape(batch_size, seq_len)
+            .to(torch.int32)
+        )
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
+class TestFlashAttentionMLABackend(CustomTestCase):
+    def setUp(self):
+        # Test parameters
+        self.batch_size = 2
+        self.seq_len = 360
+        self.num_heads = 2
+        self.device = "cuda"
+        self.dtype = torch.float16
+        self.kv_lora_rank = 512
+        self.q_lora_rank = 128
+        self.qk_rope_head_dim = 64
+        self.qk_head_dim = self.qk_rope_head_dim + self.kv_lora_rank
+        # Assume no rope scaling
+        self.scaling = self.qk_head_dim**-0.5
+        # Initialize model runner and backend
+        self._init_model_runner()
+        self.backend = FlashAttentionBackend(self.model_runner)
+        self.num_local_heads = 2
+
+    def _init_model_runner(self):
+        self.model_runner = MockModelRunner(
+            kv_lora_rank=self.kv_lora_rank,
+            qk_rope_head_dim=self.qk_rope_head_dim,
+        )
+        self.backend = FlashAttentionBackend(self.model_runner)
+
+    def _create_attention_layer(self):
+        """Create attention layer for testing."""
+        self.attn_mqa = RadixAttention(
+            num_heads=self.num_local_heads,
+            head_dim=self.kv_lora_rank + self.qk_rope_head_dim,
+            scaling=self.scaling,
+            num_kv_heads=1,
+            layer_id=0,
+            v_head_dim=self.kv_lora_rank,
+            prefix="attn_mqa",
+        )
+        return self.attn_mqa
+
+    def _run_reference_forward(
+        self, mode, q, k, v, layer, forward_batch, expected_shape
+    ):
+        """Run reference forward pass using native backend."""
+        if mode == ForwardMode.EXTEND:
+            output = self.ref_backend.forward_extend(q, k, v, layer, forward_batch)
+        else:  # ForwardMode.DECODE
+            output = self.ref_backend.forward_decode(q, k, v, layer, forward_batch)
+        return output.view(expected_shape)
+
+    def _verify_output(self, output, expected_shape):
+        """Verify output tensor shape, dtype, and values."""
+        self.assertEqual(
+            output.shape,
+            expected_shape,
+            f"Expected shape {expected_shape}, got {output.shape}",
+        )
+        self.assertEqual(output.dtype, self.dtype)
+        self.assertEqual(output.device.type, "cuda")
+        self.assertEqual(
+            torch.isnan(output).sum().item(), 0, "Output contains NaN values"
+        )
+
+    def _create_forward_batch(self, mode, q_len=None, prefix_len=0):
+        """Create a forward batch for testing based on mode and lengths."""
+        # Default to self.seq_len if not specified
+        q_len = q_len or self.seq_len
+
+        if mode == ForwardMode.EXTEND:
+            total_len = prefix_len + q_len
+            out_cache_start = prefix_len * self.batch_size
+            out_cache_end = total_len * self.batch_size
+
+            forward_batch = ForwardBatch(
+                batch_size=self.batch_size,
+                input_ids=torch.randint(
+                    0, 100, (self.batch_size, q_len), device=self.device
+                ),
+                out_cache_loc=torch.arange(
+                    out_cache_start, out_cache_end, device=self.device
+                ),
+                seq_lens_sum=self.batch_size * total_len,
+                forward_mode=mode,
+                req_pool_indices=torch.arange(self.batch_size, device=self.device),
+                seq_lens=torch.tensor(
+                    [total_len] * self.batch_size, device=self.device
+                ),
+                seq_lens_cpu=torch.tensor([total_len] * self.batch_size, device="cpu"),
+                extend_prefix_lens=torch.tensor(
+                    [prefix_len] * self.batch_size, device=self.device
+                ),
+                extend_prefix_lens_cpu=torch.tensor(
+                    [prefix_len] * self.batch_size, device="cpu"
+                ),
+                extend_seq_lens=torch.tensor(
+                    [q_len] * self.batch_size, device=self.device
+                ),
+                extend_seq_lens_cpu=torch.tensor(
+                    [q_len] * self.batch_size, device="cpu"
+                ),
+                attn_backend=self.backend,
+            )
+
+        else:  # ForwardMode.DECODE
+            decode_len = q_len  # typically 1 for decode mode
+            total_len = self.seq_len + decode_len
+            out_cache_start = self.batch_size * self.seq_len
+            out_cache_end = self.batch_size * total_len
+
+            forward_batch = ForwardBatch(
+                batch_size=self.batch_size,
+                input_ids=torch.randint(
+                    0, 100, (self.batch_size, decode_len), device=self.device
+                ),
+                out_cache_loc=torch.arange(
+                    out_cache_start, out_cache_end, device=self.device
+                ),
+                seq_lens_sum=self.batch_size * total_len,
+                forward_mode=mode,
+                req_pool_indices=torch.arange(self.batch_size, device=self.device),
+                seq_lens=torch.tensor(
+                    [total_len] * self.batch_size, device=self.device
+                ),
+                seq_lens_cpu=torch.tensor([total_len] * self.batch_size, device="cpu"),
+                attn_backend=self.backend,
+            )
+
+        # Add token pool from model runner to forward batch
+        forward_batch.req_to_token_pool = self.model_runner.req_to_token_pool
+
+        # Add KV cache from model runner to forward batch
+        forward_batch.token_to_kv_pool = self.model_runner.token_to_kv_pool
+
+        return forward_batch
+
+    def _setup_kv_cache(self, forward_batch, layer, cache_len):
+        """Set up KV cache with prefix tokens."""
+        if cache_len <= 0:
+            return
+
+        # Create constant values for the prefix cache for easy debugging
+        latent_cache = torch.ones(
+            self.batch_size * cache_len,
+            1,  # latent cache has only one head in MQA
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            dtype=self.dtype,
+            device=self.device,
+        )
+
+        # Set the prefix KV cache
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer,
+            torch.arange(self.batch_size * cache_len, device=self.device),
+            latent_cache,
+            None,
+        )
+
+    def _run_attention_test(self, mode, q_len, prefix_len=0):
+        """
+        Run an attention test with the specified parameters.
+        Args:
+            mode: ForwardMode.EXTEND or ForwardMode.DECODE
+            q_len: Length of the query sequence. For decode mode, q_len is 1.
+            prefix_len: Length of the prefix sequence for extend mode
+        """
+        layer = self._create_attention_layer()
+
+        # Create forward batch and set up
+        forward_batch = self._create_forward_batch(mode, q_len, prefix_len)
+
+        # Create q, kv_compressed for testing
+        q_shape = (self.batch_size * q_len, self.num_heads, self.qk_head_dim)
+        kv_shape = (self.batch_size * q_len, self.qk_head_dim)
+        q = torch.randn(q_shape, dtype=self.dtype, device=self.device)
+        kv_compressed = torch.randn(kv_shape, dtype=self.dtype, device=self.device)
+        # v is not used for mqa, all values passed in through k
+        k = kv_compressed.unsqueeze(1)
+        v = torch.randn((1), dtype=self.dtype, device=self.device)
+
+        self._setup_kv_cache(forward_batch, layer, prefix_len)
+
+        self.backend.init_forward_metadata(forward_batch)
+
+        expected_shape = (
+            self.batch_size * q_len,
+            self.num_heads * self.kv_lora_rank,
+        )
+
+        if mode == ForwardMode.EXTEND:
+            output = self.backend.forward_extend(q, k, v, layer, forward_batch)
+        else:
+            output = self.backend.forward_decode(q, k, v, layer, forward_batch)
+
+        self._verify_output(output, expected_shape)
+        return output
+
+    def test_forward_extend(self):
+        """Test the standard extend operation."""
+        self._run_attention_test(ForwardMode.EXTEND, q_len=self.seq_len)
+
+    def test_forward_decode(self):
+        """Test the decode operation with cached tokens."""
+        self._run_attention_test(ForwardMode.DECODE, q_len=1)
+
+    def test_forward_extend_with_prefix(self):
+        """Test extending from cached prefix tokens."""
+        prefix_len = self.seq_len // 2
+        extend_len = self.seq_len - prefix_len
+        self._run_attention_test(
+            ForwardMode.EXTEND, q_len=extend_len, prefix_len=prefix_len
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
sglang/test/attention/test_prefix_chunk_info.py ADDED
@@ -0,0 +1,224 @@
+import unittest
+
+import torch
+
+from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.test.test_utils import CustomTestCase
+
+TEST_CASES = [
+    # Sequence with same prefix lens
+    {
+        "batch_size": 3,
+        "prefix_lens": [64, 64, 64],
+        "max_chunk_capacity": 48,
+        "prefix_chunk_len": 16,
+        "num_prefix_chunks": 4,
+        "prefix_chunk_starts": torch.tensor(
+            [
+                [0, 0, 0],
+                [16, 16, 16],
+                [32, 32, 32],
+                [48, 48, 48],
+            ],
+            dtype=torch.int32,
+        ),
+        "prefix_chunk_seq_lens": torch.tensor(
+            [
+                [16, 16, 16],
+                [16, 16, 16],
+                [16, 16, 16],
+                [16, 16, 16],
+            ],
+            dtype=torch.int32,
+        ),
+    },
+    # Sequence with different prefix lens
+    {
+        "batch_size": 4,
+        "prefix_lens": [16, 32, 48, 64],
+        "max_chunk_capacity": 64,
+        "prefix_chunk_len": 16,
+        "num_prefix_chunks": 4,
+        "prefix_chunk_starts": torch.tensor(
+            [
+                [0, 0, 0, 0],
+                [16, 16, 16, 16],
+                [32, 32, 32, 32],
+                [48, 48, 48, 48],
+            ],
+            dtype=torch.int32,
+        ),
+        "prefix_chunk_seq_lens": torch.tensor(
+            [
+                [16, 16, 16, 16],
+                [0, 16, 16, 16],
+                [0, 0, 16, 16],
+                [0, 0, 0, 16],
+            ],
+            dtype=torch.int32,
+        ),
+    },
+    # Sequence with irregular shapes
+    {
+        "batch_size": 2,
+        "prefix_lens": [1, 64],
+        "max_chunk_capacity": 31,
+        "prefix_chunk_len": 15,
+        "num_prefix_chunks": 5,
+        "prefix_chunk_starts": torch.tensor(
+            [
+                [0, 0],
+                [15, 15],
+                [30, 30],
+                [45, 45],
+                [60, 60],
+            ],
+            dtype=torch.int32,
+        ),
+        "prefix_chunk_seq_lens": torch.tensor(
+            [
+                [1, 15],
+                [0, 15],
+                [0, 15],
+                [0, 15],
+                [0, 4],
+            ],
+            dtype=torch.int32,
+        ),
+    },
+]
+
+
+class MockForwardBatch(ForwardBatch):
+    def __init__(self, max_chunk_capacity: int, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.max_chunk_capacity = max_chunk_capacity
+
+    def get_max_chunk_capacity(self):
+        return self.max_chunk_capacity
+
+
+class MockReqToTokenPool:
+    def __init__(self, batch_size, seq_len, device):
+        self.req_to_token = (
+            torch.arange(batch_size * seq_len, device=device)
+            .reshape(batch_size, seq_len)
+            .to(torch.int32)
+        )
+
+
+# Test correctness of triton kernel for computing kv indices
+def check_kv_indices(forward_batch):
+    for i in range(forward_batch.num_prefix_chunks):
+        computed_kv_indices = forward_batch.prefix_chunk_kv_indices[i]
+        req_to_token = forward_batch.req_to_token_pool.req_to_token[
+            : forward_batch.batch_size, :
+        ]
+        ref_kv_indices = torch.empty(
+            forward_batch.prefix_chunk_num_tokens[i],
+            dtype=torch.int32,
+            device=computed_kv_indices.device,
+        )
+        running_ptr = 0
+        for j in range(forward_batch.batch_size):
+            seq_start = forward_batch.prefix_chunk_starts[i, j].item()
+            seq_len = forward_batch.prefix_chunk_seq_lens[i, j].item()
+            ref_kv_indices[running_ptr : running_ptr + seq_len].copy_(
+                req_to_token[j, seq_start : seq_start + seq_len]
+            )
+            running_ptr += seq_len
+        assert torch.allclose(computed_kv_indices, ref_kv_indices)
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
+class TestPrefixChunkInfo(CustomTestCase):
+    def setUp(self):
+        # Common test parameters
+        self.num_local_heads = 128
+        self.kv_lora_rank = 512
+        self.qk_rope_head_dim = 64
+        self.device = torch.device("cuda")
+        self.dtype = torch.bfloat16
+        self.extend_len = 64
+        self.max_bs = 4
+        self.max_seq_len = 128
+
+        # req_to_token_pool
+        self.req_to_token_pool = MockReqToTokenPool(
+            self.max_bs,
+            self.max_seq_len,
+            self.device,
+        )
+
+        # token_to_kv_pool
+        self.token_to_kv_pool = MLATokenToKVPool(
+            size=self.max_bs * self.max_seq_len,
+            page_size=1,  # only consider page=1 for unit test
+            dtype=self.dtype,
+            kv_lora_rank=self.kv_lora_rank,
+            qk_rope_head_dim=self.qk_rope_head_dim,
+            layer_num=1,  # only consider layer=1 for unit test
+            device=self.device,
+            enable_memory_saver=False,
+        )
+
+    def test_prefix_chunk_info(self):
+        """Test the standard extend operation."""
+
+        for test_case in TEST_CASES:
+            print(
+                f"Test case with batch_size={test_case['batch_size']}, prefix_lens={test_case['prefix_lens']}, max_chunk_capacity={test_case['max_chunk_capacity']}"
+            )
+            batch_size = test_case["batch_size"]
+            prefix_lens_cpu = test_case["prefix_lens"]
+            assert len(prefix_lens_cpu) == batch_size
+            prefix_lens = torch.tensor(prefix_lens_cpu, device=self.device)
+            max_chunk_capacity = test_case["max_chunk_capacity"]
+            seq_lens_cpu = [
+                self.extend_len + prefix_lens_cpu[i] for i in range(batch_size)
+            ]
+            seq_lens = torch.tensor(seq_lens_cpu, device=self.device)
+
+            # Create forward batch
+            # input_ids and out_cache_loc are dummy tensors in this test
+            forward_batch = MockForwardBatch(
+                max_chunk_capacity=max_chunk_capacity,
+                batch_size=batch_size,
+                input_ids=torch.randint(
+                    0, 100, (batch_size, self.extend_len), device=self.device
+                ),
+                out_cache_loc=torch.arange(
+                    self.max_bs * self.max_seq_len - batch_size * self.extend_len,
+                    self.max_bs * self.max_seq_len,
+                    device=self.device,
+                ),
+                seq_lens_sum=sum(seq_lens_cpu),
+                forward_mode=ForwardMode.EXTEND,
+                req_pool_indices=torch.arange(batch_size, device=self.device),
+                seq_lens=seq_lens,
+                seq_lens_cpu=seq_lens_cpu,
+                extend_prefix_lens=prefix_lens,
+                extend_prefix_lens_cpu=prefix_lens_cpu,
+            )
+            forward_batch.req_to_token_pool = self.req_to_token_pool
+            forward_batch.token_to_kv_pool = self.token_to_kv_pool
+
+            forward_batch.prepare_chunked_prefix_cache_info(self.device)
+            assert forward_batch.get_max_chunk_capacity() == max_chunk_capacity
+            assert forward_batch.prefix_chunk_len == test_case["prefix_chunk_len"]
+            assert forward_batch.num_prefix_chunks == test_case["num_prefix_chunks"]
+            assert torch.allclose(
+                forward_batch.prefix_chunk_starts,
+                test_case["prefix_chunk_starts"].to(self.device),
+            )
+            assert torch.allclose(
+                forward_batch.prefix_chunk_seq_lens,
+                test_case["prefix_chunk_seq_lens"].to(self.device),
+            )
+
+            check_kv_indices(forward_batch)
+
+
+if __name__ == "__main__":
+    unittest.main()
sglang/test/test_block_fp8.py CHANGED
@@ -7,10 +7,12 @@ import torch
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
 from sglang.srt.layers.quantization.fp8_kernel import (
+    per_tensor_quant_mla_fp8,
     per_token_group_quant_fp8,
     static_quant_fp8,
     w8a8_block_fp8_matmul,
 )
+from sglang.srt.layers.quantization.fp8_utils import input_to_float8
 from sglang.test.test_utils import CustomTestCase
 
 _is_cuda = torch.cuda.is_available() and torch.version.cuda
@@ -155,6 +157,61 @@ class TestStaticQuantFP8(CustomTestCase):
                 self._static_quant_fp8(*params)
 
 
+class TestPerTensorQuantMlaFP8(CustomTestCase):
+    DTYPES = [torch.half, torch.bfloat16, torch.float32]
+    NUM_TOKENS = [7, 83, 2048]
+    D = [512, 4096, 5120, 13824]
+    LAST_D_EXT = [1024, 0]
+    LAST_D = [512]
+    SEEDS = [0]
+
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        torch.set_default_device("cuda")
+
+    def _per_tensor_quant_mla_fp8(self, num_tokens, d, last_d_ext, last_d, dtype, seed):
+        torch.manual_seed(seed)
+
+        x = torch.rand(
+            (num_tokens, d // last_d, last_d + last_d_ext),
+            dtype=dtype,
+        )
+        x_sub, _ = x.split([last_d, last_d_ext], dim=-1)
+
+        with torch.inference_mode():
+            ref_out, ref_s = input_to_float8(x_sub.transpose(0, 1))
+            out, out_s = per_tensor_quant_mla_fp8(x_sub.transpose(0, 1))
+
+        self.assertTrue(out.is_contiguous())
+        self.assertTrue(
+            torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.50)
+        )
+        self.assertTrue(
+            torch.allclose(out_s.to(torch.float32), ref_s.to(torch.float32))
+        )
+
+    def test_per_tensor_quant_mla_fp8(self):
+        for params in itertools.product(
+            self.NUM_TOKENS,
+            self.D,
+            self.LAST_D_EXT,
+            self.LAST_D,
+            self.DTYPES,
+            self.SEEDS,
+        ):
+            with self.subTest(
+                num_tokens=params[0],
+                d=params[1],
+                last_d_ext=params[2],
+                last_d=params[3],
+                dtype=params[4],
+                seed=params[5],
+            ):
+                self._per_tensor_quant_mla_fp8(*params)
+
+
 # For test
 def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
     """This function performs matrix multiplication with block-wise quantization using native torch.
sglang/test/test_utils.py CHANGED
@@ -25,7 +25,12 @@ from sglang.bench_serving import run_benchmark
 from sglang.global_config import global_config
 from sglang.lang.backend.openai import OpenAI
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
-from sglang.srt.utils import get_bool_env_var, kill_process_tree, retry
+from sglang.srt.utils import (
+    get_bool_env_var,
+    is_port_available,
+    kill_process_tree,
+    retry,
+)
 from sglang.test.run_eval import run_eval
 from sglang.utils import get_exception_traceback
 
@@ -37,11 +42,6 @@ DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST = (
 DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST = (
     "nvidia/Llama-3.1-8B-Instruct-FP8"
 )
-# TODO(yundai424): right now specifying to an older revision since the latest one
-# carries kv cache quantization which doesn't work yet
-DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST_REVISION = (
-    "13858565416dbdc0b4e7a4a677fadfbd5b9e5bb9"
-)
 
 DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
 DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
@@ -103,6 +103,17 @@ def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None)
     return pred
 
 
+def find_available_port(base_port: int):
+    port = base_port + random.randint(100, 1000)
+    while True:
+        if is_port_available(port):
+            return port
+        if port < 60000:
+            port += 42
+        else:
+            port -= 43
+
+
 def call_generate_vllm(prompt, temperature, max_tokens, stop=None, n=1, url=None):
     assert url is not None
 
@@ -674,8 +685,6 @@ def run_bench_one_batch(model, other_args):
         "python3",
         "-m",
         "sglang.bench_one_batch",
-        "--model-path",
-        model,
         "--batch-size",
         "1",
         "--input",
@@ -684,6 +693,8 @@
         "8",
         *[str(x) for x in other_args],
     ]
+    if model is not None:
+        command += ["--model-path", model]
     process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
 
     try:
sglang/version.py CHANGED
@@ -1 +1 @@
-__version__ = "0.4.5"
+__version__ = "0.4.5.post1"
{sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sglang
-Version: 0.4.5
+Version: 0.4.5.post1
 Summary: SGLang is yet another fast serving framework for large language models and vision language models.
 License: Apache License
 Version 2.0, January 2004
@@ -239,20 +239,30 @@ Requires-Dist: python-multipart; extra == "runtime-common"
 Requires-Dist: pyzmq>=25.1.2; extra == "runtime-common"
 Requires-Dist: soundfile==0.13.1; extra == "runtime-common"
 Requires-Dist: torchao>=0.7.0; extra == "runtime-common"
-Requires-Dist: transformers==4.51.0; extra == "runtime-common"
+Requires-Dist: transformers==4.51.1; extra == "runtime-common"
 Requires-Dist: uvicorn; extra == "runtime-common"
 Requires-Dist: uvloop; extra == "runtime-common"
 Requires-Dist: compressed-tensors; extra == "runtime-common"
 Requires-Dist: xgrammar==0.1.17; extra == "runtime-common"
 Provides-Extra: srt
 Requires-Dist: sglang[runtime_common]; extra == "srt"
-Requires-Dist: sgl-kernel==0.0.8; extra == "srt"
+Requires-Dist: sgl-kernel==0.0.9.post1; extra == "srt"
 Requires-Dist: flashinfer_python==0.2.3; extra == "srt"
 Requires-Dist: torch==2.5.1; extra == "srt"
+Requires-Dist: torchvision==0.20.1; extra == "srt"
 Requires-Dist: cuda-python; extra == "srt"
 Requires-Dist: outlines<=0.1.11,>=0.0.44; extra == "srt"
 Requires-Dist: partial_json_parser; extra == "srt"
 Requires-Dist: einops; extra == "srt"
+Provides-Extra: blackwell
+Requires-Dist: sglang[runtime_common]; extra == "blackwell"
+Requires-Dist: sgl-kernel; extra == "blackwell"
+Requires-Dist: torch; extra == "blackwell"
+Requires-Dist: torchvision; extra == "blackwell"
+Requires-Dist: cuda-python; extra == "blackwell"
+Requires-Dist: outlines<=0.1.11,>=0.0.44; extra == "blackwell"
+Requires-Dist: partial_json_parser; extra == "blackwell"
+Requires-Dist: einops; extra == "blackwell"
 Provides-Extra: srt-hip
 Requires-Dist: sglang[runtime_common]; extra == "srt-hip"
 Requires-Dist: torch; extra == "srt-hip"
@@ -391,7 +401,7 @@ Learn more in the release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-s
 
 ## Adoption and Sponsorship
 The project has been deployed to large-scale production, generating trillions of tokens every day.
-It is supported by the following institutions: AMD, Atlas Cloud, Baseten, Cursor, DataCrunch, Etched, Hyperbolic, Iflytek, Jam & Tea Studios, LinkedIn, LMSYS, Meituan, Nebius, Novita AI, NVIDIA, RunPod, Stanford, UC Berkeley, UCLA, xAI, and 01.AI.
+It is supported by the following institutions: AMD, Atlas Cloud, Baseten, Cursor, DataCrunch, Etched, Hyperbolic, Iflytek, Jam & Tea Studios, LinkedIn, LMSYS, Meituan, Nebius, Novita AI, NVIDIA, Oracle, RunPod, Stanford, UC Berkeley, UCLA, xAI, and 01.AI.
 
 <img src="https://raw.githubusercontent.com/sgl-project/sgl-learning-materials/main/slides/adoption.png" alt="logo" width="800" margin="10px"></img>