sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_worker.py

@@ -7,8 +7,12 @@ from typing import List, Optional, Tuple
 import torch
 from huggingface_hub import snapshot_download
 
-from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
-from sglang.srt.layers.dp_attention import disable_dp_size
+from sglang.srt.distributed import (
+    GroupCoordinator,
+    get_tensor_model_parallel_world_size,
+    get_tp_group,
+    patch_tensor_parallel_group,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
 from sglang.srt.managers.schedule_batch import (
@@ -35,11 +39,17 @@ from sglang.srt.speculative.eagle_utils import (
     EagleVerifyInput,
     EagleVerifyOutput,
     assign_draft_cache_locs,
+    fast_topk,
     generate_token_bitmask,
     select_top_k_tokens,
 )
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-from sglang.srt.utils import empty_context, fast_topk, get_available_gpu_memory, is_cuda
+from sglang.srt.utils import (
+    empty_context,
+    get_available_gpu_memory,
+    is_cuda,
+    next_power_of_2,
+)
 
 if is_cuda():
     from sgl_kernel import segment_packbits
@@ -51,7 +61,7 @@ logger = logging.getLogger(__name__)
 def draft_tp_context(tp_group: GroupCoordinator):
     # Draft model doesn't use dp and has its own tp group.
     # We disable mscclpp now because it doesn't support 2 comm groups.
-    with disable_dp_size(), patch_tensor_parallel_group(tp_group):
+    with patch_tensor_parallel_group(tp_group):
         yield
 
 
@@ -70,6 +80,7 @@ class EAGLEWorker(TpModelWorker):
         self.server_args = server_args
         self.topk = server_args.speculative_eagle_topk
         self.speculative_num_steps = server_args.speculative_num_steps
+        self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
         self.enable_nan_detection = server_args.enable_nan_detection
         self.gpu_id = gpu_id
         self.device = server_args.device
@@ -152,8 +163,18 @@ class EAGLEWorker(TpModelWorker):
         self.init_attention_backend()
         self.init_cuda_graphs()
 
+        # Some dummy tensors
+        self.num_new_pages_per_topk = torch.empty(
+            (), dtype=torch.int64, device=self.device
+        )
+        self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)
+
     def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
+
+        self.has_prefill_wrapper_verify = False
+        self.draft_extend_attn_backend = None
+
         if self.server_args.attention_backend == "flashinfer":
             if not global_server_args_dict["use_mla_backend"]:
                 from sglang.srt.layers.attention.flashinfer_backend import (
@@ -201,7 +222,6 @@ class EAGLEWorker(TpModelWorker):
                     self.draft_model_runner,
                     skip_prefill=False,
                 )
-            self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "fa3":
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,
@@ -217,7 +237,6 @@ class EAGLEWorker(TpModelWorker):
                 self.draft_model_runner,
                 skip_prefill=False,
             )
-            self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "flashmla":
             from sglang.srt.layers.attention.flashmla_backend import (
                 FlashMLAMultiStepDraftBackend,
@@ -228,8 +247,6 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
-            self.draft_extend_attn_backend = None
-            self.has_prefill_wrapper_verify = False
         else:
             raise ValueError(
                 f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
@@ -254,7 +271,7 @@ class EAGLEWorker(TpModelWorker):
             self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
+            f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
        )
 
@@ -269,7 +286,7 @@ class EAGLEWorker(TpModelWorker):
         )
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
+            f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
         )
 
     @property
@@ -290,17 +307,27 @@ class EAGLEWorker(TpModelWorker):
         A tuple of the final logit output of the target model, next tokens accepted,
         the batch id (used for overlap schedule), and number of accepted tokens.
         """
-        if batch.forward_mode.is_decode():
+        if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
+            logits_output, next_token_ids, bid, seq_lens_cpu = (
+                self.forward_target_extend(batch)
+            )
+            with self.draft_tp_context(self.draft_model_runner.tp_group):
+                self.forward_draft_extend(
+                    batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
+                )
+            return logits_output, next_token_ids, bid, 0, False
+        else:
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 spec_info = self.draft(batch)
             logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
                 self.verify(batch, spec_info)
             )
 
-            # If it is None, it means all requests are finished
-            if batch.spec_info.verified_id is not None:
+            if self.check_forward_draft_extend_after_decode(batch):
                 with self.draft_tp_context(self.draft_model_runner.tp_group):
-                    self.forward_draft_extend_after_decode(batch)
+                    self.forward_draft_extend_after_decode(
+                        batch,
+                    )
             return (
                 logits_output,
                 verify_output.verified_id,
@@ -308,22 +335,27 @@ class EAGLEWorker(TpModelWorker):
                 sum(verify_output.accept_length_per_req_cpu),
                 can_run_cuda_graph,
             )
-        elif batch.forward_mode.is_idle():
-            model_worker_batch = batch.get_model_worker_batch()
-            logits_output, next_token_ids, _ = (
-                self.target_worker.forward_batch_generation(model_worker_batch)
-            )
 
-            return logits_output, next_token_ids, model_worker_batch.bid, 0, False
-        else:
-            logits_output, next_token_ids, bid, seq_lens_cpu = (
-                self.forward_target_extend(batch)
-            )
-            with self.draft_tp_context(self.draft_model_runner.tp_group):
-                self.forward_draft_extend(
-                    batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
-                )
-            return logits_output, next_token_ids, bid, 0, False
+    def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
+        local_need_forward = (
+            batch.spec_info.verified_id is not None
+            and batch.spec_info.verified_id.shape[0] > 0
+        )
+        if not self.server_args.enable_dp_attention:
+            return local_need_forward
+
+        global_need_forward = torch.tensor(
+            [
+                (local_need_forward),
+            ],
+            dtype=torch.int64,
+        )
+        torch.distributed.all_reduce(
+            global_need_forward, group=get_tp_group().cpu_group
+        )
+        global_need_forward_cnt = global_need_forward[0].item()
+        need_forward = global_need_forward_cnt > 0
+        return need_forward
 
     def forward_target_extend(
         self, batch: ScheduleBatch
@@ -342,6 +374,7 @@ class EAGLEWorker(TpModelWorker):
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
+        model_worker_batch.spec_num_draft_tokens = 1
         logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
             model_worker_batch
         )
@@ -352,7 +385,7 @@ class EAGLEWorker(TpModelWorker):
             model_worker_batch.seq_lens_cpu,
         )
 
-    def draft(self, batch: ScheduleBatch):
+    def _draft_preprocess_decode(self, batch: ScheduleBatch):
         # Parse args
         num_seqs = batch.batch_size()
         spec_info = batch.spec_info
@@ -365,14 +398,21 @@ class EAGLEWorker(TpModelWorker):
         )
 
         # Allocate cache locations
+        # Layout of the out_cache_loc
+        # [       topk 0         ] [       topk 1         ]
+        # [iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2]
         if self.page_size == 1:
             out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
-                num_seqs * self.topk * self.speculative_num_steps, backup_state=True
+                num_seqs * self.speculative_num_steps * self.topk, backup_state=True
             )
         else:
            if self.topk == 1:
-                prefix_lens = batch.seq_lens
-                seq_lens = prefix_lens + self.speculative_num_steps
+                prefix_lens, seq_lens, last_loc = get_last_loc_large_page_size_top_k_1(
+                    batch.req_to_token_pool.req_to_token,
+                    batch.req_pool_indices,
+                    batch.seq_lens,
+                    self.speculative_num_steps,
+                )
                 extend_num_tokens = num_seqs * self.speculative_num_steps
             else:
                 # In this case, the last partial page needs to be duplicated.
@@ -385,29 +425,33 @@ class EAGLEWorker(TpModelWorker):
                 # "x" means speculative draft tokens
                 # "." means padded tokens
 
-                # TODO: fuse these ops
-                prefix_lens = batch.seq_lens
-                last_page_lens = prefix_lens % self.page_size
-                num_new_pages = (
-                    last_page_lens + self.speculative_num_steps + self.page_size - 1
-                ) // self.page_size
-                seq_lens = (
-                    prefix_lens // self.page_size * self.page_size
-                    + num_new_pages * (self.page_size * self.topk)
-                )
-                extend_num_tokens = torch.sum(seq_lens - prefix_lens).item()
-                raise NotImplementedError(
-                    "page_size > 1 and top_k > 1 are not supported."
+                # TODO(lmzheng): The current implementation is still a fake support
+                # for page size > 1. In the `assign_draft_cache_locs` below,
+                # we directly move the indices instead of the real kv cache.
+                # This only works when the kernel backend runs with page size = 1.
+                # If the kernel backend runs with page size > 1, we need to
+                # duplicate the real KV cache. The overhead of duplicating KV
+                # cache seems okay because the draft KV cache only has one layer.
+                # see a related copy operation in MHATokenToKVPool::move_kv_cache.
+
+                (
+                    prefix_lens,
+                    seq_lens,
+                    last_loc,
+                    self.num_new_pages_per_topk,
+                    self.extend_lens,
+                ) = get_last_loc_large_page_size_large_top_k(
+                    batch.req_to_token_pool.req_to_token,
+                    batch.req_pool_indices,
+                    batch.seq_lens,
+                    self.speculative_num_steps,
+                    self.topk,
+                    self.page_size,
                 )
-            # TODO: Support page_size > 1 and top_k > 1
-            # 1. Duplicate the KV cache in the last partial page for all top-k segments
-            # 2. Modify generate_draft_decode_kv_indices accordingly
-
-            last_loc = get_last_loc(
-                batch.req_to_token_pool.req_to_token,
-                batch.req_pool_indices,
-                prefix_lens,
-            )
+
+                # TODO(lmzheng): remove this device sync
+                extend_num_tokens = torch.sum(self.extend_lens).item()
+
         out_cache_loc, token_to_kv_pool_state_backup = (
             batch.alloc_paged_token_slots_extend(
                 prefix_lens,
@@ -422,19 +466,54 @@ class EAGLEWorker(TpModelWorker):
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
             batch.seq_lens,
+            self.extend_lens,
+            self.num_new_pages_per_topk,
             out_cache_loc,
             batch.req_to_token_pool.req_to_token.shape[1],
             self.topk,
             self.speculative_num_steps,
             self.page_size,
+            next_power_of_2(num_seqs),
+            next_power_of_2(self.speculative_num_steps),
         )
+
+        if self.page_size > 1 and self.topk > 1:
+            # Remove padded slots
+            out_cache_loc = out_cache_loc[
+                : num_seqs * self.topk * self.speculative_num_steps
+            ]
+
         batch.out_cache_loc = out_cache_loc
         batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
+        batch.return_hidden_states = False
         spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
+        self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
+
+    def _draft_preprocess_idle(self, batch: ScheduleBatch):
+        batch.spec_info = EagleDraftInput.create_idle_input(
+            device=self.device,
+            hidden_size=self.model_config.hidden_size,
+            dtype=self.model_config.dtype,
+            topk=self.topk,
+            capture_hidden_mode=CaptureHiddenMode.LAST,
+        )
+
+    def draft(self, batch: ScheduleBatch):
+        # Parse args
+        if batch.forward_mode.is_idle():
+            self._draft_preprocess_idle(batch)
+        else:
+            self._draft_preprocess_decode(batch)
+
+        spec_info = batch.spec_info
 
-        # Get forward batch
         spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        batch.return_hidden_states = False
+
+        # Get forward batch
         model_worker_batch = batch.get_model_worker_batch()
+        model_worker_batch.spec_num_draft_tokens = self.topk
+        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -446,15 +525,18 @@ class EAGLEWorker(TpModelWorker):
                 forward_batch
             )
         else:
-            # Initialize attention backend
-            self.draft_attn_backend.init_forward_metadata(forward_batch)
-            forward_batch = ForwardBatch.init_new(
-                model_worker_batch, self.draft_model_runner
-            )
+            if not forward_batch.forward_mode.is_idle():
+                # Initialize attention backend
+                self.draft_attn_backend.init_forward_metadata(forward_batch)
             # Run forward steps
             score_list, token_list, parents_list = self.draft_forward(forward_batch)
 
-        self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
+        if batch.forward_mode.is_idle():
+            return EagleVerifyInput.create_idle_input(
+                self.topk,
+                self.speculative_num_steps,
+                self.speculative_num_draft_tokens,
+            )
 
         (
             tree_mask,
@@ -472,7 +554,7 @@ class EAGLEWorker(TpModelWorker):
             batch.seq_lens_sum,
             self.topk,
             self.speculative_num_steps,
-            self.server_args.speculative_num_draft_tokens,
+            self.speculative_num_draft_tokens,
         )
 
         return EagleVerifyInput(
@@ -503,6 +585,13 @@ class EAGLEWorker(TpModelWorker):
         if self.hot_token_id is not None:
             topk_index = self.hot_token_id[topk_index]
 
+        out_cache_loc = out_cache_loc.reshape(
+            forward_batch.batch_size, self.topk, self.speculative_num_steps
+        )
+        out_cache_loc = out_cache_loc.permute((2, 0, 1)).reshape(
+            self.speculative_num_steps, -1
+        )
+
         # Return values
         score_list: List[torch.Tensor] = []
         token_list: List[torch.Tensor] = []
@@ -524,10 +613,7 @@ class EAGLEWorker(TpModelWorker):
 
             # Set inputs
             forward_batch.input_ids = input_ids
-            out_cache_loc = out_cache_loc.view(forward_batch.batch_size, -1)
-            forward_batch.out_cache_loc = out_cache_loc[
-                :, self.topk * i : self.topk * (i + 1)
-            ].flatten()
+            forward_batch.out_cache_loc = out_cache_loc[i]
             forward_batch.positions.add_(1)
             forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
             spec_info.hidden_states = hidden_states
@@ -547,11 +633,18 @@ class EAGLEWorker(TpModelWorker):
 
     def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
         spec_info.prepare_for_verify(batch, self.page_size)
-        batch.forward_mode = ForwardMode.TARGET_VERIFY
+        batch.return_hidden_states = False
+        batch.forward_mode = (
+            ForwardMode.TARGET_VERIFY
+            if not batch.forward_mode.is_idle()
+            else ForwardMode.IDLE
+        )
         batch.spec_info = spec_info
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=spec_info.seq_lens_cpu
         )
+        model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
+        assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode
 
         if batch.has_grammar:
             retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
@@ -583,7 +676,7 @@ class EAGLEWorker(TpModelWorker):
         if vocab_mask is not None:
             assert spec_info.grammar is not None
             vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device)
-            # otherwise, this vocab mask will be the one from the previous extend stage
+            # NOTE (sk): otherwise, this vocab mask will be the one from the previous extend stage
             # and will be applied to produce wrong results
             batch.sampling_info.vocab_mask = None
 
@@ -604,13 +697,15 @@ class EAGLEWorker(TpModelWorker):
         ]
         logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
 
-        # Prepare the batch for the next draft forwards.
-        batch.forward_mode = ForwardMode.DECODE
-        batch.spec_info = res.draft_input
-
         if batch.return_logprob:
             self.add_logprob_values(batch, res, logits_output)
 
+        # Prepare the batch for the next draft forwards.
+        batch.forward_mode = (
+            ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE
+        )
+        batch.spec_info = res.draft_input
+
         return logits_output, res, model_worker_batch, can_run_cuda_graph
 
     def add_logprob_values(
@@ -623,8 +718,16 @@ class EAGLEWorker(TpModelWorker):
         logits_output = res.logits_output
         top_logprobs_nums = batch.top_logprobs_nums
         token_ids_logprobs = batch.token_ids_logprobs
+        accepted_indices = res.accepted_indices
+        assert len(accepted_indices) == len(logits_output.next_token_logits)
+        temperatures = batch.sampling_info.temperatures
+        num_draft_tokens = batch.spec_info.draft_token_num
+        # acceptance indices are the indices in a "flattened" batch.
+        # dividing it to num_draft_tokens will yield the actual batch index.
+        temperatures = temperatures[accepted_indices // num_draft_tokens]
+
         logprobs = torch.nn.functional.log_softmax(
-            logits_output.next_token_logits, dim=-1
+            logits_output.next_token_logits / temperatures, dim=-1
         )
         batch_next_token_ids = res.verified_id
         num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]
@@ -659,7 +762,7 @@ class EAGLEWorker(TpModelWorker):
         pt = 0
         next_token_logprobs = logits_output.next_token_logprobs.tolist()
         verified_ids = batch_next_token_ids.tolist()
-        for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
+        for req, num_tokens in zip(batch.reqs, num_tokens_per_req, strict=True):
             for _ in range(num_tokens):
                 if req.return_logprob:
                     req.output_token_logprobs_val.append(next_token_logprobs[pt])
@@ -691,11 +794,13 @@ class EAGLEWorker(TpModelWorker):
             hidden_states=hidden_states,
             verified_id=next_token_ids,
         )
+        batch.return_hidden_states = False
         batch.spec_info.prepare_for_extend(batch)
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=seq_lens_cpu
         )
+        model_worker_batch.spec_num_draft_tokens = 1
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -712,13 +817,33 @@ class EAGLEWorker(TpModelWorker):
         req_pool_indices_backup = batch.req_pool_indices
         accept_length_backup = batch.spec_info.accept_length
         return_logprob_backup = batch.return_logprob
-
-        # Prepare metadata
-        batch.spec_info.prepare_extend_after_decode(
-            batch,
-            self.speculative_num_steps,
-        )
+        input_is_idle = batch.forward_mode.is_idle()
+        if not input_is_idle:
+            # Prepare metadata
+            if batch.spec_info.verified_id is not None:
+                batch.spec_info.prepare_extend_after_decode(
+                    batch,
+                    self.speculative_num_steps,
+                )
+        else:
+            batch = batch.copy()
+            batch.prepare_for_idle()
+            hidden_size = (
+                self.model_config.hidden_size * 3
+                if self.speculative_algorithm.is_eagle3()
+                else self.model_config.hidden_size
+            )
+            batch.spec_info = EagleDraftInput.create_idle_input(
+                device=self.device,
+                hidden_size=hidden_size,
+                dtype=self.model_config.dtype,
+                topk=self.topk,
+                capture_hidden_mode=CaptureHiddenMode.LAST,
+            )
+        batch.return_hidden_states = False
         model_worker_batch = batch.get_model_worker_batch()
+        model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
+        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -742,7 +867,10 @@ class EAGLEWorker(TpModelWorker):
             )
             forward_batch.spec_info.hidden_states = logits_output.hidden_states
         else:
-            self.draft_model_runner.attn_backend.init_forward_metadata(forward_batch)
+            if not forward_batch.forward_mode.is_idle():
+                self.draft_model_runner.attn_backend.init_forward_metadata(
+                    forward_batch
+                )
             logits_output = self.draft_model_runner.model.forward(
                 forward_batch.input_ids, forward_batch.positions, forward_batch
             )
@@ -752,7 +880,9 @@ class EAGLEWorker(TpModelWorker):
 
         # Restore backup.
         # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
-        batch.forward_mode = ForwardMode.DECODE
+        batch.forward_mode = (
+            ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
+        )
         batch.seq_lens = seq_lens_backup
         batch.req_pool_indices = req_pool_indices_backup
         batch.spec_info.accept_length = accept_length_backup
@@ -781,4 +911,48 @@ def load_token_map(token_map_path: str) -> List[int]:
         )
         token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
     hot_token_id = torch.load(token_map_path, weights_only=True)
-    return torch.tensor(hot_token_id, dtype=torch.int32)
+    return torch.tensor(hot_token_id, dtype=torch.int64)
+
+
+@torch.compile(dynamic=True)
+def get_last_loc_large_page_size_top_k_1(
+    req_to_token: torch.Tensor,
+    req_pool_indices: torch.Tensor,
+    seq_lens,
+    speculative_num_steps: int,
+):
+    prefix_lens = seq_lens
+    seq_lens = prefix_lens + speculative_num_steps
+    last_loc = get_last_loc(
+        req_to_token,
+        req_pool_indices,
+        prefix_lens,
+    )
+    return prefix_lens, seq_lens, last_loc
+
+
+@torch.compile(dynamic=True)
+def get_last_loc_large_page_size_large_top_k(
+    req_to_token: torch.Tensor,
+    req_pool_indices: torch.Tensor,
+    seq_lens: torch.Tensor,
+    speculative_num_steps: int,
+    topk: int,
+    page_size: int,
+):
+    prefix_lens = seq_lens
+    last_page_lens = prefix_lens % page_size
+    num_new_pages_per_topk = (
+        last_page_lens + speculative_num_steps + page_size - 1
+    ) // page_size
+    seq_lens = prefix_lens // page_size * page_size + num_new_pages_per_topk * (
+        page_size * topk
+    )
+    extend_lens = seq_lens - prefix_lens
+    last_loc = get_last_loc(
+        req_to_token,
+        req_pool_indices,
+        prefix_lens,
+    )
+
+    return prefix_lens, seq_lens, last_loc, num_new_pages_per_topk, extend_lens
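
To make the new paging math concrete: with page_size > 1 and topk > 1, each request is rounded down to a page boundary and every top-k branch receives num_new_pages_per_topk fresh pages, enough for a copy of the partial last page plus the speculative draft tokens. Below is a minimal sketch of the same arithmetic on plain integers (illustration only, not the library code; the name draft_token_layout is invented for this example):

    # Illustration only: the page arithmetic of
    # get_last_loc_large_page_size_large_top_k, applied to one request.
    def draft_token_layout(prefix_len: int, steps: int, topk: int, page_size: int):
        last_page_len = prefix_len % page_size
        # Pages per top-k branch: a copy of the partial last page plus
        # room for `steps` speculative draft tokens, rounded up.
        num_new_pages_per_topk = (last_page_len + steps + page_size - 1) // page_size
        # Round the prefix down to a page boundary, then append
        # `topk` groups of `num_new_pages_per_topk` fresh pages.
        seq_len = (prefix_len // page_size) * page_size + num_new_pages_per_topk * page_size * topk
        extend_len = seq_len - prefix_len  # includes padded slots
        return num_new_pages_per_topk, seq_len, extend_len

    # prefix_len=10, steps=3, topk=2, page_size=4:
    # the last page holds 2 tokens -> ceil((2 + 3) / 4) = 2 new pages per branch;
    # seq_len = 8 + 2 * 4 * 2 = 24, extend_len = 14.
    print(draft_token_layout(10, 3, 2, 4))  # (2, 24, 14)

Because the extend length counts padded slots as well, the decode preprocessing above trims out_cache_loc back to num_seqs * topk * speculative_num_steps real entries when page_size > 1 and topk > 1.
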
sglang/srt/torch_memory_saver_adapter.py

@@ -1,11 +1,13 @@
 import logging
+import threading
+import time
 from abc import ABC
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 
 try:
     import torch_memory_saver
 
-    _primary_memory_saver = torch_memory_saver.TorchMemorySaver()
+    _memory_saver = torch_memory_saver.torch_memory_saver
     import_error = None
 except ImportError as e:
     import_error = e
@@ -38,13 +40,13 @@ class TorchMemorySaverAdapter(ABC):
     def configure_subprocess(self):
         raise NotImplementedError
 
-    def region(self):
+    def region(self, tag: str):
         raise NotImplementedError
 
-    def pause(self):
+    def pause(self, tag: str):
         raise NotImplementedError
 
-    def resume(self):
+    def resume(self, tag: str):
         raise NotImplementedError
 
     @property
@@ -53,21 +55,23 @@ class TorchMemorySaverAdapter(ABC):
 
 
 class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
+    """Adapter for TorchMemorySaver with tag-based control"""
+
     def configure_subprocess(self):
         return torch_memory_saver.configure_subprocess()
 
-    def region(self):
-        return _primary_memory_saver.region()
+    def region(self, tag: str):
+        return _memory_saver.region(tag=tag)
 
-    def pause(self):
-        return _primary_memory_saver.pause()
+    def pause(self, tag: str):
+        return _memory_saver.pause(tag=tag)
 
-    def resume(self):
-        return _primary_memory_saver.resume()
+    def resume(self, tag: str):
+        return _memory_saver.resume(tag=tag)
 
     @property
     def enabled(self):
-        return _primary_memory_saver.enabled
+        return _memory_saver is not None and _memory_saver.enabled
 
 
 class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
@@ -76,13 +80,13 @@ class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
         yield
 
     @contextmanager
-    def region(self):
+    def region(self, tag: str):
         yield
 
-    def pause(self):
+    def pause(self, tag: str):
         pass
 
-    def resume(self):
+    def resume(self, tag: str):
         pass
 
     @property
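
The torch_memory_saver_adapter.py change replaces the private TorchMemorySaver() instance with the library's shared tag-aware saver, so individual memory pools can be paused and resumed independently. A minimal usage sketch, assuming the adapter's existing create(enable=...) factory; the tag string and allocate_kv_cache() helper are hypothetical:

    # Usage sketch (assumptions: the create(...) factory behaves as in
    # prior releases; "kv_cache" and allocate_kv_cache() are made up).
    from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

    saver = TorchMemorySaverAdapter.create(enable=True)

    # Tensors allocated inside the region are tracked under the given tag.
    with saver.region(tag="kv_cache"):
        kv_buffers = allocate_kv_cache()  # hypothetical allocation helper

    saver.pause(tag="kv_cache")   # release only this tag's GPU memory
    saver.resume(tag="kv_cache")  # re-materialize it before serving resumes

With tags, releasing the KV cache (for example, during a weight update) no longer has to release tensors captured under a different region.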