sglang 0.4.0__py3-none-any.whl → 0.4.0.post2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (72)
  1. sglang/__init__.py +1 -1
  2. sglang/bench_offline_throughput.py +18 -6
  3. sglang/bench_one_batch.py +13 -0
  4. sglang/bench_serving.py +8 -1
  5. sglang/check_env.py +140 -48
  6. sglang/lang/backend/runtime_endpoint.py +1 -0
  7. sglang/lang/chat_template.py +32 -0
  8. sglang/llama3_eval.py +316 -0
  9. sglang/srt/constrained/outlines_backend.py +5 -0
  10. sglang/srt/constrained/xgrammar_backend.py +9 -6
  11. sglang/srt/layers/attention/__init__.py +5 -2
  12. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  13. sglang/srt/layers/attention/flashinfer_backend.py +22 -5
  14. sglang/srt/layers/attention/torch_native_backend.py +22 -8
  15. sglang/srt/layers/attention/triton_backend.py +38 -33
  16. sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
  17. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  18. sglang/srt/layers/ep_moe/__init__.py +0 -0
  19. sglang/srt/layers/ep_moe/kernels.py +349 -0
  20. sglang/srt/layers/ep_moe/layer.py +665 -0
  21. sglang/srt/layers/fused_moe_triton/fused_moe.py +64 -21
  22. sglang/srt/layers/fused_moe_triton/layer.py +1 -1
  23. sglang/srt/layers/logits_processor.py +133 -95
  24. sglang/srt/layers/quantization/__init__.py +2 -47
  25. sglang/srt/layers/quantization/fp8.py +607 -0
  26. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  27. sglang/srt/layers/radix_attention.py +11 -2
  28. sglang/srt/layers/sampler.py +29 -5
  29. sglang/srt/layers/torchao_utils.py +58 -45
  30. sglang/srt/managers/detokenizer_manager.py +37 -17
  31. sglang/srt/managers/io_struct.py +39 -10
  32. sglang/srt/managers/schedule_batch.py +39 -24
  33. sglang/srt/managers/schedule_policy.py +64 -5
  34. sglang/srt/managers/scheduler.py +236 -197
  35. sglang/srt/managers/tokenizer_manager.py +99 -58
  36. sglang/srt/managers/tp_worker_overlap_thread.py +7 -5
  37. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  38. sglang/srt/mem_cache/chunk_cache.py +2 -2
  39. sglang/srt/mem_cache/memory_pool.py +5 -1
  40. sglang/srt/mem_cache/radix_cache.py +12 -2
  41. sglang/srt/model_executor/cuda_graph_runner.py +39 -11
  42. sglang/srt/model_executor/model_runner.py +24 -9
  43. sglang/srt/model_parallel.py +67 -10
  44. sglang/srt/models/commandr.py +2 -2
  45. sglang/srt/models/deepseek_v2.py +87 -7
  46. sglang/srt/models/gemma2.py +34 -0
  47. sglang/srt/models/gemma2_reward.py +0 -1
  48. sglang/srt/models/granite.py +517 -0
  49. sglang/srt/models/grok.py +72 -13
  50. sglang/srt/models/llama.py +22 -5
  51. sglang/srt/models/llama_classification.py +11 -23
  52. sglang/srt/models/llama_reward.py +0 -2
  53. sglang/srt/models/llava.py +37 -14
  54. sglang/srt/models/mixtral.py +12 -9
  55. sglang/srt/models/phi3_small.py +0 -5
  56. sglang/srt/models/qwen2.py +20 -0
  57. sglang/srt/models/qwen2_moe.py +0 -5
  58. sglang/srt/models/torch_native_llama.py +0 -5
  59. sglang/srt/openai_api/adapter.py +4 -0
  60. sglang/srt/openai_api/protocol.py +9 -4
  61. sglang/srt/sampling/sampling_batch_info.py +9 -8
  62. sglang/srt/server.py +4 -4
  63. sglang/srt/server_args.py +62 -13
  64. sglang/srt/utils.py +57 -10
  65. sglang/test/test_utils.py +3 -2
  66. sglang/utils.py +10 -3
  67. sglang/version.py +1 -1
  68. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/METADATA +15 -9
  69. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/RECORD +72 -65
  70. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/LICENSE +0 -0
  71. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/WHEEL +0 -0
  72. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/top_level.txt +0 -0
sglang/llama3_eval.py ADDED
@@ -0,0 +1,316 @@
+# Adapt from https://github.com/fw-ai/llm_eval_meta
+
+import argparse
+import asyncio
+import os
+import pickle
+import re
+import shutil
+from collections import defaultdict
+from dataclasses import dataclass
+
+import httpx
+import numpy as np
+import openai
+import transformers
+from datasets import load_dataset
+from openai import AsyncOpenAI
+from tqdm import tqdm
+
+# Mapping providers to their clients and models
+provider_to_models = {
+    "b10": {
+        "8b": "meta-llama/Llama-3.1-8B-Instruct",
+        "70b": "meta-llama/Llama-3.1-70B-Instruct",
+        "405b": "meta-llama/Llama-3.1-405B-Instruct",
+    },
+    "oai": {
+        "8b": "meta-llama/Llama-3.1-8B-Instruct",
+        "70b": "meta-llama/Llama-3.1-70B-Instruct",
+        "405b": "meta-llama/Llama-3.1-405B-Instruct",
+    },
+    "sgl": {
+        "8b": "meta-llama/Llama-3.1-8B-Instruct",
+        "70b": "meta-llama/Llama-3.1-70B-Instruct",
+        "405b": "meta-llama/Llama-3.1-405B-Instruct",
+    },
+}
+
+
+async def fetch_responses(
+    client, prompt, semaphore, index, provider, model_size, output_dir, max_tokens
+):
+    output_file = os.path.join(output_dir, f"response_{index}.pkl")
+    if os.path.exists(output_file):
+        print(f"File {output_file} already exists, skipping.")
+        return
+
+    async with semaphore:
+        response = await client.completions.create(
+            model=provider_to_models[provider][model_size],
+            prompt=prompt,
+            temperature=0.0,
+            max_tokens=max_tokens,
+        )
+        if isinstance(response, openai.BadRequestError):
+            with open(output_file, "wb") as f:
+                pickle.dump("bad_response", f)
+        assert isinstance(response, openai.types.completion.Completion)
+        # Save response to a file
+        with open(output_file, "wb") as f:
+            pickle.dump(response, f)
+
+
+TASK_TO_MAX_TOKENS = {
+    "evals__mmlu__details": 1,
+    "evals__mmlu__0_shot__cot__details": 1024,
+    # Official meta uses 1024, but a small % (.05) of questions are answered correctly after relaxing
+    "evals__mmlu_pro__details": 2048,
+    "evals__gsm8k__details": 1024,
+}
+
+TASK_TO_EVAL_SET = {
+    "mmlu": "evals__mmlu__details",
+    "mmlu_cot": "evals__mmlu__0_shot__cot__details",
+    "mmlu_pro": "evals__mmlu_pro__details",
+    "gsm8k": "evals__gsm8k__details",
+}
+
+
+class CustomAsyncHTTPXClient(httpx.AsyncClient):
+    async def send(self, request: httpx.Request, *args, **kwargs) -> httpx.Response:
+        request.url = httpx.URL(
+            f"https://model-{os.getenv('MODEL_ID')}.api.baseten.co/development/predict"
+        )
+        return await super().send(request, *args, **kwargs)
+
+
+def get_client(provider):
+    if provider not in "b10":
+        if os.getenv("OPENAI_API_KEY") == None:
+            os.environ["OPENAI_API_KEY"] = "EMPTY"
+    return {
+        "oai": AsyncOpenAI(base_url="http://127.0.0.1:8000/v1/"),
+        "b10": AsyncOpenAI(
+            api_key=f"Api-Key {os.getenv('OPENAI_API_KEY')}",
+            base_url=f"https://model-{os.getenv('MODEL_ID')}.api.baseten.co/development/predict",
+            http_client=CustomAsyncHTTPXClient(),
+        ),
+        "sgl": AsyncOpenAI(base_url="http://127.0.0.1:30000/v1/"),
+    }[provider]
+
+
+# Define the benchmark function
+async def benchmark(args):
+    ds = load_dataset(
+        "meta-llama/Llama-3.1-405B-Instruct-evals",
+        f"Llama-3.1-405B-Instruct-{TASK_TO_EVAL_SET[args.task]}",
+    )
+    semaphore = asyncio.Semaphore(args.concurrency)  # Limit to 16 concurrent tasks
+
+    if args.num_examples is None:
+        args.num_examples = len(ds["latest"]["input_final_prompts"])
+    prompts = ds["latest"]["input_final_prompts"][: args.num_examples]
+
+    # Create the output directory if it does not exist
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    tasks = []
+    # Create the tasks with tqdm progress bar
+    max_tokens = TASK_TO_MAX_TOKENS[TASK_TO_EVAL_SET[args.task]]
+    client = get_client(args.provider)
+    for idx, prompt in enumerate(tqdm(prompts, desc="Creating tasks")):
+        tasks.append(
+            asyncio.create_task(
+                fetch_responses(
+                    client,
+                    f"<|begin_of_text|>{prompt[0]}",
+                    semaphore,
+                    idx,
+                    args.provider,
+                    args.model_size,
+                    args.output_dir,
+                    max_tokens=max_tokens,
+                )
+            )
+        )
+
+    # Run the tasks with tqdm progress bar
+    for future in tqdm(
+        asyncio.as_completed(tasks), total=len(tasks), desc="Processing tasks"
+    ):
+        await future
+
+
+def get_mmlu_answer(response):
+    if response is not None:
+        return response.choices[0].text.lstrip().rstrip().upper().replace(".", "")
+    return None
+
+
+def get_mmlu_cot_answer(response):
+    pattern = r"The best answer is (.+)\.?"
+    match = re.search(pattern, response.choices[0].text)
+    if match:
+        return match.group(1).replace(".", "").replace("*", "")
+
+    pattern = r"the best answer is (.+)\.?"
+    match = re.search(pattern, response.choices[0].text)
+    if match:
+        return match.group(1).replace(".", "")
+
+    pattern = r"The correct answer is (.+)\.?"
+    match = re.search(pattern, response.choices[0].text)
+    if match:
+        return match.group(1).replace(".", "")
+
+    pattern = r"the correct answer is (.+)\.?"
+    match = re.search(pattern, response.choices[0].text)
+    if match:
+        return match.group(1).replace(".", "")
+
+
+def get_answer_gsm8k(response):
+    pattern = r"The final answer is (.+)\.?"
+    match = re.search(pattern, response.choices[0].text)
+    if match:
+        s = match.group(1)
+        for ok_symbol in ["%", "$"]:
+            s = s.replace(ok_symbol, "")
+        return s
+
+
+TASK_TO_ANSWER_EXTRACTOR = {
+    "evals__mmlu__details": get_mmlu_answer,
+    "evals__mmlu__0_shot__cot__details": get_mmlu_cot_answer,
+    "evals__gsm8k__details": get_answer_gsm8k,
+    "evals__mmlu_pro__details": get_mmlu_cot_answer,
+}
+
+
+def get_dataset_from_task(task, response_path, model_size):
+    ds_405b = load_dataset(
+        f"meta-llama/Llama-3.1-405B-Instruct-evals",
+        f"Llama-3.1-405B-Instruct-{task}",
+    )
+    ds_405b_hash_order = [x[0] for x in ds_405b["latest"]["input_final_prompts_hash"]]
+
+    if "70b" in model_size or "8b" in model_size:
+        if "70" in model_size:
+            ref_model_ds = load_dataset(
+                f"meta-llama/Llama-3.1-70B-Instruct-evals",
+                f"Llama-3.1-70B-Instruct-{task}",
+            )
+        else:
+            ref_model_ds = load_dataset(
+                f"meta-llama/Llama-3.1-8B-Instruct-evals",
+                f"Llama-3.1-8B-Instruct-{task}",
+            )
+
+        hash_to_row = {}
+        for row in ref_model_ds["latest"]:
+            hash_to_row[row["input_final_prompts_hash"][0]] = row
+        reordered_rows = []
+        for prompt_hash in ds_405b_hash_order:
+            reordered_rows.append(hash_to_row[prompt_hash])
+        ref_model_ds["latest"] = reordered_rows
+        return ref_model_ds
+
+    return ds_405b
+
+
+def analyze(task, response_path, model_size):
+    ds = get_dataset_from_task(task, response_path, model_size)
+
+    responses = []
+    total = len(ds["latest"])
+
+    for i in range(0, total):
+        response = pickle.load(
+            open(os.path.join(response_path, f"response_{i}.pkl"), "rb")
+        )
+        responses.append(response)
+
+    @dataclass
+    class Stats:
+        correct: int = 0
+        total: int = 0
+        meta_correct: int = 0
+
+        average: float = None
+
+    subtask_name_to_stats = defaultdict(lambda: Stats())
+
+    for response, ds_row in zip(responses, ds["latest"]):
+        model_answer = TASK_TO_ANSWER_EXTRACTOR[task](response)
+
+        subtask = ds_row["subtask_name"]
+
+        is_eval_correct = model_answer in ds_row["input_correct_responses"]
+        if is_eval_correct:
+            subtask_name_to_stats[subtask].correct += 1
+
+        if ds_row["is_correct"]:
+            subtask_name_to_stats[subtask].meta_correct += 1
+
+        subtask_name_to_stats[subtask].total += 1
+
+    micro_stats = Stats()
+    for subtask, stats in subtask_name_to_stats.items():
+        stats.average = stats.correct / stats.total
+        stats.meta_average = stats.meta_correct / stats.total
+
+        micro_stats.correct += stats.correct
+        micro_stats.total += stats.total
+        micro_stats.meta_correct += stats.meta_correct
+
+    micro_stats.average = micro_stats.correct / micro_stats.total
+    micro_stats.meta_average = micro_stats.meta_correct / micro_stats.total
+
+    print("Macro average", np.mean([x.average for x in subtask_name_to_stats.values()]))
+    print(
+        "Meta Macro average",
+        np.mean([x.meta_average for x in subtask_name_to_stats.values()]),
+    )
+    print("Micro average", micro_stats.average)
+    print("Meta Micro average", micro_stats.meta_average)
+
+
+# Entry point for the script
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Script to run model with specified parameters."
+    )
+    parser.add_argument(
+        "--model-size",
+        type=str,
+        default="8b",
+        help="Size of the model (e.g., 8b or 70b)",
+    )
+    parser.add_argument(
+        "--provider",
+        type=str,
+        default="sgl",
+        help="Provider name (e.g., sgl, oai, b10)",
+    )
+    parser.add_argument(
+        "--task",
+        type=str,
+        required=True,
+        help="Task (e.g., mmlu, mmlu_cot, mmlu_pro, gsm8k)",
+    )
+    parser.add_argument(
+        "--num-examples", type=int, default=None, help="Number of examples to process"
+    )
+    parser.add_argument("--concurrency", type=int, default=16)
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default="tmp-output-dir",
+        help="Directory to save responses",
+    )
+
+    args = parser.parse_args()
+    asyncio.run(benchmark(args))
+    analyze(TASK_TO_EVAL_SET[args.task], args.output_dir, args.model_size)
+    shutil.rmtree("tmp-output-dir", ignore_errors=True)
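The new sglang/llama3_eval.py script runs Meta's published Llama 3.1 evals (mmlu, mmlu_cot, mmlu_pro, gsm8k) against an OpenAI-compatible endpoint and then scores the pickled responses. As a rough illustration of how its entry points fit together (not part of the diff), the sketch below drives benchmark() and analyze() programmatically; it assumes the module is importable as sglang.llama3_eval and that an SGLang server is already serving the 8B model at the default "sgl" provider URL, http://127.0.0.1:30000/v1/.

```python
# Illustrative sketch only: roughly equivalent to
#   python -m sglang.llama3_eval --task mmlu_cot --model-size 8b --provider sgl \
#       --num-examples 64 --concurrency 16 --output-dir tmp-output-dir
import asyncio
from argparse import Namespace

from sglang.llama3_eval import TASK_TO_EVAL_SET, analyze, benchmark

args = Namespace(
    task="mmlu_cot",            # one of: mmlu, mmlu_cot, mmlu_pro, gsm8k
    model_size="8b",            # picks the model name from provider_to_models
    provider="sgl",             # local AsyncOpenAI client on port 30000
    num_examples=64,            # None runs the full eval set
    concurrency=16,             # size of the asyncio.Semaphore
    output_dir="tmp-output-dir",
)
asyncio.run(benchmark(args))    # writes response_<i>.pkl files to output_dir
analyze(TASK_TO_EVAL_SET[args.task], args.output_dir, args.model_size)
```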
sglang/srt/constrained/outlines_backend.py CHANGED
@@ -42,6 +42,7 @@ class OutlinesGrammar(BaseGrammarObject):
         self.guide = guide
         self.jump_forward_map = jump_forward_map
         self.state = 0
+        self.finished = False
 
     def accept_token(self, token: int):
         self.state = self.guide.get_next_state(self.state, token)
@@ -84,6 +85,10 @@ class OutlinesGrammar(BaseGrammarObject):
     ) -> torch.Tensor:
         return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
 
+    @staticmethod
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return vocab_mask
+
     def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
         tokens = torch.tensor(
             self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
sglang/srt/constrained/xgrammar_backend.py CHANGED
@@ -45,6 +45,7 @@ class XGrammarGrammar(BaseGrammarObject):
         self.matcher = matcher
         self.vocab_size = vocab_size
         self.ctx = ctx
+        self.finished = False
 
     def accept_token(self, token: int):
         assert self.matcher.accept_token(token)
@@ -85,12 +86,11 @@ class XGrammarGrammar(BaseGrammarObject):
         self.matcher.fill_next_token_bitmask(vocab_mask, idx)
 
     @staticmethod
-    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        if vocab_mask.device.type != logits.device.type:
-            # vocab_mask must then be on the same device as logits
-            # when applying the token bitmask, so we check and move if needed
-            vocab_mask = vocab_mask.to(logits.device)
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return vocab_mask.to(device, non_blocking=True)
 
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
         apply_token_bitmask_inplace(logits, vocab_mask)
 
     def copy(self):
@@ -117,7 +117,10 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         key_type, key_string = key
         if key_type == "json":
             try:
-                ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
+                if key_string == "$$ANY$$":
+                    ctx = self.grammar_compiler.compile_builtin_json_grammar()
+                else:
+                    ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
             except RuntimeError as e:
                 logging.warning(
                     f"Skip invalid json_schema: json_schema={key_string}, {e=}"
sglang/srt/layers/attention/__init__.py CHANGED
@@ -52,12 +52,13 @@ class AttentionBackend(ABC):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
     ):
         """Run forward on an attention layer."""
         if forward_batch.forward_mode.is_decode():
-            return self.forward_decode(q, k, v, layer, forward_batch)
+            return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
         else:
-            return self.forward_extend(q, k, v, layer, forward_batch)
+            return self.forward_extend(q, k, v, layer, forward_batch, save_kv_cache)
 
     def forward_decode(
         self,
@@ -66,6 +67,7 @@ class AttentionBackend(ABC):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
     ):
         """Run a forward for decode."""
         raise NotImplementedError()
@@ -77,6 +79,7 @@ class AttentionBackend(ABC):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
     ):
         """Run a forward for extend."""
         raise NotImplementedError()
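Every attention backend's forward()/forward_extend()/forward_decode() now takes a save_kv_cache flag (default True), so callers can run attention without writing k/v into the KV-cache pool; the hunks that follow show the same guard being added around set_kv_buffer() in the double-sparsity, FlashInfer, and torch-native backends. The toy subclass below is illustrative only: it assumes AttentionBackend is importable from sglang.srt.layers.attention, skips the actual attention computation, and simply mirrors the guarding pattern.

```python
import torch

from sglang.srt.layers.attention import AttentionBackend


class ToyAttnBackend(AttentionBackend):
    """Illustrative only: honor save_kv_cache by guarding the KV-cache write."""

    def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True):
        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )
        return torch.empty_like(q)  # placeholder for the real attention output

    def forward_decode(self, q, k, v, layer, forward_batch, save_kv_cache=True):
        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )
        return torch.empty_like(q)  # placeholder for the real attention output
```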
sglang/srt/layers/attention/double_sparsity_backend.py CHANGED
@@ -165,7 +165,13 @@ class DoubleSparseAttnBackend(AttentionBackend):
         return 1
 
     def forward_extend(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
@@ -181,9 +187,10 @@ class DoubleSparseAttnBackend(AttentionBackend):
             .expand(k.shape[0], -1, -1),
         )
 
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v, k_label
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v, k_label
+            )
 
         (
             start_loc,
@@ -212,7 +219,13 @@ class DoubleSparseAttnBackend(AttentionBackend):
         return o
 
     def forward_decode(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -242,9 +255,10 @@ class DoubleSparseAttnBackend(AttentionBackend):
             .expand(k.shape[0], -1, -1),
         )
 
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v, k_label
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v, k_label
+            )
 
         # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
         # and set a minimum value for sparse_decode
sglang/srt/layers/attention/flashinfer_backend.py CHANGED
@@ -221,7 +221,13 @@ class FlashInferAttnBackend(AttentionBackend):
         return 0
 
     def forward_extend(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         prefill_wrapper_paged = self.prefill_wrappers_paged[
             self._get_wrapper_idx(layer)
@@ -237,7 +243,8 @@ class FlashInferAttnBackend(AttentionBackend):
         if not use_ragged:
             if k is not None:
                 assert v is not None
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
 
             o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -270,12 +277,19 @@ class FlashInferAttnBackend(AttentionBackend):
 
             o, _ = merge_state(o1, s1, o2, s2)
 
-            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
     def forward_decode(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
         cache_loc = (
@@ -286,7 +300,8 @@ class FlashInferAttnBackend(AttentionBackend):
 
         if k is not None:
             assert v is not None
-            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
 
         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -663,6 +678,7 @@ class FlashInferIndicesUpdaterPrefill:
             self.num_qo_heads,
             self.num_kv_heads,
             self.head_dim,
+            q_data_type=self.q_data_type,
         )
 
         # cached part
@@ -676,6 +692,7 @@ class FlashInferIndicesUpdaterPrefill:
             self.num_kv_heads,
             self.head_dim,
             1,
+            q_data_type=self.q_data_type,
         )
sglang/srt/layers/attention/torch_native_backend.py CHANGED
@@ -216,16 +216,23 @@ class TorchNativeAttnBackend(AttentionBackend):
         return output
 
     def forward_extend(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         if layer.qk_head_dim != layer.v_head_dim:
             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
         else:
             o = torch.empty_like(q)
 
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
 
         use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
 
@@ -249,7 +256,13 @@ class TorchNativeAttnBackend(AttentionBackend):
         return o
 
     def forward_decode(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -260,9 +273,10 @@ class TorchNativeAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
 
         use_gqa = layer.tp_q_head_num != layer.tp_k_head_num