sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff compares publicly available package versions as published to their respective registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in those public registries.
- sglang/__init__.py +59 -2
- sglang/api.py +40 -11
- sglang/backend/anthropic.py +17 -3
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +160 -12
- sglang/backend/runtime_endpoint.py +62 -27
- sglang/backend/vertexai.py +1 -0
- sglang/bench_latency.py +320 -0
- sglang/global_config.py +24 -3
- sglang/lang/chat_template.py +122 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +206 -98
- sglang/lang/ir.py +98 -34
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +32 -0
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +9 -2
- sglang/srt/constrained/jump_forward.py +113 -24
- sglang/srt/conversation.py +4 -2
- sglang/srt/flush_cache.py +18 -0
- sglang/srt/hf_transformers_utils.py +144 -3
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +20 -1
- sglang/srt/layers/fused_moe.py +596 -0
- sglang/srt/layers/logits_processor.py +190 -61
- sglang/srt/layers/radix_attention.py +62 -53
- sglang/srt/layers/token_attention.py +21 -9
- sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/controller/infer_batch.py +908 -0
- sglang/srt/managers/controller/manager_multi.py +195 -0
- sglang/srt/managers/controller/manager_single.py +177 -0
- sglang/srt/managers/controller/model_runner.py +359 -0
- sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
- sglang/srt/managers/controller/schedule_heuristic.py +65 -0
- sglang/srt/managers/controller/tp_worker.py +813 -0
- sglang/srt/managers/detokenizer_manager.py +42 -40
- sglang/srt/managers/io_struct.py +44 -10
- sglang/srt/managers/tokenizer_manager.py +224 -82
- sglang/srt/memory_pool.py +52 -59
- sglang/srt/model_config.py +97 -2
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +369 -0
- sglang/srt/models/dbrx.py +406 -0
- sglang/srt/models/gemma.py +34 -38
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +47 -37
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +92 -27
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/minicpm.py +366 -0
- sglang/srt/models/mixtral.py +302 -127
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +40 -35
- sglang/srt/models/qwen2.py +33 -36
- sglang/srt/models/qwen2_moe.py +473 -0
- sglang/srt/models/stablelm.py +33 -39
- sglang/srt/models/yivl.py +19 -26
- sglang/srt/openai_api_adapter.py +411 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +197 -481
- sglang/srt/server_args.py +190 -74
- sglang/srt/utils.py +460 -95
- sglang/test/test_programs.py +73 -10
- sglang/test/test_utils.py +226 -7
- sglang/utils.py +97 -27
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
- sglang-0.1.21.dist-info/RECORD +82 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/managers/router/infer_batch.py +0 -503
- sglang/srt/managers/router/manager.py +0 -79
- sglang/srt/managers/router/model_rpc.py +0 -686
- sglang/srt/managers/router/model_runner.py +0 -514
- sglang/srt/managers/router/scheduler.py +0 -70
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/infer_batch.py
@@ -0,0 +1,908 @@
"""Meta data for requests and batches"""

import warnings
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import List, Union

import numpy as np
import torch

from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Store some global server args
global_server_args_dict = {}


class ForwardMode(IntEnum):
    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
    PREFILL = auto()
    # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
    EXTEND = auto()
    # Decode one token.
    DECODE = auto()


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def __str__(self):
        raise NotImplementedError("Subclasses must implement this method")


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def __str__(self) -> str:
        return f"FINISH_MATCHED_TOKEN: {self.matched}"


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def __str__(self) -> str:
        return f"FINISH_LENGTH: {self.length}"


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def __str__(self) -> str:
        return f"FINISH_MATCHED_STR: {self.matched}"


class FINISH_ABORT(BaseFinishReason):
    def __init__(self):
        super().__init__(is_error=True)

    def __str__(self) -> str:
        return "FINISH_ABORT"


class Req:
    """Store all inforamtion of a request."""

    def __init__(self, rid, origin_input_text, origin_input_ids):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
        self.origin_input_ids = origin_input_ids
        self.output_ids = []  # Each decode stage's output ids
        self.input_ids = None  # input_ids = origin_input_ids + output_ids

        # For incremental decoding
        self.decoded_text = ""
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None

        # The number of decoded tokens for token usage report. Note that
        # this does not include the jump forward tokens.
        self.completion_tokens_wo_jump_forward = 0

        # For vision input
        self.pixel_values = None
        self.image_size = None
        self.image_offset = 0
        self.pad_value = None

        # Prefix info
        self.extend_input_len = 0
        self.prefix_indices = []
        self.last_node = None

        # Sampling parameters
        self.sampling_params = None
        self.stream = False

        # Check finish
        self.tokenizer = None
        self.finished_reason = None

        # Logprobs
        self.return_logprob = False
        self.logprob_start_len = 0
        self.top_logprobs_num = 0
        self.normalized_prompt_logprob = None
        self.prefill_token_logprobs = None
        self.prefill_top_logprobs = None
        self.decode_token_logprobs = []
        self.decode_top_logprobs = []
        # The tokens is prefilled but need to be considered as decode tokens
        # and should be updated for the decode logprobs
        self.last_update_decode_tokens = 0

        # Constrained decoding
        self.regex_fsm: RegexGuide = None
        self.regex_fsm_state: int = 0
        self.jump_forward_map: JumpForwardMap = None

    # whether request reached finished condition
    def finished(self) -> bool:
        return self.finished_reason is not None

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_detokenize_incrementally(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        surr_ids = all_ids[self.surr_offset : self.read_offset]
        read_ids = all_ids[self.surr_offset :]

        return surr_ids, read_ids, len(all_ids)

    def detokenize_incrementally(self, inplace: bool = True):
        surr_ids, read_ids, num_all_tokens = self.init_detokenize_incrementally()

        surr_text = self.tokenizer.decode(
            surr_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )
        new_text = self.tokenizer.decode(
            read_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )

        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
            new_text = new_text[len(surr_text) :]
            if inplace:
                self.decoded_text += new_text
                self.surr_offset = self.read_offset
                self.read_offset = num_all_tokens

            return True, new_text

        return False, ""

    def check_finished(self):
        if self.finished():
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(len(self.output_ids))
            return

        if (
            self.output_ids[-1] == self.tokenizer.eos_token_id
            and not self.sampling_params.ignore_eos
        ):
            self.finished_reason = FINISH_MATCHED_TOKEN(
                matched=self.tokenizer.eos_token_id
            )
            return

        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
        if self.origin_input_text is None:
            # Recovering text can only use unpadded ids
            self.origin_input_text = self.tokenizer.decode(
                self.origin_input_ids_unpadded
            )

        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
        all_ids = self.tokenizer.encode(all_text)
        prompt_tokens = len(self.origin_input_ids_unpadded)

        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
            # TODO(lsyin): fix token fusion
            warnings.warn(
                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
            )
            return False

        old_output_ids = self.output_ids
        self.output_ids = all_ids[prompt_tokens:]
        self.decoded_text = self.decoded_text + jump_forward_str
        self.surr_offset = prompt_tokens
        self.read_offset = len(all_ids)

        # NOTE: A trick to reduce the surrouding tokens decoding overhead
        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
            surr_text_ = self.tokenizer.decode(
                all_ids[self.read_offset - i : self.read_offset]
            )
            if not surr_text_.endswith("�"):
                self.surr_offset = self.read_offset - i
                break

        self.regex_fsm_state = next_state

        if self.return_logprob:
            # For fast-forward part's logprobs
            k = 0
            for i, old_id in enumerate(old_output_ids):
                if old_id == self.output_ids[i]:
                    k = k + 1
                else:
                    break
            self.decode_token_logprobs = self.decode_token_logprobs[:k]
            self.decode_top_logprobs = self.decode_top_logprobs[:k]
            self.logprob_start_len = prompt_tokens + k
            self.last_update_decode_tokens = len(self.output_ids) - k

        return True

    def __repr__(self):
        return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "


@dataclass
class Batch:
    """Store all inforamtion of a batch."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool
    token_to_kv_pool: TokenToKVPool
    tree_cache: RadixCache

    # Batched arguments to model runner
    input_ids: torch.Tensor = None
    req_pool_indices: torch.Tensor = None
    seq_lens: torch.Tensor = None
    prefix_lens: torch.Tensor = None
    position_ids_offsets: torch.Tensor = None
    out_cache_loc: torch.Tensor = None

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: List[int] = None

    # For multimodal
    pixel_values: List[torch.Tensor] = None
    image_sizes: List[List[int]] = None
    image_offsets: List[int] = None

    # Other arguments for control
    output_ids: torch.Tensor = None
    extend_num_tokens: int = None

    # Batched sampling params
    temperatures: torch.Tensor = None
    top_ps: torch.Tensor = None
    top_ks: torch.Tensor = None
    frequency_penalties: torch.Tensor = None
    presence_penalties: torch.Tensor = None
    logit_bias: torch.Tensor = None

    @classmethod
    def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
        return_logprob = any(req.return_logprob for req in reqs)

        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool=token_to_kv_pool,
            tree_cache=tree_cache,
            return_logprob=return_logprob,
        )

    def is_empty(self):
        return len(self.reqs) == 0

    def has_stream(self) -> bool:
        # Return whether batch has at least 1 streaming request
        return any(r.stream for r in self.reqs)

    def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
        device = "cuda"
        bs = len(self.reqs)
        reqs = self.reqs
        input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
        prefix_indices = [r.prefix_indices for r in reqs]

        # Handle prefix
        flatten_input_ids = []
        extend_lens = []
        prefix_lens = []
        seq_lens = []

        req_pool_indices = self.req_to_token_pool.alloc(bs)
        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
        for i in range(bs):
            flatten_input_ids.extend(input_ids[i])
            extend_lens.append(len(input_ids[i]))

            if len(prefix_indices[i]) == 0:
                prefix_lens.append(0)
            else:
                prefix_lens.append(len(prefix_indices[i]))
                self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
                    : len(prefix_indices[i])
                ] = prefix_indices[i]

            seq_lens.append(prefix_lens[-1] + extend_lens[-1])

        position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)

        # Allocate memory
        seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
        extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
        out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
        if out_cache_loc is None:
            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
            out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)

            if out_cache_loc is None:
                print("Prefill out of memory. This should never happen.")
                self.tree_cache.pretty_print()
                exit()

        pt = 0
        for i in range(bs):
            self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
                prefix_lens[i] : prefix_lens[i] + extend_lens[i]
            ] = out_cache_loc[pt : pt + extend_lens[i]]
            pt += extend_lens[i]

        # Handle logit bias but only allocate when needed
        logit_bias = None
        for i in range(bs):
            if reqs[i].sampling_params.dtype == "int":
                if logit_bias is None:
                    logit_bias = torch.zeros(
                        (bs, vocab_size), dtype=torch.float32, device=device
                    )
                logit_bias[i] = int_token_logit_bias

        # Set fields
        self.input_ids = torch.tensor(
            flatten_input_ids, dtype=torch.int32, device=device
        )
        self.pixel_values = [r.pixel_values for r in reqs]
        self.image_sizes = [r.image_size for r in reqs]
        self.image_offsets = [
            r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
        ]
        self.req_pool_indices = req_pool_indices
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
        self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
        self.position_ids_offsets = position_ids_offsets
        self.extend_num_tokens = extend_num_tokens
        self.out_cache_loc = out_cache_loc
        self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]

        self.temperatures = torch.tensor(
            [r.sampling_params.temperature for r in reqs],
            dtype=torch.float,
            device=device,
        ).view(-1, 1)
        self.top_ps = torch.tensor(
            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
        ).view(-1, 1)
        self.top_ks = torch.tensor(
            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
        ).view(-1, 1)
        self.frequency_penalties = torch.tensor(
            [r.sampling_params.frequency_penalty for r in reqs],
            dtype=torch.float,
            device=device,
        )
        self.presence_penalties = torch.tensor(
            [r.sampling_params.presence_penalty for r in reqs],
            dtype=torch.float,
            device=device,
        )
        self.logit_bias = logit_bias

    def check_decode_mem(self):
        bs = len(self.reqs)
        if self.token_to_kv_pool.available_size() >= bs:
            return True

        self.tree_cache.evict(bs, self.token_to_kv_pool.free)

        if self.token_to_kv_pool.available_size() >= bs:
            return True

        return False

    def retract_decode(self):
        sorted_indices = [i for i in range(len(self.reqs))]
        # TODO(lsyin): improve the priority of retraction
        sorted_indices.sort(
            key=lambda i: (
                len(self.reqs[i].output_ids),
                -len(self.reqs[i].origin_input_ids),
            ),
            reverse=True,
        )

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
        while self.token_to_kv_pool.available_size() < len(self.reqs):
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            # TODO: apply more fine-grained retraction
            last_uncached_pos = len(req.prefix_indices)
            token_indices = self.req_to_token_pool.req_to_token[
                req_pool_indices_cpu[idx]
            ][last_uncached_pos : seq_lens_cpu[idx]]
            self.token_to_kv_pool.free(token_indices)

            # release the last node
            self.tree_cache.dec_lock_ref(req.last_node)

            req.prefix_indices = None
            req.last_node = None
            req.extend_input_len = 0

            # For incremental logprobs
            req.last_update_decode_tokens = 0
            req.logprob_start_len = 10**9

        self.filter_batch(sorted_indices)

        return retracted_reqs

    def check_for_jump_forward(self, model_runner):
        jump_forward_reqs = []
        filter_indices = [i for i in range(len(self.reqs))]

        req_pool_indices_cpu = None

        for i, req in enumerate(self.reqs):
            if req.jump_forward_map is not None:
                jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
                    req.regex_fsm_state
                )
                if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
                    suffix_bytes = []
                    continuation_range = range(0x80, 0xC0)
                    cur_state = req.regex_fsm_state
                    while (
                        len(jump_forward_bytes)
                        and jump_forward_bytes[0][0] in continuation_range
                    ):
                        # continuation bytes
                        byte_edge = jump_forward_bytes.pop(0)
                        suffix_bytes.append(byte_edge[0])
                        cur_state = byte_edge[1]

                    suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
                    suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)

                    # Current ids, for cache and revert
                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                    cur_output_ids = req.output_ids

                    req.output_ids.extend(suffix_ids)
                    decode_res, new_text = req.detokenize_incrementally(inplace=False)
                    if not decode_res:
                        req.output_ids = cur_output_ids
                        continue

                    (
                        jump_forward_str,
                        next_state,
                    ) = req.jump_forward_map.jump_forward_symbol(cur_state)

                    # Make the incrementally decoded text part of jump_forward_str
                    # so that the UTF-8 will not corrupt
                    jump_forward_str = new_text + jump_forward_str
                    if not req.jump_forward_and_retokenize(
                        jump_forward_str, next_state
                    ):
                        req.output_ids = cur_output_ids
                        continue

                    # insert the old request into tree_cache
                    if req_pool_indices_cpu is None:
                        req_pool_indices_cpu = self.req_pool_indices.tolist()
                    self.tree_cache.cache_req(
                        token_ids=cur_all_ids,
                        last_uncached_pos=len(req.prefix_indices),
                        req_pool_idx=req_pool_indices_cpu[i],
                    )

                    # unlock the last node
                    self.tree_cache.dec_lock_ref(req.last_node)

                    # re-applying image padding
                    if req.pixel_values is not None:
                        (
                            req.origin_input_ids,
                            req.image_offset,
                        ) = model_runner.model.pad_input_ids(
                            req.origin_input_ids_unpadded,
                            req.pad_value,
                            req.pixel_values.shape,
                            req.image_size,
                        )

                    jump_forward_reqs.append(req)
                    filter_indices.remove(i)

        if len(filter_indices) < len(self.reqs):
            self.filter_batch(filter_indices)

        return jump_forward_reqs

    def prepare_for_decode(self, input_ids=None):
        if input_ids is None:
            input_ids = [
                r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
            ]
        self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
        self.seq_lens.add_(1)
        self.prefix_lens = None

        # Alloc mem
        bs = len(self.reqs)
        self.out_cache_loc = self.token_to_kv_pool.alloc(bs)

        if self.out_cache_loc is None:
            print("Decode out of memory. This should never happen.")
            self.tree_cache.pretty_print()
            exit()

        self.req_to_token_pool.req_to_token[
            self.req_pool_indices, self.seq_lens - 1
        ] = self.out_cache_loc

    def filter_batch(self, unfinished_indices: List[int]):
        self.reqs = [self.reqs[i] for i in unfinished_indices]
        new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
        self.seq_lens = self.seq_lens[new_indices]
        self.input_ids = None
        self.req_pool_indices = self.req_pool_indices[new_indices]
        self.prefix_lens = None
        self.position_ids_offsets = self.position_ids_offsets[new_indices]
        self.out_cache_loc = None
        self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
        self.return_logprob = any(req.return_logprob for req in self.reqs)

        for item in [
            "temperatures",
            "top_ps",
            "top_ks",
            "frequency_penalties",
            "presence_penalties",
            "logit_bias",
        ]:
            self_val = getattr(self, item, None)
            if self_val is not None:  # logit_bias can be None
                setattr(self, item, self_val[new_indices])

    def merge(self, other: "Batch"):
        self.reqs.extend(other.reqs)

        self.req_pool_indices = torch.concat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
        self.prefix_lens = None
        self.position_ids_offsets = torch.concat(
            [self.position_ids_offsets, other.position_ids_offsets]
        )
        self.out_cache_loc = None
        self.top_logprobs_nums.extend(other.top_logprobs_nums)
        self.return_logprob = any(req.return_logprob for req in self.reqs)

        for item in [
            "temperatures",
            "top_ps",
            "top_ks",
            "frequency_penalties",
            "presence_penalties",
        ]:
            self_val = getattr(self, item, None)
            other_val = getattr(other, item, None)
            setattr(self, item, torch.concat([self_val, other_val]))

        # logit_bias can be None
        if self.logit_bias is not None or other.logit_bias is not None:
            vocab_size = (
                self.logit_bias.shape[1]
                if self.logit_bias is not None
                else other.logit_bias.shape[1]
            )
            if self.logit_bias is None:
                self.logit_bias = torch.zeros(
                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
                )
            if other.logit_bias is None:
                other.logit_bias = torch.zeros(
                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
                )
            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])

    def sample(self, logits: torch.Tensor):
        # Post process logits
        logits = logits.contiguous()
        logits.div_(self.temperatures)
        if self.logit_bias is not None:
            logits.add_(self.logit_bias)

        has_regex = any(req.regex_fsm is not None for req in self.reqs)
        if has_regex:
            allowed_mask = torch.empty_like(logits[0], dtype=torch.bool)
            for i, req in enumerate(self.reqs):
                if req.regex_fsm is not None:
                    allowed_mask.zero_()
                    allowed_mask[
                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
                    ] = 1
                    logits[i].masked_fill_(~allowed_mask, float("-inf"))

        # TODO(lmzheng): apply penalty
        probs = torch.softmax(logits, dim=-1)
        probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
        try:
            sampled_index = torch.multinomial(probs_sort, num_samples=1)
        except RuntimeError as e:
            warnings.warn(f"Ignore errors in sampling: {e}")
            sampled_index = torch.ones(
                probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device
            )
        batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
            -1
        )
        batch_next_token_probs = torch.gather(
            probs_sort, dim=1, index=sampled_index
        ).view(-1)

        if has_regex:
            batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
            for i, req in enumerate(self.reqs):
                if req.regex_fsm is not None:
                    req.regex_fsm_state = req.regex_fsm.get_next_state(
                        req.regex_fsm_state, batch_next_token_ids_cpu[i]
                    )

        return batch_next_token_ids, batch_next_token_probs


def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0
    probs_sort[
        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks
    ] = 0.0
    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
    return probs_sort, probs_idx


@dataclass
class InputMetadata:
    """Store all inforamtion of a forward pass."""

    forward_mode: ForwardMode
    batch_size: int
    total_num_tokens: int
    req_pool_indices: torch.Tensor
    seq_lens: torch.Tensor
    positions: torch.Tensor
    req_to_token_pool: ReqToTokenPool
    token_to_kv_pool: TokenToKVPool

    # For extend
    extend_seq_lens: torch.Tensor
    extend_start_loc: torch.Tensor
    extend_no_prefix: bool

    # Output location of the KV cache
    out_cache_loc: torch.Tensor = None

    # Output options
    return_logprob: bool = False
    top_logprobs_nums: List[int] = None

    # Trition attention backend
    triton_max_seq_len: int = 0
    triton_max_extend_len: int = 0
    triton_start_loc: torch.Tensor = None
    triton_prefix_lens: torch.Tensor = None

    # FlashInfer attention backend
    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None

    @classmethod
    def create(
        cls,
        model_runner,
        forward_mode,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        top_logprobs_nums=None,
        return_logprob=False,
        skip_flashinfer_init=False,
    ):
        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
            init_flashinfer_args(
                forward_mode,
                model_runner,
                req_pool_indices,
                seq_lens,
                prefix_lens,
                model_runner.flashinfer_decode_wrapper,
            )

        batch_size = len(req_pool_indices)

        if forward_mode == ForwardMode.DECODE:
            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
            extend_seq_lens = extend_start_loc = extend_no_prefix = None
            if not model_runner.server_args.disable_flashinfer:
                # This variable is not needed in this case,
                # we do not compute it to make it compatbile with cuda graph.
                total_num_tokens = None
            else:
                total_num_tokens = int(torch.sum(seq_lens))
        else:
            seq_lens_cpu = seq_lens.cpu().numpy()
            prefix_lens_cpu = prefix_lens.cpu().numpy()
            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
            positions = torch.tensor(
                np.concatenate(
                    [
                        np.arange(
                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
                        )
                        for i in range(batch_size)
                    ],
                    axis=0,
                ),
                device="cuda",
            )
            extend_seq_lens = seq_lens - prefix_lens
            extend_start_loc = torch.zeros_like(seq_lens)
            extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
            extend_no_prefix = torch.all(prefix_lens == 0)
            total_num_tokens = int(torch.sum(seq_lens))

        ret = cls(
            forward_mode=forward_mode,
            batch_size=batch_size,
            total_num_tokens=total_num_tokens,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            positions=positions,
            req_to_token_pool=model_runner.req_to_token_pool,
            token_to_kv_pool=model_runner.token_to_kv_pool,
            out_cache_loc=out_cache_loc,
            extend_seq_lens=extend_seq_lens,
            extend_start_loc=extend_start_loc,
            extend_no_prefix=extend_no_prefix,
            return_logprob=return_logprob,
            top_logprobs_nums=top_logprobs_nums,
            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
        )

        if model_runner.server_args.disable_flashinfer:
            (
                ret.triton_max_seq_len,
                ret.triton_max_extend_len,
                ret.triton_start_loc,
                ret.triton_prefix_lens,
            ) = init_triton_args(forward_mode, seq_lens, prefix_lens)

        return ret


def init_flashinfer_args(
    forward_mode,
    model_runner,
    req_pool_indices,
    seq_lens,
    prefix_lens,
    flashinfer_decode_wrapper,
):
    num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
    head_dim = model_runner.model_config.head_dim
    batch_size = len(req_pool_indices)

    if forward_mode == ForwardMode.DECODE:
        paged_kernel_lens = seq_lens
    else:
        paged_kernel_lens = prefix_lens

    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
    kv_indices = torch.cat(
        [
            model_runner.req_to_token_pool.req_to_token[
                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
            ]
            for i in range(batch_size)
        ],
        dim=0,
    ).contiguous()
    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")

    if forward_mode == ForwardMode.DECODE:
        flashinfer_decode_wrapper.end_forward()
        flashinfer_decode_wrapper.begin_forward(
            kv_indptr,
            kv_indices,
            kv_last_page_len,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            1,
        )
    else:
        # extend part
        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

        model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
        model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
            qo_indptr,
            qo_indptr,
            num_qo_heads,
            num_kv_heads,
            head_dim,
        )

        # cached part
        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
        model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
            qo_indptr,
            kv_indptr,
            kv_indices,
            kv_last_page_len,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            1,
        )


def init_triton_args(forward_mode, seq_lens, prefix_lens):
    batch_size = len(seq_lens)
    max_seq_len = int(torch.max(seq_lens))
    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)

    if forward_mode == ForwardMode.DECODE:
        max_extend_len = None
    else:
        extend_seq_lens = seq_lens - prefix_lens
        max_extend_len = int(torch.max(extend_seq_lens))

    return max_seq_len, max_extend_len, start_loc, prefix_lens