sglang 0.4.0__py3-none-any.whl → 0.4.0.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/bench_offline_throughput.py +18 -6
- sglang/bench_one_batch.py +13 -0
- sglang/bench_serving.py +8 -1
- sglang/check_env.py +140 -48
- sglang/lang/backend/runtime_endpoint.py +1 -0
- sglang/lang/chat_template.py +32 -0
- sglang/llama3_eval.py +316 -0
- sglang/srt/constrained/outlines_backend.py +5 -0
- sglang/srt/constrained/xgrammar_backend.py +9 -6
- sglang/srt/layers/attention/__init__.py +5 -2
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
- sglang/srt/layers/attention/flashinfer_backend.py +22 -5
- sglang/srt/layers/attention/torch_native_backend.py +22 -8
- sglang/srt/layers/attention/triton_backend.py +38 -33
- sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
- sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
- sglang/srt/layers/ep_moe/__init__.py +0 -0
- sglang/srt/layers/ep_moe/kernels.py +349 -0
- sglang/srt/layers/ep_moe/layer.py +665 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +64 -21
- sglang/srt/layers/fused_moe_triton/layer.py +1 -1
- sglang/srt/layers/logits_processor.py +133 -95
- sglang/srt/layers/quantization/__init__.py +2 -47
- sglang/srt/layers/quantization/fp8.py +607 -0
- sglang/srt/layers/quantization/fp8_utils.py +27 -0
- sglang/srt/layers/radix_attention.py +11 -2
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/torchao_utils.py +58 -45
- sglang/srt/managers/detokenizer_manager.py +37 -17
- sglang/srt/managers/io_struct.py +39 -10
- sglang/srt/managers/schedule_batch.py +39 -24
- sglang/srt/managers/schedule_policy.py +64 -5
- sglang/srt/managers/scheduler.py +236 -197
- sglang/srt/managers/tokenizer_manager.py +99 -58
- sglang/srt/managers/tp_worker_overlap_thread.py +7 -5
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +2 -2
- sglang/srt/mem_cache/memory_pool.py +5 -1
- sglang/srt/mem_cache/radix_cache.py +12 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -11
- sglang/srt/model_executor/model_runner.py +24 -9
- sglang/srt/model_parallel.py +67 -10
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/deepseek_v2.py +87 -7
- sglang/srt/models/gemma2.py +34 -0
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/granite.py +517 -0
- sglang/srt/models/grok.py +72 -13
- sglang/srt/models/llama.py +22 -5
- sglang/srt/models/llama_classification.py +11 -23
- sglang/srt/models/llama_reward.py +0 -2
- sglang/srt/models/llava.py +37 -14
- sglang/srt/models/mixtral.py +12 -9
- sglang/srt/models/phi3_small.py +0 -5
- sglang/srt/models/qwen2.py +20 -0
- sglang/srt/models/qwen2_moe.py +0 -5
- sglang/srt/models/torch_native_llama.py +0 -5
- sglang/srt/openai_api/adapter.py +4 -0
- sglang/srt/openai_api/protocol.py +9 -4
- sglang/srt/sampling/sampling_batch_info.py +9 -8
- sglang/srt/server.py +4 -4
- sglang/srt/server_args.py +62 -13
- sglang/srt/utils.py +57 -10
- sglang/test/test_utils.py +3 -2
- sglang/utils.py +10 -3
- sglang/version.py +1 -1
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/METADATA +15 -9
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/RECORD +72 -65
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/top_level.txt +0 -0
sglang/llama3_eval.py
ADDED
@@ -0,0 +1,316 @@
# Adapt from https://github.com/fw-ai/llm_eval_meta

import argparse
import asyncio
import os
import pickle
import re
import shutil
from collections import defaultdict
from dataclasses import dataclass

import httpx
import numpy as np
import openai
import transformers
from datasets import load_dataset
from openai import AsyncOpenAI
from tqdm import tqdm

# Mapping providers to their clients and models
provider_to_models = {
    "b10": {
        "8b": "meta-llama/Llama-3.1-8B-Instruct",
        "70b": "meta-llama/Llama-3.1-70B-Instruct",
        "405b": "meta-llama/Llama-3.1-405B-Instruct",
    },
    "oai": {
        "8b": "meta-llama/Llama-3.1-8B-Instruct",
        "70b": "meta-llama/Llama-3.1-70B-Instruct",
        "405b": "meta-llama/Llama-3.1-405B-Instruct",
    },
    "sgl": {
        "8b": "meta-llama/Llama-3.1-8B-Instruct",
        "70b": "meta-llama/Llama-3.1-70B-Instruct",
        "405b": "meta-llama/Llama-3.1-405B-Instruct",
    },
}


async def fetch_responses(
    client, prompt, semaphore, index, provider, model_size, output_dir, max_tokens
):
    output_file = os.path.join(output_dir, f"response_{index}.pkl")
    if os.path.exists(output_file):
        print(f"File {output_file} already exists, skipping.")
        return

    async with semaphore:
        response = await client.completions.create(
            model=provider_to_models[provider][model_size],
            prompt=prompt,
            temperature=0.0,
            max_tokens=max_tokens,
        )
        if isinstance(response, openai.BadRequestError):
            with open(output_file, "wb") as f:
                pickle.dump("bad_response", f)
        assert isinstance(response, openai.types.completion.Completion)
        # Save response to a file
        with open(output_file, "wb") as f:
            pickle.dump(response, f)


TASK_TO_MAX_TOKENS = {
    "evals__mmlu__details": 1,
    "evals__mmlu__0_shot__cot__details": 1024,
    # Official meta uses 1024, but a small % (.05) of questions are answered correctly after relaxing
    "evals__mmlu_pro__details": 2048,
    "evals__gsm8k__details": 1024,
}

TASK_TO_EVAL_SET = {
    "mmlu": "evals__mmlu__details",
    "mmlu_cot": "evals__mmlu__0_shot__cot__details",
    "mmlu_pro": "evals__mmlu_pro__details",
    "gsm8k": "evals__gsm8k__details",
}


class CustomAsyncHTTPXClient(httpx.AsyncClient):
    async def send(self, request: httpx.Request, *args, **kwargs) -> httpx.Response:
        request.url = httpx.URL(
            f"https://model-{os.getenv('MODEL_ID')}.api.baseten.co/development/predict"
        )
        return await super().send(request, *args, **kwargs)


def get_client(provider):
    if provider not in "b10":
        if os.getenv("OPENAI_API_KEY") == None:
            os.environ["OPENAI_API_KEY"] = "EMPTY"
    return {
        "oai": AsyncOpenAI(base_url="http://127.0.0.1:8000/v1/"),
        "b10": AsyncOpenAI(
            api_key=f"Api-Key {os.getenv('OPENAI_API_KEY')}",
            base_url=f"https://model-{os.getenv('MODEL_ID')}.api.baseten.co/development/predict",
            http_client=CustomAsyncHTTPXClient(),
        ),
        "sgl": AsyncOpenAI(base_url="http://127.0.0.1:30000/v1/"),
    }[provider]


# Define the benchmark function
async def benchmark(args):
    ds = load_dataset(
        "meta-llama/Llama-3.1-405B-Instruct-evals",
        f"Llama-3.1-405B-Instruct-{TASK_TO_EVAL_SET[args.task]}",
    )
    semaphore = asyncio.Semaphore(args.concurrency)  # Limit to 16 concurrent tasks

    if args.num_examples is None:
        args.num_examples = len(ds["latest"]["input_final_prompts"])
    prompts = ds["latest"]["input_final_prompts"][: args.num_examples]

    # Create the output directory if it does not exist
    os.makedirs(args.output_dir, exist_ok=True)

    tasks = []
    # Create the tasks with tqdm progress bar
    max_tokens = TASK_TO_MAX_TOKENS[TASK_TO_EVAL_SET[args.task]]
    client = get_client(args.provider)
    for idx, prompt in enumerate(tqdm(prompts, desc="Creating tasks")):
        tasks.append(
            asyncio.create_task(
                fetch_responses(
                    client,
                    f"<|begin_of_text|>{prompt[0]}",
                    semaphore,
                    idx,
                    args.provider,
                    args.model_size,
                    args.output_dir,
                    max_tokens=max_tokens,
                )
            )
        )

    # Run the tasks with tqdm progress bar
    for future in tqdm(
        asyncio.as_completed(tasks), total=len(tasks), desc="Processing tasks"
    ):
        await future


def get_mmlu_answer(response):
    if response is not None:
        return response.choices[0].text.lstrip().rstrip().upper().replace(".", "")
    return None


def get_mmlu_cot_answer(response):
    pattern = r"The best answer is (.+)\.?"
    match = re.search(pattern, response.choices[0].text)
    if match:
        return match.group(1).replace(".", "").replace("*", "")

    pattern = r"the best answer is (.+)\.?"
    match = re.search(pattern, response.choices[0].text)
    if match:
        return match.group(1).replace(".", "")

    pattern = r"The correct answer is (.+)\.?"
    match = re.search(pattern, response.choices[0].text)
    if match:
        return match.group(1).replace(".", "")

    pattern = r"the correct answer is (.+)\.?"
    match = re.search(pattern, response.choices[0].text)
    if match:
        return match.group(1).replace(".", "")


def get_answer_gsm8k(response):
    pattern = r"The final answer is (.+)\.?"
    match = re.search(pattern, response.choices[0].text)
    if match:
        s = match.group(1)
        for ok_symbol in ["%", "$"]:
            s = s.replace(ok_symbol, "")
        return s


TASK_TO_ANSWER_EXTRACTOR = {
    "evals__mmlu__details": get_mmlu_answer,
    "evals__mmlu__0_shot__cot__details": get_mmlu_cot_answer,
    "evals__gsm8k__details": get_answer_gsm8k,
    "evals__mmlu_pro__details": get_mmlu_cot_answer,
}


def get_dataset_from_task(task, response_path, model_size):
    ds_405b = load_dataset(
        f"meta-llama/Llama-3.1-405B-Instruct-evals",
        f"Llama-3.1-405B-Instruct-{task}",
    )
    ds_405b_hash_order = [x[0] for x in ds_405b["latest"]["input_final_prompts_hash"]]

    if "70b" in model_size or "8b" in model_size:
        if "70" in model_size:
            ref_model_ds = load_dataset(
                f"meta-llama/Llama-3.1-70B-Instruct-evals",
                f"Llama-3.1-70B-Instruct-{task}",
            )
        else:
            ref_model_ds = load_dataset(
                f"meta-llama/Llama-3.1-8B-Instruct-evals",
                f"Llama-3.1-8B-Instruct-{task}",
            )

        hash_to_row = {}
        for row in ref_model_ds["latest"]:
            hash_to_row[row["input_final_prompts_hash"][0]] = row
        reordered_rows = []
        for prompt_hash in ds_405b_hash_order:
            reordered_rows.append(hash_to_row[prompt_hash])
        ref_model_ds["latest"] = reordered_rows
        return ref_model_ds

    return ds_405b


def analyze(task, response_path, model_size):
    ds = get_dataset_from_task(task, response_path, model_size)

    responses = []
    total = len(ds["latest"])

    for i in range(0, total):
        response = pickle.load(
            open(os.path.join(response_path, f"response_{i}.pkl"), "rb")
        )
        responses.append(response)

    @dataclass
    class Stats:
        correct: int = 0
        total: int = 0
        meta_correct: int = 0

        average: float = None

    subtask_name_to_stats = defaultdict(lambda: Stats())

    for response, ds_row in zip(responses, ds["latest"]):
        model_answer = TASK_TO_ANSWER_EXTRACTOR[task](response)

        subtask = ds_row["subtask_name"]

        is_eval_correct = model_answer in ds_row["input_correct_responses"]
        if is_eval_correct:
            subtask_name_to_stats[subtask].correct += 1

        if ds_row["is_correct"]:
            subtask_name_to_stats[subtask].meta_correct += 1

        subtask_name_to_stats[subtask].total += 1

    micro_stats = Stats()
    for subtask, stats in subtask_name_to_stats.items():
        stats.average = stats.correct / stats.total
        stats.meta_average = stats.meta_correct / stats.total

        micro_stats.correct += stats.correct
        micro_stats.total += stats.total
        micro_stats.meta_correct += stats.meta_correct

    micro_stats.average = micro_stats.correct / micro_stats.total
    micro_stats.meta_average = micro_stats.meta_correct / micro_stats.total

    print("Macro average", np.mean([x.average for x in subtask_name_to_stats.values()]))
    print(
        "Meta Macro average",
        np.mean([x.meta_average for x in subtask_name_to_stats.values()]),
    )
    print("Micro average", micro_stats.average)
    print("Meta Micro average", micro_stats.meta_average)


# Entry point for the script
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Script to run model with specified parameters."
    )
    parser.add_argument(
        "--model-size",
        type=str,
        default="8b",
        help="Size of the model (e.g., 8b or 70b)",
    )
    parser.add_argument(
        "--provider",
        type=str,
        default="sgl",
        help="Provider name (e.g., sgl, oai, b10)",
    )
    parser.add_argument(
        "--task",
        type=str,
        required=True,
        help="Task (e.g., mmlu, mmlu_cot, mmlu_pro, gsm8k)",
    )
    parser.add_argument(
        "--num-examples", type=int, default=None, help="Number of examples to process"
    )
    parser.add_argument("--concurrency", type=int, default=16)
    parser.add_argument(
        "--output-dir",
        type=str,
        default="tmp-output-dir",
        help="Directory to save responses",
    )

    args = parser.parse_args()
    asyncio.run(benchmark(args))
    analyze(TASK_TO_EVAL_SET[args.task], args.output_dir, args.model_size)
    shutil.rmtree("tmp-output-dir", ignore_errors=True)
sglang/srt/constrained/outlines_backend.py

@@ -42,6 +42,7 @@ class OutlinesGrammar(BaseGrammarObject):
         self.guide = guide
         self.jump_forward_map = jump_forward_map
         self.state = 0
+        self.finished = False

     def accept_token(self, token: int):
         self.state = self.guide.get_next_state(self.state, token)

@@ -84,6 +85,10 @@ class OutlinesGrammar(BaseGrammarObject):
     ) -> torch.Tensor:
         return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)

+    @staticmethod
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return vocab_mask
+
     def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
         tokens = torch.tensor(
             self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
sglang/srt/constrained/xgrammar_backend.py

@@ -45,6 +45,7 @@ class XGrammarGrammar(BaseGrammarObject):
         self.matcher = matcher
         self.vocab_size = vocab_size
         self.ctx = ctx
+        self.finished = False

     def accept_token(self, token: int):
         assert self.matcher.accept_token(token)

@@ -85,12 +86,11 @@ class XGrammarGrammar(BaseGrammarObject):
         self.matcher.fill_next_token_bitmask(vocab_mask, idx)

     @staticmethod
-    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-
-        # vocab_mask must then be on the same device as logits
-        # when applying the token bitmask, so we check and move if needed
-        vocab_mask = vocab_mask.to(logits.device)
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return vocab_mask.to(device, non_blocking=True)

+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
         apply_token_bitmask_inplace(logits, vocab_mask)

     def copy(self):

@@ -117,7 +117,10 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         key_type, key_string = key
         if key_type == "json":
             try:
-                ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
+                if key_string == "$$ANY$$":
+                    ctx = self.grammar_compiler.compile_builtin_json_grammar()
+                else:
+                    ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
             except RuntimeError as e:
                 logging.warning(
                     f"Skip invalid json_schema: json_schema={key_string}, {e=}"
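
Taken together, the two constrained-decoding backends now split mask handling into an explicit transfer step (move_vocab_mask) and the in-place application (apply_vocab_mask): Outlines keeps its boolean mask where it is, while xgrammar issues a non-blocking copy of its bitmask onto the logits' device. The snippet below is a purely illustrative sketch of how a caller might use the pair; it is not the actual sampler code in sglang, and it assumes fill_vocab_mask has already populated the mask.

import torch


def apply_grammar_mask(grammar, logits: torch.Tensor, vocab_mask: torch.Tensor) -> torch.Tensor:
    # `grammar` is an OutlinesGrammar or XGrammarGrammar whose fill_vocab_mask()
    # has already filled `vocab_mask` (possibly on the CPU).
    vocab_mask = grammar.move_vocab_mask(vocab_mask, logits.device)  # no-op for Outlines
    grammar.apply_vocab_mask(logits, vocab_mask)  # masks banned tokens in place
    return logits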
sglang/srt/layers/attention/__init__.py

@@ -52,12 +52,13 @@ class AttentionBackend(ABC):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
     ):
         """Run forward on an attention layer."""
         if forward_batch.forward_mode.is_decode():
-            return self.forward_decode(q, k, v, layer, forward_batch)
+            return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
         else:
-            return self.forward_extend(q, k, v, layer, forward_batch)
+            return self.forward_extend(q, k, v, layer, forward_batch, save_kv_cache)

     def forward_decode(
         self,

@@ -66,6 +67,7 @@ class AttentionBackend(ABC):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
     ):
         """Run a forward for decode."""
         raise NotImplementedError()

@@ -77,6 +79,7 @@ class AttentionBackend(ABC):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
     ):
         """Run a forward for extend."""
         raise NotImplementedError()
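
The concrete backends that follow guard their set_kv_buffer calls with this new flag. The sketch below is illustrative only, not actual sglang call-site code: the helper name run_attention is made up, and the import of ForwardBatch from sglang.srt.model_executor.forward_batch_info is an assumption about the 0.4.0 module layout.

import torch

from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


def run_attention(
    backend: AttentionBackend,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    layer: RadixAttention,
    forward_batch: ForwardBatch,
    write_kv: bool = True,
) -> torch.Tensor:
    # forward() passes the flag through to forward_decode / forward_extend, and
    # the backends below guard token_to_kv_pool.set_kv_buffer(...) with it, so
    # write_kv=False runs attention without writing k/v into the shared KV pool.
    # The default (True) keeps the pre-0.4.0.post2 behavior.
    return backend.forward(q, k, v, layer, forward_batch, save_kv_cache=write_kv)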
sglang/srt/layers/attention/double_sparsity_backend.py

@@ -165,7 +165,13 @@ class DoubleSparseAttnBackend(AttentionBackend):
         return 1

     def forward_extend(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:

@@ -181,9 +187,10 @@ class DoubleSparseAttnBackend(AttentionBackend):
             .expand(k.shape[0], -1, -1),
         )

-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v, k_label
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v, k_label
+            )

         (
             start_loc,

@@ -212,7 +219,13 @@ class DoubleSparseAttnBackend(AttentionBackend):
         return o

     def forward_decode(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.

@@ -242,9 +255,10 @@ class DoubleSparseAttnBackend(AttentionBackend):
             .expand(k.shape[0], -1, -1),
         )

-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v, k_label
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v, k_label
+            )

         # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
         # and set a minimum value for sparse_decode
sglang/srt/layers/attention/flashinfer_backend.py

@@ -221,7 +221,13 @@ class FlashInferAttnBackend(AttentionBackend):
         return 0

     def forward_extend(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         prefill_wrapper_paged = self.prefill_wrappers_paged[
             self._get_wrapper_idx(layer)

@@ -237,7 +243,8 @@ class FlashInferAttnBackend(AttentionBackend):
         if not use_ragged:
             if k is not None:
                 assert v is not None
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

             o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),

@@ -270,12 +277,19 @@ class FlashInferAttnBackend(AttentionBackend):

             o, _ = merge_state(o1, s1, o2, s2)

-            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

     def forward_decode(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
         cache_loc = (

@@ -286,7 +300,8 @@ class FlashInferAttnBackend(AttentionBackend):

         if k is not None:
             assert v is not None
-            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),

@@ -663,6 +678,7 @@ class FlashInferIndicesUpdaterPrefill:
             self.num_qo_heads,
             self.num_kv_heads,
             self.head_dim,
+            q_data_type=self.q_data_type,
         )

         # cached part

@@ -676,6 +692,7 @@ class FlashInferIndicesUpdaterPrefill:
             self.num_kv_heads,
             self.head_dim,
             1,
+            q_data_type=self.q_data_type,
         )

sglang/srt/layers/attention/torch_native_backend.py

@@ -216,16 +216,23 @@ class TorchNativeAttnBackend(AttentionBackend):
         return output

     def forward_extend(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         if layer.qk_head_dim != layer.v_head_dim:
             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
         else:
             o = torch.empty_like(q)

-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )

         use_gqa = layer.tp_q_head_num != layer.tp_k_head_num

@@ -249,7 +256,13 @@ class TorchNativeAttnBackend(AttentionBackend):
         return o

     def forward_decode(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.

@@ -260,9 +273,10 @@ class TorchNativeAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)

-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )

         use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
