sglang 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. sglang/bench_latency.py +10 -6
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +0 -4
  4. sglang/lang/backend/runtime_endpoint.py +5 -2
  5. sglang/lang/interpreter.py +1 -1
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +29 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +1 -3
  14. sglang/srt/layers/attention_backend.py +480 -0
  15. sglang/srt/layers/flashinfer_utils.py +235 -0
  16. sglang/srt/layers/logits_processor.py +64 -77
  17. sglang/srt/layers/radix_attention.py +11 -161
  18. sglang/srt/layers/sampler.py +6 -25
  19. sglang/srt/layers/torchao_utils.py +75 -0
  20. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  21. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  22. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  23. sglang/srt/lora/lora.py +403 -0
  24. sglang/srt/lora/lora_config.py +43 -0
  25. sglang/srt/lora/lora_manager.py +256 -0
  26. sglang/srt/managers/controller_multi.py +1 -5
  27. sglang/srt/managers/controller_single.py +0 -5
  28. sglang/srt/managers/io_struct.py +16 -1
  29. sglang/srt/managers/policy_scheduler.py +122 -5
  30. sglang/srt/managers/schedule_batch.py +104 -71
  31. sglang/srt/managers/tokenizer_manager.py +17 -8
  32. sglang/srt/managers/tp_worker.py +181 -115
  33. sglang/srt/model_executor/cuda_graph_runner.py +58 -133
  34. sglang/srt/model_executor/forward_batch_info.py +35 -312
  35. sglang/srt/model_executor/model_runner.py +117 -131
  36. sglang/srt/models/baichuan.py +416 -0
  37. sglang/srt/models/chatglm.py +1 -5
  38. sglang/srt/models/commandr.py +1 -5
  39. sglang/srt/models/dbrx.py +1 -5
  40. sglang/srt/models/deepseek.py +1 -5
  41. sglang/srt/models/deepseek_v2.py +1 -5
  42. sglang/srt/models/exaone.py +1 -5
  43. sglang/srt/models/gemma.py +1 -5
  44. sglang/srt/models/gemma2.py +1 -5
  45. sglang/srt/models/gpt_bigcode.py +1 -5
  46. sglang/srt/models/grok.py +1 -5
  47. sglang/srt/models/internlm2.py +1 -5
  48. sglang/srt/models/llama.py +51 -5
  49. sglang/srt/models/llama_classification.py +1 -20
  50. sglang/srt/models/llava.py +30 -5
  51. sglang/srt/models/llavavid.py +2 -2
  52. sglang/srt/models/minicpm.py +1 -5
  53. sglang/srt/models/minicpm3.py +665 -0
  54. sglang/srt/models/mixtral.py +6 -5
  55. sglang/srt/models/mixtral_quant.py +1 -5
  56. sglang/srt/models/qwen.py +1 -5
  57. sglang/srt/models/qwen2.py +1 -5
  58. sglang/srt/models/qwen2_moe.py +6 -5
  59. sglang/srt/models/stablelm.py +1 -5
  60. sglang/srt/models/xverse.py +375 -0
  61. sglang/srt/models/xverse_moe.py +445 -0
  62. sglang/srt/openai_api/adapter.py +65 -46
  63. sglang/srt/openai_api/protocol.py +11 -3
  64. sglang/srt/sampling/sampling_batch_info.py +57 -44
  65. sglang/srt/server.py +24 -14
  66. sglang/srt/server_args.py +130 -28
  67. sglang/srt/utils.py +12 -0
  68. sglang/test/few_shot_gsm8k.py +132 -0
  69. sglang/test/runners.py +114 -22
  70. sglang/test/test_programs.py +7 -5
  71. sglang/test/test_utils.py +85 -1
  72. sglang/utils.py +32 -37
  73. sglang/version.py +1 -1
  74. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/METADATA +30 -18
  75. sglang-0.3.1.dist-info/RECORD +129 -0
  76. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
  77. sglang-0.3.0.dist-info/RECORD +0 -118
  78. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
  79. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention_backend.py (new file)
@@ -0,0 +1,480 @@
+ from __future__ import annotations
+
+ """
+ Support different attention backends.
+ Now there are two backends: FlashInfer and Triton.
+ FlashInfer is faster and Triton is easier to customize.
+ Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
+ """
+
+ from abc import ABC, abstractmethod
+ from typing import TYPE_CHECKING
+
+ import torch
+ import torch.nn as nn
+ from flashinfer import (
+     BatchDecodeWithPagedKVCacheWrapper,
+     BatchPrefillWithPagedKVCacheWrapper,
+     BatchPrefillWithRaggedKVCacheWrapper,
+ )
+ from flashinfer.cascade import merge_state
+ from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+ from sglang.global_config import global_config
+ from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
+ from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
+ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+
+ if TYPE_CHECKING:
+     from sglang.srt.model_executor.model_runner import ModelRunner
+
+
+ class AttentionBackend(ABC):
+     """The base class of attention backends"""
+
+     @abstractmethod
+     def init_forward_metadata(
+         self, batch: ScheduleBatch, input_metadata: InputMetadata
+     ):
+         """Init the metadata for a forward pass."""
+         raise NotImplementedError()
+
+     def init_cuda_graph_state(self, max_bs: int):
+         """Init the global shared states for cuda graph."""
+         raise NotImplementedError()
+
+     def init_forward_metadata_capture_cuda_graph(
+         self, bs: int, req_pool_indices, seq_lens
+     ):
+         """Init the metadata for a forward pass for capturing a cuda graph."""
+         raise NotImplementedError()
+
+     def init_forward_metadata_replay_cuda_graph(
+         self, bs: int, req_pool_indices, seq_lens
+     ):
+         """Init the metadata for a forward pass for replaying a cuda graph."""
+         raise NotImplementedError()
+
+     def get_cuda_graph_seq_len_fill_value(self):
+         """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
+         raise NotImplementedError()
+
+     def forward(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+         """Run forward on an attention layer."""
+         if input_metadata.forward_mode.is_decode():
+             return self.forward_decode(q, k, v, layer, input_metadata)
+         else:
+             return self.forward_extend(q, k, v, layer, input_metadata)
+
+     def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+         """Run a forward for decode."""
+         raise NotImplementedError()
+
+     def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+         """Run a forward for extend."""
+         raise NotImplementedError()
+
+
+ class FlashInferAttnBackend(AttentionBackend):
+     """Flashinfer attention kernels."""
+
+     def __init__(self, model_runner: ModelRunner):
+         super().__init__()
+         self.model_runner = model_runner
+
+         local_num_qo_heads = (
+             model_runner.model_config.num_attention_heads // model_runner.tp_size
+         )
+         local_num_kv_heads = model_runner.model_config.get_num_kv_heads(
+             model_runner.tp_size
+         )
+         if (
+             not _grouped_size_compiled_for_decode_kernels(
+                 local_num_qo_heads, local_num_kv_heads
+             )
+             or local_num_qo_heads // local_num_kv_heads > 4
+         ):
+             self.decode_use_tensor_cores = True
+         else:
+             self.decode_use_tensor_cores = False
+
+         self.workspace_buffer = torch.empty(
+             global_config.flashinfer_workspace_size,
+             dtype=torch.uint8,
+             device="cuda",
+         )
+
+         if model_runner.sliding_window_size is None:
+             self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+                 self.workspace_buffer, "NHD"
+             )
+             self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+                 self.workspace_buffer, "NHD"
+             )
+             self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                 self.workspace_buffer,
+                 "NHD",
+                 use_tensor_cores=self.decode_use_tensor_cores,
+             )
+         else:
+             # Two wrappers: one for sliding window attention and one for full attention.
+             # Using two wrappers is unnecessary in the current PR, but they are prepared for future PRs.
+             self.prefill_wrapper_ragged = None
+             self.prefill_wrapper_paged = []
+             self.decode_wrapper = []
+             for _ in range(2):
+                 self.prefill_wrapper_paged.append(
+                     BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
+                 )
+                 self.decode_wrapper.append(
+                     BatchDecodeWithPagedKVCacheWrapper(
+                         self.workspace_buffer,
+                         "NHD",
+                         use_tensor_cores=self.decode_use_tensor_cores,
+                     )
+                 )
+
+         self.forward_metadata = None
+         self.cuda_graph_metadata = {}
+
+     def init_forward_metadata(
+         self, batch: ScheduleBatch, input_metadata: InputMetadata
+     ):
+         if input_metadata.forward_mode.is_decode():
+             prefix_lens = None
+             use_ragged = False
+             total_num_tokens = None
+         else:
+             prefix_lens = input_metadata.extend_prefix_lens
+
+             # Some heuristics to check whether to use ragged forward
+             use_ragged = False
+             if (
+                 int(torch.sum(input_metadata.seq_lens)) > 4096
+                 and self.model_runner.sliding_window_size is None
+             ):
+                 use_ragged = True
+
+             total_num_tokens = torch.sum(input_metadata.seq_lens).item()
+
+         update_flashinfer_indices(
+             input_metadata.forward_mode,
+             self.model_runner,
+             input_metadata.req_pool_indices,
+             input_metadata.seq_lens,
+             prefix_lens,
+             use_ragged=use_ragged,
+         )
+
+         self.forward_metadata = (use_ragged, total_num_tokens, self.decode_wrapper)
+
+     def init_cuda_graph_state(self, max_bs: int):
+         self.cuda_graph_kv_indptr = torch.zeros(
+             (max_bs + 1,), dtype=torch.int32, device="cuda"
+         )
+         self.cuda_graph_kv_indices = torch.zeros(
+             (max_bs * self.model_runner.model_config.context_len,),
+             dtype=torch.int32,
+             device="cuda",
+         )
+         self.cuda_graph_kv_last_page_len = torch.ones(
+             (max_bs,), dtype=torch.int32, device="cuda"
+         )
+
+         if self.model_runner.sliding_window_size is not None:
+             self.cuda_graph_kv_indptr = [
+                 self.cuda_graph_kv_indptr,
+                 self.cuda_graph_kv_indptr.clone(),
+             ]
+             self.cuda_graph_kv_indices = [
+                 self.cuda_graph_kv_indices,
+                 self.cuda_graph_kv_indices.clone(),
+             ]
+
+     def init_forward_metadata_capture_cuda_graph(
+         self, bs: int, req_pool_indices, seq_lens
+     ):
+         if self.model_runner.sliding_window_size is None:
+             decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                 self.workspace_buffer,
+                 "NHD",
+                 use_cuda_graph=True,
+                 use_tensor_cores=self.decode_use_tensor_cores,
+                 paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[: bs + 1],
+                 paged_kv_indices_buffer=self.cuda_graph_kv_indices,
+                 paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs],
+             )
+         else:
+             decode_wrapper = []
+             for i in range(2):
+                 decode_wrapper.append(
+                     BatchDecodeWithPagedKVCacheWrapper(
+                         self.workspace_buffer,
+                         "NHD",
+                         use_cuda_graph=True,
+                         use_tensor_cores=self.decode_use_tensor_cores,
+                         paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1],
+                         paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+                         paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[
+                             :bs
+                         ],
+                     )
+                 )
+
+         update_flashinfer_indices(
+             ForwardMode.DECODE,
+             self.model_runner,
+             req_pool_indices,
+             seq_lens,
+             None,
+             decode_wrapper,
+         )
+
+         self.cuda_graph_metadata[bs] = decode_wrapper
+
+         self.forward_metadata = (False, None, decode_wrapper)
+
+     def init_forward_metadata_replay_cuda_graph(
+         self, bs: int, req_pool_indices, seq_lens
+     ):
+         update_flashinfer_indices(
+             ForwardMode.DECODE,
+             self.model_runner,
+             req_pool_indices[:bs],
+             seq_lens[:bs],
+             None,
+             self.cuda_graph_metadata[bs],
+         )
+
+     def get_cuda_graph_seq_len_fill_value(self):
+         return 0
+
+     def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+         if not isinstance(self.prefill_wrapper_paged, list):
+             prefill_wrapper_paged = self.prefill_wrapper_paged
+         else:
+             if layer.sliding_window_size != -1:
+                 prefill_wrapper_paged = self.prefill_wrapper_paged[0]
+             else:
+                 prefill_wrapper_paged = self.prefill_wrapper_paged[1]
+
+         use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata
+
+         if not use_ragged:
+             if k is not None:
+                 assert v is not None
+                 input_metadata.token_to_kv_pool.set_kv_buffer(
+                     layer.layer_id, input_metadata.out_cache_loc, k, v
+                 )
+             o = prefill_wrapper_paged.forward(
+                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                 input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                 causal=True,
+                 sm_scale=layer.scaling,
+                 window_left=layer.sliding_window_size,
+                 logits_soft_cap=layer.logit_cap,
+             )
+         else:
+             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                 k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim),
+                 v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
+                 causal=True,
+                 sm_scale=layer.scaling,
+                 logits_soft_cap=layer.logit_cap,
+             )
+
+             if input_metadata.extend_no_prefix:
+                 o = o1
+             else:
+                 o2, s2 = prefill_wrapper_paged.forward_return_lse(
+                     q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                     input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                     causal=False,
+                     sm_scale=layer.scaling,
+                     logits_soft_cap=layer.logit_cap,
+                 )
+
+                 o, _ = merge_state(o1, s1, o2, s2)
+
+             input_metadata.token_to_kv_pool.set_kv_buffer(
+                 layer.layer_id, input_metadata.out_cache_loc, k, v
+             )
+
+             if total_num_tokens >= global_config.layer_sync_threshold:
+                 # TODO: Revisit this. Why is this synchronize needed?
+                 torch.cuda.synchronize()
+
+         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
+     def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+         use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata
+
+         if isinstance(decode_wrapper, list):
+             if layer.sliding_window_size != -1:
+                 decode_wrapper = decode_wrapper[0]
+             else:
+                 decode_wrapper = decode_wrapper[1]
+
+         if k is not None:
+             assert v is not None
+             input_metadata.token_to_kv_pool.set_kv_buffer(
+                 layer.layer_id, input_metadata.out_cache_loc, k, v
+             )
+
+         o = decode_wrapper.forward(
+             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+             input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+             sm_scale=layer.scaling,
+             logits_soft_cap=layer.logit_cap,
+         )
+
+         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
+
+ class TritonAttnBackend(AttentionBackend):
+     def __init__(self, model_runner: ModelRunner):
+         # Lazy import to avoid the initialization of cuda context
+         from sglang.srt.layers.triton_attention.decode_attention import (
+             decode_attention_fwd,
+         )
+         from sglang.srt.layers.triton_attention.extend_attention import (
+             extend_attention_fwd,
+         )
+
+         super().__init__()
+
+         self.decode_attention_fwd = decode_attention_fwd
+         self.extend_attention_fwd = extend_attention_fwd
+         self.num_head = model_runner.model_config.num_attention_heads
+
+         if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
+             self.reduce_dtype = torch.float32
+         else:
+             self.reduce_dtype = torch.float16
+
+         self.forward_metadata = None
+
+         self.cuda_graph_max_seq_len = model_runner.model_config.context_len
+
+     def init_forward_metadata(
+         self, batch: ScheduleBatch, input_metadata: InputMetadata
+     ):
+         """Init auxiliary variables for triton attention backend."""
+
+         if input_metadata.forward_mode.is_decode():
+             start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32)
+             start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0)
+
+             total_num_tokens = torch.sum(input_metadata.seq_lens).item()
+             attn_logits = torch.empty(
+                 (self.num_head, total_num_tokens),
+                 dtype=self.reduce_dtype,
+                 device="cuda",
+             )
+
+             max_seq_len = torch.max(input_metadata.seq_lens).item()
+             max_extend_len = None
+         else:
+             start_loc = attn_logits = max_seq_len = None
+             prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
+             max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()
+
+         self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
+
+     def init_cuda_graph_state(self, max_bs: int):
+         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
+
+         self.cuda_graph_start_loc = torch.zeros(
+             (max_bs,), dtype=torch.int32, device="cuda"
+         )
+         self.cuda_graph_attn_logits = torch.empty(
+             (
+                 self.num_head,
+                 self.cuda_graph_max_total_num_tokens,
+             ),
+             dtype=self.reduce_dtype,
+             device="cuda",
+         )
+
+     def init_forward_metadata_capture_cuda_graph(
+         self, bs: int, req_pool_indices, seq_lens
+     ):
+         self.forward_metadata = (
+             self.cuda_graph_start_loc,
+             self.cuda_graph_attn_logits,
+             self.cuda_graph_max_seq_len,
+             None,
+         )
+
+     def init_forward_metadata_replay_cuda_graph(
+         self, bs: int, req_pool_indices, seq_lens
+     ):
+         self.cuda_graph_start_loc.zero_()
+         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
+
+     def get_cuda_graph_seq_len_fill_value(self):
+         return 1
+
+     def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+         # TODO: reuse the buffer across layers
+         if layer.qk_head_dim != layer.v_head_dim:
+             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+         else:
+             o = torch.empty_like(q)
+
+         input_metadata.token_to_kv_pool.set_kv_buffer(
+             layer.layer_id, input_metadata.out_cache_loc, k, v
+         )
+
+         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
+         self.extend_attention_fwd(
+             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+             k.contiguous(),
+             v.contiguous(),
+             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+             input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
+             input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
+             input_metadata.req_to_token_pool.req_to_token,
+             input_metadata.req_pool_indices,
+             input_metadata.seq_lens,
+             input_metadata.extend_seq_lens,
+             input_metadata.extend_start_loc,
+             max_extend_len,
+             layer.scaling,
+             layer.logit_cap,
+         )
+         return o
+
+     def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+         # During torch.compile, there is a bug in rotary_emb that causes the
+         # output value to have a 3D tensor shape. This reshapes the output correctly.
+         q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+
+         # TODO: reuse the buffer across layers
+         if layer.qk_head_dim != layer.v_head_dim:
+             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+         else:
+             o = torch.empty_like(q)
+
+         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
+
+         input_metadata.token_to_kv_pool.set_kv_buffer(
+             layer.layer_id, input_metadata.out_cache_loc, k, v
+         )
+
+         self.decode_attention_fwd(
+             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+             input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
+             input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
+             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+             input_metadata.req_to_token_pool.req_to_token,
+             input_metadata.req_pool_indices,
+             start_loc,
+             input_metadata.seq_lens,
+             attn_logits,
+             max_seq_len,
+             layer.scaling,
+             layer.logit_cap,
+         )
+         return o
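
The new file defines the interface that both backends implement. As a quick orientation for readers of this diff, here is a minimal sketch (not part of the package) of what a third backend would have to provide. It only relies on the `AttentionBackend` base class added above; note that importing the module pulls in `flashinfer` at import time, so that dependency must be installed. The method bodies below are placeholders, not a working attention implementation.

import torch
import torch.nn as nn

# Importing this module also imports flashinfer (module-level import in the new file).
from sglang.srt.layers.attention_backend import AttentionBackend


class DummyAttnBackend(AttentionBackend):
    """Placeholder backend illustrating the hooks a real backend must fill in."""

    def init_forward_metadata(self, batch, input_metadata):
        # The only abstract method: precompute whatever the kernels need for
        # this batch (indices, wrappers, buffers) and stash it for forward().
        self.forward_metadata = None

    def get_cuda_graph_seq_len_fill_value(self):
        # Fill value for padded sequence lengths during CUDA graph replay
        # (FlashInferAttnBackend returns 0, TritonAttnBackend returns 1).
        return 1

    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata):
        # A real backend writes k/v into the paged KV cache and runs an
        # extend (prefill-with-cached-prefix) kernel; this placeholder just
        # returns an uninitialized tensor shaped like q.
        return torch.empty_like(q)

    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata):
        # A real backend runs a single-token decode kernel against the KV cache.
        return torch.empty_like(q)

Callers never branch on the mode themselves: `AttentionBackend.forward()` dispatches to `forward_decode` or `forward_extend` based on `input_metadata.forward_mode.is_decode()`, which is presumably what the much slimmer `radix_attention.py` (+11/-161 in the file list above) now delegates to.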
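The CUDA-graph hooks (`init_cuda_graph_state`, `init_forward_metadata_capture_cuda_graph`, `init_forward_metadata_replay_cuda_graph`, `get_cuda_graph_seq_len_fill_value`) are driven by `cuda_graph_runner.py`, which this release also rewrites but whose contents are not shown in this hunk. The sketch below is therefore an assumed call order, not an excerpt of that file; the function name and arguments are hypothetical.

import torch


def run_with_cuda_graphs(backend, max_bs, capture_batch_sizes, req_pool_indices, seq_lens):
    """Hypothetical driver showing the intended order of the CUDA-graph hooks."""
    # 1. Allocate shared buffers once, sized for the largest captured batch.
    backend.init_cuda_graph_state(max_bs)

    # Padded sequence lengths use a backend-specific fill value
    # (0 for FlashInfer, 1 for Triton, per get_cuda_graph_seq_len_fill_value()).
    fill_value = backend.get_cuda_graph_seq_len_fill_value()
    padded_seq_lens = torch.full((max_bs,), fill_value, dtype=torch.int32, device="cuda")

    # 2. Capture: build per-batch-size metadata right before recording each graph.
    for bs in capture_batch_sizes:
        padded_seq_lens[:bs] = seq_lens[:bs]
        backend.init_forward_metadata_capture_cuda_graph(
            bs, req_pool_indices[:bs], padded_seq_lens[:bs]
        )
        # ... record the decode forward pass with torch.cuda.graph(...) here ...

    # 3. Replay: refresh the metadata in place, then replay the captured graph.
    bs = capture_batch_sizes[0]
    backend.init_forward_metadata_replay_cuda_graph(bs, req_pool_indices, padded_seq_lens)
    # ... graph.replay() here ...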