sglang 0.3.1.post3__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +23 -1
- sglang/bench_latency.py +48 -33
- sglang/bench_server_latency.py +0 -6
- sglang/bench_serving.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +14 -1
- sglang/lang/interpreter.py +16 -6
- sglang/lang/ir.py +20 -4
- sglang/srt/configs/model_config.py +11 -9
- sglang/srt/constrained/fsm_cache.py +9 -1
- sglang/srt/constrained/jump_forward.py +15 -2
- sglang/srt/hf_transformers_utils.py +1 -0
- sglang/srt/layers/activation.py +4 -4
- sglang/srt/layers/attention/__init__.py +49 -0
- sglang/srt/layers/attention/flashinfer_backend.py +277 -0
- sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
- sglang/srt/layers/attention/triton_backend.py +161 -0
- sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
- sglang/srt/layers/fused_moe/patch.py +117 -0
- sglang/srt/layers/layernorm.py +4 -4
- sglang/srt/layers/logits_processor.py +19 -15
- sglang/srt/layers/pooler.py +3 -3
- sglang/srt/layers/quantization/__init__.py +0 -2
- sglang/srt/layers/radix_attention.py +6 -4
- sglang/srt/layers/sampler.py +6 -4
- sglang/srt/layers/torchao_utils.py +18 -0
- sglang/srt/lora/lora.py +20 -21
- sglang/srt/lora/lora_manager.py +97 -25
- sglang/srt/managers/detokenizer_manager.py +31 -18
- sglang/srt/managers/image_processor.py +187 -0
- sglang/srt/managers/io_struct.py +99 -75
- sglang/srt/managers/schedule_batch.py +187 -68
- sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
- sglang/srt/managers/scheduler.py +1021 -0
- sglang/srt/managers/tokenizer_manager.py +120 -247
- sglang/srt/managers/tp_worker.py +28 -925
- sglang/srt/mem_cache/memory_pool.py +34 -52
- sglang/srt/mem_cache/radix_cache.py +5 -5
- sglang/srt/model_executor/cuda_graph_runner.py +25 -25
- sglang/srt/model_executor/forward_batch_info.py +94 -97
- sglang/srt/model_executor/model_runner.py +76 -78
- sglang/srt/models/baichuan.py +10 -10
- sglang/srt/models/chatglm.py +12 -12
- sglang/srt/models/commandr.py +10 -10
- sglang/srt/models/dbrx.py +12 -12
- sglang/srt/models/deepseek.py +10 -10
- sglang/srt/models/deepseek_v2.py +14 -15
- sglang/srt/models/exaone.py +10 -10
- sglang/srt/models/gemma.py +10 -10
- sglang/srt/models/gemma2.py +11 -11
- sglang/srt/models/gpt_bigcode.py +10 -10
- sglang/srt/models/grok.py +10 -10
- sglang/srt/models/internlm2.py +10 -10
- sglang/srt/models/llama.py +22 -10
- sglang/srt/models/llama_classification.py +5 -5
- sglang/srt/models/llama_embedding.py +4 -4
- sglang/srt/models/llama_reward.py +142 -0
- sglang/srt/models/llava.py +39 -33
- sglang/srt/models/llavavid.py +31 -28
- sglang/srt/models/minicpm.py +10 -10
- sglang/srt/models/minicpm3.py +14 -15
- sglang/srt/models/mixtral.py +10 -10
- sglang/srt/models/mixtral_quant.py +10 -10
- sglang/srt/models/olmoe.py +10 -10
- sglang/srt/models/qwen.py +10 -10
- sglang/srt/models/qwen2.py +11 -11
- sglang/srt/models/qwen2_moe.py +10 -10
- sglang/srt/models/stablelm.py +10 -10
- sglang/srt/models/torch_native_llama.py +506 -0
- sglang/srt/models/xverse.py +10 -10
- sglang/srt/models/xverse_moe.py +10 -10
- sglang/srt/openai_api/adapter.py +7 -0
- sglang/srt/sampling/sampling_batch_info.py +36 -27
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +170 -119
- sglang/srt/server_args.py +54 -27
- sglang/srt/utils.py +101 -128
- sglang/test/runners.py +76 -33
- sglang/test/test_programs.py +38 -5
- sglang/test/test_utils.py +53 -9
- sglang/version.py +1 -1
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/METADATA +42 -23
- sglang-0.3.3.dist-info/RECORD +139 -0
- sglang/srt/layers/attention_backend.py +0 -482
- sglang/srt/managers/controller_multi.py +0 -207
- sglang/srt/managers/controller_single.py +0 -164
- sglang-0.3.1.post3.dist-info/RECORD +0 -134
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0

sglang/srt/mem_cache/memory_pool.py

@@ -16,9 +16,9 @@ limitations under the License.
 """Memory pool."""
 
 import logging
-from abc import ABC, abstractmethod
 from typing import List, Tuple, Union
 
+import numpy as np
 import torch
 
 logger = logging.getLogger(__name__)
@@ -27,12 +27,17 @@ logger = logging.getLogger(__name__)
 class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""
 
-    def __init__(self, size: int, max_context_len: int):
+    def __init__(self, size: int, max_context_len: int, device: str):
         self.size = size
-        self.
+        self.max_context_len = max_context_len
+        self.device = device
         self.req_to_token = torch.empty(
-            (size, max_context_len), dtype=torch.int32, device=
+            (size, max_context_len), dtype=torch.int32, device=device
         )
+        self.free_slots = list(range(size))
+
+    def available_size(self):
+        return len(self.free_slots)
 
     def alloc(self, need_size: int) -> List[int]:
         if need_size > len(self.free_slots):
@@ -53,86 +58,55 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))
 
 
-class BaseTokenToKVPool
+class BaseTokenToKVPool:
     """A memory pool that maps a token to its kv cache locations"""
 
     def __init__(
         self,
         size: int,
         dtype: torch.dtype,
+        device: str,
     ):
         self.size = size
         self.dtype = dtype
+        self.device = device
         if dtype == torch.float8_e5m2:
             # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
             self.store_dtype = torch.uint8
         else:
             self.store_dtype = dtype
 
-
-        self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
-
-        # Prefetch buffer
-        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
-        self.prefetch_chunk_size = 512
-
-        self.can_use_mem_size = self.size
+        self.free_slots = None
         self.clear()
 
     def available_size(self):
-        return
+        return len(self.free_slots)
 
     def alloc(self, need_size: int):
-
-        if need_size <= buffer_len:
-            select_index = self.prefetch_buffer[:need_size]
-            self.prefetch_buffer = self.prefetch_buffer[need_size:]
-            return select_index
-
-        addition_size = need_size - buffer_len
-        alloc_size = max(addition_size, self.prefetch_chunk_size)
-        select_index = (
-            torch.nonzero(self.mem_state).squeeze(1)[:alloc_size].to(torch.int32)
-        )
-
-        if select_index.shape[0] < addition_size:
+        if need_size > len(self.free_slots):
             return None
 
-        self.
-        self.
-
-        self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))
-        ret_index = self.prefetch_buffer[:need_size]
-        self.prefetch_buffer = self.prefetch_buffer[need_size:]
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]
 
-        return
+        return torch.tensor(select_index, dtype=torch.int32, device=self.device)
 
     def free(self, free_index: torch.Tensor):
-        self.
-        self.can_use_mem_size += len(free_index)
+        self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))
 
     def clear(self):
-
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
+        self.free_slots = np.arange(1, self.size + 1)
 
-        self.mem_state.fill_(True)
-        self.can_use_mem_size = self.size
-
-        # We also add one slot. This slot is used for writing dummy output from padded tokens.
-        self.mem_state[0] = False
-
-    @abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
 
-    @abstractmethod
     def get_value_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
 
-    @abstractmethod
     def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError()
 
-    @abstractmethod
     def set_kv_buffer(
         self,
         layer_id: int,
@@ -152,19 +126,25 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         head_num: int,
         head_dim: int,
         layer_num: int,
+        device: str,
     ):
-        super().__init__(size, dtype)
+        super().__init__(size, dtype, device)
 
         # [size, head_num, head_dim] for each layer
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.k_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim),
+                (size + 1, head_num, head_dim),
+                dtype=self.store_dtype,
+                device=device,
             )
             for _ in range(layer_num)
         ]
         self.v_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim),
+                (size + 1, head_num, head_dim),
+                dtype=self.store_dtype,
+                device=device,
             )
             for _ in range(layer_num)
         ]
@@ -210,15 +190,17 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         kv_lora_rank: int,
         qk_rope_head_dim: int,
         layer_num: int,
+        device: str,
     ):
-        super().__init__(size, dtype)
+        super().__init__(size, dtype, device)
 
         self.kv_lora_rank = kv_lora_rank
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.kv_buffer = [
             torch.empty(
                 (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
                 dtype=self.store_dtype,
-                device=
+                device=device,
             )
             for _ in range(layer_num)
         ]
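
The memory_pool.py hunks above replace the boolean mem_state bitmap and prefetch buffer with a simple free-slot list (a Python list in ReqToTokenPool, a NumPy array in BaseTokenToKVPool), keeping slot 0 reserved for dummy writes from padded tokens. The sketch below is a minimal standalone illustration of that allocation pattern, not the sglang classes themselves; the class name and toy pool size are made up.

import numpy as np
import torch


class FreeSlotPool:
    """Toy slot allocator mirroring the free-list pattern used above."""

    def __init__(self, size: int, device: str = "cpu"):
        self.size = size
        self.device = device
        self.clear()

    def available_size(self) -> int:
        return len(self.free_slots)

    def alloc(self, need_size: int):
        if need_size > len(self.free_slots):
            return None  # not enough free slots
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return torch.tensor(select_index, dtype=torch.int32, device=self.device)

    def free(self, free_index: torch.Tensor):
        # Returned slots go back onto the free list.
        self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))

    def clear(self):
        # Slot 0 stays reserved for dummy writes from padded tokens.
        self.free_slots = np.arange(1, self.size + 1)


pool = FreeSlotPool(size=8)
idx = pool.alloc(3)           # tensor([1, 2, 3], dtype=torch.int32)
print(pool.available_size())  # 5
pool.free(idx)
print(pool.available_size())  # 8
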
sglang/srt/mem_cache/radix_cache.py

@@ -291,15 +291,15 @@ class RadixCache(BasePrefixCache):
 
     def _collect_leaves(self):
         ret_list = []
+        stack = [self.root_node]
 
-
+        while stack:
+            cur_node = stack.pop()
             if len(cur_node.children) == 0:
                 ret_list.append(cur_node)
+            else:
+                stack.extend(cur_node.children.values())
 
-            for x in cur_node.children.values():
-                dfs_(x)
-
-        dfs_(self.root_node)
         return ret_list
 
 
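
The radix_cache.py hunk swaps the recursive dfs_ helper for an explicit stack, so collecting leaves no longer risks hitting Python's recursion limit on deep trees. Below is a self-contained sketch of the same iterative traversal on a generic dict-of-children tree; the Node class is illustrative, not the sglang tree node.

from typing import Dict, List


class Node:
    def __init__(self) -> None:
        self.children: Dict[str, "Node"] = {}


def collect_leaves(root: Node) -> List[Node]:
    """Iterative DFS: push children onto a stack instead of recursing."""
    ret_list: List[Node] = []
    stack = [root]
    while stack:
        cur_node = stack.pop()
        if len(cur_node.children) == 0:
            ret_list.append(cur_node)
        else:
            stack.extend(cur_node.children.values())
    return ret_list


# Usage: a root with two children, one of which has a child of its own.
root = Node()
root.children["a"] = Node()
root.children["b"] = Node()
root.children["b"].children["c"] = Node()
print(len(collect_leaves(root)))  # 2 leaves: "a" and "c"
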
sglang/srt/model_executor/cuda_graph_runner.py

@@ -25,13 +25,13 @@ import torch
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.layers.fused_moe.patch import fused_moe_forward_native
 from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
     LogitsProcessorOutput,
 )
-from sglang.srt.
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import monkey_patch_vllm_all_gather
 
 if TYPE_CHECKING:
@@ -41,14 +41,15 @@ if TYPE_CHECKING:
 def _to_torch(model: torch.nn.Module, reverse: bool = False):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
-            # NOTE: FusedMoE torch native implementaiton is not efficient
-            if "FusedMoE" in sub.__class__.__name__:
-                continue
             if reverse:
                 sub._forward_method = sub.forward_cuda
                 setattr(sub, "is_torch_compile", False)
             else:
-
+                # NOTE: Temporarily workaround MoE
+                if "FusedMoE" in sub.__class__.__name__:
+                    sub._forward_method = fused_moe_forward_native
+                else:
+                    sub._forward_method = sub.forward_native
                 setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
             _to_torch(sub, reverse)
@@ -67,7 +68,9 @@ def patch_model(
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
             tp_group.ca_comm = None
-            yield torch.compile(
+            yield torch.compile(
+                torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
+            )
         else:
             yield model.forward
     finally:
@@ -139,7 +142,6 @@ class CudaGraphRunner:
         self.seq_lens = torch.full(
             (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
         )
-        self.position_ids_offsets = torch.ones((self.max_bs,), dtype=torch.int32)
         self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
 
         # Capture
@@ -150,7 +152,7 @@ class CudaGraphRunner:
                 f"Capture cuda graph failed: {e}\n"
                 "Possible solutions:\n"
                 "1. disable cuda graph by --disable-cuda-graph\n"
-                "2. set --mem-fraction-static to a smaller value\n"
+                "2. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
                 "3. disable torch compile by not using --enable-torch-compile\n"
                 "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
             )
@@ -185,7 +187,6 @@ class CudaGraphRunner:
         input_ids = self.input_ids[:bs]
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
-        position_ids_offsets = self.position_ids_offsets[:bs]
         out_cache_loc = self.out_cache_loc[:bs]
 
         # Attention backend
@@ -195,9 +196,10 @@ class CudaGraphRunner:
 
         # Run and capture
         def run_once():
-
+            forward_batch = ForwardBatch(
                 forward_mode=ForwardMode.DECODE,
                 batch_size=bs,
+                input_ids=input_ids,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,
                 req_to_token_pool=self.model_runner.req_to_token_pool,
@@ -206,9 +208,9 @@ class CudaGraphRunner:
                 out_cache_loc=out_cache_loc,
                 return_logprob=False,
                 top_logprobs_nums=[0] * bs,
-                positions=(seq_lens - 1
+                positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
             )
-            return forward(input_ids,
+            return forward(input_ids, forward_batch.positions, forward_batch)
 
         for _ in range(2):
             torch.cuda.synchronize()
@@ -231,24 +233,22 @@ class CudaGraphRunner:
        self.graph_memory_pool = graph.pool()
        return graph, out
 
-    def replay(self,
-        assert
-        raw_bs =
+    def replay(self, forward_batch: ForwardBatch):
+        assert forward_batch.out_cache_loc is not None
+        raw_bs = forward_batch.batch_size
 
         # Pad
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
             self.seq_lens.fill_(self.seq_len_fill_value)
-            self.position_ids_offsets.fill_(1)
             self.out_cache_loc.zero_()
 
         # Common inputs
-        self.input_ids[:raw_bs] =
-        self.req_pool_indices[:raw_bs] =
-        self.seq_lens[:raw_bs] =
-        self.
-        self.out_cache_loc[:raw_bs] = batch.out_cache_loc
+        self.input_ids[:raw_bs] = forward_batch.input_ids
+        self.req_pool_indices[:raw_bs] = forward_batch.req_pool_indices
+        self.seq_lens[:raw_bs] = forward_batch.seq_lens
+        self.out_cache_loc[:raw_bs] = forward_batch.out_cache_loc
 
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
@@ -271,15 +271,15 @@ class CudaGraphRunner:
         )
 
         # Extract logprobs
-        if
+        if forward_batch.return_logprob:
             logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
                 logits_output.next_token_logits, dim=-1
             )
-            return_top_logprob = any(x > 0 for x in
+            return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
             if return_top_logprob:
                 logits_metadata = LogitsMetadata(
                     forward_mode=ForwardMode.DECODE,
-                    top_logprobs_nums=
+                    top_logprobs_nums=forward_batch.top_logprobs_nums,
                 )
                 logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
                     logits_output.next_token_logprobs, logits_metadata
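
In the new replay path, the runner looks up the nearest captured batch size with bisect.bisect_left and pads the real batch up to it, reusing the statically allocated input buffers. A small standalone illustration of that lookup (the capture sizes below are made up for the example):

import bisect

# Hypothetical list of batch sizes for which CUDA graphs were captured.
capture_bs = [1, 2, 4, 8, 16, 32]


def padded_batch_size(raw_bs: int) -> int:
    """Return the smallest captured batch size >= raw_bs."""
    index = bisect.bisect_left(capture_bs, raw_bs)
    return capture_bs[index]


print(padded_batch_size(3))   # 4
print(padded_batch_size(8))   # 8
print(padded_batch_size(17))  # 32
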
sglang/srt/model_executor/forward_batch_info.py

@@ -15,19 +15,33 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""
+"""
+Store information about a forward batch.
+
+The following is the flow of data structures for a batch:
+
+ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
+
+- ScheduleBatch is managed by `scheduler.py::Scheduler`.
+  It contains high-level scheduling data. Most of the data is on the CPU.
+- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
+- ForwardBatch is managed by `model_runner.py::ModelRunner`.
+  It contains low-level tensor data. Most of the data consists of GPU tensors.
+"""
+
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
 
 import numpy as np
 import torch
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.
-    from sglang.srt.managers.schedule_batch import
+    from sglang.srt.layers.attention import AttentionBackend
+    from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
 
 class ForwardMode(IntEnum):
@@ -37,7 +51,7 @@ class ForwardMode(IntEnum):
     EXTEND = auto()
     # Decode one token.
     DECODE = auto()
-    # Contains both
+    # Contains both EXTEND and DECODE.
     MIXED = auto()
 
     def is_prefill(self):
@@ -54,123 +68,106 @@ class ForwardMode(IntEnum):
 
 
 @dataclass
-class
-    """Store all
+class ForwardBatch:
+    """Store all inputs of a forward pass."""
 
+    # The forward mode
     forward_mode: ForwardMode
+    # The batch size
     batch_size: int
+    # The input ids
+    input_ids: torch.Tensor
+    # The indices of requests in the req_to_token_pool
     req_pool_indices: torch.Tensor
+    # The sequence length
     seq_lens: torch.Tensor
-
-    token_to_kv_pool: BaseTokenToKVPool
-    attn_backend: AttentionBackend
-
-    # Output location of the KV cache
+    # The indices of output tokens in the token_to_kv_pool
     out_cache_loc: torch.Tensor
 
+    # For logprob
+    return_logprob: bool = False
+    top_logprobs_nums: Optional[List[int]] = None
+
     # Position information
     positions: torch.Tensor = None
 
     # For extend
-    extend_seq_lens: torch.Tensor = None
-    extend_prefix_lens: torch.Tensor = None
-    extend_start_loc: torch.Tensor = None
-
-
-    # For logprob
-    return_logprob: bool = False
-    top_logprobs_nums: List[int] = None
-    extend_seq_lens_cpu: List[int] = None
-    extend_logprob_start_lens_cpu: List[int] = None
+    extend_seq_lens: Optional[torch.Tensor] = None
+    extend_prefix_lens: Optional[torch.Tensor] = None
+    extend_start_loc: Optional[torch.Tensor] = None
+    extend_seq_lens_cpu: Optional[List[int]] = None
+    extend_logprob_start_lens_cpu: Optional[List[int]] = None
 
     # For multimodal
-
-
-
-
-
-
-
-
-
-
-
-
-    def compute_positions(self, batch: ScheduleBatch):
-        position_ids_offsets = batch.position_ids_offsets
-
-        if self.forward_mode.is_decode():
-            if True:
-                self.positions = self.seq_lens - 1
-            else:
-                # Deprecated
-                self.positions = (self.seq_lens - 1) + position_ids_offsets
-        else:
-            if True:
-                self.positions = torch.tensor(
-                    np.concatenate(
-                        [
-                            np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
-                            for i, req in enumerate(batch.reqs)
-                        ],
-                        axis=0,
-                    ),
-                    device="cuda",
-                )
-            else:
-                # Deprecated
-                position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
-                self.positions = torch.tensor(
-                    np.concatenate(
-                        [
-                            np.arange(
-                                batch.prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
-                                len(req.fill_ids) + position_ids_offsets_cpu[i],
-                            )
-                            for i, req in enumerate(batch.reqs)
-                        ],
-                        axis=0,
-                    ),
-                    device="cuda",
-                )
-
-        # Positions should be in long type
-        self.positions = self.positions.to(torch.int64)
-
-    def compute_extend_infos(self, batch: ScheduleBatch):
-        self.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda")
-        self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
-        self.extend_start_loc = torch.zeros_like(self.extend_seq_lens)
-        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-        self.extend_no_prefix = all(x == 0 for x in batch.prefix_lens_cpu)
-        self.extend_seq_lens_cpu = batch.extend_lens_cpu
-        self.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu
+    image_inputs: Optional[List[ImageInputs]] = None
+
+    # For LoRA
+    lora_paths: Optional[List[str]] = None
+
+    # Sampling info
+    sampling_info: SamplingBatchInfo = None
+
+    # Attention backend
+    req_to_token_pool: ReqToTokenPool = None
+    token_to_kv_pool: BaseTokenToKVPool = None
+    attn_backend: AttentionBackend = None
 
     @classmethod
-    def
+    def init_new(
         cls,
-
-
+        batch: ModelWorkerBatch,
+        model_runner: ModelRunner,
     ):
+        device = "cuda"
+
         ret = cls(
             forward_mode=batch.forward_mode,
-            batch_size=batch.
+            batch_size=len(batch.seq_lens),
+            input_ids=batch.input_ids,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
-            req_to_token_pool=model_runner.req_to_token_pool,
-            token_to_kv_pool=model_runner.token_to_kv_pool,
-            attn_backend=model_runner.attn_backend,
             out_cache_loc=batch.out_cache_loc,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
+            lora_paths=batch.lora_paths,
+            sampling_info=batch.sampling_info,
         )
 
-
-
-
-
-        ret.
-
-
+        # Init position information
+        if ret.forward_mode.is_decode():
+            ret.positions = (ret.seq_lens - 1).to(torch.int64)
+        else:
+            ret.positions = torch.tensor(
+                np.concatenate(
+                    [
+                        np.arange(prefix_len, prefix_len + extend_len)
+                        for prefix_len, extend_len in zip(
+                            batch.extend_prefix_lens, batch.extend_seq_lens
+                        )
+                    ],
+                    axis=0,
+                ),
+                device=device,
+            ).to(torch.int64)
+
+            ret.image_inputs = batch.image_inputs
+            ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, device=device)
+            ret.extend_prefix_lens = torch.tensor(
+                batch.extend_prefix_lens, device=device
+            )
+            ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
+            ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
+            ret.extend_seq_lens_cpu = batch.extend_seq_lens
+            ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
+
+        # Init attention information
+        ret.req_to_token_pool = model_runner.req_to_token_pool
+        ret.token_to_kv_pool = model_runner.token_to_kv_pool
+        ret.attn_backend = model_runner.attn_backend
+        model_runner.attn_backend.init_forward_metadata(ret)
+
+        # Init lora information
+        if model_runner.server_args.lora_paths is not None:
+            model_runner.lora_manager.prepare_lora_batch(ret)
 
         return ret
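
The position logic in ForwardBatch.init_new distinguishes decode (one token per request, placed at seq_len - 1) from extend (a contiguous range starting at each request's prefix length). Below is a standalone sketch of the same computation on plain tensors and arrays; the helper names and example values are illustrative, not part of the package.

import numpy as np
import torch


def decode_positions(seq_lens: torch.Tensor) -> torch.Tensor:
    # One token per request, placed right after the existing sequence.
    return (seq_lens - 1).to(torch.int64)


def extend_positions(extend_prefix_lens, extend_seq_lens) -> torch.Tensor:
    # Each request contributes positions [prefix_len, prefix_len + extend_len).
    return torch.tensor(
        np.concatenate(
            [
                np.arange(prefix_len, prefix_len + extend_len)
                for prefix_len, extend_len in zip(extend_prefix_lens, extend_seq_lens)
            ],
            axis=0,
        )
    ).to(torch.int64)


print(decode_positions(torch.tensor([5, 9])))  # tensor([4, 8])
print(extend_positions([0, 3], [4, 2]))        # tensor([0, 1, 2, 3, 3, 4])
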