sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_utils.py

@@ -4,14 +4,128 @@ from typing import List, Optional
 
 import torch
 
-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import is_cuda, is_hip, is_npu
 
-if is_cuda() or is_hip():
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+_is_npu = is_npu()
+
+if _is_cuda or _is_hip:
     from sgl_kernel import (
         build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
     )
 
 
+def build_tree_efficient_native(
+    parent_list: torch.Tensor,
+    selected_index: torch.Tensor,
+    verified_seq_len: torch.Tensor,
+    tree_mask: torch.Tensor,
+    retrive_index: torch.Tensor,
+    retrive_next_token: torch.Tensor,
+    retrive_next_sibling: torch.Tensor,
+    topk: int,
+    draft_token_num: int,
+    tree_mask_mode: int,
+    bs: int,
+):
+    # Generate batch and token index ranges
+    bs_range = torch.arange(bs, device=tree_mask.device).view(-1, 1)
+    draft_token_num_range = torch.arange(draft_token_num, device=tree_mask.device)
+
+    # Optimized common case for performance.
+    if draft_token_num == 2 and topk == 1 and tree_mask_mode == TreeMaskMode.FULL_MASK:
+        positions = verified_seq_len.repeat_interleave(draft_token_num)
+        positions = (positions.view(bs, -1) + draft_token_num_range).view(-1)
+
+        retrive_index[:] = bs_range * draft_token_num + draft_token_num_range
+        retrive_next_token[:, 0] = 1
+        retrive_next_token[:, 1] = -1
+        return (
+            positions,
+            retrive_index,
+            retrive_next_token,
+            retrive_next_sibling,
+            tree_mask,
+        )
+
+    # Precompute sequence tree indices
+    draft_token_num_range1 = torch.arange(draft_token_num - 1, device=tree_mask.device)
+    cum_seq_len = torch.cumsum(verified_seq_len * draft_token_num, dim=0)
+    cum_seq_len = torch.cat((torch.tensor([0], device=tree_mask.device), cum_seq_len))
+    cum_seq_len = cum_seq_len[:-1]
+    seq_tree_idx = (
+        draft_token_num * draft_token_num * torch.arange(bs, device=tree_mask.device)
+        + cum_seq_len
+    )
+
+    # Batch processing for tree mask
+    if tree_mask_mode == TreeMaskMode.FULL_MASK:
+        token_tree_base = (
+            seq_tree_idx.view(-1, 1)
+            + (verified_seq_len.view(-1, 1) + draft_token_num) * draft_token_num_range
+        )
+        token_tree_indices = token_tree_base + verified_seq_len.view(-1, 1) + 1
+    else:
+        token_tree_indices = (
+            bs_range * draft_token_num**2 + draft_token_num_range * draft_token_num + 1
+        )
+
+    tree_mask[token_tree_indices.flatten() - 1] = True
+    indices = token_tree_indices.unsqueeze(-1) + draft_token_num_range1.view(1, 1, -1)
+    tree_mask[indices.view(-1)] = False
+
+    positions = verified_seq_len.repeat_interleave(draft_token_num)
+    parent_tb_indices = selected_index // topk
+    retrive_index[:] = bs_range * draft_token_num + draft_token_num_range
+    tree_mask[token_tree_indices.view(-1, 1) + draft_token_num_range1] = True
+
+    for bid in range(bs):
+        for tid in range(draft_token_num):
+            position = 0
+            if tid == 0:
+                # Process root node
+                for i in range(draft_token_num - 1, 0, -1):
+                    parent_position = 0
+                    parent_tb_idx = parent_tb_indices[bid][i - 1]
+                    if parent_tb_idx > 0:
+                        parent_token_idx = parent_list[bid][parent_tb_idx]
+                        loop_num = draft_token_num - parent_position
+                        for _ in range(loop_num):
+                            if selected_index[bid][parent_position] == parent_token_idx:
+                                parent_position += 1
+                                break
+                            parent_position += 1
+                    if parent_position == draft_token_num:
+                        continue
+
+                    if retrive_next_token[bid][parent_position] != -1:
+                        retrive_next_sibling[bid][i] = retrive_next_token[bid][
+                            parent_position
+                        ]
+                    retrive_next_token[bid][parent_position] = i
+            else:
+                # Process non-root nodes
+                cur_position = tid - 1
+                while True:
+                    position += 1
+                    if cur_position >= draft_token_num:
+                        tree_mask[token_tree_indices + cur_position] = True
+                        parent_tb_idx = selected_index[bid][cur_position] // topk
+                    else:
+                        parent_tb_idx = parent_tb_indices[bid][cur_position]
+                    if parent_tb_idx == 0:
+                        break
+                    token_idx = parent_list[bid][parent_tb_idx]
+                    cur_position = 0
+                    for _ in range(draft_token_num):
+                        if selected_index[bid][cur_position] == token_idx:
+                            break
+                        cur_position += 1
+            positions[bid * draft_token_num + tid] += position
+    return positions, retrive_index, retrive_next_token, retrive_next_sibling, tree_mask
+
+
 def organize_draft_results(
     score_list: List[torch.Tensor],
     token_list: List[torch.Tensor],
@@ -114,20 +228,41 @@ def build_tree_kernel_efficient(
         (bs * num_verify_tokens,), device=device, dtype=torch.long
     )
 
-    sgl_build_tree_kernel_efficient(
-        parent_list,
-        top_scores_index,
-        seq_lens,
-        tree_mask,
-        positions,
-        retrive_index,
-        retrive_next_token,
-        retrive_next_sibling,
-        topk,
-        spec_steps,
-        num_verify_tokens,
-        tree_mask_mode,
-    )
+    if _is_npu:
+        (
+            positions,
+            retrive_index,
+            retrive_next_token,
+            retrive_next_sibling,
+            tree_mask,
+        ) = build_tree_efficient_native(
+            parent_list,
+            top_scores_index,
+            seq_lens,
+            tree_mask,
+            retrive_index,
+            retrive_next_token,
+            retrive_next_sibling,
+            topk,
+            num_verify_tokens,
+            tree_mask_mode,
+            bs,
+        )
+    else:
+        sgl_build_tree_kernel_efficient(
+            parent_list,
+            top_scores_index,
+            seq_lens,
+            tree_mask,
+            positions,
+            retrive_index,
+            retrive_next_token,
+            retrive_next_sibling,
+            topk,
+            spec_steps,
+            num_verify_tokens,
+            tree_mask_mode,
+        )
     return (
         tree_mask,
         positions,
@@ -136,3 +271,113 @@
         retrive_next_sibling,
         draft_tokens,
     )
+
+
+def verify_tree_greedy_native(
+    predicts: torch.Tensor,
+    accept_index: torch.Tensor,
+    accept_token_num: torch.Tensor,
+    candidates: torch.Tensor,
+    retrive_index: torch.Tensor,
+    retrive_next_token: torch.Tensor,
+    retrive_next_sibling: torch.Tensor,
+    target_predict: torch.Tensor,
+    topk: int = -1,
+):
+    batch_size, num_draft_tokens = candidates.shape
+
+    # Optimized common case for performance.
+    if num_draft_tokens == 2 and accept_index.shape[1] == 2 and topk == 1:
+        comparison_result = candidates[:, 1] == target_predict[:, 0]
+
+        predicts = target_predict.flatten()
+
+        accept_index = torch.arange(
+            0, num_draft_tokens * batch_size, device=candidates.device, dtype=torch.long
+        ).reshape(batch_size, num_draft_tokens)
+        comparison_result = comparison_result.to(torch.int64)
+        accept_index_mask = accept_index[:, 1] * comparison_result
+        accept_index[:, 1] = accept_index_mask - (1 - comparison_result)
+
+        accept_token_num = comparison_result.int()
+        return predicts, accept_index, accept_token_num
+
+    # BFS
+    for bx in range(batch_size):
+        cur_candidates = candidates[bx]
+        cur_retrive_index = retrive_index[bx]
+        cur_next_token = retrive_next_token[bx]
+        cur_next_sibling = retrive_next_sibling[bx]
+        cur_target = target_predict[bx]
+
+        last_accepted_idx = cur_retrive_index[0]
+        accept_index[bx, 0] = last_accepted_idx
+        num_accepted = 0
+        cur_node = 0
+
+        for _ in range(1, num_draft_tokens):
+            cur_node = cur_next_token[cur_node]
+            found = False
+            while cur_node != -1:
+                draft_idx = cur_retrive_index[cur_node]
+                draft_token = cur_candidates[cur_node]
+                target_token = cur_target[last_accepted_idx - num_draft_tokens * bx]
+
+                if draft_token == target_token:
+                    predicts[last_accepted_idx] = target_token
+                    num_accepted += 1
+                    accept_index[bx, num_accepted] = draft_idx
+                    last_accepted_idx = draft_idx
+                    found = True
+                    break
+                else:
+                    cur_node = cur_next_sibling[cur_node]
+            if not found:
+                break
+
+        accept_token_num[bx] = num_accepted
+        predicts[last_accepted_idx] = cur_target[
+            last_accepted_idx - num_draft_tokens * bx
+        ]
+    return predicts, accept_index, accept_token_num
+
+
+def verify_tree_greedy_func(
+    predicts: torch.Tensor,
+    accept_index: torch.Tensor,
+    accept_token_num: torch.Tensor,
+    candidates: torch.Tensor,
+    retrive_index: torch.Tensor,
+    retrive_next_token: torch.Tensor,
+    retrive_next_sibling: torch.Tensor,
+    target_predict: torch.Tensor,
+    topk: int = -1,
+):
+    if _is_cuda or _is_hip:
+        from sgl_kernel import verify_tree_greedy
+
+        verify_tree_greedy(
+            predicts=predicts,  # mutable
+            accept_index=accept_index,  # mutable
+            accept_token_num=accept_token_num,  # mutable
+            candidates=candidates,
+            retrive_index=retrive_index,
+            retrive_next_token=retrive_next_token,
+            retrive_next_sibling=retrive_next_sibling,
+            target_predict=target_predict,
+        )
+
+    elif _is_npu:
+        predicts, accept_index, accept_token_num = verify_tree_greedy_native(
+            predicts=predicts,  # mutable
+            accept_index=accept_index,  # mutable
+            accept_token_num=accept_token_num,  # mutable
+            candidates=candidates,
+            retrive_index=retrive_index,
+            retrive_next_token=retrive_next_token,
+            retrive_next_sibling=retrive_next_sibling,
+            target_predict=target_predict,
+            topk=topk,
+        )
+
+    return predicts, accept_index, accept_token_num
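A hedged usage sketch of the new native verifier (the import path is assumed from the file list above; the tensor values are illustrative): one request with a linear draft chain 0 → 1 → 2 whose two drafted tokens both match the target model, so two tokens are accepted and the bonus token is appended.

```python
import torch
from sglang.srt.speculative.eagle_utils import verify_tree_greedy_native

bs, n = 1, 3  # n=3 draft tokens, so the general BFS path runs
predicts = torch.full((bs * n,), -1, dtype=torch.long)
accept_index = torch.full((bs, n), -1, dtype=torch.long)
accept_token_num = torch.zeros(bs, dtype=torch.long)

predicts, accept_index, accept_token_num = verify_tree_greedy_native(
    predicts=predicts,
    accept_index=accept_index,
    accept_token_num=accept_token_num,
    candidates=torch.tensor([[10, 11, 12]]),        # root + 2 draft tokens
    retrive_index=torch.tensor([[0, 1, 2]]),        # flat index of each node
    retrive_next_token=torch.tensor([[1, 2, -1]]),  # first-child pointers
    retrive_next_sibling=torch.tensor([[-1, -1, -1]]),
    target_predict=torch.tensor([[11, 12, 13]]),    # target token after each node
)
# predicts == [11, 12, 13]; accept_index == [[0, 1, 2]]; accept_token_num == [2]
```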
sglang/srt/speculative/eagle_worker.py

@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple
 import torch
 
 from sglang.srt.distributed import get_tp_group
+from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
 from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -52,9 +53,12 @@ from sglang.srt.utils import (
     get_available_gpu_memory,
     get_bool_env_var,
     is_cuda,
+    is_npu,
     next_power_of_2,
 )
 
+_is_npu = is_npu()
+
 if is_cuda():
     from sgl_kernel import segment_packbits  # noqa: F401
 
@@ -117,7 +121,11 @@ class EAGLEWorker(TpModelWorker):
         self.hot_token_id = None
 
         # Init draft worker
-        with empty_context():
+        if server_args.enable_dp_attention and self.speculative_algorithm.is_eagle3():
+            ctx = draft_tp_context(get_attention_tp_group())
+        else:
+            ctx = empty_context()
+        with ctx:
             super().__init__(
                 server_args=server_args,
                 gpu_id=gpu_id,
@@ -200,7 +208,7 @@
         self.cuda_graph_runner = None
         self.cuda_graph_runner_for_draft_extend = None
 
-        if self.server_args.disable_cuda_graph:
+        if self.server_args.disable_cuda_graph or _is_npu:
             return
 
         # Capture draft
@@ -940,7 +948,7 @@
         draft_input.hidden_states = logits_output.hidden_states
 
 
-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, disable=_is_npu)
 def get_last_loc_large_page_size_top_k_1(
     req_to_token: torch.Tensor,
     req_pool_indices: torch.Tensor,
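The `disable=` flag added above is the standard `torch.compile` escape hatch: when it is true, the decorator returns the function unmodified, so NPU builds skip Dynamo entirely while CUDA/HIP builds still compile. A minimal sketch of the same pattern (the flag name here is illustrative, standing in for `_is_npu`):

```python
import torch

ON_UNSUPPORTED_BACKEND = False  # stand-in for a platform check like _is_npu

@torch.compile(dynamic=True, disable=ON_UNSUPPORTED_BACKEND)
def scale(x: torch.Tensor) -> torch.Tensor:
    # Compiled where supported; runs eagerly when disable=True.
    return x * 2.0

print(scale(torch.ones(4)))
```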
sglang/srt/speculative/eagle_worker_v2.py

@@ -4,7 +4,6 @@ import time
 from typing import List, Optional, Tuple
 
 import torch
-from torch.cuda import Stream as CudaStream
 
 from sglang.srt.environ import envs
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
@@ -38,18 +37,21 @@ from sglang.srt.utils.common import (
     empty_context,
     fast_topk,
     get_available_gpu_memory,
+    is_npu,
     next_power_of_2,
 )
 
+_is_npu = is_npu()
+
 logger = logging.getLogger(__name__)
 
 
 def _get_plan_stream(
     device: str,
-) -> Tuple[Optional[CudaStream], contextlib.AbstractContextManager]:
+) -> Tuple[any, contextlib.AbstractContextManager]:
     if envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get():
-        plan_stream: CudaStream = torch.get_device_module(device).Stream()
-        plan_stream_ctx = torch.cuda.stream(plan_stream)
+        plan_stream = torch.get_device_module(device).Stream()
+        plan_stream_ctx = torch.get_device_module(device).stream(plan_stream)
         return plan_stream, plan_stream_ctx
     else:
         return None, contextlib.nullcontext()
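`torch.get_device_module(device)` (public in recent PyTorch releases) returns the accelerator namespace for a device string, e.g. `torch.cuda` for "cuda" or `torch.npu` where that backend is registered, so the stream plumbing no longer hard-codes CUDA. A small editorial sketch of the pattern, with the runnable part gated on a CUDA build:

```python
import torch

def make_side_stream(device: str = "cuda"):
    mod = torch.get_device_module(device)  # e.g. torch.cuda for "cuda"
    stream = mod.Stream()                  # side stream for overlapped work
    ctx = mod.stream(stream)               # context manager that enters it
    return stream, ctx

if torch.cuda.is_available():
    stream, ctx = make_side_stream("cuda")
    with ctx:
        y = torch.ones(1024, device="cuda") * 3  # enqueued on the side stream
    torch.cuda.current_stream().wait_stream(stream)  # re-join the main stream
```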
@@ -206,7 +208,7 @@
         self.cuda_graph_runner = None
         self.cuda_graph_runner_for_draft_extend = None
 
-        if self.server_args.disable_cuda_graph:
+        if self.server_args.disable_cuda_graph or _is_npu:
             return
 
         # Capture draft
@@ -456,7 +458,9 @@
         )
 
         if self.plan_stream:
-            torch.cuda.current_stream().wait_stream(self.plan_stream)
+            torch.get_device_module(self.device).current_stream().wait_stream(
+                self.plan_stream
+            )
 
         # Run draft extend batch in the main compute stream
         draft_logits_output = self.draft_runner.model.forward(
@@ -577,7 +581,9 @@
         # Since batch.seq_lens is allocated in another stream, we need
         # record_stream() to prevent pytorch gc and reuse the gpu memory
         # while forward_stream is still running.
-        batch.seq_lens.record_stream(torch.cuda.current_stream())
+        batch.seq_lens.record_stream(
+            torch.get_device_module(self.device).current_stream()
+        )
 
         # Parse args
         verify_input: EagleVerifyInput = batch.spec_info
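For context on the comment above: `record_stream()` protects a tensor that is consumed on a stream other than the one it was allocated on. Without it, PyTorch's caching allocator may recycle the memory as soon as the Python reference dies, even while the other stream still has queued kernels reading it. An editorial illustration of the idiom (not taken from the diff):

```python
import torch

if torch.cuda.is_available():
    side = torch.cuda.Stream()
    x = torch.randn(1 << 20, device="cuda")  # allocated on the default stream
    with torch.cuda.stream(side):
        y = x * 2                            # the side stream reads x
    x.record_stream(side)                    # delay reuse of x's memory until
                                             # the side stream's work completes
    torch.cuda.current_stream().wait_stream(side)
```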
@@ -596,7 +602,7 @@
 
         # Correct some buffers due to the overlap plan
         if self.plan_stream:
-            torch.cuda.current_stream().wait_stream(self.plan_stream)
+            torch.get_device_module().current_stream().wait_stream(self.plan_stream)
 
         # Some values such as custom_mask and position depend on the output of draft,
         # so the previous plan step used the wrong values. Here, we need to run the related
@@ -628,7 +634,7 @@
             accept_index,
         ) = verify_input.sample(batch, logits_output)
         new_seq_lens = batch.seq_lens + accept_length
-        verify_done = torch.cuda.Event()
+        verify_done = torch.get_device_module(self.device).Event()
         verify_done.record()
 
         all_verified_id = predict[accept_index]
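The same module lookup covers events, which is what replaces `torch.cuda.Event()` above: `Event()` from the accelerator namespace marks a point on the current stream that other streams or the host can wait on. A brief sketch (editorial, gated so it only runs where a CUDA device exists):

```python
import torch

def record_done_event(device: str = "cuda"):
    done = torch.get_device_module(device).Event()
    done.record()       # enqueue the event on the current stream
    return done

if torch.cuda.is_available():
    done = record_done_event("cuda")
    done.synchronize()  # block the host until all prior work has finished
```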