sglang 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/api.py +7 -1
- sglang/bench_latency.py +9 -6
- sglang/bench_serving.py +46 -22
- sglang/global_config.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +60 -49
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +4 -2
- sglang/lang/ir.py +16 -7
- sglang/srt/constrained/base_tool_cache.py +1 -1
- sglang/srt/constrained/fsm_cache.py +12 -2
- sglang/srt/constrained/jump_forward.py +13 -2
- sglang/srt/layers/activation.py +32 -0
- sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
- sglang/srt/layers/extend_attention.py +9 -2
- sglang/srt/layers/fused_moe/__init__.py +1 -0
- sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
- sglang/srt/layers/fused_moe/layer.py +587 -0
- sglang/srt/layers/layernorm.py +65 -0
- sglang/srt/layers/logits_processor.py +7 -2
- sglang/srt/layers/pooler.py +50 -0
- sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
- sglang/srt/layers/radix_attention.py +40 -16
- sglang/srt/managers/detokenizer_manager.py +31 -9
- sglang/srt/managers/io_struct.py +63 -0
- sglang/srt/managers/policy_scheduler.py +173 -25
- sglang/srt/managers/schedule_batch.py +115 -97
- sglang/srt/managers/tokenizer_manager.py +194 -112
- sglang/srt/managers/tp_worker.py +290 -359
- sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
- sglang/srt/mem_cache/chunk_cache.py +43 -20
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +74 -40
- sglang/srt/model_executor/cuda_graph_runner.py +71 -25
- sglang/srt/model_executor/forward_batch_info.py +293 -156
- sglang/srt/model_executor/model_runner.py +77 -57
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/deepseek.py +2 -2
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +11 -6
- sglang/srt/models/grok.py +50 -396
- sglang/srt/models/internlm2.py +2 -7
- sglang/srt/models/llama2.py +4 -4
- sglang/srt/models/llama_embedding.py +88 -0
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/mixtral.py +56 -254
- sglang/srt/models/mixtral_quant.py +1 -4
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_moe.py +2 -13
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +187 -48
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/penaltylib/__init__.py +13 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
- sglang/srt/sampling_params.py +31 -8
- sglang/srt/server.py +91 -29
- sglang/srt/server_args.py +32 -19
- sglang/srt/utils.py +32 -15
- sglang/test/run_eval.py +10 -1
- sglang/test/runners.py +81 -73
- sglang/test/simple_eval_humaneval.py +2 -8
- sglang/test/simple_eval_mgsm.py +203 -0
- sglang/test/srt/sampling/penaltylib/utils.py +337 -0
- sglang/test/test_layernorm.py +60 -0
- sglang/test/test_programs.py +36 -7
- sglang/test/test_utils.py +24 -2
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/METADATA +33 -16
- sglang-0.2.13.dist-info/RECORD +112 -0
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
- sglang/srt/layers/linear.py +0 -884
- sglang/srt/layers/quantization/__init__.py +0 -64
- sglang/srt/layers/quantization/fp8.py +0 -677
- sglang/srt/model_loader/model_loader.py +0 -292
- sglang/srt/model_loader/utils.py +0 -275
- sglang-0.2.11.dist-info/RECORD +0 -102
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
sglang/srt/sampling_params.py
CHANGED
@@ -23,17 +23,19 @@ _SAMPLING_EPS = 1e-6
 class SamplingParams:
     def __init__(
         self,
-        max_new_tokens: int =
+        max_new_tokens: int = 128,
+        min_new_tokens: int = 0,
         stop: Optional[Union[str, List[str]]] = None,
+        stop_token_ids: Optional[List[int]] = [],
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
+        repetition_penalty: float = 1.0,
         ignore_eos: bool = False,
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
-        dtype: Optional[str] = None,
         regex: Optional[str] = None,
         n: int = 1,
     ) -> None:
@@ -42,12 +44,14 @@ class SamplingParams:
         self.top_k = top_k
         self.frequency_penalty = frequency_penalty
         self.presence_penalty = presence_penalty
+        self.repetition_penalty = repetition_penalty
         self.stop_strs = stop
+        self.stop_token_ids = {*stop_token_ids}
         self.max_new_tokens = max_new_tokens
+        self.min_new_tokens = min_new_tokens
         self.ignore_eos = ignore_eos
         self.skip_special_tokens = skip_special_tokens
         self.spaces_between_special_tokens = spaces_between_special_tokens
-        self.dtype = dtype
         self.regex = regex
         self.n = n

@@ -57,8 +61,6 @@ class SamplingParams:
             self.top_k = 1
         if self.top_k == -1:
             self.top_k = 1 << 30  # whole vocabulary
-        if self.dtype == "int":
-            self.stop_strs = [" ", "\n"]

     def verify(self):
         if self.temperature < 0.0:
@@ -80,23 +82,44 @@ class SamplingParams:
             raise ValueError(
                 "presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}."
             )
+        if not 0.0 <= self.repetition_penalty <= 2.0:
+            raise ValueError(
+                "repetition_penalty must be in (0, 2], got "
+                f"{self.repetition_penalty}."
+            )
+        if not 0 <= self.min_new_tokens:
+            raise ValueError(
+                f"min_new_tokens must be in (0, max_new_tokens], got "
+                f"{self.min_new_tokens}."
+            )
         if self.max_new_tokens is not None:
             if self.max_new_tokens < 0:
                 raise ValueError(
                     f"max_new_tokens must be at least 0, got {self.max_new_tokens}."
                 )
+            if not self.min_new_tokens <= self.max_new_tokens:
+                raise ValueError(
+                    f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got "
+                    f"{self.min_new_tokens}."
+                )

     def normalize(self, tokenizer):
         # Process stop strings
         if self.stop_strs is None:
             self.stop_strs = []
-            self.
+            if self.stop_token_ids is None:
+                self.stop_str_max_len = 0
+            else:
+                self.stop_str_max_len = 1
         else:
             if isinstance(self.stop_strs, str):
                 self.stop_strs = [self.stop_strs]

             stop_str_max_len = 0
             for stop_str in self.stop_strs:
-
-
+                if tokenizer is not None:
+                    stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
+                    stop_str_max_len = max(stop_str_max_len, len(stop_str_ids))
+                else:
+                    stop_str_max_len = max(stop_str_max_len, len(stop_str))
             self.stop_str_max_len = stop_str_max_len

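The three new knobs above (`min_new_tokens`, `stop_token_ids`, `repetition_penalty`) are plain constructor arguments, so they can be passed through the `sampling_params` dict of a `/generate` request. A minimal sketch, assuming a locally launched server; the URL, prompt, and stop token id are placeholders:

```python
import requests

# Placeholder URL for a locally launched sglang server.
url = "http://127.0.0.1:30000/generate"

resp = requests.post(
    url,
    json={
        "text": "List three prime numbers:",
        "sampling_params": {
            "temperature": 0.7,
            "max_new_tokens": 64,
            "min_new_tokens": 4,        # new: generate at least 4 tokens
            "repetition_penalty": 1.1,  # new: validated to lie in (0, 2]
            "stop_token_ids": [2],      # new: stop on explicit token ids (placeholder id)
        },
    },
)
print(resp.json()["text"])
```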
sglang/srt/server.py
CHANGED
@@ -52,7 +52,7 @@ from sglang.srt.managers.controller_single import (
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
@@ -60,6 +60,7 @@ from sglang.srt.openai_api.adapter import (
     v1_chat_completions,
     v1_completions,
     v1_delete_file,
+    v1_embeddings,
     v1_files_create,
     v1_retrieve_batch,
     v1_retrieve_file,
@@ -74,7 +75,8 @@ from sglang.srt.utils import (
     enable_show_time_cost,
     kill_child_process,
     maybe_set_triton_cache_manager,
-
+    prepare_model,
+    prepare_tokenizer,
     set_ulimit,
 )
 from sglang.utils import get_exception_traceback
@@ -98,6 +100,7 @@ async def health() -> Response:
 async def get_model_info():
     result = {
         "model_path": tokenizer_manager.model_path,
+        "is_generation": tokenizer_manager.is_generation,
     }
     return result

@@ -149,6 +152,21 @@ app.post("/generate")(generate_request)
 app.put("/generate")(generate_request)


+async def encode_request(obj: EmbeddingReqInput, request: Request):
+    """Handle an embedding request."""
+    try:
+        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        return JSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
+
+
+app.post("/encode")(encode_request)
+app.put("/encode")(encode_request)
+
+
 @app.post("/v1/completions")
 async def openai_v1_completions(raw_request: Request):
     return await v1_completions(tokenizer_manager, raw_request)
@@ -159,6 +177,12 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)


+@app.post("/v1/embeddings")
+async def openai_v1_embeddings(raw_request: Request):
+    response = await v1_embeddings(tokenizer_manager, raw_request)
+    return response
+
+
 @app.get("/v1/models")
 def available_models():
     """Show available models."""
@@ -235,6 +259,10 @@ def launch_server(
     )
     logger.info(f"{server_args=}")

+    # Use model from www.modelscope.cn, first download the model.
+    server_args.model_path = prepare_model(server_args.model_path)
+    server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)
+
     # Launch processes for multi-node tensor parallelism
     if server_args.nnodes > 1:
         if server_args.node_rank != 0:
@@ -260,6 +288,8 @@ def launch_server(

     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
+    if server_args.chat_template:
+        load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
     pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
     pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)

@@ -330,6 +360,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     os.environ["NCCL_CUMEM_ENABLE"] = "0"
     os.environ["NCCL_NVLS_ENABLE"] = "0"
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"

     # Set ulimit
     set_ulimit()
@@ -347,20 +378,11 @@ def _set_envs_and_config(server_args: ServerArgs):
         # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
         maybe_set_triton_cache_manager()

-    # Set torch compile config
-    if server_args.enable_torch_compile:
-        set_torch_compile_config()
-
-    # Set global chat template
-    if server_args.chat_template:
-        # TODO: replace this with huggingface transformers template
-        load_chat_template_for_openai_api(server_args.chat_template)
-
     # Check flashinfer version
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.1.
+            "0.1.5",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
@@ -385,6 +407,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         except (AssertionError, requests.exceptions.RequestException) as e:
             last_traceback = get_exception_traceback()
             pass
+    model_info = res.json()

     if not success:
         if pipe_finish_writer is not None:
@@ -393,17 +416,24 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         sys.exit(1)

     # Send a warmup request
+    request_name = "/generate" if model_info["is_generation"] else "/encode"
+    max_new_tokens = 8 if model_info["is_generation"] else 1
+    json_data = {
+        "sampling_params": {
+            "temperature": 0,
+            "max_new_tokens": max_new_tokens,
+        },
+    }
+    if server_args.skip_tokenizer_init:
+        json_data["input_ids"] = [10, 11, 12]
+    else:
+        json_data["text"] = "The capital city of France is"
+
     try:
         for _ in range(server_args.dp_size):
             res = requests.post(
-                url +
-                json=
-                    "text": "The capital city of France is",
-                    "sampling_params": {
-                        "temperature": 0,
-                        "max_new_tokens": 8,
-                    },
-                },
+                url + request_name,
+                json=json_data,
                 headers=headers,
                 timeout=600,
             )
@@ -415,6 +445,15 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
         sys.exit(1)

+    # Print warnings here
+    if server_args.disable_radix_cache and server_args.chunked_prefill_size is not None:
+        logger.warning(
+            "You set both `--disable-radix-cache` and `--chunked-prefill-size`. "
+            "This combination is an experimental feature and we noticed it can lead to "
+            "wrong generation results. If you want to use chunked prefill, it is recommended "
+            "not using `--disable-radix-cache`."
+        )
+
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("init ok")
@@ -492,11 +531,18 @@ class Runtime:
         prompt: str,
         sampling_params: Optional[Dict] = None,
     ):
-
-
-
-
-
+        if self.server_args.skip_tokenizer_init:
+            json_data = {
+                "input_ids": prompt,
+                "sampling_params": sampling_params,
+                "stream": True,
+            }
+        else:
+            json_data = {
+                "text": prompt,
+                "sampling_params": sampling_params,
+                "stream": True,
+            }
         pos = 0

         timeout = aiohttp.ClientTimeout(total=3 * 3600)
@@ -508,10 +554,13 @@ class Runtime:
                 if chunk == "data: [DONE]\n\n":
                     break
                 data = json.loads(chunk[5:].strip("\n"))
-
-
-
-
+                if hasattr(data, "text"):
+                    cur = data["text"][pos:]
+                    if cur:
+                        yield cur
+                    pos += len(cur)
+                else:
+                    yield data

     add_request = async_generate

@@ -534,5 +583,18 @@ class Runtime:
         )
         return json.dumps(response.json())

+    def encode(
+        self,
+        prompt: str,
+    ):
+        json_data = {
+            "text": prompt,
+        }
+        response = requests.post(
+            self.url + "/encode",
+            json=json_data,
+        )
+        return json.dumps(response.json())
+
     def __del__(self):
         self.shutdown()

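`Runtime.encode` and the new `/encode` route are intended for embedding (non-generation) models; generation models keep using `/generate`. A minimal sketch, assuming `Runtime` forwards keyword arguments to `ServerArgs` and using a placeholder embedding checkpoint; the OpenAI-style `/v1/embeddings` route added in the same change is expected to mirror the OpenAI request schema, so only the raw endpoint is shown:

```python
import requests

from sglang.srt.server import Runtime

# Placeholder embedding checkpoint; launching it starts a full server process and needs a GPU.
runtime = Runtime(model_path="intfloat/e5-mistral-7b-instruct")

# In-process helper added in this release; it POSTs to the new /encode route.
print(runtime.encode("The capital city of France is"))

# The same route over plain HTTP.
out = requests.post(runtime.url + "/encode", json={"text": "hello world"}).json()
print(out)

runtime.shutdown()
```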
sglang/srt/server_args.py
CHANGED
@@ -17,9 +17,12 @@ limitations under the License.

 import argparse
 import dataclasses
+import logging
 import random
 from typing import List, Optional, Union

+logger = logging.getLogger(__name__)
+

 @dataclasses.dataclass
 class ServerArgs:
@@ -27,6 +30,7 @@ class ServerArgs:
     model_path: str
     tokenizer_path: Optional[str] = None
     tokenizer_mode: str = "auto"
+    skip_tokenizer_init: bool = False
     load_format: str = "auto"
     dtype: str = "auto"
     trust_remote_code: bool = True
@@ -42,10 +46,11 @@ class ServerArgs:

     # Memory and scheduling
     mem_fraction_static: Optional[float] = None
-    max_prefill_tokens: Optional[int] = None
     max_running_requests: Optional[int] = None
     max_num_reqs: Optional[int] = None
     max_total_tokens: Optional[int] = None
+    chunked_prefill_size: int = 8192
+    max_prefill_tokens: int = 16384
     schedule_policy: str = "lpm"
     schedule_conservativeness: float = 1.0

@@ -62,15 +67,12 @@ class ServerArgs:

     # Other
     api_key: Optional[str] = None
-    file_storage_pth: str = "
+    file_storage_pth: str = "SGLang_storage"

     # Data parallelism
     dp_size: int = 1
     load_balance_method: str = "round_robin"

-    # Chunked Prefill
-    chunked_prefill_size: Optional[int] = None
-
     # Optimization/debug options
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
@@ -96,6 +98,10 @@ class ServerArgs:
         if self.served_model_name is None:
             self.served_model_name = self.model_path

+        if self.chunked_prefill_size <= 0:
+            # Disable chunked prefill
+            self.chunked_prefill_size = None
+
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
                 self.mem_fraction_static = 0.79
@@ -107,6 +113,7 @@ class ServerArgs:
                 self.mem_fraction_static = 0.87
             else:
                 self.mem_fraction_static = 0.88
+
         if isinstance(self.additional_ports, int):
             self.additional_ports = [self.additional_ports]
         elif self.additional_ports is None:
@@ -151,6 +158,11 @@ class ServerArgs:
             "tokenizer if available, and 'slow' will "
             "always use the slow tokenizer.",
         )
+        parser.add_argument(
+            "--skip-tokenizer-init",
+            action="store_true",
+            help="If set, skip init tokenizer and pass input_ids in generate request",
+        )
         parser.add_argument(
             "--load-format",
             type=str,
@@ -226,12 +238,6 @@ class ServerArgs:
             default=ServerArgs.mem_fraction_static,
             help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
         )
-        parser.add_argument(
-            "--max-prefill-tokens",
-            type=int,
-            default=ServerArgs.max_prefill_tokens,
-            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
-        )
         parser.add_argument(
             "--max-running-requests",
             type=int,
@@ -250,6 +256,18 @@ class ServerArgs:
             default=ServerArgs.max_total_tokens,
             help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes.",
         )
+        parser.add_argument(
+            "--chunked-prefill-size",
+            type=int,
+            default=ServerArgs.chunked_prefill_size,
+            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill",
+        )
+        parser.add_argument(
+            "--max-prefill-tokens",
+            type=int,
+            default=ServerArgs.max_prefill_tokens,
+            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
+        )
         parser.add_argument(
             "--schedule-policy",
             type=str,
@@ -347,14 +365,6 @@ class ServerArgs:
         )
         parser.add_argument("--node-rank", type=int, help="The node rank.")

-        # Chunked prefill
-        parser.add_argument(
-            "--chunked-prefill-size",
-            type=int,
-            default=ServerArgs.chunked_prefill_size,
-            help="The size of the chunked prefill.",
-        )
-
         # Optimization/debug options
         parser.add_argument(
             "--disable-flashinfer",
@@ -439,6 +449,9 @@ class ServerArgs:
         assert not (
             self.dp_size > 1 and self.node_rank is not None
         ), "multi-node data parallel is not supported"
+        if "gemma-2" in self.model_path.lower():
+            logger.info(f"When using sliding window in gemma-2, turn on flashinfer.")
+            self.disable_flashinfer = False


 @dataclasses.dataclass

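Note the changed semantics: `chunked_prefill_size` now defaults to 8192, a non-positive value disables chunked prefill in `__post_init__`, and `max_prefill_tokens` defaults to 16384. A small sketch of that normalization (the model path is a placeholder and no server is launched, so this should not need a GPU):

```python
from sglang.srt.server_args import ServerArgs

# Placeholder model path; __post_init__ only fills in and normalizes defaults here.
args = ServerArgs(model_path="meta-llama/Meta-Llama-3-8B-Instruct")
print(args.chunked_prefill_size, args.max_prefill_tokens)  # expected: 8192 16384

# A non-positive size is turned into None, i.e. chunked prefill disabled.
args_off = ServerArgs(
    model_path="meta-llama/Meta-Llama-3-8B-Instruct",
    chunked_prefill_size=-1,
)
print(args_off.chunked_prefill_size)  # expected: None
```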
sglang/srt/utils.py
CHANGED
@@ -35,7 +35,6 @@ import torch
 import torch.distributed as dist
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
-from starlette.middleware.base import BaseHTTPMiddleware
 from torch.nn.parameter import Parameter
 from triton.runtime.cache import (
     FileCacheManager,
@@ -197,6 +196,8 @@ def allocate_init_ports(
 def get_int_token_logit_bias(tokenizer, vocab_size):
     """Get the logit bias for integer-only tokens."""
     # a bug when model's vocab size > tokenizer.vocab_size
+    if tokenizer == None:
+        return [-1e5] * vocab_size
     vocab_size = tokenizer.vocab_size
     logit_bias = np.zeros(vocab_size, dtype=np.float32)
     for t_id in range(vocab_size):
@@ -223,6 +224,15 @@ def is_multimodal_model(model):
     raise ValueError("unrecognized type")


+def is_generation_model(model_architectures):
+    if (
+        "LlamaEmbeddingModel" in model_architectures
+        or "MistralModel" in model_architectures
+    ):
+        return False
+    return True
+
+
 def decode_video_base64(video_base64):
     from PIL import Image

@@ -622,19 +632,6 @@ def receive_addrs(model_port_args, server_args):
     dist.destroy_process_group()


-def set_torch_compile_config():
-    # The following configurations are for torch compile optimizations
-    import torch._dynamo.config
-    import torch._inductor.config
-
-    torch._inductor.config.coordinate_descent_tuning = True
-    torch._inductor.config.triton.unique_kernel_names = True
-    torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
-
-    # FIXME: tmp workaround
-    torch._dynamo.config.accumulated_cache_size_limit = 256
-
-
 def set_ulimit(target_soft_limit=65535):
     resource_type = resource.RLIMIT_NOFILE
     current_soft, current_hard = resource.getrlimit(resource_type)
@@ -646,7 +643,7 @@ def set_ulimit(target_soft_limit=65535):
         logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")


-def
+def is_llama3_405b_fp8_head_16(model_config):
     """Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads."""
     if (
         model_config.hf_config.architectures[0] == "LlamaForCausalLM"
@@ -705,3 +702,23 @@ def add_api_key_middleware(app, api_key):
         if request.headers.get("Authorization") != "Bearer " + api_key:
             return JSONResponse(content={"error": "Unauthorized"}, status_code=401)
         return await call_next(request)
+
+
+def prepare_model(model_path):
+    if "SGLANG_USE_MODELSCOPE" in os.environ:
+        if not os.path.exists(model_path):
+            from modelscope import snapshot_download
+
+            return snapshot_download(model_path)
+    return model_path
+
+
+def prepare_tokenizer(tokenizer_path):
+    if "SGLANG_USE_MODELSCOPE" in os.environ:
+        if not os.path.exists(tokenizer_path):
+            from modelscope import snapshot_download
+
+            return snapshot_download(
+                tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"]
+            )
+    return tokenizer_path

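`prepare_model` and `prepare_tokenizer` only act when the `SGLANG_USE_MODELSCOPE` environment variable is set and the given path does not exist locally; otherwise the path is returned unchanged. A sketch, assuming the `modelscope` package is installed and using a placeholder model id (calling it will download the weights):

```python
import os

# The code only checks for the variable's presence, not its value.
os.environ["SGLANG_USE_MODELSCOPE"] = "1"

from sglang.srt.utils import prepare_model, prepare_tokenizer

# Placeholder ModelScope model id; an existing local path is returned unchanged.
model_dir = prepare_model("qwen/Qwen2-7B-Instruct")
tokenizer_dir = prepare_tokenizer("qwen/Qwen2-7B-Instruct")  # skips *.bin / *.safetensors
print(model_dir, tokenizer_dir)
```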
sglang/test/run_eval.py
CHANGED
@@ -16,6 +16,8 @@ from sglang.test.simple_eval_common import (


 def run_eval(args):
+    set_ulimit()
+
     if "OPENAI_API_KEY" not in os.environ:
         os.environ["OPENAI_API_KEY"] = "EMPTY"

@@ -39,6 +41,14 @@ def run_eval(args):
         eval_obj = MathEval(
             filename, equality_checker, args.num_examples, args.num_threads
         )
+    elif args.eval_name == "mgsm":
+        from sglang.test.simple_eval_mgsm import MGSMEval
+
+        eval_obj = MGSMEval(args.num_examples, args.num_threads)
+    elif args.eval_name == "mgsm_en":
+        from sglang.test.simple_eval_mgsm import MGSMEval
+
+        eval_obj = MGSMEval(args.num_examples, args.num_threads, languages=["en"])
     elif args.eval_name == "gpqa":
         from sglang.test.simple_eval_gpqa import GPQAEval

@@ -109,7 +119,6 @@ if __name__ == "__main__":
     parser.add_argument("--eval-name", type=str, default="mmlu")
     parser.add_argument("--num-examples", type=int)
     parser.add_argument("--num-threads", type=int, default=512)
-    set_ulimit()
     args = parser.parse_args()

     run_eval(args)