sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -14,7 +14,6 @@ from sglang.srt.distributed import (
|
|
14
14
|
)
|
15
15
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
16
16
|
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
|
17
|
-
from sglang.srt.managers.mm_utils import embed_mm_inputs
|
18
17
|
from sglang.srt.managers.schedule_batch import (
|
19
18
|
ScheduleBatch,
|
20
19
|
get_last_loc,
|
@@ -24,6 +23,7 @@ from sglang.srt.managers.tp_worker import TpModelWorker
|
|
24
23
|
from sglang.srt.model_executor.forward_batch_info import (
|
25
24
|
CaptureHiddenMode,
|
26
25
|
ForwardBatch,
|
26
|
+
ForwardBatchOutput,
|
27
27
|
ForwardMode,
|
28
28
|
)
|
29
29
|
from sglang.srt.server_args import ServerArgs
|
@@ -34,16 +34,18 @@ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
|
|
34
34
|
from sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner import (
|
35
35
|
EAGLEDraftExtendCudaGraphRunner,
|
36
36
|
)
|
37
|
-
from sglang.srt.speculative.
|
37
|
+
from sglang.srt.speculative.eagle_info import (
|
38
38
|
EagleDraftInput,
|
39
39
|
EagleVerifyInput,
|
40
40
|
EagleVerifyOutput,
|
41
|
+
)
|
42
|
+
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
43
|
+
from sglang.srt.speculative.spec_utils import (
|
41
44
|
assign_draft_cache_locs,
|
42
45
|
fast_topk,
|
43
46
|
generate_token_bitmask,
|
44
47
|
select_top_k_tokens,
|
45
48
|
)
|
46
|
-
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
47
49
|
from sglang.srt.utils import (
|
48
50
|
empty_context,
|
49
51
|
get_available_gpu_memory,
|
@@ -242,6 +244,7 @@ class EAGLEWorker(TpModelWorker):
|
|
242
244
|
if not is_blackwell()
|
243
245
|
else self._create_triton_prefill_backend
|
244
246
|
),
|
247
|
+
"flashmla": self._create_flashmla_prefill_backend,
|
245
248
|
"trtllm_mha": self._create_trtllm_mha_prefill_backend,
|
246
249
|
"trtllm_mla": self._create_trtllm_mla_prefill_backend,
|
247
250
|
}
|
@@ -381,6 +384,12 @@ class EAGLEWorker(TpModelWorker):
|
|
381
384
|
|
382
385
|
return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)
|
383
386
|
|
387
|
+
def _create_flashmla_prefill_backend(self):
|
388
|
+
logger.warning(
|
389
|
+
"flashmla prefill backend is not yet supported for draft extend."
|
390
|
+
)
|
391
|
+
return None
|
392
|
+
|
384
393
|
def init_cuda_graphs(self):
|
385
394
|
"""Capture cuda graphs."""
|
386
395
|
self.cuda_graph_runner = None
|
@@ -420,9 +429,7 @@ class EAGLEWorker(TpModelWorker):
|
|
420
429
|
def draft_model_runner(self):
|
421
430
|
return self.model_runner
|
422
431
|
|
423
|
-
def
|
424
|
-
self, batch: ScheduleBatch
|
425
|
-
) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, int, bool]:
|
432
|
+
def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput:
|
426
433
|
"""Run speculative decoding forward.
|
427
434
|
|
428
435
|
NOTE: Many states of batch is modified as you go through. It is not guaranteed that
|
@@ -435,14 +442,19 @@ class EAGLEWorker(TpModelWorker):
|
|
435
442
|
the batch id (used for overlap schedule), and number of accepted tokens.
|
436
443
|
"""
|
437
444
|
if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
|
438
|
-
logits_output, next_token_ids,
|
439
|
-
|
445
|
+
logits_output, next_token_ids, seq_lens_cpu = self.forward_target_extend(
|
446
|
+
batch
|
440
447
|
)
|
441
448
|
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
442
449
|
self.forward_draft_extend(
|
443
450
|
batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
|
444
451
|
)
|
445
|
-
return
|
452
|
+
return ForwardBatchOutput(
|
453
|
+
logits_output=logits_output,
|
454
|
+
next_token_ids=next_token_ids,
|
455
|
+
num_accepted_tokens=0,
|
456
|
+
can_run_cuda_graph=False,
|
457
|
+
)
|
446
458
|
else:
|
447
459
|
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
448
460
|
spec_info = self.draft(batch)
|
@@ -460,12 +472,11 @@ class EAGLEWorker(TpModelWorker):
|
|
460
472
|
# decode is not finished
|
461
473
|
self.forward_draft_extend_after_decode(batch)
|
462
474
|
|
463
|
-
return (
|
464
|
-
logits_output,
|
465
|
-
verify_output.verified_id,
|
466
|
-
|
467
|
-
|
468
|
-
can_run_cuda_graph,
|
475
|
+
return ForwardBatchOutput(
|
476
|
+
logits_output=logits_output,
|
477
|
+
next_token_ids=verify_output.verified_id,
|
478
|
+
num_accepted_tokens=sum(verify_output.accept_length_per_req_cpu),
|
479
|
+
can_run_cuda_graph=can_run_cuda_graph,
|
469
480
|
)
|
470
481
|
|
471
482
|
def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
|
@@ -497,19 +508,21 @@ class EAGLEWorker(TpModelWorker):
|
|
497
508
|
Returns:
|
498
509
|
logits_output: The output of logits. It will contain the full hidden states.
|
499
510
|
next_token_ids: Next token ids generated.
|
500
|
-
bid: The model batch ID. Used for overlap schedule.
|
501
511
|
"""
|
502
512
|
# Forward with the target model and get hidden states.
|
503
513
|
# We need the full hidden states to prefill the KV cache of the draft model.
|
504
514
|
model_worker_batch = batch.get_model_worker_batch()
|
505
515
|
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
|
506
|
-
|
516
|
+
forward_batch_output = self.target_worker.forward_batch_generation(
|
507
517
|
model_worker_batch
|
508
518
|
)
|
519
|
+
logits_output, next_token_ids = (
|
520
|
+
forward_batch_output.logits_output,
|
521
|
+
forward_batch_output.next_token_ids,
|
522
|
+
)
|
509
523
|
return (
|
510
524
|
logits_output,
|
511
525
|
next_token_ids,
|
512
|
-
model_worker_batch.bid,
|
513
526
|
model_worker_batch.seq_lens_cpu,
|
514
527
|
)
|
515
528
|
|
@@ -541,6 +554,8 @@ class EAGLEWorker(TpModelWorker):
|
|
541
554
|
batch.seq_lens,
|
542
555
|
self.speculative_num_steps,
|
543
556
|
)
|
557
|
+
prefix_lens_cpu = batch.seq_lens_cpu
|
558
|
+
seq_lens_cpu = batch.seq_lens_cpu + self.speculative_num_steps
|
544
559
|
extend_num_tokens = num_seqs * self.speculative_num_steps
|
545
560
|
else:
|
546
561
|
# In this case, the last partial page needs to be duplicated.
|
@@ -576,14 +591,23 @@ class EAGLEWorker(TpModelWorker):
|
|
576
591
|
self.topk,
|
577
592
|
self.page_size,
|
578
593
|
)
|
579
|
-
|
580
|
-
|
581
|
-
|
594
|
+
prefix_lens_cpu = batch.seq_lens_cpu
|
595
|
+
last_page_lens = prefix_lens_cpu % self.page_size
|
596
|
+
num_new_pages_per_topk = (
|
597
|
+
last_page_lens + self.speculative_num_steps + self.page_size - 1
|
598
|
+
) // self.page_size
|
599
|
+
seq_lens_cpu = (
|
600
|
+
prefix_lens_cpu // self.page_size * self.page_size
|
601
|
+
+ num_new_pages_per_topk * (self.page_size * self.topk)
|
602
|
+
)
|
603
|
+
extend_num_tokens = torch.sum((seq_lens_cpu - prefix_lens_cpu)).item()
|
582
604
|
|
583
605
|
out_cache_loc, token_to_kv_pool_state_backup = (
|
584
606
|
batch.alloc_paged_token_slots_extend(
|
585
607
|
prefix_lens,
|
608
|
+
prefix_lens_cpu,
|
586
609
|
seq_lens,
|
610
|
+
seq_lens_cpu,
|
587
611
|
last_loc,
|
588
612
|
extend_num_tokens,
|
589
613
|
backup_state=True,
|
@@ -771,6 +795,10 @@ class EAGLEWorker(TpModelWorker):
|
|
771
795
|
|
772
796
|
return score_list, token_list, parents_list
|
773
797
|
|
798
|
+
def clear_cache_pool(self):
|
799
|
+
self.model_runner.req_to_token_pool.clear()
|
800
|
+
self.model_runner.token_to_kv_pool_allocator.clear()
|
801
|
+
|
774
802
|
def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
|
775
803
|
spec_info.prepare_for_verify(batch, self.page_size)
|
776
804
|
batch.return_hidden_states = False
|
@@ -794,10 +822,12 @@ class EAGLEWorker(TpModelWorker):
|
|
794
822
|
).cpu()
|
795
823
|
|
796
824
|
# Forward
|
797
|
-
|
798
|
-
|
799
|
-
|
800
|
-
|
825
|
+
forward_batch_output = self.target_worker.forward_batch_generation(
|
826
|
+
model_worker_batch, is_verify=True
|
827
|
+
)
|
828
|
+
logits_output, can_run_cuda_graph = (
|
829
|
+
forward_batch_output.logits_output,
|
830
|
+
forward_batch_output.can_run_cuda_graph,
|
801
831
|
)
|
802
832
|
|
803
833
|
vocab_mask = None
|
@@ -997,6 +1027,7 @@ class EAGLEWorker(TpModelWorker):
|
|
997
1027
|
assert isinstance(batch.spec_info, EagleDraftInput)
|
998
1028
|
# Backup fields that will be modified in-place
|
999
1029
|
seq_lens_backup = batch.seq_lens.clone()
|
1030
|
+
seq_lens_cpu_backup = batch.seq_lens_cpu.clone()
|
1000
1031
|
req_pool_indices_backup = batch.req_pool_indices
|
1001
1032
|
accept_length_backup = batch.spec_info.accept_length
|
1002
1033
|
return_logprob_backup = batch.return_logprob
|
@@ -1075,6 +1106,7 @@ class EAGLEWorker(TpModelWorker):
|
|
1075
1106
|
ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
|
1076
1107
|
)
|
1077
1108
|
batch.seq_lens = seq_lens_backup
|
1109
|
+
batch.seq_lens_cpu = seq_lens_cpu_backup
|
1078
1110
|
batch.req_pool_indices = req_pool_indices_backup
|
1079
1111
|
batch.spec_info.accept_length = accept_length_backup
|
1080
1112
|
batch.return_logprob = return_logprob_backup
|
@@ -0,0 +1,428 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import copy
|
4
|
+
import logging
|
5
|
+
from typing import Optional, Tuple
|
6
|
+
|
7
|
+
import torch
|
8
|
+
import triton
|
9
|
+
|
10
|
+
logger = logging.getLogger(__name__)
|
11
|
+
|
12
|
+
from dataclasses import dataclass
|
13
|
+
|
14
|
+
import torch.nn.functional as F
|
15
|
+
|
16
|
+
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
17
|
+
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
18
|
+
from sglang.srt.layers.sampler import apply_custom_logit_processor
|
19
|
+
from sglang.srt.managers.schedule_batch import (
|
20
|
+
ScheduleBatch,
|
21
|
+
get_last_loc,
|
22
|
+
global_server_args_dict,
|
23
|
+
)
|
24
|
+
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
25
|
+
from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
|
26
|
+
from sglang.srt.speculative.spec_utils import (
|
27
|
+
TREE_SPEC_KERNEL_AVAILABLE,
|
28
|
+
assign_req_to_token_pool,
|
29
|
+
get_src_tgt_cache_loc,
|
30
|
+
get_target_cache_loc,
|
31
|
+
)
|
32
|
+
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
|
33
|
+
|
34
|
+
if is_cuda():
|
35
|
+
from sgl_kernel import (
|
36
|
+
top_k_renorm_prob,
|
37
|
+
top_p_renorm_prob,
|
38
|
+
tree_speculative_sampling_target_only,
|
39
|
+
verify_tree_greedy,
|
40
|
+
)
|
41
|
+
elif is_hip():
|
42
|
+
from sgl_kernel import verify_tree_greedy
|
43
|
+
|
44
|
+
|
45
|
+
@dataclass
|
46
|
+
class NgramVerifyInput(SpecInput):
|
47
|
+
def __init__(
|
48
|
+
self,
|
49
|
+
draft_token: torch.Tensor,
|
50
|
+
tree_mask: torch.Tensor,
|
51
|
+
positions: torch.Tensor,
|
52
|
+
retrive_index: torch.Tensor,
|
53
|
+
retrive_next_token: torch.Tensor,
|
54
|
+
retrive_next_sibling: torch.Tensor,
|
55
|
+
draft_token_num: int,
|
56
|
+
):
|
57
|
+
super().__init__(SpecInputType.NGRAM_VERIFY)
|
58
|
+
self.draft_token = draft_token
|
59
|
+
self.custom_mask = tree_mask
|
60
|
+
self.positions = positions
|
61
|
+
self.retrive_index = retrive_index
|
62
|
+
self.retrive_next_token = retrive_next_token
|
63
|
+
self.retrive_next_sibling = retrive_next_sibling
|
64
|
+
self.draft_token_num = draft_token_num
|
65
|
+
self.device = self.custom_mask.device
|
66
|
+
|
67
|
+
def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
|
68
|
+
return self.draft_token_num, self.draft_token_num
|
69
|
+
|
70
|
+
def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
|
71
|
+
if batch.forward_mode.is_idle():
|
72
|
+
return
|
73
|
+
|
74
|
+
batch.input_ids = self.draft_token
|
75
|
+
|
76
|
+
if page_size == 1:
|
77
|
+
batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
|
78
|
+
end_offset = batch.seq_lens + self.draft_token_num
|
79
|
+
else:
|
80
|
+
# TODO(lsyin): add prefix lens cpu here to support page size > 1
|
81
|
+
prefix_lens = batch.seq_lens
|
82
|
+
prefix_lens_cpu = batch.seq_lens_cpu
|
83
|
+
end_offset = prefix_lens + self.draft_token_num
|
84
|
+
end_offset_cpu = prefix_lens_cpu + self.draft_token_num
|
85
|
+
last_loc = get_last_loc(
|
86
|
+
batch.req_to_token_pool.req_to_token,
|
87
|
+
batch.req_pool_indices,
|
88
|
+
prefix_lens,
|
89
|
+
)
|
90
|
+
batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
|
91
|
+
prefix_lens,
|
92
|
+
prefix_lens_cpu,
|
93
|
+
end_offset,
|
94
|
+
end_offset_cpu,
|
95
|
+
last_loc,
|
96
|
+
len(batch.input_ids),
|
97
|
+
)
|
98
|
+
self.last_loc = last_loc
|
99
|
+
|
100
|
+
bs = batch.batch_size()
|
101
|
+
assign_req_to_token_pool[(bs,)](
|
102
|
+
batch.req_pool_indices,
|
103
|
+
batch.req_to_token_pool.req_to_token,
|
104
|
+
batch.seq_lens,
|
105
|
+
end_offset,
|
106
|
+
batch.out_cache_loc,
|
107
|
+
batch.req_to_token_pool.req_to_token.shape[1],
|
108
|
+
triton.next_power_of_2(bs),
|
109
|
+
)
|
110
|
+
|
111
|
+
def generate_attn_arg_prefill(
|
112
|
+
self,
|
113
|
+
req_pool_indices: torch.Tensor,
|
114
|
+
paged_kernel_lens: torch.Tensor,
|
115
|
+
paged_kernel_lens_sum: int,
|
116
|
+
req_to_token: torch.Tensor,
|
117
|
+
):
|
118
|
+
bs = len(req_pool_indices)
|
119
|
+
|
120
|
+
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
|
121
|
+
|
122
|
+
paged_kernel_lens = paged_kernel_lens + self.draft_token_num
|
123
|
+
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
124
|
+
|
125
|
+
self.qo_indptr = (
|
126
|
+
torch.arange(0, bs + 1, dtype=torch.int32, device=self.device)
|
127
|
+
* self.draft_token_num
|
128
|
+
)
|
129
|
+
|
130
|
+
kv_indices = torch.empty(
|
131
|
+
cum_kv_seq_len[-1], dtype=torch.int32, device=self.device
|
132
|
+
)
|
133
|
+
|
134
|
+
create_flashinfer_kv_indices_triton[(bs,)](
|
135
|
+
req_to_token,
|
136
|
+
req_pool_indices,
|
137
|
+
paged_kernel_lens,
|
138
|
+
cum_kv_seq_len,
|
139
|
+
None,
|
140
|
+
kv_indices,
|
141
|
+
req_to_token.size(1),
|
142
|
+
)
|
143
|
+
return kv_indices, cum_kv_seq_len, self.qo_indptr, self.custom_mask
|
144
|
+
|
145
|
+
def _fill_requests(
|
146
|
+
self,
|
147
|
+
batch: ScheduleBatch,
|
148
|
+
logits_output: torch.Tensor,
|
149
|
+
):
|
150
|
+
accept_index_cpu = self.accept_index.tolist()
|
151
|
+
predict_cpu = self.predict.tolist()
|
152
|
+
has_finished = False
|
153
|
+
|
154
|
+
# Iterate every accepted token and check if req has finished after append the token
|
155
|
+
# should be checked BEFORE free kv cache slots
|
156
|
+
for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
|
157
|
+
for j, idx in enumerate(accept_index_row):
|
158
|
+
if idx == -1:
|
159
|
+
break
|
160
|
+
id = predict_cpu[idx]
|
161
|
+
req.output_ids.append(id)
|
162
|
+
req.check_finished()
|
163
|
+
if req.finished():
|
164
|
+
has_finished = True
|
165
|
+
# set all tokens after finished token to -1 and break
|
166
|
+
self.accept_index[i, j + 1 :] = -1
|
167
|
+
break
|
168
|
+
else:
|
169
|
+
if req.grammar is not None:
|
170
|
+
try:
|
171
|
+
req.grammar.accept_token(id)
|
172
|
+
except ValueError as e:
|
173
|
+
logger.info(
|
174
|
+
f"{i=}, {req=}\n"
|
175
|
+
f"{self.accept_index=}\n"
|
176
|
+
f"{self.predict=}\n"
|
177
|
+
)
|
178
|
+
raise e
|
179
|
+
req.spec_verify_ct += 1
|
180
|
+
if has_finished:
|
181
|
+
self.accept_length = (self.accept_index != -1).sum(dim=1) - 1
|
182
|
+
self.accept_index = self.accept_index[self.accept_index != -1]
|
183
|
+
|
184
|
+
logits_output.next_token_logits = logits_output.next_token_logits[
|
185
|
+
self.accept_index
|
186
|
+
]
|
187
|
+
if logits_output.hidden_states:
|
188
|
+
logits_output.hidden_states = logits_output.hidden_states[self.accept_index]
|
189
|
+
self.verified_id = self.predict[self.accept_index]
|
190
|
+
|
191
|
+
def _free_cache(self, batch: ScheduleBatch, page_size: int):
    """Free the KV-cache slots of unaccepted draft tokens.

    For ``page_size == 1`` rejected slots are released directly. For larger
    pages, accepted tokens are first compacted to the front of each
    request's slot range so that only trailing, page-aligned slots are
    freed; the KV data is moved accordingly. Finally the req-to-token table
    is updated to cover the newly accepted tokens.
    """
    bs = batch.batch_size()
    # Free the KV cache for unaccepted tokens
    if page_size == 1:
        # TODO: boolean array index leads to a device sync. Remove it.
        evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
        evict_mask[self.accept_index] = False
        batch.token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
        batch.out_cache_loc = batch.out_cache_loc[self.accept_index]
    else:
        # Shift the accepted tokens to the beginning.
        # Only evict the last part
        src_cache_loc, tgt_cache_loc, to_free_num_slots = get_src_tgt_cache_loc(
            batch.seq_lens,
            batch.out_cache_loc,
            self.accept_index,
            self.accept_length,
            self.draft_token_num,
            page_size,
        )
        # NOTE: .item() here forces a device sync to size the buffer.
        to_free_slots = torch.empty(
            (to_free_num_slots.sum().item(),),
            dtype=torch.int64,
            device=to_free_num_slots.device,
        )

        # out_cache_loc: [0 1 2, 3 4 5, 6 7 8]
        # accept_index:  [0 -1 2, 3 4 -1, 6 -1 -1]
        # tgt_cache_loc: [0 1  , 3 4  , 6    ]
        # to_free_slots: [    2,     5,   7 8]
        # to_free_slots also needs to be page-aligned without the first partial page
        #
        # Split each row of out_cache_loc into two parts:
        # 1. the first part goes to tgt_cache_loc. length = accept_length[i] + 1
        # 2. the second part goes to to_free_slots.
        # One Triton program per request (grid = (bs,)).
        get_target_cache_loc[(bs,)](
            tgt_cache_loc,
            to_free_slots,
            self.accept_length,
            to_free_num_slots,
            batch.out_cache_loc,
            self.draft_token_num,
            next_power_of_2(self.draft_token_num),
            next_power_of_2(bs),
        )

        # Free the kv cache
        batch.token_to_kv_pool_allocator.free(to_free_slots)

        # Copy the kv cache entries of accepted tokens into their compacted
        # target slots.
        batch.token_to_kv_pool_allocator.get_kvcache().move_kv_cache(
            tgt_cache_loc, src_cache_loc
        )
        batch.out_cache_loc = tgt_cache_loc

    # Record the accepted tokens' cache locations in the req-to-token table,
    # extending each request by accept_length + 1 tokens.
    assign_req_to_token_pool[(bs,)](
        batch.req_pool_indices,
        batch.req_to_token_pool.req_to_token,
        batch.seq_lens,
        batch.seq_lens + self.accept_length + 1,
        batch.out_cache_loc,
        batch.req_to_token_pool.req_to_token.shape[1],
        triton.next_power_of_2(bs),
    )
|
255
|
+
|
256
|
+
def _greedy_verify(
    self,
    batch: ScheduleBatch,
    logits_output: LogitsProcessorOutput,
):
    """Greedy (argmax) verification of the draft token tree.

    Compares the draft candidates against the target model's argmax
    predictions and fills ``self.predict``, ``self.accept_index`` and
    ``self.accept_length`` in place via the ``verify_tree_greedy`` kernel.
    """
    bs = batch.batch_size()
    # Target model's greedy choice at every draft position.
    target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
    target_predict = target_predict.reshape(bs, self.draft_token_num)

    candidates = self.draft_token.reshape(bs, self.draft_token_num)
    # predict gets one extra slot beyond the draft positions — presumably
    # for the token emitted after the last accepted one; TODO confirm
    # against the kernel's contract.
    predict_shape = list(logits_output.next_token_logits.shape)[:-1]
    predict_shape[-1] += 1
    self.predict = torch.empty(predict_shape, dtype=torch.int32, device=self.device)
    # accept_index starts all -1; the kernel writes accepted positions.
    self.accept_index = torch.full(
        (bs, self.draft_token_num), -1, dtype=torch.int32, device=self.device
    )
    self.accept_length = torch.empty((bs,), dtype=torch.int32, device=self.device)

    verify_tree_greedy(
        predicts=self.predict,  # mutable
        accept_index=self.accept_index,  # mutable
        accept_token_num=self.accept_length,  # mutable
        candidates=candidates,
        retrive_index=self.retrive_index,
        retrive_next_token=self.retrive_next_token,
        retrive_next_sibling=self.retrive_next_sibling,
        target_predict=target_predict,
    )
|
284
|
+
|
285
|
+
def _sampling_verify(
    self,
    batch: ScheduleBatch,
    logits_output: LogitsProcessorOutput,
    sampling_info: SamplingBatchInfo,
):
    """Rejection-sampling verification of the draft token tree.

    Builds temperature/top-k/top-p-adjusted target probabilities and runs
    ``tree_speculative_sampling_target_only``, which fills
    ``self.predict``, ``self.accept_index`` and ``self.accept_length``
    in place.
    """
    bs = batch.batch_size()
    candidates = self.draft_token.reshape(bs, self.draft_token_num)
    # predict gets one extra slot beyond the draft positions — presumably
    # for the token emitted after the last accepted one; TODO confirm.
    predict_shape = list(logits_output.next_token_logits.shape)[:-1]
    predict_shape[-1] += 1
    self.predict = torch.empty(predict_shape, dtype=torch.int32, device=self.device)
    self.accept_index = torch.full(
        (bs, self.draft_token_num), -1, dtype=torch.int32, device=self.device
    )
    self.accept_length = torch.empty((bs,), dtype=torch.int32, device=self.device)
    # Apply temperature and get target probs; per-request temperatures are
    # expanded so each draft position of a request shares its temperature.
    expanded_temperature = torch.repeat_interleave(
        sampling_info.temperatures, self.draft_token_num, dim=0
    )  # (bs * draft_token_num, 1)

    target_probs = F.softmax(
        logits_output.next_token_logits / expanded_temperature, dim=-1
    )  # (bs * draft_token_num, vocab_size)

    # NOTE: The test shows that top_p_renorm_prob and top_k_renorm_prob are the key factors
    # contributing to the poor performance of _sampling_verify.
    target_probs = top_k_renorm_prob(
        target_probs,
        torch.repeat_interleave(sampling_info.top_ks, self.draft_token_num, dim=0),
    )  # (bs * draft_token_num, vocab_size)

    if sampling_info.need_top_p_sampling:
        # logger.info("Using top-p sampling in speculative decoding verification.")
        target_probs = top_p_renorm_prob(
            target_probs,
            torch.repeat_interleave(
                sampling_info.top_ps, self.draft_token_num, dim=0
            ),
        )

    target_probs = target_probs.reshape(bs, self.draft_token_num, -1)
    # Draft probabilities are all-zero here: only the target distribution
    # drives acceptance in the "target only" kernel.
    draft_probs = torch.zeros(
        target_probs.shape, dtype=torch.float32, device=self.device
    )

    # coins for rejection sampling
    coins = torch.rand_like(candidates, dtype=torch.float32, device=self.device)
    # coins for final sampling
    coins_for_final_sampling = torch.rand(
        (bs,), dtype=torch.float32, device=self.device
    )
    tree_speculative_sampling_target_only(
        predicts=self.predict,  # mutable
        accept_index=self.accept_index,  # mutable
        accept_token_num=self.accept_length,  # mutable
        candidates=candidates.to(torch.int64),
        retrive_index=self.retrive_index.to(torch.int64),
        retrive_next_token=self.retrive_next_token.to(torch.int64),
        retrive_next_sibling=self.retrive_next_sibling.to(torch.int64),
        uniform_samples=coins,
        uniform_samples_for_final_sampling=coins_for_final_sampling,
        target_probs=target_probs,
        draft_probs=draft_probs,
        threshold_single=global_server_args_dict[
            "speculative_accept_threshold_single"
        ],
        threshold_acc=global_server_args_dict["speculative_accept_threshold_acc"],
        deterministic=True,
    )
|
354
|
+
|
355
|
+
def verify(
    self,
    batch: ScheduleBatch,
    logits_output: LogitsProcessorOutput,
    page_size: int,
    vocab_mask: Optional[torch.Tensor] = None,  # For grammar
) -> torch.Tensor:
    """Verify the draft tokens against the target model's logits.

    Orchestrates: sampling-info filtering, logit adjustments (custom
    processors, penalties, grammar mask), tree verification, request
    bookkeeping, KV-cache cleanup, and sequence-length updates.

    NOTE(review): despite the ``-> torch.Tensor`` annotation, this returns
    a 3-tuple ``(logits_output, verified_id, num_accepted_tokens)`` —
    annotation looks stale; confirm with callers.
    """
    bs = self.retrive_index.shape[0]
    sampling_info = batch.sampling_info

    if bs != len(sampling_info):
        # The batch was filtered after the draft step; filter a copy of the
        # sampling info to match, without mutating the batch's own copy.
        sampling_info = copy.deepcopy(sampling_info)
        # NOTE: retrive_index are the indices of the requests that are kept.
        sampling_info.filter_batch(self.retrive_index.tolist(), self.retrive_index)

    # Apply the custom logit processors if registered in the sampling info.
    if sampling_info.has_custom_logit_processor:
        apply_custom_logit_processor(
            logits_output.next_token_logits,
            sampling_info,
            num_tokens_in_batch=self.draft_token_num,
        )

    # Apply penalty
    if sampling_info.penalizer_orchestrator.is_required:
        # This is a relaxed version of penalties for speculative decoding.
        linear_penalty = torch.zeros(
            (bs, logits_output.next_token_logits.shape[1]),
            dtype=torch.float32,
            device=self.device,
        )
        sampling_info.apply_logits_bias(linear_penalty)
        # One penalty row per request, broadcast across its draft tokens.
        logits_output.next_token_logits.add_(
            torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
        )

    # Apply grammar mask
    if vocab_mask is not None:
        assert self.grammar is not None
        self.grammar.apply_vocab_mask(
            logits=logits_output.next_token_logits, vocab_mask=vocab_mask
        )

    # Sample tokens. Force greedy sampling on AMD
    is_all_greedy = sampling_info.is_all_greedy
    if (not is_all_greedy) and (not TREE_SPEC_KERNEL_AVAILABLE):
        logger.warning(
            "Tree speculative sampling kernel unavailable (likely AMD/HIP build). "
            "Falling back to greedy verification."
        )

    if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE:
        self._greedy_verify(batch, logits_output)
    else:
        # NOTE: Compared with greedy_verify, the performance of _sampling_verify is relatively poor.
        # Deliberately using greedy verification here; the sampling path is
        # kept below (commented out) pending performance work.
        self._greedy_verify(batch, logits_output)
        # self._sampling_verify(batch, logits_output, sampling_info)

    self._fill_requests(batch, logits_output)
    self._free_cache(batch, page_size)

    # .cpu() forces a device sync; reuse the host copy for both the count
    # and the CPU-side sequence-length update.
    accept_length_cpu = self.accept_length.cpu()
    num_accepted_tokens = accept_length_cpu.sum().item()

    # Each request advances by its accepted tokens plus the bonus token.
    batch.seq_lens.add_(self.accept_length + 1)
    batch.seq_lens_cpu.add_(accept_length_cpu + 1)

    return logits_output, self.verified_id, num_accepted_tokens
|
423
|
+
|
424
|
+
def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
    """Intentional no-op: this verify input needs no per-request filtering."""
|
426
|
+
|
427
|
+
def merge_batch(self, spec_info: NgramVerifyInput):
    """Intentional no-op: nothing to merge for ngram verification inputs."""
|