sglang 0.2.10__py3-none-any.whl → 0.2.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -0
- sglang/api.py +10 -2
- sglang/bench_latency.py +145 -36
- sglang/check_env.py +24 -2
- sglang/global_config.py +0 -1
- sglang/lang/backend/base_backend.py +3 -1
- sglang/lang/backend/openai.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +46 -29
- sglang/lang/choices.py +164 -0
- sglang/lang/interpreter.py +6 -13
- sglang/lang/ir.py +11 -2
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/radix_attention.py +2 -5
- sglang/srt/managers/schedule_batch.py +95 -324
- sglang/srt/managers/tokenizer_manager.py +6 -3
- sglang/srt/managers/tp_worker.py +20 -22
- sglang/srt/mem_cache/memory_pool.py +9 -14
- sglang/srt/model_executor/cuda_graph_runner.py +3 -3
- sglang/srt/model_executor/forward_batch_info.py +256 -0
- sglang/srt/model_executor/model_runner.py +6 -10
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +1 -1
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -1
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/internlm2.py +1 -1
- sglang/srt/models/llama2.py +1 -1
- sglang/srt/models/llama_classification.py +1 -1
- sglang/srt/models/llava.py +1 -2
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -1
- sglang/srt/models/mixtral.py +1 -1
- sglang/srt/models/mixtral_quant.py +1 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +34 -12
- sglang/srt/openai_api/protocol.py +6 -0
- sglang/srt/server.py +24 -6
- sglang/srt/server_args.py +4 -0
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/METADATA +34 -24
- {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/RECORD +52 -50
- {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/LICENSE +0 -0
- {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/WHEEL +0 -0
- {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/top_level.txt +0 -0

sglang/srt/managers/schedule_batch.py

@@ -18,7 +18,6 @@ limitations under the License.
 import logging
 import warnings
 from dataclasses import dataclass
-from enum import IntEnum, auto
 from typing import List, Union
 
 import numpy as np
@@ -46,15 +45,6 @@ global_server_args_dict = {
 logger = logging.getLogger(__name__)
 
 
-class ForwardMode(IntEnum):
-    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
-    PREFILL = auto()
-    # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
-    EXTEND = auto()
-    # Decode one token.
-    DECODE = auto()
-
-
 class BaseFinishReason:
     def __init__(self, is_error: bool = False):
         self.is_error = is_error
@@ -110,6 +100,9 @@ class Req:
         self.output_ids = []  # Each decode stage's output ids
         self.input_ids = None  # input_ids = origin_input_ids + output_ids
 
+        # Memory info
+        self.req_pool_idx = None
+
         # For incremental decoding
         # ----- | --------- read_ids -------|
         # ----- | surr_ids |
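
Note on the `req_pool_idx` field added above: each `Req` now carries its own slot index into the request-to-token pool, and later hunks index the pool with `req.req_pool_idx` instead of looking the slot up positionally from a batch-level `req_pool_indices` list. A toy sketch of the bookkeeping difference (toy objects, not sglang's real classes):

```python
# Toy illustration: the pool slot is stored on the request itself rather than
# recovered positionally from a separate batch-level list.
class ToyReq:
    def __init__(self, rid: str):
        self.rid = rid
        self.req_pool_idx = None  # new in 0.2.11: set when request slots are allocated


reqs = [ToyReq("a"), ToyReq("b")]
req_pool_indices = [7, 3]  # slots handed out by a pool allocator

# 0.2.10-style: positional lookup into the batch-level list
idx_old = req_pool_indices[1]

# 0.2.11-style: carried by the request object
for req, slot in zip(reqs, req_pool_indices):
    req.req_pool_idx = slot
idx_new = reqs[1].req_pool_idx

assert idx_old == idx_new == 3
```
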
@@ -284,7 +277,7 @@ class Req:
 
 
 @dataclass
-class Batch:
+class ScheduleBatch:
     """Store all inforamtion of a batch."""
 
     # Request, memory pool, and cache
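
The batch dataclass is renamed from `Batch` to `ScheduleBatch` in this release. A minimal migration sketch for downstream code that imported the old name (assuming no compatibility alias is re-exported; verify against the module before relying on this):

```python
# sglang 0.2.10
# from sglang.srt.managers.schedule_batch import Batch

# sglang 0.2.11
from sglang.srt.managers.schedule_batch import ScheduleBatch
```
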
@@ -331,6 +324,9 @@ class Batch:
             return_logprob=return_logprob,
         )
 
+    def batch_size(self):
+        return len(self.reqs) if self.reqs is not None else 0
+
     def is_empty(self):
         return len(self.reqs) == 0
 
@@ -338,118 +334,127 @@ class Batch:
         # Return whether batch has at least 1 streaming request
         return any(r.stream for r in self.reqs)
 
+    def alloc_req_slots(self, num_reqs):
+        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
+        if req_pool_indices is None:
+            raise RuntimeError(
+                "Out of memory. "
+                "Please set a smaller number for `--max-running-requests`."
+            )
+        return req_pool_indices
+
+    def alloc_token_slots(self, num_tokens: int):
+        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
+
+        if out_cache_loc is None:
+            if self.tree_cache is not None:
+                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
+                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
+
+            if out_cache_loc is None:
+                logger.error("Prefill out of memory. Try to lower your batch size.")
+                if self.tree_cache is not None:
+                    self.tree_cache.pretty_print()
+                exit(1)
+
+        return out_cache_loc
+
+    def batch_sampling_params(self, vocab_size, int_token_logit_bias):
+        device = "cuda"
+        bs, reqs = self.batch_size(), self.reqs
+        self.temperatures = torch.tensor(
+            [r.sampling_params.temperature for r in reqs],
+            dtype=torch.float,
+            device=device,
+        ).view(-1, 1)
+        self.top_ps = torch.tensor(
+            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
+        )
+        self.top_ks = torch.tensor(
+            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
+        )
+        self.frequency_penalties = torch.tensor(
+            [r.sampling_params.frequency_penalty for r in reqs],
+            dtype=torch.float,
+            device=device,
+        )
+        self.presence_penalties = torch.tensor(
+            [r.sampling_params.presence_penalty for r in reqs],
+            dtype=torch.float,
+            device=device,
+        )
+
+        # Handle logit bias but only allocate when needed
+        self.logit_bias = None
+        for i in range(bs):
+            if reqs[i].sampling_params.dtype == "int":
+                if self.logit_bias is None:
+                    self.logit_bias = torch.zeros(
+                        (bs, vocab_size), dtype=torch.float32, device=device
+                    )
+                self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
+
     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
         device = "cuda"
-        bs =
+        bs = self.batch_size()
         reqs = self.reqs
         input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
         prefix_indices = [r.prefix_indices for r in reqs]
 
         # Handle prefix
-        flatten_input_ids = []
         extend_lens = []
         prefix_lens = []
         seq_lens = []
 
-
+        req_pool_indices_cpu = self.alloc_req_slots(bs)
 
-
-
-                "Out of memory. "
-                "Please set a smaller number for `--max-running-requests`."
-            )
-
-        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-        for i in range(bs):
-            flatten_input_ids.extend(input_ids[i])
+        for i, req in enumerate(reqs):
+            req.req_pool_idx = req_pool_indices_cpu[i]
             extend_lens.append(len(input_ids[i]))
 
             if len(prefix_indices[i]) == 0:
                 prefix_lens.append(0)
             else:
                 prefix_lens.append(len(prefix_indices[i]))
-                self.req_to_token_pool.req_to_token[
+                self.req_to_token_pool.req_to_token[req.req_pool_idx][
                     : len(prefix_indices[i])
                 ] = prefix_indices[i]
 
             seq_lens.append(prefix_lens[-1] + extend_lens[-1])
 
-        position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)
-
         # Allocate memory
         seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
-        out_cache_loc = self.
-        if out_cache_loc is None:
-            if self.tree_cache is not None:
-                self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
-                out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
-
-            if out_cache_loc is None:
-                logger.error("Prefill out of memory. Try to lower your batch size.")
-                if self.tree_cache is not None:
-                    self.tree_cache.pretty_print()
-                exit(1)
+        out_cache_loc = self.alloc_token_slots(extend_num_tokens)
 
         pt = 0
-        for i in
-            self.req_to_token_pool.req_to_token[
+        for i, req in enumerate(reqs):
+            self.req_to_token_pool.req_to_token[req.req_pool_idx][
                 prefix_lens[i] : prefix_lens[i] + extend_lens[i]
             ] = out_cache_loc[pt : pt + extend_lens[i]]
             pt += extend_lens[i]
 
-        # Handle logit bias but only allocate when needed
-        logit_bias = None
-        for i in range(bs):
-            if reqs[i].sampling_params.dtype == "int":
-                if logit_bias is None:
-                    logit_bias = torch.zeros(
-                        (bs, vocab_size), dtype=torch.float32, device=device
-                    )
-                logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
-
         # Set fields
-
-
-
+        with torch.device("cuda"):
+            self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
+            self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
+            self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
+            self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int32)
+
         self.pixel_values = [r.pixel_values for r in reqs]
         self.image_sizes = [r.image_size for r in reqs]
         self.image_offsets = [
             r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
         ]
-        self.req_pool_indices = req_pool_indices
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
         self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
-        self.position_ids_offsets = position_ids_offsets
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
 
-        self.
-            [r.sampling_params.temperature for r in reqs],
-            dtype=torch.float,
-            device=device,
-        ).view(-1, 1)
-        self.top_ps = torch.tensor(
-            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-        )
-        self.top_ks = torch.tensor(
-            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-        )
-        self.frequency_penalties = torch.tensor(
-            [r.sampling_params.frequency_penalty for r in reqs],
-            dtype=torch.float,
-            device=device,
-        )
-        self.presence_penalties = torch.tensor(
-            [r.sampling_params.presence_penalty for r in reqs],
-            dtype=torch.float,
-            device=device,
-        )
-        self.logit_bias = logit_bias
+        self.batch_sampling_params(vocab_size, int_token_logit_bias)
 
     def check_decode_mem(self):
-        bs =
+        bs = self.batch_size()
         if self.token_to_kv_pool.available_size() >= bs:
             return True
 
@@ -474,7 +479,6 @@ class Batch:
 
         retracted_reqs = []
         seq_lens_cpu = self.seq_lens.cpu().numpy()
-        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
         while (
             self.token_to_kv_pool.available_size()
             < len(sorted_indices) * global_config.retract_decode_steps
@@ -492,20 +496,20 @@ class Batch:
 
             if isinstance(self.tree_cache, ChunkCache):
                 # ChunkCache does not have eviction
-                token_indices = self.req_to_token_pool.req_to_token[
-
-                ]
+                token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
+                    : seq_lens_cpu[idx]
+                ]
                 self.token_to_kv_pool.free(token_indices)
-                self.req_to_token_pool.free(
+                self.req_to_token_pool.free(req.req_pool_idx)
                 del self.tree_cache.entries[req.rid]
             else:
                 # TODO: apply more fine-grained retraction
                 last_uncached_pos = len(req.prefix_indices)
-                token_indices = self.req_to_token_pool.req_to_token[
-
-                ]
+                token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
+                    last_uncached_pos : seq_lens_cpu[idx]
+                ]
                 self.token_to_kv_pool.free(token_indices)
-                self.req_to_token_pool.free(
+                self.req_to_token_pool.free(req.req_pool_idx)
 
                 # release the last node
                 self.tree_cache.dec_lock_ref(req.last_node)
@@ -543,8 +547,6 @@ class Batch:
         jump_forward_reqs = []
         filter_indices = [i for i in range(len(self.reqs))]
 
-        req_pool_indices_cpu = None
-
         for i, req in enumerate(self.reqs):
             if req.jump_forward_map is not None:
                 jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
@@ -594,13 +596,11 @@ class Batch:
                     req.vid += 1
 
                     # insert the old request into tree_cache
-                    if req_pool_indices_cpu is None:
-                        req_pool_indices_cpu = self.req_pool_indices.tolist()
                     self.tree_cache.cache_req(
                         rid=req.rid,
                         token_ids=cur_all_ids,
                         last_uncached_pos=len(req.prefix_indices),
-                        req_pool_idx=
+                        req_pool_idx=req.req_pool_idx,
                     )
 
                     # unlock the last node
@@ -636,14 +636,8 @@ class Batch:
         self.prefix_lens = None
 
         # Alloc mem
-        bs =
-        self.out_cache_loc = self.
-
-        if self.out_cache_loc is None:
-            logger.error("Decode out of memory. Try to lower your batch size.")
-            if self.tree_cache is not None:
-                self.tree_cache.pretty_print()
-            exit(1)
+        bs = self.batch_size()
+        self.out_cache_loc = self.alloc_token_slots(bs)
 
         self.req_to_token_pool.req_to_token[
             self.req_pool_indices, self.seq_lens - 1
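
Both the prefill path (`prepare_for_extend`) and the decode path above now delegate KV slot allocation to the new `ScheduleBatch.alloc_token_slots` helper, which keeps the try-allocate, evict-from-cache, retry, then fail pattern in one place. Below is a self-contained toy sketch of that pattern; `TinyPool` and `TinyCache` are illustrative stand-ins, not sglang's `TokenToKVPool` or radix-cache APIs, and the sketch raises an exception where the real code logs an error and exits:

```python
from typing import List, Optional


class TinyPool:
    """A toy slot pool: hands out slot indices until it runs dry."""

    def __init__(self, size: int) -> None:
        self.free_slots: List[int] = list(range(size))

    def alloc(self, n: int) -> Optional[List[int]]:
        if len(self.free_slots) < n:
            return None
        out, self.free_slots = self.free_slots[:n], self.free_slots[n:]
        return out

    def free(self, slots: List[int]) -> None:
        self.free_slots.extend(slots)


class TinyCache:
    """A toy cache that can hand slots back to the pool when asked to evict."""

    def __init__(self, cached_slots: List[int]) -> None:
        self.cached_slots = cached_slots

    def evict(self, num_tokens: int, free_fn) -> None:
        victims = self.cached_slots[:num_tokens]
        self.cached_slots = self.cached_slots[num_tokens:]
        free_fn(victims)


def alloc_token_slots(pool: TinyPool, cache: TinyCache, num_tokens: int) -> List[int]:
    # Try the pool first; on failure, evict cached entries and retry once,
    # mirroring the shape of ScheduleBatch.alloc_token_slots in the diff above.
    out = pool.alloc(num_tokens)
    if out is None:
        cache.evict(num_tokens, pool.free)
        out = pool.alloc(num_tokens)
    if out is None:
        raise RuntimeError("Prefill out of memory. Try to lower your batch size.")
    return out


if __name__ == "__main__":
    pool = TinyPool(size=4)
    cache = TinyCache(cached_slots=pool.alloc(3))  # cache holds 3 of the 4 slots
    print(alloc_token_slots(pool, cache, 3))       # evicts, retries, succeeds
```

Running the script allocates past the pool's remaining free capacity by evicting cached slots first, which is the eviction-then-retry behavior the refactor centralizes.
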
@@ -673,7 +667,7 @@ class Batch:
             if self_val is not None:  # logit_bias can be None
                 setattr(self, item, self_val[new_indices])
 
-    def merge(self, other: "Batch"):
+    def merge(self, other: "ScheduleBatch"):
         self.reqs.extend(other.reqs)
 
         self.req_pool_indices = torch.concat(
@@ -770,229 +764,6 @@ class Batch:
         return batch_next_token_ids
 
 
-@dataclass
-class InputMetadata:
-    """Store all inforamtion of a forward pass."""
-
-    forward_mode: ForwardMode
-    batch_size: int
-    total_num_tokens: int
-    req_pool_indices: torch.Tensor
-    seq_lens: torch.Tensor
-    positions: torch.Tensor
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: BaseTokenToKVPool
-
-    # For extend
-    extend_seq_lens: torch.Tensor
-    extend_start_loc: torch.Tensor
-    extend_no_prefix: bool
-
-    # Output location of the KV cache
-    out_cache_loc: torch.Tensor = None
-
-    # Output options
-    return_logprob: bool = False
-    top_logprobs_nums: List[int] = None
-
-    # Trition attention backend
-    triton_max_seq_len: int = 0
-    triton_max_extend_len: int = 0
-    triton_start_loc: torch.Tensor = None
-    triton_prefix_lens: torch.Tensor = None
-
-    # FlashInfer attention backend
-    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
-    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
-    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
-    flashinfer_use_ragged: bool = False
-
-    @classmethod
-    def create(
-        cls,
-        model_runner,
-        forward_mode,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        top_logprobs_nums=None,
-        return_logprob=False,
-        skip_flashinfer_init=False,
-    ):
-        flashinfer_use_ragged = False
-        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
-            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
-                flashinfer_use_ragged = True
-            init_flashinfer_args(
-                forward_mode,
-                model_runner,
-                req_pool_indices,
-                seq_lens,
-                prefix_lens,
-                model_runner.flashinfer_decode_wrapper,
-                flashinfer_use_ragged,
-            )
-
-        batch_size = len(req_pool_indices)
-
-        if forward_mode == ForwardMode.DECODE:
-            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
-            extend_seq_lens = extend_start_loc = extend_no_prefix = None
-            if not model_runner.server_args.disable_flashinfer:
-                # This variable is not needed in this case,
-                # we do not compute it to make it compatbile with cuda graph.
-                total_num_tokens = None
-            else:
-                total_num_tokens = int(torch.sum(seq_lens))
-        else:
-            seq_lens_cpu = seq_lens.cpu().numpy()
-            prefix_lens_cpu = prefix_lens.cpu().numpy()
-            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
-            positions = torch.tensor(
-                np.concatenate(
-                    [
-                        np.arange(
-                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
-                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
-                        )
-                        for i in range(batch_size)
-                    ],
-                    axis=0,
-                ),
-                device="cuda",
-            )
-            extend_seq_lens = seq_lens - prefix_lens
-            extend_start_loc = torch.zeros_like(seq_lens)
-            extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
-            extend_no_prefix = torch.all(prefix_lens == 0)
-            total_num_tokens = int(torch.sum(seq_lens))
-
-        ret = cls(
-            forward_mode=forward_mode,
-            batch_size=batch_size,
-            total_num_tokens=total_num_tokens,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            positions=positions,
-            req_to_token_pool=model_runner.req_to_token_pool,
-            token_to_kv_pool=model_runner.token_to_kv_pool,
-            out_cache_loc=out_cache_loc,
-            extend_seq_lens=extend_seq_lens,
-            extend_start_loc=extend_start_loc,
-            extend_no_prefix=extend_no_prefix,
-            return_logprob=return_logprob,
-            top_logprobs_nums=top_logprobs_nums,
-            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
-            flashinfer_use_ragged=flashinfer_use_ragged,
-        )
-
-        if model_runner.server_args.disable_flashinfer:
-            (
-                ret.triton_max_seq_len,
-                ret.triton_max_extend_len,
-                ret.triton_start_loc,
-                ret.triton_prefix_lens,
-            ) = init_triton_args(forward_mode, seq_lens, prefix_lens)
-
-        return ret
-
-
-def init_flashinfer_args(
-    forward_mode,
-    model_runner,
-    req_pool_indices,
-    seq_lens,
-    prefix_lens,
-    flashinfer_decode_wrapper,
-    flashinfer_use_ragged=False,
-):
-    """Init auxiliary variables for FlashInfer attention backend."""
-    num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
-    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
-    head_dim = model_runner.model_config.head_dim
-    batch_size = len(req_pool_indices)
-    total_num_tokens = int(torch.sum(seq_lens))
-
-    if flashinfer_use_ragged:
-        paged_kernel_lens = prefix_lens
-    else:
-        paged_kernel_lens = seq_lens
-
-    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-    kv_indices = torch.cat(
-        [
-            model_runner.req_to_token_pool.req_to_token[
-                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-            ]
-            for i in range(batch_size)
-        ],
-        dim=0,
-    ).contiguous()
-    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
-
-    if forward_mode == ForwardMode.DECODE:
-        flashinfer_decode_wrapper.end_forward()
-        flashinfer_decode_wrapper.begin_forward(
-            kv_indptr,
-            kv_indices,
-            kv_last_page_len,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-            1,
-        )
-    else:
-        # extend part
-        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
-
-        if flashinfer_use_ragged:
-            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
-            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
-                qo_indptr,
-                qo_indptr,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-            )
-
-        # cached part
-        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
-        model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
-            qo_indptr,
-            kv_indptr,
-            kv_indices,
-            kv_last_page_len,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-            1,
-        )
-
-
-def init_triton_args(forward_mode, seq_lens, prefix_lens):
-    """Init auxiliary variables for triton attention backend."""
-    batch_size = len(seq_lens)
-    max_seq_len = int(torch.max(seq_lens))
-    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-
-    if forward_mode == ForwardMode.DECODE:
-        max_extend_len = None
-    else:
-        extend_seq_lens = seq_lens - prefix_lens
-        max_extend_len = int(torch.max(extend_seq_lens))
-
-    return max_seq_len, max_extend_len, start_loc, prefix_lens
-
-
 def top_k_top_p_sampling_from_probs_torch(
     probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
 ):
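
The `ForwardMode` enum, the `InputMetadata` dataclass, and the FlashInfer/Triton initialization helpers removed above are not simply dropped: judging by the file list at the top, they appear to move into the new `sglang/srt/model_executor/forward_batch_info.py` (+256 lines). A hedged sketch of the corresponding import update (the new module path is inferred from the file list, not verified here):

```python
# sglang 0.2.10
# from sglang.srt.managers.schedule_batch import ForwardMode, InputMetadata

# sglang 0.2.11 (assumed new location, based on the added forward_batch_info.py file)
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
```
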

sglang/srt/managers/tokenizer_manager.py

@@ -308,7 +308,6 @@ class TokenizerManager:
             event = asyncio.Event()
             state = ReqState([], False, event)
             self.rid_to_state[rid] = state
-
         # Then wait for all responses
         output_list = []
         for i in range(batch_size):
@@ -341,7 +340,6 @@ class TokenizerManager:
                 )
                 assert state.finished
                 del self.rid_to_state[rid]
-
             yield output_list
 
     def _validate_input_length(self, input_ids: List[int]):
@@ -390,8 +388,13 @@ class TokenizerManager:
                 obj.return_text_in_logprobs,
             )
 
+            # Log requests
             if self.server_args.log_requests and state.finished:
-
+                if obj.text is None:
+                    in_obj = {"text": self.tokenizer.decode(obj.input_ids)}
+                else:
+                    in_obj = {"text": obj.text}
+                logger.info(f"in={in_obj}, out={out}")
 
             state.out_list = []
             if state.finished: