sglang 0.2.10__py3-none-any.whl → 0.2.11__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only.
- sglang/__init__.py +8 -0
- sglang/api.py +10 -2
- sglang/bench_latency.py +145 -36
- sglang/check_env.py +24 -2
- sglang/global_config.py +0 -1
- sglang/lang/backend/base_backend.py +3 -1
- sglang/lang/backend/openai.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +46 -29
- sglang/lang/choices.py +164 -0
- sglang/lang/interpreter.py +6 -13
- sglang/lang/ir.py +11 -2
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/radix_attention.py +2 -5
- sglang/srt/managers/schedule_batch.py +95 -324
- sglang/srt/managers/tokenizer_manager.py +6 -3
- sglang/srt/managers/tp_worker.py +20 -22
- sglang/srt/mem_cache/memory_pool.py +9 -14
- sglang/srt/model_executor/cuda_graph_runner.py +3 -3
- sglang/srt/model_executor/forward_batch_info.py +256 -0
- sglang/srt/model_executor/model_runner.py +6 -10
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +1 -1
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -1
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/internlm2.py +1 -1
- sglang/srt/models/llama2.py +1 -1
- sglang/srt/models/llama_classification.py +1 -1
- sglang/srt/models/llava.py +1 -2
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -1
- sglang/srt/models/mixtral.py +1 -1
- sglang/srt/models/mixtral_quant.py +1 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +34 -12
- sglang/srt/openai_api/protocol.py +6 -0
- sglang/srt/server.py +24 -6
- sglang/srt/server_args.py +4 -0
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/METADATA +34 -24
- {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/RECORD +52 -50
- {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/LICENSE +0 -0
- {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/WHEEL +0 -0
- {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py
CHANGED
@@ -39,13 +39,13 @@ from sglang.srt.managers.policy_scheduler import PolicyScheduler
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     BaseFinishReason,
-    Batch,
-    ForwardMode,
     Req,
+    ScheduleBatch,
 )
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.model_config import ModelConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -172,7 +172,7 @@ class ModelTpServer:
 
         # Init running status
         self.waiting_queue: List[Req] = []
-        self.running_batch: Batch = None
+        self.running_batch: ScheduleBatch = None
         self.out_pyobjs = []
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval
@@ -200,7 +200,6 @@ class ModelTpServer:
         )
         self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
-        self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
 
     def exposed_step(self, recv_reqs):
         try:
@@ -290,10 +289,10 @@ class ModelTpServer:
                 "KV cache pool leak detected!"
             )
 
-        if self.req_to_token_pool.can_use_mem_size != self.req_to_token_pool.size:
+        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
             warnings.warn(
                 "Warning: "
-                f"available req slots={self.req_to_token_pool.can_use_mem_size}, "
+                f"available req slots={len(self.req_to_token_pool.free_slots)}, "
                 f"total slots={self.req_to_token_pool.size}\n"
                 "Memory pool leak detected!"
             )
@@ -353,7 +352,7 @@ class ModelTpServer:
         )
         self.waiting_queue.append(req)
 
-    def get_new_prefill_batch(self) -> Optional[Batch]:
+    def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
         # TODO(lsyin): organize this function
         running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0
@@ -364,12 +363,13 @@ class ModelTpServer:
         # Compute matched prefix length
         for req in self.waiting_queue:
             req.input_ids = req.origin_input_ids + req.output_ids
+            try_match_ids = req.input_ids
+            if req.return_logprob:
+                try_match_ids = req.input_ids[: req.logprob_start_len]
+            # NOTE: the prefix_indices must always be aligned with last_node
             prefix_indices, last_node = self.tree_cache.match_prefix(
-                rid=req.rid,
-                key=req.input_ids,
+                rid=req.rid, key=try_match_ids
             )
-            if req.return_logprob:
-                prefix_indices = prefix_indices[: req.logprob_start_len]
             req.extend_input_len = len(req.input_ids) - len(prefix_indices)
             req.prefix_indices = prefix_indices
             req.last_node = last_node
@@ -525,7 +525,7 @@ class ModelTpServer:
         )
 
         # Return the new batch
-        new_batch = Batch.init_new(
+        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
@@ -534,7 +534,7 @@ class ModelTpServer:
         self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
         return new_batch
 
-    def forward_prefill_batch(self, batch: Batch):
+    def forward_prefill_batch(self, batch: ScheduleBatch):
         # Build batch tensors
         batch.prepare_for_extend(
             self.model_config.vocab_size, self.int_token_logit_bias
@@ -623,14 +623,13 @@ class ModelTpServer:
                 )
                 req.output_top_logprobs.append(output.output_top_logprobs[i])
 
-    def cache_filled_batch(self, batch: Batch):
-        req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
+    def cache_filled_batch(self, batch: ScheduleBatch):
         for i, req in enumerate(batch.reqs):
             new_prefix_indices, new_last_node = self.tree_cache.cache_req(
                 rid=req.rid,
                 token_ids=tuple(req.input_ids),
                 last_uncached_pos=len(req.prefix_indices),
-                req_pool_idx=req_pool_indices_cpu[i],
+                req_pool_idx=req.req_pool_idx,
                 del_in_memory_pool=False,
                 old_last_node=req.last_node,
             )
@@ -638,9 +637,9 @@ class ModelTpServer:
 
             if req is self.current_inflight_req:
                 # inflight request would get a new req idx
-                self.req_to_token_pool.free(req_pool_indices_cpu[i])
+                self.req_to_token_pool.free(req.req_pool_idx)
 
-    def forward_decode_batch(self, batch: Batch):
+    def forward_decode_batch(self, batch: ScheduleBatch):
         # Check if decode out of memory
         if not batch.check_decode_mem():
             old_ratio = self.new_token_ratio
@@ -699,7 +698,7 @@ class ModelTpServer:
 
         self.handle_finished_requests(batch)
 
-    def handle_finished_requests(self, batch: Batch):
+    def handle_finished_requests(self, batch: ScheduleBatch):
         output_rids = []
         output_vids = []
         decoded_texts = []
@@ -781,14 +780,13 @@ class ModelTpServer:
         # Remove finished reqs
         if finished_indices:
             # Update radix cache
-            req_pool_indices_cpu = batch.req_pool_indices.tolist()
            for i in finished_indices:
                req = batch.reqs[i]
                self.tree_cache.cache_req(
                    rid=req.rid,
                    token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                    last_uncached_pos=len(req.prefix_indices),
-                    req_pool_idx=req_pool_indices_cpu[i],
+                    req_pool_idx=req.req_pool_idx,
                )
 
                self.tree_cache.dec_lock_ref(req.last_node)
@@ -799,7 +797,7 @@ class ModelTpServer:
         else:
             batch.reqs = []
 
-    def filter_out_inflight(self, batch: Batch):
+    def filter_out_inflight(self, batch: ScheduleBatch):
         # TODO(lsyin): reduce the overhead, make a special version for this
         if self.current_inflight_req is None:
             return

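Most of the tp_worker.py hunks are a mechanical Batch → ScheduleBatch rename; the clearest behavioral change is the prefix-matching block in get_new_prefill_batch: when a request asks for logprobs, the radix-cache key is now truncated to logprob_start_len before the lookup instead of truncating the returned prefix_indices afterwards, which could leave prefix_indices out of step with last_node (hence the new NOTE comment). A minimal sketch of that key selection; choose_match_key is a hypothetical helper, not part of sglang:

```python
def choose_match_key(input_ids, return_logprob, logprob_start_len):
    # Tokens at or after logprob_start_len must be recomputed so their logprobs
    # can be returned, so only the earlier tokens are offered to the radix
    # cache for prefix matching.
    if return_logprob:
        return input_ids[:logprob_start_len]
    return input_ids


ids = [101, 7, 8, 9, 10]
print(choose_match_key(ids, return_logprob=True, logprob_start_len=3))   # [101, 7, 8]
print(choose_match_key(ids, return_logprob=False, logprob_start_len=3))  # [101, 7, 8, 9, 10]
```
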
sglang/srt/mem_cache/memory_pool.py
CHANGED
@@ -16,6 +16,7 @@ limitations under the License.
 """Memory pool."""
 
 import logging
+from typing import List
 
 import torch
 
@@ -27,34 +28,28 @@ class ReqToTokenPool:
 
     def __init__(self, size: int, max_context_len: int):
         self.size = size
-        self.
+        self.free_slots = list(range(size))
         self.req_to_token = torch.empty(
             (size, max_context_len), dtype=torch.int32, device="cuda"
         )
-        self.can_use_mem_size = size
 
-    def alloc(self, need_size: int):
-        if need_size > self.can_use_mem_size:
+    def alloc(self, need_size: int) -> List[int]:
+        if need_size > len(self.free_slots):
             return None
 
-        select_index =
-
-        )
-        self.mem_state[select_index] = False
-        self.can_use_mem_size -= need_size
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]
 
         return select_index
 
     def free(self, free_index):
-        self.mem_state[free_index] = True
         if isinstance(free_index, (int,)):
-            self.can_use_mem_size += 1
+            self.free_slots.append(free_index)
         else:
-            self.can_use_mem_size += len(free_index)
+            self.free_slots.extend(free_index)
 
     def clear(self):
-        self.
-        self.can_use_mem_size = len(self.mem_state)
+        self.free_slots = list(range(self.size))
 
 
 class BaseTokenToKVPool:

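The ReqToTokenPool hunk above replaces the mem_state flag tensor and its can_use_mem_size counter with a plain Python free list. A CPU-only sketch of the new bookkeeping (the class name FreeSlotPool is ours, and the req_to_token tensor the real class also owns is omitted):

```python
from typing import List, Optional


class FreeSlotPool:
    """Free-list slot allocator mirroring the new ReqToTokenPool bookkeeping."""

    def __init__(self, size: int):
        self.size = size
        self.free_slots = list(range(size))

    def alloc(self, need_size: int) -> Optional[List[int]]:
        if need_size > len(self.free_slots):
            return None  # not enough request slots left
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return select_index

    def free(self, free_index) -> None:
        # Accept either a single slot index or an iterable of indices,
        # matching the real free() signature.
        if isinstance(free_index, int):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)

    def clear(self) -> None:
        self.free_slots = list(range(self.size))


pool = FreeSlotPool(4)
slots = pool.alloc(3)      # [0, 1, 2]
pool.free(slots[0])        # slot 0 becomes reusable
print(pool.alloc(2))       # [3, 0]
```

Allocation and release are now plain Python list operations on the host side.
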
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -29,8 +29,8 @@ from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
 )
-from sglang.srt.managers.schedule_batch import (
-    Batch,
+from sglang.srt.managers.schedule_batch import ScheduleBatch
+from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
     InputMetadata,
     init_flashinfer_args,
@@ -202,7 +202,7 @@ class CudaGraphRunner:
         self.graph_memory_pool = graph.pool()
         return graph, None, out, flashinfer_decode_wrapper
 
-    def replay(self, batch: Batch):
+    def replay(self, batch: ScheduleBatch):
         assert batch.out_cache_loc is not None
         raw_bs = len(batch.reqs)
 

sglang/srt/model_executor/forward_batch_info.py
ADDED
@@ -0,0 +1,256 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""ModelRunner runs the forward passes of the models."""
+from dataclasses import dataclass
+from enum import IntEnum, auto
+from typing import List
+
+import numpy as np
+import torch
+
+from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+
+
+class ForwardMode(IntEnum):
+    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
+    PREFILL = auto()
+    # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
+    EXTEND = auto()
+    # Decode one token.
+    DECODE = auto()
+
+
+@dataclass
+class InputMetadata:
+    """Store all inforamtion of a forward pass."""
+
+    forward_mode: ForwardMode
+    batch_size: int
+    total_num_tokens: int
+    req_pool_indices: torch.Tensor
+    seq_lens: torch.Tensor
+    positions: torch.Tensor
+    req_to_token_pool: ReqToTokenPool
+    token_to_kv_pool: BaseTokenToKVPool
+
+    # For extend
+    extend_seq_lens: torch.Tensor
+    extend_start_loc: torch.Tensor
+    extend_no_prefix: bool
+
+    # Output location of the KV cache
+    out_cache_loc: torch.Tensor = None
+
+    # Output options
+    return_logprob: bool = False
+    top_logprobs_nums: List[int] = None
+
+    # Trition attention backend
+    triton_max_seq_len: int = 0
+    triton_max_extend_len: int = 0
+    triton_start_loc: torch.Tensor = None
+    triton_prefix_lens: torch.Tensor = None
+
+    # FlashInfer attention backend
+    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
+    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
+    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
+    flashinfer_use_ragged: bool = False
+
+    @classmethod
+    def create(
+        cls,
+        model_runner,
+        forward_mode,
+        req_pool_indices,
+        seq_lens,
+        prefix_lens,
+        position_ids_offsets,
+        out_cache_loc,
+        top_logprobs_nums=None,
+        return_logprob=False,
+        skip_flashinfer_init=False,
+    ):
+        flashinfer_use_ragged = False
+        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
+            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
+                flashinfer_use_ragged = True
+            init_flashinfer_args(
+                forward_mode,
+                model_runner,
+                req_pool_indices,
+                seq_lens,
+                prefix_lens,
+                model_runner.flashinfer_decode_wrapper,
+                flashinfer_use_ragged,
+            )
+
+        batch_size = len(req_pool_indices)
+
+        if forward_mode == ForwardMode.DECODE:
+            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
+            extend_seq_lens = extend_start_loc = extend_no_prefix = None
+            if not model_runner.server_args.disable_flashinfer:
+                # This variable is not needed in this case,
+                # we do not compute it to make it compatbile with cuda graph.
+                total_num_tokens = None
+            else:
+                total_num_tokens = int(torch.sum(seq_lens))
+        else:
+            seq_lens_cpu = seq_lens.cpu().numpy()
+            prefix_lens_cpu = prefix_lens.cpu().numpy()
+            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
+            positions = torch.tensor(
+                np.concatenate(
+                    [
+                        np.arange(
+                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
+                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
+                        )
+                        for i in range(batch_size)
+                    ],
+                    axis=0,
+                ),
+                device="cuda",
+            )
+            extend_seq_lens = seq_lens - prefix_lens
+            extend_start_loc = torch.zeros_like(seq_lens)
+            extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
+            extend_no_prefix = torch.all(prefix_lens == 0)
+            total_num_tokens = int(torch.sum(seq_lens))
+
+        ret = cls(
+            forward_mode=forward_mode,
+            batch_size=batch_size,
+            total_num_tokens=total_num_tokens,
+            req_pool_indices=req_pool_indices,
+            seq_lens=seq_lens,
+            positions=positions,
+            req_to_token_pool=model_runner.req_to_token_pool,
+            token_to_kv_pool=model_runner.token_to_kv_pool,
+            out_cache_loc=out_cache_loc,
+            extend_seq_lens=extend_seq_lens,
+            extend_start_loc=extend_start_loc,
+            extend_no_prefix=extend_no_prefix,
+            return_logprob=return_logprob,
+            top_logprobs_nums=top_logprobs_nums,
+            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
+            flashinfer_use_ragged=flashinfer_use_ragged,
+        )
+
+        if model_runner.server_args.disable_flashinfer:
+            (
+                ret.triton_max_seq_len,
+                ret.triton_max_extend_len,
+                ret.triton_start_loc,
+                ret.triton_prefix_lens,
+            ) = init_triton_args(forward_mode, seq_lens, prefix_lens)
+
+        return ret
+
+
+def init_flashinfer_args(
+    forward_mode,
+    model_runner,
+    req_pool_indices,
+    seq_lens,
+    prefix_lens,
+    flashinfer_decode_wrapper,
+    flashinfer_use_ragged=False,
+):
+    """Init auxiliary variables for FlashInfer attention backend."""
+    num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
+    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
+    head_dim = model_runner.model_config.head_dim
+    batch_size = len(req_pool_indices)
+    total_num_tokens = int(torch.sum(seq_lens))
+
+    if flashinfer_use_ragged:
+        paged_kernel_lens = prefix_lens
+    else:
+        paged_kernel_lens = seq_lens
+
+    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+    kv_indices = torch.cat(
+        [
+            model_runner.req_to_token_pool.req_to_token[
+                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
+            ]
+            for i in range(batch_size)
+        ],
+        dim=0,
+    ).contiguous()
+    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+
+    if forward_mode == ForwardMode.DECODE:
+        flashinfer_decode_wrapper.end_forward()
+        flashinfer_decode_wrapper.begin_forward(
+            kv_indptr,
+            kv_indices,
+            kv_last_page_len,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            1,
+        )
+    else:
+        # extend part
+        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+        if flashinfer_use_ragged:
+            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+                qo_indptr,
+                qo_indptr,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+            )
+
+        # cached part
+        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
+        model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            kv_last_page_len,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            1,
+        )
+
+
+def init_triton_args(forward_mode, seq_lens, prefix_lens):
+    """Init auxiliary variables for triton attention backend."""
+    batch_size = len(seq_lens)
+    max_seq_len = int(torch.max(seq_lens))
+    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
+    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
+
+    if forward_mode == ForwardMode.DECODE:
+        max_extend_len = None
+    else:
+        extend_seq_lens = seq_lens - prefix_lens
+        max_extend_len = int(torch.max(extend_seq_lens))
+
+    return max_seq_len, max_extend_len, start_loc, prefix_lens

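Much of init_flashinfer_args above is CSR-style bookkeeping: per-request KV lengths are turned into a (kv_indptr, kv_indices) pair over the req_to_token table before being handed to the FlashInfer wrappers. A CPU sketch of just that step (build_kv_csr is a hypothetical helper; the real code keeps these tensors on cuda):

```python
import torch


def build_kv_csr(req_to_token: torch.Tensor, req_pool_indices, kv_lens):
    """Turn per-request KV lengths into CSR-style (indptr, indices) tensors."""
    batch_size = len(req_pool_indices)
    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(torch.tensor(kv_lens, dtype=torch.int32), dim=0)
    kv_indices = torch.cat(
        [req_to_token[req_pool_indices[i], : kv_lens[i]] for i in range(batch_size)]
    )
    return kv_indptr, kv_indices


# Two requests sitting in slots 1 and 0 of a toy req_to_token table,
# with 3 and 2 cached tokens respectively.
req_to_token = torch.arange(20, dtype=torch.int32).reshape(4, 5)
kv_indptr, kv_indices = build_kv_csr(req_to_token, [1, 0], [3, 2])
print(kv_indptr.tolist())   # [0, 3, 5]
print(kv_indices.tolist())  # [5, 6, 7, 0, 1]
```
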
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -41,18 +41,14 @@ from vllm.distributed import (
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
-from sglang.srt.managers.schedule_batch import (
-    Batch,
-    ForwardMode,
-    InputMetadata,
-    global_server_args_dict,
-)
+from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
 )
 from sglang.srt.model_config import AttentionArch
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
@@ -350,7 +346,7 @@ class ModelRunner:
         )
 
     @torch.inference_mode()
-    def forward_decode(self, batch: Batch):
+    def forward_decode(self, batch: ScheduleBatch):
         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
             return self.cuda_graph_runner.replay(batch)
 
@@ -370,7 +366,7 @@ class ModelRunner:
         )
 
     @torch.inference_mode()
-    def forward_extend(self, batch: Batch):
+    def forward_extend(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,
@@ -387,7 +383,7 @@ class ModelRunner:
         )
 
     @torch.inference_mode()
-    def forward_extend_multi_modal(self, batch: Batch):
+    def forward_extend_multi_modal(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,
@@ -408,7 +404,7 @@ class ModelRunner:
             batch.image_offsets,
         )
 
-    def forward(self, batch: Batch, forward_mode: ForwardMode):
+    def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
             return self.forward_extend_multi_modal(batch)
         elif forward_mode == ForwardMode.DECODE:

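forward_extend above delegates to InputMetadata.create with ForwardMode.EXTEND, whose extend branch builds the per-token position ids by concatenating one arange per request, from the cached prefix length up to the full sequence length plus any position offset. A plain-numpy sketch of that computation (extend_positions is a hypothetical helper; the real code wraps the result in a CUDA tensor):

```python
import numpy as np


def extend_positions(prefix_lens, seq_lens, position_ids_offsets):
    """Concatenate one position range per request, as in the EXTEND branch."""
    return np.concatenate(
        [
            np.arange(p + off, s + off)
            for p, s, off in zip(prefix_lens, seq_lens, position_ids_offsets)
        ],
        axis=0,
    )


# Request 0: 2 of its 5 tokens are already cached; request 1: 3 new tokens, no cache.
print(extend_positions([2, 0], [5, 3], [0, 0]).tolist())  # [2, 3, 4, 0, 1, 2]
```
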
sglang/srt/models/chatglm.py
CHANGED
@@ -45,7 +45,7 @@ from vllm.transformers_utils.configs import ChatGLMConfig
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 LoraConfig = None
 

sglang/srt/models/commandr.py
CHANGED
@@ -64,7 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 @torch.compile

sglang/srt/models/dbrx.py
CHANGED
@@ -45,7 +45,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class DbrxRouter(nn.Module):

sglang/srt/models/deepseek.py
CHANGED
@@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.schedule_batch import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class DeepseekMLP(nn.Module):

sglang/srt/models/deepseek_v2.py
CHANGED
@@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class DeepseekV2MLP(nn.Module):

sglang/srt/models/gemma.py
CHANGED
@@ -37,7 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class GemmaMLP(nn.Module):

sglang/srt/models/gemma2.py
CHANGED
@@ -42,7 +42,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class GemmaRMSNorm(CustomOp):

sglang/srt/models/gpt_bigcode.py
CHANGED
@@ -35,7 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.schedule_batch import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class GPTBigCodeAttention(nn.Module):

sglang/srt/models/grok.py
CHANGED
@@ -52,7 +52,7 @@ from vllm.utils import print_warning_once
 from sglang.srt.layers.fused_moe import fused_moe
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 use_fused = True
 

sglang/srt/models/internlm2.py
CHANGED
@@ -40,7 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class InternLM2MLP(nn.Module):

sglang/srt/models/llama2.py
CHANGED
@@ -41,7 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class LlamaMLP(nn.Module):

sglang/srt/models/llama_classification.py
CHANGED
@@ -25,7 +25,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitProcessorOutput
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama2 import LlamaModel
 
 

sglang/srt/models/llava.py
CHANGED
@@ -32,13 +32,12 @@ from vllm.config import CacheConfig
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from sglang.srt.managers.schedule_batch import ForwardMode
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
     unpad_image_shape,
 )
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.models.llama2 import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM