sglang 0.2.14.post1__py3-none-any.whl → 0.2.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +2 -0
- sglang/bench_latency.py +39 -28
- sglang/lang/interpreter.py +3 -0
- sglang/lang/ir.py +5 -0
- sglang/launch_server_llavavid.py +26 -0
- sglang/srt/configs/__init__.py +5 -0
- sglang/srt/configs/exaone.py +195 -0
- sglang/srt/constrained/fsm_cache.py +1 -1
- sglang/srt/conversation.py +24 -2
- sglang/srt/hf_transformers_utils.py +11 -160
- sglang/srt/layers/activation.py +10 -4
- sglang/srt/layers/extend_attention.py +13 -8
- sglang/srt/layers/layernorm.py +47 -1
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/sampler.py +69 -16
- sglang/srt/managers/controller_multi.py +5 -5
- sglang/srt/managers/controller_single.py +5 -5
- sglang/srt/managers/io_struct.py +11 -5
- sglang/srt/managers/schedule_batch.py +25 -13
- sglang/srt/managers/tokenizer_manager.py +76 -63
- sglang/srt/managers/tp_worker.py +47 -36
- sglang/srt/model_config.py +3 -3
- sglang/srt/model_executor/cuda_graph_runner.py +24 -9
- sglang/srt/model_executor/forward_batch_info.py +78 -43
- sglang/srt/model_executor/model_runner.py +29 -18
- sglang/srt/models/chatglm.py +5 -13
- sglang/srt/models/commandr.py +5 -1
- sglang/srt/models/dbrx.py +5 -1
- sglang/srt/models/deepseek.py +5 -1
- sglang/srt/models/deepseek_v2.py +57 -25
- sglang/srt/models/exaone.py +399 -0
- sglang/srt/models/gemma.py +7 -3
- sglang/srt/models/gemma2.py +6 -52
- sglang/srt/models/gpt_bigcode.py +5 -1
- sglang/srt/models/grok.py +14 -4
- sglang/srt/models/internlm2.py +5 -1
- sglang/srt/models/llama2.py +10 -7
- sglang/srt/models/llama_classification.py +2 -6
- sglang/srt/models/llama_embedding.py +3 -4
- sglang/srt/models/llava.py +69 -91
- sglang/srt/models/llavavid.py +40 -86
- sglang/srt/models/minicpm.py +5 -1
- sglang/srt/models/mixtral.py +6 -2
- sglang/srt/models/mixtral_quant.py +5 -1
- sglang/srt/models/qwen.py +5 -2
- sglang/srt/models/qwen2.py +9 -6
- sglang/srt/models/qwen2_moe.py +12 -33
- sglang/srt/models/stablelm.py +5 -1
- sglang/srt/models/yivl.py +2 -7
- sglang/srt/openai_api/adapter.py +16 -1
- sglang/srt/openai_api/protocol.py +5 -5
- sglang/srt/sampling/sampling_batch_info.py +79 -6
- sglang/srt/server.py +9 -9
- sglang/srt/utils.py +18 -36
- sglang/test/runners.py +2 -2
- sglang/test/test_layernorm.py +53 -1
- sglang/version.py +1 -1
- {sglang-0.2.14.post1.dist-info → sglang-0.2.15.dist-info}/METADATA +8 -8
- sglang-0.2.15.dist-info/RECORD +118 -0
- sglang-0.2.14.post1.dist-info/RECORD +0 -114
- {sglang-0.2.14.post1.dist-info → sglang-0.2.15.dist-info}/LICENSE +0 -0
- {sglang-0.2.14.post1.dist-info → sglang-0.2.15.dist-info}/WHEEL +0 -0
- {sglang-0.2.14.post1.dist-info → sglang-0.2.15.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -26,16 +26,18 @@ from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
 from sglang.srt.layers.logits_processor import (
-    LogitProcessorOutput,
     LogitsMetadata,
     LogitsProcessor,
+    LogitsProcessorOutput,
 )
+from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
     InputMetadata,
     update_flashinfer_indices,
 )
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import monkey_patch_vllm_all_gather
 
 
@@ -144,6 +146,10 @@ class CudaGraphRunner:
             self.flashinfer_kv_indices.clone(),
         ]
 
+        # Sampling inputs
+        vocab_size = model_runner.model_config.vocab_size
+        self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
+
         self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
 
         if use_torch_compile:
@@ -235,6 +241,7 @@ class CudaGraphRunner:
        def run_once():
            input_metadata = InputMetadata(
                forward_mode=ForwardMode.DECODE,
+               sampling_info=self.sampling_info[:bs],
                batch_size=bs,
                req_pool_indices=req_pool_indices,
                seq_lens=seq_lens,
@@ -299,27 +306,35 @@ class CudaGraphRunner:
            self.flashinfer_handlers[bs],
        )
 
+       # Sampling inputs
+       self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
+
        # Replay
        torch.cuda.synchronize()
        self.graphs[bs].replay()
        torch.cuda.synchronize()
-       output = self.output_buffers[bs]
+       sample_output, logits_output = self.output_buffers[bs]
 
        # Unpad
        if bs != raw_bs:
-           output = LogitProcessorOutput(
-               next_token_logits=output.next_token_logits[:raw_bs],
+           logits_output = LogitsProcessorOutput(
+               next_token_logits=logits_output.next_token_logits[:raw_bs],
                next_token_logprobs=None,
                normalized_prompt_logprobs=None,
                input_token_logprobs=None,
                input_top_logprobs=None,
                output_top_logprobs=None,
            )
+           sample_output = SampleOutput(
+               sample_output.success[:raw_bs],
+               sample_output.probs[:raw_bs],
+               sample_output.batch_next_token_ids[:raw_bs],
+           )
 
        # Extract logprobs
        if batch.return_logprob:
-           output.next_token_logprobs = torch.nn.functional.log_softmax(
-               output.next_token_logits, dim=-1
+           logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
+               logits_output.next_token_logits, dim=-1
            )
            return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
            if return_top_logprob:
@@ -327,8 +342,8 @@ class CudaGraphRunner:
                    forward_mode=ForwardMode.DECODE,
                    top_logprobs_nums=batch.top_logprobs_nums,
                )
-               output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
-                   output.next_token_logprobs, logits_metadata
+               logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
+                   logits_output.next_token_logprobs, logits_metadata
                )[1]
 
-       return output
+       return sample_output, logits_output
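
Note on the replay path above: graphs are captured at fixed batch sizes, so a live batch replays on the smallest captured graph that fits, and the padded outputs are sliced back to the real batch size. A minimal sketch of the pick-and-unpad pattern (illustrative only; the names below are made up, not the package's API):

    import torch

    CAPTURED_BS = [1, 2, 4, 8, 16, 24, 32]  # batch sizes with a captured graph

    def pick_bs(raw_bs: int) -> int:
        # Smallest captured batch size that can hold the live batch.
        return next(bs for bs in CAPTURED_BS if bs >= raw_bs)

    def unpad(out_buf: torch.Tensor, raw_bs: int) -> torch.Tensor:
        # Graph buffers are sized for the captured batch; only the first
        # raw_bs rows correspond to real requests.
        return out_buf[:raw_bs]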
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,16 +18,19 @@ limitations under the License.
 """ModelRunner runs the forward passes of the models."""
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List
 
 import numpy as np
 import torch
+import triton
+import triton.language as tl
 
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
 
 class ForwardMode(IntEnum):
@@ -42,6 +47,7 @@ class InputMetadata:
     """Store all inforamtion of a forward pass."""
 
     forward_mode: ForwardMode
+    sampling_info: SamplingBatchInfo
     batch_size: int
     req_pool_indices: torch.Tensor
     seq_lens: torch.Tensor
@@ -58,6 +64,7 @@ class InputMetadata:
 
     # For extend
     extend_seq_lens: torch.Tensor = None
+    extend_prefix_lens: torch.Tensor = None
     extend_start_loc: torch.Tensor = None
     extend_no_prefix: bool = None
 
@@ -69,8 +76,8 @@ class InputMetadata:
 
     # For multimodal
     pixel_values: List[torch.Tensor] = None
-    image_sizes: List[List[int]] = None
-    image_offsets: List[int] = None
+    image_sizes: List[List[List[int]]] = None
+    image_offsets: List[List[int]] = None
 
     # Trition attention backend
     triton_max_seq_len: int = 0
@@ -87,20 +94,8 @@ class InputMetadata:
    def init_multimuldal_info(self, batch: ScheduleBatch):
        reqs = batch.reqs
        self.pixel_values = [r.pixel_values for r in reqs]
-       self.image_sizes = [r.image_size for r in reqs]
-       self.image_offsets = []
-       for r in reqs:
-           if isinstance(r.image_offset, list):
-               self.image_offsets.append(
-                   [
-                       (image_offset - len(r.prefix_indices))
-                       for image_offset in r.image_offset
-                   ]
-               )
-           elif isinstance(r.image_offset, int):
-               self.image_offsets.append(r.image_offset - len(r.prefix_indices))
-           elif r.image_offset is None:
-               self.image_offsets.append(0)
+       self.image_sizes = [r.image_sizes for r in reqs]
+       self.image_offsets = [r.image_offsets for r in reqs]
 
    def compute_positions(self, batch: ScheduleBatch):
        position_ids_offsets = batch.position_ids_offsets
@@ -153,6 +148,7 @@ class InputMetadata:
            for i, r in enumerate(batch.reqs)
        ]
        self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
+       self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
        self.extend_start_loc = torch.zeros_like(self.seq_lens)
        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
        self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
@@ -179,6 +175,7 @@ class InputMetadata:
    ):
        ret = cls(
            forward_mode=forward_mode,
+           sampling_info=batch.sampling_info,
            batch_size=batch.batch_size(),
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
@@ -189,6 +186,8 @@ class InputMetadata:
            top_logprobs_nums=batch.top_logprobs_nums,
        )
 
+       ret.sampling_info.prepare_penalties()
+
        ret.compute_positions(batch)
 
        ret.compute_extend_infos(batch)
@@ -238,10 +237,10 @@ class InputMetadata:
        prefix_lens_cpu,
        flashinfer_use_ragged,
    ):
-       if self.forward_mode != ForwardMode.DECODE:
-           prefix_lens = torch.tensor(prefix_lens_cpu, device="cuda")
-       else:
+       if self.forward_mode == ForwardMode.DECODE:
            prefix_lens = None
+       else:
+           prefix_lens = self.extend_prefix_lens
 
        update_flashinfer_indices(
            self.forward_mode,
@@ -265,6 +264,42 @@ class InputMetadata:
    )
 
 
+@triton.jit
+def create_flashinfer_kv_indices_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices_ptr,
+    page_kernel_lens_ptr,
+    kv_indptr,
+    kv_start_idx,
+    max_context_len,
+    kv_indices_ptr,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(axis=0)
+    req_pool_index = tl.load(req_pool_indices_ptr + pid)
+    kv_indices_offset = tl.load(kv_indptr + pid)
+
+    kv_start = 0
+    kv_end = 0
+    if kv_start_idx:
+        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
+        kv_end = kv_start
+    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
+
+    req_to_token_ptr += req_pool_index * max_context_len
+    kv_indices_ptr += kv_indices_offset
+
+    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
+    st_offset = tl.arange(0, BLOCK_SIZE)
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for _ in range(num_loop):
+        mask = ld_offset < kv_end
+        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
+        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
+        ld_offset += BLOCK_SIZE
+        st_offset += BLOCK_SIZE
+
+
 def update_flashinfer_indices(
     forward_mode,
     model_runner,
@@ -288,17 +323,18 @@ def update_flashinfer_indices(
 
        kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-       [10 removed lines not rendered in the source diff view]
-       )
+
+       kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+       create_flashinfer_kv_indices_triton[(batch_size,)](
+           model_runner.req_to_token_pool.req_to_token,
+           req_pool_indices,
+           paged_kernel_lens,
+           kv_indptr,
+           None,
+           model_runner.req_to_token_pool.req_to_token.size(1),
+           kv_indices,
+       )
+
        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
 
        if forward_mode == ForwardMode.DECODE:
@@ -368,18 +404,17 @@ def update_flashinfer_indices(
 
        kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-       [11 removed lines not rendered in the source diff view]
-       ).contiguous()
+
+       kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+       create_flashinfer_kv_indices_triton[(batch_size,)](
+           model_runner.req_to_token_pool.req_to_token,
+           req_pool_indices,
+           paged_kernel_lens,
+           kv_indptr,
+           kv_start_idx,
+           model_runner.req_to_token_pool.req_to_token.size(1),
+           kv_indices,
+       )
 
        if forward_mode == ForwardMode.DECODE:
            # CUDA graph uses different flashinfer_decode_wrapper
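
Note on the new `create_flashinfer_kv_indices_triton` kernel above: it builds FlashInfer's flat kv_indices array on the GPU, one Triton program per request, copying that request's slots from the req_to_token table in 512-element blocks. A plain-PyTorch reference of the same computation (an illustrative sketch, not part of the package):

    import torch

    def create_kv_indices_reference(
        req_to_token,        # [max_batch, max_context_len] token-slot table
        req_pool_indices,    # row in req_to_token used by each request
        paged_kernel_lens,   # number of KV entries per request
        kv_indptr,           # [batch+1] exclusive prefix sum of the lengths
        kv_start_idx=None,   # optional per-request start offset
    ):
        kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
        for i in range(req_pool_indices.numel()):
            row = int(req_pool_indices[i])
            start = int(kv_start_idx[i]) if kv_start_idx is not None else 0
            end = start + int(paged_kernel_lens[i])
            out = int(kv_indptr[i])
            # Copy this request's token slots into its slice of the flat array.
            kv_indices[out : out + end - start] = req_to_token[row, start:end]
        return kv_indices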
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -21,7 +21,7 @@ import importlib.resources
 import logging
 import pkgutil
 from functools import lru_cache
-from typing import Optional, Type
+from typing import Optional, Tuple, Type
 
 import torch
 import torch.nn as nn
@@ -44,13 +44,15 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
 )
-from sglang.srt.model_config import AttentionArch
+from sglang.srt.model_config import AttentionArch, ModelConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -69,7 +71,7 @@ logger = logging.getLogger(__name__)
 class ModelRunner:
    def __init__(
        self,
-       model_config,
+       model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
@@ -85,7 +87,9 @@ class ModelRunner:
        self.tp_size = tp_size
        self.nccl_port = nccl_port
        self.server_args = server_args
-       self.is_multimodal_model = is_multimodal_model(self.model_config)
+       self.is_multimodal_model = is_multimodal_model(
+           self.model_config.hf_config.architectures
+       )
        global_server_args_dict.update(
            {
                "disable_flashinfer": server_args.disable_flashinfer,
@@ -95,6 +99,13 @@ class ModelRunner:
            }
        )
 
+       if self.is_multimodal_model:
+           logger.info(
+               "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
+           )
+           server_args.chunked_prefill_size = None
+           server_args.mem_fraction_static *= 0.95
+
        min_per_gpu_memory = self.init_torch_distributed()
        self.load_model()
        self.init_memory_pool(
@@ -184,9 +195,9 @@ class ModelRunner:
            monkey_patch_vllm_qvk_linear_loader()
 
        self.dtype = self.vllm_model_config.dtype
-       if self.model_config.model_overide_args is not None:
+       if self.model_config.model_override_args is not None:
            self.vllm_model_config.hf_config.update(
-               self.model_config.model_overide_args
+               self.model_config.model_override_args
            )
 
        self.model = get_model(
@@ -337,13 +348,7 @@ class ModelRunner:
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
-           if self.server_args.disable_flashinfer:
-               logger.warning(
-                   "FP8 KV cache is not supported for Triton kernel now, using auto kv cache dtype"
-               )
-               self.kv_cache_dtype = self.dtype
-           else:
-               self.kv_cache_dtype = torch.float8_e5m2
+           self.kv_cache_dtype = torch.float8_e5m2
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
@@ -507,15 +512,19 @@ class ModelRunner:
            raise Exception(
                f"Capture cuda graph failed: {e}\n"
                "Possible solutions:\n"
-               "1. disable
-               "2.
-               "3.
+               "1. disable cuda graph by --disable-cuda-graph\n"
+               "2. set --mem-fraction-static to a smaller value\n"
+               "3. disable torch compile by not using --enable-torch-compile\n"
                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
            )
 
    @torch.inference_mode()
    def forward_decode(self, batch: ScheduleBatch):
-       if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
+       if (
+           self.cuda_graph_runner
+           and self.cuda_graph_runner.can_run(len(batch.reqs))
+           and not batch.sampling_info.has_bias()
+       ):
            return self.cuda_graph_runner.replay(batch)
 
        input_metadata = InputMetadata.from_schedule_batch(
@@ -564,7 +573,9 @@ class ModelRunner:
            input_metadata.image_offsets,
        )
 
-   def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
+   def forward(
+       self, batch: ScheduleBatch, forward_mode: ForwardMode
+   ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
            return self.forward_extend_multi_modal(batch)
        elif forward_mode == ForwardMode.DECODE:
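
Note: ModelRunner.forward is now annotated to return a (SampleOutput, LogitsProcessorOutput) pair, so callers get sampled token ids and raw logits together. A hypothetical caller, sketched with stand-in variable names (not the package's code):

    # sample_output carries success flags, probs, and batch_next_token_ids;
    # logits_output still carries next_token_logits for logprob extraction.
    sample_output, logits_output = model_runner.forward(batch, ForwardMode.DECODE)
    next_token_ids = sample_output.batch_next_token_ids
    logits = logits_output.next_token_logits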
sglang/srt/models/chatglm.py
CHANGED
@@ -17,7 +17,7 @@ limitations under the License.
 # Adapted from
 # https://github.com/THUDM/ChatGLM2-6B
 """Inference-only ChatGLM model compatible with THUDM weights."""
-from typing import Iterable,
+from typing import Iterable, Optional, Tuple
 
 import torch
 from torch import nn
@@ -31,20 +31,18 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs import ChatGLMConfig
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 LoraConfig = None
@@ -383,17 +381,11 @@ class ChatGLMForCausalLM(nn.Module):
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, input_metadata)
-       return self.logits_processor(
+       logits_output = self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )
-
-   def sample(
-       self,
-       logits: torch.Tensor,
-       sampling_metadata: SamplingMetadata,
-   ) -> Optional[SamplerOutput]:
-       next_tokens = self.sampler(logits, sampling_metadata)
-       return next_tokens
+       sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+       return sample_output, logits_output
 
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
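
The same refactor repeats mechanically in the model files that follow (commandr, dbrx, deepseek, and the other models in the summary above): drop vLLM's Sampler and the per-model sample() method, instantiate sglang's Sampler, and return the (sample_output, logits_output) pair from forward(). A generic skeleton of the pattern (the class name and backbone are placeholders, not any one model's code):

    import torch
    from torch import nn

    from sglang.srt.layers.logits_processor import LogitsProcessor
    from sglang.srt.layers.sampler import Sampler

    class SomeModelForCausalLM(nn.Module):  # placeholder name
        def __init__(self, config, quant_config=None):
            super().__init__()
            self.model = ...    # model-specific backbone (placeholder)
            self.lm_head = ...  # model-specific LM head (placeholder)
            self.logits_processor = LogitsProcessor(config)
            self.sampler = Sampler()  # sglang's sampler, not vLLM's

        @torch.no_grad()
        def forward(self, input_ids, positions, input_metadata):
            hidden_states = self.model(input_ids, positions, input_metadata)
            logits_output = self.logits_processor(
                input_ids, hidden_states, self.lm_head.weight, input_metadata
            )
            sample_output = self.sampler(logits_output, input_metadata.sampling_info)
            return sample_output, logits_output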
sglang/srt/models/commandr.py
CHANGED
@@ -64,6 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -326,6 +327,7 @@ class CohereForCausalLM(nn.Module):
        self.config = config
        self.quant_config = quant_config
        self.logits_processor = LogitsProcessor(config)
+       self.sampler = Sampler()
        self.model = CohereModel(config, quant_config)
 
    @torch.no_grad()
@@ -340,9 +342,11 @@ class CohereForCausalLM(nn.Module):
            positions,
            input_metadata,
        )
-       return self.logits_processor(
+       logits_output = self.logits_processor(
            input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
        )
+       sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+       return sample_output, logits_output
 
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
sglang/srt/models/dbrx.py
CHANGED
@@ -45,6 +45,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -382,6 +383,7 @@ class DbrxForCausalLM(nn.Module):
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
        )
        self.logits_processor = LogitsProcessor(config)
+       self.sampler = Sampler()
 
    @torch.no_grad()
    def forward(
@@ -391,9 +393,11 @@ class DbrxForCausalLM(nn.Module):
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, input_metadata)
-       return self.logits_processor(
+       logits_output = self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )
+       sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+       return sample_output, logits_output
 
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        expert_params_mapping = [
sglang/srt/models/deepseek.py
CHANGED
@@ -46,6 +46,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -385,6 +386,7 @@ class DeepseekForCausalLM(nn.Module):
            config.vocab_size, config.hidden_size, quant_config=quant_config
        )
        self.logits_processor = LogitsProcessor(config)
+       self.sampler = Sampler()
 
    @torch.no_grad()
    def forward(
@@ -394,9 +396,11 @@ class DeepseekForCausalLM(nn.Module):
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, input_metadata)
-       return self.logits_processor(
+       logits_output = self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )
+       sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+       return sample_output, logits_output
 
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [