sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +23 -2
  3. sglang/bench_serving.py +6 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/configs/model_config.py +37 -5
  12. sglang/srt/constrained/base_grammar_backend.py +26 -5
  13. sglang/srt/constrained/llguidance_backend.py +1 -0
  14. sglang/srt/constrained/outlines_backend.py +1 -0
  15. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  16. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  17. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  18. sglang/srt/constrained/xgrammar_backend.py +27 -4
  19. sglang/srt/custom_op.py +0 -62
  20. sglang/srt/disaggregation/base/__init__.py +8 -0
  21. sglang/srt/disaggregation/base/conn.py +113 -0
  22. sglang/srt/disaggregation/decode.py +80 -11
  23. sglang/srt/disaggregation/mini_lb.py +58 -123
  24. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  25. sglang/srt/disaggregation/mooncake/conn.py +585 -0
  26. sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
  27. sglang/srt/disaggregation/prefill.py +82 -22
  28. sglang/srt/disaggregation/utils.py +46 -0
  29. sglang/srt/entrypoints/EngineBase.py +53 -0
  30. sglang/srt/entrypoints/engine.py +36 -8
  31. sglang/srt/entrypoints/http_server.py +37 -8
  32. sglang/srt/entrypoints/http_server_engine.py +142 -0
  33. sglang/srt/entrypoints/verl_engine.py +42 -13
  34. sglang/srt/hf_transformers_utils.py +4 -0
  35. sglang/srt/layers/activation.py +6 -8
  36. sglang/srt/layers/attention/flashattention_backend.py +430 -257
  37. sglang/srt/layers/attention/flashinfer_backend.py +18 -9
  38. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  39. sglang/srt/layers/attention/triton_backend.py +6 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  41. sglang/srt/layers/attention/vision.py +1 -1
  42. sglang/srt/layers/dp_attention.py +2 -4
  43. sglang/srt/layers/elementwise.py +15 -2
  44. sglang/srt/layers/layernorm.py +1 -1
  45. sglang/srt/layers/linear.py +18 -3
  46. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  48. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +34 -34
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
  56. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  57. sglang/srt/layers/moe/router.py +7 -1
  58. sglang/srt/layers/moe/topk.py +63 -45
  59. sglang/srt/layers/parameter.py +0 -2
  60. sglang/srt/layers/quantization/__init__.py +13 -5
  61. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  62. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
  64. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  65. sglang/srt/layers/quantization/fp8.py +131 -136
  66. sglang/srt/layers/quantization/fp8_kernel.py +328 -46
  67. sglang/srt/layers/quantization/fp8_utils.py +206 -253
  68. sglang/srt/layers/quantization/kv_cache.py +43 -52
  69. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  70. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  71. sglang/srt/layers/quantization/utils.py +5 -11
  72. sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
  73. sglang/srt/layers/quantization/w8a8_int8.py +8 -7
  74. sglang/srt/layers/radix_attention.py +28 -1
  75. sglang/srt/layers/rotary_embedding.py +15 -3
  76. sglang/srt/layers/sampler.py +5 -10
  77. sglang/srt/lora/backend/base_backend.py +18 -2
  78. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  79. sglang/srt/lora/backend/triton_backend.py +1 -1
  80. sglang/srt/lora/layers.py +1 -1
  81. sglang/srt/lora/lora.py +1 -1
  82. sglang/srt/lora/lora_manager.py +1 -1
  83. sglang/srt/managers/detokenizer_manager.py +0 -1
  84. sglang/srt/managers/io_struct.py +255 -97
  85. sglang/srt/managers/mm_utils.py +7 -5
  86. sglang/srt/managers/multimodal_processor.py +0 -2
  87. sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
  88. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  89. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  90. sglang/srt/managers/schedule_batch.py +64 -25
  91. sglang/srt/managers/scheduler.py +80 -82
  92. sglang/srt/managers/tokenizer_manager.py +18 -3
  93. sglang/srt/managers/tp_worker.py +1 -0
  94. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  95. sglang/srt/mem_cache/memory_pool.py +21 -3
  96. sglang/srt/metrics/collector.py +9 -0
  97. sglang/srt/model_executor/cuda_graph_runner.py +9 -6
  98. sglang/srt/model_executor/forward_batch_info.py +234 -15
  99. sglang/srt/model_executor/model_runner.py +67 -35
  100. sglang/srt/model_loader/loader.py +31 -4
  101. sglang/srt/model_loader/weight_utils.py +4 -2
  102. sglang/srt/models/baichuan.py +2 -0
  103. sglang/srt/models/bert.py +398 -0
  104. sglang/srt/models/chatglm.py +1 -0
  105. sglang/srt/models/commandr.py +1 -0
  106. sglang/srt/models/dbrx.py +1 -0
  107. sglang/srt/models/deepseek.py +2 -1
  108. sglang/srt/models/deepseek_nextn.py +74 -70
  109. sglang/srt/models/deepseek_v2.py +494 -366
  110. sglang/srt/models/exaone.py +1 -0
  111. sglang/srt/models/gemma.py +1 -0
  112. sglang/srt/models/gemma2.py +1 -0
  113. sglang/srt/models/gemma3_causal.py +1 -0
  114. sglang/srt/models/gpt2.py +1 -0
  115. sglang/srt/models/gpt_bigcode.py +1 -0
  116. sglang/srt/models/granite.py +1 -0
  117. sglang/srt/models/grok.py +1 -0
  118. sglang/srt/models/internlm2.py +1 -0
  119. sglang/srt/models/llama.py +6 -5
  120. sglang/srt/models/llama4.py +101 -34
  121. sglang/srt/models/minicpm.py +1 -0
  122. sglang/srt/models/minicpm3.py +30 -200
  123. sglang/srt/models/mixtral.py +1 -0
  124. sglang/srt/models/mixtral_quant.py +1 -0
  125. sglang/srt/models/mllama.py +51 -8
  126. sglang/srt/models/mllama4.py +102 -29
  127. sglang/srt/models/olmo.py +1 -0
  128. sglang/srt/models/olmo2.py +1 -0
  129. sglang/srt/models/olmoe.py +1 -0
  130. sglang/srt/models/phi3_small.py +1 -0
  131. sglang/srt/models/qwen.py +1 -0
  132. sglang/srt/models/qwen2.py +5 -1
  133. sglang/srt/models/qwen2_5_vl.py +35 -70
  134. sglang/srt/models/qwen2_moe.py +15 -13
  135. sglang/srt/models/qwen2_vl.py +27 -25
  136. sglang/srt/models/qwen3.py +335 -0
  137. sglang/srt/models/qwen3_moe.py +423 -0
  138. sglang/srt/models/stablelm.py +1 -0
  139. sglang/srt/models/xverse.py +1 -0
  140. sglang/srt/models/xverse_moe.py +1 -0
  141. sglang/srt/openai_api/adapter.py +4 -1
  142. sglang/srt/patch_torch.py +11 -0
  143. sglang/srt/reasoning_parser.py +0 -1
  144. sglang/srt/sampling/sampling_batch_info.py +2 -3
  145. sglang/srt/server_args.py +55 -19
  146. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  147. sglang/srt/speculative/eagle_utils.py +1 -11
  148. sglang/srt/speculative/eagle_worker.py +10 -9
  149. sglang/srt/utils.py +136 -10
  150. sglang/test/attention/test_flashattn_backend.py +259 -221
  151. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  152. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  153. sglang/test/runners.py +5 -1
  154. sglang/test/test_block_fp8.py +224 -0
  155. sglang/test/test_custom_ops.py +1 -1
  156. sglang/test/test_utils.py +19 -8
  157. sglang/version.py +1 -1
  158. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
  159. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
  160. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  161. sglang/lang/__init__.py +0 -0
  162. sglang/srt/disaggregation/conn.py +0 -81
  163. sglang/srt/lora/backend/__init__.py +0 -25
  164. sglang/srt/server.py +0 -18
  165. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  166. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/test/attention/test_flashattn_mla_backend.py ADDED
@@ -0,0 +1,285 @@
+ import unittest
+
+ import torch
+
+ from sglang.srt.configs.model_config import AttentionArch
+ from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+ from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+ from sglang.test.test_utils import CustomTestCase
+
+
+ class MockModelRunner:
+     def __init__(
+         self,
+         kv_lora_rank,
+         qk_rope_head_dim,
+     ):
+         attention_arch = AttentionArch.MLA
+         self.device = "cuda"
+         self.dtype = torch.float16
+         context_len = 2048
+         self.model_config = type(
+             "ModelConfig",
+             (),
+             {
+                 "context_len": context_len,
+                 "attention_arch": attention_arch,
+             },
+         )
+         self.sliding_window_size = None
+
+         batch_size = 160
+         # Create a proper req_to_token_pool with the req_to_token attribute
+         self.req_to_token_pool = type(
+             "TokenPool",
+             (),
+             {
+                 # A typical max_bs * max_context_len for cuda graph decode
+                 "size": batch_size,
+                 # Add req_to_token attribute
+                 "req_to_token": torch.zeros(
+                     batch_size, context_len, dtype=torch.int32, device=self.device
+                 ),
+             },
+         )
+         self.page_size = 1
+         max_total_num_tokens = batch_size * context_len
+         self.token_to_kv_pool = MLATokenToKVPool(
+             size=max_total_num_tokens,
+             page_size=self.page_size,
+             dtype=self.dtype,
+             kv_lora_rank=kv_lora_rank,
+             qk_rope_head_dim=qk_rope_head_dim,
+             layer_num=1,  # only consider layer=1 for unit test
+             device=self.device,
+             enable_memory_saver=False,
+         )
+
+
+ class MockReqToTokenPool:
+     def __init__(self, batch_size, seq_len, device):
+         self.req_to_token = (
+             torch.arange(batch_size * seq_len, device=device)
+             .reshape(batch_size, seq_len)
+             .to(torch.int32)
+         )
+
+
+ @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
+ class TestFlashAttentionMLABackend(CustomTestCase):
+     def setUp(self):
+         # Test parameters
+         self.batch_size = 2
+         self.seq_len = 360
+         self.num_heads = 2
+         self.device = "cuda"
+         self.dtype = torch.float16
+         self.kv_lora_rank = 512
+         self.q_lora_rank = 128
+         self.qk_rope_head_dim = 64
+         self.qk_head_dim = self.qk_rope_head_dim + self.kv_lora_rank
+         # Assume no rope scaling
+         self.scaling = self.qk_head_dim**-0.5
+         # Initialize model runner and backend
+         self._init_model_runner()
+         self.backend = FlashAttentionBackend(self.model_runner)
+         self.num_local_heads = 2
+
+     def _init_model_runner(self):
+         self.model_runner = MockModelRunner(
+             kv_lora_rank=self.kv_lora_rank,
+             qk_rope_head_dim=self.qk_rope_head_dim,
+         )
+         self.backend = FlashAttentionBackend(self.model_runner)
+
+     def _create_attention_layer(self):
+         """Create attention layer for testing."""
+         self.attn_mqa = RadixAttention(
+             num_heads=self.num_local_heads,
+             head_dim=self.kv_lora_rank + self.qk_rope_head_dim,
+             scaling=self.scaling,
+             num_kv_heads=1,
+             layer_id=0,
+             v_head_dim=self.kv_lora_rank,
+             prefix="attn_mqa",
+         )
+         return self.attn_mqa
+
+     def _run_reference_forward(
+         self, mode, q, k, v, layer, forward_batch, expected_shape
+     ):
+         """Run reference forward pass using native backend."""
+         if mode == ForwardMode.EXTEND:
+             output = self.ref_backend.forward_extend(q, k, v, layer, forward_batch)
+         else:  # ForwardMode.DECODE
+             output = self.ref_backend.forward_decode(q, k, v, layer, forward_batch)
+         return output.view(expected_shape)
+
+     def _verify_output(self, output, expected_shape):
+         """Verify output tensor shape, dtype, and values."""
+         self.assertEqual(
+             output.shape,
+             expected_shape,
+             f"Expected shape {expected_shape}, got {output.shape}",
+         )
+         self.assertEqual(output.dtype, self.dtype)
+         self.assertEqual(output.device.type, "cuda")
+         self.assertEqual(
+             torch.isnan(output).sum().item(), 0, "Output contains NaN values"
+         )
+
+     def _create_forward_batch(self, mode, q_len=None, prefix_len=0):
+         """Create a forward batch for testing based on mode and lengths."""
+         # Default to self.seq_len if not specified
+         q_len = q_len or self.seq_len
+
+         if mode == ForwardMode.EXTEND:
+             total_len = prefix_len + q_len
+             out_cache_start = prefix_len * self.batch_size
+             out_cache_end = total_len * self.batch_size
+
+             forward_batch = ForwardBatch(
+                 batch_size=self.batch_size,
+                 input_ids=torch.randint(
+                     0, 100, (self.batch_size, q_len), device=self.device
+                 ),
+                 out_cache_loc=torch.arange(
+                     out_cache_start, out_cache_end, device=self.device
+                 ),
+                 seq_lens_sum=self.batch_size * total_len,
+                 forward_mode=mode,
+                 req_pool_indices=torch.arange(self.batch_size, device=self.device),
+                 seq_lens=torch.tensor(
+                     [total_len] * self.batch_size, device=self.device
+                 ),
+                 seq_lens_cpu=torch.tensor([total_len] * self.batch_size, device="cpu"),
+                 extend_prefix_lens=torch.tensor(
+                     [prefix_len] * self.batch_size, device=self.device
+                 ),
+                 extend_prefix_lens_cpu=torch.tensor(
+                     [prefix_len] * self.batch_size, device="cpu"
+                 ),
+                 extend_seq_lens=torch.tensor(
+                     [q_len] * self.batch_size, device=self.device
+                 ),
+                 extend_seq_lens_cpu=torch.tensor(
+                     [q_len] * self.batch_size, device="cpu"
+                 ),
+                 attn_backend=self.backend,
+             )
+
+         else:  # ForwardMode.DECODE
+             decode_len = q_len  # typically 1 for decode mode
+             total_len = self.seq_len + decode_len
+             out_cache_start = self.batch_size * self.seq_len
+             out_cache_end = self.batch_size * total_len
+
+             forward_batch = ForwardBatch(
+                 batch_size=self.batch_size,
+                 input_ids=torch.randint(
+                     0, 100, (self.batch_size, decode_len), device=self.device
+                 ),
+                 out_cache_loc=torch.arange(
+                     out_cache_start, out_cache_end, device=self.device
+                 ),
+                 seq_lens_sum=self.batch_size * total_len,
+                 forward_mode=mode,
+                 req_pool_indices=torch.arange(self.batch_size, device=self.device),
+                 seq_lens=torch.tensor(
+                     [total_len] * self.batch_size, device=self.device
+                 ),
+                 seq_lens_cpu=torch.tensor([total_len] * self.batch_size, device="cpu"),
+                 attn_backend=self.backend,
+             )
+
+         # Add token pool from model runner to forward batch
+         forward_batch.req_to_token_pool = self.model_runner.req_to_token_pool
+
+         # Add KV cache from model runner to forward batch
+         forward_batch.token_to_kv_pool = self.model_runner.token_to_kv_pool
+
+         return forward_batch
+
+     def _setup_kv_cache(self, forward_batch, layer, cache_len):
+         """Set up KV cache with prefix tokens."""
+         if cache_len <= 0:
+             return
+
+         # Create constant values for the prefix cache for easy debugging
+         latent_cache = torch.ones(
+             self.batch_size * cache_len,
+             1,  # latent cache has only one head in MQA
+             self.kv_lora_rank + self.qk_rope_head_dim,
+             dtype=self.dtype,
+             device=self.device,
+         )
+
+         # Set the prefix KV cache
+         forward_batch.token_to_kv_pool.set_kv_buffer(
+             layer,
+             torch.arange(self.batch_size * cache_len, device=self.device),
+             latent_cache,
+             None,
+         )
+
+     def _run_attention_test(self, mode, q_len, prefix_len=0):
+         """
+         Run an attention test with the specified parameters.
+         Args:
+             mode: ForwardMode.EXTEND or ForwardMode.DECODE
+             q_len: Length of the query sequence. For decode mode, q_len is 1.
+             prefix_len: Length of the prefix sequence for extend mode
+         """
+         layer = self._create_attention_layer()
+
+         # Create forward batch and set up
+         forward_batch = self._create_forward_batch(mode, q_len, prefix_len)
+
+         # Create q, kv_compressed for testing
+         q_shape = (self.batch_size * q_len, self.num_heads, self.qk_head_dim)
+         kv_shape = (self.batch_size * q_len, self.qk_head_dim)
+         q = torch.randn(q_shape, dtype=self.dtype, device=self.device)
+         kv_compressed = torch.randn(kv_shape, dtype=self.dtype, device=self.device)
+         # v is not used for mqa, all values passed in through k
+         k = kv_compressed.unsqueeze(1)
+         v = torch.randn((1), dtype=self.dtype, device=self.device)
+
+         self._setup_kv_cache(forward_batch, layer, prefix_len)
+
+         self.backend.init_forward_metadata(forward_batch)
+
+         expected_shape = (
+             self.batch_size * q_len,
+             self.num_heads * self.kv_lora_rank,
+         )
+
+         if mode == ForwardMode.EXTEND:
+             output = self.backend.forward_extend(q, k, v, layer, forward_batch)
+         else:
+             output = self.backend.forward_decode(q, k, v, layer, forward_batch)
+
+         self._verify_output(output, expected_shape)
+         return output
+
+     def test_forward_extend(self):
+         """Test the standard extend operation."""
+         self._run_attention_test(ForwardMode.EXTEND, q_len=self.seq_len)
+
+     def test_forward_decode(self):
+         """Test the decode operation with cached tokens."""
+         self._run_attention_test(ForwardMode.DECODE, q_len=1)
+
+     def test_forward_extend_with_prefix(self):
+         """Test extending from cached prefix tokens."""
+         prefix_len = self.seq_len // 2
+         extend_len = self.seq_len - prefix_len
+         self._run_attention_test(
+             ForwardMode.EXTEND, q_len=extend_len, prefix_len=prefix_len
+         )
+
+
+ if __name__ == "__main__":
+     unittest.main()
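
A quick way to exercise this new MLA FlashAttention test locally is the standard unittest runner. This is a sketch, not part of the diff; it assumes the 0.4.5.post2 wheel is installed, the dotted module path implied by the file list imports cleanly, and a CUDA device is present (the test class skips itself otherwise).

# Sketch: load and run the added MLA FlashAttention backend tests.
# Assumes "sglang.test.attention.test_flashattn_mla_backend" matches the
# file listed above and that CUDA is available.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName(
    "sglang.test.attention.test_flashattn_mla_backend"
)
unittest.TextTestRunner(verbosity=2).run(suite)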
sglang/test/attention/test_prefix_chunk_info.py ADDED
@@ -0,0 +1,224 @@
+ import unittest
+
+ import torch
+
+ from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+ from sglang.test.test_utils import CustomTestCase
+
+ TEST_CASES = [
+     # Sequence with same prefix lens
+     {
+         "batch_size": 3,
+         "prefix_lens": [64, 64, 64],
+         "max_chunk_capacity": 48,
+         "prefix_chunk_len": 16,
+         "num_prefix_chunks": 4,
+         "prefix_chunk_starts": torch.tensor(
+             [
+                 [0, 0, 0],
+                 [16, 16, 16],
+                 [32, 32, 32],
+                 [48, 48, 48],
+             ],
+             dtype=torch.int32,
+         ),
+         "prefix_chunk_seq_lens": torch.tensor(
+             [
+                 [16, 16, 16],
+                 [16, 16, 16],
+                 [16, 16, 16],
+                 [16, 16, 16],
+             ],
+             dtype=torch.int32,
+         ),
+     },
+     # Sequence with different prefix lens
+     {
+         "batch_size": 4,
+         "prefix_lens": [16, 32, 48, 64],
+         "max_chunk_capacity": 64,
+         "prefix_chunk_len": 16,
+         "num_prefix_chunks": 4,
+         "prefix_chunk_starts": torch.tensor(
+             [
+                 [0, 0, 0, 0],
+                 [16, 16, 16, 16],
+                 [32, 32, 32, 32],
+                 [48, 48, 48, 48],
+             ],
+             dtype=torch.int32,
+         ),
+         "prefix_chunk_seq_lens": torch.tensor(
+             [
+                 [16, 16, 16, 16],
+                 [0, 16, 16, 16],
+                 [0, 0, 16, 16],
+                 [0, 0, 0, 16],
+             ],
+             dtype=torch.int32,
+         ),
+     },
+     # Sequence with irregular shapes
+     {
+         "batch_size": 2,
+         "prefix_lens": [1, 64],
+         "max_chunk_capacity": 31,
+         "prefix_chunk_len": 15,
+         "num_prefix_chunks": 5,
+         "prefix_chunk_starts": torch.tensor(
+             [
+                 [0, 0],
+                 [15, 15],
+                 [30, 30],
+                 [45, 45],
+                 [60, 60],
+             ],
+             dtype=torch.int32,
+         ),
+         "prefix_chunk_seq_lens": torch.tensor(
+             [
+                 [1, 15],
+                 [0, 15],
+                 [0, 15],
+                 [0, 15],
+                 [0, 4],
+             ],
+             dtype=torch.int32,
+         ),
+     },
+ ]
+
+
+ class MockForwardBatch(ForwardBatch):
+     def __init__(self, max_chunk_capacity: int, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.max_chunk_capacity = max_chunk_capacity
+
+     def get_max_chunk_capacity(self):
+         return self.max_chunk_capacity
+
+
+ class MockReqToTokenPool:
+     def __init__(self, batch_size, seq_len, device):
+         self.req_to_token = (
+             torch.arange(batch_size * seq_len, device=device)
+             .reshape(batch_size, seq_len)
+             .to(torch.int32)
+         )
+
+
+ # Test correctness of triton kernel for computing kv indices
+ def check_kv_indices(forward_batch):
+     for i in range(forward_batch.num_prefix_chunks):
+         computed_kv_indices = forward_batch.prefix_chunk_kv_indices[i]
+         req_to_token = forward_batch.req_to_token_pool.req_to_token[
+             : forward_batch.batch_size, :
+         ]
+         ref_kv_indices = torch.empty(
+             forward_batch.prefix_chunk_num_tokens[i],
+             dtype=torch.int32,
+             device=computed_kv_indices.device,
+         )
+         running_ptr = 0
+         for j in range(forward_batch.batch_size):
+             seq_start = forward_batch.prefix_chunk_starts[i, j].item()
+             seq_len = forward_batch.prefix_chunk_seq_lens[i, j].item()
+             ref_kv_indices[running_ptr : running_ptr + seq_len].copy_(
+                 req_to_token[j, seq_start : seq_start + seq_len]
+             )
+             running_ptr += seq_len
+         assert torch.allclose(computed_kv_indices, ref_kv_indices)
+
+
+ @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
+ class TestPrefixChunkInfo(CustomTestCase):
+     def setUp(self):
+         # Common test parameters
+         self.num_local_heads = 128
+         self.kv_lora_rank = 512
+         self.qk_rope_head_dim = 64
+         self.device = torch.device("cuda")
+         self.dtype = torch.bfloat16
+         self.extend_len = 64
+         self.max_bs = 4
+         self.max_seq_len = 128
+
+         # req_to_token_pool
+         self.req_to_token_pool = MockReqToTokenPool(
+             self.max_bs,
+             self.max_seq_len,
+             self.device,
+         )
+
+         # token_to_kv_pool
+         self.token_to_kv_pool = MLATokenToKVPool(
+             size=self.max_bs * self.max_seq_len,
+             page_size=1,  # only consider page=1 for unit test
+             dtype=self.dtype,
+             kv_lora_rank=self.kv_lora_rank,
+             qk_rope_head_dim=self.qk_rope_head_dim,
+             layer_num=1,  # only consider layer=1 for unit test
+             device=self.device,
+             enable_memory_saver=False,
+         )
+
+     def test_prefix_chunk_info(self):
+         """Test the standard extend operation."""
+
+         for test_case in TEST_CASES:
+             print(
+                 f"Test case with batch_size={test_case['batch_size']}, prefix_lens={test_case['prefix_lens']}, max_chunk_capacity={test_case['max_chunk_capacity']}"
+             )
+             batch_size = test_case["batch_size"]
+             prefix_lens_cpu = test_case["prefix_lens"]
+             assert len(prefix_lens_cpu) == batch_size
+             prefix_lens = torch.tensor(prefix_lens_cpu, device=self.device)
+             max_chunk_capacity = test_case["max_chunk_capacity"]
+             seq_lens_cpu = [
+                 self.extend_len + prefix_lens_cpu[i] for i in range(batch_size)
+             ]
+             seq_lens = torch.tensor(seq_lens_cpu, device=self.device)
+
+             # Create forward batch
+             # input_ids and out_cache_loc are dummy tensors in this test
+             forward_batch = MockForwardBatch(
+                 max_chunk_capacity=max_chunk_capacity,
+                 batch_size=batch_size,
+                 input_ids=torch.randint(
+                     0, 100, (batch_size, self.extend_len), device=self.device
+                 ),
+                 out_cache_loc=torch.arange(
+                     self.max_bs * self.max_seq_len - batch_size * self.extend_len,
+                     self.max_bs * self.max_seq_len,
+                     device=self.device,
+                 ),
+                 seq_lens_sum=sum(seq_lens_cpu),
+                 forward_mode=ForwardMode.EXTEND,
+                 req_pool_indices=torch.arange(batch_size, device=self.device),
+                 seq_lens=seq_lens,
+                 seq_lens_cpu=seq_lens_cpu,
+                 extend_prefix_lens=prefix_lens,
+                 extend_prefix_lens_cpu=prefix_lens_cpu,
+             )
+             forward_batch.req_to_token_pool = self.req_to_token_pool
+             forward_batch.token_to_kv_pool = self.token_to_kv_pool
+
+             forward_batch.prepare_chunked_prefix_cache_info(self.device)
+             assert forward_batch.get_max_chunk_capacity() == max_chunk_capacity
+             assert forward_batch.prefix_chunk_len == test_case["prefix_chunk_len"]
+             assert forward_batch.num_prefix_chunks == test_case["num_prefix_chunks"]
+             assert torch.allclose(
+                 forward_batch.prefix_chunk_starts,
+                 test_case["prefix_chunk_starts"].to(self.device),
+             )
+             assert torch.allclose(
+                 forward_batch.prefix_chunk_seq_lens,
+                 test_case["prefix_chunk_seq_lens"].to(self.device),
+             )
+
+             check_kv_indices(forward_batch)
+
+
+ if __name__ == "__main__":
+     unittest.main()
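
The TEST_CASES table in this file pins down the chunk layout that prepare_chunked_prefix_cache_info is expected to produce. The helper below is a minimal reference sketch of that relationship, derived only from the expected values in the test cases; it is illustrative and is not the Triton kernel or the actual ForwardBatch implementation.

# Sketch: reproduce the chunk layout asserted in TEST_CASES.
# Each chunk holds at most max_chunk_capacity prefix tokens across the batch,
# so the per-sequence chunk length is capacity // batch_size and all sequences
# advance through their prefixes in lockstep.
import math
from typing import List, Tuple

import torch


def reference_prefix_chunks(
    prefix_lens: List[int], max_chunk_capacity: int
) -> Tuple[int, int, torch.Tensor, torch.Tensor]:
    batch_size = len(prefix_lens)
    prefix_chunk_len = max_chunk_capacity // batch_size
    num_prefix_chunks = math.ceil(max(prefix_lens) / prefix_chunk_len)

    # Chunk i covers prefix positions [i * chunk_len, (i + 1) * chunk_len)
    # for every sequence; shorter prefixes contribute zero-length slices.
    starts = torch.tensor(
        [[i * prefix_chunk_len] * batch_size for i in range(num_prefix_chunks)],
        dtype=torch.int32,
    )
    seq_lens = torch.tensor(
        [
            [
                max(0, min(prefix_chunk_len, plen - i * prefix_chunk_len))
                for plen in prefix_lens
            ]
            for i in range(num_prefix_chunks)
        ],
        dtype=torch.int32,
    )
    return prefix_chunk_len, num_prefix_chunks, starts, seq_lens


# Matches the second test case: prefix_lens=[16, 32, 48, 64], capacity 64
# gives chunk length 16 and 4 chunks, with exhausted prefixes dropping to 0.
print(reference_prefix_chunks([16, 32, 48, 64], 64))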
sglang/test/runners.py CHANGED
@@ -26,8 +26,8 @@ from transformers import (
      AutoProcessor,
  )

+ from sglang.srt.entrypoints.engine import Engine
  from sglang.srt.hf_transformers_utils import get_tokenizer
- from sglang.srt.server import Engine
  from sglang.srt.utils import load_image
  from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l

@@ -51,6 +51,8 @@ NUM_TOP_LOGPROBS = 5
  def get_dtype_str(torch_dtype):
      if torch_dtype is torch.float16:
          return "float16"
+     if torch_dtype is torch.float32:
+         return "float32"
      else:
          raise NotImplementedError()

@@ -447,6 +449,7 @@ class SRTRunner:
          port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
          lora_paths: List[str] = None,
          max_loras_per_batch: int = 4,
+         attention_backend: Optional[str] = None,
          lora_backend: str = "triton",
          disable_cuda_graph: bool = False,
          disable_radix_cache: bool = False,
@@ -487,6 +490,7 @@ class SRTRunner:
              lora_paths=lora_paths,
              max_loras_per_batch=max_loras_per_batch,
              lora_backend=lora_backend,
+             attention_backend=attention_backend,
              disable_cuda_graph=disable_cuda_graph,
              disable_radix_cache=disable_radix_cache,
              chunked_prefill_size=chunked_prefill_size,
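
The runners.py hunk above reflects a broader relocation in this release: sglang/srt/server.py is removed (entry 164 in the file list) and Engine is now imported from sglang.srt.entrypoints.engine. The snippet below is a minimal offline-generation sketch against the new import path only; the model path and sampling parameters are placeholders and not taken from this diff.

# Illustrative sketch: exercises the relocated Engine import shown above.
# Model path and sampling parameters are assumptions, not part of the release.
from sglang.srt.entrypoints.engine import Engine

if __name__ == "__main__":
    engine = Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")
    output = engine.generate(
        prompt="The capital of France is",
        sampling_params={"temperature": 0.0, "max_new_tokens": 16},
    )
    print(output)
    engine.shutdown()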