sglang 0.3.6.post3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_one_batch.py +4 -0
  2. sglang/bench_serving.py +13 -0
  3. sglang/check_env.py +1 -1
  4. sglang/srt/_custom_ops.py +118 -0
  5. sglang/srt/configs/device_config.py +17 -0
  6. sglang/srt/configs/load_config.py +84 -0
  7. sglang/srt/configs/model_config.py +161 -4
  8. sglang/srt/configs/qwen2vl.py +5 -8
  9. sglang/srt/constrained/outlines_backend.py +6 -1
  10. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  11. sglang/srt/distributed/__init__.py +3 -0
  12. sglang/srt/distributed/communication_op.py +34 -0
  13. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  14. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  15. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  16. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  17. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  18. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  21. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  22. sglang/srt/distributed/parallel_state.py +1275 -0
  23. sglang/srt/distributed/utils.py +223 -0
  24. sglang/srt/hf_transformers_utils.py +37 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +13 -15
  26. sglang/srt/layers/attention/torch_native_backend.py +285 -0
  27. sglang/srt/layers/fused_moe_patch.py +20 -11
  28. sglang/srt/layers/linear.py +1 -0
  29. sglang/srt/layers/logits_processor.py +17 -3
  30. sglang/srt/layers/quantization/__init__.py +34 -0
  31. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  32. sglang/srt/lora/lora.py +1 -1
  33. sglang/srt/managers/io_struct.py +48 -2
  34. sglang/srt/managers/schedule_batch.py +18 -14
  35. sglang/srt/managers/schedule_policy.py +7 -4
  36. sglang/srt/managers/scheduler.py +76 -20
  37. sglang/srt/managers/tokenizer_manager.py +166 -68
  38. sglang/srt/managers/tp_worker.py +36 -3
  39. sglang/srt/managers/tp_worker_overlap_thread.py +21 -3
  40. sglang/srt/model_executor/cuda_graph_runner.py +16 -7
  41. sglang/srt/model_executor/forward_batch_info.py +9 -4
  42. sglang/srt/model_executor/model_runner.py +136 -150
  43. sglang/srt/model_loader/__init__.py +34 -0
  44. sglang/srt/model_loader/loader.py +1139 -0
  45. sglang/srt/model_loader/utils.py +41 -0
  46. sglang/srt/model_loader/weight_utils.py +640 -0
  47. sglang/srt/models/baichuan.py +9 -10
  48. sglang/srt/models/chatglm.py +6 -15
  49. sglang/srt/models/commandr.py +2 -3
  50. sglang/srt/models/dbrx.py +2 -3
  51. sglang/srt/models/deepseek.py +4 -11
  52. sglang/srt/models/deepseek_v2.py +3 -11
  53. sglang/srt/models/exaone.py +2 -3
  54. sglang/srt/models/gemma.py +2 -6
  55. sglang/srt/models/gemma2.py +3 -14
  56. sglang/srt/models/gemma2_reward.py +0 -1
  57. sglang/srt/models/gpt2.py +5 -12
  58. sglang/srt/models/gpt_bigcode.py +6 -22
  59. sglang/srt/models/grok.py +3 -3
  60. sglang/srt/models/internlm2.py +2 -3
  61. sglang/srt/models/internlm2_reward.py +0 -1
  62. sglang/srt/models/llama.py +97 -27
  63. sglang/srt/models/llama_classification.py +1 -2
  64. sglang/srt/models/llama_embedding.py +1 -2
  65. sglang/srt/models/llama_reward.py +2 -3
  66. sglang/srt/models/llava.py +1 -4
  67. sglang/srt/models/llavavid.py +1 -2
  68. sglang/srt/models/minicpm.py +4 -7
  69. sglang/srt/models/minicpm3.py +6 -19
  70. sglang/srt/models/mixtral.py +12 -5
  71. sglang/srt/models/mixtral_quant.py +2 -3
  72. sglang/srt/models/mllama.py +3 -7
  73. sglang/srt/models/olmo.py +2 -8
  74. sglang/srt/models/olmo2.py +0 -1
  75. sglang/srt/models/olmoe.py +3 -5
  76. sglang/srt/models/phi3_small.py +8 -8
  77. sglang/srt/models/qwen.py +2 -3
  78. sglang/srt/models/qwen2.py +10 -9
  79. sglang/srt/models/qwen2_moe.py +4 -11
  80. sglang/srt/models/qwen2_vl.py +2 -6
  81. sglang/srt/models/registry.py +99 -0
  82. sglang/srt/models/stablelm.py +2 -3
  83. sglang/srt/models/torch_native_llama.py +6 -12
  84. sglang/srt/models/xverse.py +2 -4
  85. sglang/srt/models/xverse_moe.py +4 -11
  86. sglang/srt/models/yivl.py +2 -3
  87. sglang/srt/openai_api/adapter.py +9 -5
  88. sglang/srt/openai_api/protocol.py +1 -0
  89. sglang/srt/server.py +267 -170
  90. sglang/srt/server_args.py +65 -31
  91. sglang/srt/utils.py +245 -28
  92. sglang/test/test_utils.py +7 -0
  93. sglang/version.py +1 -1
  94. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/METADATA +1 -1
  95. sglang-0.4.0.dist-info/RECORD +184 -0
  96. sglang-0.3.6.post3.dist-info/RECORD +0 -162
  97. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
  98. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
  99. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
sglang/srt/server.py CHANGED
@@ -52,8 +52,11 @@ from sglang.srt.managers.io_struct import (
     CloseSessionReqInput,
     EmbeddingReqInput,
     GenerateReqInput,
+    GetWeightsByNameReqInput,
+    InitWeightsUpdateGroupReqInput,
     OpenSessionReqInput,
-    UpdateWeightReqInput,
+    UpdateWeightFromDiskReqInput,
+    UpdateWeightsFromDistributedReqInput,
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -150,13 +153,11 @@ async def get_model_info():
 
 @app.get("/get_server_info")
 async def get_server_info():
-    try:
-        return await _get_server_info()
-
-    except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+    return {
+        **dataclasses.asdict(tokenizer_manager.server_args),  # server args
+        **scheduler_info,
+        "version": __version__,
+    }
 
 
 @app.post("/flush_cache")
@@ -192,11 +193,11 @@ async def stop_profile_async():
     )
 
 
-@app.post("/update_weights")
+@app.post("/update_weights_from_disk")
 @time_func_latency
-async def update_weights(obj: UpdateWeightReqInput, request: Request):
-    """Update the weights inplace without re-launching the server."""
-    success, message = await tokenizer_manager.update_weights(obj, request)
+async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
+    """Update the weights from disk inplace without re-launching the server."""
+    success, message = await tokenizer_manager.update_weights_from_disk(obj, request)
     content = {"success": success, "message": message}
     if success:
         return ORJSONResponse(
@@ -210,6 +211,52 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
         )
 
 
+@app.post("/init_weights_update_group")
+async def init_weights_update_group(
+    obj: InitWeightsUpdateGroupReqInput, request: Request
+):
+    """Initialize the parameter update group."""
+    success, message = await tokenizer_manager.init_weights_update_group(obj, request)
+    content = {"success": success, "message": message}
+    if success:
+        return ORJSONResponse(content, status_code=200)
+    else:
+        return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
+
+
+@app.post("/update_weights_from_distributed")
+async def update_weights_from_distributed(
+    obj: UpdateWeightsFromDistributedReqInput, request: Request
+):
+    """Update model parameter from distributed online."""
+    success, message = await tokenizer_manager.update_weights_from_distributed(
+        obj, request
+    )
+    content = {"success": success, "message": message}
+    if success:
+        return ORJSONResponse(content, status_code=200)
+    else:
+        return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
+
+
+@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
+async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
+    """Get model parameter by name."""
+    try:
+        ret = await tokenizer_manager.get_weights_by_name(obj, request)
+        if ret is None:
+            return ORJSONResponse(
+                {"error": {"message": "Get parameter by name failed"}},
+                status_code=HTTPStatus.BAD_REQUEST,
+            )
+        else:
+            return ORJSONResponse(ret, status_code=200)
+    except Exception as e:
+        return ORJSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
+
+
 @app.api_route("/open_session", methods=["GET", "POST"])
 async def open_session(obj: OpenSessionReqInput, request: Request):
     """Open a session, and return its unique session id."""
@@ -517,14 +564,6 @@ def launch_server(
         t.join()
 
 
-async def _get_server_info():
-    return {
-        **dataclasses.asdict(tokenizer_manager.server_args),  # server args
-        **scheduler_info,
-        "version": __version__,
-    }
-
-
 def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -637,11 +676,218 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         delete_directory(server_args.model_path)
 
 
+STREAM_END_SYMBOL = b"data: [DONE]"
+STREAM_CHUNK_START_SYMBOL = b"data:"
+
+
+class Engine:
+    """
+    SRT Engine without an HTTP server layer.
+
+    This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where
+    launching the HTTP server adds unnecessary complexity or overhead,
+    """
+
+    def __init__(self, log_level: str = "error", *args, **kwargs):
+        """See the arguments in server_args.py::ServerArgs"""
+
+        # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
+        atexit.register(self.shutdown)
+
+        server_args = ServerArgs(*args, log_level=log_level, **kwargs)
+        launch_engine(server_args=server_args)
+
+    def generate(
+        self,
+        # The input prompt. It can be a single prompt or a batch of prompts.
+        prompt: Optional[Union[List[str], str]] = None,
+        sampling_params: Optional[Union[List[Dict], Dict]] = None,
+        # The token ids for text; one can either specify text or input_ids.
+        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+        lora_path: Optional[List[Optional[str]]] = None,
+        stream: bool = False,
+    ):
+        obj = GenerateReqInput(
+            text=prompt,
+            input_ids=input_ids,
+            sampling_params=sampling_params,
+            return_logprob=return_logprob,
+            logprob_start_len=logprob_start_len,
+            top_logprobs_num=top_logprobs_num,
+            lora_path=lora_path,
+            stream=stream,
+        )
+
+        # get the current event loop
+        loop = asyncio.get_event_loop()
+        ret = loop.run_until_complete(generate_request(obj, None))
+
+        if stream is True:
+
+            def generator_wrapper():
+                offset = 0
+                loop = asyncio.get_event_loop()
+                generator = ret.body_iterator
+                while True:
+                    chunk = loop.run_until_complete(generator.__anext__())
+
+                    if chunk.startswith(STREAM_END_SYMBOL):
+                        break
+                    else:
+                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
+                        data["text"] = data["text"][offset:]
+                        offset += len(data["text"])
+                        yield data
+
+            # we cannot yield in the scope of generate() because python does not allow yield + return in the same function
+            # however, it allows to wrap the generator as a subfunction and return
+            return generator_wrapper()
+        else:
+            return ret
+
+    async def async_generate(
+        self,
+        # The input prompt. It can be a single prompt or a batch of prompts.
+        prompt: Optional[Union[List[str], str]] = None,
+        sampling_params: Optional[Dict] = None,
+        # The token ids for text; one can either specify text or input_ids.
+        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+        lora_path: Optional[List[Optional[str]]] = None,
+        stream: bool = False,
+    ):
+        obj = GenerateReqInput(
+            text=prompt,
+            input_ids=input_ids,
+            sampling_params=sampling_params,
+            return_logprob=return_logprob,
+            logprob_start_len=logprob_start_len,
+            top_logprobs_num=top_logprobs_num,
+            lora_path=lora_path,
+            stream=stream,
+        )
+
+        ret = await generate_request(obj, None)
+
+        if stream is True:
+            generator = ret.body_iterator
+
+            async def generator_wrapper():
+
+                offset = 0
+
+                while True:
+                    chunk = await generator.__anext__()
+
+                    if chunk.startswith(STREAM_END_SYMBOL):
+                        break
+                    else:
+                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
+                        data["text"] = data["text"][offset:]
+                        offset += len(data["text"])
+                        yield data
+
+            return generator_wrapper()
+        else:
+            return ret
+
+    def shutdown(self):
+        kill_process_tree(os.getpid(), include_parent=False)
+
+    def get_tokenizer(self):
+        global tokenizer_manager
+
+        if tokenizer_manager is None:
+            raise ReferenceError("Tokenizer Manager is not initialized.")
+        else:
+            return tokenizer_manager.tokenizer
+
+    def encode(
+        self,
+        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
+    ):
+        obj = EmbeddingReqInput(text=prompt)
+
+        # get the current event loop
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(encode_request(obj, None))
+
+    def start_profile(self):
+        tokenizer_manager.start_profile()
+
+    def stop_profile(self):
+        tokenizer_manager.stop_profile()
+
+    def get_server_info(self):
+        return {
+            **dataclasses.asdict(tokenizer_manager.server_args),  # server args
+            **scheduler_info,
+            "version": __version__,
+        }
+
+    def init_weights_update_group(
+        self,
+        master_address: str,
+        master_port: int,
+        rank_offset: int,
+        world_size: int,
+        group_name: str,
+        backend: str = "nccl",
+    ):
+        """Initialize parameter update group."""
+        obj = InitWeightsUpdateGroupReqInput(
+            master_address=master_address,
+            master_port=master_port,
+            rank_offset=rank_offset,
+            world_size=world_size,
+            group_name=group_name,
+            backend=backend,
+        )
+
+        async def _init_group():
+            return await tokenizer_manager.init_weights_update_group(obj, None)
+
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(_init_group())
+
+    def update_weights_from_distributed(self, name, dtype, shape):
+        """Update weights from distributed source."""
+        obj = UpdateWeightsFromDistributedReqInput(
+            name=name,
+            dtype=dtype,
+            shape=shape,
+        )
+
+        async def _update_weights():
+            return await tokenizer_manager.update_weights_from_distributed(obj, None)
+
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(_update_weights())
+
+    def get_weights_by_name(self, name, truncate_size=100):
+        """Get weights by parameter name."""
+        obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
+
+        async def _get_weights():
+            return await tokenizer_manager.get_weights_by_name(obj, None)
+
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(_get_weights())
+
+
 class Runtime:
     """
-    A wrapper for the server.
+    A wrapper for the HTTP server.
     This is used for launching the server in a python program without
     using the commond line interface.
+
+    It is mainly used for the frontend language.
+    You should use the Engine class if you want to do normal offline processing.
     """
 
     def __init__(
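Since this hunk introduces the Engine class ahead of Runtime, here is a short offline-usage sketch based on the constructor and generate() signatures shown above. Engine is imported from sglang.srt.server as defined in this file; the model path and sampling parameters are placeholders, not values taken from the diff:

    from sglang.srt.server import Engine

    # Engine forwards *args/**kwargs to ServerArgs, so any server argument works here.
    llm = Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")

    # Non-streaming generation over a batch of prompts.
    outputs = llm.generate(
        prompt=["The capital of France is", "1 + 1 ="],
        sampling_params={"temperature": 0.0, "max_new_tokens": 16},
    )
    print(outputs)

    # Streaming returns a generator that yields incremental text chunks.
    for chunk in llm.generate(prompt="Write a haiku about GPUs.", stream=True):
        print(chunk["text"], end="", flush=True)

    llm.shutdown()  # also registered via atexit in __init__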
@@ -789,152 +1035,3 @@ class Runtime:
 
     def __del__(self):
         self.shutdown()
-
-
-STREAM_END_SYMBOL = b"data: [DONE]"
-STREAM_CHUNK_START_SYMBOL = b"data:"
-
-
-class Engine:
-    """
-    SRT Engine without an HTTP server layer.
-
-    This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where
-    launching the HTTP server adds unnecessary complexity or overhead,
-    """
-
-    def __init__(self, log_level: str = "error", *args, **kwargs):
-        # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
-        atexit.register(self.shutdown)
-
-        server_args = ServerArgs(*args, log_level=log_level, **kwargs)
-        launch_engine(server_args=server_args)
-
-    def generate(
-        self,
-        # The input prompt. It can be a single prompt or a batch of prompts.
-        prompt: Optional[Union[List[str], str]] = None,
-        sampling_params: Optional[Union[List[Dict], Dict]] = None,
-        # The token ids for text; one can either specify text or input_ids.
-        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
-        return_logprob: Optional[Union[List[bool], bool]] = False,
-        logprob_start_len: Optional[Union[List[int], int]] = None,
-        top_logprobs_num: Optional[Union[List[int], int]] = None,
-        lora_path: Optional[List[Optional[str]]] = None,
-        stream: bool = False,
-    ):
-        obj = GenerateReqInput(
-            text=prompt,
-            input_ids=input_ids,
-            sampling_params=sampling_params,
-            return_logprob=return_logprob,
-            logprob_start_len=logprob_start_len,
-            top_logprobs_num=top_logprobs_num,
-            lora_path=lora_path,
-            stream=stream,
-        )
-
-        # get the current event loop
-        loop = asyncio.get_event_loop()
-        ret = loop.run_until_complete(generate_request(obj, None))
-
-        if stream is True:
-
-            def generator_wrapper():
-                offset = 0
-                loop = asyncio.get_event_loop()
-                generator = ret.body_iterator
-                while True:
-                    chunk = loop.run_until_complete(generator.__anext__())
-
-                    if chunk.startswith(STREAM_END_SYMBOL):
-                        break
-                    else:
-                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
-                        data["text"] = data["text"][offset:]
-                        offset += len(data["text"])
-                        yield data
-
-            # we cannot yield in the scope of generate() because python does not allow yield + return in the same function
-            # however, it allows to wrap the generator as a subfunction and return
-            return generator_wrapper()
-        else:
-            return ret
-
-    async def async_generate(
-        self,
-        # The input prompt. It can be a single prompt or a batch of prompts.
-        prompt: Optional[Union[List[str], str]] = None,
-        sampling_params: Optional[Dict] = None,
-        # The token ids for text; one can either specify text or input_ids.
-        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
-        return_logprob: Optional[Union[List[bool], bool]] = False,
-        logprob_start_len: Optional[Union[List[int], int]] = None,
-        top_logprobs_num: Optional[Union[List[int], int]] = None,
-        lora_path: Optional[List[Optional[str]]] = None,
-        stream: bool = False,
-    ):
-        obj = GenerateReqInput(
-            text=prompt,
-            input_ids=input_ids,
-            sampling_params=sampling_params,
-            return_logprob=return_logprob,
-            logprob_start_len=logprob_start_len,
-            top_logprobs_num=top_logprobs_num,
-            lora_path=lora_path,
-            stream=stream,
-        )
-
-        ret = await generate_request(obj, None)
-
-        if stream is True:
-            generator = ret.body_iterator
-
-            async def generator_wrapper():
-
-                offset = 0
-
-                while True:
-                    chunk = await generator.__anext__()
-
-                    if chunk.startswith(STREAM_END_SYMBOL):
-                        break
-                    else:
-                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
-                        data["text"] = data["text"][offset:]
-                        offset += len(data["text"])
-                        yield data
-
-            return generator_wrapper()
-        else:
-            return ret
-
-    def shutdown(self):
-        kill_process_tree(os.getpid(), include_parent=False)
-
-    def get_tokenizer(self):
-        global tokenizer_manager
-
-        if tokenizer_manager is None:
-            raise ReferenceError("Tokenizer Manager is not initialized.")
-        else:
-            return tokenizer_manager.tokenizer
-
-    def encode(
-        self,
-        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
-    ):
-        obj = EmbeddingReqInput(text=prompt)
-
-        # get the current event loop
-        loop = asyncio.get_event_loop()
-        return loop.run_until_complete(encode_request(obj, None))
-
-    def start_profile(self):
-        tokenizer_manager.start_profile()
-
-    def stop_profile(self):
-        tokenizer_manager.stop_profile()
-
-    async def get_server_info(self):
-        return await _get_server_info()
sglang/srt/server_args.py CHANGED
@@ -20,6 +20,7 @@ import random
 import tempfile
 from typing import List, Optional
 
+from sglang.srt.hf_transformers_utils import check_gguf_file
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
     get_nvgpu_memory_capacity,
@@ -49,6 +50,7 @@ class ServerArgs:
     served_model_name: Optional[str] = None
     chat_template: Optional[str] = None
     is_embedding: bool = False
+    revision: Optional[str] = None
 
     # Port
     host: str = "127.0.0.1"
@@ -58,7 +60,7 @@
     mem_fraction_static: Optional[float] = None
     max_running_requests: Optional[int] = None
     max_total_tokens: Optional[int] = None
-    chunked_prefill_size: int = 8192
+    chunked_prefill_size: Optional[int] = None
     max_prefill_tokens: int = 16384
     schedule_policy: str = "lpm"
     schedule_conservativeness: float = 1.0
@@ -120,7 +122,7 @@
     disable_jump_forward: bool = False
     disable_cuda_graph: bool = False
    disable_cuda_graph_padding: bool = False
-    disable_disk_cache: bool = False
+    disable_outlines_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
     disable_overlap_schedule: bool = False
@@ -128,7 +130,7 @@
     enable_dp_attention: bool = False
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
-    cuda_graph_max_bs: int = 160
+    cuda_graph_max_bs: Optional[int] = None
     torchao_config: str = ""
     enable_nan_detection: bool = False
     enable_p2p_check: bool = False
@@ -144,19 +146,20 @@
         if self.served_model_name is None:
             self.served_model_name = self.model_path
 
-        if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
-            # Disable chunked prefill
-            self.chunked_prefill_size = None
-
         if self.random_seed is None:
             self.random_seed = random.randint(0, 1 << 30)
 
-        # Mem fraction depends on the tensor parallelism size
+        if is_hip():
+            gpu_mem = get_amdgpu_memory_capacity()
+        else:
+            gpu_mem = get_nvgpu_memory_capacity()
+
+        # Set mem fraction static, which depends on the tensor parallelism size
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
                 self.mem_fraction_static = 0.79
             elif self.tp_size >= 8:
-                self.mem_fraction_static = 0.82
+                self.mem_fraction_static = 0.81
             elif self.tp_size >= 4:
                 self.mem_fraction_static = 0.85
             elif self.tp_size >= 2:
@@ -164,25 +167,35 @@
         else:
             self.mem_fraction_static = 0.88
 
-        # Adjust for GPUs with small memory capacities
-        if is_hip():
-            gpu_mem = get_amdgpu_memory_capacity()
-        else:
-            gpu_mem = get_nvgpu_memory_capacity()
-        if gpu_mem < 25000:
-            self.chunked_prefill_size //= 4  # make it 2048
-            self.cuda_graph_max_bs = 4
-            logger.info("Automatically adjust --chunked-prefill-size for small GPUs.")
+        # Set chunked prefill size, which depends on the gpu memory capacity
+        if self.chunked_prefill_size is None:
+            if gpu_mem < 25_000:
+                self.chunked_prefill_size = 2048
+            else:
+                self.chunked_prefill_size = 8192
 
-        # Choose kernel backends
-        if not is_flashinfer_available():
-            self.attention_backend = "triton"
-            self.sampling_backend = "pytorch"
+        # Set cuda graph max batch size
+        if self.cuda_graph_max_bs is None:
+            if gpu_mem < 25_000:
+                self.cuda_graph_max_bs = 8
+            else:
+                self.cuda_graph_max_bs = 160
 
+        # Choose kernel backends
         if self.attention_backend is None:
-            self.attention_backend = "flashinfer"
+            self.attention_backend = (
+                "flashinfer" if is_flashinfer_available() else "triton"
+            )
         if self.sampling_backend is None:
-            self.sampling_backend = "flashinfer"
+            self.sampling_backend = (
+                "flashinfer" if is_flashinfer_available() else "pytorch"
+            )
+
+        if self.attention_backend == "torch_native":
+            logger.warning(
+                "Cuda graph is disabled because of using torch native attention backend"
+            )
+            self.disable_cuda_graph = True
 
         # Others
         if self.enable_dp_attention:
@@ -191,14 +204,20 @@
             self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
             self.disable_overlap_schedule = True
-            logger.info(
+            logger.warning(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
                 f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
                 f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
                 "Data parallel size is adjusted to be the same as tensor parallel size. "
-                "Overlap schedule is disabled."
+                "Overlap scheduler is disabled."
             )
 
+        # GGUF
+        if (
+            self.load_format == "auto" or self.load_format == "gguf"
+        ) and check_gguf_file(self.model_path):
+            self.quantization = self.load_format = "gguf"
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and port args
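The GGUF block above means a .gguf checkpoint is picked up automatically: when load_format is "auto" or "gguf" and check_gguf_file() recognizes the file, both quantization and load_format are forced to "gguf". A hedged sketch of the effect (the file path is a placeholder, and constructing ServerArgs also runs the GPU-memory probing in __post_init__, so this assumes a GPU machine):

    from sglang.srt.server_args import ServerArgs

    # A local GGUF file; __post_init__ calls check_gguf_file() on model_path.
    args = ServerArgs(model_path="/models/llama-3.1-8b-q4_k_m.gguf")

    # Both fields collapse to "gguf" even though load_format defaulted to "auto".
    print(args.load_format, args.quantization)  # expected: gguf gguf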
@@ -238,7 +257,7 @@ class ServerArgs:
             "--load-format",
             type=str,
             default=ServerArgs.load_format,
-            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
+            choices=["auto", "pt", "safetensors", "npcache", "dummy", "gguf"],
             help="The format of the model weights to load. "
             '"auto" will try to load the weights in the safetensors format '
             "and fall back to the pytorch bin format if safetensors format "
@@ -248,7 +267,8 @@ class ServerArgs:
             '"npcache" will load the weights in pytorch format and store '
             "a numpy cache to speed up the loading. "
             '"dummy" will initialize the weights with random values, '
-            "which is mainly for profiling.",
+            "which is mainly for profiling."
+            '"gguf" will load the weights in the gguf format. ',
         )
         parser.add_argument(
             "--trust-remote-code",
@@ -288,6 +308,7 @@
                 "gptq_marlin",
                 "awq_marlin",
                 "bitsandbytes",
+                "gguf",
             ],
             help="The quantization method.",
         )
@@ -321,6 +342,14 @@
             action="store_true",
             help="Whether to use a CausalLM as an embedding model.",
         )
+        parser.add_argument(
+            "--revision",
+            type=str,
+            default=None,
+            help="The specific model version to use. It can be a branch "
+            "name, a tag name, or a commit id. If unspecified, will use "
+            "the default version.",
+        )
 
         # Memory and scheduling
         parser.add_argument(
@@ -572,7 +601,7 @@
         parser.add_argument(
             "--attention-backend",
             type=str,
-            choices=["flashinfer", "triton"],
+            choices=["flashinfer", "triton", "torch_native"],
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
         )
@@ -613,9 +642,9 @@
             help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
         )
         parser.add_argument(
-            "--disable-disk-cache",
+            "--disable-outlines-disk-cache",
             action="store_true",
-            help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
+            help="Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency.",
         )
         parser.add_argument(
             "--disable-custom-all-reduce",
@@ -716,6 +745,11 @@ class ServerArgs:
             action=DeprecatedAction,
             help="'--disable-flashinfer-sampling' is deprecated. Please use '--sampling-backend pytroch' instead.",
         )
+        parser.add_argument(
+            "--disable-disk-cache",
+            action=DeprecatedAction,
+            help="'--disable-disk-cache' is deprecated. Please use '--disable-outlines-disk-cache' instead.",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):