sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +13 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +12 -16
- sglang/srt/disaggregation/prefill.py +17 -13
- sglang/srt/disaggregation/utils.py +46 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +22 -28
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +21 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +19 -9
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +207 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +91 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/io_struct.py +9 -12
- sglang/srt/managers/schedule_batch.py +40 -31
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +147 -62
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +76 -45
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +22 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +108 -26
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +36 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/utils.py +177 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/managers/io_struct.py
CHANGED
@@ -226,11 +226,11 @@ class GenerateReqInput:
 
         # Expand input based on type
         self._expand_inputs(num)
+        self._normalize_rid(num)
         self._normalize_lora_paths(num)
         self._normalize_image_data(num)
         self._normalize_audio_data(num)
         self._normalize_sampling_params(num)
-        self._normalize_rid(num)
         self._normalize_logprob_params(num)
         self._normalize_custom_logit_processor(num)
 
@@ -530,6 +530,7 @@ class EmbeddingReqInput:
         if self.text is not None:
             if isinstance(self.text, list):
                 self.batch_size += len(self.text)
+                self.is_single = False
             else:
                 self.batch_size += 1
 
@@ -537,12 +538,10 @@ class EmbeddingReqInput:
         if self.input_ids is not None:
             if isinstance(self.input_ids[0], list):
                 self.batch_size += len(self.input_ids)
+                self.is_single = False
             else:
                 self.batch_size += 1
 
-        if self.batch_size > 1:
-            self.is_single = False
-
         # Fill in default arguments
         if self.is_single:
             if self.rid is None:
@@ -812,7 +811,9 @@ class GetWeightsByNameReqOutput:
 
 @dataclass
 class ReleaseMemoryOccupationReqInput:
-    pass
+    # Optional tags to identify the memory region, which is primarily used for RL
+    # Currently we only support `weights` and `kv_cache`
+    tags: Optional[List[str]] = None
 
 
 @dataclass
@@ -822,7 +823,9 @@ class ReleaseMemoryOccupationReqOutput:
 
 @dataclass
 class ResumeMemoryOccupationReqInput:
-    pass
+    # Optional tags to identify the memory region, which is primarily used for RL
+    # Currently we only support `weights` and `kv_cache`
+    tags: Optional[List[str]] = None
 
 
 @dataclass
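
The two hunks above add an optional `tags` field to `ReleaseMemoryOccupationReqInput` and `ResumeMemoryOccupationReqInput`, so an RL training loop can release and later restore only part of the server's memory. Below is a minimal usage sketch; the dataclasses and the `weights`/`kv_cache` tag values come straight from the diff, while how the requests are dispatched to a running engine is not shown here and is therefore omitted.

```python
# Sketch only: the dataclasses and tag values are taken from the diff above;
# sending the requests to a live engine is outside the scope of this diff.
from sglang.srt.managers.io_struct import (
    ReleaseMemoryOccupationReqInput,
    ResumeMemoryOccupationReqInput,
)

# Release only the weights (e.g., before syncing new RL weights),
# while keeping the KV cache resident.
release_weights = ReleaseMemoryOccupationReqInput(tags=["weights"])

# tags=None (the default) keeps the previous behavior of targeting everything.
release_all = ReleaseMemoryOccupationReqInput()

# Bring the weights back afterwards.
resume_weights = ResumeMemoryOccupationReqInput(tags=["weights"])
```
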
@@ -861,12 +864,6 @@ class SetInternalStateReq:
     server_args: Dict[str, Any]
 
 
-@dataclass
-class V1RerankReqInput:
-    query: str
-    documents: List[str]
-
-
 @dataclass
 class SetInternalStateReqOutput:
     updated: bool
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -38,7 +38,7 @@ import logging
 import threading
 from enum import Enum, auto
 from http import HTTPStatus
-from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
@@ -54,9 +54,10 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
 )
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
 from sglang.srt.layers.multimodal import gpu_tensor_hash
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.metrics.collector import TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -85,6 +86,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "enable_deepep_moe",
     "deepep_mode",
     "enable_ep_moe",
+    "enable_flashinfer_moe",
     "moe_dense_tp_size",
     "ep_dispatch_algorithm",
     "deepep_config",
@@ -99,6 +101,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "torchao_config",
     "triton_attention_reduce_in_fp32",
     "num_reserved_decode_tokens",
+    "weight_loader_disable_mmap",
 ]
 
 # Put some global args for easy access
@@ -436,7 +439,7 @@ class Req:
         self,
         rid: str,
         origin_input_text: str,
-        origin_input_ids:
+        origin_input_ids: List[int],
         sampling_params: SamplingParams,
         return_logprob: bool = False,
         top_logprobs_num: int = 0,
@@ -467,7 +470,7 @@ class Req:
         # Each decode stage's output ids
         self.output_ids = []
         # fill_ids = origin_input_ids + output_ids. Updated if chunked.
-        self.fill_ids =
+        self.fill_ids = []
         self.session_id = session_id
         self.input_embeds = input_embeds
 
@@ -519,13 +522,14 @@ class Req:
 
         # Prefix info
         # The indices to kv cache for the shared prefix.
-        self.prefix_indices = []
+        self.prefix_indices: torch.Tensor = []
         # Number of tokens to run prefill.
         self.extend_input_len = 0
         # The relative logprob_start_len in an extend batch
         self.extend_logprob_start_len = 0
-        self.last_node = None
-        self.
+        self.last_node: Any = None
+        self.last_host_node: Any = None
+        self.host_hit_length = 0
 
         # Whether or not if it is chunked. It increments whenever
         # it is chunked, and decrement whenever chunked request is
@@ -583,6 +587,7 @@ class Req:
             self.output_token_ids_logprobs_idx
         ) = None
         self.hidden_states: List[List[float]] = []
+        self.hidden_states_tensor = None  # Note: use tensor instead of list to transfer hidden_states when PD + MTP
 
         # Embedding (return values)
         self.embedding = None
@@ -644,29 +649,17 @@ class Req:
     def init_next_round_input(
         self,
         tree_cache: Optional[BasePrefixCache] = None,
-        enable_hierarchical_cache=False,
     ):
         self.fill_ids = self.origin_input_ids + self.output_ids
         if tree_cache is not None:
-
-
-            self.
-
-
-
-                )
-
-            self.prefix_indices, self.last_node = tree_cache.match_prefix(
-                rid=self.rid, key=self.adjust_max_prefix_ids()
-            )
-        elif enable_hierarchical_cache:
-            # in case last_node is evicted during scheduling, we need to update the prefix_indices
-            while self.last_node.evicted:
-                self.prefix_indices = self.prefix_indices[
-                    : -len(self.last_node.host_value)
-                ]
-                self.last_node = self.last_node.parent
-
+            (
+                self.prefix_indices,
+                self.last_node,
+                self.last_host_node,
+                self.host_hit_length,
+            ) = tree_cache.match_prefix(
+                key=self.adjust_max_prefix_ids(),
+            )
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
 
     def adjust_max_prefix_ids(self):
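
Note that `match_prefix` now returns four values instead of two: the device prefix indices, the last matched device node, the last matched host node, and the number of prefix tokens that are only available in the host (CPU) cache. The sketch below is schematic and inferred from the call sites in this diff; `_ToyCache` is a made-up stand-in, not the real `BasePrefixCache` interface.

```python
# Schematic consumer of the new four-value match_prefix contract.
# _ToyCache is hypothetical; only the return shape mirrors the diff above.
from typing import Any, List, Tuple

import torch


class _ToyCache:
    def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, Any, Any, int]:
        # (device indices, last device node, last host node, host-only hit length)
        return torch.empty((0,), dtype=torch.int64), None, None, 0


cache = _ToyCache()
prefix_indices, last_node, last_host_node, host_hit_length = cache.match_prefix(
    key=[1, 2, 3]
)
# When host_hit_length > 0, PrefillAdder.add_one_req later calls
# tree_cache.init_load_back(last_host_node, host_hit_length) to move those
# tokens from host memory onto the device before prefill.
```
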
@@ -796,6 +789,7 @@ class Req:
         self.multimodal_inputs = None
         self.grammar = None
         self.origin_input_ids = [0]  # set it to one token to skip the long prefill
+        self.return_logprob = False
         self.finished_reason = FINISH_ABORT(
             error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
         )
@@ -820,7 +814,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool = None
-    token_to_kv_pool_allocator:
+    token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
     tree_cache: BasePrefixCache = None
 
     # Batch configs
@@ -862,6 +856,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     global_num_tokens: Optional[List[int]] = None
     global_num_tokens_for_logprob: Optional[List[int]] = None
     can_run_dp_cuda_graph: bool = False
+    is_extend_in_batch: bool = False
     tbo_split_seq_index: Optional[int] = None
     global_forward_mode: Optional[ForwardMode] = None
 
@@ -908,12 +903,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Whether to return hidden states
     return_hidden_states: bool = False
 
+    # hicache pointer for synchronizing data loading from CPU to GPU
+    hicache_consumer_index: int = 0
+
     @classmethod
     def init_new(
         cls,
         reqs: List[Req],
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool_allocator:
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         tree_cache: BasePrefixCache,
         model_config: ModelConfig,
         enable_overlap: bool,
@@ -1365,7 +1363,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             return len(self.reqs)
         # In the decoding phase, the length of a request's KV cache should be
         # the total length of the request minus 1
-        return
+        return (
+            sum(1 for req in self.reqs if req.seqlen % page_size == 0)
+            if self.enable_overlap
+            else sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0)
+        )
 
     def check_decode_mem(self, buf_multiplier=1):
         tokens_required = (
@@ -1734,6 +1736,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             token_type_ids=self.token_type_ids,
             spec_algorithm=self.spec_algorithm,
             spec_info=self.spec_info,
+            hicache_consumer_index=self.hicache_consumer_index,
             capture_hidden_mode=(
                 CaptureHiddenMode.FULL
                 if self.return_hidden_states
@@ -1760,11 +1763,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             decoding_reqs=self.decoding_reqs,
             spec_algorithm=self.spec_algorithm,
             enable_custom_logit_processor=self.enable_custom_logit_processor,
+            global_num_tokens=self.global_num_tokens,
+            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
+            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
+            is_extend_in_batch=self.is_extend_in_batch,
         )
 
     def __str__(self):
         return (
-            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
+            f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
             f"#req={(len(self.reqs))})"
         )
 
@@ -1833,6 +1840,8 @@ class ModelWorkerBatch:
     spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
+    spec_num_draft_tokens: Optional[int] = None
+    hicache_consumer_index: int = 0
 
     # Overlap event
     launch_done: Optional[threading.Event] = None
sglang/srt/managers/schedule_policy.py
CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,15 +20,17 @@ import random
 from collections import defaultdict
 from contextlib import contextmanager
 from enum import Enum, auto
-from typing import Dict, List, Optional, Set, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union
 
 import torch
 
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
 
+if TYPE_CHECKING:
+    from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+
 # Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
 # This can prevent the server from being too conservative.
 # Note that this only clips the estimation in the scheduler but does not change the stop
@@ -51,6 +55,9 @@ IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int(
 )
 
 
+IGNORE_EOS_RESERVE_TOKENS = 1
+
+
 class CacheAwarePolicy(Enum):
     """Scheduling policies that are aware of the tree cache."""
 
@@ -90,7 +97,7 @@ class SchedulePolicy:
     def calc_priority(self, waiting_queue: List[Req]) -> bool:
         if self.policy == CacheAgnosticPolicy.FCFS:
             # A shortcut for FCFS
-            return
+            return False
 
         policy = self._determine_active_policy(waiting_queue)
 
@@ -134,7 +141,7 @@ class SchedulePolicy:
         """
         try:
             policy_enum = CacheAwarePolicy(policy)
-            if tree_cache
+            if getattr(tree_cache, "disable", True):
                 # If tree_cache is disabled, using CacheAgnosticPolicy policy
                 return CacheAgnosticPolicy.FCFS
             return policy_enum
@@ -158,14 +165,9 @@ class SchedulePolicy:
             prefix_ids = r.adjust_max_prefix_ids()
 
             # NOTE: the prefix_indices must always be aligned with last_node
-
-
-
-                )
-            else:
-                r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
-                    rid=r.rid, key=prefix_ids
-                )
+            r.prefix_indices, r.last_node, r.last_host_node, r.host_hit_length = (
+                self.tree_cache.match_prefix(rid=r.rid, key=prefix_ids)
+            )
 
             # NOTE(sang): This logic is for in-batch prefix caching;
             # If there are more than 1 request that have small matching prefix from
@@ -175,7 +177,7 @@ class SchedulePolicy:
             # threshold means we cannot use in-batch prefix caching for short prefixes.
             # It is kind of common when the engine is long running (e.g., imagine the prefix "the").
             if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
-                in_batch_matching_prefixes, _ = (
+                in_batch_matching_prefixes, _, _, _ = (
                     self.waiting_queue_radix_tree.match_prefix(
                         rid=r.rid, key=prefix_ids
                     )
@@ -268,14 +270,16 @@ class AddReqResult(Enum):
 class PrefillAdder:
     def __init__(
         self,
+        page_size: int,
         tree_cache: BasePrefixCache,
-        token_to_kv_pool_allocator:
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         running_batch: ScheduleBatch,
         new_token_ratio: float,
         rem_input_tokens: int,
         rem_chunk_tokens: Optional[int],
         mixed_with_decode_tokens: int = 0,
     ):
+        self.page_size = page_size
         self.tree_cache = tree_cache
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.running_batch = running_batch
@@ -292,6 +296,7 @@ class PrefillAdder:
         self.can_run_list = []
         self.new_chunked_req = None
         self.log_hit_tokens = 0
+        # TODO(lsyin): report the real input tokens excluding page alignment
        self.log_input_tokens = 0
 
         if running_batch is not None:
@@ -322,6 +327,9 @@ class PrefillAdder:
             - self.cur_rem_token_offset
         )
 
+    def ceil_paged_tokens(self, tokens: int) -> int:
+        return -(-tokens // self.page_size) * self.page_size
+
     def budget_state(self):
         if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
             return AddReqResult.NO_TOKEN
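
The new `ceil_paged_tokens` helper rounds a token count up to the next multiple of the KV-cache page size using the negative-floor-division trick `-(-tokens // page_size) * page_size`. A standalone check of the arithmetic follows; the free function below mirrors the method but takes `page_size` as an explicit argument.

```python
def ceil_paged_tokens(tokens: int, page_size: int) -> int:
    # -(-a // b) is ceiling division for a positive divisor, so the result is
    # tokens rounded up to the nearest multiple of page_size.
    return -(-tokens // page_size) * page_size


assert ceil_paged_tokens(1, 16) == 16    # a single token still occupies one page
assert ceil_paged_tokens(16, 16) == 16   # exact multiples are unchanged
assert ceil_paged_tokens(17, 16) == 32   # one token past a boundary costs a new page
assert ceil_paged_tokens(0, 16) == 0
```
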
@@ -333,9 +341,12 @@ class PrefillAdder:
 
         return AddReqResult.CONTINUE
 
-    def
+    def _update_prefill_budget(
         self, prefix_len: int, extend_input_len: int, max_new_tokens: int
     ):
+        # TODO(lsyin): check this workaround logic, which only ensures the prefill will not out of memory, and may be too conservative
+        extend_input_len = self.ceil_paged_tokens(extend_input_len)
+
         self.rem_total_token_offset += extend_input_len + max_new_tokens
         self.cur_rem_token_offset += extend_input_len
         self.rem_input_tokens -= extend_input_len
@@ -350,7 +361,7 @@ class PrefillAdder:
             req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
             req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
             self.can_run_list.append(req)
-            self.
+            self._update_prefill_budget(
                 0,
                 req.extend_input_len,
                 (
@@ -372,6 +383,12 @@ class PrefillAdder:
             self.tree_cache.dec_lock_ref(last_node)
 
     def add_one_req_ignore_eos(self, req: Req, has_chunked_req: bool):
+        # Early exit if no enough tokens for the input tokens
+        if self.ceil_paged_tokens(req.extend_input_len) > min(
+            self.cur_rem_tokens, self.rem_total_tokens
+        ):
+            return AddReqResult.NO_TOKEN
+
         def add_req_state(r, insert_sort=False):
             new_token_ratio = (
                 1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
@@ -381,15 +398,17 @@ class PrefillAdder:
             )
             tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)
 
-            if tokens_left
-
-
-
-
-
-
-
-                self.req_states
+            if tokens_left <= 0:
+                return
+
+            if not insert_sort:
+                self.req_states.append((tokens_left, tokens_occupied))
+            else:
+                i = 0
+                for i in range(len(self.req_states)):
+                    if tokens_left <= self.req_states[i][0]:
+                        break
+                self.req_states.insert(i, (tokens_left, tokens_occupied))
 
         if self.req_states is None:
             self.req_states = []
@@ -406,13 +425,11 @@ class PrefillAdder:
         cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
         tokens_freed = 0
         for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
-
-                self.req_states[i + 1][0]
-                if i + 1 < len(self.req_states)
-                else tokens_left
-            )
+            # tokens_left gives a reservative calculation as the last token is not stored
             bs = len(self.req_states) - i
-
+            min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs
+            # reserve tokens for corner cases
+            if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs:
                 return AddReqResult.NO_TOKEN
             tokens_freed += tokens_occupied
 
@@ -422,7 +439,7 @@ class PrefillAdder:
         ):
             # Non-chunked prefill
             self.can_run_list.append(req)
-            self.
+            self._update_prefill_budget(
                 0,
                 req.extend_input_len,
                 min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
@@ -438,55 +455,52 @@ class PrefillAdder:
             req.fill_ids = req.fill_ids[:trunc_len]
             self.can_run_list.append(req)
             self.new_chunked_req = req
-            self.
+            self._update_prefill_budget(0, trunc_len, 0)
 
         return self.budget_state()
 
-    def add_one_req(
-        self, req: Req, has_chunked_req: bool, enable_hierarchical_cache: bool = False
-    ):
+    def add_one_req(self, req: Req, has_chunked_req: bool):
         if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
             return self.add_one_req_ignore_eos(req, has_chunked_req)
 
         total_tokens = req.extend_input_len + min(
             req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
         )
-
-
-
-        )
+
+        # adjusting the input_tokens based on host_hit_length and page_size
+        real_input_tokens = req.extend_input_len - req.host_hit_length
+        real_input_tokens = self.ceil_paged_tokens(real_input_tokens)
         prefix_len = len(req.prefix_indices)
 
         if total_tokens >= self.rem_total_tokens:
             return AddReqResult.NO_TOKEN
 
-        if
+        if real_input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0:
            return AddReqResult.OTHER
 
        with self._lock_node(req.last_node):
-
+            # self.rem_total_tokens may decrease after the lock acquisition
+            if total_tokens >= self.rem_total_tokens:
                return AddReqResult.NO_TOKEN
 
-            if
-
-
-                and req.last_node_global.evicted
-            ):
-                req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
-                    req.last_node_global, req.prefix_indices
+            if req.host_hit_length > 0:
+                new_indices, req.last_node = self.tree_cache.init_load_back(
+                    req.last_host_node, req.host_hit_length
                )
+                req.prefix_indices = torch.cat([req.prefix_indices, new_indices])
            req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
-            input_tokens = (
-                -(-req.extend_input_len // self.tree_cache.page_size)
-                * self.tree_cache.page_size
-            )
            prefix_len = len(req.prefix_indices)
 
+            input_tokens = self.ceil_paged_tokens(req.extend_input_len)
+
+            if input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0:
+                return AddReqResult.OTHER
+
            if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
                # Non-chunked prefill
                self.can_run_list.append(req)
                self.tree_cache.inc_lock_ref(req.last_node)
-                self.
+                self._update_prefill_budget(
                    prefix_len,
                    input_tokens,
                    min(
@@ -496,7 +510,7 @@ class PrefillAdder:
                 )
             else:
                 # Make sure at least one page is available
-                trunc_len = self.rem_chunk_tokens - self.
+                trunc_len = self.rem_chunk_tokens - self.page_size + 1
                 if trunc_len <= 0:
                     return AddReqResult.OTHER
 
@@ -507,6 +521,6 @@ class PrefillAdder:
                 self.can_run_list.append(req)
                 self.new_chunked_req = req
                 self.tree_cache.inc_lock_ref(req.last_node)
-                self.
+                self._update_prefill_budget(prefix_len, trunc_len, 0)
 
         return self.budget_state()