sglang 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +13 -1
- sglang/bench_latency.py +10 -5
- sglang/bench_serving.py +50 -26
- sglang/check_env.py +15 -0
- sglang/global_config.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +60 -49
- sglang/lang/chat_template.py +10 -5
- sglang/lang/compiler.py +4 -0
- sglang/lang/interpreter.py +5 -2
- sglang/lang/ir.py +22 -4
- sglang/launch_server.py +8 -1
- sglang/srt/constrained/jump_forward.py +13 -2
- sglang/srt/conversation.py +50 -1
- sglang/srt/hf_transformers_utils.py +22 -23
- sglang/srt/layers/activation.py +24 -2
- sglang/srt/layers/decode_attention.py +338 -50
- sglang/srt/layers/extend_attention.py +3 -1
- sglang/srt/layers/fused_moe/__init__.py +1 -0
- sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
- sglang/srt/layers/fused_moe/layer.py +587 -0
- sglang/srt/layers/layernorm.py +3 -0
- sglang/srt/layers/logits_processor.py +64 -27
- sglang/srt/layers/radix_attention.py +41 -18
- sglang/srt/layers/sampler.py +154 -0
- sglang/srt/managers/controller_multi.py +2 -8
- sglang/srt/managers/controller_single.py +7 -10
- sglang/srt/managers/detokenizer_manager.py +20 -9
- sglang/srt/managers/io_struct.py +44 -11
- sglang/srt/managers/policy_scheduler.py +5 -2
- sglang/srt/managers/schedule_batch.py +59 -179
- sglang/srt/managers/tokenizer_manager.py +193 -84
- sglang/srt/managers/tp_worker.py +131 -50
- sglang/srt/mem_cache/memory_pool.py +82 -8
- sglang/srt/mm_utils.py +79 -7
- sglang/srt/model_executor/cuda_graph_runner.py +97 -28
- sglang/srt/model_executor/forward_batch_info.py +188 -82
- sglang/srt/model_executor/model_runner.py +269 -87
- sglang/srt/models/chatglm.py +6 -14
- sglang/srt/models/commandr.py +6 -2
- sglang/srt/models/dbrx.py +5 -1
- sglang/srt/models/deepseek.py +7 -3
- sglang/srt/models/deepseek_v2.py +12 -7
- sglang/srt/models/gemma.py +6 -2
- sglang/srt/models/gemma2.py +22 -8
- sglang/srt/models/gpt_bigcode.py +5 -1
- sglang/srt/models/grok.py +66 -398
- sglang/srt/models/internlm2.py +5 -1
- sglang/srt/models/llama2.py +7 -3
- sglang/srt/models/llama_classification.py +2 -2
- sglang/srt/models/llama_embedding.py +4 -0
- sglang/srt/models/llava.py +176 -59
- sglang/srt/models/minicpm.py +7 -3
- sglang/srt/models/mixtral.py +61 -255
- sglang/srt/models/mixtral_quant.py +6 -5
- sglang/srt/models/qwen.py +7 -4
- sglang/srt/models/qwen2.py +15 -5
- sglang/srt/models/qwen2_moe.py +7 -16
- sglang/srt/models/stablelm.py +6 -2
- sglang/srt/openai_api/adapter.py +149 -58
- sglang/srt/sampling/sampling_batch_info.py +209 -0
- sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -4
- sglang/srt/server.py +107 -71
- sglang/srt/server_args.py +49 -15
- sglang/srt/utils.py +27 -18
- sglang/test/runners.py +38 -38
- sglang/test/simple_eval_common.py +9 -10
- sglang/test/simple_eval_gpqa.py +2 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_math.py +2 -1
- sglang/test/simple_eval_mmlu.py +2 -1
- sglang/test/test_activation.py +55 -0
- sglang/test/test_programs.py +32 -5
- sglang/test/test_utils.py +37 -50
- sglang/version.py +1 -1
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/METADATA +102 -27
- sglang-0.2.14.dist-info/RECORD +114 -0
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
- sglang/launch_server_llavavid.py +0 -29
- sglang/srt/model_loader/model_loader.py +0 -292
- sglang/srt/model_loader/utils.py +0 -275
- sglang-0.2.12.dist-info/RECORD +0 -112
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py
CHANGED
@@ -17,9 +17,12 @@ limitations under the License.
 
 import argparse
 import dataclasses
+import logging
 import random
 from typing import List, Optional, Union
 
+logger = logging.getLogger(__name__)
+
 
 @dataclasses.dataclass
 class ServerArgs:
@@ -30,11 +33,13 @@ class ServerArgs:
     skip_tokenizer_init: bool = False
     load_format: str = "auto"
     dtype: str = "auto"
+    kv_cache_dtype: str = "auto"
     trust_remote_code: bool = True
     context_length: Optional[int] = None
     quantization: Optional[str] = None
     served_model_name: Optional[str] = None
     chat_template: Optional[str] = None
+    is_embedding: bool = False
 
     # Port
     host: str = "127.0.0.1"
@@ -46,7 +51,7 @@ class ServerArgs:
     max_running_requests: Optional[int] = None
     max_num_reqs: Optional[int] = None
    max_total_tokens: Optional[int] = None
-    chunked_prefill_size: int =
+    chunked_prefill_size: int = 8192
     max_prefill_tokens: int = 16384
     schedule_policy: str = "lpm"
     schedule_conservativeness: float = 1.0
@@ -76,12 +81,14 @@ class ServerArgs:
     disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
     disable_cuda_graph: bool = False
+    disable_cuda_graph_padding: bool = False
     disable_disk_cache: bool = False
+    disable_custom_all_reduce: bool = False
+    enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     enable_p2p_check: bool = False
     enable_mla: bool = False
-
-    efficient_weight_load: bool = False
+    triton_attention_reduce_in_fp32: bool = False
 
     # Distributed args
     nccl_init_addr: Optional[str] = None
@@ -190,11 +197,23 @@ class ServerArgs:
             '* "float" is shorthand for FP32 precision.\n'
             '* "float32" for FP32 precision.',
         )
+        parser.add_argument(
+            "--kv-cache-dtype",
+            type=str,
+            default=ServerArgs.kv_cache_dtype,
+            choices=["auto", "fp8_e5m2"],
+            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
+        )
         parser.add_argument(
             "--trust-remote-code",
             action="store_true",
             help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
         )
+        parser.add_argument(
+            "--is-embedding",
+            action="store_true",
+            help="Whether to use a CausalLM as an embedding model.",
+        )
         parser.add_argument(
             "--context-length",
             type=int,
@@ -388,11 +407,27 @@ class ServerArgs:
             action="store_true",
             help="Disable cuda graph.",
         )
+        parser.add_argument(
+            "--disable-cuda-graph-padding",
+            action="store_true",
+            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
+        )
         parser.add_argument(
             "--disable-disk-cache",
             action="store_true",
             help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
         )
+        parser.add_argument(
+            "--disable-custom-all-reduce",
+            action="store_true",
+            default=False,
+            help="Disable the custom all-reduce kernel and fall back to NCCL.",
+        )
+        parser.add_argument(
+            "--enable-mixed-chunk",
+            action="store_true",
+            help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
+        )
         parser.add_argument(
             "--enable-torch-compile",
             action="store_true",
@@ -406,13 +441,13 @@ class ServerArgs:
         parser.add_argument(
             "--enable-mla",
             action="store_true",
-            help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2",
+            help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
         )
         parser.add_argument(
-            "--attention-reduce-in-fp32",
+            "--triton-attention-reduce-in-fp32",
            action="store_true",
             help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
-            "This only affects Triton attention kernels",
+            "This only affects Triton attention kernels.",
         )
         parser.add_argument(
             "--efficient-weight-load",
@@ -430,15 +465,6 @@ class ServerArgs:
     def url(self):
         return f"http://{self.host}:{self.port}"
 
-    def print_mode_args(self):
-        return (
-            f"disable_flashinfer={self.disable_flashinfer}, "
-            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
-            f"disable_radix_cache={self.disable_radix_cache}, "
-            f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
-            f"disable_disk_cache={self.disable_disk_cache}, "
-        )
-
     def check_server_args(self):
         assert (
             self.tp_size % self.nnodes == 0
@@ -446,6 +472,14 @@ class ServerArgs:
         assert not (
             self.dp_size > 1 and self.node_rank is not None
         ), "multi-node data parallel is not supported"
+        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
+            logger.info(
+                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
+            )
+            self.trust_remote_code = False
+        if "gemma-2" in self.model_path.lower():
+            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
+            self.disable_flashinfer = False
 
 
 @dataclasses.dataclass
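Taken together, the server-args changes above add new constructor fields and CLI flags. Below is a minimal, hedged sketch (not from the package) of setting the new fields programmatically; the model path is a placeholder and all unlisted fields keep their defaults.

from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(
    model_path="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model path
    kv_cache_dtype="fp8_e5m2",        # new field: "auto" or "fp8_e5m2" (CUDA 11.8+)
    is_embedding=False,               # new field: serve a CausalLM as an embedding model
    chunked_prefill_size=8192,        # default is now 8192
    enable_mixed_chunk=True,          # new field: mix prefill and decode in one batch
    disable_custom_all_reduce=False,  # new field: set True to fall back to NCCL
)
server_args.check_server_args()  # also runs the new model-specific adjustments
print(server_args.url())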
sglang/srt/utils.py
CHANGED
@@ -35,7 +35,6 @@ import torch
 import torch.distributed as dist
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
-from starlette.middleware.base import BaseHTTPMiddleware
 from torch.nn.parameter import Parameter
 from triton.runtime.cache import (
     FileCacheManager,
@@ -225,13 +224,18 @@ def is_multimodal_model(model):
     raise ValueError("unrecognized type")
 
 
-def is_generation_model(model_architectures):
+def is_generation_model(model_architectures, is_embedding: bool = False):
+    # We have two ways to determine whether a model is a generative model.
+    # 1. Check the model architectue
+    # 2. check the `is_embedding` server args
+
     if (
         "LlamaEmbeddingModel" in model_architectures
         or "MistralModel" in model_architectures
     ):
         return False
-
+    else:
+        return not is_embedding
 
 
 def decode_video_base64(video_base64):
@@ -348,7 +352,7 @@ def suppress_other_loggers():
         logging.WARN
     )
     logging.getLogger("vllm.selector").setLevel(logging.WARN)
-    logging.getLogger("vllm.utils").setLevel(logging.
+    logging.getLogger("vllm.utils").setLevel(logging.ERROR)
 
 
 def assert_pkg_version(pkg: str, min_version: str, message: str):
@@ -370,14 +374,11 @@ def kill_parent_process():
     """Kill the parent process and all children of the parent process."""
     current_process = psutil.Process()
     parent_process = current_process.parent()
-
-    for child in children:
-        if child.pid != current_process.pid:
-            os.kill(child.pid, 9)
-    os.kill(parent_process.pid, 9)
+    kill_child_process(parent_process.pid, skip_pid=current_process.pid)
 
 
-def kill_child_process(pid, including_parent=True):
+def kill_child_process(pid, including_parent=True, skip_pid=None):
+    """Kill the process and all its children process."""
     try:
         parent = psutil.Process(pid)
     except psutil.NoSuchProcess:
@@ -385,6 +386,8 @@ def kill_child_process(pid, including_parent=True):
 
     children = parent.children(recursive=True)
     for child in children:
+        if child.pid == skip_pid:
+            continue
         try:
             child.kill()
         except psutil.NoSuchProcess:
@@ -453,10 +456,6 @@ def monkey_patch_vllm_dummy_weight_loader():
             quant_method = getattr(module, "quant_method", None)
             if quant_method is not None:
                 quant_method.process_weights_after_loading(module)
-            # FIXME: Remove this after Mixtral is updated
-            # to use quant_method.
-            if hasattr(module, "process_weights_after_loading"):
-                module.process_weights_after_loading()
 
         # NOTE(woosuk): For accurate performance evaluation, we assign
         # random values to the weights.
@@ -644,7 +643,7 @@ def set_ulimit(target_soft_limit=65535):
         logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")
 
 
-def
+def is_llama3_405b_fp8_head_16(model_config):
     """Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads."""
     if (
         model_config.hf_config.architectures[0] == "LlamaForCausalLM"
@@ -693,7 +692,7 @@ def monkey_patch_vllm_qvk_linear_loader():
     setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)
 
 
-def add_api_key_middleware(app, api_key):
+def add_api_key_middleware(app, api_key: str):
     @app.middleware("http")
     async def authentication(request, call_next):
         if request.method == "OPTIONS":
@@ -705,7 +704,7 @@ def add_api_key_middleware(app, api_key):
         return await call_next(request)
 
 
-def prepare_model(model_path):
+def prepare_model(model_path: str):
     if "SGLANG_USE_MODELSCOPE" in os.environ:
         if not os.path.exists(model_path):
             from modelscope import snapshot_download
@@ -714,7 +713,7 @@ def prepare_model(model_path):
     return model_path
 
 
-def prepare_tokenizer(tokenizer_path):
+def prepare_tokenizer(tokenizer_path: str):
     if "SGLANG_USE_MODELSCOPE" in os.environ:
         if not os.path.exists(tokenizer_path):
             from modelscope import snapshot_download
@@ -723,3 +722,13 @@ def prepare_tokenizer(tokenizer_path):
             tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"]
         )
         return tokenizer_path
+
+
+def configure_logger(server_args, prefix: str = ""):
+    format = f"[%(asctime)s{prefix}] %(message)s"
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format=format,
+        datefmt="%H:%M:%S",
+        force=True,
+    )
sglang/test/runners.py
CHANGED
@@ -14,7 +14,8 @@ limitations under the License.
 """
 
 import json
-import multiprocessing
+import multiprocessing as mp
+import os
 from dataclasses import dataclass
 from typing import List, Union
 
@@ -23,16 +24,22 @@ import torch.nn.functional as F
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from sglang.srt.server import Runtime
-from sglang.
+from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER
 
 DEFAULT_PROMPTS = [
     # the output of gemma-2-2b from SRT is unstable on the commented prompt
     # "The capital of France is",
+    "Apple is red. Banana is Yellow. " * 800 + "Apple is",
     "The capital of the United Kindom is",
     "Today is a sunny day and I like",
     "AI is a field of computer science focused on",
 ]
 
+dirpath = os.path.dirname(__file__)
+with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
+    long_prompt = f.read()
+DEFAULT_PROMPTS.append(long_prompt)
+
 NUM_TOP_LOGPROBS = 5
 
 
@@ -56,44 +63,37 @@ class HFRunner:
     def __init__(
         self,
         model_path,
-        torch_dtype
-
+        torch_dtype,
+        is_generation,
     ):
-        self.
-
+        self.is_generation = is_generation
+
+        self.in_queue = mp.Queue()
+        self.out_queue = mp.Queue()
 
-        self.model_proc =
+        self.model_proc = mp.Process(
            target=self.start_model_process,
            args=(
                self.in_queue,
                self.out_queue,
                model_path,
                torch_dtype,
-                is_generation_model,
            ),
        )
        self.model_proc.start()
 
-    def start_model_process(
-        self, in_queue, out_queue, model_path, torch_dtype, is_generation_model
-    ):
+    def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
-            trust_remote_code=True,
        )
 
-        self.
-        is_generation_model(model_path)
-        if is_generation_model is None
-        else is_generation_model
-        )
-        if self.is_generation_model:
+        if self.is_generation:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
+                trust_remote_code=False,
                low_cpu_mem_usage=True,
-                trust_remote_code=True,
            ).cuda()
        else:
            from sentence_transformers import SentenceTransformer
@@ -106,7 +106,7 @@ class HFRunner:
        while True:
            prompts, max_new_tokens = in_queue.get()
            if prompts is not None:
-                if self.
+                if self.is_generation:
                    output_strs = []
                    prefill_logprobs = []
                    for p in prompts:
@@ -125,16 +125,14 @@ class HFRunner:
                        )
 
                        logits = self.model.forward(input_ids).logits[0]
-                        logprobs = F.log_softmax(
-
-
-
-                        # print("index",
-                        logprobs
-
-
-                        ]
-                        prefill_logprobs.append(logprobs)
+                        logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
+                        logprobs, top_indices = torch.topk(
+                            logprobs, k=NUM_TOP_LOGPROBS, dim=-1
+                        )
+                        # print("index", top_indices)
+                        prefill_logprobs.append(logprobs.tolist())
+                        del logits
+                        del logprobs
 
                    out_queue.put(
                        ModelOutput(
@@ -171,19 +169,20 @@ class SRTRunner:
    def __init__(
        self,
        model_path,
+        torch_dtype,
+        is_generation,
        tp_size=1,
-
-        is_generation_model=None,
+        port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
    ):
-        self.
-        is_generation_model(model_path)
-        if is_generation_model is None
-        else is_generation_model
-        )
+        self.is_generation = is_generation
        self.runtime = Runtime(
            model_path=model_path,
            tp_size=tp_size,
            dtype=get_dtype_str(torch_dtype),
+            port=port,
+            mem_fraction_static=0.69,
+            trust_remote_code=False,
+            is_embedding=not self.is_generation,
        )
 
    def forward(
@@ -191,7 +190,7 @@ class SRTRunner:
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        max_new_tokens=8,
    ):
-        if self.
+        if self.is_generation:
            # the return value contains logprobs from prefill
            output_strs = []
            top_input_logprobs = []
@@ -201,6 +200,7 @@ class SRTRunner:
                    prompt,
                    sampling_params=sampling_params,
                    return_logprob=True,
+                    logprob_start_len=0,
                    top_logprobs_num=NUM_TOP_LOGPROBS,
                )
                response = json.loads(response)
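The test runners now take torch_dtype and is_generation explicitly instead of the old is_generation_model argument. A hedged sketch of the new SRTRunner call (placeholder model name; requires a GPU and the model weights):

import torch

from sglang.test.runners import SRTRunner

runner = SRTRunner(
    "meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model path
    torch_dtype=torch.float16,
    is_generation=True,  # False starts the runtime with is_embedding=True
)
outputs = runner.forward(
    ["Today is a sunny day and I like"],
    max_new_tokens=8,
)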
sglang/test/simple_eval_common.py
CHANGED
@@ -1,13 +1,12 @@
 # Adapted from https://github.com/openai/simple-evals/
 
-import base64
 import os
 import resource
 import time
 from collections import defaultdict
 from dataclasses import dataclass, field
 from multiprocessing.pool import ThreadPool
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import httpx
 import jinja2
@@ -44,8 +43,8 @@ class EvalResult:
     Result of running an evaluation (usually consisting of many samples)
     """
 
-    score: float
-    metrics: Dict[str, float]
+    score: Optional[float]  # top-line metric
+    metrics: Optional[Dict[str, float]]  # other metrics
     htmls: List[str]  # strings of valid HTML
     convos: List[MessageList]  # sampled conversations
 
@@ -56,10 +55,10 @@ class SingleEvalResult:
     Result of evaluating a single sample
     """
 
-    score: float
+    score: Optional[float]
     metrics: Dict[str, float] = field(default_factory=dict)
-    html: str
-    convo: MessageList
+    html: Optional[str] = None
+    convo: Optional[MessageList] = None  # sampled conversation
 
 
 class Eval:
@@ -89,8 +88,8 @@ class ChatCompletionSampler(SamplerBase):
     def __init__(
         self,
         base_url: str = None,
-        model: str
-        system_message: str
+        model: Optional[str] = None,
+        system_message: Optional[str] = None,
         temperature: float = 0.0,
         max_tokens: int = 2048,
     ):
@@ -272,7 +271,7 @@ def _compute_stat(values: list, stat: str):
 def aggregate_results(
     single_eval_results: List[SingleEvalResult],
     default_stats: Tuple[str] = ("mean", "std"),
-    name2stats: Dict[str, Tuple[str]]
+    name2stats: Optional[Dict[str, Tuple[str]]] = None,
 ) -> EvalResult:
     """
     Aggregate results from multiple evaluations into a single EvalResult.
sglang/test/simple_eval_gpqa.py
CHANGED
@@ -8,6 +8,7 @@ https://arxiv.org/abs/2311.12022
 
 import random
 import re
+from typing import Optional
 
 import pandas
 
@@ -28,7 +29,7 @@ class GPQAEval(Eval):
     def __init__(
         self,
         filename: str,
-        num_examples: int
+        num_examples: Optional[int],
         num_threads: int,
         n_repeats: int = 1,
     ):
sglang/test/simple_eval_humaneval.py
CHANGED
@@ -9,7 +9,7 @@ https://arxiv.org/abs/2107.03374 https://github.com/openai/human-eval/
 import random
 import re
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Dict, List
+from typing import Dict, List, Optional
 
 import tqdm
 
@@ -61,7 +61,7 @@ def evaluate_functional_correctness(
 class HumanEval(Eval):
     def __init__(
         self,
-        num_examples: int
+        num_examples: Optional[int],
         num_threads: int,
         num_samples_per_task: int = 5,
         ks_passes: List[int] = [1, 2, 5],
sglang/test/simple_eval_math.py
CHANGED
@@ -8,6 +8,7 @@ https://arxiv.org/abs/2103.03874
 
 import random
 import re
+from typing import Optional
 
 import pandas
 
@@ -36,7 +37,7 @@ class MathEval(Eval):
         self,
         filename: str,
         equality_checker: SamplerBase,
-        num_examples: int
+        num_examples: Optional[int],
         num_threads: int,
     ):
         df = pandas.read_csv(filename)
sglang/test/simple_eval_mmlu.py
CHANGED
@@ -8,6 +8,7 @@ https://arxiv.org/abs/2009.03300
 
 import random
 import re
+from typing import Optional
 
 import pandas
 
@@ -84,7 +85,7 @@ subject2category = {
 
 
 class MMLUEval(Eval):
-    def __init__(self, filename: str, num_examples: int
+    def __init__(self, filename: str, num_examples: Optional[int], num_threads: int):
         df = pandas.read_csv(filename)
         examples = [row.to_dict() for _, row in df.iterrows()]
         if num_examples:
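Across the simple_eval_* files, num_examples is now annotated Optional[int]; passing None skips the `if num_examples:` truncation and evaluates the full set. A hedged sketch with a placeholder CSV path:

from sglang.test.simple_eval_mmlu import MMLUEval

mmlu_eval = MMLUEval(filename="mmlu.csv", num_examples=None, num_threads=8)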
sglang/test/test_activation.py
ADDED
@@ -0,0 +1,55 @@
+import itertools
+import unittest
+
+import torch
+
+from sglang.srt.layers.activation import GeluAndMul
+
+
+class TestGeluAndMul(unittest.TestCase):
+    DTYPES = [torch.half, torch.bfloat16]
+    NUM_TOKENS = [7, 83, 2048]
+    D = [512, 4096, 5120, 13824]
+    SEEDS = [0]
+
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        torch.set_default_device("cuda")
+
+    def _run_gelu_and_mul_test(self, num_tokens, d, dtype, seed):
+        torch.manual_seed(seed)
+
+        layer = GeluAndMul().to(dtype=dtype)
+        x = torch.randn(num_tokens, 2 * d, dtype=dtype)
+
+        with torch.inference_mode():
+            ref_out = layer.forward_native(x)
+            out = layer.forward_cuda(x)
+
+        if dtype == torch.bfloat16:
+            atol = rtol = 1e-2
+        else:
+            atol = rtol = 1e-3
+
+        self.assertTrue(torch.allclose(out, ref_out, atol=atol, rtol=rtol))
+
+    def test_gelu_and_mul(self):
+        for params in itertools.product(
+            self.NUM_TOKENS,
+            self.D,
+            self.DTYPES,
+            self.SEEDS,
+        ):
+            with self.subTest(
+                num_tokens=params[0],
+                d=params[1],
+                dtype=params[2],
+                seed=params[3],
+            ):
+                self._run_gelu_and_mul_test(*params)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
sglang/test/test_programs.py
CHANGED
@@ -103,16 +103,19 @@ def test_decode_int():
 def test_decode_json_regex():
     @sgl.function
     def decode_json(s):
-        from sglang.lang.ir import REGEX_FLOAT, REGEX_INT,
+        from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR
 
         s += "Generate a JSON object to describe the basic city information of Paris.\n"
+        s += "Here are the JSON object:\n"
+
+        # NOTE: we recommend using dtype gen or whole regex string to control the output
 
         with s.var_scope("json_output"):
             s += "{\n"
-            s += '  "name": ' + sgl.gen(regex=
-            s += '  "population": ' + sgl.gen(regex=REGEX_INT
-            s += '  "area": ' + sgl.gen(regex=REGEX_INT
-            s += '  "latitude": ' + sgl.gen(regex=REGEX_FLOAT) + "\n"
+            s += '  "name": ' + sgl.gen(regex=REGEX_STR) + ",\n"
+            s += '  "population": ' + sgl.gen(regex=REGEX_INT, stop=[" ", "\n"]) + ",\n"
+            s += '  "area": ' + sgl.gen(regex=REGEX_INT, stop=[" ", "\n"]) + ",\n"
+            s += '  "latitude": ' + sgl.gen(regex=REGEX_FLOAT, stop=[" ", "\n"]) + "\n"
             s += "}"
 
     ret = decode_json.run(temperature=0.0)
@@ -359,6 +362,30 @@ def test_regex():
     assert re.match(regex, answer)
 
 
+def test_dtype_gen():
+    @sgl.function
+    def dtype_gen(s):
+        s += "Q: What is the full name of DNS?\n"
+        s += "A: The full nams is " + sgl.gen("str_res", dtype=str, stop="\n") + "\n"
+        s += "Q: Which year was DNS invented?\n"
+        s += "A: " + sgl.gen("int_res", dtype=int) + "\n"
+        s += "Q: What is the value of pi?\n"
+        s += "A: " + sgl.gen("float_res", dtype=float) + "\n"
+        s += "Q: Is the sky blue?\n"
+        s += "A: " + sgl.gen("bool_res", dtype=bool) + "\n"
+
+    state = dtype_gen.run()
+
+    try:
+        state["int_res"] = int(state["int_res"])
+        state["float_res"] = float(state["float_res"])
+        state["bool_res"] = bool(state["bool_res"])
+        # assert state["str_res"].startswith('"') and state["str_res"].endswith('"')
+    except ValueError:
+        print(state)
+        raise
+
+
 def test_completion_speculative():
     @sgl.function(num_api_spec_tokens=64)
     def gen_character_spec(s):
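For reference, a hedged sketch of the dtype-based generation that the new test_dtype_gen exercises, pointed at a locally launched sglang server (the endpoint URL is a placeholder):

import sglang as sgl


@sgl.function
def dns_year(s):
    s += "Q: Which year was DNS invented?\n"
    s += "A: " + sgl.gen("year", dtype=int) + "\n"


sgl.set_default_backend(sgl.RuntimeEndpoint("http://127.0.0.1:30000"))
state = dns_year.run(temperature=0)
print(int(state["year"]))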
|