sglang 0.4.10__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. sglang/bench_offline_throughput.py +20 -0
  2. sglang/compile_deep_gemm.py +8 -1
  3. sglang/global_config.py +5 -1
  4. sglang/srt/configs/model_config.py +1 -0
  5. sglang/srt/conversation.py +0 -112
  6. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
  7. sglang/srt/disaggregation/launch_lb.py +5 -20
  8. sglang/srt/disaggregation/mooncake/conn.py +33 -15
  9. sglang/srt/disaggregation/prefill.py +1 -0
  10. sglang/srt/distributed/device_communicators/pynccl.py +7 -0
  11. sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
  12. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
  13. sglang/srt/distributed/parallel_state.py +11 -0
  14. sglang/srt/entrypoints/engine.py +4 -2
  15. sglang/srt/entrypoints/http_server.py +35 -15
  16. sglang/srt/eplb/expert_distribution.py +4 -2
  17. sglang/srt/hf_transformers_utils.py +25 -10
  18. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  19. sglang/srt/layers/attention/flashattention_backend.py +7 -11
  20. sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
  21. sglang/srt/layers/attention/utils.py +6 -1
  22. sglang/srt/layers/attention/vision.py +27 -10
  23. sglang/srt/layers/communicator.py +14 -4
  24. sglang/srt/layers/linear.py +7 -1
  25. sglang/srt/layers/logits_processor.py +9 -1
  26. sglang/srt/layers/moe/ep_moe/layer.py +29 -68
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +82 -25
  29. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
  30. sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
  31. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
  32. sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
  33. sglang/srt/layers/moe/utils.py +43 -0
  34. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
  35. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  36. sglang/srt/layers/quantization/fp8.py +57 -1
  37. sglang/srt/layers/quantization/fp8_kernel.py +0 -4
  38. sglang/srt/layers/quantization/w8a8_int8.py +4 -1
  39. sglang/srt/layers/vocab_parallel_embedding.py +7 -1
  40. sglang/srt/lora/lora_registry.py +7 -0
  41. sglang/srt/managers/cache_controller.py +43 -39
  42. sglang/srt/managers/data_parallel_controller.py +52 -2
  43. sglang/srt/managers/io_struct.py +6 -1
  44. sglang/srt/managers/schedule_batch.py +3 -2
  45. sglang/srt/managers/schedule_policy.py +3 -1
  46. sglang/srt/managers/scheduler.py +145 -6
  47. sglang/srt/managers/template_manager.py +25 -22
  48. sglang/srt/managers/tokenizer_manager.py +114 -62
  49. sglang/srt/managers/utils.py +45 -1
  50. sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
  51. sglang/srt/mem_cache/hicache_storage.py +13 -12
  52. sglang/srt/mem_cache/hiradix_cache.py +21 -4
  53. sglang/srt/mem_cache/memory_pool.py +15 -118
  54. sglang/srt/mem_cache/memory_pool_host.py +350 -33
  55. sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
  56. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +8 -2
  57. sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
  58. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +163 -0
  59. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +238 -0
  60. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +216 -0
  61. sglang/srt/model_executor/cuda_graph_runner.py +42 -4
  62. sglang/srt/model_executor/forward_batch_info.py +13 -3
  63. sglang/srt/model_executor/model_runner.py +13 -1
  64. sglang/srt/model_loader/weight_utils.py +2 -0
  65. sglang/srt/models/deepseek_v2.py +28 -23
  66. sglang/srt/models/glm4_moe.py +85 -22
  67. sglang/srt/models/grok.py +3 -3
  68. sglang/srt/models/llama4.py +13 -2
  69. sglang/srt/models/mixtral.py +3 -3
  70. sglang/srt/models/mllama4.py +428 -19
  71. sglang/srt/models/qwen2_moe.py +1 -4
  72. sglang/srt/models/qwen3_moe.py +7 -8
  73. sglang/srt/models/step3_vl.py +1 -4
  74. sglang/srt/multimodal/processors/base_processor.py +4 -3
  75. sglang/srt/multimodal/processors/gemma3n.py +0 -7
  76. sglang/srt/operations_strategy.py +1 -1
  77. sglang/srt/server_args.py +115 -21
  78. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
  79. sglang/srt/two_batch_overlap.py +6 -4
  80. sglang/srt/utils.py +4 -24
  81. sglang/srt/weight_sync/utils.py +1 -1
  82. sglang/test/attention/test_trtllm_mla_backend.py +945 -0
  83. sglang/test/runners.py +2 -2
  84. sglang/test/test_utils.py +3 -3
  85. sglang/version.py +1 -1
  86. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
  87. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +92 -81
  88. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
  89. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
  90. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
  91. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
  92. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,945 @@
+import math
+import unittest
+
+import numpy as np
+import torch
+
+from sglang.srt.layers import dp_attention as _dp_attn
+
+# Patch DP-attention globals before importing backends
+# TODO: change the interface of both trtllm_mla and flashinfer backends to take tp_size as an argument instead of patching
+_dp_attn.get_attention_tp_size = lambda: 1  # TP size = 1 for unit test
+
+from sglang.srt.configs.model_config import AttentionArch
+from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
+from sglang.srt.layers.attention.trtllm_mla_backend import (
+    TRTLLMMLABackend,
+    TRTLLMMLADecodeMetadata,
+)
+from sglang.srt.layers.attention.utils import TRITON_PAD_NUM_PAGE_PER_BLOCK
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.utils import is_flashinfer_available
+from sglang.test.test_utils import CustomTestCase
+
+# Global configuration for all tests
+DEFAULT_CONFIG = {
+    "device": "cuda",
+    "dtype": torch.bfloat16,
+    "kv_cache_dtype": torch.bfloat16,
+    "context_len": 2048,
+    "max_bs": 64,
+    "tolerance": 1e-2,
+    "seed_cache": 42,
+    "seed_qkv": 123,
+    # MLA model config (TRTLLM MLA has fixed constraints)
+    "num_attention_heads": 128,
+    "kv_lora_rank": 512,
+    "qk_nope_head_dim": 128,
+    "qk_rope_head_dim": 64,
+    "v_head_dim": 512,
+    "num_kv_heads": 1,
+    "layer_id": 0,
+}
+
+# Centralized test cases for different test scenarios
+TEST_CASES = {
+    "basic_functionality": [
+        {
+            "name": "single",
+            "batch_size": 1,
+            "max_seq_len": 32,
+            "page_size": 32,
+            "description": "Minimal smoke test",
+        },
+        {
+            "name": "batch",
+            "batch_size": 32,
+            "max_seq_len": 128,
+            "page_size": 32,
+            "description": "Medium-scale batch",
+        },
+    ],
+    "decode_output_match": [
+        {
+            "name": "single",
+            "batch_size": 1,
+            "max_seq_len": 64,
+            "page_size": 32,
+            "description": "Single vs reference",
+        },
+        {
+            "name": "batch",
+            "batch_size": 32,
+            "max_seq_len": 64,
+            "page_size": 32,
+            "description": "Batch vs reference",
+        },
+    ],
+    "page_size_consistency": [
+        # Only 32 and 64 supported for now in flashinfer TRTLLM-GEN MLA kernel
+        {
+            "name": "page_32",
+            "batch_size": 8,
+            "max_seq_len": 128,
+            "page_size": 32,
+            "description": "32-token pages",
+        },
+        {
+            "name": "page_64",
+            "batch_size": 8,
+            "max_seq_len": 128,
+            "page_size": 64,
+            "description": "64-token pages",
+        },
+    ],
+    "shape_sanity_tests": [
+        {
+            "name": "basic",
+            "batch_size": 1,
+            "max_seq_len": 128,
+            "page_size": 32,
+            "description": "Single sequence",
+        },
+        {
+            "name": "basic_different_pagesize",
+            "batch_size": 1,
+            "max_seq_len": 128,
+            "page_size": 64,
+            "description": "Different page size",
+        },
+        {
+            "name": "batch",
+            "batch_size": 8,
+            "max_seq_len": 128,
+            "page_size": 32,
+            "description": "Batch shapes",
+        },
+    ],
+    "metadata_tests": [
+        {
+            "name": "single_sequence",
+            "batch_size": 1,
+            "max_seq_len": 64,
+            "page_size": 32,
+            "description": "Single sequence metadata",
+        },
+        {
+            "name": "batch_mixed_lengths",
+            "batch_size": 8,
+            "max_seq_len": 128,
+            "page_size": 32,
+            "description": "Mixed sequence lengths",
+        },
+        {
+            "name": "large_batch",
+            "batch_size": 32,
+            "max_seq_len": 256,
+            "page_size": 64,
+            "description": "Large batch stress test",
+        },
+        {
+            "name": "edge_case_short",
+            "batch_size": 4,
+            "max_seq_len": 16,
+            "page_size": 32,
+            "description": "Sub-page sequences",
+        },
+    ],
+}
+
+
+class MockModelRunner:
+    """Minimal fake ModelRunner for testing MLA backends."""
+
+    def __init__(self, config):
+        self.device = config["device"]
+        self.dtype = config["dtype"]
+        self.kv_cache_dtype = config["kv_cache_dtype"]
+        self.page_size = config["page_size"]
+
+        # Model-config stub with MLA attributes
+        self.model_config = type(
+            "ModelConfig",
+            (),
+            {
+                "context_len": config["context_len"],
+                "attention_arch": AttentionArch.MLA,
+                "num_attention_heads": config["num_attention_heads"],
+                "kv_lora_rank": config["kv_lora_rank"],
+                "qk_nope_head_dim": config["qk_nope_head_dim"],
+                "qk_rope_head_dim": config["qk_rope_head_dim"],
+                "v_head_dim": config["v_head_dim"],
+                "scaling": 1.0
+                / ((config["qk_nope_head_dim"] + config["qk_rope_head_dim"]) ** 0.5),
+                "get_num_kv_heads": staticmethod(lambda _: config["num_kv_heads"]),
+            },
+        )
+
+        # Req-to-token pool
+        max_bs = config["max_bs"]
+        max_ctx = self.model_config.context_len
+        req_to_token = torch.arange(
+            max_bs * max_ctx, dtype=torch.int32, device=self.device
+        ).reshape(max_bs, max_ctx)
+        self.req_to_token_pool = type(
+            "TokenPool",
+            (),
+            {
+                "size": max_bs,
+                "req_to_token": req_to_token,
+            },
+        )
+
+        # KV-token pool (MLA)
+        self.token_to_kv_pool = MLATokenToKVPool(
+            size=max_bs * max_ctx,
+            page_size=config["page_size"],
+            dtype=self.kv_cache_dtype,
+            kv_lora_rank=config["kv_lora_rank"],
+            qk_rope_head_dim=config["qk_rope_head_dim"],
+            layer_num=1,
+            device=self.device,
+            enable_memory_saver=False,
+        )
+
+
+def compare_outputs(trtllm_out, reference_out, tolerance=1e-2):
+    """Compare outputs with detailed analysis."""
+
+    # Basic checks
+    assert (
+        trtllm_out.shape == reference_out.shape
+    ), f"Shape mismatch: {trtllm_out.shape} vs {reference_out.shape}"
+    assert (
+        trtllm_out.dtype == reference_out.dtype
+    ), f"Dtype mismatch: {trtllm_out.dtype} vs {reference_out.dtype}"
+
+    # Check for NaN/Inf
+    assert not torch.isnan(trtllm_out).any(), "TRTLLM output contains NaN"
+    assert not torch.isnan(reference_out).any(), "Reference output contains NaN"
+    assert not torch.isinf(trtllm_out).any(), "TRTLLM output contains Inf"
+    assert not torch.isinf(reference_out).any(), "Reference output contains Inf"
+
+    # Element-wise differences
+    diff = (trtllm_out - reference_out).abs()
+    max_diff = diff.max().item()
+    mean_diff = diff.mean().item()
+
+    # Check numerical equivalence
+    all_close = torch.allclose(
+        trtllm_out, reference_out, rtol=tolerance, atol=tolerance
+    )
+
+    if not all_close:
+        print(
+            f"Comparison failed: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}, tolerance={tolerance}"
+        )
+        # Find top differences for debugging
+        flat_diff = diff.flatten()
+        top_diff_indices = torch.topk(flat_diff, k=min(5, flat_diff.numel())).indices
+        print("Top 5 differences:")
+        for i, idx in enumerate(top_diff_indices):
+            idx_tuple = np.unravel_index(idx.cpu().numpy(), trtllm_out.shape)
+            trt_val = trtllm_out[idx_tuple].item()
+            ref_val = reference_out[idx_tuple].item()
+            print(
+                f" [{idx_tuple}]: TRTLLM={trt_val:.6f}, Reference={ref_val:.6f}, diff={abs(trt_val-ref_val):.6f}"
+            )
+
+    return all_close
+
+
+@unittest.skipIf(
+    not torch.cuda.is_available() or not is_flashinfer_available(),
+    "CUDA + flashinfer required",
+)
+class TestTRTLLMMLA(CustomTestCase):
+    """Test suite for TRTLLM MLA backend with centralized configuration."""
+
+    def _merge_config(self, test_case):
+        """Merge test case with default configuration."""
+        config = DEFAULT_CONFIG.copy()
+        config.update(test_case)
+        return config
+
+    def _create_model_components(self, config):
+        """Create model runners, backends, and layer for testing."""
+        # Create model runners
+        model_runner_trtllm = MockModelRunner(config)
+        model_runner_reference = MockModelRunner(config)
+
+        # Create backends
+        trtllm_backend = TRTLLMMLABackend(model_runner_trtllm)
+        reference_backend = FlashInferMLAAttnBackend(model_runner_reference)
+
+        # Create RadixAttention layer
+        layer = RadixAttention(
+            num_heads=config["num_attention_heads"],
+            head_dim=config["kv_lora_rank"] + config["qk_rope_head_dim"],
+            scaling=model_runner_trtllm.model_config.scaling,
+            num_kv_heads=config["num_kv_heads"],
+            layer_id=config["layer_id"],
+            v_head_dim=config["v_head_dim"],
+            prefix="attn_mqa",
+        )
+
+        return (
+            model_runner_trtllm,
+            model_runner_reference,
+            trtllm_backend,
+            reference_backend,
+            layer,
+        )
+
+    def _create_qkv_tensors(self, batch_size, config):
+        """Create Q, K, V tensors for testing."""
+        head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"]
+        device = config["device"]
+        dtype = config["dtype"]
+
+        q = torch.randn(
+            (batch_size, config["num_attention_heads"], head_dim),
+            dtype=dtype,
+            device=device,
+        )
+        k = torch.randn(
+            (batch_size, config["num_kv_heads"], head_dim), dtype=dtype, device=device
+        )
+        v = torch.randn(
+            (batch_size, config["num_kv_heads"], config["v_head_dim"]),
+            dtype=dtype,
+            device=device,
+        )
+        return q, k, v
+
+    def _create_forward_batch(
+        self, batch_size, seq_lens, backend, model_runner, config
+    ):
+        """Create a forward batch for the given backend."""
+        fb = ForwardBatch(
+            batch_size=batch_size,
+            input_ids=torch.randint(0, 100, (batch_size, 1), device=config["device"]),
+            out_cache_loc=torch.arange(batch_size, device=config["device"]),
+            seq_lens_sum=int(seq_lens.sum().item()),
+            forward_mode=ForwardMode.DECODE,
+            req_pool_indices=torch.arange(batch_size, device=config["device"]),
+            seq_lens=seq_lens,
+            seq_lens_cpu=seq_lens.cpu(),
+            attn_backend=backend,
+        )
+        fb.req_to_token_pool = model_runner.req_to_token_pool
+        fb.token_to_kv_pool = model_runner.token_to_kv_pool
+        return fb
+
+    def _populate_kv_cache(self, batch_size, seq_lens, model_runners, layer, config):
+        """Populate KV cache with identical data for both backends."""
+        torch.manual_seed(config["seed_cache"])  # Fixed seed for reproducible cache
+
+        for model_runner in model_runners:
+            torch.manual_seed(config["seed_cache"])  # Reset seed for each backend
+            for i in range(batch_size):
+                seq_len = int(seq_lens[i].item())
+                for token_idx in range(seq_len - 1):
+                    # Create random K components for MLA
+                    cache_k_nope = torch.randn(
+                        (1, config["qk_nope_head_dim"]),
+                        dtype=config["dtype"],
+                        device=config["device"],
+                    )
+                    cache_k_rope = torch.randn(
+                        (1, config["qk_rope_head_dim"]),
+                        dtype=config["dtype"],
+                        device=config["device"],
+                    )
+
+                    # Calculate cache location
+                    cache_loc = model_runner.req_to_token_pool.req_to_token[
+                        i, token_idx
+                    ]
+
+                    # Save to KV cache
+                    model_runner.token_to_kv_pool.set_mla_kv_buffer(
+                        layer,
+                        cache_loc.unsqueeze(0),
+                        cache_k_nope.squeeze(0),
+                        cache_k_rope.squeeze(0),
+                    )
+
+    def test_basic_functionality(self):
+        """Test basic functionality with minimal setup."""
+        print(f"\nRunning basic functionality tests...")
+
+        for test_case in TEST_CASES["basic_functionality"]:
+            with self.subTest(test_case=test_case["name"]):
+                print(f" Testing {test_case['name']}: {test_case['description']}")
+
+                config = self._merge_config(test_case)
+                batch_size = config["batch_size"]
+                max_seq_len = config["max_seq_len"]
+
+                # Create components
+                model_runner_trtllm, _, trtllm_backend, _, layer = (
+                    self._create_model_components(config)
+                )
+
+                # Create sequence lengths - properly handle different batch sizes
+                if batch_size == 2:
+                    seq_lens = torch.tensor(
+                        [max_seq_len, max_seq_len // 2], device=config["device"]
+                    )
+                else:
+                    # For larger batch sizes, create varied sequence lengths
+                    torch.manual_seed(config["seed_cache"])
+                    seq_lens = torch.randint(
+                        max_seq_len // 2,
+                        max_seq_len + 1,
+                        (batch_size,),
+                        device=config["device"],
+                    )
+                    seq_lens[0] = max_seq_len  # Ensure at least one max length
+
+                # Create forward batch
+                fb = self._create_forward_batch(
+                    batch_size, seq_lens, trtllm_backend, model_runner_trtllm, config
+                )
+                trtllm_backend.init_forward_metadata(fb)
+
+                # Populate KV cache
+                self._populate_kv_cache(
+                    batch_size, seq_lens, [model_runner_trtllm], layer, config
+                )
+
+                # Create Q, K, V tensors
+                torch.manual_seed(config["seed_qkv"])
+                q, k, v = self._create_qkv_tensors(batch_size, config)
+
+                # Run forward decode
+                output = trtllm_backend.forward_decode(q, k, v, layer, fb)
+
+                # Basic checks
+                expected_shape = (
+                    batch_size,
+                    config["num_attention_heads"] * config["v_head_dim"],
+                )
+                self.assertEqual(output.shape, expected_shape)
+                self.assertEqual(output.dtype, config["dtype"])
+                self.assertFalse(torch.isnan(output).any())
+                self.assertFalse(torch.isinf(output).any())
+
+    def test_decode_output_match(self):
+        """Test that TRTLLM and FlashInfer MLA backends produce matching outputs."""
+        print(f"\nRunning decode output matching tests...")
+
+        for test_case in TEST_CASES["decode_output_match"]:
+            with self.subTest(test_case=test_case["name"]):
+                print(f" Testing {test_case['name']}: {test_case['description']}")
+
+                config = self._merge_config(test_case)
+                batch_size = config["batch_size"]
+                max_seq_len = config["max_seq_len"]
+
+                # Create components
+                (
+                    model_runner_trtllm,
+                    model_runner_reference,
+                    trtllm_backend,
+                    reference_backend,
+                    layer,
+                ) = self._create_model_components(config)
+
+                # Create identical sequence lengths for both backends
+                torch.manual_seed(config["seed_cache"])
+                seq_lens = torch.randint(
+                    1, max_seq_len, (batch_size,), device=config["device"]
+                )
+                seq_lens[0] = max_seq_len  # Ensure at least one max length
+
+                # Create forward batches with identical inputs
+                fb_trtllm = self._create_forward_batch(
+                    batch_size,
+                    seq_lens.clone(),
+                    trtllm_backend,
+                    model_runner_trtllm,
+                    config,
+                )
+                fb_reference = self._create_forward_batch(
+                    batch_size,
+                    seq_lens.clone(),
+                    reference_backend,
+                    model_runner_reference,
+                    config,
+                )
+
+                # Initialize metadata for both backends
+                trtllm_backend.init_forward_metadata(fb_trtllm)
+                reference_backend.init_forward_metadata(fb_reference)
+
+                # Populate both KV caches identically
+                self._populate_kv_cache(
+                    batch_size,
+                    seq_lens,
+                    [model_runner_trtllm, model_runner_reference],
+                    layer,
+                    config,
+                )
+
+                # Create Q, K, V tensors for current decode step
+                torch.manual_seed(config["seed_qkv"])
+                q, k, v = self._create_qkv_tensors(batch_size, config)
+
+                # Run forward decode on both backends
+                out_trtllm = trtllm_backend.forward_decode(
+                    q.clone(), k.clone(), v.clone(), layer, fb_trtllm
+                )
+                out_reference = reference_backend.forward_decode(
+                    q.clone(), k.clone(), v.clone(), layer, fb_reference
+                )
+
+                # Compare outputs
+                comparison_passed = compare_outputs(
+                    out_trtllm, out_reference, tolerance=config["tolerance"]
+                )
+
+                self.assertTrue(
+                    comparison_passed,
+                    f"TRTLLM and Reference outputs differ beyond tolerance. "
+                    f"Config: {test_case['name']}, "
+                    f"Max diff: {(out_trtllm - out_reference).abs().max().item()}",
+                )
+
+    def test_page_size_consistency(self):
+        """Test output consistency across different page sizes."""
+        print(f"\nRunning page size consistency tests...")
+
+        for test_case in TEST_CASES["page_size_consistency"]:
+            with self.subTest(test_case=test_case["name"]):
+                print(f" Testing {test_case['name']}: {test_case['description']}")
+
+                config = self._merge_config(test_case)
+                batch_size = config["batch_size"]
+                max_seq_len = config["max_seq_len"]
+
+                # Create components
+                model_runner, _, backend, _, layer = self._create_model_components(
+                    config
+                )
+
+                # Create sequence lengths
+                torch.manual_seed(config["seed_cache"])
+                seq_lens = torch.randint(
+                    1, max_seq_len, (batch_size,), device=config["device"]
+                )
+                seq_lens[0] = max_seq_len
+
+                # Create forward batch
+                fb = self._create_forward_batch(
+                    batch_size, seq_lens, backend, model_runner, config
+                )
+                backend.init_forward_metadata(fb)
+
+                # Populate KV cache
+                self._populate_kv_cache(
+                    batch_size, seq_lens, [model_runner], layer, config
+                )
+
+                # Create Q, K, V tensors
+                torch.manual_seed(config["seed_qkv"])
+                q, k, v = self._create_qkv_tensors(batch_size, config)
+
+                # Run forward decode
+                output = backend.forward_decode(q, k, v, layer, fb)
+
+                expected_shape = (
+                    batch_size,
+                    config["num_attention_heads"] * config["v_head_dim"],
+                )
+                self.assertEqual(
+                    output.shape,
+                    expected_shape,
+                    f"Output shape mismatch: {output.shape} vs {expected_shape}",
+                )
+                self.assertFalse(torch.isnan(output).any(), "Output contains NaN")
+                self.assertFalse(torch.isinf(output).any(), "Output contains Inf")
+
+    def test_shape_sanity(self):
+        """Smoke test decode across several configurations."""
+        print(f"\nRunning shape sanity tests...")
+
+        for test_case in TEST_CASES["shape_sanity_tests"]:
+            with self.subTest(test_case=test_case["name"]):
+                print(f" Testing {test_case['name']}: {test_case['description']}")
+
+                config = self._merge_config(test_case)
+                batch_size = config["batch_size"]
+                max_seq_len = config["max_seq_len"]
+
+                model_runner, _, backend, _, layer = self._create_model_components(
+                    config
+                )
+
+                # Random seq lens (ensure one matches max)
+                torch.manual_seed(config["seed_cache"])
+                seq_lens = torch.randint(
+                    1, max_seq_len, (batch_size,), device=config["device"]
+                )
+                seq_lens[0] = max_seq_len
+
+                fb = self._create_forward_batch(
+                    batch_size, seq_lens, backend, model_runner, config
+                )
+                backend.init_forward_metadata(fb)
+
+                # Create Q, K, V tensors
+                torch.manual_seed(config["seed_qkv"])
+                head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"]
+                q = torch.randn(
+                    (batch_size, config["num_attention_heads"], head_dim),
+                    dtype=config["dtype"],
+                    device=config["device"],
+                )
+                k = torch.randn(
+                    (batch_size, config["num_kv_heads"], head_dim),
+                    dtype=config["dtype"],
+                    device=config["device"],
+                )
+                v = None
+
+                # Run forward decode
+                output = backend.forward_decode(q, k, v, layer, fb)
+
+                # Shape and sanity checks
+                expected_shape = (
+                    batch_size,
+                    config["num_attention_heads"] * config["v_head_dim"],
+                )
+                self.assertEqual(
+                    output.shape,
+                    expected_shape,
+                    f"Output shape mismatch for {test_case['name']}",
+                )
+                self.assertEqual(output.dtype, config["dtype"])
+                self.assertEqual(output.device.type, "cuda")
+                self.assertFalse(
+                    torch.isnan(output).any(),
+                    f"Output contains NaN for {test_case['name']}",
+                )
+                self.assertFalse(
+                    torch.isinf(output).any(),
+                    f"Output contains Inf for {test_case['name']}",
+                )
+
+    def test_metadata_initialization(self):
+        """Test TRTLLM MLA metadata initialization and structure."""
+        print(f"\nRunning metadata initialization tests...")
+
+        for test_case in TEST_CASES["metadata_tests"]:
+            with self.subTest(test_case=test_case["name"]):
+                print(f" Testing {test_case['name']}: {test_case['description']}")
+
+                config = self._merge_config(test_case)
+                batch_size = config["batch_size"]
+                max_seq_len = config["max_seq_len"]
+
+                # Create components
+                model_runner, _, backend, _, layer = self._create_model_components(
+                    config
+                )
+
+                # Create varied sequence lengths
+                torch.manual_seed(config["seed_cache"])
+                if batch_size == 1:
+                    seq_lens = torch.tensor([max_seq_len], device=config["device"])
+                else:
+                    seq_lens = torch.randint(
+                        max(1, max_seq_len // 4),
+                        max_seq_len + 1,
+                        (batch_size,),
+                        device=config["device"],
+                    )
+                    seq_lens[0] = max_seq_len  # Ensure at least one max length
+
+                # Create forward batch
+                fb = self._create_forward_batch(
+                    batch_size, seq_lens, backend, model_runner, config
+                )
+
+                # Initialize metadata
+                backend.init_forward_metadata(fb)
+
+                # Verify metadata exists
+                self.assertIsNotNone(backend.forward_metadata)
+                self.assertIsInstance(backend.forward_metadata, TRTLLMMLADecodeMetadata)
+
+                # Test metadata structure
+                metadata = backend.forward_metadata
+                self.assertIsNotNone(
+                    metadata.workspace, "Workspace should be allocated"
+                )
+                self.assertIsNotNone(
+                    metadata.block_kv_indices, "Block KV indices should be created"
+                )
+
+                # Test workspace properties
+                self.assertEqual(metadata.workspace.device.type, "cuda")
+                self.assertEqual(metadata.workspace.dtype, torch.int8)
+                self.assertGreater(
+                    metadata.workspace.numel(), 0, "Workspace should have non-zero size"
+                )
+
+                # Test block KV indices properties
+                self.assertEqual(metadata.block_kv_indices.device.type, "cuda")
+                self.assertEqual(metadata.block_kv_indices.dtype, torch.int32)
+                self.assertEqual(metadata.block_kv_indices.shape[0], batch_size)
+
+                # Verify block indices are valid (>= -1, since -1 is padding)
+                self.assertTrue(
+                    (metadata.block_kv_indices >= -1).all(),
+                    "All block indices should be >= -1 (with -1 as padding)",
+                )
+
+    def test_metadata_block_calculation(self):
+        """Test block count calculation logic."""
+        print(f"\nRunning metadata block calculation tests...")
+
+        test_scenarios = [
+            {"seq_len": 31, "page_size": 32, "expected_min_blocks": 1},
+            {"seq_len": 32, "page_size": 32, "expected_min_blocks": 1},
+            {"seq_len": 33, "page_size": 32, "expected_min_blocks": 2},
+            {"seq_len": 128, "page_size": 32, "expected_min_blocks": 4},
+            {"seq_len": 128, "page_size": 64, "expected_min_blocks": 2},
+        ]
+
+        for scenario in test_scenarios:
+            with self.subTest(scenario=scenario):
+                config = self._merge_config(
+                    {
+                        "batch_size": 1,
+                        "max_seq_len": scenario["seq_len"],
+                        "page_size": scenario["page_size"],
+                    }
+                )
+
+                model_runner, _, backend, _, _ = self._create_model_components(config)
+
+                # Test internal block calculation
+                calculated_blocks = backend._calc_padded_blocks(scenario["seq_len"])
+
+                # Should be at least the minimum required
+                self.assertGreaterEqual(
+                    calculated_blocks,
+                    scenario["expected_min_blocks"],
+                    f"Calculated blocks ({calculated_blocks}) should be >= minimum required ({scenario['expected_min_blocks']})",
+                )
+
+                # Should satisfy page_size constraint
+                total_tokens = calculated_blocks * scenario["page_size"]
+                self.assertGreaterEqual(
+                    total_tokens,
+                    scenario["seq_len"],
+                    f"Total tokens ({total_tokens}) should cover sequence length ({scenario['seq_len']})",
+                )
+
+                # Should satisfy TRT-LLM and Triton constraints
+                trtllm_constraint = 128 // scenario["page_size"]
+                constraint_lcm = math.lcm(
+                    trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK
+                )
+                self.assertEqual(
+                    calculated_blocks % constraint_lcm,
+                    0,
+                    f"Block count should be multiple of LCM of constraints ({constraint_lcm})",
+                )
+
+    def test_metadata_kv_indices_correctness(self):
+        """Test KV indices creation and correctness."""
+        print(f"\nRunning KV indices correctness tests...")
+
+        for test_case in TEST_CASES["metadata_tests"][
+            :2
+        ]:  # Test subset for performance
+            with self.subTest(test_case=test_case["name"]):
+                print(f" Testing {test_case['name']}: {test_case['description']}")
+
+                config = self._merge_config(test_case)
+                batch_size = config["batch_size"]
+                max_seq_len = config["max_seq_len"]
+
+                model_runner, _, backend, _, layer = self._create_model_components(
+                    config
+                )
+
+                # Create known sequence lengths
+                torch.manual_seed(config["seed_cache"])
+                if batch_size == 1:
+                    seq_lens = torch.tensor([max_seq_len], device=config["device"])
+                else:
+                    seq_lens = torch.randint(
+                        max_seq_len // 2,
+                        max_seq_len + 1,
+                        (batch_size,),
+                        device=config["device"],
+                    )
+
+                fb = self._create_forward_batch(
+                    batch_size, seq_lens, backend, model_runner, config
+                )
+
+                # Populate some KV cache to have valid indices
+                self._populate_kv_cache(
+                    batch_size, seq_lens, [model_runner], layer, config
+                )
+
+                # Initialize metadata
+                backend.init_forward_metadata(fb)
+                metadata = backend.forward_metadata
+
+                # Verify KV indices structure
+                block_kv_indices = metadata.block_kv_indices
+
+                for i in range(batch_size):
+                    seq_len = seq_lens[i].item()
+                    expected_blocks = backend._calc_padded_blocks(seq_len)
+
+                    # Count valid (non -1) indices for this sequence
+                    valid_indices = (block_kv_indices[i] >= 0).sum().item()
+
+                    # Should have at least enough blocks for the sequence
+                    min_required_blocks = (seq_len + config["page_size"] - 1) // config[
+                        "page_size"
+                    ]
+                    self.assertGreaterEqual(
+                        valid_indices,
+                        min_required_blocks,
+                        f"Sequence {i} should have at least {min_required_blocks} valid blocks, got {valid_indices}",
+                    )
+
+                    # Verify indices are within valid range
+                    valid_block_indices = block_kv_indices[i][block_kv_indices[i] >= 0]
+                    if len(valid_block_indices) > 0:
+                        max_possible_blocks = (
+                            model_runner.token_to_kv_pool.size // config["page_size"]
+                        )
+                        self.assertTrue(
+                            (valid_block_indices < max_possible_blocks).all(),
+                            f"All block indices should be < {max_possible_blocks}",
+                        )
+
+    def test_metadata_cuda_graph_compatibility(self):
+        """Test metadata compatibility with CUDA graph capture/replay."""
+        print(f"\nRunning CUDA graph compatibility tests...")
+
+        config = self._merge_config(
+            {"batch_size": 4, "max_seq_len": 64, "page_size": 32}
+        )
+
+        model_runner, _, backend, _, layer = self._create_model_components(config)
+        batch_size = config["batch_size"]
+
+        # Initialize CUDA graph state
+        backend.init_cuda_graph_state(
+            max_bs=batch_size, max_num_tokens=config["max_seq_len"] * batch_size
+        )
+
+        # Verify CUDA graph buffers are allocated
+        self.assertIsNotNone(backend.cuda_graph_kv_indices)
+        self.assertIsNotNone(backend.cuda_graph_workspace)
+
+        # Test capture metadata
+        seq_lens = torch.full(
+            (batch_size,), config["max_seq_len"], device=config["device"]
+        )
+        req_pool_indices = torch.arange(batch_size, device=config["device"])
+
+        backend.init_forward_metadata_capture_cuda_graph(
+            bs=batch_size,
+            num_tokens=batch_size,
+            req_pool_indices=req_pool_indices,
+            seq_lens=seq_lens,
+            encoder_lens=None,
+            forward_mode=ForwardMode.DECODE,
+            spec_info=None,
+        )
+
+        # Verify capture metadata
+        self.assertIn(batch_size, backend.decode_cuda_graph_metadata)
+        capture_metadata = backend.decode_cuda_graph_metadata[batch_size]
+
+        self.assertIsNotNone(capture_metadata.workspace)
+        self.assertIsNotNone(capture_metadata.block_kv_indices)
+
+        # Test replay with different sequence lengths
+        new_seq_lens = torch.randint(
+            config["max_seq_len"] // 2,
+            config["max_seq_len"] + 1,
+            (batch_size,),
+            device=config["device"],
+        )
+
+        backend.init_forward_metadata_replay_cuda_graph(
+            bs=batch_size,
+            req_pool_indices=req_pool_indices,
+            seq_lens=new_seq_lens,
+            seq_lens_sum=new_seq_lens.sum().item(),
+            encoder_lens=None,
+            forward_mode=ForwardMode.DECODE,
+            spec_info=None,
+            seq_lens_cpu=new_seq_lens.cpu(),
+        )
+
+        # Verify replay updated the metadata
+        replay_metadata = backend.forward_metadata
+        self.assertIsNotNone(replay_metadata)
+        self.assertEqual(
+            replay_metadata.workspace.data_ptr(), capture_metadata.workspace.data_ptr()
+        )
+
+    def test_metadata_consistency_across_calls(self):
+        """Test metadata consistency across multiple forward calls."""
+        print(f"\nRunning metadata consistency tests...")
+
+        config = self._merge_config(
+            {"batch_size": 2, "max_seq_len": 64, "page_size": 32}
+        )
+
+        model_runner, _, backend, _, layer = self._create_model_components(config)
+
+        # First call
+        seq_lens_1 = torch.tensor([32, 48], device=config["device"])
+        fb_1 = self._create_forward_batch(
+            config["batch_size"], seq_lens_1, backend, model_runner, config
+        )
+        backend.init_forward_metadata(fb_1)
+        metadata_1 = backend.forward_metadata
+
+        # Second call with same sequence lengths
+        seq_lens_2 = torch.tensor([32, 48], device=config["device"])
+        fb_2 = self._create_forward_batch(
+            config["batch_size"], seq_lens_2, backend, model_runner, config
+        )
+        backend.init_forward_metadata(fb_2)
+        metadata_2 = backend.forward_metadata
+
+        # Metadata structure should be consistent
+        self.assertEqual(metadata_1.workspace.shape, metadata_2.workspace.shape)
+        self.assertEqual(
+            metadata_1.block_kv_indices.shape, metadata_2.block_kv_indices.shape
+        )
+
+        # Third call with different sequence lengths
+        seq_lens_3 = torch.tensor([16, 64], device=config["device"])
+        fb_3 = self._create_forward_batch(
+            config["batch_size"], seq_lens_3, backend, model_runner, config
+        )
+        backend.init_forward_metadata(fb_3)
+        metadata_3 = backend.forward_metadata
+
+        # Should still have valid structure
+        self.assertIsNotNone(metadata_3.workspace)
+        self.assertIsNotNone(metadata_3.block_kv_indices)
+        self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"])
+
+
+if __name__ == "__main__":
+    unittest.main()
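
The new module guards itself with a class-level skipIf (CUDA plus flashinfer) and ends in a standard unittest.main() entry point. A minimal sketch of driving it from a script, assuming the 0.4.10.post2 wheel is installed on a machine that meets those requirements (the runner invocation below is illustrative and not part of the diff):

# Sketch: load and run the new TRTLLM MLA backend tests by module name.
# Without a CUDA GPU and flashinfer, the class-level skipIf skips every test.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName(
    "sglang.test.attention.test_trtllm_mla_backend"
)
unittest.TextTestRunner(verbosity=2).run(suite)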