sglang 0.4.0.post1__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +1 -0
  3. sglang/bench_serving.py +9 -1
  4. sglang/check_env.py +140 -48
  5. sglang/lang/backend/runtime_endpoint.py +1 -0
  6. sglang/lang/chat_template.py +32 -0
  7. sglang/llama3_eval.py +316 -0
  8. sglang/srt/aio_rwlock.py +100 -0
  9. sglang/srt/configs/model_config.py +8 -1
  10. sglang/srt/constrained/xgrammar_backend.py +4 -1
  11. sglang/srt/layers/attention/flashinfer_backend.py +51 -5
  12. sglang/srt/layers/attention/triton_backend.py +16 -25
  13. sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
  14. sglang/srt/layers/linear.py +20 -2
  15. sglang/srt/layers/logits_processor.py +133 -95
  16. sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +18 -39
  17. sglang/srt/layers/moe/fused_moe_native.py +46 -0
  18. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
  19. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +174 -119
  20. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +17 -49
  21. sglang/srt/layers/moe/topk.py +191 -0
  22. sglang/srt/layers/quantization/__init__.py +5 -50
  23. sglang/srt/layers/quantization/fp8.py +221 -36
  24. sglang/srt/layers/quantization/fp8_kernel.py +278 -0
  25. sglang/srt/layers/quantization/fp8_utils.py +90 -1
  26. sglang/srt/layers/radix_attention.py +8 -1
  27. sglang/srt/layers/sampler.py +27 -5
  28. sglang/srt/layers/torchao_utils.py +31 -0
  29. sglang/srt/managers/detokenizer_manager.py +37 -17
  30. sglang/srt/managers/io_struct.py +39 -10
  31. sglang/srt/managers/schedule_batch.py +54 -34
  32. sglang/srt/managers/schedule_policy.py +64 -5
  33. sglang/srt/managers/scheduler.py +171 -136
  34. sglang/srt/managers/tokenizer_manager.py +184 -133
  35. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  36. sglang/srt/mem_cache/chunk_cache.py +2 -2
  37. sglang/srt/mem_cache/memory_pool.py +15 -8
  38. sglang/srt/mem_cache/radix_cache.py +12 -2
  39. sglang/srt/model_executor/cuda_graph_runner.py +25 -11
  40. sglang/srt/model_executor/model_runner.py +28 -14
  41. sglang/srt/model_parallel.py +66 -5
  42. sglang/srt/models/dbrx.py +1 -1
  43. sglang/srt/models/deepseek.py +1 -1
  44. sglang/srt/models/deepseek_v2.py +67 -18
  45. sglang/srt/models/gemma2.py +34 -0
  46. sglang/srt/models/gemma2_reward.py +0 -1
  47. sglang/srt/models/granite.py +517 -0
  48. sglang/srt/models/grok.py +73 -9
  49. sglang/srt/models/llama.py +22 -0
  50. sglang/srt/models/llama_classification.py +11 -23
  51. sglang/srt/models/llama_reward.py +0 -2
  52. sglang/srt/models/llava.py +37 -14
  53. sglang/srt/models/mixtral.py +2 -2
  54. sglang/srt/models/olmoe.py +1 -1
  55. sglang/srt/models/qwen2.py +20 -0
  56. sglang/srt/models/qwen2_moe.py +1 -1
  57. sglang/srt/models/xverse_moe.py +1 -1
  58. sglang/srt/openai_api/adapter.py +8 -0
  59. sglang/srt/openai_api/protocol.py +9 -4
  60. sglang/srt/server.py +2 -1
  61. sglang/srt/server_args.py +19 -9
  62. sglang/srt/utils.py +40 -54
  63. sglang/test/test_block_fp8.py +341 -0
  64. sglang/test/test_utils.py +3 -2
  65. sglang/utils.py +10 -3
  66. sglang/version.py +1 -1
  67. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/METADATA +12 -7
  68. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/RECORD +73 -67
  69. sglang/srt/layers/fused_moe_patch.py +0 -133
  70. /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
  71. /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
  72. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/LICENSE +0 -0
  73. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/WHEEL +0 -0
  74. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/top_level.txt +0 -0
sglang/llama3_eval.py ADDED
@@ -0,0 +1,316 @@
+ # Adapt from https://github.com/fw-ai/llm_eval_meta
+
+ import argparse
+ import asyncio
+ import os
+ import pickle
+ import re
+ import shutil
+ from collections import defaultdict
+ from dataclasses import dataclass
+
+ import httpx
+ import numpy as np
+ import openai
+ import transformers
+ from datasets import load_dataset
+ from openai import AsyncOpenAI
+ from tqdm import tqdm
+
+ # Mapping providers to their clients and models
+ provider_to_models = {
+     "b10": {
+         "8b": "meta-llama/Llama-3.1-8B-Instruct",
+         "70b": "meta-llama/Llama-3.1-70B-Instruct",
+         "405b": "meta-llama/Llama-3.1-405B-Instruct",
+     },
+     "oai": {
+         "8b": "meta-llama/Llama-3.1-8B-Instruct",
+         "70b": "meta-llama/Llama-3.1-70B-Instruct",
+         "405b": "meta-llama/Llama-3.1-405B-Instruct",
+     },
+     "sgl": {
+         "8b": "meta-llama/Llama-3.1-8B-Instruct",
+         "70b": "meta-llama/Llama-3.1-70B-Instruct",
+         "405b": "meta-llama/Llama-3.1-405B-Instruct",
+     },
+ }
+
+
+ async def fetch_responses(
+     client, prompt, semaphore, index, provider, model_size, output_dir, max_tokens
+ ):
+     output_file = os.path.join(output_dir, f"response_{index}.pkl")
+     if os.path.exists(output_file):
+         print(f"File {output_file} already exists, skipping.")
+         return
+
+     async with semaphore:
+         response = await client.completions.create(
+             model=provider_to_models[provider][model_size],
+             prompt=prompt,
+             temperature=0.0,
+             max_tokens=max_tokens,
+         )
+         if isinstance(response, openai.BadRequestError):
+             with open(output_file, "wb") as f:
+                 pickle.dump("bad_response", f)
+         assert isinstance(response, openai.types.completion.Completion)
+         # Save response to a file
+         with open(output_file, "wb") as f:
+             pickle.dump(response, f)
+
+
+ TASK_TO_MAX_TOKENS = {
+     "evals__mmlu__details": 1,
+     "evals__mmlu__0_shot__cot__details": 1024,
+     # Official meta uses 1024, but a small % (.05) of questions are answered correctly after relaxing
+     "evals__mmlu_pro__details": 2048,
+     "evals__gsm8k__details": 1024,
+ }
+
+ TASK_TO_EVAL_SET = {
+     "mmlu": "evals__mmlu__details",
+     "mmlu_cot": "evals__mmlu__0_shot__cot__details",
+     "mmlu_pro": "evals__mmlu_pro__details",
+     "gsm8k": "evals__gsm8k__details",
+ }
+
+
+ class CustomAsyncHTTPXClient(httpx.AsyncClient):
+     async def send(self, request: httpx.Request, *args, **kwargs) -> httpx.Response:
+         request.url = httpx.URL(
+             f"https://model-{os.getenv('MODEL_ID')}.api.baseten.co/development/predict"
+         )
+         return await super().send(request, *args, **kwargs)
+
+
+ def get_client(provider):
+     if provider not in "b10":
+         if os.getenv("OPENAI_API_KEY") == None:
+             os.environ["OPENAI_API_KEY"] = "EMPTY"
+     return {
+         "oai": AsyncOpenAI(base_url="http://127.0.0.1:8000/v1/"),
+         "b10": AsyncOpenAI(
+             api_key=f"Api-Key {os.getenv('OPENAI_API_KEY')}",
+             base_url=f"https://model-{os.getenv('MODEL_ID')}.api.baseten.co/development/predict",
+             http_client=CustomAsyncHTTPXClient(),
+         ),
+         "sgl": AsyncOpenAI(base_url="http://127.0.0.1:30000/v1/"),
+     }[provider]
+
+
+ # Define the benchmark function
+ async def benchmark(args):
+     ds = load_dataset(
+         "meta-llama/Llama-3.1-405B-Instruct-evals",
+         f"Llama-3.1-405B-Instruct-{TASK_TO_EVAL_SET[args.task]}",
+     )
+     semaphore = asyncio.Semaphore(args.concurrency)  # Limit to 16 concurrent tasks
+
+     if args.num_examples is None:
+         args.num_examples = len(ds["latest"]["input_final_prompts"])
+     prompts = ds["latest"]["input_final_prompts"][: args.num_examples]
+
+     # Create the output directory if it does not exist
+     os.makedirs(args.output_dir, exist_ok=True)
+
+     tasks = []
+     # Create the tasks with tqdm progress bar
+     max_tokens = TASK_TO_MAX_TOKENS[TASK_TO_EVAL_SET[args.task]]
+     client = get_client(args.provider)
+     for idx, prompt in enumerate(tqdm(prompts, desc="Creating tasks")):
+         tasks.append(
+             asyncio.create_task(
+                 fetch_responses(
+                     client,
+                     f"<|begin_of_text|>{prompt[0]}",
+                     semaphore,
+                     idx,
+                     args.provider,
+                     args.model_size,
+                     args.output_dir,
+                     max_tokens=max_tokens,
+                 )
+             )
+         )
+
+     # Run the tasks with tqdm progress bar
+     for future in tqdm(
+         asyncio.as_completed(tasks), total=len(tasks), desc="Processing tasks"
+     ):
+         await future
+
+
+ def get_mmlu_answer(response):
+     if response is not None:
+         return response.choices[0].text.lstrip().rstrip().upper().replace(".", "")
+     return None
+
+
+ def get_mmlu_cot_answer(response):
+     pattern = r"The best answer is (.+)\.?"
+     match = re.search(pattern, response.choices[0].text)
+     if match:
+         return match.group(1).replace(".", "").replace("*", "")
+
+     pattern = r"the best answer is (.+)\.?"
+     match = re.search(pattern, response.choices[0].text)
+     if match:
+         return match.group(1).replace(".", "")
+
+     pattern = r"The correct answer is (.+)\.?"
+     match = re.search(pattern, response.choices[0].text)
+     if match:
+         return match.group(1).replace(".", "")
+
+     pattern = r"the correct answer is (.+)\.?"
+     match = re.search(pattern, response.choices[0].text)
+     if match:
+         return match.group(1).replace(".", "")
+
+
+ def get_answer_gsm8k(response):
+     pattern = r"The final answer is (.+)\.?"
+     match = re.search(pattern, response.choices[0].text)
+     if match:
+         s = match.group(1)
+         for ok_symbol in ["%", "$"]:
+             s = s.replace(ok_symbol, "")
+         return s
+
+
+ TASK_TO_ANSWER_EXTRACTOR = {
+     "evals__mmlu__details": get_mmlu_answer,
+     "evals__mmlu__0_shot__cot__details": get_mmlu_cot_answer,
+     "evals__gsm8k__details": get_answer_gsm8k,
+     "evals__mmlu_pro__details": get_mmlu_cot_answer,
+ }
+
+
+ def get_dataset_from_task(task, response_path, model_size):
+     ds_405b = load_dataset(
+         f"meta-llama/Llama-3.1-405B-Instruct-evals",
+         f"Llama-3.1-405B-Instruct-{task}",
+     )
+     ds_405b_hash_order = [x[0] for x in ds_405b["latest"]["input_final_prompts_hash"]]
+
+     if "70b" in model_size or "8b" in model_size:
+         if "70" in model_size:
+             ref_model_ds = load_dataset(
+                 f"meta-llama/Llama-3.1-70B-Instruct-evals",
+                 f"Llama-3.1-70B-Instruct-{task}",
+             )
+         else:
+             ref_model_ds = load_dataset(
+                 f"meta-llama/Llama-3.1-8B-Instruct-evals",
+                 f"Llama-3.1-8B-Instruct-{task}",
+             )
+
+         hash_to_row = {}
+         for row in ref_model_ds["latest"]:
+             hash_to_row[row["input_final_prompts_hash"][0]] = row
+         reordered_rows = []
+         for prompt_hash in ds_405b_hash_order:
+             reordered_rows.append(hash_to_row[prompt_hash])
+         ref_model_ds["latest"] = reordered_rows
+         return ref_model_ds
+
+     return ds_405b
+
+
+ def analyze(task, response_path, model_size):
+     ds = get_dataset_from_task(task, response_path, model_size)
+
+     responses = []
+     total = len(ds["latest"])
+
+     for i in range(0, total):
+         response = pickle.load(
+             open(os.path.join(response_path, f"response_{i}.pkl"), "rb")
+         )
+         responses.append(response)
+
+     @dataclass
+     class Stats:
+         correct: int = 0
+         total: int = 0
+         meta_correct: int = 0
+
+         average: float = None
+
+     subtask_name_to_stats = defaultdict(lambda: Stats())
+
+     for response, ds_row in zip(responses, ds["latest"]):
+         model_answer = TASK_TO_ANSWER_EXTRACTOR[task](response)
+
+         subtask = ds_row["subtask_name"]
+
+         is_eval_correct = model_answer in ds_row["input_correct_responses"]
+         if is_eval_correct:
+             subtask_name_to_stats[subtask].correct += 1
+
+         if ds_row["is_correct"]:
+             subtask_name_to_stats[subtask].meta_correct += 1
+
+         subtask_name_to_stats[subtask].total += 1
+
+     micro_stats = Stats()
+     for subtask, stats in subtask_name_to_stats.items():
+         stats.average = stats.correct / stats.total
+         stats.meta_average = stats.meta_correct / stats.total
+
+         micro_stats.correct += stats.correct
+         micro_stats.total += stats.total
+         micro_stats.meta_correct += stats.meta_correct
+
+     micro_stats.average = micro_stats.correct / micro_stats.total
+     micro_stats.meta_average = micro_stats.meta_correct / micro_stats.total
+
+     print("Macro average", np.mean([x.average for x in subtask_name_to_stats.values()]))
+     print(
+         "Meta Macro average",
+         np.mean([x.meta_average for x in subtask_name_to_stats.values()]),
+     )
+     print("Micro average", micro_stats.average)
+     print("Meta Micro average", micro_stats.meta_average)
+
+
+ # Entry point for the script
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(
+         description="Script to run model with specified parameters."
+     )
+     parser.add_argument(
+         "--model-size",
+         type=str,
+         default="8b",
+         help="Size of the model (e.g., 8b or 70b)",
+     )
+     parser.add_argument(
+         "--provider",
+         type=str,
+         default="sgl",
+         help="Provider name (e.g., sgl, oai, b10)",
+     )
+     parser.add_argument(
+         "--task",
+         type=str,
+         required=True,
+         help="Task (e.g., mmlu, mmlu_cot, mmlu_pro, gsm8k)",
+     )
+     parser.add_argument(
+         "--num-examples", type=int, default=None, help="Number of examples to process"
+     )
+     parser.add_argument("--concurrency", type=int, default=16)
+     parser.add_argument(
+         "--output-dir",
+         type=str,
+         default="tmp-output-dir",
+         help="Directory to save responses",
+     )
+
+     args = parser.parse_args()
+     asyncio.run(benchmark(args))
+     analyze(TASK_TO_EVAL_SET[args.task], args.output_dir, args.model_size)
+     shutil.rmtree("tmp-output-dir", ignore_errors=True)
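
Not part of the diff: a minimal sketch of how this new eval script could be driven programmatically, assuming an OpenAI-compatible sglang server is already listening on 127.0.0.1:30000 (the base URL hard-coded for the "sgl" provider above). The Namespace fields simply mirror the argparse flags defined in __main__; the chosen task and example count are illustrative.

# Hedged usage sketch (not from the package); assumes a running sglang server.
import argparse
import asyncio

from sglang.llama3_eval import TASK_TO_EVAL_SET, analyze, benchmark

args = argparse.Namespace(
    model_size="8b",
    provider="sgl",
    task="mmlu_cot",
    num_examples=100,
    concurrency=16,
    output_dir="tmp-output-dir",
)
asyncio.run(benchmark(args))                                        # fetch and cache responses
analyze(TASK_TO_EVAL_SET[args.task], args.output_dir, args.model_size)  # print macro/micro accuracy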
sglang/srt/aio_rwlock.py ADDED
@@ -0,0 +1,100 @@
+ import asyncio
+
+
+ class RWLock:
+     def __init__(self):
+         # Protects internal state
+         self._lock = asyncio.Lock()
+
+         # Condition variable used to wait for state changes
+         self._cond = asyncio.Condition(self._lock)
+
+         # Number of readers currently holding the lock
+         self._readers = 0
+
+         # Whether a writer is currently holding the lock
+         self._writer_active = False
+
+         # How many writers are queued waiting for a turn
+         self._waiting_writers = 0
+
+     @property
+     def reader_lock(self):
+         """
+         A context manager for acquiring a shared (reader) lock.
+
+         Example:
+             async with rwlock.reader_lock:
+                 # read-only access
+         """
+         return _ReaderLock(self)
+
+     @property
+     def writer_lock(self):
+         """
+         A context manager for acquiring an exclusive (writer) lock.
+
+         Example:
+             async with rwlock.writer_lock:
+                 # exclusive access
+         """
+         return _WriterLock(self)
+
+     async def acquire_reader(self):
+         async with self._lock:
+             # Wait until there is no active writer or waiting writer
+             # to ensure fairness.
+             while self._writer_active or self._waiting_writers > 0:
+                 await self._cond.wait()
+             self._readers += 1
+
+     async def release_reader(self):
+         async with self._lock:
+             self._readers -= 1
+             # If this was the last reader, wake up anyone waiting
+             # (potentially a writer or new readers).
+             if self._readers == 0:
+                 self._cond.notify_all()
+
+     async def acquire_writer(self):
+         async with self._lock:
+             # Increment the count of writers waiting
+             self._waiting_writers += 1
+             try:
+                 # Wait while either a writer is active or readers are present
+                 while self._writer_active or self._readers > 0:
+                     await self._cond.wait()
+                 self._writer_active = True
+             finally:
+                 # Decrement waiting writers only after we've acquired the writer lock
+                 self._waiting_writers -= 1
+
+     async def release_writer(self):
+         async with self._lock:
+             self._writer_active = False
+             # Wake up anyone waiting (readers or writers)
+             self._cond.notify_all()
+
+
+ class _ReaderLock:
+     def __init__(self, rwlock: RWLock):
+         self._rwlock = rwlock
+
+     async def __aenter__(self):
+         await self._rwlock.acquire_reader()
+         return self
+
+     async def __aexit__(self, exc_type, exc_val, exc_tb):
+         await self._rwlock.release_reader()
+
+
+ class _WriterLock:
+     def __init__(self, rwlock: RWLock):
+         self._rwlock = rwlock
+
+     async def __aenter__(self):
+         await self._rwlock.acquire_writer()
+         return self
+
+     async def __aexit__(self, exc_type, exc_val, exc_tb):
+         await self._rwlock.release_writer()
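
A short usage sketch for the new writer-preferring RWLock, assuming it is imported from sglang.srt.aio_rwlock as added above; the surrounding names (state, reader, writer) are illustrative only and not part of the package.

# Hedged sketch: many concurrent readers, exclusive non-starving writers.
import asyncio

from sglang.srt.aio_rwlock import RWLock

rwlock = RWLock()
state = {"weights_version": 0}


async def reader(i):
    # Multiple readers may hold the lock at the same time.
    async with rwlock.reader_lock:
        return f"reader {i} saw version {state['weights_version']}"


async def writer():
    # A writer waits for readers to drain; queued writers also block new readers,
    # which is what keeps writers from starving.
    async with rwlock.writer_lock:
        state["weights_version"] += 1


async def main():
    print(await asyncio.gather(*(reader(i) for i in range(3)), writer()))


asyncio.run(main())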
sglang/srt/configs/model_config.py CHANGED
@@ -94,7 +94,10 @@ class ModelConfig:
  )
 
  # FIXME: temporary special judge for MLA architecture
- if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
+ if (
+ "DeepseekV2ForCausalLM" in self.hf_config.architectures
+ or "DeepseekV3ForCausalLM" in self.hf_config.architectures
+ ):
  self.head_dim = 256
  self.attention_arch = AttentionArch.MLA
  self.kv_lora_rank = self.hf_config.kv_lora_rank
@@ -124,8 +127,12 @@ class ModelConfig:
  self.num_hidden_layers = self.hf_text_config.num_hidden_layers
  self.vocab_size = self.hf_text_config.vocab_size
 
+ # Veirfy quantization
  self._verify_quantization()
 
+ # Multimodel attrs
+ self.image_token_id = getattr(self.hf_config, "image_token_id", None)
+
  # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
  def get_total_num_kv_heads(self) -> int:
  """Returns the total number of KV heads."""
sglang/srt/constrained/xgrammar_backend.py CHANGED
@@ -117,7 +117,10 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
  key_type, key_string = key
  if key_type == "json":
  try:
- ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
+ if key_string == "$$ANY$$":
+ ctx = self.grammar_compiler.compile_builtin_json_grammar()
+ else:
+ ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
  except RuntimeError as e:
  logging.warning(
  f"Skip invalid json_schema: json_schema={key_string}, {e=}"
sglang/srt/layers/attention/flashinfer_backend.py CHANGED
@@ -18,11 +18,7 @@ import triton.language as tl
  from sglang.global_config import global_config
  from sglang.srt.layers.attention import AttentionBackend
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
- from sglang.srt.utils import (
- get_bool_env_var,
- is_flashinfer_available,
- should_use_tensor_core,
- )
+ from sglang.srt.utils import is_flashinfer_available
 
  if TYPE_CHECKING:
  from sglang.srt.layers.radix_attention import RadixAttention
@@ -678,6 +674,7 @@ class FlashInferIndicesUpdaterPrefill:
  self.num_qo_heads,
  self.num_kv_heads,
  self.head_dim,
+ q_data_type=self.q_data_type,
  )
 
  # cached part
@@ -691,6 +688,7 @@ class FlashInferIndicesUpdaterPrefill:
  self.num_kv_heads,
  self.head_dim,
  1,
+ q_data_type=self.q_data_type,
  )
 
 
@@ -729,3 +727,51 @@ def create_flashinfer_kv_indices_triton(
  mask=mask,
  )
  tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
+
+
+ def should_use_tensor_core(
+ kv_cache_dtype: torch.dtype,
+ num_attention_heads: int,
+ num_kv_heads: int,
+ ) -> bool:
+ """
+ Determine whether to use tensor cores for attention computation.
+
+ Args:
+ kv_cache_dtype: Data type of the KV cache
+ num_attention_heads: Number of attention heads
+ num_kv_heads: Number of key/value heads
+
+ Returns:
+ bool: Whether to use tensor cores
+ """
+ # Try to use environment variable first
+ env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
+ if env_override is not None:
+ return env_override.lower() == "true"
+
+ # Try to use _grouped_size_compiled_for_decode_kernels if available
+ # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
+ try:
+ from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+ if not _grouped_size_compiled_for_decode_kernels(
+ num_attention_heads,
+ num_kv_heads,
+ ):
+ return True
+ else:
+ return False
+ except (ImportError, AttributeError):
+ pass
+
+ # Calculate GQA group size
+ gqa_group_size = num_attention_heads // num_kv_heads
+
+ # Determine based on dtype and GQA group size
+ if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+ return True
+ elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
+ return gqa_group_size > 4
+ else:
+ return False
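
A hedged illustration of the heuristic moved into this backend above, calling the new helper directly; the head counts and dtypes below are just examples and assume the SGLANG_FLASHINFER_USE_TENSOR_CORE environment variable is unset and the flashinfer <=0.1.6 fallback path is not taken.

# Hedged sketch (not from the diff): exercising should_use_tensor_core.
import torch

from sglang.srt.layers.attention.flashinfer_backend import should_use_tensor_core

# FP8 KV cache: tensor cores are always preferred.
print(should_use_tensor_core(torch.float8_e4m3fn, num_attention_heads=32, num_kv_heads=8))  # True

# FP16 KV cache with GQA group size 32 // 8 = 4: not above the threshold of 4.
print(should_use_tensor_core(torch.float16, num_attention_heads=32, num_kv_heads=8))  # False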
sglang/srt/layers/attention/triton_backend.py CHANGED
@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING
  import torch
 
  from sglang.srt.layers.attention import AttentionBackend
- from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
  if TYPE_CHECKING:
@@ -35,10 +34,8 @@ class TritonAttnBackend(AttentionBackend):
  model_runner.model_config.num_attention_heads // model_runner.tp_size
  )
 
- if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
- self.reduce_dtype = torch.float32
- else:
- self.reduce_dtype = torch.float16
+ self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
+ self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
 
  self.forward_metadata = None
 
@@ -50,23 +47,23 @@ class TritonAttnBackend(AttentionBackend):
  """Init auxiliary variables for triton attention backend."""
 
  if forward_batch.forward_mode.is_decode():
- start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
- start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
-
- total_num_tokens = forward_batch.seq_lens_sum
  attn_logits = torch.empty(
- (self.num_head, total_num_tokens),
- dtype=self.reduce_dtype,
+ (
+ forward_batch.batch_size,
+ self.num_head,
+ self.num_kv_splits,
+ self.v_head_dim + 1,
+ ),
+ dtype=torch.float32,
  device=self.device,
  )
 
- max_seq_len = torch.max(forward_batch.seq_lens).item()
  max_extend_len = None
  else:
- start_loc = attn_logits = max_seq_len = None
+ attn_logits = None
  max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
 
- self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
+ self.forward_metadata = attn_logits, max_extend_len
 
  def init_cuda_graph_state(self, max_bs: int):
  self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
@@ -75,11 +72,8 @@ class TritonAttnBackend(AttentionBackend):
  (max_bs,), dtype=torch.int32, device=self.device
  )
  self.cuda_graph_attn_logits = torch.empty(
- (
- self.num_head,
- self.cuda_graph_max_total_num_tokens,
- ),
- dtype=self.reduce_dtype,
+ (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
+ dtype=torch.float32,
  device="cuda",
  )
 
@@ -92,9 +86,7 @@ class TritonAttnBackend(AttentionBackend):
  ):
  # NOTE: encoder_lens expected to be zeros or None
  self.forward_metadata = (
- self.cuda_graph_start_loc,
  self.cuda_graph_attn_logits,
- self.cuda_graph_max_seq_len,
  None,
  )
 
@@ -133,7 +125,7 @@ class TritonAttnBackend(AttentionBackend):
  layer, forward_batch.out_cache_loc, k, v
  )
 
- start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
+ _, max_extend_len = self.forward_metadata
  self.extend_attention_fwd(
  q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
  k.contiguous(),
@@ -171,7 +163,7 @@ class TritonAttnBackend(AttentionBackend):
  else:
  o = torch.empty_like(q)
 
- start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
+ attn_logits, _ = self.forward_metadata
 
  if save_kv_cache:
  forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -185,10 +177,9 @@ class TritonAttnBackend(AttentionBackend):
  o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
  forward_batch.req_to_token_pool.req_to_token,
  forward_batch.req_pool_indices,
- start_loc,
  forward_batch.seq_lens,
  attn_logits,
- max_seq_len,
+ self.num_kv_splits,
  layer.scaling,
  layer.logit_cap,
  )
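
To make the shape change above concrete, a small sketch of the new split-KV decode scratch buffer; the sizes are illustrative only, and num_kv_splits corresponds to the triton_attention_num_kv_splits server argument read in __init__ above.

# Illustrative only: layout of the decode scratch buffer allocated above.
import torch

batch_size, num_head, num_kv_splits, v_head_dim = 8, 32, 8, 128

# One fp32 partial result per (request, head, KV split); the extra final slot is
# used by the decode kernel when combining the per-split partial results.
attn_logits = torch.empty(
    (batch_size, num_head, num_kv_splits, v_head_dim + 1),
    dtype=torch.float32,
)
print(attn_logits.shape)  # torch.Size([8, 32, 8, 129])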