sglang 0.4.4.post4__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl
This diff shows the changes between the two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- sglang/bench_one_batch.py +21 -0
- sglang/bench_serving.py +10 -4
- sglang/lang/chat_template.py +24 -0
- sglang/srt/configs/model_config.py +40 -4
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/conversation.py +29 -4
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +18 -5
- sglang/srt/disaggregation/mini_lb.py +53 -122
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +615 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
- sglang/srt/disaggregation/prefill.py +43 -19
- sglang/srt/disaggregation/utils.py +31 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +37 -10
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/attention/flashattention_backend.py +609 -202
- sglang/srt/layers/attention/flashinfer_backend.py +13 -7
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +5 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +51 -24
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +37 -16
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
- sglang/srt/layers/quantization/fp8.py +28 -14
- sglang/srt/layers/quantization/fp8_kernel.py +130 -4
- sglang/srt/layers/quantization/fp8_utils.py +34 -6
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
- sglang/srt/layers/quantization/w8a8_int8.py +3 -0
- sglang/srt/layers/radix_attention.py +14 -0
- sglang/srt/layers/rotary_embedding.py +75 -1
- sglang/srt/managers/io_struct.py +254 -97
- sglang/srt/managers/mm_utils.py +3 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +146 -0
- sglang/srt/managers/schedule_batch.py +62 -21
- sglang/srt/managers/scheduler.py +71 -14
- sglang/srt/managers/tokenizer_manager.py +17 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/memory_pool.py +14 -1
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +7 -4
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +49 -9
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +1 -0
- sglang/srt/models/deepseek_v2.py +248 -61
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +13 -4
- sglang/srt/models/llama4.py +487 -0
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +2 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +227 -0
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +1 -0
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +1 -0
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/server_args.py +34 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +6 -2
- sglang/srt/utils.py +120 -9
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/test_block_fp8.py +57 -0
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +133 -109
- sglang/srt/disaggregation/conn.py +0 -81
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
sglang/test/attention/test_flashattn_mla_backend.py
ADDED
@@ -0,0 +1,285 @@
+import unittest
+
+import torch
+
+from sglang.srt.configs.model_config import AttentionArch
+from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.test.test_utils import CustomTestCase
+
+
+class MockModelRunner:
+    def __init__(
+        self,
+        kv_lora_rank,
+        qk_rope_head_dim,
+    ):
+        attention_arch = AttentionArch.MLA
+        self.device = "cuda"
+        self.dtype = torch.float16
+        context_len = 2048
+        self.model_config = type(
+            "ModelConfig",
+            (),
+            {
+                "context_len": context_len,
+                "attention_arch": attention_arch,
+            },
+        )
+        self.sliding_window_size = None
+
+        batch_size = 160
+        # Create a proper req_to_token_pool with the req_to_token attribute
+        self.req_to_token_pool = type(
+            "TokenPool",
+            (),
+            {
+                # A typical max_bs * max_context_len for cuda graph decode
+                "size": batch_size,
+                # Add req_to_token attribute
+                "req_to_token": torch.zeros(
+                    batch_size, context_len, dtype=torch.int32, device=self.device
+                ),
+            },
+        )
+        self.page_size = 1
+        max_total_num_tokens = batch_size * context_len
+        self.token_to_kv_pool = MLATokenToKVPool(
+            size=max_total_num_tokens,
+            page_size=self.page_size,
+            dtype=self.dtype,
+            kv_lora_rank=kv_lora_rank,
+            qk_rope_head_dim=qk_rope_head_dim,
+            layer_num=1,  # only consider layer=1 for unit test
+            device=self.device,
+            enable_memory_saver=False,
+        )
+
+
+class MockReqToTokenPool:
+    def __init__(self, batch_size, seq_len, device):
+        self.req_to_token = (
+            torch.arange(batch_size * seq_len, device=device)
+            .reshape(batch_size, seq_len)
+            .to(torch.int32)
+        )
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
+class TestFlashAttentionMLABackend(CustomTestCase):
+    def setUp(self):
+        # Test parameters
+        self.batch_size = 2
+        self.seq_len = 360
+        self.num_heads = 2
+        self.device = "cuda"
+        self.dtype = torch.float16
+        self.kv_lora_rank = 512
+        self.q_lora_rank = 128
+        self.qk_rope_head_dim = 64
+        self.qk_head_dim = self.qk_rope_head_dim + self.kv_lora_rank
+        # Assume no rope scaling
+        self.scaling = self.qk_head_dim**-0.5
+        # Initialize model runner and backend
+        self._init_model_runner()
+        self.backend = FlashAttentionBackend(self.model_runner)
+        self.num_local_heads = 2
+
+    def _init_model_runner(self):
+        self.model_runner = MockModelRunner(
+            kv_lora_rank=self.kv_lora_rank,
+            qk_rope_head_dim=self.qk_rope_head_dim,
+        )
+        self.backend = FlashAttentionBackend(self.model_runner)
+
+    def _create_attention_layer(self):
+        """Create attention layer for testing."""
+        self.attn_mqa = RadixAttention(
+            num_heads=self.num_local_heads,
+            head_dim=self.kv_lora_rank + self.qk_rope_head_dim,
+            scaling=self.scaling,
+            num_kv_heads=1,
+            layer_id=0,
+            v_head_dim=self.kv_lora_rank,
+            prefix="attn_mqa",
+        )
+        return self.attn_mqa
+
+    def _run_reference_forward(
+        self, mode, q, k, v, layer, forward_batch, expected_shape
+    ):
+        """Run reference forward pass using native backend."""
+        if mode == ForwardMode.EXTEND:
+            output = self.ref_backend.forward_extend(q, k, v, layer, forward_batch)
+        else:  # ForwardMode.DECODE
+            output = self.ref_backend.forward_decode(q, k, v, layer, forward_batch)
+        return output.view(expected_shape)
+
+    def _verify_output(self, output, expected_shape):
+        """Verify output tensor shape, dtype, and values."""
+        self.assertEqual(
+            output.shape,
+            expected_shape,
+            f"Expected shape {expected_shape}, got {output.shape}",
+        )
+        self.assertEqual(output.dtype, self.dtype)
+        self.assertEqual(output.device.type, "cuda")
+        self.assertEqual(
+            torch.isnan(output).sum().item(), 0, "Output contains NaN values"
+        )
+
+    def _create_forward_batch(self, mode, q_len=None, prefix_len=0):
+        """Create a forward batch for testing based on mode and lengths."""
+        # Default to self.seq_len if not specified
+        q_len = q_len or self.seq_len
+
+        if mode == ForwardMode.EXTEND:
+            total_len = prefix_len + q_len
+            out_cache_start = prefix_len * self.batch_size
+            out_cache_end = total_len * self.batch_size
+
+            forward_batch = ForwardBatch(
+                batch_size=self.batch_size,
+                input_ids=torch.randint(
+                    0, 100, (self.batch_size, q_len), device=self.device
+                ),
+                out_cache_loc=torch.arange(
+                    out_cache_start, out_cache_end, device=self.device
+                ),
+                seq_lens_sum=self.batch_size * total_len,
+                forward_mode=mode,
+                req_pool_indices=torch.arange(self.batch_size, device=self.device),
+                seq_lens=torch.tensor(
+                    [total_len] * self.batch_size, device=self.device
+                ),
+                seq_lens_cpu=torch.tensor([total_len] * self.batch_size, device="cpu"),
+                extend_prefix_lens=torch.tensor(
+                    [prefix_len] * self.batch_size, device=self.device
+                ),
+                extend_prefix_lens_cpu=torch.tensor(
+                    [prefix_len] * self.batch_size, device="cpu"
+                ),
+                extend_seq_lens=torch.tensor(
+                    [q_len] * self.batch_size, device=self.device
+                ),
+                extend_seq_lens_cpu=torch.tensor(
+                    [q_len] * self.batch_size, device="cpu"
+                ),
+                attn_backend=self.backend,
+            )
+
+        else:  # ForwardMode.DECODE
+            decode_len = q_len  # typically 1 for decode mode
+            total_len = self.seq_len + decode_len
+            out_cache_start = self.batch_size * self.seq_len
+            out_cache_end = self.batch_size * total_len
+
+            forward_batch = ForwardBatch(
+                batch_size=self.batch_size,
+                input_ids=torch.randint(
+                    0, 100, (self.batch_size, decode_len), device=self.device
+                ),
+                out_cache_loc=torch.arange(
+                    out_cache_start, out_cache_end, device=self.device
+                ),
+                seq_lens_sum=self.batch_size * total_len,
+                forward_mode=mode,
+                req_pool_indices=torch.arange(self.batch_size, device=self.device),
+                seq_lens=torch.tensor(
+                    [total_len] * self.batch_size, device=self.device
+                ),
+                seq_lens_cpu=torch.tensor([total_len] * self.batch_size, device="cpu"),
+                attn_backend=self.backend,
+            )
+
+        # Add token pool from model runner to forward batch
+        forward_batch.req_to_token_pool = self.model_runner.req_to_token_pool
+
+        # Add KV cache from model runner to forward batch
+        forward_batch.token_to_kv_pool = self.model_runner.token_to_kv_pool
+
+        return forward_batch
+
+    def _setup_kv_cache(self, forward_batch, layer, cache_len):
+        """Set up KV cache with prefix tokens."""
+        if cache_len <= 0:
+            return
+
+        # Create constant values for the prefix cache for easy debugging
+        latent_cache = torch.ones(
+            self.batch_size * cache_len,
+            1,  # latent cache has only one head in MQA
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            dtype=self.dtype,
+            device=self.device,
+        )
+
+        # Set the prefix KV cache
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer,
+            torch.arange(self.batch_size * cache_len, device=self.device),
+            latent_cache,
+            None,
+        )
+
+    def _run_attention_test(self, mode, q_len, prefix_len=0):
+        """
+        Run an attention test with the specified parameters.
+        Args:
+            mode: ForwardMode.EXTEND or ForwardMode.DECODE
+            q_len: Length of the query sequence. For decode mode, q_len is 1.
+            prefix_len: Length of the prefix sequence for extend mode
+        """
+        layer = self._create_attention_layer()
+
+        # Create forward batch and set up
+        forward_batch = self._create_forward_batch(mode, q_len, prefix_len)
+
+        # Create q, kv_compressed for testing
+        q_shape = (self.batch_size * q_len, self.num_heads, self.qk_head_dim)
+        kv_shape = (self.batch_size * q_len, self.qk_head_dim)
+        q = torch.randn(q_shape, dtype=self.dtype, device=self.device)
+        kv_compressed = torch.randn(kv_shape, dtype=self.dtype, device=self.device)
+        # v is not used for mqa, all values passed in through k
+        k = kv_compressed.unsqueeze(1)
+        v = torch.randn((1), dtype=self.dtype, device=self.device)
+
+        self._setup_kv_cache(forward_batch, layer, prefix_len)
+
+        self.backend.init_forward_metadata(forward_batch)
+
+        expected_shape = (
+            self.batch_size * q_len,
+            self.num_heads * self.kv_lora_rank,
+        )
+
+        if mode == ForwardMode.EXTEND:
+            output = self.backend.forward_extend(q, k, v, layer, forward_batch)
+        else:
+            output = self.backend.forward_decode(q, k, v, layer, forward_batch)
+
+        self._verify_output(output, expected_shape)
+        return output
+
+    def test_forward_extend(self):
+        """Test the standard extend operation."""
+        self._run_attention_test(ForwardMode.EXTEND, q_len=self.seq_len)
+
+    def test_forward_decode(self):
+        """Test the decode operation with cached tokens."""
+        self._run_attention_test(ForwardMode.DECODE, q_len=1)
+
+    def test_forward_extend_with_prefix(self):
+        """Test extending from cached prefix tokens."""
+        prefix_len = self.seq_len // 2
+        extend_len = self.seq_len - prefix_len
+        self._run_attention_test(
+            ForwardMode.EXTEND, q_len=extend_len, prefix_len=prefix_len
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
sglang/test/attention/test_prefix_chunk_info.py
ADDED
@@ -0,0 +1,224 @@
+import unittest
+
+import torch
+
+from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.test.test_utils import CustomTestCase
+
+TEST_CASES = [
+    # Sequence with same prefix lens
+    {
+        "batch_size": 3,
+        "prefix_lens": [64, 64, 64],
+        "max_chunk_capacity": 48,
+        "prefix_chunk_len": 16,
+        "num_prefix_chunks": 4,
+        "prefix_chunk_starts": torch.tensor(
+            [
+                [0, 0, 0],
+                [16, 16, 16],
+                [32, 32, 32],
+                [48, 48, 48],
+            ],
+            dtype=torch.int32,
+        ),
+        "prefix_chunk_seq_lens": torch.tensor(
+            [
+                [16, 16, 16],
+                [16, 16, 16],
+                [16, 16, 16],
+                [16, 16, 16],
+            ],
+            dtype=torch.int32,
+        ),
+    },
+    # Sequence with different prefix lens
+    {
+        "batch_size": 4,
+        "prefix_lens": [16, 32, 48, 64],
+        "max_chunk_capacity": 64,
+        "prefix_chunk_len": 16,
+        "num_prefix_chunks": 4,
+        "prefix_chunk_starts": torch.tensor(
+            [
+                [0, 0, 0, 0],
+                [16, 16, 16, 16],
+                [32, 32, 32, 32],
+                [48, 48, 48, 48],
+            ],
+            dtype=torch.int32,
+        ),
+        "prefix_chunk_seq_lens": torch.tensor(
+            [
+                [16, 16, 16, 16],
+                [0, 16, 16, 16],
+                [0, 0, 16, 16],
+                [0, 0, 0, 16],
+            ],
+            dtype=torch.int32,
+        ),
+    },
+    # Sequence with irregular shapes
+    {
+        "batch_size": 2,
+        "prefix_lens": [1, 64],
+        "max_chunk_capacity": 31,
+        "prefix_chunk_len": 15,
+        "num_prefix_chunks": 5,
+        "prefix_chunk_starts": torch.tensor(
+            [
+                [0, 0],
+                [15, 15],
+                [30, 30],
+                [45, 45],
+                [60, 60],
+            ],
+            dtype=torch.int32,
+        ),
+        "prefix_chunk_seq_lens": torch.tensor(
+            [
+                [1, 15],
+                [0, 15],
+                [0, 15],
+                [0, 15],
+                [0, 4],
+            ],
+            dtype=torch.int32,
+        ),
+    },
+]
+
+
+class MockForwardBatch(ForwardBatch):
+    def __init__(self, max_chunk_capacity: int, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.max_chunk_capacity = max_chunk_capacity
+
+    def get_max_chunk_capacity(self):
+        return self.max_chunk_capacity
+
+
+class MockReqToTokenPool:
+    def __init__(self, batch_size, seq_len, device):
+        self.req_to_token = (
+            torch.arange(batch_size * seq_len, device=device)
+            .reshape(batch_size, seq_len)
+            .to(torch.int32)
+        )
+
+
+# Test correctness of triton kernel for computing kv indices
+def check_kv_indices(forward_batch):
+    for i in range(forward_batch.num_prefix_chunks):
+        computed_kv_indices = forward_batch.prefix_chunk_kv_indices[i]
+        req_to_token = forward_batch.req_to_token_pool.req_to_token[
+            : forward_batch.batch_size, :
+        ]
+        ref_kv_indices = torch.empty(
+            forward_batch.prefix_chunk_num_tokens[i],
+            dtype=torch.int32,
+            device=computed_kv_indices.device,
+        )
+        running_ptr = 0
+        for j in range(forward_batch.batch_size):
+            seq_start = forward_batch.prefix_chunk_starts[i, j].item()
+            seq_len = forward_batch.prefix_chunk_seq_lens[i, j].item()
+            ref_kv_indices[running_ptr : running_ptr + seq_len].copy_(
+                req_to_token[j, seq_start : seq_start + seq_len]
+            )
+            running_ptr += seq_len
+        assert torch.allclose(computed_kv_indices, ref_kv_indices)
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
+class TestPrefixChunkInfo(CustomTestCase):
+    def setUp(self):
+        # Common test parameters
+        self.num_local_heads = 128
+        self.kv_lora_rank = 512
+        self.qk_rope_head_dim = 64
+        self.device = torch.device("cuda")
+        self.dtype = torch.bfloat16
+        self.extend_len = 64
+        self.max_bs = 4
+        self.max_seq_len = 128
+
+        # req_to_token_pool
+        self.req_to_token_pool = MockReqToTokenPool(
+            self.max_bs,
+            self.max_seq_len,
+            self.device,
+        )
+
+        # token_to_kv_pool
+        self.token_to_kv_pool = MLATokenToKVPool(
+            size=self.max_bs * self.max_seq_len,
+            page_size=1,  # only consider page=1 for unit test
+            dtype=self.dtype,
+            kv_lora_rank=self.kv_lora_rank,
+            qk_rope_head_dim=self.qk_rope_head_dim,
+            layer_num=1,  # only consider layer=1 for unit test
+            device=self.device,
+            enable_memory_saver=False,
+        )
+
+    def test_prefix_chunk_info(self):
+        """Test the standard extend operation."""
+
+        for test_case in TEST_CASES:
+            print(
+                f"Test case with batch_size={test_case['batch_size']}, prefix_lens={test_case['prefix_lens']}, max_chunk_capacity={test_case['max_chunk_capacity']}"
+            )
+            batch_size = test_case["batch_size"]
+            prefix_lens_cpu = test_case["prefix_lens"]
+            assert len(prefix_lens_cpu) == batch_size
+            prefix_lens = torch.tensor(prefix_lens_cpu, device=self.device)
+            max_chunk_capacity = test_case["max_chunk_capacity"]
+            seq_lens_cpu = [
+                self.extend_len + prefix_lens_cpu[i] for i in range(batch_size)
+            ]
+            seq_lens = torch.tensor(seq_lens_cpu, device=self.device)
+
+            # Create forward batch
+            # input_ids and out_cache_loc are dummy tensors in this test
+            forward_batch = MockForwardBatch(
+                max_chunk_capacity=max_chunk_capacity,
+                batch_size=batch_size,
+                input_ids=torch.randint(
+                    0, 100, (batch_size, self.extend_len), device=self.device
+                ),
+                out_cache_loc=torch.arange(
+                    self.max_bs * self.max_seq_len - batch_size * self.extend_len,
+                    self.max_bs * self.max_seq_len,
+                    device=self.device,
+                ),
+                seq_lens_sum=sum(seq_lens_cpu),
+                forward_mode=ForwardMode.EXTEND,
+                req_pool_indices=torch.arange(batch_size, device=self.device),
+                seq_lens=seq_lens,
+                seq_lens_cpu=seq_lens_cpu,
+                extend_prefix_lens=prefix_lens,
+                extend_prefix_lens_cpu=prefix_lens_cpu,
+            )
+            forward_batch.req_to_token_pool = self.req_to_token_pool
+            forward_batch.token_to_kv_pool = self.token_to_kv_pool
+
+            forward_batch.prepare_chunked_prefix_cache_info(self.device)
+            assert forward_batch.get_max_chunk_capacity() == max_chunk_capacity
+            assert forward_batch.prefix_chunk_len == test_case["prefix_chunk_len"]
+            assert forward_batch.num_prefix_chunks == test_case["num_prefix_chunks"]
+            assert torch.allclose(
+                forward_batch.prefix_chunk_starts,
+                test_case["prefix_chunk_starts"].to(self.device),
+            )
+            assert torch.allclose(
+                forward_batch.prefix_chunk_seq_lens,
+                test_case["prefix_chunk_seq_lens"].to(self.device),
+            )
+
+            check_kv_indices(forward_batch)
+
+
+if __name__ == "__main__":
+    unittest.main()
sglang/test/test_block_fp8.py
CHANGED
@@ -7,10 +7,12 @@ import torch
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
 from sglang.srt.layers.quantization.fp8_kernel import (
+    per_tensor_quant_mla_fp8,
     per_token_group_quant_fp8,
     static_quant_fp8,
     w8a8_block_fp8_matmul,
 )
+from sglang.srt.layers.quantization.fp8_utils import input_to_float8
 from sglang.test.test_utils import CustomTestCase
 
 _is_cuda = torch.cuda.is_available() and torch.version.cuda
@@ -155,6 +157,61 @@ class TestStaticQuantFP8(CustomTestCase):
             self._static_quant_fp8(*params)
 
 
+class TestPerTensorQuantMlaFP8(CustomTestCase):
+    DTYPES = [torch.half, torch.bfloat16, torch.float32]
+    NUM_TOKENS = [7, 83, 2048]
+    D = [512, 4096, 5120, 13824]
+    LAST_D_EXT = [1024, 0]
+    LAST_D = [512]
+    SEEDS = [0]
+
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        torch.set_default_device("cuda")
+
+    def _per_tensor_quant_mla_fp8(self, num_tokens, d, last_d_ext, last_d, dtype, seed):
+        torch.manual_seed(seed)
+
+        x = torch.rand(
+            (num_tokens, d // last_d, last_d + last_d_ext),
+            dtype=dtype,
+        )
+        x_sub, _ = x.split([last_d, last_d_ext], dim=-1)
+
+        with torch.inference_mode():
+            ref_out, ref_s = input_to_float8(x_sub.transpose(0, 1))
+            out, out_s = per_tensor_quant_mla_fp8(x_sub.transpose(0, 1))
+
+        self.assertTrue(out.is_contiguous())
+        self.assertTrue(
+            torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.50)
+        )
+        self.assertTrue(
+            torch.allclose(out_s.to(torch.float32), ref_s.to(torch.float32))
+        )
+
+    def test_per_tensor_quant_mla_fp8(self):
+        for params in itertools.product(
+            self.NUM_TOKENS,
+            self.D,
+            self.LAST_D_EXT,
+            self.LAST_D,
+            self.DTYPES,
+            self.SEEDS,
+        ):
+            with self.subTest(
+                num_tokens=params[0],
+                d=params[1],
+                last_d_ext=params[2],
+                last_d=params[3],
+                dtype=params[4],
+                seed=params[5],
+            ):
+                self._per_tensor_quant_mla_fp8(*params)
+
+
 # For test
 def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
     """This function performs matrix multiplication with block-wise quantization using native torch.
sglang/test/test_utils.py
CHANGED
@@ -25,7 +25,12 @@ from sglang.bench_serving import run_benchmark
 from sglang.global_config import global_config
 from sglang.lang.backend.openai import OpenAI
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    get_bool_env_var,
+    is_port_available,
+    kill_process_tree,
+    retry,
+)
 from sglang.test.run_eval import run_eval
 from sglang.utils import get_exception_traceback
 
@@ -37,11 +42,6 @@ DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST = (
 DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST = (
     "nvidia/Llama-3.1-8B-Instruct-FP8"
 )
-# TODO(yundai424): right now specifying to an older revision since the latest one
-# carries kv cache quantization which doesn't work yet
-DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST_REVISION = (
-    "13858565416dbdc0b4e7a4a677fadfbd5b9e5bb9"
-)
 
 DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
 DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
@@ -103,6 +103,17 @@ def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None)
     return pred
 
 
+def find_available_port(base_port: int):
+    port = base_port + random.randint(100, 1000)
+    while True:
+        if is_port_available(port):
+            return port
+        if port < 60000:
+            port += 42
+        else:
+            port -= 43
+
+
 def call_generate_vllm(prompt, temperature, max_tokens, stop=None, n=1, url=None):
     assert url is not None
 
@@ -674,8 +685,6 @@ def run_bench_one_batch(model, other_args):
         "python3",
         "-m",
         "sglang.bench_one_batch",
-        "--model-path",
-        model,
         "--batch-size",
         "1",
         "--input",
@@ -684,6 +693,8 @@ def run_bench_one_batch(model, other_args):
         "8",
         *[str(x) for x in other_args],
     ]
+    if model is not None:
+        command += ["--model-path", model]
     process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
 
     try:
sglang/version.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.4.4.post4"
+__version__ = "0.4.5.post1"
{sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sglang
-Version: 0.4.4.post4
+Version: 0.4.5.post1
 Summary: SGLang is yet another fast serving framework for large language models and vision language models.
 License: Apache License
                            Version 2.0, January 2004
@@ -239,20 +239,30 @@ Requires-Dist: python-multipart; extra == "runtime-common"
 Requires-Dist: pyzmq>=25.1.2; extra == "runtime-common"
 Requires-Dist: soundfile==0.13.1; extra == "runtime-common"
 Requires-Dist: torchao>=0.7.0; extra == "runtime-common"
-Requires-Dist: transformers==4.51.
+Requires-Dist: transformers==4.51.1; extra == "runtime-common"
 Requires-Dist: uvicorn; extra == "runtime-common"
 Requires-Dist: uvloop; extra == "runtime-common"
 Requires-Dist: compressed-tensors; extra == "runtime-common"
 Requires-Dist: xgrammar==0.1.17; extra == "runtime-common"
 Provides-Extra: srt
 Requires-Dist: sglang[runtime_common]; extra == "srt"
-Requires-Dist: sgl-kernel==0.0.
+Requires-Dist: sgl-kernel==0.0.9.post1; extra == "srt"
 Requires-Dist: flashinfer_python==0.2.3; extra == "srt"
 Requires-Dist: torch==2.5.1; extra == "srt"
+Requires-Dist: torchvision==0.20.1; extra == "srt"
 Requires-Dist: cuda-python; extra == "srt"
 Requires-Dist: outlines<=0.1.11,>=0.0.44; extra == "srt"
 Requires-Dist: partial_json_parser; extra == "srt"
 Requires-Dist: einops; extra == "srt"
+Provides-Extra: blackwell
+Requires-Dist: sglang[runtime_common]; extra == "blackwell"
+Requires-Dist: sgl-kernel; extra == "blackwell"
+Requires-Dist: torch; extra == "blackwell"
+Requires-Dist: torchvision; extra == "blackwell"
+Requires-Dist: cuda-python; extra == "blackwell"
+Requires-Dist: outlines<=0.1.11,>=0.0.44; extra == "blackwell"
+Requires-Dist: partial_json_parser; extra == "blackwell"
+Requires-Dist: einops; extra == "blackwell"
 Provides-Extra: srt-hip
 Requires-Dist: sglang[runtime_common]; extra == "srt-hip"
 Requires-Dist: torch; extra == "srt-hip"
@@ -391,7 +401,7 @@ Learn more in the release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-s
 
 ## Adoption and Sponsorship
 The project has been deployed to large-scale production, generating trillions of tokens every day.
-It is supported by the following institutions: AMD, Atlas Cloud, Baseten, Cursor, DataCrunch, Etched, Hyperbolic, Iflytek, Jam & Tea Studios, LinkedIn, LMSYS, Meituan, Nebius, Novita AI, NVIDIA, RunPod, Stanford, UC Berkeley, UCLA, xAI, and 01.AI.
+It is supported by the following institutions: AMD, Atlas Cloud, Baseten, Cursor, DataCrunch, Etched, Hyperbolic, Iflytek, Jam & Tea Studios, LinkedIn, LMSYS, Meituan, Nebius, Novita AI, NVIDIA, Oracle, RunPod, Stanford, UC Berkeley, UCLA, xAI, and 01.AI.
 
 <img src="https://raw.githubusercontent.com/sgl-project/sgl-learning-materials/main/slides/adoption.png" alt="logo" width="800" margin="10px"></img>
 