sglang 0.4.0.post1__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +1 -0
- sglang/bench_serving.py +9 -1
- sglang/check_env.py +140 -48
- sglang/lang/backend/runtime_endpoint.py +1 -0
- sglang/lang/chat_template.py +32 -0
- sglang/llama3_eval.py +316 -0
- sglang/srt/aio_rwlock.py +100 -0
- sglang/srt/configs/model_config.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +4 -1
- sglang/srt/layers/attention/flashinfer_backend.py +51 -5
- sglang/srt/layers/attention/triton_backend.py +16 -25
- sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
- sglang/srt/layers/linear.py +20 -2
- sglang/srt/layers/logits_processor.py +133 -95
- sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +18 -39
- sglang/srt/layers/moe/fused_moe_native.py +46 -0
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +174 -119
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +17 -49
- sglang/srt/layers/moe/topk.py +191 -0
- sglang/srt/layers/quantization/__init__.py +5 -50
- sglang/srt/layers/quantization/fp8.py +221 -36
- sglang/srt/layers/quantization/fp8_kernel.py +278 -0
- sglang/srt/layers/quantization/fp8_utils.py +90 -1
- sglang/srt/layers/radix_attention.py +8 -1
- sglang/srt/layers/sampler.py +27 -5
- sglang/srt/layers/torchao_utils.py +31 -0
- sglang/srt/managers/detokenizer_manager.py +37 -17
- sglang/srt/managers/io_struct.py +39 -10
- sglang/srt/managers/schedule_batch.py +54 -34
- sglang/srt/managers/schedule_policy.py +64 -5
- sglang/srt/managers/scheduler.py +171 -136
- sglang/srt/managers/tokenizer_manager.py +184 -133
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +2 -2
- sglang/srt/mem_cache/memory_pool.py +15 -8
- sglang/srt/mem_cache/radix_cache.py +12 -2
- sglang/srt/model_executor/cuda_graph_runner.py +25 -11
- sglang/srt/model_executor/model_runner.py +28 -14
- sglang/srt/model_parallel.py +66 -5
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +67 -18
- sglang/srt/models/gemma2.py +34 -0
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/granite.py +517 -0
- sglang/srt/models/grok.py +73 -9
- sglang/srt/models/llama.py +22 -0
- sglang/srt/models/llama_classification.py +11 -23
- sglang/srt/models/llama_reward.py +0 -2
- sglang/srt/models/llava.py +37 -14
- sglang/srt/models/mixtral.py +2 -2
- sglang/srt/models/olmoe.py +1 -1
- sglang/srt/models/qwen2.py +20 -0
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/xverse_moe.py +1 -1
- sglang/srt/openai_api/adapter.py +8 -0
- sglang/srt/openai_api/protocol.py +9 -4
- sglang/srt/server.py +2 -1
- sglang/srt/server_args.py +19 -9
- sglang/srt/utils.py +40 -54
- sglang/test/test_block_fp8.py +341 -0
- sglang/test/test_utils.py +3 -2
- sglang/utils.py +10 -3
- sglang/version.py +1 -1
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/METADATA +12 -7
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/RECORD +73 -67
- sglang/srt/layers/fused_moe_patch.py +0 -133
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/top_level.txt +0 -0
sglang/llama3_eval.py
ADDED
@@ -0,0 +1,316 @@
+# Adapt from https://github.com/fw-ai/llm_eval_meta
+
+import argparse
+import asyncio
+import os
+import pickle
+import re
+import shutil
+from collections import defaultdict
+from dataclasses import dataclass
+
+import httpx
+import numpy as np
+import openai
+import transformers
+from datasets import load_dataset
+from openai import AsyncOpenAI
+from tqdm import tqdm
+
+# Mapping providers to their clients and models
+provider_to_models = {
+    "b10": {
+        "8b": "meta-llama/Llama-3.1-8B-Instruct",
+        "70b": "meta-llama/Llama-3.1-70B-Instruct",
+        "405b": "meta-llama/Llama-3.1-405B-Instruct",
+    },
+    "oai": {
+        "8b": "meta-llama/Llama-3.1-8B-Instruct",
+        "70b": "meta-llama/Llama-3.1-70B-Instruct",
+        "405b": "meta-llama/Llama-3.1-405B-Instruct",
+    },
+    "sgl": {
+        "8b": "meta-llama/Llama-3.1-8B-Instruct",
+        "70b": "meta-llama/Llama-3.1-70B-Instruct",
+        "405b": "meta-llama/Llama-3.1-405B-Instruct",
+    },
+}
+
+
+async def fetch_responses(
+    client, prompt, semaphore, index, provider, model_size, output_dir, max_tokens
+):
+    output_file = os.path.join(output_dir, f"response_{index}.pkl")
+    if os.path.exists(output_file):
+        print(f"File {output_file} already exists, skipping.")
+        return
+
+    async with semaphore:
+        response = await client.completions.create(
+            model=provider_to_models[provider][model_size],
+            prompt=prompt,
+            temperature=0.0,
+            max_tokens=max_tokens,
+        )
+        if isinstance(response, openai.BadRequestError):
+            with open(output_file, "wb") as f:
+                pickle.dump("bad_response", f)
+        assert isinstance(response, openai.types.completion.Completion)
+        # Save response to a file
+        with open(output_file, "wb") as f:
+            pickle.dump(response, f)
+
+
+TASK_TO_MAX_TOKENS = {
+    "evals__mmlu__details": 1,
+    "evals__mmlu__0_shot__cot__details": 1024,
+    # Official meta uses 1024, but a small % (.05) of questions are answered correctly after relaxing
+    "evals__mmlu_pro__details": 2048,
+    "evals__gsm8k__details": 1024,
+}
+
+TASK_TO_EVAL_SET = {
+    "mmlu": "evals__mmlu__details",
+    "mmlu_cot": "evals__mmlu__0_shot__cot__details",
+    "mmlu_pro": "evals__mmlu_pro__details",
+    "gsm8k": "evals__gsm8k__details",
+}
+
+
+class CustomAsyncHTTPXClient(httpx.AsyncClient):
+    async def send(self, request: httpx.Request, *args, **kwargs) -> httpx.Response:
+        request.url = httpx.URL(
+            f"https://model-{os.getenv('MODEL_ID')}.api.baseten.co/development/predict"
+        )
+        return await super().send(request, *args, **kwargs)
+
+
+def get_client(provider):
+    if provider not in "b10":
+        if os.getenv("OPENAI_API_KEY") == None:
+            os.environ["OPENAI_API_KEY"] = "EMPTY"
+    return {
+        "oai": AsyncOpenAI(base_url="http://127.0.0.1:8000/v1/"),
+        "b10": AsyncOpenAI(
+            api_key=f"Api-Key {os.getenv('OPENAI_API_KEY')}",
+            base_url=f"https://model-{os.getenv('MODEL_ID')}.api.baseten.co/development/predict",
+            http_client=CustomAsyncHTTPXClient(),
+        ),
+        "sgl": AsyncOpenAI(base_url="http://127.0.0.1:30000/v1/"),
+    }[provider]
+
+
+# Define the benchmark function
+async def benchmark(args):
+    ds = load_dataset(
+        "meta-llama/Llama-3.1-405B-Instruct-evals",
+        f"Llama-3.1-405B-Instruct-{TASK_TO_EVAL_SET[args.task]}",
+    )
+    semaphore = asyncio.Semaphore(args.concurrency)  # Limit to 16 concurrent tasks
+
+    if args.num_examples is None:
+        args.num_examples = len(ds["latest"]["input_final_prompts"])
+    prompts = ds["latest"]["input_final_prompts"][: args.num_examples]
+
+    # Create the output directory if it does not exist
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    tasks = []
+    # Create the tasks with tqdm progress bar
+    max_tokens = TASK_TO_MAX_TOKENS[TASK_TO_EVAL_SET[args.task]]
+    client = get_client(args.provider)
+    for idx, prompt in enumerate(tqdm(prompts, desc="Creating tasks")):
+        tasks.append(
+            asyncio.create_task(
+                fetch_responses(
+                    client,
+                    f"<|begin_of_text|>{prompt[0]}",
+                    semaphore,
+                    idx,
+                    args.provider,
+                    args.model_size,
+                    args.output_dir,
+                    max_tokens=max_tokens,
+                )
+            )
+        )
+
+    # Run the tasks with tqdm progress bar
+    for future in tqdm(
+        asyncio.as_completed(tasks), total=len(tasks), desc="Processing tasks"
+    ):
+        await future
+
+
+def get_mmlu_answer(response):
+    if response is not None:
+        return response.choices[0].text.lstrip().rstrip().upper().replace(".", "")
+    return None
+
+
+def get_mmlu_cot_answer(response):
+    pattern = r"The best answer is (.+)\.?"
+    match = re.search(pattern, response.choices[0].text)
+    if match:
+        return match.group(1).replace(".", "").replace("*", "")
+
+    pattern = r"the best answer is (.+)\.?"
+    match = re.search(pattern, response.choices[0].text)
+    if match:
+        return match.group(1).replace(".", "")
+
+    pattern = r"The correct answer is (.+)\.?"
+    match = re.search(pattern, response.choices[0].text)
+    if match:
+        return match.group(1).replace(".", "")
+
+    pattern = r"the correct answer is (.+)\.?"
+    match = re.search(pattern, response.choices[0].text)
+    if match:
+        return match.group(1).replace(".", "")
+
+
+def get_answer_gsm8k(response):
+    pattern = r"The final answer is (.+)\.?"
+    match = re.search(pattern, response.choices[0].text)
+    if match:
+        s = match.group(1)
+        for ok_symbol in ["%", "$"]:
+            s = s.replace(ok_symbol, "")
+        return s
+
+
+TASK_TO_ANSWER_EXTRACTOR = {
+    "evals__mmlu__details": get_mmlu_answer,
+    "evals__mmlu__0_shot__cot__details": get_mmlu_cot_answer,
+    "evals__gsm8k__details": get_answer_gsm8k,
+    "evals__mmlu_pro__details": get_mmlu_cot_answer,
+}
+
+
+def get_dataset_from_task(task, response_path, model_size):
+    ds_405b = load_dataset(
+        f"meta-llama/Llama-3.1-405B-Instruct-evals",
+        f"Llama-3.1-405B-Instruct-{task}",
+    )
+    ds_405b_hash_order = [x[0] for x in ds_405b["latest"]["input_final_prompts_hash"]]
+
+    if "70b" in model_size or "8b" in model_size:
+        if "70" in model_size:
+            ref_model_ds = load_dataset(
+                f"meta-llama/Llama-3.1-70B-Instruct-evals",
+                f"Llama-3.1-70B-Instruct-{task}",
+            )
+        else:
+            ref_model_ds = load_dataset(
+                f"meta-llama/Llama-3.1-8B-Instruct-evals",
+                f"Llama-3.1-8B-Instruct-{task}",
+            )
+
+        hash_to_row = {}
+        for row in ref_model_ds["latest"]:
+            hash_to_row[row["input_final_prompts_hash"][0]] = row
+        reordered_rows = []
+        for prompt_hash in ds_405b_hash_order:
+            reordered_rows.append(hash_to_row[prompt_hash])
+        ref_model_ds["latest"] = reordered_rows
+        return ref_model_ds
+
+    return ds_405b
+
+
+def analyze(task, response_path, model_size):
+    ds = get_dataset_from_task(task, response_path, model_size)
+
+    responses = []
+    total = len(ds["latest"])
+
+    for i in range(0, total):
+        response = pickle.load(
+            open(os.path.join(response_path, f"response_{i}.pkl"), "rb")
+        )
+        responses.append(response)
+
+    @dataclass
+    class Stats:
+        correct: int = 0
+        total: int = 0
+        meta_correct: int = 0
+
+        average: float = None
+
+    subtask_name_to_stats = defaultdict(lambda: Stats())
+
+    for response, ds_row in zip(responses, ds["latest"]):
+        model_answer = TASK_TO_ANSWER_EXTRACTOR[task](response)
+
+        subtask = ds_row["subtask_name"]
+
+        is_eval_correct = model_answer in ds_row["input_correct_responses"]
+        if is_eval_correct:
+            subtask_name_to_stats[subtask].correct += 1
+
+        if ds_row["is_correct"]:
+            subtask_name_to_stats[subtask].meta_correct += 1
+
+        subtask_name_to_stats[subtask].total += 1
+
+    micro_stats = Stats()
+    for subtask, stats in subtask_name_to_stats.items():
+        stats.average = stats.correct / stats.total
+        stats.meta_average = stats.meta_correct / stats.total
+
+        micro_stats.correct += stats.correct
+        micro_stats.total += stats.total
+        micro_stats.meta_correct += stats.meta_correct
+
+    micro_stats.average = micro_stats.correct / micro_stats.total
+    micro_stats.meta_average = micro_stats.meta_correct / micro_stats.total
+
+    print("Macro average", np.mean([x.average for x in subtask_name_to_stats.values()]))
+    print(
+        "Meta Macro average",
+        np.mean([x.meta_average for x in subtask_name_to_stats.values()]),
+    )
+    print("Micro average", micro_stats.average)
+    print("Meta Micro average", micro_stats.meta_average)
+
+
+# Entry point for the script
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Script to run model with specified parameters."
+    )
+    parser.add_argument(
+        "--model-size",
+        type=str,
+        default="8b",
+        help="Size of the model (e.g., 8b or 70b)",
+    )
+    parser.add_argument(
+        "--provider",
+        type=str,
+        default="sgl",
+        help="Provider name (e.g., sgl, oai, b10)",
+    )
+    parser.add_argument(
+        "--task",
+        type=str,
+        required=True,
+        help="Task (e.g., mmlu, mmlu_cot, mmlu_pro, gsm8k)",
+    )
+    parser.add_argument(
+        "--num-examples", type=int, default=None, help="Number of examples to process"
+    )
+    parser.add_argument("--concurrency", type=int, default=16)
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default="tmp-output-dir",
+        help="Directory to save responses",
+    )
+
+    args = parser.parse_args()
+    asyncio.run(benchmark(args))
+    analyze(TASK_TO_EVAL_SET[args.task], args.output_dir, args.model_size)
+    shutil.rmtree("tmp-output-dir", ignore_errors=True)
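Usage note (not part of the diff): the new evaluation script is a standalone CLI driven by the argparse flags above, so with an SGLang server already serving the model on the default "sgl" endpoint (http://127.0.0.1:30000/v1/), a small GSM8K run might look like the line below; invoking it as a module is an assumption based on the file's location inside the wheel.

python -m sglang.llama3_eval --task gsm8k --model-size 8b --provider sgl --num-examples 32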
sglang/srt/aio_rwlock.py
ADDED
@@ -0,0 +1,100 @@
+import asyncio
+
+
+class RWLock:
+    def __init__(self):
+        # Protects internal state
+        self._lock = asyncio.Lock()
+
+        # Condition variable used to wait for state changes
+        self._cond = asyncio.Condition(self._lock)
+
+        # Number of readers currently holding the lock
+        self._readers = 0
+
+        # Whether a writer is currently holding the lock
+        self._writer_active = False
+
+        # How many writers are queued waiting for a turn
+        self._waiting_writers = 0
+
+    @property
+    def reader_lock(self):
+        """
+        A context manager for acquiring a shared (reader) lock.
+
+        Example:
+            async with rwlock.reader_lock:
+                # read-only access
+        """
+        return _ReaderLock(self)
+
+    @property
+    def writer_lock(self):
+        """
+        A context manager for acquiring an exclusive (writer) lock.
+
+        Example:
+            async with rwlock.writer_lock:
+                # exclusive access
+        """
+        return _WriterLock(self)
+
+    async def acquire_reader(self):
+        async with self._lock:
+            # Wait until there is no active writer or waiting writer
+            # to ensure fairness.
+            while self._writer_active or self._waiting_writers > 0:
+                await self._cond.wait()
+            self._readers += 1
+
+    async def release_reader(self):
+        async with self._lock:
+            self._readers -= 1
+            # If this was the last reader, wake up anyone waiting
+            # (potentially a writer or new readers).
+            if self._readers == 0:
+                self._cond.notify_all()
+
+    async def acquire_writer(self):
+        async with self._lock:
+            # Increment the count of writers waiting
+            self._waiting_writers += 1
+            try:
+                # Wait while either a writer is active or readers are present
+                while self._writer_active or self._readers > 0:
+                    await self._cond.wait()
+                self._writer_active = True
+            finally:
+                # Decrement waiting writers only after we've acquired the writer lock
+                self._waiting_writers -= 1
+
+    async def release_writer(self):
+        async with self._lock:
+            self._writer_active = False
+            # Wake up anyone waiting (readers or writers)
+            self._cond.notify_all()
+
+
+class _ReaderLock:
+    def __init__(self, rwlock: RWLock):
+        self._rwlock = rwlock
+
+    async def __aenter__(self):
+        await self._rwlock.acquire_reader()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await self._rwlock.release_reader()
+
+
+class _WriterLock:
+    def __init__(self, rwlock: RWLock):
+        self._rwlock = rwlock
+
+    async def __aenter__(self):
+        await self._rwlock.acquire_writer()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await self._rwlock.release_writer()
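For orientation (not part of the diff): a minimal sketch of how the new asyncio reader-writer lock can be used. The import path simply mirrors the new file's location; note that waiting writers take priority over newly arriving readers, per the acquire_reader wait condition above.

import asyncio

from sglang.srt.aio_rwlock import RWLock


async def main():
    rwlock = RWLock()
    state = {"version": 0}

    async def reader(i: int):
        async with rwlock.reader_lock:
            # Shared access: several readers may hold the lock at the same time.
            print(f"reader {i} sees version {state['version']}")

    async def writer():
        async with rwlock.writer_lock:
            # Exclusive access: no readers or other writers run concurrently.
            state["version"] += 1

    await asyncio.gather(*(reader(i) for i in range(3)), writer())


asyncio.run(main())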
sglang/srt/configs/model_config.py
CHANGED
@@ -94,7 +94,10 @@ class ModelConfig:
         )
 
         # FIXME: temporary special judge for MLA architecture
-        if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
+        if (
+            "DeepseekV2ForCausalLM" in self.hf_config.architectures
+            or "DeepseekV3ForCausalLM" in self.hf_config.architectures
+        ):
             self.head_dim = 256
             self.attention_arch = AttentionArch.MLA
             self.kv_lora_rank = self.hf_config.kv_lora_rank
@@ -124,8 +127,12 @@ class ModelConfig:
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
         self.vocab_size = self.hf_text_config.vocab_size
 
+        # Veirfy quantization
         self._verify_quantization()
 
+        # Multimodel attrs
+        self.image_token_id = getattr(self.hf_config, "image_token_id", None)
+
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
sglang/srt/constrained/xgrammar_backend.py
CHANGED
@@ -117,7 +117,10 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         key_type, key_string = key
         if key_type == "json":
             try:
-                ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
+                if key_string == "$$ANY$$":
+                    ctx = self.grammar_compiler.compile_builtin_json_grammar()
+                else:
+                    ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
             except RuntimeError as e:
                 logging.warning(
                     f"Skip invalid json_schema: json_schema={key_string}, {e=}"
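In other words (illustration only, not code from the package): the "json" key now treats the literal string "$$ANY$$" as a sentinel meaning "any valid JSON" and compiles xgrammar's built-in JSON grammar instead of a schema-specific one. A hypothetical helper mirroring the new branch, assuming only the two compiler methods named above:

def compile_json_key(grammar_compiler, key_string: str):
    # Sentinel: accept any syntactically valid JSON document.
    if key_string == "$$ANY$$":
        return grammar_compiler.compile_builtin_json_grammar()
    # Otherwise constrain generation to the supplied JSON schema.
    return grammar_compiler.compile_json_schema(schema=key_string)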
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -18,11 +18,7 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import (
-    get_bool_env_var,
-    is_flashinfer_available,
-    should_use_tensor_core,
-)
+from sglang.srt.utils import is_flashinfer_available
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -678,6 +674,7 @@ class FlashInferIndicesUpdaterPrefill:
             self.num_qo_heads,
             self.num_kv_heads,
             self.head_dim,
+            q_data_type=self.q_data_type,
         )
 
         # cached part
@@ -691,6 +688,7 @@ class FlashInferIndicesUpdaterPrefill:
             self.num_kv_heads,
             self.head_dim,
             1,
+            q_data_type=self.q_data_type,
         )
 
 
@@ -729,3 +727,51 @@ def create_flashinfer_kv_indices_triton(
         mask=mask,
     )
     tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
+
+
+def should_use_tensor_core(
+    kv_cache_dtype: torch.dtype,
+    num_attention_heads: int,
+    num_kv_heads: int,
+) -> bool:
+    """
+    Determine whether to use tensor cores for attention computation.
+
+    Args:
+        kv_cache_dtype: Data type of the KV cache
+        num_attention_heads: Number of attention heads
+        num_kv_heads: Number of key/value heads
+
+    Returns:
+        bool: Whether to use tensor cores
+    """
+    # Try to use environment variable first
+    env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
+    if env_override is not None:
+        return env_override.lower() == "true"
+
+    # Try to use _grouped_size_compiled_for_decode_kernels if available
+    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
+    try:
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        if not _grouped_size_compiled_for_decode_kernels(
+            num_attention_heads,
+            num_kv_heads,
+        ):
+            return True
+        else:
+            return False
+    except (ImportError, AttributeError):
+        pass
+
+    # Calculate GQA group size
+    gqa_group_size = num_attention_heads // num_kv_heads
+
+    # Determine based on dtype and GQA group size
+    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        return True
+    elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
+        return gqa_group_size > 4
+    else:
+        return False
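A quick sketch of how the relocated should_use_tensor_core heuristic behaves (illustrative head counts; assumes a recent flashinfer so the legacy _grouped_size_compiled_for_decode_kernels probe is skipped, and no SGLANG_FLASHINFER_USE_TENSOR_CORE override is set):

import torch

from sglang.srt.layers.attention.flashinfer_backend import should_use_tensor_core

# GQA group size 32 // 8 = 4 with an fp16 KV cache: not > 4, so False.
print(should_use_tensor_core(torch.float16, num_attention_heads=32, num_kv_heads=8))

# An fp8 KV cache always opts into tensor cores: True.
print(should_use_tensor_core(torch.float8_e4m3fn, num_attention_heads=32, num_kv_heads=8))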
sglang/srt/layers/attention/triton_backend.py
CHANGED
@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING
 import torch
 
 from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 if TYPE_CHECKING:
@@ -35,10 +34,8 @@ class TritonAttnBackend(AttentionBackend):
             model_runner.model_config.num_attention_heads // model_runner.tp_size
         )
 
-        if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
-            self.reduce_dtype = torch.float32
-        else:
-            self.reduce_dtype = torch.float16
+        self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
+        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
 
         self.forward_metadata = None
 
@@ -50,23 +47,23 @@ class TritonAttnBackend(AttentionBackend):
         """Init auxiliary variables for triton attention backend."""
 
         if forward_batch.forward_mode.is_decode():
-            start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
-            start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
-
-            total_num_tokens = forward_batch.seq_lens_sum
             attn_logits = torch.empty(
-                (self.num_head, total_num_tokens),
-                dtype=self.reduce_dtype,
+                (
+                    forward_batch.batch_size,
+                    self.num_head,
+                    self.num_kv_splits,
+                    self.v_head_dim + 1,
+                ),
+                dtype=torch.float32,
                 device=self.device,
             )
 
-            max_seq_len = torch.max(forward_batch.seq_lens).item()
             max_extend_len = None
         else:
-            start_loc = attn_logits = max_seq_len = None
+            attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
 
-        self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
+        self.forward_metadata = attn_logits, max_extend_len
 
     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
@@ -75,11 +72,8 @@ class TritonAttnBackend(AttentionBackend):
             (max_bs,), dtype=torch.int32, device=self.device
         )
         self.cuda_graph_attn_logits = torch.empty(
-            (
-                self.num_head,
-                self.cuda_graph_max_total_num_tokens,
-            ),
-            dtype=self.reduce_dtype,
+            (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
+            dtype=torch.float32,
             device="cuda",
         )
 
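Rough sizing note (illustrative numbers, not from the diff): the Triton decode path now allocates a float32 buffer of shape (batch_size, num_head, num_kv_splits, v_head_dim + 1), where each of the num_kv_splits slices presumably holds a partial value output plus one extra slot for its running log-sum-exp.

import torch

# Illustrative sizes: batch 64, 32 query heads, 8 KV splits, 128-dim value heads.
batch_size, num_head, num_kv_splits, v_head_dim = 64, 32, 8, 128
attn_logits = torch.empty(
    (batch_size, num_head, num_kv_splits, v_head_dim + 1),
    dtype=torch.float32,
)
print(f"{attn_logits.numel() * attn_logits.element_size() / 1e6:.1f} MB")  # ~8.5 MB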
@@ -92,9 +86,7 @@ class TritonAttnBackend(AttentionBackend):
     ):
         # NOTE: encoder_lens expected to be zeros or None
         self.forward_metadata = (
-            self.cuda_graph_start_loc,
             self.cuda_graph_attn_logits,
-            self.cuda_graph_max_seq_len,
             None,
         )
 
@@ -133,7 +125,7 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )
 
-        _, _, _, max_extend_len = self.forward_metadata
+        _, max_extend_len = self.forward_metadata
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -171,7 +163,7 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
-        start_loc, attn_logits, max_seq_len, _ = self.forward_metadata
+        attn_logits, _ = self.forward_metadata
 
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -185,10 +177,9 @@ class TritonAttnBackend(AttentionBackend):
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
             forward_batch.req_to_token_pool.req_to_token,
             forward_batch.req_pool_indices,
-            start_loc,
             forward_batch.seq_lens,
             attn_logits,
-            max_seq_len,
+            self.num_kv_splits,
             layer.scaling,
             layer.logit_cap,
         )