sglang 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +31 -13
- sglang/bench_server_latency.py +21 -10
- sglang/bench_serving.py +101 -7
- sglang/global_config.py +0 -1
- sglang/srt/conversation.py +11 -2
- sglang/srt/layers/attention/__init__.py +27 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
- sglang/srt/layers/attention/flashinfer_backend.py +352 -83
- sglang/srt/layers/attention/triton_backend.py +6 -4
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
- sglang/srt/layers/sampler.py +6 -2
- sglang/srt/managers/data_parallel_controller.py +177 -0
- sglang/srt/managers/detokenizer_manager.py +31 -10
- sglang/srt/managers/io_struct.py +11 -2
- sglang/srt/managers/schedule_batch.py +126 -43
- sglang/srt/managers/schedule_policy.py +2 -1
- sglang/srt/managers/scheduler.py +245 -142
- sglang/srt/managers/tokenizer_manager.py +14 -1
- sglang/srt/managers/tp_worker.py +111 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -4
- sglang/srt/mem_cache/memory_pool.py +77 -4
- sglang/srt/mem_cache/radix_cache.py +15 -7
- sglang/srt/model_executor/cuda_graph_runner.py +4 -4
- sglang/srt/model_executor/forward_batch_info.py +16 -21
- sglang/srt/model_executor/model_runner.py +100 -36
- sglang/srt/models/baichuan.py +2 -3
- sglang/srt/models/chatglm.py +5 -6
- sglang/srt/models/commandr.py +1 -2
- sglang/srt/models/dbrx.py +1 -2
- sglang/srt/models/deepseek.py +4 -5
- sglang/srt/models/deepseek_v2.py +5 -6
- sglang/srt/models/exaone.py +1 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +5 -5
- sglang/srt/models/gpt_bigcode.py +5 -5
- sglang/srt/models/grok.py +1 -2
- sglang/srt/models/internlm2.py +1 -2
- sglang/srt/models/llama.py +1 -2
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +4 -8
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -2
- sglang/srt/models/minicpm3.py +5 -6
- sglang/srt/models/mixtral.py +1 -2
- sglang/srt/models/mixtral_quant.py +1 -2
- sglang/srt/models/olmo.py +352 -0
- sglang/srt/models/olmoe.py +1 -2
- sglang/srt/models/qwen.py +1 -2
- sglang/srt/models/qwen2.py +1 -2
- sglang/srt/models/qwen2_moe.py +4 -5
- sglang/srt/models/stablelm.py +1 -2
- sglang/srt/models/torch_native_llama.py +1 -2
- sglang/srt/models/xverse.py +1 -2
- sglang/srt/models/xverse_moe.py +4 -5
- sglang/srt/models/yivl.py +1 -2
- sglang/srt/openai_api/adapter.py +97 -52
- sglang/srt/openai_api/protocol.py +10 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
- sglang/srt/sampling/sampling_batch_info.py +105 -59
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server.py +171 -37
- sglang/srt/server_args.py +127 -48
- sglang/srt/utils.py +37 -14
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/few_shot_gsm8k_engine.py +144 -0
- sglang/test/srt/sampling/penaltylib/utils.py +16 -12
- sglang/version.py +1 -1
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/METADATA +82 -32
- sglang-0.3.4.dist-info/RECORD +143 -0
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
- sglang/srt/layers/attention/flashinfer_utils.py +0 -237
- sglang-0.3.3.dist-info/RECORD +0 -139
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -46,6 +46,7 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
+    ProfileReq,
     RewardReqInput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -149,9 +150,13 @@ class TokenizerManager:
         while self.model_update_lock.locked():
             await asyncio.sleep(0.001)

+        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
+            raise ValueError(
+                "This model does not appear to be an embedding model by default. Please add `--is-embedding` when launching the server or try another model."
+            )
+
         obj.post_init()
         is_single = obj.is_single
-
         if is_single:
             async for response in self._handle_single_request(obj, request):
                 yield response
@@ -512,6 +517,14 @@ class TokenizerManager:
         req = AbortReq(rid)
         self.send_to_scheduler.send_pyobj(req)

+    def start_profile(self):
+        req = ProfileReq.START_PROFILE
+        self.send_to_scheduler.send_pyobj(req)
+
+    def stop_profile(self):
+        req = ProfileReq.STOP_PROFILE
+        self.send_to_scheduler.send_pyobj(req)
+
     async def update_weights(
         self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
     ):
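Note: `start_profile` and `stop_profile` only push a `ProfileReq` enum value to the scheduler process over ZMQ; the profiler itself is started and stopped on the scheduler side (scheduler.py also changed in this release). The snippet below is a minimal, hypothetical sketch of what such a receiver could do with `torch.profiler`; the class and method names are illustrative, not the real scheduler code.

import torch
from torch.profiler import ProfilerActivity, profile

from sglang.srt.managers.io_struct import ProfileReq


class ProfileHandlerSketch:
    """Illustrative only: one way a scheduler-side loop could react to ProfileReq."""

    def __init__(self):
        self._profiler = None

    def handle_profile_req(self, req: ProfileReq):
        if req == ProfileReq.START_PROFILE and self._profiler is None:
            # Start collecting CPU and CUDA activity.
            self._profiler = profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                with_stack=True,
            )
            self._profiler.start()
        elif req == ProfileReq.STOP_PROFILE and self._profiler is not None:
            # Stop and dump a Chrome trace for offline inspection.
            self._profiler.stop()
            self._profiler.export_chrome_trace("trace.json")
            self._profiler = None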
sglang/srt/managers/tp_worker.py
CHANGED
@@ -17,6 +17,11 @@ limitations under the License.

 import json
 import logging
+import threading
+import time
+from queue import Queue
+
+import torch

 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
@@ -75,6 +80,7 @@ class TpModelWorker:
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
             )
+        self.device = self.model_runner.device

         # Profile number of tokens
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
@@ -100,6 +106,9 @@ class TpModelWorker:
         )[0]
         set_random_seed(self.random_seed)

+        if server_args.enable_overlap_schedule:
+            self.init_overlap_status()
+
     def get_token_and_memory_info(self):
         return (
             self.max_total_num_tokens,
@@ -109,6 +118,81 @@
             self.random_seed,
         )

+    def init_overlap_status(self):
+        self.future_logits_output_dict = dict()
+        self.future_logits_output_ct = 0
+        self.future_token_ids_ct = 0
+        self.future_token_ids_map = torch.empty(
+            (self.max_running_requests * 5,), dtype=torch.int32, device=self.device
+        )
+        self.future_token_ids_limit = self.max_running_requests * 3
+        self.future_token_ids_output = dict()
+
+        self.future_event_map = dict()
+        self.forward_queue = Queue()
+        self.forward_stream = torch.cuda.Stream()
+        self.forward_thread = threading.Thread(
+            target=self.forward_thread_func,
+        )
+        self.forward_thread.start()
+
+    def forward_thread_func(self):
+        with torch.cuda.stream(self.forward_stream):
+            self.forward_thread_func_()
+
+    @torch.inference_mode()
+    def forward_thread_func_(self):
+        while True:
+            tic1 = time.time()
+            model_worker_batch, future_logits_output, future_next_token_ids = (
+                self.forward_queue.get()
+            )
+
+            # Resolve future tokens in the input
+            tic2 = time.time()
+            resolved_input_ids = model_worker_batch.input_ids
+            future_mask = resolved_input_ids < 0
+            resolved_input_ids[future_mask] = self.future_token_ids_map[
+                -resolved_input_ids[future_mask]
+            ]
+
+            # Run forward
+            logits_output, next_token_ids = self.forward_batch_generation(
+                model_worker_batch
+            )
+
+            # Set future values
+            if model_worker_batch.return_logprob:
+                self.future_logits_output_dict[future_logits_output] = logits_output
+
+            # logger.info(f"set output {future_next_token_ids=}, {next_token_ids=}")
+            self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
+                torch.int32
+            )
+            # logger.info("Set event")
+            self.future_token_ids_output[model_worker_batch.bid] = (
+                next_token_ids.tolist()
+            )
+            self.future_event_map[model_worker_batch.bid].set()
+
+            if False:
+                tic3 = time.time()
+                self.acc_time_with_waiting += tic3 - tic1
+                self.acc_time_without_waiting += tic3 - tic2
+                if self.forward_queue.qsize() == 0:
+                    logger.info(
+                        f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
+                    )
+
+    def resolve_future_token_ids(self, bid: int):
+        self.future_event_map[bid].wait()
+        ret = self.future_token_ids_output[bid]
+        del self.future_event_map[bid]
+        return ret
+
+    def resolve_future_logits_output(self, future_obj):
+        return self.future_logits_output_dict.pop(future_obj)
+
     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
@@ -118,9 +202,35 @@ class TpModelWorker:
     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
-        embeddings = logits_output.embeddings
+        embeddings = logits_output.embeddings
         return embeddings

+    def forward_batch_generation_non_blocking(
+        self, model_worker_batch: ModelWorkerBatch
+    ):
+        # Allocate output future objects
+        future_logits_output = self.future_logits_output_ct
+        self.future_logits_output_ct += 1
+
+        bs = len(model_worker_batch.seq_lens)
+        with torch.cuda.stream(self.forward_stream):
+            future_next_token_ids = -torch.arange(
+                self.future_token_ids_ct + 1,
+                self.future_token_ids_ct + 1 + bs,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            self.future_token_ids_ct = (
+                self.future_token_ids_ct + bs
+            ) % self.future_token_ids_limit
+        ret = future_logits_output, future_next_token_ids
+
+        self.future_event_map[model_worker_batch.bid] = threading.Event()
+        self.forward_queue.put(
+            (model_worker_batch.copy(), future_logits_output, future_next_token_ids)
+        )
+        return ret
+
     def update_weights(self, recv_req: UpdateWeightReqInput):
         success, message = self.model_runner.update_weights(
             recv_req.model_path, recv_req.load_format
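The overlap schedule introduced above decouples CPU-side scheduling from GPU execution: `forward_batch_generation_non_blocking` returns negative placeholder token ids immediately, a background thread runs the real forward pass on a dedicated CUDA stream, and any placeholder that later shows up in a batch's `input_ids` is resolved by indexing `future_token_ids_map` with its negated value. The toy below (CPU-only, not sglang code) isolates just that placeholder trick.

import torch

# Toy illustration of the negative-placeholder scheme used by the overlap schedule.
future_map = torch.zeros(16, dtype=torch.int32)  # slot 0 is intentionally unused

def allocate_placeholders(ct: int, bs: int):
    # Hand out ids -(ct+1) .. -(ct+bs) before the real tokens exist.
    return -torch.arange(ct + 1, ct + 1 + bs, dtype=torch.int32), ct + bs

def publish(placeholders: torch.Tensor, real_ids: torch.Tensor):
    # The worker thread writes the real token ids into the map.
    future_map[-placeholders] = real_ids

def resolve(input_ids: torch.Tensor) -> torch.Tensor:
    # A later batch replaces negative ids with the published real ids.
    mask = input_ids < 0
    input_ids[mask] = future_map[-input_ids[mask]]
    return input_ids

placeholders, ct = allocate_placeholders(0, 3)
publish(placeholders, torch.tensor([11, 22, 33], dtype=torch.int32))
print(resolve(placeholders.clone()))  # tensor([11, 22, 33], dtype=torch.int32)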
sglang/srt/mem_cache/chunk_cache.py
CHANGED
@@ -40,10 +40,12 @@ class ChunkCache(BasePrefixCache):

     def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         if token_ids is None:
-
+            token_id_len = len(req.origin_input_ids) + len(req.output_ids) - 1
+        else:
+            token_id_len = len(token_ids)

         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, :
+            req.req_pool_idx, :token_id_len
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool.free(kv_indices)
@@ -53,10 +55,12 @@ class ChunkCache(BasePrefixCache):

     def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         if token_ids is None:
-
+            token_id_len = len(req.fill_ids)
+        else:
+            token_id_len = len(token_ids)

         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, :
+            req.req_pool_idx, :token_id_len
         ]

         if req.rid not in self.entries:
sglang/srt/mem_cache/memory_pool.py
CHANGED
@@ -18,7 +18,6 @@ limitations under the License.
 import logging
 from typing import List, Tuple, Union

-import numpy as np
 import torch

 logger = logging.getLogger(__name__)
@@ -77,6 +76,8 @@ class BaseTokenToKVPool:
         self.store_dtype = dtype

         self.free_slots = None
+        self.is_not_in_free_group = True
+        self.free_group = []
         self.clear()

     def available_size(self):
@@ -89,14 +90,28 @@ class BaseTokenToKVPool:
         select_index = self.free_slots[:need_size]
         self.free_slots = self.free_slots[need_size:]

-        return
+        return select_index.to(self.device, non_blocking=True)

     def free(self, free_index: torch.Tensor):
-
+        if self.is_not_in_free_group:
+            self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
+        else:
+            self.free_group.append(free_index)
+
+    def free_group_begin(self):
+        self.is_not_in_free_group = False
+        self.free_group = []
+
+    def free_group_end(self):
+        self.is_not_in_free_group = True
+        if self.free_group:
+            self.free(torch.concat(self.free_group))

     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.free_slots =
+        self.free_slots = torch.arange(1, self.size + 1, dtype=torch.int32)
+        self.is_in_free_group = False
+        self.free_group = []

     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
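`free_group_begin()` / `free_group_end()` let a caller buffer many small `free()` calls, for example when a whole batch of finished requests is released in one scheduling step, and return them to `free_slots` with a single `torch.concat` instead of one concatenation per request. A trimmed standalone restatement of the pattern (not the sglang class itself):

import torch

# Minimal re-implementation of the deferred-free pattern shown in the hunk above.
class TinyPool:
    def __init__(self, size: int):
        self.free_slots = torch.arange(1, size + 1, dtype=torch.int32)
        self.is_not_in_free_group = True
        self.free_group = []

    def free(self, free_index: torch.Tensor):
        if self.is_not_in_free_group:
            self.free_slots = torch.concat((self.free_slots, free_index))
        else:
            self.free_group.append(free_index)  # defer until free_group_end()

    def free_group_begin(self):
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self):
        self.is_not_in_free_group = True
        if self.free_group:
            self.free(torch.concat(self.free_group))  # one concat for the whole group

pool = TinyPool(4)
taken, pool.free_slots = pool.free_slots[:3], pool.free_slots[3:]  # "allocate" 3 slots
pool.free_group_begin()
pool.free(taken[:1])
pool.free(taken[1:])
pool.free_group_end()  # free_slots is [4, 1, 2, 3] again, via a single concat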
@@ -231,3 +246,61 @@ class MLATokenToKVPool(BaseTokenToKVPool):
             self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
         else:
             self.kv_buffer[layer_id][loc] = cache_k
+
+
+class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
+
+    def __init__(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        head_num: int,
+        head_dim: int,
+        layer_num: int,
+        device: str,
+        heavy_channel_num: int,
+    ):
+        super().__init__(size, dtype, device)
+
+        # [size, head_num, head_dim] for each layer
+        self.k_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+            for _ in range(layer_num)
+        ]
+        self.v_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+            for _ in range(layer_num)
+        ]
+
+        # [size, head_num, heavy_channel_num] for each layer
+        self.label_buffer = [
+            torch.empty(
+                (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
+            )
+            for _ in range(layer_num)
+        ]
+
+    def get_key_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id]
+
+    def get_value_buffer(self, layer_id: int):
+        return self.v_buffer[layer_id]
+
+    def get_label_buffer(self, layer_id: int):
+        return self.label_buffer[layer_id]
+
+    def get_kv_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id], self.v_buffer[layer_id]
+
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+        cache_label: torch.Tensor,
+    ):
+        # NOTE(Andy): ignore the dtype check
+        self.k_buffer[layer_id][loc] = cache_k
+        self.v_buffer[layer_id][loc] = cache_v
+        self.label_buffer[layer_id][loc] = cache_label
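`DoubleSparseTokenToKVPool` backs the new double-sparsity attention backend: besides the usual per-layer K/V buffers it keeps a `label_buffer` holding only `heavy_channel_num` channels per head, which the sparse kernels can use as a cheap proxy when scoring tokens. A hypothetical standalone usage sketch follows (assumes a CUDA device; the label values here are stand-ins for illustration, not how sglang actually selects heavy channels).

import torch

from sglang.srt.mem_cache.memory_pool import DoubleSparseTokenToKVPool

# Small sizes chosen only for illustration.
pool = DoubleSparseTokenToKVPool(
    size=64,
    dtype=torch.float16,
    head_num=8,
    head_dim=128,
    layer_num=2,
    device="cuda",
    heavy_channel_num=16,
)

loc = torch.arange(4, device="cuda")  # KV slots to write
cache_k = torch.randn(4, 8, 128, dtype=torch.float16, device="cuda")
cache_v = torch.randn_like(cache_k)
cache_label = cache_k[..., :16].contiguous()  # stand-in "heavy channel" values

pool.set_kv_buffer(0, loc, cache_k, cache_v, cache_label)
k_buf, v_buf = pool.get_kv_buffer(0)
labels = pool.get_label_buffer(0)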
sglang/srt/mem_cache/radix_cache.py
CHANGED
@@ -99,17 +99,25 @@ class RadixCache(BasePrefixCache):

     def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         """Cache request when it finishes."""
+        if self.disable:
+            if token_ids is None:
+                token_ids_len = len(req.origin_input_ids) + len(req.output_ids) - 1
+            else:
+                token_ids_len = len(token_ids)
+
+            kv_indices = self.req_to_token_pool.req_to_token[
+                req.req_pool_idx, :token_ids_len
+            ]
+            self.token_to_kv_pool.free(kv_indices)
+            self.req_to_token_pool.free(req.req_pool_idx)
+            return
+
         if token_ids is None:
             token_ids = (req.origin_input_ids + req.output_ids)[:-1]
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, : len(token_ids)
         ]

-        if self.disable:
-            self.token_to_kv_pool.free(kv_indices)
-            self.req_to_token_pool.free(req.req_pool_idx)
-            return
-
         # Radix Cache takes one ref in memory pool
         new_prefix_len = self.insert(token_ids, kv_indices.clone())
         self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
@@ -229,7 +237,7 @@ class RadixCache(BasePrefixCache):
     def _split_node(self, key, child: TreeNode, split_len: int):
         # new_node -> child
         new_node = TreeNode()
-        new_node.children = {key[split_len
+        new_node.children = {key[split_len]: child}
         new_node.parent = child.parent
         new_node.lock_ref = child.lock_ref
         new_node.key = child.key[:split_len]
@@ -237,7 +245,7 @@ class RadixCache(BasePrefixCache):
         child.parent = new_node
         child.key = child.key[split_len:]
         child.value = child.value[split_len:]
-        new_node.parent.children[key[
+        new_node.parent.children[key[0]] = new_node
         return new_node

     def _insert_helper(self, node: TreeNode, key: List, value):
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -245,10 +245,10 @@ class CudaGraphRunner:
         self.out_cache_loc.zero_()

         # Common inputs
-        self.input_ids[:raw_bs]
-        self.req_pool_indices[:raw_bs]
-        self.seq_lens[:raw_bs]
-        self.out_cache_loc[:raw_bs]
+        self.input_ids[:raw_bs].copy_(forward_batch.input_ids)
+        self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
+        self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
+        self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)

         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
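The switch to `copy_()` matters because CUDA graph replay re-executes kernels against the exact memory captured at record time: fresh batch data must be written into the pre-allocated buffers in place, not bound to new tensors. A generic PyTorch illustration of that capture/copy/replay pattern (not sglang code; requires a CUDA device):

import torch

# Static buffers whose addresses will be baked into the captured graph.
static_x = torch.zeros(8, device="cuda")
static_y = torch.zeros(8, device="cuda")

# Warm up on a side stream, as recommended before graph capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_y.copy_(static_x * 2)
torch.cuda.current_stream().wait_stream(s)

# Capture the computation once.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_y.copy_(static_x * 2)

# Replay with new inputs: write them into the captured buffer in place.
static_x.copy_(torch.arange(8, dtype=torch.float32, device="cuda"))
g.replay()
torch.cuda.synchronize()
print(static_y)  # tensor([0., 2., 4., ...], device='cuda:0')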
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -118,7 +118,7 @@ class ForwardBatch:
         batch: ModelWorkerBatch,
         model_runner: ModelRunner,
     ):
-        device =
+        device = model_runner.device

         ret = cls(
             forward_mode=batch.forward_mode,
@@ -134,27 +134,23 @@ class ForwardBatch:
         )

         # Init position information
-        if ret.forward_mode.is_decode():
-            ret.positions =
-
-
-
-
-
-
-
-
-                ],
-                axis=0,
-            ),
-            device=device,
-        ).to(torch.int64)
-
+        if not ret.forward_mode.is_decode():
+            ret.positions = torch.concat(
+                [
+                    torch.arange(prefix_len, prefix_len + extend_len, device=device)
+                    for prefix_len, extend_len in zip(
+                        batch.extend_prefix_lens, batch.extend_seq_lens
+                    )
+                ],
+                axis=0,
+            )
         ret.image_inputs = batch.image_inputs
-        ret.extend_seq_lens = torch.tensor(
+        ret.extend_seq_lens = torch.tensor(
+            batch.extend_seq_lens, dtype=torch.int32
+        ).to(device, non_blocking=True)
         ret.extend_prefix_lens = torch.tensor(
-            batch.extend_prefix_lens,
-        )
+            batch.extend_prefix_lens, dtype=torch.int32
+        ).to(device, non_blocking=True)
         ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
         ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
         ret.extend_seq_lens_cpu = batch.extend_seq_lens
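With this change, positions are built here only for extend (prefill) batches: each request contributes `arange(prefix_len, prefix_len + extend_len)`, and the pieces are concatenated into one flat tensor; the old decode branch is removed from this function. A small worked example of the concatenation:

import torch

# Two requests: one resuming after a cached prefix of 2 tokens and extending by 3,
# one starting fresh and extending by 2.
extend_prefix_lens = [2, 0]
extend_seq_lens = [3, 2]

positions = torch.concat(
    [
        torch.arange(prefix_len, prefix_len + extend_len)
        for prefix_len, extend_len in zip(extend_prefix_lens, extend_seq_lens)
    ],
    dim=0,
)
print(positions)  # tensor([2, 3, 4, 0, 1])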
@@ -164,7 +160,6 @@ class ForwardBatch:
         ret.req_to_token_pool = model_runner.req_to_token_pool
         ret.token_to_kv_pool = model_runner.token_to_kv_pool
         ret.attn_backend = model_runner.attn_backend
-        model_runner.attn_backend.init_forward_metadata(ret)

         # Init lora information
         if model_runner.server_args.lora_paths is not None: