sglang 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
- sglang/__init__.py +1 -1
- sglang/api.py +14 -0
- sglang/backend/anthropic.py +18 -12
- sglang/backend/base_backend.py +6 -0
- sglang/backend/openai.py +41 -12
- sglang/backend/runtime_endpoint.py +57 -6
- sglang/lang/chat_template.py +47 -26
- sglang/lang/interpreter.py +15 -2
- sglang/lang/ir.py +1 -1
- sglang/srt/constrained/__init__.py +23 -1
- sglang/srt/constrained/fsm_cache.py +14 -3
- sglang/srt/layers/context_flashattention_nopad.py +1 -1
- sglang/srt/layers/extend_attention.py +7 -6
- sglang/srt/layers/radix_attention.py +2 -10
- sglang/srt/layers/token_attention.py +12 -4
- sglang/srt/managers/io_struct.py +3 -1
- sglang/srt/managers/router/infer_batch.py +6 -2
- sglang/srt/managers/router/model_rpc.py +45 -32
- sglang/srt/managers/router/model_runner.py +40 -25
- sglang/srt/managers/tokenizer_manager.py +2 -0
- sglang/srt/model_config.py +12 -5
- sglang/srt/models/gemma.py +340 -0
- sglang/srt/models/llama2.py +5 -5
- sglang/srt/models/llava.py +2 -4
- sglang/srt/models/mixtral.py +5 -5
- sglang/srt/models/qwen.py +4 -4
- sglang/srt/models/qwen2.py +5 -5
- sglang/srt/models/stablelm.py +293 -0
- sglang/srt/server.py +111 -47
- sglang/srt/server_args.py +44 -9
- sglang/srt/utils.py +1 -0
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +15 -12
- {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/METADATA +16 -6
- sglang-0.1.14.dist-info/RECORD +64 -0
- {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/WHEEL +1 -1
- sglang/srt/models/gpt_neox.py +0 -274
- sglang-0.1.12.dist-info/RECORD +0 -63
- {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/LICENSE +0 -0
- {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/top_level.txt +0 -0
sglang/srt/models/stablelm.py
ADDED
@@ -0,0 +1,293 @@
+# This code is based on:
+# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/stablelm.py
+"""Inference-only StableLM-2 (https://huggingface.co/stabilityai/stablelm-2-1_6b)
+model compatible with HuggingFace weights."""
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.router.model_runner import InputMetadata
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.linear import (
+    LinearMethodBase,
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding,
+    ParallelLMHead,
+)
+from vllm.model_executor.parallel_utils.parallel_state import (
+    get_tensor_model_parallel_world_size,
+)
+from vllm.model_executor.weight_utils import (
+    default_weight_loader,
+    hf_model_weights_iterator,
+)
+
+
+class StablelmMLP(nn.Module):
+    def __init__(
+        self, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_up_proj = MergedColumnParallelLinear(
+            config.hidden_size,
+            [config.intermediate_size] * 2,
+            bias=False,
+            linear_method=linear_method,
+        )
+        self.down_proj = RowParallelLinear(
+            config.intermediate_size, config.hidden_size, bias=False
+        )
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class StablelmAttention(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int = 0,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = config.num_attention_heads
+        self.num_heads = self.total_num_heads // tp_size
+
+        self.total_num_key_value_heads = config.num_key_value_heads
+        if self.total_num_key_value_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_key_value_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_key_value_heads == 0
+        self.num_key_value_heads = max(1, self.total_num_key_value_heads // tp_size)
+        self.head_dim = self.hidden_size // self.total_num_heads
+        self.max_position_embeddings = config.max_position_embeddings
+        rope_pct = getattr(
+            config, "rope_pct", getattr(config, "partial_rotary_factor", 1)
+        )
+        self.rotary_ndims = int(self.head_dim * rope_pct)
+        self.scaling = self.head_dim**-0.5
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_key_value_heads * self.head_dim
+        self.qkv_bias = getattr(config, "use_qkv_bias", False)
+        if (self.head_dim * self.num_heads * tp_size) != self.hidden_size:
+            raise ValueError(
+                f"hidden_size must be divisible by num_heads "
+                f"(got `hidden_size`: {self.hidden_size}"
+                f" and `num_heads`: {self.num_heads})."
+            )
+
+        self.qkv_proj = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_key_value_heads,
+            self.qkv_bias,
+            linear_method=linear_method,
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            self.hidden_size,
+            bias=False,
+            linear_method=linear_method,
+        )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.rotary_ndims,
+            max_position=self.config.max_position_embeddings,
+            base=self.config.rope_theta,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_key_value_heads,
+            layer_id=layer_id,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, input_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class StablelmDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int = 0,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.self_attn = StablelmAttention(config, layer_id=layer_id)
+        self.mlp = StablelmMLP(config, linear_method)
+        norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05))
+        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
+        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            input_metadata=input_metadata,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states, residual
+
+
+class StableLMEpochModel(nn.Module):
+    def __init__(
+        self, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None
+    ) -> None:
+        super().__init__()
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.layers = nn.ModuleList(
+            [
+                StablelmDecoderLayer(config, i, linear_method)
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05))
+        self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                input_metadata,
+            )
+        hidden_states = self.norm(hidden_states)
+        return hidden_states
+
+
+class StableLmForCausalLM(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.model = StableLMEpochModel(config, linear_method)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.logits_processor = LogitsProcessor(config)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, input_metadata
+        )
+
+    def load_weights(
+        self,
+        model_name_or_path: str,
+        cache_dir: Optional[str] = None,
+        load_format: str = "auto",
+        revision: Optional[str] = None,
+    ):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in hf_model_weights_iterator(
+            model_name_or_path, cache_dir, load_format, revision
+        ):
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = StableLmForCausalLM
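Editor's note: the new file exposes the model through `EntryClass = StableLmForCausalLM`, so serving a StableLM-2 checkpoint should only require pointing the runtime at the model path. A minimal sketch, assuming the in-process `Runtime` wrapper shown in the `server.py` diff below accepts `model_path` as its first argument and exposes `url`/`shutdown()` as in other sglang releases:

```python
# Hypothetical usage sketch; not part of the diff.
# Assumes the stabilityai/stablelm-2-1_6b weights are reachable from Hugging Face.
import requests
from sglang.srt.server import Runtime

runtime = Runtime(model_path="stabilityai/stablelm-2-1_6b", tp_size=1, log_level="error")
resp = requests.post(
    runtime.url + "/generate",  # native generation endpoint shown in server.py
    json={"text": "The capital of France is", "sampling_params": {"max_new_tokens": 8}},
    timeout=60,
)
print(resp.json()["text"])
runtime.shutdown()  # assumption: Runtime exposes a shutdown() method
```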
sglang/srt/server.py
CHANGED
@@ -1,6 +1,7 @@
 """SRT: SGLang Runtime"""
 
 import asyncio
+import dataclasses
 import json
 import multiprocessing as mp
 import os
@@ -52,10 +53,31 @@ from sglang.srt.managers.openai_protocol import (
 from sglang.srt.managers.router.manager import start_router_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import
+from sglang.srt.utils import handle_port_init
+from starlette.middleware.base import BaseHTTPMiddleware
+from starlette.responses import JSONResponse
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
+API_KEY_HEADER_NAME = "X-API-Key"
+
+
+class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
+    def __init__(self, app, api_key: str):
+        super().__init__(app)
+        self.api_key = api_key
+
+    async def dispatch(self, request: Request, call_next):
+        # extract API key from the request headers
+        api_key_header = request.headers.get(API_KEY_HEADER_NAME)
+        if not api_key_header or api_key_header != self.api_key:
+            return JSONResponse(
+                status_code=403,
+                content={"detail": "Invalid API Key"},
+            )
+        response = await call_next(request)
+        return response
+
 
 app = FastAPI()
 tokenizer_manager = None
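Editor's note: with the middleware above, every HTTP request must carry the configured key in the `X-API-Key` header or the server answers 403. A quick client-side sketch (host, port, and key are placeholders, not values from the diff):

```python
# Hypothetical client sketch; the endpoint paths come from the server.py shown here.
import requests

headers = {"X-API-Key": "my-secret-key"}   # value passed to the server via --api-key
base = "http://127.0.0.1:30000"            # placeholder host/port

print(requests.get(base + "/get_model_info", headers=headers).json())
print(requests.get(base + "/get_server_args", headers=headers).json())  # new endpoint below
```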
@@ -86,6 +108,11 @@ async def get_model_info():
     return result
 
 
+@app.get("/get_server_args")
+async def get_server_args():
+    return dataclasses.asdict(tokenizer_manager.server_args)
+
+
 @app.get("/flush_cache")
 async def flush_cache():
     await tokenizer_manager.flush_cache()
@@ -96,19 +123,25 @@ async def flush_cache():
     )
 
 
-async def
+async def detokenize_logprob_tokens(token_logprobs):
+    token_ids = [tid for tid, _ in token_logprobs]
+    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
+    return [(text, logprob) for text, (_, logprob) in zip(token_texts, token_logprobs)]
+
+
+async def stream_generator(obj: GenerateReqInput):
     async for out in tokenizer_manager.generate_request(obj):
+        if obj.return_logprob and obj.return_text_in_logprobs:
+            out["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
+                out["meta_info"]["token_logprob"]
+            )
         yield out
 
 
 async def make_openai_style_logprobs(token_logprobs):
     ret_logprobs = LogProbs()
 
-
-    token_ids = [tid for tid, _ in token_logprobs]
-    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
-
-    for token_text, (_, token_logprob) in zip(token_texts, token_logprobs):
+    for token_text, token_logprob in token_logprobs:
         ret_logprobs.tokens.append(token_text)
         ret_logprobs.token_logprobs.append(token_logprob)
 
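Editor's note: the new `return_text_in_logprobs` flag makes the server detokenize returned token IDs, so clients receive `(text, logprob)` pairs instead of `(token_id, logprob)` pairs. A hedged request sketch against the native `/generate` endpoint (field names follow the handlers in this diff; host and port are placeholders):

```python
# Hypothetical client sketch; not part of the diff.
import requests

resp = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 8},
        "return_logprob": True,
        "return_text_in_logprobs": True,
    },
    timeout=60,
)
# With return_text_in_logprobs, each entry should be (token_text, logprob).
for token_text, logprob in resp.json()["meta_info"]["token_logprob"]:
    print(repr(token_text), logprob)
```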
@@ -132,6 +165,11 @@ async def generate_request(obj: GenerateReqInput):
         return StreamingResponse(stream_results(), media_type="text/event-stream")
 
     ret = await tokenizer_manager.generate_request(obj).__anext__()
+    if obj.return_logprob and obj.return_text_in_logprobs:
+        ret["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
+            ret["meta_info"]["token_logprob"]
+        )
+
     return ret
 
 
@@ -155,6 +193,7 @@ async def v1_completions(raw_request: Request):
             "regex": request.regex,
         },
         return_logprob=request.logprobs is not None,
+        return_text_in_logprobs=True,
         stream=request.stream,
     )
     adapted_request.post_init()
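Editor's note: because the OpenAI-compatible completions route now always asks for text in the logprobs, the `logprobs` field of a completion response can be filled with real token strings. A sketch of such a call; the `/v1/completions` path, port, and `model` value are assumptions based on the usual sglang defaults rather than anything shown in this diff:

```python
# Hypothetical client sketch for the OpenAI-compatible completions route.
import requests

resp = requests.post(
    "http://127.0.0.1:30000/v1/completions",   # assumed default path/port
    json={
        "model": "default",                    # placeholder model name
        "prompt": "Say this is a test.",
        "max_tokens": 8,
        "logprobs": 1,                         # triggers return_logprob on the server
    },
    timeout=60,
)
choice = resp.json()["choices"][0]
print(choice["text"])
print(choice.get("logprobs"))
```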
@@ -211,6 +250,7 @@ async def v1_completions(raw_request: Request):
 
     # Non-streaming response.
     ret = await generate_request(adapted_request)
+    ret = ret[0] if isinstance(ret, list) else ret
 
     prompt_tokens = ret["meta_info"]["prompt_tokens"]
     completion_tokens = ret["meta_info"]["completion_tokens"]
@@ -463,8 +503,10 @@ def launch_server(server_args, pipe_finish_writer):
 
     assert proc_router.is_alive() and proc_detoken.is_alive()
 
+    if server_args.api_key and server_args.api_key != "":
+        app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
+
     def _launch_server():
-        # Launch api server
         uvicorn.run(
             app,
             host=server_args.host,
@@ -474,49 +516,59 @@ def launch_server(server_args, pipe_finish_writer):
             loop="uvloop",
         )
 
-
-
+    def _wait_and_warmup():
+        headers = {}
+        url = server_args.url()
+        if server_args.api_key and server_args.api_key != "":
+            headers[API_KEY_HEADER_NAME] = server_args.api_key
 
-
-
-
-
-
-
-
-            pass
-        else:
-            if pipe_finish_writer is not None:
-                pipe_finish_writer.send(str(e))
+        for _ in range(120):
+            time.sleep(0.5)
+            try:
+                requests.get(url + "/get_model_info", timeout=5, headers=headers)
+                break
+            except requests.exceptions.RequestException as e:
+                pass
         else:
-
-
+            if pipe_finish_writer is not None:
+                pipe_finish_writer.send(str(e))
+            else:
+                print(e, flush=True)
+            return
 
-
-
-
-
-
-
-
-
-
-
+        # Warmup
+        try:
+            # print("Warmup...", flush=True)
+            res = requests.post(
+                url + "/generate",
+                json={
+                    "text": "Say this is a warmup request.",
+                    "sampling_params": {
+                        "temperature": 0,
+                        "max_new_tokens": 16,
+                    },
                 },
-
-
-
-
-
-
+                headers=headers,
+                timeout=60,
+            )
+            # print(f"Warmup done. model response: {res.json()['text']}")
+            # print("=" * 20, "Server is ready", "=" * 20, flush=True)
+        except requests.exceptions.RequestException as e:
+            if pipe_finish_writer is not None:
+                pipe_finish_writer.send(str(e))
+            else:
+                print(e, flush=True)
+            return
+
         if pipe_finish_writer is not None:
-            pipe_finish_writer.send(
-        else:
-            print(e, flush=True)
-            return
+            pipe_finish_writer.send("init ok")
 
-
-
+    t = threading.Thread(target=_wait_and_warmup)
+    t.start()
+    try:
+        _launch_server()
+    finally:
+        t.join()
 
 
 class Runtime:
@@ -529,11 +581,17 @@ class Runtime:
         trust_remote_code: bool = True,
         mem_fraction_static: float = ServerArgs.mem_fraction_static,
         max_prefill_num_token: int = ServerArgs.max_prefill_num_token,
+        context_length: int = ServerArgs.context_length,
         tp_size: int = 1,
-        model_mode: List[str] = (),
         schedule_heuristic: str = "lpm",
+        attention_reduce_in_fp32: bool = False,
         random_seed: int = 42,
         log_level: str = "error",
+        disable_radix_cache: bool = False,
+        enable_flashinfer: bool = False,
+        disable_regex_jump_forward: bool = False,
+        disable_disk_cache: bool = False,
+        api_key: str = "",
         port: Optional[int] = None,
         additional_ports: Optional[Union[List[int], int]] = None,
     ):
@@ -550,11 +608,17 @@ class Runtime:
             trust_remote_code=trust_remote_code,
             mem_fraction_static=mem_fraction_static,
             max_prefill_num_token=max_prefill_num_token,
+            context_length=context_length,
             tp_size=tp_size,
-            model_mode=model_mode,
             schedule_heuristic=schedule_heuristic,
+            attention_reduce_in_fp32=attention_reduce_in_fp32,
             random_seed=random_seed,
             log_level=log_level,
+            disable_radix_cache=disable_radix_cache,
+            enable_flashinfer=enable_flashinfer,
+            disable_regex_jump_forward=disable_regex_jump_forward,
+            disable_disk_cache=disable_disk_cache,
+            api_key=api_key,
         )
 
         self.url = self.server_args.url()
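Editor's note: taken together, the `Runtime` wrapper now exposes the same knobs as the CLI (the old `model_mode` list is replaced by explicit flags). A hedged construction sketch; the `model_path` keyword and the model name are assumptions/placeholders, the remaining keywords come from the signature in this diff:

```python
# Hypothetical sketch; keyword arguments mirror the Runtime signature above.
from sglang.srt.server import Runtime

runtime = Runtime(
    model_path="meta-llama/Llama-2-7b-chat-hf",  # placeholder model
    context_length=4096,            # override the value from the model's config.json
    enable_flashinfer=True,         # replaces model_mode=["flashinfer"]
    disable_radix_cache=False,      # keep RadixAttention prefix caching on
    attention_reduce_in_fp32=True,
    api_key="my-secret-key",        # every HTTP call must then send X-API-Key
)
print(runtime.url)
```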
sglang/srt/server_args.py
CHANGED
@@ -16,17 +16,23 @@ class ServerArgs:
     trust_remote_code: bool = True
     mem_fraction_static: Optional[float] = None
     max_prefill_num_token: Optional[int] = None
+    context_length: Optional[int] = None
     tp_size: int = 1
-    model_mode: List[str] = ()
     schedule_heuristic: str = "lpm"
     schedule_conservativeness: float = 1.0
+    attention_reduce_in_fp32: bool = False
     random_seed: int = 42
     stream_interval: int = 8
     disable_log_stats: bool = False
     log_stats_interval: int = 10
     log_level: str = "info"
+
+    # optional modes
+    disable_radix_cache: bool = False
+    enable_flashinfer: bool = False
     disable_regex_jump_forward: bool = False
     disable_disk_cache: bool = False
+    api_key: str = ""
 
     def __post_init__(self):
         if self.tokenizer_path is None:
@@ -117,20 +123,18 @@ class ServerArgs:
             default=ServerArgs.max_prefill_num_token,
             help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
         )
+        parser.add_argument(
+            "--context-length",
+            type=int,
+            default=ServerArgs.context_length,
+            help="The model's maximum context length. Use this to reduce the context length to save memory. Defaults to None (will use the value from the model's config.json instead).",
+        )
         parser.add_argument(
             "--tp-size",
             type=int,
            default=ServerArgs.tp_size,
             help="Tensor parallelism degree.",
         )
-        parser.add_argument(
-            "--model-mode",
-            type=str,
-            default=[],
-            nargs="+",
-            choices=["flashinfer", "no-cache"],
-            help="Model mode: [flashinfer, no-cache]",
-        )
         parser.add_argument(
             "--schedule-heuristic",
             type=str,
@@ -149,6 +153,11 @@ class ServerArgs:
             default=ServerArgs.random_seed,
             help="Random seed.",
         )
+        parser.add_argument(
+            "--attention-reduce-in-fp32",
+            action="store_true",
+            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
+        )
         parser.add_argument(
             "--stream-interval",
             type=int,
@@ -172,6 +181,17 @@ class ServerArgs:
             default=ServerArgs.log_stats_interval,
             help="Log stats interval in second.",
         )
+        # optional modes
+        parser.add_argument(
+            "--disable-radix-cache",
+            action="store_true",
+            help="Disable RadixAttention",
+        )
+        parser.add_argument(
+            "--enable-flashinfer",
+            action="store_true",
+            help="Enable flashinfer inference kernels",
+        )
         parser.add_argument(
             "--disable-regex-jump-forward",
             action="store_true",
@@ -182,6 +202,12 @@ class ServerArgs:
             action="store_true",
             help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
         )
+        parser.add_argument(
+            "--api-key",
+            type=str,
+            default=ServerArgs.api_key,
+            help="Set API Key",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -191,6 +217,15 @@ class ServerArgs:
     def url(self):
         return f"http://{self.host}:{self.port}"
 
+    def get_optional_modes_logging(self):
+        return (
+            f"disable_radix_cache={self.disable_radix_cache}, "
+            f"enable_flashinfer={self.enable_flashinfer}, "
+            f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
+            f"disable_disk_cache={self.disable_disk_cache}, "
+            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}"
+        )
+
 
 @dataclasses.dataclass
 class PortArgs:
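Editor's note: the new fields and the `get_optional_modes_logging` helper can be exercised directly, without going through the CLI parser. A small sketch using names that appear in this diff; the `model_path` field and its value are assumptions/placeholders:

```python
# Hypothetical sketch; constructs the dataclass directly instead of via argparse.
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="meta-llama/Llama-2-7b-chat-hf",  # placeholder checkpoint
    context_length=2048,            # cap the context to save KV-cache memory
    enable_flashinfer=True,
    attention_reduce_in_fp32=True,
    api_key="my-secret-key",
)
print(args.url())                          # http://<host>:<port>
print(args.get_optional_modes_logging())   # one-line summary of the optional modes
```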
sglang/srt/utils.py
CHANGED
@@ -103,6 +103,7 @@ def alloc_usable_network_port(num, used_list=()):
     def check_port(port):
         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
             try:
+                s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                 s.bind(("", port))
                 return True
             except socket.error:
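Editor's note: setting `SO_REUSEADDR` before binding lets the probe succeed on ports that are only held by sockets in `TIME_WAIT`, which reduces spurious "port in use" failures when the server is restarted quickly. A standalone illustration of the same probing pattern (not sglang code):

```python
import socket

def port_is_free(port: int) -> bool:
    """Return True if the port can be bound, treating TIME_WAIT ports as free."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind(("", port))
            return True
        except socket.error:
            return False

print(port_is_free(30000))
```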
sglang/test/test_utils.py
CHANGED
@@ -155,7 +155,7 @@ def select_sglang_backend(args):
         global_config.enable_parallel_decoding = False
         global_config.enable_parallel_encoding = False
         backend = RuntimeEndpoint(f"{args.host}:{args.port}")
-    elif args.backend.startswith("gpt"):
+    elif args.backend.startswith("gpt-"):
         backend = OpenAI(args.backend)
     else:
         raise ValueError(f"Invalid backend: {args.backend}")