sglang 0.4.7__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in the supported public registries. It is provided for informational purposes only.
Files changed (99)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_serving.py +1 -1
  4. sglang/lang/interpreter.py +40 -1
  5. sglang/lang/ir.py +27 -0
  6. sglang/math_utils.py +8 -0
  7. sglang/srt/configs/model_config.py +6 -0
  8. sglang/srt/conversation.py +6 -0
  9. sglang/srt/disaggregation/base/__init__.py +1 -1
  10. sglang/srt/disaggregation/base/conn.py +25 -11
  11. sglang/srt/disaggregation/common/__init__.py +5 -1
  12. sglang/srt/disaggregation/common/utils.py +42 -0
  13. sglang/srt/disaggregation/decode.py +196 -51
  14. sglang/srt/disaggregation/fake/__init__.py +1 -1
  15. sglang/srt/disaggregation/fake/conn.py +15 -9
  16. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  17. sglang/srt/disaggregation/mooncake/conn.py +18 -13
  18. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  19. sglang/srt/disaggregation/nixl/conn.py +17 -12
  20. sglang/srt/disaggregation/prefill.py +128 -43
  21. sglang/srt/disaggregation/utils.py +127 -123
  22. sglang/srt/entrypoints/engine.py +15 -1
  23. sglang/srt/entrypoints/http_server.py +13 -2
  24. sglang/srt/eplb_simulator/__init__.py +1 -0
  25. sglang/srt/eplb_simulator/reader.py +51 -0
  26. sglang/srt/layers/activation.py +19 -0
  27. sglang/srt/layers/attention/aiter_backend.py +15 -2
  28. sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
  29. sglang/srt/layers/attention/flashattention_backend.py +53 -64
  30. sglang/srt/layers/attention/flashinfer_backend.py +1 -2
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
  32. sglang/srt/layers/attention/flashmla_backend.py +2 -10
  33. sglang/srt/layers/attention/triton_backend.py +119 -119
  34. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  35. sglang/srt/layers/attention/vision.py +51 -24
  36. sglang/srt/layers/communicator.py +23 -5
  37. sglang/srt/layers/linear.py +0 -4
  38. sglang/srt/layers/logits_processor.py +0 -12
  39. sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
  40. sglang/srt/layers/moe/ep_moe/layer.py +42 -32
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  42. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
  43. sglang/srt/layers/moe/topk.py +16 -8
  44. sglang/srt/layers/pooler.py +56 -0
  45. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  46. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  47. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  48. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  49. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  50. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  51. sglang/srt/layers/radix_attention.py +2 -3
  52. sglang/srt/lora/lora_manager.py +79 -34
  53. sglang/srt/lora/mem_pool.py +4 -5
  54. sglang/srt/managers/cache_controller.py +2 -1
  55. sglang/srt/managers/io_struct.py +28 -4
  56. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  57. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  58. sglang/srt/managers/schedule_batch.py +39 -6
  59. sglang/srt/managers/scheduler.py +73 -17
  60. sglang/srt/managers/tokenizer_manager.py +29 -2
  61. sglang/srt/mem_cache/chunk_cache.py +1 -0
  62. sglang/srt/mem_cache/hiradix_cache.py +4 -2
  63. sglang/srt/mem_cache/memory_pool.py +111 -407
  64. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  65. sglang/srt/mem_cache/radix_cache.py +36 -12
  66. sglang/srt/model_executor/cuda_graph_runner.py +122 -55
  67. sglang/srt/model_executor/forward_batch_info.py +14 -5
  68. sglang/srt/model_executor/model_runner.py +6 -6
  69. sglang/srt/model_loader/loader.py +8 -1
  70. sglang/srt/models/bert.py +113 -13
  71. sglang/srt/models/deepseek_v2.py +113 -155
  72. sglang/srt/models/internvl.py +46 -102
  73. sglang/srt/models/roberta.py +117 -9
  74. sglang/srt/models/vila.py +305 -0
  75. sglang/srt/openai_api/adapter.py +162 -4
  76. sglang/srt/openai_api/protocol.py +37 -1
  77. sglang/srt/sampling/sampling_batch_info.py +24 -0
  78. sglang/srt/sampling/sampling_params.py +2 -0
  79. sglang/srt/server_args.py +318 -233
  80. sglang/srt/speculative/build_eagle_tree.py +1 -1
  81. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
  82. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
  83. sglang/srt/speculative/eagle_utils.py +389 -109
  84. sglang/srt/speculative/eagle_worker.py +134 -43
  85. sglang/srt/two_batch_overlap.py +4 -2
  86. sglang/srt/utils.py +58 -0
  87. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  88. sglang/test/runners.py +38 -3
  89. sglang/test/test_block_fp8.py +1 -0
  90. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  91. sglang/test/test_block_fp8_ep.py +1 -0
  92. sglang/test/test_utils.py +3 -1
  93. sglang/utils.py +9 -0
  94. sglang/version.py +1 -1
  95. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
  96. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
  97. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_utils.py

@@ -23,7 +23,7 @@ from sglang.srt.managers.schedule_batch import (
  )
  from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
  from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
- from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2
+ from sglang.srt.utils import is_cuda, is_hip, next_power_of_2

  if is_cuda():
  from sgl_kernel import (
@@ -32,6 +32,7 @@ if is_cuda():
  tree_speculative_sampling_target_only,
  verify_tree_greedy,
  )
+ from sgl_kernel.top_k import fast_topk
  elif is_hip():
  from sgl_kernel import verify_tree_greedy

@@ -67,8 +68,6 @@ class EagleDraftInput:
  kv_indptr: torch.Tensor = None
  kv_indices: torch.Tensor = None

- all_padding_lens: Optional[torch.Tensor] = None
-
  def prepare_for_extend(self, batch: ScheduleBatch):
  # Prefill only generate 1 token.
  assert len(self.verified_id) == len(batch.seq_lens)
@@ -93,6 +92,7 @@ class EagleDraftInput:
  batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
  batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
  batch.return_logprob = False
+ batch.return_hidden_states = False

  self.capture_hidden_mode = CaptureHiddenMode.LAST
  self.accept_length.add_(1)
@@ -116,13 +116,14 @@ class EagleDraftInput:
  req_to_token: torch.Tensor,
  ):
  bs = self.accept_length.numel()
-
  qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
  qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
-
  cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
  cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)

+ if paged_kernel_lens_sum is None:
+ paged_kernel_lens_sum = cum_kv_seq_len[-1]
+
  kv_indices = torch.empty(
  paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
  )
@@ -136,7 +137,6 @@ class EagleDraftInput:
  kv_indices,
  req_to_token.size(1),
  )
-
  return kv_indices, cum_kv_seq_len, qo_indptr, None

  def filter_batch(self, new_indices: torch.Tensor):
@@ -267,7 +267,7 @@ class EagleVerifyInput:
  logits_output: torch.Tensor,
  token_to_kv_pool_allocator: TokenToKVPoolAllocator,
  page_size: int,
- vocab_mask: Optional[torch.Tensor] = None,
+ vocab_mask: Optional[torch.Tensor] = None, # For grammar
  ) -> torch.Tensor:
  """
  Verify and find accepted tokens based on logits output and batch
@@ -291,6 +291,14 @@ class EagleVerifyInput:
  )
  accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")

+ # Apply the custom logit processors if registered in the sampling info.
+ if sampling_info.has_custom_logit_processor:
+ apply_custom_logit_processor(
+ logits_output.next_token_logits,
+ sampling_info,
+ num_tokens_in_batch=self.draft_token_num,
+ )
+
  # Apply penalty
  if sampling_info.penalizer_orchestrator.is_required:
  # This is a relaxed version of penalties for speculative decoding.
@@ -320,11 +328,11 @@ class EagleVerifyInput:
  predicts=predict, # mutable
  accept_index=accept_index, # mutable
  accept_token_num=accept_length, # mutable
- candidates=candidates.to(torch.int32),
- retrive_index=self.retrive_index.to(torch.int32),
- retrive_next_token=self.retrive_next_token.to(torch.int32),
- retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
- target_predict=target_predict.to(torch.int32),
+ candidates=candidates,
+ retrive_index=self.retrive_index,
+ retrive_next_token=self.retrive_next_token,
+ retrive_next_sibling=self.retrive_next_sibling,
+ target_predict=target_predict,
  )
  else:
  # apply temperature and get target probs
@@ -352,16 +360,23 @@ class EagleVerifyInput:
  draft_probs = torch.zeros(
  target_probs.shape, dtype=torch.float32, device="cuda"
  )
+
+ # coins for rejection sampling
  coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda")
+ # coins for final sampling
+ coins_for_final_sampling = torch.rand(
+ (bs,), dtype=torch.float32, device="cuda"
+ )
  tree_speculative_sampling_target_only(
  predicts=predict, # mutable
  accept_index=accept_index, # mutable
  accept_token_num=accept_length, # mutable
- candidates=candidates.to(torch.int32),
- retrive_index=self.retrive_index.to(torch.int32),
- retrive_next_token=self.retrive_next_token.to(torch.int32),
- retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
+ candidates=candidates,
+ retrive_index=self.retrive_index,
+ retrive_next_token=self.retrive_next_token,
+ retrive_next_sibling=self.retrive_next_sibling,
  uniform_samples=coins,
+ uniform_samples_for_final_sampling=coins_for_final_sampling,
  target_probs=target_probs,
  draft_probs=draft_probs,
  threshold_single=global_server_args_dict[
@@ -384,8 +399,8 @@ class EagleVerifyInput:
  spec_steps=self.spec_steps,
  )

- new_accept_index = []
  unfinished_index = []
+ unfinished_accept_index = []
  accept_index_cpu = accept_index.tolist()
  predict_cpu = predict.tolist()
  has_finished = False
@@ -393,12 +408,10 @@ class EagleVerifyInput:
  # Iterate every accepted token and check if req has finished after append the token
  # should be checked BEFORE free kv cache slots
  for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
- new_accept_index_ = []
  for j, idx in enumerate(accept_index_row):
  if idx == -1:
  break
  id = predict_cpu[idx]
- # if not found_finished:
  req.output_ids.append(id)
  req.check_finished()
  if req.finished():
@@ -407,8 +420,6 @@ class EagleVerifyInput:
  accept_index[i, j + 1 :] = -1
  break
  else:
- new_accept_index_.append(idx)
- # update grammar state
  if req.grammar is not None:
  try:
  req.grammar.accept_token(id)
@@ -418,50 +429,104 @@ class EagleVerifyInput:
  )
  raise e
  if not req.finished():
- new_accept_index.extend(new_accept_index_)
  unfinished_index.append(i)
+ if idx == -1:
+ unfinished_accept_index.append(accept_index[i, :j])
+ else:
+ unfinished_accept_index.append(accept_index[i])
  req.spec_verify_ct += 1

  if has_finished:
  accept_length = (accept_index != -1).sum(dim=1) - 1

  # Free the KV cache for unaccepted tokens
+ # TODO: fuse them
  accept_index = accept_index[accept_index != -1]
  verified_id = predict[accept_index]
  evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
  evict_mask[accept_index] = False

- if page_size != 1:
- align_evict_mask_to_page_size[len(batch.seq_lens),](
- batch.seq_lens,
- evict_mask,
- page_size,
- self.draft_token_num,
- next_power_of_2(self.draft_token_num),
- )
+ if page_size == 1:
+ # TODO: boolean array index leads to a device sync. Remove it.
+ token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
+ else:
+ if self.topk == 1:
+ # Only evict full empty page. Do not evict partial empty page
+ align_evict_mask_to_page_size[len(batch.seq_lens),](
+ batch.seq_lens,
+ evict_mask,
+ page_size,
+ self.draft_token_num,
+ next_power_of_2(self.draft_token_num),
+ )
+ token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
+ else:
+ # Shift the accepted tokens to the beginning.
+ # Only evict the last part
+ src_cache_loc, tgt_cache_loc, to_free_num_slots = get_src_tgt_cache_loc(
+ batch.seq_lens,
+ batch.out_cache_loc,
+ accept_index,
+ accept_length,
+ self.draft_token_num,
+ page_size,
+ )
+ to_free_slots = torch.empty(
+ (to_free_num_slots.sum().item(),),
+ dtype=torch.int64,
+ device=to_free_num_slots.device,
+ )
+
+ # out_cache_loc: [0 1 2, 3 4 5, 6 7 8]
+ # accept_index: [0 -1 2, 3 4 -1, 6 -1 -1]
+ # tgt_cache_loc: [0 1 , 3 4 , 6 ]
+ # to_free_slots: [ 2, 5, 7 8]
+ # to_free_slots also needs to be page-aligned without the first partial page
+ #
+ # split each row of out_cache_loc into two parts.
+ # 1. the first part goes to tgt_cache_loc. length = accept_length[i] + 1
+ # 2. the second part goes to to_free_slots.
+ get_target_cache_loc[(bs,)](
+ tgt_cache_loc,
+ to_free_slots,
+ accept_length,
+ to_free_num_slots,
+ batch.out_cache_loc,
+ self.draft_token_num,
+ next_power_of_2(self.draft_token_num),
+ next_power_of_2(bs),
+ )

- token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
+ # Free the kv cache
+ token_to_kv_pool_allocator.free(to_free_slots)
+
+ # Copy the kv cache
+ batch.token_to_kv_pool_allocator.get_kvcache().move_kv_cache(
+ tgt_cache_loc, src_cache_loc
+ )

  # Construct EagleVerifyOutput
  if not has_finished:
- batch.out_cache_loc = batch.out_cache_loc[accept_index]
- assign_req_to_token_pool[(bs,)](
- batch.req_pool_indices,
- batch.req_to_token_pool.req_to_token,
- batch.seq_lens,
- batch.seq_lens + accept_length + 1,
- batch.out_cache_loc,
- batch.req_to_token_pool.req_to_token.shape[1],
- next_power_of_2(bs),
- )
+ if page_size == 1 or self.topk == 1:
+ batch.out_cache_loc = batch.out_cache_loc[accept_index]
+ assign_req_to_token_pool[(bs,)](
+ batch.req_pool_indices,
+ batch.req_to_token_pool.req_to_token,
+ batch.seq_lens,
+ batch.seq_lens + accept_length + 1,
+ batch.out_cache_loc,
+ batch.req_to_token_pool.req_to_token.shape[1],
+ next_power_of_2(bs),
+ )
+ else:
+ batch.out_cache_loc = tgt_cache_loc
  batch.seq_lens.add_(accept_length + 1)
- accept_length_cpu = accept_length.tolist()

  draft_input = EagleDraftInput()
  draft_input.hidden_states = batch.spec_info.hidden_states[accept_index]
  draft_input.verified_id = verified_id
  draft_input.accept_length = accept_length
- draft_input.accept_length_cpu = accept_length_cpu
+ draft_input.accept_length_cpu = accept_length.tolist()
  draft_input.seq_lens_for_draft_extend = batch.seq_lens
  draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices

@@ -469,47 +534,66 @@ class EagleVerifyInput:
  draft_input=draft_input,
  logits_output=logits_output,
  verified_id=verified_id,
- accept_length_per_req_cpu=accept_length_cpu,
+ accept_length_per_req_cpu=draft_input.accept_length_cpu,
  accepted_indices=accept_index,
  )
  else:
- assign_req_to_token_pool[(bs,)](
- batch.req_pool_indices,
- batch.req_to_token_pool.req_to_token,
- batch.seq_lens,
- batch.seq_lens + accept_length + 1,
- batch.out_cache_loc[accept_index],
- batch.req_to_token_pool.req_to_token.shape[1],
- next_power_of_2(bs),
- )
- batch.seq_lens.add_(accept_length + 1)
- accept_length_cpu = accept_length.tolist()
+ if page_size == 1 or self.topk == 1:
+ assign_req_to_token_pool[(bs,)](
+ batch.req_pool_indices,
+ batch.req_to_token_pool.req_to_token,
+ batch.seq_lens,
+ batch.seq_lens + accept_length + 1,
+ batch.out_cache_loc[accept_index],
+ batch.req_to_token_pool.req_to_token.shape[1],
+ next_power_of_2(bs),
+ )
+ batch.seq_lens.add_(accept_length + 1)

+ accept_length_cpu = accept_length.tolist()
  draft_input = EagleDraftInput()
- if len(new_accept_index) > 0:
- new_accept_index = torch.tensor(new_accept_index, device="cuda")
- unfinished_index_device = torch.tensor(unfinished_index, device="cuda")
- draft_input.hidden_states = batch.spec_info.hidden_states[
- new_accept_index
- ]
- draft_input.verified_id = predict[new_accept_index]
- draft_input.accept_length_cpu = [
+ if len(unfinished_accept_index) > 0:
+ unfinished_accept_index = torch.cat(unfinished_accept_index)
+ unfinished_index_device = torch.tensor(
+ unfinished_index, dtype=torch.int64, device=predict.device
+ )
+ draft_input_accept_length_cpu = [
  accept_length_cpu[i] for i in unfinished_index
  ]
- draft_input.accept_length = accept_length[unfinished_index_device]
- if has_finished:
- draft_input.seq_lens_for_draft_extend = batch.seq_lens[
- unfinished_index_device
- ]
- draft_input.req_pool_indices_for_draft_extend = (
- batch.req_pool_indices[unfinished_index_device]
- )
+ if page_size == 1 or self.topk == 1:
+ batch.out_cache_loc = batch.out_cache_loc[unfinished_accept_index]
  else:
- draft_input.seq_lens_for_draft_extend = batch.seq_lens
- draft_input.req_pool_indices_for_draft_extend = (
- batch.req_pool_indices
+ batch.out_cache_loc = torch.empty(
+ len(unfinished_index) + sum(draft_input_accept_length_cpu),
+ dtype=torch.int64,
+ device=predict.device,
+ )
+ accept_length_filter = create_accept_length_filter(
+ accept_length,
+ unfinished_index_device,
+ batch.seq_lens,
  )
- batch.out_cache_loc = batch.out_cache_loc[new_accept_index]
+ filter_finished_cache_loc_kernel[(bs,)](
+ batch.out_cache_loc,
+ tgt_cache_loc,
+ accept_length,
+ accept_length_filter,
+ next_power_of_2(bs),
+ next_power_of_2(self.draft_token_num),
+ )
+
+ draft_input.hidden_states = batch.spec_info.hidden_states[
+ unfinished_accept_index
+ ]
+ draft_input.verified_id = predict[unfinished_accept_index]
+ draft_input.accept_length_cpu = draft_input_accept_length_cpu
+ draft_input.accept_length = accept_length[unfinished_index_device]
+ draft_input.seq_lens_for_draft_extend = batch.seq_lens[
+ unfinished_index_device
+ ]
+ draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
+ unfinished_index_device
+ ]

  return EagleVerifyOutput(
  draft_input=draft_input,
@@ -586,36 +670,75 @@ def assign_draft_cache_locs(
  req_pool_indices,
  req_to_token,
  seq_lens,
+ extend_lens,
+ num_new_pages_per_topk,
  out_cache_loc,
  pool_len: tl.constexpr,
  topk: tl.constexpr,
  speculative_num_steps: tl.constexpr,
  page_size: tl.constexpr,
+ bs_upper: tl.constexpr,
+ iter_upper: tl.constexpr,
  ):
- BLOCK_SIZE: tl.constexpr = 32
+ BLOCK_SIZE: tl.constexpr = 128
  pid = tl.program_id(axis=0)
- kv_start = tl.load(seq_lens + pid)

  if page_size == 1 or topk == 1:
- kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
+ copy_len = topk * speculative_num_steps
  out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
  else:
- prefix_len = tl.load(seq_lens + pid)
- last_page_len = prefix_len % page_size
- num_new_page = (
- last_page_len + speculative_num_steps + page_size - 1
- ) // page_size
- kv_end = prefix_len // page_size * page_size + num_new_page * (page_size * topk)
+ bs_offset = tl.arange(0, bs_upper)
+ copy_len = tl.load(extend_lens + pid)
+ cum_copy_len = tl.sum(tl.load(extend_lens + bs_offset, mask=bs_offset < pid))
+ out_cache_ptr = out_cache_loc + cum_copy_len

+ # Part 1: Copy from out_cache_loc to req_to_token
+ kv_start = tl.load(seq_lens + pid)
  token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
-
- num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE)
+ num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
  for i in range(num_loop):
- save_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + kv_start
- load_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
- mask = save_offset < kv_end
- data = tl.load(out_cache_ptr + load_offset, mask=mask)
- tl.store(token_pool + save_offset, data, mask=mask)
+ copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+ mask = copy_offset < copy_len
+ data = tl.load(out_cache_ptr + copy_offset, mask=mask)
+ tl.store(token_pool + kv_start + copy_offset, data, mask=mask)
+
+ if page_size == 1 or topk == 1:
+ return
+
+ # Part 2: Copy the indices for the last partial page
+ prefix_len = tl.load(seq_lens + pid)
+ last_page_len = prefix_len % page_size
+ offsets = tl.arange(0, page_size)
+ mask = offsets < last_page_len
+ num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid)
+ prefix_base = token_pool + prefix_len - last_page_len
+
+ for topk_id in range(topk):
+ value = tl.load(prefix_base + offsets, mask=mask)
+ tl.store(
+ prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets,
+ value,
+ mask=mask,
+ )
+
+ # Part 3: Remove the padding in out_cache_loc
+ iter_offest = tl.arange(0, iter_upper)
+ for topk_id in range(topk):
+ indices = tl.load(
+ prefix_base
+ + topk_id * num_new_pages_per_topk_ * page_size
+ + last_page_len
+ + iter_offest,
+ mask=iter_offest < speculative_num_steps,
+ )
+ tl.store(
+ out_cache_loc
+ + pid * topk * speculative_num_steps
+ + topk_id * speculative_num_steps
+ + iter_offest,
+ indices,
+ mask=iter_offest < speculative_num_steps,
+ )


  @triton.jit
@@ -626,20 +749,23 @@ def generate_draft_decode_kv_indices(
  kv_indices,
  kv_indptr,
  positions,
- num_seqs: tl.constexpr,
- topk: tl.constexpr,
  pool_len: tl.constexpr,
  kv_indices_stride: tl.constexpr,
  kv_indptr_stride: tl.constexpr,
  bs_upper: tl.constexpr,
  iter_upper: tl.constexpr,
  num_tokens_upper: tl.constexpr,
+ page_size: tl.constexpr,
  ):
  BLOCK_SIZE: tl.constexpr = 128
  iters = tl.program_id(axis=0)
  bid = tl.program_id(axis=1)
  topk_id = tl.program_id(axis=2)

+ num_steps = tl.num_programs(axis=0)
+ num_seqs = tl.num_programs(axis=1)
+ topk = tl.num_programs(axis=2)
+
  kv_indices += kv_indices_stride * iters
  kv_indptr += kv_indptr_stride * iters
  iters += 1
@@ -649,6 +775,7 @@ def generate_draft_decode_kv_indices(
  seq_len = tl.load(paged_kernel_lens + bid)
  cum_seq_len = tl.sum(seq_lens)

+ # Update kv_indices
  kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
  kv_ptr = kv_indices + kv_offset
  token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len
@@ -662,10 +789,26 @@ def generate_draft_decode_kv_indices(
  kv_offset += BLOCK_SIZE

  extend_offset = tl.arange(0, iter_upper)
- extend_data = tl.load(
- token_pool_ptr + seq_len + tl.arange(0, iter_upper) * topk + topk_id,
- mask=extend_offset < iters,
- )
+ if page_size == 1 or topk == 1:
+ extend_data = tl.load(
+ token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper),
+ mask=extend_offset < iters,
+ )
+ else:
+ prefix_len = seq_len
+ last_page_len = prefix_len % page_size
+ num_new_pages_per_topk = (
+ last_page_len + num_steps + page_size - 1
+ ) // page_size
+ prefix_base = seq_len // page_size * page_size
+ start = (
+ prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len
+ )
+ extend_data = tl.load(
+ token_pool_ptr + start + extend_offset,
+ mask=extend_offset < iters,
+ )
+
  tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)

  # Update kv_indptr
@@ -704,6 +847,116 @@ def align_evict_mask_to_page_size(
  tl.store(evict_mask + bid * num_draft_tokens + i, False)


+ @triton.jit
+ def get_target_cache_loc(
+ tgt_cache_loc,
+ to_free_slots,
+ accept_length,
+ to_free_num_slots,
+ out_cache_loc,
+ num_verify_tokens: tl.constexpr,
+ num_verify_tokens_upper: tl.constexpr,
+ bs_upper: tl.constexpr,
+ ):
+ bid = tl.program_id(axis=0)
+ offset = tl.arange(0, num_verify_tokens_upper)
+ bs_offset = tl.arange(0, bs_upper)
+
+ # write the first part to tgt_cache_loc
+ accept_len_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
+ tgt_cache_loc_start = tl.sum(accept_len_all) + bid
+ copy_len = tl.load(accept_length + bid) + 1
+ out_cache_loc_row = tl.load(
+ out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len
+ )
+ tl.store(
+ tgt_cache_loc + tgt_cache_loc_start + offset,
+ out_cache_loc_row,
+ mask=offset < copy_len,
+ )
+
+ # write the second part to to_free_num_pages
+ to_free_num_slots_all = tl.load(to_free_num_slots + bs_offset, mask=bs_offset < bid)
+ to_free_num_slots_cur = tl.load(to_free_num_slots + bid)
+ out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur
+ to_free_slots_start = tl.sum(to_free_num_slots_all)
+
+ copy_len = to_free_num_slots_cur
+ out_cache_loc_row = tl.load(
+ out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset,
+ mask=offset < copy_len,
+ )
+ tl.store(
+ to_free_slots + to_free_slots_start + offset,
+ out_cache_loc_row,
+ mask=offset < copy_len,
+ )
+
+
+ @torch.compile(dynamic=True)
+ def get_src_tgt_cache_loc(
+ seq_lens: torch.Tensor,
+ out_cache_loc: torch.Tensor,
+ accept_index: torch.Tensor,
+ accept_length: torch.Tensor,
+ draft_token_num: int,
+ page_size: int,
+ ):
+ src_cache_loc = out_cache_loc[accept_index]
+ tgt_cache_loc = torch.empty_like(src_cache_loc)
+ extended_len = seq_lens + draft_token_num
+ keep_len = torch.minimum(
+ (seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size,
+ extended_len,
+ )
+ to_free_num_slots = extended_len - keep_len
+ return src_cache_loc, tgt_cache_loc, to_free_num_slots
+
+
+ @triton.jit
+ def filter_finished_cache_loc_kernel(
+ out_cache_loc,
+ tgt_cache_loc,
+ accept_length,
+ accept_length_filter,
+ bs_upper: tl.constexpr,
+ num_verify_tokens_upper: tl.constexpr,
+ ):
+ bid = tl.program_id(0)
+ bs_offset = tl.arange(0, bs_upper)
+
+ accept_length_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
+ old_start = tl.sum(accept_length_all) + bid
+
+ accept_length_filter_all = tl.load(
+ accept_length_filter + bs_offset, mask=bs_offset < bid
+ )
+ new_start = tl.sum(accept_length_filter_all)
+
+ copy_len = tl.load(accept_length_filter + bid)
+ copy_offset = tl.arange(0, num_verify_tokens_upper)
+ value = tl.load(
+ tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len
+ )
+ tl.store(
+ out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len
+ )
+
+
+ @torch.compile(dynamic=True)
+ def create_accept_length_filter(
+ accept_length: torch.Tensor,
+ unfinished_index_device: torch.Tensor,
+ seq_lens: torch.Tensor,
+ ):
+ accept_length_filter = torch.zeros_like(accept_length)
+ accept_length_filter[unfinished_index_device] = (
+ accept_length[unfinished_index_device] + 1
+ )
+ seq_lens.add_(accept_length + 1)
+ return accept_length_filter
+
+
  @torch.compile(dynamic=True)
  def select_top_k_tokens(
  i: int,
@@ -762,15 +1015,35 @@ def _generate_simulated_accept_index(
  spec_steps,
  ):
  simulate_acc_len_float = float(simulate_acc_len)
- simulated_values = torch.normal(
- mean=simulate_acc_len_float,
- std=1.0,
- size=(1,),
- device="cpu",
- )
- # clamp simulated values to be between 1 and self.spec_steps
- simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps)
- simulate_acc_len = int(simulated_values.round().item())
+ if SIMULATE_ACC_METHOD == "multinomial":
+ simulated_values = torch.normal(
+ mean=simulate_acc_len_float,
+ std=1.0,
+ size=(1,),
+ device="cpu",
+ )
+ # clamp simulated values to be between 1 and self.spec_steps
+ simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
+ simulate_acc_len = int(simulated_values.round().item())
+ elif SIMULATE_ACC_METHOD == "match-expected":
+ # multinomial sampling does not match the expected length
+ # we keep it for the sake of compatibility of existing tests
+ # but it's better to use "match-expected" for the cases that need to
+ # match the expected length, One caveat is that this will only sample
+ # either round down or round up of the expected length
+ simulate_acc_len_float = max(1.0, min(spec_steps + 1, simulate_acc_len_float))
+ lower = int(simulate_acc_len_float // 1)
+ upper = lower + 1 if lower < spec_steps + 1 else lower
+ if lower == upper:
+ simulate_acc_len = lower
+ else:
+ weight_upper = simulate_acc_len_float - lower
+ weight_lower = 1.0 - weight_upper
+ probs = torch.tensor([weight_lower, weight_upper], device="cpu")
+ sampled_index = torch.multinomial(probs, num_samples=1)
+ simulate_acc_len = lower if sampled_index == 0 else upper
+ else:
+ raise ValueError(f"Invalid simulate_acc_method: {SIMULATE_ACC_METHOD}")

  accept_indx_first_col = accept_index[:, 0].view(-1, 1)
  sim_accept_index = torch.full(
@@ -861,9 +1134,9 @@ def generate_token_bitmask(
  """
  Generate the logit mask for structured output.
  Draft model's token can be either valid or invalid with respect to the grammar.
- We need to perform DFS to figure out:
- 1. which tokens are accepted by the grammar
- 2. what is the corresponding logit mask.
+ We need to perform DFS to
+ 1. figure out which tokens are accepted by the grammar.
+ 2. if so, what is the corresponding logit mask.
  """

  num_draft_tokens = draft_tokens_cpu.shape[-1]
@@ -880,6 +1153,7 @@ def generate_token_bitmask(
  device="cpu",
  )
  grammar = req.grammar
+ s = time.perf_counter()
  traverse_tree(
  retrieve_next_token_cpu[i],
  retrieve_next_sibling_cpu[i],
@@ -889,6 +1163,12 @@ def generate_token_bitmask(
  i * num_draft_tokens : (i + 1) * num_draft_tokens
  ],
  )
+ tree_traverse_time = time.perf_counter() - s
+ if tree_traverse_time > TREE_TRAVERSE_TIME_THRESHOLD:
+ logger.warning(
+ f"Bit mask generation took {tree_traverse_time} seconds with "
+ f"grammar: {req.grammar}"
+ )

  verify_input.grammar = grammar
  return allocate_token_bitmask
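
The page-size-aware verification path above compacts each request's accepted KV slots to the front of its out_cache_loc row (via get_target_cache_loc and move_kv_cache) and frees the tail. A minimal CPU-only sketch of just the index split, with the helper name split_out_cache_loc chosen here for illustration and the page alignment of the freed tail omitted, reproducing the worked example from the inline comments:

def split_out_cache_loc(out_cache_loc_rows, accept_lengths):
    # Split each row into kept slots (front) and slots to free (tail).
    tgt_cache_loc, to_free_slots = [], []
    for row, acc_len in zip(out_cache_loc_rows, accept_lengths):
        keep = acc_len + 1  # accepted draft tokens plus the verified bonus token
        tgt_cache_loc.extend(row[:keep])
        to_free_slots.extend(row[keep:])
    return tgt_cache_loc, to_free_slots

# Example from the comments in the diff:
# out_cache_loc: [0 1 2, 3 4 5, 6 7 8], accept_index: [0 -1 2, 3 4 -1, 6 -1 -1]
rows = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
accept_lengths = [1, 1, 0]  # accepted tokens beyond the first, per request
print(split_out_cache_loc(rows, accept_lengths))
# -> ([0, 1, 3, 4, 6], [2, 5, 7, 8])

The new "match-expected" branch of _generate_simulated_accept_index draws either the floor or the ceiling of the requested acceptance length so that the mean matches the target. A small sketch of that rule, using random.random() in place of torch.multinomial for brevity:

import random

def sample_acc_len_match_expected(expected: float, spec_steps: int) -> int:
    # Clamp to the valid range [1, spec_steps + 1].
    expected = max(1.0, min(spec_steps + 1, expected))
    lower = int(expected)
    upper = lower + 1 if lower < spec_steps + 1 else lower
    if lower == upper:
        return lower
    # Pick the ceiling with probability equal to the fractional part,
    # so E[sample] == expected.
    weight_upper = expected - lower
    return upper if random.random() < weight_upper else lower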