sglang 0.3.4.post2__py3-none-any.whl → 0.3.5.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. sglang/api.py +1 -1
  2. sglang/bench_latency.py +3 -3
  3. sglang/bench_server_latency.py +2 -3
  4. sglang/bench_serving.py +205 -3
  5. sglang/global_config.py +9 -3
  6. sglang/lang/chat_template.py +50 -25
  7. sglang/lang/interpreter.py +9 -1
  8. sglang/lang/ir.py +11 -2
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/configs/model_config.py +54 -13
  11. sglang/srt/constrained/__init__.py +2 -48
  12. sglang/srt/constrained/base_grammar_backend.py +72 -0
  13. sglang/srt/constrained/outlines_backend.py +165 -0
  14. sglang/srt/constrained/outlines_jump_forward.py +182 -0
  15. sglang/srt/constrained/xgrammar_backend.py +114 -0
  16. sglang/srt/hf_transformers_utils.py +6 -5
  17. sglang/srt/layers/attention/triton_ops/decode_attention.py +117 -30
  18. sglang/srt/layers/attention/triton_ops/extend_attention.py +6 -0
  19. sglang/srt/layers/attention/triton_ops/prefill_attention.py +1 -1
  20. sglang/srt/layers/fused_moe/fused_moe.py +27 -10
  21. sglang/srt/layers/fused_moe/layer.py +28 -0
  22. sglang/srt/layers/quantization/base_config.py +14 -1
  23. sglang/srt/layers/vocab_parallel_embedding.py +552 -0
  24. sglang/srt/managers/data_parallel_controller.py +7 -6
  25. sglang/srt/managers/detokenizer_manager.py +9 -11
  26. sglang/srt/managers/image_processor.py +4 -3
  27. sglang/srt/managers/io_struct.py +74 -80
  28. sglang/srt/managers/schedule_batch.py +35 -57
  29. sglang/srt/managers/schedule_policy.py +24 -13
  30. sglang/srt/managers/scheduler.py +266 -150
  31. sglang/srt/managers/tokenizer_manager.py +292 -340
  32. sglang/srt/managers/tp_worker.py +5 -5
  33. sglang/srt/mem_cache/flush_cache.py +1 -1
  34. sglang/srt/metrics/collector.py +211 -0
  35. sglang/srt/metrics/func_timer.py +108 -0
  36. sglang/srt/mm_utils.py +1 -1
  37. sglang/srt/model_executor/cuda_graph_runner.py +9 -6
  38. sglang/srt/model_executor/forward_batch_info.py +7 -3
  39. sglang/srt/model_executor/model_runner.py +10 -18
  40. sglang/srt/models/baichuan.py +4 -4
  41. sglang/srt/models/chatglm.py +4 -4
  42. sglang/srt/models/commandr.py +1 -1
  43. sglang/srt/models/dbrx.py +5 -5
  44. sglang/srt/models/deepseek.py +4 -4
  45. sglang/srt/models/deepseek_v2.py +4 -4
  46. sglang/srt/models/exaone.py +4 -4
  47. sglang/srt/models/gemma.py +1 -1
  48. sglang/srt/models/gemma2.py +1 -1
  49. sglang/srt/models/gemma2_reward.py +69 -0
  50. sglang/srt/models/gpt2.py +281 -0
  51. sglang/srt/models/gpt_bigcode.py +1 -1
  52. sglang/srt/models/grok.py +4 -4
  53. sglang/srt/models/internlm2.py +4 -4
  54. sglang/srt/models/internlm2_reward.py +62 -0
  55. sglang/srt/models/llama.py +25 -12
  56. sglang/srt/models/llama_embedding.py +2 -10
  57. sglang/srt/models/llama_reward.py +10 -26
  58. sglang/srt/models/minicpm.py +4 -4
  59. sglang/srt/models/minicpm3.py +4 -4
  60. sglang/srt/models/mixtral.py +7 -5
  61. sglang/srt/models/mixtral_quant.py +4 -4
  62. sglang/srt/models/mllama.py +5 -5
  63. sglang/srt/models/olmo.py +4 -4
  64. sglang/srt/models/olmoe.py +4 -4
  65. sglang/srt/models/qwen.py +4 -4
  66. sglang/srt/models/qwen2.py +4 -4
  67. sglang/srt/models/qwen2_moe.py +4 -4
  68. sglang/srt/models/qwen2_vl.py +9 -15
  69. sglang/srt/models/stablelm.py +4 -4
  70. sglang/srt/models/torch_native_llama.py +4 -4
  71. sglang/srt/models/xverse.py +4 -4
  72. sglang/srt/models/xverse_moe.py +4 -4
  73. sglang/srt/openai_api/adapter.py +58 -68
  74. sglang/srt/sampling/sampling_batch_info.py +6 -13
  75. sglang/srt/sampling/sampling_params.py +0 -14
  76. sglang/srt/server.py +84 -46
  77. sglang/srt/server_args.py +61 -12
  78. sglang/srt/utils.py +127 -56
  79. sglang/test/runners.py +2 -1
  80. sglang/test/simple_eval_common.py +1 -1
  81. sglang/test/simple_eval_humaneval.py +2 -2
  82. sglang/test/simple_eval_mgsm.py +2 -2
  83. sglang/test/test_utils.py +89 -27
  84. sglang/utils.py +63 -1
  85. sglang/version.py +1 -1
  86. sglang-0.3.5.post1.dist-info/METADATA +348 -0
  87. sglang-0.3.5.post1.dist-info/RECORD +155 -0
  88. {sglang-0.3.4.post2.dist-info → sglang-0.3.5.post1.dist-info}/WHEEL +1 -1
  89. sglang/srt/constrained/base_tool_cache.py +0 -65
  90. sglang/srt/constrained/fsm_cache.py +0 -95
  91. sglang/srt/constrained/jump_forward.py +0 -203
  92. sglang-0.3.4.post2.dist-info/METADATA +0 -899
  93. sglang-0.3.4.post2.dist-info/RECORD +0 -148
  94. {sglang-0.3.4.post2.dist-info → sglang-0.3.5.post1.dist-info}/LICENSE +0 -0
  95. {sglang-0.3.4.post2.dist-info → sglang-0.3.5.post1.dist-info}/top_level.txt +0 -0
sglang/api.py CHANGED
@@ -99,7 +99,7 @@ def gen(
     regex: Optional[str] = None,
     json_schema: Optional[str] = None,
 ):
-    """Call the model to generate. See the meaning of the arguments in docs/en/sampling_params.md"""
+    """Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
 
     if choices:
         return SglSelect(
sglang/bench_latency.py CHANGED
@@ -129,9 +129,9 @@ def load_model(server_args, port_args, tp_rank):
 
     model_config = ModelConfig(
         server_args.model_path,
-        server_args.trust_remote_code,
+        trust_remote_code=server_args.trust_remote_code,
         context_length=server_args.context_length,
-        model_override_args=json.loads(server_args.json_model_override_args),
+        model_override_args=server_args.json_model_override_args,
     )
     model_runner = ModelRunner(
         model_config=model_config,
@@ -550,4 +550,4 @@ if __name__ == "__main__":
     except Exception as e:
         raise e
     finally:
-        kill_child_process(os.getpid(), including_parent=False)
+        kill_child_process()
sglang/bench_server_latency.py CHANGED
@@ -15,7 +15,6 @@ import dataclasses
 import itertools
 import json
 import multiprocessing
-import os
 import time
 from typing import Tuple
 
@@ -70,7 +69,7 @@ def launch_server_internal(server_args):
     except Exception as e:
         raise e
     finally:
-        kill_child_process(os.getpid(), including_parent=False)
+        kill_child_process()
 
 
 def launch_server_process(server_args: ServerArgs):
@@ -176,7 +175,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
         )
     finally:
         if proc:
-            kill_child_process(proc.pid)
+            kill_child_process(proc.pid, include_self=True)
 
     print(f"\nResults are saved to {bench_args.result_filename}")
 
sglang/bench_serving.py CHANGED
@@ -222,6 +222,85 @@ async def async_request_openai_completions(
     return output
 
 
+async def async_request_truss(
+    request_func_input: RequestFuncInput,
+    pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+    api_url = request_func_input.api_url
+
+    prompt = request_func_input.prompt
+
+    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+        payload = {
+            "model": request_func_input.model,
+            "prompt": prompt,
+            "temperature": 0.0,
+            "best_of": 1,
+            "max_tokens": request_func_input.output_len,
+            "stream": not args.disable_stream,
+            "ignore_eos": not args.disable_ignore_eos,
+            **request_func_input.extra_request_body,
+        }
+        headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
+
+        output = RequestFuncOutput()
+        output.prompt_len = request_func_input.prompt_len
+
+        generated_text = ""
+        ttft = 0.0
+        st = time.perf_counter()
+        most_recent_timestamp = st
+        try:
+            async with session.post(
+                url=api_url, json=payload, headers=headers
+            ) as response:
+                if response.status == 200:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
+                            continue
+
+                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
+                        latency = time.perf_counter() - st
+                        if chunk == "[DONE]":
+                            pass
+                        else:
+                            data = json.loads(chunk)
+
+                            # NOTE: Some completion API might have a last
+                            # usage summary response without a token so we
+                            # want to check a token was generated
+                            if data["choices"][0]["delta"]["content"]:
+                                timestamp = time.perf_counter()
+                                # First token
+                                if ttft == 0.0:
+                                    ttft = time.perf_counter() - st
+                                    output.ttft = ttft
+
+                                # Decoding phase
+                                else:
+                                    output.itl.append(timestamp - most_recent_timestamp)
+
+                                most_recent_timestamp = timestamp
+                                generated_text += data["choices"][0]["delta"]["content"]
+
+                    output.generated_text = generated_text
+                    output.success = True
+                    output.latency = latency
+                    output.output_len = request_func_input.output_len
+                else:
+                    output.error = response.reason or ""
+                    output.success = False
+        except Exception:
+            output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
+
+    if pbar:
+        pbar.update(1)
+    return output
+
+
 async def async_request_sglang_generate(
     request_func_input: RequestFuncInput,
     pbar: Optional[tqdm] = None,
@@ -350,6 +429,7 @@ ASYNC_REQUEST_FUNCS = {
     "lmdeploy": async_request_openai_completions,
     "trt": async_request_trt_llm,
     "gserver": async_request_gserver,
+    "truss": async_request_truss,
 }
 
 
@@ -516,12 +596,20 @@ def sample_random_requests(
 
     # Filter out sequences that are too long or too short
     input_requests: List[Tuple[str, int, int]] = []
-    for i in range(num_prompts):
+    for data in dataset:
+        i = len(input_requests)
+        if i == num_prompts:
+            break
+
         # Tokenize the prompts and completions.
-        prompt = dataset[i][0]
+        prompt = data[0]
         prompt_token_ids = tokenizer.encode(prompt)
         prompt_len = len(prompt_token_ids)
 
+        # Skip empty prompt
+        if prompt_len == 0:
+            continue
+
         if prompt_len > input_lens[i]:
             input_ids = prompt_token_ids[: input_lens[i]]
         else:
@@ -547,6 +635,66 @@ def sample_random_requests(
     return input_requests
 
 
+def gen_prompt(tokenizer, token_num):
+    """Generate a random prompt of specified token length using tokenizer vocabulary."""
+    all_available_tokens = list(tokenizer.get_vocab().values())
+    selected_tokens = random.choices(all_available_tokens, k=token_num)
+    return tokenizer.decode(selected_tokens)
+
+
+def sample_generated_shared_prefix_requests(
+    num_groups: int,
+    prompts_per_group: int,
+    system_prompt_len: int,
+    question_len: int,
+    output_len: int,
+    tokenizer: PreTrainedTokenizerBase,
+) -> List[Tuple[str, int, int]]:
+    """Generate benchmark requests with shared system prompts using random tokens."""
+    # Generate system prompts for each group
+    system_prompts = []
+    for _ in range(num_groups):
+        system_prompt = gen_prompt(tokenizer, system_prompt_len)
+        system_prompts.append(system_prompt)
+
+    # Generate questions
+    questions = []
+    for _ in range(num_groups * prompts_per_group):
+        question = gen_prompt(tokenizer, question_len)
+        questions.append(question)
+
+    # Combine system prompts with questions
+    input_requests = []
+    total_input_tokens = 0
+    total_output_tokens = 0
+
+    for group_idx in range(num_groups):
+        system_prompt = system_prompts[group_idx]
+        for prompt_idx in range(prompts_per_group):
+            question = questions[group_idx * prompts_per_group + prompt_idx]
+            full_prompt = f"{system_prompt}\n\n{question}"
+            prompt_len = len(tokenizer.encode(full_prompt))
+
+            input_requests.append((full_prompt, prompt_len, output_len))
+            total_input_tokens += prompt_len
+            total_output_tokens += output_len
+
+    print(f"\nGenerated shared prefix dataset statistics:")
+    print(f"Number of groups: {num_groups}")
+    print(f"Prompts per group: {prompts_per_group}")
+    print(f"Total prompts: {len(input_requests)}")
+    print(f"Total input tokens: {total_input_tokens}")
+    print(f"Total output tokens: {total_output_tokens}")
+    print(
+        f"Average system prompt length: {sum(len(tokenizer.encode(sp)) for sp in system_prompts) / len(system_prompts):.1f} tokens"
+    )
+    print(
+        f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n"
+    )
+
+    return input_requests
+
+
 async def get_request(
     input_requests: List[Tuple[str, int, int]],
     request_rate: float,
@@ -873,6 +1021,7 @@ def run_benchmark(args_: argparse.Namespace):
         "vllm": 8000,
         "trt": 8000,
         "gserver": 9988,
+        "truss": 8080,
     }.get(args.backend, 30000)
 
     model_url = (
@@ -905,9 +1054,20 @@ def run_benchmark(args_: argparse.Namespace):
     elif args.backend == "gserver":
         api_url = args.base_url if args.base_url else f"{args.host}:{args.port}"
         args.model = args.model or "default"
+    elif args.backend == "truss":
+        api_url = (
+            f"{args.base_url}/v1/models/model:predict"
+            if args.base_url
+            else f"http://{args.host}:{args.port}/v1/models/model:predict"
+        )
 
     # Get model name
     if args.model is None:
+        if args.backend == "truss":
+            print(
+                "Please provide a model with `--model` when using truss backend. e.g. --model meta-llama/Llama-3.1-8B-Instruct"
+            )
+            sys.exit(1)
         try:
             response = requests.get(model_url)
             model_list = response.json().get("data", [])
@@ -956,6 +1116,15 @@ def run_benchmark(args_: argparse.Namespace):
             tokenizer=tokenizer,
             dataset_path=args.dataset_path,
         )
+    elif args.dataset_name == "generated-shared-prefix":
+        input_requests = sample_generated_shared_prefix_requests(
+            num_groups=args.gen_num_groups,
+            prompts_per_group=args.gen_prompts_per_group,
+            system_prompt_len=args.gen_system_prompt_len,
+            question_len=args.gen_question_len,
+            output_len=args.gen_output_len,
+            tokenizer=tokenizer,
+        )
     else:
         raise ValueError(f"Unknown dataset: {args.dataset_name}")
 
@@ -1029,7 +1198,7 @@ if __name__ == "__main__":
         "--dataset-name",
        type=str,
         default="sharegpt",
-        choices=["sharegpt", "random"],
+        choices=["sharegpt", "random", "generated-shared-prefix"],
         help="Name of the dataset to benchmark on.",
     )
     parser.add_argument(
@@ -1116,5 +1285,38 @@ if __name__ == "__main__":
         help="Append given JSON object to the request payload. You can use this to specify"
         "additional generate params like sampling params.",
     )
+
+    group = parser.add_argument_group("generated-shared-prefix dataset arguments")
+    group.add_argument(
+        "--gen-num-groups",
+        type=int,
+        default=64,
+        help="Number of system prompt groups for generated-shared-prefix dataset",
+    )
+    group.add_argument(
+        "--gen-prompts-per-group",
+        type=int,
+        default=16,
+        help="Number of prompts per system prompt group for generated-shared-prefix dataset",
+    )
+    group.add_argument(
+        "--gen-system-prompt-len",
+        type=int,
+        default=2048,
+        help="Target length in tokens for system prompts in generated-shared-prefix dataset",
+    )
+    group.add_argument(
+        "--gen-question-len",
+        type=int,
+        default=128,
+        help="Target length in tokens for questions in generated-shared-prefix dataset",
+    )
+    group.add_argument(
+        "--gen-output-len",
+        type=int,
+        default=256,
+        help="Target length in tokens for outputs in generated-shared-prefix dataset",
+    )
+
     args = parser.parse_args()
     run_benchmark(args)
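The new generated-shared-prefix dataset builds groups of prompts that share a long random system prefix. Below is a minimal sketch (not part of the diff) of driving the new sampler directly rather than through the CLI; it assumes a Hugging Face tokenizer is available locally, reuses the example model name mentioned in the diff, and picks small sizes for a quick smoke test.

```python
# Sketch only: exercises sample_generated_shared_prefix_requests outside the CLI.
from transformers import AutoTokenizer

from sglang.bench_serving import sample_generated_shared_prefix_requests

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
requests = sample_generated_shared_prefix_requests(
    num_groups=4,              # --gen-num-groups (default 64)
    prompts_per_group=8,       # --gen-prompts-per-group (default 16)
    system_prompt_len=2048,    # --gen-system-prompt-len
    question_len=128,          # --gen-question-len
    output_len=256,            # --gen-output-len
    tokenizer=tokenizer,
)
prompt, prompt_len, output_len = requests[0]
print(len(requests), prompt_len, output_len)
```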
sglang/global_config.py CHANGED
@@ -14,9 +14,15 @@ class GlobalConfig:
         self.default_backend = None
 
         # Runtime constants: New generation token ratio estimation
-        self.init_new_token_ratio = 0.7
-        self.base_min_new_token_ratio = 0.1
-        self.new_token_ratio_decay = 0.001
+        self.default_init_new_token_ratio = float(
+            os.environ.get("SGLANG_INIT_NEW_TOKEN_RATIO", 0.7)
+        )
+        self.default_min_new_token_ratio_factor = float(
+            os.environ.get("SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR", 0.14)
+        )
+        self.default_new_token_ratio_decay_steps = float(
+            os.environ.get("SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS", 600)
+        )
 
         # Runtime constants: others
         self.retract_decode_steps = 20
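The three token-ratio constants are now read from environment variables with new names and defaults. A minimal sketch of overriding them, assuming the module-level `global_config` instance is imported only after the variables are set (the values below are illustrative, not recommendations):

```python
# Sketch only: override the new scheduler-ratio defaults via environment variables.
import os

os.environ["SGLANG_INIT_NEW_TOKEN_RATIO"] = "0.5"         # default 0.7
os.environ["SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR"] = "0.2"   # default 0.14
os.environ["SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS"] = "800"  # default 600

from sglang.global_config import global_config  # reads the variables at import time

print(global_config.default_init_new_token_ratio)
print(global_config.default_min_new_token_ratio_factor)
print(global_config.default_new_token_ratio_decay_steps)
```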
sglang/lang/chat_template.py CHANGED
@@ -116,12 +116,10 @@ register_chat_template(
     )
 )
 
-# There is default system prompt for qwen
-# reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
-# The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+
 register_chat_template(
     ChatTemplate(
-        name="qwen",
+        name="chatml-llava",
         default_system_prompt="You are a helpful assistant.",
         role_prefix_and_suffix={
             "system": ("<|im_start|>system\n", "<|im_end|>\n"),
@@ -130,13 +128,17 @@ register_chat_template(
         },
         style=ChatTemplateStyle.PLAIN,
         stop_str=("<|im_end|>",),
+        image_token="<image>\n",
     )
 )
 
-# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
+
+# There is default system prompt for qwen
+# reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
+# The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
 register_chat_template(
     ChatTemplate(
-        name="qwen2-vl",
+        name="qwen",
         default_system_prompt="You are a helpful assistant.",
         role_prefix_and_suffix={
             "system": ("<|im_start|>system\n", "<|im_end|>\n"),
@@ -144,15 +146,14 @@ register_chat_template(
             "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
         },
         style=ChatTemplateStyle.PLAIN,
-        stop_str=("<|im_end|>"),
-        image_token="<|vision_start|><|image_pad|><|vision_end|>",
+        stop_str=("<|im_end|>",),
     )
 )
 
-
+# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
 register_chat_template(
     ChatTemplate(
-        name="chatml-llava",
+        name="qwen2-vl",
         default_system_prompt="You are a helpful assistant.",
         role_prefix_and_suffix={
             "system": ("<|im_start|>system\n", "<|im_end|>\n"),
@@ -161,7 +162,7 @@ register_chat_template(
         },
         style=ChatTemplateStyle.PLAIN,
         stop_str=("<|im_end|>",),
-        image_token="<image>\n",
+        image_token="<|vision_start|><|image_pad|><|vision_end|>",
     )
 )
 
@@ -182,37 +183,46 @@ register_chat_template(
     )
 )
 
-# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
 register_chat_template(
     ChatTemplate(
-        name="yi-1.5",
+        name="llama-2-chat",
         default_system_prompt=None,
         role_prefix_and_suffix={
-            "system": ("", ""),
-            "user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
-            "assistant": ("", "<|im_end|>\n"),
+            "system": ("<<SYS>>\n", "\n<</SYS>>\n\n"),
+            "user": ("[INST] ", " [/INST]"),
+            "assistant": ("", " </s><s>"),
         },
-        style=ChatTemplateStyle.PLAIN,
-        stop_str=("<|im_end|>",),
+        style=ChatTemplateStyle.LLAMA2,
     )
 )
 
 register_chat_template(
     ChatTemplate(
-        name="llama-2-chat",
+        name="llama-3-instruct",
         default_system_prompt=None,
         role_prefix_and_suffix={
-            "system": ("<<SYS>>\n", "\n<</SYS>>\n\n"),
-            "user": ("[INST] ", " [/INST]"),
-            "assistant": ("", " </s><s>"),
+            "system": (
+                "<|start_header_id|>system<|end_header_id|>\n\n",
+                "<|eot_id|>",
+            ),
+            "user": (
+                "<|start_header_id|>user<|end_header_id|>\n\n",
+                "<|eot_id|>",
+            ),
+            "assistant": (
+                "<|start_header_id|>assistant<|end_header_id|>\n\n",
+                "<|eot_id|>",
+            ),
         },
-        style=ChatTemplateStyle.LLAMA2,
+        stop_str=("<|eot_id|>",),
+        image_token="<|image|>",
     )
 )
 
+# The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
 register_chat_template(
     ChatTemplate(
-        name="llama-3-instruct",
+        name="llama-3-instruct-llava",
         default_system_prompt=None,
         role_prefix_and_suffix={
             "system": (
@@ -229,7 +239,22 @@ register_chat_template(
             ),
         },
         stop_str=("<|eot_id|>",),
-        image_token="<|image|>",
+        image_token="<image>\n",
+    )
+)
+
+# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
+register_chat_template(
+    ChatTemplate(
+        name="yi-1.5",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": ("", ""),
+            "user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
+            "assistant": ("", "<|im_end|>\n"),
+        },
+        style=ChatTemplateStyle.PLAIN,
+        stop_str=("<|im_end|>",),
     )
 )
 
sglang/lang/interpreter.py CHANGED
@@ -54,7 +54,14 @@ def run_internal(state, program, func_args, func_kwargs, sync):
 
 
 def run_program(
-    program, backend, func_args, func_kwargs, default_sampling_para, stream, sync=False
+    program,
+    backend,
+    func_args,
+    func_kwargs,
+    default_sampling_para,
+    stream,
+    sync=False,
+    use_thread=True,
 ):
     if hasattr(backend, "endpoint"):
         backend = backend.endpoint
@@ -67,6 +74,7 @@ def run_program(
         chat_template=None,
         stream=stream,
         num_api_spec_tokens=program.num_api_spec_tokens,
+        use_thread=use_thread,
     )
     state = ProgramState(stream_executor)
 
sglang/lang/ir.py CHANGED
@@ -168,6 +168,7 @@ class SglFunction:
         return_text_in_logprobs: Optional[bool] = None,
         stream: bool = False,
         backend=None,
+        use_thread: bool = True,
         **kwargs,
     ):
         from sglang.lang.interpreter import run_program
@@ -195,7 +196,15 @@ class SglFunction:
             return_text_in_logprobs=return_text_in_logprobs,
         )
         backend = backend or global_config.default_backend
-        return run_program(self, backend, args, kwargs, default_sampling_para, stream)
+        return run_program(
+            self,
+            backend,
+            args,
+            kwargs,
+            default_sampling_para,
+            stream,
+            use_thread=use_thread,
+        )
 
     def run_batch(
         self,
@@ -445,7 +454,7 @@ class SglGen(SglExpr):
         regex: Optional[str] = None,
         json_schema: Optional[str] = None,
     ):
-        """Call the model to generate. See the meaning of the arguments in docs/en/sampling_params.md"""
+        """Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
         super().__init__()
         self.name = name
         self.sampling_params = SglSamplingParams(
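`SglFunction.run` gains a `use_thread` flag that is forwarded through `run_program` to the stream executor. A minimal sketch of passing it from user code, assuming a default backend has already been configured with `sgl.set_default_backend`; the program itself is only an example, not part of the diff.

```python
# Sketch only: the new use_thread flag on SglFunction.run.
import sglang as sgl

@sgl.function
def answer(s, question):
    s += question
    s += sgl.gen("reply", max_tokens=16)

# use_thread=False asks the interpreter to execute without the background
# worker thread that run() would otherwise use (the default remains True).
state = answer.run(question="What is 2 + 2? ", use_thread=False)
print(state["reply"])
```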
sglang/launch_server.py CHANGED
@@ -15,4 +15,4 @@ if __name__ == "__main__":
     except Exception as e:
         raise e
     finally:
-        kill_child_process(os.getpid(), including_parent=False)
+        kill_child_process()
sglang/srt/configs/model_config.py CHANGED
@@ -13,10 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+import json
 import logging
 import os
 from enum import IntEnum, auto
-from typing import Optional
+from typing import List, Optional
 
 from transformers import PretrainedConfig
 
@@ -38,18 +39,26 @@ class ModelConfig:
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
         model_override_args: Optional[dict] = None,
+        is_embedding: Optional[bool] = None,
     ) -> None:
-        self.path = path
-        self.trust_remote_code = trust_remote_code
-        self.revision = revision
-        self.model_override_args = model_override_args
+        # Parse args
+        self.model_override_args = json.loads(model_override_args)
         self.hf_config = get_config(
-            self.path,
-            trust_remote_code,
-            revision,
-            model_override_args=model_override_args,
+            path,
+            trust_remote_code=trust_remote_code,
+            revision=revision,
+            model_override_args=self.model_override_args,
         )
         self.hf_text_config = get_hf_text_config(self.hf_config)
+
+        # Check model type
+        self.is_generation = is_generation_model(
+            self.hf_config.architectures, is_embedding
+        )
+        self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
+        self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
+
+        # Derive context length
         derived_context_len = get_context_length(self.hf_text_config)
         allow_long_context = os.environ.get(
             "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None
@@ -81,7 +90,7 @@ class ModelConfig:
             self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
         )
 
-        # FIXME: temporary special judge for deepseek v2 MLA architecture
+        # FIXME: temporary special judge for MLA architecture
         if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
             self.head_dim = 256
             self.attention_arch = AttentionArch.MLA
@@ -112,8 +121,6 @@ class ModelConfig:
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
         self.vocab_size = self.hf_text_config.vocab_size
 
-        self.is_encoder_decoder = self.hf_config.model_type in ["mllama"]
-
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
@@ -163,7 +170,6 @@ class ModelConfig:
         # equal to the number of attention heads.
         return self.hf_text_config.num_attention_heads
 
-    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L328
     def get_num_kv_heads(self, tensor_parallel_size) -> int:
         """Returns the number of KV heads per GPU."""
         total_num_kv_heads = self.get_total_num_kv_heads()
@@ -192,3 +198,38 @@ def get_hf_text_config(config: PretrainedConfig):
         return config.text_config
     else:
         return config
+
+
+def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
+    # We have two ways to determine whether a model is a generative model.
+    # 1. Check the model architectue
+    # 2. check the `is_embedding` server args
+
+    if (
+        "LlamaEmbeddingModel" in model_architectures
+        or "MistralModel" in model_architectures
+        or "LlamaForSequenceClassification" in model_architectures
+        or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
+        or "InternLM2ForRewardModel" in model_architectures
+    ):
+        return False
+    else:
+        return not is_embedding
+
+
+def is_multimodal_model(model_architectures: List[str]):
+    if (
+        "LlavaLlamaForCausalLM" in model_architectures
+        or "LlavaQwenForCausalLM" in model_architectures
+        or "LlavaMistralForCausalLM" in model_architectures
+        or "LlavaVidForCausalLM" in model_architectures
+        or "MllamaForConditionalGeneration" in model_architectures
+        or "Qwen2VLForConditionalGeneration" in model_architectures
+    ):
+        return True
+    else:
+        return False
+
+
+def is_encoder_decoder_model(model_architectures: List[str]):
+    return "MllamaForConditionalGeneration" in model_architectures