sglang-0.3.0-py3-none-any.whl → sglang-0.3.1.post1-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (83)
  1. sglang/bench_latency.py +17 -8
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +5 -17
  4. sglang/lang/backend/runtime_endpoint.py +5 -2
  5. sglang/lang/interpreter.py +1 -4
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +33 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +1 -3
  14. sglang/srt/layers/activation.py +12 -0
  15. sglang/srt/layers/attention_backend.py +480 -0
  16. sglang/srt/layers/flashinfer_utils.py +235 -0
  17. sglang/srt/layers/fused_moe/layer.py +27 -7
  18. sglang/srt/layers/layernorm.py +12 -0
  19. sglang/srt/layers/logits_processor.py +64 -77
  20. sglang/srt/layers/radix_attention.py +11 -161
  21. sglang/srt/layers/sampler.py +38 -122
  22. sglang/srt/layers/torchao_utils.py +75 -0
  23. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  24. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  25. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  26. sglang/srt/lora/lora.py +403 -0
  27. sglang/srt/lora/lora_config.py +43 -0
  28. sglang/srt/lora/lora_manager.py +259 -0
  29. sglang/srt/managers/controller_multi.py +1 -5
  30. sglang/srt/managers/controller_single.py +0 -5
  31. sglang/srt/managers/io_struct.py +16 -1
  32. sglang/srt/managers/policy_scheduler.py +122 -5
  33. sglang/srt/managers/schedule_batch.py +105 -71
  34. sglang/srt/managers/tokenizer_manager.py +17 -8
  35. sglang/srt/managers/tp_worker.py +188 -121
  36. sglang/srt/model_executor/cuda_graph_runner.py +69 -133
  37. sglang/srt/model_executor/forward_batch_info.py +35 -312
  38. sglang/srt/model_executor/model_runner.py +123 -154
  39. sglang/srt/models/baichuan.py +416 -0
  40. sglang/srt/models/chatglm.py +1 -5
  41. sglang/srt/models/commandr.py +1 -5
  42. sglang/srt/models/dbrx.py +1 -5
  43. sglang/srt/models/deepseek.py +1 -5
  44. sglang/srt/models/deepseek_v2.py +7 -6
  45. sglang/srt/models/exaone.py +1 -5
  46. sglang/srt/models/gemma.py +1 -5
  47. sglang/srt/models/gemma2.py +1 -5
  48. sglang/srt/models/gpt_bigcode.py +1 -5
  49. sglang/srt/models/grok.py +1 -5
  50. sglang/srt/models/internlm2.py +1 -5
  51. sglang/srt/models/llama.py +51 -5
  52. sglang/srt/models/llama_classification.py +1 -20
  53. sglang/srt/models/llava.py +30 -5
  54. sglang/srt/models/llavavid.py +2 -2
  55. sglang/srt/models/minicpm.py +1 -5
  56. sglang/srt/models/minicpm3.py +669 -0
  57. sglang/srt/models/mixtral.py +6 -5
  58. sglang/srt/models/mixtral_quant.py +1 -5
  59. sglang/srt/models/olmoe.py +415 -0
  60. sglang/srt/models/qwen.py +1 -5
  61. sglang/srt/models/qwen2.py +1 -5
  62. sglang/srt/models/qwen2_moe.py +6 -5
  63. sglang/srt/models/stablelm.py +1 -5
  64. sglang/srt/models/xverse.py +375 -0
  65. sglang/srt/models/xverse_moe.py +445 -0
  66. sglang/srt/openai_api/adapter.py +65 -46
  67. sglang/srt/openai_api/protocol.py +11 -3
  68. sglang/srt/sampling/sampling_batch_info.py +46 -80
  69. sglang/srt/server.py +30 -15
  70. sglang/srt/server_args.py +163 -28
  71. sglang/srt/utils.py +19 -51
  72. sglang/test/few_shot_gsm8k.py +132 -0
  73. sglang/test/runners.py +114 -22
  74. sglang/test/test_programs.py +7 -5
  75. sglang/test/test_utils.py +85 -2
  76. sglang/utils.py +32 -37
  77. sglang/version.py +1 -1
  78. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/METADATA +30 -18
  79. sglang-0.3.1.post1.dist-info/RECORD +130 -0
  80. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/WHEEL +1 -1
  81. sglang-0.3.0.dist-info/RECORD +0 -118
  82. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/LICENSE +0 -0
  83. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/top_level.txt +0 -0
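The hunk below is the new file sglang/srt/layers/attention_backend.py (entry 15 above, +480 lines), which introduces a pluggable attention-backend abstraction. For orientation, a minimal sketch of how the interface is meant to be driven; this is not code from the package, and the prefer_flashinfer flag is a made-up placeholder (the real backend selection happens elsewhere in the package):

    # Illustrative sketch only; `prefer_flashinfer` is hypothetical.
    if prefer_flashinfer:
        attn_backend = FlashInferAttnBackend(model_runner)
    else:
        attn_backend = TritonAttnBackend(model_runner)

    # Once per forward pass: precompute indices/buffers for the whole batch.
    attn_backend.init_forward_metadata(batch, input_metadata)

    # In every attention layer: forward() routes to forward_extend()
    # (prefill with a cached prefix) or forward_decode(), based on
    # input_metadata.forward_mode.is_decode().
    out = attn_backend.forward(q, k, v, layer, input_metadata)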
@@ -0,0 +1,480 @@
+from __future__ import annotations
+
+"""
+Support different attention backends.
+Now there are two backends: FlashInfer and Triton.
+FlashInfer is faster and Triton is easier to customize.
+Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
+"""
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING
+
+import torch
+import torch.nn as nn
+
+from sglang.global_config import global_config
+from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
+from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.utils import is_hip
+
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import (
+        BatchDecodeWithPagedKVCacheWrapper,
+        BatchPrefillWithPagedKVCacheWrapper,
+        BatchPrefillWithRaggedKVCacheWrapper,
+    )
+    from flashinfer.cascade import merge_state
+    from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+
+class AttentionBackend(ABC):
+    """The base class of attention backends"""
+
+    @abstractmethod
+    def init_forward_metadata(
+        self, batch: ScheduleBatch, input_metadata: InputMetadata
+    ):
+        """Init the metadata for a forward pass."""
+        raise NotImplementedError()
+
+    def init_cuda_graph_state(self, max_bs: int):
+        """Init the global shared states for cuda graph."""
+        raise NotImplementedError()
+
+    def init_forward_metadata_capture_cuda_graph(
+        self, bs: int, req_pool_indices, seq_lens
+    ):
+        """Init the metadata for a forward pass for capturing a cuda graph."""
+        raise NotImplementedError()
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, bs: int, req_pool_indices, seq_lens
+    ):
+        """Init the metadata for a forward pass for replaying a cuda graph."""
+        raise NotImplementedError()
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
+        raise NotImplementedError()
+
+    def forward(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+        """Run forward on an attention layer."""
+        if input_metadata.forward_mode.is_decode():
+            return self.forward_decode(q, k, v, layer, input_metadata)
+        else:
+            return self.forward_extend(q, k, v, layer, input_metadata)
+
+    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+        """Run a forward for decode."""
+        raise NotImplementedError()
+
+    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+        """Run a forward for extend."""
+        raise NotImplementedError()
+
+
+class FlashInferAttnBackend(AttentionBackend):
+    """Flashinfer attention kernels."""
+
+    def __init__(self, model_runner: ModelRunner):
+        super().__init__()
+        self.model_runner = model_runner
+
+        local_num_qo_heads = (
+            model_runner.model_config.num_attention_heads // model_runner.tp_size
+        )
+        local_num_kv_heads = model_runner.model_config.get_num_kv_heads(
+            model_runner.tp_size
+        )
+        if (
+            not _grouped_size_compiled_for_decode_kernels(
+                local_num_qo_heads, local_num_kv_heads
+            )
+            or local_num_qo_heads // local_num_kv_heads > 4
+        ):
+            self.decode_use_tensor_cores = True
+        else:
+            self.decode_use_tensor_cores = False
+
+        self.workspace_buffer = torch.empty(
+            global_config.flashinfer_workspace_size,
+            dtype=torch.uint8,
+            device="cuda",
+        )
+
+        if model_runner.sliding_window_size is None:
+            self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+                self.workspace_buffer, "NHD"
+            )
+            self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+                self.workspace_buffer, "NHD"
+            )
+            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                self.workspace_buffer,
+                "NHD",
+                use_tensor_cores=self.decode_use_tensor_cores,
+            )
+        else:
+            # Two wrappers: one for sliding window attention and one for full attention.
+            # Using two wrappers is unnecessary in the current PR, but they are prepared for future PRs.
+            self.prefill_wrapper_ragged = None
+            self.prefill_wrapper_paged = []
+            self.decode_wrapper = []
+            for _ in range(2):
+                self.prefill_wrapper_paged.append(
+                    BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
+                )
+                self.decode_wrapper.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        use_tensor_cores=self.decode_use_tensor_cores,
+                    )
+                )
+
+        self.forward_metadata = None
+        self.cuda_graph_metadata = {}
+
+    def init_forward_metadata(
+        self, batch: ScheduleBatch, input_metadata: InputMetadata
+    ):
+        if input_metadata.forward_mode.is_decode():
+            prefix_lens = None
+            use_ragged = False
+            total_num_tokens = None
+        else:
+            prefix_lens = input_metadata.extend_prefix_lens
+
+            # Some heuristics to check whether to use ragged forward
+            use_ragged = False
+            if (
+                torch.sum(input_metadata.seq_lens).item() >= 4096
+                and self.model_runner.sliding_window_size is None
+            ):
+                use_ragged = True
+
+            total_num_tokens = torch.sum(input_metadata.seq_lens).item()
+
+        update_flashinfer_indices(
+            input_metadata.forward_mode,
+            self.model_runner,
+            input_metadata.req_pool_indices,
+            input_metadata.seq_lens,
+            prefix_lens,
+            use_ragged=use_ragged,
+        )
+
+        self.forward_metadata = (use_ragged, total_num_tokens, self.decode_wrapper)
+
+    def init_cuda_graph_state(self, max_bs: int):
+        self.cuda_graph_kv_indptr = torch.zeros(
+            (max_bs + 1,), dtype=torch.int32, device="cuda"
+        )
+        self.cuda_graph_kv_indices = torch.zeros(
+            (max_bs * self.model_runner.model_config.context_len,),
+            dtype=torch.int32,
+            device="cuda",
+        )
+        self.cuda_graph_kv_last_page_len = torch.ones(
+            (max_bs,), dtype=torch.int32, device="cuda"
+        )
+
+        if self.model_runner.sliding_window_size is not None:
+            self.cuda_graph_kv_indptr = [
+                self.cuda_graph_kv_indptr,
+                self.cuda_graph_kv_indptr.clone(),
+            ]
+            self.cuda_graph_kv_indices = [
+                self.cuda_graph_kv_indices,
+                self.cuda_graph_kv_indices.clone(),
+            ]
+
+    def init_forward_metadata_capture_cuda_graph(
+        self, bs: int, req_pool_indices, seq_lens
+    ):
+        if self.model_runner.sliding_window_size is None:
+            decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                self.workspace_buffer,
+                "NHD",
+                use_cuda_graph=True,
+                use_tensor_cores=self.decode_use_tensor_cores,
+                paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[: bs + 1],
+                paged_kv_indices_buffer=self.cuda_graph_kv_indices,
+                paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs],
+            )
+        else:
+            decode_wrapper = []
+            for i in range(2):
+                decode_wrapper.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        use_cuda_graph=True,
+                        use_tensor_cores=self.decode_use_tensor_cores,
+                        paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1],
+                        paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+                        paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[
+                            :bs
+                        ],
+                    )
+                )
+
+        update_flashinfer_indices(
+            ForwardMode.DECODE,
+            self.model_runner,
+            req_pool_indices,
+            seq_lens,
+            None,
+            decode_wrapper,
+        )
+
+        self.cuda_graph_metadata[bs] = decode_wrapper
+
+        self.forward_metadata = (False, None, decode_wrapper)
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, bs: int, req_pool_indices, seq_lens
+    ):
+        update_flashinfer_indices(
+            ForwardMode.DECODE,
+            self.model_runner,
+            req_pool_indices[:bs],
+            seq_lens[:bs],
+            None,
+            self.cuda_graph_metadata[bs],
+        )
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 0
+
+    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+        if not isinstance(self.prefill_wrapper_paged, list):
+            prefill_wrapper_paged = self.prefill_wrapper_paged
+        else:
+            if layer.sliding_window_size != -1:
+                prefill_wrapper_paged = self.prefill_wrapper_paged[0]
+            else:
+                prefill_wrapper_paged = self.prefill_wrapper_paged[1]
+
+        use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata
+
+        if not use_ragged:
+            if k is not None:
+                assert v is not None
+                input_metadata.token_to_kv_pool.set_kv_buffer(
+                    layer.layer_id, input_metadata.out_cache_loc, k, v
+                )
+            o = prefill_wrapper_paged.forward(
+                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                causal=True,
+                sm_scale=layer.scaling,
+                window_left=layer.sliding_window_size,
+                logits_soft_cap=layer.logit_cap,
+            )
+        else:
+            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim),
+                v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
+                causal=True,
+                sm_scale=layer.scaling,
+                logits_soft_cap=layer.logit_cap,
+            )
+
+            if input_metadata.extend_no_prefix:
+                o = o1
+            else:
+                o2, s2 = prefill_wrapper_paged.forward_return_lse(
+                    q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                    causal=False,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=layer.logit_cap,
+                )
+
+                o, _ = merge_state(o1, s1, o2, s2)
+
+            input_metadata.token_to_kv_pool.set_kv_buffer(
+                layer.layer_id, input_metadata.out_cache_loc, k, v
+            )
+
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
+    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+        use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata
+
+        if isinstance(decode_wrapper, list):
+            if layer.sliding_window_size != -1:
+                decode_wrapper = decode_wrapper[0]
+            else:
+                decode_wrapper = decode_wrapper[1]
+
+        if k is not None:
+            assert v is not None
+            input_metadata.token_to_kv_pool.set_kv_buffer(
+                layer.layer_id, input_metadata.out_cache_loc, k, v
+            )
+
+        o = decode_wrapper.forward(
+            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+            input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+            sm_scale=layer.scaling,
+            logits_soft_cap=layer.logit_cap,
+        )
+
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
+
+class TritonAttnBackend(AttentionBackend):
+    def __init__(self, model_runner: ModelRunner):
+        # Lazy import to avoid the initialization of cuda context
+        from sglang.srt.layers.triton_attention.decode_attention import (
+            decode_attention_fwd,
+        )
+        from sglang.srt.layers.triton_attention.extend_attention import (
+            extend_attention_fwd,
+        )
+
+        super().__init__()
+
+        self.decode_attention_fwd = decode_attention_fwd
+        self.extend_attention_fwd = extend_attention_fwd
+        self.num_head = model_runner.model_config.num_attention_heads
+
+        if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
+            self.reduce_dtype = torch.float32
+        else:
+            self.reduce_dtype = torch.float16
+
+        self.forward_metadata = None
+
+        self.cuda_graph_max_seq_len = model_runner.model_config.context_len
+
+    def init_forward_metadata(
+        self, batch: ScheduleBatch, input_metadata: InputMetadata
+    ):
+        """Init auxiliary variables for triton attention backend."""
+
+        if input_metadata.forward_mode.is_decode():
+            start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32)
+            start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0)
+
+            total_num_tokens = torch.sum(input_metadata.seq_lens).item()
+            attn_logits = torch.empty(
+                (self.num_head, total_num_tokens),
+                dtype=self.reduce_dtype,
+                device="cuda",
+            )
+
+            max_seq_len = torch.max(input_metadata.seq_lens).item()
+            max_extend_len = None
+        else:
+            start_loc = attn_logits = max_seq_len = None
+            prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
+            max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()
+
+        self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
+
+    def init_cuda_graph_state(self, max_bs: int):
+        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
+
+        self.cuda_graph_start_loc = torch.zeros(
+            (max_bs,), dtype=torch.int32, device="cuda"
+        )
+        self.cuda_graph_attn_logits = torch.empty(
+            (
+                self.num_head,
+                self.cuda_graph_max_total_num_tokens,
+            ),
+            dtype=self.reduce_dtype,
+            device="cuda",
+        )
+
+    def init_forward_metadata_capture_cuda_graph(
+        self, bs: int, req_pool_indices, seq_lens
+    ):
+        self.forward_metadata = (
+            self.cuda_graph_start_loc,
+            self.cuda_graph_attn_logits,
+            self.cuda_graph_max_seq_len,
+            None,
+        )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, bs: int, req_pool_indices, seq_lens
+    ):
+        self.cuda_graph_start_loc.zero_()
+        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1
+
+    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+        # TODO: reuse the buffer across layers
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        input_metadata.token_to_kv_pool.set_kv_buffer(
+            layer.layer_id, input_metadata.out_cache_loc, k, v
+        )
+
+        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
+        self.extend_attention_fwd(
+            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+            k.contiguous(),
+            v.contiguous(),
+            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+            input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            input_metadata.req_to_token_pool.req_to_token,
+            input_metadata.req_pool_indices,
+            input_metadata.seq_lens,
+            input_metadata.extend_seq_lens,
+            input_metadata.extend_start_loc,
+            max_extend_len,
+            layer.scaling,
+            layer.logit_cap,
+        )
+        return o
+
+    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+        # During torch.compile, there is a bug in rotary_emb that causes the
+        # output value to have a 3D tensor shape. This reshapes the output correctly.
+        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+
+        # TODO: reuse the buffer across layers
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
+
+        input_metadata.token_to_kv_pool.set_kv_buffer(
+            layer.layer_id, input_metadata.out_cache_loc, k, v
+        )
+
+        self.decode_attention_fwd(
+            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+            input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+            input_metadata.req_to_token_pool.req_to_token,
+            input_metadata.req_pool_indices,
+            start_loc,
+            input_metadata.seq_lens,
+            attn_logits,
+            max_seq_len,
+            layer.scaling,
+            layer.logit_cap,
+        )
+        return o
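
For reference, a minimal skeleton of what a third backend would have to provide to plug into the AttentionBackend interface above. This is an illustrative sketch only, not part of the release; MyAttnBackend and its method bodies are made up, and it assumes the class is defined alongside AttentionBackend in the module shown above.

    import torch
    import torch.nn as nn

    class MyAttnBackend(AttentionBackend):
        """Illustrative skeleton of a custom attention backend."""

        def __init__(self, model_runner):
            super().__init__()
            self.model_runner = model_runner
            self.forward_metadata = None

        def init_forward_metadata(self, batch, input_metadata):
            # Precompute whatever per-batch indices/buffers the kernels need.
            self.forward_metadata = (input_metadata.seq_lens,)

        def get_cuda_graph_seq_len_fill_value(self):
            # Fill value for padded seq lens during CUDA graph capture
            # (FlashInfer returns 0 above, Triton returns 1).
            return 1

        def forward_extend(self, q, k, v, layer: nn.Module, input_metadata):
            # Prefill with a cached prefix; a real implementation would also
            # write k/v into the paged KV pool before running its kernel.
            raise NotImplementedError()

        def forward_decode(self, q, k, v, layer: nn.Module, input_metadata):
            # Single-token decode against the paged KV cache.
            raise NotImplementedError()

The three init_*_cuda_graph hooks only need to be overridden if the backend is run under the CUDA graph runner; in the base class they raise NotImplementedError.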