sglang 0.2.13__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +6 -0
- sglang/bench_latency.py +7 -3
- sglang/bench_serving.py +50 -26
- sglang/check_env.py +15 -0
- sglang/lang/chat_template.py +10 -5
- sglang/lang/compiler.py +4 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +9 -0
- sglang/launch_server.py +8 -1
- sglang/srt/conversation.py +50 -1
- sglang/srt/hf_transformers_utils.py +22 -23
- sglang/srt/layers/activation.py +24 -1
- sglang/srt/layers/decode_attention.py +338 -50
- sglang/srt/layers/fused_moe/layer.py +2 -2
- sglang/srt/layers/layernorm.py +3 -0
- sglang/srt/layers/logits_processor.py +60 -23
- sglang/srt/layers/radix_attention.py +3 -4
- sglang/srt/layers/sampler.py +154 -0
- sglang/srt/managers/controller_multi.py +2 -8
- sglang/srt/managers/controller_single.py +7 -10
- sglang/srt/managers/detokenizer_manager.py +20 -9
- sglang/srt/managers/io_struct.py +44 -11
- sglang/srt/managers/policy_scheduler.py +5 -2
- sglang/srt/managers/schedule_batch.py +52 -167
- sglang/srt/managers/tokenizer_manager.py +192 -83
- sglang/srt/managers/tp_worker.py +130 -43
- sglang/srt/mem_cache/memory_pool.py +82 -8
- sglang/srt/mm_utils.py +79 -7
- sglang/srt/model_executor/cuda_graph_runner.py +49 -11
- sglang/srt/model_executor/forward_batch_info.py +59 -27
- sglang/srt/model_executor/model_runner.py +210 -61
- sglang/srt/models/chatglm.py +4 -12
- sglang/srt/models/commandr.py +5 -1
- sglang/srt/models/dbrx.py +5 -1
- sglang/srt/models/deepseek.py +5 -1
- sglang/srt/models/deepseek_v2.py +5 -1
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/gemma2.py +15 -7
- sglang/srt/models/gpt_bigcode.py +5 -1
- sglang/srt/models/grok.py +16 -2
- sglang/srt/models/internlm2.py +5 -1
- sglang/srt/models/llama2.py +7 -3
- sglang/srt/models/llama_classification.py +2 -2
- sglang/srt/models/llama_embedding.py +4 -0
- sglang/srt/models/llava.py +176 -59
- sglang/srt/models/minicpm.py +5 -1
- sglang/srt/models/mixtral.py +5 -1
- sglang/srt/models/mixtral_quant.py +5 -1
- sglang/srt/models/qwen.py +5 -2
- sglang/srt/models/qwen2.py +13 -3
- sglang/srt/models/qwen2_moe.py +5 -14
- sglang/srt/models/stablelm.py +5 -1
- sglang/srt/openai_api/adapter.py +117 -37
- sglang/srt/sampling/sampling_batch_info.py +209 -0
- sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -0
- sglang/srt/server.py +84 -56
- sglang/srt/server_args.py +43 -15
- sglang/srt/utils.py +26 -16
- sglang/test/runners.py +23 -31
- sglang/test/simple_eval_common.py +9 -10
- sglang/test/simple_eval_gpqa.py +2 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_math.py +2 -1
- sglang/test/simple_eval_mmlu.py +2 -1
- sglang/test/test_activation.py +55 -0
- sglang/test/test_utils.py +36 -53
- sglang/version.py +1 -1
- {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/METADATA +92 -25
- sglang-0.2.14.dist-info/RECORD +114 -0
- {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
- sglang/launch_server_llavavid.py +0 -29
- sglang-0.2.13.dist-info/RECORD +0 -112
- {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
- {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
@@ -15,13 +15,13 @@ limitations under the License.
 
 """ModelRunner runs the forward passes of the models."""
 
+import gc
 import importlib
 import importlib.resources
 import logging
 import pkgutil
-import warnings
 from functools import lru_cache
-from typing import Optional, Type
+from typing import Optional, Tuple, Type
 
 import torch
 import torch.nn as nn
@@ -37,11 +37,15 @@ from vllm.distributed import (
     get_tp_group,
     init_distributed_environment,
     initialize_model_parallel,
+    set_custom_all_reduce,
 )
+from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
@@ -88,22 +92,35 @@ class ModelRunner:
             {
                 "disable_flashinfer": server_args.disable_flashinfer,
                 "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
-                "
+                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "enable_mla": server_args.enable_mla,
             }
         )
 
+        min_per_gpu_memory = self.init_torch_distributed()
+        self.load_model()
+        self.init_memory_pool(
+            min_per_gpu_memory,
+            server_args.max_num_reqs,
+            server_args.max_total_tokens,
+        )
+        self.init_cublas()
+        self.init_flashinfer()
+        self.init_cuda_graphs()
+
+    def init_torch_distributed(self):
         # Init torch distributed
         torch.cuda.set_device(self.gpu_id)
-        logger.info(
+        logger.info("Init nccl begin.")
 
-        if not server_args.enable_p2p_check:
+        if not self.server_args.enable_p2p_check:
             monkey_patch_vllm_p2p_access_check(self.gpu_id)
 
-        if server_args.nccl_init_addr:
-            nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
+        if self.server_args.nccl_init_addr:
+            nccl_init_method = f"tcp://{self.server_args.nccl_init_addr}"
         else:
            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
+        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
         init_distributed_environment(
             backend="nccl",
             world_size=self.tp_size,
@@ -112,43 +129,43 @@ class ModelRunner:
             distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
-
-        total_gpu_memory = get_available_gpu_memory(
+        min_per_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
+        self.tp_group = get_tp_group()
+
+        # Currently, there is a bug with mulit-node tensor parallelsim + padded cuda graph,
+        # so we disable padding in cuda graph.
+        if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)):
+            self.server_args.disable_cuda_graph_padding = True
+            logger.info(
+                "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
+            )
 
+        # Check memory for tensor parallelism
         if self.tp_size > 1:
-
-            if
+            local_gpu_memory = get_available_gpu_memory(self.gpu_id)
+            if min_per_gpu_memory < local_gpu_memory * 0.9:
                 raise ValueError(
                     "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                 )
 
-
-        self.load_model()
-        self.init_memory_pool(
-            total_gpu_memory,
-            server_args.max_num_reqs,
-            server_args.max_total_tokens,
-        )
-        self.init_cublas()
-        self.init_flashinfer()
-
-        if self.is_generation:
-            # FIXME Currently, cuda graph only capture decode steps, which only exists in causal models
-            # Capture cuda graphs
-            self.init_cuda_graphs()
+        return min_per_gpu_memory
 
     def load_model(self):
         logger.info(
-            f"
-            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
+        if torch.cuda.get_device_capability()[0] < 8:
+            logger.info(
+                "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
+            )
+            self.server_args.dtype = "float16"
 
         monkey_patch_vllm_dummy_weight_loader()
-        device_config = DeviceConfig()
-        load_config = LoadConfig(load_format=self.server_args.load_format)
-        vllm_model_config = VllmModelConfig(
+        self.device_config = DeviceConfig()
+        self.load_config = LoadConfig(load_format=self.server_args.load_format)
+        self.vllm_model_config = VllmModelConfig(
             model=self.server_args.model_path,
             quantization=self.server_args.quantization,
             tokenizer=None,
@@ -159,43 +176,132 @@ class ModelRunner:
             skip_tokenizer_init=True,
         )
 
+        # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
+        # Drop this after Sept, 2024.
         if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
-            # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
             self.model_config.hf_config.num_key_value_heads = 8
-            vllm_model_config.hf_config.num_key_value_heads = 8
+            self.vllm_model_config.hf_config.num_key_value_heads = 8
             monkey_patch_vllm_qvk_linear_loader()
 
-        self.dtype = vllm_model_config.dtype
+        self.dtype = self.vllm_model_config.dtype
         if self.model_config.model_overide_args is not None:
-            vllm_model_config.hf_config.update(
+            self.vllm_model_config.hf_config.update(
+                self.model_config.model_overide_args
+            )
 
         self.model = get_model(
-            model_config=vllm_model_config,
-
-
-            lora_config=None,
-            multimodal_config=None,
+            model_config=self.vllm_model_config,
+            load_config=self.load_config,
+            device_config=self.device_config,
             parallel_config=None,
             scheduler_config=None,
+            lora_config=None,
             cache_config=None,
         )
         self.sliding_window_size = (
-            self.model.
-            if hasattr(self.model, "
+            self.model.get_attention_sliding_window_size()
+            if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
         self.is_generation = is_generation_model(
-            self.model_config.hf_config.architectures
+            self.model_config.hf_config.architectures, self.server_args.is_embedding
         )
 
         logger.info(
-            f"
+            f"Load weight end. "
             f"type={type(self.model).__name__}, "
             f"dtype={self.dtype}, "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
 
-    def
+    def update_weights(self, model_path: str, load_format: str):
+        """Update weights in-place."""
+        from vllm.model_executor.model_loader.loader import (
+            DefaultModelLoader,
+            device_loading_context,
+            get_model_loader,
+        )
+        from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+
+        logger.info(
+            f"Update weights begin. "
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+        )
+
+        target_device = torch.device(self.device_config.device)
+
+        try:
+            # TODO: Use a better method to check this
+            vllm_model_config = VllmModelConfig(
+                model=model_path,
+                quantization=self.server_args.quantization,
+                tokenizer=None,
+                tokenizer_mode=None,
+                trust_remote_code=self.server_args.trust_remote_code,
+                dtype=self.server_args.dtype,
+                seed=42,
+                skip_tokenizer_init=True,
+            )
+        except Exception as e:
+            logger.error(f"Failed to load model config: {e}")
+            return False, "Failed to update model weights"
+
+        load_config = LoadConfig(load_format=load_format)
+
+        # Only support vllm DefaultModelLoader for now
+        loader = get_model_loader(load_config)
+        if not isinstance(loader, DefaultModelLoader):
+            logger.error("Failed to get weights iterator: Unsupported loader")
+            return False, "Failed to update model weights"
+
+        def get_weight_iter(config):
+            iter = loader._get_weights_iterator(
+                config.model,
+                config.revision,
+                fall_back_to_pt=getattr(
+                    self.model, "fall_back_to_pt_during_load", True
+                ),
+            )
+            return iter
+
+        def model_load_weights(model, iter):
+            model.load_weights(iter)
+            for _, module in self.model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    with device_loading_context(module, target_device):
+                        quant_method.process_weights_after_loading(module)
+            return model
+
+        with set_default_torch_dtype(vllm_model_config.dtype):
+            try:
+                iter = get_weight_iter(vllm_model_config)
+            except Exception as e:
+                message = f"Failed to get weights iterator: {e}"
+                logger.error(message)
+                return False, message
+            try:
+                model = model_load_weights(self.model, iter)
+            except Exception as e:
+                message = f"Failed to update weights: {e}. \n Rolling back to original weights"
+                logger.error(message)
+                del iter
+                gc.collect()
+                iter = get_weight_iter(self.vllm_model_config)
+                self.model = model_load_weights(self.model, iter)
+                return False, message
+
+        self.model = model
+        self.server_args.model_path = model_path
+        self.server_args.load_format = load_format
+        self.vllm_model_config = vllm_model_config
+        self.load_config = load_config
+        self.model_config.path = model_path
+
+        logger.info("Update weights end.")
+        return True, "Succeeded to update model weights"
+
+    def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
@@ -206,7 +312,7 @@ class ModelRunner:
             cell_size = (
                 (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                 * self.model_config.num_hidden_layers
-                * torch._utils._element_size(self.
+                * torch._utils._element_size(self.kv_cache_dtype)
             )
         else:
             cell_size = (
@@ -214,7 +320,7 @@ class ModelRunner:
                 * self.model_config.head_dim
                 * self.model_config.num_hidden_layers
                 * 2
-                * torch._utils._element_size(self.
+                * torch._utils._element_size(self.kv_cache_dtype)
             )
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
@@ -223,12 +329,30 @@ class ModelRunner:
         return max_num_token
 
     def init_memory_pool(
-        self,
+        self,
+        total_gpu_memory: int,
+        max_num_reqs: int = None,
+        max_total_tokens: int = None,
     ):
+        if self.server_args.kv_cache_dtype == "auto":
+            self.kv_cache_dtype = self.dtype
+        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
+            if self.server_args.disable_flashinfer or self.server_args.enable_mla:
+                logger.warning(
+                    "FP8 KV cache is not supported for Triton kernel now, using auto kv cache dtype"
+                )
+                self.kv_cache_dtype = self.dtype
+            else:
+                self.kv_cache_dtype = torch.float8_e5m2
+        else:
+            raise ValueError(
+                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
+            )
+
         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
         if max_total_tokens is not None:
             if max_total_tokens > self.max_total_num_tokens:
-
+                logging.warning(
                     f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                     f"{self.max_total_num_tokens}. "
                     f"Use the profiled value instead."
@@ -261,7 +385,7 @@ class ModelRunner:
         ):
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
-                dtype=self.
+                dtype=self.kv_cache_dtype,
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
@@ -272,13 +396,13 @@ class ModelRunner:
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
-                dtype=self.
+                dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(self.tp_size),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
             )
         logger.info(
-            f"
+            f"Memory pool end. "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
 
@@ -292,6 +416,7 @@ class ModelRunner:
         return c
 
     def init_flashinfer(self):
+        """Init flashinfer attention kernel wrappers."""
         if self.server_args.disable_flashinfer:
             assert (
                 self.sliding_window_size is None
@@ -352,20 +477,29 @@ class ModelRunner:
             )
 
     def init_cuda_graphs(self):
+        """Capture cuda graphs."""
+        if not self.is_generation:
+            # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
+            return
+
         from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 
         if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
             self.cuda_graph_runner = None
             return
 
-        logger.info(
-
-
-
+        logger.info("Capture cuda graph begin. This can take up to several minutes.")
+
+        if self.server_args.disable_cuda_graph_padding:
+            batch_size_list = list(range(1, 32)) + [64, 128]
+        else:
+            batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+
         self.cuda_graph_runner = CudaGraphRunner(
             self,
             max_batch_size_to_capture=max(batch_size_list),
             use_torch_compile=self.server_args.enable_torch_compile,
+            disable_padding=self.server_args.disable_cuda_graph_padding,
         )
         try:
             self.cuda_graph_runner.capture(batch_size_list)
@@ -381,7 +515,11 @@ class ModelRunner:
 
     @torch.inference_mode()
     def forward_decode(self, batch: ScheduleBatch):
-        if
+        if (
+            self.cuda_graph_runner
+            and self.cuda_graph_runner.can_run(len(batch.reqs))
+            and not batch.sampling_info.has_bias()
+        ):
             return self.cuda_graph_runner.replay(batch)
 
         input_metadata = InputMetadata.from_schedule_batch(
@@ -401,9 +539,18 @@ class ModelRunner:
             batch,
             forward_mode=ForwardMode.EXTEND,
         )
-
-
-
+        if self.is_generation:
+            return self.model.forward(
+                batch.input_ids, input_metadata.positions, input_metadata
+            )
+        else:
+            # Only embedding models have get_embedding parameter
+            return self.model.forward(
+                batch.input_ids,
+                input_metadata.positions,
+                input_metadata,
+                get_embedding=True,
+            )
 
     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
@@ -421,7 +568,9 @@ class ModelRunner:
                 input_metadata.image_offsets,
             )
 
-    def forward(
+    def forward(
+        self, batch: ScheduleBatch, forward_mode: ForwardMode
+    ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
             return self.forward_extend_multi_modal(batch)
         elif forward_mode == ForwardMode.DECODE:
@@ -477,4 +626,4 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
 
 
 # Monkey patch model loader
-setattr(ModelRegistry, "
+setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
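Note on the model_runner.py hunks above: ModelRunner.__init__ is split into explicit init_torch_distributed / load_model / init_memory_pool / init_cublas / init_flashinfer / init_cuda_graphs steps, and a new update_weights(model_path, load_format) method reloads weights in place, returning a (success, message) tuple and rolling back to the original weights on failure. A minimal usage sketch follows; the runner instance, checkpoint path, and "auto" load-format value are illustrative assumptions, not values taken from the diff:

# Hypothetical call site for the new in-place weight update API.
# `runner` is assumed to be an already-initialized ModelRunner.
success, message = runner.update_weights(
    model_path="/checkpoints/my-model",  # illustrative path
    load_format="auto",                  # assumed load-format value
)
if not success:
    # update_weights logs the error and restores the original weights.
    print(message)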
sglang/srt/models/chatglm.py
CHANGED
@@ -31,20 +31,18 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs import ChatGLMConfig
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 LoraConfig = None
@@ -383,17 +381,11 @@ class ChatGLMForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-
-
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
sglang/srt/models/commandr.py
CHANGED
@@ -64,6 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -326,6 +327,7 @@ class CohereForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
         self.model = CohereModel(config, quant_config)
 
     @torch.no_grad()
@@ -340,9 +342,11 @@ class CohereForCausalLM(nn.Module):
             positions,
             input_metadata,
         )
-
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/dbrx.py
CHANGED
@@ -45,6 +45,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -382,6 +383,7 @@ class DbrxForCausalLM(nn.Module):
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
         )
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -391,9 +393,11 @@ class DbrxForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         expert_params_mapping = [
sglang/srt/models/deepseek.py
CHANGED
@@ -46,6 +46,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -385,6 +386,7 @@ class DeepseekForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -394,9 +396,11 @@ class DeepseekForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -45,6 +45,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
@@ -632,6 +633,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     def forward(
         self,
@@ -640,9 +642,11 @@ class DeepseekV2ForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/gemma.py
CHANGED
@@ -37,6 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -287,6 +288,7 @@ class GemmaForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.model = GemmaModel(config, quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -297,9 +299,11 @@ class GemmaForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return (sample_output, logits_output)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
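The per-model diffs shown here (chatglm.py, commandr.py, dbrx.py, deepseek.py, deepseek_v2.py, gemma.py) all apply the same change: the vLLM Sampler import is replaced by sglang.srt.layers.sampler.Sampler, a self.sampler is constructed next to the LogitsProcessor, and forward() returns a (sample_output, logits_output) tuple instead of bare logits. Below is a minimal sketch of that pattern; the class name, constructor, and lm_head here are hypothetical stand-ins (real models use their own backbones and parallel embedding layers), while the sglang imports and call signatures mirror the diffs above.

# Sketch of the 0.2.14 per-model sampling pattern; "SketchForCausalLM" and
# "backbone" are hypothetical stand-ins, not classes from the sglang package.
import torch
from torch import nn

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata


class SketchForCausalLM(nn.Module):
    def __init__(self, config, backbone: nn.Module):
        super().__init__()
        self.model = backbone  # hypothetical transformer backbone
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.logits_processor = LogitsProcessor(config)
        self.sampler = Sampler()

    @torch.no_grad()
    def forward(self, input_ids, positions, input_metadata: InputMetadata):
        hidden_states = self.model(input_ids, positions, input_metadata)
        logits_output = self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )
        # Sampling now happens inside the model's forward pass; both the sampled
        # tokens and the logits are returned to the ModelRunner.
        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
        return sample_output, logits_output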