sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +302 -414
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +13 -8
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +144 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +773 -334
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +225 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +68 -37
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +102 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +56 -31
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +280 -81
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +135 -60
  181. sglang/srt/speculative/build_eagle_tree.py +8 -9
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
  183. sglang/srt/speculative/eagle_utils.py +92 -57
  184. sglang/srt/speculative/eagle_worker.py +238 -111
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
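
Of the 205 files, the inline diffs below reproduce only the speculative-decoding changes (sglang/srt/speculative/eagle_worker.py and sglang/srt/speculative/spec_info.py). After upgrading, the version bump recorded in sglang/version.py (+1 -1) can be sanity-checked from Python; this is a generic check under the assumption of a standard pip environment, not part of the diff itself:

# Assumes a standard pip install; sglang re-exports __version__ from sglang/version.py.
#   pip install "sglang==0.4.3.post4"
import sglang

print(sglang.__version__)  # expect "0.4.3.post4" after the upgrade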
sglang/srt/speculative/eagle_worker.py

@@ -1,18 +1,19 @@
 import logging
+import os
 import time
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple

 import torch
+from huggingface_hub import snapshot_download

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
     EAGLEDraftCudaGraphRunner,
@@ -20,11 +21,12 @@ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
 from sglang.srt.speculative.eagle_utils import (
     EagleDraftInput,
     EagleVerifyInput,
+    EagleVerifyOutput,
     assign_draft_cache_locs,
     fast_topk,
     select_top_k_tokens,
 )
-from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.utils import get_available_gpu_memory

 logger = logging.getLogger(__name__)

@@ -40,10 +42,39 @@ class EAGLEWorker(TpModelWorker):
         nccl_port: int,
         target_worker: TpModelWorker,
     ):
+        # Parse arguments
+        self.server_args = server_args
+        self.topk = server_args.speculative_eagle_topk
+        self.speculative_num_steps = server_args.speculative_num_steps
+        self.padded_static_len = self.speculative_num_steps + 1
+        self.enable_nan_detection = server_args.enable_nan_detection
+        self.gpu_id = gpu_id
+        self.device = server_args.device
+        self.target_worker = target_worker
+
+        # Override context length with target model's context length
+        server_args.context_length = target_worker.model_runner.model_config.context_len
+
         # Do not capture cuda graph in `super().__init__()`
-        # We will capture it later
+        # It will be captured later.
         backup_disable_cuda_graph = server_args.disable_cuda_graph
         server_args.disable_cuda_graph = True
+        # Share the allocator with a target worker.
+        # Draft and target worker own their own KV cache pools.
+        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+            target_worker.get_memory_pool()
+        )
+
+        # Load hot token ids
+        if server_args.speculative_token_map is not None:
+            self.hot_token_id = load_token_map(server_args.speculative_token_map)
+            server_args.json_model_override_args = (
+                f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
+            )
+        else:
+            self.hot_token_id = None
+
+        # Init draft worker
         super().__init__(
             gpu_id=gpu_id,
             tp_rank=tp_rank,
@@ -51,26 +82,27 @@ class EAGLEWorker(TpModelWorker):
             nccl_port=nccl_port,
             dp_rank=dp_rank,
             is_draft_worker=True,
+            req_to_token_pool=self.req_to_token_pool,
+            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
         )
-        self.target_worker = target_worker
-        self.finish_extend_len = []

-        # Parse arguments
-        self.topk = server_args.speculative_eagle_topk
-        self.speculative_num_steps = server_args.speculative_num_steps
-        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
-            server_args.speculative_algorithm
+        # Share the embedding and lm_head
+        embed, head = self.target_worker.model_runner.model.get_embed_and_head()
+        if self.hot_token_id is not None:
+            head = head.clone()
+            self.hot_token_id = self.hot_token_id.to(head.device)
+            head.data = head.data[self.hot_token_id]
+        self.draft_model_runner.model.set_embed_and_head(embed, head)
+        self.draft_model_runner.server_args.disable_cuda_graph = (
+            backup_disable_cuda_graph
         )
-        self.server_args = server_args

-        # Share the embedding and lm_head
-        if not self.speculative_algorithm.is_nextn():
-            embed, head = self.target_worker.model_runner.model.get_embed_and_head()
-            self.model_runner.model.set_embed_and_head(embed, head)
-        self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
+        self.init_attention_backend()
+        self.init_cuda_graphs()

+    def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
-        if server_args.attention_backend == "flashinfer":
+        if self.server_args.attention_backend == "flashinfer":
             from sglang.srt.layers.attention.flashinfer_backend import (
                 FlashInferMultiStepDraftBackend,
             )
@@ -80,7 +112,7 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
-        elif server_args.attention_backend == "triton":
+        elif self.server_args.attention_backend == "triton":
             from sglang.srt.layers.attention.triton_backend import (
                 TritonMultiStepDraftBackend,
             )
@@ -92,11 +124,9 @@
             )
         else:
             raise ValueError(
-                f"EAGLE is not supportted in attention backend {server_args.attention_backend}"
+                f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
             )
-
-        self.model_runner.draft_attn_backend = self.draft_attn_backend
-        self.init_cuda_graphs()
+        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend

     def init_cuda_graphs(self):
         """Capture cuda graphs."""
@@ -106,55 +136,81 @@
             return

         tic = time.time()
-        logger.info("Capture cuda graph begin. This can take up to several minutes.")
+        logger.info(
+            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
         self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
-        logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
+        logger.info(
+            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )

-    def forward_batch_speculative_generation(self, batch: ScheduleBatch):
+    @property
+    def draft_model_runner(self):
+        return self.model_runner
+
+    def forward_batch_speculative_generation(
+        self, batch: ScheduleBatch
+    ) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
+        """Run speculative decoding forward.
+
+        NOTE: Many states of batch is modified as you go through. It is not guaranteed
+        the final output batch doesn't have the same state as the input.
+
+        Args:
+            batch: The batch to run forward. The state of the batch is modified as it runs.
+        Returns:
+            A tuple of the final logit output of the target model, next tokens accepeted,
+            the batch id (used for overlap schedule), and number of accepeted tokens.
+        """
+        assert not batch.spec_algorithm.is_none()
         if batch.forward_mode.is_decode():
-            # Draft
-            spec_info: EagleVerifyInput = self.draft(batch)
-
-            # Verify
-            (
-                next_draft_input,
-                logits_output,
-                verified_id,
-                self.finish_extend_len,
-                accept_length_cpu,
-                model_worker_batch,
-            ) = self.verify(batch, spec_info)
-            batch.spec_info = next_draft_input
-            # if it is None, means all requsets are finished
+            spec_info, to_free_cache_loc = self.draft(batch)
+            logits_output, verify_output, model_worker_batch = self.verify(
+                batch, spec_info
+            )
+            # Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
+            self.token_to_kv_pool_allocator.free(to_free_cache_loc)
+            # if it is None, means all requests are finished
             if batch.spec_info.verified_id is not None:
                 self.forward_draft_extend_after_decode(batch)
+
             return (
                 logits_output,
-                verified_id,
-                model_worker_batch,
-                sum(accept_length_cpu),
+                verify_output.verified_id,
+                model_worker_batch.bid,
+                sum(verify_output.accept_length_per_req_cpu),
             )

         else:
-            # Forward with the target model and get hidden states.
-            # We need the full hidden states to prefill the KV cache of the draft model.
-            model_worker_batch = batch.get_model_worker_batch()
-            model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-            logits_output, next_token_ids = self.target_worker.forward_batch_generation(
-                model_worker_batch
+            logits_output, next_token_ids, bid = self.forward_target_extend(batch)
+            self.forward_draft_extend(
+                batch, logits_output.hidden_states, next_token_ids
             )
-
-            # Forward with the draft model.
-            batch.spec_info = EagleDraftInput(
-                hidden_states=logits_output.hidden_states,
-                verified_id=next_token_ids,
-            )
-            self.forward_draft_extend(batch)
-            return logits_output, next_token_ids, model_worker_batch, 0
+            return logits_output, next_token_ids, bid, 0
+
+    def forward_target_extend(
+        self, batch: ScheduleBatch
+    ) -> Tuple[LogitsProcessorOutput, List[int], int]:
+        """Run the target extend.
+
+        Args:
+            batch: The batch to run. States could be modified.
+
+        Returns:
+            logits_output: The output of logits. It will contain the full hidden states.
+            next_token_ids: Next token ids generated.
+            bid: The model batch ID. Used for overlap schedule.
+        """
+        # Forward with the target model and get hidden states.
+        # We need the full hidden states to prefill the KV cache of the draft model.
+        model_worker_batch = batch.get_model_worker_batch()
+        model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
+        logits_output, next_token_ids = self.target_worker.forward_batch_generation(
+            model_worker_batch
+        )
+        return logits_output, next_token_ids, model_worker_batch.bid

     def draft(self, batch: ScheduleBatch):
-        self._set_mem_pool(batch, self.model_runner)
-
         # Parse args
         num_seqs = batch.batch_size()
         spec_info = batch.spec_info
@@ -172,7 +228,6 @@ class EAGLEWorker(TpModelWorker):
             self.topk,
             self.speculative_num_steps,
         )
-
         batch.out_cache_loc = out_cache_loc
         batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
         spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
@@ -180,11 +235,12 @@ class EAGLEWorker(TpModelWorker):
         # Get forward batch
         spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+        forward_batch = ForwardBatch.init_new(
+            model_worker_batch, self.draft_model_runner
+        )
         can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run(
             forward_batch
         )
-
         if can_cuda_graph:
             score_list, token_list, parents_list = self.cuda_graph_runner.replay(
                 forward_batch
@@ -192,7 +248,9 @@ class EAGLEWorker(TpModelWorker):
         else:
             # Initialize attention backend
             self.draft_attn_backend.init_forward_metadata(forward_batch)
-
+            forward_batch = ForwardBatch.init_new(
+                model_worker_batch, self.draft_model_runner
+            )
             # Run forward steps
             score_list, token_list, parents_list = self.draft_forward(forward_batch)

@@ -209,10 +267,7 @@ class EAGLEWorker(TpModelWorker):
             batch.sampling_info.is_all_greedy,
         )

-        # Free cache locations
-        batch.token_to_kv_pool.free(out_cache_loc)
-        self._set_mem_pool(batch, self.target_worker.model_runner)
-        return ret
+        return ret, out_cache_loc

     def draft_forward(self, forward_batch: ForwardBatch):
         # Parse args
@@ -223,6 +278,8 @@ class EAGLEWorker(TpModelWorker):
             spec_info.topk_index,
             spec_info.hidden_states,
         )
+        if self.hot_token_id is not None:
+            topk_index = self.hot_token_id[topk_index]

         # Return values
         score_list: List[torch.Tensor] = []
@@ -260,8 +317,11 @@ class EAGLEWorker(TpModelWorker):
             logits_output = self.model_runner.model.forward(
                 forward_batch.input_ids, forward_batch.positions, forward_batch
             )
+            self._detect_nan_if_needed(logits_output)
             probs = torch.softmax(logits_output.next_token_logits, dim=-1)
             topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
+            if self.hot_token_id is not None:
+                topk_index = self.hot_token_id[topk_index]
             hidden_states = logits_output.hidden_states

         return score_list, token_list, parents_list
@@ -274,68 +334,135 @@
         logits_output, _ = self.target_worker.forward_batch_generation(
             model_worker_batch, skip_sample=True
         )
+        self._detect_nan_if_needed(logits_output)
         spec_info.hidden_states = logits_output.hidden_states
-        res = spec_info.verify(batch, logits_output)
+        res: EagleVerifyOutput = spec_info.verify(
+            batch, logits_output, self.token_to_kv_pool_allocator
+        )
+
+        # Post process based on verified outputs.
+        # Pick indices that we care (accepeted)
+        logits_output.next_token_logits = logits_output.next_token_logits[
+            res.accepeted_indices_cpu
+        ]
+        logits_output.hidden_states = logits_output.hidden_states[
+            res.accepeted_indices_cpu
+        ]
+        # Prepare the batch for the next draft forwards.
         batch.forward_mode = ForwardMode.DECODE
-        return res + (model_worker_batch,)
+        batch.spec_info = res.draft_input
+
+        if batch.return_logprob:
+            # Compute output logprobs using the sampler.
+            num_tokens_per_req = [
+                accept + 1 for accept in res.accept_length_per_req_cpu
+            ]
+            self.target_worker.model_runner.update_output_logprobs(
+                logits_output,
+                batch.sampling_info,
+                batch.top_logprobs_nums,
+                batch.token_ids_logprobs,
+                res.verified_id,
+                # +1 for bonus token.
+                num_tokens_per_req=num_tokens_per_req,
+            )

-    def forward_draft_extend(self, batch: ScheduleBatch):
-        self._set_mem_pool(batch, self.model_runner)
+            # Add output logprobs to the request.
+            pt = 0
+            # NOTE: tolist() of these values are skipped when output is processed
+            next_token_logprobs = res.logits_output.next_token_logprobs.tolist()
+            verified_ids = res.verified_id.tolist()
+            for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
+                for _ in range(num_tokens):
+                    if req.return_logprob:
+                        token_id = verified_ids[pt]
+                        req.output_token_logprobs_val.append(next_token_logprobs[pt])
+                        req.output_token_logprobs_idx.append(token_id)
+                        if req.top_logprobs_num > 0:
+                            req.output_top_logprobs_val.append(
+                                res.logits_output.next_token_top_logprobs_val[pt]
+                            )
+                            req.output_top_logprobs_idx.append(
+                                res.logits_output.next_token_top_logprobs_idx[pt]
+                            )
+                    pt += 1
+
+        return logits_output, res, model_worker_batch
+
+    def forward_draft_extend(
+        self,
+        batch: ScheduleBatch,
+        hidden_states: torch.Tensor,
+        next_token_ids: List[int],
+    ):
+        """Run draft model extend. This API modifies the states of the batch.
+
+        Args:
+            batch: The batch to run.
+            hidden_states: Hidden states from the target model forward
+            next_token_ids: Next token ids generated from the target forward.
+        """
+        batch.spec_info = EagleDraftInput(
+            hidden_states=hidden_states,
+            verified_id=next_token_ids,
+        )
         batch.spec_info.prepare_for_extend(batch)
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output = self.model_runner.forward(forward_batch)
-        self.capture_for_decode(logits_output, forward_batch)
-        self._set_mem_pool(batch, self.target_worker.model_runner)
-
-    def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
-        batch.token_to_kv_pool = runner.token_to_kv_pool
-        batch.req_to_token_pool = runner.req_to_token_pool
+        forward_batch = ForwardBatch.init_new(
+            model_worker_batch, self.draft_model_runner
+        )
+        forward_batch.return_logprob = False
+        logits_output = self.draft_model_runner.forward(forward_batch)
+        self._detect_nan_if_needed(logits_output)
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+        assert forward_batch.spec_info is batch.spec_info
+        self.capture_for_decode(logits_output, forward_batch.spec_info)

     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
         seq_lens_backup = batch.seq_lens
-        req_pool_indices_backup = batch.req_pool_indices
-
-        self._set_mem_pool(batch, self.model_runner)
         batch.forward_mode = ForwardMode.DRAFT_EXTEND
         batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        # We don't need logprob for this extend.
+        original_return_logprob = batch.return_logprob
+        batch.return_logprob = False
         model_worker_batch = batch.get_model_worker_batch()
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output = self.model_runner.forward(forward_batch)
-        self.capture_for_decode(logits_output, forward_batch)
-        self._set_mem_pool(batch, self.target_worker.model_runner)
+        forward_batch = ForwardBatch.init_new(
+            model_worker_batch, self.draft_model_runner
+        )
+        logits_output = self.draft_model_runner.forward(forward_batch)
+        self._detect_nan_if_needed(logits_output)
+        assert forward_batch.spec_info is batch.spec_info
+        self.capture_for_decode(logits_output, forward_batch.spec_info)

         # Restore backup.
         # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
+        batch.return_logprob = original_return_logprob
         batch.forward_mode = ForwardMode.DECODE
         batch.seq_lens = seq_lens_backup
-        batch.req_pool_indices = req_pool_indices_backup

     def capture_for_decode(
-        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
+        self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
     ):
         probs = torch.softmax(logits_output.next_token_logits, dim=-1)
-        spec_info = forward_batch.spec_info
-        spec_info.topk_p, spec_info.topk_index = fast_topk(probs, self.topk, dim=-1)
-        spec_info.hidden_states = logits_output.hidden_states
-
-    # Don't support prefix share now.
-    def finish_request(self, reqs: Union[Req, List[Req]]):
-        if not isinstance(reqs, List):
-            reqs = [reqs]
-        for req in reqs:
-            if req.rid not in self.finish_extend_len:
-                continue
-            req_len = (
-                len(req.origin_input_ids)
-                + len(req.output_ids)
-                - self.finish_extend_len[req.rid]
-                - 1
-            )
-            kv_indices = self.model_runner.req_to_token_pool.req_to_token[
-                req.req_pool_idx
-            ][:req_len]
-            self.model_runner.token_to_kv_pool.free(kv_indices)
-            self.model_runner.req_to_token_pool.free(req.req_pool_idx)
+        draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
+        draft_input.hidden_states = logits_output.hidden_states
+
+    def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
+        if self.enable_nan_detection:
+            logits = logits_output.next_token_logits
+            if torch.any(torch.isnan(logits)):
+                logger.warning("Detected errors during sampling! NaN in the logits.")
+                raise ValueError("Detected errors during sampling! NaN in the logits.")
+
+
+def load_token_map(token_map_path: str) -> List[int]:
+    if not os.path.exists(token_map_path):
+        cache_dir = snapshot_download(
+            os.path.dirname(token_map_path),
+            ignore_patterns=["*.bin", "*.safetensors"],
+        )
+        token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
+    hot_token_id = torch.load(token_map_path)
+    return torch.tensor(hot_token_id, dtype=torch.int32)
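
The net effect of the eagle_worker.py changes: the draft worker now shares the target worker's req_to_token_pool and KV-cache allocator instead of swapping memory pools around every call (the removed _set_mem_pool), and draft() now returns the cache locations it wrote so they can be freed after verify(), per the diff's comment about avoiding synchronization and hiding kernel-launch overhead. Below is a minimal sketch of that decode-step ordering; FakeAllocator, SpeculativeDecodeFlow, and all values are invented stand-ins, and only the draft/verify/free ordering mirrors the diff:

from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class FakeAllocator:
    # Stand-in for the shared token_to_kv_pool_allocator; just records frees.
    freed: List[int] = field(default_factory=list)

    def free(self, locs: List[int]) -> None:
        self.freed.extend(locs)


class SpeculativeDecodeFlow:
    # Mirrors the decode branch of forward_batch_speculative_generation:
    # draft proposals, target verification, then a deferred cache free.

    def __init__(self, allocator: FakeAllocator) -> None:
        self.allocator = allocator

    def draft(self) -> Tuple[List[str], List[int]]:
        # The new draft() returns the cache locations it wrote instead of
        # freeing them itself, as the old version did.
        proposals = ["tok_a", "tok_b", "tok_c"]
        to_free_cache_loc = [101, 102, 103]
        return proposals, to_free_cache_loc

    def verify(self, proposals: List[str]) -> List[str]:
        # The target model accepts some prefix of the proposed tokens.
        return proposals[:2]

    def step(self) -> List[str]:
        proposals, to_free_cache_loc = self.draft()
        accepted = self.verify(proposals)
        # Free only after verify, so the free() launch is not on the
        # critical path between draft and verification.
        self.allocator.free(to_free_cache_loc)
        return accepted


allocator = FakeAllocator()
flow = SpeculativeDecodeFlow(allocator)
print(flow.step(), allocator.freed)  # ['tok_a', 'tok_b'] [101, 102, 103]
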
sglang/srt/speculative/spec_info.py

@@ -5,30 +5,18 @@ class SpeculativeAlgorithm(IntEnum):
     NONE = auto()
     EAGLE = auto()

-    # NEXTN spec decoding is for DeepSeek V3/R1
-    # currently it's implemented based on EAGLE
-    NEXTN = auto()
-
     def is_none(self):
         return self == SpeculativeAlgorithm.NONE

     def is_eagle(self):
-        return self == SpeculativeAlgorithm.EAGLE or self == SpeculativeAlgorithm.NEXTN
-
-    def is_nextn(self):
-        return self == SpeculativeAlgorithm.NEXTN
+        return self == SpeculativeAlgorithm.EAGLE

     @staticmethod
     def from_string(name: str):
         name_map = {
             "EAGLE": SpeculativeAlgorithm.EAGLE,
-            "NEXTN": SpeculativeAlgorithm.NEXTN,
             None: SpeculativeAlgorithm.NONE,
         }
         if name is not None:
             name = name.upper()
         return name_map[name]
-
-
-class SpecInfo:
-    pass
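
With NEXTN folded into EAGLE here, "NEXTN" no longer round-trips through from_string. The following standalone reproduction of the post-change enum (copied from the + side of the hunk above, so it runs without sglang installed) demonstrates the behavior:

from enum import IntEnum, auto


class SpeculativeAlgorithm(IntEnum):
    NONE = auto()
    EAGLE = auto()

    def is_none(self):
        return self == SpeculativeAlgorithm.NONE

    def is_eagle(self):
        return self == SpeculativeAlgorithm.EAGLE

    @staticmethod
    def from_string(name):
        name_map = {
            "EAGLE": SpeculativeAlgorithm.EAGLE,
            None: SpeculativeAlgorithm.NONE,
        }
        if name is not None:
            name = name.upper()
        return name_map[name]


print(SpeculativeAlgorithm.from_string("eagle").is_eagle())  # True
print(SpeculativeAlgorithm.from_string(None).is_none())      # True
try:
    SpeculativeAlgorithm.from_string("nextn")
except KeyError:
    print("NEXTN was removed in this release")  # now raises KeyError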