sglang-0.4.7-py3-none-any.whl → sglang-0.4.8-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_one_batch.py +8 -6
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +19 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +211 -72
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +30 -29
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +144 -55
- sglang/srt/disaggregation/utils.py +155 -123
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +37 -29
- sglang/srt/entrypoints/http_server.py +153 -72
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +40 -3
- sglang/srt/layers/attention/aiter_backend.py +20 -4
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
- sglang/srt/layers/attention/flashattention_backend.py +71 -72
- sglang/srt/layers/attention/flashinfer_backend.py +10 -8
- sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -12
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +138 -130
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +28 -10
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +2 -14
- sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
- sglang/srt/layers/moe/ep_moe/layer.py +249 -33
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +107 -12
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +249 -105
- sglang/srt/lora/mem_pool.py +53 -50
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -14
- sglang/srt/managers/io_struct.py +31 -10
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +79 -37
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +220 -79
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +40 -10
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -15
- sglang/srt/mem_cache/hiradix_cache.py +38 -25
- sglang/srt/mem_cache/memory_pool.py +213 -505
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +56 -28
- sglang/srt/model_executor/cuda_graph_runner.py +198 -100
- sglang/srt/model_executor/forward_batch_info.py +32 -10
- sglang/srt/model_executor/model_runner.py +28 -12
- sglang/srt/model_loader/loader.py +16 -2
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +213 -173
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +351 -238
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
- sglang/srt/speculative/eagle_utils.py +468 -116
- sglang/srt/speculative/eagle_worker.py +258 -84
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +235 -11
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +2 -0
- sglang/test/test_utils.py +4 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -1990
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
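To reproduce the full wheel-to-wheel comparison locally, a minimal sketch (assuming both wheels have been downloaded next to the script, e.g. with `pip download sglang==0.4.7 --no-deps` and `pip download sglang==0.4.8 --no-deps`; the file names below are the standard wheel names, not part of this page):

```python
# Compare the Python sources inside two locally downloaded wheels (zip archives).
import difflib
import zipfile

OLD_WHEEL = "sglang-0.4.7-py3-none-any.whl"
NEW_WHEEL = "sglang-0.4.8-py3-none-any.whl"

def read_py_members(path):
    # Map member name -> list of text lines (with newlines kept for difflib).
    with zipfile.ZipFile(path) as zf:
        return {
            name: zf.read(name).decode("utf-8", errors="replace").splitlines(keepends=True)
            for name in zf.namelist()
            if name.endswith(".py")
        }

old, new = read_py_members(OLD_WHEEL), read_py_members(NEW_WHEEL)
for name in sorted(set(old) | set(new)):
    for line in difflib.unified_diff(
        old.get(name, []), new.get(name, []), fromfile=name, tofile=name
    ):
        print(line, end="")
```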
sglang/srt/managers/schedule_batch.py

@@ -38,7 +38,7 @@ import logging
 import threading
 from enum import Enum, auto
 from http import HTTPStatus
-from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
@@ -54,9 +54,10 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
 )
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
 from sglang.srt.layers.multimodal import gpu_tensor_hash
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.metrics.collector import TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -72,32 +73,35 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
 GLOBAL_SERVER_ARGS_KEYS = [
     "attention_backend",
+    "mm_attention_backend",
     "debug_tensor_dump_inject",
     "debug_tensor_dump_output_folder",
     "chunked_prefill_size",
-    "deepep_mode",
     "device",
     "disable_chunked_prefix_cache",
     "disable_radix_cache",
-    "enable_deepep_moe",
     "enable_dp_attention",
     "enable_two_batch_overlap",
     "enable_dp_lm_head",
+    "enable_deepep_moe",
+    "deepep_mode",
     "enable_ep_moe",
+    "enable_flashinfer_moe",
+    "moe_dense_tp_size",
+    "ep_dispatch_algorithm",
     "deepep_config",
+    "ep_num_redundant_experts",
     "enable_nan_detection",
     "flashinfer_mla_disable_ragged",
     "max_micro_batch_size",
-    "moe_dense_tp_size",
-    "ep_dispatch_algorithm",
     "disable_shared_experts_fusion",
     "sampling_backend",
     "speculative_accept_threshold_acc",
     "speculative_accept_threshold_single",
     "torchao_config",
     "triton_attention_reduce_in_fp32",
-    "
-    "
+    "num_reserved_decode_tokens",
+    "weight_loader_disable_mmap",
 ]
 
 # Put some global args for easy access
@@ -435,7 +439,7 @@ class Req:
         self,
         rid: str,
         origin_input_text: str,
-        origin_input_ids:
+        origin_input_ids: List[int],
         sampling_params: SamplingParams,
         return_logprob: bool = False,
         top_logprobs_num: int = 0,
@@ -444,6 +448,7 @@ class Req:
         origin_input_ids_unpadded: Optional[Tuple[int]] = None,
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
+        token_type_ids: List[int] = None,
         session_id: Optional[str] = None,
         custom_logit_processor: Optional[str] = None,
         return_hidden_states: bool = False,
@@ -465,10 +470,13 @@ class Req:
         # Each decode stage's output ids
         self.output_ids = []
         # fill_ids = origin_input_ids + output_ids. Updated if chunked.
-        self.fill_ids =
+        self.fill_ids = []
         self.session_id = session_id
         self.input_embeds = input_embeds
 
+        # for cross-encoder model
+        self.token_type_ids = token_type_ids
+
         # Sampling info
         if isinstance(sampling_params.custom_params, dict):
             sampling_params = copy.copy(sampling_params)
@@ -514,13 +522,14 @@ class Req:
 
         # Prefix info
         # The indices to kv cache for the shared prefix.
-        self.prefix_indices = []
+        self.prefix_indices: torch.Tensor = []
         # Number of tokens to run prefill.
         self.extend_input_len = 0
         # The relative logprob_start_len in an extend batch
         self.extend_logprob_start_len = 0
-        self.last_node = None
-        self.
+        self.last_node: Any = None
+        self.last_host_node: Any = None
+        self.host_hit_length = 0
 
         # Whether or not if it is chunked. It increments whenever
         # it is chunked, and decrement whenever chunked request is
@@ -578,6 +587,7 @@ class Req:
             self.output_token_ids_logprobs_idx
         ) = None
         self.hidden_states: List[List[float]] = []
+        self.hidden_states_tensor = None  # Note: use tensor instead of list to transfer hidden_states when PD + MTP
 
         # Embedding (return values)
         self.embedding = None
@@ -639,29 +649,17 @@ class Req:
     def init_next_round_input(
         self,
         tree_cache: Optional[BasePrefixCache] = None,
-        enable_hierarchical_cache=False,
     ):
         self.fill_ids = self.origin_input_ids + self.output_ids
         if tree_cache is not None:
-
-
-            self.
-
-
-
-            )
-
-            self.prefix_indices, self.last_node = tree_cache.match_prefix(
-                rid=self.rid, key=self.adjust_max_prefix_ids()
-            )
-        elif enable_hierarchical_cache:
-            # in case last_node is evicted during scheduling, we need to update the prefix_indices
-            while self.last_node.evicted:
-                self.prefix_indices = self.prefix_indices[
-                    : -len(self.last_node.host_value)
-                ]
-                self.last_node = self.last_node.parent
-
+            (
+                self.prefix_indices,
+                self.last_node,
+                self.last_host_node,
+                self.host_hit_length,
+            ) = tree_cache.match_prefix(
+                key=self.adjust_max_prefix_ids(),
+            )
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
 
     def adjust_max_prefix_ids(self):
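The hunk above changes the `tree_cache.match_prefix(...)` call site: 0.4.7 unpacked two values, while 0.4.8 unpacks four, additionally carrying the deepest matching node in the host (CPU) cache and the number of extra tokens that hit only there. A minimal sketch of the new unpacking (the stub below is illustrative only, not the sglang implementation):

```python
from typing import Any, List, Tuple

def match_prefix_stub(key: List[int]) -> Tuple[List[int], Any, Any, int]:
    # Mirrors the 4-value unpacking used in init_next_round_input above.
    prefix_indices: List[int] = []   # device KV indices of the matched prefix
    last_node = None                 # deepest matching node in the device radix tree
    last_host_node = None            # deepest matching node in the host cache
    host_hit_length = 0              # extra tokens available only on the host
    return prefix_indices, last_node, last_host_node, host_hit_length

prefix_indices, last_node, last_host_node, host_hit_length = match_prefix_stub([1, 2, 3])
```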
@@ -791,6 +789,7 @@ class Req:
         self.multimodal_inputs = None
         self.grammar = None
         self.origin_input_ids = [0]  # set it to one token to skip the long prefill
+        self.return_logprob = False
         self.finished_reason = FINISH_ABORT(
             error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
         )
@@ -815,7 +814,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool = None
-    token_to_kv_pool_allocator:
+    token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
     tree_cache: BasePrefixCache = None
 
     # Batch configs
@@ -840,6 +839,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Batched arguments to model runner
     input_ids: torch.Tensor = None  # shape: [b], int64
     input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
+    token_type_ids: torch.Tensor = None  # shape: [b], int64
     req_pool_indices: torch.Tensor = None  # shape: [b], int64
     seq_lens: torch.Tensor = None  # shape: [b], int64
     # The output locations of the KV cache
@@ -856,6 +856,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     global_num_tokens: Optional[List[int]] = None
     global_num_tokens_for_logprob: Optional[List[int]] = None
     can_run_dp_cuda_graph: bool = False
+    is_extend_in_batch: bool = False
     tbo_split_seq_index: Optional[int] = None
     global_forward_mode: Optional[ForwardMode] = None
 
@@ -902,12 +903,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Whether to return hidden states
     return_hidden_states: bool = False
 
+    # hicache pointer for synchronizing data loading from CPU to GPU
+    hicache_consumer_index: int = 0
+
     @classmethod
     def init_new(
         cls,
         reqs: List[Req],
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool_allocator:
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         tree_cache: BasePrefixCache,
         model_config: ModelConfig,
         enable_overlap: bool,
@@ -1141,6 +1145,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         prefix_lens = [len(r.prefix_indices) for r in reqs]
         extend_lens = [r.extend_input_len for r in reqs]
 
+        token_type_ids = [
+            r.token_type_ids for r in reqs if r.token_type_ids is not None
+        ]
+
         req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
@@ -1153,6 +1161,13 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         prefix_lens_tensor = torch.tensor(
             prefix_lens, dtype=torch.int64, device=self.device
         )
+
+        token_type_ids_tensor = None
+        if len(token_type_ids) > 0:
+            token_type_ids_tensor = torch.tensor(
+                sum(token_type_ids, []), dtype=torch.int64
+            ).to(self.device, non_blocking=True)
+
         extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor
 
         # Copy prefix and do some basic check
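The two hunks above gather the per-request `token_type_ids` lists and flatten them into a single device tensor; `sum(lists, [])` is plain list concatenation. A standalone sketch of the same flattening, with made-up values:

```python
import torch

# One token_type_ids list per request (made-up values for two requests).
token_type_ids = [[0, 0, 1, 1], [0, 1]]

flat = sum(token_type_ids, [])  # [0, 0, 1, 1, 0, 1]
token_type_ids_tensor = torch.tensor(flat, dtype=torch.int64)
print(token_type_ids_tensor)    # tensor([0, 0, 1, 1, 0, 1])
```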
@@ -1268,6 +1283,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.device, non_blocking=True
         )
         self.multimodal_inputs = multimodal_inputs
+        self.token_type_ids = token_type_ids_tensor
         self.seq_lens_sum = sum(seq_lens)
 
         if self.return_logprob:
@@ -1347,7 +1363,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             return len(self.reqs)
         # In the decoding phase, the length of a request's KV cache should be
         # the total length of the request minus 1
-        return
+        return (
+            sum(1 for req in self.reqs if req.seqlen % page_size == 0)
+            if self.enable_overlap
+            else sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0)
+        )
 
     def check_decode_mem(self, buf_multiplier=1):
         tokens_required = (
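The rewritten return value above counts the requests whose cached length lands exactly on a page boundary, i.e. those that will need a fresh KV-cache page for the next decoded token (without overlap scheduling the cache holds `seqlen - 1` tokens; with overlap the check shifts by one token). A worked example with made-up numbers:

```python
page_size = 4
seqlens = [8, 9, 12, 13]

# (seqlen - 1) tokens are already cached; a multiple of page_size means the
# next token starts a new page.
new_pages_needed = sum(1 for s in seqlens if (s - 1) % page_size == 0)
print(new_pages_needed)  # 2  (the requests with seqlen 9 and 13)
```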
@@ -1414,6 +1434,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             req = self.reqs[idx]
             retracted_reqs.append(req)
 
+            if server_args.disaggregation_mode == "decode":
+                req.offload_kv_cache(
+                    self.req_to_token_pool, self.token_to_kv_pool_allocator
+                )
+
             if isinstance(self.tree_cache, ChunkCache):
                 # ChunkCache does not have eviction
                 token_indices = self.req_to_token_pool.req_to_token[
@@ -1445,6 +1470,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
             req.reset_for_retract()
 
+        if len(retracted_reqs) == 0:
+            # Corner case: only one request left
+            raise ValueError(
+                "Failed to retract any request. No space left for only one request."
+            )
+
         self.filter_batch(keep_indices=sorted_indices)
 
         # Reqs in batch are filtered
@@ -1702,8 +1733,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
             input_embeds=self.input_embeds,
+            token_type_ids=self.token_type_ids,
             spec_algorithm=self.spec_algorithm,
             spec_info=self.spec_info,
+            hicache_consumer_index=self.hicache_consumer_index,
             capture_hidden_mode=(
                 CaptureHiddenMode.FULL
                 if self.return_hidden_states
|
|
1730
1763
|
decoding_reqs=self.decoding_reqs,
|
1731
1764
|
spec_algorithm=self.spec_algorithm,
|
1732
1765
|
enable_custom_logit_processor=self.enable_custom_logit_processor,
|
1766
|
+
global_num_tokens=self.global_num_tokens,
|
1767
|
+
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
|
1768
|
+
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
|
1769
|
+
is_extend_in_batch=self.is_extend_in_batch,
|
1733
1770
|
)
|
1734
1771
|
|
1735
1772
|
def __str__(self):
|
1736
1773
|
return (
|
1737
|
-
f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
|
1774
|
+
f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
|
1738
1775
|
f"#req={(len(self.reqs))})"
|
1739
1776
|
)
|
1740
1777
|
|
@@ -1795,11 +1832,16 @@ class ModelWorkerBatch:
     # The input Embeds
     input_embeds: Optional[torch.tensor] = None
 
+    # For cross-encoder model
+    token_type_ids: Optional[torch.Tensor] = None
+
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
+    spec_num_draft_tokens: Optional[int] = None
+    hicache_consumer_index: int = 0
 
     # Overlap event
     launch_done: Optional[threading.Event] = None
sglang/srt/managers/schedule_policy.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,15 +20,17 @@ import random
 from collections import defaultdict
 from contextlib import contextmanager
 from enum import Enum, auto
-from typing import Dict, List, Optional, Set, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union
 
 import torch
 
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
 
+if TYPE_CHECKING:
+    from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+
 # Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
 # This can prevent the server from being too conservative.
 # Note that this only clips the estimation in the scheduler but does not change the stop
@@ -51,6 +55,9 @@ IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int(
 )
 
 
+IGNORE_EOS_RESERVE_TOKENS = 1
+
+
 class CacheAwarePolicy(Enum):
     """Scheduling policies that are aware of the tree cache."""
 
@@ -90,7 +97,7 @@ class SchedulePolicy:
     def calc_priority(self, waiting_queue: List[Req]) -> bool:
         if self.policy == CacheAgnosticPolicy.FCFS:
             # A shortcut for FCFS
-            return
+            return False
 
         policy = self._determine_active_policy(waiting_queue)
 
@@ -134,7 +141,7 @@ class SchedulePolicy:
         """
         try:
             policy_enum = CacheAwarePolicy(policy)
-            if tree_cache
+            if getattr(tree_cache, "disable", True):
                 # If tree_cache is disabled, using CacheAgnosticPolicy policy
                 return CacheAgnosticPolicy.FCFS
             return policy_enum
@@ -158,14 +165,9 @@ class SchedulePolicy:
             prefix_ids = r.adjust_max_prefix_ids()
 
             # NOTE: the prefix_indices must always be aligned with last_node
-
-
-
-                )
-            else:
-                r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
-                    rid=r.rid, key=prefix_ids
-                )
+            r.prefix_indices, r.last_node, r.last_host_node, r.host_hit_length = (
+                self.tree_cache.match_prefix(rid=r.rid, key=prefix_ids)
+            )
 
             # NOTE(sang): This logic is for in-batch prefix caching;
             # If there are more than 1 request that have small matching prefix from
@@ -175,7 +177,7 @@ class SchedulePolicy:
             # threshold means we cannot use in-batch prefix caching for short prefixes.
             # It is kind of common when the engine is long running (e.g., imagine the prefix "the").
             if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
-                in_batch_matching_prefixes, _ = (
+                in_batch_matching_prefixes, _, _, _ = (
                     self.waiting_queue_radix_tree.match_prefix(
                         rid=r.rid, key=prefix_ids
                     )
@@ -268,14 +270,16 @@ class AddReqResult(Enum):
 class PrefillAdder:
     def __init__(
         self,
+        page_size: int,
         tree_cache: BasePrefixCache,
-        token_to_kv_pool_allocator:
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         running_batch: ScheduleBatch,
         new_token_ratio: float,
         rem_input_tokens: int,
         rem_chunk_tokens: Optional[int],
         mixed_with_decode_tokens: int = 0,
     ):
+        self.page_size = page_size
         self.tree_cache = tree_cache
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.running_batch = running_batch
@@ -292,6 +296,7 @@ class PrefillAdder:
         self.can_run_list = []
         self.new_chunked_req = None
         self.log_hit_tokens = 0
+        # TODO(lsyin): report the real input tokens excluding page alignment
        self.log_input_tokens = 0
 
         if running_batch is not None:
|
|
322
327
|
- self.cur_rem_token_offset
|
323
328
|
)
|
324
329
|
|
330
|
+
def ceil_paged_tokens(self, tokens: int) -> int:
|
331
|
+
return -(-tokens // self.page_size) * self.page_size
|
332
|
+
|
325
333
|
def budget_state(self):
|
326
334
|
if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
|
327
335
|
return AddReqResult.NO_TOKEN
|
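The new `ceil_paged_tokens` helper rounds a token count up to the next multiple of the page size using integer ceiling division (`-(-a // b)`), so no floating point is involved. A standalone version of the same arithmetic:

```python
def ceil_paged_tokens(tokens: int, page_size: int) -> int:
    # Ceiling division without floats: -(-a // b) == ceil(a / b) for positive b.
    return -(-tokens // page_size) * page_size

assert ceil_paged_tokens(0, 16) == 0
assert ceil_paged_tokens(1, 16) == 16
assert ceil_paged_tokens(16, 16) == 16
assert ceil_paged_tokens(17, 16) == 32
```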
@@ -333,9 +341,12 @@ class PrefillAdder:
 
         return AddReqResult.CONTINUE
 
-    def
+    def _update_prefill_budget(
         self, prefix_len: int, extend_input_len: int, max_new_tokens: int
     ):
+        # TODO(lsyin): check this workaround logic, which only ensures the prefill will not run out of memory, and may be too conservative
+        extend_input_len = self.ceil_paged_tokens(extend_input_len)
+
         self.rem_total_token_offset += extend_input_len + max_new_tokens
         self.cur_rem_token_offset += extend_input_len
         self.rem_input_tokens -= extend_input_len
|
|
350
361
|
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
|
351
362
|
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
|
352
363
|
self.can_run_list.append(req)
|
353
|
-
self.
|
364
|
+
self._update_prefill_budget(
|
354
365
|
0,
|
355
366
|
req.extend_input_len,
|
356
367
|
(
|
@@ -372,6 +383,12 @@ class PrefillAdder:
             self.tree_cache.dec_lock_ref(last_node)
 
     def add_one_req_ignore_eos(self, req: Req, has_chunked_req: bool):
+        # Early exit if there are not enough tokens for the input tokens
+        if self.ceil_paged_tokens(req.extend_input_len) > min(
+            self.cur_rem_tokens, self.rem_total_tokens
+        ):
+            return AddReqResult.NO_TOKEN
+
         def add_req_state(r, insert_sort=False):
             new_token_ratio = (
                 1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
@@ -381,15 +398,17 @@ class PrefillAdder:
             )
             tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)
 
-            if tokens_left
-
-
-
-
-
-
-
-                self.req_states
+            if tokens_left <= 0:
+                return
+
+            if not insert_sort:
+                self.req_states.append((tokens_left, tokens_occupied))
+            else:
+                i = 0
+                for i in range(len(self.req_states)):
+                    if tokens_left <= self.req_states[i][0]:
+                        break
+                self.req_states.insert(i, (tokens_left, tokens_occupied))
 
         if self.req_states is None:
             self.req_states = []
@@ -406,13 +425,11 @@ class PrefillAdder:
         cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
         tokens_freed = 0
         for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
-
-                self.req_states[i + 1][0]
-                if i + 1 < len(self.req_states)
-                else tokens_left
-            )
+            # tokens_left gives a conservative calculation as the last token is not stored
             bs = len(self.req_states) - i
-
+            min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs
+            # reserve tokens for corner cases
+            if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs:
                 return AddReqResult.NO_TOKEN
             tokens_freed += tokens_occupied
 
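The rewritten loop above walks `req_states` (assumed sorted ascending by remaining tokens, as maintained by `add_req_state`) and estimates the lowest point the free-token pool will reach: before the i-th request finishes, each of the `bs` still-running requests may consume up to `tokens_left` more tokens, while requests that already finished have returned `tokens_freed` tokens. A standalone sketch of the same check, with made-up numbers:

```python
IGNORE_EOS_RESERVE_TOKENS = 1

def enough_decode_budget(cur_rem_tokens, req_states):
    # req_states: (tokens_left, tokens_occupied) pairs, sorted by tokens_left ascending.
    tokens_freed = 0
    for i, (tokens_left, tokens_occupied) in enumerate(req_states):
        bs = len(req_states) - i
        min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs
        if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs:
            return False
        tokens_freed += tokens_occupied
    return True

print(enough_decode_budget(100, [(10, 30), (40, 50)]))  # True
print(enough_decode_budget(30, [(10, 5), (40, 50)]))    # False: 30 + 5 - 40 is below the reserve
```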
@@ -422,7 +439,7 @@ class PrefillAdder:
         ):
             # Non-chunked prefill
             self.can_run_list.append(req)
-            self.
+            self._update_prefill_budget(
                 0,
                 req.extend_input_len,
                 min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
@@ -438,55 +455,52 @@ class PrefillAdder:
         req.fill_ids = req.fill_ids[:trunc_len]
         self.can_run_list.append(req)
         self.new_chunked_req = req
-        self.
+        self._update_prefill_budget(0, trunc_len, 0)
 
         return self.budget_state()
 
-    def add_one_req(
-        self, req: Req, has_chunked_req: bool, enable_hierarchical_cache: bool = False
-    ):
+    def add_one_req(self, req: Req, has_chunked_req: bool):
         if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
             return self.add_one_req_ignore_eos(req, has_chunked_req)
 
         total_tokens = req.extend_input_len + min(
             req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
         )
-
-
-
-        )
+
+        # adjusting the input_tokens based on host_hit_length and page_size
+        real_input_tokens = req.extend_input_len - req.host_hit_length
+        real_input_tokens = self.ceil_paged_tokens(real_input_tokens)
         prefix_len = len(req.prefix_indices)
 
         if total_tokens >= self.rem_total_tokens:
             return AddReqResult.NO_TOKEN
 
-        if
+        if real_input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0:
             return AddReqResult.OTHER
 
         with self._lock_node(req.last_node):
-
+            # self.rem_total_tokens may decrease after the lock acquisition
+            if total_tokens >= self.rem_total_tokens:
                 return AddReqResult.NO_TOKEN
 
-            if
-
-
-                and req.last_node_global.evicted
-            ):
-                req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
-                    req.last_node_global, req.prefix_indices
+            if req.host_hit_length > 0:
+                new_indices, req.last_node = self.tree_cache.init_load_back(
+                    req.last_host_node, req.host_hit_length
                 )
+                req.prefix_indices = torch.cat([req.prefix_indices, new_indices])
                 req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
-                input_tokens = (
-                    -(-req.extend_input_len // self.tree_cache.page_size)
-                    * self.tree_cache.page_size
-                )
                 prefix_len = len(req.prefix_indices)
 
+            input_tokens = self.ceil_paged_tokens(req.extend_input_len)
+
+            if input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0:
+                return AddReqResult.OTHER
+
             if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
                 # Non-chunked prefill
                 self.can_run_list.append(req)
                 self.tree_cache.inc_lock_ref(req.last_node)
-                self.
+                self._update_prefill_budget(
                     prefix_len,
                     input_tokens,
                     min(
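In the rewritten `add_one_req` above, the admission check first subtracts the tokens that can be loaded back from the host cache (`host_hit_length`) and then rounds the remainder up to the page size before comparing it against the remaining input budget. A worked example with made-up numbers:

```python
page_size = 16
extend_input_len = 70    # tokens this prefill still has to cover
host_hit_length = 30     # tokens recoverable from the host (CPU) cache

real_input_tokens = extend_input_len - host_hit_length               # 40
real_input_tokens = -(-real_input_tokens // page_size) * page_size   # 48 (page aligned)
print(real_input_tokens)
```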
@@ -496,7 +510,7 @@ class PrefillAdder:
                 )
             else:
                 # Make sure at least one page is available
-                trunc_len = self.rem_chunk_tokens - self.
+                trunc_len = self.rem_chunk_tokens - self.page_size + 1
                 if trunc_len <= 0:
                     return AddReqResult.OTHER
 
@@ -507,6 +521,6 @@ class PrefillAdder:
                 self.can_run_list.append(req)
                 self.new_chunked_req = req
                 self.tree_cache.inc_lock_ref(req.last_node)
-                self.
+                self._update_prefill_budget(prefix_len, trunc_len, 0)
 
         return self.budget_state()