sglang 0.2.13__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. sglang/api.py +6 -0
  2. sglang/bench_latency.py +7 -3
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/lang/chat_template.py +10 -5
  6. sglang/lang/compiler.py +4 -0
  7. sglang/lang/interpreter.py +1 -0
  8. sglang/lang/ir.py +9 -0
  9. sglang/launch_server.py +8 -1
  10. sglang/srt/conversation.py +50 -1
  11. sglang/srt/hf_transformers_utils.py +22 -23
  12. sglang/srt/layers/activation.py +24 -1
  13. sglang/srt/layers/decode_attention.py +338 -50
  14. sglang/srt/layers/fused_moe/layer.py +2 -2
  15. sglang/srt/layers/layernorm.py +3 -0
  16. sglang/srt/layers/logits_processor.py +60 -23
  17. sglang/srt/layers/radix_attention.py +3 -4
  18. sglang/srt/layers/sampler.py +154 -0
  19. sglang/srt/managers/controller_multi.py +2 -8
  20. sglang/srt/managers/controller_single.py +7 -10
  21. sglang/srt/managers/detokenizer_manager.py +20 -9
  22. sglang/srt/managers/io_struct.py +44 -11
  23. sglang/srt/managers/policy_scheduler.py +5 -2
  24. sglang/srt/managers/schedule_batch.py +52 -167
  25. sglang/srt/managers/tokenizer_manager.py +192 -83
  26. sglang/srt/managers/tp_worker.py +130 -43
  27. sglang/srt/mem_cache/memory_pool.py +82 -8
  28. sglang/srt/mm_utils.py +79 -7
  29. sglang/srt/model_executor/cuda_graph_runner.py +49 -11
  30. sglang/srt/model_executor/forward_batch_info.py +59 -27
  31. sglang/srt/model_executor/model_runner.py +210 -61
  32. sglang/srt/models/chatglm.py +4 -12
  33. sglang/srt/models/commandr.py +5 -1
  34. sglang/srt/models/dbrx.py +5 -1
  35. sglang/srt/models/deepseek.py +5 -1
  36. sglang/srt/models/deepseek_v2.py +5 -1
  37. sglang/srt/models/gemma.py +5 -1
  38. sglang/srt/models/gemma2.py +15 -7
  39. sglang/srt/models/gpt_bigcode.py +5 -1
  40. sglang/srt/models/grok.py +16 -2
  41. sglang/srt/models/internlm2.py +5 -1
  42. sglang/srt/models/llama2.py +7 -3
  43. sglang/srt/models/llama_classification.py +2 -2
  44. sglang/srt/models/llama_embedding.py +4 -0
  45. sglang/srt/models/llava.py +176 -59
  46. sglang/srt/models/minicpm.py +5 -1
  47. sglang/srt/models/mixtral.py +5 -1
  48. sglang/srt/models/mixtral_quant.py +5 -1
  49. sglang/srt/models/qwen.py +5 -2
  50. sglang/srt/models/qwen2.py +13 -3
  51. sglang/srt/models/qwen2_moe.py +5 -14
  52. sglang/srt/models/stablelm.py +5 -1
  53. sglang/srt/openai_api/adapter.py +117 -37
  54. sglang/srt/sampling/sampling_batch_info.py +209 -0
  55. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -0
  56. sglang/srt/server.py +84 -56
  57. sglang/srt/server_args.py +43 -15
  58. sglang/srt/utils.py +26 -16
  59. sglang/test/runners.py +23 -31
  60. sglang/test/simple_eval_common.py +9 -10
  61. sglang/test/simple_eval_gpqa.py +2 -1
  62. sglang/test/simple_eval_humaneval.py +2 -2
  63. sglang/test/simple_eval_math.py +2 -1
  64. sglang/test/simple_eval_mmlu.py +2 -1
  65. sglang/test/test_activation.py +55 -0
  66. sglang/test/test_utils.py +36 -53
  67. sglang/version.py +1 -1
  68. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/METADATA +92 -25
  69. sglang-0.2.14.dist-info/RECORD +114 -0
  70. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
  71. sglang/launch_server_llavavid.py +0 -29
  72. sglang-0.2.13.dist-info/RECORD +0 -112
  73. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
  74. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/api.py CHANGED
@@ -66,6 +66,7 @@ def gen(
     temperature: Optional[float] = None,
     top_p: Optional[float] = None,
     top_k: Optional[int] = None,
+    min_p: Optional[float] = None,
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
     ignore_eos: Optional[bool] = None,
@@ -103,6 +104,7 @@ def gen(
         temperature,
         top_p,
         top_k,
+        min_p,
         frequency_penalty,
         presence_penalty,
         ignore_eos,
@@ -123,6 +125,7 @@ def gen_int(
     temperature: Optional[float] = None,
     top_p: Optional[float] = None,
     top_k: Optional[int] = None,
+    min_p: Optional[float] = None,
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
     ignore_eos: Optional[bool] = None,
@@ -139,6 +142,7 @@ def gen_int(
         temperature,
         top_p,
         top_k,
+        min_p,
         frequency_penalty,
         presence_penalty,
         ignore_eos,
@@ -159,6 +163,7 @@ def gen_string(
     temperature: Optional[float] = None,
     top_p: Optional[float] = None,
     top_k: Optional[int] = None,
+    min_p: Optional[float] = None,
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
     ignore_eos: Optional[bool] = None,
@@ -175,6 +180,7 @@ def gen_string(
         temperature,
         top_p,
         top_k,
+        min_p,
         frequency_penalty,
         presence_penalty,
         ignore_eos,
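Note: this release threads a new `min_p` sampling parameter through the frontend (`gen`, `gen_int`, `gen_string`), the IR, the compiler, and the new sampler layer. For reference, min-p sampling keeps only tokens whose probability is at least `min_p` times that of the most likely token; the sketch below is an illustrative reimplementation, not the kernel sglang actually ships.

```python
import torch

def min_p_filter(logits: torch.Tensor, min_p: float) -> torch.Tensor:
    """Illustrative min-p filtering: drop tokens whose probability is below
    min_p * max_prob, then renormalize. With min_p=0.0 (the new default),
    this is a no-op."""
    probs = torch.softmax(logits, dim=-1)
    threshold = min_p * probs.max(dim=-1, keepdim=True).values
    kept = torch.where(probs >= threshold, probs, torch.zeros_like(probs))
    return kept / kept.sum(dim=-1, keepdim=True)
```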
sglang/bench_latency.py CHANGED
@@ -54,7 +54,7 @@ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.model_config import ModelConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.sampling_params import SamplingParams
+from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import suppress_other_loggers

@@ -111,7 +111,11 @@ def load_model(server_args, tp_rank):
     suppress_other_loggers()
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

-    model_config = ModelConfig(path=server_args.model_path)
+    model_config = ModelConfig(
+        server_args.model_path,
+        server_args.trust_remote_code,
+        context_length=server_args.context_length,
+    )
     model_runner = ModelRunner(
         model_config=model_config,
         mem_fraction_static=server_args.mem_fraction_static,
@@ -350,7 +354,7 @@ def latency_test(
     for bs, il, ol in itertools.product(
         bench_args.batch_size, bench_args.input_len, bench_args.output_len
     ):
-        req = prepare_synthetic_inputs_for_latency_test(bs, il)
+        reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
         ret = latency_test_run_once(
             bench_args.run_name, model_runner, rank_print, reqs, bs, il, ol
         )
sglang/bench_serving.py CHANGED
@@ -149,10 +149,12 @@ async def async_request_openai_completions(
         "completions"
     ), "OpenAI Completions API URL must end with 'completions'."

+    prompt = request_func_input.prompt
+
     async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
         payload = {
             "model": request_func_input.model,
-            "prompt": request_func_input.prompt,
+            "prompt": prompt,
             "temperature": 0.0,
             "best_of": 1,
             "max_tokens": request_func_input.output_len,
@@ -220,6 +222,13 @@ async def async_request_openai_completions(
         return output


+async def async_request_gserver(
+    request_func_input: RequestFuncInput,
+    pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+    raise NotImplementedError()
+
+
 def get_model(pretrained_model_name_or_path: str) -> str:
     if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true":
         import huggingface_hub.constants
@@ -238,6 +247,13 @@ def get_model(pretrained_model_name_or_path: str) -> str:
 def get_tokenizer(
     pretrained_model_name_or_path: str,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+    if pretrained_model_name_or_path.endswith(
+        ".json"
+    ) or pretrained_model_name_or_path.endswith(".model"):
+        from sglang.srt.hf_transformers_utils import get_tokenizer
+
+        return get_tokenizer(pretrained_model_name_or_path)
+
     if pretrained_model_name_or_path is not None and not os.path.exists(
         pretrained_model_name_or_path
     ):
@@ -252,6 +268,7 @@ ASYNC_REQUEST_FUNCS = {
     "vllm": async_request_openai_completions,
     "lmdeploy": async_request_openai_completions,
     "trt": async_request_trt_llm,
+    "gserver": async_request_gserver,
 }


@@ -351,9 +368,9 @@ def sample_sharegpt_requests(

         # Tokenize the prompts and completions.
         prompt = dataset[i][0]
-        prompt_token_ids = tokenizer(prompt).input_ids
+        prompt_token_ids = tokenizer.encode(prompt)
         completion = dataset[i][1]
-        completion_token_ids = tokenizer(completion).input_ids
+        completion_token_ids = tokenizer.encode(completion)
         prompt_len = len(prompt_token_ids)
         output_len = (
             len(completion_token_ids) if fixed_output_len is None else fixed_output_len
@@ -361,7 +378,9 @@ def sample_sharegpt_requests(
         if prompt_len < 4 or output_len < 4:
             # Prune too short sequences.
             continue
-        if prompt_len > 1024 or prompt_len + output_len > 2048:
+        if prompt_len > 1024 or (
+            prompt_len + output_len > 2048 and fixed_output_len is None
+        ):
             # Prune too long sequences.
             continue
         filtered_dataset.append((prompt, prompt_len, output_len))
@@ -422,7 +441,7 @@ def sample_random_requests(
     for i in range(num_prompts):
         # Tokenize the prompts and completions.
         prompt = dataset[i][0]
-        prompt_token_ids = tokenizer(prompt).input_ids
+        prompt_token_ids = tokenizer.encode(prompt)
         prompt_len = len(prompt_token_ids)

         if prompt_len > input_lens[i]:
@@ -488,7 +507,7 @@ def calculate_metrics(
         output_len = outputs[i].output_len
         output_lens.append(output_len)
         retokenized_output_len = len(
-            tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids
+            tokenizer.encode(outputs[i].generated_text, add_special_tokens=False)
         )
         retokenized_output_lens.append(retokenized_output_len)
         total_input += input_requests[i][1]
@@ -547,7 +566,6 @@ async def benchmark(
     input_requests: List[Tuple[str, int, int]],
     request_rate: float,
     disable_tqdm: bool,
-    enable_multi: bool,
     extra_request_body: Dict[str, Any],
 ):
     if backend in ASYNC_REQUEST_FUNCS:
@@ -756,6 +774,7 @@ def run_benchmark(args_: argparse.Namespace):
     global args
     args = args_

+    # Set global environments
     set_ulimit()
     random.seed(args.seed)
     np.random.seed(args.seed)
@@ -764,12 +783,14 @@ def run_benchmark(args_: argparse.Namespace):
     if args.extra_request_body:
         extra_request_body = json.loads(args.extra_request_body)

+    # Set url
     if args.port is None:
         args.port = {
             "sglang": 30000,
             "lmdeploy": 23333,
             "vllm": 8000,
             "trt": 8000,
+            "gserver": 9988,
         }.get(args.backend, 30000)

     api_url = (
@@ -792,7 +813,11 @@ def run_benchmark(args_: argparse.Namespace):
         if args.model is None:
             print("Please provide a model using `--model` when using `trt` backend.")
             sys.exit(1)
+    elif args.backend == "gserver":
+        api_url = args.base_url if args.base_url else f"{args.host}:{args.port}"
+        args.model = args.model or "default"

+    # Get model name
     if args.model is None:
         try:
             response = requests.get(model_url)
@@ -817,6 +842,7 @@ def run_benchmark(args_: argparse.Namespace):

     print(f"{args}\n")

+    # Read dataset
     backend = args.backend
     model_id = args.model
     tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
@@ -842,7 +868,21 @@ def run_benchmark(args_: argparse.Namespace):
     else:
         raise ValueError(f"Unknown dataset: {args.dataset_name}")

-    if args.multi:
+    if not args.multi:
+        return asyncio.run(
+            benchmark(
+                backend=backend,
+                api_url=api_url,
+                model_id=model_id,
+                tokenizer=tokenizer,
+                input_requests=input_requests,
+                request_rate=args.request_rate,
+                disable_tqdm=args.disable_tqdm,
+                extra_request_body=extra_request_body,
+            )
+        )
+    else:
+        # Benchmark multiple rps. TODO: use a fixed duration to compute num_prompts
         request_rates = parse_request_rate_range(args.request_rate_range)

         for rate in request_rates:
@@ -855,27 +895,11 @@ def run_benchmark(args_: argparse.Namespace):
                     input_requests=input_requests,
                     request_rate=rate,
                     disable_tqdm=args.disable_tqdm,
-                    enable_multi=args.multi,
                     extra_request_body=extra_request_body,
                 )
             )
-    else:
-        return asyncio.run(
-            benchmark(
-                backend=backend,
-                api_url=api_url,
-                model_id=model_id,
-                tokenizer=tokenizer,
-                input_requests=input_requests,
-                request_rate=args.request_rate,
-                disable_tqdm=args.disable_tqdm,
-                enable_multi=args.multi,
-                extra_request_body=extra_request_body,
-            )
-        )


-# to avoid relying on SGLang's components
 def set_ulimit(target_soft_limit=65535):
     resource_type = resource.RLIMIT_NOFILE
     current_soft, current_hard = resource.getrlimit(resource_type)
@@ -966,9 +990,9 @@ if __name__ == "__main__":
         type=float,
         default=float("inf"),
         help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
-        "Otherwise, we use Poisson process to synthesize the request arrival times. Default is 128.0.",
+        "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.",
     )
-    parser.add_argument("--seed", type=int, default=0, help="Default is 0.")
+    parser.add_argument("--seed", type=int, default=1, help="The random seed.")
     parser.add_argument(
         "--multi",
         action="store_true",
sglang/check_env.py CHANGED
@@ -170,6 +170,17 @@ def get_gpu_topology():
     return None


+def get_hypervisor_vendor():
+    try:
+        output = subprocess.check_output(["lscpu"], text=True)
+        for line in output.split("\n"):
+            if "Hypervisor vendor:" in line:
+                return line.split(":")[1].strip()
+        return None
+    except:
+        return None
+
+
 def check_env():
     """
     Check and print environment information.
@@ -184,6 +195,10 @@ def check_env():
     if gpu_topo:
         env_info["NVIDIA Topology"] = gpu_topo

+    hypervisor_vendor = get_hypervisor_vendor()
+    if hypervisor_vendor:
+        env_info["Hypervisor vendor"] = hypervisor_vendor
+
     ulimit_soft, _ = resource.getrlimit(resource.RLIMIT_NOFILE)
     env_info["ulimit soft"] = ulimit_soft

sglang/lang/chat_template.py CHANGED
@@ -1,6 +1,6 @@
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from enum import Enum, auto
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Tuple


 class ChatTemplateStyle(Enum):
@@ -137,7 +137,7 @@ register_chat_template(
 register_chat_template(
     ChatTemplate(
         name="chatml-llava",
-        default_system_prompt="Answer the questions.",
+        default_system_prompt="You are a helpful assistant.",
         role_prefix_and_suffix={
             "system": ("<|im_start|>system\n", "<|im_end|>\n"),
             "user": ("<|im_start|>user\n", "<|im_end|>\n"),
@@ -145,7 +145,7 @@ register_chat_template(
         },
         style=ChatTemplateStyle.PLAIN,
         stop_str=("<|im_end|>",),
-        image_token=" <image>\n",
+        image_token="<image>\n",
     )
 )

@@ -322,12 +322,17 @@ def match_chat_ml(model_path: str):
     if "tinyllama" in model_path:
         return get_chat_template("chatml")
     # Now the suffix for qwen2 chat model is "instruct"
-    if "qwen" in model_path and ("chat" in model_path or "instruct" in model_path):
+    if (
+        "qwen" in model_path
+        and ("chat" in model_path or "instruct" in model_path)
+        and ("llava" not in model_path)
+    ):
         return get_chat_template("qwen")
     if (
         "llava-v1.6-34b" in model_path
         or "llava-v1.6-yi-34b" in model_path
         or "llava-next-video-34b" in model_path
+        or "llava-onevision-qwen2" in model_path
     ):
         return get_chat_template("chatml-llava")

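Given the `role_prefix_and_suffix` and the updated `image_token` above, a single-image "chatml-llava" prompt renders roughly as follows (illustrative example; the question text is made up):

```python
prompt = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|>\n"
    "<|im_start|>assistant\n"
)
```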
sglang/lang/compiler.py CHANGED
@@ -130,6 +130,7 @@ class CompiledFunction:
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
+        min_p: float = 0.0,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         backend=None,
@@ -145,6 +146,7 @@ class CompiledFunction:
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            min_p=min_p,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
         )
@@ -160,6 +162,7 @@ class CompiledFunction:
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
+        min_p: float = 0.0,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         backend=None,
@@ -178,6 +181,7 @@ class CompiledFunction:
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            min_p=min_p,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
         )
sglang/lang/interpreter.py CHANGED
@@ -663,6 +663,7 @@ class StreamExecutor:
             "temperature",
             "top_p",
             "top_k",
+            "min_p",
             "frequency_penalty",
             "presence_penalty",
             "ignore_eos",
sglang/lang/ir.py CHANGED
@@ -22,6 +22,7 @@ class SglSamplingParams:
     temperature: float = 1.0
     top_p: float = 1.0
     top_k: int = -1  # -1 means disable
+    min_p: float = 0.0
     frequency_penalty: float = 0.0
     presence_penalty: float = 0.0
     ignore_eos: bool = False
@@ -42,6 +43,7 @@ class SglSamplingParams:
             self.temperature,
             self.top_p,
             self.top_k,
+            self.min_p,
             self.frequency_penalty,
             self.presence_penalty,
             self.ignore_eos,
@@ -114,6 +116,7 @@ class SglSamplingParams:
             "temperature": self.temperature,
             "top_p": self.top_p,
             "top_k": self.top_k,
+            "min_p": self.min_p,
             "frequency_penalty": self.frequency_penalty,
             "presence_penalty": self.presence_penalty,
             "ignore_eos": self.ignore_eos,
@@ -149,6 +152,7 @@ class SglFunction:
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
+        min_p: float = 0.0,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         ignore_eos: bool = False,
@@ -169,6 +173,7 @@ class SglFunction:
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            min_p=min_p,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             ignore_eos=ignore_eos,
@@ -190,6 +195,7 @@ class SglFunction:
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
+        min_p: float = 0.0,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         ignore_eos: bool = False,
@@ -228,6 +234,7 @@ class SglFunction:
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            min_p=min_p,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             ignore_eos=ignore_eos,
@@ -408,6 +415,7 @@ class SglGen(SglExpr):
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         top_k: Optional[int] = None,
+        min_p: Optional[float] = None,
         frequency_penalty: Optional[float] = None,
         presence_penalty: Optional[float] = None,
         ignore_eos: Optional[bool] = None,
@@ -428,6 +436,7 @@ class SglGen(SglExpr):
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            min_p=min_p,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             ignore_eos=ignore_eos,
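With `min_p` now part of `SglSamplingParams`, a frontend program can pass it like any other sampling knob (illustrative usage; running the function additionally requires a backend, e.g. `sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))`):

```python
import sglang as sgl

@sgl.function
def answer(s, question):
    # min_p is forwarded through SglGen -> SglSamplingParams to the server
    s += "Q: " + question + "\nA:"
    s += sgl.gen("answer", max_tokens=64, temperature=0.7, min_p=0.05)
```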
sglang/launch_server.py CHANGED
@@ -1,9 +1,11 @@
 """Launch the inference server."""

 import argparse
+import os

 from sglang.srt.server import launch_server
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import kill_child_process

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -11,4 +13,9 @@ if __name__ == "__main__":
     args = parser.parse_args()
     server_args = ServerArgs.from_cli_args(args)

-    launch_server(server_args)
+    try:
+        launch_server(server_args)
+    except Exception as e:
+        raise e
+    finally:
+        kill_child_process(os.getpid(), including_parent=False)
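`launch_server.py` now always cleans up worker processes on exit via `kill_child_process`. As an assumption about what such a helper typically does (a sketch, not the actual `sglang/srt/utils.py` implementation), it can be approximated with psutil:

```python
import psutil

def kill_child_process(pid: int, including_parent: bool = True) -> None:
    """Terminate every child of `pid`, and optionally `pid` itself (sketch)."""
    try:
        parent = psutil.Process(pid)
    except psutil.NoSuchProcess:
        return
    for child in parent.children(recursive=True):
        try:
            child.kill()
        except psutil.NoSuchProcess:
            pass
    if including_parent:
        parent.kill()
```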
sglang/srt/conversation.py CHANGED
@@ -34,6 +34,7 @@ class SeparatorStyle(IntEnum):
     NO_COLON_TWO = auto()
     ADD_NEW_LINE_SINGLE = auto()
     LLAMA2 = auto()
+    LLAMA3 = auto()
     CHATGLM = auto()
     CHATML = auto()
     CHATINTERN = auto()
@@ -137,6 +138,20 @@ class Conversation:
                 else:
                     ret += role + ":"
             return ret
+        elif self.sep_style == SeparatorStyle.LLAMA3:
+            ret = "<|begin_of_text|>"
+            if self.system_message:
+                ret += system_prompt
+            else:
+                ret += ""
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
+                    ret += f"{message.strip()}<|eot_id|>"
+                else:
+                    ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
+            # print(ret)
+            return ret
         elif self.sep_style == SeparatorStyle.LLAMA2:
             seps = [self.sep, self.sep2]
             if self.system_message:
@@ -379,12 +394,23 @@ def generate_chat_conv(
                 conv.append_message(conv.roles[0], message.content)
             else:
                 real_content = ""
+                # calculate number of image_url
+                num_image_url = 0
+                for content in message.content:
+                    if content.type == "image_url":
+                        num_image_url += 1
+                if num_image_url > 1:
+                    image_token = "<image>"
+                else:
+                    image_token = "<image>\n"
                 for content in message.content:
                     if content.type == "text":
+                        if num_image_url > 16:
+                            real_content += "\n"  # for video
                         real_content += content.text
                     elif content.type == "image_url":
                         # NOTE: Only works for llava
-                        real_content += "<image>\n"
+                        real_content += image_token
                         conv.append_image(content.image_url.url)
                 conv.append_message(conv.roles[0], real_content)
         elif msg_role == "assistant":
@@ -425,6 +451,18 @@ register_conv_template(
     )
 )

+register_conv_template(
+    Conversation(
+        name="chatml-llava",
+        system_template="<|im_start|>system\n{system_message}",
+        system_message="You are a helpful assistant.",
+        roles=("<|im_start|>user", "<|im_start|>assistant"),
+        sep_style=SeparatorStyle.CHATML,
+        sep="<|im_end|>",
+        stop_str=["<|endoftext|>", "<|im_end|>"],
+    )
+)
+
 register_conv_template(
     Conversation(
         name="vicuna_v1.1",
@@ -437,6 +475,17 @@ register_conv_template(
     )
 )

+register_conv_template(
+    Conversation(
+        name="llava_llama_3",
+        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
+        system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
+        roles=("user", "assistant"),
+        sep_style=SeparatorStyle.LLAMA3,
+        sep="",
+        stop_str=["<|end_of_text|>", "<|eot_id|>"],
+    )
+)
 # Reference: https://github.com/InternLM/lmdeploy/blob/387bf54b4f124e72aab30ae9755f562e435d3d01/lmdeploy/model.py#L425-L442
 register_conv_template(
     Conversation(
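Following the `get_prompt` logic in the LLAMA3 branch above, a one-turn "llava_llama_3" conversation renders approximately as (illustrative; system message abbreviated, user text made up):

```python
prompt = (
    "<|begin_of_text|>"
    "<|start_header_id|>system<|end_header_id|>\n\n"
    "You are a helpful language and vision assistant. [...]<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n"
    "<image>\nDescribe this image.<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
```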
sglang/srt/hf_transformers_utils.py CHANGED
@@ -30,14 +30,19 @@ from transformers import (
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
 )
-from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig

-from sglang.srt.utils import is_multimodal_model
+try:
+    from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
+
+    _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
+        ChatGLMConfig.model_type: ChatGLMConfig,
+        DbrxConfig.model_type: DbrxConfig,
+    }
+except ImportError:
+    # We want this file to run without vllm dependency
+    _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {}

-_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
-    ChatGLMConfig.model_type: ChatGLMConfig,
-    DbrxConfig.model_type: DbrxConfig,
-}
+from sglang.srt.utils import is_multimodal_model


 def download_from_hf(model_path: str):
@@ -137,18 +142,6 @@ def get_tokenizer(
             raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
         kwargs["use_fast"] = False

-    if (
-        "llama" in tokenizer_name.lower()
-        and kwargs.get("use_fast", True)
-        and tokenizer_name != _FAST_LLAMA_TOKENIZER
-    ):
-        pass
-        # warnings.warn(
-        #     "For some LLaMA V1 models, initializing the fast tokenizer may "
-        #     "take a long time. To reduce the initialization time, consider "
-        #     f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
-        #     "tokenizer."
-        # )
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
@@ -229,6 +222,8 @@ class TiktokenTokenizer:
         }
         assert tok_dict["word_split"] == "V1"

+        default_allowed_special = None
+
         kwargs = {
             "name": name,
             "pat_str": tok_dict.get("pat_str", PAT_STR_B),
@@ -242,14 +237,18 @@ class TiktokenTokenizer:
                     for bytes_list in tok_dict["default_allowed_special"]
                 ]
             )
-        else:
-            default_allowed_special = None
         if "vocab_size" in tok_dict:
             kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]

+        PAD = "<|pad|>"
+        EOS = "<|eos|>"
+        SEP = "<|separator|>"
+
+        DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}
+
         tokenizer = tiktoken.Encoding(**kwargs)
         tokenizer._default_allowed_special = default_allowed_special or set()
-        tokenizer._default_allowed_special |= {"<|separator|>"}
+        tokenizer._control_tokens = DEFAULT_CONTROL_TOKENS

         def encode_patched(
             self,
@@ -266,14 +265,14 @@ class TiktokenTokenizer:
                 self,
                 text,
                 allowed_special=allowed_special,
-                disallowed_special=disallowed_special,
+                disallowed_special=(),
             )

         tokenizer.encode = functools.partial(encode_patched, tokenizer)

         # Convert to HF interface
         self.tokenizer = tokenizer
-        self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
+        self.eos_token_id = tokenizer._special_tokens[EOS]
         self.vocab_size = tokenizer.n_vocab
         self.chat_template = Template(
             "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"