sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +24 -16
- sglang/bench_one_batch.py +51 -3
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +37 -28
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +15 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/model_config.py +16 -6
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +449 -0
- sglang/srt/entrypoints/http_server.py +579 -0
- sglang/srt/layers/activation.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +27 -12
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +69 -0
- sglang/srt/layers/linear.py +76 -102
- sglang/srt/layers/logits_processor.py +48 -63
- sglang/srt/layers/moe/ep_moe/layer.py +4 -4
- sglang/srt/layers/moe/fused_moe_native.py +69 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -14
- sglang/srt/layers/moe/topk.py +4 -2
- sglang/srt/layers/parameter.py +26 -17
- sglang/srt/layers/quantization/__init__.py +22 -23
- sglang/srt/layers/quantization/fp8.py +112 -55
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/int8_kernel.py +54 -0
- sglang/srt/layers/quantization/modelopt_quant.py +2 -3
- sglang/srt/layers/quantization/w8a8_int8.py +117 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/rotary_embedding.py +1179 -31
- sglang/srt/layers/sampler.py +39 -1
- sglang/srt/layers/vocab_parallel_embedding.py +17 -4
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +46 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +23 -8
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +54 -15
- sglang/srt/managers/schedule_batch.py +49 -22
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +319 -181
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +303 -158
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +110 -77
- sglang/srt/metrics/collector.py +25 -11
- sglang/srt/model_executor/cuda_graph_runner.py +4 -6
- sglang/srt/model_executor/model_runner.py +80 -21
- sglang/srt/model_loader/loader.py +8 -6
- sglang/srt/model_loader/weight_utils.py +55 -2
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +3 -3
- sglang/srt/models/dbrx.py +4 -4
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +8 -8
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +6 -24
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +41 -4
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +6 -6
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +52 -4
- sglang/srt/models/qwen2_eagle.py +131 -0
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +3 -3
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
- sglang/srt/sampling/sampling_batch_info.py +153 -9
- sglang/srt/sampling/sampling_params.py +4 -2
- sglang/srt/server.py +4 -1037
- sglang/srt/server_args.py +84 -32
- sglang/srt/speculative/eagle_worker.py +1 -0
- sglang/srt/torch_memory_saver_adapter.py +59 -0
- sglang/srt/utils.py +130 -63
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +26 -13
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +126 -117
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/__init__.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
|
-
#
|
1
|
+
# SGLang public APIs
|
2
2
|
|
3
|
+
# Frontend Language APIs
|
3
4
|
from sglang.api import (
|
4
5
|
Engine,
|
5
6
|
Runtime,
|
@@ -23,16 +24,26 @@ from sglang.api import (
|
|
23
24
|
user_end,
|
24
25
|
video,
|
25
26
|
)
|
27
|
+
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
|
26
28
|
from sglang.lang.choices import (
|
27
29
|
greedy_token_selection,
|
28
30
|
token_length_normalized,
|
29
31
|
unconditional_likelihood_normalized,
|
30
32
|
)
|
33
|
+
from sglang.utils import LazyImport
|
34
|
+
|
35
|
+
Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
|
36
|
+
LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
|
37
|
+
OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
|
38
|
+
VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
|
39
|
+
|
40
|
+
# Other configs
|
41
|
+
from sglang.global_config import global_config
|
42
|
+
from sglang.version import __version__
|
31
43
|
|
32
|
-
# SGLang DSL APIs
|
33
44
|
__all__ = [
|
34
|
-
"Runtime",
|
35
45
|
"Engine",
|
46
|
+
"Runtime",
|
36
47
|
"assistant",
|
37
48
|
"assistant_begin",
|
38
49
|
"assistant_end",
|
@@ -52,27 +63,14 @@ __all__ = [
|
|
52
63
|
"user_begin",
|
53
64
|
"user_end",
|
54
65
|
"video",
|
66
|
+
"RuntimeEndpoint",
|
55
67
|
"greedy_token_selection",
|
56
68
|
"token_length_normalized",
|
57
69
|
"unconditional_likelihood_normalized",
|
70
|
+
"Anthropic",
|
71
|
+
"LiteLLM",
|
72
|
+
"OpenAI",
|
73
|
+
"VertexAI",
|
74
|
+
"global_config",
|
75
|
+
"__version__",
|
58
76
|
]
|
59
|
-
|
60
|
-
# Global Configurations
|
61
|
-
from sglang.global_config import global_config
|
62
|
-
|
63
|
-
__all__ += ["global_config"]
|
64
|
-
|
65
|
-
from sglang.version import __version__
|
66
|
-
|
67
|
-
__all__ += ["__version__"]
|
68
|
-
|
69
|
-
# SGLang Backends
|
70
|
-
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
|
71
|
-
from sglang.utils import LazyImport
|
72
|
-
|
73
|
-
Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
|
74
|
-
LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
|
75
|
-
OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
|
76
|
-
VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
|
77
|
-
|
78
|
-
__all__ += ["Anthropic", "LiteLLM", "OpenAI", "VertexAI", "RuntimeEndpoint"]
|
sglang/api.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1
1
|
"""Public APIs of the language."""
|
2
2
|
|
3
|
-
import os
|
4
3
|
import re
|
5
4
|
from typing import Callable, List, Optional, Union
|
6
5
|
|
@@ -33,19 +32,15 @@ def function(
|
|
33
32
|
|
34
33
|
|
35
34
|
def Runtime(*args, **kwargs):
|
36
|
-
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
37
|
-
|
38
35
|
# Avoid importing unnecessary dependency
|
39
|
-
from sglang.
|
36
|
+
from sglang.lang.backend.runtime_endpoint import Runtime
|
40
37
|
|
41
38
|
return Runtime(*args, **kwargs)
|
42
39
|
|
43
40
|
|
44
41
|
def Engine(*args, **kwargs):
|
45
|
-
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
46
|
-
|
47
42
|
# Avoid importing unnecessary dependency
|
48
|
-
from sglang.srt.
|
43
|
+
from sglang.srt.entrypoints.engine import Engine
|
49
44
|
|
50
45
|
return Engine(*args, **kwargs)
|
51
46
|
|
@@ -27,7 +27,8 @@ from sglang.bench_serving import (
|
|
27
27
|
sample_random_requests,
|
28
28
|
set_ulimit,
|
29
29
|
)
|
30
|
-
from sglang.
|
30
|
+
from sglang.lang.backend.runtime_endpoint import Runtime
|
31
|
+
from sglang.srt.entrypoints.engine import Engine
|
31
32
|
from sglang.srt.server_args import ServerArgs
|
32
33
|
|
33
34
|
|
@@ -39,14 +40,15 @@ class BenchArgs:
|
|
39
40
|
dataset_path: str = ""
|
40
41
|
num_prompts: int = 1000
|
41
42
|
sharegpt_output_len: Optional[int] = None
|
43
|
+
sharegpt_context_len: Optional[int] = None
|
42
44
|
random_input_len: int = 1024
|
43
45
|
random_output_len: int = 1024
|
44
46
|
random_range_ratio: float = 0.0
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
47
|
+
gsp_num_groups: int = 64
|
48
|
+
gsp_prompts_per_group: int = 16
|
49
|
+
gsp_system_prompt_len: int = 2048
|
50
|
+
gsp_question_len: int = 128
|
51
|
+
gsp_output_len: int = 256
|
50
52
|
disable_ignore_eos: bool = False
|
51
53
|
extra_request_body: Optional[str] = None
|
52
54
|
seed: int = 1
|
@@ -82,6 +84,12 @@ class BenchArgs:
|
|
82
84
|
default=BenchArgs.sharegpt_output_len,
|
83
85
|
help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
|
84
86
|
)
|
87
|
+
parser.add_argument(
|
88
|
+
"--sharegpt-context-len",
|
89
|
+
type=int,
|
90
|
+
default=BenchArgs.sharegpt_context_len,
|
91
|
+
help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
|
92
|
+
)
|
85
93
|
parser.add_argument(
|
86
94
|
"--random-input-len",
|
87
95
|
type=int,
|
@@ -102,35 +110,35 @@ class BenchArgs:
|
|
102
110
|
"used only for random dataset.",
|
103
111
|
)
|
104
112
|
parser.add_argument(
|
105
|
-
"--
|
113
|
+
"--gsp-num-groups",
|
106
114
|
type=int,
|
107
|
-
default=BenchArgs.
|
115
|
+
default=BenchArgs.gsp_num_groups,
|
108
116
|
help="Number of groups with shared prefix, used"
|
109
117
|
"only for generate-shared-prefix",
|
110
118
|
)
|
111
119
|
parser.add_argument(
|
112
|
-
"--
|
120
|
+
"--gsp-prompts-per-group",
|
113
121
|
type=int,
|
114
|
-
default=BenchArgs.
|
122
|
+
default=BenchArgs.gsp_prompts_per_group,
|
115
123
|
help="Number of prompts per group of shared prefix, used"
|
116
124
|
"only for generate-shared-prefix",
|
117
125
|
)
|
118
126
|
parser.add_argument(
|
119
|
-
"--
|
127
|
+
"--gsp-system-prompt-len",
|
120
128
|
type=int,
|
121
|
-
default=BenchArgs.
|
129
|
+
default=BenchArgs.gsp_system_prompt_len,
|
122
130
|
help="System prompt length, used" "only for generate-shared-prefix",
|
123
131
|
)
|
124
132
|
parser.add_argument(
|
125
|
-
"--
|
133
|
+
"--gsp-question-len",
|
126
134
|
type=int,
|
127
|
-
default=BenchArgs.
|
135
|
+
default=BenchArgs.gsp_question_len,
|
128
136
|
help="Question length, used" "only for generate-shared-prefix",
|
129
137
|
)
|
130
138
|
parser.add_argument(
|
131
|
-
"--
|
139
|
+
"--gsp-output-len",
|
132
140
|
type=int,
|
133
|
-
default=BenchArgs.
|
141
|
+
default=BenchArgs.gsp_output_len,
|
134
142
|
help="Target length in tokens for outputs in generated-shared-prefix dataset",
|
135
143
|
)
|
136
144
|
parser.add_argument(
|
sglang/bench_one_batch.py
CHANGED
@@ -9,7 +9,8 @@ It accepts server arguments (the same as launch_server.py) and benchmark argumen
|
|
9
9
|
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
|
10
10
|
## sweep through multiple data points and store (append) the results in a jsonl file:
|
11
11
|
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run
|
12
|
-
|
12
|
+
## run with profiling:
|
13
|
+
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --profile
|
13
14
|
# Usage (correctness test):
|
14
15
|
python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
|
15
16
|
|
@@ -56,12 +57,12 @@ import torch
|
|
56
57
|
import torch.distributed as dist
|
57
58
|
|
58
59
|
from sglang.srt.configs.model_config import ModelConfig
|
60
|
+
from sglang.srt.entrypoints.engine import _set_envs_and_config
|
59
61
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
60
62
|
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
61
63
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
62
64
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
63
65
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
64
|
-
from sglang.srt.server import _set_envs_and_config
|
65
66
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
66
67
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
67
68
|
from sglang.srt.utils import configure_logger, kill_process_tree, suppress_other_loggers
|
@@ -77,6 +78,8 @@ class BenchArgs:
|
|
77
78
|
correctness_test: bool = False
|
78
79
|
# This is only used for correctness test
|
79
80
|
cut_len: int = 4
|
81
|
+
profile: bool = False
|
82
|
+
profile_filename_prefix: str = "profile"
|
80
83
|
|
81
84
|
@staticmethod
|
82
85
|
def add_cli_args(parser: argparse.ArgumentParser):
|
@@ -95,6 +98,19 @@ class BenchArgs:
|
|
95
98
|
)
|
96
99
|
parser.add_argument("--correctness-test", action="store_true")
|
97
100
|
parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
|
101
|
+
parser.add_argument(
|
102
|
+
"--profile",
|
103
|
+
action="store_true",
|
104
|
+
help="Use Torch Profiler. The endpoint must be launched with "
|
105
|
+
"SGLANG_TORCH_PROFILER_DIR to enable profiler.",
|
106
|
+
)
|
107
|
+
parser.add_argument(
|
108
|
+
"--profile-filename-prefix",
|
109
|
+
type=str,
|
110
|
+
default=BenchArgs.profile_filename_prefix,
|
111
|
+
help="Prefix of the profiling file names. The full profiling result file(s) be "
|
112
|
+
'"[profile_filename_prefix]_batch[batch_size]_input[input_len]_output[output_len].trace.json.gz"',
|
113
|
+
)
|
98
114
|
|
99
115
|
@classmethod
|
100
116
|
def from_cli_args(cls, args: argparse.Namespace):
|
@@ -216,6 +232,7 @@ def extend(reqs, model_runner):
|
|
216
232
|
model_config=model_runner.model_config,
|
217
233
|
enable_overlap=False,
|
218
234
|
spec_algorithm=SpeculativeAlgorithm.NONE,
|
235
|
+
enable_custom_logit_processor=False,
|
219
236
|
)
|
220
237
|
batch.prepare_for_extend()
|
221
238
|
model_worker_batch = batch.get_model_worker_batch()
|
@@ -286,7 +303,16 @@ def synchronize(device):
|
|
286
303
|
|
287
304
|
|
288
305
|
def latency_test_run_once(
|
289
|
-
run_name,
|
306
|
+
run_name,
|
307
|
+
model_runner,
|
308
|
+
rank_print,
|
309
|
+
reqs,
|
310
|
+
batch_size,
|
311
|
+
input_len,
|
312
|
+
output_len,
|
313
|
+
device,
|
314
|
+
profile,
|
315
|
+
profile_filename_prefix,
|
290
316
|
):
|
291
317
|
max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
|
292
318
|
if batch_size > max_batch_size:
|
@@ -308,6 +334,17 @@ def latency_test_run_once(
|
|
308
334
|
|
309
335
|
tot_latency = 0
|
310
336
|
|
337
|
+
profiler = None
|
338
|
+
if profile:
|
339
|
+
profiler = torch.profiler.profile(
|
340
|
+
activities=[
|
341
|
+
torch.profiler.ProfilerActivity.CPU,
|
342
|
+
torch.profiler.ProfilerActivity.CUDA,
|
343
|
+
],
|
344
|
+
with_stack=True,
|
345
|
+
)
|
346
|
+
profiler.start()
|
347
|
+
|
311
348
|
# Prefill
|
312
349
|
synchronize(device)
|
313
350
|
tic = time.time()
|
@@ -338,6 +375,13 @@ def latency_test_run_once(
|
|
338
375
|
f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
|
339
376
|
)
|
340
377
|
|
378
|
+
if profile:
|
379
|
+
profiler.stop()
|
380
|
+
profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}.trace.json.gz"
|
381
|
+
parent_dir = os.path.dirname(os.path.abspath(profile_filename))
|
382
|
+
os.makedirs(parent_dir, exist_ok=True)
|
383
|
+
profiler.export_chrome_trace(profile_filename)
|
384
|
+
|
341
385
|
# Record decode timing from 2nd output
|
342
386
|
if output_len > 1:
|
343
387
|
med_decode_latency = np.median(decode_latencies)
|
@@ -386,6 +430,8 @@ def latency_test(
|
|
386
430
|
bench_args.input_len[0],
|
387
431
|
8, # shorter decoding to speed up the warmup
|
388
432
|
server_args.device,
|
433
|
+
profile=False,
|
434
|
+
profile_filename_prefix="", # not used
|
389
435
|
)
|
390
436
|
|
391
437
|
rank_print("Benchmark ...")
|
@@ -405,6 +451,8 @@ def latency_test(
|
|
405
451
|
il,
|
406
452
|
ol,
|
407
453
|
server_args.device,
|
454
|
+
bench_args.profile,
|
455
|
+
bench_args.profile_filename_prefix,
|
408
456
|
)
|
409
457
|
if ret is not None:
|
410
458
|
result_list.append(ret)
|
sglang/bench_one_batch_server.py
CHANGED
@@ -22,7 +22,7 @@ from typing import Tuple
|
|
22
22
|
import numpy as np
|
23
23
|
import requests
|
24
24
|
|
25
|
-
from sglang.srt.
|
25
|
+
from sglang.srt.entrypoints.http_server import launch_server
|
26
26
|
from sglang.srt.server_args import ServerArgs
|
27
27
|
from sglang.srt.utils import kill_process_tree
|
28
28
|
|
sglang/bench_serving.py
CHANGED
@@ -452,6 +452,7 @@ def get_dataset(args, tokenizer):
|
|
452
452
|
num_requests=args.num_prompts,
|
453
453
|
tokenizer=tokenizer,
|
454
454
|
fixed_output_len=args.sharegpt_output_len,
|
455
|
+
context_len=args.sharegpt_context_len,
|
455
456
|
)
|
456
457
|
elif args.dataset_name == "random":
|
457
458
|
input_requests = sample_random_requests(
|
@@ -464,11 +465,11 @@ def get_dataset(args, tokenizer):
|
|
464
465
|
)
|
465
466
|
elif args.dataset_name == "generated-shared-prefix":
|
466
467
|
input_requests = sample_generated_shared_prefix_requests(
|
467
|
-
num_groups=args.
|
468
|
-
prompts_per_group=args.
|
469
|
-
system_prompt_len=args.
|
470
|
-
question_len=args.
|
471
|
-
output_len=args.
|
468
|
+
num_groups=args.gsp_num_groups,
|
469
|
+
prompts_per_group=args.gsp_prompts_per_group,
|
470
|
+
system_prompt_len=args.gsp_system_prompt_len,
|
471
|
+
question_len=args.gsp_question_len,
|
472
|
+
output_len=args.gsp_output_len,
|
472
473
|
tokenizer=tokenizer,
|
473
474
|
)
|
474
475
|
else:
|
@@ -560,6 +561,7 @@ def sample_sharegpt_requests(
|
|
560
561
|
num_requests: int,
|
561
562
|
tokenizer: PreTrainedTokenizerBase,
|
562
563
|
fixed_output_len: Optional[int] = None,
|
564
|
+
context_len: Optional[int] = None,
|
563
565
|
) -> List[Tuple[str, int, int]]:
|
564
566
|
if fixed_output_len is not None and fixed_output_len < 4:
|
565
567
|
raise ValueError("output_len too small")
|
@@ -597,14 +599,15 @@ def sample_sharegpt_requests(
|
|
597
599
|
output_len = (
|
598
600
|
len(completion_token_ids) if fixed_output_len is None else fixed_output_len
|
599
601
|
)
|
600
|
-
|
602
|
+
|
603
|
+
if prompt_len < 1 or output_len < 1:
|
601
604
|
# Prune too short sequences.
|
602
605
|
continue
|
603
|
-
|
604
|
-
|
605
|
-
):
|
606
|
+
|
607
|
+
if context_len and prompt_len + output_len > context_len:
|
606
608
|
# Prune too long sequences.
|
607
609
|
continue
|
610
|
+
|
608
611
|
filtered_dataset.append((prompt, prompt_len, output_len))
|
609
612
|
|
610
613
|
print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}")
|
@@ -706,8 +709,8 @@ def get_gen_prefix_cache_path(args, tokenizer):
|
|
706
709
|
|
707
710
|
# Create a unique cache filename based on the generation parameters
|
708
711
|
cache_key = (
|
709
|
-
f"
|
710
|
-
f"{args.
|
712
|
+
f"gen_shared_prefix_{args.gsp_num_groups}_{args.gsp_prompts_per_group}_"
|
713
|
+
f"{args.gsp_system_prompt_len}_{args.gsp_question_len}_{args.gsp_output_len}_"
|
711
714
|
f"{tokenizer.__class__.__name__}.pkl"
|
712
715
|
)
|
713
716
|
return cache_dir / cache_key
|
@@ -1374,6 +1377,12 @@ if __name__ == "__main__":
|
|
1374
1377
|
default=None,
|
1375
1378
|
help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
|
1376
1379
|
)
|
1380
|
+
parser.add_argument(
|
1381
|
+
"--sharegpt-context-len",
|
1382
|
+
type=int,
|
1383
|
+
default=None,
|
1384
|
+
help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
|
1385
|
+
)
|
1377
1386
|
parser.add_argument(
|
1378
1387
|
"--random-input-len",
|
1379
1388
|
type=int,
|
@@ -1453,49 +1462,49 @@ if __name__ == "__main__":
|
|
1453
1462
|
help="Append given JSON object to the request payload. You can use this to specify"
|
1454
1463
|
"additional generate params like sampling params.",
|
1455
1464
|
)
|
1465
|
+
parser.add_argument(
|
1466
|
+
"--profile",
|
1467
|
+
action="store_true",
|
1468
|
+
help="Use Torch Profiler. The endpoint must be launched with "
|
1469
|
+
"SGLANG_TORCH_PROFILER_DIR to enable profiler.",
|
1470
|
+
)
|
1471
|
+
parser.add_argument(
|
1472
|
+
"--lora-name",
|
1473
|
+
type=str,
|
1474
|
+
default=None,
|
1475
|
+
help="The name of LoRA adapter",
|
1476
|
+
)
|
1456
1477
|
|
1457
1478
|
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
|
1458
1479
|
group.add_argument(
|
1459
|
-
"--
|
1480
|
+
"--gsp-num-groups",
|
1460
1481
|
type=int,
|
1461
1482
|
default=64,
|
1462
1483
|
help="Number of system prompt groups for generated-shared-prefix dataset",
|
1463
1484
|
)
|
1464
1485
|
group.add_argument(
|
1465
|
-
"--
|
1486
|
+
"--gsp-prompts-per-group",
|
1466
1487
|
type=int,
|
1467
1488
|
default=16,
|
1468
1489
|
help="Number of prompts per system prompt group for generated-shared-prefix dataset",
|
1469
1490
|
)
|
1470
1491
|
group.add_argument(
|
1471
|
-
"--
|
1492
|
+
"--gsp-system-prompt-len",
|
1472
1493
|
type=int,
|
1473
1494
|
default=2048,
|
1474
1495
|
help="Target length in tokens for system prompts in generated-shared-prefix dataset",
|
1475
1496
|
)
|
1476
1497
|
group.add_argument(
|
1477
|
-
"--
|
1498
|
+
"--gsp-question-len",
|
1478
1499
|
type=int,
|
1479
1500
|
default=128,
|
1480
1501
|
help="Target length in tokens for questions in generated-shared-prefix dataset",
|
1481
1502
|
)
|
1482
1503
|
group.add_argument(
|
1483
|
-
"--
|
1504
|
+
"--gsp-output-len",
|
1484
1505
|
type=int,
|
1485
1506
|
default=256,
|
1486
1507
|
help="Target length in tokens for outputs in generated-shared-prefix dataset",
|
1487
1508
|
)
|
1488
|
-
parser.add_argument(
|
1489
|
-
"--profile",
|
1490
|
-
action="store_true",
|
1491
|
-
help="Use Torch Profiler. The endpoint must be launched with "
|
1492
|
-
"SGLANG_TORCH_PROFILER_DIR to enable profiler.",
|
1493
|
-
)
|
1494
|
-
parser.add_argument(
|
1495
|
-
"--lora-name",
|
1496
|
-
type=str,
|
1497
|
-
default=None,
|
1498
|
-
help="The name of LoRA adapter",
|
1499
|
-
)
|
1500
1509
|
args = parser.parse_args()
|
1501
1510
|
run_benchmark(args)
|
@@ -1,6 +1,11 @@
|
|
1
|
+
import atexit
|
1
2
|
import json
|
3
|
+
import multiprocessing
|
2
4
|
import warnings
|
3
|
-
from typing import List, Optional
|
5
|
+
from typing import Dict, List, Optional, Union
|
6
|
+
|
7
|
+
import aiohttp
|
8
|
+
import requests
|
4
9
|
|
5
10
|
from sglang.global_config import global_config
|
6
11
|
from sglang.lang.backend.base_backend import BaseBackend
|
@@ -251,11 +256,12 @@ class RuntimeEndpoint(BaseBackend):
|
|
251
256
|
}
|
252
257
|
obj = self._generate_http_request(s, data)
|
253
258
|
|
254
|
-
normalized_prompt_logprobs = [
|
255
|
-
r["meta_info"]["normalized_prompt_logprob"] for r in obj
|
256
|
-
]
|
257
259
|
input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
|
258
260
|
output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
|
261
|
+
normalized_prompt_logprobs = [
|
262
|
+
compute_normalized_prompt_logprobs(r["meta_info"]["input_token_logprobs"])
|
263
|
+
for r in obj
|
264
|
+
]
|
259
265
|
|
260
266
|
# Remove extra token if no token healing occurred
|
261
267
|
for i in range(len(input_token_logprobs)):
|
@@ -319,3 +325,176 @@ class RuntimeEndpoint(BaseBackend):
|
|
319
325
|
def _assert_success(self, res):
|
320
326
|
if res.status_code != 200:
|
321
327
|
raise RuntimeError(res.json())
|
328
|
+
|
329
|
+
|
330
|
+
def compute_normalized_prompt_logprobs(input_logprobs):
|
331
|
+
values = [x[0] for x in input_logprobs if x[0]]
|
332
|
+
return sum(values) / len(values)
|
333
|
+
|
334
|
+
|
335
|
+
class Runtime:
|
336
|
+
"""
|
337
|
+
A wrapper for the HTTP server.
|
338
|
+
This is used for launching the server in a python program without
|
339
|
+
using the commond line interface.
|
340
|
+
|
341
|
+
It is mainly used for the frontend language.
|
342
|
+
You should use the Engine class if you want to do normal offline processing without the frontend language.
|
343
|
+
"""
|
344
|
+
|
345
|
+
def __init__(
|
346
|
+
self,
|
347
|
+
log_level: str = "error",
|
348
|
+
*args,
|
349
|
+
**kwargs,
|
350
|
+
):
|
351
|
+
"""See the arguments in server_args.py::ServerArgs"""
|
352
|
+
# We delay the import of any `sglang.srt` components in `sglang.lang`, so users can run
|
353
|
+
# client code without installing SRT server and its dependency if they want.
|
354
|
+
from sglang.srt.entrypoints.http_server import launch_server
|
355
|
+
from sglang.srt.server_args import ServerArgs
|
356
|
+
from sglang.srt.utils import is_port_available
|
357
|
+
|
358
|
+
self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
|
359
|
+
|
360
|
+
# Pre-allocate ports
|
361
|
+
for port in range(self.server_args.port, 40000):
|
362
|
+
if is_port_available(port):
|
363
|
+
break
|
364
|
+
self.server_args.port = port
|
365
|
+
|
366
|
+
self.url = self.server_args.url()
|
367
|
+
self.generate_url = self.url + "/generate"
|
368
|
+
|
369
|
+
# NOTE: We store pid instead of proc to fix some issues during __delete__
|
370
|
+
self.pid = None
|
371
|
+
pipe_reader, pipe_writer = multiprocessing.Pipe(duplex=False)
|
372
|
+
|
373
|
+
proc = multiprocessing.Process(
|
374
|
+
target=launch_server,
|
375
|
+
args=(self.server_args, pipe_writer),
|
376
|
+
)
|
377
|
+
proc.start()
|
378
|
+
pipe_writer.close()
|
379
|
+
self.pid = proc.pid
|
380
|
+
|
381
|
+
# Before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
|
382
|
+
atexit.register(self.shutdown)
|
383
|
+
|
384
|
+
# TODO: remove this pipe_writer mechanism and use `/health_generate` instead.
|
385
|
+
try:
|
386
|
+
init_state = pipe_reader.recv()
|
387
|
+
except EOFError:
|
388
|
+
init_state = ""
|
389
|
+
|
390
|
+
if init_state != "ready":
|
391
|
+
self.shutdown()
|
392
|
+
raise RuntimeError(
|
393
|
+
"Initialization failed. Please see the error messages above."
|
394
|
+
)
|
395
|
+
|
396
|
+
self.endpoint = RuntimeEndpoint(self.url)
|
397
|
+
|
398
|
+
def shutdown(self):
|
399
|
+
from sglang.srt.utils import kill_process_tree
|
400
|
+
|
401
|
+
if self.pid is not None:
|
402
|
+
kill_process_tree(self.pid)
|
403
|
+
self.pid = None
|
404
|
+
|
405
|
+
def cache_prefix(self, prefix: str):
|
406
|
+
self.endpoint.cache_prefix(prefix)
|
407
|
+
|
408
|
+
def get_tokenizer(self):
|
409
|
+
from sglang.srt.hf_transformers_utils import get_tokenizer
|
410
|
+
|
411
|
+
return get_tokenizer(
|
412
|
+
self.server_args.tokenizer_path,
|
413
|
+
tokenizer_mode=self.server_args.tokenizer_mode,
|
414
|
+
trust_remote_code=self.server_args.trust_remote_code,
|
415
|
+
revision=self.server_args.revision,
|
416
|
+
)
|
417
|
+
|
418
|
+
async def async_generate(
|
419
|
+
self,
|
420
|
+
prompt: str,
|
421
|
+
sampling_params: Optional[Dict] = None,
|
422
|
+
):
|
423
|
+
if self.server_args.skip_tokenizer_init:
|
424
|
+
json_data = {
|
425
|
+
"input_ids": prompt,
|
426
|
+
"sampling_params": sampling_params,
|
427
|
+
"stream": True,
|
428
|
+
}
|
429
|
+
else:
|
430
|
+
json_data = {
|
431
|
+
"text": prompt,
|
432
|
+
"sampling_params": sampling_params,
|
433
|
+
"stream": True,
|
434
|
+
}
|
435
|
+
pos = 0
|
436
|
+
|
437
|
+
timeout = aiohttp.ClientTimeout(total=3 * 3600)
|
438
|
+
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
439
|
+
async with session.post(self.generate_url, json=json_data) as response:
|
440
|
+
async for chunk, _ in response.content.iter_chunks():
|
441
|
+
chunk = chunk.decode("utf-8")
|
442
|
+
if chunk and chunk.startswith("data:"):
|
443
|
+
if chunk == "data: [DONE]\n\n":
|
444
|
+
break
|
445
|
+
data = json.loads(chunk[5:].strip("\n"))
|
446
|
+
if "text" in data:
|
447
|
+
cur = data["text"][pos:]
|
448
|
+
if cur:
|
449
|
+
yield cur
|
450
|
+
pos += len(cur)
|
451
|
+
else:
|
452
|
+
yield data
|
453
|
+
|
454
|
+
add_request = async_generate
|
455
|
+
|
456
|
+
def generate(
|
457
|
+
self,
|
458
|
+
prompt: Union[str, List[str]],
|
459
|
+
sampling_params: Optional[Dict] = None,
|
460
|
+
return_logprob: Optional[Union[List[bool], bool]] = False,
|
461
|
+
logprob_start_len: Optional[Union[List[int], int]] = None,
|
462
|
+
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
463
|
+
lora_path: Optional[List[Optional[str]]] = None,
|
464
|
+
):
|
465
|
+
json_data = {
|
466
|
+
"text": prompt,
|
467
|
+
"sampling_params": sampling_params,
|
468
|
+
"return_logprob": return_logprob,
|
469
|
+
"logprob_start_len": logprob_start_len,
|
470
|
+
"top_logprobs_num": top_logprobs_num,
|
471
|
+
"lora_path": lora_path,
|
472
|
+
}
|
473
|
+
assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
|
474
|
+
response = requests.post(
|
475
|
+
self.url + "/generate",
|
476
|
+
json=json_data,
|
477
|
+
)
|
478
|
+
return json.dumps(response.json())
|
479
|
+
|
480
|
+
def encode(
|
481
|
+
self,
|
482
|
+
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
|
483
|
+
):
|
484
|
+
json_data = {"text": prompt}
|
485
|
+
response = requests.post(self.url + "/encode", json=json_data)
|
486
|
+
return json.dumps(response.json())
|
487
|
+
|
488
|
+
async def get_server_info(self):
|
489
|
+
async with aiohttp.ClientSession() as session:
|
490
|
+
async with session.get(f"{self.url}/get_server_info") as response:
|
491
|
+
if response.status == 200:
|
492
|
+
return await response.json()
|
493
|
+
else:
|
494
|
+
error_data = await response.json()
|
495
|
+
raise RuntimeError(
|
496
|
+
f"Failed to get server info. {error_data['error']['message']}"
|
497
|
+
)
|
498
|
+
|
499
|
+
def __del__(self):
|
500
|
+
self.shutdown()
|