sglang 0.2.11__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +6 -4
- sglang/bench_serving.py +46 -22
- sglang/lang/compiler.py +2 -2
- sglang/lang/ir.py +3 -3
- sglang/srt/constrained/base_tool_cache.py +1 -1
- sglang/srt/constrained/fsm_cache.py +12 -2
- sglang/srt/layers/activation.py +33 -0
- sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
- sglang/srt/layers/extend_attention.py +6 -1
- sglang/srt/layers/layernorm.py +65 -0
- sglang/srt/layers/logits_processor.py +5 -0
- sglang/srt/layers/pooler.py +50 -0
- sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
- sglang/srt/layers/radix_attention.py +2 -2
- sglang/srt/managers/detokenizer_manager.py +31 -9
- sglang/srt/managers/io_struct.py +63 -0
- sglang/srt/managers/policy_scheduler.py +173 -25
- sglang/srt/managers/schedule_batch.py +110 -87
- sglang/srt/managers/tokenizer_manager.py +193 -111
- sglang/srt/managers/tp_worker.py +289 -352
- sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
- sglang/srt/mem_cache/chunk_cache.py +43 -20
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +74 -40
- sglang/srt/model_executor/cuda_graph_runner.py +24 -9
- sglang/srt/model_executor/forward_batch_info.py +168 -105
- sglang/srt/model_executor/model_runner.py +24 -37
- sglang/srt/models/gemma2.py +0 -1
- sglang/srt/models/internlm2.py +2 -7
- sglang/srt/models/llama2.py +4 -4
- sglang/srt/models/llama_embedding.py +88 -0
- sglang/srt/models/qwen2_moe.py +0 -11
- sglang/srt/openai_api/adapter.py +155 -27
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/penaltylib/__init__.py +13 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
- sglang/srt/sampling_params.py +31 -4
- sglang/srt/server.py +69 -15
- sglang/srt/server_args.py +26 -19
- sglang/srt/utils.py +31 -13
- sglang/test/run_eval.py +10 -1
- sglang/test/runners.py +63 -63
- sglang/test/simple_eval_humaneval.py +2 -8
- sglang/test/simple_eval_mgsm.py +203 -0
- sglang/test/srt/sampling/penaltylib/utils.py +337 -0
- sglang/test/test_layernorm.py +60 -0
- sglang/test/test_programs.py +4 -2
- sglang/test/test_utils.py +20 -2
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/METADATA +23 -14
- sglang-0.2.12.dist-info/RECORD +112 -0
- sglang/srt/layers/linear.py +0 -884
- sglang/srt/layers/quantization/__init__.py +0 -64
- sglang/srt/layers/quantization/fp8.py +0 -677
- sglang-0.2.11.dist-info/RECORD +0 -102
- {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
- {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
- {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
sglang/srt/sampling_params.py
CHANGED
@@ -23,13 +23,16 @@ _SAMPLING_EPS = 1e-6
 class SamplingParams:
     def __init__(
         self,
-        max_new_tokens: int =
+        max_new_tokens: int = 128,
+        min_new_tokens: int = 0,
         stop: Optional[Union[str, List[str]]] = None,
+        stop_token_ids: Optional[List[int]] = [],
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
+        repetition_penalty: float = 1.0,
         ignore_eos: bool = False,
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
@@ -42,8 +45,11 @@ class SamplingParams:
         self.top_k = top_k
         self.frequency_penalty = frequency_penalty
         self.presence_penalty = presence_penalty
+        self.repetition_penalty = repetition_penalty
         self.stop_strs = stop
+        self.stop_token_ids = {*stop_token_ids}
         self.max_new_tokens = max_new_tokens
+        self.min_new_tokens = min_new_tokens
         self.ignore_eos = ignore_eos
         self.skip_special_tokens = skip_special_tokens
         self.spaces_between_special_tokens = spaces_between_special_tokens
@@ -80,23 +86,44 @@ class SamplingParams:
             raise ValueError(
                 "presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}."
             )
+        if not 0.0 <= self.repetition_penalty <= 2.0:
+            raise ValueError(
+                "repetition_penalty must be in (0, 2], got "
+                f"{self.repetition_penalty}."
+            )
+        if not 0 <= self.min_new_tokens:
+            raise ValueError(
+                f"min_new_tokens must be in (0, max_new_tokens], got "
+                f"{self.min_new_tokens}."
+            )
         if self.max_new_tokens is not None:
             if self.max_new_tokens < 0:
                 raise ValueError(
                     f"max_new_tokens must be at least 0, got {self.max_new_tokens}."
                 )
+            if not self.min_new_tokens <= self.max_new_tokens:
+                raise ValueError(
+                    f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got "
+                    f"{self.min_new_tokens}."
+                )
 
     def normalize(self, tokenizer):
         # Process stop strings
         if self.stop_strs is None:
             self.stop_strs = []
-            self.
+            if self.stop_token_ids is None:
+                self.stop_str_max_len = 0
+            else:
+                self.stop_str_max_len = 1
         else:
            if isinstance(self.stop_strs, str):
                self.stop_strs = [self.stop_strs]
 
            stop_str_max_len = 0
            for stop_str in self.stop_strs:
-
-
+                if tokenizer is not None:
+                    stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
+                    stop_str_max_len = max(stop_str_max_len, len(stop_str_ids))
+                else:
+                    stop_str_max_len = max(stop_str_max_len, len(stop_str))
            self.stop_str_max_len = stop_str_max_len
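The diff above introduces three new sampling knobs: `min_new_tokens`, `stop_token_ids`, and `repetition_penalty`. A minimal sketch of exercising them through the `/generate` HTTP endpoint, mirroring the warmup request shape used in `server.py` below; the server address and the stop token id are placeholders, and it is assumed that keys in `sampling_params` are forwarded unchanged to `SamplingParams.__init__`.

```python
# Hedged sketch, not part of the package: sends the new SamplingParams fields
# through /generate. The URL and the stop token id (128009) are illustrative only.
import json

import requests

payload = {
    "text": "The capital city of France is",
    "sampling_params": {
        "temperature": 0,
        "max_new_tokens": 64,
        "min_new_tokens": 4,         # new: lower bound on generated tokens
        "repetition_penalty": 1.05,  # new: validated to lie in (0, 2]
        "stop_token_ids": [128009],  # new: token-id stops alongside stop strings
    },
}
resp = requests.post("http://127.0.0.1:30000/generate", json=payload, timeout=600)
print(json.dumps(resp.json(), indent=2))
```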
sglang/srt/server.py
CHANGED
@@ -52,7 +52,7 @@ from sglang.srt.managers.controller_single import (
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
@@ -60,6 +60,7 @@ from sglang.srt.openai_api.adapter import (
     v1_chat_completions,
     v1_completions,
     v1_delete_file,
+    v1_embeddings,
     v1_files_create,
     v1_retrieve_batch,
     v1_retrieve_file,
@@ -74,7 +75,8 @@ from sglang.srt.utils import (
     enable_show_time_cost,
     kill_child_process,
     maybe_set_triton_cache_manager,
-    set_torch_compile_config,
+    prepare_model,
+    prepare_tokenizer,
     set_ulimit,
 )
 from sglang.utils import get_exception_traceback
@@ -98,6 +100,7 @@ async def health() -> Response:
 async def get_model_info():
     result = {
         "model_path": tokenizer_manager.model_path,
+        "is_generation": tokenizer_manager.is_generation,
     }
     return result
 
@@ -149,6 +152,21 @@ app.post("/generate")(generate_request)
 app.put("/generate")(generate_request)
 
 
+async def encode_request(obj: EmbeddingReqInput, request: Request):
+    """Handle an embedding request."""
+    try:
+        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        return JSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
+
+
+app.post("/encode")(encode_request)
+app.put("/encode")(encode_request)
+
+
 @app.post("/v1/completions")
 async def openai_v1_completions(raw_request: Request):
     return await v1_completions(tokenizer_manager, raw_request)
@@ -159,6 +177,12 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)
 
 
+@app.post("/v1/embeddings")
+async def openai_v1_embeddings(raw_request: Request):
+    response = await v1_embeddings(tokenizer_manager, raw_request)
+    return response
+
+
 @app.get("/v1/models")
 def available_models():
     """Show available models."""
@@ -235,6 +259,10 @@ def launch_server(
     )
     logger.info(f"{server_args=}")
 
+    # Use model from www.modelscope.cn, first download the model.
+    server_args.model_path = prepare_model(server_args.model_path)
+    server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)
+
     # Launch processes for multi-node tensor parallelism
     if server_args.nnodes > 1:
         if server_args.node_rank != 0:
@@ -347,10 +375,6 @@ def _set_envs_and_config(server_args: ServerArgs):
     # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
     maybe_set_triton_cache_manager()
 
-    # Set torch compile config
-    if server_args.enable_torch_compile:
-        set_torch_compile_config()
-
     # Set global chat template
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
@@ -360,7 +384,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.1.
+            "0.1.4",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
@@ -385,6 +409,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
     except (AssertionError, requests.exceptions.RequestException) as e:
         last_traceback = get_exception_traceback()
         pass
+    model_info = res.json()
 
     if not success:
         if pipe_finish_writer is not None:
@@ -393,17 +418,24 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         sys.exit(1)
 
     # Send a warmup request
+    request_name = "/generate" if model_info["is_generation"] else "/encode"
+    max_new_tokens = 8 if model_info["is_generation"] else 1
+    json_data = {
+        "sampling_params": {
+            "temperature": 0,
+            "max_new_tokens": max_new_tokens,
+        },
+    }
+    if server_args.skip_tokenizer_init:
+        json_data["input_ids"] = [10, 11, 12]
+    else:
+        json_data["text"] = "The capital city of France is"
+
     try:
         for _ in range(server_args.dp_size):
             res = requests.post(
-                url + "/generate",
-                json={
-                    "text": "The capital city of France is",
-                    "sampling_params": {
-                        "temperature": 0,
-                        "max_new_tokens": 8,
-                    },
-                },
+                url + request_name,
+                json=json_data,
                 headers=headers,
                 timeout=600,
             )
@@ -415,6 +447,15 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
         sys.exit(1)
 
+    # Print warnings here
+    if server_args.disable_radix_cache and server_args.chunked_prefill_size is not None:
+        logger.warning(
+            "You set both `--disable-radix-cache` and `--chunked-prefill-size`. "
+            "This combination is an experimental feature and we noticed it can lead to "
+            "wrong generation results. If you want to use chunked prefill, it is recommended "
+            "not using `--disable-radix-cache`."
+        )
+
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("init ok")
@@ -534,5 +575,18 @@ class Runtime:
         )
         return json.dumps(response.json())
 
+    def encode(
+        self,
+        prompt: str,
+    ):
+        json_data = {
+            "text": prompt,
+        }
+        response = requests.post(
+            self.url + "/encode",
+            json=json_data,
+        )
+        return json.dumps(response.json())
+
     def __del__(self):
         self.shutdown()
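A minimal usage sketch for the new embedding path: POST `/encode` mirrors `Runtime.encode()` above, and the response is read the same way `SRTRunner.forward()` does in `runners.py` further down. The base URL is a placeholder, and the response schema beyond the `embedding` field is an assumption.

```python
# Hedged sketch, not part of the package: query the new /encode endpoint.
# The server address is a placeholder; response parsing follows runners.py.
import requests

base_url = "http://127.0.0.1:30000"  # assumed server address
prompts = ["The capital of the United Kindom is", "Today is a sunny day and I like"]

resp = requests.post(base_url + "/encode", json={"text": prompts})
embeddings = [item["embedding"] for item in resp.json()]  # one vector per prompt
print(len(embeddings), len(embeddings[0]))
```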
sglang/srt/server_args.py
CHANGED
@@ -27,6 +27,7 @@ class ServerArgs:
     model_path: str
     tokenizer_path: Optional[str] = None
     tokenizer_mode: str = "auto"
+    skip_tokenizer_init: bool = False
     load_format: str = "auto"
     dtype: str = "auto"
     trust_remote_code: bool = True
@@ -42,10 +43,11 @@ class ServerArgs:
 
     # Memory and scheduling
     mem_fraction_static: Optional[float] = None
-    max_prefill_tokens: Optional[int] = None
     max_running_requests: Optional[int] = None
     max_num_reqs: Optional[int] = None
     max_total_tokens: Optional[int] = None
+    chunked_prefill_size: int = -1
+    max_prefill_tokens: int = 16384
     schedule_policy: str = "lpm"
     schedule_conservativeness: float = 1.0
 
@@ -62,15 +64,12 @@ class ServerArgs:
 
     # Other
     api_key: Optional[str] = None
-    file_storage_pth: str = "
+    file_storage_pth: str = "SGLang_storage"
 
     # Data parallelism
     dp_size: int = 1
     load_balance_method: str = "round_robin"
 
-    # Chunked Prefill
-    chunked_prefill_size: Optional[int] = None
-
     # Optimization/debug options
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
@@ -96,6 +95,10 @@ class ServerArgs:
         if self.served_model_name is None:
             self.served_model_name = self.model_path
 
+        if self.chunked_prefill_size <= 0:
+            # Disable chunked prefill
+            self.chunked_prefill_size = None
+
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
                 self.mem_fraction_static = 0.79
@@ -107,6 +110,7 @@ class ServerArgs:
                 self.mem_fraction_static = 0.87
             else:
                 self.mem_fraction_static = 0.88
+
         if isinstance(self.additional_ports, int):
             self.additional_ports = [self.additional_ports]
         elif self.additional_ports is None:
@@ -151,6 +155,11 @@ class ServerArgs:
             "tokenizer if available, and 'slow' will "
             "always use the slow tokenizer.",
         )
+        parser.add_argument(
+            "--skip-tokenizer-init",
+            action="store_true",
+            help="If set, skip init tokenizer and pass input_ids in generate request",
+        )
         parser.add_argument(
             "--load-format",
             type=str,
@@ -226,12 +235,6 @@ class ServerArgs:
             default=ServerArgs.mem_fraction_static,
             help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
         )
-        parser.add_argument(
-            "--max-prefill-tokens",
-            type=int,
-            default=ServerArgs.max_prefill_tokens,
-            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
-        )
         parser.add_argument(
             "--max-running-requests",
             type=int,
@@ -250,6 +253,18 @@ class ServerArgs:
             default=ServerArgs.max_total_tokens,
             help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes.",
         )
+        parser.add_argument(
+            "--chunked-prefill-size",
+            type=int,
+            default=ServerArgs.chunked_prefill_size,
+            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill",
+        )
+        parser.add_argument(
+            "--max-prefill-tokens",
+            type=int,
+            default=ServerArgs.max_prefill_tokens,
+            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
+        )
         parser.add_argument(
             "--schedule-policy",
             type=str,
@@ -347,14 +362,6 @@ class ServerArgs:
         )
         parser.add_argument("--node-rank", type=int, help="The node rank.")
 
-        # Chunked prefill
-        parser.add_argument(
-            "--chunked-prefill-size",
-            type=int,
-            default=ServerArgs.chunked_prefill_size,
-            help="The size of the chunked prefill.",
-        )
-
         # Optimization/debug options
         parser.add_argument(
             "--disable-flashinfer",
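A minimal sketch of how the reshuffled scheduling options fit together after this change: `chunked_prefill_size` now lives with the other memory/scheduling fields (default `-1`, which `__post_init__` maps to `None`, i.e. disabled), and `max_prefill_tokens` gets a concrete default of 16384. The model path below is a placeholder, and constructing `ServerArgs` directly like this is just an illustration of the defaults.

```python
# Hedged sketch, not part of the package: constructing ServerArgs directly to
# show the new chunked-prefill defaults. The model path is a placeholder.
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder
    chunked_prefill_size=4096,  # enable chunked prefill with 4096-token chunks
    max_prefill_tokens=16384,   # cap on tokens in one prefill batch
    skip_tokenizer_init=False,  # set True to send input_ids instead of text
)
assert args.chunked_prefill_size == 4096

# The default of -1 is translated to None (chunked prefill disabled) in __post_init__.
disabled = ServerArgs(model_path="meta-llama/Meta-Llama-3-8B-Instruct")
assert disabled.chunked_prefill_size is None
```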
sglang/srt/utils.py
CHANGED
@@ -197,6 +197,8 @@ def allocate_init_ports(
 def get_int_token_logit_bias(tokenizer, vocab_size):
     """Get the logit bias for integer-only tokens."""
     # a bug when model's vocab size > tokenizer.vocab_size
+    if tokenizer == None:
+        return [-1e5] * vocab_size
     vocab_size = tokenizer.vocab_size
     logit_bias = np.zeros(vocab_size, dtype=np.float32)
     for t_id in range(vocab_size):
@@ -223,6 +225,15 @@ def is_multimodal_model(model):
     raise ValueError("unrecognized type")
 
 
+def is_generation_model(model_architectures):
+    if (
+        "LlamaEmbeddingModel" in model_architectures
+        or "MistralModel" in model_architectures
+    ):
+        return False
+    return True
+
+
 def decode_video_base64(video_base64):
     from PIL import Image
 
@@ -622,19 +633,6 @@ def receive_addrs(model_port_args, server_args):
     dist.destroy_process_group()
 
 
-def set_torch_compile_config():
-    # The following configurations are for torch compile optimizations
-    import torch._dynamo.config
-    import torch._inductor.config
-
-    torch._inductor.config.coordinate_descent_tuning = True
-    torch._inductor.config.triton.unique_kernel_names = True
-    torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
-
-    # FIXME: tmp workaround
-    torch._dynamo.config.accumulated_cache_size_limit = 256
-
-
 def set_ulimit(target_soft_limit=65535):
     resource_type = resource.RLIMIT_NOFILE
     current_soft, current_hard = resource.getrlimit(resource_type)
@@ -705,3 +703,23 @@ def add_api_key_middleware(app, api_key):
         if request.headers.get("Authorization") != "Bearer " + api_key:
             return JSONResponse(content={"error": "Unauthorized"}, status_code=401)
         return await call_next(request)
+
+
+def prepare_model(model_path):
+    if "SGLANG_USE_MODELSCOPE" in os.environ:
+        if not os.path.exists(model_path):
+            from modelscope import snapshot_download
+
+            return snapshot_download(model_path)
+    return model_path
+
+
+def prepare_tokenizer(tokenizer_path):
+    if "SGLANG_USE_MODELSCOPE" in os.environ:
+        if not os.path.exists(tokenizer_path):
+            from modelscope import snapshot_download
+
+            return snapshot_download(
+                tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"]
+            )
+    return tokenizer_path
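A minimal sketch of the new helpers above: `prepare_model()` and `prepare_tokenizer()` only resolve the path through ModelScope when `SGLANG_USE_MODELSCOPE` is set and the path does not exist locally, while `is_generation_model()` routes embedding architectures to the `/encode` path. The ModelScope model id below is a placeholder.

```python
# Hedged sketch, not part of the package: exercising the new utils helpers.
# The model id is a placeholder; downloading only happens when the
# SGLANG_USE_MODELSCOPE variable is present in the environment.
import os

from sglang.srt.utils import is_generation_model, prepare_model, prepare_tokenizer

os.environ["SGLANG_USE_MODELSCOPE"] = "1"

model_path = prepare_model("qwen/Qwen2-7B-Instruct")          # resolved to a local snapshot
tokenizer_path = prepare_tokenizer("qwen/Qwen2-7B-Instruct")  # *.bin / *.safetensors skipped

print(is_generation_model(["LlamaForCausalLM"]))     # True  -> /generate path
print(is_generation_model(["LlamaEmbeddingModel"]))  # False -> /encode path
```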
sglang/test/run_eval.py
CHANGED
@@ -16,6 +16,8 @@ from sglang.test.simple_eval_common import (
 
 
 def run_eval(args):
+    set_ulimit()
+
     if "OPENAI_API_KEY" not in os.environ:
         os.environ["OPENAI_API_KEY"] = "EMPTY"
 
@@ -39,6 +41,14 @@ def run_eval(args):
         eval_obj = MathEval(
             filename, equality_checker, args.num_examples, args.num_threads
         )
+    elif args.eval_name == "mgsm":
+        from sglang.test.simple_eval_mgsm import MGSMEval
+
+        eval_obj = MGSMEval(args.num_examples, args.num_threads)
+    elif args.eval_name == "mgsm_en":
+        from sglang.test.simple_eval_mgsm import MGSMEval
+
+        eval_obj = MGSMEval(args.num_examples, args.num_threads, languages=["en"])
     elif args.eval_name == "gpqa":
         from sglang.test.simple_eval_gpqa import GPQAEval
 
@@ -109,7 +119,6 @@ if __name__ == "__main__":
     parser.add_argument("--eval-name", type=str, default="mmlu")
     parser.add_argument("--num-examples", type=int)
     parser.add_argument("--num-threads", type=int, default=512)
-    set_ulimit()
     args = parser.parse_args()
 
     run_eval(args)
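The new MGSM benchmark (added as `sglang/test/simple_eval_mgsm.py` in the file list above) is selected with `--eval-name mgsm` or `--eval-name mgsm_en`. A sketch of the equivalent programmatic construction, mirroring the dispatch in the diff; the example counts are placeholders.

```python
# Hedged sketch, not part of the package: the two MGSM variants wired into
# run_eval(). From the command line this would be invoked roughly as
#   python -m sglang.test.run_eval --eval-name mgsm    --num-examples 8
#   python -m sglang.test.run_eval --eval-name mgsm_en --num-examples 8
# (other run_eval flags are unchanged). The counts below are placeholders.
from sglang.test.simple_eval_mgsm import MGSMEval

eval_all_languages = MGSMEval(8, 16)                   # num_examples, num_threads
eval_english_only = MGSMEval(8, 16, languages=["en"])  # restrict to English
```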
sglang/test/runners.py
CHANGED
@@ -23,23 +23,19 @@ import torch.nn.functional as F
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from sglang.srt.server import Runtime
+from sglang.srt.utils import is_generation_model
 
 DEFAULT_PROMPTS = [
-    "The capital of France is",
+    # the output of gemma-2-2b from SRT is unstable on the commented prompt
+    # "The capital of France is",
     "The capital of the United Kindom is",
     "Today is a sunny day and I like",
+    "AI is a field of computer science focused on",
 ]
 
 NUM_TOP_LOGPROBS = 5
 
 
-def is_embedding_model(model_path):
-    # FIXME incomplete list
-    if "e5-mistral-7b-instruct" in model_path.lower():
-        return True
-    return False
-
-
 def get_dtype_str(torch_dtype):
     if torch_dtype is torch.float16:
         return "float16"
@@ -49,10 +45,11 @@ def get_dtype_str(torch_dtype):
 
 @dataclass
 class ModelOutput:
-    output_strs: str = None
-
-
-
+    output_strs: List[str] = None
+    output_ids: List[int] = None
+    top_input_logprobs: List[torch.Tensor] = None
+    top_output_logprobs: List[torch.Tensor] = None
+    embed_logits: List[torch.Tensor] = None
 
 
 class HFRunner:
@@ -60,7 +57,7 @@ class HFRunner:
         self,
         model_path,
         torch_dtype=torch.float16,
-
+        is_generation_model=None,
     ):
         self.in_queue = multiprocessing.Queue()
         self.out_queue = multiprocessing.Queue()
@@ -72,13 +69,13 @@ class HFRunner:
                 self.out_queue,
                 model_path,
                 torch_dtype,
-
+                is_generation_model,
             ),
         )
         self.model_proc.start()
 
     def start_model_process(
-        self, in_queue, out_queue, model_path, torch_dtype,
+        self, in_queue, out_queue, model_path, torch_dtype, is_generation_model
     ):
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_path,
@@ -86,12 +83,12 @@
             trust_remote_code=True,
         )
 
-        self.
-
-        if
-        else
+        self.is_generation_model = (
+            is_generation_model(model_path)
+            if is_generation_model is None
+            else is_generation_model
         )
-        if
+        if self.is_generation_model:
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 torch_dtype=torch_dtype,
@@ -103,13 +100,13 @@
 
             self.model = SentenceTransformer(
                 model_path,
-
-            )
+                model_kwargs={"torch_dtype": torch_dtype},
+            )
 
         while True:
             prompts, max_new_tokens = in_queue.get()
             if prompts is not None:
-                if
+                if self.is_generation_model:
                     output_strs = []
                     prefill_logprobs = []
                     for p in prompts:
@@ -123,7 +120,9 @@
                         output_ids = self.model.generate(
                             input_ids, do_sample=False, max_new_tokens=max_new_tokens
                         )
-                        output_strs.append(
+                        output_strs.append(
+                            self.tokenizer.decode(output_ids[0][len(input_ids[0]) :])
+                        )
 
                         logits = self.model.forward(input_ids).logits[0]
                         logprobs = F.log_softmax(
@@ -144,7 +143,6 @@
                     )
 
                 else:
-                    assert isinstance(prompts, List[str])
                     logits = self.model.encode(prompts).tolist()
 
                     out_queue.put(ModelOutput(embed_logits=logits))
@@ -152,7 +150,7 @@
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
-        max_new_tokens=
+        max_new_tokens=8,
     ):
         self.in_queue.put((prompts, max_new_tokens))
         return self.out_queue.get()
@@ -175,16 +173,13 @@ class SRTRunner:
         model_path,
         tp_size=1,
         torch_dtype=torch.float16,
-
+        is_generation_model=None,
     ):
-        self.
-
-        if
-        else
+        self.is_generation_model = (
+            is_generation_model(model_path)
+            if is_generation_model is None
+            else is_generation_model
        )
-        if self.is_embedding_model:
-            raise NotImplementedError()
-
        self.runtime = Runtime(
            model_path=model_path,
            tp_size=tp_size,
@@ -194,40 +189,45 @@
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
-        max_new_tokens=
+        max_new_tokens=8,
     ):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
-            for x in response["meta_info"]["input_top_logprobs"][1:]
-        ]
-        + [
+        if self.is_generation_model:
+            # the return value contains logprobs from prefill
+            output_strs = []
+            top_input_logprobs = []
+            sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
+            for prompt in prompts:
+                response = self.runtime.generate(
+                    prompt,
+                    sampling_params=sampling_params,
+                    return_logprob=True,
+                    top_logprobs_num=NUM_TOP_LOGPROBS,
+                )
+                response = json.loads(response)
+                output_strs.append(response["text"])
+                top_input_logprobs.append(
                     [
-            tup[0]
-            for
-
+                        [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
+                        for x in response["meta_info"]["input_top_logprobs"][1:]
+                    ]
+                    + [
+                        [
+                            tup[0]
+                            for tup in response["meta_info"]["output_top_logprobs"][0][
+                                :NUM_TOP_LOGPROBS
+                            ]
                         ]
                     ]
-
-        )
-        # print(response["meta_info"]["output_top_logprobs"][0])
+                )
 
-
-
-
+            return ModelOutput(
+                output_strs=output_strs, top_input_logprobs=top_input_logprobs
+            )
+        else:
+            response = self.runtime.encode(prompts)
+            response = json.loads(response)
+            logits = [x["embedding"] for x in response]
+            return ModelOutput(embed_logits=logits)
 
     def __enter__(self):
         return self
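With the diff above, both runners take an explicit `is_generation_model` flag and, for embedding models, return `ModelOutput.embed_logits` instead of raising `NotImplementedError`. A minimal comparison sketch under the assumption that an e5-mistral-style embedding model (the family named in the removed helper) is available; the exact model id is a placeholder and process shutdown is omitted for brevity.

```python
# Hedged sketch, not part of the package: comparing HF and SRT embeddings with
# the updated test runners. The model id is a placeholder; cleanup of the runner
# processes is omitted here.
import torch

from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner

MODEL = "intfloat/e5-mistral-7b-instruct"  # assumed embedding model id

hf_runner = HFRunner(MODEL, torch_dtype=torch.float16, is_generation_model=False)
hf_out = hf_runner.forward(DEFAULT_PROMPTS)

srt_runner = SRTRunner(
    MODEL, tp_size=1, torch_dtype=torch.float16, is_generation_model=False
)
srt_out = srt_runner.forward(DEFAULT_PROMPTS)

# Both runners now report one embedding vector per prompt in embed_logits.
hf_emb = torch.tensor(hf_out.embed_logits)
srt_emb = torch.tensor(srt_out.embed_logits)
print(torch.nn.functional.cosine_similarity(hf_emb, srt_emb, dim=-1))
```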