sglang 0.3.6.post2__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +55 -2
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +4 -3
- sglang/bench_serving.py +13 -0
- sglang/check_env.py +1 -1
- sglang/launch_server.py +3 -2
- sglang/srt/_custom_ops.py +118 -0
- sglang/srt/configs/device_config.py +17 -0
- sglang/srt/configs/load_config.py +84 -0
- sglang/srt/configs/model_config.py +161 -4
- sglang/srt/configs/qwen2vl.py +5 -8
- sglang/srt/constrained/outlines_backend.py +6 -1
- sglang/srt/constrained/outlines_jump_forward.py +8 -1
- sglang/srt/distributed/__init__.py +3 -0
- sglang/srt/distributed/communication_op.py +34 -0
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang/srt/distributed/parallel_state.py +1275 -0
- sglang/srt/distributed/utils.py +223 -0
- sglang/srt/hf_transformers_utils.py +37 -1
- sglang/srt/layers/attention/flashinfer_backend.py +13 -15
- sglang/srt/layers/attention/torch_native_backend.py +285 -0
- sglang/srt/layers/fused_moe_patch.py +20 -11
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +17 -3
- sglang/srt/layers/quantization/__init__.py +34 -0
- sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +7 -11
- sglang/srt/managers/detokenizer_manager.py +7 -4
- sglang/srt/managers/image_processor.py +1 -1
- sglang/srt/managers/io_struct.py +48 -12
- sglang/srt/managers/schedule_batch.py +42 -36
- sglang/srt/managers/schedule_policy.py +7 -4
- sglang/srt/managers/scheduler.py +111 -46
- sglang/srt/managers/session_controller.py +0 -3
- sglang/srt/managers/tokenizer_manager.py +169 -100
- sglang/srt/managers/tp_worker.py +36 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +32 -5
- sglang/srt/model_executor/cuda_graph_runner.py +16 -7
- sglang/srt/model_executor/forward_batch_info.py +9 -4
- sglang/srt/model_executor/model_runner.py +136 -150
- sglang/srt/model_loader/__init__.py +34 -0
- sglang/srt/model_loader/loader.py +1139 -0
- sglang/srt/model_loader/utils.py +41 -0
- sglang/srt/model_loader/weight_utils.py +640 -0
- sglang/srt/models/baichuan.py +9 -10
- sglang/srt/models/chatglm.py +6 -15
- sglang/srt/models/commandr.py +2 -3
- sglang/srt/models/dbrx.py +2 -3
- sglang/srt/models/deepseek.py +4 -11
- sglang/srt/models/deepseek_v2.py +3 -11
- sglang/srt/models/exaone.py +2 -3
- sglang/srt/models/gemma.py +2 -6
- sglang/srt/models/gemma2.py +3 -14
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/gpt2.py +5 -12
- sglang/srt/models/gpt_bigcode.py +6 -22
- sglang/srt/models/grok.py +14 -51
- sglang/srt/models/internlm2.py +2 -3
- sglang/srt/models/internlm2_reward.py +0 -1
- sglang/srt/models/llama.py +97 -27
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_embedding.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +10 -12
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +4 -7
- sglang/srt/models/minicpm3.py +6 -19
- sglang/srt/models/mixtral.py +12 -5
- sglang/srt/models/mixtral_quant.py +2 -3
- sglang/srt/models/mllama.py +3 -7
- sglang/srt/models/olmo.py +2 -8
- sglang/srt/models/olmo2.py +391 -0
- sglang/srt/models/olmoe.py +3 -5
- sglang/srt/models/phi3_small.py +8 -8
- sglang/srt/models/qwen.py +2 -3
- sglang/srt/models/qwen2.py +10 -9
- sglang/srt/models/qwen2_moe.py +4 -11
- sglang/srt/models/qwen2_vl.py +12 -9
- sglang/srt/models/registry.py +99 -0
- sglang/srt/models/stablelm.py +2 -3
- sglang/srt/models/torch_native_llama.py +6 -12
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +4 -11
- sglang/srt/models/yivl.py +2 -3
- sglang/srt/openai_api/adapter.py +10 -6
- sglang/srt/openai_api/protocol.py +1 -0
- sglang/srt/server.py +303 -204
- sglang/srt/server_args.py +65 -31
- sglang/srt/utils.py +253 -48
- sglang/test/test_utils.py +27 -7
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/METADATA +2 -1
- sglang-0.4.0.dist-info/RECORD +184 -0
- sglang/srt/layers/fused_moe_grok/__init__.py +0 -1
- sglang/srt/layers/fused_moe_grok/fused_moe.py +0 -692
- sglang/srt/layers/fused_moe_grok/layer.py +0 -630
- sglang-0.3.6.post2.dist-info/RECORD +0 -164
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
sglang/srt/server.py
CHANGED
@@ -23,6 +23,7 @@ import json
 import logging
 import multiprocessing as mp
 import os
+import signal
 import threading
 import time
 from http import HTTPStatus
@@ -51,8 +52,11 @@ from sglang.srt.managers.io_struct import (
     CloseSessionReqInput,
     EmbeddingReqInput,
     GenerateReqInput,
+    GetWeightsByNameReqInput,
+    InitWeightsUpdateGroupReqInput,
     OpenSessionReqInput,
-    UpdateWeightReqInput,
+    UpdateWeightFromDiskReqInput,
+    UpdateWeightsFromDistributedReqInput,
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -79,7 +83,7 @@ from sglang.srt.utils import (
     configure_logger,
     delete_directory,
     is_port_available,
-    kill_child_process,
+    kill_process_tree,
     maybe_set_triton_cache_manager,
     prepare_model_and_tokenizer,
     set_prometheus_multiproc_dir,
@@ -92,7 +96,7 @@ logger = logging.getLogger(__name__)

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

-
+# Fast API
 app = FastAPI()
 app.add_middleware(
     CORSMiddleware,
@@ -103,7 +107,7 @@ app.add_middleware(
 )

 tokenizer_manager: TokenizerManager = None
-
+scheduler_info: Dict = None

 ##### Native API endpoints #####

@@ -149,13 +153,11 @@ async def get_model_info():

 @app.get("/get_server_info")
 async def get_server_info():
-
-
-
-
-
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+    return {
+        **dataclasses.asdict(tokenizer_manager.server_args),  # server args
+        **scheduler_info,
+        "version": __version__,
+    }


 @app.post("/flush_cache")
@@ -171,7 +173,7 @@ async def flush_cache():

 @app.get("/start_profile")
 @app.post("/start_profile")
-async def start_profile():
+async def start_profile_async():
     """Start profiling."""
     tokenizer_manager.start_profile()
     return Response(
@@ -182,7 +184,7 @@ async def start_profile():

 @app.get("/stop_profile")
 @app.post("/stop_profile")
-async def stop_profile():
+async def stop_profile_async():
     """Stop profiling."""
     tokenizer_manager.stop_profile()
     return Response(
@@ -191,11 +193,11 @@ async def stop_profile():
     )


-@app.post("/update_weights")
+@app.post("/update_weights_from_disk")
 @time_func_latency
-async def update_weights(obj: UpdateWeightReqInput, request: Request):
-    """Update the weights inplace without re-launching the server."""
-    success, message = await tokenizer_manager.update_weights(obj, request)
+async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
+    """Update the weights from disk inplace without re-launching the server."""
+    success, message = await tokenizer_manager.update_weights_from_disk(obj, request)
     content = {"success": success, "message": message}
     if success:
         return ORJSONResponse(
@@ -209,6 +211,52 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
     )


+@app.post("/init_weights_update_group")
+async def init_weights_update_group(
+    obj: InitWeightsUpdateGroupReqInput, request: Request
+):
+    """Initialize the parameter update group."""
+    success, message = await tokenizer_manager.init_weights_update_group(obj, request)
+    content = {"success": success, "message": message}
+    if success:
+        return ORJSONResponse(content, status_code=200)
+    else:
+        return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
+
+
+@app.post("/update_weights_from_distributed")
+async def update_weights_from_distributed(
+    obj: UpdateWeightsFromDistributedReqInput, request: Request
+):
+    """Update model parameter from distributed online."""
+    success, message = await tokenizer_manager.update_weights_from_distributed(
+        obj, request
+    )
+    content = {"success": success, "message": message}
+    if success:
+        return ORJSONResponse(content, status_code=200)
+    else:
+        return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
+
+
+@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
+async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
+    """Get model parameter by name."""
+    try:
+        ret = await tokenizer_manager.get_weights_by_name(obj, request)
+        if ret is None:
+            return ORJSONResponse(
+                {"error": {"message": "Get parameter by name failed"}},
+                status_code=HTTPStatus.BAD_REQUEST,
+            )
+        else:
+            return ORJSONResponse(ret, status_code=200)
+    except Exception as e:
+        return ORJSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
+
+
 @app.api_route("/open_session", methods=["GET", "POST"])
 async def open_session(obj: OpenSessionReqInput, request: Request):
     """Open a session, and return its unique session id."""
@@ -233,6 +281,8 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
     )


+# fastapi implicitly converts json in the request to obj (dataclass)
+@app.api_route("/generate", methods=["POST", "PUT"])
 @time_func_latency
 async def generate_request(obj: GenerateReqInput, request: Request):
     """Handle a generate request."""
@@ -266,11 +316,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
     )


-
-app.post("/generate")(generate_request)
-app.put("/generate")(generate_request)
-
-
+@app.api_route("/encode", methods=["POST", "PUT"])
 @time_func_latency
 async def encode_request(obj: EmbeddingReqInput, request: Request):
     """Handle an embedding request."""
@@ -283,10 +329,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
     )


-app.post("/encode")(encode_request)
-app.put("/encode")(encode_request)
-
-
+@app.api_route("/encode", methods=["POST", "PUT"])
 @time_func_latency
 async def classify_request(obj: EmbeddingReqInput, request: Request):
     """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
@@ -299,10 +342,6 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
     )


-app.post("/classify")(classify_request)
-app.put("/classify")(classify_request)
-
-
 ##### OpenAI-compatible API endpoints #####


@@ -380,11 +419,11 @@ def launch_engine(
     server_args: ServerArgs,
 ):
     """
-    Launch the
+    Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
     """

     global tokenizer_manager
-    global _max_total_num_tokens
+    global scheduler_info

     # Configure global environment
     configure_logger(server_args)
@@ -450,8 +489,8 @@ def launch_engine(
     if server_args.chat_template:
         load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)

-    # Wait for model to finish loading
-
+    # Wait for model to finish loading
+    scheduler_infos = []
     for i in range(len(scheduler_pipe_readers)):
         data = scheduler_pipe_readers[i].recv()

@@ -459,10 +498,10 @@
             raise RuntimeError(
                 "Initialization failed. Please see the error messages above."
             )
-
+        scheduler_infos.append(data)

     # Assume all schedulers have same max_total_num_tokens
-
+    scheduler_info = scheduler_infos[0]


 def launch_server(
@@ -476,12 +515,12 @@ def launch_server(

     1. HTTP server: A FastAPI server that routes requests to the engine.
     2. SRT engine:
-        1.
+        1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.
         2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
-        3.
+        3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.

     Note:
-    1. The HTTP server and
+    1. The HTTP server and TokenizerManager both run in the main process.
     2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library.
     """
     launch_engine(server_args=server_args)
@@ -490,7 +529,7 @@ def launch_server(
     if server_args.api_key:
         add_api_key_middleware(app, server_args.api_key)

-    #
+    # Add prometheus middleware
     if server_args.enable_metrics:
         add_prometheus_middleware(app)
         enable_func_timer()
@@ -502,7 +541,7 @@ def launch_server(
     t.start()

     try:
-        #
+        # Update logging configs
         LOGGING_CONFIG["formatters"]["default"][
             "fmt"
         ] = "[%(asctime)s] %(levelprefix)s %(message)s"
@@ -511,6 +550,8 @@ def launch_server(
             "fmt"
         ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
         LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
+
+        # Listen for HTTP requests
         uvicorn.run(
             app,
             host=server_args.host,
@@ -523,15 +564,6 @@ def launch_server(
         t.join()


-async def _get_server_info():
-    return {
-        **dataclasses.asdict(tokenizer_manager.server_args),  # server args
-        "memory_pool_size": await tokenizer_manager.get_memory_pool_size(),  # memory pool size
-        "max_total_num_tokens": _max_total_num_tokens,  # max total num tokens
-        "version": __version__,
-    }
-
-
 def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -562,6 +594,15 @@ def _set_envs_and_config(server_args: ServerArgs):
         "at https://docs.flashinfer.ai/installation.html.",
     )

+    # Register the signal handler.
+    # The child processes will send SIGQUIT to this process when any error happens
+    # This process then clean up the whole process tree
+    def sigquit_handler(signum, frame):
+        kill_process_tree(os.getpid())
+
+    signal.signal(signal.SIGQUIT, sigquit_handler)
+
+    # Set mp start method
     mp.set_start_method("spawn", force=True)


@@ -588,7 +629,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
         logger.error(f"Initialization failed. warmup error: {last_traceback}")
-
+        kill_process_tree(os.getpid())
         return

     model_info = res.json()
@@ -621,9 +662,10 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
         logger.error(f"Initialization failed. warmup error: {last_traceback}")
-
+        kill_process_tree(os.getpid())
         return

+    # Debug print
     # logger.info(f"{res.json()=}")

     logger.info("The server is fired up and ready to roll!")
@@ -634,11 +676,218 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         delete_directory(server_args.model_path)


+STREAM_END_SYMBOL = b"data: [DONE]"
+STREAM_CHUNK_START_SYMBOL = b"data:"
+
+
+class Engine:
+    """
+    SRT Engine without an HTTP server layer.
+
+    This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where
+    launching the HTTP server adds unnecessary complexity or overhead,
+    """
+
+    def __init__(self, log_level: str = "error", *args, **kwargs):
+        """See the arguments in server_args.py::ServerArgs"""
+
+        # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
+        atexit.register(self.shutdown)
+
+        server_args = ServerArgs(*args, log_level=log_level, **kwargs)
+        launch_engine(server_args=server_args)
+
+    def generate(
+        self,
+        # The input prompt. It can be a single prompt or a batch of prompts.
+        prompt: Optional[Union[List[str], str]] = None,
+        sampling_params: Optional[Union[List[Dict], Dict]] = None,
+        # The token ids for text; one can either specify text or input_ids.
+        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+        lora_path: Optional[List[Optional[str]]] = None,
+        stream: bool = False,
+    ):
+        obj = GenerateReqInput(
+            text=prompt,
+            input_ids=input_ids,
+            sampling_params=sampling_params,
+            return_logprob=return_logprob,
+            logprob_start_len=logprob_start_len,
+            top_logprobs_num=top_logprobs_num,
+            lora_path=lora_path,
+            stream=stream,
+        )
+
+        # get the current event loop
+        loop = asyncio.get_event_loop()
+        ret = loop.run_until_complete(generate_request(obj, None))
+
+        if stream is True:
+
+            def generator_wrapper():
+                offset = 0
+                loop = asyncio.get_event_loop()
+                generator = ret.body_iterator
+                while True:
+                    chunk = loop.run_until_complete(generator.__anext__())
+
+                    if chunk.startswith(STREAM_END_SYMBOL):
+                        break
+                    else:
+                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
+                        data["text"] = data["text"][offset:]
+                        offset += len(data["text"])
+                        yield data
+
+            # we cannot yield in the scope of generate() because python does not allow yield + return in the same function
+            # however, it allows to wrap the generator as a subfunction and return
+            return generator_wrapper()
+        else:
+            return ret
+
+    async def async_generate(
+        self,
+        # The input prompt. It can be a single prompt or a batch of prompts.
+        prompt: Optional[Union[List[str], str]] = None,
+        sampling_params: Optional[Dict] = None,
+        # The token ids for text; one can either specify text or input_ids.
+        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+        lora_path: Optional[List[Optional[str]]] = None,
+        stream: bool = False,
+    ):
+        obj = GenerateReqInput(
+            text=prompt,
+            input_ids=input_ids,
+            sampling_params=sampling_params,
+            return_logprob=return_logprob,
+            logprob_start_len=logprob_start_len,
+            top_logprobs_num=top_logprobs_num,
+            lora_path=lora_path,
+            stream=stream,
+        )
+
+        ret = await generate_request(obj, None)
+
+        if stream is True:
+            generator = ret.body_iterator
+
+            async def generator_wrapper():
+
+                offset = 0
+
+                while True:
+                    chunk = await generator.__anext__()
+
+                    if chunk.startswith(STREAM_END_SYMBOL):
+                        break
+                    else:
+                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
+                        data["text"] = data["text"][offset:]
+                        offset += len(data["text"])
+                        yield data
+
+            return generator_wrapper()
+        else:
+            return ret
+
+    def shutdown(self):
+        kill_process_tree(os.getpid(), include_parent=False)
+
+    def get_tokenizer(self):
+        global tokenizer_manager
+
+        if tokenizer_manager is None:
+            raise ReferenceError("Tokenizer Manager is not initialized.")
+        else:
+            return tokenizer_manager.tokenizer
+
+    def encode(
+        self,
+        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
+    ):
+        obj = EmbeddingReqInput(text=prompt)
+
+        # get the current event loop
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(encode_request(obj, None))
+
+    def start_profile(self):
+        tokenizer_manager.start_profile()
+
+    def stop_profile(self):
+        tokenizer_manager.stop_profile()
+
+    def get_server_info(self):
+        return {
+            **dataclasses.asdict(tokenizer_manager.server_args),  # server args
+            **scheduler_info,
+            "version": __version__,
+        }
+
+    def init_weights_update_group(
+        self,
+        master_address: str,
+        master_port: int,
+        rank_offset: int,
+        world_size: int,
+        group_name: str,
+        backend: str = "nccl",
+    ):
+        """Initialize parameter update group."""
+        obj = InitWeightsUpdateGroupReqInput(
+            master_address=master_address,
+            master_port=master_port,
+            rank_offset=rank_offset,
+            world_size=world_size,
+            group_name=group_name,
+            backend=backend,
+        )
+
+        async def _init_group():
+            return await tokenizer_manager.init_weights_update_group(obj, None)
+
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(_init_group())
+
+    def update_weights_from_distributed(self, name, dtype, shape):
+        """Update weights from distributed source."""
+        obj = UpdateWeightsFromDistributedReqInput(
+            name=name,
+            dtype=dtype,
+            shape=shape,
+        )
+
+        async def _update_weights():
+            return await tokenizer_manager.update_weights_from_distributed(obj, None)
+
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(_update_weights())
+
+    def get_weights_by_name(self, name, truncate_size=100):
+        """Get weights by parameter name."""
+        obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
+
+        async def _get_weights():
+            return await tokenizer_manager.get_weights_by_name(obj, None)
+
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(_get_weights())
+
+
 class Runtime:
     """
-    A wrapper for the server.
+    A wrapper for the HTTP server.
     This is used for launching the server in a python program without
     using the commond line interface.
+
+    It is mainly used for the frontend language.
+    You should use the Engine class if you want to do normal offline processing.
     """

     def __init__(
@@ -690,7 +939,7 @@ class Runtime:

     def shutdown(self):
         if self.pid is not None:
-
+            kill_process_tree(self.pid)
             self.pid = None

     def cache_prefix(self, prefix: str):
@@ -786,153 +1035,3 @@ class Runtime:

     def __del__(self):
         self.shutdown()
-
-
-STREAM_END_SYMBOL = b"data: [DONE]"
-STREAM_CHUNK_START_SYMBOL = b"data:"
-
-
-class Engine:
-    """
-    SRT Engine without an HTTP server layer.
-
-    This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where
-    launching the HTTP server adds unnecessary complexity or overhead,
-    """
-
-    def __init__(self, *args, **kwargs):
-
-        # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
-        atexit.register(self.shutdown)
-
-        # runtime server default log level is log
-        # offline engine works in scripts, so we set it to error
-
-        if "log_level" not in kwargs:
-            kwargs["log_level"] = "error"
-
-        server_args = ServerArgs(*args, **kwargs)
-        launch_engine(server_args=server_args)
-
-    def generate(
-        self,
-        # The input prompt. It can be a single prompt or a batch of prompts.
-        prompt: Optional[Union[List[str], str]] = None,
-        sampling_params: Optional[Union[List[Dict], Dict]] = None,
-        # The token ids for text; one can either specify text or input_ids.
-        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
-        return_logprob: Optional[Union[List[bool], bool]] = False,
-        logprob_start_len: Optional[Union[List[int], int]] = None,
-        top_logprobs_num: Optional[Union[List[int], int]] = None,
-        lora_path: Optional[List[Optional[str]]] = None,
-        stream: bool = False,
-    ):
-        obj = GenerateReqInput(
-            text=prompt,
-            input_ids=input_ids,
-            sampling_params=sampling_params,
-            return_logprob=return_logprob,
-            logprob_start_len=logprob_start_len,
-            top_logprobs_num=top_logprobs_num,
-            lora_path=lora_path,
-            stream=stream,
-        )
-
-        # get the current event loop
-        loop = asyncio.get_event_loop()
-        ret = loop.run_until_complete(generate_request(obj, None))
-
-        if stream is True:
-
-            def generator_wrapper():
-                offset = 0
-                loop = asyncio.get_event_loop()
-                generator = ret.body_iterator
-                while True:
-                    chunk = loop.run_until_complete(generator.__anext__())
-
-                    if chunk.startswith(STREAM_END_SYMBOL):
-                        break
-                    else:
-                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
-                        data["text"] = data["text"][offset:]
-                        offset += len(data["text"])
-                        yield data
-
-            # we cannot yield in the scope of generate() because python does not allow yield + return in the same function
-            # however, it allows to wrap the generator as a subfunction and return
-            return generator_wrapper()
-        else:
-            return ret
-
-    async def async_generate(
-        self,
-        # The input prompt. It can be a single prompt or a batch of prompts.
-        prompt: Optional[Union[List[str], str]] = None,
-        sampling_params: Optional[Dict] = None,
-        # The token ids for text; one can either specify text or input_ids.
-        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
-        return_logprob: Optional[Union[List[bool], bool]] = False,
-        logprob_start_len: Optional[Union[List[int], int]] = None,
-        top_logprobs_num: Optional[Union[List[int], int]] = None,
-        lora_path: Optional[List[Optional[str]]] = None,
-        stream: bool = False,
-    ):
-        obj = GenerateReqInput(
-            text=prompt,
-            input_ids=input_ids,
-            sampling_params=sampling_params,
-            return_logprob=return_logprob,
-            logprob_start_len=logprob_start_len,
-            top_logprobs_num=top_logprobs_num,
-            lora_path=lora_path,
-            stream=stream,
-        )
-
-        ret = await generate_request(obj, None)
-
-        if stream is True:
-            generator = ret.body_iterator
-
-            async def generator_wrapper():
-
-                offset = 0
-
-                while True:
-                    chunk = await generator.__anext__()
-
-                    if chunk.startswith(STREAM_END_SYMBOL):
-                        break
-                    else:
-                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
-                        data["text"] = data["text"][offset:]
-                        offset += len(data["text"])
-                        yield data
-
-            return generator_wrapper()
-        else:
-            return ret
-
-    def shutdown(self):
-        kill_child_process()
-
-    def get_tokenizer(self):
-        global tokenizer_manager
-
-        if tokenizer_manager is None:
-            raise ReferenceError("Tokenizer Manager is not initialized.")
-        else:
-            return tokenizer_manager.tokenizer
-
-    def encode(
-        self,
-        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
-    ):
-        obj = EmbeddingReqInput(text=prompt)
-
-        # get the current event loop
-        loop = asyncio.get_event_loop()
-        return loop.run_until_complete(encode_request(obj, None))
-
-    async def get_server_info(self):
-        return await _get_server_info()