sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
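For reference, a comparison like this can be reproduced locally from the public registry. The sketch below is an illustrative example, not part of the published diff: it assumes `pip` is available on the PATH and that both releases ship a pure-Python wheel, downloads the two wheels with `pip download`, and prints a unified diff of the bundled `.py` files using only the standard library.

```python
# Minimal sketch (assumptions: `pip` on PATH, both releases publish a wheel).
# Downloads the two sglang wheels and prints a unified diff of their .py files.
import difflib
import subprocess
import zipfile
from pathlib import Path

VERSIONS = ("0.4.3.post2", "0.4.3.post4")
work = Path("wheel_diff")


def fetch_wheel(version: str) -> Path:
    dest = work / version
    dest.mkdir(parents=True, exist_ok=True)
    # --no-deps: we only want the sglang wheel itself, not its dependencies.
    subprocess.run(
        ["pip", "download", f"sglang=={version}", "--no-deps", "-d", str(dest)],
        check=True,
    )
    return next(dest.glob("sglang-*.whl"))


def python_sources(wheel: Path) -> dict[str, list[str]]:
    # A wheel is a zip archive; read every .py member as text lines.
    with zipfile.ZipFile(wheel) as zf:
        return {
            name: zf.read(name).decode("utf-8", "replace").splitlines(keepends=True)
            for name in zf.namelist()
            if name.endswith(".py")
        }


old, new = (python_sources(fetch_wheel(v)) for v in VERSIONS)
for name in sorted(set(old) | set(new)):
    diff = difflib.unified_diff(old.get(name, []), new.get(name, []), name, name)
    print("".join(diff), end="")
```

If it behaves as expected, the per-file output should correspond to the file list above.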
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +302 -414
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +13 -8
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +144 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +773 -334
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +225 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +68 -37
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +102 -36
- sglang/srt/model_executor/cuda_graph_runner.py +56 -31
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +280 -81
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +135 -60
- sglang/srt/speculative/build_eagle_tree.py +8 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
- sglang/srt/speculative/eagle_utils.py +92 -57
- sglang/srt/speculative/eagle_worker.py +238 -111
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py

@@ -29,6 +29,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 It contains low-level tensor data. Most of the data consists of GPU tensors.
 """

+import copy
 import dataclasses
 import logging
 from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
@@ -43,14 +44,15 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
-from sglang.srt.mem_cache.memory_pool import
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs

 if TYPE_CHECKING:
-    from sglang.srt.speculative.
+    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

@@ -65,8 +67,11 @@ global_server_args_dict = {
     "enable_dp_attention": ServerArgs.enable_dp_attention,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
     "device": ServerArgs.device,
+    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
+    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
     "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
+    "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
 }

 logger = logging.getLogger(__name__)
@@ -229,12 +234,14 @@ class Req:
         sampling_params: SamplingParams,
         return_logprob: bool = False,
         top_logprobs_num: int = 0,
+        token_ids_logprob: List[int] = None,
         stream: bool = False,
         origin_input_ids_unpadded: Optional[Tuple[int]] = None,
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
         custom_logit_processor: Optional[str] = None,
+        return_hidden_states: bool = False,
         eos_token_ids: Optional[Set[int]] = None,
     ):
         # Input and output info
@@ -254,16 +261,27 @@ class Req:
         self.input_embeds = input_embeds

         # Sampling info
+        if isinstance(sampling_params.custom_params, dict):
+            sampling_params = copy.copy(sampling_params)
+            sampling_params.custom_params = sampling_params.custom_params | {
+                "__req__": self
+            }
         self.sampling_params = sampling_params
+
         self.custom_logit_processor = custom_logit_processor
+        self.return_hidden_states = return_hidden_states

         # Memory pool info
-        self.req_pool_idx = None
+        self.req_pool_idx: Optional[int] = None

         # Check finish
         self.tokenizer = None
         self.finished_reason = None
+        # If we want to abort the request in the middle of the event loop, set this to true
+        # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
        self.to_abort = False
+        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
+        self.to_abort_message: str = "Unknown error"
         self.stream = stream
         self.eos_token_ids = eos_token_ids

@@ -276,7 +294,6 @@ class Req:
         # 1: surr_offset
         # 2: read_offset
         # 3: last token
-        self.vid = 0  # version id to sync decode status with in detokenizer_manager
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
         self.decoded_text = ""
@@ -285,47 +302,58 @@ class Req:
         self.image_inputs: Optional[ImageInputs] = None

         # Prefix info
+        # The indices to kv cache for the shared prefix.
         self.prefix_indices = []
-        #
-        # Updated if chunked.
+        # Number of tokens to run prefill.
         self.extend_input_len = 0
+        # The relative logprob_start_len in an extend batch
+        self.extend_logprob_start_len = 0
         self.last_node = None

-        #
-
+        # Whether or not if it is chunked. It increments whenever
+        # it is chunked, and decrement whenever chunked request is
+        # processed.
+        self.is_chunked = 0

         # For retraction
         self.is_retracted = False

         # Logprobs (arguments)
         self.return_logprob = return_logprob
+        # Start index to compute logprob from.
         self.logprob_start_len = 0
         self.top_logprobs_num = top_logprobs_num
+        self.token_ids_logprob = token_ids_logprob

         # Logprobs (return values)
         self.input_token_logprobs_val: Optional[List[float]] = None
         self.input_token_logprobs_idx: Optional[List[int]] = None
         self.input_top_logprobs_val: Optional[List[float]] = None
         self.input_top_logprobs_idx: Optional[List[int]] = None
+        self.input_token_ids_logprobs_val: Optional[List[float]] = None
+        self.input_token_ids_logprobs_idx: Optional[List[int]] = None
+        # Temporary holder to store input_token_logprobs.
+        self.input_token_logprobs: Optional[List[Tuple[int]]] = None
+        self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
+        self.temp_input_top_logprobs_idx: Optional[List[int]] = None
+        self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
+        self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

         if return_logprob:
             self.output_token_logprobs_val = []
             self.output_token_logprobs_idx = []
             self.output_top_logprobs_val = []
             self.output_top_logprobs_idx = []
+            self.output_token_ids_logprobs_val = []
+            self.output_token_ids_logprobs_idx = []
         else:
             self.output_token_logprobs_val = self.output_token_logprobs_idx = (
                 self.output_top_logprobs_val
-            ) = self.output_top_logprobs_idx =
+            ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
+                self.output_token_ids_logprobs_idx
+            ) = None
         self.hidden_states = []

-        # Logprobs (internal values)
-        # The tokens is prefilled but need to be considered as decode tokens
-        # and should be updated for the decode logprobs
-        self.last_update_decode_tokens = 0
-        # The relative logprob_start_len in an extend batch
-        self.extend_logprob_start_len = 0
-
         # Embedding (return values)
         self.embedding = None

@@ -341,6 +369,10 @@ class Req:
         self.spec_verify_ct = 0
         self.lora_path = lora_path

+    @property
+    def seqlen(self):
+        return len(self.origin_input_ids) + len(self.output_ids)
+
     def extend_image_inputs(self, image_inputs):
         if self.image_inputs is None:
             self.image_inputs = image_inputs
@@ -418,7 +450,9 @@ class Req:
             return

         if self.to_abort:
-            self.finished_reason = FINISH_ABORT(
+            self.finished_reason = FINISH_ABORT(
+                message=self.to_abort_message,
+            )
             return

         if len(self.output_ids) >= self.sampling_params.max_new_tokens:
@@ -458,81 +492,22 @@ class Req:
                     self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                     return

-    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
-        if self.origin_input_text is None:
-            # Recovering text can only use unpadded ids
-            self.origin_input_text = self.tokenizer.decode(
-                self.origin_input_ids_unpadded
-            )
-
-        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
-        all_ids = self.tokenizer.encode(all_text)
-        if not all_ids:
-            logger.warning("Encoded all_text resulted in empty all_ids")
-            return False
-
-        prompt_tokens = len(self.origin_input_ids_unpadded)
-        if prompt_tokens > len(all_ids):
-            logger.warning("prompt_tokens is larger than encoded all_ids")
-            return False
-
-        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
-            # TODO(lsyin): fix token fusion
-            logger.warning(
-                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
-            )
-            return False
-
-        old_output_ids = self.output_ids
-        self.output_ids = all_ids[prompt_tokens:]
-        self.decoded_text = self.decoded_text + jump_forward_str
-        self.surr_offset = prompt_tokens
-        self.read_offset = len(all_ids)
-
-        # NOTE: A trick to reduce the surrouding tokens decoding overhead
-        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
-            surr_text_ = self.tokenizer.decode(
-                all_ids[self.read_offset - i : self.read_offset]
-            )
-            if not surr_text_.endswith("�"):
-                self.surr_offset = self.read_offset - i
-                break
-
-        # update the inner state of the grammar
-        self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)
-
-        if self.return_logprob:
-            # For fast-forward part's logprobs
-            k = 0
-            for i, old_id in enumerate(old_output_ids):
-                if old_id == self.output_ids[i]:
-                    k = k + 1
-                else:
-                    break
-            self.output_token_logprobs_val = self.output_token_logprobs_val[:k]
-            self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
-            self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
-            self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
-            self.logprob_start_len = prompt_tokens + k
-            self.last_update_decode_tokens = len(self.output_ids) - k
-
-        return True
-
     def reset_for_retract(self):
         self.prefix_indices = []
         self.last_node = None
         self.extend_input_len = 0
         self.is_retracted = True
-
-
-
-        self.
-        self.
+        self.input_token_logprobs = None
+        self.temp_input_top_logprobs_val = None
+        self.temp_input_top_logprobs_idx = None
+        self.extend_logprob_start_len = 0
+        self.is_chunked = 0
+        self.req_pool_idx = None

     def __repr__(self):
         return (
-            f"rid
-            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}"
+            f"Req(rid={self.rid}, "
+            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})"
         )

@@ -546,7 +521,7 @@ class ScheduleBatch:
     # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool = None
-
+    token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
     tree_cache: BasePrefixCache = None

     # Batch configs
@@ -572,11 +547,13 @@ class ScheduleBatch:

     # For DP attention
     global_num_tokens: Optional[List[int]] = None
+    global_num_tokens_for_logprob: Optional[List[int]] = None
     can_run_dp_cuda_graph: bool = False

     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
+    token_ids_logprobs: Optional[List[List[int]]] = None

     # For extend and mixed chunekd prefill
     prefix_lens: List[int] = None
@@ -584,6 +561,8 @@ class ScheduleBatch:
     extend_num_tokens: int = None
     decoding_reqs: List[Req] = None
     extend_logprob_start_lens: List[int] = None
+    # It comes empty list if logprob is not required.
+    extend_input_logprob_token_ids: Optional[torch.Tensor] = None

     # For encoder-decoder
     encoder_cached: Optional[List[bool]] = None
@@ -602,12 +581,12 @@ class ScheduleBatch:

     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
-    spec_info: Optional[
+    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None

     # Enable custom logit processor
     enable_custom_logit_processor: bool = False

-    #
+    # Whether to return hidden states
     return_hidden_states: bool = False

     @classmethod
@@ -615,18 +594,17 @@ class ScheduleBatch:
         cls,
         reqs: List[Req],
         req_to_token_pool: ReqToTokenPool,
-
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         tree_cache: BasePrefixCache,
         model_config: ModelConfig,
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
         enable_custom_logit_processor: bool,
-        return_hidden_states: bool = False,
     ):
         return cls(
             reqs=reqs,
             req_to_token_pool=req_to_token_pool,
-
+            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
             tree_cache=tree_cache,
             model_config=model_config,
             enable_overlap=enable_overlap,
@@ -636,7 +614,7 @@ class ScheduleBatch:
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
             enable_custom_logit_processor=enable_custom_logit_processor,
-            return_hidden_states=return_hidden_states,
+            return_hidden_states=any(req.return_hidden_states for req in reqs),
         )

     def batch_size(self):
@@ -649,25 +627,27 @@ class ScheduleBatch:
         req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
         if req_pool_indices is None:
             raise RuntimeError(
-                "
-                "Please set a smaller number for `--max-running-requests`."
+                "alloc_req_slots runs out of memory. "
+                "Please set a smaller number for `--max-running-requests`. "
+                f"{self.req_to_token_pool.available_size()=}, "
+                f"{num_reqs=}, "
             )
         return req_pool_indices

     def alloc_token_slots(self, num_tokens: int):
-        out_cache_loc = self.
+        out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)

         if out_cache_loc is None:
             if self.tree_cache is not None:
-                self.tree_cache.evict(num_tokens, self.
-                out_cache_loc = self.
+                self.tree_cache.evict(num_tokens, self.token_to_kv_pool_allocator.free)
+                out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)

             if out_cache_loc is None:
                 phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
                 logger.error(
                     f"{phase_str} out of memory. Try to lower your batch size.\n"
                     f"Try to allocate {num_tokens} tokens.\n"
-                    f"Avaliable tokens: {self.
+                    f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                 )
                 if self.tree_cache is not None:
                     self.tree_cache.pretty_print()
@@ -761,6 +741,7 @@ class ScheduleBatch:
         out_cache_loc = self.alloc_token_slots(extend_num_tokens)

         input_embeds = []
+        extend_input_logprob_token_ids = []

         pt = 0
         for i, req in enumerate(reqs):
@@ -779,22 +760,64 @@ class ScheduleBatch:
                 # If req.input_embeds is already a list, append its content directly
                 input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

-            if req.return_logprob:
-                # Compute the relative logprob_start_len in an extend batch
-                if req.logprob_start_len >= pre_len:
-                    extend_logprob_start_len = min(
-                        req.logprob_start_len - pre_len, req.extend_input_len - 1
-                    )
-                else:
-                    raise RuntimeError(
-                        f"This should never happen. {req.logprob_start_len=}, {pre_len=}"
-                    )
-                req.extend_logprob_start_len = extend_logprob_start_len
-
             req.cached_tokens += pre_len - req.already_computed
             req.already_computed = seq_len
             req.is_retracted = False
             pre_lens.append(pre_len)
+            # Compute the relative logprob_start_len in an extend batch
+            if req.logprob_start_len >= pre_len:
+                req.extend_logprob_start_len = min(
+                    req.logprob_start_len - pre_len,
+                    req.extend_input_len,
+                    req.seqlen - 1,
+                )
+            else:
+                req.extend_logprob_start_len = 0
+
+            if self.return_logprob:
+                # Find input logprob token ids.
+                # First, find a global index within origin_input_ids and slide it by 1
+                # to compute input logprobs. It is because you need the next token
+                # to compute input logprobs. E.g., (chunk size 2)
+                #
+                # input_logprobs = [1, 2, 3, 4]
+                # fill_ids = [1, 2]
+                # extend_input_logprob_token_id = [2, 3]
+                #
+                # Note that it can also overflow. In this case, we pad it with 0.
+                # input_logprobs = [1, 2, 3, 4]
+                # fill_ids = [3, 4]
+                # extend_input_logprob_token_id = [4, 0]
+                global_start_idx, global_end_idx = (
+                    len(req.prefix_indices),
+                    len(req.fill_ids),
+                )
+                # Apply logprob_start_len
+                if global_start_idx < req.logprob_start_len:
+                    global_start_idx = req.logprob_start_len
+
+                logprob_token_ids = req.origin_input_ids[
+                    global_start_idx + 1 : global_end_idx + 1
+                ]
+                extend_input_logprob_token_ids.extend(logprob_token_ids)
+
+                # We will need req.extend_input_len - req.extend_logprob_start_len number of
+                # tokens, and logprob_token_ids is for input logprob, so pad the rest of them by 0.
+                extend_input_logprob_token_ids.extend(
+                    [0]
+                    * (
+                        req.extend_input_len
+                        - req.extend_logprob_start_len
+                        - len(logprob_token_ids)
+                    )
+                )
+
+        if self.return_logprob:
+            extend_input_logprob_token_ids = torch.tensor(
+                extend_input_logprob_token_ids
+            )
+        else:
+            extend_input_logprob_token_ids = None

         # Set fields
         self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
@@ -817,10 +840,12 @@ class ScheduleBatch:
         self.seq_lens_sum = sum(seq_lens)
         if self.return_logprob:
             self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
         self.extend_num_tokens = extend_num_tokens
         self.prefix_lens = [len(r.prefix_indices) for r in reqs]
         self.extend_lens = [r.extend_input_len for r in reqs]
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
+        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

         # Write to req_to_token_pool
         pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
@@ -856,7 +881,6 @@ class ScheduleBatch:
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
             self,
             self.model_config.vocab_size,
-            enable_overlap_schedule=self.enable_overlap,
         )

     def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -891,41 +915,60 @@ class ScheduleBatch:

     def check_decode_mem(self, buf_multiplier=1):
         bs = len(self.reqs) * buf_multiplier
-        if self.
+        if self.token_to_kv_pool_allocator.available_size() >= bs:
             return True

-        self.tree_cache.evict(bs, self.
+        self.tree_cache.evict(bs, self.token_to_kv_pool_allocator.free)

-        if self.
+        if self.token_to_kv_pool_allocator.available_size() >= bs:
             return True

         return False

-    def retract_decode(self):
+    def retract_decode(self, server_args: ServerArgs):
         """Retract the decoding requests when there is not enough memory."""
         sorted_indices = [i for i in range(len(self.reqs))]

         # TODO(lsyin): improve retraction policy for radix cache
-
-
-
-
-
-
-
+        # For spec decoding, filter_batch API can only filter
+        # requests from the back, so we can only retract from the back.
+        # TODO(sang): Clean up finish path and support better retract
+        # policy.
+        if not server_args.speculative_algorithm:
+            sorted_indices.sort(
+                key=lambda i: (
+                    len(self.reqs[i].output_ids),
+                    -len(self.reqs[i].origin_input_ids),
+                ),
+                reverse=True,
+            )

         retracted_reqs = []
         seq_lens_cpu = self.seq_lens.cpu().numpy()
         first_iter = True
+
+        def get_required_tokens(num_reqs: int):
+            headroom_for_spec_decode = 0
+            if server_args.speculative_algorithm:
+                headroom_for_spec_decode += (
+                    num_reqs
+                    * server_args.speculative_eagle_topk
+                    * server_args.speculative_num_steps
+                    + num_reqs * server_args.speculative_num_draft_tokens
+                )
+            return (
+                num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
+            )
+
         while (
-            self.
-            < len(sorted_indices)
+            self.token_to_kv_pool_allocator.available_size()
+            < get_required_tokens(len(sorted_indices))
             or first_iter
         ):
             if len(sorted_indices) == 1:
                 # Corner case: only one request left
                 assert (
-                    self.
+                    self.token_to_kv_pool_allocator.available_size() > 0
                 ), "No space left for only one request"
                 break

@@ -939,7 +982,7 @@ class ScheduleBatch:
                 token_indices = self.req_to_token_pool.req_to_token[
                     req.req_pool_idx, : seq_lens_cpu[idx]
                 ]
-                self.
+                self.token_to_kv_pool_allocator.free(token_indices)
                 self.req_to_token_pool.free(req.req_pool_idx)
                 del self.tree_cache.entries[req.rid]
             else:
@@ -948,7 +991,7 @@ class ScheduleBatch:
                 token_indices = self.req_to_token_pool.req_to_token[
                     req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                 ]
-                self.
+                self.token_to_kv_pool_allocator.free(token_indices)
                 self.req_to_token_pool.free(req.req_pool_idx)

                 # release the last node
@@ -957,10 +1000,13 @@ class ScheduleBatch:
                 # NOTE(lsyin): we should use the newly evictable memory instantly.
                 residual_size = (
                     len(sorted_indices) * global_config.retract_decode_steps
-                    - self.
+                    - self.token_to_kv_pool_allocator.available_size()
                 )
                 residual_size = max(0, residual_size)
-                self.tree_cache.evict(
+                self.tree_cache.evict(
+                    residual_size, self.token_to_kv_pool_allocator.free
+                )
+
             req.reset_for_retract()

         self.filter_batch(keep_indices=sorted_indices)
@@ -976,59 +1022,6 @@ class ScheduleBatch:

         return retracted_reqs, new_estimate_ratio

-    def check_for_jump_forward(self, pad_input_ids_func):
-        jump_forward_reqs = []
-        keep_indices = set(i for i in range(len(self.reqs)))
-
-        for i, req in enumerate(self.reqs):
-            if req.grammar is not None:
-                jump_helper = req.grammar.try_jump_forward(req.tokenizer)
-                if jump_helper:
-                    suffix_ids, _ = jump_helper
-
-                    # Current ids, for cache and revert
-                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
-                    cur_output_ids = req.output_ids
-
-                    req.output_ids.extend(suffix_ids)
-                    decode_res, new_text = req.get_next_inc_detokenization()
-                    if not decode_res:
-                        req.output_ids = cur_output_ids
-                        continue
-
-                    (
-                        jump_forward_str,
-                        next_state,
-                    ) = req.grammar.jump_forward_str_state(jump_helper)
-
-                    # Make the incrementally decoded text part of jump_forward_str
-                    # so that the UTF-8 will not corrupt
-                    jump_forward_str = new_text + jump_forward_str
-                    if not req.jump_forward_and_retokenize(
-                        jump_forward_str, next_state
-                    ):
-                        req.output_ids = cur_output_ids
-                        continue
-
-                    # The decode status has diverged from detokenizer_manager
-                    req.vid += 1
-
-                    # insert the old request into tree_cache
-                    self.tree_cache.cache_finished_req(req, cur_all_ids)
-
-                    # re-applying image padding
-                    if req.image_inputs is not None:
-                        req.origin_input_ids = pad_input_ids_func(
-                            req.origin_input_ids_unpadded, req.image_inputs
-                        )
-
-                    jump_forward_reqs.append(req)
-                    keep_indices.remove(i)
-
-        self.filter_batch(keep_indices=list(keep_indices))
-
-        return jump_forward_reqs
-
     def prepare_encoder_info_decode(self):
         # Reset the encoder cached status
         self.encoder_cached = [True] * len(self.reqs)
@@ -1044,17 +1037,40 @@ class ScheduleBatch:
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
             self,
             self.model_config.vocab_size,
-            enable_overlap_schedule=self.enable_overlap,
         )

     def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
         if self.spec_algorithm.is_eagle():
+            # if spec decoding is used, the decode batch is prepared inside
+            # `forward_batch_speculative_generation` after running draft models.
             return

+        if self.sampling_info.penalizer_orchestrator.is_required:
+            if self.enable_overlap:
+                # TODO: this can be slow, optimize this.
+                delayed_output_ids = torch.tensor(
+                    [
+                        (
+                            req.output_ids[-1]
+                            if len(req.output_ids)
+                            else req.origin_input_ids[-1]
+                        )
+                        for req in self.reqs
+                    ],
+                    dtype=torch.int64,
+                    device=self.device,
+                )
+                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                    delayed_output_ids
+                )
+            else:
+                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                    self.output_ids.to(torch.int64)
+                )
+
         self.input_ids = self.output_ids
         self.output_ids = None
-        self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)

         # Alloc mem
         bs = len(self.reqs)
@@ -1082,14 +1098,15 @@ class ScheduleBatch:

     def filter_batch(
         self,
-
+        chunked_req_to_exclude: Optional[Req] = None,
         keep_indices: Optional[List[int]] = None,
     ):
         if keep_indices is None:
             keep_indices = [
                 i
                 for i in range(len(self.reqs))
-                if not self.reqs[i].finished()
+                if not self.reqs[i].finished()
+                and self.reqs[i] is not chunked_req_to_exclude
             ]

         if keep_indices is None or len(keep_indices) == 0:
@@ -1101,31 +1118,34 @@ class ScheduleBatch:
             # No need to filter
             return

+        keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
+            self.device, non_blocking=True
+        )
+
         if self.model_config.is_encoder_decoder:
-            self.encoder_lens = self.encoder_lens[
+            self.encoder_lens = self.encoder_lens[keep_indices_device]
             self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

         self.reqs = [self.reqs[i] for i in keep_indices]
-
-
-        )
-        self.req_pool_indices = self.req_pool_indices[new_indices]
-        self.seq_lens = self.seq_lens[new_indices]
+        self.req_pool_indices = self.req_pool_indices[keep_indices_device]
+        self.seq_lens = self.seq_lens[keep_indices_device]
         self.out_cache_loc = None
         self.seq_lens_sum = self.seq_lens.sum().item()
-        self.output_ids = self.output_ids[
+        self.output_ids = self.output_ids[keep_indices_device]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
         if self.return_logprob:
             self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
+            self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
         else:
             self.top_logprobs_nums = None
+            self.token_ids_logprobs = None

         self.has_stream = any(req.stream for req in self.reqs)
         self.has_grammar = any(req.grammar for req in self.reqs)

-        self.sampling_info.filter_batch(keep_indices,
+        self.sampling_info.filter_batch(keep_indices, keep_indices_device)
         if self.spec_info:
-            self.spec_info.filter_batch(
+            self.spec_info.filter_batch(keep_indices_device)

     def merge_batch(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
@@ -1148,23 +1168,32 @@ class ScheduleBatch:
         self.output_ids = torch.concat([self.output_ids, other.output_ids])
         if self.return_logprob and other.return_logprob:
             self.top_logprobs_nums.extend(other.top_logprobs_nums)
+            self.token_ids_logprobs.extend(other.token_ids_logprobs)
         elif self.return_logprob:
             self.top_logprobs_nums.extend([0] * len(other.reqs))
+            self.token_ids_logprobs.extend([None] * len(other.reqs))
         elif other.return_logprob:
             self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
+            self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
         self.reqs.extend(other.reqs)

         self.return_logprob |= other.return_logprob
         self.has_stream |= other.has_stream
         self.has_grammar |= other.has_grammar
+        self.return_hidden_states |= other.return_hidden_states

         if self.spec_info:
             self.spec_info.merge_batch(other.spec_info)

-    def get_model_worker_batch(self):
+    def get_model_worker_batch(self) -> ModelWorkerBatch:
         if self.forward_mode.is_decode_or_idle():
+            if global_server_args_dict["enable_flashinfer_mla"]:
+                decode_seq_lens = self.seq_lens.cpu()
+            else:
+                decode_seq_lens = None
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
+            decode_seq_lens = None
             extend_seq_lens = self.extend_lens
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens
@@ -1187,8 +1216,11 @@ class ScheduleBatch:
             seq_lens_sum=self.seq_lens_sum,
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
+            token_ids_logprobs=self.token_ids_logprobs,
             global_num_tokens=self.global_num_tokens,
+            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
+            decode_seq_lens=decode_seq_lens,
             extend_num_tokens=self.extend_num_tokens,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
@@ -1214,6 +1246,7 @@ class ScheduleBatch:
                     else CaptureHiddenMode.NULL
                 )
             ),
+            extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
         )

     def copy(self):
@@ -1248,7 +1281,7 @@ class ModelWorkerBatch:
     req_pool_indices: torch.Tensor
     # The sequence length
     seq_lens: torch.Tensor
-    # The indices of output tokens in the
+    # The indices of output tokens in the token_to_kv_pool_allocator
     out_cache_loc: torch.Tensor

     # The sum of all sequence lengths
@@ -1257,16 +1290,22 @@ class ModelWorkerBatch:
     # For logprob
     return_logprob: bool
     top_logprobs_nums: Optional[List[int]]
+    token_ids_logprobs: Optional[List[List[int]]]

     # For DP attention
     global_num_tokens: Optional[List[int]]
+    global_num_tokens_for_logprob: Optional[List[int]]
     can_run_dp_cuda_graph: bool

+    # For decode
+    decode_seq_lens: Optional[torch.Tensor]
+
     # For extend
     extend_num_tokens: Optional[int]
     extend_seq_lens: Optional[List[int]]
     extend_prefix_lens: Optional[List[int]]
     extend_logprob_start_lens: Optional[List[int]]
+    extend_input_logprob_token_ids: Optional[torch.Tensor]

     # For multimodal
     image_inputs: Optional[List[ImageInputs]]
@@ -1288,7 +1327,8 @@ class ModelWorkerBatch:

     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
-    spec_info: Optional[
+    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
+    # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
