sglang 0.3.6.post3__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl
This diff covers publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in the public registry.
- sglang/__init__.py +1 -1
- sglang/bench_one_batch.py +4 -0
- sglang/bench_serving.py +13 -0
- sglang/check_env.py +1 -1
- sglang/srt/_custom_ops.py +118 -0
- sglang/srt/configs/device_config.py +17 -0
- sglang/srt/configs/load_config.py +84 -0
- sglang/srt/configs/model_config.py +161 -4
- sglang/srt/configs/qwen2vl.py +5 -8
- sglang/srt/constrained/outlines_backend.py +11 -1
- sglang/srt/constrained/outlines_jump_forward.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +5 -5
- sglang/srt/distributed/__init__.py +3 -0
- sglang/srt/distributed/communication_op.py +34 -0
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang/srt/distributed/parallel_state.py +1275 -0
- sglang/srt/distributed/utils.py +223 -0
- sglang/srt/hf_transformers_utils.py +37 -1
- sglang/srt/layers/attention/__init__.py +5 -2
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
- sglang/srt/layers/attention/flashinfer_backend.py +33 -20
- sglang/srt/layers/attention/torch_native_backend.py +299 -0
- sglang/srt/layers/attention/triton_backend.py +22 -8
- sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
- sglang/srt/layers/ep_moe/__init__.py +0 -0
- sglang/srt/layers/ep_moe/kernels.py +349 -0
- sglang/srt/layers/ep_moe/layer.py +661 -0
- sglang/srt/layers/fused_moe_patch.py +20 -11
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +17 -3
- sglang/srt/layers/quantization/__init__.py +36 -2
- sglang/srt/layers/quantization/fp8.py +559 -0
- sglang/srt/layers/quantization/fp8_utils.py +27 -0
- sglang/srt/layers/radix_attention.py +4 -2
- sglang/srt/layers/sampler.py +2 -0
- sglang/srt/layers/torchao_utils.py +23 -45
- sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/managers/io_struct.py +48 -2
- sglang/srt/managers/schedule_batch.py +19 -14
- sglang/srt/managers/schedule_policy.py +7 -4
- sglang/srt/managers/scheduler.py +145 -85
- sglang/srt/managers/tokenizer_manager.py +166 -68
- sglang/srt/managers/tp_worker.py +36 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
- sglang/srt/mem_cache/memory_pool.py +5 -1
- sglang/srt/model_executor/cuda_graph_runner.py +30 -7
- sglang/srt/model_executor/forward_batch_info.py +9 -4
- sglang/srt/model_executor/model_runner.py +146 -153
- sglang/srt/model_loader/__init__.py +34 -0
- sglang/srt/model_loader/loader.py +1139 -0
- sglang/srt/model_loader/utils.py +41 -0
- sglang/srt/model_loader/weight_utils.py +640 -0
- sglang/srt/model_parallel.py +1 -5
- sglang/srt/models/baichuan.py +9 -10
- sglang/srt/models/chatglm.py +6 -15
- sglang/srt/models/commandr.py +4 -5
- sglang/srt/models/dbrx.py +2 -3
- sglang/srt/models/deepseek.py +4 -11
- sglang/srt/models/deepseek_v2.py +90 -18
- sglang/srt/models/exaone.py +2 -3
- sglang/srt/models/gemma.py +2 -6
- sglang/srt/models/gemma2.py +3 -14
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/gpt2.py +5 -12
- sglang/srt/models/gpt_bigcode.py +6 -22
- sglang/srt/models/grok.py +3 -8
- sglang/srt/models/internlm2.py +2 -3
- sglang/srt/models/internlm2_reward.py +0 -1
- sglang/srt/models/llama.py +96 -31
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_embedding.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +1 -4
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +4 -7
- sglang/srt/models/minicpm3.py +6 -19
- sglang/srt/models/mixtral.py +24 -14
- sglang/srt/models/mixtral_quant.py +2 -3
- sglang/srt/models/mllama.py +3 -7
- sglang/srt/models/olmo.py +2 -8
- sglang/srt/models/olmo2.py +0 -1
- sglang/srt/models/olmoe.py +3 -5
- sglang/srt/models/phi3_small.py +8 -13
- sglang/srt/models/qwen.py +2 -3
- sglang/srt/models/qwen2.py +10 -9
- sglang/srt/models/qwen2_moe.py +4 -16
- sglang/srt/models/qwen2_vl.py +2 -6
- sglang/srt/models/registry.py +99 -0
- sglang/srt/models/stablelm.py +2 -3
- sglang/srt/models/torch_native_llama.py +6 -17
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +4 -11
- sglang/srt/models/yivl.py +2 -3
- sglang/srt/openai_api/adapter.py +9 -5
- sglang/srt/openai_api/protocol.py +1 -0
- sglang/srt/sampling/sampling_batch_info.py +9 -8
- sglang/srt/server.py +270 -173
- sglang/srt/server_args.py +102 -29
- sglang/srt/utils.py +295 -28
- sglang/test/test_utils.py +7 -0
- sglang/version.py +1 -1
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
- sglang-0.4.0.post1.dist-info/RECORD +189 -0
- sglang-0.3.6.post3.dist-info/RECORD +0 -162
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/torch_native_backend.py (new file)
@@ -0,0 +1,299 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional
+
+import torch
+from torch.nn.functional import scaled_dot_product_attention
+
+from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+
+class TorchNativeAttnBackend(AttentionBackend):
+    def __init__(self, model_runner: ModelRunner):
+        super().__init__()
+        self.forward_metadata = None
+        self.device = model_runner.device
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Init the metadata for a forward pass."""
+        pass
+
+    def init_cuda_graph_state(self, max_bs: int):
+        # TODO: Support CUDA graph
+        raise ValueError(
+            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
+        )
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor] = None,
+    ):
+        # TODO: Support CUDA graph
+        raise ValueError(
+            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
+        )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor] = None,
+    ):
+        # TODO: Support CUDA graph
+        raise ValueError(
+            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
+        )
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        # TODO: Support CUDA graph
+        raise ValueError(
+            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
+        )
+
+    def _run_sdpa_forward_extend(
+        self,
+        query: torch.Tensor,
+        output: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        req_to_token: torch.Tensor,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        extend_prefix_lens: torch.Tensor,
+        extend_seq_lens: torch.Tensor,
+        scaling=None,
+        enable_gqa=False,
+        causal=False,
+    ):
+        """Run the extend forward by using torch native sdpa op.
+
+        Args:
+            query: [num_tokens, num_heads, head_size]
+            output: [num_tokens, num_heads, head_size]
+            k_cache: [max_total_num_tokens, num_heads, head_size]
+            v_cache: [max_total_num_tokens, num_heads, head_size]
+            req_to_token: [max_num_reqs, max_context_len]
+            req_pool_indices: [num_seqs]
+            seq_lens: [num_seqs]
+            extend_prefix_lens: [num_seqs]
+            extend_seq_lens: [num_seqs]
+            scaling: float or None
+            enable_gqa: bool
+            causal: bool
+
+        Returns:
+            output: [num_tokens, num_heads, head_size]
+        """
+
+        assert seq_lens.shape[0] == extend_prefix_lens.shape[0]
+        assert seq_lens.shape[0] == extend_seq_lens.shape[0]
+
+        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
+        query = query.movedim(0, query.dim() - 2)
+
+        start_q, start_kv = 0, 0
+        for seq_idx in range(seq_lens.shape[0]):
+            # TODO: this loop process a sequence per iter, this is inefficient.
+            # Need optimize the performance later.
+
+            extend_seq_len_q = extend_seq_lens[seq_idx]
+            prefill_seq_len_q = extend_prefix_lens[seq_idx]
+
+            seq_len_kv = seq_lens[seq_idx]
+            end_q = start_q + extend_seq_len_q
+            end_kv = start_kv + seq_len_kv
+
+            per_req_query = query[:, start_q:end_q, :]
+            per_req_query_redudant = torch.empty(
+                (per_req_query.shape[0], seq_len_kv, per_req_query.shape[2]),
+                dtype=per_req_query.dtype,
+                device=per_req_query.device,
+            )
+
+            per_req_query_redudant[:, prefill_seq_len_q:, :] = per_req_query
+
+            # get key and value from cache. per_req_tokens contains the kv cache
+            # index for each token in the sequence.
+            req_pool_idx = req_pool_indices[seq_idx]
+            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
+            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
+            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
+
+            per_req_out_redudant = (
+                scaled_dot_product_attention(
+                    per_req_query_redudant.unsqueeze(0),
+                    per_req_key.unsqueeze(0),
+                    per_req_value.unsqueeze(0),
+                    enable_gqa=enable_gqa,
+                    scale=scaling,
+                    is_causal=causal,
+                )
+                .squeeze(0)
+                .movedim(query.dim() - 2, 0)
+            )
+            output[start_q:end_q, :, :] = per_req_out_redudant[prefill_seq_len_q:, :, :]
+            start_q, start_kv = end_q, end_kv
+        return output
+
+    def _run_sdpa_forward_decode(
+        self,
+        query: torch.Tensor,
+        output: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        req_to_token: torch.Tensor,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        scaling=None,
+        enable_gqa=False,
+        causal=False,
+    ):
+        """Run the decode forward by using torch native sdpa op.
+
+        Args:
+            query: [num_tokens, num_heads, head_size]
+            output: [num_tokens, num_heads, head_size]
+            k_cache: [max_total_num_tokens, num_heads, head_size]
+            v_cache: [max_total_num_tokens, num_heads, head_size]
+            req_to_token: [max_num_reqs, max_context_len]
+            req_pool_indices: [num_seqs]
+            seq_lens: [num_seqs]
+            scaling: float or None
+            enable_gqa: bool
+            causal: bool
+
+        Returns:
+            output: [num_tokens, num_heads, head_size]
+        """
+
+        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
+        query = query.movedim(0, query.dim() - 2)
+
+        start_q, start_kv = 0, 0
+        for seq_idx in range(seq_lens.shape[0]):
+            # TODO: this loop process a sequence per iter, this is inefficient.
+            # Need optimize the performance later.
+
+            seq_len_q = 1
+            seq_len_kv = seq_lens[seq_idx]
+            end_q = start_q + seq_len_q
+            end_kv = start_kv + seq_len_kv
+
+            per_req_query = query[:, start_q:end_q, :]
+
+            # get key and value from cache. per_req_tokens contains the kv cache
+            # index for each token in the sequence.
+            req_pool_idx = req_pool_indices[seq_idx]
+            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
+            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
+            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
+
+            per_req_out = (
+                scaled_dot_product_attention(
+                    per_req_query.unsqueeze(0),
+                    per_req_key.unsqueeze(0),
+                    per_req_value.unsqueeze(0),
+                    enable_gqa=enable_gqa,
+                    scale=scaling,
+                    is_causal=causal,
+                )
+                .squeeze(0)
+                .movedim(query.dim() - 2, 0)
+            )
+            output[start_q:end_q, :, :] = per_req_out
+            start_q, start_kv = end_q, end_kv
+
+        return output
+
+    def forward_extend(
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
+
+        use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+
+        q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+        o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+
+        self._run_sdpa_forward_extend(
+            q_,
+            o_,
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            forward_batch.extend_prefix_lens,
+            forward_batch.extend_seq_lens,
+            scaling=layer.scaling,
+            enable_gqa=use_gqa,
+            causal=not layer.is_cross_attention,
+        )
+        return o
+
+    def forward_decode(
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        # During torch.compile, there is a bug in rotary_emb that causes the
+        # output value to have a 3D tensor shape. This reshapes the output correctly.
+        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
+
+        use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+
+        q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+        o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+
+        self._run_sdpa_forward_decode(
+            q_,
+            o_,
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            scaling=layer.scaling,
+            enable_gqa=use_gqa,
+            causal=False,
+        )
+
+        return o
sglang/srt/layers/attention/triton_backend.py
@@ -114,7 +114,13 @@ class TritonAttnBackend(AttentionBackend):
         return 1
 
     def forward_extend(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
@@ -122,9 +128,10 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
 
         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
         self.extend_attention_fwd(
@@ -146,7 +153,13 @@ class TritonAttnBackend(AttentionBackend):
         return o
 
     def forward_decode(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -160,9 +173,10 @@ class TritonAttnBackend(AttentionBackend):
 
         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
 
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
 
         self.decode_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
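
Both Triton hunks (like the torch-native backend above) add a save_kv_cache flag, so the caller now decides whether the backend writes the incoming k/v into the token-to-KV pool; the default of True preserves the old behavior. A minimal sketch of that guard, using a hypothetical ToyKVPool stub and toy_forward_decode helper rather than sglang's real classes:

```python
import torch


class ToyKVPool:
    """Hypothetical stand-in for the pool whose set_kv_buffer(...) call appears in the diff."""

    def __init__(self, capacity: int, num_heads: int, head_size: int):
        self.k = torch.zeros(capacity, num_heads, head_size)
        self.v = torch.zeros(capacity, num_heads, head_size)

    def set_kv_buffer(self, layer_id, loc, k, v):
        # Write the new tokens' keys/values into their assigned cache slots.
        self.k[loc] = k
        self.v[loc] = v


def toy_forward_decode(pool, layer_id, out_cache_loc, k, v, save_kv_cache=True):
    # New in 0.4.0: the write-back is optional and controlled by the caller.
    if save_kv_cache:
        pool.set_kv_buffer(layer_id, out_cache_loc, k, v)
    # ... the attention kernel would then read from pool.k / pool.v ...


pool = ToyKVPool(capacity=8, num_heads=2, head_size=4)
loc = torch.tensor([3])
k_new, v_new = torch.randn(1, 2, 4), torch.randn(1, 2, 4)
toy_forward_decode(pool, 0, loc, k_new, v_new)                       # writes slot 3
toy_forward_decode(pool, 0, loc, k_new, v_new, save_kv_cache=False)  # skips the write
```

Keeping True as the default means existing call sites are unaffected, while new callers can opt out of the redundant write when the cache slots are already populated.
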