sglang 0.4.4.post4__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl
This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- sglang/bench_one_batch.py +21 -0
- sglang/bench_serving.py +10 -4
- sglang/lang/chat_template.py +24 -0
- sglang/srt/configs/model_config.py +40 -4
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/conversation.py +29 -4
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +18 -5
- sglang/srt/disaggregation/mini_lb.py +53 -122
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +615 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
- sglang/srt/disaggregation/prefill.py +43 -19
- sglang/srt/disaggregation/utils.py +31 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +37 -10
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/attention/flashattention_backend.py +609 -202
- sglang/srt/layers/attention/flashinfer_backend.py +13 -7
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +5 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +51 -24
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +37 -16
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
- sglang/srt/layers/quantization/fp8.py +28 -14
- sglang/srt/layers/quantization/fp8_kernel.py +130 -4
- sglang/srt/layers/quantization/fp8_utils.py +34 -6
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
- sglang/srt/layers/quantization/w8a8_int8.py +3 -0
- sglang/srt/layers/radix_attention.py +14 -0
- sglang/srt/layers/rotary_embedding.py +75 -1
- sglang/srt/managers/io_struct.py +254 -97
- sglang/srt/managers/mm_utils.py +3 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +146 -0
- sglang/srt/managers/schedule_batch.py +62 -21
- sglang/srt/managers/scheduler.py +71 -14
- sglang/srt/managers/tokenizer_manager.py +17 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/memory_pool.py +14 -1
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +7 -4
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +49 -9
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +1 -0
- sglang/srt/models/deepseek_v2.py +248 -61
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +13 -4
- sglang/srt/models/llama4.py +487 -0
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +2 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +227 -0
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +1 -0
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +1 -0
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/server_args.py +34 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +6 -2
- sglang/srt/utils.py +120 -9
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/test_block_fp8.py +57 -0
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +133 -109
- sglang/srt/disaggregation/conn.py +0 -81
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
sglang/test/attention/test_flashattn_backend.py

@@ -2,60 +2,109 @@ import unittest
 
 
 import torch
 
+from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.model_runner import ServerArgs
 from sglang.test.test_utils import CustomTestCase
 
 
 class MockModelRunner:
-    … (8 lines not rendered)
+    def __init__(
+        self,
+        page_size=1,
+        num_heads=2,
+        head_dim=8,
+    ):
+        self.device = "cuda"
+        self.dtype = torch.float16
+        attention_arch = AttentionArch.MHA
+        # Max batch size for the test.
+        max_batch_size = 160
+        # Total tokens(prefix + extend + decode) in the test should not exceed this length.
+        max_context_len = 2048
+        self.model_config = type(
+            "ModelConfig",
+            (),
+            {
+                "context_len": max_context_len,
+                "is_multimodal": False,
+                "attention_arch": attention_arch,
+            },
+        )
+        self.sliding_window_size = None
+        self.device = self.device
+        # Create a large enough req_to_token_pool to fit the test usage.
         self.req_to_token_pool = type(
             "TokenPool",
             (),
             {
-                …
+                # A typical max_bs * max_context_len for cuda graph decode
+                "size": max_batch_size,
+                # Add req_to_token attribute
                 "req_to_token": torch.zeros(
-                    … (2 lines not rendered)
+                    max_batch_size,
+                    max_context_len,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
             },
         )
-        … (8 lines not rendered)
+        self.page_size = page_size
+        max_total_num_tokens = max_batch_size * max_context_len
+        self.token_to_kv_pool = MHATokenToKVPool(
+            size=max_total_num_tokens,
+            page_size=page_size,
+            dtype=self.dtype,
+            head_num=num_heads,
+            head_dim=head_dim,
+            layer_num=1,  # only consider layer=1 for unit test
+            device=self.device,
+            enable_memory_saver=False,
         )
+        # Required by torch native backend
+        self.server_args = ServerArgs(model_path="fake_model_path")
 
 
 @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
 class TestFlashAttentionBackend(CustomTestCase):
     def setUp(self):
-        …
-        self.model_runner = MockModelRunner()
-        self.backend = FlashAttentionBackend(self.model_runner)
-        …
-        # Common test parameters
+        # Test parameters
         self.batch_size = 2
-        self.seq_len =
+        self.seq_len = 256
         self.num_heads = 2
         self.head_dim = 8
         self.device = "cuda"
         self.dtype = torch.float16
 
+    def _init_model_runner(self, page_size=1):
+        self.model_runner = MockModelRunner(
+            page_size=page_size,
+            num_heads=self.num_heads,
+            head_dim=self.head_dim,
+        )
+        self.backend = FlashAttentionBackend(self.model_runner)
+        self.ref_backend = TorchNativeAttnBackend(self.model_runner)
+        self.model_runner.model_config.num_attention_heads = self.num_heads
+
+    def _mock_write_to_req_to_token_pool(self, batch_size, seq_len, page_size):
+        # if page_size > 1, the token pool stores the index to the page.
+        # so we need to multiply the index by page_size.
+        self.req_to_token = (
+            torch.arange(0, batch_size, dtype=torch.int32, device=self.device)[:, None]
+            * seq_len
+            + torch.arange(0, seq_len, dtype=torch.int32, device=self.device)[None, :]
+            + page_size
+        )
+        self.model_runner.req_to_token_pool.req_to_token[:batch_size, :seq_len] = (
+            self.req_to_token
+        )
+
     def _create_attention_layer(self):
-        """
+        """Create attention layer for testing."""
         return RadixAttention(
             num_heads=self.num_heads,
             head_dim=self.head_dim,
@@ -64,47 +113,27 @@ class TestFlashAttentionBackend(CustomTestCase):
             layer_id=0,
         )
 
-    def _create_kv_pool(self, size):
-        """Helper method to create a KV pool."""
-        return MHATokenToKVPool(
-            size=size,
-            page_size=1,  # only consider page=1 for unit test
-            dtype=self.dtype,
-            head_num=self.num_heads,
-            head_dim=self.head_dim,
-            layer_num=1,  # only consider layer=1 for unit test
-            device=self.device,
-            enable_memory_saver=False,
-        )
-
     def _create_qkv_tensors(self, tokens_len):
-        """
+        """Create q, k, v tensors for testing."""
+        shape = (tokens_len, self.num_heads, self.head_dim)
         return (
-            torch.randn(
-                tokens_len,
-                self.num_heads,
-                self.head_dim,
-                dtype=self.dtype,
-                device=self.device,
-            ),
-            torch.randn(
-                tokens_len,
-                self.num_heads,
-                self.head_dim,
-                dtype=self.dtype,
-                device=self.device,
-            ),
-            torch.randn(
-                tokens_len,
-                self.num_heads,
-                self.head_dim,
-                dtype=self.dtype,
-                device=self.device,
-            ),
+            torch.randn(shape, dtype=self.dtype, device=self.device),
+            torch.randn(shape, dtype=self.dtype, device=self.device),
+            torch.randn(shape, dtype=self.dtype, device=self.device),
         )
 
-    … (2 lines not rendered)
+    def _run_reference_forward(
+        self, mode, q, k, v, layer, forward_batch, expected_shape
+    ):
+        """Run reference forward pass using native backend."""
+        if mode == ForwardMode.EXTEND:
+            output = self.ref_backend.forward_extend(q, k, v, layer, forward_batch)
+        else:  # ForwardMode.DECODE
+            output = self.ref_backend.forward_decode(q, k, v, layer, forward_batch)
+        return output.view(expected_shape)
+
+    def _verify_output(self, output, expected_shape, output_ref=None):
+        """Verify output tensor shape, dtype, and values."""
         self.assertEqual(
             output.shape,
             expected_shape,
@@ -116,161 +145,110 @@ class TestFlashAttentionBackend(CustomTestCase):
             torch.isnan(output).sum().item(), 0, "Output contains NaN values"
         )
 
-    … (89 lines not rendered)
-            layer,
-            torch.arange(self.batch_size * self.seq_len, device=self.device),
-            cache_k,
-            cache_v,
-            layer.k_scale,
-            layer.v_scale,
-        )
-
-        # Initialize forward metadata before running the attention
-        self.backend.init_forward_metadata(forward_batch)
-
-        # Run forward_decode
-        output = self.backend.forward_decode(q, k, v, layer, forward_batch)
-
-        # Verify output
-        expected_shape = (self.batch_size, self.num_heads * self.head_dim)
-        self._verify_output(output, expected_shape)
-
-    def test_forward_extend_with_prefix(self):
-        """Test extending from cached prefix tokens."""
-        # Define prefix and extend lengths
-        prefix_len = 2
-        extend_len = 2
-        total_len = prefix_len + extend_len
-
-        # Create test inputs for the extend portion
-        q, k, v = self._create_qkv_tensors(self.batch_size * extend_len)
+        if output_ref is not None:
+            if not torch.allclose(output, output_ref, atol=1e-1, rtol=0.0):
+                # Check where the values differ beyond the given tolerances
+                diff_mask = ~torch.isclose(output, output_ref, atol=1e-1, rtol=0.0)
+
+                # Find the first index where the difference occurs
+                if diff_mask.any():
+                    first_mismatch_idx = diff_mask.nonzero()[0]
+                    print(
+                        "First mismatch at index:", tuple(first_mismatch_idx.tolist())
+                    )
+                    print("output:", output[tuple(first_mismatch_idx.tolist())])
+                    print("output_ref:", output_ref[tuple(first_mismatch_idx.tolist())])
+                raise AssertionError(
+                    "Attention output is not close to the torch native backend output"
+                )
+
+    def _create_forward_batch(self, mode, q_len=None, prefix_len=0, page_size=1):
+        """Create a forward batch for testing based on mode and lengths."""
+        self._init_model_runner(page_size=page_size)
+
+        # Default to self.seq_len if not specified
+        q_len = q_len or self.seq_len
+
+        if mode == ForwardMode.EXTEND:
+            total_len = prefix_len + q_len
+            out_cache_start = prefix_len * self.batch_size
+            out_cache_end = total_len * self.batch_size
+
+            forward_batch = ForwardBatch(
+                batch_size=self.batch_size,
+                input_ids=torch.randint(
+                    0, 100, (self.batch_size, q_len), device=self.device
+                ),
+                out_cache_loc=torch.arange(
+                    out_cache_start, out_cache_end, device=self.device
+                ),
+                seq_lens_sum=self.batch_size * total_len,
+                forward_mode=mode,
+                req_pool_indices=torch.arange(self.batch_size, device=self.device),
+                seq_lens=torch.tensor(
+                    [total_len] * self.batch_size, device=self.device
+                ),
+                seq_lens_cpu=torch.tensor([total_len] * self.batch_size, device="cpu"),
+                extend_prefix_lens=torch.tensor(
+                    [prefix_len] * self.batch_size, device=self.device
+                ),
+                extend_prefix_lens_cpu=torch.tensor(
+                    [prefix_len] * self.batch_size, device="cpu"
+                ),
+                extend_seq_lens=torch.tensor(
+                    [q_len] * self.batch_size, device=self.device
+                ),
+                extend_seq_lens_cpu=torch.tensor(
+                    [q_len] * self.batch_size, device="cpu"
+                ),
+                attn_backend=self.backend,
+            )
+        else:  # ForwardMode.DECODE
+            decode_len = q_len  # Assuming 1 for decode testing
+            total_len = self.seq_len + decode_len
+            if mode == ForwardMode.DECODE and page_size > 1:
+                # Get next page_size multiple of self.seq_len
+                out_cache_start = (
+                    self.batch_size * self.seq_len // page_size + 1
+                ) * page_size
+                # out_cache_end is the start of the next block
+                out_cache_end = out_cache_start + decode_len * page_size
+            else:
+                out_cache_start = self.batch_size * self.seq_len
+                out_cache_end = self.batch_size * total_len
+
+            forward_batch = ForwardBatch(
+                batch_size=self.batch_size,
+                input_ids=torch.randint(
+                    0, 100, (self.batch_size, decode_len), device=self.device
+                ),
+                out_cache_loc=torch.tensor(
+                    [out_cache_start, out_cache_end], device=self.device
+                ),
+                seq_lens_sum=self.batch_size * total_len,
+                forward_mode=mode,
+                req_pool_indices=torch.arange(self.batch_size, device=self.device),
+                seq_lens=torch.tensor(
+                    [total_len] * self.batch_size, device=self.device
+                ),
+                seq_lens_cpu=torch.tensor([total_len] * self.batch_size, device="cpu"),
+                attn_backend=self.backend,
+            )
 
-        … (2 lines not rendered)
+        # Add token pool
+        forward_batch.req_to_token_pool = self.model_runner.req_to_token_pool
 
-        … (4 lines not rendered)
-                0, 100, (self.batch_size, extend_len), device=self.device
-            ),
-            out_cache_loc=torch.arange(
-                self.batch_size * prefix_len,
-                self.batch_size * total_len,
-                device=self.device,
-            ),
-            seq_lens_sum=self.batch_size * total_len,
-            forward_mode=ForwardMode.EXTEND,
-            req_pool_indices=torch.arange(self.batch_size, device=self.device),
-            seq_lens=torch.tensor([total_len] * self.batch_size, device=self.device),
-            extend_prefix_lens=torch.tensor(
-                [prefix_len] * self.batch_size, device=self.device
-            ),
-            extend_seq_lens=torch.tensor(
-                [extend_len] * self.batch_size, device=self.device
-            ),
-            attn_backend=self.backend,
-        )
+        # Write current batch's req_to_token to req_to_token_pool
+        self._mock_write_to_req_to_token_pool(self.batch_size, total_len, page_size)
+        # Add kv pool for this forward batch
+        forward_batch.token_to_kv_pool = self.model_runner.token_to_kv_pool
 
-        …
-        forward_batch.req_to_token_pool = MockReqToTokenPool(
-            self.batch_size, total_len, self.device
-        )
-        forward_batch.token_to_kv_pool = self._create_kv_pool(
-            self.batch_size * total_len
-        )
+        return forward_batch
 
-        …
+    def _setup_kv_cache(self, forward_batch, layer, cache_len):
+        # Create constant values for the prefix cache for easy debugging
         cache_k = torch.ones(
-            self.batch_size *
+            self.batch_size * cache_len,
             self.num_heads,
             self.head_dim,
             dtype=self.dtype,
@@ -278,7 +256,7 @@ class TestFlashAttentionBackend(CustomTestCase):
         )
         cache_v = (
             torch.ones(
-                self.batch_size *
+                self.batch_size * cache_len,
                 self.num_heads,
                 self.head_dim,
                 dtype=self.dtype,
@@ -290,22 +268,82 @@ class TestFlashAttentionBackend(CustomTestCase):
         # Set the prefix KV cache
         forward_batch.token_to_kv_pool.set_kv_buffer(
             layer,
-            torch.arange(self.batch_size *
+            torch.arange(self.batch_size * cache_len, device=self.device),
             cache_k,
             cache_v,
             layer.k_scale,
             layer.v_scale,
         )
 
-    …
+    def _run_attention_test(self, mode, q_len, prefix_len=0, page_size=1):
+        """
+        Run an attention test with the specified parameters.
+        Args:
+            mode: ForwardMode.EXTEND or ForwardMode.DECODE
+            q_len: Length of the query sequence. For decode mode, q_len is 1.
+            prefix_len: Length of the prefix sequence for extend mode
+            page_size: Page size for the KV cache
+        """
+        layer = self._create_attention_layer()
+
+        # Create forward batch and set up
+        forward_batch = self._create_forward_batch(mode, q_len, prefix_len, page_size)
+
+        # Create QKV tensors for the input
+        q, k, v = self._create_qkv_tensors(self.batch_size * q_len)
+
+        # KV cache for prefixed extend is prefix_len
+        # KV cache for decode is same as seq_len
+        # No KV cache for extend without prefix
+        if mode == ForwardMode.EXTEND:
+            if prefix_len > 0:
+                self._setup_kv_cache(forward_batch, layer, prefix_len)
+        else:
+            self._setup_kv_cache(forward_batch, layer, self.seq_len)
+
         self.backend.init_forward_metadata(forward_batch)
 
-        … (2 lines not rendered)
+        if mode == ForwardMode.EXTEND:
+            expected_shape = (
+                self.batch_size * q_len,
+                self.num_heads * self.head_dim,
+            )
+            output = self.backend.forward_extend(q, k, v, layer, forward_batch)
+        else:
+            expected_shape = (self.batch_size, self.num_heads * self.head_dim)
+            output = self.backend.forward_decode(q, k, v, layer, forward_batch)
+
+        output_ref = self._run_reference_forward(
+            mode, q, k, v, layer, forward_batch, expected_shape
+        )
+
+        self._verify_output(output, expected_shape, output_ref)
+
+        return output
+
+    def test_forward_extend(self):
+        """Test the standard extend operation."""
+        self._run_attention_test(ForwardMode.EXTEND, q_len=self.seq_len)
+
+    def test_forward_decode(self):
+        """Test the decode operation with cached tokens."""
+        self._run_attention_test(ForwardMode.DECODE, q_len=1)
+
+    def test_forward_extend_with_prefix(self):
+        """Test extending from cached prefix tokens."""
+        prefix_len = self.seq_len // 2
+        extend_len = self.seq_len - prefix_len
+        self._run_attention_test(
+            ForwardMode.EXTEND, q_len=extend_len, prefix_len=prefix_len
+        )
+
+    def test_forward_extend_with_page_size_greater_than_1(self):
+        """Test extending from cached prefix tokens with page size greater than 1."""
+        self._run_attention_test(ForwardMode.EXTEND, q_len=self.seq_len, page_size=64)
 
-        … (3 lines not rendered)
+    def test_forward_decode_with_page_size_greater_than_1(self):
+        """Test decode operation with page size greater than 1."""
+        self._run_attention_test(ForwardMode.DECODE, q_len=1, page_size=64)
 
 
 if __name__ == "__main__":
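Note: the hunks above are from sglang/test/attention/test_flashattn_backend.py, which in 0.4.5.post1 checks FlashAttentionBackend outputs against TorchNativeAttnBackend for extend and decode, with and without page sizes greater than 1. A minimal sketch of running just this module after installing the new wheel, assuming a CUDA-capable machine (the module path is taken from the file list above):

    import unittest

    # Load and run the attention backend tests shipped with the wheel.
    suite = unittest.defaultTestLoader.loadTestsFromName(
        "sglang.test.attention.test_flashattn_backend"
    )
    unittest.TextTestRunner(verbosity=2).run(suite)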