sglang 0.2.13__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +6 -0
- sglang/bench_latency.py +7 -3
- sglang/bench_serving.py +50 -26
- sglang/check_env.py +15 -0
- sglang/lang/chat_template.py +10 -5
- sglang/lang/compiler.py +4 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +9 -0
- sglang/launch_server.py +8 -1
- sglang/srt/conversation.py +50 -1
- sglang/srt/hf_transformers_utils.py +22 -23
- sglang/srt/layers/activation.py +24 -1
- sglang/srt/layers/decode_attention.py +338 -50
- sglang/srt/layers/fused_moe/layer.py +2 -2
- sglang/srt/layers/layernorm.py +3 -0
- sglang/srt/layers/logits_processor.py +60 -23
- sglang/srt/layers/radix_attention.py +3 -4
- sglang/srt/layers/sampler.py +154 -0
- sglang/srt/managers/controller_multi.py +2 -8
- sglang/srt/managers/controller_single.py +7 -10
- sglang/srt/managers/detokenizer_manager.py +20 -9
- sglang/srt/managers/io_struct.py +44 -11
- sglang/srt/managers/policy_scheduler.py +5 -2
- sglang/srt/managers/schedule_batch.py +52 -167
- sglang/srt/managers/tokenizer_manager.py +192 -83
- sglang/srt/managers/tp_worker.py +130 -43
- sglang/srt/mem_cache/memory_pool.py +82 -8
- sglang/srt/mm_utils.py +79 -7
- sglang/srt/model_executor/cuda_graph_runner.py +49 -11
- sglang/srt/model_executor/forward_batch_info.py +59 -27
- sglang/srt/model_executor/model_runner.py +210 -61
- sglang/srt/models/chatglm.py +4 -12
- sglang/srt/models/commandr.py +5 -1
- sglang/srt/models/dbrx.py +5 -1
- sglang/srt/models/deepseek.py +5 -1
- sglang/srt/models/deepseek_v2.py +5 -1
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/gemma2.py +15 -7
- sglang/srt/models/gpt_bigcode.py +5 -1
- sglang/srt/models/grok.py +16 -2
- sglang/srt/models/internlm2.py +5 -1
- sglang/srt/models/llama2.py +7 -3
- sglang/srt/models/llama_classification.py +2 -2
- sglang/srt/models/llama_embedding.py +4 -0
- sglang/srt/models/llava.py +176 -59
- sglang/srt/models/minicpm.py +5 -1
- sglang/srt/models/mixtral.py +5 -1
- sglang/srt/models/mixtral_quant.py +5 -1
- sglang/srt/models/qwen.py +5 -2
- sglang/srt/models/qwen2.py +13 -3
- sglang/srt/models/qwen2_moe.py +5 -14
- sglang/srt/models/stablelm.py +5 -1
- sglang/srt/openai_api/adapter.py +117 -37
- sglang/srt/sampling/sampling_batch_info.py +209 -0
- sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -0
- sglang/srt/server.py +84 -56
- sglang/srt/server_args.py +43 -15
- sglang/srt/utils.py +26 -16
- sglang/test/runners.py +23 -31
- sglang/test/simple_eval_common.py +9 -10
- sglang/test/simple_eval_gpqa.py +2 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_math.py +2 -1
- sglang/test/simple_eval_mmlu.py +2 -1
- sglang/test/test_activation.py +55 -0
- sglang/test/test_utils.py +36 -53
- sglang/version.py +1 -1
- {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/METADATA +92 -25
- sglang-0.2.14.dist-info/RECORD +114 -0
- {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
- sglang/launch_server_llavavid.py +0 -29
- sglang-0.2.13.dist-info/RECORD +0 -112
- {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
- {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/sampling/sampling_batch_info.py
ADDED
@@ -0,0 +1,209 @@
+from __future__ import annotations
+
+import dataclasses
+from typing import TYPE_CHECKING, List
+
+import torch
+
+import sglang.srt.sampling.penaltylib as penaltylib
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import ScheduleBatch
+
+
+@dataclasses.dataclass
+class SamplingBatchInfo:
+    # Basic Info
+    vocab_size: int
+
+    # Batched sampling params
+    temperatures: torch.Tensor = None
+    top_ps: torch.Tensor = None
+    top_ks: torch.Tensor = None
+    min_ps: torch.Tensor = None
+
+    # Dispatch in CUDA graph
+    need_min_p_sampling: bool = False
+
+    # Bias Tensors
+    logit_bias: torch.Tensor = None
+    vocab_mask: torch.Tensor = None
+
+    # Penalizer
+    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
+    linear_penalties: torch.Tensor = None
+    scaling_penalties: torch.Tensor = None
+
+    def has_bias(self):
+        return (
+            self.logit_bias is not None
+            or self.vocab_mask is not None
+            or self.linear_penalties is not None
+            or self.scaling_penalties is not None
+        )
+
+    @classmethod
+    def dummy_one(cls, max_bs: int, vocab_size: int):
+        ret = cls(vocab_size=vocab_size)
+        ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float, device="cuda")
+        ret.top_ps = torch.ones((max_bs,), dtype=torch.float, device="cuda")
+        ret.top_ks = torch.ones((max_bs,), dtype=torch.int, device="cuda")
+        ret.min_ps = torch.zeros((max_bs,), dtype=torch.float, device="cuda")
+        return ret
+
+    def __getitem__(self, key):
+        if isinstance(key, slice):
+            # NOTE: We do not use cuda graph when there is bias tensors
+            assert not self.has_bias()
+            return SamplingBatchInfo(
+                vocab_size=self.vocab_size,
+                temperatures=self.temperatures[key],
+                top_ps=self.top_ps[key],
+                top_ks=self.top_ks[key],
+                min_ps=self.min_ps[key],
+                need_min_p_sampling=self.need_min_p_sampling,
+            )
+        else:
+            raise NotImplementedError
+
+    def inplace_assign(self, bs: int, other: SamplingBatchInfo):
+        # NOTE: We do not use cuda graph when there is bias tensors
+        assert not self.has_bias()
+
+        self.vocab_size = other.vocab_size
+        self.need_min_p_sampling = other.need_min_p_sampling
+
+        self.temperatures[:bs] = other.temperatures
+        self.top_ps[:bs] = other.top_ps
+        self.top_ks[:bs] = other.top_ks
+        self.min_ps[:bs] = other.min_ps
+
+    @classmethod
+    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
+        device = "cuda"
+        reqs = batch.reqs
+        ret = cls(vocab_size=vocab_size)
+
+        ret.temperatures = torch.tensor(
+            [r.sampling_params.temperature for r in reqs],
+            dtype=torch.float,
+            device=device,
+        ).view(-1, 1)
+        ret.top_ps = torch.tensor(
+            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
+        )
+        ret.top_ks = torch.tensor(
+            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
+        )
+        ret.min_ps = torch.tensor(
+            [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
+        )
+        ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
+
+        # Each penalizers will do nothing if they evaluate themselves as not required by looking at
+        # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
+        # should not add hefty computation overhead other than simple checks.
+        #
+        # While we choose not to even create the class instances if they are not required, this
+        # could add additional complexity to the {ScheduleBatch} class, especially we need to
+        # handle {filter_batch()} and {merge()} cases as well.
+        ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+            vocab_size=vocab_size,
+            batch=batch,
+            device=device,
+            Penalizers={
+                penaltylib.BatchedFrequencyPenalizer,
+                penaltylib.BatchedMinNewTokensPenalizer,
+                penaltylib.BatchedPresencePenalizer,
+                penaltylib.BatchedRepetitionPenalizer,
+            },
+        )
+
+        # Handle logit bias but only allocate when needed
+        ret.logit_bias = None
+
+        ret.update_regex_vocab_mask(batch)
+
+        return ret
+
+    def prepare_penalties(self):
+        self.scaling_penalties = None
+        self.linear_penalties = None
+
+        for penalizer in self.penalizer_orchestrator.penalizers.values():
+            if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
+                if penalizer.is_prepared():
+                    self.scaling_penalties = penalizer.cumulated_repetition_penalties
+            else:
+                if penalizer.is_prepared():
+                    if self.linear_penalties is None:
+                        bs = self.penalizer_orchestrator.batch.batch_size()
+                        self.linear_penalties = torch.zeros(
+                            (bs, self.vocab_size),
+                            dtype=torch.float32,
+                            device="cuda",
+                        )
+                    self.linear_penalties = penalizer.apply(self.linear_penalties)
+
+    def update_regex_vocab_mask(self, batch: ScheduleBatch):
+        bs, reqs = batch.batch_size(), batch.reqs
+        device = "cuda"
+        has_regex = any(req.regex_fsm is not None for req in reqs)
+
+        # Reset the vocab mask
+        self.vocab_mask = None
+
+        if has_regex:
+            for i, req in enumerate(reqs):
+                if req.regex_fsm is not None:
+                    if self.vocab_mask is None:
+                        self.vocab_mask = torch.zeros(
+                            bs, self.vocab_size, dtype=torch.bool, device=device
+                        )
+                    self.vocab_mask[i][
+                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
+                    ] = 1
+
+    def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor):
+        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+
+        for item in [
+            "temperatures",
+            "top_ps",
+            "top_ks",
+            "min_ps",
+            "logit_bias",
+        ]:
+            self_val = getattr(self, item, None)
+            if self_val is not None:  # logit_bias can be None
+                setattr(self, item, self_val[new_indices])
+
+    def merge(self, other: "SamplingBatchInfo"):
+        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+
+        for item in [
+            "temperatures",
+            "top_ps",
+            "top_ks",
+            "min_ps",
+        ]:
+            self_val = getattr(self, item, None)
+            other_val = getattr(other, item, None)
+            setattr(self, item, torch.concat([self_val, other_val]))
+
+        # logit_bias can be None
+        if self.logit_bias is not None or other.logit_bias is not None:
+            vocab_size = (
+                self.logit_bias.shape[1]
+                if self.logit_bias is not None
+                else other.logit_bias.shape[1]
+            )
+            if self.logit_bias is None:
+                self.logit_bias = torch.zeros(
+                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
+                )
+            if other.logit_bias is None:
+                other.logit_bias = torch.zeros(
+                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
+                )
+            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
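The new `SamplingBatchInfo` packs each request's sampling parameters into batched CUDA tensors so the sampler can filter an entire `[batch, vocab]` logits tensor at once. The sketch below is illustrative only: it is not the `sglang/srt/layers/sampler.py` added in this release, and the function name plus the treatment of `top_k == -1` as "disabled" are assumptions made for the example.

```python
import torch

def apply_sampling_filters(logits: torch.Tensor, info) -> torch.Tensor:
    """Turn raw logits into filtered, renormalized probabilities using the
    batched fields of a SamplingBatchInfo-like object (illustrative only)."""
    # Temperature scaling; `temperatures` has shape [batch, 1].
    probs = torch.softmax(logits / info.temperatures, dim=-1)

    # Min-p: drop tokens whose probability is below min_p * max probability.
    if info.need_min_p_sampling:
        max_probs = probs.max(dim=-1, keepdim=True).values
        probs = probs.masked_fill(probs < info.min_ps.view(-1, 1) * max_probs, 0.0)

    # Sort once, then apply top-k and top-p on the sorted probabilities.
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    ranks = torch.arange(probs.shape[-1], device=probs.device).expand_as(probs)

    # Top-k: treat -1 as "disabled" by substituting the full vocabulary size.
    top_ks = torch.where(
        info.top_ks > 0, info.top_ks, torch.full_like(info.top_ks, probs.shape[-1])
    )
    sorted_probs = sorted_probs.masked_fill(ranks >= top_ks.view(-1, 1), 0.0)

    # Top-p (nucleus): keep the smallest prefix whose cumulative mass reaches top_p.
    cum_mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
    sorted_probs = sorted_probs.masked_fill(
        cum_mass_before > info.top_ps.view(-1, 1), 0.0
    )

    # Scatter the filtered probabilities back to vocabulary order and renormalize.
    probs = torch.zeros_like(probs).scatter_(-1, sorted_idx, sorted_probs)
    return probs / probs.sum(dim=-1, keepdim=True)
```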
sglang/srt/sampling_params.py → sglang/srt/sampling/sampling_params.py
RENAMED
@@ -30,6 +30,7 @@ class SamplingParams:
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
+        min_p: float = 0.0,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         repetition_penalty: float = 1.0,
@@ -42,6 +43,7 @@ class SamplingParams:
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
+        self.min_p = min_p
         self.frequency_penalty = frequency_penalty
         self.presence_penalty = presence_penalty
         self.repetition_penalty = repetition_penalty
@@ -69,6 +71,8 @@ class SamplingParams:
             )
         if not 0.0 < self.top_p <= 1.0:
             raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
+        if not 0.0 <= self.min_p <= 1.0:
+            raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
         if self.top_k < -1 or self.top_k == 0:
             raise ValueError(
                 f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}."
@@ -123,3 +127,17 @@ class SamplingParams:
                 else:
                     stop_str_max_len = max(stop_str_max_len, len(stop_str))
             self.stop_str_max_len = stop_str_max_len
+
+    def to_srt_kwargs(self):
+        return {
+            "max_new_tokens": self.max_new_tokens,
+            "stop": self.stop_strs,
+            "stop_token_ids": list(self.stop_token_ids),
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k": self.top_k,
+            "frequency_penalty": self.frequency_penalty,
+            "presence_penalty": self.presence_penalty,
+            "ignore_eos": self.ignore_eos,
+            "regex": self.regex,
+        }
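For reference, `min_p` keeps only tokens whose probability is at least `min_p` times the probability of the most likely token; the default of 0.0 disables the filter, which is why the validation above only requires a value in [0, 1]. The snippet below is a hypothetical client-side sketch that forwards a params object exposing the new `to_srt_kwargs()` helper to the server's `/generate` endpoint; the local URL and port are assumptions, while the payload shape matches `Runtime.generate` later in this diff.

```python
import requests

def generate(prompt: str, params) -> dict:
    # `params` is any object exposing to_srt_kwargs(), as added in the hunk above.
    payload = {
        "text": prompt,
        "sampling_params": params.to_srt_kwargs(),
    }
    # Default local server address; adjust host/port as needed.
    return requests.post("http://127.0.0.1:30000/generate", json=payload).json()
```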
sglang/srt/server.py
CHANGED
@@ -24,7 +24,6 @@ import json
 import logging
 import multiprocessing as mp
 import os
-import sys
 import threading
 import time
 from http import HTTPStatus
@@ -34,7 +33,6 @@ from typing import Dict, List, Optional, Union
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

 import aiohttp
-import psutil
 import requests
 import uvicorn
 import uvloop
@@ -52,7 +50,11 @@ from sglang.srt.managers.controller_single import (
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import
+from sglang.srt.managers.io_struct import (
+    EmbeddingReqInput,
+    GenerateReqInput,
+    UpdateWeightReqInput,
+)
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
@@ -72,6 +74,7 @@ from sglang.srt.utils import (
     add_api_key_middleware,
     allocate_init_ports,
     assert_pkg_version,
+    configure_logger,
     enable_show_time_cost,
     kill_child_process,
     maybe_set_triton_cache_manager,
@@ -92,10 +95,25 @@ tokenizer_manager = None

 @app.get("/health")
 async def health() -> Response:
-    """
+    """Check the health of the http server."""
     return Response(status_code=200)


+@app.get("/health_generate")
+async def health_generate(request: Request) -> Response:
+    """Check the health of the inference server by generating one token."""
+    gri = GenerateReqInput(
+        text="s", sampling_params={"max_new_tokens": 1, "temperature": 0.7}
+    )
+    try:
+        async for _ in tokenizer_manager.generate_request(gri, request):
+            break
+        return Response(status_code=200)
+    except Exception as e:
+        logger.exception(e)
+        return Response(status_code=503)
+
+
 @app.get("/get_model_info")
 async def get_model_info():
     result = {
@@ -120,6 +138,23 @@ async def flush_cache():
     )


+@app.post("/update_weights")
+async def update_weights(obj: UpdateWeightReqInput, request: Request):
+
+    success, message = await tokenizer_manager.update_weights(obj, request)
+    content = {"message": message, "success": str(success)}
+    if success:
+        return JSONResponse(
+            content,
+            status_code=HTTPStatus.OK,
+        )
+    else:
+        return JSONResponse(
+            content,
+            status_code=HTTPStatus.BAD_REQUEST,
+        )
+
+
 async def generate_request(obj: GenerateReqInput, request: Request):
     """Handle a generate request."""
     if obj.stream:
@@ -236,15 +271,12 @@ def launch_server(
     """Launch an HTTP server."""
     global tokenizer_manager

-    logging.basicConfig(
-        level=getattr(logging, server_args.log_level.upper()),
-        format="%(message)s",
-    )
+    configure_logger(server_args)

     server_args.check_server_args()
     _set_envs_and_config(server_args)

-    # Allocate ports
+    # Allocate ports for inter-process communications
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
         server_args.additional_ports,
@@ -264,27 +296,29 @@ def launch_server(
     server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)

     # Launch processes for multi-node tensor parallelism
-    if server_args.nnodes > 1:
-
-
-
-
-
-
-                range(
-                    server_args.node_rank * tp_size_local,
-                    (server_args.node_rank + 1) * tp_size_local,
-                )
-            )
-            procs = launch_tp_servers(
-                gpu_ids,
-                tp_rank_range,
-                server_args,
-                ports[3],
-                model_overide_args,
+    if server_args.nnodes > 1 and server_args.node_rank != 0:
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
+        tp_rank_range = list(
+            range(
+                server_args.node_rank * tp_size_local,
+                (server_args.node_rank + 1) * tp_size_local,
             )
-
-
+        )
+        procs = launch_tp_servers(
+            gpu_ids,
+            tp_rank_range,
+            server_args,
+            ports[3],
+            model_overide_args,
+        )
+
+        try:
+            for p in procs:
+                p.join()
+        finally:
+            kill_child_process(os.getpid(), including_parent=False)
+        return

     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
@@ -297,11 +331,13 @@ def launch_server(
         start_process = start_controller_process_single
     else:
         start_process = start_controller_process_multi
+
     proc_controller = mp.Process(
         target=start_process,
         args=(server_args, port_args, pipe_controller_writer, model_overide_args),
     )
     proc_controller.start()
+
     proc_detoken = mp.Process(
         target=start_detokenizer_process,
         args=(
@@ -319,15 +355,11 @@ def launch_server(
     if controller_init_state != "init ok" or detoken_init_state != "init ok":
         proc_controller.kill()
         proc_detoken.kill()
-
-
-
-
-        print(
-            f"Initialization failed. detoken_init_state: {detoken_init_state}",
-            flush=True,
+        raise RuntimeError(
+            "Initialization failed. "
+            f"controller_init_state: {controller_init_state}, "
+            f"detoken_init_state: {detoken_init_state}"
         )
-        sys.exit(1)
     assert proc_controller.is_alive() and proc_detoken.is_alive()

     # Add api key authorization
@@ -336,12 +368,12 @@ def launch_server(

     # Send a warmup request
     t = threading.Thread(
-        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
+        target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid())
     )
     t.start()

-    # Listen for requests
     try:
+        # Listen for requests
         uvicorn.run(
             app,
             host=server_args.host,
@@ -389,7 +421,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     )


-def _wait_and_warmup(server_args, pipe_finish_writer):
+def _wait_and_warmup(server_args, pipe_finish_writer, pid):
     headers = {}
     url = server_args.url()
     if server_args.api_key:
@@ -412,8 +444,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
     if not success:
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
-
-
+        logger.error(f"Initialization failed. warmup error: {last_traceback}")
+        kill_child_process(pid, including_parent=False)
+        return

     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
@@ -438,21 +471,13 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
             timeout=600,
         )
         assert res.status_code == 200, f"{res}"
-    except Exception
+    except Exception:
         last_traceback = get_exception_traceback()
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
-
-
-
-    # Print warnings here
-    if server_args.disable_radix_cache and server_args.chunked_prefill_size is not None:
-        logger.warning(
-            "You set both `--disable-radix-cache` and `--chunked-prefill-size`. "
-            "This combination is an experimental feature and we noticed it can lead to "
-            "wrong generation results. If you want to use chunked prefill, it is recommended "
-            "not using `--disable-radix-cache`."
-        )
+        logger.error(f"Initialization failed. warmup error: {last_traceback}")
+        kill_child_process(pid, including_parent=False)
+        return

     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
@@ -490,6 +515,7 @@ class Runtime:

         self.pid = None
         pipe_reader, pipe_writer = mp.Pipe(duplex=False)
+
         proc = mp.Process(
             target=launch_server,
             args=(self.server_args, model_overide_args, pipe_writer),
@@ -566,15 +592,17 @@ class Runtime:

     def generate(
         self,
-        prompt: str,
+        prompt: Union[str, List[str]],
         sampling_params: Optional[Dict] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
     ):
         json_data = {
             "text": prompt,
             "sampling_params": sampling_params,
             "return_logprob": return_logprob,
+            "logprob_start_len": logprob_start_len,
             "top_logprobs_num": top_logprobs_num,
         }
         response = requests.post(
@@ -585,7 +613,7 @@ class Runtime:

     def encode(
         self,
-        prompt: str,
+        prompt: Union[str, List[str]],
     ):
         json_data = {
             "text": prompt,
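A minimal client-side sketch for the two endpoints added above, assuming a server on the default local address. The exact fields of `UpdateWeightReqInput` are not shown in this diff, so the `model_path` field in the example is an assumption.

```python
import requests

BASE_URL = "http://127.0.0.1:30000"

# /health_generate returns 200 only if the server can actually decode a token,
# and 503 if generation fails.
print(requests.get(f"{BASE_URL}/health_generate").status_code)

# /update_weights responds with {"message": ..., "success": ...} and an HTTP
# 200 or 400 status depending on whether the in-place weight reload succeeded.
resp = requests.post(
    f"{BASE_URL}/update_weights",
    json={"model_path": "/path/to/new/checkpoint"},  # field name is an assumption
)
print(resp.status_code, resp.json())
```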
sglang/srt/server_args.py
CHANGED
@@ -33,11 +33,13 @@ class ServerArgs:
     skip_tokenizer_init: bool = False
     load_format: str = "auto"
     dtype: str = "auto"
+    kv_cache_dtype: str = "auto"
     trust_remote_code: bool = True
     context_length: Optional[int] = None
     quantization: Optional[str] = None
     served_model_name: Optional[str] = None
     chat_template: Optional[str] = None
+    is_embedding: bool = False

     # Port
     host: str = "127.0.0.1"
@@ -79,12 +81,14 @@ class ServerArgs:
     disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
     disable_cuda_graph: bool = False
+    disable_cuda_graph_padding: bool = False
     disable_disk_cache: bool = False
+    disable_custom_all_reduce: bool = False
+    enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     enable_p2p_check: bool = False
     enable_mla: bool = False
-
-    efficient_weight_load: bool = False
+    triton_attention_reduce_in_fp32: bool = False

     # Distributed args
     nccl_init_addr: Optional[str] = None
@@ -193,11 +197,23 @@ class ServerArgs:
             '* "float" is shorthand for FP32 precision.\n'
             '* "float32" for FP32 precision.',
         )
+        parser.add_argument(
+            "--kv-cache-dtype",
+            type=str,
+            default=ServerArgs.kv_cache_dtype,
+            choices=["auto", "fp8_e5m2"],
+            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
+        )
         parser.add_argument(
             "--trust-remote-code",
             action="store_true",
             help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
         )
+        parser.add_argument(
+            "--is-embedding",
+            action="store_true",
+            help="Whether to use a CausalLM as an embedding model.",
+        )
         parser.add_argument(
             "--context-length",
             type=int,
@@ -391,11 +407,27 @@ class ServerArgs:
             action="store_true",
             help="Disable cuda graph.",
         )
+        parser.add_argument(
+            "--disable-cuda-graph-padding",
+            action="store_true",
+            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
+        )
         parser.add_argument(
             "--disable-disk-cache",
             action="store_true",
             help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
         )
+        parser.add_argument(
+            "--disable-custom-all-reduce",
+            action="store_true",
+            default=False,
+            help="Disable the custom all-reduce kernel and fall back to NCCL.",
+        )
+        parser.add_argument(
+            "--enable-mixed-chunk",
+            action="store_true",
+            help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
+        )
         parser.add_argument(
             "--enable-torch-compile",
             action="store_true",
@@ -409,13 +441,13 @@ class ServerArgs:
         parser.add_argument(
             "--enable-mla",
             action="store_true",
-            help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2",
+            help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
         )
         parser.add_argument(
-            "--attention-reduce-in-fp32",
+            "--triton-attention-reduce-in-fp32",
            action="store_true",
             help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
-            "This only affects Triton attention kernels",
+            "This only affects Triton attention kernels.",
         )
         parser.add_argument(
             "--efficient-weight-load",
@@ -433,15 +465,6 @@ class ServerArgs:
     def url(self):
         return f"http://{self.host}:{self.port}"

-    def print_mode_args(self):
-        return (
-            f"disable_flashinfer={self.disable_flashinfer}, "
-            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
-            f"disable_radix_cache={self.disable_radix_cache}, "
-            f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
-            f"disable_disk_cache={self.disable_disk_cache}, "
-        )
-
     def check_server_args(self):
         assert (
             self.tp_size % self.nnodes == 0
@@ -449,8 +472,13 @@ class ServerArgs:
         assert not (
             self.dp_size > 1 and self.node_rank is not None
         ), "multi-node data parallel is not supported"
+        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
+            logger.info(
+                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
+            )
+            self.trust_remote_code = False
         if "gemma-2" in self.model_path.lower():
-            logger.info(
+            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
             self.disable_flashinfer = False

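The new server arguments can also be set programmatically when embedding the runtime. The sketch below assumes the usual pattern where `sglang.Runtime` forwards its keyword arguments to `ServerArgs`; the model path is a placeholder.

```python
import sglang as sgl

# Keyword arguments are forwarded to ServerArgs, so the 0.2.14 fields added
# above can be set directly (placeholder model path).
runtime = sgl.Runtime(
    model_path="meta-llama/Meta-Llama-3.1-8B-Instruct",
    kv_cache_dtype="fp8_e5m2",        # store the KV cache in FP8 (CUDA 11.8+)
    disable_cuda_graph_padding=True,  # skip cuda graph whenever padding would be needed
    enable_mixed_chunk=True,          # allow mixing prefill and decode in one batch
)

print(runtime.generate("The capital of France is", {"max_new_tokens": 16}))
runtime.shutdown()
```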