sglang 0.3.6.post3__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. sglang/__init__.py +1 -1
  2. sglang/bench_one_batch.py +4 -0
  3. sglang/bench_serving.py +13 -0
  4. sglang/check_env.py +1 -1
  5. sglang/srt/_custom_ops.py +118 -0
  6. sglang/srt/configs/device_config.py +17 -0
  7. sglang/srt/configs/load_config.py +84 -0
  8. sglang/srt/configs/model_config.py +161 -4
  9. sglang/srt/configs/qwen2vl.py +5 -8
  10. sglang/srt/constrained/outlines_backend.py +11 -1
  11. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  12. sglang/srt/constrained/xgrammar_backend.py +5 -5
  13. sglang/srt/distributed/__init__.py +3 -0
  14. sglang/srt/distributed/communication_op.py +34 -0
  15. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  16. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  19. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  20. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  21. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  23. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  24. sglang/srt/distributed/parallel_state.py +1275 -0
  25. sglang/srt/distributed/utils.py +223 -0
  26. sglang/srt/hf_transformers_utils.py +37 -1
  27. sglang/srt/layers/attention/__init__.py +5 -2
  28. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  29. sglang/srt/layers/attention/flashinfer_backend.py +33 -20
  30. sglang/srt/layers/attention/torch_native_backend.py +299 -0
  31. sglang/srt/layers/attention/triton_backend.py +22 -8
  32. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  33. sglang/srt/layers/ep_moe/__init__.py +0 -0
  34. sglang/srt/layers/ep_moe/kernels.py +349 -0
  35. sglang/srt/layers/ep_moe/layer.py +661 -0
  36. sglang/srt/layers/fused_moe_patch.py +20 -11
  37. sglang/srt/layers/linear.py +1 -0
  38. sglang/srt/layers/logits_processor.py +17 -3
  39. sglang/srt/layers/quantization/__init__.py +36 -2
  40. sglang/srt/layers/quantization/fp8.py +559 -0
  41. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  42. sglang/srt/layers/radix_attention.py +4 -2
  43. sglang/srt/layers/sampler.py +2 -0
  44. sglang/srt/layers/torchao_utils.py +23 -45
  45. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  46. sglang/srt/lora/lora.py +1 -1
  47. sglang/srt/managers/io_struct.py +48 -2
  48. sglang/srt/managers/schedule_batch.py +19 -14
  49. sglang/srt/managers/schedule_policy.py +7 -4
  50. sglang/srt/managers/scheduler.py +145 -85
  51. sglang/srt/managers/tokenizer_manager.py +166 -68
  52. sglang/srt/managers/tp_worker.py +36 -3
  53. sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
  54. sglang/srt/mem_cache/memory_pool.py +5 -1
  55. sglang/srt/model_executor/cuda_graph_runner.py +30 -7
  56. sglang/srt/model_executor/forward_batch_info.py +9 -4
  57. sglang/srt/model_executor/model_runner.py +146 -153
  58. sglang/srt/model_loader/__init__.py +34 -0
  59. sglang/srt/model_loader/loader.py +1139 -0
  60. sglang/srt/model_loader/utils.py +41 -0
  61. sglang/srt/model_loader/weight_utils.py +640 -0
  62. sglang/srt/model_parallel.py +1 -5
  63. sglang/srt/models/baichuan.py +9 -10
  64. sglang/srt/models/chatglm.py +6 -15
  65. sglang/srt/models/commandr.py +4 -5
  66. sglang/srt/models/dbrx.py +2 -3
  67. sglang/srt/models/deepseek.py +4 -11
  68. sglang/srt/models/deepseek_v2.py +90 -18
  69. sglang/srt/models/exaone.py +2 -3
  70. sglang/srt/models/gemma.py +2 -6
  71. sglang/srt/models/gemma2.py +3 -14
  72. sglang/srt/models/gemma2_reward.py +0 -1
  73. sglang/srt/models/gpt2.py +5 -12
  74. sglang/srt/models/gpt_bigcode.py +6 -22
  75. sglang/srt/models/grok.py +3 -8
  76. sglang/srt/models/internlm2.py +2 -3
  77. sglang/srt/models/internlm2_reward.py +0 -1
  78. sglang/srt/models/llama.py +96 -31
  79. sglang/srt/models/llama_classification.py +1 -2
  80. sglang/srt/models/llama_embedding.py +1 -2
  81. sglang/srt/models/llama_reward.py +2 -3
  82. sglang/srt/models/llava.py +1 -4
  83. sglang/srt/models/llavavid.py +1 -2
  84. sglang/srt/models/minicpm.py +4 -7
  85. sglang/srt/models/minicpm3.py +6 -19
  86. sglang/srt/models/mixtral.py +24 -14
  87. sglang/srt/models/mixtral_quant.py +2 -3
  88. sglang/srt/models/mllama.py +3 -7
  89. sglang/srt/models/olmo.py +2 -8
  90. sglang/srt/models/olmo2.py +0 -1
  91. sglang/srt/models/olmoe.py +3 -5
  92. sglang/srt/models/phi3_small.py +8 -13
  93. sglang/srt/models/qwen.py +2 -3
  94. sglang/srt/models/qwen2.py +10 -9
  95. sglang/srt/models/qwen2_moe.py +4 -16
  96. sglang/srt/models/qwen2_vl.py +2 -6
  97. sglang/srt/models/registry.py +99 -0
  98. sglang/srt/models/stablelm.py +2 -3
  99. sglang/srt/models/torch_native_llama.py +6 -17
  100. sglang/srt/models/xverse.py +2 -4
  101. sglang/srt/models/xverse_moe.py +4 -11
  102. sglang/srt/models/yivl.py +2 -3
  103. sglang/srt/openai_api/adapter.py +9 -5
  104. sglang/srt/openai_api/protocol.py +1 -0
  105. sglang/srt/sampling/sampling_batch_info.py +9 -8
  106. sglang/srt/server.py +270 -173
  107. sglang/srt/server_args.py +102 -29
  108. sglang/srt/utils.py +295 -28
  109. sglang/test/test_utils.py +7 -0
  110. sglang/version.py +1 -1
  111. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
  112. sglang-0.4.0.post1.dist-info/RECORD +189 -0
  113. sglang-0.3.6.post3.dist-info/RECORD +0 -162
  114. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
  115. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
  116. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
sglang/srt/server.py CHANGED
@@ -52,8 +52,11 @@ from sglang.srt.managers.io_struct import (
52
52
  CloseSessionReqInput,
53
53
  EmbeddingReqInput,
54
54
  GenerateReqInput,
55
+ GetWeightsByNameReqInput,
56
+ InitWeightsUpdateGroupReqInput,
55
57
  OpenSessionReqInput,
56
- UpdateWeightReqInput,
58
+ UpdateWeightFromDiskReqInput,
59
+ UpdateWeightsFromDistributedReqInput,
57
60
  )
58
61
  from sglang.srt.managers.scheduler import run_scheduler_process
59
62
  from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -150,13 +153,11 @@ async def get_model_info():
150
153
 
151
154
  @app.get("/get_server_info")
152
155
  async def get_server_info():
153
- try:
154
- return await _get_server_info()
155
-
156
- except Exception as e:
157
- return ORJSONResponse(
158
- {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
159
- )
156
+ return {
157
+ **dataclasses.asdict(tokenizer_manager.server_args), # server args
158
+ **scheduler_info,
159
+ "version": __version__,
160
+ }
160
161
 
161
162
 
162
163
  @app.post("/flush_cache")
@@ -192,11 +193,11 @@ async def stop_profile_async():
192
193
  )
193
194
 
194
195
 
195
- @app.post("/update_weights")
196
+ @app.post("/update_weights_from_disk")
196
197
  @time_func_latency
197
- async def update_weights(obj: UpdateWeightReqInput, request: Request):
198
- """Update the weights inplace without re-launching the server."""
199
- success, message = await tokenizer_manager.update_weights(obj, request)
198
+ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
199
+ """Update the weights from disk inplace without re-launching the server."""
200
+ success, message = await tokenizer_manager.update_weights_from_disk(obj, request)
200
201
  content = {"success": success, "message": message}
201
202
  if success:
202
203
  return ORJSONResponse(
@@ -210,6 +211,52 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
210
211
  )
211
212
 
212
213
 
214
+ @app.post("/init_weights_update_group")
215
+ async def init_weights_update_group(
216
+ obj: InitWeightsUpdateGroupReqInput, request: Request
217
+ ):
218
+ """Initialize the parameter update group."""
219
+ success, message = await tokenizer_manager.init_weights_update_group(obj, request)
220
+ content = {"success": success, "message": message}
221
+ if success:
222
+ return ORJSONResponse(content, status_code=200)
223
+ else:
224
+ return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
225
+
226
+
227
+ @app.post("/update_weights_from_distributed")
228
+ async def update_weights_from_distributed(
229
+ obj: UpdateWeightsFromDistributedReqInput, request: Request
230
+ ):
231
+ """Update model parameter from distributed online."""
232
+ success, message = await tokenizer_manager.update_weights_from_distributed(
233
+ obj, request
234
+ )
235
+ content = {"success": success, "message": message}
236
+ if success:
237
+ return ORJSONResponse(content, status_code=200)
238
+ else:
239
+ return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
240
+
241
+
242
+ @app.api_route("/get_weights_by_name", methods=["GET", "POST"])
243
+ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
244
+ """Get model parameter by name."""
245
+ try:
246
+ ret = await tokenizer_manager.get_weights_by_name(obj, request)
247
+ if ret is None:
248
+ return ORJSONResponse(
249
+ {"error": {"message": "Get parameter by name failed"}},
250
+ status_code=HTTPStatus.BAD_REQUEST,
251
+ )
252
+ else:
253
+ return ORJSONResponse(ret, status_code=200)
254
+ except Exception as e:
255
+ return ORJSONResponse(
256
+ {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
257
+ )
258
+
259
+
213
260
  @app.api_route("/open_session", methods=["GET", "POST"])
214
261
  async def open_session(obj: OpenSessionReqInput, request: Request):
215
262
  """Open a session, and return its unique session id."""
@@ -282,7 +329,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
282
329
  )
283
330
 
284
331
 
285
- @app.api_route("/encode", methods=["POST", "PUT"])
332
+ @app.api_route("/classify", methods=["POST", "PUT"])
286
333
  @time_func_latency
287
334
  async def classify_request(obj: EmbeddingReqInput, request: Request):
288
335
  """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
@@ -415,8 +462,8 @@ def launch_engine(
415
462
  if server_args.node_rank >= 1:
416
463
  # For other nodes, they do not need to run tokenizer or detokenizer,
417
464
  # so they can just wait here.
418
- while True:
419
- pass
465
+ for proc in scheduler_procs:
466
+ proc.join()
420
467
  else:
421
468
  # Launch the data parallel controller
422
469
  reader, writer = mp.Pipe(duplex=False)
@@ -517,14 +564,6 @@ def launch_server(
517
564
  t.join()
518
565
 
519
566
 
520
- async def _get_server_info():
521
- return {
522
- **dataclasses.asdict(tokenizer_manager.server_args), # server args
523
- **scheduler_info,
524
- "version": __version__,
525
- }
526
-
527
-
528
567
  def _set_envs_and_config(server_args: ServerArgs):
529
568
  # Set global environments
530
569
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -637,11 +676,218 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
637
676
  delete_directory(server_args.model_path)
638
677
 
639
678
 
679
+ STREAM_END_SYMBOL = b"data: [DONE]"
680
+ STREAM_CHUNK_START_SYMBOL = b"data:"
681
+
682
+
683
+ class Engine:
684
+ """
685
+ SRT Engine without an HTTP server layer.
686
+
687
+ This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where
688
+ launching the HTTP server adds unnecessary complexity or overhead,
689
+ """
690
+
691
+ def __init__(self, log_level: str = "error", *args, **kwargs):
692
+ """See the arguments in server_args.py::ServerArgs"""
693
+
694
+ # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
695
+ atexit.register(self.shutdown)
696
+
697
+ server_args = ServerArgs(*args, log_level=log_level, **kwargs)
698
+ launch_engine(server_args=server_args)
699
+
700
+ def generate(
701
+ self,
702
+ # The input prompt. It can be a single prompt or a batch of prompts.
703
+ prompt: Optional[Union[List[str], str]] = None,
704
+ sampling_params: Optional[Union[List[Dict], Dict]] = None,
705
+ # The token ids for text; one can either specify text or input_ids.
706
+ input_ids: Optional[Union[List[List[int]], List[int]]] = None,
707
+ return_logprob: Optional[Union[List[bool], bool]] = False,
708
+ logprob_start_len: Optional[Union[List[int], int]] = None,
709
+ top_logprobs_num: Optional[Union[List[int], int]] = None,
710
+ lora_path: Optional[List[Optional[str]]] = None,
711
+ stream: bool = False,
712
+ ):
713
+ obj = GenerateReqInput(
714
+ text=prompt,
715
+ input_ids=input_ids,
716
+ sampling_params=sampling_params,
717
+ return_logprob=return_logprob,
718
+ logprob_start_len=logprob_start_len,
719
+ top_logprobs_num=top_logprobs_num,
720
+ lora_path=lora_path,
721
+ stream=stream,
722
+ )
723
+
724
+ # get the current event loop
725
+ loop = asyncio.get_event_loop()
726
+ ret = loop.run_until_complete(generate_request(obj, None))
727
+
728
+ if stream is True:
729
+
730
+ def generator_wrapper():
731
+ offset = 0
732
+ loop = asyncio.get_event_loop()
733
+ generator = ret.body_iterator
734
+ while True:
735
+ chunk = loop.run_until_complete(generator.__anext__())
736
+
737
+ if chunk.startswith(STREAM_END_SYMBOL):
738
+ break
739
+ else:
740
+ data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
741
+ data["text"] = data["text"][offset:]
742
+ offset += len(data["text"])
743
+ yield data
744
+
745
+ # we cannot yield in the scope of generate() because python does not allow yield + return in the same function
746
+ # however, it allows to wrap the generator as a subfunction and return
747
+ return generator_wrapper()
748
+ else:
749
+ return ret
750
+
751
+ async def async_generate(
752
+ self,
753
+ # The input prompt. It can be a single prompt or a batch of prompts.
754
+ prompt: Optional[Union[List[str], str]] = None,
755
+ sampling_params: Optional[Dict] = None,
756
+ # The token ids for text; one can either specify text or input_ids.
757
+ input_ids: Optional[Union[List[List[int]], List[int]]] = None,
758
+ return_logprob: Optional[Union[List[bool], bool]] = False,
759
+ logprob_start_len: Optional[Union[List[int], int]] = None,
760
+ top_logprobs_num: Optional[Union[List[int], int]] = None,
761
+ lora_path: Optional[List[Optional[str]]] = None,
762
+ stream: bool = False,
763
+ ):
764
+ obj = GenerateReqInput(
765
+ text=prompt,
766
+ input_ids=input_ids,
767
+ sampling_params=sampling_params,
768
+ return_logprob=return_logprob,
769
+ logprob_start_len=logprob_start_len,
770
+ top_logprobs_num=top_logprobs_num,
771
+ lora_path=lora_path,
772
+ stream=stream,
773
+ )
774
+
775
+ ret = await generate_request(obj, None)
776
+
777
+ if stream is True:
778
+ generator = ret.body_iterator
779
+
780
+ async def generator_wrapper():
781
+
782
+ offset = 0
783
+
784
+ while True:
785
+ chunk = await generator.__anext__()
786
+
787
+ if chunk.startswith(STREAM_END_SYMBOL):
788
+ break
789
+ else:
790
+ data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
791
+ data["text"] = data["text"][offset:]
792
+ offset += len(data["text"])
793
+ yield data
794
+
795
+ return generator_wrapper()
796
+ else:
797
+ return ret
798
+
799
+ def shutdown(self):
800
+ kill_process_tree(os.getpid(), include_parent=False)
801
+
802
+ def get_tokenizer(self):
803
+ global tokenizer_manager
804
+
805
+ if tokenizer_manager is None:
806
+ raise ReferenceError("Tokenizer Manager is not initialized.")
807
+ else:
808
+ return tokenizer_manager.tokenizer
809
+
810
+ def encode(
811
+ self,
812
+ prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
813
+ ):
814
+ obj = EmbeddingReqInput(text=prompt)
815
+
816
+ # get the current event loop
817
+ loop = asyncio.get_event_loop()
818
+ return loop.run_until_complete(encode_request(obj, None))
819
+
820
+ def start_profile(self):
821
+ tokenizer_manager.start_profile()
822
+
823
+ def stop_profile(self):
824
+ tokenizer_manager.stop_profile()
825
+
826
+ def get_server_info(self):
827
+ return {
828
+ **dataclasses.asdict(tokenizer_manager.server_args), # server args
829
+ **scheduler_info,
830
+ "version": __version__,
831
+ }
832
+
833
+ def init_weights_update_group(
834
+ self,
835
+ master_address: str,
836
+ master_port: int,
837
+ rank_offset: int,
838
+ world_size: int,
839
+ group_name: str,
840
+ backend: str = "nccl",
841
+ ):
842
+ """Initialize parameter update group."""
843
+ obj = InitWeightsUpdateGroupReqInput(
844
+ master_address=master_address,
845
+ master_port=master_port,
846
+ rank_offset=rank_offset,
847
+ world_size=world_size,
848
+ group_name=group_name,
849
+ backend=backend,
850
+ )
851
+
852
+ async def _init_group():
853
+ return await tokenizer_manager.init_weights_update_group(obj, None)
854
+
855
+ loop = asyncio.get_event_loop()
856
+ return loop.run_until_complete(_init_group())
857
+
858
+ def update_weights_from_distributed(self, name, dtype, shape):
859
+ """Update weights from distributed source."""
860
+ obj = UpdateWeightsFromDistributedReqInput(
861
+ name=name,
862
+ dtype=dtype,
863
+ shape=shape,
864
+ )
865
+
866
+ async def _update_weights():
867
+ return await tokenizer_manager.update_weights_from_distributed(obj, None)
868
+
869
+ loop = asyncio.get_event_loop()
870
+ return loop.run_until_complete(_update_weights())
871
+
872
+ def get_weights_by_name(self, name, truncate_size=100):
873
+ """Get weights by parameter name."""
874
+ obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
875
+
876
+ async def _get_weights():
877
+ return await tokenizer_manager.get_weights_by_name(obj, None)
878
+
879
+ loop = asyncio.get_event_loop()
880
+ return loop.run_until_complete(_get_weights())
881
+
882
+
640
883
  class Runtime:
641
884
  """
642
- A wrapper for the server.
885
+ A wrapper for the HTTP server.
643
886
  This is used for launching the server in a python program without
644
887
  using the commond line interface.
888
+
889
+ It is mainly used for the frontend language.
890
+ You should use the Engine class if you want to do normal offline processing.
645
891
  """
646
892
 
647
893
  def __init__(
@@ -789,152 +1035,3 @@ class Runtime:
789
1035
 
790
1036
  def __del__(self):
791
1037
  self.shutdown()
792
-
793
-
794
- STREAM_END_SYMBOL = b"data: [DONE]"
795
- STREAM_CHUNK_START_SYMBOL = b"data:"
796
-
797
-
798
- class Engine:
799
- """
800
- SRT Engine without an HTTP server layer.
801
-
802
- This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where
803
- launching the HTTP server adds unnecessary complexity or overhead,
804
- """
805
-
806
- def __init__(self, log_level: str = "error", *args, **kwargs):
807
- # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
808
- atexit.register(self.shutdown)
809
-
810
- server_args = ServerArgs(*args, log_level=log_level, **kwargs)
811
- launch_engine(server_args=server_args)
812
-
813
- def generate(
814
- self,
815
- # The input prompt. It can be a single prompt or a batch of prompts.
816
- prompt: Optional[Union[List[str], str]] = None,
817
- sampling_params: Optional[Union[List[Dict], Dict]] = None,
818
- # The token ids for text; one can either specify text or input_ids.
819
- input_ids: Optional[Union[List[List[int]], List[int]]] = None,
820
- return_logprob: Optional[Union[List[bool], bool]] = False,
821
- logprob_start_len: Optional[Union[List[int], int]] = None,
822
- top_logprobs_num: Optional[Union[List[int], int]] = None,
823
- lora_path: Optional[List[Optional[str]]] = None,
824
- stream: bool = False,
825
- ):
826
- obj = GenerateReqInput(
827
- text=prompt,
828
- input_ids=input_ids,
829
- sampling_params=sampling_params,
830
- return_logprob=return_logprob,
831
- logprob_start_len=logprob_start_len,
832
- top_logprobs_num=top_logprobs_num,
833
- lora_path=lora_path,
834
- stream=stream,
835
- )
836
-
837
- # get the current event loop
838
- loop = asyncio.get_event_loop()
839
- ret = loop.run_until_complete(generate_request(obj, None))
840
-
841
- if stream is True:
842
-
843
- def generator_wrapper():
844
- offset = 0
845
- loop = asyncio.get_event_loop()
846
- generator = ret.body_iterator
847
- while True:
848
- chunk = loop.run_until_complete(generator.__anext__())
849
-
850
- if chunk.startswith(STREAM_END_SYMBOL):
851
- break
852
- else:
853
- data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
854
- data["text"] = data["text"][offset:]
855
- offset += len(data["text"])
856
- yield data
857
-
858
- # we cannot yield in the scope of generate() because python does not allow yield + return in the same function
859
- # however, it allows to wrap the generator as a subfunction and return
860
- return generator_wrapper()
861
- else:
862
- return ret
863
-
864
- async def async_generate(
865
- self,
866
- # The input prompt. It can be a single prompt or a batch of prompts.
867
- prompt: Optional[Union[List[str], str]] = None,
868
- sampling_params: Optional[Dict] = None,
869
- # The token ids for text; one can either specify text or input_ids.
870
- input_ids: Optional[Union[List[List[int]], List[int]]] = None,
871
- return_logprob: Optional[Union[List[bool], bool]] = False,
872
- logprob_start_len: Optional[Union[List[int], int]] = None,
873
- top_logprobs_num: Optional[Union[List[int], int]] = None,
874
- lora_path: Optional[List[Optional[str]]] = None,
875
- stream: bool = False,
876
- ):
877
- obj = GenerateReqInput(
878
- text=prompt,
879
- input_ids=input_ids,
880
- sampling_params=sampling_params,
881
- return_logprob=return_logprob,
882
- logprob_start_len=logprob_start_len,
883
- top_logprobs_num=top_logprobs_num,
884
- lora_path=lora_path,
885
- stream=stream,
886
- )
887
-
888
- ret = await generate_request(obj, None)
889
-
890
- if stream is True:
891
- generator = ret.body_iterator
892
-
893
- async def generator_wrapper():
894
-
895
- offset = 0
896
-
897
- while True:
898
- chunk = await generator.__anext__()
899
-
900
- if chunk.startswith(STREAM_END_SYMBOL):
901
- break
902
- else:
903
- data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
904
- data["text"] = data["text"][offset:]
905
- offset += len(data["text"])
906
- yield data
907
-
908
- return generator_wrapper()
909
- else:
910
- return ret
911
-
912
- def shutdown(self):
913
- kill_process_tree(os.getpid(), include_parent=False)
914
-
915
- def get_tokenizer(self):
916
- global tokenizer_manager
917
-
918
- if tokenizer_manager is None:
919
- raise ReferenceError("Tokenizer Manager is not initialized.")
920
- else:
921
- return tokenizer_manager.tokenizer
922
-
923
- def encode(
924
- self,
925
- prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
926
- ):
927
- obj = EmbeddingReqInput(text=prompt)
928
-
929
- # get the current event loop
930
- loop = asyncio.get_event_loop()
931
- return loop.run_until_complete(encode_request(obj, None))
932
-
933
- def start_profile(self):
934
- tokenizer_manager.start_profile()
935
-
936
- def stop_profile(self):
937
- tokenizer_manager.stop_profile()
938
-
939
- async def get_server_info(self):
940
- return await _get_server_info()