sglang 0.2.13__py3-none-any.whl → 0.2.14.post1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/api.py +6 -0
- sglang/bench_latency.py +7 -3
- sglang/bench_serving.py +50 -26
- sglang/check_env.py +15 -0
- sglang/lang/chat_template.py +10 -5
- sglang/lang/compiler.py +4 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +9 -0
- sglang/launch_server.py +8 -1
- sglang/srt/constrained/fsm_cache.py +11 -2
- sglang/srt/constrained/jump_forward.py +1 -0
- sglang/srt/conversation.py +50 -1
- sglang/srt/hf_transformers_utils.py +22 -23
- sglang/srt/layers/activation.py +100 -1
- sglang/srt/layers/decode_attention.py +338 -50
- sglang/srt/layers/fused_moe/layer.py +2 -2
- sglang/srt/layers/logits_processor.py +56 -19
- sglang/srt/layers/radix_attention.py +3 -4
- sglang/srt/layers/sampler.py +101 -0
- sglang/srt/managers/controller_multi.py +2 -8
- sglang/srt/managers/controller_single.py +7 -10
- sglang/srt/managers/detokenizer_manager.py +20 -9
- sglang/srt/managers/io_struct.py +44 -11
- sglang/srt/managers/policy_scheduler.py +5 -2
- sglang/srt/managers/schedule_batch.py +46 -166
- sglang/srt/managers/tokenizer_manager.py +192 -83
- sglang/srt/managers/tp_worker.py +118 -24
- sglang/srt/mem_cache/memory_pool.py +82 -8
- sglang/srt/mm_utils.py +79 -7
- sglang/srt/model_executor/cuda_graph_runner.py +32 -8
- sglang/srt/model_executor/forward_batch_info.py +51 -26
- sglang/srt/model_executor/model_runner.py +201 -58
- sglang/srt/models/gemma2.py +10 -6
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +11 -1
- sglang/srt/models/llama_embedding.py +4 -0
- sglang/srt/models/llava.py +176 -59
- sglang/srt/models/qwen2.py +9 -3
- sglang/srt/openai_api/adapter.py +200 -39
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/sampling_batch_info.py +136 -0
- sglang/srt/{sampling_params.py → sampling/sampling_params.py} +22 -0
- sglang/srt/server.py +92 -57
- sglang/srt/server_args.py +43 -15
- sglang/srt/utils.py +26 -16
- sglang/test/runners.py +22 -30
- sglang/test/simple_eval_common.py +9 -10
- sglang/test/simple_eval_gpqa.py +2 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_math.py +2 -1
- sglang/test/simple_eval_mmlu.py +2 -1
- sglang/test/test_activation.py +55 -0
- sglang/test/test_utils.py +36 -53
- sglang/version.py +1 -1
- {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/METADATA +100 -27
- sglang-0.2.14.post1.dist-info/RECORD +114 -0
- {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/WHEEL +1 -1
- sglang/launch_server_llavavid.py +0 -29
- sglang-0.2.13.dist-info/RECORD +0 -112
- {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/LICENSE +0 -0
- {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/top_level.txt +0 -0
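Note for downstream users: the file list above shows `sampling_params.py` moving into the new `sglang/srt/sampling/` package. The sketch below is a hedged compatibility shim for code that imports it directly; it assumes the `SamplingParams` class name and its common constructor arguments are unchanged across the move.

```python
# Hedged sketch: import SamplingParams from the new 0.2.14 location, falling
# back to the old 0.2.13 path. Assumes the class itself did not change.
try:
    from sglang.srt.sampling.sampling_params import SamplingParams  # >= 0.2.14
except ImportError:
    from sglang.srt.sampling_params import SamplingParams  # <= 0.2.13

params = SamplingParams(temperature=0.7, max_new_tokens=64)
```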
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -15,11 +15,11 @@ limitations under the License.
 
 """ModelRunner runs the forward passes of the models."""
 
+import gc
 import importlib
 import importlib.resources
 import logging
 import pkgutil
-import warnings
 from functools import lru_cache
 from typing import Optional, Type
 
@@ -37,7 +37,9 @@ from vllm.distributed import (
     get_tp_group,
     init_distributed_environment,
     initialize_model_parallel,
+    set_custom_all_reduce,
 )
+from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
@@ -88,22 +90,35 @@ class ModelRunner:
             {
                 "disable_flashinfer": server_args.disable_flashinfer,
                 "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
-                "
+                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "enable_mla": server_args.enable_mla,
             }
         )
 
+        min_per_gpu_memory = self.init_torch_distributed()
+        self.load_model()
+        self.init_memory_pool(
+            min_per_gpu_memory,
+            server_args.max_num_reqs,
+            server_args.max_total_tokens,
+        )
+        self.init_cublas()
+        self.init_flashinfer()
+        self.init_cuda_graphs()
+
+    def init_torch_distributed(self):
         # Init torch distributed
         torch.cuda.set_device(self.gpu_id)
-        logger.info(
+        logger.info("Init nccl begin.")
 
-        if not server_args.enable_p2p_check:
+        if not self.server_args.enable_p2p_check:
             monkey_patch_vllm_p2p_access_check(self.gpu_id)
 
-        if server_args.nccl_init_addr:
-            nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
+        if self.server_args.nccl_init_addr:
+            nccl_init_method = f"tcp://{self.server_args.nccl_init_addr}"
         else:
             nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
+        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
         init_distributed_environment(
             backend="nccl",
             world_size=self.tp_size,
@@ -112,43 +127,45 @@ class ModelRunner:
             distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
-
-        total_gpu_memory = get_available_gpu_memory(
+        min_per_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
+        self.tp_group = get_tp_group()
 
+        # Currently, there is a bug with mulit-node tensor parallelsim + padded cuda graph,
+        # so we disable padding in cuda graph.
+        if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)):
+            self.server_args.disable_cuda_graph_padding = True
+            logger.info(
+                "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
+            )
+
+        # Check memory for tensor parallelism
         if self.tp_size > 1:
-
-            if
+            local_gpu_memory = get_available_gpu_memory(self.gpu_id)
+            if min_per_gpu_memory < local_gpu_memory * 0.9:
                 raise ValueError(
                     "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                 )
 
-
-        self.load_model()
-        self.init_memory_pool(
-            total_gpu_memory,
-            server_args.max_num_reqs,
-            server_args.max_total_tokens,
-        )
-        self.init_cublas()
-        self.init_flashinfer()
-
-        if self.is_generation:
-            # FIXME Currently, cuda graph only capture decode steps, which only exists in causal models
-            # Capture cuda graphs
-            self.init_cuda_graphs()
+        return min_per_gpu_memory
 
     def load_model(self):
         logger.info(
-            f"
-            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
+        if torch.cuda.get_device_capability()[0] < 8:
+            logger.info(
+                "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
+            )
+            self.server_args.dtype = "float16"
+            if torch.cuda.get_device_capability()[1] < 5:
+                raise RuntimeError("SGLang only supports sm75 and above.")
 
         monkey_patch_vllm_dummy_weight_loader()
-        device_config = DeviceConfig()
-        load_config = LoadConfig(load_format=self.server_args.load_format)
-        vllm_model_config = VllmModelConfig(
+        self.device_config = DeviceConfig()
+        self.load_config = LoadConfig(load_format=self.server_args.load_format)
+        self.vllm_model_config = VllmModelConfig(
             model=self.server_args.model_path,
             quantization=self.server_args.quantization,
             tokenizer=None,
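The hunk above adds two guards: a memory-balance check across tensor-parallel ranks and a compute-capability fallback that forces float16 on pre-sm80 GPUs. Below is a minimal, hedged sketch of the capability logic in isolation; only the `torch.cuda` calls are real, the helper function and its CPU fallback are illustrative.

```python
import torch


def pick_dtype_for_gpu(requested: str = "auto") -> str:
    """Sketch of the capability fallback added in load_model():
    bfloat16 needs sm80+, and anything below sm75 is rejected."""
    if not torch.cuda.is_available():
        return "float32"  # CPU-only fallback, not part of the original diff
    major, minor = torch.cuda.get_device_capability()
    if major < 8:
        if major < 7 or (major == 7 and minor < 5):
            raise RuntimeError("SGLang only supports sm75 and above.")
        return "float16"  # no bfloat16 support below sm80
    return requested


print(pick_dtype_for_gpu())
```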
@@ -159,43 +176,132 @@ class ModelRunner:
             skip_tokenizer_init=True,
         )
 
+        # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
+        # Drop this after Sept, 2024.
         if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
-            # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
             self.model_config.hf_config.num_key_value_heads = 8
-            vllm_model_config.hf_config.num_key_value_heads = 8
+            self.vllm_model_config.hf_config.num_key_value_heads = 8
             monkey_patch_vllm_qvk_linear_loader()
 
-        self.dtype = vllm_model_config.dtype
+        self.dtype = self.vllm_model_config.dtype
         if self.model_config.model_overide_args is not None:
-            vllm_model_config.hf_config.update(
+            self.vllm_model_config.hf_config.update(
+                self.model_config.model_overide_args
+            )
 
         self.model = get_model(
-            model_config=vllm_model_config,
-
-
-            lora_config=None,
-            multimodal_config=None,
+            model_config=self.vllm_model_config,
+            load_config=self.load_config,
+            device_config=self.device_config,
             parallel_config=None,
             scheduler_config=None,
+            lora_config=None,
             cache_config=None,
         )
         self.sliding_window_size = (
-            self.model.
-            if hasattr(self.model, "
+            self.model.get_attention_sliding_window_size()
+            if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
         self.is_generation = is_generation_model(
-            self.model_config.hf_config.architectures
+            self.model_config.hf_config.architectures, self.server_args.is_embedding
         )
 
         logger.info(
-            f"
+            f"Load weight end. "
             f"type={type(self.model).__name__}, "
             f"dtype={self.dtype}, "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
 
-    def
+    def update_weights(self, model_path: str, load_format: str):
+        """Update weights in-place."""
+        from vllm.model_executor.model_loader.loader import (
+            DefaultModelLoader,
+            device_loading_context,
+            get_model_loader,
+        )
+        from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+
+        logger.info(
+            f"Update weights begin. "
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+        )
+
+        target_device = torch.device(self.device_config.device)
+
+        try:
+            # TODO: Use a better method to check this
+            vllm_model_config = VllmModelConfig(
+                model=model_path,
+                quantization=self.server_args.quantization,
+                tokenizer=None,
+                tokenizer_mode=None,
+                trust_remote_code=self.server_args.trust_remote_code,
+                dtype=self.server_args.dtype,
+                seed=42,
+                skip_tokenizer_init=True,
+            )
+        except Exception as e:
+            logger.error(f"Failed to load model config: {e}")
+            return False, "Failed to update model weights"
+
+        load_config = LoadConfig(load_format=load_format)
+
+        # Only support vllm DefaultModelLoader for now
+        loader = get_model_loader(load_config)
+        if not isinstance(loader, DefaultModelLoader):
+            logger.error("Failed to get weights iterator: Unsupported loader")
+            return False, "Failed to update model weights"
+
+        def get_weight_iter(config):
+            iter = loader._get_weights_iterator(
+                config.model,
+                config.revision,
+                fall_back_to_pt=getattr(
+                    self.model, "fall_back_to_pt_during_load", True
+                ),
+            )
+            return iter
+
+        def model_load_weights(model, iter):
+            model.load_weights(iter)
+            for _, module in self.model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    with device_loading_context(module, target_device):
+                        quant_method.process_weights_after_loading(module)
+            return model
+
+        with set_default_torch_dtype(vllm_model_config.dtype):
+            try:
+                iter = get_weight_iter(vllm_model_config)
+            except Exception as e:
+                message = f"Failed to get weights iterator: {e}"
+                logger.error(message)
+                return False, message
+            try:
+                model = model_load_weights(self.model, iter)
+            except Exception as e:
+                message = f"Failed to update weights: {e}. \n Rolling back to original weights"
+                logger.error(message)
+                del iter
+                gc.collect()
+                iter = get_weight_iter(self.vllm_model_config)
+                self.model = model_load_weights(self.model, iter)
+                return False, message
+
+        self.model = model
+        self.server_args.model_path = model_path
+        self.server_args.load_format = load_format
+        self.vllm_model_config = vllm_model_config
+        self.load_config = load_config
+        self.model_config.path = model_path
+
+        logger.info("Update weights end.")
+        return True, "Succeeded to update model weights"
+
+    def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
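The new `update_weights()` returns a `(success, message)` tuple and rolls the runner back to its previous weights on failure. A hedged sketch of how a caller might drive it is below; `runner` is assumed to be an already-initialized `ModelRunner`, and only the call pattern and return contract are taken from the diff above.

```python
# Hedged sketch: driving ModelRunner.update_weights() added in this release.
def swap_weights(runner, model_path: str, load_format: str = "auto"):
    success, message = runner.update_weights(model_path, load_format)
    if not success:
        # On failure the runner reloads its original weights before returning.
        raise RuntimeError(f"Weight update failed: {message}")
    return message
```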
@@ -206,7 +312,7 @@ class ModelRunner:
             cell_size = (
                 (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                 * self.model_config.num_hidden_layers
-                * torch._utils._element_size(self.
+                * torch._utils._element_size(self.kv_cache_dtype)
             )
         else:
             cell_size = (
@@ -214,7 +320,7 @@ class ModelRunner:
                 * self.model_config.head_dim
                 * self.model_config.num_hidden_layers
                 * 2
-                * torch._utils._element_size(self.
+                * torch._utils._element_size(self.kv_cache_dtype)
             )
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
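For the MHA branch above, the per-token KV-cache cell size is kv_heads * head_dim * num_layers * 2 (K and V) * element_size, now computed from the resolved `kv_cache_dtype` instead of the model dtype. A worked example with illustrative numbers (not taken from the diff):

```python
import torch

# Illustrative per-rank numbers: 8 KV heads, head_dim 128, 32 layers, fp16 cache.
kv_heads, head_dim, num_layers = 8, 128, 32
elem = torch._utils._element_size(torch.float16)  # 2 bytes

cell_size = kv_heads * head_dim * num_layers * 2 * elem  # K and V planes
print(cell_size)           # 131072 bytes per token
print(cell_size / 2**20)   # 0.125 MiB per token
```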
@@ -223,12 +329,30 @@ class ModelRunner:
         return max_num_token
 
     def init_memory_pool(
-        self,
+        self,
+        total_gpu_memory: int,
+        max_num_reqs: int = None,
+        max_total_tokens: int = None,
     ):
+        if self.server_args.kv_cache_dtype == "auto":
+            self.kv_cache_dtype = self.dtype
+        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
+            if self.server_args.disable_flashinfer or self.server_args.enable_mla:
+                logger.warning(
+                    "FP8 KV cache is not supported for Triton kernel now, using auto kv cache dtype"
+                )
+                self.kv_cache_dtype = self.dtype
+            else:
+                self.kv_cache_dtype = torch.float8_e5m2
+        else:
+            raise ValueError(
+                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
+            )
+
         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
         if max_total_tokens is not None:
             if max_total_tokens > self.max_total_num_tokens:
-
+                logging.warning(
                     f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                     f"{self.max_total_num_tokens}. "
                     f"Use the profiled value instead."
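The new `init_memory_pool()` resolves the KV-cache dtype before sizing the pool: `"auto"` follows the model dtype, and `"fp8_e5m2"` falls back to the model dtype when flashinfer is disabled or MLA is enabled. The standalone function below is an illustrative sketch of that decision table; the flag names mirror the diff, the function itself is not part of sglang.

```python
import torch


def resolve_kv_cache_dtype(kv_cache_dtype: str, model_dtype: torch.dtype,
                           disable_flashinfer: bool, enable_mla: bool) -> torch.dtype:
    """Sketch of the kv_cache_dtype resolution added in init_memory_pool()."""
    if kv_cache_dtype == "auto":
        return model_dtype
    if kv_cache_dtype == "fp8_e5m2":
        if disable_flashinfer or enable_mla:
            # The Triton / MLA paths do not support an FP8 KV cache yet.
            return model_dtype
        return torch.float8_e5m2
    raise ValueError(f"Unsupported kv_cache_dtype: {kv_cache_dtype}.")


print(resolve_kv_cache_dtype("fp8_e5m2", torch.float16,
                             disable_flashinfer=False, enable_mla=False))
```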
@@ -261,7 +385,7 @@ class ModelRunner:
         ):
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
-                dtype=self.
+                dtype=self.kv_cache_dtype,
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
@@ -272,13 +396,13 @@ class ModelRunner:
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
-                dtype=self.
+                dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(self.tp_size),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
             )
         logger.info(
-            f"
+            f"Memory pool end. "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
 
@@ -292,6 +416,7 @@ class ModelRunner:
         return c
 
     def init_flashinfer(self):
+        """Init flashinfer attention kernel wrappers."""
         if self.server_args.disable_flashinfer:
             assert (
                 self.sliding_window_size is None
@@ -352,20 +477,29 @@ class ModelRunner:
             )
 
     def init_cuda_graphs(self):
+        """Capture cuda graphs."""
+        if not self.is_generation:
+            # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
+            return
+
         from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 
         if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
             self.cuda_graph_runner = None
             return
 
-        logger.info(
-
-
-
+        logger.info("Capture cuda graph begin. This can take up to several minutes.")
+
+        if self.server_args.disable_cuda_graph_padding:
+            batch_size_list = list(range(1, 32)) + [64, 128]
+        else:
+            batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+
         self.cuda_graph_runner = CudaGraphRunner(
             self,
             max_batch_size_to_capture=max(batch_size_list),
             use_torch_compile=self.server_args.enable_torch_compile,
+            disable_padding=self.server_args.disable_cuda_graph_padding,
         )
         try:
             self.cuda_graph_runner.capture(batch_size_list)
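`init_cuda_graphs()` now picks which batch sizes to capture based on whether graph padding is disabled. Evaluating the two list expressions from the diff shows the difference in coverage; nothing here is assumed beyond those expressions.

```python
# The two capture lists from init_cuda_graphs(), evaluated.
with_padding_disabled = list(range(1, 32)) + [64, 128]            # 1..31, 64, 128
with_padding_enabled = [1, 2, 4] + [i * 8 for i in range(1, 21)]  # 1, 2, 4, 8..160

print(len(with_padding_disabled), max(with_padding_disabled))  # 33 128
print(len(with_padding_enabled), max(with_padding_enabled))    # 23 160
```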
@@ -401,9 +535,18 @@ class ModelRunner:
             batch,
             forward_mode=ForwardMode.EXTEND,
         )
-
-
-
+        if self.is_generation:
+            return self.model.forward(
+                batch.input_ids, input_metadata.positions, input_metadata
+            )
+        else:
+            # Only embedding models have get_embedding parameter
+            return self.model.forward(
+                batch.input_ids,
+                input_metadata.positions,
+                input_metadata,
+                get_embedding=True,
+            )
 
     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
@@ -477,4 +620,4 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
 
 
 # Monkey patch model loader
-setattr(ModelRegistry, "
+setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
sglang/srt/models/gemma2.py
CHANGED
@@ -25,7 +25,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
 
 # FIXME: temporary solution, remove after next vllm release
 from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.activation import GeluAndMul
 
 # from vllm.model_executor.layers.layernorm import GemmaRMSNorm
 from vllm.model_executor.layers.linear import (
@@ -39,6 +38,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -46,7 +46,7 @@ from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 # Aligned with HF's implementation, using sliding window inclusive with the last token
 # SGLang assumes exclusive
-def
+def get_attention_sliding_window_size(config):
     return config.sliding_window - 1
 
 
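Gemma 2's HF config treats the sliding window as inclusive of the last token, while SGLang's attention treats it as exclusive, hence the `- 1`. A tiny illustration follows; the 4096 window is an assumed example value, not read from the diff.

```python
from types import SimpleNamespace


def get_attention_sliding_window_size(config):
    # Same conversion as the gemma2.py helper above: HF inclusive -> SGLang exclusive.
    return config.sliding_window - 1


config = SimpleNamespace(sliding_window=4096)  # assumed example value
print(get_attention_sliding_window_size(config))  # 4095
```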
@@ -135,7 +135,7 @@ class Gemma2MLP(nn.Module):
                 "function. Please set `hidden_act` and `hidden_activation` to "
                 "`gelu_pytorch_tanh`."
             )
-        self.act_fn = GeluAndMul(
+        self.act_fn = GeluAndMul()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         gate_up, _ = self.gate_up_proj(x)
@@ -213,7 +213,11 @@ class Gemma2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_idx,
-            sliding_window_size=
+            sliding_window_size=(
+                get_attention_sliding_window_size(config)
+                if use_sliding_window
+                else None
+            ),
             logit_cap=self.config.attn_logit_softcapping,
         )
 
@@ -406,8 +410,8 @@ class Gemma2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
 
-    def
-        return
+    def get_attention_sliding_window_size(self):
+        return get_attention_sliding_window_size(self.config)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/gpt_bigcode.py
CHANGED
@@ -23,7 +23,6 @@ from torch import nn
 from transformers import GPTBigCodeConfig
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -33,6 +32,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
sglang/srt/models/grok.py
CHANGED
@@ -300,6 +300,9 @@ class Grok1ModelForCausalLM(nn.Module):
 
         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
+
+        self.use_presharded_weights = True
+
         warnings.filterwarnings("ignore", category=FutureWarning)
 
     def forward(
@@ -355,6 +358,13 @@ class Grok1ModelForCausalLM(nn.Module):
                     continue
                 name = name.replace(weight_name, param_name)
 
+                if self.use_presharded_weights:
+                    extra_kwargs = {
+                        "use_presharded_weights": self.use_presharded_weights
+                    }
+                else:
+                    extra_kwargs = {}
+
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(
@@ -363,7 +373,7 @@ class Grok1ModelForCausalLM(nn.Module):
                     weight_name,
                     shard_id=shard_id,
                     expert_id=expert_id,
-
+                    **extra_kwargs,
                 )
                 break
             else:
sglang/srt/models/llama_embedding.py
CHANGED
@@ -29,7 +29,11 @@ class LlamaEmbeddingModel(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
+        get_embedding: bool = True,
     ) -> EmbeddingPoolerOutput:
+        assert (
+            get_embedding
+        ), "LlamaEmbeddingModel / MistralModel is only used for embedding"
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         return self.pooler(hidden_states, input_metadata)
 
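Together with the `forward_extend` hunk in model_runner.py above, this gives every model one calling convention: the runner adds `get_embedding=True` only for embedding models, and `LlamaEmbeddingModel` asserts it is only used that way. The sketch below restates that dispatch from the caller's side; `model` and `is_generation` are stand-ins, not sglang objects.

```python
# Hedged sketch of the unified calling convention between the runner and models.
def run_extend(model, input_ids, positions, input_metadata, is_generation: bool):
    if is_generation:
        return model.forward(input_ids, positions, input_metadata)
    # Only embedding models accept the get_embedding parameter.
    return model.forward(input_ids, positions, input_metadata, get_embedding=True)
```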