sglang 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +10 -6
- sglang/bench_serving.py +33 -38
- sglang/global_config.py +0 -4
- sglang/lang/backend/runtime_endpoint.py +5 -2
- sglang/lang/interpreter.py +1 -1
- sglang/launch_server.py +3 -6
- sglang/launch_server_llavavid.py +7 -8
- sglang/srt/{model_config.py → configs/model_config.py} +5 -0
- sglang/srt/constrained/__init__.py +2 -0
- sglang/srt/constrained/fsm_cache.py +29 -38
- sglang/srt/constrained/jump_forward.py +0 -1
- sglang/srt/conversation.py +4 -1
- sglang/srt/hf_transformers_utils.py +1 -3
- sglang/srt/layers/attention_backend.py +480 -0
- sglang/srt/layers/flashinfer_utils.py +235 -0
- sglang/srt/layers/logits_processor.py +64 -77
- sglang/srt/layers/radix_attention.py +11 -161
- sglang/srt/layers/sampler.py +6 -25
- sglang/srt/layers/torchao_utils.py +75 -0
- sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
- sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
- sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
- sglang/srt/lora/lora.py +403 -0
- sglang/srt/lora/lora_config.py +43 -0
- sglang/srt/lora/lora_manager.py +256 -0
- sglang/srt/managers/controller_multi.py +1 -5
- sglang/srt/managers/controller_single.py +0 -5
- sglang/srt/managers/io_struct.py +16 -1
- sglang/srt/managers/policy_scheduler.py +122 -5
- sglang/srt/managers/schedule_batch.py +104 -71
- sglang/srt/managers/tokenizer_manager.py +17 -8
- sglang/srt/managers/tp_worker.py +181 -115
- sglang/srt/model_executor/cuda_graph_runner.py +58 -133
- sglang/srt/model_executor/forward_batch_info.py +35 -312
- sglang/srt/model_executor/model_runner.py +117 -131
- sglang/srt/models/baichuan.py +416 -0
- sglang/srt/models/chatglm.py +1 -5
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +1 -5
- sglang/srt/models/exaone.py +1 -5
- sglang/srt/models/gemma.py +1 -5
- sglang/srt/models/gemma2.py +1 -5
- sglang/srt/models/gpt_bigcode.py +1 -5
- sglang/srt/models/grok.py +1 -5
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/llama.py +51 -5
- sglang/srt/models/llama_classification.py +1 -20
- sglang/srt/models/llava.py +30 -5
- sglang/srt/models/llavavid.py +2 -2
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/minicpm3.py +665 -0
- sglang/srt/models/mixtral.py +6 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/qwen.py +1 -5
- sglang/srt/models/qwen2.py +1 -5
- sglang/srt/models/qwen2_moe.py +6 -5
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/xverse.py +375 -0
- sglang/srt/models/xverse_moe.py +445 -0
- sglang/srt/openai_api/adapter.py +65 -46
- sglang/srt/openai_api/protocol.py +11 -3
- sglang/srt/sampling/sampling_batch_info.py +57 -44
- sglang/srt/server.py +24 -14
- sglang/srt/server_args.py +130 -28
- sglang/srt/utils.py +12 -0
- sglang/test/few_shot_gsm8k.py +132 -0
- sglang/test/runners.py +114 -22
- sglang/test/test_programs.py +7 -5
- sglang/test/test_utils.py +85 -1
- sglang/utils.py +32 -37
- sglang/version.py +1 -1
- {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/METADATA +30 -18
- sglang-0.3.1.dist-info/RECORD +129 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
- sglang-0.3.0.dist-info/RECORD +0 -118
- {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api/protocol.py
CHANGED
@@ -82,6 +82,14 @@ class StreamOptions(BaseModel):
     include_usage: Optional[bool] = False
 
 
+class JsonSchemaResponseFormat(BaseModel):
+    name: str
+    description: Optional[str] = None
+    # use alias to workaround pydantic conflict
+    schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
+    strict: Optional[bool] = False
+
+
 class FileRequest(BaseModel):
     # https://platform.openai.com/docs/api-reference/files/create
     file: bytes  # The File object (not file name) to be uploaded
@@ -213,6 +221,7 @@ class ChatCompletionMessageContentImageURL(BaseModel):
 class ChatCompletionMessageContentImagePart(BaseModel):
     type: Literal["image_url"]
     image_url: ChatCompletionMessageContentImageURL
+    modalities: Optional[Literal["image", "multi-images", "video"]] = "image"
 
 
 ChatCompletionMessageContentPart = Union[
@@ -236,8 +245,8 @@ ChatCompletionMessageParam = Union[
 
 
 class ResponseFormat(BaseModel):
-
-
+    type: Literal["text", "json_object", "json_schema"]
+    json_schema: Optional[JsonSchemaResponseFormat] = None
 
 
 class ChatCompletionRequest(BaseModel):
@@ -263,7 +272,6 @@ class ChatCompletionRequest(BaseModel):
 
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     regex: Optional[str] = None
-    json_schema: Optional[str] = None
     min_tokens: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
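The protocol changes above add a `JsonSchemaResponseFormat` model and extend `ResponseFormat` so the OpenAI-compatible endpoint can accept structured-output requests. A hedged client-side sketch; the model name, port, and schema are illustrative assumptions, not taken from this diff:

```python
# Hypothetical client sketch: exercising the new "json_schema" response format
# against a locally launched sglang server (model name and port are assumptions).
import json

import openai

client = openai.OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Give me a capital city as JSON."}],
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "capital_info",
            # the "schema" key maps to JsonSchemaResponseFormat.schema_ via the alias
            "schema": {
                "type": "object",
                "properties": {
                    "country": {"type": "string"},
                    "capital": {"type": "string"},
                },
                "required": ["country", "capital"],
            },
        },
    },
)
print(json.loads(response.choices[0].message.content))
```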
sglang/srt/sampling/sampling_batch_info.py
CHANGED
@@ -34,11 +34,13 @@ class SamplingBatchInfo:
     linear_penalties: torch.Tensor = None
     scaling_penalties: torch.Tensor = None
 
+    def __len__(self):
+        return len(self.temperatures)
+
     def can_run_in_cuda_graph(self):
         # Vocab bias and min_ps are not supported in CUDA graph
         return (
             self.logit_bias is None
-            and self.vocab_mask is None
             and self.linear_penalties is None
             and self.scaling_penalties is None
             and not self.need_min_p_sampling
@@ -47,9 +49,11 @@ class SamplingBatchInfo:
     @classmethod
     def dummy_one(cls, max_bs: int, vocab_size: int):
         ret = cls(vocab_size=vocab_size)
-
-
-
+        with torch.device("cuda"):
+            ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float)
+            ret.top_ps = torch.ones((max_bs,), dtype=torch.float)
+            ret.top_ks = torch.ones((max_bs,), dtype=torch.int)
+            ret.vocab_mask = torch.zeros((max_bs, vocab_size), dtype=torch.bool)
         return ret
 
     def __getitem__(self, key):
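`dummy_one` now allocates its tensors inside a `torch.device("cuda")` context instead of passing `device=` to every factory call. A minimal standalone sketch of that PyTorch 2.x pattern, not code from the diff:

```python
# Minimal sketch: torch.device as a context manager sets the default device for
# factory calls inside the block, so per-tensor device=... arguments like the
# ones removed from dummy_one() become unnecessary.
import torch

max_bs, vocab_size = 4, 32000

if torch.cuda.is_available():  # guard so the sketch also runs on CPU-only machines
    with torch.device("cuda"):
        temperatures = torch.ones((max_bs, 1), dtype=torch.float)
        vocab_mask = torch.zeros((max_bs, vocab_size), dtype=torch.bool)
    assert temperatures.device.type == "cuda" and vocab_mask.device.type == "cuda"
```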
@@ -61,6 +65,7 @@ class SamplingBatchInfo:
                 temperatures=self.temperatures[key],
                 top_ps=self.top_ps[key],
                 top_ks=self.top_ks[key],
+                vocab_mask=self.vocab_mask[key],
             )
         else:
             raise NotImplementedError
@@ -74,26 +79,31 @@ class SamplingBatchInfo:
         self.top_ps[:bs] = other.top_ps
         self.top_ks[:bs] = other.top_ks
 
+        if other.vocab_mask is None:
+            self.vocab_mask[:bs].fill_(False)
+        else:
+            self.vocab_mask[:bs] = other.vocab_mask
+
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
-        device = "cuda"
         reqs = batch.reqs
         ret = cls(vocab_size=vocab_size)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        with torch.device("cuda"):
+            ret.temperatures = torch.tensor(
+                [r.sampling_params.temperature for r in reqs],
+                dtype=torch.float,
+            ).view(-1, 1)
+            ret.top_ps = torch.tensor(
+                [r.sampling_params.top_p for r in reqs], dtype=torch.float
+            )
+            ret.top_ks = torch.tensor(
+                [r.sampling_params.top_k for r in reqs], dtype=torch.int
+            )
+            ret.min_ps = torch.tensor(
+                [r.sampling_params.min_p for r in reqs], dtype=torch.float
+            )
+
         ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
 
         # Each penalizers will do nothing if they evaluate themselves as not required by looking at
@@ -106,7 +116,7 @@ class SamplingBatchInfo:
         ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
             vocab_size=vocab_size,
             batch=batch,
-            device=
+            device="cuda",
             Penalizers={
                 penaltylib.BatchedFrequencyPenalizer,
                 penaltylib.BatchedMinNewTokensPenalizer,
@@ -118,11 +128,9 @@ class SamplingBatchInfo:
         # Handle logit bias but only allocate when needed
         ret.logit_bias = None
 
-        ret.update_regex_vocab_mask(batch)
-
         return ret
 
-    def
+    def update_penalties(self):
         self.scaling_penalties = None
         self.linear_penalties = None
 
@@ -142,18 +150,16 @@ class SamplingBatchInfo:
             self.linear_penalties = penalizer.apply(self.linear_penalties)
 
     def update_regex_vocab_mask(self, batch: ScheduleBatch):
-
-        device = "cuda"
-        has_regex = any(req.regex_fsm is not None for req in reqs)
+        has_regex = any(req.regex_fsm is not None for req in batch.reqs)
 
         # Reset the vocab mask
         self.vocab_mask = None
 
         if has_regex:
             self.vocab_mask = torch.zeros(
-
+                batch.batch_size(), self.vocab_size, dtype=torch.bool, device="cuda"
             )
-            for i, req in enumerate(reqs):
+            for i, req in enumerate(batch.reqs):
                 if req.regex_fsm is not None:
                     self.vocab_mask[i].fill_(1)
                     self.vocab_mask[i][
@@ -174,6 +180,26 @@ class SamplingBatchInfo:
             if self_val is not None:  # logit_bias can be None
                 setattr(self, item, self_val[new_indices])
 
+    @staticmethod
+    def merge_bias_tensor(
+        lhs: torch.Tensor, rhs: torch.Tensor, bs1: int, bs2: int, default: int = 0
+    ):
+        # bias tensor can be None
+        if lhs is not None or rhs is not None:
+            shape, dtype = None, None
+            if lhs is not None:
+                shape, dtype = lhs.shape[1:], lhs.dtype
+            else:
+                shape, dtype = rhs.shape[1:], rhs.dtype
+            with torch.dtype(dtype):
+                if lhs is None:
+                    lhs = torch.empty((bs1, *shape), device="cuda").fill_(default)
+                if rhs is None:
+                    rhs = torch.empty((bs2, *shape), device="cuda").fill_(default)
+            return torch.cat([lhs, rhs])
+
+        return None
+
     def merge(self, other: "SamplingBatchInfo"):
         self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
 
@@ -187,19 +213,6 @@ class SamplingBatchInfo:
             other_val = getattr(other, item, None)
             setattr(self, item, torch.concat([self_val, other_val]))
 
-
-
-
-            self.logit_bias.shape[1]
-            if self.logit_bias is not None
-            else other.logit_bias.shape[1]
-        )
-        if self.logit_bias is None:
-            self.logit_bias = torch.zeros(
-                (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
-            )
-        if other.logit_bias is None:
-            other.logit_bias = torch.zeros(
-                (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
-            )
-        self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
+        self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
+            self.logit_bias, other.logit_bias, len(self), len(other)
+        )
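The `merge` path above replaces the inline logit-bias concatenation with the new `merge_bias_tensor` static helper. A simplified, CPU-only sketch of the same contract, illustrative only and not the sglang code:

```python
# Concatenate two optional per-request bias tensors, materializing a
# default-filled tensor for whichever side is None.
from typing import Optional

import torch


def merge_bias_tensor_cpu(
    lhs: Optional[torch.Tensor],
    rhs: Optional[torch.Tensor],
    bs1: int,
    bs2: int,
    default: float = 0.0,
) -> Optional[torch.Tensor]:
    if lhs is None and rhs is None:
        return None
    shape, dtype = (lhs.shape[1:], lhs.dtype) if lhs is not None else (rhs.shape[1:], rhs.dtype)
    if lhs is None:
        lhs = torch.full((bs1, *shape), default, dtype=dtype)
    if rhs is None:
        rhs = torch.full((bs2, *shape), default, dtype=dtype)
    return torch.cat([lhs, rhs])


bias = merge_bias_tensor_cpu(torch.ones(2, 5), None, 2, 3)
assert bias.shape == (5, 5) and bias[2:].abs().sum() == 0
```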
sglang/srt/server.py
CHANGED
@@ -37,6 +37,7 @@ import requests
 import uvicorn
 import uvloop
 from fastapi import FastAPI, File, Form, Request, UploadFile
+from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
@@ -93,6 +94,14 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 app = FastAPI()
 tokenizer_manager = None
 
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
 
 @app.get("/health")
 async def health() -> Response:
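The server now installs a permissive `CORSMiddleware` on the FastAPI app so browser front-ends on other origins can reach the HTTP API. A standalone sketch of the same pattern, not the sglang app itself:

```python
# A permissive CORSMiddleware lets browser clients on other origins call the API.
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # consider restricting this in production deployments
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health")
async def health() -> dict:
    return {"status": "ok"}
```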
@@ -272,7 +281,6 @@ async def retrieve_file_content(file_id: str):
 
 def launch_server(
     server_args: ServerArgs,
-    model_override_args: Optional[dict] = None,
     pipe_finish_writer: Optional[mp.connection.Connection] = None,
 ):
     """Launch an HTTP server."""
@@ -317,7 +325,6 @@ def launch_server(
             tp_rank_range,
             server_args,
             ports[3],
-            model_override_args,
         )
 
     try:
@@ -328,23 +335,19 @@ def launch_server(
         return
 
     # Launch processes
-    tokenizer_manager = TokenizerManager(server_args, port_args, model_override_args)
-    if server_args.chat_template:
-        load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
     pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
-    pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
 
     if server_args.dp_size == 1:
         start_controller_process = start_controller_process_single
     else:
         start_controller_process = start_controller_process_multi
-
     proc_controller = mp.Process(
         target=start_controller_process,
-        args=(server_args, port_args, pipe_controller_writer
+        args=(server_args, port_args, pipe_controller_writer),
     )
     proc_controller.start()
 
+    pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
     proc_detoken = mp.Process(
         target=start_detokenizer_process,
         args=(
@@ -355,6 +358,10 @@ def launch_server(
     )
     proc_detoken.start()
 
+    tokenizer_manager = TokenizerManager(server_args, port_args)
+    if server_args.chat_template:
+        load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
+
     # Wait for the model to finish loading
     controller_init_state = pipe_controller_reader.recv()
     detoken_init_state = pipe_detoken_reader.recv()
@@ -418,7 +425,7 @@ def _set_envs_and_config(server_args: ServerArgs):
         maybe_set_triton_cache_manager()
 
     # Check flashinfer version
-    if
+    if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer",
             "0.1.6",
@@ -440,13 +447,12 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
         time.sleep(1)
         try:
             res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
-            assert res.status_code == 200, f"{res}"
+            assert res.status_code == 200, f"{res=}, {res.text=}"
             success = True
             break
-        except (AssertionError, requests.exceptions.RequestException)
+        except (AssertionError, requests.exceptions.RequestException):
             last_traceback = get_exception_traceback()
             pass
-    model_info = res.json()
 
     if not success:
         if pipe_finish_writer is not None:
@@ -455,6 +461,8 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
         kill_child_process(pid, including_parent=False)
         return
 
+    model_info = res.json()
+
     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
     max_new_tokens = 8 if model_info["is_generation"] else 1
@@ -501,7 +509,6 @@ class Runtime:
     def __init__(
         self,
         log_level: str = "error",
-        model_override_args: Optional[dict] = None,
         *args,
         **kwargs,
     ):
@@ -525,7 +532,7 @@ class Runtime:
 
         proc = mp.Process(
             target=launch_server,
-            args=(self.server_args,
+            args=(self.server_args, pipe_writer),
         )
         proc.start()
         pipe_writer.close()
@@ -604,6 +611,7 @@ class Runtime:
         return_logprob: Optional[Union[List[bool], bool]] = False,
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
+        lora_path: Optional[List[Optional[str]]] = None,
     ):
         json_data = {
             "text": prompt,
@@ -611,7 +619,9 @@ class Runtime:
             "return_logprob": return_logprob,
             "logprob_start_len": logprob_start_len,
             "top_logprobs_num": top_logprobs_num,
+            "lora_path": lora_path,
         }
+        assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
         response = requests.post(
             self.url + "/generate",
             json=json_data,
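`Runtime` and the `/generate` payload gain an optional per-request `lora_path` list, one adapter (or `None` for the base model) per prompt. A hedged request sketch; the endpoint URL, adapter name, and the `sampling_params` key are assumptions based on sglang's native generate API rather than this diff:

```python
# Hypothetical client sketch of per-request LoRA routing via the HTTP API.
import requests

prompts = ["Summarize LoRA in one sentence.", "Say hello."]
payload = {
    "text": prompts,
    "lora_path": ["my-lora-adapter", None],  # must match len(prompts), per the new assert
    "sampling_params": {"max_new_tokens": 32},
}
resp = requests.post("http://127.0.0.1:30000/generate", json=payload, timeout=60)
print(resp.json())
```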
sglang/srt/server_args.py
CHANGED
@@ -49,7 +49,6 @@ class ServerArgs:
     # Memory and scheduling
     mem_fraction_static: Optional[float] = None
     max_running_requests: Optional[int] = None
-    max_num_reqs: Optional[int] = None
     max_total_tokens: Optional[int] = None
     chunked_prefill_size: int = 8192
     max_prefill_tokens: int = 16384
@@ -75,7 +74,18 @@ class ServerArgs:
     dp_size: int = 1
     load_balance_method: str = "round_robin"
 
+    # Distributed args
+    nccl_init_addr: Optional[str] = None
+    nnodes: int = 1
+    node_rank: Optional[int] = None
+
+    # Model override args in JSON
+    json_model_override_args: str = "{}"
+
     # Optimization/debug options
+    attention_backend: Optional[str] = None
+    sampling_backend: Optional[str] = None
+
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
     disable_radix_cache: bool = False
@@ -86,16 +96,17 @@ class ServerArgs:
     disable_custom_all_reduce: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
+    torchao_config: str = ""
     enable_p2p_check: bool = False
     enable_mla: bool = False
     triton_attention_reduce_in_fp32: bool = False
 
-    #
-
-
-    node_rank: Optional[int] = None
+    # LoRA
+    lora_paths: Optional[List[str]] = None
+    max_loras_per_batch: int = 8
 
     def __post_init__(self):
+        # Set missing default values
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
 
@@ -106,6 +117,7 @@ class ServerArgs:
             # Disable chunked prefill
             self.chunked_prefill_size = None
 
+        # Mem fraction depends on the tensor parallelism size
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
                 self.mem_fraction_static = 0.79
@@ -126,6 +138,42 @@ class ServerArgs:
         if self.random_seed is None:
             self.random_seed = random.randint(0, 1 << 30)
 
+        # Deprecation warnings
+        if self.disable_flashinfer:
+            logger.warning(
+                "The option '--disable-flashinfer' will be deprecated in the next release. "
+                "Please use '--attention-backend triton' instead."
+            )
+            self.attention_backend = "triton"
+        if self.disable_flashinfer_sampling:
+            logger.warning(
+                "The option '--disable-flashinfer-sampling' will be deprecated in the next release. "
+                "Please use '--sampling-backend pytorch' instead. "
+            )
+            self.sampling_backend = "pytorch"
+
+        # Default kernel backends
+        if self.enable_mla:
+            logger.info("MLA optimization is tunred on. Use triton backend.")
+            self.attention_backend = "triton"
+
+        if self.attention_backend is None:
+            self.attention_backend = "flashinfer"
+
+        if self.sampling_backend is None:
+            self.sampling_backend = "flashinfer"
+
+        # Model-specific patches
+        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
+            logger.info(
+                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
+            )
+            self.trust_remote_code = False
+
+        if "gemma-2" in self.model_path.lower():
+            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
+            self.attention_backend = "flashinfer"
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         parser.add_argument(
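`__post_init__` now translates the deprecated `--disable-flashinfer*` flags into the new `attention_backend`/`sampling_backend` selectors and fills in flashinfer defaults. A minimal standalone sketch of that mapping logic, not the sglang class itself:

```python
# The legacy disable_* booleans are translated into the new backend selectors,
# and unset backends fall back to flashinfer.
from dataclasses import dataclass
from typing import Optional


@dataclass
class BackendArgs:
    disable_flashinfer: bool = False
    disable_flashinfer_sampling: bool = False
    attention_backend: Optional[str] = None
    sampling_backend: Optional[str] = None

    def resolve(self) -> None:
        if self.disable_flashinfer:
            self.attention_backend = "triton"
        if self.disable_flashinfer_sampling:
            self.sampling_backend = "pytorch"
        self.attention_backend = self.attention_backend or "flashinfer"
        self.sampling_backend = self.sampling_backend or "flashinfer"


args = BackendArgs(disable_flashinfer=True)
args.resolve()
assert (args.attention_backend, args.sampling_backend) == ("triton", "flashinfer")
```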
@@ -209,11 +257,6 @@ class ServerArgs:
             action="store_true",
             help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
         )
-        parser.add_argument(
-            "--is-embedding",
-            action="store_true",
-            help="Whether to use a CausalLM as an embedding model.",
-        )
         parser.add_argument(
             "--context-length",
             type=int,
@@ -248,6 +291,11 @@ class ServerArgs:
             default=ServerArgs.chat_template,
             help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
         )
+        parser.add_argument(
+            "--is-embedding",
+            action="store_true",
+            help="Whether to use a CausalLM as an embedding model.",
+        )
         parser.add_argument(
             "--mem-fraction-static",
             type=float,
@@ -260,17 +308,12 @@ class ServerArgs:
             default=ServerArgs.max_running_requests,
             help="The maximum number of running requests.",
         )
-        parser.add_argument(
-            "--max-num-reqs",
-            type=int,
-            default=ServerArgs.max_num_reqs,
-            help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.",
-        )
         parser.add_argument(
             "--max-total-tokens",
             type=int,
             default=ServerArgs.max_total_tokens,
-            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction.
+            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
+            "This option is typically used for development and debugging purposes.",
         )
         parser.add_argument(
             "--chunked-prefill-size",
@@ -381,16 +424,38 @@ class ServerArgs:
         )
         parser.add_argument("--node-rank", type=int, help="The node rank.")
 
+        # Model override args
+        parser.add_argument(
+            "--json-model-override-args",
+            type=str,
+            help="A dictionary in JSON string format used to override default model configurations.",
+            default=ServerArgs.json_model_override_args,
+        )
+
         # Optimization/debug options
+        parser.add_argument(
+            "--attention-backend",
+            type=str,
+            choices=["flashinfer", "triton"],
+            default=ServerArgs.attention_backend,
+            help="Choose the kernels for attention layers.",
+        )
+        parser.add_argument(
+            "--sampling-backend",
+            type=str,
+            choices=["flashinfer", "pytorch"],
+            default=ServerArgs.sampling_backend,
+            help="Choose the kernels for sampling layers.",
+        )
         parser.add_argument(
             "--disable-flashinfer",
             action="store_true",
-            help="Disable flashinfer attention kernels.",
+            help="Disable flashinfer attention kernels. This option will be deprecated in the next release. Please use '--attention-backend triton' instead.",
         )
         parser.add_argument(
             "--disable-flashinfer-sampling",
             action="store_true",
-            help="Disable flashinfer sampling kernels.",
+            help="Disable flashinfer sampling kernels. This option will be deprecated in the next release. Please use '--sampling-backend pytorch' instead.",
         )
         parser.add_argument(
             "--disable-radix-cache",
@@ -431,7 +496,13 @@ class ServerArgs:
         parser.add_argument(
             "--enable-torch-compile",
             action="store_true",
-            help="Optimize the model with torch.compile
+            help="Optimize the model with torch.compile. Experimental feature.",
+        )
+        parser.add_argument(
+            "--torchao-config",
+            type=str,
+            default=ServerArgs.torchao_config,
+            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo",
         )
         parser.add_argument(
             "--enable-p2p-check",
@@ -455,6 +526,21 @@ class ServerArgs:
             help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
         )
 
+        # LoRA options
+        parser.add_argument(
+            "--lora-paths",
+            type=str,
+            nargs="*",
+            default=None,
+            help="The list of LoRA adapters.",
+        )
+        parser.add_argument(
+            "--max-loras-per-batch",
+            type=int,
+            default=8,
+            help="Maximum number of adapters for a running batch, include base-only request",
+        )
+
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size
@@ -472,14 +558,30 @@ class ServerArgs:
         assert not (
             self.dp_size > 1 and self.node_rank is not None
         ), "multi-node data parallel is not supported"
-
-
-
-        )
-        self.
-
-
-
+        assert (
+            self.max_loras_per_batch > 0
+            # FIXME
+            and (self.lora_paths is None or self.disable_cuda_graph)
+            and (self.lora_paths is None or self.disable_radix_cache)
+        ), "compatibility of lora and cuda graph and radix attention is in progress"
+
+
+def prepare_server_args(argv: List[str]) -> ServerArgs:
+    """
+    Prepare the server arguments from the command line arguments.
+
+    Args:
+        args: The command line arguments. Typically, it should be `sys.argv[1:]`
+            to ensure compatibility with `parse_args` when no arguments are passed.
+
+    Returns:
+        The server arguments.
+    """
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    raw_args = parser.parse_args(argv)
+    server_args = ServerArgs.from_cli_args(raw_args)
+    return server_args
 
 
 @dataclasses.dataclass
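The new module-level `prepare_server_args` helper wraps argument parsing so launch scripts no longer build the parser themselves. A hedged sketch of a launch script using it; the `__main__` wrapper is ours, while the imported functions are the ones shown in this diff:

```python
# Launch-script sketch: parse CLI args with prepare_server_args, then start the server.
import sys

from sglang.srt.server import launch_server
from sglang.srt.server_args import prepare_server_args

if __name__ == "__main__":
    server_args = prepare_server_args(sys.argv[1:])
    launch_server(server_args)
```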
sglang/srt/utils.py
CHANGED
@@ -35,6 +35,7 @@ import torch
 import torch.distributed as dist
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
+from torch import nn
 from torch.nn.parameter import Parameter
 from triton.runtime.cache import (
     FileCacheManager,
@@ -714,3 +715,14 @@ def configure_logger(server_args, prefix: str = ""):
         datefmt="%H:%M:%S",
         force=True,
     )
+
+
+# source: https://github.com/vllm-project/vllm/blob/93b38bea5dd03e1b140ca997dfaadef86f8f1855/vllm/lora/utils.py#L9
+def replace_submodule(
+    model: nn.Module, module_name: str, new_module: nn.Module
+) -> nn.Module:
+    """Replace a submodule in a model with a new module."""
+    parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
+    target_name = module_name.split(".")[-1]
+    setattr(parent, target_name, new_module)
+    return new_module
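`replace_submodule` (ported from vLLM) swaps a named submodule in place, which is the pattern the new LoRA manager relies on to wrap target layers. A toy usage sketch; `TinyModel` is ours, only `replace_submodule` comes from sglang:

```python
import torch
from torch import nn

from sglang.srt.utils import replace_submodule


class TinyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.block = nn.Sequential(nn.Linear(8, 8), nn.ReLU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)


model = TinyModel()
# Dotted path addresses the submodule; the Linear(8, 8) is swapped for Linear(8, 4).
replace_submodule(model, "block.0", nn.Linear(8, 4))
print(model(torch.randn(2, 8)).shape)  # torch.Size([2, 4])
```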