sglang 0.4.10__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +20 -0
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/disaggregation/launch_lb.py +5 -20
- sglang/srt/disaggregation/mooncake/conn.py +33 -15
- sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
- sglang/srt/layers/attention/utils.py +6 -1
- sglang/srt/layers/moe/ep_moe/layer.py +19 -34
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -2
- sglang/srt/layers/quantization/fp8.py +52 -0
- sglang/srt/layers/quantization/w8a8_int8.py +4 -1
- sglang/srt/managers/cache_controller.py +35 -35
- sglang/srt/managers/scheduler.py +1 -0
- sglang/srt/mem_cache/hicache_storage.py +15 -6
- sglang/srt/mem_cache/hiradix_cache.py +21 -4
- sglang/srt/mem_cache/memory_pool.py +15 -118
- sglang/srt/mem_cache/memory_pool_host.py +350 -33
- sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
- sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
- sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +8 -2
- sglang/srt/model_executor/cuda_graph_runner.py +25 -1
- sglang/srt/model_executor/model_runner.py +8 -1
- sglang/srt/model_loader/weight_utils.py +2 -0
- sglang/srt/models/deepseek_v2.py +5 -6
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/step3_vl.py +0 -3
- sglang/srt/server_args.py +40 -6
- sglang/srt/utils.py +1 -0
- sglang/test/attention/test_trtllm_mla_backend.py +945 -0
- sglang/version.py +1 -1
- {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +1 -1
- {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +35 -30
- {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
sglang/test/attention/test_trtllm_mla_backend.py (new file)
@@ -0,0 +1,945 @@
import math
import unittest

import numpy as np
import torch

from sglang.srt.layers import dp_attention as _dp_attn

# Patch DP-attention globals before importing backends
# TODO: change the interface of both trtllm_mla and flashinfer backends to take tp_size as an argument instead of patching
_dp_attn.get_attention_tp_size = lambda: 1  # TP size = 1 for unit test

from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
from sglang.srt.layers.attention.trtllm_mla_backend import (
    TRTLLMMLABackend,
    TRTLLMMLADecodeMetadata,
)
from sglang.srt.layers.attention.utils import TRITON_PAD_NUM_PAGE_PER_BLOCK
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_flashinfer_available
from sglang.test.test_utils import CustomTestCase

# Global configuration for all tests
DEFAULT_CONFIG = {
    "device": "cuda",
    "dtype": torch.bfloat16,
    "kv_cache_dtype": torch.bfloat16,
    "context_len": 2048,
    "max_bs": 64,
    "tolerance": 1e-2,
    "seed_cache": 42,
    "seed_qkv": 123,
    # MLA model config (TRTLLM MLA has fixed constraints)
    "num_attention_heads": 128,
    "kv_lora_rank": 512,
    "qk_nope_head_dim": 128,
    "qk_rope_head_dim": 64,
    "v_head_dim": 512,
    "num_kv_heads": 1,
    "layer_id": 0,
}
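# Derived sizes implied by the defaults above: each decode query/key head is
# kv_lora_rank + qk_rope_head_dim = 512 + 64 = 576 wide, and one decode output
# row is num_attention_heads * v_head_dim = 128 * 512 = 65536 elements. The
# shape assertions in the tests below rely on these derived values.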

# Centralized test cases for different test scenarios
TEST_CASES = {
    "basic_functionality": [
        {
            "name": "single",
            "batch_size": 1,
            "max_seq_len": 32,
            "page_size": 32,
            "description": "Minimal smoke test",
        },
        {
            "name": "batch",
            "batch_size": 32,
            "max_seq_len": 128,
            "page_size": 32,
            "description": "Medium-scale batch",
        },
    ],
    "decode_output_match": [
        {
            "name": "single",
            "batch_size": 1,
            "max_seq_len": 64,
            "page_size": 32,
            "description": "Single vs reference",
        },
        {
            "name": "batch",
            "batch_size": 32,
            "max_seq_len": 64,
            "page_size": 32,
            "description": "Batch vs reference",
        },
    ],
    "page_size_consistency": [
        # Only 32 and 64 supported for now in flashinfer TRTLLM-GEN MLA kernel
        {
            "name": "page_32",
            "batch_size": 8,
            "max_seq_len": 128,
            "page_size": 32,
            "description": "32-token pages",
        },
        {
            "name": "page_64",
            "batch_size": 8,
            "max_seq_len": 128,
            "page_size": 64,
            "description": "64-token pages",
        },
    ],
    "shape_sanity_tests": [
        {
            "name": "basic",
            "batch_size": 1,
            "max_seq_len": 128,
            "page_size": 32,
            "description": "Single sequence",
        },
        {
            "name": "basic_different_pagesize",
            "batch_size": 1,
            "max_seq_len": 128,
            "page_size": 64,
            "description": "Different page size",
        },
        {
            "name": "batch",
            "batch_size": 8,
            "max_seq_len": 128,
            "page_size": 32,
            "description": "Batch shapes",
        },
    ],
    "metadata_tests": [
        {
            "name": "single_sequence",
            "batch_size": 1,
            "max_seq_len": 64,
            "page_size": 32,
            "description": "Single sequence metadata",
        },
        {
            "name": "batch_mixed_lengths",
            "batch_size": 8,
            "max_seq_len": 128,
            "page_size": 32,
            "description": "Mixed sequence lengths",
        },
        {
            "name": "large_batch",
            "batch_size": 32,
            "max_seq_len": 256,
            "page_size": 64,
            "description": "Large batch stress test",
        },
        {
            "name": "edge_case_short",
            "batch_size": 4,
            "max_seq_len": 16,
            "page_size": 32,
            "description": "Sub-page sequences",
        },
    ],
}
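# Each test method below iterates one group of TEST_CASES and overlays the case
# on DEFAULT_CONFIG via _merge_config, so e.g. {"batch_size": 8, "page_size": 64}
# becomes a complete config with the fixed MLA dimensions filled in.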


class MockModelRunner:
    """Minimal fake ModelRunner for testing MLA backends."""

    def __init__(self, config):
        self.device = config["device"]
        self.dtype = config["dtype"]
        self.kv_cache_dtype = config["kv_cache_dtype"]
        self.page_size = config["page_size"]

        # Model-config stub with MLA attributes
        self.model_config = type(
            "ModelConfig",
            (),
            {
                "context_len": config["context_len"],
                "attention_arch": AttentionArch.MLA,
                "num_attention_heads": config["num_attention_heads"],
                "kv_lora_rank": config["kv_lora_rank"],
                "qk_nope_head_dim": config["qk_nope_head_dim"],
                "qk_rope_head_dim": config["qk_rope_head_dim"],
                "v_head_dim": config["v_head_dim"],
                "scaling": 1.0
                / ((config["qk_nope_head_dim"] + config["qk_rope_head_dim"]) ** 0.5),
                "get_num_kv_heads": staticmethod(lambda _: config["num_kv_heads"]),
            },
        )

        # Req-to-token pool
        max_bs = config["max_bs"]
        max_ctx = self.model_config.context_len
        req_to_token = torch.arange(
            max_bs * max_ctx, dtype=torch.int32, device=self.device
        ).reshape(max_bs, max_ctx)
        self.req_to_token_pool = type(
            "TokenPool",
            (),
            {
                "size": max_bs,
                "req_to_token": req_to_token,
            },
        )

        # KV-token pool (MLA)
        self.token_to_kv_pool = MLATokenToKVPool(
            size=max_bs * max_ctx,
            page_size=config["page_size"],
            dtype=self.kv_cache_dtype,
            kv_lora_rank=config["kv_lora_rank"],
            qk_rope_head_dim=config["qk_rope_head_dim"],
            layer_num=1,
            device=self.device,
            enable_memory_saver=False,
        )
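# Note: req_to_token above is a plain arange reshaped to (max_bs, context_len),
# so request i, token t maps to KV slot i * context_len + t. That keeps the
# page-table layout deterministic without running a real scheduler.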


def compare_outputs(trtllm_out, reference_out, tolerance=1e-2):
    """Compare outputs with detailed analysis."""

    # Basic checks
    assert (
        trtllm_out.shape == reference_out.shape
    ), f"Shape mismatch: {trtllm_out.shape} vs {reference_out.shape}"
    assert (
        trtllm_out.dtype == reference_out.dtype
    ), f"Dtype mismatch: {trtllm_out.dtype} vs {reference_out.dtype}"

    # Check for NaN/Inf
    assert not torch.isnan(trtllm_out).any(), "TRTLLM output contains NaN"
    assert not torch.isnan(reference_out).any(), "Reference output contains NaN"
    assert not torch.isinf(trtllm_out).any(), "TRTLLM output contains Inf"
    assert not torch.isinf(reference_out).any(), "Reference output contains Inf"

    # Element-wise differences
    diff = (trtllm_out - reference_out).abs()
    max_diff = diff.max().item()
    mean_diff = diff.mean().item()

    # Check numerical equivalence
    all_close = torch.allclose(
        trtllm_out, reference_out, rtol=tolerance, atol=tolerance
    )

    if not all_close:
        print(
            f"Comparison failed: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}, tolerance={tolerance}"
        )
        # Find top differences for debugging
        flat_diff = diff.flatten()
        top_diff_indices = torch.topk(flat_diff, k=min(5, flat_diff.numel())).indices
        print("Top 5 differences:")
        for i, idx in enumerate(top_diff_indices):
            idx_tuple = np.unravel_index(idx.cpu().numpy(), trtllm_out.shape)
            trt_val = trtllm_out[idx_tuple].item()
            ref_val = reference_out[idx_tuple].item()
            print(
                f"  [{idx_tuple}]: TRTLLM={trt_val:.6f}, Reference={ref_val:.6f}, diff={abs(trt_val-ref_val):.6f}"
            )

    return all_close
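# compare_outputs returns a bool instead of asserting closeness itself, so a
# caller such as test_decode_output_match can attach its own failure message
# (including the max abs diff) via assertTrue while still getting the printed
# per-element breakdown above.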


@unittest.skipIf(
    not torch.cuda.is_available() or not is_flashinfer_available(),
    "CUDA + flashinfer required",
)
class TestTRTLLMMLA(CustomTestCase):
    """Test suite for TRTLLM MLA backend with centralized configuration."""

    def _merge_config(self, test_case):
        """Merge test case with default configuration."""
        config = DEFAULT_CONFIG.copy()
        config.update(test_case)
        return config

    def _create_model_components(self, config):
        """Create model runners, backends, and layer for testing."""
        # Create model runners
        model_runner_trtllm = MockModelRunner(config)
        model_runner_reference = MockModelRunner(config)

        # Create backends
        trtllm_backend = TRTLLMMLABackend(model_runner_trtllm)
        reference_backend = FlashInferMLAAttnBackend(model_runner_reference)

        # Create RadixAttention layer
        layer = RadixAttention(
            num_heads=config["num_attention_heads"],
            head_dim=config["kv_lora_rank"] + config["qk_rope_head_dim"],
            scaling=model_runner_trtllm.model_config.scaling,
            num_kv_heads=config["num_kv_heads"],
            layer_id=config["layer_id"],
            v_head_dim=config["v_head_dim"],
            prefix="attn_mqa",
        )

        return (
            model_runner_trtllm,
            model_runner_reference,
            trtllm_backend,
            reference_backend,
            layer,
        )

    def _create_qkv_tensors(self, batch_size, config):
        """Create Q, K, V tensors for testing."""
        head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"]
        device = config["device"]
        dtype = config["dtype"]

        q = torch.randn(
            (batch_size, config["num_attention_heads"], head_dim),
            dtype=dtype,
            device=device,
        )
        k = torch.randn(
            (batch_size, config["num_kv_heads"], head_dim), dtype=dtype, device=device
        )
        v = torch.randn(
            (batch_size, config["num_kv_heads"], config["v_head_dim"]),
            dtype=dtype,
            device=device,
        )
        return q, k, v

    def _create_forward_batch(
        self, batch_size, seq_lens, backend, model_runner, config
    ):
        """Create a forward batch for the given backend."""
        fb = ForwardBatch(
            batch_size=batch_size,
            input_ids=torch.randint(0, 100, (batch_size, 1), device=config["device"]),
            out_cache_loc=torch.arange(batch_size, device=config["device"]),
            seq_lens_sum=int(seq_lens.sum().item()),
            forward_mode=ForwardMode.DECODE,
            req_pool_indices=torch.arange(batch_size, device=config["device"]),
            seq_lens=seq_lens,
            seq_lens_cpu=seq_lens.cpu(),
            attn_backend=backend,
        )
        fb.req_to_token_pool = model_runner.req_to_token_pool
        fb.token_to_kv_pool = model_runner.token_to_kv_pool
        return fb

    def _populate_kv_cache(self, batch_size, seq_lens, model_runners, layer, config):
        """Populate KV cache with identical data for both backends."""
        torch.manual_seed(config["seed_cache"])  # Fixed seed for reproducible cache

        for model_runner in model_runners:
            torch.manual_seed(config["seed_cache"])  # Reset seed for each backend
            for i in range(batch_size):
                seq_len = int(seq_lens[i].item())
                for token_idx in range(seq_len - 1):
                    # Create random K components for MLA
                    cache_k_nope = torch.randn(
                        (1, config["qk_nope_head_dim"]),
                        dtype=config["dtype"],
                        device=config["device"],
                    )
                    cache_k_rope = torch.randn(
                        (1, config["qk_rope_head_dim"]),
                        dtype=config["dtype"],
                        device=config["device"],
                    )

                    # Calculate cache location
                    cache_loc = model_runner.req_to_token_pool.req_to_token[
                        i, token_idx
                    ]

                    # Save to KV cache
                    model_runner.token_to_kv_pool.set_mla_kv_buffer(
                        layer,
                        cache_loc.unsqueeze(0),
                        cache_k_nope.squeeze(0),
                        cache_k_rope.squeeze(0),
                    )

    def test_basic_functionality(self):
        """Test basic functionality with minimal setup."""
        print(f"\nRunning basic functionality tests...")

        for test_case in TEST_CASES["basic_functionality"]:
            with self.subTest(test_case=test_case["name"]):
                print(f"  Testing {test_case['name']}: {test_case['description']}")

                config = self._merge_config(test_case)
                batch_size = config["batch_size"]
                max_seq_len = config["max_seq_len"]

                # Create components
                model_runner_trtllm, _, trtllm_backend, _, layer = (
                    self._create_model_components(config)
                )

                # Create sequence lengths - properly handle different batch sizes
                if batch_size == 2:
                    seq_lens = torch.tensor(
                        [max_seq_len, max_seq_len // 2], device=config["device"]
                    )
                else:
                    # For larger batch sizes, create varied sequence lengths
                    torch.manual_seed(config["seed_cache"])
                    seq_lens = torch.randint(
                        max_seq_len // 2,
                        max_seq_len + 1,
                        (batch_size,),
                        device=config["device"],
                    )
                    seq_lens[0] = max_seq_len  # Ensure at least one max length

                # Create forward batch
                fb = self._create_forward_batch(
                    batch_size, seq_lens, trtllm_backend, model_runner_trtllm, config
                )
                trtllm_backend.init_forward_metadata(fb)

                # Populate KV cache
                self._populate_kv_cache(
                    batch_size, seq_lens, [model_runner_trtllm], layer, config
                )

                # Create Q, K, V tensors
                torch.manual_seed(config["seed_qkv"])
                q, k, v = self._create_qkv_tensors(batch_size, config)

                # Run forward decode
                output = trtllm_backend.forward_decode(q, k, v, layer, fb)

                # Basic checks
                expected_shape = (
                    batch_size,
                    config["num_attention_heads"] * config["v_head_dim"],
                )
                self.assertEqual(output.shape, expected_shape)
                self.assertEqual(output.dtype, config["dtype"])
                self.assertFalse(torch.isnan(output).any())
                self.assertFalse(torch.isinf(output).any())
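    # With DEFAULT_CONFIG, the expected decode output above works out to
    # (batch_size, 128 * 512) = (batch_size, 65536) in bfloat16.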

    def test_decode_output_match(self):
        """Test that TRTLLM and FlashInfer MLA backends produce matching outputs."""
        print(f"\nRunning decode output matching tests...")

        for test_case in TEST_CASES["decode_output_match"]:
            with self.subTest(test_case=test_case["name"]):
                print(f"  Testing {test_case['name']}: {test_case['description']}")

                config = self._merge_config(test_case)
                batch_size = config["batch_size"]
                max_seq_len = config["max_seq_len"]

                # Create components
                (
                    model_runner_trtllm,
                    model_runner_reference,
                    trtllm_backend,
                    reference_backend,
                    layer,
                ) = self._create_model_components(config)

                # Create identical sequence lengths for both backends
                torch.manual_seed(config["seed_cache"])
                seq_lens = torch.randint(
                    1, max_seq_len, (batch_size,), device=config["device"]
                )
                seq_lens[0] = max_seq_len  # Ensure at least one max length

                # Create forward batches with identical inputs
                fb_trtllm = self._create_forward_batch(
                    batch_size,
                    seq_lens.clone(),
                    trtllm_backend,
                    model_runner_trtllm,
                    config,
                )
                fb_reference = self._create_forward_batch(
                    batch_size,
                    seq_lens.clone(),
                    reference_backend,
                    model_runner_reference,
                    config,
                )

                # Initialize metadata for both backends
                trtllm_backend.init_forward_metadata(fb_trtllm)
                reference_backend.init_forward_metadata(fb_reference)

                # Populate both KV caches identically
                self._populate_kv_cache(
                    batch_size,
                    seq_lens,
                    [model_runner_trtllm, model_runner_reference],
                    layer,
                    config,
                )

                # Create Q, K, V tensors for current decode step
                torch.manual_seed(config["seed_qkv"])
                q, k, v = self._create_qkv_tensors(batch_size, config)

                # Run forward decode on both backends
                out_trtllm = trtllm_backend.forward_decode(
                    q.clone(), k.clone(), v.clone(), layer, fb_trtllm
                )
                out_reference = reference_backend.forward_decode(
                    q.clone(), k.clone(), v.clone(), layer, fb_reference
                )

                # Compare outputs
                comparison_passed = compare_outputs(
                    out_trtllm, out_reference, tolerance=config["tolerance"]
                )

                self.assertTrue(
                    comparison_passed,
                    f"TRTLLM and Reference outputs differ beyond tolerance. "
                    f"Config: {test_case['name']}, "
                    f"Max diff: {(out_trtllm - out_reference).abs().max().item()}",
                )
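    # Both backends see identical seeds (seed_cache / seed_qkv) and identical
    # sequence lengths here, so any mismatch beyond tolerance=1e-2 should come
    # from the attention kernels themselves rather than from the test inputs.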

    def test_page_size_consistency(self):
        """Test output consistency across different page sizes."""
        print(f"\nRunning page size consistency tests...")

        for test_case in TEST_CASES["page_size_consistency"]:
            with self.subTest(test_case=test_case["name"]):
                print(f"  Testing {test_case['name']}: {test_case['description']}")

                config = self._merge_config(test_case)
                batch_size = config["batch_size"]
                max_seq_len = config["max_seq_len"]

                # Create components
                model_runner, _, backend, _, layer = self._create_model_components(
                    config
                )

                # Create sequence lengths
                torch.manual_seed(config["seed_cache"])
                seq_lens = torch.randint(
                    1, max_seq_len, (batch_size,), device=config["device"]
                )
                seq_lens[0] = max_seq_len

                # Create forward batch
                fb = self._create_forward_batch(
                    batch_size, seq_lens, backend, model_runner, config
                )
                backend.init_forward_metadata(fb)

                # Populate KV cache
                self._populate_kv_cache(
                    batch_size, seq_lens, [model_runner], layer, config
                )

                # Create Q, K, V tensors
                torch.manual_seed(config["seed_qkv"])
                q, k, v = self._create_qkv_tensors(batch_size, config)

                # Run forward decode
                output = backend.forward_decode(q, k, v, layer, fb)

                expected_shape = (
                    batch_size,
                    config["num_attention_heads"] * config["v_head_dim"],
                )
                self.assertEqual(
                    output.shape,
                    expected_shape,
                    f"Output shape mismatch: {output.shape} vs {expected_shape}",
                )
                self.assertFalse(torch.isnan(output).any(), "Output contains NaN")
                self.assertFalse(torch.isinf(output).any(), "Output contains Inf")

    def test_shape_sanity(self):
        """Smoke test decode across several configurations."""
        print(f"\nRunning shape sanity tests...")

        for test_case in TEST_CASES["shape_sanity_tests"]:
            with self.subTest(test_case=test_case["name"]):
                print(f"  Testing {test_case['name']}: {test_case['description']}")

                config = self._merge_config(test_case)
                batch_size = config["batch_size"]
                max_seq_len = config["max_seq_len"]

                model_runner, _, backend, _, layer = self._create_model_components(
                    config
                )

                # Random seq lens (ensure one matches max)
                torch.manual_seed(config["seed_cache"])
                seq_lens = torch.randint(
                    1, max_seq_len, (batch_size,), device=config["device"]
                )
                seq_lens[0] = max_seq_len

                fb = self._create_forward_batch(
                    batch_size, seq_lens, backend, model_runner, config
                )
                backend.init_forward_metadata(fb)

                # Create Q, K, V tensors
                torch.manual_seed(config["seed_qkv"])
                head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"]
                q = torch.randn(
                    (batch_size, config["num_attention_heads"], head_dim),
                    dtype=config["dtype"],
                    device=config["device"],
                )
                k = torch.randn(
                    (batch_size, config["num_kv_heads"], head_dim),
                    dtype=config["dtype"],
                    device=config["device"],
                )
                v = None

                # Run forward decode
                output = backend.forward_decode(q, k, v, layer, fb)

                # Shape and sanity checks
                expected_shape = (
                    batch_size,
                    config["num_attention_heads"] * config["v_head_dim"],
                )
                self.assertEqual(
                    output.shape,
                    expected_shape,
                    f"Output shape mismatch for {test_case['name']}",
                )
                self.assertEqual(output.dtype, config["dtype"])
                self.assertEqual(output.device.type, "cuda")
                self.assertFalse(
                    torch.isnan(output).any(),
                    f"Output contains NaN for {test_case['name']}",
                )
                self.assertFalse(
                    torch.isinf(output).any(),
                    f"Output contains Inf for {test_case['name']}",
                )
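    # test_shape_sanity deliberately passes v=None, presumably because MLA decode
    # reads the value from the compressed KV cache rather than from the v argument;
    # the smoke test therefore only checks output shape, dtype, device, and finiteness.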

    def test_metadata_initialization(self):
        """Test TRTLLM MLA metadata initialization and structure."""
        print(f"\nRunning metadata initialization tests...")

        for test_case in TEST_CASES["metadata_tests"]:
            with self.subTest(test_case=test_case["name"]):
                print(f"  Testing {test_case['name']}: {test_case['description']}")

                config = self._merge_config(test_case)
                batch_size = config["batch_size"]
                max_seq_len = config["max_seq_len"]

                # Create components
                model_runner, _, backend, _, layer = self._create_model_components(
                    config
                )

                # Create varied sequence lengths
                torch.manual_seed(config["seed_cache"])
                if batch_size == 1:
                    seq_lens = torch.tensor([max_seq_len], device=config["device"])
                else:
                    seq_lens = torch.randint(
                        max(1, max_seq_len // 4),
                        max_seq_len + 1,
                        (batch_size,),
                        device=config["device"],
                    )
                    seq_lens[0] = max_seq_len  # Ensure at least one max length

                # Create forward batch
                fb = self._create_forward_batch(
                    batch_size, seq_lens, backend, model_runner, config
                )

                # Initialize metadata
                backend.init_forward_metadata(fb)

                # Verify metadata exists
                self.assertIsNotNone(backend.forward_metadata)
                self.assertIsInstance(backend.forward_metadata, TRTLLMMLADecodeMetadata)

                # Test metadata structure
                metadata = backend.forward_metadata
                self.assertIsNotNone(
                    metadata.workspace, "Workspace should be allocated"
                )
                self.assertIsNotNone(
                    metadata.block_kv_indices, "Block KV indices should be created"
                )

                # Test workspace properties
                self.assertEqual(metadata.workspace.device.type, "cuda")
                self.assertEqual(metadata.workspace.dtype, torch.int8)
                self.assertGreater(
                    metadata.workspace.numel(), 0, "Workspace should have non-zero size"
                )

                # Test block KV indices properties
                self.assertEqual(metadata.block_kv_indices.device.type, "cuda")
                self.assertEqual(metadata.block_kv_indices.dtype, torch.int32)
                self.assertEqual(metadata.block_kv_indices.shape[0], batch_size)

                # Verify block indices are valid (>= -1, since -1 is padding)
                self.assertTrue(
                    (metadata.block_kv_indices >= -1).all(),
                    "All block indices should be >= -1 (with -1 as padding)",
                )

    def test_metadata_block_calculation(self):
        """Test block count calculation logic."""
        print(f"\nRunning metadata block calculation tests...")

        test_scenarios = [
            {"seq_len": 31, "page_size": 32, "expected_min_blocks": 1},
            {"seq_len": 32, "page_size": 32, "expected_min_blocks": 1},
            {"seq_len": 33, "page_size": 32, "expected_min_blocks": 2},
            {"seq_len": 128, "page_size": 32, "expected_min_blocks": 4},
            {"seq_len": 128, "page_size": 64, "expected_min_blocks": 2},
        ]

        for scenario in test_scenarios:
            with self.subTest(scenario=scenario):
                config = self._merge_config(
                    {
                        "batch_size": 1,
                        "max_seq_len": scenario["seq_len"],
                        "page_size": scenario["page_size"],
                    }
                )

                model_runner, _, backend, _, _ = self._create_model_components(config)

                # Test internal block calculation
                calculated_blocks = backend._calc_padded_blocks(scenario["seq_len"])

                # Should be at least the minimum required
                self.assertGreaterEqual(
                    calculated_blocks,
                    scenario["expected_min_blocks"],
                    f"Calculated blocks ({calculated_blocks}) should be >= minimum required ({scenario['expected_min_blocks']})",
                )

                # Should satisfy page_size constraint
                total_tokens = calculated_blocks * scenario["page_size"]
                self.assertGreaterEqual(
                    total_tokens,
                    scenario["seq_len"],
                    f"Total tokens ({total_tokens}) should cover sequence length ({scenario['seq_len']})",
                )

                # Should satisfy TRT-LLM and Triton constraints
                trtllm_constraint = 128 // scenario["page_size"]
                constraint_lcm = math.lcm(
                    trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK
                )
                self.assertEqual(
                    calculated_blocks % constraint_lcm,
                    0,
                    f"Block count should be multiple of LCM of constraints ({constraint_lcm})",
                )
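    # Worked example for the scenarios above: seq_len=33 with page_size=32 needs
    # ceil(33 / 32) = 2 pages at minimum; _calc_padded_blocks may then round that
    # up further to a multiple of lcm(128 // page_size, TRITON_PAD_NUM_PAGE_PER_BLOCK),
    # which is exactly what the two assertions check.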

    def test_metadata_kv_indices_correctness(self):
        """Test KV indices creation and correctness."""
        print(f"\nRunning KV indices correctness tests...")

        for test_case in TEST_CASES["metadata_tests"][
            :2
        ]:  # Test subset for performance
            with self.subTest(test_case=test_case["name"]):
                print(f"  Testing {test_case['name']}: {test_case['description']}")

                config = self._merge_config(test_case)
                batch_size = config["batch_size"]
                max_seq_len = config["max_seq_len"]

                model_runner, _, backend, _, layer = self._create_model_components(
                    config
                )

                # Create known sequence lengths
                torch.manual_seed(config["seed_cache"])
                if batch_size == 1:
                    seq_lens = torch.tensor([max_seq_len], device=config["device"])
                else:
                    seq_lens = torch.randint(
                        max_seq_len // 2,
                        max_seq_len + 1,
                        (batch_size,),
                        device=config["device"],
                    )

                fb = self._create_forward_batch(
                    batch_size, seq_lens, backend, model_runner, config
                )

                # Populate some KV cache to have valid indices
                self._populate_kv_cache(
                    batch_size, seq_lens, [model_runner], layer, config
                )

                # Initialize metadata
                backend.init_forward_metadata(fb)
                metadata = backend.forward_metadata

                # Verify KV indices structure
                block_kv_indices = metadata.block_kv_indices

                for i in range(batch_size):
                    seq_len = seq_lens[i].item()
                    expected_blocks = backend._calc_padded_blocks(seq_len)

                    # Count valid (non -1) indices for this sequence
                    valid_indices = (block_kv_indices[i] >= 0).sum().item()

                    # Should have at least enough blocks for the sequence
                    min_required_blocks = (seq_len + config["page_size"] - 1) // config[
                        "page_size"
                    ]
                    self.assertGreaterEqual(
                        valid_indices,
                        min_required_blocks,
                        f"Sequence {i} should have at least {min_required_blocks} valid blocks, got {valid_indices}",
                    )

                    # Verify indices are within valid range
                    valid_block_indices = block_kv_indices[i][block_kv_indices[i] >= 0]
                    if len(valid_block_indices) > 0:
                        max_possible_blocks = (
                            model_runner.token_to_kv_pool.size // config["page_size"]
                        )
                        self.assertTrue(
                            (valid_block_indices < max_possible_blocks).all(),
                            f"All block indices should be < {max_possible_blocks}",
                        )

    def test_metadata_cuda_graph_compatibility(self):
        """Test metadata compatibility with CUDA graph capture/replay."""
        print(f"\nRunning CUDA graph compatibility tests...")

        config = self._merge_config(
            {"batch_size": 4, "max_seq_len": 64, "page_size": 32}
        )

        model_runner, _, backend, _, layer = self._create_model_components(config)
        batch_size = config["batch_size"]

        # Initialize CUDA graph state
        backend.init_cuda_graph_state(
            max_bs=batch_size, max_num_tokens=config["max_seq_len"] * batch_size
        )

        # Verify CUDA graph buffers are allocated
        self.assertIsNotNone(backend.cuda_graph_kv_indices)
        self.assertIsNotNone(backend.cuda_graph_workspace)

        # Test capture metadata
        seq_lens = torch.full(
            (batch_size,), config["max_seq_len"], device=config["device"]
        )
        req_pool_indices = torch.arange(batch_size, device=config["device"])

        backend.init_forward_metadata_capture_cuda_graph(
            bs=batch_size,
            num_tokens=batch_size,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            encoder_lens=None,
            forward_mode=ForwardMode.DECODE,
            spec_info=None,
        )

        # Verify capture metadata
        self.assertIn(batch_size, backend.decode_cuda_graph_metadata)
        capture_metadata = backend.decode_cuda_graph_metadata[batch_size]

        self.assertIsNotNone(capture_metadata.workspace)
        self.assertIsNotNone(capture_metadata.block_kv_indices)

        # Test replay with different sequence lengths
        new_seq_lens = torch.randint(
            config["max_seq_len"] // 2,
            config["max_seq_len"] + 1,
            (batch_size,),
            device=config["device"],
        )

        backend.init_forward_metadata_replay_cuda_graph(
            bs=batch_size,
            req_pool_indices=req_pool_indices,
            seq_lens=new_seq_lens,
            seq_lens_sum=new_seq_lens.sum().item(),
            encoder_lens=None,
            forward_mode=ForwardMode.DECODE,
            spec_info=None,
            seq_lens_cpu=new_seq_lens.cpu(),
        )

        # Verify replay updated the metadata
        replay_metadata = backend.forward_metadata
        self.assertIsNotNone(replay_metadata)
        self.assertEqual(
            replay_metadata.workspace.data_ptr(), capture_metadata.workspace.data_ptr()
        )
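    # The data_ptr() comparison above matters because a captured CUDA graph bakes
    # in device addresses: replay must reuse the same workspace buffer rather than
    # allocate a new one.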

    def test_metadata_consistency_across_calls(self):
        """Test metadata consistency across multiple forward calls."""
        print(f"\nRunning metadata consistency tests...")

        config = self._merge_config(
            {"batch_size": 2, "max_seq_len": 64, "page_size": 32}
        )

        model_runner, _, backend, _, layer = self._create_model_components(config)

        # First call
        seq_lens_1 = torch.tensor([32, 48], device=config["device"])
        fb_1 = self._create_forward_batch(
            config["batch_size"], seq_lens_1, backend, model_runner, config
        )
        backend.init_forward_metadata(fb_1)
        metadata_1 = backend.forward_metadata

        # Second call with same sequence lengths
        seq_lens_2 = torch.tensor([32, 48], device=config["device"])
        fb_2 = self._create_forward_batch(
            config["batch_size"], seq_lens_2, backend, model_runner, config
        )
        backend.init_forward_metadata(fb_2)
        metadata_2 = backend.forward_metadata

        # Metadata structure should be consistent
        self.assertEqual(metadata_1.workspace.shape, metadata_2.workspace.shape)
        self.assertEqual(
            metadata_1.block_kv_indices.shape, metadata_2.block_kv_indices.shape
        )

        # Third call with different sequence lengths
        seq_lens_3 = torch.tensor([16, 64], device=config["device"])
        fb_3 = self._create_forward_batch(
            config["batch_size"], seq_lens_3, backend, model_runner, config
        )
        backend.init_forward_metadata(fb_3)
        metadata_3 = backend.forward_metadata

        # Should still have valid structure
        self.assertIsNotNone(metadata_3.workspace)
        self.assertIsNotNone(metadata_3.block_kv_indices)
        self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"])


if __name__ == "__main__":
    unittest.main()