sglang 0.3.6.post3__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +4 -0
- sglang/bench_serving.py +13 -0
- sglang/check_env.py +1 -1
- sglang/srt/_custom_ops.py +118 -0
- sglang/srt/configs/device_config.py +17 -0
- sglang/srt/configs/load_config.py +84 -0
- sglang/srt/configs/model_config.py +161 -4
- sglang/srt/configs/qwen2vl.py +5 -8
- sglang/srt/constrained/outlines_backend.py +6 -1
- sglang/srt/constrained/outlines_jump_forward.py +8 -1
- sglang/srt/distributed/__init__.py +3 -0
- sglang/srt/distributed/communication_op.py +34 -0
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang/srt/distributed/parallel_state.py +1275 -0
- sglang/srt/distributed/utils.py +223 -0
- sglang/srt/hf_transformers_utils.py +37 -1
- sglang/srt/layers/attention/flashinfer_backend.py +13 -15
- sglang/srt/layers/attention/torch_native_backend.py +285 -0
- sglang/srt/layers/fused_moe_patch.py +20 -11
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +17 -3
- sglang/srt/layers/quantization/__init__.py +34 -0
- sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/managers/io_struct.py +48 -2
- sglang/srt/managers/schedule_batch.py +18 -14
- sglang/srt/managers/schedule_policy.py +7 -4
- sglang/srt/managers/scheduler.py +76 -20
- sglang/srt/managers/tokenizer_manager.py +166 -68
- sglang/srt/managers/tp_worker.py +36 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +21 -3
- sglang/srt/model_executor/cuda_graph_runner.py +16 -7
- sglang/srt/model_executor/forward_batch_info.py +9 -4
- sglang/srt/model_executor/model_runner.py +136 -150
- sglang/srt/model_loader/__init__.py +34 -0
- sglang/srt/model_loader/loader.py +1139 -0
- sglang/srt/model_loader/utils.py +41 -0
- sglang/srt/model_loader/weight_utils.py +640 -0
- sglang/srt/models/baichuan.py +9 -10
- sglang/srt/models/chatglm.py +6 -15
- sglang/srt/models/commandr.py +2 -3
- sglang/srt/models/dbrx.py +2 -3
- sglang/srt/models/deepseek.py +4 -11
- sglang/srt/models/deepseek_v2.py +3 -11
- sglang/srt/models/exaone.py +2 -3
- sglang/srt/models/gemma.py +2 -6
- sglang/srt/models/gemma2.py +3 -14
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/gpt2.py +5 -12
- sglang/srt/models/gpt_bigcode.py +6 -22
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/internlm2.py +2 -3
- sglang/srt/models/internlm2_reward.py +0 -1
- sglang/srt/models/llama.py +97 -27
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_embedding.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +1 -4
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +4 -7
- sglang/srt/models/minicpm3.py +6 -19
- sglang/srt/models/mixtral.py +12 -5
- sglang/srt/models/mixtral_quant.py +2 -3
- sglang/srt/models/mllama.py +3 -7
- sglang/srt/models/olmo.py +2 -8
- sglang/srt/models/olmo2.py +0 -1
- sglang/srt/models/olmoe.py +3 -5
- sglang/srt/models/phi3_small.py +8 -8
- sglang/srt/models/qwen.py +2 -3
- sglang/srt/models/qwen2.py +10 -9
- sglang/srt/models/qwen2_moe.py +4 -11
- sglang/srt/models/qwen2_vl.py +2 -6
- sglang/srt/models/registry.py +99 -0
- sglang/srt/models/stablelm.py +2 -3
- sglang/srt/models/torch_native_llama.py +6 -12
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +4 -11
- sglang/srt/models/yivl.py +2 -3
- sglang/srt/openai_api/adapter.py +9 -5
- sglang/srt/openai_api/protocol.py +1 -0
- sglang/srt/server.py +267 -170
- sglang/srt/server_args.py +65 -31
- sglang/srt/utils.py +245 -28
- sglang/test/test_utils.py +7 -0
- sglang/version.py +1 -1
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/METADATA +1 -1
- sglang-0.4.0.dist-info/RECORD +184 -0
- sglang-0.3.6.post3.dist-info/RECORD +0 -162
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
sglang/srt/server.py
CHANGED
@@ -52,8 +52,11 @@ from sglang.srt.managers.io_struct import (
     CloseSessionReqInput,
     EmbeddingReqInput,
     GenerateReqInput,
+    GetWeightsByNameReqInput,
+    InitWeightsUpdateGroupReqInput,
     OpenSessionReqInput,
-    UpdateWeightReqInput,
+    UpdateWeightFromDiskReqInput,
+    UpdateWeightsFromDistributedReqInput,
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -150,13 +153,11 @@ async def get_model_info():


 @app.get("/get_server_info")
 async def get_server_info():
-    try:
-        return await _get_server_info()
-
-    except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+    return {
+        **dataclasses.asdict(tokenizer_manager.server_args),  # server args
+        **scheduler_info,
+        "version": __version__,
+    }


 @app.post("/flush_cache")
@@ -192,11 +193,11 @@ async def stop_profile_async():
     )


-@app.post("/update_weights")
+@app.post("/update_weights_from_disk")
 @time_func_latency
-async def update_weights(obj: UpdateWeightReqInput, request: Request):
-    """Update the weights inplace without re-launching the server."""
-    success, message = await tokenizer_manager.update_weights(obj, request)
+async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
+    """Update the weights from disk inplace without re-launching the server."""
+    success, message = await tokenizer_manager.update_weights_from_disk(obj, request)
     content = {"success": success, "message": message}
     if success:
         return ORJSONResponse(
@@ -210,6 +211,52 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
         )


+@app.post("/init_weights_update_group")
+async def init_weights_update_group(
+    obj: InitWeightsUpdateGroupReqInput, request: Request
+):
+    """Initialize the parameter update group."""
+    success, message = await tokenizer_manager.init_weights_update_group(obj, request)
+    content = {"success": success, "message": message}
+    if success:
+        return ORJSONResponse(content, status_code=200)
+    else:
+        return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
+
+
+@app.post("/update_weights_from_distributed")
+async def update_weights_from_distributed(
+    obj: UpdateWeightsFromDistributedReqInput, request: Request
+):
+    """Update model parameter from distributed online."""
+    success, message = await tokenizer_manager.update_weights_from_distributed(
+        obj, request
+    )
+    content = {"success": success, "message": message}
+    if success:
+        return ORJSONResponse(content, status_code=200)
+    else:
+        return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
+
+
+@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
+async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
+    """Get model parameter by name."""
+    try:
+        ret = await tokenizer_manager.get_weights_by_name(obj, request)
+        if ret is None:
+            return ORJSONResponse(
+                {"error": {"message": "Get parameter by name failed"}},
+                status_code=HTTPStatus.BAD_REQUEST,
+            )
+        else:
+            return ORJSONResponse(ret, status_code=200)
+    except Exception as e:
+        return ORJSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
+
+
 @app.api_route("/open_session", methods=["GET", "POST"])
 async def open_session(obj: OpenSessionReqInput, request: Request):
     """Open a session, and return its unique session id."""
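For reference, below is a hypothetical client-side sketch (not part of this diff) that exercises one of the new routes. The base URL and port are assumptions about a locally running server; the payload field names ("name", "truncate_size") mirror GetWeightsByNameReqInput as used by the Engine.get_weights_by_name helper later in this diff.

import requests  # hypothetical client sketch, not part of the package

base_url = "http://127.0.0.1:30000"  # assumed server address and port

# Query a (truncated) view of a named parameter through /get_weights_by_name.
# The parameter name is a placeholder; any key in the model's state dict works.
resp = requests.post(
    f"{base_url}/get_weights_by_name",
    json={"name": "model.embed_tokens.weight", "truncate_size": 16},
)
print(resp.status_code, resp.json())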
@@ -517,14 +564,6 @@ def launch_server(
         t.join()


-async def _get_server_info():
-    return {
-        **dataclasses.asdict(tokenizer_manager.server_args),  # server args
-        **scheduler_info,
-        "version": __version__,
-    }
-
-
 def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -637,11 +676,218 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         delete_directory(server_args.model_path)


+STREAM_END_SYMBOL = b"data: [DONE]"
+STREAM_CHUNK_START_SYMBOL = b"data:"
+
+
+class Engine:
+    """
+    SRT Engine without an HTTP server layer.
+
+    This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where
+    launching the HTTP server adds unnecessary complexity or overhead,
+    """
+
+    def __init__(self, log_level: str = "error", *args, **kwargs):
+        """See the arguments in server_args.py::ServerArgs"""
+
+        # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
+        atexit.register(self.shutdown)
+
+        server_args = ServerArgs(*args, log_level=log_level, **kwargs)
+        launch_engine(server_args=server_args)
+
+    def generate(
+        self,
+        # The input prompt. It can be a single prompt or a batch of prompts.
+        prompt: Optional[Union[List[str], str]] = None,
+        sampling_params: Optional[Union[List[Dict], Dict]] = None,
+        # The token ids for text; one can either specify text or input_ids.
+        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+        lora_path: Optional[List[Optional[str]]] = None,
+        stream: bool = False,
+    ):
+        obj = GenerateReqInput(
+            text=prompt,
+            input_ids=input_ids,
+            sampling_params=sampling_params,
+            return_logprob=return_logprob,
+            logprob_start_len=logprob_start_len,
+            top_logprobs_num=top_logprobs_num,
+            lora_path=lora_path,
+            stream=stream,
+        )
+
+        # get the current event loop
+        loop = asyncio.get_event_loop()
+        ret = loop.run_until_complete(generate_request(obj, None))
+
+        if stream is True:
+
+            def generator_wrapper():
+                offset = 0
+                loop = asyncio.get_event_loop()
+                generator = ret.body_iterator
+                while True:
+                    chunk = loop.run_until_complete(generator.__anext__())
+
+                    if chunk.startswith(STREAM_END_SYMBOL):
+                        break
+                    else:
+                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
+                        data["text"] = data["text"][offset:]
+                        offset += len(data["text"])
+                        yield data
+
+            # we cannot yield in the scope of generate() because python does not allow yield + return in the same function
+            # however, it allows to wrap the generator as a subfunction and return
+            return generator_wrapper()
+        else:
+            return ret
+
+    async def async_generate(
+        self,
+        # The input prompt. It can be a single prompt or a batch of prompts.
+        prompt: Optional[Union[List[str], str]] = None,
+        sampling_params: Optional[Dict] = None,
+        # The token ids for text; one can either specify text or input_ids.
+        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+        lora_path: Optional[List[Optional[str]]] = None,
+        stream: bool = False,
+    ):
+        obj = GenerateReqInput(
+            text=prompt,
+            input_ids=input_ids,
+            sampling_params=sampling_params,
+            return_logprob=return_logprob,
+            logprob_start_len=logprob_start_len,
+            top_logprobs_num=top_logprobs_num,
+            lora_path=lora_path,
+            stream=stream,
+        )
+
+        ret = await generate_request(obj, None)
+
+        if stream is True:
+            generator = ret.body_iterator
+
+            async def generator_wrapper():
+
+                offset = 0
+
+                while True:
+                    chunk = await generator.__anext__()
+
+                    if chunk.startswith(STREAM_END_SYMBOL):
+                        break
+                    else:
+                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
+                        data["text"] = data["text"][offset:]
+                        offset += len(data["text"])
+                        yield data
+
+            return generator_wrapper()
+        else:
+            return ret
+
+    def shutdown(self):
+        kill_process_tree(os.getpid(), include_parent=False)
+
+    def get_tokenizer(self):
+        global tokenizer_manager
+
+        if tokenizer_manager is None:
+            raise ReferenceError("Tokenizer Manager is not initialized.")
+        else:
+            return tokenizer_manager.tokenizer
+
+    def encode(
+        self,
+        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
+    ):
+        obj = EmbeddingReqInput(text=prompt)
+
+        # get the current event loop
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(encode_request(obj, None))
+
+    def start_profile(self):
+        tokenizer_manager.start_profile()
+
+    def stop_profile(self):
+        tokenizer_manager.stop_profile()
+
+    def get_server_info(self):
+        return {
+            **dataclasses.asdict(tokenizer_manager.server_args),  # server args
+            **scheduler_info,
+            "version": __version__,
+        }
+
+    def init_weights_update_group(
+        self,
+        master_address: str,
+        master_port: int,
+        rank_offset: int,
+        world_size: int,
+        group_name: str,
+        backend: str = "nccl",
+    ):
+        """Initialize parameter update group."""
+        obj = InitWeightsUpdateGroupReqInput(
+            master_address=master_address,
+            master_port=master_port,
+            rank_offset=rank_offset,
+            world_size=world_size,
+            group_name=group_name,
+            backend=backend,
+        )
+
+        async def _init_group():
+            return await tokenizer_manager.init_weights_update_group(obj, None)
+
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(_init_group())
+
+    def update_weights_from_distributed(self, name, dtype, shape):
+        """Update weights from distributed source."""
+        obj = UpdateWeightsFromDistributedReqInput(
+            name=name,
+            dtype=dtype,
+            shape=shape,
+        )
+
+        async def _update_weights():
+            return await tokenizer_manager.update_weights_from_distributed(obj, None)
+
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(_update_weights())
+
+    def get_weights_by_name(self, name, truncate_size=100):
+        """Get weights by parameter name."""
+        obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
+
+        async def _get_weights():
+            return await tokenizer_manager.get_weights_by_name(obj, None)
+
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(_get_weights())
+
+
 class Runtime:
     """
-    A wrapper for the server.
+    A wrapper for the HTTP server.
     This is used for launching the server in a python program without
     using the commond line interface.
+
+    It is mainly used for the frontend language.
+    You should use the Engine class if you want to do normal offline processing.
     """

     def __init__(
@@ -789,152 +1035,3 @@ class Runtime:

     def __del__(self):
         self.shutdown()
-
-
-STREAM_END_SYMBOL = b"data: [DONE]"
-STREAM_CHUNK_START_SYMBOL = b"data:"
-
-
-class Engine:
-    """
-    SRT Engine without an HTTP server layer.
-
-    This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where
-    launching the HTTP server adds unnecessary complexity or overhead,
-    """
-
-    def __init__(self, log_level: str = "error", *args, **kwargs):
-        # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
-        atexit.register(self.shutdown)
-
-        server_args = ServerArgs(*args, log_level=log_level, **kwargs)
-        launch_engine(server_args=server_args)
-
-    def generate(
-        self,
-        # The input prompt. It can be a single prompt or a batch of prompts.
-        prompt: Optional[Union[List[str], str]] = None,
-        sampling_params: Optional[Union[List[Dict], Dict]] = None,
-        # The token ids for text; one can either specify text or input_ids.
-        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
-        return_logprob: Optional[Union[List[bool], bool]] = False,
-        logprob_start_len: Optional[Union[List[int], int]] = None,
-        top_logprobs_num: Optional[Union[List[int], int]] = None,
-        lora_path: Optional[List[Optional[str]]] = None,
-        stream: bool = False,
-    ):
-        obj = GenerateReqInput(
-            text=prompt,
-            input_ids=input_ids,
-            sampling_params=sampling_params,
-            return_logprob=return_logprob,
-            logprob_start_len=logprob_start_len,
-            top_logprobs_num=top_logprobs_num,
-            lora_path=lora_path,
-            stream=stream,
-        )
-
-        # get the current event loop
-        loop = asyncio.get_event_loop()
-        ret = loop.run_until_complete(generate_request(obj, None))
-
-        if stream is True:
-
-            def generator_wrapper():
-                offset = 0
-                loop = asyncio.get_event_loop()
-                generator = ret.body_iterator
-                while True:
-                    chunk = loop.run_until_complete(generator.__anext__())
-
-                    if chunk.startswith(STREAM_END_SYMBOL):
-                        break
-                    else:
-                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
-                        data["text"] = data["text"][offset:]
-                        offset += len(data["text"])
-                        yield data
-
-            # we cannot yield in the scope of generate() because python does not allow yield + return in the same function
-            # however, it allows to wrap the generator as a subfunction and return
-            return generator_wrapper()
-        else:
-            return ret
-
-    async def async_generate(
-        self,
-        # The input prompt. It can be a single prompt or a batch of prompts.
-        prompt: Optional[Union[List[str], str]] = None,
-        sampling_params: Optional[Dict] = None,
-        # The token ids for text; one can either specify text or input_ids.
-        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
-        return_logprob: Optional[Union[List[bool], bool]] = False,
-        logprob_start_len: Optional[Union[List[int], int]] = None,
-        top_logprobs_num: Optional[Union[List[int], int]] = None,
-        lora_path: Optional[List[Optional[str]]] = None,
-        stream: bool = False,
-    ):
-        obj = GenerateReqInput(
-            text=prompt,
-            input_ids=input_ids,
-            sampling_params=sampling_params,
-            return_logprob=return_logprob,
-            logprob_start_len=logprob_start_len,
-            top_logprobs_num=top_logprobs_num,
-            lora_path=lora_path,
-            stream=stream,
-        )
-
-        ret = await generate_request(obj, None)
-
-        if stream is True:
-            generator = ret.body_iterator
-
-            async def generator_wrapper():
-
-                offset = 0
-
-                while True:
-                    chunk = await generator.__anext__()
-
-                    if chunk.startswith(STREAM_END_SYMBOL):
-                        break
-                    else:
-                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
-                        data["text"] = data["text"][offset:]
-                        offset += len(data["text"])
-                        yield data
-
-            return generator_wrapper()
-        else:
-            return ret
-
-    def shutdown(self):
-        kill_process_tree(os.getpid(), include_parent=False)
-
-    def get_tokenizer(self):
-        global tokenizer_manager
-
-        if tokenizer_manager is None:
-            raise ReferenceError("Tokenizer Manager is not initialized.")
-        else:
-            return tokenizer_manager.tokenizer
-
-    def encode(
-        self,
-        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
-    ):
-        obj = EmbeddingReqInput(text=prompt)
-
-        # get the current event loop
-        loop = asyncio.get_event_loop()
-        return loop.run_until_complete(encode_request(obj, None))
-
-    def start_profile(self):
-        tokenizer_manager.start_profile()
-
-    def stop_profile(self):
-        tokenizer_manager.stop_profile()
-
-    async def get_server_info(self):
-        return await _get_server_info()
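Taken together, the server.py changes move and extend the old post-Runtime Engine definition into the expanded Engine shown above. A hypothetical offline-usage sketch follows (not part of the diff); the model path and sampling-parameter values are placeholders, and the accepted keyword arguments are those of server_args.py::ServerArgs.

from sglang.srt.server import Engine  # new offline entry point in this release

# Placeholder model path; any ServerArgs keyword argument can be passed through.
llm = Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")

# Non-streaming call: returns the completed generation result.
out = llm.generate(
    prompt="The capital of France is",
    sampling_params={"temperature": 0, "max_new_tokens": 16},
)
print(out)

# Streaming call: yields incremental {"text": ...} chunks, matching the
# generator_wrapper() logic in the diff above.
for chunk in llm.generate(prompt="Count to five:", stream=True):
    print(chunk["text"], end="", flush=True)

llm.shutdown()  # optional; Engine.__init__ also registers this via atexit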
sglang/srt/server_args.py
CHANGED
@@ -20,6 +20,7 @@ import random
 import tempfile
 from typing import List, Optional

+from sglang.srt.hf_transformers_utils import check_gguf_file
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
     get_nvgpu_memory_capacity,
@@ -49,6 +50,7 @@ class ServerArgs:
     served_model_name: Optional[str] = None
     chat_template: Optional[str] = None
     is_embedding: bool = False
+    revision: Optional[str] = None

     # Port
     host: str = "127.0.0.1"
@@ -58,7 +60,7 @@ class ServerArgs:
     mem_fraction_static: Optional[float] = None
     max_running_requests: Optional[int] = None
     max_total_tokens: Optional[int] = None
-    chunked_prefill_size: int = 8192
+    chunked_prefill_size: Optional[int] = None
     max_prefill_tokens: int = 16384
     schedule_policy: str = "lpm"
     schedule_conservativeness: float = 1.0
@@ -120,7 +122,7 @@ class ServerArgs:
     disable_jump_forward: bool = False
     disable_cuda_graph: bool = False
     disable_cuda_graph_padding: bool = False
-    disable_disk_cache: bool = False
+    disable_outlines_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
     disable_overlap_schedule: bool = False
@@ -128,7 +130,7 @@ class ServerArgs:
     enable_dp_attention: bool = False
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
-    cuda_graph_max_bs: int = 160
+    cuda_graph_max_bs: Optional[int] = None
     torchao_config: str = ""
     enable_nan_detection: bool = False
     enable_p2p_check: bool = False
@@ -144,19 +146,20 @@ class ServerArgs:
         if self.served_model_name is None:
             self.served_model_name = self.model_path

-        if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
-            # Disable chunked prefill
-            self.chunked_prefill_size = None
-
         if self.random_seed is None:
             self.random_seed = random.randint(0, 1 << 30)

-        # Mem fraction depends on the tensor parallelism size
+        if is_hip():
+            gpu_mem = get_amdgpu_memory_capacity()
+        else:
+            gpu_mem = get_nvgpu_memory_capacity()
+
+        # Set mem fraction static, which depends on the tensor parallelism size
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
                 self.mem_fraction_static = 0.79
             elif self.tp_size >= 8:
-                self.mem_fraction_static = 0.83
+                self.mem_fraction_static = 0.81
             elif self.tp_size >= 4:
                 self.mem_fraction_static = 0.85
             elif self.tp_size >= 2:
@@ -164,25 +167,35 @@ class ServerArgs:
             else:
                 self.mem_fraction_static = 0.88

-        # Adjust for GPUs with small memory capacities
-        if is_hip():
-            gpu_mem = get_amdgpu_memory_capacity()
-        else:
-            gpu_mem = get_nvgpu_memory_capacity()
-        if gpu_mem < 25000:
-            self.chunked_prefill_size //= 4  # make it 2048
-            self.cuda_graph_max_bs = 4
-            logger.info("Automatically adjust --chunked-prefill-size for small GPUs.")
+        # Set chunked prefill size, which depends on the gpu memory capacity
+        if self.chunked_prefill_size is None:
+            if gpu_mem < 25_000:
+                self.chunked_prefill_size = 2048
+            else:
+                self.chunked_prefill_size = 8192

-        # Choose kernel backends
-        if not is_flashinfer_available():
-            self.attention_backend = "triton"
-            self.sampling_backend = "pytorch"
+        # Set cuda graph max batch size
+        if self.cuda_graph_max_bs is None:
+            if gpu_mem < 25_000:
+                self.cuda_graph_max_bs = 8
+            else:
+                self.cuda_graph_max_bs = 160

+        # Choose kernel backends
         if self.attention_backend is None:
-            self.attention_backend = "flashinfer"
+            self.attention_backend = (
+                "flashinfer" if is_flashinfer_available() else "triton"
+            )
         if self.sampling_backend is None:
-            self.sampling_backend = "flashinfer"
+            self.sampling_backend = (
+                "flashinfer" if is_flashinfer_available() else "pytorch"
+            )
+
+        if self.attention_backend == "torch_native":
+            logger.warning(
+                "Cuda graph is disabled because of using torch native attention backend"
+            )
+            self.disable_cuda_graph = True

         # Others
         if self.enable_dp_attention:
@@ -191,14 +204,20 @@ class ServerArgs:
             self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
             self.disable_overlap_schedule = True
-            logger.info(
+            logger.warning(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
                 f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
                 f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
                 "Data parallel size is adjusted to be the same as tensor parallel size. "
-                "Overlap schedule is disabled."
+                "Overlap scheduler is disabled."
             )

+        # GGUF
+        if (
+            self.load_format == "auto" or self.load_format == "gguf"
+        ) and check_gguf_file(self.model_path):
+            self.quantization = self.load_format = "gguf"
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and port args
@@ -238,7 +257,7 @@ class ServerArgs:
             "--load-format",
             type=str,
             default=ServerArgs.load_format,
-            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
+            choices=["auto", "pt", "safetensors", "npcache", "dummy", "gguf"],
             help="The format of the model weights to load. "
             '"auto" will try to load the weights in the safetensors format '
             "and fall back to the pytorch bin format if safetensors format "
@@ -248,7 +267,8 @@ class ServerArgs:
             '"npcache" will load the weights in pytorch format and store '
             "a numpy cache to speed up the loading. "
             '"dummy" will initialize the weights with random values, '
-            "which is mainly for profiling.",
+            "which is mainly for profiling."
+            '"gguf" will load the weights in the gguf format. ',
         )
         parser.add_argument(
             "--trust-remote-code",
@@ -288,6 +308,7 @@ class ServerArgs:
                 "gptq_marlin",
                 "awq_marlin",
                 "bitsandbytes",
+                "gguf",
             ],
             help="The quantization method.",
         )
@@ -321,6 +342,14 @@ class ServerArgs:
             action="store_true",
             help="Whether to use a CausalLM as an embedding model.",
         )
+        parser.add_argument(
+            "--revision",
+            type=str,
+            default=None,
+            help="The specific model version to use. It can be a branch "
+            "name, a tag name, or a commit id. If unspecified, will use "
+            "the default version.",
+        )

         # Memory and scheduling
         parser.add_argument(
@@ -572,7 +601,7 @@ class ServerArgs:
         parser.add_argument(
             "--attention-backend",
             type=str,
-            choices=["flashinfer", "triton"],
+            choices=["flashinfer", "triton", "torch_native"],
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
         )
@@ -613,9 +642,9 @@ class ServerArgs:
             help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
         )
         parser.add_argument(
-            "--disable-disk-cache",
+            "--disable-outlines-disk-cache",
             action="store_true",
-            help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
+            help="Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency.",
         )
         parser.add_argument(
             "--disable-custom-all-reduce",
@@ -716,6 +745,11 @@ class ServerArgs:
             action=DeprecatedAction,
             help="'--disable-flashinfer-sampling' is deprecated. Please use '--sampling-backend pytroch' instead.",
         )
+        parser.add_argument(
+            "--disable-disk-cache",
+            action=DeprecatedAction,
+            help="'--disable-disk-cache' is deprecated. Please use '--disable-outlines-disk-cache' instead.",
+        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):