sglang 0.4.5__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (121)
  1. sglang/bench_one_batch.py +21 -0
  2. sglang/bench_serving.py +10 -4
  3. sglang/srt/configs/model_config.py +37 -5
  4. sglang/srt/constrained/base_grammar_backend.py +26 -5
  5. sglang/srt/constrained/llguidance_backend.py +1 -0
  6. sglang/srt/constrained/outlines_backend.py +1 -0
  7. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  8. sglang/srt/constrained/xgrammar_backend.py +1 -0
  9. sglang/srt/disaggregation/base/__init__.py +8 -0
  10. sglang/srt/disaggregation/base/conn.py +113 -0
  11. sglang/srt/disaggregation/decode.py +18 -5
  12. sglang/srt/disaggregation/mini_lb.py +53 -122
  13. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  14. sglang/srt/disaggregation/mooncake/conn.py +615 -0
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
  16. sglang/srt/disaggregation/prefill.py +43 -19
  17. sglang/srt/disaggregation/utils.py +31 -0
  18. sglang/srt/entrypoints/EngineBase.py +53 -0
  19. sglang/srt/entrypoints/engine.py +36 -8
  20. sglang/srt/entrypoints/http_server.py +37 -8
  21. sglang/srt/entrypoints/http_server_engine.py +142 -0
  22. sglang/srt/entrypoints/verl_engine.py +37 -10
  23. sglang/srt/hf_transformers_utils.py +4 -0
  24. sglang/srt/layers/attention/flashattention_backend.py +330 -200
  25. sglang/srt/layers/attention/flashinfer_backend.py +13 -7
  26. sglang/srt/layers/attention/vision.py +1 -1
  27. sglang/srt/layers/dp_attention.py +2 -4
  28. sglang/srt/layers/elementwise.py +15 -2
  29. sglang/srt/layers/linear.py +1 -0
  30. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +38 -21
  38. sglang/srt/layers/moe/router.py +7 -1
  39. sglang/srt/layers/moe/topk.py +37 -16
  40. sglang/srt/layers/quantization/__init__.py +12 -5
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
  42. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
  43. sglang/srt/layers/quantization/fp8.py +25 -13
  44. sglang/srt/layers/quantization/fp8_kernel.py +130 -4
  45. sglang/srt/layers/quantization/fp8_utils.py +34 -6
  46. sglang/srt/layers/quantization/kv_cache.py +43 -52
  47. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  48. sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
  49. sglang/srt/layers/quantization/w8a8_int8.py +1 -0
  50. sglang/srt/layers/radix_attention.py +13 -1
  51. sglang/srt/layers/rotary_embedding.py +12 -1
  52. sglang/srt/managers/io_struct.py +254 -97
  53. sglang/srt/managers/mm_utils.py +3 -2
  54. sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
  55. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  56. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  57. sglang/srt/managers/schedule_batch.py +62 -21
  58. sglang/srt/managers/scheduler.py +71 -14
  59. sglang/srt/managers/tokenizer_manager.py +17 -3
  60. sglang/srt/managers/tp_worker.py +1 -0
  61. sglang/srt/mem_cache/memory_pool.py +14 -1
  62. sglang/srt/metrics/collector.py +9 -0
  63. sglang/srt/model_executor/cuda_graph_runner.py +7 -4
  64. sglang/srt/model_executor/forward_batch_info.py +234 -15
  65. sglang/srt/model_executor/model_runner.py +48 -9
  66. sglang/srt/model_loader/loader.py +31 -4
  67. sglang/srt/model_loader/weight_utils.py +4 -2
  68. sglang/srt/models/baichuan.py +2 -0
  69. sglang/srt/models/chatglm.py +1 -0
  70. sglang/srt/models/commandr.py +1 -0
  71. sglang/srt/models/dbrx.py +1 -0
  72. sglang/srt/models/deepseek.py +1 -0
  73. sglang/srt/models/deepseek_v2.py +248 -61
  74. sglang/srt/models/exaone.py +1 -0
  75. sglang/srt/models/gemma.py +1 -0
  76. sglang/srt/models/gemma2.py +1 -0
  77. sglang/srt/models/gemma3_causal.py +1 -0
  78. sglang/srt/models/gpt2.py +1 -0
  79. sglang/srt/models/gpt_bigcode.py +1 -0
  80. sglang/srt/models/granite.py +1 -0
  81. sglang/srt/models/grok.py +1 -0
  82. sglang/srt/models/internlm2.py +1 -0
  83. sglang/srt/models/llama.py +1 -0
  84. sglang/srt/models/llama4.py +101 -34
  85. sglang/srt/models/minicpm.py +1 -0
  86. sglang/srt/models/minicpm3.py +2 -0
  87. sglang/srt/models/mixtral.py +1 -0
  88. sglang/srt/models/mixtral_quant.py +1 -0
  89. sglang/srt/models/mllama.py +51 -8
  90. sglang/srt/models/mllama4.py +102 -29
  91. sglang/srt/models/olmo.py +1 -0
  92. sglang/srt/models/olmo2.py +1 -0
  93. sglang/srt/models/olmoe.py +1 -0
  94. sglang/srt/models/phi3_small.py +1 -0
  95. sglang/srt/models/qwen.py +1 -0
  96. sglang/srt/models/qwen2.py +1 -0
  97. sglang/srt/models/qwen2_5_vl.py +35 -70
  98. sglang/srt/models/qwen2_moe.py +1 -0
  99. sglang/srt/models/qwen2_vl.py +27 -25
  100. sglang/srt/models/stablelm.py +1 -0
  101. sglang/srt/models/xverse.py +1 -0
  102. sglang/srt/models/xverse_moe.py +1 -0
  103. sglang/srt/openai_api/adapter.py +4 -1
  104. sglang/srt/patch_torch.py +11 -0
  105. sglang/srt/server_args.py +34 -0
  106. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  107. sglang/srt/speculative/eagle_utils.py +1 -11
  108. sglang/srt/speculative/eagle_worker.py +6 -2
  109. sglang/srt/utils.py +120 -9
  110. sglang/test/attention/test_flashattn_backend.py +259 -221
  111. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  112. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  113. sglang/test/test_block_fp8.py +57 -0
  114. sglang/test/test_utils.py +19 -8
  115. sglang/version.py +1 -1
  116. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
  117. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +120 -106
  118. sglang/srt/disaggregation/conn.py +0 -81
  119. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
  120. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
  121. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
sglang/test/attention/test_flashattn_backend.py
@@ -2,60 +2,109 @@ import unittest
 
 import torch
 
+from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.model_runner import ServerArgs
 from sglang.test.test_utils import CustomTestCase
 
 
 class MockModelRunner:
-    model_config = type(
-        "ModelConfig", (), {"context_len": 2048, "is_multimodal": False}
-    )
-    sliding_window_size = None
-
-    def __init__(self, device="cuda"):
-        self.device = device
-        # Create a proper req_to_token_pool with the req_to_token attribute
+    def __init__(
+        self,
+        page_size=1,
+        num_heads=2,
+        head_dim=8,
+    ):
+        self.device = "cuda"
+        self.dtype = torch.float16
+        attention_arch = AttentionArch.MHA
+        # Max batch size for the test.
+        max_batch_size = 160
+        # Total tokens(prefix + extend + decode) in the test should not exceed this length.
+        max_context_len = 2048
+        self.model_config = type(
+            "ModelConfig",
+            (),
+            {
+                "context_len": max_context_len,
+                "is_multimodal": False,
+                "attention_arch": attention_arch,
+            },
+        )
+        self.sliding_window_size = None
+        self.device = self.device
+        # Create a large enough req_to_token_pool to fit the test usage.
         self.req_to_token_pool = type(
             "TokenPool",
             (),
             {
-                "size": 160,  # a typical max_bs * max_context_len for cuda graph decode
+                # A typical max_bs * max_context_len for cuda graph decode
+                "size": max_batch_size,
+                # Add req_to_token attribute
                 "req_to_token": torch.zeros(
-                    160, 2048, dtype=torch.int32, device=device
-                ),  # Add req_to_token attribute
+                    max_batch_size,
+                    max_context_len,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
             },
         )
-
-
-class MockReqToTokenPool:
-    def __init__(self, batch_size, seq_len, device):
-        self.req_to_token = (
-            torch.arange(batch_size * seq_len, device=device)
-            .reshape(batch_size, seq_len)
-            .to(torch.int32)
+        self.page_size = page_size
+        max_total_num_tokens = max_batch_size * max_context_len
+        self.token_to_kv_pool = MHATokenToKVPool(
+            size=max_total_num_tokens,
+            page_size=page_size,
+            dtype=self.dtype,
+            head_num=num_heads,
+            head_dim=head_dim,
+            layer_num=1,  # only consider layer=1 for unit test
+            device=self.device,
+            enable_memory_saver=False,
         )
+        # Required by torch native backend
+        self.server_args = ServerArgs(model_path="fake_model_path")
 
 
 @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
 class TestFlashAttentionBackend(CustomTestCase):
     def setUp(self):
-        """Set up test fixtures before each test method."""
-        self.model_runner = MockModelRunner()
-        self.backend = FlashAttentionBackend(self.model_runner)
-
-        # Common test parameters
+        # Test parameters
         self.batch_size = 2
-        self.seq_len = 4
+        self.seq_len = 256
         self.num_heads = 2
         self.head_dim = 8
         self.device = "cuda"
         self.dtype = torch.float16
 
+    def _init_model_runner(self, page_size=1):
+        self.model_runner = MockModelRunner(
+            page_size=page_size,
+            num_heads=self.num_heads,
+            head_dim=self.head_dim,
+        )
+        self.backend = FlashAttentionBackend(self.model_runner)
+        self.ref_backend = TorchNativeAttnBackend(self.model_runner)
+        self.model_runner.model_config.num_attention_heads = self.num_heads
+
+    def _mock_write_to_req_to_token_pool(self, batch_size, seq_len, page_size):
+        # if page_size > 1, the token pool stores the index to the page.
+        # so we need to multiply the index by page_size.
+        self.req_to_token = (
+            torch.arange(0, batch_size, dtype=torch.int32, device=self.device)[:, None]
+            * seq_len
+            + torch.arange(0, seq_len, dtype=torch.int32, device=self.device)[None, :]
+            + page_size
+        )
+        self.model_runner.req_to_token_pool.req_to_token[:batch_size, :seq_len] = (
+            self.req_to_token
+        )
+
     def _create_attention_layer(self):
-        """Helper method to create an attention layer."""
+        """Create attention layer for testing."""
         return RadixAttention(
             num_heads=self.num_heads,
             head_dim=self.head_dim,
@@ -64,47 +113,27 @@ class TestFlashAttentionBackend(CustomTestCase):
             layer_id=0,
         )
 
-    def _create_kv_pool(self, size):
-        """Helper method to create a KV pool."""
-        return MHATokenToKVPool(
-            size=size,
-            page_size=1,  # only consider page=1 for unit test
-            dtype=self.dtype,
-            head_num=self.num_heads,
-            head_dim=self.head_dim,
-            layer_num=1,  # only consider layer=1 for unit test
-            device=self.device,
-            enable_memory_saver=False,
-        )
-
     def _create_qkv_tensors(self, tokens_len):
-        """Helper method to create q, k, v tensors."""
+        """Create q, k, v tensors for testing."""
+        shape = (tokens_len, self.num_heads, self.head_dim)
         return (
-            torch.randn(
-                tokens_len,
-                self.num_heads,
-                self.head_dim,
-                dtype=self.dtype,
-                device=self.device,
-            ),
-            torch.randn(
-                tokens_len,
-                self.num_heads,
-                self.head_dim,
-                dtype=self.dtype,
-                device=self.device,
-            ),
-            torch.randn(
-                tokens_len,
-                self.num_heads,
-                self.head_dim,
-                dtype=self.dtype,
-                device=self.device,
-            ),
+            torch.randn(shape, dtype=self.dtype, device=self.device),
+            torch.randn(shape, dtype=self.dtype, device=self.device),
+            torch.randn(shape, dtype=self.dtype, device=self.device),
         )
 
-    def _verify_output(self, output, expected_shape):
-        """Helper method to verify output."""
+    def _run_reference_forward(
+        self, mode, q, k, v, layer, forward_batch, expected_shape
+    ):
+        """Run reference forward pass using native backend."""
+        if mode == ForwardMode.EXTEND:
+            output = self.ref_backend.forward_extend(q, k, v, layer, forward_batch)
+        else:  # ForwardMode.DECODE
+            output = self.ref_backend.forward_decode(q, k, v, layer, forward_batch)
+        return output.view(expected_shape)
+
+    def _verify_output(self, output, expected_shape, output_ref=None):
+        """Verify output tensor shape, dtype, and values."""
         self.assertEqual(
             output.shape,
             expected_shape,
@@ -116,161 +145,110 @@ class TestFlashAttentionBackend(CustomTestCase):
             torch.isnan(output).sum().item(), 0, "Output contains NaN values"
         )
 
-    def test_forward_extend(self):
-        """Test the standard extend operation."""
-        # Create test inputs
-        q, k, v = self._create_qkv_tensors(self.batch_size * self.seq_len)
-
-        # Create attention layer
-        layer = self._create_attention_layer()
-
-        # Create forward batch
-        forward_batch = ForwardBatch(
-            batch_size=self.batch_size,
-            input_ids=torch.randint(
-                0, 100, (self.batch_size, self.seq_len), device=self.device
-            ),
-            out_cache_loc=torch.arange(
-                self.batch_size * self.seq_len, device=self.device
-            ),
-            seq_lens_sum=self.batch_size * self.seq_len,
-            forward_mode=ForwardMode.EXTEND,
-            req_pool_indices=torch.arange(self.batch_size, device=self.device),
-            seq_lens=torch.tensor([self.seq_len] * self.batch_size, device=self.device),
-            # 0 prefix, 4 extend
-            extend_prefix_lens=torch.tensor([0] * self.batch_size, device=self.device),
-            extend_seq_lens=torch.tensor([4] * self.batch_size, device=self.device),
-            attn_backend=self.backend,
-        )
-
-        # Add token pool and KV cache
-        forward_batch.req_to_token_pool = MockReqToTokenPool(
-            self.batch_size, self.seq_len, self.device
-        )
-        forward_batch.token_to_kv_pool = self._create_kv_pool(
-            self.batch_size * self.seq_len
-        )
-
-        # Initialize forward metadata before running the attention
-        self.backend.init_forward_metadata(forward_batch)
-
-        # Run forward_extend
-        output = self.backend.forward_extend(q, k, v, layer, forward_batch)
-
-        # Verify output
-        expected_shape = (
-            self.batch_size * self.seq_len,
-            self.num_heads * self.head_dim,
-        )
-        self._verify_output(output, expected_shape)
-
-    def test_forward_decode(self):
-        """Test the decode operation with cached tokens."""
-        # For decode, we only have one token per sequence
-        decode_len = 1
-        curr_seq_len = self.seq_len + decode_len
-
-        # Create test inputs
-        q, k, v = self._create_qkv_tensors(self.batch_size * decode_len)
-
-        # Create attention layer
-        layer = self._create_attention_layer()
-
-        # Create forward batch
-        forward_batch = ForwardBatch(
-            batch_size=self.batch_size,
-            input_ids=torch.randint(
-                0, 100, (self.batch_size, decode_len), device=self.device
-            ),
-            out_cache_loc=torch.arange(
-                self.batch_size * self.seq_len,
-                self.batch_size * curr_seq_len,
-                device=self.device,
-            ),
-            seq_lens_sum=self.batch_size * curr_seq_len,
-            forward_mode=ForwardMode.DECODE,
-            req_pool_indices=torch.arange(self.batch_size, device=self.device),
-            seq_lens=torch.tensor([curr_seq_len] * self.batch_size, device=self.device),
-            attn_backend=self.backend,
-        )
-
-        # Add token pool and KV cache
-        forward_batch.req_to_token_pool = MockReqToTokenPool(
-            self.batch_size, curr_seq_len, self.device
-        )
-        forward_batch.token_to_kv_pool = self._create_kv_pool(
-            self.batch_size * curr_seq_len
-        )
-
-        # Pre-fill KV cache
-        cache_k, cache_v, _ = self._create_qkv_tensors(self.batch_size * self.seq_len)
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer,
-            torch.arange(self.batch_size * self.seq_len, device=self.device),
-            cache_k,
-            cache_v,
-            layer.k_scale,
-            layer.v_scale,
-        )
-
-        # Initialize forward metadata before running the attention
-        self.backend.init_forward_metadata(forward_batch)
-
-        # Run forward_decode
-        output = self.backend.forward_decode(q, k, v, layer, forward_batch)
-
-        # Verify output
-        expected_shape = (self.batch_size, self.num_heads * self.head_dim)
-        self._verify_output(output, expected_shape)
-
-    def test_forward_extend_with_prefix(self):
-        """Test extending from cached prefix tokens."""
-        # Define prefix and extend lengths
-        prefix_len = 2
-        extend_len = 2
-        total_len = prefix_len + extend_len
-
-        # Create test inputs for the extend portion
-        q, k, v = self._create_qkv_tensors(self.batch_size * extend_len)
+        if output_ref is not None:
+            if not torch.allclose(output, output_ref, atol=1e-1, rtol=0.0):
+                # Check where the values differ beyond the given tolerances
+                diff_mask = ~torch.isclose(output, output_ref, atol=1e-1, rtol=0.0)
+
+                # Find the first index where the difference occurs
+                if diff_mask.any():
+                    first_mismatch_idx = diff_mask.nonzero()[0]
+                    print(
+                        "First mismatch at index:", tuple(first_mismatch_idx.tolist())
+                    )
+                    print("output:", output[tuple(first_mismatch_idx.tolist())])
+                    print("output_ref:", output_ref[tuple(first_mismatch_idx.tolist())])
+                raise AssertionError(
+                    "Attention output is not close to the torch native backend output"
+                )
+
+    def _create_forward_batch(self, mode, q_len=None, prefix_len=0, page_size=1):
+        """Create a forward batch for testing based on mode and lengths."""
+        self._init_model_runner(page_size=page_size)
+
+        # Default to self.seq_len if not specified
+        q_len = q_len or self.seq_len
+
+        if mode == ForwardMode.EXTEND:
+            total_len = prefix_len + q_len
+            out_cache_start = prefix_len * self.batch_size
+            out_cache_end = total_len * self.batch_size
+
+            forward_batch = ForwardBatch(
+                batch_size=self.batch_size,
+                input_ids=torch.randint(
+                    0, 100, (self.batch_size, q_len), device=self.device
+                ),
+                out_cache_loc=torch.arange(
+                    out_cache_start, out_cache_end, device=self.device
+                ),
+                seq_lens_sum=self.batch_size * total_len,
+                forward_mode=mode,
+                req_pool_indices=torch.arange(self.batch_size, device=self.device),
+                seq_lens=torch.tensor(
+                    [total_len] * self.batch_size, device=self.device
+                ),
+                seq_lens_cpu=torch.tensor([total_len] * self.batch_size, device="cpu"),
+                extend_prefix_lens=torch.tensor(
+                    [prefix_len] * self.batch_size, device=self.device
+                ),
+                extend_prefix_lens_cpu=torch.tensor(
+                    [prefix_len] * self.batch_size, device="cpu"
+                ),
+                extend_seq_lens=torch.tensor(
+                    [q_len] * self.batch_size, device=self.device
+                ),
+                extend_seq_lens_cpu=torch.tensor(
+                    [q_len] * self.batch_size, device="cpu"
+                ),
+                attn_backend=self.backend,
+            )
+        else:  # ForwardMode.DECODE
+            decode_len = q_len  # Assuming 1 for decode testing
+            total_len = self.seq_len + decode_len
+            if mode == ForwardMode.DECODE and page_size > 1:
+                # Get next page_size multiple of self.seq_len
+                out_cache_start = (
+                    self.batch_size * self.seq_len // page_size + 1
+                ) * page_size
+                # out_cache_end is the start of the next block
+                out_cache_end = out_cache_start + decode_len * page_size
+            else:
+                out_cache_start = self.batch_size * self.seq_len
+                out_cache_end = self.batch_size * total_len
+
+            forward_batch = ForwardBatch(
+                batch_size=self.batch_size,
+                input_ids=torch.randint(
+                    0, 100, (self.batch_size, decode_len), device=self.device
+                ),
+                out_cache_loc=torch.tensor(
+                    [out_cache_start, out_cache_end], device=self.device
+                ),
+                seq_lens_sum=self.batch_size * total_len,
+                forward_mode=mode,
+                req_pool_indices=torch.arange(self.batch_size, device=self.device),
+                seq_lens=torch.tensor(
+                    [total_len] * self.batch_size, device=self.device
+                ),
+                seq_lens_cpu=torch.tensor([total_len] * self.batch_size, device="cpu"),
+                attn_backend=self.backend,
+            )
 
-        # Create attention layer
-        layer = self._create_attention_layer()
+        # Add token pool
+        forward_batch.req_to_token_pool = self.model_runner.req_to_token_pool
 
-        # Create forward batch
-        forward_batch = ForwardBatch(
-            batch_size=self.batch_size,
-            input_ids=torch.randint(
-                0, 100, (self.batch_size, extend_len), device=self.device
-            ),
-            out_cache_loc=torch.arange(
-                self.batch_size * prefix_len,
-                self.batch_size * total_len,
-                device=self.device,
-            ),
-            seq_lens_sum=self.batch_size * total_len,
-            forward_mode=ForwardMode.EXTEND,
-            req_pool_indices=torch.arange(self.batch_size, device=self.device),
-            seq_lens=torch.tensor([total_len] * self.batch_size, device=self.device),
-            extend_prefix_lens=torch.tensor(
-                [prefix_len] * self.batch_size, device=self.device
-            ),
-            extend_seq_lens=torch.tensor(
-                [extend_len] * self.batch_size, device=self.device
-            ),
-            attn_backend=self.backend,
-        )
+        # Write current batch's req_to_token to req_to_token_pool
+        self._mock_write_to_req_to_token_pool(self.batch_size, total_len, page_size)
+        # Add kv pool for this forward batch
+        forward_batch.token_to_kv_pool = self.model_runner.token_to_kv_pool
 
-        # Add token pool and KV cache
-        forward_batch.req_to_token_pool = MockReqToTokenPool(
-            self.batch_size, total_len, self.device
-        )
-        forward_batch.token_to_kv_pool = self._create_kv_pool(
-            self.batch_size * total_len
-        )
+        return forward_batch
 
-        # Pre-fill the KV cache for prefix with known values
+    def _setup_kv_cache(self, forward_batch, layer, cache_len):
+        # Create constant values for the prefix cache for easy debugging
         cache_k = torch.ones(
-            self.batch_size * prefix_len,
+            self.batch_size * cache_len,
             self.num_heads,
             self.head_dim,
             dtype=self.dtype,
@@ -278,7 +256,7 @@ class TestFlashAttentionBackend(CustomTestCase):
         )
         cache_v = (
             torch.ones(
-                self.batch_size * prefix_len,
+                self.batch_size * cache_len,
                 self.num_heads,
                 self.head_dim,
                 dtype=self.dtype,
@@ -290,22 +268,82 @@ class TestFlashAttentionBackend(CustomTestCase):
         # Set the prefix KV cache
         forward_batch.token_to_kv_pool.set_kv_buffer(
             layer,
-            torch.arange(self.batch_size * prefix_len, device=self.device),
+            torch.arange(self.batch_size * cache_len, device=self.device),
             cache_k,
             cache_v,
             layer.k_scale,
             layer.v_scale,
         )
 
-        # Initialize forward metadata before running the attention
+    def _run_attention_test(self, mode, q_len, prefix_len=0, page_size=1):
+        """
+        Run an attention test with the specified parameters.
+        Args:
+            mode: ForwardMode.EXTEND or ForwardMode.DECODE
+            q_len: Length of the query sequence. For decode mode, q_len is 1.
+            prefix_len: Length of the prefix sequence for extend mode
+            page_size: Page size for the KV cache
+        """
+        layer = self._create_attention_layer()
+
+        # Create forward batch and set up
+        forward_batch = self._create_forward_batch(mode, q_len, prefix_len, page_size)
+
+        # Create QKV tensors for the input
+        q, k, v = self._create_qkv_tensors(self.batch_size * q_len)
+
+        # KV cache for prefixed extend is prefix_len
+        # KV cache for decode is same as seq_len
+        # No KV cache for extend without prefix
+        if mode == ForwardMode.EXTEND:
+            if prefix_len > 0:
+                self._setup_kv_cache(forward_batch, layer, prefix_len)
+        else:
+            self._setup_kv_cache(forward_batch, layer, self.seq_len)
+
         self.backend.init_forward_metadata(forward_batch)
 
-        # Run forward_extend
-        output = self.backend.forward_extend(q, k, v, layer, forward_batch)
+        if mode == ForwardMode.EXTEND:
+            expected_shape = (
+                self.batch_size * q_len,
+                self.num_heads * self.head_dim,
+            )
+            output = self.backend.forward_extend(q, k, v, layer, forward_batch)
+        else:
+            expected_shape = (self.batch_size, self.num_heads * self.head_dim)
+            output = self.backend.forward_decode(q, k, v, layer, forward_batch)
+
+        output_ref = self._run_reference_forward(
+            mode, q, k, v, layer, forward_batch, expected_shape
+        )
+
+        self._verify_output(output, expected_shape, output_ref)
+
+        return output
+
+    def test_forward_extend(self):
+        """Test the standard extend operation."""
+        self._run_attention_test(ForwardMode.EXTEND, q_len=self.seq_len)
+
+    def test_forward_decode(self):
+        """Test the decode operation with cached tokens."""
+        self._run_attention_test(ForwardMode.DECODE, q_len=1)
+
+    def test_forward_extend_with_prefix(self):
+        """Test extending from cached prefix tokens."""
+        prefix_len = self.seq_len // 2
+        extend_len = self.seq_len - prefix_len
+        self._run_attention_test(
+            ForwardMode.EXTEND, q_len=extend_len, prefix_len=prefix_len
+        )
+
+    def test_forward_extend_with_page_size_greater_than_1(self):
+        """Test extending from cached prefix tokens with page size greater than 1."""
+        self._run_attention_test(ForwardMode.EXTEND, q_len=self.seq_len, page_size=64)
 
-        # Verify output
-        expected_shape = (self.batch_size * extend_len, self.num_heads * self.head_dim)
-        self._verify_output(output, expected_shape)
+    def test_forward_decode_with_page_size_greater_than_1(self):
+        """Test decode operation with page size greater than 1."""
+        self._run_attention_test(ForwardMode.DECODE, q_len=1, page_size=64)
 
 
 if __name__ == "__main__":