sglang 0.3.1__py3-none-any.whl → 0.3.1.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +10 -3
- sglang/bench_server_latency.py +187 -0
- sglang/bench_serving.py +1 -1
- sglang/global_config.py +5 -13
- sglang/lang/interpreter.py +0 -3
- sglang/srt/constrained/fsm_cache.py +5 -1
- sglang/srt/layers/activation.py +16 -1
- sglang/srt/layers/attention_backend.py +12 -12
- sglang/srt/layers/fused_moe/layer.py +27 -7
- sglang/srt/layers/layernorm.py +21 -6
- sglang/srt/layers/sampler.py +40 -98
- sglang/srt/lora/lora_manager.py +11 -8
- sglang/srt/managers/io_struct.py +3 -0
- sglang/srt/managers/policy_scheduler.py +49 -93
- sglang/srt/managers/schedule_batch.py +2 -1
- sglang/srt/managers/tp_worker.py +19 -13
- sglang/srt/model_executor/cuda_graph_runner.py +25 -13
- sglang/srt/model_executor/model_runner.py +37 -46
- sglang/srt/models/deepseek_v2.py +8 -3
- sglang/srt/models/llama.py +1 -3
- sglang/srt/models/llama_classification.py +2 -3
- sglang/srt/models/minicpm3.py +7 -3
- sglang/srt/models/olmoe.py +415 -0
- sglang/srt/models/xverse.py +1 -3
- sglang/srt/models/xverse_moe.py +1 -4
- sglang/srt/sampling/sampling_batch_info.py +3 -50
- sglang/srt/server.py +6 -1
- sglang/srt/server_args.py +39 -10
- sglang/srt/utils.py +7 -51
- sglang/test/few_shot_gsm8k.py +8 -2
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.3.1.dist-info → sglang-0.3.1.post2.dist-info}/METADATA +4 -5
- {sglang-0.3.1.dist-info → sglang-0.3.1.post2.dist-info}/RECORD +37 -35
- {sglang-0.3.1.dist-info → sglang-0.3.1.post2.dist-info}/WHEEL +1 -1
- {sglang-0.3.1.dist-info → sglang-0.3.1.post2.dist-info}/LICENSE +0 -0
- {sglang-0.3.1.dist-info → sglang-0.3.1.post2.dist-info}/top_level.txt +0 -0
sglang/bench_latency.py
CHANGED
@@ -1,5 +1,7 @@
 """
-Benchmark the latency of a
+Benchmark the latency of running a single static batch.
+This script does not launch a server and uses the low-level APIs.
+It accepts arguments similar to those of launch_server.py.
 
 # Usage (latency test)
 ## with dummy weights:
@@ -63,7 +65,7 @@ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import suppress_other_loggers
+from sglang.srt.utils import kill_child_process, suppress_other_loggers
 
 
 @dataclasses.dataclass
@@ -502,4 +504,9 @@ if __name__ == "__main__":
         format="%(message)s",
     )
 
-    main(server_args, bench_args)
+    try:
+        main(server_args, bench_args)
+    except Exception as e:
+        raise e
+    finally:
+        kill_child_process(os.getpid(), including_parent=False)
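Illustration (not part of the diff): the cleanup idiom this hunk introduces, sketched standalone. kill_child_process is the sglang helper imported above; run_main is a hypothetical stand-in for main(server_args, bench_args).

import os

from sglang.srt.utils import kill_child_process


def run_main():
    # Hypothetical stand-in for main(server_args, bench_args).
    ...


try:
    run_main()
finally:
    # Reap any worker subprocesses the benchmark spawned, but keep this
    # process alive so the interpreter can exit normally.
    kill_child_process(os.getpid(), including_parent=False)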
sglang/bench_server_latency.py
ADDED
@@ -0,0 +1,187 @@
+"""
+Benchmark the latency of serving a single batch with a real server.
+This script launches a server and uses the HTTP interface.
+It accepts arguments similar to those of launch_server.py.
+
+Usage:
+
+python3 -m sglang.bench_server_latency --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
+"""
+
+import argparse
+import dataclasses
+import itertools
+import json
+import multiprocessing
+import os
+import time
+from typing import Tuple
+
+import numpy as np
+import requests
+
+from sglang.srt.server import launch_server
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import kill_child_process
+
+
+@dataclasses.dataclass
+class BenchArgs:
+    run_name: str = "default"
+    batch_size: Tuple[int] = (1,)
+    input_len: Tuple[int] = (1024,)
+    output_len: Tuple[int] = (16,)
+    result_filename: str = "result.jsonl"
+
+    @staticmethod
+    def add_cli_args(parser: argparse.ArgumentParser):
+        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
+        parser.add_argument(
+            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
+        )
+        parser.add_argument(
+            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
+        )
+        parser.add_argument(
+            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
+        )
+        parser.add_argument(
+            "--result-filename", type=str, default=BenchArgs.result_filename
+        )
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace):
+        # use the default value's type to case the args into correct types.
+        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
+        return cls(
+            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
+        )
+
+
+def launch_server_internal(server_args):
+    try:
+        launch_server(server_args)
+    except Exception as e:
+        raise e
+    finally:
+        kill_child_process(os.getpid(), including_parent=False)
+
+
+def launch_server_process(server_args: ServerArgs):
+    proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
+    proc.start()
+    base_url = f"http://{server_args.host}:{server_args.port}"
+    timeout = 600
+
+    start_time = time.time()
+    while time.time() - start_time < timeout:
+        try:
+            headers = {
+                "Content-Type": "application/json; charset=utf-8",
+            }
+            response = requests.get(f"{base_url}/v1/models", headers=headers)
+            if response.status_code == 200:
+                return proc, base_url
+        except requests.RequestException:
+            pass
+        time.sleep(10)
+    raise TimeoutError("Server failed to start within the timeout period.")
+
+
+def run_one_case(
+    url: str,
+    batch_size: int,
+    input_len: int,
+    output_len: int,
+    run_name: str,
+    result_filename: str,
+):
+    input_ids = [
+        [int(x) for x in np.random.randint(0, high=16384, size=(input_len,))]
+        for _ in range(batch_size)
+    ]
+
+    tic = time.time()
+    response = requests.post(
+        url + "/generate",
+        json={
+            "input_ids": input_ids,
+            "sampling_params": {
+                "temperature": 0,
+                "max_new_tokens": output_len,
+                "ignore_eos": True,
+            },
+        },
+    )
+    latency = time.time() - tic
+
+    _ = response.json()
+    output_throughput = batch_size * output_len / latency
+    overall_throughput = batch_size * (input_len + output_len) / latency
+
+    print(f"batch size: {batch_size}")
+    print(f"latency: {latency:.2f} s")
+    print(f"output throughput: {output_throughput:.2f} token/s")
+    print(f"(input + output) throughput: {overall_throughput:.2f} token/s")
+
+    if result_filename:
+        with open(result_filename, "a") as fout:
+            res = {
+                "run_name": run_name,
+                "batch_size": batch_size,
+                "input_len": input_len,
+                "output_len": output_len,
+                "latency": round(latency, 4),
+                "output_throughput": round(output_throughput, 2),
+                "overall_throughput": round(overall_throughput, 2),
+            }
+            fout.write(json.dumps(res) + "\n")
+
+
+def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
+    proc, base_url = launch_server_process(server_args)
+
+    # warmup
+    run_one_case(
+        base_url,
+        batch_size=16,
+        input_len=1024,
+        output_len=16,
+        run_name="",
+        result_filename="",
+    )
+
+    # benchmark
+    try:
+        for bs, il, ol in itertools.product(
+            bench_args.batch_size, bench_args.input_len, bench_args.output_len
+        ):
+            run_one_case(
+                base_url,
+                bs,
+                il,
+                ol,
+                bench_args.run_name,
+                bench_args.result_filename,
+            )
+    finally:
+        kill_child_process(proc.pid)
+
+    print(f"\nResults are saved to {bench_args.result_filename}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    BenchArgs.add_cli_args(parser)
+    # For this script, model-path is not required
+    assert (
+        parser._actions[1].option_strings[0] == "--model-path"
+    ), "options changed, this code need to be updated"
+    parser._actions[1].required = False
+    args = parser.parse_args()
+
+    server_args = ServerArgs.from_cli_args(args)
+    bench_args = BenchArgs.from_cli_args(args)
+
+    run_benchmark(server_args, bench_args)
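Illustration (not part of the package): a minimal sketch of consuming the result.jsonl records that run_one_case appends; the field names come directly from the code above.

import json

with open("result.jsonl") as fin:
    for line in fin:
        rec = json.loads(line)
        print(
            rec["run_name"],
            rec["batch_size"],
            rec["input_len"],
            rec["output_len"],
            rec["latency"],
            rec["output_throughput"],
        )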
sglang/bench_serving.py
CHANGED
@@ -2,7 +2,7 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py
 
 """
-Benchmark online serving.
+Benchmark online serving with dynamic requests.
 
 Usage:
 python3 -m sglang.bench_serving --backend sglang --num-prompt 10
sglang/global_config.py
CHANGED
@@ -1,5 +1,7 @@
 """Global configurations"""
 
+import os
+
 
 class GlobalConfig:
     def __init__(self):
@@ -16,30 +18,20 @@ class GlobalConfig:
         self.base_min_new_token_ratio = 0.1
         self.new_token_ratio_decay = 0.001
 
-        # Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
-        # This can improve the speed for large batch sizes during prefill.
-        self.layer_sync_threshold = 8192
-
         # Runtime constants: others
         self.num_continue_decode_steps = 10
         self.retract_decode_steps = 20
-        self.flashinfer_workspace_size = 384 * 1024 * 1024
+        self.flashinfer_workspace_size = os.environ.get(
+            "FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024
+        )
 
         # Output tokenization configs
         self.skip_special_tokens_in_output = True
         self.spaces_between_special_tokens_in_out = True
 
         # Interpreter optimization configs
-        self.eager_fill_image = False
         self.enable_precache_with_tracing = True
         self.enable_parallel_encoding = True
-        self.enable_parallel_decoding = True
-
-        # Deprecated
-        # Choices: ["no_adjust", "adjust_cache"]
-        # no_adjust: Do not adjust the position embedding of KV cache.
-        # adjust_cache: Adjust the position embedding of KV cache.
-        self.concate_and_append_mode = "no_adjust"
 
 
 global_config = GlobalConfig()
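Illustration (not part of the diff): with this change the FlashInfer workspace size can be overridden from the environment. Note that os.environ.get returns a string when the variable is set and falls back to the integer default (384 MiB) only when it is unset, and that the variable must be exported before sglang.global_config is first imported, since GlobalConfig is instantiated at import time.

import os

# Must happen before the first sglang import that pulls in global_config.
os.environ["FLASHINFER_WORKSPACE_SIZE"] = str(512 * 1024 * 1024)

from sglang.global_config import global_config

print(global_config.flashinfer_workspace_size)  # "536870912" (a string when set via the environment)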
sglang/lang/interpreter.py
CHANGED
@@ -434,9 +434,6 @@ class StreamExecutor:
         self.cur_images.append((path, base64_data))
         self.text_ += self.chat_template.image_token
 
-        # if global_config.eager_fill_image:
-        #     self.backend.fill_image(self)
-
     def _spec_gen(self, sampling_params):
         stop = sampling_params.stop
         max_new_tokens = sampling_params.max_new_tokens
sglang/srt/constrained/fsm_cache.py
CHANGED
@@ -29,6 +29,7 @@ class FSMCache(BaseToolCache):
         tokenizer_args_dict,
         enable=True,
         skip_tokenizer_init=False,
+        constrained_json_whitespace_pattern=None,
     ):
         super().__init__(enable=enable)
 
@@ -63,11 +64,14 @@ class FSMCache(BaseToolCache):
                 self.outlines_tokenizer.vocabulary = (
                     self.outlines_tokenizer.tokenizer.get_vocab()
                 )
+        self.constrained_json_whitespace_pattern = constrained_json_whitespace_pattern
 
     def init_value(self, key):
         key_type, key_string = key
         if key_type == "json":
-            regex = build_regex_from_schema(key_string)
+            regex = build_regex_from_schema(
+                key_string, whitespace_pattern=self.constrained_json_whitespace_pattern
+            )
         elif key_type == "regex":
             regex = key_string
         else:
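Illustration (not part of the diff): what threading constrained_json_whitespace_pattern through to outlines changes. build_regex_from_schema is the outlines helper FSMCache already uses; the import path and schema below are illustrative and may differ across outlines versions.

from outlines.fsm.json_schema import build_regex_from_schema  # path may vary by outlines version

schema = '{"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}'

# Default behaviour: outlines decides which whitespace the JSON grammar accepts.
regex_default = build_regex_from_schema(schema)

# With the new option, a custom pattern can be passed through, e.g. allow spaces, tabs and newlines.
regex_custom = build_regex_from_schema(schema, whitespace_pattern=r"[\n\t ]*")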
sglang/srt/layers/activation.py
CHANGED
@@ -13,12 +13,18 @@ limitations under the License.
 
 """Fused operators for activation layers."""
 
+import logging
 from typing import Optional
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
+
+from sglang.srt.utils import is_hip
+
+if not is_hip():
+    from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
+
 from vllm.distributed import (
     divide,
     get_tensor_model_parallel_rank,
@@ -28,6 +34,8 @@ from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs
 
+logger = logging.getLogger(__name__)
+
 
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -135,3 +143,10 @@ def get_act_fn(
         act_fn, intermediate_size, input_is_parallel, params_dtype
     )
     return act_fn
+
+
+if is_hip():
+    logger.info(
+        "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
+    )
+    from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
sglang/srt/layers/attention_backend.py
CHANGED
@@ -12,22 +12,26 @@ from typing import TYPE_CHECKING
 
 import torch
 import torch.nn as nn
-from flashinfer import (
-    BatchDecodeWithPagedKVCacheWrapper,
-    BatchPrefillWithPagedKVCacheWrapper,
-    BatchPrefillWithRaggedKVCacheWrapper,
-)
-from flashinfer.cascade import merge_state
-from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 
 from sglang.global_config import global_config
 from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.utils import is_hip
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import (
+        BatchDecodeWithPagedKVCacheWrapper,
+        BatchPrefillWithPagedKVCacheWrapper,
+        BatchPrefillWithRaggedKVCacheWrapper,
+    )
+    from flashinfer.cascade import merge_state
+    from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
 
 class AttentionBackend(ABC):
     """The base class of attention backends"""
@@ -150,7 +154,7 @@ class FlashInferAttnBackend(AttentionBackend):
         # Some heuristics to check whether to use ragged forward
         use_ragged = False
         if (
-
+            torch.sum(input_metadata.seq_lens).item() >= 4096
             and self.model_runner.sliding_window_size is None
         ):
             use_ragged = True
@@ -301,10 +305,6 @@ class FlashInferAttnBackend(AttentionBackend):
                 layer.layer_id, input_metadata.out_cache_loc, k, v
             )
 
-        if total_num_tokens >= global_config.layer_sync_threshold:
-            # TODO: Revisit this. Why is this synchronize needed?
-            torch.cuda.synchronize()
-
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
     def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
sglang/srt/layers/fused_moe/layer.py
CHANGED
@@ -18,6 +18,8 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.utils import set_weight_attrs
 
+from sglang.srt.utils import is_hip
+
 logger = init_logger(__name__)
 
 
@@ -381,6 +383,7 @@ from torch.nn import Module
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d,
+    normalize_e4m3fn_to_e4m3fnuz,
     per_tensor_dequantize,
 )
 from vllm.utils import print_warning_once
@@ -479,14 +482,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
     def process_weights_after_loading(self, layer: Module) -> None:
 
-        # If checkpoint is fp16, quantize in place.
+        # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
-            w13_weight = torch.empty_like(
-                layer.w13_weight.data, dtype=torch.float8_e4m3fn
-            )
-            w2_weight = torch.empty_like(
-                layer.w2_weight.data, dtype=torch.float8_e4m3fn
-            )
+            # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
+            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
+            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
 
             # Re-initialize w13_scale because we directly quantize
             # merged w13 weights and generate a single scaling factor.
@@ -534,6 +535,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     layer.a2_scale.max(), requires_grad=False
                 )
 
+            # If ROCm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # Normalize the weights and scales
+                w13_weight, w13_scale, a13_scale = normalize_e4m3fn_to_e4m3fnuz(
+                    layer.w13_weight, layer.w13_scale, layer.a13_scale
+                )
+                w2_weight, w2_scale, a2_scale = normalize_e4m3fn_to_e4m3fnuz(
+                    layer.w2_weight, layer.w2_scale, layer.a2_scale
+                )
+                # Reset the parameters
+                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+                layer.w13_scale = torch.nn.Parameter(w13_scale, requires_grad=False)
+                if a13_scale is not None:
+                    layer.a13_scale = torch.nn.Parameter(a13_scale, requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+                layer.w2_scale = torch.nn.Parameter(w2_scale, requires_grad=False)
+                if a2_scale is not None:
+                    layer.a2_scale = torch.nn.Parameter(a2_scale, requires_grad=False)
+
             # Fp8 moe kernel needs single weight scale for w13 per expert.
             # We take the max then dequant and requant each expert.
             assert layer.w13_scale is not None
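Illustration (not part of the diff): the dtype switch the FP8 MoE path now performs. torch.float8_e4m3fnuz is the FP8 variant ROCm MI300-class hardware uses, while CUDA keeps torch.float8_e4m3fn; the shapes below are illustrative.

import torch

from sglang.srt.utils import is_hip

# Pick the FP8 storage dtype the way process_weights_after_loading now does.
fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn

w13 = torch.randn(4, 128, 64, dtype=torch.bfloat16)  # illustrative expert weights
w13_fp8 = torch.empty_like(w13, dtype=fp8_dtype)     # buffer the quantized weights land in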
sglang/srt/layers/layernorm.py
CHANGED
@@ -15,18 +15,26 @@ limitations under the License.
 
 """Fused operators for normalization layers."""
 
+import logging
 from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-from flashinfer.norm import (
-    fused_add_rmsnorm,
-    gemma_fused_add_rmsnorm,
-    gemma_rmsnorm,
-    rmsnorm,
-)
+
+from sglang.srt.utils import is_hip
+
+if not is_hip():
+    from flashinfer.norm import (
+        fused_add_rmsnorm,
+        gemma_fused_add_rmsnorm,
+        gemma_rmsnorm,
+        rmsnorm,
+    )
+
 from vllm.model_executor.custom_op import CustomOp
 
+logger = logging.getLogger(__name__)
+
 
 class RMSNorm(CustomOp):
     def __init__(
@@ -109,3 +117,10 @@ class GemmaRMSNorm(CustomOp):
             return x, residual
         out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out
+
+
+if is_hip():
+    logger.info(
+        "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
+    )
+    from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
sglang/srt/layers/sampler.py
CHANGED
@@ -1,51 +1,28 @@
-import dataclasses
 import logging
-from typing import
+from typing import Union
 
 import torch
-from flashinfer.sampling import (
-    min_p_sampling_from_probs,
-    top_k_renorm_prob,
-    top_k_top_p_sampling_from_probs,
-    top_p_renorm_prob,
-)
-from torch.library import custom_op as torch_custom_op
-from vllm.model_executor.custom_op import CustomOp
+from torch import nn
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-
-# TODO: move this dict to another place
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.utils import is_hip
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer.sampling import (
+        min_p_sampling_from_probs,
+        top_k_renorm_prob,
+        top_k_top_p_sampling_from_probs,
+        top_p_renorm_prob,
+    )
 
 logger = logging.getLogger(__name__)
 
 
-@dataclasses.dataclass
-class SampleOutput:
-    success: torch.Tensor
-    probs: torch.Tensor
-    batch_next_token_ids: torch.Tensor
-
-
-class Sampler(CustomOp):
-    def __init__(self):
-        super().__init__()
-        # FIXME: torch.multinomial has too many bugs
-        self.forward_native = self.forward_cuda
-        self.is_torch_compile = False
-
-    def _get_probs(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
-        # Post process logits
-        logits = logits.contiguous()
-        logits.div_(sampling_info.temperatures)
-        if self.is_torch_compile:
-            # FIXME: Temporary workaround for unknown bugs in torch.compile
-            logits.add_(0)
-
-        return torch.softmax(logits, dim=-1)
-
-    def forward_cuda(
+class Sampler(nn.Module):
+    def forward(
         self,
         logits: Union[torch.Tensor, LogitsProcessorOutput],
         sampling_info: SamplingBatchInfo,
@@ -53,7 +30,18 @@ class Sampler(CustomOp):
         if isinstance(logits, LogitsProcessorOutput):
            logits = logits.next_token_logits
 
-        probs = self._get_probs(logits, sampling_info)
+        # Post process logits
+        logits = logits.contiguous()
+        logits.div_(sampling_info.temperatures)
+        probs = torch.softmax(logits, dim=-1)
+        logits = None
+        del logits
+
+        if torch.any(torch.isnan(probs)):
+            logger.warning("Detected errors during sampling! NaN in the probability.")
+            probs = torch.where(
+                torch.isnan(probs), torch.full_like(probs, 1e-10), probs
+            )
 
         if global_server_args_dict["sampling_backend"] == "flashinfer":
             max_top_k_round, batch_size = 32, probs.shape[0]
@@ -67,12 +55,20 @@ class Sampler(CustomOp):
                     probs, uniform_samples, sampling_info.min_ps
                 )
             else:
-                batch_next_token_ids, success =
-                    probs,
+                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                    probs,
+                    uniform_samples,
+                    sampling_info.top_ks,
+                    sampling_info.top_ps,
+                    filter_apply_order="joint",
                 )
+
+            if not torch.all(success):
+                logger.warning("Detected errors during sampling!")
+                batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
         elif global_server_args_dict["sampling_backend"] == "pytorch":
             # Here we provide a slower fallback implementation.
-            batch_next_token_ids
+            batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
                 probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
             )
         else:
@@ -80,48 +76,7 @@ class Sampler(CustomOp):
                 f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
             )
 
-        return
-
-    def forward_native(
-        self,
-        logits: Union[torch.Tensor, LogitsProcessorOutput],
-        sampling_info: SamplingBatchInfo,
-    ):
-        if isinstance(logits, LogitsProcessorOutput):
-            logits = logits.next_token_logits
-
-        probs = self._get_probs(logits, sampling_info)
-
-        batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
-            probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
-        )
-
-        return SampleOutput(success, probs, batch_next_token_ids)
-
-
-@torch_custom_op("my_lib::flashinfer_top_k_top_p", mutates_args={})
-def flashinfer_top_k_top_p(
-    probs: torch.Tensor,
-    uniform_samples: torch.Tensor,
-    top_ks: torch.Tensor,
-    top_ps: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    # NOTE: we do not use min_p neither in CUDA nor in torch.compile
-    return top_k_top_p_sampling_from_probs(probs, uniform_samples, top_ks, top_ps)
-
-
-@flashinfer_top_k_top_p.register_fake
-def _(
-    probs: torch.Tensor,
-    uniform_samples: torch.Tensor,
-    top_ks: torch.Tensor,
-    top_ps: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    bs = probs.shape[0]
-    return (
-        torch.ones(bs, dtype=torch.bool, device=probs.device),
-        torch.zeros(bs, dtype=torch.int32, device=probs.device),
-    )
+        return batch_next_token_ids
 
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
@@ -141,19 +96,6 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     ] = 0.0
     probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    try:
-        # FIXME: torch.multiomial does not support num_samples = 1
-        sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[
-            :, :1
-        ]
-    except RuntimeError as e:
-        logger.warning(f"Sampling error: {e}")
-        batch_next_token_ids = torch.zeros(
-            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
-        )
-        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
-        return batch_next_token_ids, success
-
+    sampled_index = torch.multinomial(probs_sort, num_samples=1)
     batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
-
-    return batch_next_token_ids, success
+    return batch_next_token_ids