sglang 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/api.py +13 -1
  2. sglang/bench_latency.py +10 -5
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/global_config.py +1 -1
  6. sglang/lang/backend/runtime_endpoint.py +60 -49
  7. sglang/lang/chat_template.py +10 -5
  8. sglang/lang/compiler.py +4 -0
  9. sglang/lang/interpreter.py +5 -2
  10. sglang/lang/ir.py +22 -4
  11. sglang/launch_server.py +8 -1
  12. sglang/srt/constrained/jump_forward.py +13 -2
  13. sglang/srt/conversation.py +50 -1
  14. sglang/srt/hf_transformers_utils.py +22 -23
  15. sglang/srt/layers/activation.py +24 -2
  16. sglang/srt/layers/decode_attention.py +338 -50
  17. sglang/srt/layers/extend_attention.py +3 -1
  18. sglang/srt/layers/fused_moe/__init__.py +1 -0
  19. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  20. sglang/srt/layers/fused_moe/layer.py +587 -0
  21. sglang/srt/layers/layernorm.py +3 -0
  22. sglang/srt/layers/logits_processor.py +64 -27
  23. sglang/srt/layers/radix_attention.py +41 -18
  24. sglang/srt/layers/sampler.py +154 -0
  25. sglang/srt/managers/controller_multi.py +2 -8
  26. sglang/srt/managers/controller_single.py +7 -10
  27. sglang/srt/managers/detokenizer_manager.py +20 -9
  28. sglang/srt/managers/io_struct.py +44 -11
  29. sglang/srt/managers/policy_scheduler.py +5 -2
  30. sglang/srt/managers/schedule_batch.py +59 -179
  31. sglang/srt/managers/tokenizer_manager.py +193 -84
  32. sglang/srt/managers/tp_worker.py +131 -50
  33. sglang/srt/mem_cache/memory_pool.py +82 -8
  34. sglang/srt/mm_utils.py +79 -7
  35. sglang/srt/model_executor/cuda_graph_runner.py +97 -28
  36. sglang/srt/model_executor/forward_batch_info.py +188 -82
  37. sglang/srt/model_executor/model_runner.py +269 -87
  38. sglang/srt/models/chatglm.py +6 -14
  39. sglang/srt/models/commandr.py +6 -2
  40. sglang/srt/models/dbrx.py +5 -1
  41. sglang/srt/models/deepseek.py +7 -3
  42. sglang/srt/models/deepseek_v2.py +12 -7
  43. sglang/srt/models/gemma.py +6 -2
  44. sglang/srt/models/gemma2.py +22 -8
  45. sglang/srt/models/gpt_bigcode.py +5 -1
  46. sglang/srt/models/grok.py +66 -398
  47. sglang/srt/models/internlm2.py +5 -1
  48. sglang/srt/models/llama2.py +7 -3
  49. sglang/srt/models/llama_classification.py +2 -2
  50. sglang/srt/models/llama_embedding.py +4 -0
  51. sglang/srt/models/llava.py +176 -59
  52. sglang/srt/models/minicpm.py +7 -3
  53. sglang/srt/models/mixtral.py +61 -255
  54. sglang/srt/models/mixtral_quant.py +6 -5
  55. sglang/srt/models/qwen.py +7 -4
  56. sglang/srt/models/qwen2.py +15 -5
  57. sglang/srt/models/qwen2_moe.py +7 -16
  58. sglang/srt/models/stablelm.py +6 -2
  59. sglang/srt/openai_api/adapter.py +149 -58
  60. sglang/srt/sampling/sampling_batch_info.py +209 -0
  61. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -4
  62. sglang/srt/server.py +107 -71
  63. sglang/srt/server_args.py +49 -15
  64. sglang/srt/utils.py +27 -18
  65. sglang/test/runners.py +38 -38
  66. sglang/test/simple_eval_common.py +9 -10
  67. sglang/test/simple_eval_gpqa.py +2 -1
  68. sglang/test/simple_eval_humaneval.py +2 -2
  69. sglang/test/simple_eval_math.py +2 -1
  70. sglang/test/simple_eval_mmlu.py +2 -1
  71. sglang/test/test_activation.py +55 -0
  72. sglang/test/test_programs.py +32 -5
  73. sglang/test/test_utils.py +37 -50
  74. sglang/version.py +1 -1
  75. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/METADATA +102 -27
  76. sglang-0.2.14.dist-info/RECORD +114 -0
  77. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
  78. sglang/launch_server_llavavid.py +0 -29
  79. sglang/srt/model_loader/model_loader.py +0 -292
  80. sglang/srt/model_loader/utils.py +0 -275
  81. sglang-0.2.12.dist-info/RECORD +0 -112
  82. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
  83. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/api.py CHANGED
@@ -62,9 +62,11 @@ def gen(
  name: Optional[str] = None,
  max_tokens: Optional[int] = None,
  stop: Optional[Union[str, List[str]]] = None,
+ stop_token_ids: Optional[List[int]] = None,
  temperature: Optional[float] = None,
  top_p: Optional[float] = None,
  top_k: Optional[int] = None,
+ min_p: Optional[float] = None,
  frequency_penalty: Optional[float] = None,
  presence_penalty: Optional[float] = None,
  ignore_eos: Optional[bool] = None,
@@ -72,7 +74,7 @@ def gen(
  logprob_start_len: Optional[int] = None,
  top_logprobs_num: Optional[int] = None,
  return_text_in_logprobs: Optional[bool] = None,
- dtype: Optional[type] = None,
+ dtype: Optional[Union[type, str]] = None,
  choices: Optional[List[str]] = None,
  choices_method: Optional[ChoicesSamplingMethod] = None,
  regex: Optional[str] = None,
@@ -98,9 +100,11 @@ def gen(
  name,
  max_tokens,
  stop,
+ stop_token_ids,
  temperature,
  top_p,
  top_k,
+ min_p,
  frequency_penalty,
  presence_penalty,
  ignore_eos,
@@ -117,9 +121,11 @@ def gen_int(
  name: Optional[str] = None,
  max_tokens: Optional[int] = None,
  stop: Optional[Union[str, List[str]]] = None,
+ stop_token_ids: Optional[List[int]] = None,
  temperature: Optional[float] = None,
  top_p: Optional[float] = None,
  top_k: Optional[int] = None,
+ min_p: Optional[float] = None,
  frequency_penalty: Optional[float] = None,
  presence_penalty: Optional[float] = None,
  ignore_eos: Optional[bool] = None,
@@ -132,9 +138,11 @@ def gen_int(
  name,
  max_tokens,
  stop,
+ stop_token_ids,
  temperature,
  top_p,
  top_k,
+ min_p,
  frequency_penalty,
  presence_penalty,
  ignore_eos,
@@ -151,9 +159,11 @@ def gen_string(
  name: Optional[str] = None,
  max_tokens: Optional[int] = None,
  stop: Optional[Union[str, List[str]]] = None,
+ stop_token_ids: Optional[List[int]] = None,
  temperature: Optional[float] = None,
  top_p: Optional[float] = None,
  top_k: Optional[int] = None,
+ min_p: Optional[float] = None,
  frequency_penalty: Optional[float] = None,
  presence_penalty: Optional[float] = None,
  ignore_eos: Optional[bool] = None,
@@ -166,9 +176,11 @@ def gen_string(
  name,
  max_tokens,
  stop,
+ stop_token_ids,
  temperature,
  top_p,
  top_k,
+ min_p,
  frequency_penalty,
  presence_penalty,
  ignore_eos,
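The hunks above add `stop_token_ids` and `min_p` to `gen`, `gen_int`, and `gen_string`, and allow `dtype` to be passed as a string as well as a type. A minimal usage sketch (the prompt text and the stop token id 128009 are illustrative placeholders, not taken from the diff):

```python
import sglang as sgl

@sgl.function
def answer(s, question):
    s += "Q: " + question + "\nA: "
    s += sgl.gen(
        "answer",
        max_tokens=64,
        min_p=0.05,                # new in 0.2.14: min-p sampling cutoff
        stop_token_ids=[128009],   # new in 0.2.14: stop on explicit token ids
    )

# state = answer.run(question="What is SGLang?")  # requires a backend to be configured
```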
sglang/bench_latency.py CHANGED
@@ -54,7 +54,7 @@ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
  from sglang.srt.model_config import ModelConfig
  from sglang.srt.model_executor.forward_batch_info import ForwardMode
  from sglang.srt.model_executor.model_runner import ModelRunner
- from sglang.srt.sampling_params import SamplingParams
+ from sglang.srt.sampling.sampling_params import SamplingParams
  from sglang.srt.server_args import ServerArgs
  from sglang.srt.utils import suppress_other_loggers

@@ -64,7 +64,7 @@ class BenchArgs:
  run_name: str = "before"
  batch_size: Tuple[int] = (1,)
  input_len: Tuple[int] = (1024,)
- output_len: Tuple[int] = (4,)
+ output_len: Tuple[int] = (16,)
  result_filename: str = ""
  correctness_test: bool = False
  # This is only used for correctness test
@@ -111,7 +111,11 @@ def load_model(server_args, tp_rank):
  suppress_other_loggers()
  rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

- model_config = ModelConfig(path=server_args.model_path)
+ model_config = ModelConfig(
+ server_args.model_path,
+ server_args.trust_remote_code,
+ context_length=server_args.context_length,
+ )
  model_runner = ModelRunner(
  model_config=model_config,
  mem_fraction_static=server_args.mem_fraction_static,
@@ -195,7 +199,7 @@ def extend(reqs, model_runner):
  token_to_kv_pool=model_runner.token_to_kv_pool,
  tree_cache=None,
  )
- batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
+ batch.prepare_for_extend(model_runner.model_config.vocab_size)
  output = model_runner.forward(batch, ForwardMode.EXTEND)
  next_token_ids = batch.sample(output.next_token_logits)
  return next_token_ids, output.next_token_logits, batch
@@ -221,6 +225,7 @@ def correctness_test(

  # Prepare inputs
  input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
+ rank_print(f"{input_ids=}")

  if bench_args.cut_len > 0:
  # Prefill
@@ -349,7 +354,7 @@ def latency_test(
  for bs, il, ol in itertools.product(
  bench_args.batch_size, bench_args.input_len, bench_args.output_len
  ):
- req = prepare_synthetic_inputs_for_latency_test(bs, il)
+ reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
  ret = latency_test_run_once(
  bench_args.run_name, model_runner, rank_print, reqs, bs, il, ol
  )
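Besides the relocated sampling-params import and the `reqs` variable fix, the benchmark now builds `ModelConfig` with the model path and `trust_remote_code` as positional arguments plus an explicit `context_length`. A sketch of the new call, assuming default `ServerArgs` values (the model path is a placeholder):

```python
from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(model_path="meta-llama/Llama-2-7b-hf")  # placeholder path
model_config = ModelConfig(
    server_args.model_path,
    server_args.trust_remote_code,
    context_length=server_args.context_length,
)
```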
sglang/bench_serving.py CHANGED
@@ -149,10 +149,12 @@ async def async_request_openai_completions(
  "completions"
  ), "OpenAI Completions API URL must end with 'completions'."

+ prompt = request_func_input.prompt
+
  async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
  payload = {
  "model": request_func_input.model,
- "prompt": request_func_input.prompt,
+ "prompt": prompt,
  "temperature": 0.0,
  "best_of": 1,
  "max_tokens": request_func_input.output_len,
@@ -220,6 +222,13 @@ async def async_request_openai_completions(
  return output


+ async def async_request_gserver(
+ request_func_input: RequestFuncInput,
+ pbar: Optional[tqdm] = None,
+ ) -> RequestFuncOutput:
+ raise NotImplementedError()
+
+
  def get_model(pretrained_model_name_or_path: str) -> str:
  if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true":
  import huggingface_hub.constants
@@ -238,6 +247,13 @@ def get_model(pretrained_model_name_or_path: str) -> str:
  def get_tokenizer(
  pretrained_model_name_or_path: str,
  ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+ if pretrained_model_name_or_path.endswith(
+ ".json"
+ ) or pretrained_model_name_or_path.endswith(".model"):
+ from sglang.srt.hf_transformers_utils import get_tokenizer
+
+ return get_tokenizer(pretrained_model_name_or_path)
+
  if pretrained_model_name_or_path is not None and not os.path.exists(
  pretrained_model_name_or_path
  ):
@@ -252,6 +268,7 @@ ASYNC_REQUEST_FUNCS = {
  "vllm": async_request_openai_completions,
  "lmdeploy": async_request_openai_completions,
  "trt": async_request_trt_llm,
+ "gserver": async_request_gserver,
  }


@@ -351,9 +368,9 @@ def sample_sharegpt_requests(

  # Tokenize the prompts and completions.
  prompt = dataset[i][0]
- prompt_token_ids = tokenizer(prompt).input_ids
+ prompt_token_ids = tokenizer.encode(prompt)
  completion = dataset[i][1]
- completion_token_ids = tokenizer(completion).input_ids
+ completion_token_ids = tokenizer.encode(completion)
  prompt_len = len(prompt_token_ids)
  output_len = (
  len(completion_token_ids) if fixed_output_len is None else fixed_output_len
@@ -361,7 +378,9 @@ def sample_sharegpt_requests(
  if prompt_len < 4 or output_len < 4:
  # Prune too short sequences.
  continue
- if prompt_len > 1024 or prompt_len + output_len > 2048:
+ if prompt_len > 1024 or (
+ prompt_len + output_len > 2048 and fixed_output_len is None
+ ):
  # Prune too long sequences.
  continue
  filtered_dataset.append((prompt, prompt_len, output_len))
@@ -422,7 +441,7 @@ def sample_random_requests(
  for i in range(num_prompts):
  # Tokenize the prompts and completions.
  prompt = dataset[i][0]
- prompt_token_ids = tokenizer(prompt).input_ids
+ prompt_token_ids = tokenizer.encode(prompt)
  prompt_len = len(prompt_token_ids)

  if prompt_len > input_lens[i]:
@@ -488,7 +507,7 @@ def calculate_metrics(
  output_len = outputs[i].output_len
  output_lens.append(output_len)
  retokenized_output_len = len(
- tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids
+ tokenizer.encode(outputs[i].generated_text, add_special_tokens=False)
  )
  retokenized_output_lens.append(retokenized_output_len)
  total_input += input_requests[i][1]
@@ -547,7 +566,6 @@ async def benchmark(
  input_requests: List[Tuple[str, int, int]],
  request_rate: float,
  disable_tqdm: bool,
- enable_multi: bool,
  extra_request_body: Dict[str, Any],
  ):
  if backend in ASYNC_REQUEST_FUNCS:
@@ -756,6 +774,7 @@ def run_benchmark(args_: argparse.Namespace):
  global args
  args = args_

+ # Set global environments
  set_ulimit()
  random.seed(args.seed)
  np.random.seed(args.seed)
@@ -764,12 +783,14 @@ def run_benchmark(args_: argparse.Namespace):
  if args.extra_request_body:
  extra_request_body = json.loads(args.extra_request_body)

+ # Set url
  if args.port is None:
  args.port = {
  "sglang": 30000,
  "lmdeploy": 23333,
  "vllm": 8000,
  "trt": 8000,
+ "gserver": 9988,
  }.get(args.backend, 30000)

  api_url = (
@@ -792,7 +813,11 @@ def run_benchmark(args_: argparse.Namespace):
  if args.model is None:
  print("Please provide a model using `--model` when using `trt` backend.")
  sys.exit(1)
+ elif args.backend == "gserver":
+ api_url = args.base_url if args.base_url else f"{args.host}:{args.port}"
+ args.model = args.model or "default"

+ # Get model name
  if args.model is None:
  try:
  response = requests.get(model_url)
@@ -817,6 +842,7 @@ def run_benchmark(args_: argparse.Namespace):

  print(f"{args}\n")

+ # Read dataset
  backend = args.backend
  model_id = args.model
  tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
@@ -842,7 +868,21 @@ def run_benchmark(args_: argparse.Namespace):
  else:
  raise ValueError(f"Unknown dataset: {args.dataset_name}")

- if args.multi:
+ if not args.multi:
+ return asyncio.run(
+ benchmark(
+ backend=backend,
+ api_url=api_url,
+ model_id=model_id,
+ tokenizer=tokenizer,
+ input_requests=input_requests,
+ request_rate=args.request_rate,
+ disable_tqdm=args.disable_tqdm,
+ extra_request_body=extra_request_body,
+ )
+ )
+ else:
+ # Benchmark multiple rps. TODO: use a fixed duration to compute num_prompts
  request_rates = parse_request_rate_range(args.request_rate_range)

  for rate in request_rates:
@@ -855,27 +895,11 @@ def run_benchmark(args_: argparse.Namespace):
  input_requests=input_requests,
  request_rate=rate,
  disable_tqdm=args.disable_tqdm,
- enable_multi=args.multi,
  extra_request_body=extra_request_body,
  )
  )
- else:
- return asyncio.run(
- benchmark(
- backend=backend,
- api_url=api_url,
- model_id=model_id,
- tokenizer=tokenizer,
- input_requests=input_requests,
- request_rate=args.request_rate,
- disable_tqdm=args.disable_tqdm,
- enable_multi=args.multi,
- extra_request_body=extra_request_body,
- )
- )


- # to avoid relying on SGLang's components
  def set_ulimit(target_soft_limit=65535):
  resource_type = resource.RLIMIT_NOFILE
  current_soft, current_hard = resource.getrlimit(resource_type)
@@ -966,9 +990,9 @@ if __name__ == "__main__":
  type=float,
  default=float("inf"),
  help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
- "Otherwise, we use Poisson process to synthesize the request arrival times. Default is 128.0.",
+ "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.",
  )
- parser.add_argument("--seed", type=int, default=0, help="Default is 0.")
+ parser.add_argument("--seed", type=int, default=1, help="The random seed.")
  parser.add_argument(
  "--multi",
  action="store_true",
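Token counting in the benchmark switches from `tokenizer(text).input_ids` to `tokenizer.encode(text)`, which also matches the interface of the tokenizer loaded through the new `.json`/`.model` path in `get_tokenizer`. A small sanity check using a Hugging Face tokenizer (the `gpt2` model name is just an example):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
text = "Hello, world!"
# Both forms should produce the same ids for a Hugging Face tokenizer.
assert tokenizer.encode(text) == tokenizer(text).input_ids
prompt_len = len(tokenizer.encode(text))
print(prompt_len)
```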
sglang/check_env.py CHANGED
@@ -170,6 +170,17 @@ def get_gpu_topology():
  return None


+ def get_hypervisor_vendor():
+ try:
+ output = subprocess.check_output(["lscpu"], text=True)
+ for line in output.split("\n"):
+ if "Hypervisor vendor:" in line:
+ return line.split(":")[1].strip()
+ return None
+ except:
+ return None
+
+
  def check_env():
  """
  Check and print environment information.
@@ -184,6 +195,10 @@ def check_env():
  if gpu_topo:
  env_info["NVIDIA Topology"] = gpu_topo

+ hypervisor_vendor = get_hypervisor_vendor()
+ if hypervisor_vendor:
+ env_info["Hypervisor vendor"] = hypervisor_vendor
+
  ulimit_soft, _ = resource.getrlimit(resource.RLIMIT_NOFILE)
  env_info["ulimit soft"] = ulimit_soft
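The new `get_hypervisor_vendor` helper simply scrapes the `Hypervisor vendor:` row from `lscpu`, so `check_env` can report whether it is running inside a VM. A standalone sketch of the same parsing against canned output (the sample lines are made up):

```python
# Canned lscpu output for illustration; the real helper shells out to `lscpu`.
sample = "Architecture: x86_64\nHypervisor vendor: KVM\nVirtualization type: full\n"

vendor = next(
    (
        line.split(":")[1].strip()
        for line in sample.split("\n")
        if "Hypervisor vendor:" in line
    ),
    None,
)
print(vendor)  # -> KVM
```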
 
sglang/global_config.py CHANGED
@@ -27,7 +27,7 @@ class GlobalConfig:
  # Runtime constants: others
  self.num_continue_decode_steps = 10
  self.retract_decode_steps = 20
- self.flashinfer_workspace_size = 192 * 1024 * 1024
+ self.flashinfer_workspace_size = 384 * 1024 * 1024

  # Output tokenization configs
  self.skip_special_tokens_in_output = True
sglang/lang/backend/runtime_endpoint.py CHANGED
@@ -1,21 +1,23 @@
  import json
+ import warnings
  from typing import List, Optional

  from sglang.global_config import global_config
  from sglang.lang.backend.base_backend import BaseBackend
  from sglang.lang.chat_template import get_chat_template_by_model_path
- from sglang.lang.choices import (
- ChoicesDecision,
- ChoicesSamplingMethod,
- token_length_normalized,
- )
+ from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
  from sglang.lang.interpreter import StreamExecutor
- from sglang.lang.ir import SglSamplingParams
+ from sglang.lang.ir import (
+ REGEX_BOOL,
+ REGEX_FLOAT,
+ REGEX_INT,
+ REGEX_STR,
+ SglSamplingParams,
+ )
  from sglang.utils import http_request


  class RuntimeEndpoint(BaseBackend):
-
  def __init__(
  self,
  base_url: str,
@@ -95,32 +97,52 @@ class RuntimeEndpoint(BaseBackend):
  )
  self._assert_success(res)

+ def _handle_dtype_to_regex(self, sampling_params: SglSamplingParams):
+ if sampling_params.dtype is None:
+ return
+
+ if sampling_params.stop == ():
+ sampling_params.stop = []
+
+ dtype_regex = None
+ if sampling_params.dtype in ["int", int]:
+
+ dtype_regex = REGEX_INT
+ sampling_params.stop.extend([" ", "\n"])
+ elif sampling_params.dtype in ["float", float]:
+
+ dtype_regex = REGEX_FLOAT
+ sampling_params.stop.extend([" ", "\n"])
+ elif sampling_params.dtype in ["str", str]:
+
+ dtype_regex = REGEX_STR
+ elif sampling_params.dtype in ["bool", bool]:
+
+ dtype_regex = REGEX_BOOL
+ else:
+ raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
+
+ if dtype_regex is not None and sampling_params.regex is not None:
+ warnings.warn(
+ f"Both dtype and regex are set. Only dtype will be used. dtype: {sampling_params.dtype}, regex: {sampling_params.regex}"
+ )
+
+ sampling_params.regex = dtype_regex
+
  def generate(
  self,
  s: StreamExecutor,
  sampling_params: SglSamplingParams,
  ):
- if sampling_params.dtype is None:
- data = {
- "text": s.text_,
- "sampling_params": {
- "skip_special_tokens": global_config.skip_special_tokens_in_output,
- "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
- **sampling_params.to_srt_kwargs(),
- },
- }
- elif sampling_params.dtype in [int, "int"]:
- data = {
- "text": s.text_,
- "sampling_params": {
- "skip_special_tokens": global_config.skip_special_tokens_in_output,
- "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
- "dtype": "int",
- **sampling_params.to_srt_kwargs(),
- },
- }
- else:
- raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
+ self._handle_dtype_to_regex(sampling_params)
+ data = {
+ "text": s.text_,
+ "sampling_params": {
+ "skip_special_tokens": global_config.skip_special_tokens_in_output,
+ "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
+ **sampling_params.to_srt_kwargs(),
+ },
+ }

  for item in [
  "return_logprob",
@@ -151,27 +173,16 @@ class RuntimeEndpoint(BaseBackend):
  s: StreamExecutor,
  sampling_params: SglSamplingParams,
  ):
- if sampling_params.dtype is None:
- data = {
- "text": s.text_,
- "sampling_params": {
- "skip_special_tokens": global_config.skip_special_tokens_in_output,
- "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
- **sampling_params.to_srt_kwargs(),
- },
- }
- elif sampling_params.dtype in [int, "int"]:
- data = {
- "text": s.text_,
- "sampling_params": {
- "skip_special_tokens": global_config.skip_special_tokens_in_output,
- "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
- "dtype": "int",
- **sampling_params.to_srt_kwargs(),
- },
- }
- else:
- raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
+ self._handle_dtype_to_regex(sampling_params)
+
+ data = {
+ "text": s.text_,
+ "sampling_params": {
+ "skip_special_tokens": global_config.skip_special_tokens_in_output,
+ "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
+ **sampling_params.to_srt_kwargs(),
+ },
+ }

  for item in [
  "return_logprob",
sglang/lang/chat_template.py CHANGED
@@ -1,6 +1,6 @@
- from dataclasses import dataclass, field
+ from dataclasses import dataclass
  from enum import Enum, auto
- from typing import Callable, Dict, List, Optional, Tuple
+ from typing import Callable, Dict, List, Tuple


  class ChatTemplateStyle(Enum):
@@ -137,7 +137,7 @@ register_chat_template(
  register_chat_template(
  ChatTemplate(
  name="chatml-llava",
- default_system_prompt="Answer the questions.",
+ default_system_prompt="You are a helpful assistant.",
  role_prefix_and_suffix={
  "system": ("<|im_start|>system\n", "<|im_end|>\n"),
  "user": ("<|im_start|>user\n", "<|im_end|>\n"),
@@ -145,7 +145,7 @@ register_chat_template(
  },
  style=ChatTemplateStyle.PLAIN,
  stop_str=("<|im_end|>",),
- image_token=" <image>\n",
+ image_token="<image>\n",
  )
  )

@@ -322,12 +322,17 @@ def match_chat_ml(model_path: str):
  if "tinyllama" in model_path:
  return get_chat_template("chatml")
  # Now the suffix for qwen2 chat model is "instruct"
- if "qwen" in model_path and ("chat" in model_path or "instruct" in model_path):
+ if (
+ "qwen" in model_path
+ and ("chat" in model_path or "instruct" in model_path)
+ and ("llava" not in model_path)
+ ):
  return get_chat_template("qwen")
  if (
  "llava-v1.6-34b" in model_path
  or "llava-v1.6-yi-34b" in model_path
  or "llava-next-video-34b" in model_path
+ or "llava-onevision-qwen2" in model_path
  ):
  return get_chat_template("chatml-llava")
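With the qwen matcher now skipping LLaVA checkpoints and the new `llava-onevision-qwen2` entry, a Qwen2-based LLaVA-OneVision path should resolve to the `chatml-llava` template (with its updated system prompt and image token). An illustrative check (the model path string is a placeholder):

```python
from sglang.lang.chat_template import get_chat_template_by_model_path

template = get_chat_template_by_model_path("lmms-lab/llava-onevision-qwen2-7b-ov")
print(template.name)  # expected: chatml-llava
```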
 
sglang/lang/compiler.py CHANGED
@@ -130,6 +130,7 @@ class CompiledFunction:
  temperature: float = 1.0,
  top_p: float = 1.0,
  top_k: int = -1,
+ min_p: float = 0.0,
  frequency_penalty: float = 0.0,
  presence_penalty: float = 0.0,
  backend=None,
@@ -145,6 +146,7 @@ class CompiledFunction:
  temperature=temperature,
  top_p=top_p,
  top_k=top_k,
+ min_p=min_p,
  frequency_penalty=frequency_penalty,
  presence_penalty=presence_penalty,
  )
@@ -160,6 +162,7 @@ class CompiledFunction:
  temperature: float = 1.0,
  top_p: float = 1.0,
  top_k: int = -1,
+ min_p: float = 0.0,
  frequency_penalty: float = 0.0,
  presence_penalty: float = 0.0,
  backend=None,
@@ -178,6 +181,7 @@ class CompiledFunction:
  temperature=temperature,
  top_p=top_p,
  top_k=top_k,
+ min_p=min_p,
  frequency_penalty=frequency_penalty,
  presence_penalty=presence_penalty,
  )
sglang/lang/interpreter.py CHANGED
@@ -20,7 +20,6 @@ from sglang.lang.ir import (
  SglConstantText,
  SglExpr,
  SglExprList,
- SglFunction,
  SglGen,
  SglImage,
  SglRoleBegin,
@@ -181,8 +180,10 @@ class StreamExecutor:
  num_api_spec_tokens=None,
  use_thread=True,
  ):
+ from sglang.lang.backend.base_backend import BaseBackend
+
  self.sid = uuid.uuid4().hex
- self.backend = backend
+ self.backend: BaseBackend = backend
  self.arguments: Dict[str, Any] = arguments
  self.default_sampling_para = default_sampling_para
  self.stream = stream
@@ -658,9 +659,11 @@ class StreamExecutor:
  for item in [
  "max_new_tokens",
  "stop",
+ "stop_token_ids",
  "temperature",
  "top_p",
  "top_k",
+ "min_p",
  "frequency_penalty",
  "presence_penalty",
  "ignore_eos",