sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/managers/schedule_batch.py
+++ b/sglang/srt/managers/schedule_batch.py
@@ -38,7 +38,7 @@ import threading
 from enum import Enum, auto
 from http import HTTPStatus
 from itertools import chain
-from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
@@ -52,7 +52,6 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
|
|
52
52
|
ScheduleBatchDisaggregationDecodeMixin,
|
53
53
|
)
|
54
54
|
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
|
55
|
-
from sglang.srt.layers.moe import is_tbo_enabled
|
56
55
|
from sglang.srt.mem_cache.allocator import (
|
57
56
|
BaseTokenToKVPoolAllocator,
|
58
57
|
SWATokenToKVPoolAllocator,
|
@@ -60,7 +59,7 @@ from sglang.srt.mem_cache.allocator import (
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.metrics.collector import TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
@@ -99,13 +98,13 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "sampling_backend",
     "speculative_accept_threshold_single",
     "speculative_accept_threshold_acc",
+    "speculative_attention_mode",
     "torchao_config",
     "triton_attention_reduce_in_fp32",
     "num_reserved_decode_tokens",
     "weight_loader_disable_mmap",
     "enable_multimodal",
     "enable_symm_mem",
-    "quantization",
     "enable_custom_logit_processor",
     "disaggregation_mode",
 ]
@@ -561,7 +560,10 @@ class Req:
             # shape: (bs, k)
             self.output_top_logprobs_val = []
             self.output_top_logprobs_idx = []
-            self.output_token_ids_logprobs_val = []
+            # Can contain either lists or GPU tensors (delayed copy optimization for prefill-only scoring)
+            self.output_token_ids_logprobs_val: List[
+                Union[List[float], torch.Tensor]
+            ] = []
             self.output_token_ids_logprobs_idx = []
         else:
             self.output_token_logprobs_val = self.output_token_logprobs_idx = (
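
The widened annotation lets prefill-only scoring stash the raw GPU logprob tensor and defer the device-to-host copy until results are assembled. A minimal sketch of that delayed-copy pattern (variable names and shapes here are illustrative, not sglang's internals):

```python
from typing import List, Union

import torch

# Hot path: append the device tensor as-is; no per-request sync.
output_token_ids_logprobs_val: List[Union[List[float], torch.Tensor]] = []
logprobs = torch.randn(4)  # stand-in for logprobs computed on the device
output_token_ids_logprobs_val.append(logprobs)

# Serialization time: one conversion per entry; plain lists pass through.
finalized = [
    v.tolist() if isinstance(v, torch.Tensor) else v
    for v in output_token_ids_logprobs_val
]
```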
@@ -619,6 +621,11 @@ class Req:
     def seqlen(self):
         return len(self.origin_input_ids) + len(self.output_ids)
 
+    @property
+    def is_prefill_only(self) -> bool:
+        """Check if this request is prefill-only (no token generation needed)."""
+        return self.sampling_params.max_new_tokens == 0
+
     def extend_image_inputs(self, image_inputs):
         if self.multimodal_inputs is None:
             self.multimodal_inputs = image_inputs
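
The new property centralizes the `max_new_tokens == 0` check that `init_new` below now reuses. A self-contained sketch of the idea, with a stub `SamplingParams` standing in for sglang's real class:

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class SamplingParams:  # stub; the real class lives elsewhere in sglang
    max_new_tokens: int = 0

@dataclass
class Req:  # reduced to the fields the property needs
    origin_input_ids: List[int]
    sampling_params: SamplingParams = field(default_factory=SamplingParams)

    @property
    def is_prefill_only(self) -> bool:
        # max_new_tokens == 0 means "score the prompt, generate nothing"
        return self.sampling_params.max_new_tokens == 0

reqs = [Req([1, 2, 3]), Req([4, 5], SamplingParams(max_new_tokens=8))]
assert [r.is_prefill_only for r in reqs] == [True, False]
assert not all(r.is_prefill_only for r in reqs)  # mirrors init_new's check
```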
@@ -684,9 +691,15 @@ class Req:
             self.surr_offset = max(
                 self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
             )
+            self.surr_and_decode_ids = (
+                self.origin_input_ids_unpadded[self.surr_offset :] + self.output_ids
+            )
+            self.cur_decode_ids_len = len(self.output_ids)
+        else:
+            self.surr_and_decode_ids.extend(self.output_ids[self.cur_decode_ids_len :])
+            self.cur_decode_ids_len = len(self.output_ids)
 
-        all_ids = self.origin_input_ids_unpadded + self.output_ids
-        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
+        return self.surr_and_decode_ids, self.read_offset - self.surr_offset
 
     def check_finished(self):
         if self.finished():
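
Previously every call rebuilt `all_ids = origin_input_ids_unpadded + output_ids` and sliced it, which is O(sequence length) per decode step; the cached `surr_and_decode_ids` buffer now grows by only the newly decoded tokens. A stripped-down sketch of the same amortization (class and names are illustrative):

```python
class IncrementalDetokIds:
    """Keeps prefix[surr_offset:] + output_ids without re-concatenating."""

    def __init__(self, prefix_ids, surr_offset):
        self.buf = list(prefix_ids[surr_offset:])  # ~ surr_and_decode_ids
        self.seen = 0                              # ~ cur_decode_ids_len

    def update(self, output_ids):
        self.buf.extend(output_ids[self.seen:])    # append only new tokens
        self.seen = len(output_ids)
        return self.buf

inc = IncrementalDetokIds([10, 11, 12], surr_offset=1)
assert inc.update([7]) == [11, 12, 7]
assert inc.update([7, 8]) == [11, 12, 7, 8]  # O(1) amortized per step
```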
@@ -911,7 +924,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     is_prefill_only: bool = False
 
     # hicache pointer for synchronizing data loading from CPU to GPU
-    hicache_consumer_index: int = 0
+    hicache_consumer_index: int = -1
 
     @classmethod
     def init_new(
@@ -950,9 +963,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
             return_hidden_states=any(req.return_hidden_states for req in reqs),
-            is_prefill_only=all(
-                req.sampling_params.max_new_tokens == 0 for req in reqs
-            ),
+            is_prefill_only=all(req.is_prefill_only for req in reqs),
             chunked_req=chunked_req,
         )
 
@@ -962,8 +973,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def is_empty(self):
         return len(self.reqs) == 0
 
-    def alloc_req_slots(self, num_reqs: int):
-        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
+    def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
+        if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
+            req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
+        else:
+            req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
         if req_pool_indices is None:
             raise RuntimeError(
                 "alloc_req_slots runs out of memory. "
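
`HybridReqToTokenPool` (added to the imports above) backs the hybrid-attention models introduced in this release, which appear to need the request objects themselves, not just a count, when handing out slots (e.g., to also bind per-request recurrent state for linear-attention layers). A sketch of why the call site dispatches on type; the pool internals below are placeholders, not sglang's implementation:

```python
from typing import List, Optional

class ReqToTokenPool:  # placeholder internals
    def alloc(self, num_reqs: int) -> Optional[List[int]]:
        return list(range(num_reqs))

class HybridReqToTokenPool(ReqToTokenPool):
    def alloc(self, num_reqs: int, reqs=None) -> Optional[List[int]]:
        # A hybrid pool can inspect each request here before picking slots.
        return list(range(num_reqs))

def alloc_req_slots(pool, num_reqs, reqs=None):
    # Same shape as the new method: widen the call only for hybrid pools.
    if isinstance(pool, HybridReqToTokenPool):
        return pool.alloc(num_reqs, reqs)
    return pool.alloc(num_reqs)

assert alloc_req_slots(HybridReqToTokenPool(), 2, reqs=["r0", "r1"]) == [0, 1]
```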
@@ -1138,7 +1152,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
         # Allocate req slots
         bs = len(self.reqs)
-        req_pool_indices = self.alloc_req_slots(bs)
+        req_pool_indices = self.alloc_req_slots(bs, self.reqs)
 
         # Init tensors
         reqs = self.reqs
@@ -1207,13 +1221,36 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             req.is_retracted = False
 
             # Compute the relative logprob_start_len in an extend batch
+            #
+            # Key variables:
+            # - logprob_start_len: Absolute position in full sequence where logprob computation begins
+            # - extend_logprob_start_len: Relative position within current extend batch where logprob computation begins
+            # - extend_input_len: Number of tokens that need to be processed in this extend batch
+            #   (= len(fill_ids) - len(prefix_indices), where fill_ids = origin_input_ids + output_ids
+            #   and prefix_indices are the cached/shared prefix tokens)
+            #
             if req.logprob_start_len >= pre_len:
-                req.extend_logprob_start_len = min(
-                    req.logprob_start_len - pre_len,
-                    req.extend_input_len,
-                    req.seqlen - 1,
-                )
+                # Optimization for prefill-only requests: When we only need logprobs at
+                # positions beyond the input sequence (to score next-token likelihood), skip all
+                # input logprob computation during prefill since no generation will occur.
+                if self.is_prefill_only and req.logprob_start_len == len(
+                    req.origin_input_ids
+                ):
+                    # Skip ALL input logprobs: set extend_logprob_start_len = extend_input_len
+                    req.extend_logprob_start_len = req.extend_input_len
+                else:
+                    # Convert absolute logprob_start_len to relative extend_logprob_start_len
+                    #
+                    # Example: origin_input_ids=[1,2,3,4,5] (5 tokens, positions 0-4), logprob_start_len=3
+                    # Regular logic: min(3-0, 5, 5-1) = min(3,5,4) = 3
+                    # This means: "compute logprobs from position 3 onwards in extend batch"
+                    req.extend_logprob_start_len = min(
+                        req.logprob_start_len - pre_len,
+                        req.extend_input_len,
+                        req.seqlen - 1,
+                    )
             else:
+                # logprob_start_len is before the current extend batch, so start from beginning
                 req.extend_logprob_start_len = 0
 
         if self.return_logprob:
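
The example in the new comments can be checked standalone; the three terms of the `min` are the relative start, the batch-length cap, and the last input position that has a next token to score:

```python
# Worked check of the comment's example: 5 prompt tokens, no cached
# prefix (pre_len = 0), logprobs requested from absolute position 3.
logprob_start_len, pre_len = 3, 0
extend_input_len = 5   # all 5 tokens processed in this extend batch
seqlen = 5             # len(origin_input_ids) + len(output_ids)

extend_logprob_start_len = min(
    logprob_start_len - pre_len,  # relative start inside the batch: 3
    extend_input_len,             # cannot start past this batch: 5
    seqlen - 1,                   # last input position with a target: 4
)
assert extend_logprob_start_len == 3
```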
@@ -1372,21 +1409,28 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             # TODO (lianmin): Revisit this. It should be seq_len - 1
             self.extend_logprob_start_lens.extend([0] * running_bs)
 
-    def new_page_count_next_decode(self):
+    def new_page_count_next_decode(self, selected_indices: Optional[List[int]] = None):
         page_size = self.token_to_kv_pool_allocator.page_size
+        requests = (
+            self.reqs
+            if selected_indices is None
+            else [self.reqs[i] for i in selected_indices]
+        )
         if page_size == 1:
-            return len(self.reqs)
+            return len(requests)
         # In the decoding phase, the length of a request's KV cache should be
         # the total length of the request minus 1
         return (
-            sum(1 for req in self.reqs if req.seqlen % page_size == 0)
+            sum(1 for req in requests if req.seqlen % page_size == 0)
             if self.enable_overlap
-            else sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0)
+            else sum(1 for req in requests if (req.seqlen - 1) % page_size == 0)
         )
 
-    def check_decode_mem(self, buf_multiplier=1):
+    def check_decode_mem(
+        self, buf_multiplier=1, selected_indices: Optional[List[int]] = None
+    ):
         num_tokens = (
-            self.new_page_count_next_decode()
+            self.new_page_count_next_decode(selected_indices)
             * buf_multiplier
             * self.token_to_kv_pool_allocator.page_size
         )
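
`new_page_count_next_decode` counts how many requests will need a fresh KV page on the next decode step: with overlap scheduling the check uses `req.seqlen`, otherwise `req.seqlen - 1`, and a new page is needed exactly when that length is a multiple of `page_size`. The same arithmetic, standalone:

```python
def new_pages_next_decode(seqlens, page_size, enable_overlap):
    # One new token per request; a fresh page is needed only when the
    # current KV length exactly fills its last page.
    if page_size == 1:
        return len(seqlens)  # every appended token is its own page
    kv_len = (lambda s: s) if enable_overlap else (lambda s: s - 1)
    return sum(1 for s in seqlens if kv_len(s) % page_size == 0)

# page_size=16: seqlen 33 -> KV length 32 (non-overlap) fills 2 pages,
# so the next token opens a third; seqlen 34 -> KV length 33 does not.
assert new_pages_next_decode([33, 34], page_size=16, enable_overlap=False) == 1
```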
@@ -1412,34 +1456,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             reverse=True,
         )
 
-        def get_required_tokens(num_reqs: int):
-            headroom_for_spec_decode = 0
-            if server_args.speculative_algorithm:
-                headroom_for_spec_decode += (
-                    num_reqs
-                    * server_args.speculative_eagle_topk
-                    * server_args.speculative_num_steps
-                    + num_reqs * server_args.speculative_num_draft_tokens
-                )
-            return (
-                num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
-            )
-
-        def _get_available_size():
-            if self.is_hybrid:
-                return min(
-                    self.token_to_kv_pool_allocator.full_available_size(),
-                    self.token_to_kv_pool_allocator.swa_available_size(),
-                )
-            else:
-                return self.token_to_kv_pool_allocator.available_size()
-
         retracted_reqs = []
         seq_lens_cpu = self.seq_lens.cpu().numpy()
         first_iter = True
-        while (
-            _get_available_size() < get_required_tokens(len(sorted_indices))
-            or first_iter
+        while first_iter or (
+            not self.check_decode_mem(selected_indices=sorted_indices)
         ):
             if len(sorted_indices) == 1:
                 # Corner case: only one request left
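
Retraction used to compare a hand-built token estimate (`get_required_tokens`, including speculative-decoding headroom) against allocator free space; it now simply asks `check_decode_mem` whether the requests still scheduled (`selected_indices=sorted_indices`) fit, so decode memory has one source of truth. A much-simplified sketch of the new loop shape; it omits the single-request corner case and the actual KV freeing:

```python
def retract_until_fits(sorted_indices, check_decode_mem):
    """Pop lowest-priority requests until the rest fit in decode memory."""
    retracted = []
    first_iter = True
    while first_iter or not check_decode_mem(selected_indices=sorted_indices):
        first_iter = False
        if not sorted_indices:
            break
        retracted.append(sorted_indices.pop())  # retract one, then re-check
    return retracted

# Toy memory check: pretend at most two requests fit.
fits = lambda selected_indices: len(selected_indices) <= 2
assert retract_until_fits([3, 1, 0, 2], fits) == [2, 0]
```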
@@ -1493,10 +1514,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             else:
                 self.tree_cache.dec_lock_ref(req.last_node)
 
-            # NOTE(lsyin): we should use the newly evictable memory instantly.
-            num_tokens = len(sorted_indices) * global_config.retract_decode_steps
-            self._evict_tree_cache_if_needed(num_tokens)
-
             req.reset_for_retract()
 
         if len(retracted_reqs) == 0:
@@ -1540,7 +1557,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.forward_mode = ForwardMode.DECODE
         bs = len(self.reqs)
 
-        if self.spec_algorithm.is_eagle():
+        if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
             # if spec decoding is used, the decode batch is prepared inside
             # `forward_batch_speculative_generation` after running draft models.
             return
@@ -1780,6 +1797,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             ),
             extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
             launch_done=self.launch_done,
+            is_prefill_only=self.is_prefill_only,
         )
 
     def copy(self):
@@ -1917,11 +1935,14 @@ class ModelWorkerBatch:
     spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
-    hicache_consumer_index: int = 0
+    hicache_consumer_index: int = -1
 
     # Overlap event
     launch_done: Optional[threading.Event] = None
 
+    # Whether this batch is prefill-only (no token generation needed)
+    is_prefill_only: bool = False
+
 
 @triton.jit
 def write_req_to_token_pool_triton(
--- a/sglang/srt/managers/schedule_policy.py
+++ b/sglang/srt/managers/schedule_policy.py
@@ -550,7 +550,7 @@ class PrefillAdder:
             )
         else:
             # Make sure at least one page is available
-            trunc_len = self.rem_chunk_tokens
+            trunc_len = self.rem_chunk_tokens // self.page_size * self.page_size
             if trunc_len <= 0:
                 return AddReqResult.OTHER
 
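
The `PrefillAdder` fix floors the chunked-prefill truncation length to a whole number of pages, so a chunk boundary never lands mid-page; if less than one full page of budget remains, `trunc_len` becomes 0 and the request is deferred. The arithmetic:

```python
page_size = 16

# 70 remaining chunk tokens -> keep only 4 full pages (64 tokens).
rem_chunk_tokens = 70
assert rem_chunk_tokens // page_size * page_size == 64

# Less than one page of budget left -> trunc_len == 0, and the caller
# returns AddReqResult.OTHER instead of scheduling a partial page.
rem_chunk_tokens = 9
assert rem_chunk_tokens // page_size * page_size == 0
```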