sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
@@ -21,19 +21,22 @@ from sglang.srt.managers.schedule_batch import (
     get_last_loc,
     global_server_args_dict,
 )
-from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
-from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2
+from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
+
+logger = logging.getLogger(__name__)
 
 if is_cuda():
     from sgl_kernel import (
+        fast_topk,
         top_k_renorm_prob,
         top_p_renorm_prob,
         tree_speculative_sampling_target_only,
         verify_tree_greedy,
     )
 elif is_hip():
-    from sgl_kernel import verify_tree_greedy
+    from sgl_kernel import fast_topk, verify_tree_greedy
 
 
 logger = logging.getLogger(__name__)
@@ -67,9 +70,9 @@ class EagleDraftInput:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None
 
-    all_padding_lens: Optional[torch.Tensor] = None
-
     def prepare_for_extend(self, batch: ScheduleBatch):
+        if batch.forward_mode.is_idle():
+            return
         # Prefill only generate 1 token.
         assert len(self.verified_id) == len(batch.seq_lens)
 
@@ -81,6 +84,25 @@ class EagleDraftInput:
             )
             pt += extend_len
 
+    @classmethod
+    def create_idle_input(
+        cls,
+        device: torch.device,
+        hidden_size: int,
+        dtype: torch.dtype,
+        topk: int,
+        capture_hidden_mode: CaptureHiddenMode,
+    ):
+        return cls(
+            verified_id=None,
+            hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
+            topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
+            topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
+            capture_hidden_mode=capture_hidden_mode,
+            accept_length=torch.empty((0,), device=device, dtype=torch.int32),
+            accept_length_cpu=[],
+        )
+
     def prepare_extend_after_decode(
         self,
         batch: ScheduleBatch,
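
Note: the hunks in this diff come from sglang/srt/speculative/eagle_utils.py (item 133 in the file list above, +468 -116), judging by the EagleDraftInput/EagleVerifyInput context lines. The create_idle_input helpers added here and for EagleVerifyInput further down build a structurally complete input out of zero-length tensors, so a rank whose local batch is empty (e.g. under data-parallel attention) can still execute the shared EAGLE code path. This works because zero-row tensors propagate through ordinary PyTorch ops and simply yield zero-row results. A minimal sketch of that property (plain PyTorch with hypothetical sizes, not code from this package):

import torch

# An "idle" batch is represented by tensors with zero rows.
hidden_size, topk = 4096, 4
hidden_states = torch.empty((0, hidden_size))
topk_p = torch.empty((0, topk), dtype=torch.float32)

# Ordinary ops carry the zero batch dimension through without special casing:
w = torch.randn(hidden_size, hidden_size)
out = hidden_states @ w                   # shape (0, 4096)
probs = torch.softmax(topk_p, dim=-1)     # shape (0, 4)
print(out.shape, probs.shape)
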
@@ -93,6 +115,7 @@ class EagleDraftInput:
         batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
         batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
         batch.return_logprob = False
+        batch.return_hidden_states = False
 
         self.capture_hidden_mode = CaptureHiddenMode.LAST
         self.accept_length.add_(1)
@@ -116,13 +139,14 @@ class EagleDraftInput:
         req_to_token: torch.Tensor,
     ):
         bs = self.accept_length.numel()
-
         qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
-
         cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
         cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
 
+        if paged_kernel_lens_sum is None:
+            paged_kernel_lens_sum = cum_kv_seq_len[-1]
+
         kv_indices = torch.empty(
             paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
         )
@@ -136,7 +160,6 @@ class EagleDraftInput:
             kv_indices,
             req_to_token.size(1),
         )
-
         return kv_indices, cum_kv_seq_len, qo_indptr, None
 
     def filter_batch(self, new_indices: torch.Tensor):
@@ -193,7 +216,35 @@ class EagleVerifyInput:
     seq_lens_cpu: torch.Tensor
     grammar: BaseGrammarObject = None
 
+    @classmethod
+    def create_idle_input(cls, topk: int, spec_steps: int, num_verify_tokens: int):
+        return cls(
+            draft_token=torch.empty((0,), dtype=torch.long, device="cuda"),
+            custom_mask=torch.full((0,), True, dtype=torch.bool, device="cuda"),
+            positions=torch.empty((0,), dtype=torch.int64, device="cuda"),
+            retrive_index=torch.full(
+                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
+            ),
+            retrive_next_token=torch.full(
+                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
+            ),
+            retrive_next_sibling=torch.full(
+                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
+            ),
+            retrive_cum_len=None,
+            topk=topk,
+            draft_token_num=num_verify_tokens,
+            spec_steps=spec_steps,
+            capture_hidden_mode=CaptureHiddenMode.FULL,
+            seq_lens_sum=0,
+            seq_lens_cpu=torch.empty((0,), dtype=torch.int32),
+        )
+
     def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
+
+        if batch.forward_mode.is_idle():
+            return
+
         batch.input_ids = self.draft_token
 
         if page_size == 1:
@@ -265,9 +316,9 @@ class EagleVerifyInput:
         self,
         batch: ScheduleBatch,
         logits_output: torch.Tensor,
-        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         page_size: int,
-        vocab_mask: Optional[torch.Tensor] = None,
+        vocab_mask: Optional[torch.Tensor] = None,  # For grammar
     ) -> torch.Tensor:
         """
         Verify and find accepted tokens based on logits output and batch
@@ -279,6 +330,26 @@ class EagleVerifyInput:
         tokens. I.e., logits_output.next_token_logits only contains
         accepted token logits.
         """
+        if batch.forward_mode.is_idle():
+            return EagleVerifyOutput(
+                draft_input=EagleDraftInput.create_idle_input(
+                    device=batch.device,
+                    hidden_size=batch.model_config.hidden_size,
+                    dtype=batch.model_config.dtype,
+                    topk=self.topk,
+                    capture_hidden_mode=CaptureHiddenMode.LAST,
+                ),
+                logits_output=logits_output,
+                verified_id=torch.empty(0, dtype=torch.long, device=batch.device),
+                accept_length_per_req_cpu=[],
+                accepted_indices=torch.full(
+                    (0, self.spec_steps + 1),
+                    -1,
+                    dtype=torch.int32,
+                    device=batch.device,
+                ),
+            )
+
         bs = self.retrive_index.shape[0]
         candidates = self.draft_token.reshape(bs, self.draft_token_num)
         sampling_info = batch.sampling_info
@@ -291,6 +362,14 @@ class EagleVerifyInput:
         )
         accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
 
+        # Apply the custom logit processors if registered in the sampling info.
+        if sampling_info.has_custom_logit_processor:
+            apply_custom_logit_processor(
+                logits_output.next_token_logits,
+                sampling_info,
+                num_tokens_in_batch=self.draft_token_num,
+            )
+
         # Apply penalty
         if sampling_info.penalizer_orchestrator.is_required:
             # This is a relaxed version of penalties for speculative decoding.
@@ -320,11 +399,11 @@ class EagleVerifyInput:
                 predicts=predict,  # mutable
                 accept_index=accept_index,  # mutable
                 accept_token_num=accept_length,  # mutable
-                candidates=candidates.to(torch.int32),
-                retrive_index=self.retrive_index.to(torch.int32),
-                retrive_next_token=self.retrive_next_token.to(torch.int32),
-                retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
-                target_predict=target_predict.to(torch.int32),
+                candidates=candidates,
+                retrive_index=self.retrive_index,
+                retrive_next_token=self.retrive_next_token,
+                retrive_next_sibling=self.retrive_next_sibling,
+                target_predict=target_predict,
             )
         else:
             # apply temperature and get target probs
@@ -352,16 +431,23 @@ class EagleVerifyInput:
             draft_probs = torch.zeros(
                 target_probs.shape, dtype=torch.float32, device="cuda"
             )
+
+            # coins for rejection sampling
             coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda")
+            # coins for final sampling
+            coins_for_final_sampling = torch.rand(
+                (bs,), dtype=torch.float32, device="cuda"
+            )
             tree_speculative_sampling_target_only(
                 predicts=predict,  # mutable
                 accept_index=accept_index,  # mutable
                 accept_token_num=accept_length,  # mutable
-                candidates=candidates.to(torch.int32),
-                retrive_index=self.retrive_index.to(torch.int32),
-                retrive_next_token=self.retrive_next_token.to(torch.int32),
-                retrive_next_sibling=self.retrive_next_sibling.to(torch.int32),
+                candidates=candidates,
+                retrive_index=self.retrive_index,
+                retrive_next_token=self.retrive_next_token,
+                retrive_next_sibling=self.retrive_next_sibling,
                 uniform_samples=coins,
+                uniform_samples_for_final_sampling=coins_for_final_sampling,
                 target_probs=target_probs,
                 draft_probs=draft_probs,
                 threshold_single=global_server_args_dict[
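
The two coin tensors introduced above separate the randomness of the rejection test from the randomness of the fallback draw: uniform_samples drives the per-token accept/reject decisions along the draft tree, and uniform_samples_for_final_sampling draws one replacement token per request from the residual distribution when a draft token is rejected. For intuition, here is the chain-shaped (non-tree) version of that algorithm in plain PyTorch; the sgl_kernel call above implements the tree-structured generalization (a sketch for illustration, not the kernel's code):

import torch

def chain_speculative_sample(draft_probs, target_probs, draft_tokens, coins, final_coin):
    # Accept draft tokens left to right; on the first rejection, draw one
    # replacement token from the residual distribution max(p_target - p_draft, 0).
    accepted = []
    for i, tok in enumerate(draft_tokens.tolist()):
        p, q = target_probs[i, tok], draft_probs[i, tok]
        if coins[i] <= torch.clamp(p / q, max=1.0):   # per-token rejection coin
            accepted.append(tok)
            continue
        residual = torch.clamp(target_probs[i] - draft_probs[i], min=0.0)
        cdf = torch.cumsum(residual / residual.sum(), dim=0)
        replacement = int((cdf < final_coin).sum())   # one final-sampling coin
        return accepted, replacement
    return accepted, None  # every draft token accepted

num_steps, vocab = 3, 8
q = torch.softmax(torch.randn(num_steps, vocab), dim=-1)
p = torch.softmax(torch.randn(num_steps, vocab), dim=-1)
tokens = torch.multinomial(q, num_samples=1).squeeze(1)
print(chain_speculative_sample(q, p, tokens, torch.rand(num_steps), torch.rand(())))
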
@@ -384,8 +470,8 @@ class EagleVerifyInput:
                 spec_steps=self.spec_steps,
             )
 
-        new_accept_index = []
         unfinished_index = []
+        unfinished_accept_index = []
         accept_index_cpu = accept_index.tolist()
         predict_cpu = predict.tolist()
         has_finished = False
@@ -393,12 +479,10 @@ class EagleVerifyInput:
        # Iterate every accepted token and check if req has finished after append the token
        # should be checked BEFORE free kv cache slots
        for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
-            new_accept_index_ = []
            for j, idx in enumerate(accept_index_row):
                if idx == -1:
                    break
                id = predict_cpu[idx]
-                # if not found_finished:
                req.output_ids.append(id)
                req.check_finished()
                if req.finished():
@@ -407,8 +491,6 @@ class EagleVerifyInput:
                    accept_index[i, j + 1 :] = -1
                    break
                else:
-                    new_accept_index_.append(idx)
-                    # update grammar state
                    if req.grammar is not None:
                        try:
                            req.grammar.accept_token(id)
@@ -418,50 +500,104 @@ class EagleVerifyInput:
                            )
                            raise e
            if not req.finished():
-                new_accept_index.extend(new_accept_index_)
                unfinished_index.append(i)
+                if idx == -1:
+                    unfinished_accept_index.append(accept_index[i, :j])
+                else:
+                    unfinished_accept_index.append(accept_index[i])
            req.spec_verify_ct += 1
 
        if has_finished:
            accept_length = (accept_index != -1).sum(dim=1) - 1
 
        # Free the KV cache for unaccepted tokens
+        # TODO: fuse them
        accept_index = accept_index[accept_index != -1]
        verified_id = predict[accept_index]
        evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
        evict_mask[accept_index] = False
 
-        if page_size != 1:
-            align_evict_mask_to_page_size[len(batch.seq_lens),](
-                batch.seq_lens,
-                evict_mask,
-                page_size,
-                self.draft_token_num,
-                next_power_of_2(self.draft_token_num),
-            )
+        if page_size == 1:
+            # TODO: boolean array index leads to a device sync. Remove it.
+            token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
+        else:
+            if self.topk == 1:
+                # Only evict full empty page. Do not evict partial empty page
+                align_evict_mask_to_page_size[len(batch.seq_lens),](
+                    batch.seq_lens,
+                    evict_mask,
+                    page_size,
+                    self.draft_token_num,
+                    next_power_of_2(self.draft_token_num),
+                )
+                token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
+            else:
+                # Shift the accepted tokens to the beginning.
+                # Only evict the last part
+                src_cache_loc, tgt_cache_loc, to_free_num_slots = get_src_tgt_cache_loc(
+                    batch.seq_lens,
+                    batch.out_cache_loc,
+                    accept_index,
+                    accept_length,
+                    self.draft_token_num,
+                    page_size,
+                )
+                to_free_slots = torch.empty(
+                    (to_free_num_slots.sum().item(),),
+                    dtype=torch.int64,
+                    device=to_free_num_slots.device,
+                )
 
-        token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
+                # out_cache_loc: [0  1  2,  3  4  5,  6  7  8]
+                # accept_index:  [0 -1  2,  3  4 -1,  6 -1 -1]
+                # tgt_cache_loc: [0  1,     3  4,     6      ]
+                # to_free_slots: [      2,        5,     7  8]
+                # to_free_slots also needs to be page-aligned without the first partial page
+                #
+                # split each row of out_cache_loc into two parts.
+                # 1. the first part goes to tgt_cache_loc. length = accept_length[i] + 1
+                # 2. the second part goes to to_free_slots.
+                get_target_cache_loc[(bs,)](
+                    tgt_cache_loc,
+                    to_free_slots,
+                    accept_length,
+                    to_free_num_slots,
+                    batch.out_cache_loc,
+                    self.draft_token_num,
+                    next_power_of_2(self.draft_token_num),
+                    next_power_of_2(bs),
+                )
+
+                # Free the kv cache
+                token_to_kv_pool_allocator.free(to_free_slots)
+
+                # Copy the kv cache
+                batch.token_to_kv_pool_allocator.get_kvcache().move_kv_cache(
+                    tgt_cache_loc, src_cache_loc
+                )
 
        # Construct EagleVerifyOutput
        if not has_finished:
-            batch.out_cache_loc = batch.out_cache_loc[accept_index]
-            assign_req_to_token_pool[(bs,)](
-                batch.req_pool_indices,
-                batch.req_to_token_pool.req_to_token,
-                batch.seq_lens,
-                batch.seq_lens + accept_length + 1,
-                batch.out_cache_loc,
-                batch.req_to_token_pool.req_to_token.shape[1],
-                next_power_of_2(bs),
-            )
+            if page_size == 1 or self.topk == 1:
+                batch.out_cache_loc = batch.out_cache_loc[accept_index]
+                assign_req_to_token_pool[(bs,)](
+                    batch.req_pool_indices,
+                    batch.req_to_token_pool.req_to_token,
+                    batch.seq_lens,
+                    batch.seq_lens + accept_length + 1,
+                    batch.out_cache_loc,
+                    batch.req_to_token_pool.req_to_token.shape[1],
+                    next_power_of_2(bs),
+                )
+            else:
+                batch.out_cache_loc = tgt_cache_loc
            batch.seq_lens.add_(accept_length + 1)
-            accept_length_cpu = accept_length.tolist()
 
            draft_input = EagleDraftInput()
            draft_input.hidden_states = batch.spec_info.hidden_states[accept_index]
            draft_input.verified_id = verified_id
            draft_input.accept_length = accept_length
-            draft_input.accept_length_cpu = accept_length_cpu
+            draft_input.accept_length_cpu = accept_length.tolist()
            draft_input.seq_lens_for_draft_extend = batch.seq_lens
            draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
 
@@ -469,47 +605,66 @@ class EagleVerifyInput:
                draft_input=draft_input,
                logits_output=logits_output,
                verified_id=verified_id,
-                accept_length_per_req_cpu=accept_length_cpu,
+                accept_length_per_req_cpu=draft_input.accept_length_cpu,
                accepted_indices=accept_index,
            )
        else:
-            assign_req_to_token_pool[(bs,)](
-                batch.req_pool_indices,
-                batch.req_to_token_pool.req_to_token,
-                batch.seq_lens,
-                batch.seq_lens + accept_length + 1,
-                batch.out_cache_loc[accept_index],
-                batch.req_to_token_pool.req_to_token.shape[1],
-                next_power_of_2(bs),
-            )
-            batch.seq_lens.add_(accept_length + 1)
-            accept_length_cpu = accept_length.tolist()
+            if page_size == 1 or self.topk == 1:
+                assign_req_to_token_pool[(bs,)](
+                    batch.req_pool_indices,
+                    batch.req_to_token_pool.req_to_token,
+                    batch.seq_lens,
+                    batch.seq_lens + accept_length + 1,
+                    batch.out_cache_loc[accept_index],
+                    batch.req_to_token_pool.req_to_token.shape[1],
+                    next_power_of_2(bs),
+                )
+                batch.seq_lens.add_(accept_length + 1)
 
+            accept_length_cpu = accept_length.tolist()
            draft_input = EagleDraftInput()
-            if len(new_accept_index) > 0:
-                new_accept_index = torch.tensor(new_accept_index, device="cuda")
-                unfinished_index_device = torch.tensor(unfinished_index, device="cuda")
-                draft_input.hidden_states = batch.spec_info.hidden_states[
-                    new_accept_index
-                ]
-                draft_input.verified_id = predict[new_accept_index]
-                draft_input.accept_length_cpu = [
+            if len(unfinished_accept_index) > 0:
+                unfinished_accept_index = torch.cat(unfinished_accept_index)
+                unfinished_index_device = torch.tensor(
+                    unfinished_index, dtype=torch.int64, device=predict.device
+                )
+                draft_input_accept_length_cpu = [
                    accept_length_cpu[i] for i in unfinished_index
                ]
-                draft_input.accept_length = accept_length[unfinished_index_device]
-                if has_finished:
-                    draft_input.seq_lens_for_draft_extend = batch.seq_lens[
-                        unfinished_index_device
-                    ]
-                    draft_input.req_pool_indices_for_draft_extend = (
-                        batch.req_pool_indices[unfinished_index_device]
-                    )
+                if page_size == 1 or self.topk == 1:
+                    batch.out_cache_loc = batch.out_cache_loc[unfinished_accept_index]
                else:
-                    draft_input.seq_lens_for_draft_extend = batch.seq_lens
-                    draft_input.req_pool_indices_for_draft_extend = (
-                        batch.req_pool_indices
+                    batch.out_cache_loc = torch.empty(
+                        len(unfinished_index) + sum(draft_input_accept_length_cpu),
+                        dtype=torch.int64,
+                        device=predict.device,
+                    )
+                    accept_length_filter = create_accept_length_filter(
+                        accept_length,
+                        unfinished_index_device,
+                        batch.seq_lens,
+                    )
+                    filter_finished_cache_loc_kernel[(bs,)](
+                        batch.out_cache_loc,
+                        tgt_cache_loc,
+                        accept_length,
+                        accept_length_filter,
+                        next_power_of_2(bs),
+                        next_power_of_2(self.draft_token_num),
                    )
-                batch.out_cache_loc = batch.out_cache_loc[new_accept_index]
+
+                draft_input.hidden_states = batch.spec_info.hidden_states[
+                    unfinished_accept_index
+                ]
+                draft_input.verified_id = predict[unfinished_accept_index]
+                draft_input.accept_length_cpu = draft_input_accept_length_cpu
+                draft_input.accept_length = accept_length[unfinished_index_device]
+                draft_input.seq_lens_for_draft_extend = batch.seq_lens[
+                    unfinished_index_device
+                ]
+                draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
+                    unfinished_index_device
+                ]
 
        return EagleVerifyOutput(
            draft_input=draft_input,
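
A plain-PyTorch restatement of the split performed by get_target_cache_loc in the paged (page_size > 1, topk > 1) path above, using the numbers from the ASCII diagram in the hunk (hypothetical shapes for illustration; the real work happens in the Triton kernel defined further down):

import torch

out_cache_loc = torch.arange(9).reshape(3, 3)   # 3 requests x 3 draft tokens
accept_length = torch.tensor([1, 1, 0])         # accepted tokens per request
to_free_num_slots = torch.tensor([1, 1, 2])     # from get_src_tgt_cache_loc

# The first accept_length[i] + 1 slots of each row stay live ...
tgt_cache_loc = torch.cat(
    [row[: int(n) + 1] for row, n in zip(out_cache_loc, accept_length)]
)
# ... and the trailing to_free_num_slots[i] slots go back to the allocator.
to_free_slots = torch.cat(
    [row[3 - int(f) :] for row, f in zip(out_cache_loc, to_free_num_slots)]
)
print(tgt_cache_loc.tolist())   # [0, 1, 3, 4, 6]
print(to_free_slots.tolist())   # [2, 5, 7, 8]
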
@@ -586,36 +741,75 @@ def assign_draft_cache_locs(
    req_pool_indices,
    req_to_token,
    seq_lens,
+    extend_lens,
+    num_new_pages_per_topk,
    out_cache_loc,
    pool_len: tl.constexpr,
    topk: tl.constexpr,
    speculative_num_steps: tl.constexpr,
    page_size: tl.constexpr,
+    bs_upper: tl.constexpr,
+    iter_upper: tl.constexpr,
 ):
-    BLOCK_SIZE: tl.constexpr = 32
+    BLOCK_SIZE: tl.constexpr = 128
    pid = tl.program_id(axis=0)
-    kv_start = tl.load(seq_lens + pid)
 
    if page_size == 1 or topk == 1:
-        kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
+        copy_len = topk * speculative_num_steps
        out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
    else:
-        prefix_len = tl.load(seq_lens + pid)
-        last_page_len = prefix_len % page_size
-        num_new_page = (
-            last_page_len + speculative_num_steps + page_size - 1
-        ) // page_size
-        kv_end = prefix_len // page_size * page_size + num_new_page * (page_size * topk)
+        bs_offset = tl.arange(0, bs_upper)
+        copy_len = tl.load(extend_lens + pid)
+        cum_copy_len = tl.sum(tl.load(extend_lens + bs_offset, mask=bs_offset < pid))
+        out_cache_ptr = out_cache_loc + cum_copy_len
 
+    # Part 1: Copy from out_cache_loc to req_to_token
+    kv_start = tl.load(seq_lens + pid)
    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
-
-    num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE)
+    num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
    for i in range(num_loop):
-        save_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + kv_start
-        load_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
-        mask = save_offset < kv_end
-        data = tl.load(out_cache_ptr + load_offset, mask=mask)
-        tl.store(token_pool + save_offset, data, mask=mask)
+        copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = copy_offset < copy_len
+        data = tl.load(out_cache_ptr + copy_offset, mask=mask)
+        tl.store(token_pool + kv_start + copy_offset, data, mask=mask)
+
+    if page_size == 1 or topk == 1:
+        return
+
+    # Part 2: Copy the indices for the last partial page
+    prefix_len = tl.load(seq_lens + pid)
+    last_page_len = prefix_len % page_size
+    offsets = tl.arange(0, page_size)
+    mask = offsets < last_page_len
+    num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid)
+    prefix_base = token_pool + prefix_len - last_page_len
+
+    for topk_id in range(topk):
+        value = tl.load(prefix_base + offsets, mask=mask)
+        tl.store(
+            prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets,
+            value,
+            mask=mask,
+        )
+
+    # Part 3: Remove the padding in out_cache_loc
+    iter_offest = tl.arange(0, iter_upper)
+    for topk_id in range(topk):
+        indices = tl.load(
+            prefix_base
+            + topk_id * num_new_pages_per_topk_ * page_size
+            + last_page_len
+            + iter_offest,
+            mask=iter_offest < speculative_num_steps,
+        )
+        tl.store(
+            out_cache_loc
+            + pid * topk * speculative_num_steps
+            + topk_id * speculative_num_steps
+            + iter_offest,
+            indices,
+            mask=iter_offest < speculative_num_steps,
+        )
 
 
 @triton.jit
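
In the paged, topk > 1 layout that Parts 2 and 3 above maintain, each topk branch of a request gets num_new_pages_per_topk fresh pages whose head is a copy of the prefix's last partial page, so every branch's pages stay self-contained. Tracing the page arithmetic with assumed small numbers (the same formula reappears in generate_draft_decode_kv_indices below):

# Assumed values: prefix of 6 tokens, page_size 4, speculative_num_steps 3.
prefix_len, page_size, num_steps = 6, 4, 3

last_page_len = prefix_len % page_size                  # 2 tokens on a partial page
num_new_pages_per_topk = (
    last_page_len + num_steps + page_size - 1
) // page_size                                          # (2 + 3 + 3) // 4 = 2 pages
prefix_base = prefix_len // page_size * page_size       # 4: start of the partial page

# Each branch's draft tokens begin right after its copy of the partial page:
for topk_id in range(2):
    start = prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len
    print(f"branch {topk_id}: slots {start}..{start + num_steps - 1}")
# branch 0: slots 6..8, branch 1: slots 14..16
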
@@ -626,20 +820,23 @@ def generate_draft_decode_kv_indices(
    kv_indices,
    kv_indptr,
    positions,
-    num_seqs: tl.constexpr,
-    topk: tl.constexpr,
    pool_len: tl.constexpr,
    kv_indices_stride: tl.constexpr,
    kv_indptr_stride: tl.constexpr,
    bs_upper: tl.constexpr,
    iter_upper: tl.constexpr,
    num_tokens_upper: tl.constexpr,
+    page_size: tl.constexpr,
 ):
    BLOCK_SIZE: tl.constexpr = 128
    iters = tl.program_id(axis=0)
    bid = tl.program_id(axis=1)
    topk_id = tl.program_id(axis=2)
 
+    num_steps = tl.num_programs(axis=0)
+    num_seqs = tl.num_programs(axis=1)
+    topk = tl.num_programs(axis=2)
+
    kv_indices += kv_indices_stride * iters
    kv_indptr += kv_indptr_stride * iters
    iters += 1
@@ -649,6 +846,7 @@ def generate_draft_decode_kv_indices(
    seq_len = tl.load(paged_kernel_lens + bid)
    cum_seq_len = tl.sum(seq_lens)
 
+    # Update kv_indices
    kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
    kv_ptr = kv_indices + kv_offset
    token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len
@@ -662,10 +860,26 @@ def generate_draft_decode_kv_indices(
        kv_offset += BLOCK_SIZE
 
    extend_offset = tl.arange(0, iter_upper)
-    extend_data = tl.load(
-        token_pool_ptr + seq_len + tl.arange(0, iter_upper) * topk + topk_id,
-        mask=extend_offset < iters,
-    )
+    if page_size == 1 or topk == 1:
+        extend_data = tl.load(
+            token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper),
+            mask=extend_offset < iters,
+        )
+    else:
+        prefix_len = seq_len
+        last_page_len = prefix_len % page_size
+        num_new_pages_per_topk = (
+            last_page_len + num_steps + page_size - 1
+        ) // page_size
+        prefix_base = seq_len // page_size * page_size
+        start = (
+            prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len
+        )
+        extend_data = tl.load(
+            token_pool_ptr + start + extend_offset,
+            mask=extend_offset < iters,
+        )
+
    tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)
 
    # Update kv_indptr
@@ -704,6 +918,116 @@ def align_evict_mask_to_page_size(
            tl.store(evict_mask + bid * num_draft_tokens + i, False)
 
 
+@triton.jit
+def get_target_cache_loc(
+    tgt_cache_loc,
+    to_free_slots,
+    accept_length,
+    to_free_num_slots,
+    out_cache_loc,
+    num_verify_tokens: tl.constexpr,
+    num_verify_tokens_upper: tl.constexpr,
+    bs_upper: tl.constexpr,
+):
+    bid = tl.program_id(axis=0)
+    offset = tl.arange(0, num_verify_tokens_upper)
+    bs_offset = tl.arange(0, bs_upper)
+
+    # write the first part to tgt_cache_loc
+    accept_len_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
+    tgt_cache_loc_start = tl.sum(accept_len_all) + bid
+    copy_len = tl.load(accept_length + bid) + 1
+    out_cache_loc_row = tl.load(
+        out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len
+    )
+    tl.store(
+        tgt_cache_loc + tgt_cache_loc_start + offset,
+        out_cache_loc_row,
+        mask=offset < copy_len,
+    )
+
+    # write the second part to to_free_num_pages
+    to_free_num_slots_all = tl.load(to_free_num_slots + bs_offset, mask=bs_offset < bid)
+    to_free_num_slots_cur = tl.load(to_free_num_slots + bid)
+    out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur
+    to_free_slots_start = tl.sum(to_free_num_slots_all)
+
+    copy_len = to_free_num_slots_cur
+    out_cache_loc_row = tl.load(
+        out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset,
+        mask=offset < copy_len,
+    )
+    tl.store(
+        to_free_slots + to_free_slots_start + offset,
+        out_cache_loc_row,
+        mask=offset < copy_len,
+    )
+
+
+@torch.compile(dynamic=True)
+def get_src_tgt_cache_loc(
+    seq_lens: torch.Tensor,
+    out_cache_loc: torch.Tensor,
+    accept_index: torch.Tensor,
+    accept_length: torch.Tensor,
+    draft_token_num: int,
+    page_size: int,
+):
+    src_cache_loc = out_cache_loc[accept_index]
+    tgt_cache_loc = torch.empty_like(src_cache_loc)
+    extended_len = seq_lens + draft_token_num
+    keep_len = torch.minimum(
+        (seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size,
+        extended_len,
+    )
+    to_free_num_slots = extended_len - keep_len
+    return src_cache_loc, tgt_cache_loc, to_free_num_slots
+
+
+@triton.jit
+def filter_finished_cache_loc_kernel(
+    out_cache_loc,
+    tgt_cache_loc,
+    accept_length,
+    accept_length_filter,
+    bs_upper: tl.constexpr,
+    num_verify_tokens_upper: tl.constexpr,
+):
+    bid = tl.program_id(0)
+    bs_offset = tl.arange(0, bs_upper)
+
+    accept_length_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
+    old_start = tl.sum(accept_length_all) + bid
+
+    accept_length_filter_all = tl.load(
+        accept_length_filter + bs_offset, mask=bs_offset < bid
+    )
+    new_start = tl.sum(accept_length_filter_all)
+
+    copy_len = tl.load(accept_length_filter + bid)
+    copy_offset = tl.arange(0, num_verify_tokens_upper)
+    value = tl.load(
+        tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len
+    )
+    tl.store(
+        out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len
+    )
+
+
+@torch.compile(dynamic=True)
+def create_accept_length_filter(
+    accept_length: torch.Tensor,
+    unfinished_index_device: torch.Tensor,
+    seq_lens: torch.Tensor,
+):
+    accept_length_filter = torch.zeros_like(accept_length)
+    accept_length_filter[unfinished_index_device] = (
+        accept_length[unfinished_index_device] + 1
+    )
+    seq_lens.add_(accept_length + 1)
+    return accept_length_filter
+
+
 @torch.compile(dynamic=True)
 def select_top_k_tokens(
    i: int,
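
get_src_tgt_cache_loc above keeps every page that still holds a live token (seq_len + accept_length + 1 rounded up to a page boundary, capped at the fully extended length) and frees only the trailing dead slots. Tracing the formula with assumed numbers:

import torch

# Assumed values: page_size 4, prefix of 6 tokens, 8 draft tokens, 2 accepted.
seq_lens = torch.tensor([6])
accept_length = torch.tensor([2])
draft_token_num, page_size = 8, 4

extended_len = seq_lens + draft_token_num          # 14 slots were reserved
keep_len = torch.minimum(
    (seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size,
    extended_len,
)                                                  # 9 live tokens round up to 12
to_free_num_slots = extended_len - keep_len        # 14 - 12 = 2 slots to free
print(extended_len.item(), keep_len.item(), to_free_num_slots.item())  # 14 12 2
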
@@ -739,10 +1063,11 @@ def select_top_k_tokens(
        topk_index = topk_index.reshape(-1, topk**2)
        input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()
 
-        selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
-            0, hidden_states.shape[0], step=topk, device="cuda"
-        ).repeat_interleave(topk)
-        hidden_states = hidden_states[selected_input_index, :]
+        if hidden_states.shape[0] > 0:
+            selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
+                0, hidden_states.shape[0], step=topk, device="cuda"
+            ).repeat_interleave(topk)
+            hidden_states = hidden_states[selected_input_index, :]
 
        tree_info = (
            expand_scores,  # shape: (b, topk, topk)
@@ -762,15 +1087,35 @@ def _generate_simulated_accept_index(
    spec_steps,
 ):
    simulate_acc_len_float = float(simulate_acc_len)
-    simulated_values = torch.normal(
-        mean=simulate_acc_len_float,
-        std=1.0,
-        size=(1,),
-        device="cpu",
-    )
-    # clamp simulated values to be between 1 and self.spec_steps
-    simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps)
-    simulate_acc_len = int(simulated_values.round().item())
+    if SIMULATE_ACC_METHOD == "multinomial":
+        simulated_values = torch.normal(
+            mean=simulate_acc_len_float,
+            std=1.0,
+            size=(1,),
+            device="cpu",
+        )
+        # clamp simulated values to be between 1 and self.spec_steps
+        simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
+        simulate_acc_len = int(simulated_values.round().item())
+    elif SIMULATE_ACC_METHOD == "match-expected":
+        # multinomial sampling does not match the expected length
+        # we keep it for the sake of compatibility of existing tests
+        # but it's better to use "match-expected" for the cases that need to
+        # match the expected length, One caveat is that this will only sample
+        # either round down or round up of the expected length
+        simulate_acc_len_float = max(1.0, min(spec_steps + 1, simulate_acc_len_float))
+        lower = int(simulate_acc_len_float // 1)
+        upper = lower + 1 if lower < spec_steps + 1 else lower
+        if lower == upper:
+            simulate_acc_len = lower
+        else:
+            weight_upper = simulate_acc_len_float - lower
+            weight_lower = 1.0 - weight_upper
+            probs = torch.tensor([weight_lower, weight_upper], device="cpu")
+            sampled_index = torch.multinomial(probs, num_samples=1)
+            simulate_acc_len = lower if sampled_index == 0 else upper
+    else:
+        raise ValueError(f"Invalid simulate_acc_method: {SIMULATE_ACC_METHOD}")
 
    accept_indx_first_col = accept_index[:, 0].view(-1, 1)
    sim_accept_index = torch.full(
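
The new "match-expected" branch is stochastic rounding: it returns either the floor or the ceiling of the (clamped) target length, weighted so that the mean of the samples equals the target, which the normal-noise "multinomial" mode cannot guarantee. A standalone restatement of the idea (hypothetical helper name, not this module's API):

import random

def match_expected_acc_len(target: float, spec_steps: int) -> int:
    # Stochastic rounding: sample floor(target) or ceil(target) so that the
    # expectation equals target after clamping to [1, spec_steps + 1].
    target = max(1.0, min(spec_steps + 1, target))
    lower = int(target)
    upper = min(lower + 1, spec_steps + 1)
    if lower == upper:
        return lower
    return upper if random.random() < (target - lower) else lower

samples = [match_expected_acc_len(2.3, spec_steps=4) for _ in range(100_000)]
print(sum(samples) / len(samples))  # ~2.3
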
@@ -861,9 +1206,9 @@ def generate_token_bitmask(
    """
    Generate the logit mask for structured output.
    Draft model's token can be either valid or invalid with respect to the grammar.
-    We need to perform DFS to figure out:
-    1. which tokens are accepted by the grammar
-    2. what is the corresponding logit mask.
+    We need to perform DFS to
+    1. figure out which tokens are accepted by the grammar.
+    2. if so, what is the corresponding logit mask.
    """
 
    num_draft_tokens = draft_tokens_cpu.shape[-1]
@@ -880,6 +1225,7 @@ def generate_token_bitmask(
                device="cpu",
            )
            grammar = req.grammar
+            s = time.perf_counter()
            traverse_tree(
                retrieve_next_token_cpu[i],
                retrieve_next_sibling_cpu[i],
@@ -889,6 +1235,12 @@ def generate_token_bitmask(
                    i * num_draft_tokens : (i + 1) * num_draft_tokens
                ],
            )
+            tree_traverse_time = time.perf_counter() - s
+            if tree_traverse_time > TREE_TRAVERSE_TIME_THRESHOLD:
+                logger.warning(
+                    f"Bit mask generation took {tree_traverse_time} seconds with "
+                    f"grammar: {req.grammar}"
+                )
 
    verify_input.grammar = grammar
    return allocate_token_bitmask