sglang 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +3 -1
- sglang/api.py +7 -7
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +158 -11
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/bench_latency.py +299 -0
- sglang/global_config.py +12 -2
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +114 -67
- sglang/lang/ir.py +28 -3
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +13 -6
- sglang/srt/constrained/fsm_cache.py +8 -2
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +3 -1
- sglang/srt/hf_transformers_utils.py +130 -1
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +582 -0
- sglang/srt/layers/logits_processor.py +65 -32
- sglang/srt/layers/radix_attention.py +41 -7
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/{router → controller}/infer_batch.py +242 -100
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/{router/manager.py → controller/manager_single.py} +34 -14
- sglang/srt/managers/{router → controller}/model_runner.py +262 -158
- sglang/srt/managers/{router → controller}/radix_cache.py +11 -1
- sglang/srt/managers/{router/scheduler.py → controller/schedule_heuristic.py} +9 -7
- sglang/srt/managers/{router/model_rpc.py → controller/tp_worker.py} +298 -267
- sglang/srt/managers/detokenizer_manager.py +42 -46
- sglang/srt/managers/io_struct.py +22 -12
- sglang/srt/managers/tokenizer_manager.py +151 -87
- sglang/srt/model_config.py +83 -5
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +10 -13
- sglang/srt/models/dbrx.py +9 -15
- sglang/srt/models/gemma.py +12 -15
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +26 -15
- sglang/srt/models/llama_classification.py +104 -0
- sglang/srt/models/llava.py +86 -19
- sglang/srt/models/llavavid.py +11 -20
- sglang/srt/models/mixtral.py +282 -103
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +9 -13
- sglang/srt/models/qwen2.py +11 -13
- sglang/srt/models/stablelm.py +9 -15
- sglang/srt/models/yivl.py +17 -22
- sglang/srt/openai_api_adapter.py +150 -95
- sglang/srt/openai_protocol.py +11 -2
- sglang/srt/server.py +124 -48
- sglang/srt/server_args.py +128 -48
- sglang/srt/utils.py +234 -67
- sglang/test/test_programs.py +65 -3
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +23 -4
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/METADATA +40 -27
- sglang-0.1.18.dist-info/RECORD +78 -0
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -417
- sglang-0.1.16.dist-info/RECORD +0 -72
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
```diff
--- sglang/srt/managers/router/manager.py
+++ sglang/srt/managers/controller/manager_single.py
@@ -1,20 +1,26 @@
+"""A controller that manages a group of tensor parallel workers."""
+
 import asyncio
 import logging
+from concurrent.futures import ThreadPoolExecutor
 
 import uvloop
 import zmq
 import zmq.asyncio
 
 from sglang.global_config import global_config
-from sglang.srt.managers.
+from sglang.srt.managers.controller.tp_worker import ModelTpClient
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import
+from sglang.srt.utils import kill_parent_process
+from sglang.utils import get_exception_traceback
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
+logger = logging.getLogger("srt.controller")
+
 
-class
-    def __init__(self, model_client:
+class ControllerSingle:
+    def __init__(self, model_client: ModelTpClient, port_args: PortArgs):
         # Init communication
         context = zmq.asyncio.Context(2)
         self.recv_from_tokenizer = context.socket(zmq.PULL)
@@ -30,7 +36,7 @@ class RouterManager:
         self.recv_reqs = []
 
         # Init some configs
-        self.
+        self.request_dependency_delay = global_config.request_dependency_delay
 
     async def loop_for_forward(self):
         while True:
@@ -44,14 +50,16 @@ class RouterManager:
             # async sleep for receiving the subsequent request and avoiding cache miss
             slept = False
             if len(out_pyobjs) != 0:
-                has_finished = any(
+                has_finished = any(
+                    [obj.finished_reason is not None for obj in out_pyobjs]
+                )
                 if has_finished:
-                    if self.
+                    if self.request_dependency_delay > 0:
                         slept = True
-                        await asyncio.sleep(self.
+                        await asyncio.sleep(self.request_dependency_delay)
 
             if not slept:
-                await asyncio.sleep(
+                await asyncio.sleep(global_config.wait_for_new_request_delay)
 
     async def loop_for_recv_requests(self):
         while True:
@@ -59,7 +67,7 @@ class RouterManager:
             self.recv_reqs.append(recv_req)
 
 
-def
+def start_controller_process(
     server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
 ):
     logging.basicConfig(
@@ -68,8 +76,14 @@ def start_router_process(
     )
 
     try:
-
-
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        model_client = ModelTpClient(
+            [i for _ in range(server_args.nnodes) for i in range(tp_size_local)],
+            server_args,
+            port_args.model_port_args[0],
+            model_overide_args,
+        )
+        controller = ControllerSingle(model_client, port_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
@@ -77,6 +91,12 @@ def start_router_process(
     pipe_writer.send("init ok")
 
     loop = asyncio.new_event_loop()
+    loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
     asyncio.set_event_loop(loop)
-    loop.create_task(
-
+    loop.create_task(controller.loop_for_recv_requests())
+    try:
+        loop.run_until_complete(controller.loop_for_forward())
+    except Exception:
+        logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
+    finally:
+        kill_parent_process()
```
```diff
--- sglang/srt/managers/router/model_runner.py
+++ sglang/srt/managers/controller/model_runner.py
@@ -1,69 +1,40 @@
+"""ModelRunner runs the forward passes of the models."""
+
 import importlib
 import importlib.resources
-import inspect
 import logging
 import pkgutil
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import List
+from typing import List, Optional, Type
 
 import numpy as np
 import torch
-
-from vllm.
-from vllm.
-from vllm.
-from vllm.model_executor.model_loader
-
-
+import torch.nn as nn
+from vllm.config import DeviceConfig, LoadConfig
+from vllm.config import ModelConfig as VllmModelConfig
+from vllm.distributed import init_distributed_environment, initialize_model_parallel
+from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.models import ModelRegistry
+
+from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
-from sglang.srt.
-
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import (
+    get_available_gpu_memory,
+    is_multimodal_model,
+    monkey_patch_vllm_dummy_weight_loader,
+    monkey_patch_vllm_p2p_access_check,
+)
 
-
-    "awq": AWQConfig,
-    "gptq": GPTQConfig,
-    "marlin": MarlinConfig,
-}
-
-logger = logging.getLogger("model_runner")
+logger = logging.getLogger("srt.model_runner")
 
 # for server args in model endpoints
 global_server_args_dict = {}
 
 
-@lru_cache()
-def import_model_classes():
-    model_arch_name_to_cls = {}
-    package_name = "sglang.srt.models"
-    package = importlib.import_module(package_name)
-    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
-        if not ispkg:
-            module = importlib.import_module(name)
-            if hasattr(module, "EntryClass"):
-                model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
-    return model_arch_name_to_cls
-
-
-def get_model_cls_by_arch_name(model_arch_names):
-    model_arch_name_to_cls = import_model_classes()
-
-    model_class = None
-    for arch in model_arch_names:
-        if arch in model_arch_name_to_cls:
-            model_class = model_arch_name_to_cls[arch]
-            break
-    else:
-        raise ValueError(
-            f"Unsupported architectures: {arch}. "
-            f"Supported list: {list(model_arch_name_to_cls.keys())}"
-        )
-    return model_class
-
-
 @dataclass
 class InputMetadata:
-    model_runner: "ModelRunner"
     forward_mode: ForwardMode
     batch_size: int
     total_num_tokens: int
@@ -94,73 +65,82 @@ class InputMetadata:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None
     kv_last_page_len: torch.Tensor = None
-
-
+    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
+    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
+    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
 
-    def init_flashinfer_args(self,
-
-
-
-        )
+    def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
+        if (
+            self.forward_mode == ForwardMode.PREFILL
+            or self.forward_mode == ForwardMode.EXTEND
+        ):
+            paged_kernel_lens = self.prefix_lens
+            self.no_prefix = torch.all(self.prefix_lens == 0)
+        else:
+            paged_kernel_lens = self.seq_lens
 
         self.kv_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
-        self.kv_indptr[1:] = torch.cumsum(
+        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
         self.kv_last_page_len = torch.ones(
             (self.batch_size,), dtype=torch.int32, device="cuda"
         )
         req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-
+        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
         self.kv_indices = torch.cat(
             [
                 self.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i], :
+                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
                 ]
                 for i in range(self.batch_size)
             ],
             dim=0,
         ).contiguous()
 
-        workspace_buffer = torch.empty(
-            32 * 1024 * 1024, dtype=torch.int8, device="cuda"
-        )
         if (
             self.forward_mode == ForwardMode.PREFILL
             or self.forward_mode == ForwardMode.EXTEND
         ):
+            # extend part
             self.qo_indptr = torch.zeros(
                 (self.batch_size + 1,), dtype=torch.int32, device="cuda"
             )
             self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-
-
+
+            self.flashinfer_prefill_wrapper_ragged.end_forward()
+            self.flashinfer_prefill_wrapper_ragged.begin_forward(
+                self.qo_indptr,
+                self.qo_indptr.clone(),
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
             )
-
+
+            # cached part
+            self.flashinfer_prefill_wrapper_paged.end_forward()
+            self.flashinfer_prefill_wrapper_paged.begin_forward(
                 self.qo_indptr,
                 self.kv_indptr,
                 self.kv_indices,
                 self.kv_last_page_len,
-
-
-
-
-
-            self.prefill_wrapper.begin_forward(*args)
-        else:
-            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD"
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                1
             )
-
+        else:
+            self.flashinfer_decode_wrapper.end_forward()
+            self.flashinfer_decode_wrapper.begin_forward(
                 self.kv_indptr,
                 self.kv_indices,
                 self.kv_last_page_len,
-
-
-
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
                 1,
-                "NONE",
-
+                pos_encoding_mode="NONE",
+                data_type=self.token_to_kv_pool.kv_data[0].dtype
             )
 
     def init_extend_args(self):
@@ -184,6 +164,9 @@ class InputMetadata:
         out_cache_cont_end=None,
         top_logprobs_nums=None,
         return_logprob=False,
+        flashinfer_prefill_wrapper_ragged=None,
+        flashinfer_prefill_wrapper_paged=None,
+        flashinfer_decode_wrapper=None,
     ):
         batch_size = len(req_pool_indices)
         start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
@@ -216,7 +199,6 @@ class InputMetadata:
         other_kv_index = None
 
         ret = cls(
-            model_runner=model_runner,
             forward_mode=forward_mode,
             batch_size=batch_size,
             total_num_tokens=total_num_tokens,
@@ -234,13 +216,20 @@ class InputMetadata:
             other_kv_index=other_kv_index,
             return_logprob=return_logprob,
             top_logprobs_nums=top_logprobs_nums,
+            flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=flashinfer_decode_wrapper,
         )
 
         if forward_mode == ForwardMode.EXTEND:
             ret.init_extend_args()
 
-        if global_server_args_dict.get("
-            ret.init_flashinfer_args(
+        if not global_server_args_dict.get("disable_flashinfer", False):
+            ret.init_flashinfer_args(
+                model_runner.model_config.num_attention_heads // tp_size,
+                model_runner.model_config.get_num_kv_heads(tp_size),
+                model_runner.model_config.head_dim
+            )
 
         return ret
 
@@ -249,122 +238,180 @@ class ModelRunner:
     def __init__(
         self,
         model_config,
-        mem_fraction_static,
-
-
-
-
-
-        server_args_dict: dict = {},
+        mem_fraction_static: float,
+        gpu_id: int,
+        tp_rank: int,
+        tp_size: int,
+        nccl_port: int,
+        server_args: ServerArgs,
     ):
         self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
+        self.gpu_id = gpu_id
         self.tp_rank = tp_rank
         self.tp_size = tp_size
         self.nccl_port = nccl_port
-        self.
-        self.
-
-        global global_server_args_dict
-        global_server_args_dict = server_args_dict
+        self.server_args = server_args
+        self.is_multimodal_model = is_multimodal_model(self.model_config)
+        monkey_patch_vllm_dummy_weight_loader()
 
         # Init torch distributed
-
-        torch.
+        logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
+        torch.cuda.set_device(self.gpu_id)
+        logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
+        monkey_patch_vllm_p2p_access_check(self.gpu_id)
+        if server_args.nccl_init_addr:
+            nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
+        else:
+            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
+        init_distributed_environment(
             backend="nccl",
             world_size=self.tp_size,
             rank=self.tp_rank,
-
+            local_rank=self.gpu_id,
+            distributed_init_method=nccl_init_method
         )
-
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
-
         total_gpu_memory = get_available_gpu_memory(
-            self.
-        )
+            self.gpu_id, distributed=self.tp_size > 1
+        )
+
+        if self.tp_size > 1:
+            total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
+            if total_local_gpu_memory < total_gpu_memory * 0.9:
+                raise ValueError(
+                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
+                )
+
+        # Set some global args
+        global global_server_args_dict
+        global_server_args_dict = {
+            "disable_flashinfer": server_args.disable_flashinfer,
+            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+        }
+
+        # Load the model and create memory pool
         self.load_model()
         self.init_memory_pool(total_gpu_memory)
-
-        self.
+        self.init_cublas()
+        self.init_flash_infer()
 
     def load_model(self):
-
-
-
-
-        logger.info(f"Rank {self.tp_rank}: load weight begin.")
-
-        # Load weights
-        quant_config = None
-
-        quant_cfg = getattr(self.model_config.hf_config, "quantization_config", None)
-        if quant_cfg is not None:
-            quant_method = quant_cfg.get("quant_method", "").lower()
-            # compat: autogptq >=0.8.0 use checkpoint_format: str
-            # compat: autogptq <=0.7.1 is_marlin_format: bool
-            is_format_marlin = quant_cfg.get(
-                "checkpoint_format"
-            ) == "marlin" or quant_cfg.get("is_marlin_format", False)
-
-            # Use marlin if the GPTQ model is serialized in marlin format.
-            if quant_method == "gptq" and is_format_marlin:
-                quant_method = "marlin"
-
-            quant_config_class = QUANTIZATION_CONFIG_MAPPING.get(quant_method)
-
-            if quant_config_class is None:
-                raise ValueError(f"Unsupported quantization method: {quant_method}")
-
-            quant_config = quant_config_class.from_config(quant_cfg)
-            logger.info(f"quant_config: {quant_config}")
-
-        with set_default_torch_dtype(torch.float16):
-            with torch.device("cuda"):
-                model = model_class(
-                    config=self.model_config.hf_config, quant_config=quant_config
-                )
-                model.load_weights(
-                    self.model_config.path,
-                    cache_dir=None,
-                    load_format=self.load_format,
-                    revision=None,
-                )
-        self.model = model.eval()
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Load weight begin. "
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+        )
 
-
+        device_config = DeviceConfig()
+        load_config = LoadConfig(load_format=self.server_args.load_format)
+        vllm_model_config = VllmModelConfig(
+            model=self.server_args.model_path,
+            quantization=self.server_args.quantization,
+            tokenizer=None,
+            tokenizer_mode=None,
+            trust_remote_code=self.server_args.trust_remote_code,
+            dtype=self.server_args.dtype,
+            seed=42,
+            skip_tokenizer_init=True,
+        )
+        self.dtype = vllm_model_config.dtype
+        if self.model_config.model_overide_args is not None:
+            vllm_model_config.hf_config.update(self.model_config.model_overide_args)
+
+        self.model = get_model(
+            model_config=vllm_model_config,
+            device_config=device_config,
+            load_config=load_config,
+            lora_config=None,
+            vision_language_config=None,
+            parallel_config=None,
+            scheduler_config=None,
+            cache_config=None,
+        )
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Load weight end. "
+            f"type={type(self.model).__name__}, "
+            f"dtype={self.dtype}, "
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+        )
 
     def profile_max_num_token(self, total_gpu_memory):
         available_gpu_memory = get_available_gpu_memory(
-            self.
-        )
+            self.gpu_id, distributed=self.tp_size > 1
+        )
         head_dim = self.model_config.head_dim
-        head_num = self.model_config.
-        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 *
+        head_num = self.model_config.get_num_kv_heads(self.tp_size)
+        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * torch._utils._element_size(self.dtype)
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
-        max_num_token = int(rest_memory // cell_size)
+        max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token
 
     def init_memory_pool(self, total_gpu_memory):
-        self.
+        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
 
-        if self.
+        if self.max_total_num_tokens <= 0:
             raise RuntimeError(
-                "Not
+                "Not enough memory. Please try to increase --mem-fraction-static."
             )
 
         self.req_to_token_pool = ReqToTokenPool(
-            int(self.
+            int(self.max_total_num_tokens / self.model_config.context_len * 256),
             self.model_config.context_len + 8,
         )
         self.token_to_kv_pool = TokenToKVPool(
-            self.
-            dtype=
-            head_num=self.model_config.
+            self.max_total_num_tokens,
+            dtype=self.dtype,
+            head_num=self.model_config.get_num_kv_heads(self.tp_size),
             head_dim=self.model_config.head_dim,
             layer_num=self.model_config.num_hidden_layers,
         )
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Memory pool end. "
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+        )
+
+    def init_cublas(self):
+        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
+        dtype = torch.float16
+        device = "cuda"
+        a = torch.ones((16, 16), dtype=dtype, device=device)
+        b = torch.ones((16, 16), dtype=dtype, device=device)
+        c = a @ b
+        return c
+
+    def init_flash_infer(self):
+        if not global_server_args_dict.get("disable_flashinfer", False):
+            from flashinfer import (
+                BatchPrefillWithRaggedKVCacheWrapper,
+                BatchPrefillWithPagedKVCacheWrapper,
+                BatchDecodeWithPagedKVCacheWrapper,
+            )
+            from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+            if not _grouped_size_compiled_for_decode_kernels(
+                self.model_config.num_attention_heads // self.tp_size,
+                self.model_config.get_num_kv_heads(self.tp_size)):
+                use_tensor_cores = True
+            else:
+                use_tensor_cores = False
+
+            workspace_buffers = torch.empty(
+                3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
+            )
+            self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+                workspace_buffers[0], "NHD"
+            )
+            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+                workspace_buffers[1], "NHD"
+            )
+            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores
+            )
+        else:
+            self.flashinfer_prefill_wrapper_ragged = self.flashinfer_prefill_wrapper_paged = None
+            self.flashinfer_decode_wrapper = None
 
     @torch.inference_mode()
     def forward_prefill(self, batch: Batch):
@@ -379,6 +426,9 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -397,6 +447,9 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -417,6 +470,9 @@ class ModelRunner:
             out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -435,6 +491,9 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids,
@@ -456,3 +515,48 @@ class ModelRunner:
             return self.forward_prefill(batch)
         else:
             raise ValueError(f"Invaid forward mode: {forward_mode}")
+
+
+@lru_cache()
+def import_model_classes():
+    model_arch_name_to_cls = {}
+    package_name = "sglang.srt.models"
+    package = importlib.import_module(package_name)
+    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
+        if not ispkg:
+            module = importlib.import_module(name)
+            if hasattr(module, "EntryClass"):
+                entry = module.EntryClass
+                if isinstance(
+                    entry, list
+                ):  # To support multiple model classes in one module
+                    for tmp in entry:
+                        model_arch_name_to_cls[tmp.__name__] = tmp
+                else:
+                    model_arch_name_to_cls[entry.__name__] = entry
+
+            # compat: some models such as chatglm has incorrect class set in config.json
+            # usage: [ tuple("From_Entry_Class_Name": EntryClass), ]
+            if hasattr(module, "EntryClassRemapping") and isinstance(
+                module.EntryClassRemapping, list
+            ):
+                for remap in module.EntryClassRemapping:
+                    if isinstance(remap, tuple) and len(remap) == 2:
+                        model_arch_name_to_cls[remap[0]] = remap[1]
+
+    return model_arch_name_to_cls
+
+
+def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
+    model_arch_name_to_cls = import_model_classes()
+
+    if model_arch not in model_arch_name_to_cls:
+        raise ValueError(
+            f"Unsupported architectures: {model_arch}. "
+            f"Supported list: {list(model_arch_name_to_cls.keys())}"
+        )
+    return model_arch_name_to_cls[model_arch]
+
+
+# Monkey patch model loader
+setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
```