sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +23 -2
- sglang/bench_serving.py +6 -4
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +37 -5
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +80 -11
- sglang/srt/disaggregation/mini_lb.py +58 -123
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +585 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
- sglang/srt/disaggregation/prefill.py +82 -22
- sglang/srt/disaggregation/utils.py +46 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +42 -13
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +430 -257
- sglang/srt/layers/attention/flashinfer_backend.py +18 -9
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +18 -3
- sglang/srt/layers/moe/ep_moe/layer.py +15 -29
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +63 -45
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/fp8.py +131 -136
- sglang/srt/layers/quantization/fp8_kernel.py +328 -46
- sglang/srt/layers/quantization/fp8_utils.py +206 -253
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
- sglang/srt/layers/quantization/w8a8_int8.py +8 -7
- sglang/srt/layers/radix_attention.py +28 -1
- sglang/srt/layers/rotary_embedding.py +15 -3
- sglang/srt/layers/sampler.py +5 -10
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +255 -97
- sglang/srt/managers/mm_utils.py +7 -5
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
- sglang/srt/managers/schedule_batch.py +64 -25
- sglang/srt/managers/scheduler.py +80 -82
- sglang/srt/managers/tokenizer_manager.py +18 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +21 -3
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +9 -6
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +67 -35
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +494 -366
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +6 -5
- sglang/srt/models/llama4.py +101 -34
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +30 -200
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +102 -29
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +5 -1
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +15 -13
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +55 -19
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +10 -9
- sglang/srt/utils.py +136 -10
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +224 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/disaggregation/conn.py +0 -81
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/test/attention/test_flashattn_mla_backend.py
ADDED
@@ -0,0 +1,285 @@
+import unittest
+
+import torch
+
+from sglang.srt.configs.model_config import AttentionArch
+from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.test.test_utils import CustomTestCase
+
+
+class MockModelRunner:
+    def __init__(
+        self,
+        kv_lora_rank,
+        qk_rope_head_dim,
+    ):
+        attention_arch = AttentionArch.MLA
+        self.device = "cuda"
+        self.dtype = torch.float16
+        context_len = 2048
+        self.model_config = type(
+            "ModelConfig",
+            (),
+            {
+                "context_len": context_len,
+                "attention_arch": attention_arch,
+            },
+        )
+        self.sliding_window_size = None
+
+        batch_size = 160
+        # Create a proper req_to_token_pool with the req_to_token attribute
+        self.req_to_token_pool = type(
+            "TokenPool",
+            (),
+            {
+                # A typical max_bs * max_context_len for cuda graph decode
+                "size": batch_size,
+                # Add req_to_token attribute
+                "req_to_token": torch.zeros(
+                    batch_size, context_len, dtype=torch.int32, device=self.device
+                ),
+            },
+        )
+        self.page_size = 1
+        max_total_num_tokens = batch_size * context_len
+        self.token_to_kv_pool = MLATokenToKVPool(
+            size=max_total_num_tokens,
+            page_size=self.page_size,
+            dtype=self.dtype,
+            kv_lora_rank=kv_lora_rank,
+            qk_rope_head_dim=qk_rope_head_dim,
+            layer_num=1,  # only consider layer=1 for unit test
+            device=self.device,
+            enable_memory_saver=False,
+        )
+
+
+class MockReqToTokenPool:
+    def __init__(self, batch_size, seq_len, device):
+        self.req_to_token = (
+            torch.arange(batch_size * seq_len, device=device)
+            .reshape(batch_size, seq_len)
+            .to(torch.int32)
+        )
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
+class TestFlashAttentionMLABackend(CustomTestCase):
+    def setUp(self):
+        # Test parameters
+        self.batch_size = 2
+        self.seq_len = 360
+        self.num_heads = 2
+        self.device = "cuda"
+        self.dtype = torch.float16
+        self.kv_lora_rank = 512
+        self.q_lora_rank = 128
+        self.qk_rope_head_dim = 64
+        self.qk_head_dim = self.qk_rope_head_dim + self.kv_lora_rank
+        # Assume no rope scaling
+        self.scaling = self.qk_head_dim**-0.5
+        # Initialize model runner and backend
+        self._init_model_runner()
+        self.backend = FlashAttentionBackend(self.model_runner)
+        self.num_local_heads = 2
+
+    def _init_model_runner(self):
+        self.model_runner = MockModelRunner(
+            kv_lora_rank=self.kv_lora_rank,
+            qk_rope_head_dim=self.qk_rope_head_dim,
+        )
+        self.backend = FlashAttentionBackend(self.model_runner)
+
+    def _create_attention_layer(self):
+        """Create attention layer for testing."""
+        self.attn_mqa = RadixAttention(
+            num_heads=self.num_local_heads,
+            head_dim=self.kv_lora_rank + self.qk_rope_head_dim,
+            scaling=self.scaling,
+            num_kv_heads=1,
+            layer_id=0,
+            v_head_dim=self.kv_lora_rank,
+            prefix="attn_mqa",
+        )
+        return self.attn_mqa
+
+    def _run_reference_forward(
+        self, mode, q, k, v, layer, forward_batch, expected_shape
+    ):
+        """Run reference forward pass using native backend."""
+        if mode == ForwardMode.EXTEND:
+            output = self.ref_backend.forward_extend(q, k, v, layer, forward_batch)
+        else:  # ForwardMode.DECODE
+            output = self.ref_backend.forward_decode(q, k, v, layer, forward_batch)
+        return output.view(expected_shape)
+
+    def _verify_output(self, output, expected_shape):
+        """Verify output tensor shape, dtype, and values."""
+        self.assertEqual(
+            output.shape,
+            expected_shape,
+            f"Expected shape {expected_shape}, got {output.shape}",
+        )
+        self.assertEqual(output.dtype, self.dtype)
+        self.assertEqual(output.device.type, "cuda")
+        self.assertEqual(
+            torch.isnan(output).sum().item(), 0, "Output contains NaN values"
+        )
+
+    def _create_forward_batch(self, mode, q_len=None, prefix_len=0):
+        """Create a forward batch for testing based on mode and lengths."""
+        # Default to self.seq_len if not specified
+        q_len = q_len or self.seq_len
+
+        if mode == ForwardMode.EXTEND:
+            total_len = prefix_len + q_len
+            out_cache_start = prefix_len * self.batch_size
+            out_cache_end = total_len * self.batch_size
+
+            forward_batch = ForwardBatch(
+                batch_size=self.batch_size,
+                input_ids=torch.randint(
+                    0, 100, (self.batch_size, q_len), device=self.device
+                ),
+                out_cache_loc=torch.arange(
+                    out_cache_start, out_cache_end, device=self.device
+                ),
+                seq_lens_sum=self.batch_size * total_len,
+                forward_mode=mode,
+                req_pool_indices=torch.arange(self.batch_size, device=self.device),
+                seq_lens=torch.tensor(
+                    [total_len] * self.batch_size, device=self.device
+                ),
+                seq_lens_cpu=torch.tensor([total_len] * self.batch_size, device="cpu"),
+                extend_prefix_lens=torch.tensor(
+                    [prefix_len] * self.batch_size, device=self.device
+                ),
+                extend_prefix_lens_cpu=torch.tensor(
+                    [prefix_len] * self.batch_size, device="cpu"
+                ),
+                extend_seq_lens=torch.tensor(
+                    [q_len] * self.batch_size, device=self.device
+                ),
+                extend_seq_lens_cpu=torch.tensor(
+                    [q_len] * self.batch_size, device="cpu"
+                ),
+                attn_backend=self.backend,
+            )
+
+        else:  # ForwardMode.DECODE
+            decode_len = q_len  # typically 1 for decode mode
+            total_len = self.seq_len + decode_len
+            out_cache_start = self.batch_size * self.seq_len
+            out_cache_end = self.batch_size * total_len
+
+            forward_batch = ForwardBatch(
+                batch_size=self.batch_size,
+                input_ids=torch.randint(
+                    0, 100, (self.batch_size, decode_len), device=self.device
+                ),
+                out_cache_loc=torch.arange(
+                    out_cache_start, out_cache_end, device=self.device
+                ),
+                seq_lens_sum=self.batch_size * total_len,
+                forward_mode=mode,
+                req_pool_indices=torch.arange(self.batch_size, device=self.device),
+                seq_lens=torch.tensor(
+                    [total_len] * self.batch_size, device=self.device
+                ),
+                seq_lens_cpu=torch.tensor([total_len] * self.batch_size, device="cpu"),
+                attn_backend=self.backend,
+            )
+
+        # Add token pool from model runner to forward batch
+        forward_batch.req_to_token_pool = self.model_runner.req_to_token_pool
+
+        # Add KV cache from model runner to forward batch
+        forward_batch.token_to_kv_pool = self.model_runner.token_to_kv_pool
+
+        return forward_batch
+
+    def _setup_kv_cache(self, forward_batch, layer, cache_len):
+        """Set up KV cache with prefix tokens."""
+        if cache_len <= 0:
+            return
+
+        # Create constant values for the prefix cache for easy debugging
+        latent_cache = torch.ones(
+            self.batch_size * cache_len,
+            1,  # latent cache has only one head in MQA
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            dtype=self.dtype,
+            device=self.device,
+        )
+
+        # Set the prefix KV cache
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer,
+            torch.arange(self.batch_size * cache_len, device=self.device),
+            latent_cache,
+            None,
+        )
+
+    def _run_attention_test(self, mode, q_len, prefix_len=0):
+        """
+        Run an attention test with the specified parameters.
+        Args:
+            mode: ForwardMode.EXTEND or ForwardMode.DECODE
+            q_len: Length of the query sequence. For decode mode, q_len is 1.
+            prefix_len: Length of the prefix sequence for extend mode
+        """
+        layer = self._create_attention_layer()
+
+        # Create forward batch and set up
+        forward_batch = self._create_forward_batch(mode, q_len, prefix_len)
+
+        # Create q, kv_compressed for testing
+        q_shape = (self.batch_size * q_len, self.num_heads, self.qk_head_dim)
+        kv_shape = (self.batch_size * q_len, self.qk_head_dim)
+        q = torch.randn(q_shape, dtype=self.dtype, device=self.device)
+        kv_compressed = torch.randn(kv_shape, dtype=self.dtype, device=self.device)
+        # v is not used for mqa, all values passed in through k
+        k = kv_compressed.unsqueeze(1)
+        v = torch.randn((1), dtype=self.dtype, device=self.device)
+
+        self._setup_kv_cache(forward_batch, layer, prefix_len)
+
+        self.backend.init_forward_metadata(forward_batch)
+
+        expected_shape = (
+            self.batch_size * q_len,
+            self.num_heads * self.kv_lora_rank,
+        )
+
+        if mode == ForwardMode.EXTEND:
+            output = self.backend.forward_extend(q, k, v, layer, forward_batch)
+        else:
+            output = self.backend.forward_decode(q, k, v, layer, forward_batch)
+
+        self._verify_output(output, expected_shape)
+        return output
+
+    def test_forward_extend(self):
+        """Test the standard extend operation."""
+        self._run_attention_test(ForwardMode.EXTEND, q_len=self.seq_len)
+
+    def test_forward_decode(self):
+        """Test the decode operation with cached tokens."""
+        self._run_attention_test(ForwardMode.DECODE, q_len=1)
+
+    def test_forward_extend_with_prefix(self):
+        """Test extending from cached prefix tokens."""
+        prefix_len = self.seq_len // 2
+        extend_len = self.seq_len - prefix_len
+        self._run_attention_test(
+            ForwardMode.EXTEND, q_len=extend_len, prefix_len=prefix_len
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
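The hunk above adds a unit test that drives FlashAttentionBackend with MLA-shaped (latent-cache MQA) inputs through the extend, decode, and extend-with-prefix paths. A minimal sketch of running it on a CUDA machine with the standard unittest loader; the module path is taken from the file listing above, and no extra test flags are assumed:

import unittest

# Load every test in the new MLA FlashAttention module and run it verbosely.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "sglang.test.attention.test_flashattn_mla_backend"
)
unittest.TextTestRunner(verbosity=2).run(suite)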
sglang/test/attention/test_prefix_chunk_info.py
ADDED
@@ -0,0 +1,224 @@
+import unittest
+
+import torch
+
+from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.test.test_utils import CustomTestCase
+
+TEST_CASES = [
+    # Sequence with same prefix lens
+    {
+        "batch_size": 3,
+        "prefix_lens": [64, 64, 64],
+        "max_chunk_capacity": 48,
+        "prefix_chunk_len": 16,
+        "num_prefix_chunks": 4,
+        "prefix_chunk_starts": torch.tensor(
+            [
+                [0, 0, 0],
+                [16, 16, 16],
+                [32, 32, 32],
+                [48, 48, 48],
+            ],
+            dtype=torch.int32,
+        ),
+        "prefix_chunk_seq_lens": torch.tensor(
+            [
+                [16, 16, 16],
+                [16, 16, 16],
+                [16, 16, 16],
+                [16, 16, 16],
+            ],
+            dtype=torch.int32,
+        ),
+    },
+    # Sequence with different prefix lens
+    {
+        "batch_size": 4,
+        "prefix_lens": [16, 32, 48, 64],
+        "max_chunk_capacity": 64,
+        "prefix_chunk_len": 16,
+        "num_prefix_chunks": 4,
+        "prefix_chunk_starts": torch.tensor(
+            [
+                [0, 0, 0, 0],
+                [16, 16, 16, 16],
+                [32, 32, 32, 32],
+                [48, 48, 48, 48],
+            ],
+            dtype=torch.int32,
+        ),
+        "prefix_chunk_seq_lens": torch.tensor(
+            [
+                [16, 16, 16, 16],
+                [0, 16, 16, 16],
+                [0, 0, 16, 16],
+                [0, 0, 0, 16],
+            ],
+            dtype=torch.int32,
+        ),
+    },
+    # Sequence with irregular shapes
+    {
+        "batch_size": 2,
+        "prefix_lens": [1, 64],
+        "max_chunk_capacity": 31,
+        "prefix_chunk_len": 15,
+        "num_prefix_chunks": 5,
+        "prefix_chunk_starts": torch.tensor(
+            [
+                [0, 0],
+                [15, 15],
+                [30, 30],
+                [45, 45],
+                [60, 60],
+            ],
+            dtype=torch.int32,
+        ),
+        "prefix_chunk_seq_lens": torch.tensor(
+            [
+                [1, 15],
+                [0, 15],
+                [0, 15],
+                [0, 15],
+                [0, 4],
+            ],
+            dtype=torch.int32,
+        ),
+    },
+]
+
+
+class MockForwardBatch(ForwardBatch):
+    def __init__(self, max_chunk_capacity: int, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.max_chunk_capacity = max_chunk_capacity
+
+    def get_max_chunk_capacity(self):
+        return self.max_chunk_capacity
+
+
+class MockReqToTokenPool:
+    def __init__(self, batch_size, seq_len, device):
+        self.req_to_token = (
+            torch.arange(batch_size * seq_len, device=device)
+            .reshape(batch_size, seq_len)
+            .to(torch.int32)
+        )
+
+
+# Test correctness of triton kernel for computing kv indices
+def check_kv_indices(forward_batch):
+    for i in range(forward_batch.num_prefix_chunks):
+        computed_kv_indices = forward_batch.prefix_chunk_kv_indices[i]
+        req_to_token = forward_batch.req_to_token_pool.req_to_token[
+            : forward_batch.batch_size, :
+        ]
+        ref_kv_indices = torch.empty(
+            forward_batch.prefix_chunk_num_tokens[i],
+            dtype=torch.int32,
+            device=computed_kv_indices.device,
+        )
+        running_ptr = 0
+        for j in range(forward_batch.batch_size):
+            seq_start = forward_batch.prefix_chunk_starts[i, j].item()
+            seq_len = forward_batch.prefix_chunk_seq_lens[i, j].item()
+            ref_kv_indices[running_ptr : running_ptr + seq_len].copy_(
+                req_to_token[j, seq_start : seq_start + seq_len]
+            )
+            running_ptr += seq_len
+        assert torch.allclose(computed_kv_indices, ref_kv_indices)
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
+class TestPrefixChunkInfo(CustomTestCase):
+    def setUp(self):
+        # Common test parameters
+        self.num_local_heads = 128
+        self.kv_lora_rank = 512
+        self.qk_rope_head_dim = 64
+        self.device = torch.device("cuda")
+        self.dtype = torch.bfloat16
+        self.extend_len = 64
+        self.max_bs = 4
+        self.max_seq_len = 128
+
+        # req_to_token_pool
+        self.req_to_token_pool = MockReqToTokenPool(
+            self.max_bs,
+            self.max_seq_len,
+            self.device,
+        )
+
+        # token_to_kv_pool
+        self.token_to_kv_pool = MLATokenToKVPool(
+            size=self.max_bs * self.max_seq_len,
+            page_size=1,  # only consider page=1 for unit test
+            dtype=self.dtype,
+            kv_lora_rank=self.kv_lora_rank,
+            qk_rope_head_dim=self.qk_rope_head_dim,
+            layer_num=1,  # only consider layer=1 for unit test
+            device=self.device,
+            enable_memory_saver=False,
+        )
+
+    def test_prefix_chunk_info(self):
+        """Test the standard extend operation."""
+
+        for test_case in TEST_CASES:
+            print(
+                f"Test case with batch_size={test_case['batch_size']}, prefix_lens={test_case['prefix_lens']}, max_chunk_capacity={test_case['max_chunk_capacity']}"
+            )
+            batch_size = test_case["batch_size"]
+            prefix_lens_cpu = test_case["prefix_lens"]
+            assert len(prefix_lens_cpu) == batch_size
+            prefix_lens = torch.tensor(prefix_lens_cpu, device=self.device)
+            max_chunk_capacity = test_case["max_chunk_capacity"]
+            seq_lens_cpu = [
+                self.extend_len + prefix_lens_cpu[i] for i in range(batch_size)
+            ]
+            seq_lens = torch.tensor(seq_lens_cpu, device=self.device)
+
+            # Create forward batch
+            # input_ids and out_cache_loc are dummy tensors in this test
+            forward_batch = MockForwardBatch(
+                max_chunk_capacity=max_chunk_capacity,
+                batch_size=batch_size,
+                input_ids=torch.randint(
+                    0, 100, (batch_size, self.extend_len), device=self.device
+                ),
+                out_cache_loc=torch.arange(
+                    self.max_bs * self.max_seq_len - batch_size * self.extend_len,
+                    self.max_bs * self.max_seq_len,
+                    device=self.device,
+                ),
+                seq_lens_sum=sum(seq_lens_cpu),
+                forward_mode=ForwardMode.EXTEND,
+                req_pool_indices=torch.arange(batch_size, device=self.device),
+                seq_lens=seq_lens,
+                seq_lens_cpu=seq_lens_cpu,
+                extend_prefix_lens=prefix_lens,
+                extend_prefix_lens_cpu=prefix_lens_cpu,
+            )
+            forward_batch.req_to_token_pool = self.req_to_token_pool
+            forward_batch.token_to_kv_pool = self.token_to_kv_pool
+
+            forward_batch.prepare_chunked_prefix_cache_info(self.device)
+            assert forward_batch.get_max_chunk_capacity() == max_chunk_capacity
+            assert forward_batch.prefix_chunk_len == test_case["prefix_chunk_len"]
+            assert forward_batch.num_prefix_chunks == test_case["num_prefix_chunks"]
+            assert torch.allclose(
+                forward_batch.prefix_chunk_starts,
+                test_case["prefix_chunk_starts"].to(self.device),
+            )
+            assert torch.allclose(
+                forward_batch.prefix_chunk_seq_lens,
+                test_case["prefix_chunk_seq_lens"].to(self.device),
+            )
+
+            check_kv_indices(forward_batch)
+
+
+if __name__ == "__main__":
+    unittest.main()
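test_prefix_chunk_info.py pins down the chunked prefix-cache metadata that ForwardBatch.prepare_chunked_prefix_cache_info is expected to produce. The fixtures in TEST_CASES are consistent with a simple rule: each chunk's per-request length is the chunk capacity divided by the batch size, and a request contributes whatever slice of its prefix falls inside that chunk. The sketch below reproduces the fixture values under that assumption; it illustrates the expected contract, not the implementation in forward_batch_info.py:

import math

def expected_chunk_info(prefix_lens, max_chunk_capacity):
    """Reproduce the prefix-chunk fixtures used in TEST_CASES (assumed rule)."""
    batch_size = len(prefix_lens)
    chunk_len = max_chunk_capacity // batch_size          # per-request share of one chunk
    num_chunks = math.ceil(max(prefix_lens) / chunk_len)  # chunks needed for the longest prefix
    starts = [[i * chunk_len] * batch_size for i in range(num_chunks)]
    seq_lens = [
        [max(0, min(chunk_len, p - i * chunk_len)) for p in prefix_lens]
        for i in range(num_chunks)
    ]
    return chunk_len, num_chunks, starts, seq_lens

# The irregular case from TEST_CASES: chunk_len=15, 5 chunks,
# per-chunk lengths [[1, 15], [0, 15], [0, 15], [0, 15], [0, 4]].
print(expected_chunk_info([1, 64], 31))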
sglang/test/runners.py
CHANGED
@@ -26,8 +26,8 @@ from transformers import (
     AutoProcessor,
 )
 
+from sglang.srt.entrypoints.engine import Engine
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.server import Engine
 from sglang.srt.utils import load_image
 from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l
 
@@ -51,6 +51,8 @@ NUM_TOP_LOGPROBS = 5
 def get_dtype_str(torch_dtype):
     if torch_dtype is torch.float16:
         return "float16"
+    if torch_dtype is torch.float32:
+        return "float32"
     else:
         raise NotImplementedError()
 
@@ -447,6 +449,7 @@ class SRTRunner:
         port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
         lora_paths: List[str] = None,
         max_loras_per_batch: int = 4,
+        attention_backend: Optional[str] = None,
         lora_backend: str = "triton",
         disable_cuda_graph: bool = False,
         disable_radix_cache: bool = False,
@@ -487,6 +490,7 @@ class SRTRunner:
             lora_paths=lora_paths,
             max_loras_per_batch=max_loras_per_batch,
             lora_backend=lora_backend,
+            attention_backend=attention_backend,
             disable_cuda_graph=disable_cuda_graph,
             disable_radix_cache=disable_radix_cache,
             chunked_prefill_size=chunked_prefill_size,