sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (150)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +18 -3
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  5. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
  6. sglang/srt/checkpoint_engine/__init__.py +9 -0
  7. sglang/srt/checkpoint_engine/update.py +317 -0
  8. sglang/srt/configs/__init__.py +2 -0
  9. sglang/srt/configs/deepseek_ocr.py +542 -10
  10. sglang/srt/configs/deepseekvl2.py +95 -194
  11. sglang/srt/configs/kimi_linear.py +160 -0
  12. sglang/srt/configs/mamba_utils.py +66 -0
  13. sglang/srt/configs/model_config.py +25 -2
  14. sglang/srt/constants.py +7 -0
  15. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  16. sglang/srt/disaggregation/decode.py +34 -6
  17. sglang/srt/disaggregation/nixl/conn.py +2 -2
  18. sglang/srt/disaggregation/prefill.py +25 -3
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  20. sglang/srt/distributed/parallel_state.py +9 -5
  21. sglang/srt/entrypoints/engine.py +13 -5
  22. sglang/srt/entrypoints/http_server.py +22 -3
  23. sglang/srt/entrypoints/openai/protocol.py +7 -1
  24. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  27. sglang/srt/environ.py +7 -0
  28. sglang/srt/eplb/expert_distribution.py +34 -1
  29. sglang/srt/eplb/expert_location.py +106 -36
  30. sglang/srt/grpc/compile_proto.py +3 -0
  31. sglang/srt/layers/attention/ascend_backend.py +233 -5
  32. sglang/srt/layers/attention/attention_registry.py +3 -0
  33. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  34. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  35. sglang/srt/layers/attention/fla/kda.py +1359 -0
  36. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  37. sglang/srt/layers/attention/flashattention_backend.py +7 -6
  38. sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
  39. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  40. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  41. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  42. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  43. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  44. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  45. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  46. sglang/srt/layers/attention/nsa_backend.py +157 -23
  47. sglang/srt/layers/attention/triton_backend.py +4 -1
  48. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  49. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
  50. sglang/srt/layers/communicator.py +23 -1
  51. sglang/srt/layers/layernorm.py +16 -2
  52. sglang/srt/layers/logits_processor.py +4 -20
  53. sglang/srt/layers/moe/ep_moe/layer.py +0 -18
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  57. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  59. sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
  60. sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
  61. sglang/srt/layers/moe/topk.py +31 -6
  62. sglang/srt/layers/pooler.py +21 -2
  63. sglang/srt/layers/quantization/__init__.py +9 -78
  64. sglang/srt/layers/quantization/auto_round.py +394 -0
  65. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  66. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  67. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  68. sglang/srt/layers/rotary_embedding.py +117 -45
  69. sglang/srt/lora/lora_registry.py +9 -0
  70. sglang/srt/managers/async_mm_data_processor.py +122 -0
  71. sglang/srt/managers/data_parallel_controller.py +30 -3
  72. sglang/srt/managers/detokenizer_manager.py +3 -0
  73. sglang/srt/managers/io_struct.py +26 -4
  74. sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
  75. sglang/srt/managers/schedule_batch.py +74 -15
  76. sglang/srt/managers/scheduler.py +164 -129
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  78. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  79. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  80. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  81. sglang/srt/managers/session_controller.py +6 -5
  82. sglang/srt/managers/tokenizer_manager.py +154 -59
  83. sglang/srt/managers/tp_worker.py +24 -1
  84. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  85. sglang/srt/mem_cache/common.py +1 -0
  86. sglang/srt/mem_cache/memory_pool.py +171 -57
  87. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  88. sglang/srt/mem_cache/radix_cache.py +4 -0
  89. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  90. sglang/srt/metrics/collector.py +46 -3
  91. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  92. sglang/srt/model_executor/forward_batch_info.py +11 -11
  93. sglang/srt/model_executor/model_runner.py +76 -21
  94. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  95. sglang/srt/model_loader/weight_utils.py +1 -1
  96. sglang/srt/models/bailing_moe.py +9 -2
  97. sglang/srt/models/deepseek_nextn.py +11 -2
  98. sglang/srt/models/deepseek_v2.py +149 -34
  99. sglang/srt/models/glm4.py +391 -77
  100. sglang/srt/models/glm4v.py +196 -55
  101. sglang/srt/models/glm4v_moe.py +0 -1
  102. sglang/srt/models/gpt_oss.py +1 -10
  103. sglang/srt/models/kimi_linear.py +678 -0
  104. sglang/srt/models/llama4.py +1 -1
  105. sglang/srt/models/llama_eagle3.py +11 -1
  106. sglang/srt/models/longcat_flash.py +2 -2
  107. sglang/srt/models/minimax_m2.py +1 -1
  108. sglang/srt/models/qwen2.py +1 -1
  109. sglang/srt/models/qwen2_moe.py +30 -15
  110. sglang/srt/models/qwen3.py +1 -1
  111. sglang/srt/models/qwen3_moe.py +16 -8
  112. sglang/srt/models/qwen3_next.py +7 -0
  113. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  114. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  115. sglang/srt/multiplex/pdmux_context.py +164 -0
  116. sglang/srt/parser/conversation.py +7 -1
  117. sglang/srt/sampling/custom_logit_processor.py +67 -1
  118. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  119. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  120. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  121. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  122. sglang/srt/server_args.py +103 -22
  123. sglang/srt/single_batch_overlap.py +4 -1
  124. sglang/srt/speculative/draft_utils.py +16 -0
  125. sglang/srt/speculative/eagle_info.py +42 -36
  126. sglang/srt/speculative/eagle_info_v2.py +68 -25
  127. sglang/srt/speculative/eagle_utils.py +261 -16
  128. sglang/srt/speculative/eagle_worker.py +11 -3
  129. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  130. sglang/srt/speculative/spec_info.py +305 -31
  131. sglang/srt/speculative/spec_utils.py +44 -8
  132. sglang/srt/tracing/trace.py +121 -12
  133. sglang/srt/utils/common.py +55 -32
  134. sglang/srt/utils/hf_transformers_utils.py +38 -16
  135. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  136. sglang/test/kits/radix_cache_server_kit.py +50 -0
  137. sglang/test/runners.py +31 -7
  138. sglang/test/simple_eval_common.py +5 -3
  139. sglang/test/simple_eval_humaneval.py +1 -0
  140. sglang/test/simple_eval_math.py +1 -0
  141. sglang/test/simple_eval_mmlu.py +1 -0
  142. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  143. sglang/test/test_utils.py +7 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
  146. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
  147. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  148. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  149. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
The hunks below are from sglang/srt/speculative/eagle_info.py (entry 125 in the list above).

@@ -24,12 +24,13 @@ from sglang.srt.speculative.eagle_info_v2 import (
     EagleDraftInputV2Mixin,
     EagleVerifyInputV2Mixin,
 )
+from sglang.srt.speculative.eagle_utils import verify_tree_greedy_func
 from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
 from sglang.srt.speculative.spec_utils import (
     SIMULATE_ACC_LEN,
     TREE_SPEC_KERNEL_AVAILABLE,
     align_evict_mask_to_page_size,
-    assign_req_to_token_pool,
+    assign_req_to_token_pool_func,
     create_accept_length_filter,
     create_extend_after_decode_spec_info,
     filter_finished_cache_loc_kernel,
@@ -37,17 +38,16 @@ from sglang.srt.speculative.spec_utils import (
     get_src_tgt_cache_loc,
     get_target_cache_loc,
 )
-from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
+from sglang.srt.utils import is_cuda, is_npu, next_power_of_2
+
+_is_npu = is_npu()
 
 if is_cuda():
     from sgl_kernel import (
         top_k_renorm_prob,
         top_p_renorm_prob,
         tree_speculative_sampling_target_only,
-        verify_tree_greedy,
     )
-elif is_hip():
-    from sgl_kernel import verify_tree_greedy
 
 logger = logging.getLogger(__name__)
 
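The two import hunks above carry the release's core refactor for this module: the fused `verify_tree_greedy` kernel is no longer imported at module scope under CUDA/HIP guards; call sites now go through `verify_tree_greedy_func`, which probes the platform once (`_is_npu = is_npu()`) and dispatches to either the fused kernel or a pure-PyTorch fallback. A minimal, self-contained sketch of that pattern (illustrative names only, not sglang APIs):

```python
import torch

# Probe the platform once at import time, mirroring `_is_npu = is_npu()` above.
_HAS_CUDA = torch.cuda.is_available()

def _softmax_native(x: torch.Tensor) -> torch.Tensor:
    # Portable fallback built from plain tensor ops; runs on any backend.
    e = torch.exp(x - x.max(dim=-1, keepdim=True).values)
    return e / e.sum(dim=-1, keepdim=True)

def softmax_func(x: torch.Tensor) -> torch.Tensor:
    # The *_func wrapper hides the backend choice from call sites.
    if _HAS_CUDA:
        return torch.softmax(x, dim=-1)  # stand-in for a fused sgl_kernel call
    return _softmax_native(x)

print(softmax_func(torch.tensor([[1.0, 2.0, 3.0]])))
```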
@@ -77,18 +77,22 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
 
     @classmethod
     def create_idle_input(cls, topk: int, spec_steps: int, num_verify_tokens: int):
+        if not _is_npu:
+            device = "cuda"
+        else:
+            device = "npu"
         return cls(
-            draft_token=torch.empty((0,), dtype=torch.long, device="cuda"),
-            custom_mask=torch.full((0,), True, dtype=torch.bool, device="cuda"),
-            positions=torch.empty((0,), dtype=torch.int64, device="cuda"),
+            draft_token=torch.empty((0,), dtype=torch.long, device=device),
+            custom_mask=torch.full((0,), True, dtype=torch.bool, device=device),
+            positions=torch.empty((0,), dtype=torch.int64, device=device),
             retrive_index=torch.full(
-                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
+                (0, num_verify_tokens), -1, dtype=torch.long, device=device
             ),
             retrive_next_token=torch.full(
-                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
+                (0, num_verify_tokens), -1, dtype=torch.long, device=device
             ),
             retrive_next_sibling=torch.full(
-                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
+                (0, num_verify_tokens), -1, dtype=torch.long, device=device
             ),
             retrive_cum_len=None,
             topk=topk,
@@ -134,14 +138,13 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
         self.last_loc = last_loc
 
         bs = batch.batch_size()
-        assign_req_to_token_pool[(bs,)](
+        assign_req_to_token_pool_func(
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
             batch.seq_lens,
             end_offset,
             batch.out_cache_loc,
-            batch.req_to_token_pool.req_to_token.shape[1],
-            next_power_of_2(bs),
+            bs,
         )
 
     def generate_attn_arg_prefill(
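In the hunk above, `assign_req_to_token_pool[(bs,)](…)` is Triton's `kernel[grid](…)` launch syntax, and the dropped trailing arguments (`req_to_token.shape[1]`, `next_power_of_2(bs)`) were launch-only details of that kernel. Moving the launch behind `assign_req_to_token_pool_func(…)` keeps those details out of call sites, so the same call also works when a non-Triton backend is selected. A sketch of the wrapper idea with a hypothetical kernel (requires a GPU to actually run):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_one_kernel(x_ptr, n, BLOCK: tl.constexpr):
    # One program instance handles one BLOCK-sized slice of x.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(x_ptr + offs, x + 1, mask=mask)

def add_one_func(x: torch.Tensor) -> torch.Tensor:
    # The wrapper owns the grid and the constexpr block size, so callers
    # never pass launch details (the role next_power_of_2(bs) played above).
    n = x.numel()
    BLOCK = 1024
    add_one_kernel[(triton.cdiv(n, BLOCK),)](x, n, BLOCK=BLOCK)
    return x

# add_one_func(torch.zeros(4096, device="cuda"))  # needs a CUDA device
```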
@@ -151,16 +154,17 @@
         paged_kernel_lens_sum: int,
         req_to_token: torch.Tensor,
     ):
+        device = req_pool_indices.device
         batch_size = len(req_pool_indices)
         qo_indptr = torch.arange(
             0,
             (1 + batch_size) * self.draft_token_num,
             step=self.draft_token_num,
             dtype=torch.int32,
-            device="cuda",
+            device=device,
         )
         cum_kv_seq_len = torch.zeros(
-            (batch_size + 1,), dtype=torch.int32, device="cuda"
+            (batch_size + 1,), dtype=torch.int32, device=device
         )
 
         paged_kernel_lens = paged_kernel_lens + self.draft_token_num
@@ -169,7 +173,7 @@
         kv_indices = torch.empty(
             paged_kernel_lens_sum + self.draft_token_num * batch_size,
             dtype=torch.int32,
-            device="cuda",
+            device=device,
         )
         create_flashinfer_kv_indices_triton[(batch_size,)](
             req_to_token,
@@ -226,11 +230,11 @@
 
         predict_shape = list(logits_output.next_token_logits.shape)[:-1]
         predict_shape[-1] += 1
-        predict = torch.empty(predict_shape, dtype=torch.int32, device="cuda")
+        predict = torch.empty(predict_shape, dtype=torch.int32, device=batch.device)
         accept_index = torch.full(
-            (bs, self.spec_steps + 1), -1, dtype=torch.int32, device="cuda"
+            (bs, self.spec_steps + 1), -1, dtype=torch.int32, device=batch.device
         )
-        accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
+        accept_length = torch.empty((bs,), dtype=torch.int32, device=batch.device)
 
         if bs != len(sampling_info):
             sampling_info = copy.deepcopy(sampling_info)
@@ -254,7 +258,7 @@
             linear_penalty = torch.zeros(
                 (bs, logits_output.next_token_logits.shape[1]),
                 dtype=torch.float32,
-                device="cuda",
+                device=batch.device,
             )
             sampling_info.apply_logits_bias(linear_penalty)
             logits_output.next_token_logits.add_(
@@ -276,11 +280,10 @@
                 "Falling back to greedy verification."
             )
 
-        if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE:
+        if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE or _is_npu:
             target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
             target_predict = target_predict.reshape(bs, self.draft_token_num)
-
-            verify_tree_greedy(
+            predict, accept_index, accept_length = verify_tree_greedy_func(
                 predicts=predict,  # mutable
                 accept_index=accept_index,  # mutable
                 accept_token_num=accept_length,  # mutable
@@ -289,7 +292,9 @@
                 retrive_next_token=self.retrive_next_token,
                 retrive_next_sibling=self.retrive_next_sibling,
                 target_predict=target_predict,
+                topk=self.topk,
             )
+
         else:
             # apply temperature and get target probs
             expanded_temperature = torch.repeat_interleave(
@@ -315,14 +320,16 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
             target_probs = target_probs.reshape(bs, self.draft_token_num, -1)
 
             draft_probs = torch.zeros(
-                target_probs.shape, dtype=torch.float32, device="cuda"
+                target_probs.shape, dtype=torch.float32, device=batch.device
             )
 
             # coins for rejection sampling
-            coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda")
+            coins = torch.rand_like(
+                candidates, dtype=torch.float32, device=batch.device
+            )
             # coins for final sampling
             coins_for_final_sampling = torch.rand(
-                (bs,), dtype=torch.float32, device="cuda"
+                (bs,), dtype=torch.float32, device=batch.device
             )
             tree_speculative_sampling_target_only(
                 predicts=predict,  # mutable
@@ -468,14 +475,13 @@
         if not has_finished:
             if page_size == 1 or self.topk == 1:
                 batch.out_cache_loc = batch.out_cache_loc[accept_index]
-                assign_req_to_token_pool[(bs,)](
+                assign_req_to_token_pool_func(
                     batch.req_pool_indices,
                     batch.req_to_token_pool.req_to_token,
                     batch.seq_lens,
                     batch.seq_lens + accept_length + 1,
                     batch.out_cache_loc,
-                    batch.req_to_token_pool.req_to_token.shape[1],
-                    next_power_of_2(bs),
+                    bs,
                 )
             else:
                 batch.out_cache_loc = tgt_cache_loc
@@ -501,14 +507,13 @@
                 )
         else:
             if page_size == 1 or self.topk == 1:
-                assign_req_to_token_pool[(bs,)](
+                assign_req_to_token_pool_func(
                     batch.req_pool_indices,
                     batch.req_to_token_pool.req_to_token,
                     batch.seq_lens,
                     batch.seq_lens + accept_length + 1,
                     batch.out_cache_loc[accept_index],
-                    batch.req_to_token_pool.req_to_token.shape[1],
-                    next_power_of_2(bs),
+                    bs,
                 )
                 batch.seq_lens.add_(accept_length + 1)
                 batch.seq_lens_cpu.add_(accept_length_cpu + 1)
@@ -695,17 +700,18 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
         paged_kernel_lens_sum: int,
         req_to_token: torch.Tensor,
     ):
+        device = req_pool_indices.device
         bs = self.accept_length.numel()
-        qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
+        qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=device)
         qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
-        cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
+        cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device=device)
         cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
 
         if paged_kernel_lens_sum is None:
             paged_kernel_lens_sum = cum_kv_seq_len[-1]
 
         kv_indices = torch.empty(
-            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+            paged_kernel_lens_sum, dtype=torch.int32, device=device
        )
 
         create_flashinfer_kv_indices_triton[(bs,)](
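The remaining hunks in this file all make one change: tensors that were allocated on a hardcoded `device="cuda"` now derive their device from an input tensor (`req_pool_indices.device`) or from `batch.device`. A small sketch of the pattern, modeled on the `qo_indptr` code above (illustrative function, runnable on CPU):

```python
import torch

def make_qo_indptr(accept_length: torch.Tensor) -> torch.Tensor:
    # Derive the device from an input tensor instead of hardcoding "cuda",
    # so the same code runs unchanged on cuda, npu, or cpu tensors.
    device = accept_length.device
    bs = accept_length.numel()
    qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=device)
    qo_indptr[1:] = torch.cumsum(accept_length, dim=0)
    return qo_indptr

print(make_qo_indptr(torch.tensor([2, 3, 1])))  # tensor([0, 2, 5, 6], dtype=torch.int32)
```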
The hunks below are from sglang/srt/speculative/eagle_info_v2.py (entry 126 in the list above).

@@ -23,11 +23,16 @@ from sglang.srt.model_executor.forward_batch_info import (
 )
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import get_global_server_args
+from sglang.srt.speculative.eagle_utils import verify_tree_greedy_func
 from sglang.srt.speculative.spec_utils import (
     SIMULATE_ACC_LEN,
     generate_simulated_accept_index,
 )
-from sglang.srt.utils.common import fast_topk, is_cuda, is_hip, next_power_of_2
+from sglang.srt.utils.common import fast_topk, is_cuda, is_hip, is_npu, next_power_of_2
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+_is_npu = is_npu()
 
 if TYPE_CHECKING:
     from sglang.srt.managers.tp_worker import TpModelWorker
@@ -41,11 +46,8 @@ if is_cuda():
         top_k_renorm_prob,
         top_p_renorm_prob,
         tree_speculative_sampling_target_only,
-        verify_tree_greedy,
     )
     from sgl_kernel.top_k import fast_topk
-elif is_hip():
-    from sgl_kernel import verify_tree_greedy
 
 
 @triton.jit
@@ -78,7 +80,7 @@ def assign_draft_cache_locs_page_size_1(
 @dataclass
 class EagleDraftInputV2Mixin:
     def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch):
-        from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
+        from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func
 
         bs = batch.batch_size()
 
@@ -112,15 +114,15 @@
             extend_num_tokens,
         )
 
-        assign_req_to_token_pool[(bs,)](
+        assign_req_to_token_pool_func(
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
             self.allocate_lens,
             new_allocate_lens,
             out_cache_loc,
-            batch.req_to_token_pool.req_to_token.shape[1],
-            next_power_of_2(bs),
+            bs,
         )
+
         self.allocate_lens = new_allocate_lens
 
         # FIXME(lsyin): make this sync optional
@@ -199,22 +201,16 @@ class EagleVerifyInputV2Mixin:
         bs = len(batch.req_pool_indices)
         batch.input_ids = self.draft_token
         device = batch.input_ids.device
-        batch.out_cache_loc = torch.empty(
-            (bs * self.draft_token_num,),
-            dtype=torch.int64,
+        batch.out_cache_loc = assign_extend_cache_locs_func(
+            req_pool_indices=batch.req_pool_indices,
+            req_to_token=req_to_token_pool.req_to_token,
+            start_offset=batch.seq_lens,
+            end_offset=batch.seq_lens + self.draft_token_num,
+            batch_size=bs,
+            draft_token_num=self.draft_token_num,
             device=device,
         )
 
-        assign_extend_cache_locs[(bs,)](
-            batch.req_pool_indices,
-            req_to_token_pool.req_to_token,
-            batch.seq_lens,
-            batch.seq_lens + self.draft_token_num,
-            batch.out_cache_loc,
-            req_to_token_pool.req_to_token.shape[1],
-            next_power_of_2(bs),
-        )
-
         # Get a forward batch
         batch.forward_mode = ForwardMode.TARGET_VERIFY
         batch.capture_hidden_mode = CaptureHiddenMode.FULL
@@ -258,11 +254,10 @@
         accept_length = torch.empty((bs,), dtype=torch.int32, device=device)
 
         # Sample tokens
-        if sampling_info.is_all_greedy:
+        if sampling_info.is_all_greedy or _is_npu:
             target_predict = torch.argmax(next_token_logits, dim=-1)
             target_predict = target_predict.reshape(bs, self.draft_token_num)
-
-            verify_tree_greedy(
+            predict, accept_index, accept_length = verify_tree_greedy_func(
                 predicts=predict,  # mutable
                 accept_index=accept_index,  # mutable
                 accept_token_num=accept_length,  # mutable
@@ -271,6 +266,7 @@
                 retrive_next_token=self.retrive_next_token,
                 retrive_next_sibling=self.retrive_next_sibling,
                 target_predict=target_predict,
+                topk=self.topk,
             )
         else:
             # Apply temperature and get target probs
@@ -338,7 +334,7 @@
     return predict, accept_length, accept_index
 
 
-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, disable=_is_npu)
 def select_top_k_tokens_tmp(
     i: int,
     topk_p: torch.Tensor,
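The last hunk relies on `torch.compile`'s `disable` flag: with `disable=True`, `torch.compile` becomes a no-op and returns the function unchanged, so the decorator can stay in place on backends (here, NPU) where compilation is not wanted. A self-contained illustration:

```python
import torch

_is_npu = False  # platform probe, as in the hunk above

@torch.compile(dynamic=True, disable=_is_npu)
def fused_scale(x: torch.Tensor) -> torch.Tensor:
    # Compiled when _is_npu is False; runs as plain eager Python otherwise.
    return x * 2.0 + 1.0

print(fused_scale(torch.ones(3)))  # tensor([3., 3., 3.])
```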
@@ -456,3 +452,50 @@ def assign_extend_cache_locs(
         tl.store(out_cache_ptr + save_offset, data, mask=mask)
         load_offset += BLOCK_SIZE
         save_offset += BLOCK_SIZE
+
+
+def assign_extend_cache_locs_func(
+    req_pool_indices: torch.Tensor,
+    req_to_token: torch.Tensor,
+    start_offset: torch.Tensor,
+    end_offset: torch.Tensor,
+    batch_size: int,
+    draft_token_num: int,
+    device,
+) -> torch.Tensor:
+    if _is_cuda or _is_hip:
+        out_cache_loc = torch.empty(
+            (batch_size * draft_token_num,),
+            dtype=torch.int64,
+            device=device,
+        )
+        assign_extend_cache_locs[(batch_size,)](
+            req_pool_indices,
+            req_to_token,
+            start_offset,
+            end_offset,
+            out_cache_loc,
+            req_to_token.shape[1],
+            next_power_of_2(batch_size),
+        )
+
+        return out_cache_loc
+
+    elif _is_npu:
+        import sgl_kernel_npu  # noqa: F401
+
+        out_cache_loc = torch.empty(
+            (batch_size * draft_token_num,),
+            dtype=torch.int32,
+            device=device,
+        )
+        torch.ops.npu.cache_loc_update(
+            req_pool_indices,
+            req_to_token,
+            start_offset,
+            end_offset,
+            out_cache_loc,
+        )
+        out_cache_loc = out_cache_loc.to(dtype=torch.int64)
+
+        return out_cache_loc
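A detail worth noting in the NPU branch above: `torch.ops.npu.cache_loc_update` fills an int32 buffer, which the wrapper then casts to int64 so both branches hand the same dtype back to callers. A trivial CPU sketch of that contract (the fill is a stand-in for the vendor op):

```python
import torch

def cache_locs_stub(batch_size: int, draft_token_num: int) -> torch.Tensor:
    # The vendor op writes int32 locations into a preallocated buffer...
    out = torch.zeros((batch_size * draft_token_num,), dtype=torch.int32)
    # ...and the wrapper normalizes to int64, matching the Triton branch,
    # so downstream indexing never sees a backend-dependent dtype.
    return out.to(dtype=torch.int64)

assert cache_locs_stub(2, 3).dtype == torch.int64
```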
The hunks below are from sglang/srt/speculative/eagle_utils.py (entry 127 in the list above).

@@ -4,14 +4,128 @@ from typing import List, Optional
 
 import torch
 
-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import is_cuda, is_hip, is_npu
 
-if is_cuda() or is_hip():
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+_is_npu = is_npu()
+
+if _is_cuda or _is_hip:
     from sgl_kernel import (
         build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
     )
 
 
+def build_tree_efficient_native(
+    parent_list: torch.Tensor,
+    selected_index: torch.Tensor,
+    verified_seq_len: torch.Tensor,
+    tree_mask: torch.Tensor,
+    retrive_index: torch.Tensor,
+    retrive_next_token: torch.Tensor,
+    retrive_next_sibling: torch.Tensor,
+    topk: int,
+    draft_token_num: int,
+    tree_mask_mode: int,
+    bs: int,
+):
+    # Generate batch and token index ranges
+    bs_range = torch.arange(bs, device=tree_mask.device).view(-1, 1)
+    draft_token_num_range = torch.arange(draft_token_num, device=tree_mask.device)
+
+    # Optimized common case for performance.
+    if draft_token_num == 2 and topk == 1 and tree_mask_mode == TreeMaskMode.FULL_MASK:
+        positions = verified_seq_len.repeat_interleave(draft_token_num)
+        positions = (positions.view(bs, -1) + draft_token_num_range).view(-1)
+
+        retrive_index[:] = bs_range * draft_token_num + draft_token_num_range
+        retrive_next_token[:, 0] = 1
+        retrive_next_token[:, 1] = -1
+        return (
+            positions,
+            retrive_index,
+            retrive_next_token,
+            retrive_next_sibling,
+            tree_mask,
+        )
+
+    # Precompute sequence tree indices
+    draft_token_num_range1 = torch.arange(draft_token_num - 1, device=tree_mask.device)
+    cum_seq_len = torch.cumsum(verified_seq_len * draft_token_num, dim=0)
+    cum_seq_len = torch.cat((torch.tensor([0], device=tree_mask.device), cum_seq_len))
+    cum_seq_len = cum_seq_len[:-1]
+    seq_tree_idx = (
+        draft_token_num * draft_token_num * torch.arange(bs, device=tree_mask.device)
+        + cum_seq_len
+    )
+
+    # Batch processing for tree mask
+    if tree_mask_mode == TreeMaskMode.FULL_MASK:
+        token_tree_base = (
+            seq_tree_idx.view(-1, 1)
+            + (verified_seq_len.view(-1, 1) + draft_token_num) * draft_token_num_range
+        )
+        token_tree_indices = token_tree_base + verified_seq_len.view(-1, 1) + 1
+    else:
+        token_tree_indices = (
+            bs_range * draft_token_num**2 + draft_token_num_range * draft_token_num + 1
+        )
+
+    tree_mask[token_tree_indices.flatten() - 1] = True
+    indices = token_tree_indices.unsqueeze(-1) + draft_token_num_range1.view(1, 1, -1)
+    tree_mask[indices.view(-1)] = False
+
+    positions = verified_seq_len.repeat_interleave(draft_token_num)
+    parent_tb_indices = selected_index // topk
+    retrive_index[:] = bs_range * draft_token_num + draft_token_num_range
+    tree_mask[token_tree_indices.view(-1, 1) + draft_token_num_range1] = True
+
+    for bid in range(bs):
+        for tid in range(draft_token_num):
+            position = 0
+            if tid == 0:
+                # Process root node
+                for i in range(draft_token_num - 1, 0, -1):
+                    parent_position = 0
+                    parent_tb_idx = parent_tb_indices[bid][i - 1]
+                    if parent_tb_idx > 0:
+                        parent_token_idx = parent_list[bid][parent_tb_idx]
+                        loop_num = draft_token_num - parent_position
+                        for _ in range(loop_num):
+                            if selected_index[bid][parent_position] == parent_token_idx:
+                                parent_position += 1
+                                break
+                            parent_position += 1
+                    if parent_position == draft_token_num:
+                        continue
+
+                    if retrive_next_token[bid][parent_position] != -1:
+                        retrive_next_sibling[bid][i] = retrive_next_token[bid][
+                            parent_position
+                        ]
+                    retrive_next_token[bid][parent_position] = i
+            else:
+                # Process no-root nodes
+                cur_position = tid - 1
+                while True:
+                    position += 1
+                    if cur_position >= draft_token_num:
+                        tree_mask[token_tree_indices + cur_position] = True
+                        parent_tb_idx = selected_index[bid][cur_position] // topk
+                    else:
+                        parent_tb_idx = parent_tb_indices[bid][cur_position]
+                    if parent_tb_idx == 0:
+                        break
+                    token_idx = parent_list[bid][parent_tb_idx]
+                    cur_position = 0
+                    for _ in range(draft_token_num):
+                        if selected_index[bid][cur_position] == token_idx:
+                            break
+                        cur_position += 1
+            positions[bid * draft_token_num + tid] += position
+    return positions, retrive_index, retrive_next_token, retrive_next_sibling, tree_mask
+
+
 def organize_draft_results(
     score_list: List[torch.Tensor],
     token_list: List[torch.Tensor],
@@ -114,20 +228,41 @@ def build_tree_kernel_efficient(
         (bs * num_verify_tokens,), device=device, dtype=torch.long
     )
 
-    sgl_build_tree_kernel_efficient(
-        parent_list,
-        top_scores_index,
-        seq_lens,
-        tree_mask,
-        positions,
-        retrive_index,
-        retrive_next_token,
-        retrive_next_sibling,
-        topk,
-        spec_steps,
-        num_verify_tokens,
-        tree_mask_mode,
-    )
+    if _is_npu:
+        (
+            positions,
+            retrive_index,
+            retrive_next_token,
+            retrive_next_sibling,
+            tree_mask,
+        ) = build_tree_efficient_native(
+            parent_list,
+            top_scores_index,
+            seq_lens,
+            tree_mask,
+            retrive_index,
+            retrive_next_token,
+            retrive_next_sibling,
+            topk,
+            num_verify_tokens,
+            tree_mask_mode,
+            bs,
+        )
+    else:
+        sgl_build_tree_kernel_efficient(
+            parent_list,
+            top_scores_index,
+            seq_lens,
+            tree_mask,
+            positions,
+            retrive_index,
+            retrive_next_token,
+            retrive_next_sibling,
+            topk,
+            spec_steps,
+            num_verify_tokens,
+            tree_mask_mode,
+        )
     return (
         tree_mask,
        positions,
@@ -136,3 +271,113 @@
         retrive_next_sibling,
         draft_tokens,
     )
+
+
+def verify_tree_greedy_native(
+    predicts: torch.Tensor,
+    accept_index: torch.Tensor,
+    accept_token_num: torch.Tensor,
+    candidates: torch.Tensor,
+    retrive_index: torch.Tensor,
+    retrive_next_token: torch.Tensor,
+    retrive_next_sibling: torch.Tensor,
+    target_predict: torch.Tensor,
+    topk: int = -1,
+):
+    batch_size, num_draft_tokens = candidates.shape
+
+    # Optimized common case for performance.
+    if num_draft_tokens == 2 and accept_index.shape[1] == 2 and topk == 1:
+        comparison_result = candidates[:, 1] == target_predict[:, 0]
+
+        predicts = target_predict.flatten()
+
+        accept_index = torch.arange(
+            0, num_draft_tokens * batch_size, device=candidates.device, dtype=torch.long
+        ).reshape(batch_size, num_draft_tokens)
+        comparison_result = comparison_result.to(torch.int64)
+        accept_index_mask = accept_index[:, 1] * comparison_result
+        accept_index[:, 1] = accept_index_mask - (1 - comparison_result)
+
+        accept_token_num = comparison_result.int()
+        return predicts, accept_index, accept_token_num
+
+    # BFS
+    for bx in range(batch_size):
+        cur_candidates = candidates[bx]
+        cur_retrive_index = retrive_index[bx]
+        cur_next_token = retrive_next_token[bx]
+        cur_next_sibling = retrive_next_sibling[bx]
+        cur_target = target_predict[bx]
+
+        last_accepted_idx = cur_retrive_index[0]
+        accept_index[bx, 0] = last_accepted_idx
+        num_accepted = 0
+        cur_node = 0
+
+        for _ in range(1, num_draft_tokens):
+            cur_node = cur_next_token[cur_node]
+            found = False
+            while cur_node != -1:
+                draft_idx = cur_retrive_index[cur_node]
+                draft_token = cur_candidates[cur_node]
+                target_token = cur_target[last_accepted_idx - num_draft_tokens * bx]
+
+                if draft_token == target_token:
+                    predicts[last_accepted_idx] = target_token
+                    num_accepted += 1
+                    accept_index[bx, num_accepted] = draft_idx
+                    last_accepted_idx = draft_idx
+                    found = True
+                    break
+                else:
+                    cur_node = cur_next_sibling[cur_node]
+            if not found:
+                break
+
+        accept_token_num[bx] = num_accepted
+        predicts[last_accepted_idx] = cur_target[
+            last_accepted_idx - num_draft_tokens * bx
+        ]
+    return predicts, accept_index, accept_token_num
+
+
+def verify_tree_greedy_func(
+    predicts: torch.Tensor,
+    accept_index: torch.Tensor,
+    accept_token_num: torch.Tensor,
+    candidates: torch.Tensor,
+    retrive_index: torch.Tensor,
+    retrive_next_token: torch.Tensor,
+    retrive_next_sibling: torch.Tensor,
+    target_predict: torch.Tensor,
+    topk: int = -1,
+):
+    if _is_cuda or _is_hip:
+        from sgl_kernel import verify_tree_greedy
+
+        verify_tree_greedy(
+            predicts=predicts,  # mutable
+            accept_index=accept_index,  # mutable
+            accept_token_num=accept_token_num,  # mutable
+            candidates=candidates,
+            retrive_index=retrive_index,
+            retrive_next_token=retrive_next_token,
+            retrive_next_sibling=retrive_next_sibling,
+            target_predict=target_predict,
+        )
+
+    elif _is_npu:
+        predicts, accept_index, accept_token_num = verify_tree_greedy_native(
+            predicts=predicts,  # mutable
+            accept_index=accept_index,  # mutable
+            accept_token_num=accept_token_num,  # mutable
+            candidates=candidates,
+            retrive_index=retrive_index,
+            retrive_next_token=retrive_next_token,
+            retrive_next_sibling=retrive_next_sibling,
+            target_predict=target_predict,
+            topk=topk,
+        )
+
+    return predicts, accept_index, accept_token_num
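Note that `verify_tree_greedy_func` returns its outputs as well as mutating them: the native fast path rebinds `predicts`/`accept_index`/`accept_token_num` rather than writing in place, which is why callers in this release now assign the returned triple. To make the accept/reject walk concrete, here is a small CPU-only exercise of the pure-PyTorch verifier on a 3-node draft chain whose tokens all match the target's picks (illustrative values; assumes the function above is importable):

```python
import torch

from sglang.srt.speculative.eagle_utils import verify_tree_greedy_native

# One request; draft chain 0 -> 1 -> 2 with tokens [5, 7, 9].
predicts, accept_index, accept_token_num = verify_tree_greedy_native(
    predicts=torch.full((3,), -1, dtype=torch.int64),
    accept_index=torch.full((1, 3), -1, dtype=torch.int64),
    accept_token_num=torch.zeros(1, dtype=torch.int64),
    candidates=torch.tensor([[5, 7, 9]]),
    retrive_index=torch.tensor([[0, 1, 2]]),
    retrive_next_token=torch.tensor([[1, 2, -1]]),   # chain: 0 -> 1 -> 2
    retrive_next_sibling=torch.tensor([[-1, -1, -1]]),
    target_predict=torch.tensor([[7, 9, 4]]),        # target agrees at nodes 0 and 1
)
print(predicts.tolist())          # [7, 9, 4]  (last entry is the bonus token)
print(accept_index.tolist())      # [[0, 1, 2]]
print(accept_token_num.tolist())  # [2]
```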