sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl
This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +302 -414
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +13 -8
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +144 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +773 -334
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +225 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +68 -37
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +102 -36
- sglang/srt/model_executor/cuda_graph_runner.py +56 -31
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +280 -81
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +135 -60
- sglang/srt/speculative/build_eagle_tree.py +8 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
- sglang/srt/speculative/eagle_utils.py +92 -57
- sglang/srt/speculative/eagle_worker.py +238 -111
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
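The same comparison can be reproduced locally from the two wheels. Below is a minimal sketch, assuming `pip` is on PATH with network access to PyPI; `fetch_wheel` and `read_member` are illustrative helpers written for this page, not part of sglang:

```python
# Sketch: diff one module between two published wheels of the same package.
import difflib
import subprocess
import sys
import tempfile
import zipfile
from pathlib import Path


def fetch_wheel(spec: str, dest: Path) -> Path:
    """Download a single wheel (no deps) into `dest` and return its path."""
    subprocess.run(
        [sys.executable, "-m", "pip", "download", spec,
         "--no-deps", "--only-binary", ":all:", "-d", str(dest)],
        check=True,
    )
    return next(dest.glob("*.whl"))


def read_member(wheel: Path, name: str) -> list:
    """Read one file out of a wheel as a list of lines (newlines kept)."""
    with zipfile.ZipFile(wheel) as zf:
        return zf.read(name).decode("utf-8").splitlines(keepends=True)


with tempfile.TemporaryDirectory() as tmp:
    old_dir, new_dir = Path(tmp) / "old", Path(tmp) / "new"
    old_dir.mkdir()
    new_dir.mkdir()
    old = fetch_wheel("sglang==0.4.3.post2", old_dir)
    new = fetch_wheel("sglang==0.4.3.post4", new_dir)
    member = "sglang/srt/speculative/eagle_worker.py"
    sys.stdout.writelines(
        difflib.unified_diff(
            read_member(old, member),
            read_member(new, member),
            fromfile=f"0.4.3.post2/{member}",
            tofile=f"0.4.3.post4/{member}",
        )
    )
```

Running it against `sglang/srt/speculative/eagle_worker.py` yields a diff like the first one shown below.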
sglang/srt/speculative/eagle_worker.py

```diff
@@ -1,18 +1,19 @@
 import logging
+import os
 import time
-from typing import List, Optional,
+from typing import List, Optional, Tuple
 
 import torch
+from huggingface_hub import snapshot_download
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.schedule_batch import
+from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
     EAGLEDraftCudaGraphRunner,
@@ -20,11 +21,12 @@ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
 from sglang.srt.speculative.eagle_utils import (
     EagleDraftInput,
     EagleVerifyInput,
+    EagleVerifyOutput,
     assign_draft_cache_locs,
     fast_topk,
     select_top_k_tokens,
 )
-from sglang.srt.
+from sglang.srt.utils import get_available_gpu_memory
 
 logger = logging.getLogger(__name__)
 
@@ -40,10 +42,39 @@ class EAGLEWorker(TpModelWorker):
         nccl_port: int,
         target_worker: TpModelWorker,
     ):
+        # Parse arguments
+        self.server_args = server_args
+        self.topk = server_args.speculative_eagle_topk
+        self.speculative_num_steps = server_args.speculative_num_steps
+        self.padded_static_len = self.speculative_num_steps + 1
+        self.enable_nan_detection = server_args.enable_nan_detection
+        self.gpu_id = gpu_id
+        self.device = server_args.device
+        self.target_worker = target_worker
+
+        # Override context length with target model's context length
+        server_args.context_length = target_worker.model_runner.model_config.context_len
+
         # Do not capture cuda graph in `super().__init__()`
-        #
+        # It will be captured later.
         backup_disable_cuda_graph = server_args.disable_cuda_graph
         server_args.disable_cuda_graph = True
+        # Share the allocator with a target worker.
+        # Draft and target worker own their own KV cache pools.
+        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+            target_worker.get_memory_pool()
+        )
+
+        # Load hot token ids
+        if server_args.speculative_token_map is not None:
+            self.hot_token_id = load_token_map(server_args.speculative_token_map)
+            server_args.json_model_override_args = (
+                f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
+            )
+        else:
+            self.hot_token_id = None
+
+        # Init draft worker
         super().__init__(
             gpu_id=gpu_id,
             tp_rank=tp_rank,
@@ -51,26 +82,27 @@ class EAGLEWorker(TpModelWorker):
             nccl_port=nccl_port,
             dp_rank=dp_rank,
             is_draft_worker=True,
+            req_to_token_pool=self.req_to_token_pool,
+            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
         )
-        self.target_worker = target_worker
-        self.finish_extend_len = []
 
-        #
-
-        self.
-
-
+        # Share the embedding and lm_head
+        embed, head = self.target_worker.model_runner.model.get_embed_and_head()
+        if self.hot_token_id is not None:
+            head = head.clone()
+            self.hot_token_id = self.hot_token_id.to(head.device)
+            head.data = head.data[self.hot_token_id]
+        self.draft_model_runner.model.set_embed_and_head(embed, head)
+        self.draft_model_runner.server_args.disable_cuda_graph = (
+            backup_disable_cuda_graph
         )
-        self.server_args = server_args
 
-
-
-        embed, head = self.target_worker.model_runner.model.get_embed_and_head()
-        self.model_runner.model.set_embed_and_head(embed, head)
-        self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
+        self.init_attention_backend()
+        self.init_cuda_graphs()
 
+    def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
-        if server_args.attention_backend == "flashinfer":
+        if self.server_args.attention_backend == "flashinfer":
             from sglang.srt.layers.attention.flashinfer_backend import (
                 FlashInferMultiStepDraftBackend,
             )
@@ -80,7 +112,7 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
-        elif server_args.attention_backend == "triton":
+        elif self.server_args.attention_backend == "triton":
             from sglang.srt.layers.attention.triton_backend import (
                 TritonMultiStepDraftBackend,
             )
@@ -92,11 +124,9 @@ class EAGLEWorker(TpModelWorker):
             )
         else:
             raise ValueError(
-                f"EAGLE is not supportted in attention backend {server_args.attention_backend}"
+                f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
             )
-
-        self.model_runner.draft_attn_backend = self.draft_attn_backend
-        self.init_cuda_graphs()
+        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
 
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
@@ -106,55 +136,81 @@ class EAGLEWorker(TpModelWorker):
             return
 
         tic = time.time()
-        logger.info(
+        logger.info(
+            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
         self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
-        logger.info(
+        logger.info(
+            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
 
-
+    @property
+    def draft_model_runner(self):
+        return self.model_runner
+
+    def forward_batch_speculative_generation(
+        self, batch: ScheduleBatch
+    ) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
+        """Run speculative decoding forward.
+
+        NOTE: Many states of batch is modified as you go through. It is not guaranteed
+        the final output batch doesn't have the same state as the input.
+
+        Args:
+            batch: The batch to run forward. The state of the batch is modified as it runs.
+        Returns:
+            A tuple of the final logit output of the target model, next tokens accepeted,
+            the batch id (used for overlap schedule), and number of accepeted tokens.
+        """
+        assert not batch.spec_algorithm.is_none()
         if batch.forward_mode.is_decode():
-
-
-
-
-            (
-
-
-                verified_id,
-                self.finish_extend_len,
-                accept_length_cpu,
-                model_worker_batch,
-            ) = self.verify(batch, spec_info)
-            batch.spec_info = next_draft_input
-            # if it is None, means all requsets are finished
+            spec_info, to_free_cache_loc = self.draft(batch)
+            logits_output, verify_output, model_worker_batch = self.verify(
+                batch, spec_info
+            )
+            # Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
+            self.token_to_kv_pool_allocator.free(to_free_cache_loc)
+            # if it is None, means all requests are finished
             if batch.spec_info.verified_id is not None:
                 self.forward_draft_extend_after_decode(batch)
+
             return (
                 logits_output,
-                verified_id,
-                model_worker_batch,
-                sum(
+                verify_output.verified_id,
+                model_worker_batch.bid,
+                sum(verify_output.accept_length_per_req_cpu),
            )
 
         else:
-
-
-
-            model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-            logits_output, next_token_ids = self.target_worker.forward_batch_generation(
-                model_worker_batch
+            logits_output, next_token_ids, bid = self.forward_target_extend(batch)
+            self.forward_draft_extend(
+                batch, logits_output.hidden_states, next_token_ids
             )
-
-
-
-
-
-
-
-
+            return logits_output, next_token_ids, bid, 0
+
+    def forward_target_extend(
+        self, batch: ScheduleBatch
+    ) -> Tuple[LogitsProcessorOutput, List[int], int]:
+        """Run the target extend.
+
+        Args:
+            batch: The batch to run. States could be modified.
+
+        Returns:
+            logits_output: The output of logits. It will contain the full hidden states.
+            next_token_ids: Next token ids generated.
+            bid: The model batch ID. Used for overlap schedule.
+        """
+        # Forward with the target model and get hidden states.
+        # We need the full hidden states to prefill the KV cache of the draft model.
+        model_worker_batch = batch.get_model_worker_batch()
+        model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
+        logits_output, next_token_ids = self.target_worker.forward_batch_generation(
+            model_worker_batch
+        )
+        return logits_output, next_token_ids, model_worker_batch.bid
 
     def draft(self, batch: ScheduleBatch):
-        self._set_mem_pool(batch, self.model_runner)
-
         # Parse args
         num_seqs = batch.batch_size()
         spec_info = batch.spec_info
@@ -172,7 +228,6 @@ class EAGLEWorker(TpModelWorker):
             self.topk,
             self.speculative_num_steps,
         )
-
        batch.out_cache_loc = out_cache_loc
        batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
        spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
@@ -180,11 +235,12 @@ class EAGLEWorker(TpModelWorker):
         # Get forward batch
         spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
-        forward_batch = ForwardBatch.init_new(
+        forward_batch = ForwardBatch.init_new(
+            model_worker_batch, self.draft_model_runner
+        )
         can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run(
             forward_batch
         )
-
         if can_cuda_graph:
             score_list, token_list, parents_list = self.cuda_graph_runner.replay(
                 forward_batch
@@ -192,7 +248,9 @@ class EAGLEWorker(TpModelWorker):
         else:
             # Initialize attention backend
             self.draft_attn_backend.init_forward_metadata(forward_batch)
-
+            forward_batch = ForwardBatch.init_new(
+                model_worker_batch, self.draft_model_runner
+            )
             # Run forward steps
             score_list, token_list, parents_list = self.draft_forward(forward_batch)
 
@@ -209,10 +267,7 @@ class EAGLEWorker(TpModelWorker):
             batch.sampling_info.is_all_greedy,
         )
 
-
-        batch.token_to_kv_pool.free(out_cache_loc)
-        self._set_mem_pool(batch, self.target_worker.model_runner)
-        return ret
+        return ret, out_cache_loc
 
     def draft_forward(self, forward_batch: ForwardBatch):
         # Parse args
@@ -223,6 +278,8 @@ class EAGLEWorker(TpModelWorker):
             spec_info.topk_index,
             spec_info.hidden_states,
         )
+        if self.hot_token_id is not None:
+            topk_index = self.hot_token_id[topk_index]
 
         # Return values
         score_list: List[torch.Tensor] = []
@@ -260,8 +317,11 @@ class EAGLEWorker(TpModelWorker):
             logits_output = self.model_runner.model.forward(
                 forward_batch.input_ids, forward_batch.positions, forward_batch
             )
+            self._detect_nan_if_needed(logits_output)
             probs = torch.softmax(logits_output.next_token_logits, dim=-1)
             topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
+            if self.hot_token_id is not None:
+                topk_index = self.hot_token_id[topk_index]
             hidden_states = logits_output.hidden_states
 
         return score_list, token_list, parents_list
@@ -274,68 +334,135 @@ class EAGLEWorker(TpModelWorker):
         logits_output, _ = self.target_worker.forward_batch_generation(
             model_worker_batch, skip_sample=True
         )
+        self._detect_nan_if_needed(logits_output)
         spec_info.hidden_states = logits_output.hidden_states
-        res = spec_info.verify(
+        res: EagleVerifyOutput = spec_info.verify(
+            batch, logits_output, self.token_to_kv_pool_allocator
+        )
+
+        # Post process based on verified outputs.
+        # Pick indices that we care (accepeted)
+        logits_output.next_token_logits = logits_output.next_token_logits[
+            res.accepeted_indices_cpu
+        ]
+        logits_output.hidden_states = logits_output.hidden_states[
+            res.accepeted_indices_cpu
+        ]
+        # Prepare the batch for the next draft forwards.
         batch.forward_mode = ForwardMode.DECODE
-
+        batch.spec_info = res.draft_input
+
+        if batch.return_logprob:
+            # Compute output logprobs using the sampler.
+            num_tokens_per_req = [
+                accept + 1 for accept in res.accept_length_per_req_cpu
+            ]
+            self.target_worker.model_runner.update_output_logprobs(
+                logits_output,
+                batch.sampling_info,
+                batch.top_logprobs_nums,
+                batch.token_ids_logprobs,
+                res.verified_id,
+                # +1 for bonus token.
+                num_tokens_per_req=num_tokens_per_req,
+            )
 
-
-
+            # Add output logprobs to the request.
+            pt = 0
+            # NOTE: tolist() of these values are skipped when output is processed
+            next_token_logprobs = res.logits_output.next_token_logprobs.tolist()
+            verified_ids = res.verified_id.tolist()
+            for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
+                for _ in range(num_tokens):
+                    if req.return_logprob:
+                        token_id = verified_ids[pt]
+                        req.output_token_logprobs_val.append(next_token_logprobs[pt])
+                        req.output_token_logprobs_idx.append(token_id)
+                        if req.top_logprobs_num > 0:
+                            req.output_top_logprobs_val.append(
+                                res.logits_output.next_token_top_logprobs_val[pt]
+                            )
+                            req.output_top_logprobs_idx.append(
+                                res.logits_output.next_token_top_logprobs_idx[pt]
+                            )
+                    pt += 1
+
+        return logits_output, res, model_worker_batch
+
+    def forward_draft_extend(
+        self,
+        batch: ScheduleBatch,
+        hidden_states: torch.Tensor,
+        next_token_ids: List[int],
+    ):
+        """Run draft model extend. This API modifies the states of the batch.
+
+        Args:
+            batch: The batch to run.
+            hidden_states: Hidden states from the target model forward
+            next_token_ids: Next token ids generated from the target forward.
+        """
+        batch.spec_info = EagleDraftInput(
+            hidden_states=hidden_states,
+            verified_id=next_token_ids,
+        )
         batch.spec_info.prepare_for_extend(batch)
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
-        forward_batch = ForwardBatch.init_new(
-
-
-
-
-
-
-
+        forward_batch = ForwardBatch.init_new(
+            model_worker_batch, self.draft_model_runner
+        )
+        forward_batch.return_logprob = False
+        logits_output = self.draft_model_runner.forward(forward_batch)
+        self._detect_nan_if_needed(logits_output)
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+        assert forward_batch.spec_info is batch.spec_info
+        self.capture_for_decode(logits_output, forward_batch.spec_info)
 
     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
         seq_lens_backup = batch.seq_lens
-        req_pool_indices_backup = batch.req_pool_indices
-
-        self._set_mem_pool(batch, self.model_runner)
         batch.forward_mode = ForwardMode.DRAFT_EXTEND
         batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        # We don't need logprob for this extend.
+        original_return_logprob = batch.return_logprob
+        batch.return_logprob = False
         model_worker_batch = batch.get_model_worker_batch()
-        forward_batch = ForwardBatch.init_new(
-
-
-
+        forward_batch = ForwardBatch.init_new(
+            model_worker_batch, self.draft_model_runner
+        )
+        logits_output = self.draft_model_runner.forward(forward_batch)
+        self._detect_nan_if_needed(logits_output)
+        assert forward_batch.spec_info is batch.spec_info
+        self.capture_for_decode(logits_output, forward_batch.spec_info)
 
         # Restore backup.
         # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
+        batch.return_logprob = original_return_logprob
         batch.forward_mode = ForwardMode.DECODE
         batch.seq_lens = seq_lens_backup
-        batch.req_pool_indices = req_pool_indices_backup
 
     def capture_for_decode(
-        self, logits_output: LogitsProcessorOutput,
+        self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
     ):
         probs = torch.softmax(logits_output.next_token_logits, dim=-1)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        self.model_runner.token_to_kv_pool.free(kv_indices)
-        self.model_runner.req_to_token_pool.free(req.req_pool_idx)
+        draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
+        draft_input.hidden_states = logits_output.hidden_states
+
+    def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
+        if self.enable_nan_detection:
+            logits = logits_output.next_token_logits
+            if torch.any(torch.isnan(logits)):
+                logger.warning("Detected errors during sampling! NaN in the logits.")
+                raise ValueError("Detected errors during sampling! NaN in the logits.")
+
+
+def load_token_map(token_map_path: str) -> List[int]:
+    if not os.path.exists(token_map_path):
+        cache_dir = snapshot_download(
+            os.path.dirname(token_map_path),
+            ignore_patterns=["*.bin", "*.safetensors"],
+        )
+        token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
    hot_token_id = torch.load(token_map_path)
    return torch.tensor(hot_token_id, dtype=torch.int32)
```
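The `hot_token_id` path added above lets the draft model keep only a "hot" slice of the lm_head (`head.data = head.data[self.hot_token_id]`), so the draft's top-k picks index the reduced vocabulary and must be mapped back to full-vocabulary token ids, as `topk_index = self.hot_token_id[topk_index]` does in `draft_forward`. A toy sketch of that index arithmetic follows; all sizes and tensors here are invented for illustration:

```python
# Toy illustration of the hot-vocab remapping used in eagle_worker.py above.
import torch

full_vocab_size, hot_vocab_size, topk, hidden_dim = 1000, 16, 4, 8
hot_token_id = torch.randperm(full_vocab_size)[:hot_vocab_size]  # stand-in token map

# The draft head is sliced to hot tokens, mirroring `head.data = head.data[self.hot_token_id]`.
full_head = torch.randn(full_vocab_size, hidden_dim)
draft_head = full_head[hot_token_id]              # (hot_vocab_size, hidden_dim)

# Draft logits therefore index the *hot* vocab, not the full one.
hidden = torch.randn(1, hidden_dim)
logits = hidden @ draft_head.T                    # (1, hot_vocab_size)
probs = torch.softmax(logits, dim=-1)
topk_p, topk_index = probs.topk(topk, dim=-1)     # indices in [0, hot_vocab_size)

# Map back to full-vocabulary token ids before they leave the draft loop.
full_token_ids = hot_token_id[topk_index]         # indices in [0, full_vocab_size)
print(full_token_ids)
```

This is why the remap appears twice in `draft_forward`: both the cached top-k indices carried in `spec_info` and the per-step top-k indices live in hot-vocab space until remapped.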
sglang/srt/speculative/spec_info.py

```diff
@@ -5,30 +5,18 @@ class SpeculativeAlgorithm(IntEnum):
     NONE = auto()
     EAGLE = auto()
 
-    # NEXTN spec decoding is for DeepSeek V3/R1
-    # currently it's implemented based on EAGLE
-    NEXTN = auto()
-
     def is_none(self):
         return self == SpeculativeAlgorithm.NONE
 
     def is_eagle(self):
-        return self == SpeculativeAlgorithm.EAGLE
-
-    def is_nextn(self):
-        return self == SpeculativeAlgorithm.NEXTN
+        return self == SpeculativeAlgorithm.EAGLE
 
     @staticmethod
     def from_string(name: str):
         name_map = {
             "EAGLE": SpeculativeAlgorithm.EAGLE,
-            "NEXTN": SpeculativeAlgorithm.NEXTN,
             None: SpeculativeAlgorithm.NONE,
         }
         if name is not None:
             name = name.upper()
         return name_map[name]
-
-
-class SpecInfo:
-    pass
```