sglang 0.3.4.post2__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. sglang/api.py +1 -1
  2. sglang/bench_latency.py +3 -3
  3. sglang/bench_server_latency.py +2 -3
  4. sglang/bench_serving.py +92 -0
  5. sglang/global_config.py +9 -3
  6. sglang/lang/chat_template.py +50 -25
  7. sglang/lang/interpreter.py +9 -1
  8. sglang/lang/ir.py +11 -2
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/configs/model_config.py +51 -13
  11. sglang/srt/constrained/__init__.py +18 -0
  12. sglang/srt/constrained/bnf_cache.py +61 -0
  13. sglang/srt/constrained/grammar.py +190 -0
  14. sglang/srt/hf_transformers_utils.py +6 -5
  15. sglang/srt/layers/attention/triton_ops/decode_attention.py +110 -30
  16. sglang/srt/layers/attention/triton_ops/prefill_attention.py +1 -1
  17. sglang/srt/layers/fused_moe/fused_moe.py +4 -3
  18. sglang/srt/layers/fused_moe/layer.py +28 -0
  19. sglang/srt/layers/quantization/base_config.py +16 -1
  20. sglang/srt/layers/vocab_parallel_embedding.py +486 -0
  21. sglang/srt/managers/data_parallel_controller.py +7 -6
  22. sglang/srt/managers/detokenizer_manager.py +9 -11
  23. sglang/srt/managers/image_processor.py +4 -3
  24. sglang/srt/managers/io_struct.py +70 -78
  25. sglang/srt/managers/schedule_batch.py +33 -49
  26. sglang/srt/managers/schedule_policy.py +24 -13
  27. sglang/srt/managers/scheduler.py +137 -80
  28. sglang/srt/managers/tokenizer_manager.py +224 -336
  29. sglang/srt/managers/tp_worker.py +5 -5
  30. sglang/srt/mem_cache/flush_cache.py +1 -1
  31. sglang/srt/model_executor/cuda_graph_runner.py +7 -4
  32. sglang/srt/model_executor/model_runner.py +8 -17
  33. sglang/srt/models/baichuan.py +4 -4
  34. sglang/srt/models/chatglm.py +4 -4
  35. sglang/srt/models/commandr.py +1 -1
  36. sglang/srt/models/dbrx.py +5 -5
  37. sglang/srt/models/deepseek.py +4 -4
  38. sglang/srt/models/deepseek_v2.py +4 -4
  39. sglang/srt/models/exaone.py +4 -4
  40. sglang/srt/models/gemma.py +1 -1
  41. sglang/srt/models/gemma2.py +1 -1
  42. sglang/srt/models/gpt2.py +287 -0
  43. sglang/srt/models/gpt_bigcode.py +1 -1
  44. sglang/srt/models/grok.py +4 -4
  45. sglang/srt/models/internlm2.py +4 -4
  46. sglang/srt/models/llama.py +15 -7
  47. sglang/srt/models/llama_embedding.py +2 -10
  48. sglang/srt/models/llama_reward.py +5 -0
  49. sglang/srt/models/minicpm.py +4 -4
  50. sglang/srt/models/minicpm3.py +4 -4
  51. sglang/srt/models/mixtral.py +7 -5
  52. sglang/srt/models/mixtral_quant.py +4 -4
  53. sglang/srt/models/mllama.py +5 -5
  54. sglang/srt/models/olmo.py +4 -4
  55. sglang/srt/models/olmoe.py +4 -4
  56. sglang/srt/models/qwen.py +4 -4
  57. sglang/srt/models/qwen2.py +4 -4
  58. sglang/srt/models/qwen2_moe.py +4 -4
  59. sglang/srt/models/qwen2_vl.py +4 -8
  60. sglang/srt/models/stablelm.py +4 -4
  61. sglang/srt/models/torch_native_llama.py +4 -4
  62. sglang/srt/models/xverse.py +4 -4
  63. sglang/srt/models/xverse_moe.py +4 -4
  64. sglang/srt/openai_api/adapter.py +52 -66
  65. sglang/srt/sampling/sampling_batch_info.py +7 -13
  66. sglang/srt/server.py +31 -35
  67. sglang/srt/server_args.py +34 -5
  68. sglang/srt/utils.py +40 -56
  69. sglang/test/runners.py +2 -1
  70. sglang/test/test_utils.py +73 -25
  71. sglang/utils.py +62 -1
  72. sglang/version.py +1 -1
  73. sglang-0.3.5.dist-info/METADATA +344 -0
  74. {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/RECORD +77 -73
  75. {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/WHEEL +1 -1
  76. sglang-0.3.4.post2.dist-info/METADATA +0 -899
  77. {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/LICENSE +0 -0
  78. {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api/adapter.py CHANGED
@@ -71,6 +71,7 @@ from sglang.srt.openai_api.protocol import (
  TopLogprob,
  UsageInfo,
  )
+ from sglang.utils import get_exception_traceback

  logger = logging.getLogger(__name__)

@@ -314,6 +315,8 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
  )

  except Exception as e:
+ logger.error(f"error: {get_exception_traceback()}")
+ responses = []
  error_json = {
  "id": f"batch_req_{uuid.uuid4()}",
  "custom_id": request_data.get("custom_id"),
@@ -363,7 +366,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
  }

  except Exception as e:
- logger.error("error in SGLang:", e)
+ logger.error(f"error: {e}")
  # Update batch status to "failed"
  retrieve_batch = batch_storage[batch_id]
  retrieve_batch.status = "failed"
@@ -469,80 +472,67 @@ async def v1_retrieve_file_content(file_id: str):
  def v1_generate_request(
  all_requests: List[CompletionRequest], request_ids: List[str] = None
  ):
+ if len(all_requests) > 1:
+ first_prompt_type = type(all_requests[0].prompt)
+ for request in all_requests:
+ assert (
+ type(request.prompt) is first_prompt_type
+ ), "All prompts must be of the same type in file input settings"
+ if request.n > 1:
+ raise ValueError(
+ "Parallel sampling is not supported for completions from files"
+ )
+
  prompts = []
  sampling_params_list = []
  return_logprobs = []
  logprob_start_lens = []
  top_logprobs_nums = []

- # NOTE: with openai API, the prompt's logprobs are always not computed
- first_prompt_type = type(all_requests[0].prompt)
  for request in all_requests:
- assert (
- type(request.prompt) is first_prompt_type
- ), "All prompts must be of the same type in file input settings"
- if len(all_requests) > 1 and request.n > 1:
- raise ValueError(
- "Parallel sampling is not supported for completions from files"
- )
+ # NOTE: with openai API, the prompt's logprobs are always not computed
  if request.echo and request.logprobs:
  logger.warning(
  "Echo is not compatible with logprobs. "
- "To compute logprobs of input prompt, please use SGLang /request API."
+ "To compute logprobs of input prompt, please use the native /generate API."
  )

- for request in all_requests:
  prompts.append(request.prompt)
+ sampling_params_list.append(
+ {
+ "temperature": request.temperature,
+ "max_new_tokens": request.max_tokens,
+ "min_new_tokens": request.min_tokens,
+ "stop": request.stop,
+ "stop_token_ids": request.stop_token_ids,
+ "top_p": request.top_p,
+ "presence_penalty": request.presence_penalty,
+ "frequency_penalty": request.frequency_penalty,
+ "repetition_penalty": request.repetition_penalty,
+ "regex": request.regex,
+ "json_schema": request.json_schema,
+ "n": request.n,
+ "ignore_eos": request.ignore_eos,
+ "no_stop_trim": request.no_stop_trim,
+ }
+ )
  return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
  logprob_start_lens.append(-1)
  top_logprobs_nums.append(
  request.logprobs if request.logprobs is not None else 0
  )
- sampling_params = []
- if isinstance(request.no_stop_trim, list):
- num_reqs = len(request.prompt)
- else:
- num_reqs = 1
- for i in range(num_reqs):
- sampling_params.append(
- {
- "temperature": request.temperature,
- "max_new_tokens": request.max_tokens,
- "min_new_tokens": request.min_tokens,
- "stop": request.stop,
- "stop_token_ids": request.stop_token_ids,
- "top_p": request.top_p,
- "presence_penalty": request.presence_penalty,
- "frequency_penalty": request.frequency_penalty,
- "repetition_penalty": request.repetition_penalty,
- "regex": request.regex,
- "json_schema": request.json_schema,
- "n": request.n,
- "ignore_eos": request.ignore_eos,
- "no_stop_trim": (
- request.no_stop_trim
- if not isinstance(request.no_stop_trim, list)
- else request.no_stop_trim[i]
- ),
- }
- )
- if num_reqs == 1:
- sampling_params_list.append(sampling_params[0])
- else:
- sampling_params_list.append(sampling_params)

  if len(all_requests) == 1:
- prompt = prompts[0]
+ if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
+ prompt_kwargs = {"text": prompts[0]}
+ else:
+ prompt_kwargs = {"input_ids": prompts[0]}
  sampling_params_list = sampling_params_list[0]
- logprob_start_lens = logprob_start_lens[0]
  return_logprobs = return_logprobs[0]
+ logprob_start_lens = logprob_start_lens[0]
  top_logprobs_nums = top_logprobs_nums[0]
- if isinstance(prompt, str) or isinstance(prompt[0], str):
- prompt_kwargs = {"text": prompt}
- else:
- prompt_kwargs = {"input_ids": prompt}
  else:
- if isinstance(prompts[0], str):
+ if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
  prompt_kwargs = {"text": prompts}
  else:
  prompt_kwargs = {"input_ids": prompts}
@@ -558,9 +548,7 @@ def v1_generate_request(
  rid=request_ids,
  )

- if len(all_requests) == 1:
- return adapted_request, all_requests[0]
- return adapted_request, all_requests
+ return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]


  def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
@@ -595,7 +583,7 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
  if isinstance(request, list) and request[idx].echo:
  echo = True
  text = request[idx].prompt + text
- if (not isinstance(request, list)) and echo:
+ if echo and not isinstance(request, list):
  prompt_index = idx // request.n
  text = prompts[prompt_index] + text

@@ -709,7 +697,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
  async for content in tokenizer_manager.generate_request(
  adapted_request, raw_request
  ):
- index = content["index"]
+ index = content.get("index", 0)

  stream_buffer = stream_buffers.get(index, "")
  n_prev_token = n_prev_tokens.get(index, 0)
@@ -945,19 +933,18 @@ def v1_chat_generate_request(
  sampling_params_list.append(sampling_params)

  image_data_list.append(image_data)
- modalities_list.extend(modalities)
+ modalities_list.append(modalities)
  if len(all_requests) == 1:
- input_ids = input_ids[0]
- if isinstance(input_ids, str):
- prompt_kwargs = {"text": input_ids}
+ if isinstance(input_ids[0], str):
+ prompt_kwargs = {"text": input_ids[0]}
  else:
- prompt_kwargs = {"input_ids": input_ids}
+ prompt_kwargs = {"input_ids": input_ids[0]}
  sampling_params_list = sampling_params_list[0]
  image_data_list = image_data_list[0]
  return_logprobs = return_logprobs[0]
  logprob_start_lens = logprob_start_lens[0]
  top_logprobs_nums = top_logprobs_nums[0]
- modalities_list = modalities_list[:1]
+ modalities_list = modalities_list[0]
  else:
  if isinstance(input_ids[0], str):
  prompt_kwargs = {"text": input_ids}
@@ -976,9 +963,8 @@ def v1_chat_generate_request(
  rid=request_ids,
  modalities=modalities_list,
  )
- if len(all_requests) == 1:
- return adapted_request, all_requests[0]
- return adapted_request, all_requests
+
+ return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]


  def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
@@ -1116,7 +1102,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
  async for content in tokenizer_manager.generate_request(
  adapted_request, raw_request
  ):
- index = content["index"]
+ index = content.get("index", 0)

  is_first = is_firsts.get(index, True)
  stream_buffer = stream_buffers.get(index, "")
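The reworked v1_generate_request above builds one sampling-params dict per request and returns a single (adapted_request, request) pair for both the single and batched paths. As an illustration only, the sketch below exercises the batched-prompt path through the OpenAI-compatible completions endpoint; the server URL and model name are placeholders, not values taken from this diff.

    import openai

    # Assumes an SGLang server is already serving the OpenAI-compatible API locally.
    client = openai.OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
    resp = client.completions.create(
        model="default",
        # All prompts in one request must share the same type (all strings here).
        prompt=["The capital of France is", "The capital of Japan is"],
        max_tokens=8,
        temperature=0,
    )
    for choice in resp.choices:
        print(choice.index, choice.text)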
sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, List, Optional
  import torch

  import sglang.srt.sampling.penaltylib as penaltylib
- from sglang.srt.constrained import RegexGuide
+ from sglang.srt.constrained.grammar import Grammar

  if TYPE_CHECKING:
  from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -29,11 +29,9 @@ class SamplingBatchInfo:
  # Bias Tensors
  vocab_size: int
  logit_bias: torch.Tensor = None
- vocab_mask: torch.Tensor = None
+ vocab_mask: Optional[torch.Tensor] = None

- # FSM states
- regex_fsms: List[RegexGuide] = None
- regex_fsm_states: List[int] = None
+ grammars: Optional[List[Optional[Grammar]]] = None

  # Penalizer
  penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
@@ -136,8 +134,7 @@ class SamplingBatchInfo:
  self.linear_penalties = penalizer.apply(self.linear_penalties)

  def update_regex_vocab_mask(self):
- has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
- if not has_regex:
+ if not self.grammars or not any(grammar for grammar in self.grammars):
  self.vocab_mask = None
  return

@@ -147,12 +144,9 @@ class SamplingBatchInfo:
  dtype=torch.bool,
  device=self.device,
  )
- for i, regex_fsm in enumerate(self.regex_fsms):
- if regex_fsm is not None:
- self.vocab_mask[i].fill_(1)
- self.vocab_mask[i][
- regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
- ] = 0
+ for i, grammar in enumerate(self.grammars):
+ if grammar is not None:
+ grammar.fill_vocab_mask(self.vocab_mask[i], self.vocab_size)

  def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
  if self.penalizer_orchestrator:
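For context, the old code filled the per-request vocab mask with 1 and then zeroed the tokens allowed by the FSM, so a True entry means the token is blocked; the new Grammar.fill_vocab_mask call presumably keeps that convention. The snippet below is an illustrative, self-contained sketch of how such a boolean mask is typically applied to logits before sampling; it is not copied from the sglang sampler.

    import torch

    vocab_size = 8
    logits = torch.randn(1, vocab_size)
    vocab_mask = torch.ones(1, vocab_size, dtype=torch.bool)  # start fully blocked
    vocab_mask[0, [2, 5]] = False                             # grammar allows tokens 2 and 5
    masked = logits.masked_fill(vocab_mask, float("-inf"))    # blocked tokens get -inf
    print(masked.argmax(dim=-1))                              # greedy pick among allowed tokens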
sglang/srt/server.py CHANGED
@@ -53,7 +53,6 @@ from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
53
53
  from sglang.srt.managers.io_struct import (
54
54
  EmbeddingReqInput,
55
55
  GenerateReqInput,
56
- RewardReqInput,
57
56
  UpdateWeightReqInput,
58
57
  )
59
58
  from sglang.srt.managers.scheduler import run_scheduler_process
@@ -91,7 +90,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
91
90
 
92
91
 
93
92
  app = FastAPI()
94
- tokenizer_manager = None
93
+ tokenizer_manager: TokenizerManager = None
95
94
 
96
95
  app.add_middleware(
97
96
  CORSMiddleware,
@@ -139,7 +138,7 @@ async def get_server_args():
139
138
  return dataclasses.asdict(tokenizer_manager.server_args)
140
139
 
141
140
 
142
- @app.get("/flush_cache")
141
+ @app.post("/flush_cache")
143
142
  async def flush_cache():
144
143
  """Flush the radix cache."""
145
144
  tokenizer_manager.flush_cache()
@@ -177,9 +176,10 @@ async def get_memory_pool_size():
177
176
  """Get the memory pool size in number of tokens"""
178
177
  try:
179
178
  ret = await tokenizer_manager.get_memory_pool_size()
180
- return ret.size
179
+
180
+ return ret
181
181
  except Exception as e:
182
- return JSONResponse(
182
+ return ORJSONResponse(
183
183
  {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
184
184
  )
185
185
 
@@ -253,8 +253,8 @@ app.post("/encode")(encode_request)
253
253
  app.put("/encode")(encode_request)
254
254
 
255
255
 
256
- async def judge_request(obj: RewardReqInput, request: Request):
257
- """Handle a reward model request."""
256
+ async def judge_request(obj: EmbeddingReqInput, request: Request):
257
+ """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
258
258
  try:
259
259
  ret = await tokenizer_manager.generate_request(obj, request).__anext__()
260
260
  return ret
@@ -441,7 +441,7 @@ def launch_server(
441
441
 
442
442
  # Send a warmup request
443
443
  t = threading.Thread(
444
- target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid())
444
+ target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
445
445
  )
446
446
  t.start()
447
447
 
@@ -496,7 +496,7 @@ def _set_envs_and_config(server_args: ServerArgs):
496
496
  mp.set_start_method("spawn", force=True)
497
497
 
498
498
 
499
- def _wait_and_warmup(server_args, pipe_finish_writer, pid):
499
+ def _wait_and_warmup(server_args, pipe_finish_writer):
500
500
  headers = {}
501
501
  url = server_args.url()
502
502
  if server_args.api_key:
@@ -519,7 +519,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
519
519
  if pipe_finish_writer is not None:
520
520
  pipe_finish_writer.send(last_traceback)
521
521
  logger.error(f"Initialization failed. warmup error: {last_traceback}")
522
- kill_child_process(pid, including_parent=False)
522
+ kill_child_process(include_self=True)
523
523
  return
524
524
 
525
525
  model_info = res.json()
@@ -551,7 +551,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
551
551
  if pipe_finish_writer is not None:
552
552
  pipe_finish_writer.send(last_traceback)
553
553
  logger.error(f"Initialization failed. warmup error: {last_traceback}")
554
- kill_child_process(pid, including_parent=False)
554
+ kill_child_process(include_self=True)
555
555
  return
556
556
 
557
557
  # logger.info(f"{res.json()=}")
@@ -617,7 +617,7 @@ class Runtime:
617
617
 
618
618
  def shutdown(self):
619
619
  if self.pid is not None:
620
- kill_child_process(self.pid)
620
+ kill_child_process(self.pid, include_self=True)
621
621
  self.pid = None
622
622
 
623
623
  def cache_prefix(self, prefix: str):
@@ -696,24 +696,8 @@ class Runtime:
696
696
  self,
697
697
  prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
698
698
  ):
699
- if isinstance(prompt, str) or isinstance(prompt[0], str):
700
- # embedding
701
- json_data = {
702
- "text": prompt,
703
- }
704
- response = requests.post(
705
- self.url + "/encode",
706
- json=json_data,
707
- )
708
- else:
709
- # reward
710
- json_data = {
711
- "conv": prompt,
712
- }
713
- response = requests.post(
714
- self.url + "/judge",
715
- json=json_data,
716
- )
699
+ json_data = {"text": prompt}
700
+ response = requests.post(self.url + "/encode", json=json_data)
717
701
  return json.dumps(response.json())
718
702
 
719
703
  def __del__(self):
@@ -736,24 +720,32 @@ class Engine:
736
720
 
737
721
  # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
738
722
  atexit.register(self.shutdown)
723
+
724
+ # runtime server default log level is log
725
+ # offline engine works in scripts, so we set it to error
726
+
727
+ if 'log_level' not in kwargs:
728
+ kwargs['log_level'] = 'error'
739
729
 
740
730
  server_args = ServerArgs(*args, **kwargs)
741
731
  launch_engine(server_args=server_args)
742
732
 
743
733
  def generate(
744
734
  self,
745
- prompt: Union[str, List[str]],
735
+ # The input prompt. It can be a single prompt or a batch of prompts.
736
+ prompt: Optional[Union[List[str], str]] = None,
746
737
  sampling_params: Optional[Dict] = None,
738
+ # The token ids for text; one can either specify text or input_ids.
739
+ input_ids: Optional[Union[List[List[int]], List[int]]] = None,
747
740
  return_logprob: Optional[Union[List[bool], bool]] = False,
748
741
  logprob_start_len: Optional[Union[List[int], int]] = None,
749
742
  top_logprobs_num: Optional[Union[List[int], int]] = None,
750
743
  lora_path: Optional[List[Optional[str]]] = None,
751
744
  stream: bool = False,
752
745
  ):
753
- # TODO (ByronHsu): refactor to reduce the duplicated code
754
-
755
746
  obj = GenerateReqInput(
756
747
  text=prompt,
748
+ input_ids=input_ids,
757
749
  sampling_params=sampling_params,
758
750
  return_logprob=return_logprob,
759
751
  logprob_start_len=logprob_start_len,
@@ -791,8 +783,11 @@ class Engine:
791
783
 
792
784
  async def async_generate(
793
785
  self,
794
- prompt: Union[str, List[str]],
786
+ # The input prompt. It can be a single prompt or a batch of prompts.
787
+ prompt: Optional[Union[List[str], str]] = None,
795
788
  sampling_params: Optional[Dict] = None,
789
+ # The token ids for text; one can either specify text or input_ids.
790
+ input_ids: Optional[Union[List[List[int]], List[int]]] = None,
796
791
  return_logprob: Optional[Union[List[bool], bool]] = False,
797
792
  logprob_start_len: Optional[Union[List[int], int]] = None,
798
793
  top_logprobs_num: Optional[Union[List[int], int]] = None,
@@ -801,6 +796,7 @@ class Engine:
801
796
  ):
802
797
  obj = GenerateReqInput(
803
798
  text=prompt,
799
+ input_ids=input_ids,
804
800
  sampling_params=sampling_params,
805
801
  return_logprob=return_logprob,
806
802
  logprob_start_len=logprob_start_len,
@@ -834,7 +830,7 @@ class Engine:
834
830
  return ret
835
831
 
836
832
  def shutdown(self):
837
- kill_child_process(os.getpid(), including_parent=False)
833
+ kill_child_process()
838
834
 
839
835
  def get_tokenizer(self):
840
836
  global tokenizer_manager
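The Engine changes above add an input_ids alternative to prompt, default the offline engine's log level to "error", and route shutdown through kill_child_process(). A minimal usage sketch under those assumptions follows; the model path and token ids are placeholders and are not taken from this diff.

    from sglang.srt.server import Engine

    engine = Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # log_level now defaults to "error"
    out = engine.generate(
        prompt="The capital of France is",
        sampling_params={"temperature": 0, "max_new_tokens": 8},
    )
    # Alternatively, pass pre-tokenized input; one specifies either text or input_ids, not both.
    out_ids = engine.generate(
        input_ids=[[1, 450, 7483]],
        sampling_params={"max_new_tokens": 8},
    )
    engine.shutdown()  # tears down the engine's child processes via kill_child_process()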
sglang/srt/server_args.py CHANGED
@@ -63,6 +63,7 @@ class ServerArgs:
  stream_interval: int = 1
  random_seed: Optional[int] = None
  constrained_json_whitespace_pattern: Optional[str] = None
+ decode_log_interval: int = 40

  # Logging
  log_level: str = "info"
@@ -74,6 +75,7 @@ class ServerArgs:
  api_key: Optional[str] = None
  file_storage_pth: str = "SGLang_storage"
  enable_cache_report: bool = False
+ watchdog_timeout: float = 600

  # Data parallelism
  dp_size: int = 1
@@ -102,6 +104,7 @@ class ServerArgs:
  # Kernel backend
  attention_backend: Optional[str] = None
  sampling_backend: Optional[str] = None
+ grammar_backend: Optional[str] = "outlines"

  # Optimization/debug options
  disable_flashinfer: bool = False
@@ -118,7 +121,8 @@ class ServerArgs:
  enable_overlap_schedule: bool = False
  enable_mixed_chunk: bool = False
  enable_torch_compile: bool = False
- max_torch_compile_bs: int = 32
+ torch_compile_max_bs: int = 32
+ cuda_graph_max_bs: int = 160
  torchao_config: str = ""
  enable_p2p_check: bool = False
  triton_attention_reduce_in_fp32: bool = False
@@ -427,6 +431,18 @@ class ServerArgs:
  action="store_true",
  help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
  )
+ parser.add_argument(
+ "--watchdog-timeout",
+ type=float,
+ default=ServerArgs.watchdog_timeout,
+ help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
+ )
+ parser.add_argument(
+ "--decode-log-interval",
+ type=int,
+ default=ServerArgs.decode_log_interval,
+ help="The log interval of decode batch"
+ )

  # Data parallelism
  parser.add_argument(
@@ -537,6 +553,13 @@ class ServerArgs:
  default=ServerArgs.sampling_backend,
  help="Choose the kernels for sampling layers.",
  )
+ parser.add_argument(
+ "--grammar-backend",
+ type=str,
+ choices=["xgrammar", "outlines"],
+ default=ServerArgs.grammar_backend,
+ help="Choose the backend for constrained decoding.",
+ )

  # Optimization/debug options
  parser.add_argument(
@@ -611,11 +634,17 @@ class ServerArgs:
  help="Optimize the model with torch.compile. Experimental feature.",
  )
  parser.add_argument(
- "--max-torch-compile-bs",
+ "--torch-compile-max-bs",
  type=int,
- default=ServerArgs.max_torch_compile_bs,
+ default=ServerArgs.torch_compile_max_bs,
  help="Set the maximum batch size when using torch compile.",
  )
+ parser.add_argument(
+ "--cuda-graph-max-bs",
+ type=int,
+ default=ServerArgs.cuda_graph_max_bs,
+ help="Set the maximum batch size for cuda graph.",
+ )
  parser.add_argument(
  "--torchao-config",
  type=str,
@@ -712,11 +741,11 @@ class PortArgs:

  @staticmethod
  def init_new(server_args) -> "PortArgs":
- port = server_args.port + 1
+ port = server_args.port + 42
  while True:
  if is_port_available(port):
  break
- port += 1
+ port += 42

  return PortArgs(
  tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
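Taken together, the server_args.py changes above rename --max-torch-compile-bs to --torch-compile-max-bs and add --cuda-graph-max-bs, --grammar-backend, --watchdog-timeout, and --decode-log-interval. A sketch of the corresponding ServerArgs fields, assuming only what the diff shows; the model path and the specific values are placeholders.

    from sglang.srt.server_args import ServerArgs

    args = ServerArgs(
        model_path="meta-llama/Llama-3.1-8B-Instruct",
        grammar_backend="xgrammar",   # constrained-decoding backend; "outlines" is the default
        torch_compile_max_bs=16,      # renamed from max_torch_compile_bs
        cuda_graph_max_bs=160,        # new cap on the batch size captured by CUDA graphs
        watchdog_timeout=600,         # crash instead of hanging if a forward batch stalls
        decode_log_interval=40,       # log decode-batch progress every 40 batches
    )

The same options map one-to-one onto the command-line flags registered in the argument-parser section above, e.g. --grammar-backend xgrammar.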
sglang/srt/utils.py CHANGED
@@ -35,6 +35,7 @@ import psutil
  import requests
  import torch
  import torch.distributed as dist
+ import zmq
  from fastapi.responses import ORJSONResponse
  from packaging import version as pkg_version
  from torch import nn
@@ -203,56 +204,6 @@ def is_port_available(port):
  return False


- def is_multimodal_model(model_architectures):
- if (
- "LlavaLlamaForCausalLM" in model_architectures
- or "LlavaQwenForCausalLM" in model_architectures
- or "LlavaMistralForCausalLM" in model_architectures
- or "LlavaVidForCausalLM" in model_architectures
- or "MllamaForConditionalGeneration" in model_architectures
- or "Qwen2VLForConditionalGeneration" in model_architectures
- ):
- return True
- else:
- return False
-
-
- def is_attention_free_model(model_architectures):
- return False
-
-
- def model_has_inner_state(model_architectures):
- return False
-
-
- def is_embedding_model(model_architectures):
- if (
- "LlamaEmbeddingModel" in model_architectures
- or "MistralModel" in model_architectures
- or "LlamaForSequenceClassification" in model_architectures
- or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
- ):
- return True
- else:
- return False
-
-
- def is_generation_model(model_architectures, is_embedding: bool = False):
- # We have two ways to determine whether a model is a generative model.
- # 1. Check the model architectue
- # 2. check the `is_embedding` server args
-
- if (
- "LlamaEmbeddingModel" in model_architectures
- or "MistralModel" in model_architectures
- or "LlamaForSequenceClassification" in model_architectures
- or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
- ):
- return False
- else:
- return not is_embedding
-
-
  def decode_video_base64(video_base64):
  from PIL import Image

@@ -397,17 +348,26 @@ def kill_parent_process():
  """Kill the parent process and all children of the parent process."""
  current_process = psutil.Process()
  parent_process = current_process.parent()
- kill_child_process(parent_process.pid, skip_pid=current_process.pid)
+ kill_child_process(
+ parent_process.pid, include_self=True, skip_pid=current_process.pid
+ )
+ try:
+ current_process.kill()
+ except psutil.NoSuchProcess:
+ pass


- def kill_child_process(pid, including_parent=True, skip_pid=None):
+ def kill_child_process(pid=None, include_self=False, skip_pid=None):
  """Kill the process and all its children process."""
+ if pid is None:
+ pid = os.getpid()
+
  try:
- parent = psutil.Process(pid)
+ itself = psutil.Process(pid)
  except psutil.NoSuchProcess:
  return

- children = parent.children(recursive=True)
+ children = itself.children(recursive=True)
  for child in children:
  if child.pid == skip_pid:
  continue
@@ -416,9 +376,9 @@ def kill_child_process(pid, including_parent=True, skip_pid=None):
  except psutil.NoSuchProcess:
  pass

- if including_parent:
+ if include_self:
  try:
- parent.kill()
+ itself.kill()
  except psutil.NoSuchProcess:
  pass

@@ -720,3 +680,27 @@ def first_rank_print(*args, **kwargs):
  print(*args, **kwargs)
  else:
  pass
+
+
+ def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: str):
+ mem = psutil.virtual_memory()
+ total_mem = mem.total / 1024**3
+ available_mem = mem.available / 1024**3
+ if total_mem > 32 and available_mem > 16:
+ buf_size = int(0.5 * 1024**3)
+ else:
+ buf_size = -1
+
+ socket = context.socket(socket_type)
+ if socket_type == zmq.PUSH:
+ socket.setsockopt(zmq.SNDHWM, 0)
+ socket.setsockopt(zmq.SNDBUF, buf_size)
+ socket.connect(f"ipc://{endpoint}")
+ elif socket_type == zmq.PULL:
+ socket.setsockopt(zmq.RCVHWM, 0)
+ socket.setsockopt(zmq.RCVBUF, buf_size)
+ socket.bind(f"ipc://{endpoint}")
+ else:
+ raise ValueError(f"Unsupported socket type: {socket_type}")
+
+ return socket
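The new get_zmq_socket helper above sizes the socket buffers from available host memory, connects PUSH sockets, and binds PULL sockets on an ipc:// endpoint. A small end-to-end sketch under those assumptions; the endpoint path and payload are placeholders.

    import zmq
    from sglang.srt.utils import get_zmq_socket

    ctx = zmq.Context(io_threads=1)
    # The receiving side binds a PULL socket on the IPC endpoint ...
    pull = get_zmq_socket(ctx, zmq.PULL, "/tmp/sglang_ipc_demo")
    # ... and the sending side connects a PUSH socket to the same endpoint.
    push = get_zmq_socket(ctx, zmq.PUSH, "/tmp/sglang_ipc_demo")

    push.send_pyobj({"rid": "req-0", "text": "hello"})
    print(pull.recv_pyobj())  # {'rid': 'req-0', 'text': 'hello'}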