sglang-0.4.4.post2-py3-none-any.whl → sglang-0.4.4.post4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108):
  1. sglang/bench_serving.py +72 -10
  2. sglang/srt/_custom_ops.py +59 -92
  3. sglang/srt/configs/deepseekvl2.py +10 -1
  4. sglang/srt/configs/model_config.py +6 -16
  5. sglang/srt/constrained/base_grammar_backend.py +5 -1
  6. sglang/srt/custom_op.py +5 -0
  7. sglang/srt/distributed/device_communicators/custom_all_reduce.py +28 -80
  8. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  9. sglang/srt/distributed/parallel_state.py +32 -5
  10. sglang/srt/entrypoints/engine.py +0 -5
  11. sglang/srt/entrypoints/http_server.py +7 -1
  12. sglang/srt/entrypoints/verl_engine.py +2 -0
  13. sglang/srt/function_call_parser.py +0 -1
  14. sglang/srt/layers/attention/flashattention_backend.py +582 -125
  15. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  17. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  18. sglang/srt/layers/dp_attention.py +12 -1
  19. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  20. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  21. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
  26. sglang/srt/layers/moe/topk.py +79 -6
  27. sglang/srt/layers/quantization/__init__.py +137 -165
  28. sglang/srt/layers/quantization/awq.py +200 -0
  29. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  30. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  31. sglang/srt/layers/quantization/fp8_kernel.py +2 -1
  32. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  33. sglang/srt/layers/quantization/gptq.py +30 -40
  34. sglang/srt/layers/quantization/moe_wna16.py +501 -0
  35. sglang/srt/layers/quantization/utils.py +1 -1
  36. sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
  37. sglang/srt/lora/backend/base_backend.py +4 -4
  38. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  39. sglang/srt/lora/backend/triton_backend.py +5 -8
  40. sglang/srt/lora/layers.py +19 -33
  41. sglang/srt/lora/lora_manager.py +20 -7
  42. sglang/srt/lora/mem_pool.py +12 -6
  43. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  44. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  45. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  46. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  47. sglang/srt/lora/utils.py +6 -0
  48. sglang/srt/managers/cache_controller.py +34 -11
  49. sglang/srt/managers/io_struct.py +4 -2
  50. sglang/srt/managers/mm_utils.py +202 -156
  51. sglang/srt/managers/multimodal_processor.py +0 -2
  52. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  53. sglang/srt/managers/multimodal_processors/clip.py +44 -0
  54. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  55. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  56. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  57. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  58. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  59. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  60. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  61. sglang/srt/managers/schedule_batch.py +185 -127
  62. sglang/srt/managers/scheduler.py +29 -23
  63. sglang/srt/managers/tokenizer_manager.py +1 -2
  64. sglang/srt/managers/tp_worker.py +3 -0
  65. sglang/srt/managers/utils.py +1 -6
  66. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  67. sglang/srt/mem_cache/memory_pool.py +72 -6
  68. sglang/srt/mem_cache/paged_allocator.py +39 -0
  69. sglang/srt/metrics/collector.py +23 -53
  70. sglang/srt/model_executor/cuda_graph_runner.py +16 -13
  71. sglang/srt/model_executor/forward_batch_info.py +10 -10
  72. sglang/srt/model_executor/model_runner.py +64 -59
  73. sglang/srt/model_loader/loader.py +19 -1
  74. sglang/srt/model_loader/weight_utils.py +6 -3
  75. sglang/srt/models/clip.py +568 -0
  76. sglang/srt/models/deepseek_janus_pro.py +12 -17
  77. sglang/srt/models/deepseek_v2.py +339 -123
  78. sglang/srt/models/deepseek_vl2.py +105 -104
  79. sglang/srt/models/gemma3_causal.py +12 -2
  80. sglang/srt/models/gemma3_mm.py +20 -80
  81. sglang/srt/models/llama.py +4 -1
  82. sglang/srt/models/llava.py +31 -19
  83. sglang/srt/models/llavavid.py +16 -7
  84. sglang/srt/models/minicpmo.py +63 -147
  85. sglang/srt/models/minicpmv.py +17 -27
  86. sglang/srt/models/mllama.py +29 -14
  87. sglang/srt/models/qwen2.py +9 -6
  88. sglang/srt/models/qwen2_5_vl.py +21 -31
  89. sglang/srt/models/qwen2_vl.py +20 -21
  90. sglang/srt/openai_api/adapter.py +106 -93
  91. sglang/srt/openai_api/protocol.py +10 -5
  92. sglang/srt/patch_torch.py +71 -0
  93. sglang/srt/platforms/interface.py +371 -0
  94. sglang/srt/server_args.py +120 -25
  95. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  96. sglang/srt/speculative/eagle_utils.py +140 -28
  97. sglang/srt/speculative/eagle_worker.py +94 -25
  98. sglang/srt/utils.py +137 -51
  99. sglang/test/runners.py +27 -2
  100. sglang/test/test_custom_ops.py +55 -0
  101. sglang/test/test_utils.py +14 -27
  102. sglang/utils.py +2 -2
  103. sglang/version.py +1 -1
  104. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +10 -5
  105. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +108 -99
  106. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
  107. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
  108. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import os
3
4
  from dataclasses import dataclass
4
5
  from typing import TYPE_CHECKING, List, Optional
5
6
 
@@ -10,11 +11,15 @@ import triton.language as tl
10
11
 
11
12
  from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
12
13
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
13
- from sglang.srt.managers.schedule_batch import global_server_args_dict
14
+ from sglang.srt.managers.schedule_batch import (
15
+ ScheduleBatch,
16
+ get_last_loc,
17
+ global_server_args_dict,
18
+ )
14
19
  from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
15
20
  from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
16
21
  from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
17
- from sglang.srt.utils import is_cuda_available, is_hip
22
+ from sglang.srt.utils import is_cuda_available, is_hip, next_power_of_2
18
23
 
19
24
  if is_cuda_available():
20
25
  from sgl_kernel import (
@@ -34,6 +39,9 @@ import logging
34
39
  logger = logging.getLogger(__name__)
35
40
 
36
41
 
42
+ SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")
43
+
44
+
37
45
  @dataclass
38
46
  class EagleDraftInput:
39
47
  # The inputs for decode
@@ -93,7 +101,7 @@ class EagleDraftInput:
93
101
  torch.cumsum(self.accept_length, axis=0, dtype=torch.int),
94
102
  self.positions,
95
103
  new_verified_id,
96
- triton.next_power_of_2(speculative_num_steps + 1),
104
+ next_power_of_2(speculative_num_steps + 1),
97
105
  )
98
106
 
99
107
  batch.seq_lens_sum = sum(seq_lens_cpu)
@@ -225,18 +233,34 @@ class EagleVerifyInput:
225
233
  CaptureHiddenMode.FULL,
226
234
  )
227
235
 
228
- def prepare_for_verify(self, batch: ScheduleBatch):
236
+ def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
229
237
  batch.input_ids = self.draft_token
230
- batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
238
+
239
+ if page_size == 1:
240
+ batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
241
+ end_offset = batch.seq_lens + self.draft_token_num
242
+ else:
243
+ prefix_lens = batch.seq_lens
244
+ end_offset = prefix_lens + self.draft_token_num
245
+ last_loc = get_last_loc(
246
+ batch.req_to_token_pool.req_to_token,
247
+ batch.req_pool_indices,
248
+ prefix_lens,
249
+ )
250
+ batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
251
+ prefix_lens, end_offset, last_loc, len(batch.input_ids)
252
+ )
253
+ self.last_loc = last_loc
254
+
231
255
  bs = batch.batch_size()
232
256
  assign_req_to_token_pool[(bs,)](
233
257
  batch.req_pool_indices,
234
258
  batch.req_to_token_pool.req_to_token,
235
259
  batch.seq_lens,
236
- batch.seq_lens + self.draft_token_num,
260
+ end_offset,
237
261
  batch.out_cache_loc,
238
262
  batch.req_to_token_pool.req_to_token.shape[1],
239
- triton.next_power_of_2(bs),
263
+ next_power_of_2(bs),
240
264
  )
241
265
 
242
266
  def generate_attn_arg_prefill(
@@ -282,6 +306,7 @@ class EagleVerifyInput:
282
306
  batch: ScheduleBatch,
283
307
  logits_output: torch.Tensor,
284
308
  token_to_kv_pool_allocator: TokenToKVPoolAllocator,
309
+ page_size: int,
285
310
  ) -> torch.Tensor:
286
311
  """
287
312
  Verify and find accepted tokens based on logits output and batch
@@ -305,6 +330,7 @@ class EagleVerifyInput:
305
330
  )
306
331
  accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
307
332
 
333
+ # Apply penalty
308
334
  if sampling_info.penalizer_orchestrator.is_required:
309
335
  # This is a relaxed version of penalties for speculative decoding.
310
336
  linear_penalty = torch.zeros(
@@ -317,6 +343,7 @@ class EagleVerifyInput:
317
343
  torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
318
344
  )
319
345
 
346
+ # Sample tokens
320
347
  if batch.sampling_info.is_all_greedy:
321
348
  target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
322
349
  target_predict = target_predict.reshape(bs, self.draft_token_num)
@@ -378,13 +405,24 @@ class EagleVerifyInput:
378
405
  deterministic=True,
379
406
  )
380
407
 
408
+ if SIMULATE_ACC_LEN:
409
+ # Do simulation
410
+ accept_index = _generate_simulated_accept_index(
411
+ accept_index=accept_index,
412
+ predict=predict, # mutable
413
+ accept_length=accept_length, # mutable
414
+ simulate_acc_len=SIMULATE_ACC_LEN,
415
+ bs=bs,
416
+ spec_steps=self.spec_steps,
417
+ )
418
+
381
419
  new_accept_index = []
382
420
  unfinished_index = []
383
421
  accept_index_cpu = accept_index.tolist()
384
422
  predict_cpu = predict.tolist()
385
423
  has_finished = False
386
424
 
387
- # iterate every accepted token and check if req has finished after append the token
425
+ # Iterate every accepted token and check if req has finished after append the token
388
426
  # should be checked BEFORE free kv cache slots
389
427
  for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
390
428
  new_accept_index_ = []
@@ -407,13 +445,28 @@ class EagleVerifyInput:
407
445
  unfinished_index.append(i)
408
446
  req.spec_verify_ct += 1
409
447
 
448
+ if has_finished:
449
+ accept_length = (accept_index != -1).sum(dim=1) - 1
450
+
451
+ # Free the KV cache for unaccepted tokens
452
+ accept_index = accept_index[accept_index != -1]
453
+ verified_id = predict[accept_index]
454
+ evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
455
+ evict_mask[accept_index] = False
456
+
457
+ if page_size != 1:
458
+ align_evict_mask_to_page_size[len(batch.seq_lens),](
459
+ batch.seq_lens,
460
+ evict_mask,
461
+ page_size,
462
+ self.draft_token_num,
463
+ next_power_of_2(self.draft_token_num),
464
+ )
465
+
466
+ token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
467
+
468
+ # Construct EagleVerifyOutput
410
469
  if not has_finished:
411
- accept_index = accept_index[accept_index != -1]
412
- verified_id = predict[accept_index]
413
- evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
414
- evict_mask[accept_index] = False
415
- mem_need_free_idx = batch.out_cache_loc[evict_mask]
416
- token_to_kv_pool_allocator.free(mem_need_free_idx)
417
470
  batch.out_cache_loc = batch.out_cache_loc[accept_index]
418
471
  assign_req_to_token_pool[(bs,)](
419
472
  batch.req_pool_indices,
@@ -422,7 +475,7 @@ class EagleVerifyInput:
422
475
  batch.seq_lens + accept_length + 1,
423
476
  batch.out_cache_loc,
424
477
  batch.req_to_token_pool.req_to_token.shape[1],
425
- triton.next_power_of_2(bs),
478
+ next_power_of_2(bs),
426
479
  )
427
480
  batch.seq_lens.add_(accept_length + 1)
428
481
  accept_length_cpu = accept_length.tolist()
@@ -443,13 +496,6 @@ class EagleVerifyInput:
443
496
  accepeted_indices=accept_index,
444
497
  )
445
498
  else:
446
- accept_length = (accept_index != -1).sum(dim=1) - 1
447
- accept_index = accept_index[accept_index != -1]
448
- verified_id = predict[accept_index]
449
- evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
450
- evict_mask[accept_index] = False
451
- mem_need_free_idx = batch.out_cache_loc[evict_mask]
452
- token_to_kv_pool_allocator.free(mem_need_free_idx)
453
499
  assign_req_to_token_pool[(bs,)](
454
500
  batch.req_pool_indices,
455
501
  batch.req_to_token_pool.req_to_token,
@@ -457,7 +503,7 @@ class EagleVerifyInput:
457
503
  batch.seq_lens + accept_length + 1,
458
504
  batch.out_cache_loc[accept_index],
459
505
  batch.req_to_token_pool.req_to_token.shape[1],
460
- triton.next_power_of_2(bs),
506
+ next_power_of_2(bs),
461
507
  )
462
508
  batch.seq_lens.add_(accept_length + 1)
463
509
  accept_length_cpu = accept_length.tolist()
@@ -465,20 +511,21 @@ class EagleVerifyInput:
465
511
  draft_input = EagleDraftInput()
466
512
  if len(new_accept_index) > 0:
467
513
  new_accept_index = torch.tensor(new_accept_index, device="cuda")
514
+ unfinished_index_device = torch.tensor(unfinished_index, device="cuda")
468
515
  draft_input.hidden_states = batch.spec_info.hidden_states[
469
516
  new_accept_index
470
517
  ]
471
518
  draft_input.verified_id = predict[new_accept_index]
472
- draft_input.accept_length = accept_length[unfinished_index]
473
519
  draft_input.accept_length_cpu = [
474
520
  accept_length_cpu[i] for i in unfinished_index
475
521
  ]
522
+ draft_input.accept_length = accept_length[unfinished_index_device]
476
523
  if has_finished:
477
524
  draft_input.seq_lens_for_draft_extend = batch.seq_lens[
478
- unfinished_index
525
+ unfinished_index_device
479
526
  ]
480
527
  draft_input.req_pool_indices_for_draft_extend = (
481
- batch.req_pool_indices[unfinished_index]
528
+ batch.req_pool_indices[unfinished_index_device]
482
529
  )
483
530
  else:
484
531
  draft_input.seq_lens_for_draft_extend = batch.seq_lens
@@ -564,13 +611,24 @@ def assign_draft_cache_locs(
564
611
  pool_len: tl.constexpr,
565
612
  topk: tl.constexpr,
566
613
  speculative_num_steps: tl.constexpr,
614
+ page_size: tl.constexpr,
567
615
  ):
568
616
  BLOCK_SIZE: tl.constexpr = 32
569
617
  pid = tl.program_id(axis=0)
570
618
  kv_start = tl.load(seq_lens + pid)
571
- kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
619
+
620
+ if page_size == 1 or topk == 1:
621
+ kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
622
+ out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
623
+ else:
624
+ prefix_len = tl.load(seq_lens + pid)
625
+ last_page_len = prefix_len % page_size
626
+ num_new_page = (
627
+ last_page_len + speculative_num_steps + page_size - 1
628
+ ) // page_size
629
+ kv_end = prefix_len // page_size * page_size + num_new_page * (page_size * topk)
630
+
572
631
  token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
573
- out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
574
632
 
575
633
  num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE)
576
634
  for i in range(num_loop):
@@ -642,6 +700,29 @@ def generate_draft_decode_kv_indices(
642
700
  tl.store(kv_indptr + zid, base + zid * iters)
643
701
 
644
702
 
703
@triton.jit
def align_evict_mask_to_page_size(
    seq_lens,
    evict_mask,
    page_size: tl.constexpr,
    num_draft_tokens: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Un-evict the draft slots that share a page with the last kept token.

    One program handles one request: ``program_id(0)`` indexes ``seq_lens``
    and selects row ``pid`` of ``evict_mask``, which is addressed as a flat
    row-major buffer with ``num_draft_tokens`` entries per request. A ``True``
    entry marks a draft slot whose KV cache the caller intends to free.

    The number of ``False`` (kept) entries in the row is used as the number of
    accepted tokens. The kernel finds the page containing the last kept
    position ``seq_len + num_kept - 1`` and clears the mask for every draft
    slot that falls inside that page. With a paged allocator this keeps the
    caller from freeing slots of a page that is still partially in use.

    NOTE(review): this assumes draft slot ``j`` of a request holds KV position
    ``seq_len + j`` (contiguous per-request allocation) -- confirm against the
    allocator used in ``prepare_for_verify``.

    Args:
        seq_lens: Pointer to per-request sequence lengths (before this
            verification step's tokens are appended).
        evict_mask: Pointer to the boolean eviction mask, modified in place.
        page_size: KV-cache page size in tokens.
        num_draft_tokens: Draft slots per request (row length of the mask).
        BLOCK_SIZE: Power of two >= ``num_draft_tokens`` (``tl.arange``
            requires a power-of-two extent).
    """
    offsets = tl.arange(0, BLOCK_SIZE)

    pid = tl.program_id(axis=0)
    row_ptr = evict_mask + pid * num_draft_tokens
    seq_len = tl.load(seq_lens + pid)

    in_row = offsets < num_draft_tokens
    # Padding lanes (num_draft_tokens <= lane < BLOCK_SIZE) must read as False.
    # Without `other`, masked-out lanes are undefined and would be counted by
    # the sum below whenever num_draft_tokens is not a power of two.
    mask_row = tl.load(row_ptr + offsets, mask=in_row, other=False)

    # Sum as int32 so the count is a plain signed integer regardless of how the
    # installed Triton version reduces boolean tensors.
    num_evicted = tl.sum(mask_row.to(tl.int32))
    num_kept = num_draft_tokens - num_evicted

    # Offset, relative to this request's first draft slot, of the first position
    # in the page that holds the last kept token. It is negative when that page
    # begins inside the already-committed prefix; the clamp below handles that.
    start = (seq_len + num_kept - 1) // page_size * page_size - seq_len
    for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
        tl.store(row_ptr + i, False)
724
+
725
+
645
726
  @torch.compile(dynamic=True)
646
727
  def select_top_k_tokens(
647
728
  i: int,
@@ -699,3 +780,34 @@ def fast_topk(values, topk, dim):
699
780
  else:
700
781
  # Use topk for efficiency with larger k values
701
782
  return torch.topk(values, topk, dim=dim)
783
+
784
+
785
+ def _generate_simulated_accept_index(
786
+ accept_index,
787
+ predict,
788
+ accept_length,
789
+ simulate_acc_len,
790
+ bs,
791
+ spec_steps,
792
+ ):
793
+ simulate_acc_len_float = float(simulate_acc_len)
794
+ simulated_values = torch.normal(
795
+ mean=simulate_acc_len_float,
796
+ std=1.0,
797
+ size=(1,),
798
+ device="cpu",
799
+ )
800
+ # clamp simulated values to be between 1 and self.spec_steps
801
+ simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps)
802
+ simulate_acc_len = int(simulated_values.round().item())
803
+
804
+ accept_indx_first_col = accept_index[:, 0].view(-1, 1)
805
+ sim_accept_index = torch.full(
806
+ (bs, spec_steps + 1), -1, dtype=torch.int32, device="cuda"
807
+ )
808
+ sim_accept_index[:, :simulate_acc_len] = accept_indx_first_col + torch.arange(
809
+ simulate_acc_len, device=accept_index.device
810
+ )
811
+ accept_length.fill_(simulate_acc_len - 1)
812
+ predict.fill_(100) # some legit token id
813
+ return sim_accept_index
@@ -11,7 +11,11 @@ from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
11
11
  from sglang.srt.layers.dp_attention import disable_dp_size
12
12
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
13
13
  from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
14
- from sglang.srt.managers.schedule_batch import ScheduleBatch
14
+ from sglang.srt.managers.schedule_batch import (
15
+ ScheduleBatch,
16
+ get_last_loc,
17
+ global_server_args_dict,
18
+ )
15
19
  from sglang.srt.managers.tp_worker import TpModelWorker
16
20
  from sglang.srt.model_executor.forward_batch_info import (
17
21
  CaptureHiddenMode,
@@ -67,6 +71,7 @@ class EAGLEWorker(TpModelWorker):
67
71
  self.gpu_id = gpu_id
68
72
  self.device = server_args.device
69
73
  self.target_worker = target_worker
74
+ self.page_size = server_args.page_size
70
75
  self.speculative_algorithm = SpeculativeAlgorithm.from_string(
71
76
  server_args.speculative_algorithm
72
77
  )
@@ -145,15 +150,26 @@ class EAGLEWorker(TpModelWorker):
145
150
  def init_attention_backend(self):
146
151
  # Create multi-step attn backends and cuda graph runners
147
152
  if self.server_args.attention_backend == "flashinfer":
148
- from sglang.srt.layers.attention.flashinfer_backend import (
149
- FlashInferMultiStepDraftBackend,
150
- )
153
+ if not global_server_args_dict["use_mla_backend"]:
154
+ from sglang.srt.layers.attention.flashinfer_backend import (
155
+ FlashInferMultiStepDraftBackend,
156
+ )
151
157
 
152
- self.draft_attn_backend = FlashInferMultiStepDraftBackend(
153
- self.draft_model_runner,
154
- self.topk,
155
- self.speculative_num_steps,
156
- )
158
+ self.draft_attn_backend = FlashInferMultiStepDraftBackend(
159
+ self.draft_model_runner,
160
+ self.topk,
161
+ self.speculative_num_steps,
162
+ )
163
+ else:
164
+ from sglang.srt.layers.attention.flashinfer_mla_backend import (
165
+ FlashInferMLAMultiStepDraftBackend,
166
+ )
167
+
168
+ self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
169
+ self.draft_model_runner,
170
+ self.topk,
171
+ self.speculative_num_steps,
172
+ )
157
173
  self.draft_extend_attn_backend = None
158
174
  self.padded_static_len = self.speculative_num_steps + 1
159
175
  self.has_prefill_wrapper_verify = True
@@ -170,19 +186,19 @@ class EAGLEWorker(TpModelWorker):
170
186
  self.draft_extend_attn_backend = None
171
187
  self.padded_static_len = self.speculative_num_steps + 1
172
188
  self.has_prefill_wrapper_verify = False
173
- elif self.server_args.attention_backend == "flashinfer_mla":
174
- from sglang.srt.layers.attention.flashinfer_mla_backend import (
175
- FlashInferMLAMultiStepDraftBackend,
189
+ elif self.server_args.attention_backend == "fa3":
190
+ from sglang.srt.layers.attention.flashattention_backend import (
191
+ FlashAttentionMultiStepBackend,
176
192
  )
177
193
 
178
- self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
194
+ self.draft_attn_backend = FlashAttentionMultiStepBackend(
179
195
  self.draft_model_runner,
180
196
  self.topk,
181
197
  self.speculative_num_steps,
182
198
  )
183
199
  self.draft_extend_attn_backend = None
184
200
  self.padded_static_len = self.speculative_num_steps + 1
185
- self.has_prefill_wrapper_verify = True
201
+ self.has_prefill_wrapper_verify = False
186
202
  else:
187
203
  raise ValueError(
188
204
  f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
@@ -234,14 +250,11 @@ class EAGLEWorker(TpModelWorker):
234
250
  """
235
251
  if batch.forward_mode.is_decode():
236
252
  with self.draft_tp_context(self.draft_model_runner.tp_group):
237
- spec_info, to_free_cache_loc = self.draft(batch)
253
+ spec_info = self.draft(batch)
238
254
  logits_output, verify_output, model_worker_batch = self.verify(
239
255
  batch, spec_info
240
256
  )
241
257
 
242
- # Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
243
- self.token_to_kv_pool_allocator.free(to_free_cache_loc)
244
-
245
258
  # If it is None, it means all requests are finished
246
259
  if batch.spec_info.verified_id is not None:
247
260
  with self.draft_tp_context(self.draft_model_runner.tp_group):
@@ -305,9 +318,59 @@ class EAGLEWorker(TpModelWorker):
305
318
  )
306
319
 
307
320
  # Allocate cache locations
308
- out_cache_loc = batch.alloc_token_slots(
309
- num_seqs * self.topk * self.speculative_num_steps
310
- )
321
+ if self.page_size == 1:
322
+ out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
323
+ num_seqs * self.topk * self.speculative_num_steps, backup_state=True
324
+ )
325
+ else:
326
+ if self.topk == 1:
327
+ prefix_lens = batch.seq_lens
328
+ seq_lens = prefix_lens + self.speculative_num_steps
329
+ extend_num_tokens = num_seqs * self.speculative_num_steps
330
+ else:
331
+ # In this case, the last partial page needs to be duplicated.
332
+ # KV cache layout in batch.req_to_token_pool.req_to_token:
333
+ #
334
+ # | -------- | -- xxxx .. | -- xxxx .. | -- xxxx .. |
335
+ # prefix top-k = 0 tok-k = 1 top-k = 2
336
+ #
337
+ # "-" means prefix tokens
338
+ # "x" means speculative draft tokens
339
+ # "." means padded tokens
340
+
341
+ # TODO: fuse these ops
342
+ prefix_lens = batch.seq_lens
343
+ last_page_lens = prefix_lens % self.page_size
344
+ num_new_pages = (
345
+ last_page_lens + self.speculative_num_steps + self.page_size - 1
346
+ ) // self.page_size
347
+ seq_lens = (
348
+ prefix_lens // self.page_size * self.page_size
349
+ + num_new_pages * (self.page_size * self.topk)
350
+ )
351
+ extend_num_tokens = torch.sum(seq_lens - prefix_lens).item()
352
+ raise NotImplementedError(
353
+ "page_size > 1 and top_k > 1 are not supported."
354
+ )
355
+ # TODO: Support page_size > 1 and top_k > 1
356
+ # 1. Duplicate the KV cache in the last partial page for all top-k segments
357
+ # 2. Modify generate_draft_decode_kv_indices accordingly
358
+
359
+ last_loc = get_last_loc(
360
+ batch.req_to_token_pool.req_to_token,
361
+ batch.req_pool_indices,
362
+ prefix_lens,
363
+ )
364
+ out_cache_loc, token_to_kv_pool_state_backup = (
365
+ batch.alloc_paged_token_slots_extend(
366
+ prefix_lens,
367
+ seq_lens,
368
+ last_loc,
369
+ extend_num_tokens,
370
+ backup_state=True,
371
+ )
372
+ )
373
+
311
374
  assign_draft_cache_locs[(num_seqs,)](
312
375
  batch.req_pool_indices,
313
376
  batch.req_to_token_pool.req_to_token,
@@ -316,6 +379,7 @@ class EAGLEWorker(TpModelWorker):
316
379
  batch.req_to_token_pool.req_to_token.shape[1],
317
380
  self.topk,
318
381
  self.speculative_num_steps,
382
+ self.page_size,
319
383
  )
320
384
  batch.out_cache_loc = out_cache_loc
321
385
  batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
@@ -343,6 +407,8 @@ class EAGLEWorker(TpModelWorker):
343
407
  # Run forward steps
344
408
  score_list, token_list, parents_list = self.draft_forward(forward_batch)
345
409
 
410
+ self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
411
+
346
412
  ret = EagleVerifyInput.create(
347
413
  spec_info.verified_id,
348
414
  score_list,
@@ -354,7 +420,7 @@ class EAGLEWorker(TpModelWorker):
354
420
  self.speculative_num_steps,
355
421
  self.server_args.speculative_num_draft_tokens,
356
422
  )
357
- return ret, out_cache_loc
423
+ return ret
358
424
 
359
425
  def draft_forward(self, forward_batch: ForwardBatch):
360
426
  # Parse args
@@ -411,7 +477,7 @@ class EAGLEWorker(TpModelWorker):
411
477
  return score_list, token_list, parents_list
412
478
 
413
479
  def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
414
- spec_info.prepare_for_verify(batch)
480
+ spec_info.prepare_for_verify(batch, self.page_size)
415
481
  batch.forward_mode = ForwardMode.TARGET_VERIFY
416
482
  batch.spec_info = spec_info
417
483
  model_worker_batch = batch.get_model_worker_batch()
@@ -421,7 +487,10 @@ class EAGLEWorker(TpModelWorker):
421
487
  self._detect_nan_if_needed(logits_output)
422
488
  spec_info.hidden_states = logits_output.hidden_states
423
489
  res: EagleVerifyOutput = spec_info.verify(
424
- batch, logits_output, self.token_to_kv_pool_allocator
490
+ batch,
491
+ logits_output,
492
+ self.token_to_kv_pool_allocator,
493
+ self.page_size,
425
494
  )
426
495
 
427
496
  # Post process based on verified outputs.
@@ -586,5 +655,5 @@ def load_token_map(token_map_path: str) -> List[int]:
586
655
  ignore_patterns=["*.bin", "*.safetensors"],
587
656
  )
588
657
  token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
589
- hot_token_id = torch.load(token_map_path)
658
+ hot_token_id = torch.load(token_map_path, weights_only=True)
590
659
  return torch.tensor(hot_token_id, dtype=torch.int32)