sglang 0.4.4.post3__py3-none-any.whl → 0.4.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_serving.py +49 -7
  2. sglang/lang/chat_template.py +24 -0
  3. sglang/srt/_custom_ops.py +59 -92
  4. sglang/srt/configs/model_config.py +5 -0
  5. sglang/srt/constrained/base_grammar_backend.py +5 -1
  6. sglang/srt/conversation.py +29 -4
  7. sglang/srt/custom_op.py +5 -0
  8. sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
  9. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  10. sglang/srt/entrypoints/engine.py +0 -5
  11. sglang/srt/layers/attention/flashattention_backend.py +678 -83
  12. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  13. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  14. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  15. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  16. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  17. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  18. sglang/srt/layers/moe/fused_moe_native.py +5 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +416 -50
  30. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  31. sglang/srt/layers/moe/topk.py +49 -3
  32. sglang/srt/layers/quantization/__init__.py +5 -1
  33. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  34. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  35. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  36. sglang/srt/layers/quantization/fp8.py +3 -1
  37. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  38. sglang/srt/layers/quantization/moe_wna16.py +503 -0
  39. sglang/srt/layers/quantization/utils.py +1 -1
  40. sglang/srt/layers/quantization/w8a8_int8.py +2 -0
  41. sglang/srt/layers/radix_attention.py +2 -0
  42. sglang/srt/layers/rotary_embedding.py +63 -12
  43. sglang/srt/managers/cache_controller.py +34 -11
  44. sglang/srt/managers/mm_utils.py +202 -156
  45. sglang/srt/managers/multimodal_processor.py +0 -2
  46. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  47. sglang/srt/managers/multimodal_processors/clip.py +7 -26
  48. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  49. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  50. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  51. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  52. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  53. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  54. sglang/srt/managers/multimodal_processors/mllama4.py +161 -0
  55. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  56. sglang/srt/managers/schedule_batch.py +185 -128
  57. sglang/srt/managers/scheduler.py +4 -4
  58. sglang/srt/managers/tokenizer_manager.py +1 -1
  59. sglang/srt/managers/utils.py +1 -6
  60. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  61. sglang/srt/mem_cache/memory_pool.py +72 -6
  62. sglang/srt/mem_cache/paged_allocator.py +39 -0
  63. sglang/srt/metrics/collector.py +23 -53
  64. sglang/srt/model_executor/cuda_graph_runner.py +8 -6
  65. sglang/srt/model_executor/forward_batch_info.py +10 -10
  66. sglang/srt/model_executor/model_runner.py +60 -57
  67. sglang/srt/model_loader/loader.py +8 -0
  68. sglang/srt/models/clip.py +12 -7
  69. sglang/srt/models/deepseek_janus_pro.py +10 -15
  70. sglang/srt/models/deepseek_v2.py +212 -121
  71. sglang/srt/models/deepseek_vl2.py +105 -104
  72. sglang/srt/models/gemma3_mm.py +14 -80
  73. sglang/srt/models/llama.py +16 -5
  74. sglang/srt/models/llama4.py +420 -0
  75. sglang/srt/models/llava.py +31 -19
  76. sglang/srt/models/llavavid.py +16 -7
  77. sglang/srt/models/minicpmo.py +63 -147
  78. sglang/srt/models/minicpmv.py +17 -27
  79. sglang/srt/models/mllama.py +29 -14
  80. sglang/srt/models/mllama4.py +154 -0
  81. sglang/srt/models/qwen2.py +9 -6
  82. sglang/srt/models/qwen2_5_vl.py +21 -31
  83. sglang/srt/models/qwen2_vl.py +20 -21
  84. sglang/srt/openai_api/adapter.py +18 -6
  85. sglang/srt/platforms/interface.py +371 -0
  86. sglang/srt/server_args.py +99 -14
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  88. sglang/srt/speculative/eagle_utils.py +140 -28
  89. sglang/srt/speculative/eagle_worker.py +93 -24
  90. sglang/srt/utils.py +104 -51
  91. sglang/test/test_custom_ops.py +55 -0
  92. sglang/test/test_utils.py +13 -26
  93. sglang/utils.py +2 -2
  94. sglang/version.py +1 -1
  95. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/METADATA +4 -3
  96. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/RECORD +99 -84
  97. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py CHANGED
@@ -15,11 +15,12 @@

 import argparse
 import dataclasses
+import json
 import logging
 import os
 import random
 import tempfile
-from typing import List, Optional
+from typing import List, Literal, Optional

 from sglang.srt.hf_transformers_utils import check_gguf_file
 from sglang.srt.reasoning_parser import ReasoningParser
@@ -127,14 +128,14 @@ class ServerArgs:
     # Kernel backend
     attention_backend: Optional[str] = None
     sampling_backend: Optional[str] = None
-    grammar_backend: Optional[str] = "xgrammar"
+    grammar_backend: Optional[str] = None

     # Speculative decoding
     speculative_algorithm: Optional[str] = None
     speculative_draft_model_path: Optional[str] = None
-    speculative_num_steps: int = 5
-    speculative_eagle_topk: int = 4
-    speculative_num_draft_tokens: int = 8
+    speculative_num_steps: Optional[int] = None
+    speculative_eagle_topk: Optional[int] = None
+    speculative_num_draft_tokens: Optional[int] = None
     speculative_accept_threshold_single: float = 1.0
     speculative_accept_threshold_acc: float = 1.0
     speculative_token_map: Optional[str] = None
@@ -160,6 +161,7 @@ class ServerArgs:
     enable_dp_attention: bool = False
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
+    deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
     cuda_graph_max_bs: Optional[int] = None
@@ -177,10 +179,12 @@ class ServerArgs:
     tool_call_parser: Optional[str] = None
     enable_hierarchical_cache: bool = False
     hicache_ratio: float = 2.0
-    enable_flashinfer_mla: bool = False
+    enable_flashinfer_mla: bool = False  # TODO: remove this argument
     enable_flashmla: bool = False
     flashinfer_mla_disable_ragged: bool = False
     warmups: Optional[str] = None
+    n_share_experts_fusion: int = 0
+    disable_shared_experts_fusion: bool = False

     # Debug tensor dumps
     debug_tensor_dump_output_folder: Optional[str] = None
@@ -192,6 +196,13 @@ class ServerArgs:
     disaggregation_bootstrap_port: int = 8998

     def __post_init__(self):
+        # Expert parallelism
+        if self.enable_ep_moe:
+            self.ep_size = self.tp_size
+            logger.info(
+                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
+            )
+
         # Set missing default values
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
@@ -215,6 +226,9 @@ class ServerArgs:
             # GPU memory is not known yet or no GPU is available.
             gpu_mem = None

+        if is_hip():
+            self.disable_shared_experts_fusion = True
+
         # Set mem fraction static, which depends on the tensor parallelism size
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
@@ -253,15 +267,11 @@ class ServerArgs:
             else:
                 self.cuda_graph_max_bs = 160

-        # Choose kernel backends
+        # Set kernel backends for hpu device
         if self.device == "hpu":
             self.attention_backend = "torch_native"
             self.sampling_backend = "pytorch"

-        if self.attention_backend is None:
-            self.attention_backend = (
-                "flashinfer" if is_flashinfer_available() else "triton"
-            )
         if self.sampling_backend is None:
             self.sampling_backend = (
                 "flashinfer" if is_flashinfer_available() else "pytorch"
@@ -273,6 +283,10 @@ class ServerArgs:
             )
             self.disable_cuda_graph = True

+        # Choose grammar backend
+        if self.grammar_backend is None:
+            self.grammar_backend = "xgrammar"
+
         # Expert parallelism
         if self.enable_ep_moe:
             self.ep_size = self.tp_size
@@ -295,6 +309,10 @@ class ServerArgs:
         self.enable_sp_layernorm = False
         # DeepEP MoE
         if self.enable_deepep_moe:
+            if self.deepep_mode == "auto":
+                assert (
+                    not self.enable_dp_attention
+                ), "DeepEP MoE `auto` mode is not supported with DP Attention."
             self.ep_size = self.tp_size
             self.enable_sp_layernorm = (
                 self.dp_size < self.tp_size if self.enable_dp_attention else True
@@ -313,12 +331,29 @@ class ServerArgs:
             or self.speculative_algorithm == "EAGLE3"
         ):
             if self.max_running_requests is None:
-                self.max_running_requests = 32
+                self.max_running_requests = 48
             self.disable_overlap_schedule = True
             logger.info(
                 "Overlap scheduler is disabled because of using "
                 "eagle speculative decoding."
             )
+
+            # Auto choose parameters
+            if self.speculative_num_steps is None:
+                assert (
+                    self.speculative_eagle_topk is None
+                    and self.speculative_num_draft_tokens is None
+                )
+                (
+                    self.speculative_num_steps,
+                    self.speculative_eagle_topk,
+                    self.speculative_num_draft_tokens,
+                ) = auto_choose_speculative_params(self)
+
+            if self.page_size > 1 and self.speculative_eagle_topk > 1:
+                self.speculative_eagle_topk = 1
+                logger.info("speculative_eagle_topk is changed to 1 when page_size > 1")
+
             # The token generated from the verify step is counted.
             # If sepculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
             # assert self.speculative_num_steps < self.speculative_num_draft_tokens
@@ -462,6 +497,7 @@ class ServerArgs:
                 "modelopt",
                 "w8a8_int8",
                 "w8a8_fp8",
+                "moe_wna16",
             ],
             help="The quantization method.",
         )
@@ -795,14 +831,14 @@ class ServerArgs:
         parser.add_argument(
             "--grammar-backend",
             type=str,
-            choices=["xgrammar", "outlines", "llguidance"],
+            choices=["xgrammar", "outlines", "llguidance", "none"],
             default=ServerArgs.grammar_backend,
             help="Choose the backend for grammar-guided decoding.",
         )
         parser.add_argument(
             "--enable-flashinfer-mla",
             action="store_true",
-            help="Enable FlashInfer MLA optimization",
+            help="Enable FlashInfer MLA optimization. This argument will be deprecated soon! Please use '--attention-backend flashinfer' instead for switching on flashfiner mla!",
         )
         parser.add_argument(
             "--enable-flashmla",
@@ -1060,6 +1096,25 @@ class ServerArgs:
             action="store_true",
             help="Enabling DeepEP MoE implementation for EP MoE.",
         )
+        parser.add_argument(
+            "--deepep-mode",
+            type=str,
+            choices=["normal", "low_latency", "auto"],
+            help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
+        )
+
+        parser.add_argument(
+            "--n-share-experts-fusion",
+            type=int,
+            default=0,
+            help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1 "
+            "we use tp_size by default.",
+        )
+        parser.add_argument(
+            "--disable-shared-experts-fusion",
+            action="store_true",
+            help="Disable shared experts fusion by setting n_share_experts_fusion to 0.",
+        )

         # Server warmups
         parser.add_argument(
@@ -1253,3 +1308,33 @@ class DeprecatedAction(argparse.Action):

     def __call__(self, parser, namespace, values, option_string=None):
         raise ValueError(self.help)
+
+
+def auto_choose_speculative_params(self: ServerArgs):
+    """
+    Automatically choose the parameters for speculative decoding.
+
+    You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
+    """
+    if self.decrypted_config_file:
+        config_path = self.decrypted_config_file
+    else:
+        config_path = os.path.join(self.model_path, "config.json")
+    if not os.path.exists(config_path):
+        raise ValueError(f"{config_path} is not found.")
+
+    config = json.load(open(config_path))
+
+    arch = config.get("architectures", ["Unknown"])[0]
+
+    if arch in ["LlamaForCausalLM"]:
+        # The default value for llama
+        return (5, 4, 8)
+    elif arch in ["DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"]:
+        # The default value for deepseek
+        return (5, 4, 8)
+    elif arch in ["Grok1ForCausalLM", "Grok1VForCausalLM"]:
+        return (5, 4, 8)
+    else:
+        # The default value for all other models
+        return (5, 4, 8)
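The three speculative-decoding knobs now default to None and are filled in __post_init__ via the new module-level auto_choose_speculative_params helper, which currently returns (5, 4, 8) for every listed architecture. A minimal sketch of exercising the helper on its own; the SimpleNamespace stand-in and the local checkpoint directory are assumptions for illustration, and that directory's config.json is expected to list "LlamaForCausalLM":

from types import SimpleNamespace

from sglang.srt.server_args import auto_choose_speculative_params

fake_args = SimpleNamespace(
    decrypted_config_file=None,       # fall back to <model_path>/config.json
    model_path="/models/llama-3-8b",  # placeholder local checkpoint directory
)
steps, topk, num_draft_tokens = auto_choose_speculative_params(fake_args)
print(steps, topk, num_draft_tokens)  # 5 4 8 for LlamaForCausalLM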
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py CHANGED
@@ -214,10 +214,10 @@ class EAGLEDraftCudaGraphRunner:
         forward_batch.positions = self.positions[:num_tokens]

         # Special handle for seq_len_cpu used when flashinfer mla is used
-        if (forward_batch.decode_seq_lens_cpu is not None) and (bs != raw_bs):
+        if (forward_batch.seq_lens_cpu is not None) and (bs != raw_bs):
             self.seq_lens_cpu.fill_(1)
-            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
-            forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:bs]
+            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
+            forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]

         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
             forward_batch, bs
@@ -233,7 +233,7 @@ class EAGLEDraftCudaGraphRunner:
         forward_batch.positions = self.positions[:raw_num_token]
         forward_batch.seq_lens = self.seq_lens[:raw_bs]
         forward_batch.req_pool_indices = self.req_pool_indices[:raw_bs]
-        if forward_batch.decode_seq_lens_cpu is not None:
-            forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:raw_bs]
+        if forward_batch.seq_lens_cpu is not None:
+            forward_batch.seq_lens_cpu = self.seq_lens_cpu[:raw_bs]

         return out
sglang/srt/speculative/eagle_utils.py CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import os
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, Optional

@@ -10,11 +11,15 @@ import triton.language as tl

 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.managers.schedule_batch import (
+    ScheduleBatch,
+    get_last_loc,
+    global_server_args_dict,
+)
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
-from sglang.srt.utils import is_cuda_available, is_hip
+from sglang.srt.utils import is_cuda_available, is_hip, next_power_of_2

 if is_cuda_available():
     from sgl_kernel import (
@@ -34,6 +39,9 @@ import logging
 logger = logging.getLogger(__name__)


+SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")
+
+
 @dataclass
 class EagleDraftInput:
     # The inputs for decode
@@ -93,7 +101,7 @@ class EagleDraftInput:
             torch.cumsum(self.accept_length, axis=0, dtype=torch.int),
             self.positions,
             new_verified_id,
-            triton.next_power_of_2(speculative_num_steps + 1),
+            next_power_of_2(speculative_num_steps + 1),
         )

         batch.seq_lens_sum = sum(seq_lens_cpu)
@@ -225,18 +233,34 @@ class EagleVerifyInput:
             CaptureHiddenMode.FULL,
         )

-    def prepare_for_verify(self, batch: ScheduleBatch):
+    def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
         batch.input_ids = self.draft_token
-        batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
+
+        if page_size == 1:
+            batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
+            end_offset = batch.seq_lens + self.draft_token_num
+        else:
+            prefix_lens = batch.seq_lens
+            end_offset = prefix_lens + self.draft_token_num
+            last_loc = get_last_loc(
+                batch.req_to_token_pool.req_to_token,
+                batch.req_pool_indices,
+                prefix_lens,
+            )
+            batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
+                prefix_lens, end_offset, last_loc, len(batch.input_ids)
+            )
+            self.last_loc = last_loc
+
         bs = batch.batch_size()
         assign_req_to_token_pool[(bs,)](
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
             batch.seq_lens,
-            batch.seq_lens + self.draft_token_num,
+            end_offset,
             batch.out_cache_loc,
             batch.req_to_token_pool.req_to_token.shape[1],
-            triton.next_power_of_2(bs),
+            next_power_of_2(bs),
         )

     def generate_attn_arg_prefill(
@@ -282,6 +306,7 @@ class EagleVerifyInput:
         batch: ScheduleBatch,
         logits_output: torch.Tensor,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        page_size: int,
     ) -> torch.Tensor:
         """
         Verify and find accepted tokens based on logits output and batch
@@ -305,6 +330,7 @@ class EagleVerifyInput:
         )
         accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")

+        # Apply penalty
         if sampling_info.penalizer_orchestrator.is_required:
             # This is a relaxed version of penalties for speculative decoding.
             linear_penalty = torch.zeros(
@@ -317,6 +343,7 @@ class EagleVerifyInput:
                 torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
             )

+        # Sample tokens
         if batch.sampling_info.is_all_greedy:
             target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
             target_predict = target_predict.reshape(bs, self.draft_token_num)
@@ -378,13 +405,24 @@ class EagleVerifyInput:
                 deterministic=True,
             )

+        if SIMULATE_ACC_LEN:
+            # Do simulation
+            accept_index = _generate_simulated_accept_index(
+                accept_index=accept_index,
+                predict=predict,  # mutable
+                accept_length=accept_length,  # mutable
+                simulate_acc_len=SIMULATE_ACC_LEN,
+                bs=bs,
+                spec_steps=self.spec_steps,
+            )
+
         new_accept_index = []
         unfinished_index = []
         accept_index_cpu = accept_index.tolist()
         predict_cpu = predict.tolist()
         has_finished = False

-        # iterate every accepted token and check if req has finished after append the token
+        # Iterate every accepted token and check if req has finished after append the token
         # should be checked BEFORE free kv cache slots
         for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
             new_accept_index_ = []
@@ -407,13 +445,28 @@ class EagleVerifyInput:
                        unfinished_index.append(i)
                    req.spec_verify_ct += 1

+        if has_finished:
+            accept_length = (accept_index != -1).sum(dim=1) - 1
+
+        # Free the KV cache for unaccepted tokens
+        accept_index = accept_index[accept_index != -1]
+        verified_id = predict[accept_index]
+        evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
+        evict_mask[accept_index] = False
+
+        if page_size != 1:
+            align_evict_mask_to_page_size[len(batch.seq_lens),](
+                batch.seq_lens,
+                evict_mask,
+                page_size,
+                self.draft_token_num,
+                next_power_of_2(self.draft_token_num),
+            )
+
+        token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
+
+        # Construct EagleVerifyOutput
         if not has_finished:
-            accept_index = accept_index[accept_index != -1]
-            verified_id = predict[accept_index]
-            evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
-            evict_mask[accept_index] = False
-            mem_need_free_idx = batch.out_cache_loc[evict_mask]
-            token_to_kv_pool_allocator.free(mem_need_free_idx)
             batch.out_cache_loc = batch.out_cache_loc[accept_index]
             assign_req_to_token_pool[(bs,)](
                 batch.req_pool_indices,
@@ -422,7 +475,7 @@ class EagleVerifyInput:
                 batch.seq_lens + accept_length + 1,
                 batch.out_cache_loc,
                 batch.req_to_token_pool.req_to_token.shape[1],
-                triton.next_power_of_2(bs),
+                next_power_of_2(bs),
             )
             batch.seq_lens.add_(accept_length + 1)
             accept_length_cpu = accept_length.tolist()
@@ -443,13 +496,6 @@ class EagleVerifyInput:
                 accepeted_indices=accept_index,
             )
         else:
-            accept_length = (accept_index != -1).sum(dim=1) - 1
-            accept_index = accept_index[accept_index != -1]
-            verified_id = predict[accept_index]
-            evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
-            evict_mask[accept_index] = False
-            mem_need_free_idx = batch.out_cache_loc[evict_mask]
-            token_to_kv_pool_allocator.free(mem_need_free_idx)
             assign_req_to_token_pool[(bs,)](
                 batch.req_pool_indices,
                 batch.req_to_token_pool.req_to_token,
@@ -457,7 +503,7 @@ class EagleVerifyInput:
                 batch.seq_lens + accept_length + 1,
                 batch.out_cache_loc[accept_index],
                 batch.req_to_token_pool.req_to_token.shape[1],
-                triton.next_power_of_2(bs),
+                next_power_of_2(bs),
             )
             batch.seq_lens.add_(accept_length + 1)
             accept_length_cpu = accept_length.tolist()
@@ -465,20 +511,21 @@ class EagleVerifyInput:
         draft_input = EagleDraftInput()
         if len(new_accept_index) > 0:
             new_accept_index = torch.tensor(new_accept_index, device="cuda")
+            unfinished_index_device = torch.tensor(unfinished_index, device="cuda")
             draft_input.hidden_states = batch.spec_info.hidden_states[
                 new_accept_index
             ]
             draft_input.verified_id = predict[new_accept_index]
-            draft_input.accept_length = accept_length[unfinished_index]
             draft_input.accept_length_cpu = [
                 accept_length_cpu[i] for i in unfinished_index
             ]
+            draft_input.accept_length = accept_length[unfinished_index_device]
             if has_finished:
                 draft_input.seq_lens_for_draft_extend = batch.seq_lens[
-                    unfinished_index
+                    unfinished_index_device
                 ]
                 draft_input.req_pool_indices_for_draft_extend = (
-                    batch.req_pool_indices[unfinished_index]
+                    batch.req_pool_indices[unfinished_index_device]
                 )
             else:
                 draft_input.seq_lens_for_draft_extend = batch.seq_lens
@@ -564,13 +611,24 @@ def assign_draft_cache_locs(
     pool_len: tl.constexpr,
     topk: tl.constexpr,
     speculative_num_steps: tl.constexpr,
+    page_size: tl.constexpr,
 ):
     BLOCK_SIZE: tl.constexpr = 32
     pid = tl.program_id(axis=0)
     kv_start = tl.load(seq_lens + pid)
-    kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
+
+    if page_size == 1 or topk == 1:
+        kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
+        out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
+    else:
+        prefix_len = tl.load(seq_lens + pid)
+        last_page_len = prefix_len % page_size
+        num_new_page = (
+            last_page_len + speculative_num_steps + page_size - 1
+        ) // page_size
+        kv_end = prefix_len // page_size * page_size + num_new_page * (page_size * topk)
+
     token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
-    out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps

     num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE)
     for i in range(num_loop):
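For page_size > 1 with topk > 1, the kernel above now rounds the draft allocation up to whole pages per top-k branch. A plain-Python restatement of that arithmetic (illustrative only, not part of the diff):

def paged_kv_end(prefix_len: int, speculative_num_steps: int, topk: int, page_size: int) -> int:
    last_page_len = prefix_len % page_size
    # pages needed to cover the partially filled last page plus the draft steps
    num_new_page = (last_page_len + speculative_num_steps + page_size - 1) // page_size
    # allocation ends at the page-aligned prefix plus num_new_page pages per top-k branch
    return prefix_len // page_size * page_size + num_new_page * (page_size * topk)

# Example: a 13-token prefix with 4 draft steps, topk=2, and 8-token pages
assert paged_kv_end(13, 4, 2, 8) == 8 + 2 * (8 * 2)  # == 40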
@@ -642,6 +700,29 @@ def generate_draft_decode_kv_indices(
     tl.store(kv_indptr + zid, base + zid * iters)


+@triton.jit
+def align_evict_mask_to_page_size(
+    seq_lens,
+    evict_mask,
+    page_size: tl.constexpr,
+    num_draft_tokens: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    t_range = tl.arange(0, BLOCK_SIZE)
+
+    bid = tl.program_id(axis=0)
+    seq_len = tl.load(seq_lens + bid)
+    io_mask = t_range < num_draft_tokens
+    mask_row = tl.load(evict_mask + bid * num_draft_tokens + t_range, mask=io_mask)
+
+    num_trues = tl.sum(mask_row)
+    num_false = num_draft_tokens - num_trues
+
+    start = (seq_len + num_false - 1) // page_size * page_size - seq_len
+    for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
+        tl.store(evict_mask + bid * num_draft_tokens + i, False)
+
+
 @torch.compile(dynamic=True)
 def select_top_k_tokens(
     i: int,
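The new align_evict_mask_to_page_size kernel keeps eviction page-aligned: after counting the kept (non-evicted) draft tokens of a request, it clears the mask for the page that still holds kept tokens so only whole trailing pages are freed. A single-row reference in plain Python, written for illustration under the same semantics:

def align_row(seq_len: int, evict_row: list, page_size: int) -> list:
    num_draft_tokens = len(evict_row)
    num_false = num_draft_tokens - sum(evict_row)  # draft tokens that are kept
    # offset (relative to seq_len) of the page containing the last kept token
    start = (seq_len + num_false - 1) // page_size * page_size - seq_len
    for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
        evict_row[i] = False  # do not free slots on the partially used page
    return evict_row

# seq_len=13, 2 of 6 draft tokens kept, 4-token pages: offsets 0-2 share a page with kept tokens
print(align_row(13, [False, False, True, True, True, True], 4))
# [False, False, False, True, True, True]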
@@ -699,3 +780,34 @@ def fast_topk(values, topk, dim):
     else:
         # Use topk for efficiency with larger k values
         return torch.topk(values, topk, dim=dim)
+
+
+def _generate_simulated_accept_index(
+    accept_index,
+    predict,
+    accept_length,
+    simulate_acc_len,
+    bs,
+    spec_steps,
+):
+    simulate_acc_len_float = float(simulate_acc_len)
+    simulated_values = torch.normal(
+        mean=simulate_acc_len_float,
+        std=1.0,
+        size=(1,),
+        device="cpu",
+    )
+    # clamp simulated values to be between 1 and self.spec_steps
+    simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps)
+    simulate_acc_len = int(simulated_values.round().item())
+
+    accept_indx_first_col = accept_index[:, 0].view(-1, 1)
+    sim_accept_index = torch.full(
+        (bs, spec_steps + 1), -1, dtype=torch.int32, device="cuda"
+    )
+    sim_accept_index[:, :simulate_acc_len] = accept_indx_first_col + torch.arange(
+        simulate_acc_len, device=accept_index.device
+    )
+    accept_length.fill_(simulate_acc_len - 1)
+    predict.fill_(100)  # some legit token id
+    return sim_accept_index
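SIMULATE_ACC_LEN is read from the environment once at import time and, when set, routes verification through _generate_simulated_accept_index to fake a target accept length, which appears intended for experiments that control acceptance behavior. A hedged usage note, assuming an environment where sglang's CUDA/Triton dependencies are importable:

import os

# Must be set before sglang.srt.speculative.eagle_utils is imported,
# e.g. before launching the server process.
os.environ["SIMULATE_ACC_LEN"] = "3"  # simulated mean accepted length per verify step

import sglang.srt.speculative.eagle_utils as eagle_utils  # noqa: E402

print(eagle_utils.SIMULATE_ACC_LEN)  # "3"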