sglang 0.3.3.post1__py3-none-any.whl → 0.3.4.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +30 -11
- sglang/bench_server_latency.py +21 -10
- sglang/bench_serving.py +101 -7
- sglang/global_config.py +0 -1
- sglang/lang/chat_template.py +17 -0
- sglang/launch_server_llavavid.py +1 -1
- sglang/srt/configs/__init__.py +3 -0
- sglang/srt/configs/model_config.py +2 -0
- sglang/srt/configs/qwen2vl.py +133 -0
- sglang/srt/conversation.py +27 -0
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/attention/__init__.py +38 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +297 -0
- sglang/srt/layers/attention/flashinfer_backend.py +486 -97
- sglang/srt/layers/attention/triton_backend.py +26 -8
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +30 -6
- sglang/srt/layers/linear.py +89 -63
- sglang/srt/layers/rotary_embedding.py +145 -0
- sglang/srt/layers/sampler.py +6 -2
- sglang/srt/lora/lora.py +3 -1
- sglang/srt/managers/detokenizer_manager.py +31 -10
- sglang/srt/managers/image_processor.py +186 -13
- sglang/srt/managers/io_struct.py +4 -0
- sglang/srt/managers/schedule_batch.py +319 -82
- sglang/srt/managers/schedule_policy.py +2 -1
- sglang/srt/managers/scheduler.py +233 -158
- sglang/srt/managers/tokenizer_manager.py +15 -5
- sglang/srt/managers/tp_worker.py +30 -5
- sglang/srt/managers/tp_worker_overlap_thread.py +172 -0
- sglang/srt/mem_cache/chunk_cache.py +8 -4
- sglang/srt/mem_cache/memory_pool.py +123 -11
- sglang/srt/mem_cache/radix_cache.py +19 -10
- sglang/srt/model_executor/cuda_graph_runner.py +63 -12
- sglang/srt/model_executor/forward_batch_info.py +101 -23
- sglang/srt/model_executor/model_runner.py +92 -12
- sglang/srt/models/baichuan.py +2 -3
- sglang/srt/models/chatglm.py +8 -9
- sglang/srt/models/commandr.py +1 -2
- sglang/srt/models/dbrx.py +1 -2
- sglang/srt/models/deepseek.py +4 -5
- sglang/srt/models/deepseek_v2.py +7 -8
- sglang/srt/models/exaone.py +1 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +5 -5
- sglang/srt/models/gpt_bigcode.py +5 -5
- sglang/srt/models/grok.py +1 -2
- sglang/srt/models/internlm2.py +1 -2
- sglang/srt/models/llama.py +1 -2
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +4 -8
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -2
- sglang/srt/models/minicpm3.py +5 -6
- sglang/srt/models/mixtral.py +1 -2
- sglang/srt/models/mixtral_quant.py +1 -2
- sglang/srt/models/mllama.py +1004 -0
- sglang/srt/models/olmo.py +352 -0
- sglang/srt/models/olmoe.py +1 -2
- sglang/srt/models/qwen.py +1 -2
- sglang/srt/models/qwen2.py +1 -2
- sglang/srt/models/qwen2_moe.py +4 -5
- sglang/srt/models/qwen2_vl.py +724 -0
- sglang/srt/models/stablelm.py +1 -2
- sglang/srt/models/torch_native_llama.py +1 -2
- sglang/srt/models/xverse.py +1 -2
- sglang/srt/models/xverse_moe.py +4 -5
- sglang/srt/models/yivl.py +1 -2
- sglang/srt/openai_api/adapter.py +92 -49
- sglang/srt/openai_api/protocol.py +10 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
- sglang/srt/sampling/sampling_batch_info.py +103 -59
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server.py +116 -17
- sglang/srt/server_args.py +131 -45
- sglang/srt/utils.py +33 -3
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/few_shot_gsm8k_engine.py +144 -0
- sglang/test/runners.py +20 -1
- sglang/test/srt/sampling/penaltylib/utils.py +16 -12
- sglang/version.py +1 -1
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.post1.dist-info}/METADATA +75 -32
- sglang-0.3.4.post1.dist-info/RECORD +148 -0
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.post1.dist-info}/WHEEL +1 -1
- sglang/srt/layers/attention/flashinfer_utils.py +0 -237
- sglang-0.3.3.post1.dist-info/RECORD +0 -140
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.post1.dist-info}/LICENSE +0 -0
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.post1.dist-info}/top_level.txt +0 -0
sglang/bench_latency.py
CHANGED
```diff
@@ -227,22 +227,24 @@ def extend(reqs, model_runner):
         req_to_token_pool=model_runner.req_to_token_pool,
         token_to_kv_pool=model_runner.token_to_kv_pool,
         tree_cache=None,
+        model_config=model_runner.model_config,
     )
-    batch.prepare_for_extend(
+    batch.prepare_for_extend()
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
-    next_token_ids = model_runner.sample(logits_output, forward_batch)
+    next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits, batch


 @torch.inference_mode()
 def decode(input_token_ids, batch, model_runner):
-    batch.
+    batch.output_ids = input_token_ids
+    batch.prepare_for_decode()
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
-    next_token_ids = model_runner.sample(logits_output, forward_batch)
+    next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits


@@ -252,6 +254,7 @@ def correctness_test(
     bench_args,
     tp_rank,
 ):
+    configure_logger(server_args, prefix=f" TP{tp_rank}")
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

     # Load the model
@@ -279,8 +282,9 @@ def correctness_test(
     output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
     for _ in range(bench_args.output_len[0] - 1):
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+        next_token_ids_list = next_token_ids.tolist()
         for i in range(len(reqs)):
-            output_ids[i].append(
+            output_ids[i].append(next_token_ids_list[i])

     # Print
     for i in range(len(reqs)):
@@ -288,8 +292,15 @@ def correctness_test(
         rank_print(tokenizer.decode(output_ids[i]), "\n")


+def synchronize(device):
+    if device == "cuda":
+        torch.cuda.synchronize()
+    elif device == "xpu":
+        torch.xpu.synchronize()
+
+
 def latency_test_run_once(
-    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
+    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len, device
 ):
     max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
     if batch_size > max_batch_size:
@@ -312,10 +323,10 @@ def latency_test_run_once(
     tot_latency = 0

     # Prefill
-
+    synchronize(device)
     tic = time.time()
     next_token_ids, _, batch = extend(reqs, model_runner)
-
+    synchronize(device)
     prefill_latency = time.time() - tic
     tot_latency += prefill_latency
     throughput = input_len * batch_size / prefill_latency
@@ -328,10 +339,10 @@ def latency_test_run_once(
     # Decode
     decode_latencies = []
     for i in range(output_len - 1):
-
+        synchronize(device)
         tic = time.time()
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
-
+        synchronize(device)
         latency = time.time() - tic
         tot_latency += latency
         throughput = batch_size / latency
@@ -387,6 +398,7 @@ def latency_test(
         bench_args.batch_size[0],
         bench_args.input_len[0],
         8,  # shorter decoding to speed up the warmup
+        server_args.device,
     )
     rank_print("Benchmark ...")

@@ -397,7 +409,14 @@ def latency_test(
     ):
         reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
         ret = latency_test_run_once(
-            bench_args.run_name,
+            bench_args.run_name,
+            model_runner,
+            rank_print,
+            reqs,
+            bs,
+            il,
+            ol,
+            server_args.device,
         )
         if ret is not None:
             result_list.append(ret)
```
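The new `synchronize(device)` helper above is what makes the prefill/decode timings meaningful on asynchronous devices (CUDA and XPU). A minimal standalone sketch of the same timing pattern; the `dummy_workload` function is made up for illustration:

```python
# Sketch only: time a GPU workload correctly by draining the device queue
# before starting and before stopping the clock, as bench_latency.py now does.
import time

import torch


def synchronize(device):
    if device == "cuda":
        torch.cuda.synchronize()
    elif device == "xpu":
        torch.xpu.synchronize()


def dummy_workload(device):
    a = torch.randn(1024, 1024, device=device)
    return a @ a


def time_workload(device):
    synchronize(device)  # drain previously queued kernels
    tic = time.time()
    dummy_workload(device)
    synchronize(device)  # wait for the asynchronous kernel to finish
    return time.time() - tic


if __name__ == "__main__":
    if torch.cuda.is_available():
        print(f"latency: {time_workload('cuda'):.4f} s")
```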
sglang/bench_server_latency.py
CHANGED
```diff
@@ -6,6 +6,8 @@ It accepts arguments similar to those of launch_server.py.
 Usage:

 python3 -m sglang.bench_server_latency --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
+
+python3 -m sglang.bench_server_latency --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
 """

 import argparse
@@ -32,6 +34,8 @@ class BenchArgs:
     input_len: Tuple[int] = (1024,)
     output_len: Tuple[int] = (16,)
     result_filename: str = "result.jsonl"
+    base_url: str = ""
+    skip_warmup: bool = False

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -48,6 +52,8 @@ class BenchArgs:
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
         )
+        parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
+        parser.add_argument("--skip-warmup", action="store_true")

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -139,17 +145,21 @@ def run_one_case(


 def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
-
+    if bench_args.base_url:
+        proc, base_url = None, bench_args.base_url
+    else:
+        proc, base_url = launch_server_process(server_args)

     # warmup
-
-
-
-
-
-
-
-
+    if not bench_args.skip_warmup:
+        run_one_case(
+            base_url,
+            batch_size=16,
+            input_len=1024,
+            output_len=16,
+            run_name="",
+            result_filename="",
+        )

     # benchmark
     try:
@@ -165,7 +175,8 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
             bench_args.result_filename,
         )
     finally:
-
+        if proc:
+            kill_child_process(proc.pid)

     print(f"\nResults are saved to {bench_args.result_filename}")
```
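The `--base-url`/`--skip-warmup` flags turn the script into an attach-or-launch benchmark: it only owns (and later kills) the server process when it started one itself. A small standalone sketch of that pattern, using a hypothetical `launch_my_server()` in place of sglang's `launch_server_process()`/`kill_child_process()`:

```python
# Sketch only: attach to a running server if a base URL is given,
# otherwise launch one and clean it up afterwards.
import subprocess
from typing import Optional, Tuple


def launch_my_server() -> Tuple[subprocess.Popen, str]:
    proc = subprocess.Popen(["python3", "-m", "http.server", "30000"])
    return proc, "http://localhost:30000"


def attach_or_launch(base_url: str = "") -> Tuple[Optional[subprocess.Popen], str]:
    # If the user passed --base-url, reuse the running server and remember
    # that we do not own a child process (proc is None).
    if base_url:
        return None, base_url
    return launch_my_server()


def run(base_url: str = ""):
    proc, url = attach_or_launch(base_url)
    try:
        print(f"benchmarking against {url}")
    finally:
        # Only kill the server if this script started it.
        if proc:
            proc.terminate()
```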
sglang/bench_serving.py
CHANGED
```diff
@@ -222,6 +222,85 @@ async def async_request_openai_completions(
     return output


+async def async_request_sglang_generate(
+    request_func_input: RequestFuncInput,
+    pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+    api_url = request_func_input.api_url
+    prompt = request_func_input.prompt
+
+    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+        payload = {
+            "text": prompt,
+            "sampling_params": {
+                "temperature": 0.0,
+                "max_new_tokens": request_func_input.output_len,
+                "ignore_eos": not args.disable_ignore_eos,
+            },
+            "stream": not args.disable_stream,
+            **request_func_input.extra_request_body,
+        }
+        headers = {}
+
+        output = RequestFuncOutput()
+        output.prompt_len = request_func_input.prompt_len
+
+        generated_text = ""
+        ttft = 0.0
+        st = time.perf_counter()
+        most_recent_timestamp = st
+        try:
+            async with session.post(
+                url=api_url, json=payload, headers=headers
+            ) as response:
+                if response.status == 200:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
+                            continue
+                        # print(chunk_bytes)
+
+                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
+                        latency = time.perf_counter() - st
+                        if chunk == "[DONE]":
+                            pass
+                        else:
+                            data = json.loads(chunk)
+
+                            # NOTE: Some completion API might have a last
+                            # usage summary response without a token so we
+                            # want to check a token was generated
+                            if data["text"]:
+                                timestamp = time.perf_counter()
+                                # First token
+                                if ttft == 0.0:
+                                    ttft = time.perf_counter() - st
+                                    output.ttft = ttft
+
+                                # Decoding phase
+                                else:
+                                    output.itl.append(timestamp - most_recent_timestamp)
+
+                                most_recent_timestamp = timestamp
+                                generated_text = data["text"]
+
+                    output.generated_text = generated_text
+                    output.success = True
+                    output.latency = latency
+                    output.output_len = request_func_input.output_len
+                else:
+                    output.error = response.reason or ""
+                    output.success = False
+        except Exception:
+            output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
+
+    if pbar:
+        pbar.update(1)
+    return output
+
+
 async def async_request_gserver(
     request_func_input: RequestFuncInput,
     pbar: Optional[tqdm] = None,
@@ -264,7 +343,9 @@ def get_tokenizer(


 ASYNC_REQUEST_FUNCS = {
-    "sglang":
+    "sglang": async_request_sglang_generate,
+    "sglang-native": async_request_sglang_generate,
+    "sglang-oai": async_request_openai_completions,
     "vllm": async_request_openai_completions,
     "lmdeploy": async_request_openai_completions,
     "trt": async_request_trt_llm,
@@ -387,6 +468,8 @@ def sample_sharegpt_requests(
             continue
         filtered_dataset.append((prompt, prompt_len, output_len))

+    print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}")
+    print(f"#Output tokens: {np.sum([x[2] for x in filtered_dataset])}")
     return filtered_dataset


@@ -587,6 +670,8 @@ async def benchmark(
     else:
         print("Initial test run completed. Starting main benchmark run...")

+    time.sleep(1.5)
+
     pbar = None if disable_tqdm else tqdm(total=len(input_requests))

     benchmark_start_time = time.perf_counter()
@@ -782,24 +867,33 @@ def run_benchmark(args_: argparse.Namespace):
     if args.port is None:
         args.port = {
             "sglang": 30000,
+            "sglang-native": 30000,
+            "sglang-oai": 30000,
             "lmdeploy": 23333,
             "vllm": 8000,
             "trt": 8000,
             "gserver": 9988,
         }.get(args.backend, 30000)

-    api_url = (
-        f"{args.base_url}/v1/completions"
-        if args.base_url
-        else f"http://{args.host}:{args.port}/v1/completions"
-    )
     model_url = (
         f"{args.base_url}/v1/models"
         if args.base_url
         else f"http://{args.host}:{args.port}/v1/models"
     )

-    if args.backend
+    if args.backend in ["sglang", "sglang-native"]:
+        api_url = (
+            f"{args.base_url}/generate"
+            if args.base_url
+            else f"http://{args.host}:{args.port}/generate"
+        )
+    elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]:
+        api_url = (
+            f"{args.base_url}/v1/completions"
+            if args.base_url
+            else f"http://{args.host}:{args.port}/v1/completions"
+        )
+    elif args.backend == "trt":
         api_url = (
             f"{args.base_url}/v2/models/ensemble/generate_stream"
             if args.base_url
```
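`async_request_sglang_generate` targets the native `/generate` endpoint with a `text` + `sampling_params` payload instead of the OpenAI-compatible route. For reference, a minimal non-streaming request with the same payload shape; it assumes an sglang server is already listening on localhost:30000 and that the response JSON carries a `text` field, as the streaming code above expects:

```python
# Sketch only: one-shot call against sglang's native /generate endpoint.
import requests

payload = {
    "text": "The capital of France is",
    "sampling_params": {
        "temperature": 0.0,
        "max_new_tokens": 16,
        "ignore_eos": True,
    },
    "stream": False,
}

resp = requests.post("http://localhost:30000/generate", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json()["text"])
```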
sglang/global_config.py
CHANGED
```diff
@@ -19,7 +19,6 @@ class GlobalConfig:
         self.new_token_ratio_decay = 0.001

         # Runtime constants: others
-        self.num_continue_decode_steps = 10
         self.retract_decode_steps = 20
         self.flashinfer_workspace_size = os.environ.get(
             "FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024
```
sglang/lang/chat_template.py
CHANGED
```diff
@@ -133,6 +133,22 @@ register_chat_template(
     )
 )

+# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
+register_chat_template(
+    ChatTemplate(
+        name="qwen2-vl",
+        default_system_prompt="You are a helpful assistant.",
+        role_prefix_and_suffix={
+            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
+            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
+            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
+        },
+        style=ChatTemplateStyle.PLAIN,
+        stop_str=("<|im_end|>"),
+        image_token="<|vision_start|><|image_pad|><|vision_end|>",
+    )
+)
+

 register_chat_template(
     ChatTemplate(
@@ -213,6 +229,7 @@ register_chat_template(
             ),
         },
         stop_str=("<|eot_id|>",),
+        image_token="<|image|>",
     )
 )

```
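The qwen2-vl template is plain ChatML with a vision placeholder as the image token. A sketch of how the registered `role_prefix_and_suffix` values compose into a prompt, independent of sglang's own rendering code (the `render()` helper is hypothetical):

```python
# Sketch only: manual composition of a qwen2-vl / ChatML-style prompt
# from the role prefixes and suffixes registered above.
ROLE_PREFIX_AND_SUFFIX = {
    "system": ("<|im_start|>system\n", "<|im_end|>\n"),
    "user": ("<|im_start|>user\n", "<|im_end|>\n"),
    "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
}


def render(messages):
    parts = []
    for role, content in messages:
        prefix, suffix = ROLE_PREFIX_AND_SUFFIX[role]
        parts.append(f"{prefix}{content}{suffix}")
    # Leave an open assistant turn for the model to complete.
    parts.append("<|im_start|>assistant\n")
    return "".join(parts)


print(render([
    ("system", "You are a helpful assistant."),
    ("user", "<|vision_start|><|image_pad|><|vision_end|>Describe the image."),
]))
```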
sglang/launch_server_llavavid.py
CHANGED
```diff
@@ -14,7 +14,7 @@ if __name__ == "__main__":
     model_override_args["num_frames"] = 16
     model_override_args["model_type"] = "llavavid"
     if model_override_args["num_frames"] == 32:
-        model_override_args["rope_scaling"] = {"factor": 2.0, "
+        model_override_args["rope_scaling"] = {"factor": 2.0, "rope_type": "linear"}
     model_override_args["max_sequence_length"] = 4096 * 2
     model_override_args["tokenizer_model_max_length"] = 4096 * 2
     model_override_args["model_max_length"] = 4096 * 2
```
sglang/srt/configs/model_config.py
CHANGED
```diff
@@ -89,6 +89,8 @@ class ModelConfig:
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
         self.vocab_size = self.hf_text_config.vocab_size

+        self.is_encoder_decoder = self.hf_config.model_type in ["mllama"]
+
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
```
sglang/srt/configs/qwen2vl.py
ADDED
```diff
@@ -0,0 +1,133 @@
+# coding=utf-8
+# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Qwen2VL model configuration"""
+
+import os
+from typing import Union
+
+from transformers import PretrainedConfig
+
+
+class Qwen2VLVisionConfig(PretrainedConfig):
+    model_type = "qwen2_vl"
+
+    def __init__(
+        self,
+        depth=32,
+        embed_dim=1280,
+        hidden_size=3584,
+        hidden_act="quick_gelu",
+        mlp_ratio=4,
+        num_heads=16,
+        in_channels=3,
+        patch_size=14,
+        spatial_merge_size=2,
+        temporal_patch_size=2,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.depth = depth
+        self.embed_dim = embed_dim
+        self.hidden_size = hidden_size
+        self.hidden_act = hidden_act
+        self.mlp_ratio = mlp_ratio
+        self.num_heads = num_heads
+        self.in_channels = in_channels
+        self.patch_size = patch_size
+        self.spatial_merge_size = spatial_merge_size
+        self.temporal_patch_size = temporal_patch_size
+
+    @classmethod
+    def from_pretrained(
+        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+    ) -> "PretrainedConfig":
+        cls._set_token_in_kwargs(kwargs)
+
+        config_dict, kwargs = cls.get_config_dict(
+            pretrained_model_name_or_path, **kwargs
+        )
+
+        if config_dict.get("model_type") == "qwen2_vl":
+            config_dict = config_dict["vision_config"]
+
+        return cls.from_dict(config_dict, **kwargs)
+
+
+class Qwen2VLConfig(PretrainedConfig):
+    model_type = "qwen2_vl"
+
+    def __init__(
+        self,
+        vocab_size=152064,
+        hidden_size=8192,
+        intermediate_size=29568,
+        num_hidden_layers=80,
+        num_attention_heads=64,
+        num_key_value_heads=8,
+        hidden_act="silu",
+        max_position_embeddings=32768,
+        initializer_range=0.02,
+        rms_norm_eps=1e-05,
+        use_cache=True,
+        tie_word_embeddings=False,
+        rope_theta=1000000.0,
+        use_sliding_window=False,
+        sliding_window=4096,
+        max_window_layers=80,
+        attention_dropout=0.0,
+        vision_config=None,
+        rope_scaling=None,
+        **kwargs,
+    ):
+        if isinstance(vision_config, dict):
+            self.vision_config = Qwen2VLVisionConfig(**vision_config)
+        elif vision_config is None:
+            self.vision_config = Qwen2VLVisionConfig()
+
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.use_sliding_window = use_sliding_window
+        self.sliding_window = sliding_window
+        self.max_window_layers = max_window_layers
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.attention_dropout = attention_dropout
+        self.rope_scaling = rope_scaling
+
+        # NOTE: the following section from original transformers config
+        # for Qwen2-VL is commented out to address rope config loading issue
+        #
+        # if self.rope_scaling is not None and "type" in self.rope_scaling:
+        #     if self.rope_scaling["type"] == "mrope":
+        #         self.rope_scaling["type"] = "default"
+        #     self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+        # rope_config_validation(self)
+
+        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
```
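A quick sketch of how the new config class behaves when constructed with a nested `vision_config` dict (the values here are illustrative, not taken from a real checkpoint):

```python
# Sketch only: a nested vision_config dict is promoted to a
# Qwen2VLVisionConfig instance by Qwen2VLConfig.__init__.
from sglang.srt.configs.qwen2vl import Qwen2VLConfig, Qwen2VLVisionConfig

config = Qwen2VLConfig(
    hidden_size=3584,
    num_hidden_layers=28,
    num_attention_heads=28,
    vision_config={"depth": 32, "embed_dim": 1280, "num_heads": 16},
)

assert isinstance(config.vision_config, Qwen2VLVisionConfig)
print(config.vision_config.embed_dim)  # 1280
```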
sglang/srt/conversation.py
CHANGED
```diff
@@ -509,6 +509,19 @@ register_conv_template(
     )
 )

+register_conv_template(
+    Conversation(
+        name="llama_3_vision",
+        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
+        system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
+        roles=("user", "assistant"),
+        sep_style=SeparatorStyle.LLAMA3,
+        sep="",
+        stop_str=["<|end_of_text|>", "<|eot_id|>"],
+        image_token="<|image|>",
+    )
+)
+
 register_conv_template(
     Conversation(
         name="llava_llama_3",
@@ -530,3 +543,17 @@ register_conv_template(
         stop_str=["<|im_end|>", "<|action_end|>"],
     )
 )
+
+# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
+register_conv_template(
+    Conversation(
+        name="qwen2-vl",
+        system_message="You are a helpful assistant.",
+        system_template="<|im_start|>system\n{system_message}",
+        roles=("<|im_start|>user", "<|im_start|>assistant"),
+        sep="<|im_end|>\n",
+        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
+        stop_str=["<|im_end|>"],
+        image_token="<|vision_start|><|image_pad|><|vision_end|>",
+    )
+)
```
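For comparison with the chat-template entry, a rough sketch of what the `llama_3_vision` conversation renders to; the `build_prompt()` helper and the `<|begin_of_text|>` handling are assumptions for illustration, not sglang's `SeparatorStyle.LLAMA3` implementation:

```python
# Sketch only: compose a Llama-3-style vision prompt from the values
# registered in the llama_3_vision conversation template above.
SYSTEM_TEMPLATE = "<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>"
SYSTEM_MESSAGE = (
    "You are a helpful language and vision assistant. You are able to "
    "understand the visual content that the user provides, and assist the "
    "user with a variety of tasks using natural language."
)


def build_prompt(user_message: str) -> str:
    parts = [
        "<|begin_of_text|>",  # assumption: BOS is prepended here, not by the tokenizer
        SYSTEM_TEMPLATE.format(system_message=SYSTEM_MESSAGE),
        f"<|start_header_id|>user<|end_header_id|>\n\n{user_message}<|eot_id|>",
        "<|start_header_id|>assistant<|end_header_id|>\n\n",
    ]
    return "".join(parts)


print(build_prompt("<|image|>What is shown in this picture?"))
```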
sglang/srt/hf_transformers_utils.py
CHANGED
```diff
@@ -33,12 +33,13 @@ from transformers import (
 try:
     from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig

-    from sglang.srt.configs import ExaoneConfig
+    from sglang.srt.configs import ExaoneConfig, Qwen2VLConfig

     _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
         ChatGLMConfig.model_type: ChatGLMConfig,
         DbrxConfig.model_type: DbrxConfig,
         ExaoneConfig.model_type: ExaoneConfig,
+        Qwen2VLConfig.model_type: Qwen2VLConfig,
     }
 except ImportError:
     # We want this file to run without vllm dependency
```
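The registry maps a checkpoint's `model_type` to a custom config class. A sketch of that lookup pattern; the `load_config()` helper is hypothetical, and sglang's actual loader also handles model overrides and other options:

```python
# Sketch only: resolve a custom config class from a model_type registry,
# mirroring the _CONFIG_REGISTRY usage shown in the diff above.
from typing import Dict, Type

from transformers import AutoConfig, PretrainedConfig

from sglang.srt.configs import ExaoneConfig, Qwen2VLConfig

_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
    ExaoneConfig.model_type: ExaoneConfig,
    Qwen2VLConfig.model_type: Qwen2VLConfig,
}


def load_config(model_path: str) -> PretrainedConfig:
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    # If the checkpoint declares a model_type with a registered custom config
    # class, re-load with that class so extra fields (e.g. vision_config)
    # are parsed properly.
    if config.model_type in _CONFIG_REGISTRY:
        config_class = _CONFIG_REGISTRY[config.model_type]
        config = config_class.from_pretrained(model_path)
    return config
```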
sglang/srt/layers/attention/__init__.py
CHANGED
```diff
@@ -1,7 +1,10 @@
 from abc import ABC, abstractmethod
+from typing import Optional

+import torch
 from torch import nn

+from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch


@@ -18,13 +21,22 @@ class AttentionBackend(ABC):
         raise NotImplementedError()

     def init_forward_metadata_capture_cuda_graph(
-        self,
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor] = None,
     ):
         """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()

     def init_forward_metadata_replay_cuda_graph(
-        self,
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor] = None,
     ):
         """Init the metadata for a forward pass for replying a cuda graph."""
         raise NotImplementedError()
@@ -33,17 +45,38 @@ class AttentionBackend(ABC):
         """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
         raise NotImplementedError()

-    def forward(
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+    ):
         """Run forward on an attention layer."""
         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(q, k, v, layer, forward_batch)
         else:
             return self.forward_extend(q, k, v, layer, forward_batch)

-    def forward_decode(
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+    ):
         """Run a forward for decode."""
         raise NotImplementedError()

-    def forward_extend(
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+    ):
         """Run a forward for extend."""
         raise NotImplementedError()
```
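The widened signatures define the contract a custom attention backend must satisfy. A minimal sketch of a subclass against this interface; the single-sequence SDPA math is illustrative only and ignores the paged KV cache, batching, and GQA handling that a real backend (Triton, FlashInfer, or the new double-sparsity backend) provides:

```python
# Sketch only: a toy backend implementing the method signatures shown above.
from typing import Optional

import torch

from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class NaiveSDPABackend(AttentionBackend):
    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor] = None,
    ):
        raise NotImplementedError("cuda graphs are not supported by this sketch")

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor] = None,
    ):
        raise NotImplementedError("cuda graphs are not supported by this sketch")

    def forward_extend(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
        layer: RadixAttention, forward_batch: ForwardBatch,
    ):
        # Toy math: assume (seq_len, num_heads, head_dim) tensors and a single
        # prefill sequence; a real backend reads forward_batch to split
        # sequences and to append k/v into the KV cache.
        q, k, v = (x.transpose(0, 1) for x in (q, k, v))  # -> (heads, seq, dim)
        o = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
        return o.transpose(0, 1).reshape(o.shape[1], -1)

    def forward_decode(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
        layer: RadixAttention, forward_batch: ForwardBatch,
    ):
        raise NotImplementedError("decode path omitted in this sketch")
```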
|