sglang 0.3.6.post2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. sglang/bench_offline_throughput.py +55 -2
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +4 -3
  4. sglang/bench_serving.py +13 -0
  5. sglang/check_env.py +1 -1
  6. sglang/launch_server.py +3 -2
  7. sglang/srt/_custom_ops.py +118 -0
  8. sglang/srt/configs/device_config.py +17 -0
  9. sglang/srt/configs/load_config.py +84 -0
  10. sglang/srt/configs/model_config.py +161 -4
  11. sglang/srt/configs/qwen2vl.py +5 -8
  12. sglang/srt/constrained/outlines_backend.py +6 -1
  13. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  14. sglang/srt/distributed/__init__.py +3 -0
  15. sglang/srt/distributed/communication_op.py +34 -0
  16. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  17. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  18. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  19. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  20. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  21. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  22. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  23. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  24. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  25. sglang/srt/distributed/parallel_state.py +1275 -0
  26. sglang/srt/distributed/utils.py +223 -0
  27. sglang/srt/hf_transformers_utils.py +37 -1
  28. sglang/srt/layers/attention/flashinfer_backend.py +13 -15
  29. sglang/srt/layers/attention/torch_native_backend.py +285 -0
  30. sglang/srt/layers/fused_moe_patch.py +20 -11
  31. sglang/srt/layers/linear.py +1 -0
  32. sglang/srt/layers/logits_processor.py +17 -3
  33. sglang/srt/layers/quantization/__init__.py +34 -0
  34. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  35. sglang/srt/lora/lora.py +1 -1
  36. sglang/srt/managers/data_parallel_controller.py +7 -11
  37. sglang/srt/managers/detokenizer_manager.py +7 -4
  38. sglang/srt/managers/image_processor.py +1 -1
  39. sglang/srt/managers/io_struct.py +48 -12
  40. sglang/srt/managers/schedule_batch.py +42 -36
  41. sglang/srt/managers/schedule_policy.py +7 -4
  42. sglang/srt/managers/scheduler.py +111 -46
  43. sglang/srt/managers/session_controller.py +0 -3
  44. sglang/srt/managers/tokenizer_manager.py +169 -100
  45. sglang/srt/managers/tp_worker.py +36 -3
  46. sglang/srt/managers/tp_worker_overlap_thread.py +32 -5
  47. sglang/srt/model_executor/cuda_graph_runner.py +16 -7
  48. sglang/srt/model_executor/forward_batch_info.py +9 -4
  49. sglang/srt/model_executor/model_runner.py +136 -150
  50. sglang/srt/model_loader/__init__.py +34 -0
  51. sglang/srt/model_loader/loader.py +1139 -0
  52. sglang/srt/model_loader/utils.py +41 -0
  53. sglang/srt/model_loader/weight_utils.py +640 -0
  54. sglang/srt/models/baichuan.py +9 -10
  55. sglang/srt/models/chatglm.py +6 -15
  56. sglang/srt/models/commandr.py +2 -3
  57. sglang/srt/models/dbrx.py +2 -3
  58. sglang/srt/models/deepseek.py +4 -11
  59. sglang/srt/models/deepseek_v2.py +3 -11
  60. sglang/srt/models/exaone.py +2 -3
  61. sglang/srt/models/gemma.py +2 -6
  62. sglang/srt/models/gemma2.py +3 -14
  63. sglang/srt/models/gemma2_reward.py +0 -1
  64. sglang/srt/models/gpt2.py +5 -12
  65. sglang/srt/models/gpt_bigcode.py +6 -22
  66. sglang/srt/models/grok.py +14 -51
  67. sglang/srt/models/internlm2.py +2 -3
  68. sglang/srt/models/internlm2_reward.py +0 -1
  69. sglang/srt/models/llama.py +97 -27
  70. sglang/srt/models/llama_classification.py +1 -2
  71. sglang/srt/models/llama_embedding.py +1 -2
  72. sglang/srt/models/llama_reward.py +2 -3
  73. sglang/srt/models/llava.py +10 -12
  74. sglang/srt/models/llavavid.py +1 -2
  75. sglang/srt/models/minicpm.py +4 -7
  76. sglang/srt/models/minicpm3.py +6 -19
  77. sglang/srt/models/mixtral.py +12 -5
  78. sglang/srt/models/mixtral_quant.py +2 -3
  79. sglang/srt/models/mllama.py +3 -7
  80. sglang/srt/models/olmo.py +2 -8
  81. sglang/srt/models/olmo2.py +391 -0
  82. sglang/srt/models/olmoe.py +3 -5
  83. sglang/srt/models/phi3_small.py +8 -8
  84. sglang/srt/models/qwen.py +2 -3
  85. sglang/srt/models/qwen2.py +10 -9
  86. sglang/srt/models/qwen2_moe.py +4 -11
  87. sglang/srt/models/qwen2_vl.py +12 -9
  88. sglang/srt/models/registry.py +99 -0
  89. sglang/srt/models/stablelm.py +2 -3
  90. sglang/srt/models/torch_native_llama.py +6 -12
  91. sglang/srt/models/xverse.py +2 -4
  92. sglang/srt/models/xverse_moe.py +4 -11
  93. sglang/srt/models/yivl.py +2 -3
  94. sglang/srt/openai_api/adapter.py +10 -6
  95. sglang/srt/openai_api/protocol.py +1 -0
  96. sglang/srt/server.py +303 -204
  97. sglang/srt/server_args.py +65 -31
  98. sglang/srt/utils.py +253 -48
  99. sglang/test/test_utils.py +27 -7
  100. sglang/utils.py +2 -2
  101. sglang/version.py +1 -1
  102. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/METADATA +2 -1
  103. sglang-0.4.0.dist-info/RECORD +184 -0
  104. sglang/srt/layers/fused_moe_grok/__init__.py +0 -1
  105. sglang/srt/layers/fused_moe_grok/fused_moe.py +0 -692
  106. sglang/srt/layers/fused_moe_grok/layer.py +0 -630
  107. sglang-0.3.6.post2.dist-info/RECORD +0 -164
  108. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
  109. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
  110. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
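The most user-visible addition in this release is the new offline `Engine` class in sglang/srt/server.py (see the diff below), which drives the engine directly instead of going through the FastAPI/HTTP layer. A minimal usage sketch, with the model path and output handling as assumptions: constructor keyword arguments are forwarded to ServerArgs, and a non-streaming generate() on a batch of prompts is expected to return a list of dicts with a "text" field.

# Hedged sketch of the new offline Engine API added in 0.4.0; the model path is an example value.
from sglang.srt.server import Engine

llm = Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # any ServerArgs field can be passed as a kwarg

prompts = ["Hello, my name is", "The capital of France is"]
outputs = llm.generate(prompts, sampling_params={"temperature": 0.8, "top_p": 0.95})

for prompt, output in zip(prompts, outputs):
    print(prompt, "->", output["text"])

llm.shutdown()  # optional; shutdown is also registered via atexit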
sglang/srt/server.py CHANGED
@@ -23,6 +23,7 @@ import json
 import logging
 import multiprocessing as mp
 import os
+import signal
 import threading
 import time
 from http import HTTPStatus
@@ -51,8 +52,11 @@ from sglang.srt.managers.io_struct import (
     CloseSessionReqInput,
     EmbeddingReqInput,
     GenerateReqInput,
+    GetWeightsByNameReqInput,
+    InitWeightsUpdateGroupReqInput,
     OpenSessionReqInput,
-    UpdateWeightReqInput,
+    UpdateWeightFromDiskReqInput,
+    UpdateWeightsFromDistributedReqInput,
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -79,7 +83,7 @@ from sglang.srt.utils import (
     configure_logger,
     delete_directory,
     is_port_available,
-    kill_child_process,
+    kill_process_tree,
     maybe_set_triton_cache_manager,
     prepare_model_and_tokenizer,
     set_prometheus_multiproc_dir,
@@ -92,7 +96,7 @@ logger = logging.getLogger(__name__)
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
-
+# Fast API
 app = FastAPI()
 app.add_middleware(
     CORSMiddleware,
@@ -103,7 +107,7 @@ app.add_middleware(
 )
 
 tokenizer_manager: TokenizerManager = None
-_max_total_num_tokens = None
+scheduler_info: Dict = None
 
 ##### Native API endpoints #####
 
@@ -149,13 +153,11 @@ async def get_model_info():
 
 @app.get("/get_server_info")
 async def get_server_info():
-    try:
-        return await _get_server_info()
-
-    except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+    return {
+        **dataclasses.asdict(tokenizer_manager.server_args),  # server args
+        **scheduler_info,
+        "version": __version__,
+    }
 
 
 @app.post("/flush_cache")
@@ -171,7 +173,7 @@ async def flush_cache():
 
 @app.get("/start_profile")
 @app.post("/start_profile")
-async def start_profile():
+async def start_profile_async():
     """Start profiling."""
     tokenizer_manager.start_profile()
     return Response(
@@ -182,7 +184,7 @@ async def start_profile():
 
 @app.get("/stop_profile")
 @app.post("/stop_profile")
-async def stop_profile():
+async def stop_profile_async():
     """Stop profiling."""
     tokenizer_manager.stop_profile()
     return Response(
@@ -191,11 +193,11 @@ async def stop_profile():
     )
 
 
-@app.post("/update_weights")
+@app.post("/update_weights_from_disk")
 @time_func_latency
-async def update_weights(obj: UpdateWeightReqInput, request: Request):
-    """Update the weights inplace without re-launching the server."""
-    success, message = await tokenizer_manager.update_weights(obj, request)
+async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
+    """Update the weights from disk inplace without re-launching the server."""
+    success, message = await tokenizer_manager.update_weights_from_disk(obj, request)
     content = {"success": success, "message": message}
     if success:
         return ORJSONResponse(
@@ -209,6 +211,52 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
         )
 
 
+@app.post("/init_weights_update_group")
+async def init_weights_update_group(
+    obj: InitWeightsUpdateGroupReqInput, request: Request
+):
+    """Initialize the parameter update group."""
+    success, message = await tokenizer_manager.init_weights_update_group(obj, request)
+    content = {"success": success, "message": message}
+    if success:
+        return ORJSONResponse(content, status_code=200)
+    else:
+        return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
+
+
+@app.post("/update_weights_from_distributed")
+async def update_weights_from_distributed(
+    obj: UpdateWeightsFromDistributedReqInput, request: Request
+):
+    """Update model parameter from distributed online."""
+    success, message = await tokenizer_manager.update_weights_from_distributed(
+        obj, request
+    )
+    content = {"success": success, "message": message}
+    if success:
+        return ORJSONResponse(content, status_code=200)
+    else:
+        return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
+
+
+@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
+async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
+    """Get model parameter by name."""
+    try:
+        ret = await tokenizer_manager.get_weights_by_name(obj, request)
+        if ret is None:
+            return ORJSONResponse(
+                {"error": {"message": "Get parameter by name failed"}},
+                status_code=HTTPStatus.BAD_REQUEST,
+            )
+        else:
+            return ORJSONResponse(ret, status_code=200)
+    except Exception as e:
+        return ORJSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
+
+
 @app.api_route("/open_session", methods=["GET", "POST"])
 async def open_session(obj: OpenSessionReqInput, request: Request):
     """Open a session, and return its unique session id."""
@@ -233,6 +281,8 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
         )
 
 
+# fastapi implicitly converts json in the request to obj (dataclass)
+@app.api_route("/generate", methods=["POST", "PUT"])
 @time_func_latency
 async def generate_request(obj: GenerateReqInput, request: Request):
     """Handle a generate request."""
@@ -266,11 +316,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
         )
 
 
-# fastapi implicitly converts json in the request to obj (dataclass)
-app.post("/generate")(generate_request)
-app.put("/generate")(generate_request)
-
-
+@app.api_route("/encode", methods=["POST", "PUT"])
 @time_func_latency
 async def encode_request(obj: EmbeddingReqInput, request: Request):
     """Handle an embedding request."""
@@ -283,10 +329,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
         )
 
 
-app.post("/encode")(encode_request)
-app.put("/encode")(encode_request)
-
-
+@app.api_route("/classify", methods=["POST", "PUT"])
 @time_func_latency
 async def classify_request(obj: EmbeddingReqInput, request: Request):
     """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
@@ -299,10 +342,6 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         )
 
 
-app.post("/classify")(classify_request)
-app.put("/classify")(classify_request)
-
-
 ##### OpenAI-compatible API endpoints #####
 
 
@@ -380,11 +419,11 @@ def launch_engine(
     server_args: ServerArgs,
 ):
     """
-    Launch the Tokenizer Manager in the main process, the Scheduler in a subprocess, and the Detokenizer Manager in another subprocess.
+    Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
     """
 
     global tokenizer_manager
-    global _max_total_num_tokens
+    global scheduler_info
 
     # Configure global environment
     configure_logger(server_args)
@@ -450,8 +489,8 @@ def launch_engine(
     if server_args.chat_template:
         load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
 
-    # Wait for model to finish loading & get max token nums
-    scheduler_info = []
+    # Wait for model to finish loading
+    scheduler_infos = []
     for i in range(len(scheduler_pipe_readers)):
         data = scheduler_pipe_readers[i].recv()
 
@@ -459,10 +498,10 @@ def launch_engine(
             raise RuntimeError(
                 "Initialization failed. Please see the error messages above."
             )
-        scheduler_info.append(data)
+        scheduler_infos.append(data)
 
     # Assume all schedulers have same max_total_num_tokens
-    _max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
+    scheduler_info = scheduler_infos[0]
 
 
 def launch_server(
@@ -476,12 +515,12 @@ def launch_server(
 
     1. HTTP server: A FastAPI server that routes requests to the engine.
     2. SRT engine:
-        1. Tokenizer Manager: Tokenizes the requests and sends them to the scheduler.
+        1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.
         2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
-        3. Detokenizer Manager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
+        3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
 
     Note:
-    1. The HTTP server and Tokenizer Manager both run in the main process.
+    1. The HTTP server and TokenizerManager both run in the main process.
     2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library.
     """
     launch_engine(server_args=server_args)
@@ -490,7 +529,7 @@ def launch_server(
     if server_args.api_key:
         add_api_key_middleware(app, server_args.api_key)
 
-    # add prometheus middleware
+    # Add prometheus middleware
    if server_args.enable_metrics:
         add_prometheus_middleware(app)
         enable_func_timer()
@@ -502,7 +541,7 @@ def launch_server(
     t.start()
 
     try:
-        # Listen for HTTP requests
+        # Update logging configs
         LOGGING_CONFIG["formatters"]["default"][
             "fmt"
         ] = "[%(asctime)s] %(levelprefix)s %(message)s"
@@ -511,6 +550,8 @@ def launch_server(
             "fmt"
         ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
         LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
+
+        # Listen for HTTP requests
         uvicorn.run(
             app,
             host=server_args.host,
@@ -523,15 +564,6 @@ def launch_server(
         t.join()
 
 
-async def _get_server_info():
-    return {
-        **dataclasses.asdict(tokenizer_manager.server_args),  # server args
-        "memory_pool_size": await tokenizer_manager.get_memory_pool_size(),  # memory pool size
-        "max_total_num_tokens": _max_total_num_tokens,  # max total num tokens
-        "version": __version__,
-    }
-
-
 def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -562,6 +594,15 @@ def _set_envs_and_config(server_args: ServerArgs):
         "at https://docs.flashinfer.ai/installation.html.",
     )
 
+    # Register the signal handler.
+    # The child processes will send SIGQUIT to this process when any error happens
+    # This process then clean up the whole process tree
+    def sigquit_handler(signum, frame):
+        kill_process_tree(os.getpid())
+
+    signal.signal(signal.SIGQUIT, sigquit_handler)
+
+    # Set mp start method
     mp.set_start_method("spawn", force=True)
 
 
@@ -588,7 +629,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
         logger.error(f"Initialization failed. warmup error: {last_traceback}")
-        kill_child_process(include_self=True)
+        kill_process_tree(os.getpid())
         return
 
     model_info = res.json()
@@ -621,9 +662,10 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
         logger.error(f"Initialization failed. warmup error: {last_traceback}")
-        kill_child_process(include_self=True)
+        kill_process_tree(os.getpid())
         return
 
+    # Debug print
     # logger.info(f"{res.json()=}")
 
     logger.info("The server is fired up and ready to roll!")
@@ -634,11 +676,218 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         delete_directory(server_args.model_path)
 
 
+STREAM_END_SYMBOL = b"data: [DONE]"
+STREAM_CHUNK_START_SYMBOL = b"data:"
+
+
+class Engine:
+    """
+    SRT Engine without an HTTP server layer.
+
+    This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where
+    launching the HTTP server adds unnecessary complexity or overhead,
+    """
+
+    def __init__(self, log_level: str = "error", *args, **kwargs):
+        """See the arguments in server_args.py::ServerArgs"""
+
+        # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
+        atexit.register(self.shutdown)
+
+        server_args = ServerArgs(*args, log_level=log_level, **kwargs)
+        launch_engine(server_args=server_args)
+
+    def generate(
+        self,
+        # The input prompt. It can be a single prompt or a batch of prompts.
+        prompt: Optional[Union[List[str], str]] = None,
+        sampling_params: Optional[Union[List[Dict], Dict]] = None,
+        # The token ids for text; one can either specify text or input_ids.
+        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+        lora_path: Optional[List[Optional[str]]] = None,
+        stream: bool = False,
+    ):
+        obj = GenerateReqInput(
+            text=prompt,
+            input_ids=input_ids,
+            sampling_params=sampling_params,
+            return_logprob=return_logprob,
+            logprob_start_len=logprob_start_len,
+            top_logprobs_num=top_logprobs_num,
+            lora_path=lora_path,
+            stream=stream,
+        )
+
+        # get the current event loop
+        loop = asyncio.get_event_loop()
+        ret = loop.run_until_complete(generate_request(obj, None))
+
+        if stream is True:
+
+            def generator_wrapper():
+                offset = 0
+                loop = asyncio.get_event_loop()
+                generator = ret.body_iterator
+                while True:
+                    chunk = loop.run_until_complete(generator.__anext__())
+
+                    if chunk.startswith(STREAM_END_SYMBOL):
+                        break
+                    else:
+                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
+                        data["text"] = data["text"][offset:]
+                        offset += len(data["text"])
+                        yield data
+
+            # we cannot yield in the scope of generate() because python does not allow yield + return in the same function
+            # however, it allows to wrap the generator as a subfunction and return
+            return generator_wrapper()
+        else:
+            return ret
+
+    async def async_generate(
+        self,
+        # The input prompt. It can be a single prompt or a batch of prompts.
+        prompt: Optional[Union[List[str], str]] = None,
+        sampling_params: Optional[Dict] = None,
+        # The token ids for text; one can either specify text or input_ids.
+        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+        lora_path: Optional[List[Optional[str]]] = None,
+        stream: bool = False,
+    ):
+        obj = GenerateReqInput(
+            text=prompt,
+            input_ids=input_ids,
+            sampling_params=sampling_params,
+            return_logprob=return_logprob,
+            logprob_start_len=logprob_start_len,
+            top_logprobs_num=top_logprobs_num,
+            lora_path=lora_path,
+            stream=stream,
+        )
+
+        ret = await generate_request(obj, None)
+
+        if stream is True:
+            generator = ret.body_iterator
+
+            async def generator_wrapper():
+
+                offset = 0
+
+                while True:
+                    chunk = await generator.__anext__()
+
+                    if chunk.startswith(STREAM_END_SYMBOL):
+                        break
+                    else:
+                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
+                        data["text"] = data["text"][offset:]
+                        offset += len(data["text"])
+                        yield data
+
+            return generator_wrapper()
+        else:
+            return ret
+
+    def shutdown(self):
+        kill_process_tree(os.getpid(), include_parent=False)
+
+    def get_tokenizer(self):
+        global tokenizer_manager
+
+        if tokenizer_manager is None:
+            raise ReferenceError("Tokenizer Manager is not initialized.")
+        else:
+            return tokenizer_manager.tokenizer
+
+    def encode(
+        self,
+        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
+    ):
+        obj = EmbeddingReqInput(text=prompt)
+
+        # get the current event loop
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(encode_request(obj, None))
+
+    def start_profile(self):
+        tokenizer_manager.start_profile()
+
+    def stop_profile(self):
+        tokenizer_manager.stop_profile()
+
+    def get_server_info(self):
+        return {
+            **dataclasses.asdict(tokenizer_manager.server_args),  # server args
+            **scheduler_info,
+            "version": __version__,
+        }
+
+    def init_weights_update_group(
+        self,
+        master_address: str,
+        master_port: int,
+        rank_offset: int,
+        world_size: int,
+        group_name: str,
+        backend: str = "nccl",
+    ):
+        """Initialize parameter update group."""
+        obj = InitWeightsUpdateGroupReqInput(
+            master_address=master_address,
+            master_port=master_port,
+            rank_offset=rank_offset,
+            world_size=world_size,
+            group_name=group_name,
+            backend=backend,
+        )
+
+        async def _init_group():
+            return await tokenizer_manager.init_weights_update_group(obj, None)
+
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(_init_group())
+
+    def update_weights_from_distributed(self, name, dtype, shape):
+        """Update weights from distributed source."""
+        obj = UpdateWeightsFromDistributedReqInput(
+            name=name,
+            dtype=dtype,
+            shape=shape,
+        )
+
+        async def _update_weights():
+            return await tokenizer_manager.update_weights_from_distributed(obj, None)
+
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(_update_weights())
+
+    def get_weights_by_name(self, name, truncate_size=100):
+        """Get weights by parameter name."""
+        obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
+
+        async def _get_weights():
+            return await tokenizer_manager.get_weights_by_name(obj, None)
+
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(_get_weights())
+
+
 class Runtime:
     """
-    A wrapper for the server.
+    A wrapper for the HTTP server.
     This is used for launching the server in a python program without
     using the commond line interface.
+
+    It is mainly used for the frontend language.
+    You should use the Engine class if you want to do normal offline processing.
     """
 
     def __init__(
@@ -690,7 +939,7 @@ class Runtime:
 
     def shutdown(self):
         if self.pid is not None:
-            kill_child_process(self.pid, include_self=True)
+            kill_process_tree(self.pid)
         self.pid = None
 
     def cache_prefix(self, prefix: str):
@@ -786,153 +1035,3 @@ class Runtime:
 
     def __del__(self):
         self.shutdown()
-
-
-STREAM_END_SYMBOL = b"data: [DONE]"
-STREAM_CHUNK_START_SYMBOL = b"data:"
-
-
-class Engine:
-    """
-    SRT Engine without an HTTP server layer.
-
-    This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where
-    launching the HTTP server adds unnecessary complexity or overhead,
-    """
-
-    def __init__(self, *args, **kwargs):
-
-        # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
-        atexit.register(self.shutdown)
-
-        # runtime server default log level is log
-        # offline engine works in scripts, so we set it to error
-
-        if "log_level" not in kwargs:
-            kwargs["log_level"] = "error"
-
-        server_args = ServerArgs(*args, **kwargs)
-        launch_engine(server_args=server_args)
-
-    def generate(
-        self,
-        # The input prompt. It can be a single prompt or a batch of prompts.
-        prompt: Optional[Union[List[str], str]] = None,
-        sampling_params: Optional[Union[List[Dict], Dict]] = None,
-        # The token ids for text; one can either specify text or input_ids.
-        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
-        return_logprob: Optional[Union[List[bool], bool]] = False,
-        logprob_start_len: Optional[Union[List[int], int]] = None,
-        top_logprobs_num: Optional[Union[List[int], int]] = None,
-        lora_path: Optional[List[Optional[str]]] = None,
-        stream: bool = False,
-    ):
-        obj = GenerateReqInput(
-            text=prompt,
-            input_ids=input_ids,
-            sampling_params=sampling_params,
-            return_logprob=return_logprob,
-            logprob_start_len=logprob_start_len,
-            top_logprobs_num=top_logprobs_num,
-            lora_path=lora_path,
-            stream=stream,
-        )
-
-        # get the current event loop
-        loop = asyncio.get_event_loop()
-        ret = loop.run_until_complete(generate_request(obj, None))
-
-        if stream is True:
-
-            def generator_wrapper():
-                offset = 0
-                loop = asyncio.get_event_loop()
-                generator = ret.body_iterator
-                while True:
-                    chunk = loop.run_until_complete(generator.__anext__())
-
-                    if chunk.startswith(STREAM_END_SYMBOL):
-                        break
-                    else:
-                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
-                        data["text"] = data["text"][offset:]
-                        offset += len(data["text"])
-                        yield data
-
-            # we cannot yield in the scope of generate() because python does not allow yield + return in the same function
-            # however, it allows to wrap the generator as a subfunction and return
-            return generator_wrapper()
-        else:
-            return ret
-
-    async def async_generate(
-        self,
-        # The input prompt. It can be a single prompt or a batch of prompts.
-        prompt: Optional[Union[List[str], str]] = None,
-        sampling_params: Optional[Dict] = None,
-        # The token ids for text; one can either specify text or input_ids.
-        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
-        return_logprob: Optional[Union[List[bool], bool]] = False,
-        logprob_start_len: Optional[Union[List[int], int]] = None,
-        top_logprobs_num: Optional[Union[List[int], int]] = None,
-        lora_path: Optional[List[Optional[str]]] = None,
-        stream: bool = False,
-    ):
-        obj = GenerateReqInput(
-            text=prompt,
-            input_ids=input_ids,
-            sampling_params=sampling_params,
-            return_logprob=return_logprob,
-            logprob_start_len=logprob_start_len,
-            top_logprobs_num=top_logprobs_num,
-            lora_path=lora_path,
-            stream=stream,
-        )
-
-        ret = await generate_request(obj, None)
-
-        if stream is True:
-            generator = ret.body_iterator
-
-            async def generator_wrapper():
-
-                offset = 0
-
-                while True:
-                    chunk = await generator.__anext__()
-
-                    if chunk.startswith(STREAM_END_SYMBOL):
-                        break
-                    else:
-                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
-                        data["text"] = data["text"][offset:]
-                        offset += len(data["text"])
-                        yield data
-
-            return generator_wrapper()
-        else:
-            return ret
-
-    def shutdown(self):
-        kill_child_process()
-
-    def get_tokenizer(self):
-        global tokenizer_manager
-
-        if tokenizer_manager is None:
-            raise ReferenceError("Tokenizer Manager is not initialized.")
-        else:
-            return tokenizer_manager.tokenizer
-
-    def encode(
-        self,
-        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
-    ):
-        obj = EmbeddingReqInput(text=prompt)
-
-        # get the current event loop
-        loop = asyncio.get_event_loop()
-        return loop.run_until_complete(encode_request(obj, None))
-
-    async def get_server_info(self):
-        return await _get_server_info()
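The HTTP side of the same refactor renames /update_weights to /update_weights_from_disk and adds /init_weights_update_group, /update_weights_from_distributed, and /get_weights_by_name. A rough client-side sketch, assuming a server already listening on localhost:30000; the request field names mirror the corresponding io_struct dataclasses (name and truncate_size appear in the diff above, while model_path and the checkpoint path are assumptions):

# Hedged sketch: exercising the renamed/new endpoints from this diff.
import requests

base = "http://localhost:30000"

# /get_server_info now returns ServerArgs merged with scheduler_info (e.g. max_total_num_tokens) and the version.
info = requests.get(f"{base}/get_server_info").json()
print(info["version"], info.get("max_total_num_tokens"))

# /update_weights was renamed to /update_weights_from_disk.
resp = requests.post(
    f"{base}/update_weights_from_disk",
    json={"model_path": "/path/to/updated/checkpoint"},  # hypothetical path; field name assumed
)
print(resp.json())  # {"success": ..., "message": ...}

# New in 0.4.0: fetch a (truncated) parameter by name.
resp = requests.post(
    f"{base}/get_weights_by_name",
    json={"name": "model.embed_tokens.weight", "truncate_size": 10},  # example parameter name
)
print(resp.json())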