sglang 0.2.9.post1__py3-none-any.whl → 0.2.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -0
- sglang/api.py +10 -2
- sglang/bench_latency.py +234 -74
- sglang/check_env.py +25 -2
- sglang/global_config.py +0 -1
- sglang/lang/backend/base_backend.py +3 -1
- sglang/lang/backend/openai.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +46 -40
- sglang/lang/choices.py +164 -0
- sglang/lang/interpreter.py +6 -13
- sglang/lang/ir.py +11 -2
- sglang/srt/hf_transformers_utils.py +2 -2
- sglang/srt/layers/extend_attention.py +59 -7
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/radix_attention.py +24 -14
- sglang/srt/layers/token_attention.py +28 -2
- sglang/srt/managers/io_struct.py +9 -4
- sglang/srt/managers/schedule_batch.py +98 -323
- sglang/srt/managers/tokenizer_manager.py +34 -16
- sglang/srt/managers/tp_worker.py +20 -22
- sglang/srt/mem_cache/memory_pool.py +74 -38
- sglang/srt/model_config.py +11 -0
- sglang/srt/model_executor/cuda_graph_runner.py +3 -3
- sglang/srt/model_executor/forward_batch_info.py +256 -0
- sglang/srt/model_executor/model_runner.py +51 -26
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +199 -17
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -1
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/internlm2.py +1 -1
- sglang/srt/models/llama2.py +1 -1
- sglang/srt/models/llama_classification.py +1 -1
- sglang/srt/models/llava.py +1 -2
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -1
- sglang/srt/models/mixtral.py +1 -1
- sglang/srt/models/mixtral_quant.py +1 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +151 -29
- sglang/srt/openai_api/protocol.py +7 -1
- sglang/srt/server.py +111 -84
- sglang/srt/server_args.py +12 -2
- sglang/srt/utils.py +25 -20
- sglang/test/run_eval.py +21 -10
- sglang/test/runners.py +237 -0
- sglang/test/simple_eval_common.py +12 -12
- sglang/test/simple_eval_gpqa.py +92 -0
- sglang/test/simple_eval_humaneval.py +5 -5
- sglang/test/simple_eval_math.py +72 -0
- sglang/test/test_utils.py +95 -14
- sglang/utils.py +15 -37
- sglang/version.py +1 -1
- {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/METADATA +59 -48
- sglang-0.2.11.dist-info/RECORD +102 -0
- sglang-0.2.9.post1.dist-info/RECORD +0 -97
- {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/LICENSE +0 -0
- {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/WHEEL +0 -0
- {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/top_level.txt +0 -0
sglang/srt/layers/token_attention.py
CHANGED
@@ -54,6 +54,7 @@ def _fwd_kernel_stage1(
     att_stride_h,
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
     BLOCK_N: tl.constexpr,
     logit_cap: tl.constexpr,
 ):
@@ -73,6 +74,10 @@ def _fwd_kernel_stage1(

     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d

+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        off_qpe = cur_batch * stride_qbs + cur_head * stride_qh + offs_dpe
+
     offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)

     block_stard_index = start_n * BLOCK_N
@@ -97,6 +102,19 @@ def _fwd_kernel_stage1(
             other=0.0,
         ).to(REDUCE_TRITON_TYPE)
         att_value = tl.sum(q[None, :] * k, 1)
+        if BLOCK_DPE > 0:
+            qpe = tl.load(Q + off_qpe + start_mark).to(REDUCE_TRITON_TYPE)
+            offs_buf_kpe = (
+                k_loc[:, None] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[None, :]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_buf_kpe,
+                mask=offs_n_new[:, None] < cur_batch_end_index,
+                other=0.0,
+            ).to(REDUCE_TRITON_TYPE)
+            att_value += tl.sum(qpe[None, :] * kpe, 1)
         att_value *= sm_scale

         if logit_cap > 0:
@@ -192,7 +210,14 @@ def _token_att_m_fwd(
     # shape constraints
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
     assert Lq == Lk
-    assert Lk in {16, 32, 64, 128, 256}
+    assert Lk in {16, 32, 64, 128, 256, 576}
+
+    if Lk == 576:
+        BLOCK_DMODEL = 512
+        BLOCK_DPE = 64
+    else:
+        BLOCK_DMODEL = Lk
+        BLOCK_DPE = 0

     batch, head_num = B_req_idx.shape[0], q.shape[1]

@@ -220,7 +245,8 @@ def _token_att_m_fwd(
         k_buffer.stride(1),
         att_out.stride(0),
         kv_group_num=kv_group_num,
-        BLOCK_DMODEL=Lk,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
         BLOCK_N=BLOCK,
         logit_cap=logit_cap,
         num_warps=num_warps,
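The new BLOCK_DPE path handles a 576-wide attention head by splitting it into a 512-dim block plus a 64-dim positional-encoding block, since Triton block ranges (tl.arange) are generally restricted to powers of two and 576 is not one. Below is a minimal PyTorch sketch of the arithmetic the two-part kernel accumulates; it is illustrative only and not sglang code, with the 512/64 constants mirroring the hunk above.

import torch

# For a 576-dim head, the dot product splits into a 512-dim part and a 64-dim
# part; the kernel adds them as att_value plus the BLOCK_DPE contribution.
BLOCK_DMODEL, BLOCK_DPE = 512, 64
q = torch.randn(BLOCK_DMODEL + BLOCK_DPE)
k = torch.randn(BLOCK_DMODEL + BLOCK_DPE)
full = torch.dot(q, k)
split = torch.dot(q[:BLOCK_DMODEL], k[:BLOCK_DMODEL]) + torch.dot(
    q[BLOCK_DMODEL:], k[BLOCK_DMODEL:]
)
assert torch.allclose(full, split)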
sglang/srt/managers/io_struct.py
CHANGED
@@ -92,7 +92,7 @@ class GenerateReqInput:
                 for element in parallel_sample_num_list
             )
             if parallel_sample_num > 1 and (not all_equal):
-
+                # TODO cope with the case that the parallel_sample_num is different for different samples
                 raise ValueError(
                     "The parallel_sample_num should be the same for all samples in sample params."
                 )
@@ -103,14 +103,19 @@ class GenerateReqInput:
         if parallel_sample_num != 1:
             # parallel sampling +1 represents the original prefill stage
             num = parallel_sample_num + 1
-            if isinstance(self.text,
-
+            if isinstance(self.text, list):
+                # suppot batch operation
                 self.batch_size = len(self.text)
                 num = num * len(self.text)
+            elif isinstance(self.input_ids, list) and isinstance(
+                self.input_ids[0], list
+            ):
+                self.batch_size = len(self.input_ids)
+                num = num * len(self.input_ids)
             else:
                 self.batch_size = 1
         else:
-
+            # support select operation
             num = len(self.text) if self.text is not None else len(self.input_ids)
             self.batch_size = num

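These branches only change how the number of internal requests and the batch size are counted when parallel sampling is enabled, now covering token-id batches in addition to text batches. A standalone sketch of that bookkeeping (hypothetical helper name; logic copied from the hunk above):

def count_requests(text, input_ids, parallel_sample_num):
    # One extra request per prompt for the shared prefill stage, plus
    # parallel_sample_num sampled continuations per prompt.
    if parallel_sample_num != 1:
        num = parallel_sample_num + 1
        if isinstance(text, list):
            batch_size = len(text)
            num *= batch_size
        elif isinstance(input_ids, list) and isinstance(input_ids[0], list):
            batch_size = len(input_ids)
            num *= batch_size
        else:
            batch_size = 1
    else:
        num = len(text) if text is not None else len(input_ids)
        batch_size = num
    return num, batch_size

# Two prompts with parallel_sample_num=3 -> (3 + 1) * 2 = 8 internal requests.
assert count_requests(["a", "b"], None, 3) == (8, 2)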
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -18,7 +18,6 @@ limitations under the License.
 import logging
 import warnings
 from dataclasses import dataclass
-from enum import IntEnum, auto
 from typing import List, Union

 import numpy as np
@@ -29,7 +28,7 @@ from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.mem_cache.radix_cache import RadixCache

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
@@ -39,21 +38,13 @@ global_server_args_dict = {
     "disable_flashinfer": False,
     "disable_flashinfer_sampling": False,
     "attention_reduce_in_fp32": False,
+    "enable_mla": False,
 }


 logger = logging.getLogger(__name__)


-class ForwardMode(IntEnum):
-    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
-    PREFILL = auto()
-    # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
-    EXTEND = auto()
-    # Decode one token.
-    DECODE = auto()
-
-
 class BaseFinishReason:
     def __init__(self, is_error: bool = False):
         self.is_error = is_error
@@ -109,6 +100,9 @@ class Req:
         self.output_ids = []  # Each decode stage's output ids
         self.input_ids = None  # input_ids = origin_input_ids + output_ids

+        # Memory info
+        self.req_pool_idx = None
+
         # For incremental decoding
         # ----- | --------- read_ids -------|
         # ----- | surr_ids |
@@ -283,13 +277,13 @@ class Req:


 @dataclass
-class Batch:
+class ScheduleBatch:
     """Store all inforamtion of a batch."""

     # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: TokenToKVPool
+    token_to_kv_pool: BaseTokenToKVPool
     tree_cache: RadixCache

     # Batched arguments to model runner
@@ -330,6 +324,9 @@ class Batch:
             return_logprob=return_logprob,
         )

+    def batch_size(self):
+        return len(self.reqs) if self.reqs is not None else 0
+
     def is_empty(self):
         return len(self.reqs) == 0

@@ -337,116 +334,127 @@ class Batch:
         # Return whether batch has at least 1 streaming request
         return any(r.stream for r in self.reqs)

+    def alloc_req_slots(self, num_reqs):
+        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
+        if req_pool_indices is None:
+            raise RuntimeError(
+                "Out of memory. "
+                "Please set a smaller number for `--max-running-requests`."
+            )
+        return req_pool_indices
+
+    def alloc_token_slots(self, num_tokens: int):
+        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
+
+        if out_cache_loc is None:
+            if self.tree_cache is not None:
+                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
+                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
+
+            if out_cache_loc is None:
+                logger.error("Prefill out of memory. Try to lower your batch size.")
+                if self.tree_cache is not None:
+                    self.tree_cache.pretty_print()
+                exit(1)
+
+        return out_cache_loc
+
+    def batch_sampling_params(self, vocab_size, int_token_logit_bias):
+        device = "cuda"
+        bs, reqs = self.batch_size(), self.reqs
+        self.temperatures = torch.tensor(
+            [r.sampling_params.temperature for r in reqs],
+            dtype=torch.float,
+            device=device,
+        ).view(-1, 1)
+        self.top_ps = torch.tensor(
+            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
+        )
+        self.top_ks = torch.tensor(
+            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
+        )
+        self.frequency_penalties = torch.tensor(
+            [r.sampling_params.frequency_penalty for r in reqs],
+            dtype=torch.float,
+            device=device,
+        )
+        self.presence_penalties = torch.tensor(
+            [r.sampling_params.presence_penalty for r in reqs],
+            dtype=torch.float,
+            device=device,
+        )
+
+        # Handle logit bias but only allocate when needed
+        self.logit_bias = None
+        for i in range(bs):
+            if reqs[i].sampling_params.dtype == "int":
+                if self.logit_bias is None:
+                    self.logit_bias = torch.zeros(
+                        (bs, vocab_size), dtype=torch.float32, device=device
+                    )
+                self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
+
     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
         device = "cuda"
-        bs = len(self.reqs)
+        bs = self.batch_size()
         reqs = self.reqs
         input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
         prefix_indices = [r.prefix_indices for r in reqs]

         # Handle prefix
-        flatten_input_ids = []
         extend_lens = []
         prefix_lens = []
         seq_lens = []

-
-
-        if req_pool_indices is None:
-            raise RuntimeError(
-                "Out of memory. "
-                "Please set a smaller number for `--max-running-requests`."
-            )
+        req_pool_indices_cpu = self.alloc_req_slots(bs)

-
-
-            flatten_input_ids.extend(input_ids[i])
+        for i, req in enumerate(reqs):
+            req.req_pool_idx = req_pool_indices_cpu[i]
             extend_lens.append(len(input_ids[i]))

             if len(prefix_indices[i]) == 0:
                 prefix_lens.append(0)
             else:
                 prefix_lens.append(len(prefix_indices[i]))
-                self.req_to_token_pool.req_to_token[
+                self.req_to_token_pool.req_to_token[req.req_pool_idx][
                     : len(prefix_indices[i])
                 ] = prefix_indices[i]

             seq_lens.append(prefix_lens[-1] + extend_lens[-1])

-        position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)
-
         # Allocate memory
         seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
-        out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
-        if out_cache_loc is None:
-            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
-            out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
-
-            if out_cache_loc is None:
-                logger.error("Prefill out of memory. This should never happen.")
-                self.tree_cache.pretty_print()
-                exit()
+        out_cache_loc = self.alloc_token_slots(extend_num_tokens)

         pt = 0
-        for i in
-            self.req_to_token_pool.req_to_token[
+        for i, req in enumerate(reqs):
+            self.req_to_token_pool.req_to_token[req.req_pool_idx][
                 prefix_lens[i] : prefix_lens[i] + extend_lens[i]
             ] = out_cache_loc[pt : pt + extend_lens[i]]
             pt += extend_lens[i]

-        # Handle logit bias but only allocate when needed
-        logit_bias = None
-        for i in range(bs):
-            if reqs[i].sampling_params.dtype == "int":
-                if logit_bias is None:
-                    logit_bias = torch.zeros(
-                        (bs, vocab_size), dtype=torch.float32, device=device
-                    )
-                logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
-
         # Set fields
-
-
-
+        with torch.device("cuda"):
+            self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
+            self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
+            self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
+            self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int32)
+
         self.pixel_values = [r.pixel_values for r in reqs]
         self.image_sizes = [r.image_size for r in reqs]
         self.image_offsets = [
             r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
         ]
-        self.req_pool_indices = req_pool_indices
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
         self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
-        self.position_ids_offsets = position_ids_offsets
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]

-        self.temperatures = torch.tensor(
-            [r.sampling_params.temperature for r in reqs],
-            dtype=torch.float,
-            device=device,
-        ).view(-1, 1)
-        self.top_ps = torch.tensor(
-            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-        )
-        self.top_ks = torch.tensor(
-            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-        )
-        self.frequency_penalties = torch.tensor(
-            [r.sampling_params.frequency_penalty for r in reqs],
-            dtype=torch.float,
-            device=device,
-        )
-        self.presence_penalties = torch.tensor(
-            [r.sampling_params.presence_penalty for r in reqs],
-            dtype=torch.float,
-            device=device,
-        )
-        self.logit_bias = logit_bias
+        self.batch_sampling_params(vocab_size, int_token_logit_bias)

     def check_decode_mem(self):
-        bs = len(self.reqs)
+        bs = self.batch_size()
         if self.token_to_kv_pool.available_size() >= bs:
             return True

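Besides extracting the alloc_req_slots, alloc_token_slots, and batch_sampling_params helpers, the rewritten prepare_for_extend builds its tensors under a torch.device context manager instead of passing device= to every factory call. A small sketch of that PyTorch behavior (available in PyTorch 2.0 and later; the tensor values are illustrative):

import torch

# Factory calls inside the `with torch.device(...)` block default to that
# device, so the per-call device=device arguments of the old code are no
# longer needed.
if torch.cuda.is_available():
    with torch.device("cuda"):
        seq_lens = torch.tensor([5, 7, 3], dtype=torch.int32)
        offsets = torch.zeros((3,), dtype=torch.int32)
    assert seq_lens.is_cuda and offsets.is_cuda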
@@ -471,7 +479,6 @@ class Batch:

         retracted_reqs = []
         seq_lens_cpu = self.seq_lens.cpu().numpy()
-        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
         while (
             self.token_to_kv_pool.available_size()
             < len(sorted_indices) * global_config.retract_decode_steps
@@ -489,20 +496,20 @@ class Batch:

             if isinstance(self.tree_cache, ChunkCache):
                 # ChunkCache does not have eviction
-                token_indices = self.req_to_token_pool.req_to_token[
-
-                ]
+                token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
+                    : seq_lens_cpu[idx]
+                ]
                 self.token_to_kv_pool.free(token_indices)
-                self.req_to_token_pool.free(
+                self.req_to_token_pool.free(req.req_pool_idx)
                 del self.tree_cache.entries[req.rid]
             else:
                 # TODO: apply more fine-grained retraction
                 last_uncached_pos = len(req.prefix_indices)
-                token_indices = self.req_to_token_pool.req_to_token[
-
-                ]
+                token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
+                    last_uncached_pos : seq_lens_cpu[idx]
+                ]
                 self.token_to_kv_pool.free(token_indices)
-                self.req_to_token_pool.free(
+                self.req_to_token_pool.free(req.req_pool_idx)

                 # release the last node
                 self.tree_cache.dec_lock_ref(req.last_node)
@@ -540,8 +547,6 @@ class Batch:
         jump_forward_reqs = []
         filter_indices = [i for i in range(len(self.reqs))]

-        req_pool_indices_cpu = None
-
         for i, req in enumerate(self.reqs):
             if req.jump_forward_map is not None:
                 jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
@@ -591,13 +596,11 @@ class Batch:
                 req.vid += 1

                 # insert the old request into tree_cache
-                if req_pool_indices_cpu is None:
-                    req_pool_indices_cpu = self.req_pool_indices.tolist()
                 self.tree_cache.cache_req(
                     rid=req.rid,
                     token_ids=cur_all_ids,
                     last_uncached_pos=len(req.prefix_indices),
-                    req_pool_idx=req_pool_indices_cpu[i],
+                    req_pool_idx=req.req_pool_idx,
                 )

                 # unlock the last node
@@ -633,13 +636,8 @@ class Batch:
         self.prefix_lens = None

         # Alloc mem
-        bs = len(self.reqs)
-        self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
-
-        if self.out_cache_loc is None:
-            logger.error("Decode out of memory. This should never happen.")
-            self.tree_cache.pretty_print()
-            exit()
+        bs = self.batch_size()
+        self.out_cache_loc = self.alloc_token_slots(bs)

         self.req_to_token_pool.req_to_token[
             self.req_pool_indices, self.seq_lens - 1
@@ -669,7 +667,7 @@ class Batch:
             if self_val is not None:  # logit_bias can be None
                 setattr(self, item, self_val[new_indices])

-    def merge(self, other: "Batch"):
+    def merge(self, other: "ScheduleBatch"):
         self.reqs.extend(other.reqs)

         self.req_pool_indices = torch.concat(
@@ -766,229 +764,6 @@ class Batch:
         return batch_next_token_ids


-@dataclass
-class InputMetadata:
-    """Store all inforamtion of a forward pass."""
-
-    forward_mode: ForwardMode
-    batch_size: int
-    total_num_tokens: int
-    req_pool_indices: torch.Tensor
-    seq_lens: torch.Tensor
-    positions: torch.Tensor
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: TokenToKVPool
-
-    # For extend
-    extend_seq_lens: torch.Tensor
-    extend_start_loc: torch.Tensor
-    extend_no_prefix: bool
-
-    # Output location of the KV cache
-    out_cache_loc: torch.Tensor = None
-
-    # Output options
-    return_logprob: bool = False
-    top_logprobs_nums: List[int] = None
-
-    # Trition attention backend
-    triton_max_seq_len: int = 0
-    triton_max_extend_len: int = 0
-    triton_start_loc: torch.Tensor = None
-    triton_prefix_lens: torch.Tensor = None
-
-    # FlashInfer attention backend
-    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
-    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
-    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
-    flashinfer_use_ragged: bool = False
-
-    @classmethod
-    def create(
-        cls,
-        model_runner,
-        forward_mode,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        top_logprobs_nums=None,
-        return_logprob=False,
-        skip_flashinfer_init=False,
-    ):
-        flashinfer_use_ragged = False
-        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
-            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
-                flashinfer_use_ragged = True
-            init_flashinfer_args(
-                forward_mode,
-                model_runner,
-                req_pool_indices,
-                seq_lens,
-                prefix_lens,
-                model_runner.flashinfer_decode_wrapper,
-                flashinfer_use_ragged,
-            )
-
-        batch_size = len(req_pool_indices)
-
-        if forward_mode == ForwardMode.DECODE:
-            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
-            extend_seq_lens = extend_start_loc = extend_no_prefix = None
-            if not model_runner.server_args.disable_flashinfer:
-                # This variable is not needed in this case,
-                # we do not compute it to make it compatbile with cuda graph.
-                total_num_tokens = None
-            else:
-                total_num_tokens = int(torch.sum(seq_lens))
-        else:
-            seq_lens_cpu = seq_lens.cpu().numpy()
-            prefix_lens_cpu = prefix_lens.cpu().numpy()
-            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
-            positions = torch.tensor(
-                np.concatenate(
-                    [
-                        np.arange(
-                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
-                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
-                        )
-                        for i in range(batch_size)
-                    ],
-                    axis=0,
-                ),
-                device="cuda",
-            )
-            extend_seq_lens = seq_lens - prefix_lens
-            extend_start_loc = torch.zeros_like(seq_lens)
-            extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
-            extend_no_prefix = torch.all(prefix_lens == 0)
-            total_num_tokens = int(torch.sum(seq_lens))
-
-        ret = cls(
-            forward_mode=forward_mode,
-            batch_size=batch_size,
-            total_num_tokens=total_num_tokens,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            positions=positions,
-            req_to_token_pool=model_runner.req_to_token_pool,
-            token_to_kv_pool=model_runner.token_to_kv_pool,
-            out_cache_loc=out_cache_loc,
-            extend_seq_lens=extend_seq_lens,
-            extend_start_loc=extend_start_loc,
-            extend_no_prefix=extend_no_prefix,
-            return_logprob=return_logprob,
-            top_logprobs_nums=top_logprobs_nums,
-            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
-            flashinfer_use_ragged=flashinfer_use_ragged,
-        )
-
-        if model_runner.server_args.disable_flashinfer:
-            (
-                ret.triton_max_seq_len,
-                ret.triton_max_extend_len,
-                ret.triton_start_loc,
-                ret.triton_prefix_lens,
-            ) = init_triton_args(forward_mode, seq_lens, prefix_lens)
-
-        return ret
-
-
-def init_flashinfer_args(
-    forward_mode,
-    model_runner,
-    req_pool_indices,
-    seq_lens,
-    prefix_lens,
-    flashinfer_decode_wrapper,
-    flashinfer_use_ragged=False,
-):
-    """Init auxiliary variables for FlashInfer attention backend."""
-    num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
-    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
-    head_dim = model_runner.model_config.head_dim
-    batch_size = len(req_pool_indices)
-    total_num_tokens = int(torch.sum(seq_lens))
-
-    if flashinfer_use_ragged:
-        paged_kernel_lens = prefix_lens
-    else:
-        paged_kernel_lens = seq_lens
-
-    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-    kv_indices = torch.cat(
-        [
-            model_runner.req_to_token_pool.req_to_token[
-                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-            ]
-            for i in range(batch_size)
-        ],
-        dim=0,
-    ).contiguous()
-    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
-
-    if forward_mode == ForwardMode.DECODE:
-        flashinfer_decode_wrapper.end_forward()
-        flashinfer_decode_wrapper.begin_forward(
-            kv_indptr,
-            kv_indices,
-            kv_last_page_len,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-            1,
-        )
-    else:
-        # extend part
-        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
-
-        if flashinfer_use_ragged:
-            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
-            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
-                qo_indptr,
-                qo_indptr,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-            )
-
-        # cached part
-        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
-        model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
-            qo_indptr,
-            kv_indptr,
-            kv_indices,
-            kv_last_page_len,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-            1,
-        )
-
-
-def init_triton_args(forward_mode, seq_lens, prefix_lens):
-    """Init auxiliary variables for triton attention backend."""
-    batch_size = len(seq_lens)
-    max_seq_len = int(torch.max(seq_lens))
-    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-
-    if forward_mode == ForwardMode.DECODE:
-        max_extend_len = None
-    else:
-        extend_seq_lens = seq_lens - prefix_lens
-        max_extend_len = int(torch.max(extend_seq_lens))
-
-    return max_seq_len, max_extend_len, start_loc, prefix_lens
-
-
 def top_k_top_p_sampling_from_probs_torch(
     probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
 ):