sglang 0.3.3.post1__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +28 -10
- sglang/bench_server_latency.py +21 -10
- sglang/bench_serving.py +101 -7
- sglang/global_config.py +0 -1
- sglang/srt/layers/attention/__init__.py +27 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
- sglang/srt/layers/attention/flashinfer_backend.py +352 -83
- sglang/srt/layers/attention/triton_backend.py +6 -4
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
- sglang/srt/layers/sampler.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +31 -10
- sglang/srt/managers/io_struct.py +4 -0
- sglang/srt/managers/schedule_batch.py +120 -43
- sglang/srt/managers/schedule_policy.py +2 -1
- sglang/srt/managers/scheduler.py +202 -140
- sglang/srt/managers/tokenizer_manager.py +5 -1
- sglang/srt/managers/tp_worker.py +111 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -4
- sglang/srt/mem_cache/memory_pool.py +77 -4
- sglang/srt/mem_cache/radix_cache.py +15 -7
- sglang/srt/model_executor/cuda_graph_runner.py +4 -4
- sglang/srt/model_executor/forward_batch_info.py +16 -21
- sglang/srt/model_executor/model_runner.py +60 -1
- sglang/srt/models/baichuan.py +2 -3
- sglang/srt/models/chatglm.py +5 -6
- sglang/srt/models/commandr.py +1 -2
- sglang/srt/models/dbrx.py +1 -2
- sglang/srt/models/deepseek.py +4 -5
- sglang/srt/models/deepseek_v2.py +5 -6
- sglang/srt/models/exaone.py +1 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +5 -5
- sglang/srt/models/gpt_bigcode.py +5 -5
- sglang/srt/models/grok.py +1 -2
- sglang/srt/models/internlm2.py +1 -2
- sglang/srt/models/llama.py +1 -2
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +4 -8
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -2
- sglang/srt/models/minicpm3.py +5 -6
- sglang/srt/models/mixtral.py +1 -2
- sglang/srt/models/mixtral_quant.py +1 -2
- sglang/srt/models/olmo.py +352 -0
- sglang/srt/models/olmoe.py +1 -2
- sglang/srt/models/qwen.py +1 -2
- sglang/srt/models/qwen2.py +1 -2
- sglang/srt/models/qwen2_moe.py +4 -5
- sglang/srt/models/stablelm.py +1 -2
- sglang/srt/models/torch_native_llama.py +1 -2
- sglang/srt/models/xverse.py +1 -2
- sglang/srt/models/xverse_moe.py +4 -5
- sglang/srt/models/yivl.py +1 -2
- sglang/srt/openai_api/adapter.py +92 -49
- sglang/srt/openai_api/protocol.py +10 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
- sglang/srt/sampling/sampling_batch_info.py +92 -58
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server.py +116 -17
- sglang/srt/server_args.py +121 -45
- sglang/srt/utils.py +11 -3
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/few_shot_gsm8k_engine.py +144 -0
- sglang/test/srt/sampling/penaltylib/utils.py +16 -12
- sglang/version.py +1 -1
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/METADATA +72 -29
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/RECORD +73 -70
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
- sglang/srt/layers/attention/flashinfer_utils.py +0 -237
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/server.py
CHANGED
@@ -25,11 +25,12 @@ import json
 import logging
 import multiprocessing as mp
 import os
-import random
 import threading
 import time
 from http import HTTPStatus
-from typing import Dict, List, Optional, Union
+from typing import AsyncIterator, Dict, List, Optional, Union
+
+import orjson
 
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -40,7 +41,8 @@ import uvicorn
 import uvloop
 from fastapi import FastAPI, File, Form, Request, UploadFile
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import
+from fastapi.responses import ORJSONResponse, Response, StreamingResponse
+from uvicorn.config import LOGGING_CONFIG
 
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.hf_transformers_utils import get_tokenizer
@@ -176,12 +178,12 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
     success, message = await tokenizer_manager.update_weights(obj, request)
     content = {"success": success, "message": message}
     if success:
-        return
+        return ORJSONResponse(
             content,
             status_code=HTTPStatus.OK,
         )
     else:
-        return
+        return ORJSONResponse(
             content,
             status_code=HTTPStatus.BAD_REQUEST,
         )
@@ -192,14 +194,18 @@ async def generate_request(obj: GenerateReqInput, request: Request):
     """Handle a generate request."""
     if obj.stream:
 
-        async def stream_results():
+        async def stream_results() -> AsyncIterator[bytes]:
             try:
                 async for out in tokenizer_manager.generate_request(obj, request):
-                    yield
+                    yield b"data: " + orjson.dumps(
+                        out, option=orjson.OPT_NON_STR_KEYS
+                    ) + b"\n\n"
             except ValueError as e:
                 out = {"error": {"message": str(e)}}
-                yield
-
+                yield b"data: " + orjson.dumps(
+                    out, option=orjson.OPT_NON_STR_KEYS
+                ) + b"\n\n"
+            yield b"data: [DONE]\n\n"
 
         return StreamingResponse(
             stream_results(),
@@ -211,7 +217,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
             ret = await tokenizer_manager.generate_request(obj, request).__anext__()
             return ret
         except ValueError as e:
-            return
+            return ORJSONResponse(
                 {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
             )
 
@@ -226,7 +232,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
-        return
+        return ORJSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
         )
 
@@ -241,7 +247,7 @@ async def judge_request(obj: RewardReqInput, request: Request):
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
-        return
+        return ORJSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
         )
 
@@ -260,13 +266,13 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)
 
 
-@app.post("/v1/embeddings")
+@app.post("/v1/embeddings", response_class=ORJSONResponse)
 async def openai_v1_embeddings(raw_request: Request):
     response = await v1_embeddings(tokenizer_manager, raw_request)
     return response
 
 
-@app.get("/v1/models")
+@app.get("/v1/models", response_class=ORJSONResponse)
 def available_models():
     """Show available models."""
     served_model_names = [tokenizer_manager.served_model_name]
@@ -429,6 +435,14 @@ def launch_server(
 
     try:
         # Listen for HTTP requests
+        LOGGING_CONFIG["formatters"]["default"][
+            "fmt"
+        ] = "[%(asctime)s] %(levelprefix)s %(message)s"
+        LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
+        LOGGING_CONFIG["formatters"]["access"][
+            "fmt"
+        ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
+        LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
         uvicorn.run(
             app,
             host=server_args.host,
@@ -447,7 +461,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     os.environ["NCCL_CUMEM_ENABLE"] = "0"
     os.environ["NCCL_NVLS_ENABLE"] = "0"
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "
+    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
 
     # Set ulimit
     set_ulimit()
@@ -528,6 +542,8 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
        kill_child_process(pid, including_parent=False)
        return
 
+    # logger.info(f"{res.json()=}")
+
    logger.info("The server is fired up and ready to roll!")
    if pipe_finish_writer is not None:
        pipe_finish_writer.send("ready")
@@ -692,6 +708,10 @@ class Runtime:
         self.shutdown()
 
 
+STREAM_END_SYMBOL = b"data: [DONE]"
+STREAM_CHUNK_START_SYMBOL = b"data:"
+
+
 class Engine:
     """
     SRT Engine without an HTTP server layer.
@@ -716,7 +736,10 @@ class Engine:
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
         lora_path: Optional[List[Optional[str]]] = None,
+        stream: bool = False,
     ):
+        # TODO (ByronHsu): refactor to reduce the duplicated code
+
         obj = GenerateReqInput(
             text=prompt,
             sampling_params=sampling_params,
@@ -724,13 +747,89 @@ class Engine:
             logprob_start_len=logprob_start_len,
             top_logprobs_num=top_logprobs_num,
             lora_path=lora_path,
+            stream=stream,
         )
 
         # get the current event loop
         loop = asyncio.get_event_loop()
-
+        ret = loop.run_until_complete(generate_request(obj, None))
+
+        if stream is True:
+
+            def generator_wrapper():
+                offset = 0
+                loop = asyncio.get_event_loop()
+                generator = ret.body_iterator
+                while True:
+                    chunk = loop.run_until_complete(generator.__anext__())
+
+                    if chunk.startswith(STREAM_END_SYMBOL):
+                        break
+                    else:
+                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
+                        data["text"] = data["text"][offset:]
+                        offset += len(data["text"])
+                        yield data
+
+            # we cannot yield in the scope of generate() because python does not allow yield + return in the same function
+            # however, it allows to wrap the generator as a subfunction and return
+            return generator_wrapper()
+        else:
+            return ret
+
+    async def async_generate(
+        self,
+        prompt: Union[str, List[str]],
+        sampling_params: Optional[Dict] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+        lora_path: Optional[List[Optional[str]]] = None,
+        stream: bool = False,
+    ):
+        obj = GenerateReqInput(
+            text=prompt,
+            sampling_params=sampling_params,
+            return_logprob=return_logprob,
+            logprob_start_len=logprob_start_len,
+            top_logprobs_num=top_logprobs_num,
+            lora_path=lora_path,
+            stream=stream,
+        )
+
+        ret = await generate_request(obj, None)
+
+        if stream is True:
+            generator = ret.body_iterator
+
+            async def generator_wrapper():
+
+                offset = 0
+
+                while True:
+                    chunk = await generator.__anext__()
+
+                    if chunk.startswith(STREAM_END_SYMBOL):
+                        break
+                    else:
+                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
+                        data["text"] = data["text"][offset:]
+                        offset += len(data["text"])
+                        yield data
+
+            return generator_wrapper()
+        else:
+            return ret
 
     def shutdown(self):
         kill_child_process(os.getpid(), including_parent=False)
 
-
+    def get_tokenizer(self):
+        global tokenizer_manager
+
+        if tokenizer_manager is None:
+            raise ReferenceError("Tokenizer Manager is not initialized.")
+        else:
+            return tokenizer_manager.tokenizer
+
+    # TODO (ByronHsu): encode
sglang/srt/server_args.py
CHANGED
@@ -35,12 +35,12 @@ class ServerArgs:
     tokenizer_mode: str = "auto"
     skip_tokenizer_init: bool = False
     load_format: str = "auto"
+    trust_remote_code: bool = True
     dtype: str = "auto"
-    device: str = "cuda"
     kv_cache_dtype: str = "auto"
-    trust_remote_code: bool = True
-    context_length: Optional[int] = None
     quantization: Optional[str] = None
+    context_length: Optional[int] = None
+    device: str = "cuda"
     served_model_name: Optional[str] = None
     chat_template: Optional[str] = None
     is_embedding: bool = False
@@ -73,6 +73,7 @@ class ServerArgs:
     # Other
     api_key: Optional[str] = None
     file_storage_pth: str = "SGLang_storage"
+    enable_cache_report: bool = False
 
     # Data parallelism
     dp_size: int = 1
@@ -86,10 +87,23 @@ class ServerArgs:
     # Model override args in JSON
     json_model_override_args: str = "{}"
 
-    #
+    # Double Sparsity
+    enable_double_sparsity: bool = False
+    ds_channel_config_path: str = None
+    ds_heavy_channel_num: int = 32
+    ds_heavy_token_num: int = 256
+    ds_heavy_channel_type: str = "qk"
+    ds_sparse_decode_threshold: int = 4096
+
+    # LoRA
+    lora_paths: Optional[List[str]] = None
+    max_loras_per_batch: int = 8
+
+    # Kernel backend
     attention_backend: Optional[str] = None
     sampling_backend: Optional[str] = None
 
+    # Optimization/debug options
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
     disable_radix_cache: bool = False
@@ -99,16 +113,16 @@ class ServerArgs:
     disable_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
+    disable_penalizer: bool = False
+    disable_nan_detection: bool = False
+    enable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     max_torch_compile_bs: int = 32
     torchao_config: str = ""
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
-
-    # LoRA
-    lora_paths: Optional[List[str]] = None
-    max_loras_per_batch: int = 8
+    num_continuous_decode_steps: int = 1
 
     def __post_init__(self):
         # Set missing default values
@@ -224,6 +238,11 @@ class ServerArgs:
             '"dummy" will initialize the weights with random values, '
             "which is mainly for profiling.",
         )
+        parser.add_argument(
+            "--trust-remote-code",
+            action="store_true",
+            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
+        )
         parser.add_argument(
             "--dtype",
             type=str,
@@ -238,13 +257,6 @@ class ServerArgs:
             '* "float" is shorthand for FP32 precision.\n'
             '* "float32" for FP32 precision.',
         )
-        parser.add_argument(
-            "--device",
-            type=str,
-            default="cuda",
-            choices=["cuda"],
-            help="The device type.",
-        )
         parser.add_argument(
             "--kv-cache-dtype",
             type=str,
@@ -252,17 +264,6 @@ class ServerArgs:
             choices=["auto", "fp8_e5m2"],
             help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
         )
-        parser.add_argument(
-            "--trust-remote-code",
-            action="store_true",
-            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
-        )
-        parser.add_argument(
-            "--context-length",
-            type=int,
-            default=ServerArgs.context_length,
-            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
-        )
         parser.add_argument(
             "--quantization",
             type=str,
@@ -278,6 +279,19 @@ class ServerArgs:
             ],
             help="The quantization method.",
         )
+        parser.add_argument(
+            "--context-length",
+            type=int,
+            default=ServerArgs.context_length,
+            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
+        )
+        parser.add_argument(
+            "--device",
+            type=str,
+            default="cuda",
+            choices=["cuda", "xpu"],
+            help="The device type.",
+        )
         parser.add_argument(
             "--served-model-name",
             type=str,
@@ -398,6 +412,11 @@ class ServerArgs:
             default=ServerArgs.file_storage_pth,
             help="The path of the file storage in backend.",
         )
+        parser.add_argument(
+            "--enable-cache-report",
+            action="store_true",
+            help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
+        )
 
         # Data parallelism
         parser.add_argument(
@@ -440,7 +459,60 @@ class ServerArgs:
             default=ServerArgs.json_model_override_args,
         )
 
-        #
+        # Double Sparsity
+        parser.add_argument(
+            "--enable-double-sparsity",
+            action="store_true",
+            help="Enable double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-channel-config-path",
+            type=str,
+            default=ServerArgs.ds_channel_config_path,
+            help="The path of the double sparsity channel config",
+        )
+        parser.add_argument(
+            "--ds-heavy-channel-num",
+            type=int,
+            default=ServerArgs.ds_heavy_channel_num,
+            help="The number of heavy channels in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-heavy-token-num",
+            type=int,
+            default=ServerArgs.ds_heavy_token_num,
+            help="The number of heavy tokens in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-heavy-channel-type",
+            type=str,
+            default=ServerArgs.ds_heavy_channel_type,
+            help="The type of heavy channels in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-sparse-decode-threshold",
+            type=int,
+            default=ServerArgs.ds_sparse_decode_threshold,
+            help="The type of heavy channels in double sparsity attention",
+        )
+
+        # LoRA
+        parser.add_argument(
+            "--lora-paths",
+            type=str,
+            nargs="*",
+            default=None,
+            action=LoRAPathAction,
+            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
+        )
+        parser.add_argument(
+            "--max-loras-per-batch",
+            type=int,
+            default=8,
+            help="Maximum number of adapters for a running batch, include base-only request",
+        )
+
+        # Kernel backend
         parser.add_argument(
             "--attention-backend",
             type=str,
@@ -455,6 +527,8 @@ class ServerArgs:
             default=ServerArgs.sampling_backend,
             help="Choose the kernels for sampling layers.",
         )
+
+        # Optimization/debug options
         parser.add_argument(
             "--disable-flashinfer",
             action="store_true",
@@ -501,6 +575,21 @@ class ServerArgs:
             action="store_true",
             help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
         )
+        parser.add_argument(
+            "--disable-penalizer",
+            action="store_true",
+            help="Disable the logit penalizers (e.g., frequency and repetition penalty) for better performance if they are not used in any requests.",
+        )
+        parser.add_argument(
+            "--disable-nan-detection",
+            action="store_true",
+            help="Disable the NaN detection for better performance.",
+        )
+        parser.add_argument(
+            "--enable-overlap-schedule",
+            action="store_true",
+            help="Overlap the CPU scheduler with GPU model worker. Experimental feature.",
+        )
         parser.add_argument(
             "--enable-mixed-chunk",
             action="store_true",
@@ -535,25 +624,12 @@ class ServerArgs:
             "This only affects Triton attention kernels.",
         )
         parser.add_argument(
-            "--
-            action="store_true",
-            help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
-        )
-
-        # LoRA options
-        parser.add_argument(
-            "--lora-paths",
-            type=str,
-            nargs="*",
-            default=None,
-            action=LoRAPathAction,
-            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
-        )
-        parser.add_argument(
-            "--max-loras-per-batch",
+            "--num-continuous-decode-steps",
             type=int,
-            default=
-            help="
+            default=ServerArgs.num_continuous_decode_steps,
+            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
+            "This can potentially increase throughput but may also increase time-to-first-token latency. "
+            "The default value is 1, meaning only run one decoding step at a time.",
         )
 
     @classmethod
sglang/srt/utils.py
CHANGED
@@ -35,7 +35,7 @@ import psutil
 import requests
 import torch
 import torch.distributed as dist
-from fastapi.responses import
+from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
 from torch import nn
 from torch.profiler import ProfilerActivity, profile, record_function
@@ -566,7 +566,7 @@ def add_api_key_middleware(app, api_key: str):
         if request.url.path.startswith("/health"):
             return await call_next(request)
         if request.headers.get("Authorization") != "Bearer " + api_key:
-            return
+            return ORJSONResponse(content={"error": "Unauthorized"}, status_code=401)
         return await call_next(request)
 
 
@@ -584,10 +584,11 @@ def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):
 
 def configure_logger(server_args, prefix: str = ""):
     format = f"[%(asctime)s{prefix}] %(message)s"
+    # format = f"[%(asctime)s.%(msecs)03d{prefix}] %(message)s"
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
         format=format,
-        datefmt="%H:%M:%S",
+        datefmt="%Y-%m-%d %H:%M:%S",
         force=True,
     )
 
@@ -690,3 +691,10 @@ def pytorch_profile(name, func, *args, data_size=-1):
     prof.export_chrome_trace(f"trace/{name}_{step_counter}.json")
     step_counter += 1
     return result
+
+
+def first_rank_print(*args, **kwargs):
+    if torch.cuda.current_device() == 0:
+        print(*args, **kwargs)
+    else:
+        pass
sglang/test/few_shot_gsm8k.py
CHANGED
@@ -76,7 +76,9 @@ def run_eval(args):
     def few_shot_gsm8k(s, question):
         s += few_shot_examples + question
         s += sgl.gen(
-            "answer",
+            "answer",
+            max_tokens=args.max_new_tokens,
+            stop=["Question", "Assistant:", "<|separator|>"],
         )
 
     #####################################
@@ -131,6 +133,7 @@ if __name__ == "__main__":
     parser.add_argument("--num-shots", type=int, default=5)
     parser.add_argument("--data-path", type=str, default="test.jsonl")
     parser.add_argument("--num-questions", type=int, default=200)
+    parser.add_argument("--max-new-tokens", type=int, default=512)
     parser.add_argument("--parallel", type=int, default=128)
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
     parser.add_argument("--port", type=int, default=30000)