sglang 0.3.6.post3__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/bench_one_batch.py +4 -0
- sglang/bench_serving.py +13 -0
- sglang/check_env.py +1 -1
- sglang/srt/_custom_ops.py +118 -0
- sglang/srt/configs/device_config.py +17 -0
- sglang/srt/configs/load_config.py +84 -0
- sglang/srt/configs/model_config.py +161 -4
- sglang/srt/configs/qwen2vl.py +5 -8
- sglang/srt/constrained/outlines_backend.py +11 -1
- sglang/srt/constrained/outlines_jump_forward.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +5 -5
- sglang/srt/distributed/__init__.py +3 -0
- sglang/srt/distributed/communication_op.py +34 -0
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang/srt/distributed/parallel_state.py +1275 -0
- sglang/srt/distributed/utils.py +223 -0
- sglang/srt/hf_transformers_utils.py +37 -1
- sglang/srt/layers/attention/__init__.py +5 -2
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
- sglang/srt/layers/attention/flashinfer_backend.py +33 -20
- sglang/srt/layers/attention/torch_native_backend.py +299 -0
- sglang/srt/layers/attention/triton_backend.py +22 -8
- sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
- sglang/srt/layers/ep_moe/__init__.py +0 -0
- sglang/srt/layers/ep_moe/kernels.py +349 -0
- sglang/srt/layers/ep_moe/layer.py +661 -0
- sglang/srt/layers/fused_moe_patch.py +20 -11
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +17 -3
- sglang/srt/layers/quantization/__init__.py +36 -2
- sglang/srt/layers/quantization/fp8.py +559 -0
- sglang/srt/layers/quantization/fp8_utils.py +27 -0
- sglang/srt/layers/radix_attention.py +4 -2
- sglang/srt/layers/sampler.py +2 -0
- sglang/srt/layers/torchao_utils.py +23 -45
- sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/managers/io_struct.py +48 -2
- sglang/srt/managers/schedule_batch.py +19 -14
- sglang/srt/managers/schedule_policy.py +7 -4
- sglang/srt/managers/scheduler.py +145 -85
- sglang/srt/managers/tokenizer_manager.py +166 -68
- sglang/srt/managers/tp_worker.py +36 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
- sglang/srt/mem_cache/memory_pool.py +5 -1
- sglang/srt/model_executor/cuda_graph_runner.py +30 -7
- sglang/srt/model_executor/forward_batch_info.py +9 -4
- sglang/srt/model_executor/model_runner.py +146 -153
- sglang/srt/model_loader/__init__.py +34 -0
- sglang/srt/model_loader/loader.py +1139 -0
- sglang/srt/model_loader/utils.py +41 -0
- sglang/srt/model_loader/weight_utils.py +640 -0
- sglang/srt/model_parallel.py +1 -5
- sglang/srt/models/baichuan.py +9 -10
- sglang/srt/models/chatglm.py +6 -15
- sglang/srt/models/commandr.py +4 -5
- sglang/srt/models/dbrx.py +2 -3
- sglang/srt/models/deepseek.py +4 -11
- sglang/srt/models/deepseek_v2.py +90 -18
- sglang/srt/models/exaone.py +2 -3
- sglang/srt/models/gemma.py +2 -6
- sglang/srt/models/gemma2.py +3 -14
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/gpt2.py +5 -12
- sglang/srt/models/gpt_bigcode.py +6 -22
- sglang/srt/models/grok.py +3 -8
- sglang/srt/models/internlm2.py +2 -3
- sglang/srt/models/internlm2_reward.py +0 -1
- sglang/srt/models/llama.py +96 -31
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_embedding.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +1 -4
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +4 -7
- sglang/srt/models/minicpm3.py +6 -19
- sglang/srt/models/mixtral.py +24 -14
- sglang/srt/models/mixtral_quant.py +2 -3
- sglang/srt/models/mllama.py +3 -7
- sglang/srt/models/olmo.py +2 -8
- sglang/srt/models/olmo2.py +0 -1
- sglang/srt/models/olmoe.py +3 -5
- sglang/srt/models/phi3_small.py +8 -13
- sglang/srt/models/qwen.py +2 -3
- sglang/srt/models/qwen2.py +10 -9
- sglang/srt/models/qwen2_moe.py +4 -16
- sglang/srt/models/qwen2_vl.py +2 -6
- sglang/srt/models/registry.py +99 -0
- sglang/srt/models/stablelm.py +2 -3
- sglang/srt/models/torch_native_llama.py +6 -17
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +4 -11
- sglang/srt/models/yivl.py +2 -3
- sglang/srt/openai_api/adapter.py +9 -5
- sglang/srt/openai_api/protocol.py +1 -0
- sglang/srt/sampling/sampling_batch_info.py +9 -8
- sglang/srt/server.py +270 -173
- sglang/srt/server_args.py +102 -29
- sglang/srt/utils.py +295 -28
- sglang/test/test_utils.py +7 -0
- sglang/version.py +1 -1
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
- sglang-0.4.0.post1.dist-info/RECORD +189 -0
- sglang-0.3.6.post3.dist-info/RECORD +0 -162
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
sglang/srt/server.py
CHANGED
@@ -52,8 +52,11 @@ from sglang.srt.managers.io_struct import (
     CloseSessionReqInput,
     EmbeddingReqInput,
     GenerateReqInput,
+    GetWeightsByNameReqInput,
+    InitWeightsUpdateGroupReqInput,
     OpenSessionReqInput,
-    UpdateWeightReqInput,
+    UpdateWeightFromDiskReqInput,
+    UpdateWeightsFromDistributedReqInput,
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -150,13 +153,11 @@ async def get_model_info():

 @app.get("/get_server_info")
 async def get_server_info():
-    try:
-        return await _get_server_info()
-
-    except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+    return {
+        **dataclasses.asdict(tokenizer_manager.server_args),  # server args
+        **scheduler_info,
+        "version": __version__,
+    }


 @app.post("/flush_cache")
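With this change, GET /get_server_info returns a single JSON object that merges the launch-time ServerArgs, the scheduler's runtime info, and the package version, instead of going through the _get_server_info() helper that is removed later in this diff. A minimal client sketch (host, port, and the fields read are assumptions):

import requests

# Assumes a server launched locally on port 30000.
info = requests.get("http://127.0.0.1:30000/get_server_info").json()

print(info["version"])     # package version, e.g. "0.4.0.post1"
print(info["model_path"])  # one of the flattened ServerArgs fields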
@@ -192,11 +193,11 @@ async def stop_profile_async():
     )


-@app.post("/update_weights")
+@app.post("/update_weights_from_disk")
 @time_func_latency
-async def update_weights(obj: UpdateWeightReqInput, request: Request):
-    """Update the weights inplace without re-launching the server."""
-    success, message = await tokenizer_manager.update_weights(obj, request)
+async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
+    """Update the weights from disk inplace without re-launching the server."""
+    success, message = await tokenizer_manager.update_weights_from_disk(obj, request)
     content = {"success": success, "message": message}
     if success:
         return ORJSONResponse(
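The in-place weight reload endpoint is renamed from /update_weights to /update_weights_from_disk, which separates it from the new distributed update path added below. A minimal client sketch (host, port, and the checkpoint path are placeholders; the model_path request field mirrors the old endpoint and is an assumption here):

import requests

resp = requests.post(
    "http://127.0.0.1:30000/update_weights_from_disk",
    json={"model_path": "/checkpoints/my-finetuned-model"},  # placeholder path
)
print(resp.status_code, resp.json())  # {"success": ..., "message": ...}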
@@ -210,6 +211,52 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
         )


+@app.post("/init_weights_update_group")
+async def init_weights_update_group(
+    obj: InitWeightsUpdateGroupReqInput, request: Request
+):
+    """Initialize the parameter update group."""
+    success, message = await tokenizer_manager.init_weights_update_group(obj, request)
+    content = {"success": success, "message": message}
+    if success:
+        return ORJSONResponse(content, status_code=200)
+    else:
+        return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
+
+
+@app.post("/update_weights_from_distributed")
+async def update_weights_from_distributed(
+    obj: UpdateWeightsFromDistributedReqInput, request: Request
+):
+    """Update model parameter from distributed online."""
+    success, message = await tokenizer_manager.update_weights_from_distributed(
+        obj, request
+    )
+    content = {"success": success, "message": message}
+    if success:
+        return ORJSONResponse(content, status_code=200)
+    else:
+        return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
+
+
+@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
+async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
+    """Get model parameter by name."""
+    try:
+        ret = await tokenizer_manager.get_weights_by_name(obj, request)
+        if ret is None:
+            return ORJSONResponse(
+                {"error": {"message": "Get parameter by name failed"}},
+                status_code=HTTPStatus.BAD_REQUEST,
+            )
+        else:
+            return ORJSONResponse(ret, status_code=200)
+    except Exception as e:
+        return ORJSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
+
+
 @app.api_route("/open_session", methods=["GET", "POST"])
 async def open_session(obj: OpenSessionReqInput, request: Request):
     """Open a session, and return its unique session id."""
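Together these three endpoints support online weight synchronization, for example pushing fresh weights from a training process into a running server over NCCL. A rough client-side sketch (addresses, ranks, and the tensor name/dtype/shape are illustrative; the JSON field names follow the request dataclasses used by the Engine methods later in this diff):

import requests

BASE = "http://127.0.0.1:30000"  # assumed server address

# 1) Create the weight-update process group that the trainer will also join.
requests.post(f"{BASE}/init_weights_update_group", json={
    "master_address": "127.0.0.1",
    "master_port": 29500,
    "rank_offset": 1,
    "world_size": 2,
    "group_name": "weight_sync",
    "backend": "nccl",
})

# 2) For each parameter, the trainer broadcasts the tensor over that group while
#    the server is told which parameter to receive and load in place.
requests.post(f"{BASE}/update_weights_from_distributed", json={
    "name": "model.embed_tokens.weight",  # illustrative parameter name
    "dtype": "bfloat16",
    "shape": [128256, 4096],
})

# 3) Read back a truncated slice of the parameter to verify the update.
resp = requests.post(
    f"{BASE}/get_weights_by_name",
    json={"name": "model.embed_tokens.weight", "truncate_size": 5},
)
print(resp.json())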
@@ -282,7 +329,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
     )


-@app.api_route("/
+@app.api_route("/classify", methods=["POST", "PUT"])
 @time_func_latency
 async def classify_request(obj: EmbeddingReqInput, request: Request):
     """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
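The reward-model route is now registered at /classify. It accepts the same request body as the embedding endpoint, so a call looks roughly like this (host, port, and the exact response fields are assumptions):

import requests

resp = requests.post(
    "http://127.0.0.1:30000/classify",
    json={"text": "Question: 2 + 2 = ?\nAnswer: 4"},  # sequence to be scored
)
print(resp.json())  # embedding-style payload; exact fields depend on the model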
@@ -415,8 +462,8 @@ def launch_engine(
         if server_args.node_rank >= 1:
             # For other nodes, they do not need to run tokenizer or detokenizer,
             # so they can just wait here.
-
-
+            for proc in scheduler_procs:
+                proc.join()
     else:
         # Launch the data parallel controller
         reader, writer = mp.Pipe(duplex=False)
@@ -517,14 +564,6 @@ def launch_server(
         t.join()


-async def _get_server_info():
-    return {
-        **dataclasses.asdict(tokenizer_manager.server_args),  # server args
-        **scheduler_info,
-        "version": __version__,
-    }
-
-
 def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -637,11 +676,218 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         delete_directory(server_args.model_path)


+STREAM_END_SYMBOL = b"data: [DONE]"
+STREAM_CHUNK_START_SYMBOL = b"data:"
+
+
+class Engine:
+    """
+    SRT Engine without an HTTP server layer.
+
+    This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where
+    launching the HTTP server adds unnecessary complexity or overhead,
+    """
+
+    def __init__(self, log_level: str = "error", *args, **kwargs):
+        """See the arguments in server_args.py::ServerArgs"""
+
+        # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
+        atexit.register(self.shutdown)
+
+        server_args = ServerArgs(*args, log_level=log_level, **kwargs)
+        launch_engine(server_args=server_args)
+
+    def generate(
+        self,
+        # The input prompt. It can be a single prompt or a batch of prompts.
+        prompt: Optional[Union[List[str], str]] = None,
+        sampling_params: Optional[Union[List[Dict], Dict]] = None,
+        # The token ids for text; one can either specify text or input_ids.
+        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+        lora_path: Optional[List[Optional[str]]] = None,
+        stream: bool = False,
+    ):
+        obj = GenerateReqInput(
+            text=prompt,
+            input_ids=input_ids,
+            sampling_params=sampling_params,
+            return_logprob=return_logprob,
+            logprob_start_len=logprob_start_len,
+            top_logprobs_num=top_logprobs_num,
+            lora_path=lora_path,
+            stream=stream,
+        )
+
+        # get the current event loop
+        loop = asyncio.get_event_loop()
+        ret = loop.run_until_complete(generate_request(obj, None))
+
+        if stream is True:
+
+            def generator_wrapper():
+                offset = 0
+                loop = asyncio.get_event_loop()
+                generator = ret.body_iterator
+                while True:
+                    chunk = loop.run_until_complete(generator.__anext__())
+
+                    if chunk.startswith(STREAM_END_SYMBOL):
+                        break
+                    else:
+                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
+                        data["text"] = data["text"][offset:]
+                        offset += len(data["text"])
+                        yield data
+
+            # we cannot yield in the scope of generate() because python does not allow yield + return in the same function
+            # however, it allows to wrap the generator as a subfunction and return
+            return generator_wrapper()
+        else:
+            return ret
+
+    async def async_generate(
+        self,
+        # The input prompt. It can be a single prompt or a batch of prompts.
+        prompt: Optional[Union[List[str], str]] = None,
+        sampling_params: Optional[Dict] = None,
+        # The token ids for text; one can either specify text or input_ids.
+        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+        lora_path: Optional[List[Optional[str]]] = None,
+        stream: bool = False,
+    ):
+        obj = GenerateReqInput(
+            text=prompt,
+            input_ids=input_ids,
+            sampling_params=sampling_params,
+            return_logprob=return_logprob,
+            logprob_start_len=logprob_start_len,
+            top_logprobs_num=top_logprobs_num,
+            lora_path=lora_path,
+            stream=stream,
+        )
+
+        ret = await generate_request(obj, None)
+
+        if stream is True:
+            generator = ret.body_iterator
+
+            async def generator_wrapper():
+
+                offset = 0
+
+                while True:
+                    chunk = await generator.__anext__()
+
+                    if chunk.startswith(STREAM_END_SYMBOL):
+                        break
+                    else:
+                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
+                        data["text"] = data["text"][offset:]
+                        offset += len(data["text"])
+                        yield data
+
+            return generator_wrapper()
+        else:
+            return ret
+
+    def shutdown(self):
+        kill_process_tree(os.getpid(), include_parent=False)
+
+    def get_tokenizer(self):
+        global tokenizer_manager
+
+        if tokenizer_manager is None:
+            raise ReferenceError("Tokenizer Manager is not initialized.")
+        else:
+            return tokenizer_manager.tokenizer
+
+    def encode(
+        self,
+        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
+    ):
+        obj = EmbeddingReqInput(text=prompt)
+
+        # get the current event loop
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(encode_request(obj, None))
+
+    def start_profile(self):
+        tokenizer_manager.start_profile()
+
+    def stop_profile(self):
+        tokenizer_manager.stop_profile()
+
+    def get_server_info(self):
+        return {
+            **dataclasses.asdict(tokenizer_manager.server_args),  # server args
+            **scheduler_info,
+            "version": __version__,
+        }
+
+    def init_weights_update_group(
+        self,
+        master_address: str,
+        master_port: int,
+        rank_offset: int,
+        world_size: int,
+        group_name: str,
+        backend: str = "nccl",
+    ):
+        """Initialize parameter update group."""
+        obj = InitWeightsUpdateGroupReqInput(
+            master_address=master_address,
+            master_port=master_port,
+            rank_offset=rank_offset,
+            world_size=world_size,
+            group_name=group_name,
+            backend=backend,
+        )
+
+        async def _init_group():
+            return await tokenizer_manager.init_weights_update_group(obj, None)
+
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(_init_group())
+
+    def update_weights_from_distributed(self, name, dtype, shape):
+        """Update weights from distributed source."""
+        obj = UpdateWeightsFromDistributedReqInput(
+            name=name,
+            dtype=dtype,
+            shape=shape,
+        )
+
+        async def _update_weights():
+            return await tokenizer_manager.update_weights_from_distributed(obj, None)
+
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(_update_weights())
+
+    def get_weights_by_name(self, name, truncate_size=100):
+        """Get weights by parameter name."""
+        obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
+
+        async def _get_weights():
+            return await tokenizer_manager.get_weights_by_name(obj, None)
+
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(_get_weights())
+
+
 class Runtime:
     """
-    A wrapper for the server.
+    A wrapper for the HTTP server.
     This is used for launching the server in a python program without
     using the commond line interface.
+
+    It is mainly used for the frontend language.
+    You should use the Engine class if you want to do normal offline processing.
     """

     def __init__(
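The new Engine drives the same request objects as the HTTP routes, only in-process, so offline usage looks roughly like the sketch below (the model path and sampling parameters are illustrative, and the non-streaming return is assumed to match the dict produced by the /generate route):

from sglang.srt.server import Engine

llm = Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # kwargs are forwarded to ServerArgs

out = llm.generate(
    prompt="The capital of France is",
    sampling_params={"temperature": 0.0, "max_new_tokens": 16},
)
print(out["text"])

llm.shutdown()  # optional; Engine also registers shutdown via atexit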
@@ -789,152 +1035,3 @@ class Runtime:

     def __del__(self):
         self.shutdown()
-
-
-STREAM_END_SYMBOL = b"data: [DONE]"
-STREAM_CHUNK_START_SYMBOL = b"data:"
-
-
-class Engine:
-    """
-    SRT Engine without an HTTP server layer.
-
-    This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where
-    launching the HTTP server adds unnecessary complexity or overhead,
-    """
-
-    def __init__(self, log_level: str = "error", *args, **kwargs):
-        # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
-        atexit.register(self.shutdown)
-
-        server_args = ServerArgs(*args, log_level=log_level, **kwargs)
-        launch_engine(server_args=server_args)
-
-    def generate(
-        self,
-        # The input prompt. It can be a single prompt or a batch of prompts.
-        prompt: Optional[Union[List[str], str]] = None,
-        sampling_params: Optional[Union[List[Dict], Dict]] = None,
-        # The token ids for text; one can either specify text or input_ids.
-        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
-        return_logprob: Optional[Union[List[bool], bool]] = False,
-        logprob_start_len: Optional[Union[List[int], int]] = None,
-        top_logprobs_num: Optional[Union[List[int], int]] = None,
-        lora_path: Optional[List[Optional[str]]] = None,
-        stream: bool = False,
-    ):
-        obj = GenerateReqInput(
-            text=prompt,
-            input_ids=input_ids,
-            sampling_params=sampling_params,
-            return_logprob=return_logprob,
-            logprob_start_len=logprob_start_len,
-            top_logprobs_num=top_logprobs_num,
-            lora_path=lora_path,
-            stream=stream,
-        )
-
-        # get the current event loop
-        loop = asyncio.get_event_loop()
-        ret = loop.run_until_complete(generate_request(obj, None))
-
-        if stream is True:
-
-            def generator_wrapper():
-                offset = 0
-                loop = asyncio.get_event_loop()
-                generator = ret.body_iterator
-                while True:
-                    chunk = loop.run_until_complete(generator.__anext__())
-
-                    if chunk.startswith(STREAM_END_SYMBOL):
-                        break
-                    else:
-                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
-                        data["text"] = data["text"][offset:]
-                        offset += len(data["text"])
-                        yield data
-
-            # we cannot yield in the scope of generate() because python does not allow yield + return in the same function
-            # however, it allows to wrap the generator as a subfunction and return
-            return generator_wrapper()
-        else:
-            return ret
-
-    async def async_generate(
-        self,
-        # The input prompt. It can be a single prompt or a batch of prompts.
-        prompt: Optional[Union[List[str], str]] = None,
-        sampling_params: Optional[Dict] = None,
-        # The token ids for text; one can either specify text or input_ids.
-        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
-        return_logprob: Optional[Union[List[bool], bool]] = False,
-        logprob_start_len: Optional[Union[List[int], int]] = None,
-        top_logprobs_num: Optional[Union[List[int], int]] = None,
-        lora_path: Optional[List[Optional[str]]] = None,
-        stream: bool = False,
-    ):
-        obj = GenerateReqInput(
-            text=prompt,
-            input_ids=input_ids,
-            sampling_params=sampling_params,
-            return_logprob=return_logprob,
-            logprob_start_len=logprob_start_len,
-            top_logprobs_num=top_logprobs_num,
-            lora_path=lora_path,
-            stream=stream,
-        )
-
-        ret = await generate_request(obj, None)
-
-        if stream is True:
-            generator = ret.body_iterator
-
-            async def generator_wrapper():
-
-                offset = 0
-
-                while True:
-                    chunk = await generator.__anext__()
-
-                    if chunk.startswith(STREAM_END_SYMBOL):
-                        break
-                    else:
-                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
-                        data["text"] = data["text"][offset:]
-                        offset += len(data["text"])
-                        yield data
-
-            return generator_wrapper()
-        else:
-            return ret
-
-    def shutdown(self):
-        kill_process_tree(os.getpid(), include_parent=False)
-
-    def get_tokenizer(self):
-        global tokenizer_manager
-
-        if tokenizer_manager is None:
-            raise ReferenceError("Tokenizer Manager is not initialized.")
-        else:
-            return tokenizer_manager.tokenizer
-
-    def encode(
-        self,
-        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
-    ):
-        obj = EmbeddingReqInput(text=prompt)
-
-        # get the current event loop
-        loop = asyncio.get_event_loop()
-        return loop.run_until_complete(encode_request(obj, None))
-
-    def start_profile(self):
-        tokenizer_manager.start_profile()
-
-    def stop_profile(self):
-        tokenizer_manager.stop_profile()
-
-    async def get_server_info(self):
-        return await _get_server_info()
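The block removed here is the original Engine definition that previously sat below Runtime; apart from the additions shown earlier, it now lives above Runtime unchanged. For reference, streaming through generate(stream=True) returns a plain generator whose chunks carry only the newly produced text, because generator_wrapper slices off the prefix it has already yielded. A usage sketch (model path and prompt are illustrative):

from sglang.srt.server import Engine

llm = Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # illustrative model path

for chunk in llm.generate(
    prompt="Write a haiku about GPUs.",
    sampling_params={"temperature": 0.8, "max_new_tokens": 64},
    stream=True,
):
    print(chunk["text"], end="", flush=True)
print()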