sglang 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/__init__.py +2 -0
- sglang/api.py +23 -1
- sglang/bench_latency.py +46 -25
- sglang/bench_serving.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +14 -1
- sglang/lang/interpreter.py +16 -6
- sglang/lang/ir.py +20 -4
- sglang/srt/configs/model_config.py +11 -9
- sglang/srt/constrained/fsm_cache.py +9 -1
- sglang/srt/constrained/jump_forward.py +15 -2
- sglang/srt/layers/activation.py +4 -4
- sglang/srt/layers/attention/__init__.py +49 -0
- sglang/srt/layers/attention/flashinfer_backend.py +277 -0
- sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
- sglang/srt/layers/attention/triton_backend.py +161 -0
- sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
- sglang/srt/layers/layernorm.py +4 -4
- sglang/srt/layers/logits_processor.py +19 -15
- sglang/srt/layers/pooler.py +3 -3
- sglang/srt/layers/quantization/__init__.py +0 -2
- sglang/srt/layers/radix_attention.py +6 -4
- sglang/srt/layers/sampler.py +6 -4
- sglang/srt/layers/torchao_utils.py +18 -0
- sglang/srt/lora/lora.py +20 -21
- sglang/srt/lora/lora_manager.py +97 -25
- sglang/srt/managers/detokenizer_manager.py +31 -18
- sglang/srt/managers/image_processor.py +187 -0
- sglang/srt/managers/io_struct.py +99 -75
- sglang/srt/managers/schedule_batch.py +184 -63
- sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
- sglang/srt/managers/scheduler.py +1021 -0
- sglang/srt/managers/tokenizer_manager.py +120 -248
- sglang/srt/managers/tp_worker.py +28 -925
- sglang/srt/mem_cache/memory_pool.py +34 -52
- sglang/srt/model_executor/cuda_graph_runner.py +15 -19
- sglang/srt/model_executor/forward_batch_info.py +94 -95
- sglang/srt/model_executor/model_runner.py +76 -75
- sglang/srt/models/baichuan.py +10 -10
- sglang/srt/models/chatglm.py +12 -12
- sglang/srt/models/commandr.py +10 -10
- sglang/srt/models/dbrx.py +12 -12
- sglang/srt/models/deepseek.py +10 -10
- sglang/srt/models/deepseek_v2.py +14 -15
- sglang/srt/models/exaone.py +10 -10
- sglang/srt/models/gemma.py +10 -10
- sglang/srt/models/gemma2.py +11 -11
- sglang/srt/models/gpt_bigcode.py +10 -10
- sglang/srt/models/grok.py +10 -10
- sglang/srt/models/internlm2.py +10 -10
- sglang/srt/models/llama.py +14 -10
- sglang/srt/models/llama_classification.py +5 -5
- sglang/srt/models/llama_embedding.py +4 -4
- sglang/srt/models/llama_reward.py +142 -0
- sglang/srt/models/llava.py +39 -33
- sglang/srt/models/llavavid.py +31 -28
- sglang/srt/models/minicpm.py +10 -10
- sglang/srt/models/minicpm3.py +14 -15
- sglang/srt/models/mixtral.py +10 -10
- sglang/srt/models/mixtral_quant.py +10 -10
- sglang/srt/models/olmoe.py +10 -10
- sglang/srt/models/qwen.py +10 -10
- sglang/srt/models/qwen2.py +11 -11
- sglang/srt/models/qwen2_moe.py +10 -10
- sglang/srt/models/stablelm.py +10 -10
- sglang/srt/models/torch_native_llama.py +506 -0
- sglang/srt/models/xverse.py +10 -10
- sglang/srt/models/xverse_moe.py +10 -10
- sglang/srt/sampling/sampling_batch_info.py +36 -27
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +170 -119
- sglang/srt/server_args.py +54 -27
- sglang/srt/utils.py +101 -128
- sglang/test/runners.py +71 -26
- sglang/test/test_programs.py +38 -5
- sglang/test/test_utils.py +18 -9
- sglang/version.py +1 -1
- {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/METADATA +37 -19
- sglang-0.3.3.dist-info/RECORD +139 -0
- sglang/srt/layers/attention_backend.py +0 -474
- sglang/srt/managers/controller_multi.py +0 -207
- sglang/srt/managers/controller_single.py +0 -164
- sglang-0.3.2.dist-info/RECORD +0 -135
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
- {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
- {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
- {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool.py

@@ -16,9 +16,9 @@ limitations under the License.
 """Memory pool."""
 
 import logging
-from abc import ABC, abstractmethod
 from typing import List, Tuple, Union
 
+import numpy as np
 import torch
 
 logger = logging.getLogger(__name__)
@@ -27,12 +27,17 @@ logger = logging.getLogger(__name__)
 class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""
 
-    def __init__(self, size: int, max_context_len: int):
+    def __init__(self, size: int, max_context_len: int, device: str):
         self.size = size
-        self.free_slots = list(range(size))
+        self.max_context_len = max_context_len
+        self.device = device
         self.req_to_token = torch.empty(
-            (size, max_context_len), dtype=torch.int32, device="cuda"
+            (size, max_context_len), dtype=torch.int32, device=device
         )
+        self.free_slots = list(range(size))
+
+    def available_size(self):
+        return len(self.free_slots)
 
     def alloc(self, need_size: int) -> List[int]:
         if need_size > len(self.free_slots):
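The rewritten ReqToTokenPool pairs a plain Python free list with a device-resident req_to_token table. A minimal, self-contained sketch of that allocation pattern (illustrative only, not sglang code; TinyReqSlotPool and the "cpu" device are stand-ins):

from typing import List, Optional

import torch


class TinyReqSlotPool:
    """Hypothetical miniature of the request-slot free list."""

    def __init__(self, size: int, max_context_len: int, device: str):
        self.size = size
        self.free_slots: List[int] = list(range(size))
        # One row of token locations per request slot.
        self.req_to_token = torch.empty(
            (size, max_context_len), dtype=torch.int32, device=device
        )

    def available_size(self) -> int:
        return len(self.free_slots)

    def alloc(self, need_size: int) -> Optional[List[int]]:
        if need_size > len(self.free_slots):
            return None  # caller must handle the allocation failure
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return select_index

    def free(self, indices: List[int]) -> None:
        self.free_slots.extend(indices)


pool = TinyReqSlotPool(size=4, max_context_len=8, device="cpu")
slots = pool.alloc(2)
print(slots, pool.available_size())  # [0, 1] 2
pool.free(slots)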
@@ -53,86 +58,55 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))
 
 
-class BaseTokenToKVPool(ABC):
+class BaseTokenToKVPool:
     """A memory pool that maps a token to its kv cache locations"""
 
     def __init__(
         self,
         size: int,
         dtype: torch.dtype,
+        device: str,
     ):
         self.size = size
         self.dtype = dtype
+        self.device = device
         if dtype == torch.float8_e5m2:
             # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
             self.store_dtype = torch.uint8
         else:
             self.store_dtype = dtype
 
-
-        self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
-
-        # Prefetch buffer
-        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
-        self.prefetch_chunk_size = 512
-
-        self.can_use_mem_size = self.size
+        self.free_slots = None
         self.clear()
 
     def available_size(self):
-        return self.can_use_mem_size + len(self.prefetch_buffer)
+        return len(self.free_slots)
 
     def alloc(self, need_size: int):
-        buffer_len = len(self.prefetch_buffer)
-        if need_size <= buffer_len:
-            select_index = self.prefetch_buffer[:need_size]
-            self.prefetch_buffer = self.prefetch_buffer[need_size:]
-            return select_index
-
-        addition_size = need_size - buffer_len
-        alloc_size = max(addition_size, self.prefetch_chunk_size)
-        select_index = (
-            torch.nonzero(self.mem_state).squeeze(1)[:alloc_size].to(torch.int32)
-        )
-
-        if select_index.shape[0] < addition_size:
+        if need_size > len(self.free_slots):
             return None
 
-        self.mem_state[select_index] = False
-        self.can_use_mem_size -= len(select_index)
-
-        self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))
-        ret_index = self.prefetch_buffer[:need_size]
-        self.prefetch_buffer = self.prefetch_buffer[need_size:]
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]
 
-        return ret_index
+        return torch.tensor(select_index, dtype=torch.int32, device=self.device)
 
     def free(self, free_index: torch.Tensor):
-        self.mem_state[free_index] = True
-        self.can_use_mem_size += len(free_index)
+        self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))
 
     def clear(self):
-        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
+        self.free_slots = np.arange(1, self.size + 1)
 
-        self.mem_state.fill_(True)
-        self.can_use_mem_size = self.size
-
-        # We also add one slot. This slot is used for writing dummy output from padded tokens.
-        self.mem_state[0] = False
-
-    @abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
 
-    @abstractmethod
     def get_value_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
 
-    @abstractmethod
     def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError()
 
-    @abstractmethod
     def set_kv_buffer(
         self,
         layer_id: int,
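BaseTokenToKVPool drops the old mem_state bitmask and prefetch buffer in favor of a NumPy free list that starts at slot 1, keeping slot 0 reserved for dummy writes from padded tokens. A small standalone sketch of that scheme (not sglang code; TinyKVSlotAllocator is a made-up name and "cpu" stands in for the device argument):

import numpy as np
import torch


class TinyKVSlotAllocator:
    def __init__(self, size: int, device: str):
        self.size = size
        self.device = device
        self.clear()

    def clear(self):
        # Slot 0 stays out of the free list; it absorbs dummy writes from padded tokens.
        self.free_slots = np.arange(1, self.size + 1)

    def available_size(self) -> int:
        return len(self.free_slots)

    def alloc(self, need_size: int):
        if need_size > len(self.free_slots):
            return None
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return torch.tensor(select_index, dtype=torch.int32, device=self.device)

    def free(self, free_index: torch.Tensor):
        self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))


allocator = TinyKVSlotAllocator(size=8, device="cpu")
idx = allocator.alloc(3)
print(idx.tolist(), allocator.available_size())  # [1, 2, 3] 5
allocator.free(idx)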
@@ -152,19 +126,25 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         head_num: int,
         head_dim: int,
         layer_num: int,
+        device: str,
     ):
-        super().__init__(size, dtype)
+        super().__init__(size, dtype, device)
 
         # [size, head_num, head_dim] for each layer
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.k_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+                (size + 1, head_num, head_dim),
+                dtype=self.store_dtype,
+                device=device,
             )
             for _ in range(layer_num)
         ]
         self.v_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+                (size + 1, head_num, head_dim),
+                dtype=self.store_dtype,
+                device=device,
             )
             for _ in range(layer_num)
         ]
@@ -210,15 +190,17 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         kv_lora_rank: int,
         qk_rope_head_dim: int,
         layer_num: int,
+        device: str,
     ):
-        super().__init__(size, dtype)
+        super().__init__(size, dtype, device)
 
         self.kv_lora_rank = kv_lora_rank
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.kv_buffer = [
             torch.empty(
                 (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
                 dtype=self.store_dtype,
-                device="cuda",
+                device=device,
             )
             for _ in range(layer_num)
         ]
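Both concrete pools now allocate size + 1 rows per layer so the reserved slot 0 exists, pick a store dtype (uint8 when the cache dtype is float8_e5m2), and place the buffers on the configured device. A hedged sketch of that layout (not sglang code; make_kv_buffers is an illustrative helper):

import torch


def make_kv_buffers(size, head_num, head_dim, layer_num, dtype, device):
    # float8_e5m2 values are stored as uint8 because index_put is not implemented for that dtype.
    store_dtype = torch.uint8 if dtype == torch.float8_e5m2 else dtype
    k_buffer = [
        torch.empty((size + 1, head_num, head_dim), dtype=store_dtype, device=device)
        for _ in range(layer_num)
    ]
    v_buffer = [
        torch.empty((size + 1, head_num, head_dim), dtype=store_dtype, device=device)
        for _ in range(layer_num)
    ]
    return k_buffer, v_buffer


k, v = make_kv_buffers(size=16, head_num=2, head_dim=4, layer_num=3, dtype=torch.float16, device="cpu")
print(len(k), k[0].shape)  # 3 torch.Size([17, 2, 4])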
sglang/srt/model_executor/cuda_graph_runner.py

@@ -31,8 +31,7 @@ from sglang.srt.layers.logits_processor import (
     LogitsProcessor,
     LogitsProcessorOutput,
 )
-from sglang.srt.managers.schedule_batch import ScheduleBatch
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import monkey_patch_vllm_all_gather
 
 if TYPE_CHECKING:
@@ -143,7 +142,6 @@ class CudaGraphRunner:
         self.seq_lens = torch.full(
             (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
         )
-        self.position_ids_offsets = torch.ones((self.max_bs,), dtype=torch.int32)
         self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
 
         # Capture
@@ -189,7 +187,6 @@ class CudaGraphRunner:
         input_ids = self.input_ids[:bs]
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
-        position_ids_offsets = self.position_ids_offsets[:bs]
         out_cache_loc = self.out_cache_loc[:bs]
 
         # Attention backend
@@ -199,9 +196,10 @@
 
         # Run and capture
        def run_once():
-            input_metadata = InputMetadata(
+            forward_batch = ForwardBatch(
                 forward_mode=ForwardMode.DECODE,
                 batch_size=bs,
+                input_ids=input_ids,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,
                 req_to_token_pool=self.model_runner.req_to_token_pool,
@@ -210,9 +208,9 @@
                 out_cache_loc=out_cache_loc,
                 return_logprob=False,
                 top_logprobs_nums=[0] * bs,
-                positions=(seq_lens - 1).to(torch.int64),
+                positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
             )
-            return forward(input_ids, input_metadata.positions, input_metadata)
+            return forward(input_ids, forward_batch.positions, forward_batch)
 
         for _ in range(2):
             torch.cuda.synchronize()
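The capture path now clamps decode positions so padded rows (sequence length 0) do not yield position -1. A tiny standalone check with assumed values:

import torch

seq_lens = torch.tensor([5, 9, 0, 0], dtype=torch.int32)  # last two rows are padding
positions = torch.clamp(seq_lens - 1, min=0).to(torch.int64)
print(positions.tolist())  # [4, 8, 0, 0] rather than [4, 8, -1, -1]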
@@ -235,24 +233,22 @@
             self.graph_memory_pool = graph.pool()
         return graph, out
 
-    def replay(self, batch: ScheduleBatch):
-        assert batch.out_cache_loc is not None
-        raw_bs = batch.batch_size()
+    def replay(self, forward_batch: ForwardBatch):
+        assert forward_batch.out_cache_loc is not None
+        raw_bs = forward_batch.batch_size
 
         # Pad
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
             self.seq_lens.fill_(self.seq_len_fill_value)
-            self.position_ids_offsets.fill_(1)
             self.out_cache_loc.zero_()
 
         # Common inputs
-        self.input_ids[:raw_bs] = batch.input_ids
-        self.req_pool_indices[:raw_bs] = batch.req_pool_indices
-        self.seq_lens[:raw_bs] = batch.seq_lens
-        self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
-        self.out_cache_loc[:raw_bs] = batch.out_cache_loc
+        self.input_ids[:raw_bs] = forward_batch.input_ids
+        self.req_pool_indices[:raw_bs] = forward_batch.req_pool_indices
+        self.seq_lens[:raw_bs] = forward_batch.seq_lens
+        self.out_cache_loc[:raw_bs] = forward_batch.out_cache_loc
 
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
@@ -275,15 +271,15 @@
         )
 
         # Extract logprobs
-        if batch.return_logprob:
+        if forward_batch.return_logprob:
             logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
                 logits_output.next_token_logits, dim=-1
             )
-            return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
+            return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
             if return_top_logprob:
                 logits_metadata = LogitsMetadata(
                     forward_mode=ForwardMode.DECODE,
-                    top_logprobs_nums=batch.top_logprobs_nums,
+                    top_logprobs_nums=forward_batch.top_logprobs_nums,
                 )
                 logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
                     logits_output.next_token_logprobs, logits_metadata
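replay() rounds the real batch size up to the nearest captured size and resets the padded tail of the static input buffers before copying the real batch in. A standalone sketch of that rounding step (not sglang code; the capture_bs values are assumptions):

import bisect

import torch

capture_bs = [1, 2, 4, 8]  # batch sizes for which graphs were captured
raw_bs = 3                 # actual batch size at replay time
bs = capture_bs[bisect.bisect_left(capture_bs, raw_bs)]  # rounds 3 up to 4

seq_lens = torch.zeros(max(capture_bs), dtype=torch.int32)        # static buffer
seq_lens[:raw_bs] = torch.tensor([7, 12, 5], dtype=torch.int32)   # real requests
print(bs, seq_lens[:bs].tolist())  # 4 [7, 12, 5, 0]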
sglang/srt/model_executor/forward_batch_info.py

@@ -15,19 +15,33 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""ModelRunner runs the forward passes of the models."""
+"""
+Store information about a forward batch.
+
+The following is the flow of data structures for a batch:
+
+ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
+
+- ScheduleBatch is managed by `scheduler.py::Scheduler`.
+  It contains high-level scheduling data. Most of the data is on the CPU.
+- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
+- ForwardBatch is managed by `model_runner.py::ModelRunner`.
+  It contains low-level tensor data. Most of the data consists of GPU tensors.
+"""
+
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
 
 import numpy as np
 import torch
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.attention_backend import AttentionBackend
-    from sglang.srt.managers.schedule_batch import ScheduleBatch
+    from sglang.srt.layers.attention import AttentionBackend
+    from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
 
 class ForwardMode(IntEnum):
@@ -37,7 +51,7 @@ class ForwardMode(IntEnum):
     EXTEND = auto()
     # Decode one token.
     DECODE = auto()
-    # Contains both PREFILL and EXTEND.
+    # Contains both EXTEND and DECODE.
     MIXED = auto()
 
     def is_prefill(self):
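The new module docstring describes a three-stage hand-off, ScheduleBatch -> ModelWorkerBatch -> ForwardBatch, moving from CPU-side scheduling data to GPU tensors. A minimal sketch of that flow under simplified, assumed types (every name below is an illustrative stand-in, not an sglang class):

from dataclasses import dataclass
from typing import List

import torch


@dataclass
class TinyScheduleBatch:
    # CPU-side scheduling data: one token list per request.
    input_ids: List[List[int]]

    def to_worker_batch(self) -> "TinyWorkerBatch":
        flat = [tok for req in self.input_ids for tok in req]
        return TinyWorkerBatch(input_ids=flat, seq_lens=[len(r) for r in self.input_ids])


@dataclass
class TinyWorkerBatch:
    # The subset handed to the model worker, still plain Python data.
    input_ids: List[int]
    seq_lens: List[int]

    def to_forward_batch(self, device: str) -> "TinyForwardBatch":
        return TinyForwardBatch(
            input_ids=torch.tensor(self.input_ids, dtype=torch.int32, device=device),
            seq_lens=torch.tensor(self.seq_lens, dtype=torch.int32, device=device),
        )


@dataclass
class TinyForwardBatch:
    # Low-level tensors consumed by the model forward pass.
    input_ids: torch.Tensor
    seq_lens: torch.Tensor


fb = TinyScheduleBatch(input_ids=[[1, 2, 3], [4, 5]]).to_worker_batch().to_forward_batch("cpu")
print(fb.seq_lens.tolist())  # [3, 2]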
@@ -54,121 +68,106 @@ class ForwardMode(IntEnum):
 
 
 @dataclass
-class InputMetadata:
-    """Store all information of a forward pass."""
+class ForwardBatch:
+    """Store all inputs of a forward pass."""
 
+    # The forward mode
     forward_mode: ForwardMode
+    # The batch size
     batch_size: int
+    # The input ids
+    input_ids: torch.Tensor
+    # The indices of requests in the req_to_token_pool
     req_pool_indices: torch.Tensor
+    # The sequence length
     seq_lens: torch.Tensor
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: BaseTokenToKVPool
-    attn_backend: AttentionBackend
-
-    # Output location of the KV cache
+    # The indices of output tokens in the token_to_kv_pool
     out_cache_loc: torch.Tensor
 
+    # For logprob
+    return_logprob: bool = False
+    top_logprobs_nums: Optional[List[int]] = None
+
     # Position information
     positions: torch.Tensor = None
 
     # For extend
-    extend_seq_lens: torch.Tensor = None
-    extend_prefix_lens: torch.Tensor = None
-    extend_start_loc: torch.Tensor = None
-    extend_no_prefix: bool = None
-
-    # For logprob
-    return_logprob: bool = False
-    top_logprobs_nums: List[int] = None
-    extend_seq_lens_cpu: List[int] = None
-    extend_logprob_start_lens_cpu: List[int] = None
+    extend_seq_lens: Optional[torch.Tensor] = None
+    extend_prefix_lens: Optional[torch.Tensor] = None
+    extend_start_loc: Optional[torch.Tensor] = None
+    extend_seq_lens_cpu: Optional[List[int]] = None
+    extend_logprob_start_lens_cpu: Optional[List[int]] = None
 
     # For multimodal
-    def compute_positions(self, batch: ScheduleBatch):
-        if self.forward_mode.is_decode():
-            if True:
-                self.positions = self.seq_lens - 1
-            else:
-                # Deprecated
-                self.positions = (self.seq_lens - 1) + batch.position_ids_offsets
-        else:
-            if True:
-                self.positions = torch.tensor(
-                    np.concatenate(
-                        [
-                            np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
-                            for i, req in enumerate(batch.reqs)
-                        ],
-                        axis=0,
-                    ),
-                    device="cuda",
-                )
-            else:
-                # Deprecated
-                position_ids_offsets_cpu = batch.position_ids_offsets.cpu().numpy()
-                self.positions = torch.tensor(
-                    np.concatenate(
-                        [
-                            np.arange(
-                                batch.prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
-                                len(req.fill_ids) + position_ids_offsets_cpu[i],
-                            )
-                            for i, req in enumerate(batch.reqs)
-                        ],
-                        axis=0,
-                    ),
-                    device="cuda",
-                )
-
-        # Positions should be in long type
-        self.positions = self.positions.to(torch.int64)
-
-    def compute_extend_infos(self, batch: ScheduleBatch):
-        self.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda")
-        self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
-        self.extend_start_loc = torch.zeros_like(self.extend_seq_lens)
-        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-        self.extend_no_prefix = all(x == 0 for x in batch.prefix_lens_cpu)
-        self.extend_seq_lens_cpu = batch.extend_lens_cpu
-        self.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu
+    image_inputs: Optional[List[ImageInputs]] = None
+
+    # For LoRA
+    lora_paths: Optional[List[str]] = None
+
+    # Sampling info
+    sampling_info: SamplingBatchInfo = None
+
+    # Attention backend
+    req_to_token_pool: ReqToTokenPool = None
+    token_to_kv_pool: BaseTokenToKVPool = None
+    attn_backend: AttentionBackend = None
 
     @classmethod
-    def from_schedule_batch(
+    def init_new(
         cls,
-        model_runner: "ModelRunner",
-        batch: ScheduleBatch,
+        batch: ModelWorkerBatch,
+        model_runner: ModelRunner,
     ):
+        device = "cuda"
+
         ret = cls(
             forward_mode=batch.forward_mode,
-            batch_size=batch.batch_size(),
+            batch_size=len(batch.seq_lens),
+            input_ids=batch.input_ids,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
-            req_to_token_pool=model_runner.req_to_token_pool,
-            token_to_kv_pool=model_runner.token_to_kv_pool,
-            attn_backend=model_runner.attn_backend,
             out_cache_loc=batch.out_cache_loc,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
+            lora_paths=batch.lora_paths,
+            sampling_info=batch.sampling_info,
         )
 
+        # Init position information
+        if ret.forward_mode.is_decode():
+            ret.positions = (ret.seq_lens - 1).to(torch.int64)
+        else:
+            ret.positions = torch.tensor(
+                np.concatenate(
+                    [
+                        np.arange(prefix_len, prefix_len + extend_len)
+                        for prefix_len, extend_len in zip(
+                            batch.extend_prefix_lens, batch.extend_seq_lens
+                        )
+                    ],
+                    axis=0,
+                ),
+                device=device,
+            ).to(torch.int64)
+
+            ret.image_inputs = batch.image_inputs
+            ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, device=device)
+            ret.extend_prefix_lens = torch.tensor(
+                batch.extend_prefix_lens, device=device
+            )
+            ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
+            ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
+            ret.extend_seq_lens_cpu = batch.extend_seq_lens
+            ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
+
+        # Init attention information
+        ret.req_to_token_pool = model_runner.req_to_token_pool
+        ret.token_to_kv_pool = model_runner.token_to_kv_pool
+        ret.attn_backend = model_runner.attn_backend
+        model_runner.attn_backend.init_forward_metadata(ret)
+
+        # Init lora information
+        if model_runner.server_args.lora_paths is not None:
+            model_runner.lora_manager.prepare_lora_batch(ret)
 
         return ret