sglang 0.3.1.post3__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
- sglang/__init__.py +2 -0
- sglang/api.py +23 -1
- sglang/bench_latency.py +48 -33
- sglang/bench_server_latency.py +0 -6
- sglang/bench_serving.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +14 -1
- sglang/lang/interpreter.py +16 -6
- sglang/lang/ir.py +20 -4
- sglang/srt/configs/model_config.py +11 -9
- sglang/srt/constrained/fsm_cache.py +9 -1
- sglang/srt/constrained/jump_forward.py +15 -2
- sglang/srt/hf_transformers_utils.py +1 -0
- sglang/srt/layers/activation.py +4 -4
- sglang/srt/layers/attention/__init__.py +49 -0
- sglang/srt/layers/attention/flashinfer_backend.py +277 -0
- sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
- sglang/srt/layers/attention/triton_backend.py +161 -0
- sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
- sglang/srt/layers/fused_moe/patch.py +117 -0
- sglang/srt/layers/layernorm.py +4 -4
- sglang/srt/layers/logits_processor.py +19 -15
- sglang/srt/layers/pooler.py +3 -3
- sglang/srt/layers/quantization/__init__.py +0 -2
- sglang/srt/layers/radix_attention.py +6 -4
- sglang/srt/layers/sampler.py +6 -4
- sglang/srt/layers/torchao_utils.py +18 -0
- sglang/srt/lora/lora.py +20 -21
- sglang/srt/lora/lora_manager.py +97 -25
- sglang/srt/managers/detokenizer_manager.py +31 -18
- sglang/srt/managers/image_processor.py +187 -0
- sglang/srt/managers/io_struct.py +99 -75
- sglang/srt/managers/schedule_batch.py +187 -68
- sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
- sglang/srt/managers/scheduler.py +1021 -0
- sglang/srt/managers/tokenizer_manager.py +120 -247
- sglang/srt/managers/tp_worker.py +28 -925
- sglang/srt/mem_cache/memory_pool.py +34 -52
- sglang/srt/mem_cache/radix_cache.py +5 -5
- sglang/srt/model_executor/cuda_graph_runner.py +25 -25
- sglang/srt/model_executor/forward_batch_info.py +94 -97
- sglang/srt/model_executor/model_runner.py +76 -78
- sglang/srt/models/baichuan.py +10 -10
- sglang/srt/models/chatglm.py +12 -12
- sglang/srt/models/commandr.py +10 -10
- sglang/srt/models/dbrx.py +12 -12
- sglang/srt/models/deepseek.py +10 -10
- sglang/srt/models/deepseek_v2.py +14 -15
- sglang/srt/models/exaone.py +10 -10
- sglang/srt/models/gemma.py +10 -10
- sglang/srt/models/gemma2.py +11 -11
- sglang/srt/models/gpt_bigcode.py +10 -10
- sglang/srt/models/grok.py +10 -10
- sglang/srt/models/internlm2.py +10 -10
- sglang/srt/models/llama.py +22 -10
- sglang/srt/models/llama_classification.py +5 -5
- sglang/srt/models/llama_embedding.py +4 -4
- sglang/srt/models/llama_reward.py +142 -0
- sglang/srt/models/llava.py +39 -33
- sglang/srt/models/llavavid.py +31 -28
- sglang/srt/models/minicpm.py +10 -10
- sglang/srt/models/minicpm3.py +14 -15
- sglang/srt/models/mixtral.py +10 -10
- sglang/srt/models/mixtral_quant.py +10 -10
- sglang/srt/models/olmoe.py +10 -10
- sglang/srt/models/qwen.py +10 -10
- sglang/srt/models/qwen2.py +11 -11
- sglang/srt/models/qwen2_moe.py +10 -10
- sglang/srt/models/stablelm.py +10 -10
- sglang/srt/models/torch_native_llama.py +506 -0
- sglang/srt/models/xverse.py +10 -10
- sglang/srt/models/xverse_moe.py +10 -10
- sglang/srt/openai_api/adapter.py +7 -0
- sglang/srt/sampling/sampling_batch_info.py +36 -27
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +170 -119
- sglang/srt/server_args.py +54 -27
- sglang/srt/utils.py +101 -128
- sglang/test/runners.py +76 -33
- sglang/test/test_programs.py +38 -5
- sglang/test/test_utils.py +53 -9
- sglang/version.py +1 -1
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/METADATA +42 -23
- sglang-0.3.3.dist-info/RECORD +139 -0
- sglang/srt/layers/attention_backend.py +0 -482
- sglang/srt/managers/controller_multi.py +0 -207
- sglang/srt/managers/controller_single.py +0 -164
- sglang-0.3.1.post3.dist-info/RECORD +0 -134
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
sglang/srt/sampling/sampling_batch_info.py
CHANGED
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, List
 import torch
 
 import sglang.srt.sampling.penaltylib as penaltylib
+from sglang.srt.constrained import RegexGuide
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -13,22 +14,24 @@ if TYPE_CHECKING:
 
 @dataclasses.dataclass
 class SamplingBatchInfo:
-    # Basic Info
-    vocab_size: int
-
     # Batched sampling params
-    temperatures: torch.Tensor
-    top_ps: torch.Tensor
-    top_ks: torch.Tensor
-    min_ps: torch.Tensor
+    temperatures: torch.Tensor
+    top_ps: torch.Tensor
+    top_ks: torch.Tensor
+    min_ps: torch.Tensor
 
     # Dispatch in CUDA graph
-    need_min_p_sampling: bool
+    need_min_p_sampling: bool
 
     # Bias Tensors
+    vocab_size: int
     logit_bias: torch.Tensor = None
     vocab_mask: torch.Tensor = None
 
+    # FSM states
+    regex_fsms: List[RegexGuide] = None
+    regex_fsm_states: List[int] = None
+
     # Penalizer
     penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
     linear_penalties: torch.Tensor = None
@@ -37,24 +40,30 @@ class SamplingBatchInfo:
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         reqs = batch.reqs
-
-
-        with torch.device("cuda"):
-            ret.temperatures = torch.tensor(
+        with batch.input_ids.device:
+            temperatures = torch.tensor(
                 [r.sampling_params.temperature for r in reqs],
                 dtype=torch.float,
             ).view(-1, 1)
-
+            top_ps = torch.tensor(
                 [r.sampling_params.top_p for r in reqs], dtype=torch.float
             )
-
+            top_ks = torch.tensor(
                 [r.sampling_params.top_k for r in reqs], dtype=torch.int
             )
-
+            min_ps = torch.tensor(
                 [r.sampling_params.min_p for r in reqs], dtype=torch.float
             )
 
-        ret
+        ret = cls(
+            temperatures=temperatures,
+            top_ps=top_ps,
+            top_ks=top_ks,
+            min_ps=min_ps,
+            need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
+            vocab_size=vocab_size,
+        )
+        # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
 
         # Each penalizers will do nothing if they evaluate themselves as not required by looking at
         # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
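The rewritten `from_schedule_batch` builds its tensors under `with batch.input_ids.device:` instead of a hard-coded `with torch.device("cuda"):`, so the sampling tensors are created on whatever device the batch already lives on. A minimal sketch of that PyTorch 2.x device-context pattern (the tensor names here are illustrative, not the actual sglang objects):

```python
import torch

# Stand-in for batch.input_ids; it could live on CPU or any GPU.
input_ids = torch.tensor([1, 2, 3])

# torch.device objects act as context managers in PyTorch >= 2.0:
# tensors created inside inherit that device unless one is given explicitly.
with input_ids.device:
    temperatures = torch.tensor([0.7, 1.0], dtype=torch.float).view(-1, 1)
    top_ps = torch.tensor([0.9, 0.95], dtype=torch.float)

assert temperatures.device == input_ids.device
print(temperatures.device, top_ps.device)
```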
@@ -102,24 +111,24 @@ class SamplingBatchInfo:
                    )
                self.linear_penalties = penalizer.apply(self.linear_penalties)
 
-    def update_regex_vocab_mask(self
-        has_regex = any(
+    def update_regex_vocab_mask(self):
+        has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
 
         # Reset the vocab mask
         self.vocab_mask = None
 
         if has_regex:
             self.vocab_mask = torch.zeros(
-
+                len(self.temperatures), self.vocab_size, dtype=torch.bool, device="cuda"
             )
-            for i,
-                if
+            for i, regex_fsm in enumerate(self.regex_fsms):
+                if regex_fsm is not None:
                     self.vocab_mask[i].fill_(1)
                     self.vocab_mask[i][
-
+                        regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
                     ] = 0
 
-    def
+    def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
 
         for item in [
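`update_regex_vocab_mask` now reads the per-request `regex_fsms` stored on the batch and builds a boolean mask over the vocabulary (1 means blocked, 0 means allowed by the FSM). Downstream, such a mask is typically applied by pushing blocked logits to negative infinity before softmax. The snippet below is a generic illustration of that masking step, not the sglang sampler itself:

```python
import torch

batch_size, vocab_size = 2, 8
logits = torch.randn(batch_size, vocab_size)

# vocab_mask[i, t] == True means token t is blocked for request i, mirroring
# the fill_(1) / "allowed tokens set to 0" convention in the diff above.
vocab_mask = torch.ones(batch_size, vocab_size, dtype=torch.bool)
allowed = [[0, 3, 5], [1, 2]]  # illustrative FSM-allowed token ids per request
for i, tokens in enumerate(allowed):
    vocab_mask[i, tokens] = False

masked_logits = logits.masked_fill(vocab_mask, float("-inf"))
probs = torch.softmax(masked_logits, dim=-1)
print(probs)  # probability mass lands only on the allowed tokens of each row
```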
@@ -129,9 +138,9 @@ class SamplingBatchInfo:
             "min_ps",
             "logit_bias",
         ]:
-
-            if
-                setattr(self, item,
+            value = getattr(self, item, None)
+            if value is not None:  # logit_bias can be None
+                setattr(self, item, value[new_indices])
 
     @staticmethod
     def merge_bias_tensor(
@@ -153,7 +162,7 @@ class SamplingBatchInfo:
 
         return None
 
-    def
+    def merge_batch(self, other: "SamplingBatchInfo"):
         self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
 
         for item in [
sglang/srt/sampling/sampling_params.py
CHANGED
@@ -26,7 +26,7 @@ class SamplingParams:
         max_new_tokens: int = 128,
         min_new_tokens: int = 0,
         stop: Optional[Union[str, List[str]]] = None,
-        stop_token_ids: Optional[List[int]] =
+        stop_token_ids: Optional[List[int]] = None,
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
@@ -49,6 +49,8 @@ class SamplingParams:
         self.presence_penalty = presence_penalty
         self.repetition_penalty = repetition_penalty
         self.stop_strs = stop
+        if stop_token_ids is None:
+            stop_token_ids = []
         self.stop_token_ids = {*stop_token_ids}
         self.max_new_tokens = max_new_tokens
         self.min_new_tokens = min_new_tokens
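The `stop_token_ids` change follows the standard Python idiom for avoiding shared mutable defaults: the old default is truncated in this view, but replacing a container default with `None` and materializing the list inside `__init__` guarantees every instance gets its own container. A small standalone sketch of the difference, using a toy class rather than the real `SamplingParams`:

```python
class Buggy:
    def __init__(self, stop_token_ids=[]):    # one list shared by every call
        self.stop_token_ids = stop_token_ids


class Fixed:
    def __init__(self, stop_token_ids=None):  # pattern used in the diff above
        if stop_token_ids is None:
            stop_token_ids = []
        self.stop_token_ids = {*stop_token_ids}


a, b = Buggy(), Buggy()
a.stop_token_ids.append(7)
print(b.stop_token_ids)                       # [7]  -- leaked into another instance

c, d = Fixed(), Fixed([7])
print(c.stop_token_ids, d.stop_token_ids)     # set() {7} -- independent
```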
sglang/srt/server.py
CHANGED
@@ -19,11 +19,13 @@ SRT = SGLang Runtime.
 """
 
 import asyncio
+import atexit
 import dataclasses
 import json
 import logging
 import multiprocessing as mp
 import os
+import random
 import threading
 import time
 from http import HTTPStatus
@@ -41,21 +43,15 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
-from sglang.srt.constrained import disable_cache
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.managers.
-    start_controller_process as start_controller_process_multi,
-)
-from sglang.srt.managers.controller_single import launch_tp_servers
-from sglang.srt.managers.controller_single import (
-    start_controller_process as start_controller_process_single,
-)
-from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
+from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
 from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     GenerateReqInput,
+    RewardReqInput,
     UpdateWeightReqInput,
 )
+from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
@@ -74,15 +70,12 @@ from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     add_api_key_middleware,
-    allocate_init_ports,
     assert_pkg_version,
     configure_logger,
-
-    is_hip,
+    is_port_available,
     kill_child_process,
     maybe_set_triton_cache_manager,
-
-    prepare_tokenizer,
+    prepare_model_and_tokenizer,
     set_ulimit,
 )
 from sglang.utils import get_exception_traceback
@@ -127,6 +120,7 @@ async def health_generate(request: Request) -> Response:
 
 @app.get("/get_model_info")
 async def get_model_info():
+    """Get the model information."""
     result = {
         "model_path": tokenizer_manager.model_path,
         "is_generation": tokenizer_manager.is_generation,
@@ -136,11 +130,13 @@ async def get_model_info():
 
 @app.get("/get_server_args")
 async def get_server_args():
+    """Get the server arguments."""
     return dataclasses.asdict(tokenizer_manager.server_args)
 
 
 @app.get("/flush_cache")
 async def flush_cache():
+    """Flush the radix cache."""
     tokenizer_manager.flush_cache()
     return Response(
         content="Cache flushed.\nPlease check backend logs for more details. "
@@ -151,7 +147,7 @@ async def flush_cache():
 
 @app.post("/update_weights")
 async def update_weights(obj: UpdateWeightReqInput, request: Request):
-
+    """Update the weights inplace without re-launching the server."""
     success, message = await tokenizer_manager.update_weights(obj, request)
     content = {"success": success, "message": message}
     if success:
@@ -166,6 +162,7 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
     )
 
 
+# fastapi implicitly converts json in the request to obj (dataclass)
 async def generate_request(obj: GenerateReqInput, request: Request):
     """Handle a generate request."""
     if obj.stream:
@@ -213,6 +210,21 @@ app.post("/encode")(encode_request)
 app.put("/encode")(encode_request)
 
 
+async def judge_request(obj: RewardReqInput, request: Request):
+    """Handle a reward model request."""
+    try:
+        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        return JSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
+
+
+app.post("/judge")(judge_request)
+app.put("/judge")(judge_request)
+
+
 @app.post("/v1/completions")
 async def openai_v1_completions(raw_request: Request):
     return await v1_completions(tokenizer_manager, raw_request)
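The new `/judge` route feeds a `RewardReqInput` through the same `tokenizer_manager.generate_request` path used for generation and embedding. Judging from the `"conv"` field that `Runtime.encode` posts later in this diff, a client call against a running reward-model server would look roughly like the sketch below; the payload shape and port are assumptions inferred from this diff, not a documented API:

```python
import requests

# Assumes a reward-model server is already running locally, e.g. on port 30000.
url = "http://127.0.0.1:30000/judge"

# A single conversation as role/content messages (shape inferred from the
# "conv" field used by Runtime.encode further down in this diff).
conv = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "Paris."},
]

resp = requests.post(url, json={"conv": conv})
resp.raise_for_status()
print(resp.json())  # reward/score output as returned by the server
```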
@@ -280,102 +292,95 @@ async def retrieve_file_content(file_id: str):
     return await v1_retrieve_file_content(file_id)
 
 
-def
+def launch_engine(
     server_args: ServerArgs,
-    pipe_finish_writer: Optional[mp.connection.Connection] = None,
 ):
-    """
+    """
+    Launch the Tokenizer Manager in the main process, the Scheduler in a subprocess, and the Detokenizer Manager in another subprocess.
+    """
+
     global tokenizer_manager
 
+    # Configure global environment
     configure_logger(server_args)
-
     server_args.check_server_args()
     _set_envs_and_config(server_args)
 
     # Allocate ports for inter-process communications
-
-        server_args.port,
-        server_args.additional_ports,
-        server_args.dp_size,
-    )
-    ports = server_args.additional_ports
-    port_args = PortArgs(
-        tokenizer_port=ports[0],
-        controller_port=ports[1],
-        detokenizer_port=ports[2],
-        nccl_ports=ports[3:],
-    )
+    port_args = PortArgs.init_new(server_args)
     logger.info(f"{server_args=}")
 
-    #
-    server_args.model_path =
-
-
-    # Launch processes for multi-node tensor parallelism
-    if server_args.nnodes > 1 and server_args.node_rank != 0:
-        tp_size_local = server_args.tp_size // server_args.nnodes
-        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
-        tp_rank_range = list(
-            range(
-                server_args.node_rank * tp_size_local,
-                (server_args.node_rank + 1) * tp_size_local,
-            )
-        )
-        procs = launch_tp_servers(
-            gpu_ids,
-            tp_rank_range,
-            server_args,
-            ports[3],
-        )
-
-        try:
-            for p in procs:
-                p.join()
-        finally:
-            kill_child_process(os.getpid(), including_parent=False)
-        return
-
-    # Launch processes
-    pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
+    # If using model from www.modelscope.cn, first download the model.
+    server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
+        server_args.model_path, server_args.tokenizer_path
+    )
 
-
-
-
-
-
-
-
+    # Launch tensor parallel scheduler processes
+    scheduler_procs = []
+    scheduler_pipe_readers = []
+    tp_size_per_node = server_args.tp_size // server_args.nnodes
+    tp_rank_range = range(
+        tp_size_per_node * server_args.node_rank,
+        tp_size_per_node * (server_args.node_rank + 1),
     )
-
+    for tp_rank in tp_rank_range:
+        reader, writer = mp.Pipe(duplex=False)
+        gpu_id = tp_rank % tp_size_per_node
+        proc = mp.Process(
+            target=run_scheduler_process,
+            args=(server_args, port_args, gpu_id, tp_rank, writer),
+        )
+        proc.start()
+        scheduler_procs.append(proc)
+        scheduler_pipe_readers.append(reader)
+
+    if server_args.node_rank >= 1:
+        # For other nodes, they do not need to run tokenizer or detokenizer,
+        # so they can just wait here.
+        while True:
+            pass
 
-
-
-        target=
+    # Launch detokenizer process
+    detoken_proc = mp.Process(
+        target=run_detokenizer_process,
         args=(
             server_args,
             port_args,
-            pipe_detoken_writer,
         ),
     )
-
+    detoken_proc.start()
 
+    # Launch tokenizer process
     tokenizer_manager = TokenizerManager(server_args, port_args)
     if server_args.chat_template:
         load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
 
-    # Wait for
-
-
-
-
-
-
-
-
-
-
-
-
+    # Wait for model to finish loading
+    for i in range(len(scheduler_pipe_readers)):
+        scheduler_pipe_readers[i].recv()
+
+
+def launch_server(
+    server_args: ServerArgs,
+    pipe_finish_writer: Optional[mp.connection.Connection] = None,
+):
+    """
+    Launch SRT (SGLang Runtime) Server
+
+    The SRT server consists of an HTTP server and the SRT engine.
+
+    1. HTTP server: A FastAPI server that routes requests to the engine.
+    2. SRT engine:
+        1. Tokenizer Manager: Tokenizes the requests and sends them to the scheduler.
+        2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
+        3. Detokenizer Manager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
+
+    Note:
+    1. The HTTP server and Tokenizer Manager both run in the main process.
+    2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library.
+    """
+
+    launch_engine(server_args=server_args)
 
     # Add api key authorization
     if server_args.api_key:
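`launch_engine` replaces the old controller processes with one scheduler subprocess per local tensor-parallel rank plus a detokenizer subprocess, then blocks on each scheduler's pipe until model loading finishes. The skeleton below reproduces just that spawn-and-handshake pattern with plain `multiprocessing`; the worker body is a stand-in, not the real `run_scheduler_process`:

```python
import multiprocessing as mp


def run_worker(rank: int, writer):
    # ... load the model, allocate memory pools, etc. ...
    writer.send("ready")   # tell the parent that initialization finished
    writer.close()


if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)   # same start method as the diff
    procs, readers = [], []
    for rank in range(2):                      # e.g. tp_size_per_node = 2
        reader, writer = mp.Pipe(duplex=False)
        p = mp.Process(target=run_worker, args=(rank, writer))
        p.start()
        procs.append(p)
        readers.append(reader)

    # The parent blocks here until every worker has reported back.
    for reader in readers:
        print(reader.recv())

    for p in procs:
        p.join()
```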
@@ -388,7 +393,7 @@ def launch_server(
     t.start()
 
     try:
-        # Listen for requests
+        # Listen for HTTP requests
         uvicorn.run(
             app,
             host=server_args.host,
@@ -412,14 +417,6 @@ def _set_envs_and_config(server_args: ServerArgs):
     # Set ulimit
     set_ulimit()
 
-    # Enable show time cost for debugging
-    if server_args.show_time_cost:
-        enable_show_time_cost()
-
-    # Disable disk cache
-    if server_args.disable_disk_cache:
-        disable_cache()
-
     # Fix triton bugs
     if server_args.tp_size * server_args.dp_size > 1:
         # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
@@ -435,9 +432,7 @@ def _set_envs_and_config(server_args: ServerArgs):
             "at https://docs.flashinfer.ai/installation.html.",
         )
 
-
-    # to figure out a better method of not using fork later
-    mp.set_start_method("spawn", force=True)
+    mp.set_start_method("spawn", force=True)
 
 
 def _wait_and_warmup(server_args, pipe_finish_writer, pid):
@@ -467,7 +462,6 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
         return
 
     model_info = res.json()
-
     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
     max_new_tokens = 8 if model_info["is_generation"] else 1
@@ -501,7 +495,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
 
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
-        pipe_finish_writer.send("
+        pipe_finish_writer.send("ready")
 
 
 class Runtime:
@@ -520,18 +514,20 @@ class Runtime:
         """See the arguments in server_args.py::ServerArgs"""
         self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
 
+        # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
+        atexit.register(self.shutdown)
+
         # Pre-allocate ports
-
-
-
-
-
+        for port in range(10000, 40000):
+            if is_port_available(port):
+                break
+            port += 1
+        self.server_args.port = port
 
         self.url = self.server_args.url()
-        self.generate_url =
-            f"http://{self.server_args.host}:{self.server_args.port}/generate"
-        )
+        self.generate_url = self.url + "/generate"
 
+        # NOTE: We store pid instead of proc to fix some issues during __delete__
         self.pid = None
         pipe_reader, pipe_writer = mp.Pipe(duplex=False)
 
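`Runtime` now picks its port by scanning 10000-39999 with `is_port_available` instead of the removed `allocate_init_ports` helper. The sglang implementation of that helper is not shown in this diff; a typical check simply tries to bind the port, along the lines of the hypothetical stand-in below:

```python
import socket


def is_port_available(port: int) -> bool:
    """Hypothetical stand-in: True if nothing is currently bound to `port`."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.bind(("", port))
            return True
        except OSError:
            return False


# Same scan the Runtime constructor performs in the diff above.
port = next(p for p in range(10000, 40000) if is_port_available(p))
print(f"first free port: {port}")
```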
@@ -548,7 +544,7 @@ class Runtime:
         except EOFError:
             init_state = ""
 
-        if init_state != "
+        if init_state != "ready":
             self.shutdown()
             raise RuntimeError(
                 "Initialization failed. Please see the error messages above."
@@ -599,7 +595,7 @@ class Runtime:
             if chunk == "data: [DONE]\n\n":
                 break
             data = json.loads(chunk[5:].strip("\n"))
-            if
+            if "text" in data:
                 cur = data["text"][pos:]
                 if cur:
                     yield cur
@@ -635,16 +631,71 @@ class Runtime:
 
     def encode(
         self,
-        prompt: Union[str, List[str]],
+        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
     ):
-
-
-
-
-
-
-
+        if isinstance(prompt, str) or isinstance(prompt[0], str):
+            # embedding
+            json_data = {
+                "text": prompt,
+            }
+            response = requests.post(
+                self.url + "/encode",
+                json=json_data,
+            )
+        else:
+            # reward
+            json_data = {
+                "conv": prompt,
+            }
+            response = requests.post(
+                self.url + "/judge",
+                json=json_data,
+            )
         return json.dumps(response.json())
 
     def __del__(self):
         self.shutdown()
+
+
+class Engine:
+    """
+    SRT Engine without an HTTP server layer.
+
+    This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where
+    launching the HTTP server adds unnecessary complexity or overhead,
+    """
+
+    def __init__(self, *args, **kwargs):
+
+        # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
+        atexit.register(self.shutdown)
+
+        server_args = ServerArgs(*args, **kwargs)
+        launch_engine(server_args=server_args)
+
+    def generate(
+        self,
+        prompt: Union[str, List[str]],
+        sampling_params: Optional[Dict] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+        lora_path: Optional[List[Optional[str]]] = None,
+    ):
+        obj = GenerateReqInput(
+            text=prompt,
+            sampling_params=sampling_params,
+            return_logprob=return_logprob,
+            logprob_start_len=logprob_start_len,
+            top_logprobs_num=top_logprobs_num,
+            lora_path=lora_path,
+        )
+
+        # get the current event loop
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(generate_request(obj, None))
+
+    def shutdown(self):
+        kill_child_process(os.getpid(), including_parent=False)
+
+    # TODO (ByronHsu): encode and async generate