sglang 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -0
- sglang/api.py +10 -2
- sglang/bench_latency.py +151 -40
- sglang/bench_serving.py +46 -22
- sglang/check_env.py +24 -2
- sglang/global_config.py +0 -1
- sglang/lang/backend/base_backend.py +3 -1
- sglang/lang/backend/openai.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +46 -29
- sglang/lang/choices.py +164 -0
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +6 -13
- sglang/lang/ir.py +14 -5
- sglang/srt/constrained/base_tool_cache.py +1 -1
- sglang/srt/constrained/fsm_cache.py +12 -2
- sglang/srt/layers/activation.py +33 -0
- sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
- sglang/srt/layers/extend_attention.py +6 -1
- sglang/srt/layers/layernorm.py +65 -0
- sglang/srt/layers/logits_processor.py +6 -1
- sglang/srt/layers/pooler.py +50 -0
- sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
- sglang/srt/layers/radix_attention.py +4 -7
- sglang/srt/managers/detokenizer_manager.py +31 -9
- sglang/srt/managers/io_struct.py +63 -0
- sglang/srt/managers/policy_scheduler.py +173 -25
- sglang/srt/managers/schedule_batch.py +174 -380
- sglang/srt/managers/tokenizer_manager.py +197 -112
- sglang/srt/managers/tp_worker.py +299 -364
- sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
- sglang/srt/mem_cache/chunk_cache.py +43 -20
- sglang/srt/mem_cache/memory_pool.py +10 -15
- sglang/srt/mem_cache/radix_cache.py +74 -40
- sglang/srt/model_executor/cuda_graph_runner.py +27 -12
- sglang/srt/model_executor/forward_batch_info.py +319 -0
- sglang/srt/model_executor/model_runner.py +30 -47
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +1 -1
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -2
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/internlm2.py +3 -8
- sglang/srt/models/llama2.py +5 -5
- sglang/srt/models/llama_classification.py +1 -1
- sglang/srt/models/llama_embedding.py +88 -0
- sglang/srt/models/llava.py +1 -2
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -1
- sglang/srt/models/mixtral.py +1 -1
- sglang/srt/models/mixtral_quant.py +1 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -12
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +189 -39
- sglang/srt/openai_api/protocol.py +43 -1
- sglang/srt/sampling/penaltylib/__init__.py +13 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
- sglang/srt/sampling_params.py +31 -4
- sglang/srt/server.py +93 -21
- sglang/srt/server_args.py +30 -19
- sglang/srt/utils.py +31 -13
- sglang/test/run_eval.py +10 -1
- sglang/test/runners.py +63 -63
- sglang/test/simple_eval_humaneval.py +2 -8
- sglang/test/simple_eval_mgsm.py +203 -0
- sglang/test/srt/sampling/penaltylib/utils.py +337 -0
- sglang/test/test_layernorm.py +60 -0
- sglang/test/test_programs.py +4 -2
- sglang/test/test_utils.py +21 -3
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/METADATA +50 -31
- sglang-0.2.12.dist-info/RECORD +112 -0
- sglang/srt/layers/linear.py +0 -884
- sglang/srt/layers/quantization/__init__.py +0 -64
- sglang/srt/layers/quantization/fp8.py +0 -677
- sglang-0.2.10.dist-info/RECORD +0 -100
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py
ADDED
@@ -0,0 +1,79 @@
+import typing
+
+import torch
+
+from ..orchestrator import _BatchedPenalizer, _TokenIDs
+
+
+class BatchedPresencePenalizer(_BatchedPenalizer):
+    """
+    Presence penalizer penalizes tokens based on their presence in the output.
+    """
+
+    presence_penalties: torch.Tensor = None
+    cumulated_presence_penalties: torch.Tensor = None
+
+    def _is_required(self) -> bool:
+        return any(
+            req.sampling_params.presence_penalty != 0.0
+            for req in self.orchestrator.reqs()
+        )
+
+    def _prepare(self):
+        self.cumulated_presence_penalties = (
+            torch.tensor(
+                data=[0.0 for _ in self.orchestrator.reqs()],
+                dtype=torch.float32,
+                device=self.orchestrator.device,
+            )
+            .unsqueeze_(1)
+            .repeat(1, self.orchestrator.vocab_size)
+        )
+
+        self.presence_penalties = (
+            torch.tensor(
+                data=[
+                    req.sampling_params.presence_penalty
+                    for req in self.orchestrator.reqs()
+                ],
+                dtype=torch.float32,
+                device=self.orchestrator.device,
+            )
+            .unsqueeze_(1)
+            .expand_as(self.cumulated_presence_penalties)
+        )
+
+    def _teardown(self):
+        del self.presence_penalties
+        del self.cumulated_presence_penalties
+
+        self.presence_penalties = None
+        self.cumulated_presence_penalties = None
+
+    def _cumulate_input_tokens(self, input_ids: _TokenIDs):
+        pass
+
+    def _cumulate_output_tokens(self, output_ids: _TokenIDs):
+        mask = output_ids.occurrence_count() > 0
+        self.cumulated_presence_penalties[mask] = self.presence_penalties[mask]
+
+    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
+        logits -= self.cumulated_presence_penalties
+        return logits
+
+    def _filter(
+        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
+    ):
+        self.presence_penalties = self.presence_penalties[indices_tensor_to_keep]
+        self.cumulated_presence_penalties = self.cumulated_presence_penalties[
+            indices_tensor_to_keep
+        ]
+
+    def _merge(self, their: "BatchedPresencePenalizer"):
+        self.presence_penalties = torch.cat(
+            [self.presence_penalties, their.presence_penalties], dim=0
+        )
+        self.cumulated_presence_penalties = torch.cat(
+            [self.cumulated_presence_penalties, their.cumulated_presence_penalties],
+            dim=0,
+        )
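The new BatchedPresencePenalizer keeps a per-request row of penalties over the whole vocabulary and, in _apply, subtracts it from the logits of every token that has already appeared in the output. A minimal standalone sketch of that update rule, outside the orchestrator plumbing (the tensor values and shapes below are illustrative, not taken from the package):

import torch

# One request, a toy vocabulary of 6 tokens, presence_penalty = 0.5.
logits = torch.tensor([[2.0, 1.0, 0.5, -0.3, 0.0, 1.2]])
presence_penalty = 0.5

# Tokens 1 and 4 have already been generated for this request.
output_ids = torch.tensor([[1, 4]])
seen = torch.zeros_like(logits)
seen.scatter_(1, output_ids, 1.0)

# Presence penalty is flat: any token seen at least once loses the full penalty.
logits = logits - presence_penalty * seen

Unlike a frequency penalty, the subtracted amount does not grow with repeated occurrences, which is why _cumulate_output_tokens above only needs a boolean occurrence mask.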
sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py
ADDED
@@ -0,0 +1,83 @@
+import typing
+
+import torch
+
+from ..orchestrator import _BatchedPenalizer, _TokenIDs
+
+
+class BatchedRepetitionPenalizer(_BatchedPenalizer):
+    """
+    Repetition penalizer penalizes tokens based on their repetition in the input and output.
+    """
+
+    repetition_penalties: torch.Tensor = None
+    cumulated_repetition_penalties: torch.Tensor = None
+
+    def _is_required(self) -> bool:
+        return any(
+            req.sampling_params.repetition_penalty != 1.0
+            for req in self.orchestrator.reqs()
+        )
+
+    def _prepare(self):
+        self.cumulated_repetition_penalties = (
+            torch.tensor(
+                data=[1.0 for _ in self.orchestrator.reqs()],
+                dtype=torch.float32,
+                device=self.orchestrator.device,
+            )
+            .unsqueeze_(1)
+            .repeat(1, self.orchestrator.vocab_size)
+        )
+
+        self.repetition_penalties = (
+            torch.tensor(
+                data=[
+                    req.sampling_params.repetition_penalty
+                    for req in self.orchestrator.reqs()
+                ],
+                dtype=torch.float32,
+                device=self.orchestrator.device,
+            )
+            .unsqueeze_(1)
+            .expand_as(self.cumulated_repetition_penalties)
+        )
+
+    def _teardown(self):
+        del self.repetition_penalties
+        del self.cumulated_repetition_penalties
+
+        self.repetition_penalties = None
+        self.cumulated_repetition_penalties = None
+
+    def _cumulate_input_tokens(self, input_ids: _TokenIDs):
+        mask = input_ids.occurrence_count() > 0
+        self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
+
+    def _cumulate_output_tokens(self, output_ids: _TokenIDs):
+        mask = output_ids.occurrence_count() > 0
+        self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
+
+    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
+        return torch.where(
+            logits > 0,
+            logits / self.cumulated_repetition_penalties,
+            logits * self.cumulated_repetition_penalties,
+        )
+
+    def _filter(
+        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
+    ):
+        self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
+        self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
+            indices_tensor_to_keep
+        ]
+
+    def _merge(self, their: "BatchedRepetitionPenalizer"):
+        self.repetition_penalties = torch.cat(
+            [self.repetition_penalties, their.repetition_penalties], dim=0
+        )
+        self.cumulated_repetition_penalties = torch.cat(
+            [self.cumulated_repetition_penalties, their.cumulated_repetition_penalties],
+            dim=0,
+        )
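_apply implements the usual divide-or-multiply form of the repetition penalty: logits of repeated tokens are divided by the penalty when positive and multiplied by it when negative, so both directions become less likely. A small self-contained illustration of that rule with made-up values:

import torch

logits = torch.tensor([[3.0, -1.0, 0.5, -2.0]])
# A penalty of 1.2 only where the token already occurred (tokens 0 and 3 here);
# untouched positions keep the neutral value 1.0, matching the initialization in _prepare.
penalties = torch.tensor([[1.2, 1.0, 1.0, 1.2]])

penalized = torch.where(logits > 0, logits / penalties, logits * penalties)
# tensor([[ 2.5000, -1.0000,  0.5000, -2.4000]])

Because the neutral value is 1.0 rather than 0.0, _is_required checks for repetition_penalty != 1.0.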
sglang/srt/sampling_params.py
CHANGED
@@ -23,13 +23,16 @@ _SAMPLING_EPS = 1e-6
 class SamplingParams:
     def __init__(
         self,
-        max_new_tokens: int =
+        max_new_tokens: int = 128,
+        min_new_tokens: int = 0,
         stop: Optional[Union[str, List[str]]] = None,
+        stop_token_ids: Optional[List[int]] = [],
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
+        repetition_penalty: float = 1.0,
         ignore_eos: bool = False,
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
@@ -42,8 +45,11 @@ class SamplingParams:
         self.top_k = top_k
         self.frequency_penalty = frequency_penalty
         self.presence_penalty = presence_penalty
+        self.repetition_penalty = repetition_penalty
         self.stop_strs = stop
+        self.stop_token_ids = {*stop_token_ids}
         self.max_new_tokens = max_new_tokens
+        self.min_new_tokens = min_new_tokens
         self.ignore_eos = ignore_eos
         self.skip_special_tokens = skip_special_tokens
         self.spaces_between_special_tokens = spaces_between_special_tokens
@@ -80,23 +86,44 @@ class SamplingParams:
             raise ValueError(
                 "presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}."
             )
+        if not 0.0 <= self.repetition_penalty <= 2.0:
+            raise ValueError(
+                "repetition_penalty must be in (0, 2], got "
+                f"{self.repetition_penalty}."
+            )
+        if not 0 <= self.min_new_tokens:
+            raise ValueError(
+                f"min_new_tokens must be in (0, max_new_tokens], got "
+                f"{self.min_new_tokens}."
+            )
         if self.max_new_tokens is not None:
             if self.max_new_tokens < 0:
                 raise ValueError(
                     f"max_new_tokens must be at least 0, got {self.max_new_tokens}."
                 )
+            if not self.min_new_tokens <= self.max_new_tokens:
+                raise ValueError(
+                    f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got "
+                    f"{self.min_new_tokens}."
+                )

     def normalize(self, tokenizer):
         # Process stop strings
         if self.stop_strs is None:
             self.stop_strs = []
-            self.
+            if self.stop_token_ids is None:
+                self.stop_str_max_len = 0
+            else:
+                self.stop_str_max_len = 1
         else:
             if isinstance(self.stop_strs, str):
                 self.stop_strs = [self.stop_strs]

             stop_str_max_len = 0
             for stop_str in self.stop_strs:
-
-
+                if tokenizer is not None:
+                    stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
+                    stop_str_max_len = max(stop_str_max_len, len(stop_str_ids))
+                else:
+                    stop_str_max_len = max(stop_str_max_len, len(stop_str))
             self.stop_str_max_len = stop_str_max_len
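Taken together, the constructor now accepts min_new_tokens, stop_token_ids, and repetition_penalty next to the existing knobs, with new range checks on repetition_penalty and on min_new_tokens relative to max_new_tokens. A hedged usage sketch (the values are arbitrary, and the name of the validation entry point is assumed rather than shown in this hunk):

from sglang.srt.sampling_params import SamplingParams

params = SamplingParams(
    max_new_tokens=64,
    min_new_tokens=4,        # force at least 4 generated tokens
    stop_token_ids=[2],      # stop on a specific token id, e.g. an EOS id
    repetition_penalty=1.1,  # mild penalty; 1.0 means disabled
)
params.verify()  # assumed caller of the range checks added above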
sglang/srt/server.py
CHANGED
@@ -52,13 +52,15 @@ from sglang.srt.managers.controller_single import (
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
     v1_batches,
     v1_chat_completions,
     v1_completions,
+    v1_delete_file,
+    v1_embeddings,
     v1_files_create,
     v1_retrieve_batch,
     v1_retrieve_file,
@@ -73,7 +75,8 @@ from sglang.srt.utils import (
     enable_show_time_cost,
     kill_child_process,
     maybe_set_triton_cache_manager,
-
+    prepare_model,
+    prepare_tokenizer,
     set_ulimit,
 )
 from sglang.utils import get_exception_traceback
@@ -97,6 +100,7 @@ async def health() -> Response:
 async def get_model_info():
     result = {
         "model_path": tokenizer_manager.model_path,
+        "is_generation": tokenizer_manager.is_generation,
     }
     return result

@@ -148,6 +152,21 @@ app.post("/generate")(generate_request)
 app.put("/generate")(generate_request)


+async def encode_request(obj: EmbeddingReqInput, request: Request):
+    """Handle an embedding request."""
+    try:
+        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        return JSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
+
+
+app.post("/encode")(encode_request)
+app.put("/encode")(encode_request)
+
+
 @app.post("/v1/completions")
 async def openai_v1_completions(raw_request: Request):
     return await v1_completions(tokenizer_manager, raw_request)
@@ -158,6 +177,12 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)


+@app.post("/v1/embeddings")
+async def openai_v1_embeddings(raw_request: Request):
+    response = await v1_embeddings(tokenizer_manager, raw_request)
+    return response
+
+
 @app.get("/v1/models")
 def available_models():
     """Show available models."""
@@ -175,6 +200,12 @@ async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("bat
     )


+@app.delete("/v1/files/{file_id}")
+async def delete_file(file_id: str):
+    # https://platform.openai.com/docs/api-reference/files/delete
+    return await v1_delete_file(file_id)
+
+
 @app.post("/v1/batches")
 async def openai_v1_batches(raw_request: Request):
     return await v1_batches(tokenizer_manager, raw_request)
@@ -228,6 +259,10 @@ def launch_server(
     )
     logger.info(f"{server_args=}")

+    # Use model from www.modelscope.cn, first download the model.
+    server_args.model_path = prepare_model(server_args.model_path)
+    server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)
+
     # Launch processes for multi-node tensor parallelism
     if server_args.nnodes > 1:
         if server_args.node_rank != 0:
@@ -340,10 +375,6 @@ def _set_envs_and_config(server_args: ServerArgs):
     # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
     maybe_set_triton_cache_manager()

-    # Set torch compile config
-    if server_args.enable_torch_compile:
-        set_torch_compile_config()
-
     # Set global chat template
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
@@ -353,7 +384,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.1.
+            "0.1.4",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
@@ -367,35 +398,63 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         headers["Authorization"] = f"Bearer {server_args.api_key}"

     # Wait until the server is launched
+    success = False
     for _ in range(120):
         time.sleep(1)
         try:
-            requests.get(url + "/get_model_info", timeout=5, headers=headers)
+            res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
+            assert res.status_code == 200, f"{res}"
+            success = True
             break
-        except requests.exceptions.RequestException:
+        except (AssertionError, requests.exceptions.RequestException) as e:
+            last_traceback = get_exception_traceback()
             pass
+    model_info = res.json()
+
+    if not success:
+        if pipe_finish_writer is not None:
+            pipe_finish_writer.send(last_traceback)
+        print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
+        sys.exit(1)

     # Send a warmup request
+    request_name = "/generate" if model_info["is_generation"] else "/encode"
+    max_new_tokens = 8 if model_info["is_generation"] else 1
+    json_data = {
+        "sampling_params": {
+            "temperature": 0,
+            "max_new_tokens": max_new_tokens,
+        },
+    }
+    if server_args.skip_tokenizer_init:
+        json_data["input_ids"] = [10, 11, 12]
+    else:
+        json_data["text"] = "The capital city of France is"
+
     try:
         for _ in range(server_args.dp_size):
             res = requests.post(
-                url +
-                json=
-                    "text": "The capital city of France is",
-                    "sampling_params": {
-                        "temperature": 0,
-                        "max_new_tokens": 8,
-                    },
-                },
+                url + request_name,
+                json=json_data,
                 headers=headers,
                 timeout=600,
             )
-            assert res.status_code == 200
+            assert res.status_code == 200, f"{res}"
     except Exception as e:
+        last_traceback = get_exception_traceback()
         if pipe_finish_writer is not None:
-            pipe_finish_writer.send(
-        print(f"Initialization failed. warmup error: {
-
+            pipe_finish_writer.send(last_traceback)
+        print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
+        sys.exit(1)
+
+    # Print warnings here
+    if server_args.disable_radix_cache and server_args.chunked_prefill_size is not None:
+        logger.warning(
+            "You set both `--disable-radix-cache` and `--chunked-prefill-size`. "
+            "This combination is an experimental feature and we noticed it can lead to "
+            "wrong generation results. If you want to use chunked prefill, it is recommended "
+            "not using `--disable-radix-cache`."
+        )

     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
@@ -516,5 +575,18 @@ class Runtime:
         )
         return json.dumps(response.json())

+    def encode(
+        self,
+        prompt: str,
+    ):
+        json_data = {
+            "text": prompt,
+        }
+        response = requests.post(
+            self.url + "/encode",
+            json=json_data,
+        )
+        return json.dumps(response.json())
+
     def __del__(self):
         self.shutdown()
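For embedding models the server now exposes both a native /encode route and an OpenAI-compatible /v1/embeddings route, and Runtime.encode wraps the former. A hedged client-side sketch (the host, port, and model name are placeholders, not values from the package):

import requests

base = "http://127.0.0.1:30000"  # placeholder address of a running sglang server

# Native endpoint: the same {"text": ...} payload that Runtime.encode sends.
resp = requests.post(f"{base}/encode", json={"text": "The capital city of France is"})
print(resp.json())

# OpenAI-compatible route; the body follows the OpenAI embeddings request schema.
resp = requests.post(
    f"{base}/v1/embeddings",
    json={"model": "default", "input": "The capital city of France is"},
)
print(resp.json())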
sglang/srt/server_args.py
CHANGED
@@ -27,6 +27,7 @@ class ServerArgs:
     model_path: str
     tokenizer_path: Optional[str] = None
     tokenizer_mode: str = "auto"
+    skip_tokenizer_init: bool = False
     load_format: str = "auto"
     dtype: str = "auto"
     trust_remote_code: bool = True
@@ -42,10 +43,11 @@ class ServerArgs:

     # Memory and scheduling
     mem_fraction_static: Optional[float] = None
-    max_prefill_tokens: Optional[int] = None
     max_running_requests: Optional[int] = None
     max_num_reqs: Optional[int] = None
     max_total_tokens: Optional[int] = None
+    chunked_prefill_size: int = -1
+    max_prefill_tokens: int = 16384
     schedule_policy: str = "lpm"
     schedule_conservativeness: float = 1.0

@@ -62,15 +64,12 @@ class ServerArgs:

     # Other
     api_key: Optional[str] = None
-    file_storage_pth: str = "
+    file_storage_pth: str = "SGLang_storage"

     # Data parallelism
     dp_size: int = 1
     load_balance_method: str = "round_robin"

-    # Chunked Prefill
-    chunked_prefill_size: Optional[int] = None
-
     # Optimization/debug options
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
@@ -96,6 +95,10 @@ class ServerArgs:
         if self.served_model_name is None:
             self.served_model_name = self.model_path

+        if self.chunked_prefill_size <= 0:
+            # Disable chunked prefill
+            self.chunked_prefill_size = None
+
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
                 self.mem_fraction_static = 0.79
@@ -107,6 +110,7 @@ class ServerArgs:
                 self.mem_fraction_static = 0.87
             else:
                 self.mem_fraction_static = 0.88
+
         if isinstance(self.additional_ports, int):
             self.additional_ports = [self.additional_ports]
         elif self.additional_ports is None:
@@ -151,6 +155,11 @@ class ServerArgs:
             "tokenizer if available, and 'slow' will "
             "always use the slow tokenizer.",
         )
+        parser.add_argument(
+            "--skip-tokenizer-init",
+            action="store_true",
+            help="If set, skip init tokenizer and pass input_ids in generate request",
+        )
         parser.add_argument(
             "--load-format",
             type=str,
@@ -226,12 +235,6 @@ class ServerArgs:
             default=ServerArgs.mem_fraction_static,
             help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
         )
-        parser.add_argument(
-            "--max-prefill-tokens",
-            type=int,
-            default=ServerArgs.max_prefill_tokens,
-            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
-        )
         parser.add_argument(
             "--max-running-requests",
             type=int,
@@ -250,6 +253,18 @@ class ServerArgs:
             default=ServerArgs.max_total_tokens,
             help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes.",
         )
+        parser.add_argument(
+            "--chunked-prefill-size",
+            type=int,
+            default=ServerArgs.chunked_prefill_size,
+            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill",
+        )
+        parser.add_argument(
+            "--max-prefill-tokens",
+            type=int,
+            default=ServerArgs.max_prefill_tokens,
+            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
+        )
         parser.add_argument(
             "--schedule-policy",
             type=str,
@@ -264,6 +279,7 @@ class ServerArgs:
             help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
         )
         parser.add_argument(
+            "--tensor-parallel-size",
             "--tp-size",
             type=int,
             default=ServerArgs.tp_size,
@@ -318,6 +334,7 @@ class ServerArgs:

         # Data parallelism
         parser.add_argument(
+            "--data-parallel-size",
             "--dp-size",
             type=int,
             default=ServerArgs.dp_size,
@@ -345,14 +362,6 @@ class ServerArgs:
         )
         parser.add_argument("--node-rank", type=int, help="The node rank.")

-        # Chunked prefill
-        parser.add_argument(
-            "--chunked-prefill-size",
-            type=int,
-            default=ServerArgs.chunked_prefill_size,
-            help="The size of the chunked prefill.",
-        )
-
         # Optimization/debug options
         parser.add_argument(
             "--disable-flashinfer",
@@ -413,6 +422,8 @@ class ServerArgs:

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
+        args.tp_size = args.tensor_parallel_size
+        args.dp_size = args.data_parallel_size
         attrs = [attr.name for attr in dataclasses.fields(cls)]
         return cls(**{attr: getattr(args, attr) for attr in attrs})

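In short, --chunked-prefill-size moves next to --max-prefill-tokens, a non-positive value disables chunked prefill, and --tensor-parallel-size / --data-parallel-size become long-form aliases for --tp-size / --dp-size. A minimal programmatic sketch of the same options (the model path is a placeholder, not a value from the package):

from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model id
    tp_size=2,                  # CLI: --tp-size or the new --tensor-parallel-size
    chunked_prefill_size=-1,    # <= 0 is normalized to None (disabled) in __post_init__
    skip_tokenizer_init=False,  # CLI: --skip-tokenizer-init
)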
sglang/srt/utils.py
CHANGED
@@ -197,6 +197,8 @@ def allocate_init_ports(
 def get_int_token_logit_bias(tokenizer, vocab_size):
     """Get the logit bias for integer-only tokens."""
     # a bug when model's vocab size > tokenizer.vocab_size
+    if tokenizer == None:
+        return [-1e5] * vocab_size
     vocab_size = tokenizer.vocab_size
     logit_bias = np.zeros(vocab_size, dtype=np.float32)
     for t_id in range(vocab_size):
@@ -223,6 +225,15 @@ def is_multimodal_model(model):
     raise ValueError("unrecognized type")


+def is_generation_model(model_architectures):
+    if (
+        "LlamaEmbeddingModel" in model_architectures
+        or "MistralModel" in model_architectures
+    ):
+        return False
+    return True
+
+
 def decode_video_base64(video_base64):
     from PIL import Image

@@ -622,19 +633,6 @@ def receive_addrs(model_port_args, server_args):
     dist.destroy_process_group()


-def set_torch_compile_config():
-    # The following configurations are for torch compile optimizations
-    import torch._dynamo.config
-    import torch._inductor.config
-
-    torch._inductor.config.coordinate_descent_tuning = True
-    torch._inductor.config.triton.unique_kernel_names = True
-    torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
-
-    # FIXME: tmp workaround
-    torch._dynamo.config.accumulated_cache_size_limit = 256
-
-
 def set_ulimit(target_soft_limit=65535):
     resource_type = resource.RLIMIT_NOFILE
     current_soft, current_hard = resource.getrlimit(resource_type)
@@ -705,3 +703,23 @@ def add_api_key_middleware(app, api_key):
         if request.headers.get("Authorization") != "Bearer " + api_key:
             return JSONResponse(content={"error": "Unauthorized"}, status_code=401)
         return await call_next(request)
+
+
+def prepare_model(model_path):
+    if "SGLANG_USE_MODELSCOPE" in os.environ:
+        if not os.path.exists(model_path):
+            from modelscope import snapshot_download
+
+            return snapshot_download(model_path)
+    return model_path
+
+
+def prepare_tokenizer(tokenizer_path):
+    if "SGLANG_USE_MODELSCOPE" in os.environ:
+        if not os.path.exists(tokenizer_path):
+            from modelscope import snapshot_download
+
+            return snapshot_download(
+                tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"]
+            )
+    return tokenizer_path
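prepare_model and prepare_tokenizer are no-ops unless the SGLANG_USE_MODELSCOPE environment variable is set and the given path does not already exist locally; in that case they download from ModelScope, skipping weight files for the tokenizer via ignore_patterns. A short sketch of the intended use, assuming the optional modelscope package is installed and using a placeholder model id:

import os

from sglang.srt.utils import prepare_model, prepare_tokenizer

os.environ["SGLANG_USE_MODELSCOPE"] = "1"  # only presence is checked, any value works

model_path = prepare_model("qwen/Qwen2-7B-Instruct")          # placeholder ModelScope id
tokenizer_path = prepare_tokenizer("qwen/Qwen2-7B-Instruct")  # config/tokenizer files only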
sglang/test/run_eval.py
CHANGED
@@ -16,6 +16,8 @@ from sglang.test.simple_eval_common import (


 def run_eval(args):
+    set_ulimit()
+
     if "OPENAI_API_KEY" not in os.environ:
         os.environ["OPENAI_API_KEY"] = "EMPTY"

@@ -39,6 +41,14 @@ def run_eval(args):
         eval_obj = MathEval(
             filename, equality_checker, args.num_examples, args.num_threads
         )
+    elif args.eval_name == "mgsm":
+        from sglang.test.simple_eval_mgsm import MGSMEval
+
+        eval_obj = MGSMEval(args.num_examples, args.num_threads)
+    elif args.eval_name == "mgsm_en":
+        from sglang.test.simple_eval_mgsm import MGSMEval
+
+        eval_obj = MGSMEval(args.num_examples, args.num_threads, languages=["en"])
     elif args.eval_name == "gpqa":
         from sglang.test.simple_eval_gpqa import GPQAEval

@@ -109,7 +119,6 @@ if __name__ == "__main__":
     parser.add_argument("--eval-name", type=str, default="mmlu")
     parser.add_argument("--num-examples", type=int)
     parser.add_argument("--num-threads", type=int, default=512)
-    set_ulimit()
     args = parser.parse_args()

     run_eval(args)
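run_eval now raises the file-descriptor limit itself via set_ulimit() and recognizes two new eval names backed by the added simple_eval_mgsm module. A small sketch of the new branch in isolation; the sampler and server wiring that run_eval performs around it is omitted, and the numeric arguments are arbitrary:

from sglang.test.simple_eval_mgsm import MGSMEval

num_examples, num_threads = 10, 64

eval_all = MGSMEval(num_examples, num_threads)                   # "mgsm": all languages
eval_en = MGSMEval(num_examples, num_threads, languages=["en"])  # "mgsm_en": English only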