sglang 0.3.3.post1__py3-none-any.whl → 0.3.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. sglang/bench_latency.py +30 -11
  2. sglang/bench_server_latency.py +21 -10
  3. sglang/bench_serving.py +101 -7
  4. sglang/global_config.py +0 -1
  5. sglang/lang/chat_template.py +17 -0
  6. sglang/launch_server_llavavid.py +1 -1
  7. sglang/srt/configs/__init__.py +3 -0
  8. sglang/srt/configs/model_config.py +2 -0
  9. sglang/srt/configs/qwen2vl.py +133 -0
  10. sglang/srt/conversation.py +27 -0
  11. sglang/srt/hf_transformers_utils.py +2 -1
  12. sglang/srt/layers/attention/__init__.py +38 -5
  13. sglang/srt/layers/attention/double_sparsity_backend.py +297 -0
  14. sglang/srt/layers/attention/flashinfer_backend.py +486 -97
  15. sglang/srt/layers/attention/triton_backend.py +26 -8
  16. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
  17. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
  18. sglang/srt/layers/attention/triton_ops/prefill_attention.py +30 -6
  19. sglang/srt/layers/linear.py +89 -63
  20. sglang/srt/layers/rotary_embedding.py +145 -0
  21. sglang/srt/layers/sampler.py +6 -2
  22. sglang/srt/lora/lora.py +3 -1
  23. sglang/srt/managers/detokenizer_manager.py +31 -10
  24. sglang/srt/managers/image_processor.py +186 -13
  25. sglang/srt/managers/io_struct.py +4 -0
  26. sglang/srt/managers/schedule_batch.py +319 -82
  27. sglang/srt/managers/schedule_policy.py +2 -1
  28. sglang/srt/managers/scheduler.py +233 -158
  29. sglang/srt/managers/tokenizer_manager.py +15 -5
  30. sglang/srt/managers/tp_worker.py +30 -5
  31. sglang/srt/managers/tp_worker_overlap_thread.py +172 -0
  32. sglang/srt/mem_cache/chunk_cache.py +8 -4
  33. sglang/srt/mem_cache/memory_pool.py +123 -11
  34. sglang/srt/mem_cache/radix_cache.py +19 -10
  35. sglang/srt/model_executor/cuda_graph_runner.py +63 -12
  36. sglang/srt/model_executor/forward_batch_info.py +101 -23
  37. sglang/srt/model_executor/model_runner.py +92 -12
  38. sglang/srt/models/baichuan.py +2 -3
  39. sglang/srt/models/chatglm.py +8 -9
  40. sglang/srt/models/commandr.py +1 -2
  41. sglang/srt/models/dbrx.py +1 -2
  42. sglang/srt/models/deepseek.py +4 -5
  43. sglang/srt/models/deepseek_v2.py +7 -8
  44. sglang/srt/models/exaone.py +1 -2
  45. sglang/srt/models/gemma.py +2 -2
  46. sglang/srt/models/gemma2.py +5 -5
  47. sglang/srt/models/gpt_bigcode.py +5 -5
  48. sglang/srt/models/grok.py +1 -2
  49. sglang/srt/models/internlm2.py +1 -2
  50. sglang/srt/models/llama.py +1 -2
  51. sglang/srt/models/llama_classification.py +1 -2
  52. sglang/srt/models/llama_reward.py +2 -3
  53. sglang/srt/models/llava.py +4 -8
  54. sglang/srt/models/llavavid.py +1 -2
  55. sglang/srt/models/minicpm.py +1 -2
  56. sglang/srt/models/minicpm3.py +5 -6
  57. sglang/srt/models/mixtral.py +1 -2
  58. sglang/srt/models/mixtral_quant.py +1 -2
  59. sglang/srt/models/mllama.py +1004 -0
  60. sglang/srt/models/olmo.py +352 -0
  61. sglang/srt/models/olmoe.py +1 -2
  62. sglang/srt/models/qwen.py +1 -2
  63. sglang/srt/models/qwen2.py +1 -2
  64. sglang/srt/models/qwen2_moe.py +4 -5
  65. sglang/srt/models/qwen2_vl.py +724 -0
  66. sglang/srt/models/stablelm.py +1 -2
  67. sglang/srt/models/torch_native_llama.py +1 -2
  68. sglang/srt/models/xverse.py +1 -2
  69. sglang/srt/models/xverse_moe.py +4 -5
  70. sglang/srt/models/yivl.py +1 -2
  71. sglang/srt/openai_api/adapter.py +92 -49
  72. sglang/srt/openai_api/protocol.py +10 -2
  73. sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
  74. sglang/srt/sampling/sampling_batch_info.py +103 -59
  75. sglang/srt/sampling/sampling_params.py +2 -0
  76. sglang/srt/server.py +116 -17
  77. sglang/srt/server_args.py +131 -45
  78. sglang/srt/utils.py +33 -3
  79. sglang/test/few_shot_gsm8k.py +4 -1
  80. sglang/test/few_shot_gsm8k_engine.py +144 -0
  81. sglang/test/runners.py +20 -1
  82. sglang/test/srt/sampling/penaltylib/utils.py +16 -12
  83. sglang/version.py +1 -1
  84. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.post1.dist-info}/METADATA +75 -32
  85. sglang-0.3.4.post1.dist-info/RECORD +148 -0
  86. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.post1.dist-info}/WHEEL +1 -1
  87. sglang/srt/layers/attention/flashinfer_utils.py +0 -237
  88. sglang-0.3.3.post1.dist-info/RECORD +0 -140
  89. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.post1.dist-info}/LICENSE +0 -0
  90. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.post1.dist-info}/top_level.txt +0 -0
sglang/bench_latency.py CHANGED
@@ -227,22 +227,24 @@ def extend(reqs, model_runner):
         req_to_token_pool=model_runner.req_to_token_pool,
         token_to_kv_pool=model_runner.token_to_kv_pool,
         tree_cache=None,
+        model_config=model_runner.model_config,
     )
-    batch.prepare_for_extend(model_runner.model_config.vocab_size)
+    batch.prepare_for_extend()
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
-    next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
+    next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits, batch


 @torch.inference_mode()
 def decode(input_token_ids, batch, model_runner):
-    batch.prepare_for_decode(input_token_ids)
+    batch.output_ids = input_token_ids
+    batch.prepare_for_decode()
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
-    next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
+    next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits


@@ -252,6 +254,7 @@ def correctness_test(
     bench_args,
     tp_rank,
 ):
+    configure_logger(server_args, prefix=f" TP{tp_rank}")
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

     # Load the model
@@ -279,8 +282,9 @@ def correctness_test(
     output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
     for _ in range(bench_args.output_len[0] - 1):
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+        next_token_ids_list = next_token_ids.tolist()
         for i in range(len(reqs)):
-            output_ids[i].append(next_token_ids[i])
+            output_ids[i].append(next_token_ids_list[i])

     # Print
     for i in range(len(reqs)):
@@ -288,8 +292,15 @@ def correctness_test(
         rank_print(tokenizer.decode(output_ids[i]), "\n")


+def synchronize(device):
+    if device == "cuda":
+        torch.cuda.synchronize()
+    elif device == "xpu":
+        torch.xpu.synchronize()
+
+
 def latency_test_run_once(
-    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
+    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len, device
 ):
     max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
     if batch_size > max_batch_size:
@@ -312,10 +323,10 @@ def latency_test_run_once(
     tot_latency = 0

     # Prefill
-    torch.cuda.synchronize()
+    synchronize(device)
     tic = time.time()
     next_token_ids, _, batch = extend(reqs, model_runner)
-    torch.cuda.synchronize()
+    synchronize(device)
     prefill_latency = time.time() - tic
     tot_latency += prefill_latency
     throughput = input_len * batch_size / prefill_latency
@@ -328,10 +339,10 @@ def latency_test_run_once(
     # Decode
     decode_latencies = []
     for i in range(output_len - 1):
-        torch.cuda.synchronize()
+        synchronize(device)
         tic = time.time()
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
-        torch.cuda.synchronize()
+        synchronize(device)
         latency = time.time() - tic
         tot_latency += latency
         throughput = batch_size / latency
@@ -387,6 +398,7 @@ def latency_test(
         bench_args.batch_size[0],
         bench_args.input_len[0],
         8,  # shorter decoding to speed up the warmup
+        server_args.device,
     )
     rank_print("Benchmark ...")

@@ -397,7 +409,14 @@ def latency_test(
     ):
         reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
         ret = latency_test_run_once(
-            bench_args.run_name, model_runner, rank_print, reqs, bs, il, ol
+            bench_args.run_name,
+            model_runner,
+            rank_print,
+            reqs,
+            bs,
+            il,
+            ol,
+            server_args.device,
         )
         if ret is not None:
             result_list.append(ret)
sglang/bench_server_latency.py CHANGED
@@ -6,6 +6,8 @@ It accepts arguments similar to those of launch_server.py.
 Usage:

 python3 -m sglang.bench_server_latency --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
+
+python3 -m sglang.bench_server_latency --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
 """

 import argparse
@@ -32,6 +34,8 @@ class BenchArgs:
     input_len: Tuple[int] = (1024,)
     output_len: Tuple[int] = (16,)
     result_filename: str = "result.jsonl"
+    base_url: str = ""
+    skip_warmup: bool = False

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -48,6 +52,8 @@ class BenchArgs:
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
         )
+        parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
+        parser.add_argument("--skip-warmup", action="store_true")

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -139,17 +145,21 @@ def run_one_case(


 def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
-    proc, base_url = launch_server_process(server_args)
+    if bench_args.base_url:
+        proc, base_url = None, bench_args.base_url
+    else:
+        proc, base_url = launch_server_process(server_args)

     # warmup
-    run_one_case(
-        base_url,
-        batch_size=16,
-        input_len=1024,
-        output_len=16,
-        run_name="",
-        result_filename="",
-    )
+    if not bench_args.skip_warmup:
+        run_one_case(
+            base_url,
+            batch_size=16,
+            input_len=1024,
+            output_len=16,
+            run_name="",
+            result_filename="",
+        )

     # benchmark
     try:
@@ -165,7 +175,8 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
             bench_args.result_filename,
         )
     finally:
-        kill_child_process(proc.pid)
+        if proc:
+            kill_child_process(proc.pid)

     print(f"\nResults are saved to {bench_args.result_filename}")

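Note: with the two new BenchArgs fields above, bench_server_latency can target an already-running server instead of launching one. The sketch below is a hypothetical driver, not part of the package: the BenchArgs field names mirror the CLI flags, and ServerArgs(model_path=...) is an assumed signature; when base_url is set, run_benchmark does not launch a local server.

# Hypothetical driver script (assumptions noted above); run_benchmark and BenchArgs
# come from the module shown in this diff.
from sglang.bench_server_latency import BenchArgs, run_benchmark
from sglang.srt.server_args import ServerArgs  # assumed import path for ServerArgs

server_args = ServerArgs(model_path="meta-llama/Meta-Llama-3.1-8B")  # assumed constructor
bench_args = BenchArgs(
    batch_size=(16,),                   # field name assumed to mirror --batch-size
    input_len=(1024,),
    output_len=(8,),
    base_url="http://localhost:30000",  # reuse an existing server; no local launch
    skip_warmup=True,                   # new flag: skip the fixed 16x1024 warmup case
)
run_benchmark(server_args, bench_args)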
sglang/bench_serving.py CHANGED
@@ -222,6 +222,85 @@ async def async_request_openai_completions(
     return output


+async def async_request_sglang_generate(
+    request_func_input: RequestFuncInput,
+    pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+    api_url = request_func_input.api_url
+    prompt = request_func_input.prompt
+
+    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+        payload = {
+            "text": prompt,
+            "sampling_params": {
+                "temperature": 0.0,
+                "max_new_tokens": request_func_input.output_len,
+                "ignore_eos": not args.disable_ignore_eos,
+            },
+            "stream": not args.disable_stream,
+            **request_func_input.extra_request_body,
+        }
+        headers = {}
+
+        output = RequestFuncOutput()
+        output.prompt_len = request_func_input.prompt_len
+
+        generated_text = ""
+        ttft = 0.0
+        st = time.perf_counter()
+        most_recent_timestamp = st
+        try:
+            async with session.post(
+                url=api_url, json=payload, headers=headers
+            ) as response:
+                if response.status == 200:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
+                            continue
+                        # print(chunk_bytes)
+
+                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
+                        latency = time.perf_counter() - st
+                        if chunk == "[DONE]":
+                            pass
+                        else:
+                            data = json.loads(chunk)
+
+                            # NOTE: Some completion API might have a last
+                            # usage summary response without a token so we
+                            # want to check a token was generated
+                            if data["text"]:
+                                timestamp = time.perf_counter()
+                                # First token
+                                if ttft == 0.0:
+                                    ttft = time.perf_counter() - st
+                                    output.ttft = ttft
+
+                                # Decoding phase
+                                else:
+                                    output.itl.append(timestamp - most_recent_timestamp)
+
+                                most_recent_timestamp = timestamp
+                                generated_text = data["text"]
+
+                    output.generated_text = generated_text
+                    output.success = True
+                    output.latency = latency
+                    output.output_len = request_func_input.output_len
+                else:
+                    output.error = response.reason or ""
+                    output.success = False
+        except Exception:
+            output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
+
+    if pbar:
+        pbar.update(1)
+    return output
+
+
 async def async_request_gserver(
     request_func_input: RequestFuncInput,
     pbar: Optional[tqdm] = None,
@@ -264,7 +343,9 @@ def get_tokenizer(


 ASYNC_REQUEST_FUNCS = {
-    "sglang": async_request_openai_completions,
+    "sglang": async_request_sglang_generate,
+    "sglang-native": async_request_sglang_generate,
+    "sglang-oai": async_request_openai_completions,
     "vllm": async_request_openai_completions,
     "lmdeploy": async_request_openai_completions,
     "trt": async_request_trt_llm,
@@ -387,6 +468,8 @@ def sample_sharegpt_requests(
             continue
         filtered_dataset.append((prompt, prompt_len, output_len))

+    print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}")
+    print(f"#Output tokens: {np.sum([x[2] for x in filtered_dataset])}")
    return filtered_dataset


@@ -587,6 +670,8 @@ async def benchmark(
     else:
         print("Initial test run completed. Starting main benchmark run...")

+    time.sleep(1.5)
+
     pbar = None if disable_tqdm else tqdm(total=len(input_requests))

     benchmark_start_time = time.perf_counter()
@@ -782,24 +867,33 @@ def run_benchmark(args_: argparse.Namespace):
     if args.port is None:
         args.port = {
             "sglang": 30000,
+            "sglang-native": 30000,
+            "sglang-oai": 30000,
             "lmdeploy": 23333,
             "vllm": 8000,
             "trt": 8000,
             "gserver": 9988,
         }.get(args.backend, 30000)

-    api_url = (
-        f"{args.base_url}/v1/completions"
-        if args.base_url
-        else f"http://{args.host}:{args.port}/v1/completions"
-    )
     model_url = (
         f"{args.base_url}/v1/models"
         if args.base_url
         else f"http://{args.host}:{args.port}/v1/models"
     )

-    if args.backend == "trt":
+    if args.backend in ["sglang", "sglang-native"]:
+        api_url = (
+            f"{args.base_url}/generate"
+            if args.base_url
+            else f"http://{args.host}:{args.port}/generate"
+        )
+    elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]:
+        api_url = (
+            f"{args.base_url}/v1/completions"
+            if args.base_url
+            else f"http://{args.host}:{args.port}/v1/completions"
+        )
+    elif args.backend == "trt":
         api_url = (
             f"{args.base_url}/v2/models/ensemble/generate_stream"
             if args.base_url
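Note: the new sglang / sglang-native backends post plain JSON to the server's /generate endpoint instead of the OpenAI-compatible /v1/completions route. Below is a minimal, non-streaming sketch of that request shape, assuming a server is already listening at http://localhost:30000; it mirrors the payload built in async_request_sglang_generate above but is not part of the package.

# Minimal sketch: one non-streaming /generate request with the same payload shape
# as async_request_sglang_generate. Assumes a running sglang server at localhost:30000.
import requests

payload = {
    "text": "The capital of France is",
    "sampling_params": {
        "temperature": 0.0,
        "max_new_tokens": 32,
        "ignore_eos": True,
    },
    "stream": False,  # the benchmark streams by default; disabled here for brevity
}
resp = requests.post("http://localhost:30000/generate", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json()["text"])  # generated text is returned under "text", as the streaming handler expects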
sglang/global_config.py CHANGED
@@ -19,7 +19,6 @@ class GlobalConfig:
         self.new_token_ratio_decay = 0.001

         # Runtime constants: others
-        self.num_continue_decode_steps = 10
         self.retract_decode_steps = 20
         self.flashinfer_workspace_size = os.environ.get(
             "FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024
sglang/lang/chat_template.py CHANGED
@@ -133,6 +133,22 @@ register_chat_template(
     )
 )

+# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
+register_chat_template(
+    ChatTemplate(
+        name="qwen2-vl",
+        default_system_prompt="You are a helpful assistant.",
+        role_prefix_and_suffix={
+            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
+            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
+            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
+        },
+        style=ChatTemplateStyle.PLAIN,
+        stop_str=("<|im_end|>"),
+        image_token="<|vision_start|><|image_pad|><|vision_end|>",
+    )
+)
+

 register_chat_template(
     ChatTemplate(
@@ -213,6 +229,7 @@ register_chat_template(
             ),
         },
         stop_str=("<|eot_id|>",),
+        image_token="<|image|>",
     )
 )

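Note: the snippet below illustrates how the role prefixes and suffixes of the new qwen2-vl template compose into a prompt string. It is a standalone illustration that restates the template literals above rather than calling sglang's template machinery, and placing the image token at the start of the user turn is an assumption made for the example.

# Illustration only: compose a qwen2-vl style prompt from the registered
# role prefixes/suffixes shown in the diff above.
ROLE_PREFIX_AND_SUFFIX = {
    "system": ("<|im_start|>system\n", "<|im_end|>\n"),
    "user": ("<|im_start|>user\n", "<|im_end|>\n"),
    "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
}
IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"


def render_prompt(system: str, user_text: str, with_image: bool = False) -> str:
    sys_p, sys_s = ROLE_PREFIX_AND_SUFFIX["system"]
    usr_p, usr_s = ROLE_PREFIX_AND_SUFFIX["user"]
    asst_p, _ = ROLE_PREFIX_AND_SUFFIX["assistant"]
    user_body = (IMAGE_TOKEN + user_text) if with_image else user_text
    # End with the assistant prefix so the model continues as the assistant.
    return sys_p + system + sys_s + usr_p + user_body + usr_s + asst_p


print(render_prompt("You are a helpful assistant.", "Describe this image.", with_image=True))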
sglang/launch_server_llavavid.py CHANGED
@@ -14,7 +14,7 @@ if __name__ == "__main__":
     model_override_args["num_frames"] = 16
     model_override_args["model_type"] = "llavavid"
     if model_override_args["num_frames"] == 32:
-        model_override_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
+        model_override_args["rope_scaling"] = {"factor": 2.0, "rope_type": "linear"}
         model_override_args["max_sequence_length"] = 4096 * 2
         model_override_args["tokenizer_model_max_length"] = 4096 * 2
         model_override_args["model_max_length"] = 4096 * 2
sglang/srt/configs/__init__.py CHANGED
@@ -1,5 +1,8 @@
 from sglang.srt.configs.exaone import ExaoneConfig
+from sglang.srt.configs.qwen2vl import Qwen2VLConfig, Qwen2VLVisionConfig

 __all__ = [
     "ExaoneConfig",
+    "Qwen2VLConfig",
+    "Qwen2VLVisionConfig",
 ]
sglang/srt/configs/model_config.py CHANGED
@@ -89,6 +89,8 @@ class ModelConfig:
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
         self.vocab_size = self.hf_text_config.vocab_size

+        self.is_encoder_decoder = self.hf_config.model_type in ["mllama"]
+
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
sglang/srt/configs/qwen2vl.py ADDED
@@ -0,0 +1,133 @@
+# coding=utf-8
+# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Qwen2VL model configuration"""
+
+import os
+from typing import Union
+
+from transformers import PretrainedConfig
+
+
+class Qwen2VLVisionConfig(PretrainedConfig):
+    model_type = "qwen2_vl"
+
+    def __init__(
+        self,
+        depth=32,
+        embed_dim=1280,
+        hidden_size=3584,
+        hidden_act="quick_gelu",
+        mlp_ratio=4,
+        num_heads=16,
+        in_channels=3,
+        patch_size=14,
+        spatial_merge_size=2,
+        temporal_patch_size=2,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.depth = depth
+        self.embed_dim = embed_dim
+        self.hidden_size = hidden_size
+        self.hidden_act = hidden_act
+        self.mlp_ratio = mlp_ratio
+        self.num_heads = num_heads
+        self.in_channels = in_channels
+        self.patch_size = patch_size
+        self.spatial_merge_size = spatial_merge_size
+        self.temporal_patch_size = temporal_patch_size
+
+    @classmethod
+    def from_pretrained(
+        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+    ) -> "PretrainedConfig":
+        cls._set_token_in_kwargs(kwargs)
+
+        config_dict, kwargs = cls.get_config_dict(
+            pretrained_model_name_or_path, **kwargs
+        )
+
+        if config_dict.get("model_type") == "qwen2_vl":
+            config_dict = config_dict["vision_config"]
+
+        return cls.from_dict(config_dict, **kwargs)
+
+
+class Qwen2VLConfig(PretrainedConfig):
+    model_type = "qwen2_vl"
+
+    def __init__(
+        self,
+        vocab_size=152064,
+        hidden_size=8192,
+        intermediate_size=29568,
+        num_hidden_layers=80,
+        num_attention_heads=64,
+        num_key_value_heads=8,
+        hidden_act="silu",
+        max_position_embeddings=32768,
+        initializer_range=0.02,
+        rms_norm_eps=1e-05,
+        use_cache=True,
+        tie_word_embeddings=False,
+        rope_theta=1000000.0,
+        use_sliding_window=False,
+        sliding_window=4096,
+        max_window_layers=80,
+        attention_dropout=0.0,
+        vision_config=None,
+        rope_scaling=None,
+        **kwargs,
+    ):
+        if isinstance(vision_config, dict):
+            self.vision_config = Qwen2VLVisionConfig(**vision_config)
+        elif vision_config is None:
+            self.vision_config = Qwen2VLVisionConfig()
+
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.use_sliding_window = use_sliding_window
+        self.sliding_window = sliding_window
+        self.max_window_layers = max_window_layers
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.attention_dropout = attention_dropout
+        self.rope_scaling = rope_scaling
+
+        # NOTE: the following section from original transformers config
+        # for Qwen2-VL is commented out to address rope config loading issue
+        #
+        # if self.rope_scaling is not None and "type" in self.rope_scaling:
+        #     if self.rope_scaling["type"] == "mrope":
+        #         self.rope_scaling["type"] = "default"
+        #     self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+        #     rope_config_validation(self)
+
+        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
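Note: since the whole qwen2vl.py module is new, the constructors above are the complete surface it adds. The sketch below only assumes that sglang 0.3.4.post1 and transformers are installed; the rope_scaling value is illustrative.

# Sketch: instantiate the new config classes with a nested vision_config dict,
# exactly as the constructors above accept them.
from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig

cfg = Qwen2VLConfig(
    vision_config={"depth": 32, "embed_dim": 1280, "num_heads": 16},
    rope_scaling={"type": "mrope"},  # stored as-is; see the commented-out block above
)
print(type(cfg.vision_config).__name__)   # Qwen2VLVisionConfig
print(cfg.rope_scaling)                   # {'type': 'mrope'}
print(Qwen2VLVisionConfig().patch_size)   # 14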
sglang/srt/conversation.py CHANGED
@@ -509,6 +509,19 @@ register_conv_template(
     )
 )

+register_conv_template(
+    Conversation(
+        name="llama_3_vision",
+        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
+        system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
+        roles=("user", "assistant"),
+        sep_style=SeparatorStyle.LLAMA3,
+        sep="",
+        stop_str=["<|end_of_text|>", "<|eot_id|>"],
+        image_token="<|image|>",
+    )
+)
+
 register_conv_template(
     Conversation(
         name="llava_llama_3",
@@ -530,3 +543,17 @@ register_conv_template(
         stop_str=["<|im_end|>", "<|action_end|>"],
     )
 )
+
+# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
+register_conv_template(
+    Conversation(
+        name="qwen2-vl",
+        system_message="You are a helpful assistant.",
+        system_template="<|im_start|>system\n{system_message}",
+        roles=("<|im_start|>user", "<|im_start|>assistant"),
+        sep="<|im_end|>\n",
+        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
+        stop_str=["<|im_end|>"],
+        image_token="<|vision_start|><|image_pad|><|vision_end|>",
+    )
+)
sglang/srt/hf_transformers_utils.py CHANGED
@@ -33,12 +33,13 @@ from transformers import (
 try:
     from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig

-    from sglang.srt.configs import ExaoneConfig
+    from sglang.srt.configs import ExaoneConfig, Qwen2VLConfig

     _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
         ChatGLMConfig.model_type: ChatGLMConfig,
         DbrxConfig.model_type: DbrxConfig,
         ExaoneConfig.model_type: ExaoneConfig,
+        Qwen2VLConfig.model_type: Qwen2VLConfig,
     }
 except ImportError:
     # We want this file to run without vllm dependency
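Note: the registry above lets sglang substitute its own config classes for matching HuggingFace model_type values before falling back to AutoConfig. The helper below is hypothetical (not an sglang API) and only illustrates that lookup pattern, using the same registry entries shown in the diff.

# Illustration of a model_type -> config-class registry lookup; resolve_config is
# a hypothetical helper, not part of sglang.
from transformers import AutoConfig, PretrainedConfig

from sglang.srt.configs import ExaoneConfig, Qwen2VLConfig

_CONFIG_REGISTRY = {
    ExaoneConfig.model_type: ExaoneConfig,
    Qwen2VLConfig.model_type: Qwen2VLConfig,
}


def resolve_config(model_path: str) -> PretrainedConfig:
    # Peek at the raw config dict to read model_type, then prefer a registered class.
    config_dict, _ = PretrainedConfig.get_config_dict(model_path)
    cls = _CONFIG_REGISTRY.get(config_dict.get("model_type"))
    if cls is not None:
        return cls.from_pretrained(model_path)
    return AutoConfig.from_pretrained(model_path)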
sglang/srt/layers/attention/__init__.py CHANGED
@@ -1,7 +1,10 @@
 from abc import ABC, abstractmethod
+from typing import Optional

+import torch
 from torch import nn

+from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch


@@ -18,13 +21,22 @@ class AttentionBackend(ABC):
         raise NotImplementedError()

     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor] = None,
     ):
         """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()

     def init_forward_metadata_replay_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor] = None,
     ):
         """Init the metadata for a forward pass for replying a cuda graph."""
         raise NotImplementedError()
@@ -33,17 +45,38 @@ class AttentionBackend(ABC):
         """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
         raise NotImplementedError()

-    def forward(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+    ):
         """Run forward on an attention layer."""
         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(q, k, v, layer, forward_batch)
         else:
             return self.forward_extend(q, k, v, layer, forward_batch)

-    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+    ):
         """Run a forward for decode."""
         raise NotImplementedError()

-    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+    ):
         """Run a forward for extend."""
         raise NotImplementedError()
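Note: the backend interface now types layer as RadixAttention and threads encoder_lens / seq_lens_sum through the CUDA-graph hooks. The skeleton below is a minimal sketch of a custom backend against this 0.3.4 interface; the bodies are placeholders rather than a working attention kernel, and only methods whose signatures appear in this hunk are overridden.

# Minimal custom-backend skeleton against the interface shown above (placeholder bodies).
import torch

from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class NoopAttnBackend(AttentionBackend):
    def get_cuda_graph_seq_len_fill_value(self):
        # Padded seq lens are filled with this value during CUDA graph replay.
        return 0

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
    ):
        # A real backend would attend over the prefix plus the extended tokens here.
        return q

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
    ):
        # A real backend would attend over the cached KV entries here.
        return q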