sglang 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +13 -1
- sglang/bench_latency.py +10 -5
- sglang/bench_serving.py +50 -26
- sglang/check_env.py +15 -0
- sglang/global_config.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +60 -49
- sglang/lang/chat_template.py +10 -5
- sglang/lang/compiler.py +4 -0
- sglang/lang/interpreter.py +5 -2
- sglang/lang/ir.py +22 -4
- sglang/launch_server.py +8 -1
- sglang/srt/constrained/jump_forward.py +13 -2
- sglang/srt/conversation.py +50 -1
- sglang/srt/hf_transformers_utils.py +22 -23
- sglang/srt/layers/activation.py +24 -2
- sglang/srt/layers/decode_attention.py +338 -50
- sglang/srt/layers/extend_attention.py +3 -1
- sglang/srt/layers/fused_moe/__init__.py +1 -0
- sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
- sglang/srt/layers/fused_moe/layer.py +587 -0
- sglang/srt/layers/layernorm.py +3 -0
- sglang/srt/layers/logits_processor.py +64 -27
- sglang/srt/layers/radix_attention.py +41 -18
- sglang/srt/layers/sampler.py +154 -0
- sglang/srt/managers/controller_multi.py +2 -8
- sglang/srt/managers/controller_single.py +7 -10
- sglang/srt/managers/detokenizer_manager.py +20 -9
- sglang/srt/managers/io_struct.py +44 -11
- sglang/srt/managers/policy_scheduler.py +5 -2
- sglang/srt/managers/schedule_batch.py +59 -179
- sglang/srt/managers/tokenizer_manager.py +193 -84
- sglang/srt/managers/tp_worker.py +131 -50
- sglang/srt/mem_cache/memory_pool.py +82 -8
- sglang/srt/mm_utils.py +79 -7
- sglang/srt/model_executor/cuda_graph_runner.py +97 -28
- sglang/srt/model_executor/forward_batch_info.py +188 -82
- sglang/srt/model_executor/model_runner.py +269 -87
- sglang/srt/models/chatglm.py +6 -14
- sglang/srt/models/commandr.py +6 -2
- sglang/srt/models/dbrx.py +5 -1
- sglang/srt/models/deepseek.py +7 -3
- sglang/srt/models/deepseek_v2.py +12 -7
- sglang/srt/models/gemma.py +6 -2
- sglang/srt/models/gemma2.py +22 -8
- sglang/srt/models/gpt_bigcode.py +5 -1
- sglang/srt/models/grok.py +66 -398
- sglang/srt/models/internlm2.py +5 -1
- sglang/srt/models/llama2.py +7 -3
- sglang/srt/models/llama_classification.py +2 -2
- sglang/srt/models/llama_embedding.py +4 -0
- sglang/srt/models/llava.py +176 -59
- sglang/srt/models/minicpm.py +7 -3
- sglang/srt/models/mixtral.py +61 -255
- sglang/srt/models/mixtral_quant.py +6 -5
- sglang/srt/models/qwen.py +7 -4
- sglang/srt/models/qwen2.py +15 -5
- sglang/srt/models/qwen2_moe.py +7 -16
- sglang/srt/models/stablelm.py +6 -2
- sglang/srt/openai_api/adapter.py +149 -58
- sglang/srt/sampling/sampling_batch_info.py +209 -0
- sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -4
- sglang/srt/server.py +107 -71
- sglang/srt/server_args.py +49 -15
- sglang/srt/utils.py +27 -18
- sglang/test/runners.py +38 -38
- sglang/test/simple_eval_common.py +9 -10
- sglang/test/simple_eval_gpqa.py +2 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_math.py +2 -1
- sglang/test/simple_eval_mmlu.py +2 -1
- sglang/test/test_activation.py +55 -0
- sglang/test/test_programs.py +32 -5
- sglang/test/test_utils.py +37 -50
- sglang/version.py +1 -1
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/METADATA +102 -27
- sglang-0.2.14.dist-info/RECORD +114 -0
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
- sglang/launch_server_llavavid.py +0 -29
- sglang/srt/model_loader/model_loader.py +0 -292
- sglang/srt/model_loader/utils.py +0 -275
- sglang-0.2.12.dist-info/RECORD +0 -112
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
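The hunks that follow are from sglang/srt/model_executor/model_runner.py, one of the largest changes in this release. A small but visible consequence of the file list above is the relocation of sampling_params.py into the new sglang/srt/sampling/ package. Below is a minimal, hypothetical import shim for code written against 0.2.12, assuming SamplingParams itself is unchanged apart from the module move:

# Hypothetical compatibility shim; assumes SamplingParams is unchanged apart from the move.
try:
    # 0.2.14 layout: sampling_params.py now lives in the sampling/ package
    from sglang.srt.sampling.sampling_params import SamplingParams
except ImportError:
    # 0.2.12 layout
    from sglang.srt.sampling_params import SamplingParams

params = SamplingParams(temperature=0.7, max_new_tokens=32)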
@@ -15,13 +15,13 @@ limitations under the License.
 
 """ModelRunner runs the forward passes of the models."""
 
+import gc
 import importlib
 import importlib.resources
 import logging
 import pkgutil
-import warnings
 from functools import lru_cache
-from typing import Optional, Type
+from typing import Optional, Tuple, Type
 
 import torch
 import torch.nn as nn
@@ -37,10 +37,15 @@ from vllm.distributed import (
     get_tp_group,
     init_distributed_environment,
     initialize_model_parallel,
+    set_custom_all_reduce,
 )
+from vllm.distributed.parallel_state import in_the_same_node_as
+from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
@@ -53,7 +58,7 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
     is_generation_model,
-    …
+    is_llama3_405b_fp8_head_16,
     is_multimodal_model,
     monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
@@ -87,22 +92,35 @@ class ModelRunner:
             {
                 "disable_flashinfer": server_args.disable_flashinfer,
                 "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
-                "…
+                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "enable_mla": server_args.enable_mla,
             }
         )
 
+        min_per_gpu_memory = self.init_torch_distributed()
+        self.load_model()
+        self.init_memory_pool(
+            min_per_gpu_memory,
+            server_args.max_num_reqs,
+            server_args.max_total_tokens,
+        )
+        self.init_cublas()
+        self.init_flashinfer()
+        self.init_cuda_graphs()
+
+    def init_torch_distributed(self):
         # Init torch distributed
         torch.cuda.set_device(self.gpu_id)
-        logger.info(…
+        logger.info("Init nccl begin.")
 
-        if not server_args.enable_p2p_check:
+        if not self.server_args.enable_p2p_check:
             monkey_patch_vllm_p2p_access_check(self.gpu_id)
 
-        if server_args.nccl_init_addr:
-            nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
+        if self.server_args.nccl_init_addr:
+            nccl_init_method = f"tcp://{self.server_args.nccl_init_addr}"
         else:
             nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
+        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
         init_distributed_environment(
             backend="nccl",
             world_size=self.tp_size,
@@ -111,43 +129,43 @@ class ModelRunner:
             distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
-
-        total_gpu_memory = get_available_gpu_memory(
+        min_per_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
+        self.tp_group = get_tp_group()
 
+        # Currently, there is a bug with mulit-node tensor parallelsim + padded cuda graph,
+        # so we disable padding in cuda graph.
+        if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)):
+            self.server_args.disable_cuda_graph_padding = True
+            logger.info(
+                "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
+            )
+
+        # Check memory for tensor parallelism
         if self.tp_size > 1:
-            …
-            if …
+            local_gpu_memory = get_available_gpu_memory(self.gpu_id)
+            if min_per_gpu_memory < local_gpu_memory * 0.9:
                 raise ValueError(
                     "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                 )
 
-        …
-        self.load_model()
-        self.init_memory_pool(
-            total_gpu_memory,
-            server_args.max_num_reqs,
-            server_args.max_total_tokens,
-        )
-        self.init_cublas()
-        self.init_flashinfer()
-
-        if self.is_generation:
-            # FIXME Currently, cuda graph only capture decode steps, which only exists in causal models
-            # Capture cuda graphs
-            self.init_cuda_graphs()
+        return min_per_gpu_memory
 
     def load_model(self):
         logger.info(
-            f"…
-            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
+        if torch.cuda.get_device_capability()[0] < 8:
+            logger.info(
+                "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
+            )
+            self.server_args.dtype = "float16"
 
         monkey_patch_vllm_dummy_weight_loader()
-        device_config = DeviceConfig()
-        load_config = LoadConfig(load_format=self.server_args.load_format)
-        vllm_model_config = VllmModelConfig(
+        self.device_config = DeviceConfig()
+        self.load_config = LoadConfig(load_format=self.server_args.load_format)
+        self.vllm_model_config = VllmModelConfig(
             model=self.server_args.model_path,
             quantization=self.server_args.quantization,
             tokenizer=None,
@@ -158,47 +176,132 @@ class ModelRunner:
             skip_tokenizer_init=True,
         )
 
-        …
-        …
+        # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
+        # Drop this after Sept, 2024.
+        if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
             self.model_config.hf_config.num_key_value_heads = 8
-            vllm_model_config.hf_config.num_key_value_heads = 8
+            self.vllm_model_config.hf_config.num_key_value_heads = 8
             monkey_patch_vllm_qvk_linear_loader()
 
-        self.dtype = vllm_model_config.dtype
+        self.dtype = self.vllm_model_config.dtype
         if self.model_config.model_overide_args is not None:
-            vllm_model_config.hf_config.update(
-            …
-            …
-            self.server_args.efficient_weight_load
-            and "llama" in self.server_args.model_path.lower()
-            and self.server_args.quantization == "fp8"
-        ):
-            from sglang.srt.model_loader.model_loader import get_model
-        else:
-            from vllm.model_executor.model_loader import get_model
+            self.vllm_model_config.hf_config.update(
+                self.model_config.model_overide_args
+            )
 
         self.model = get_model(
-            model_config=vllm_model_config,
-            …
-            …
-            lora_config=None,
-            multimodal_config=None,
+            model_config=self.vllm_model_config,
+            load_config=self.load_config,
+            device_config=self.device_config,
             parallel_config=None,
             scheduler_config=None,
+            lora_config=None,
             cache_config=None,
         )
+        self.sliding_window_size = (
+            self.model.get_attention_sliding_window_size()
+            if hasattr(self.model, "get_attention_sliding_window_size")
+            else None
+        )
         self.is_generation = is_generation_model(
-            self.model_config.hf_config.architectures
+            self.model_config.hf_config.architectures, self.server_args.is_embedding
         )
 
         logger.info(
-            f"…
+            f"Load weight end. "
             f"type={type(self.model).__name__}, "
             f"dtype={self.dtype}, "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
 
-    def …
+    def update_weights(self, model_path: str, load_format: str):
+        """Update weights in-place."""
+        from vllm.model_executor.model_loader.loader import (
+            DefaultModelLoader,
+            device_loading_context,
+            get_model_loader,
+        )
+        from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+
+        logger.info(
+            f"Update weights begin. "
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+        )
+
+        target_device = torch.device(self.device_config.device)
+
+        try:
+            # TODO: Use a better method to check this
+            vllm_model_config = VllmModelConfig(
+                model=model_path,
+                quantization=self.server_args.quantization,
+                tokenizer=None,
+                tokenizer_mode=None,
+                trust_remote_code=self.server_args.trust_remote_code,
+                dtype=self.server_args.dtype,
+                seed=42,
+                skip_tokenizer_init=True,
+            )
+        except Exception as e:
+            logger.error(f"Failed to load model config: {e}")
+            return False, "Failed to update model weights"
+
+        load_config = LoadConfig(load_format=load_format)
+
+        # Only support vllm DefaultModelLoader for now
+        loader = get_model_loader(load_config)
+        if not isinstance(loader, DefaultModelLoader):
+            logger.error("Failed to get weights iterator: Unsupported loader")
+            return False, "Failed to update model weights"
+
+        def get_weight_iter(config):
+            iter = loader._get_weights_iterator(
+                config.model,
+                config.revision,
+                fall_back_to_pt=getattr(
+                    self.model, "fall_back_to_pt_during_load", True
+                ),
+            )
+            return iter
+
+        def model_load_weights(model, iter):
+            model.load_weights(iter)
+            for _, module in self.model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    with device_loading_context(module, target_device):
+                        quant_method.process_weights_after_loading(module)
+            return model
+
+        with set_default_torch_dtype(vllm_model_config.dtype):
+            try:
+                iter = get_weight_iter(vllm_model_config)
+            except Exception as e:
+                message = f"Failed to get weights iterator: {e}"
+                logger.error(message)
+                return False, message
+            try:
+                model = model_load_weights(self.model, iter)
+            except Exception as e:
+                message = f"Failed to update weights: {e}. \n Rolling back to original weights"
+                logger.error(message)
+                del iter
+                gc.collect()
+                iter = get_weight_iter(self.vllm_model_config)
+                self.model = model_load_weights(self.model, iter)
+                return False, message
+
+        self.model = model
+        self.server_args.model_path = model_path
+        self.server_args.load_format = load_format
+        self.vllm_model_config = vllm_model_config
+        self.load_config = load_config
+        self.model_config.path = model_path
+
+        logger.info("Update weights end.")
+        return True, "Succeeded to update model weights"
+
+    def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
@@ -209,7 +312,7 @@ class ModelRunner:
             cell_size = (
                 (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                 * self.model_config.num_hidden_layers
-                * torch._utils._element_size(self.…
+                * torch._utils._element_size(self.kv_cache_dtype)
             )
         else:
             cell_size = (
@@ -217,7 +320,7 @@ class ModelRunner:
                 * self.model_config.head_dim
                 * self.model_config.num_hidden_layers
                 * 2
-                * torch._utils._element_size(self.…
+                * torch._utils._element_size(self.kv_cache_dtype)
             )
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
@@ -226,12 +329,30 @@ class ModelRunner:
         return max_num_token
 
     def init_memory_pool(
-        self,…
+        self,
+        total_gpu_memory: int,
+        max_num_reqs: int = None,
+        max_total_tokens: int = None,
     ):
+        if self.server_args.kv_cache_dtype == "auto":
+            self.kv_cache_dtype = self.dtype
+        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
+            if self.server_args.disable_flashinfer or self.server_args.enable_mla:
+                logger.warning(
+                    "FP8 KV cache is not supported for Triton kernel now, using auto kv cache dtype"
+                )
+                self.kv_cache_dtype = self.dtype
+            else:
+                self.kv_cache_dtype = torch.float8_e5m2
+        else:
+            raise ValueError(
+                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
+            )
+
         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
         if max_total_tokens is not None:
             if max_total_tokens > self.max_total_num_tokens:
-                …
+                logging.warning(
                     f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                     f"{self.max_total_num_tokens}. "
                     f"Use the profiled value instead."
@@ -264,7 +385,7 @@ class ModelRunner:
         ):
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
-                dtype=self.…
+                dtype=self.kv_cache_dtype,
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
@@ -275,13 +396,13 @@ class ModelRunner:
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
-                dtype=self.…
+                dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(self.tp_size),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
             )
         logger.info(
-            f"…
+            f"Memory pool end. "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
 
@@ -295,7 +416,11 @@ class ModelRunner:
         return c
 
     def init_flashinfer(self):
+        """Init flashinfer attention kernel wrappers."""
         if self.server_args.disable_flashinfer:
+            assert (
+                self.sliding_window_size is None
+            ), "turn on flashinfer to support window attention"
             self.flashinfer_prefill_wrapper_ragged = None
             self.flashinfer_prefill_wrapper_paged = None
             self.flashinfer_decode_wrapper = None
@@ -309,36 +434,72 @@ class ModelRunner:
         else:
             use_tensor_cores = False
 
-        self.…
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
+        if self.sliding_window_size is None:
+            self.flashinfer_workspace_buffer = torch.empty(
+                global_config.flashinfer_workspace_size,
+                dtype=torch.uint8,
+                device="cuda",
+            )
+            self.flashinfer_prefill_wrapper_ragged = (
+                BatchPrefillWithRaggedKVCacheWrapper(
+                    self.flashinfer_workspace_buffer, "NHD"
+                )
+            )
+            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+                self.flashinfer_workspace_buffer, "NHD"
+            )
+            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                self.flashinfer_workspace_buffer,
+                "NHD",
+                use_tensor_cores=use_tensor_cores,
+            )
+        else:
+            self.flashinfer_workspace_buffer = torch.empty(
+                global_config.flashinfer_workspace_size,
+                dtype=torch.uint8,
+                device="cuda",
+            )
+            self.flashinfer_prefill_wrapper_ragged = None
+            self.flashinfer_prefill_wrapper_paged = []
+            self.flashinfer_decode_wrapper = []
+            for i in range(2):
+                self.flashinfer_prefill_wrapper_paged.append(
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.flashinfer_workspace_buffer, "NHD"
+                    )
+                )
+                self.flashinfer_decode_wrapper.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.flashinfer_workspace_buffer,
+                        "NHD",
+                        use_tensor_cores=use_tensor_cores,
+                    )
+                )
 
     def init_cuda_graphs(self):
+        """Capture cuda graphs."""
+        if not self.is_generation:
+            # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
+            return
+
         from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 
         if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
            self.cuda_graph_runner = None
            return
 
-        logger.info(…
-            …
-            …
-            …
+        logger.info("Capture cuda graph begin. This can take up to several minutes.")
+
+        if self.server_args.disable_cuda_graph_padding:
+            batch_size_list = list(range(1, 32)) + [64, 128]
+        else:
+            batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+
         self.cuda_graph_runner = CudaGraphRunner(
             self,
             max_batch_size_to_capture=max(batch_size_list),
             use_torch_compile=self.server_args.enable_torch_compile,
+            disable_padding=self.server_args.disable_cuda_graph_padding,
         )
         try:
             self.cuda_graph_runner.capture(batch_size_list)
@@ -354,11 +515,17 @@ class ModelRunner:
 
     @torch.inference_mode()
     def forward_decode(self, batch: ScheduleBatch):
-        if …
+        if (
+            self.cuda_graph_runner
+            and self.cuda_graph_runner.can_run(len(batch.reqs))
+            and not batch.sampling_info.has_bias()
+        ):
             return self.cuda_graph_runner.replay(batch)
 
         input_metadata = InputMetadata.from_schedule_batch(
-            self,…
+            self,
+            batch,
+            ForwardMode.DECODE,
         )
 
         return self.model.forward(
@@ -368,16 +535,29 @@ class ModelRunner:
     @torch.inference_mode()
     def forward_extend(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.from_schedule_batch(
-            self,…
-            …
-            …
-            batch.input_ids, input_metadata.positions, input_metadata
+            self,
+            batch,
+            forward_mode=ForwardMode.EXTEND,
         )
+        if self.is_generation:
+            return self.model.forward(
+                batch.input_ids, input_metadata.positions, input_metadata
+            )
+        else:
+            # Only embedding models have get_embedding parameter
+            return self.model.forward(
+                batch.input_ids,
+                input_metadata.positions,
+                input_metadata,
+                get_embedding=True,
+            )
 
     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.from_schedule_batch(
-            self,…
+            self,
+            batch,
+            forward_mode=ForwardMode.EXTEND,
         )
         return self.model.forward(
             batch.input_ids,
@@ -388,7 +568,9 @@ class ModelRunner:
             input_metadata.image_offsets,
         )
 
-    def forward(…
+    def forward(
+        self, batch: ScheduleBatch, forward_mode: ForwardMode
+    ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
             return self.forward_extend_multi_modal(batch)
         elif forward_mode == ForwardMode.DECODE:
@@ -444,4 +626,4 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
 
 
 # Monkey patch model loader
-setattr(ModelRegistry, "…
+setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
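The refactor above changes ModelRunner.forward to return a (SampleOutput, LogitsProcessorOutput) pair and adds an in-place update_weights(model_path, load_format) method that reports success as a (bool, message) tuple. Below is a rough, hypothetical sketch of how a caller might consume these two contracts; the runner and batch objects are placeholders for a real ModelRunner and ScheduleBatch built elsewhere in the server, not the actual sglang call sites:

# Hypothetical caller-side sketch; `runner` and `batch` stand in for a real
# ModelRunner instance and ScheduleBatch built by the serving loop.
from sglang.srt.model_executor.forward_batch_info import ForwardMode

def run_decode_step(runner, batch):
    # forward() now returns the sampled tokens together with the logits output.
    sample_output, logits_output = runner.forward(batch, ForwardMode.DECODE)
    return sample_output, logits_output

def swap_checkpoint(runner, model_path: str):
    # update_weights() returns (success, message) and, per the diff above,
    # rolls back to the original weights if loading the new checkpoint fails.
    success, message = runner.update_weights(model_path, load_format="auto")
    if not success:
        raise RuntimeError(f"weight update failed: {message}")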
sglang/srt/models/chatglm.py
CHANGED
@@ -24,8 +24,6 @@ from torch import nn
 from torch.nn import LayerNorm
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -33,18 +31,18 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs import ChatGLMConfig
 
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 LoraConfig = None
@@ -383,17 +381,11 @@ class ChatGLMForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        …
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-
-    …
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
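chatglm.py illustrates the pattern applied to every model in this release: the vllm Sampler/SamplingMetadata imports are dropped, each causal LM owns an sglang Sampler, and forward returns the sampled output together with the logits output. A condensed sketch of that shape is below; `self.transformer` and `self.lm_head` are placeholders for the model-specific modules, so this is an illustration of the pattern rather than any one model's implementation. The commandr.py and dbrx.py diffs that follow apply the same change:

# Condensed sketch of the forward() shape shared by the models in this diff.
import torch
from torch import nn

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata

class SketchForCausalLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Each model now owns its logits processor and sampler.
        self.logits_processor = LogitsProcessor(config)
        self.sampler = Sampler()

    @torch.no_grad()
    def forward(self, input_ids, positions, input_metadata: InputMetadata):
        # `self.transformer` / `self.lm_head` are hypothetical stand-ins here.
        hidden_states = self.transformer(input_ids, positions, input_metadata)
        logits_output = self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )
        # Sampling happens inside the model, driven by the batch's sampling_info.
        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
        return sample_output, logits_output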
sglang/srt/models/commandr.py
CHANGED
@@ -50,7 +50,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -62,8 +61,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.utils import set_weight_attrs
 
+from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -326,6 +327,7 @@ class CohereForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
         self.model = CohereModel(config, quant_config)
 
     @torch.no_grad()
@@ -340,9 +342,11 @@ class CohereForCausalLM(nn.Module):
             positions,
             input_metadata,
         )
-        …
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/dbrx.py
CHANGED
@@ -45,6 +45,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -382,6 +383,7 @@ class DbrxForCausalLM(nn.Module):
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
         )
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -391,9 +393,11 @@ class DbrxForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        …
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         expert_params_mapping = [