sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff compares the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (172)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +133 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +32 -21
  49. sglang/srt/layers/layernorm.py +24 -2
  50. sglang/srt/layers/linear.py +17 -5
  51. sglang/srt/layers/logits_processor.py +25 -7
  52. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  53. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  54. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  55. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  62. sglang/srt/layers/moe/topk.py +31 -18
  63. sglang/srt/layers/parameter.py +1 -1
  64. sglang/srt/layers/quantization/__init__.py +184 -126
  65. sglang/srt/layers/quantization/base_config.py +5 -0
  66. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  67. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  69. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  70. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  71. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  73. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  74. sglang/srt/layers/quantization/fp8.py +76 -34
  75. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  76. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  77. sglang/srt/layers/quantization/gptq.py +36 -9
  78. sglang/srt/layers/quantization/kv_cache.py +98 -0
  79. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  80. sglang/srt/layers/quantization/utils.py +153 -0
  81. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  82. sglang/srt/layers/rotary_embedding.py +66 -87
  83. sglang/srt/layers/sampler.py +1 -1
  84. sglang/srt/lora/layers.py +68 -0
  85. sglang/srt/lora/lora.py +2 -22
  86. sglang/srt/lora/lora_manager.py +47 -23
  87. sglang/srt/lora/mem_pool.py +110 -51
  88. sglang/srt/lora/utils.py +12 -1
  89. sglang/srt/managers/cache_controller.py +2 -5
  90. sglang/srt/managers/data_parallel_controller.py +30 -8
  91. sglang/srt/managers/expert_distribution.py +81 -0
  92. sglang/srt/managers/io_struct.py +39 -3
  93. sglang/srt/managers/mm_utils.py +373 -0
  94. sglang/srt/managers/multimodal_processor.py +68 -0
  95. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  96. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  97. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  98. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  99. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  100. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  101. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  102. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  103. sglang/srt/managers/schedule_batch.py +133 -30
  104. sglang/srt/managers/scheduler.py +273 -20
  105. sglang/srt/managers/session_controller.py +1 -1
  106. sglang/srt/managers/tokenizer_manager.py +59 -23
  107. sglang/srt/managers/tp_worker.py +1 -1
  108. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  109. sglang/srt/managers/utils.py +6 -1
  110. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  111. sglang/srt/mem_cache/memory_pool.py +255 -98
  112. sglang/srt/mem_cache/paged_allocator.py +2 -2
  113. sglang/srt/mem_cache/radix_cache.py +4 -4
  114. sglang/srt/model_executor/cuda_graph_runner.py +27 -13
  115. sglang/srt/model_executor/forward_batch_info.py +68 -11
  116. sglang/srt/model_executor/model_runner.py +70 -6
  117. sglang/srt/model_loader/loader.py +160 -2
  118. sglang/srt/model_loader/weight_utils.py +45 -0
  119. sglang/srt/models/deepseek_janus_pro.py +29 -86
  120. sglang/srt/models/deepseek_nextn.py +22 -10
  121. sglang/srt/models/deepseek_v2.py +208 -77
  122. sglang/srt/models/deepseek_vl2.py +358 -0
  123. sglang/srt/models/gemma3_causal.py +684 -0
  124. sglang/srt/models/gemma3_mm.py +462 -0
  125. sglang/srt/models/llama.py +47 -7
  126. sglang/srt/models/llama_eagle.py +1 -0
  127. sglang/srt/models/llama_eagle3.py +196 -0
  128. sglang/srt/models/llava.py +3 -3
  129. sglang/srt/models/llavavid.py +3 -3
  130. sglang/srt/models/minicpmo.py +1995 -0
  131. sglang/srt/models/minicpmv.py +62 -137
  132. sglang/srt/models/mllama.py +4 -4
  133. sglang/srt/models/phi3_small.py +1 -1
  134. sglang/srt/models/qwen2.py +3 -0
  135. sglang/srt/models/qwen2_5_vl.py +68 -146
  136. sglang/srt/models/qwen2_classification.py +75 -0
  137. sglang/srt/models/qwen2_moe.py +9 -1
  138. sglang/srt/models/qwen2_vl.py +25 -63
  139. sglang/srt/openai_api/adapter.py +124 -28
  140. sglang/srt/openai_api/protocol.py +23 -2
  141. sglang/srt/sampling/sampling_batch_info.py +1 -1
  142. sglang/srt/sampling/sampling_params.py +6 -6
  143. sglang/srt/server_args.py +99 -9
  144. sglang/srt/speculative/build_eagle_tree.py +7 -347
  145. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  146. sglang/srt/speculative/eagle_utils.py +208 -252
  147. sglang/srt/speculative/eagle_worker.py +139 -53
  148. sglang/srt/speculative/spec_info.py +6 -1
  149. sglang/srt/torch_memory_saver_adapter.py +22 -0
  150. sglang/srt/utils.py +182 -21
  151. sglang/test/__init__.py +0 -0
  152. sglang/test/attention/__init__.py +0 -0
  153. sglang/test/attention/test_flashattn_backend.py +312 -0
  154. sglang/test/runners.py +2 -0
  155. sglang/test/test_activation.py +2 -1
  156. sglang/test/test_block_fp8.py +5 -4
  157. sglang/test/test_block_fp8_ep.py +2 -1
  158. sglang/test/test_dynamic_grad_mode.py +58 -0
  159. sglang/test/test_layernorm.py +3 -2
  160. sglang/test/test_utils.py +55 -4
  161. sglang/utils.py +31 -0
  162. sglang/version.py +1 -1
  163. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  164. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +167 -123
  165. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  166. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  167. sglang/srt/managers/image_processor.py +0 -55
  168. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  169. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  170. sglang/srt/managers/multi_modality_padding.py +0 -134
  171. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  172. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
@@ -22,6 +22,10 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput
 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_worker import EAGLEWorker
 
+import logging
+
+logger = logging.getLogger(__name__)
+
 
 class EAGLEDraftCudaGraphRunner:
     def __init__(self, eagle_worker: EAGLEWorker):
@@ -33,13 +37,10 @@ class EAGLEDraftCudaGraphRunner:
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.tp_size = self.model_runner.tp_size
-        self.dp_size = model_runner.server_args.dp_size
         self.topk = model_runner.server_args.speculative_eagle_topk
         self.speculative_num_steps = model_runner.server_args.speculative_num_steps
         server_args = model_runner.server_args
 
-        assert self.disable_padding
-
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         self.num_tokens_per_bs = server_args.speculative_eagle_topk
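
Dropping `assert self.disable_padding` means draft-graph replay no longer requires an exact batch-size match: a raw batch can be padded up to the nearest captured size. A minimal sketch of that selection step (the helper name and capture list are illustrative, not code from this diff):

    import bisect

    def nearest_capture_bs(capture_bs, raw_bs):
        # capture_bs is sorted ascending, e.g. [1, 2, 4, 8, 16]; replay pads
        # raw_bs up to the smallest captured batch size that can hold it.
        idx = bisect.bisect_left(capture_bs, raw_bs)
        if idx == len(capture_bs):
            raise ValueError("batch is larger than any captured graph")
        return capture_bs[idx]

    assert nearest_capture_bs([1, 2, 4, 8, 16], 3) == 4
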
@@ -51,6 +52,9 @@ class EAGLEDraftCudaGraphRunner:
         self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
             0
         ].get_cuda_graph_seq_len_fill_value()
+        self.seq_lens_cpu = torch.full(
+            (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
+        )
 
         if self.enable_torch_compile:
             set_torch_compile_config()
@@ -169,6 +173,13 @@
             set_global_graph_memory_pool(graph.pool())
         return graph, out
 
+    def _postprocess_output_to_raw_bs(self, out, raw_bs):
+        score_list, token_list, parents_list = out
+        score_list = [x[:raw_bs] for x in score_list]
+        token_list = [x[:raw_bs] for x in token_list]
+        parents_list = [x[:raw_bs] for x in parents_list]
+        return (score_list, token_list, parents_list)
+
     def replay(self, forward_batch: ForwardBatch):
         assert forward_batch.out_cache_loc is not None
         raw_bs = forward_batch.batch_size
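
The new `_postprocess_output_to_raw_bs` helper slices the padded tail off each per-step output list. A toy illustration of the same slicing (shapes are made up for the example):

    import torch

    raw_bs, padded_bs, topk = 2, 4, 8
    out = (
        [torch.zeros(padded_bs, topk)],                     # score_list
        [torch.zeros(padded_bs, topk, dtype=torch.int64)],  # token_list
        [torch.zeros(padded_bs, topk, dtype=torch.int64)],  # parents_list
    )
    score_list, token_list, parents_list = ([x[:raw_bs] for x in lst] for lst in out)
    assert score_list[0].shape[0] == raw_bs
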
@@ -180,6 +191,9 @@
         if bs != raw_bs:
             self.seq_lens.fill_(1)
             self.out_cache_loc.zero_()
+            self.positions.zero_()
+
+        num_tokens = bs * self.num_tokens_per_bs
 
         # Common inputs
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
@@ -193,11 +207,33 @@
         self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
 
         # Attention backend
+        if bs != raw_bs:
+            forward_batch.batch_size = bs
+            forward_batch.seq_lens = self.seq_lens[:bs]
+            forward_batch.req_pool_indices = self.req_pool_indices[:bs]
+            forward_batch.positions = self.positions[:num_tokens]
+
+        # Special handling for seq_lens_cpu, used when flashinfer MLA is enabled
+        if (forward_batch.decode_seq_lens_cpu is not None) and (bs != raw_bs):
+            self.seq_lens_cpu.fill_(1)
+            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
+            forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:bs]
+
         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
-            forward_batch, forward_batch.batch_size
+            forward_batch, bs
         )
 
         # Replay
         self.graphs[bs].replay()
+        out = self.output_buffers[bs]
 
-        return self.output_buffers[bs]
+        if bs != raw_bs:
+            out = self._postprocess_output_to_raw_bs(out, raw_bs)
+            forward_batch.batch_size = raw_bs
+            forward_batch.positions = self.positions[:raw_num_token]
+            forward_batch.seq_lens = self.seq_lens[:raw_bs]
+            forward_batch.req_pool_indices = self.req_pool_indices[:raw_bs]
+            if forward_batch.decode_seq_lens_cpu is not None:
+                forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:raw_bs]
+
+        return out
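
Taken together, `replay` now rebinds the `forward_batch` fields to padded buffer views before the graph replay and restores the raw-size views afterwards. The same save/restore discipline, sketched as a context manager (this helper is illustrative and does not exist in the diff; `runner` stands in for the graph runner):

    from contextlib import contextmanager

    @contextmanager
    def padded_views(forward_batch, runner, bs, raw_bs, num_tokens, raw_num_tokens):
        if bs != raw_bs:
            # Point forward_batch at the padded static buffers.
            forward_batch.batch_size = bs
            forward_batch.seq_lens = runner.seq_lens[:bs]
            forward_batch.req_pool_indices = runner.req_pool_indices[:bs]
            forward_batch.positions = runner.positions[:num_tokens]
        try:
            yield
        finally:
            if bs != raw_bs:
                # Always restore the raw-size views, even on error.
                forward_batch.batch_size = raw_bs
                forward_batch.seq_lens = runner.seq_lens[:raw_bs]
                forward_batch.req_pool_indices = runner.req_pool_indices[:raw_bs]
                forward_batch.positions = runner.positions[:raw_num_tokens]
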
sglang/srt/speculative/eagle_utils.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 import torch.nn.functional as F
@@ -13,18 +13,26 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
-from sglang.srt.speculative.build_eagle_tree import (
-    build_tree_kernel,
-    build_tree_kernel_efficient,
-)
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
+from sglang.srt.utils import is_cuda_available, is_hip
 
 if is_cuda_available():
-    from sgl_kernel import tree_speculative_sampling_target_only
+    from sgl_kernel import (
+        top_k_renorm_prob,
+        top_p_renorm_prob,
+        tree_speculative_sampling_target_only,
+        verify_tree_greedy,
+    )
+elif is_hip():
+    from sgl_kernel import verify_tree_greedy
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
 
+import logging
+
+logger = logging.getLogger(__name__)
+
 
 @dataclass
 class EagleDraftInput:
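
`top_k_renorm_prob` and `top_p_renorm_prob` are fused sgl_kernel CUDA ops. A rough, unfused PyTorch reference for what they compute (tie-breaking and edge cases of the real kernels may differ):

    import torch

    def top_k_renorm_prob_ref(probs, top_ks):
        # Zero out everything below each row's k-th largest probability
        # (ties may keep a few extra entries), then renormalize to sum to 1.
        sorted_probs, _ = probs.sort(dim=-1, descending=True)
        thresh = sorted_probs.gather(-1, (top_ks - 1).clamp(min=0).long().unsqueeze(-1))
        masked = torch.where(probs >= thresh, probs, torch.zeros_like(probs))
        return masked / masked.sum(dim=-1, keepdim=True)

    def top_p_renorm_prob_ref(probs, top_ps):
        # Keep the smallest set of highest-probability tokens whose mass
        # reaches top_p, then renormalize.
        sorted_probs, idx = probs.sort(dim=-1, descending=True)
        keep_sorted = (sorted_probs.cumsum(-1) - sorted_probs) < top_ps.unsqueeze(-1)
        keep = torch.zeros_like(keep_sorted).scatter(-1, idx, keep_sorted)
        masked = probs * keep
        return masked / masked.sum(dim=-1, keepdim=True)
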
@@ -47,44 +55,32 @@ class EagleDraftInput:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None
 
-    # indices of unfinished requests during extend-after-decode
-    # e.g. [0, 2, 3, 4] if only the 1st request is finished
-    keep_indices: List[int] = None
+    all_padding_lens: Optional[torch.Tensor] = None
 
     def prepare_for_extend(self, batch: ScheduleBatch):
-        assert batch.input_ids.numel() == batch.out_cache_loc.shape[0]
         # Prefill only generates 1 token.
         assert len(self.verified_id) == len(batch.seq_lens)
 
         pt = 0
         for i, extend_len in enumerate(batch.extend_lens):
             input_ids = batch.input_ids[pt : pt + extend_len]
-            batch.input_ids[pt : pt + extend_len] = torch.concat(
+            batch.input_ids[pt : pt + extend_len] = torch.cat(
                 (input_ids[1:], self.verified_id[i].reshape(1))
             )
             pt += extend_len
 
-    def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps):
-        assert self.verified_id.numel() == batch.out_cache_loc.shape[0]
+    def prepare_extend_after_decode(
+        self,
+        batch: ScheduleBatch,
+        speculative_num_steps: int,
+    ):
+        assert len(self.verified_id) == len(batch.out_cache_loc)
         accept_length_cpu = batch.spec_info.accept_length_cpu
         batch.extend_lens = [x + 1 for x in accept_length_cpu]
         batch.extend_num_tokens = sum(batch.extend_lens)
         batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
+        batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
         seq_lens_cpu = batch.seq_lens.tolist()
-        assert len(batch.req_pool_indices) == len(batch.reqs)
-
-        pt = 0
-        i = 0
-        self.keep_indices = []
-        for idx, req in enumerate(batch.reqs):
-            if req.finished():
-                continue
-            self.keep_indices.append(idx)
-            # assert seq_len - pre_len == req.extend_input_len
-            input_len = batch.extend_lens[i]
-            seq_len = seq_lens_cpu[i]
-            pt += input_len
-            i += 1
 
         self.positions = torch.empty_like(self.verified_id, dtype=torch.long)
         new_verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
@@ -112,10 +108,6 @@
         req_to_token: torch.Tensor,
     ):
         bs = self.accept_length.numel()
-        keep_indices = torch.tensor(self.keep_indices, device=req_pool_indices.device)
-        req_pool_indices = req_pool_indices[keep_indices]
-        assert req_pool_indices.shape[0] == bs
-        assert req_pool_indices.shape[0] == paged_kernel_lens.shape[0]
 
         qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
@@ -172,7 +164,7 @@ class EagleVerifyOutput:
     # Accepted token length per sequence in a batch, on CPU.
     accept_length_per_req_cpu: List[int]
     # Accepted indices from logits_output.next_token_logits
-    accepeted_indices_cpu: List[int]
+    accepeted_indices: torch.Tensor
 
 
 @dataclass
@@ -200,67 +192,38 @@
         topk: int,
         spec_steps: int,
         num_verify_tokens: int,
-        is_all_greedy: bool,
     ):
-        if is_all_greedy:
-            tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
-                build_tree_kernel(
-                    verified_id,
-                    score_list,  # b, n, topk; n = 1 + (num_steps - 1) * self.topk
-                    token_list,
-                    parents_list,
-                    seq_lens,
-                    seq_lens_sum,
-                    topk,
-                    spec_steps,
-                    num_verify_tokens,
-                )
-            )
-
-            return cls(
-                draft_tokens,
-                tree_mask,
-                position,
-                retrive_index,
-                None,
-                None,
-                retrive_cum_len,
-                num_verify_tokens,
-                spec_steps,
-                CaptureHiddenMode.FULL,
-            )
-        else:
-            (
-                tree_mask,
-                position,
-                retrive_index,
-                retrive_next_token,
-                retrive_next_sibling,
-                draft_tokens,
-            ) = build_tree_kernel_efficient(
-                verified_id,
-                score_list,
-                token_list,
-                parents_list,
-                seq_lens,
-                seq_lens_sum,
-                topk,
-                spec_steps,
-                num_verify_tokens,
-            )
+        (
+            tree_mask,
+            position,
+            retrive_index,
+            retrive_next_token,
+            retrive_next_sibling,
+            draft_tokens,
+        ) = build_tree_kernel_efficient(
+            verified_id,
+            score_list,
+            token_list,
+            parents_list,
+            seq_lens,
+            seq_lens_sum,
+            topk,
+            spec_steps,
+            num_verify_tokens,
+        )
 
-            return cls(
-                draft_tokens,
-                tree_mask,
-                position,
-                retrive_index,
-                retrive_next_token,
-                retrive_next_sibling,
-                None,
-                num_verify_tokens,
-                spec_steps,
-                CaptureHiddenMode.FULL,
-            )
+        return cls(
+            draft_tokens,
+            tree_mask,
+            position,
+            retrive_index,
+            retrive_next_token,
+            retrive_next_sibling,
+            None,
+            num_verify_tokens,
+            spec_steps,
+            CaptureHiddenMode.FULL,
+        )
 
     def prepare_for_verify(self, batch: ScheduleBatch):
         batch.input_ids = self.draft_token
@@ -291,7 +254,6 @@
             dtype=torch.int32,
             device="cuda",
         )
-
         cum_kv_seq_len = torch.zeros(
             (batch_size + 1,), dtype=torch.int32, device="cuda"
         )
@@ -304,7 +266,6 @@
             dtype=torch.int32,
             device="cuda",
         )
-
        create_flashinfer_kv_indices_triton[(batch_size,)](
             req_to_token,
             req_pool_indices,
@@ -322,65 +283,79 @@
         logits_output: torch.Tensor,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
     ) -> torch.Tensor:
-        """WARNING: This API in-place modifies the states of logits_output
-
+        """
         Verify and find accepted tokens based on logits output and batch
         (which contains spec decoding information).
 
+        WARNING: This API in-place modifies the states of logits_output
+
         This API updates values inside logits_output based on the accepted
         tokens. I.e., logits_output.next_token_logits only contains
         accepted token logits.
         """
-        draft_token = torch.cat(
-            [self.draft_token, torch.full([1], -1, dtype=torch.int32, device="cuda")],
-            dim=-1,
+        bs = self.retrive_index.shape[0]
+        candidates = self.draft_token.reshape(bs, self.draft_token_num)
+        sampling_info = batch.sampling_info
+
+        predict_shape = list(logits_output.next_token_logits.shape)[:-1]
+        predict_shape[-1] += 1
+        predict = torch.empty(predict_shape, dtype=torch.int32, device="cuda")
+        accept_index = torch.full(
+            (bs, self.spec_steps + 1), -1, dtype=torch.int32, device="cuda"
         )
-        candidates = draft_token[self.retrive_index]
-        if batch.sampling_info.is_all_greedy:
-            # temp == 0
-            bs = self.retrive_cum_len.numel() - 1
-            predict = torch.argmax(logits_output.next_token_logits, dim=-1)
-            predict = torch.cat(
-                [predict, torch.full([1], -1, dtype=torch.int32, device="cuda")], dim=-1
+        accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
+
+        if sampling_info.penalizer_orchestrator.is_required:
+            # This is a relaxed version of penalties for speculative decoding.
+            linear_penalty = torch.zeros(
+                (bs, logits_output.next_token_logits.shape[1]),
+                dtype=torch.float32,
+                device="cuda",
             )
-            target_predict = predict[self.retrive_index]
-            # logits = logits_output.next_token_logits[self.retrive_index]
-            # target_predict = torch.argmax(logits[:, :-1], dim=-1)
-            accept_mask = candidates[:, 1:] == target_predict[:, :-1]
-
-            accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1)
-            max_draft_len = self.retrive_index.shape[-1]
-            accept_index = torch.full(
-                (bs, max_draft_len), -1, dtype=torch.int32, device="cuda"
+            sampling_info.apply_logits_bias(linear_penalty)
+            logits_output.next_token_logits.add_(
+                torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
             )
-            accept_length = torch.empty((bs,), dtype=torch.int, device="cuda")
-            extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda")
-            eagle_verify_retrive[(bs,)](
-                self.retrive_index.contiguous(),
-                accept_mask.contiguous(),
-                self.retrive_cum_len,
-                accept_index,
-                accept_length,
-                extract_index,
-                max_draft_len,
-                self.draft_token_num,
-                triton.next_power_of_2(max_draft_len),
+
+        if batch.sampling_info.is_all_greedy:
+            target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
+            target_predict = target_predict.reshape(bs, self.draft_token_num)
+
+            verify_tree_greedy(
+                predicts=predict,  # mutable
+                accept_index=accept_index,  # mutable
+                accept_token_num=accept_length,  # mutable
+                candidates=candidates.to(torch.int32),
+                retrive_index=self.retrive_index.to(torch.int32),
+                retrive_next_token=self.retrive_next_token.to(torch.int32),
+                retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
+                target_predict=target_predict.to(torch.int32),
             )
         else:
-            # temp > 0
-            bs = self.retrive_index.shape[0]
-            predict_shape = list(logits_output.next_token_logits.shape)[:-1]
-            predict_shape[-1] += 1
-            target_logits = logits_output.next_token_logits[self.retrive_index]
-            predict = torch.full(predict_shape, -1, dtype=torch.int32, device="cuda")
-            accept_index = torch.full(
-                (bs, self.spec_steps + 1), -1, dtype=torch.int32, device="cuda"
+            # Apply temperature and get target probs
+            expanded_temperature = torch.repeat_interleave(
+                sampling_info.temperatures, self.draft_token_num, dim=0
+            )  # (bs * draft_token_num, 1)
+
+            target_probs = F.softmax(
+                logits_output.next_token_logits / expanded_temperature, dim=-1
+            )  # (bs * draft_token_num, vocab_size)
+            target_probs = top_k_renorm_prob(
+                target_probs,
+                torch.repeat_interleave(
+                    sampling_info.top_ks, self.draft_token_num, dim=0
+                ),
+            )  # (bs * draft_token_num, vocab_size)
+            target_probs = top_p_renorm_prob(
+                target_probs,
+                torch.repeat_interleave(
+                    sampling_info.top_ps, self.draft_token_num, dim=0
+                ),
             )
-            accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
-            expanded_temperature = batch.sampling_info.temperatures.unsqueeze(1)
-            target_probs = F.softmax(target_logits / expanded_temperature, dim=-1)
-            draft_probs = torch.full_like(
-                target_probs, 0, dtype=torch.float32, device="cuda"
+            target_probs = target_probs.reshape(bs, self.draft_token_num, -1)
+
+            draft_probs = torch.zeros(
+                target_probs.shape, dtype=torch.float32, device="cuda"
            )
             coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda")
             tree_speculative_sampling_target_only(
@@ -394,6 +369,12 @@
             uniform_samples=coins,
             target_probs=target_probs,
             draft_probs=draft_probs,
+            threshold_single=global_server_args_dict[
+                "speculative_accept_threshold_single"
+            ],
+            threshold_acc=global_server_args_dict[
+                "speculative_accept_threshold_acc"
+            ],
             deterministic=True,
         )
 
@@ -425,119 +406,94 @@
                 new_accept_index.extend(new_accept_index_)
                 unfinished_index.append(i)
             req.spec_verify_ct += 1
-        accept_length = (accept_index != -1).sum(dim=1) - 1
-
-        accept_index = accept_index[accept_index != -1]
-        accept_length_cpu = accept_length.tolist()
-        verified_id = predict[accept_index]
-        evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
-        evict_mask[accept_index] = False
-        mem_need_free_idx = batch.out_cache_loc[evict_mask]
-        token_to_kv_pool_allocator.free(mem_need_free_idx)
-        assign_req_to_token_pool[(bs,)](
-            batch.req_pool_indices,
-            batch.req_to_token_pool.req_to_token,
-            batch.seq_lens,
-            batch.seq_lens + accept_length + 1,
-            batch.out_cache_loc[accept_index],
-            batch.req_to_token_pool.req_to_token.shape[1],
-            triton.next_power_of_2(bs),
-        )
-        batch.seq_lens.add_(accept_length + 1)
-
-        draft_input = EagleDraftInput()
-        if len(new_accept_index) > 0:
-            new_accept_index = torch.tensor(new_accept_index, device="cuda")
-            draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
-            draft_input.verified_id = predict[new_accept_index]
-            draft_input.accept_length = accept_length[unfinished_index]
-            draft_input.accept_length_cpu = [
-                accept_length_cpu[i] for i in unfinished_index
-            ]
-            if has_finished:
-                draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
-            else:
-                draft_input.seq_lens_for_draft_extend = batch.seq_lens
-            batch.out_cache_loc = batch.out_cache_loc[new_accept_index]
-
-        return EagleVerifyOutput(
-            draft_input=draft_input,
-            logits_output=logits_output,
-            verified_id=verified_id,
-            accept_length_per_req_cpu=accept_length_cpu,
-            accepeted_indices_cpu=accept_index,
-        )
-
-
-@triton.jit
-def eagle_verify_retrive(
-    retrive_index,
-    accept_mask,
-    retrive_cum_len,
-    accept_index,
-    accept_length,
-    extract_index,
-    max_len: tl.constexpr,
-    draft_token_num: tl.constexpr,
-    max_len_upper: tl.constexpr,
-):
-    """
-    Args:
-        retrive_index: Pointer to indices of draft tokens
-        accept_mask: Mask indicating which tokens were accepted
-        retrive_cum_len: Cumulative lengths of token sequences in a batch
-        accept_index (out): Accepted token indices
-        accept_length (out): Length of accepted tokens per sequence in a batch
-        extract_index (out): Index of the last accepted tokens
-        max_len: Maximum length in a batch
-        draft_token_num: Number of tokens speculatively generated
-        max_len_upper: An upper bound for token sequence length
-    """
-    pid = tl.program_id(axis=0)
-
-    retrive_end = tl.load(retrive_cum_len + pid + 1)
-    retrive_start = tl.load(retrive_cum_len + pid)
-    retrive_len = retrive_end - retrive_start
-    accept_ptr = accept_mask + retrive_start
-    accept_offset = tl.arange(0, draft_token_num)
-    accept_load_mask = accept_offset < retrive_len
-    accept_len_list = tl.load(
-        accept_ptr + accept_offset, mask=accept_load_mask, other=-1
-    )
-
-    accept_len = tl.max(accept_len_list)
-    max_index = tl.argmax(accept_len_list, axis=0, tie_break_left=True)
-    # Triton does not support argmax with tie_break_right, so implement it manually.
-    mask_max = accept_len_list == accept_len
-
-    count_mask = tl.full(shape=[draft_token_num], value=0, dtype=tl.int32)
-    count = tl.sum(tl.where(mask_max, 1, count_mask))
-    if count > 1:
-        index = tl.arange(0, draft_token_num)
-        mask_left = index != max_index
-        remained_index = tl.where(mask_max and mask_left, index, 0)
-        max_index = tl.max(remained_index)
-
-    tl.store(accept_length + pid, accept_len)
-    retrive_index_ptr = retrive_index + (retrive_start + max_index) * max_len
-    retrive_offset = tl.arange(0, max_len_upper)
-    retrive_load_mask = retrive_offset < accept_len + 1
-    data = tl.load(retrive_index_ptr + retrive_offset, mask=retrive_load_mask)
-
-    tl.store(
-        accept_index + pid * max_len + retrive_offset, data, mask=retrive_load_mask
-    )
-
-    extract_load_ptr = accept_index + pid * max_len + accept_len
-    if accept_len == max_len - 1:
-        extract_data = tl.load(extract_load_ptr - 1)
-        tl.store(extract_index + pid * 2, extract_data)
-        extract_data = tl.load(extract_load_ptr)
-        tl.store(extract_index + pid * 2 + 1, extract_data)
-
-    else:
-        extract_data = tl.load(extract_load_ptr)
-        tl.store(extract_index + pid * 2, extract_data)
+        if not has_finished:
+            accept_index = accept_index[accept_index != -1]
+            verified_id = predict[accept_index]
+            evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
+            evict_mask[accept_index] = False
+            mem_need_free_idx = batch.out_cache_loc[evict_mask]
+            token_to_kv_pool_allocator.free(mem_need_free_idx)
+            batch.out_cache_loc = batch.out_cache_loc[accept_index]
+            assign_req_to_token_pool[(bs,)](
+                batch.req_pool_indices,
+                batch.req_to_token_pool.req_to_token,
+                batch.seq_lens,
+                batch.seq_lens + accept_length + 1,
+                batch.out_cache_loc,
+                batch.req_to_token_pool.req_to_token.shape[1],
+                triton.next_power_of_2(bs),
+            )
+            batch.seq_lens.add_(accept_length + 1)
+            accept_length_cpu = accept_length.tolist()
+
+            draft_input = EagleDraftInput()
+            draft_input.hidden_states = batch.spec_info.hidden_states[accept_index]
+            draft_input.verified_id = verified_id
+            draft_input.accept_length = accept_length
+            draft_input.accept_length_cpu = accept_length_cpu
+            draft_input.seq_lens_for_draft_extend = batch.seq_lens
+            draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
+
+            return EagleVerifyOutput(
+                draft_input=draft_input,
+                logits_output=logits_output,
+                verified_id=verified_id,
+                accept_length_per_req_cpu=accept_length_cpu,
+                accepeted_indices=accept_index,
+            )
+        else:
+            accept_length = (accept_index != -1).sum(dim=1) - 1
+            accept_index = accept_index[accept_index != -1]
+            verified_id = predict[accept_index]
+            evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
+            evict_mask[accept_index] = False
+            mem_need_free_idx = batch.out_cache_loc[evict_mask]
+            token_to_kv_pool_allocator.free(mem_need_free_idx)
+            assign_req_to_token_pool[(bs,)](
+                batch.req_pool_indices,
+                batch.req_to_token_pool.req_to_token,
+                batch.seq_lens,
+                batch.seq_lens + accept_length + 1,
+                batch.out_cache_loc[accept_index],
+                batch.req_to_token_pool.req_to_token.shape[1],
+                triton.next_power_of_2(bs),
+            )
+            batch.seq_lens.add_(accept_length + 1)
+            accept_length_cpu = accept_length.tolist()
+
+            draft_input = EagleDraftInput()
+            if len(new_accept_index) > 0:
+                new_accept_index = torch.tensor(new_accept_index, device="cuda")
+                draft_input.hidden_states = batch.spec_info.hidden_states[
+                    new_accept_index
+                ]
+                draft_input.verified_id = predict[new_accept_index]
+                draft_input.accept_length = accept_length[unfinished_index]
+                draft_input.accept_length_cpu = [
+                    accept_length_cpu[i] for i in unfinished_index
+                ]
+                if has_finished:
+                    draft_input.seq_lens_for_draft_extend = batch.seq_lens[
+                        unfinished_index
+                    ]
+                    draft_input.req_pool_indices_for_draft_extend = (
+                        batch.req_pool_indices[unfinished_index]
+                    )
+                else:
+                    draft_input.seq_lens_for_draft_extend = batch.seq_lens
+                    draft_input.req_pool_indices_for_draft_extend = (
+                        batch.req_pool_indices
+                    )
+                batch.out_cache_loc = batch.out_cache_loc[new_accept_index]
+
+            return EagleVerifyOutput(
+                draft_input=draft_input,
+                logits_output=logits_output,
+                verified_id=verified_id,
+                accept_length_per_req_cpu=accept_length_cpu,
+                accepeted_indices=accept_index,
+            )
 
 
 @triton.jit