sglang 0.4.8.post1__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
Files changed (141)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +48 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +34 -0
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/nixl/conn.py +6 -6
  10. sglang/srt/disaggregation/prefill.py +2 -2
  11. sglang/srt/disaggregation/utils.py +1 -1
  12. sglang/srt/distributed/parallel_state.py +44 -17
  13. sglang/srt/entrypoints/EngineBase.py +8 -0
  14. sglang/srt/entrypoints/engine.py +40 -6
  15. sglang/srt/entrypoints/http_server.py +111 -24
  16. sglang/srt/entrypoints/openai/protocol.py +4 -2
  17. sglang/srt/eplb/__init__.py +0 -0
  18. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  19. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  20. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  21. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  22. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  24. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  25. sglang/srt/hf_transformers_utils.py +2 -1
  26. sglang/srt/layers/activation.py +2 -2
  27. sglang/srt/layers/amx_utils.py +86 -0
  28. sglang/srt/layers/attention/ascend_backend.py +219 -0
  29. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  30. sglang/srt/layers/attention/tbo_backend.py +37 -9
  31. sglang/srt/layers/communicator.py +18 -2
  32. sglang/srt/layers/dp_attention.py +9 -3
  33. sglang/srt/layers/elementwise.py +76 -12
  34. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  35. sglang/srt/layers/layernorm.py +26 -0
  36. sglang/srt/layers/linear.py +84 -14
  37. sglang/srt/layers/logits_processor.py +4 -4
  38. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +36 -13
  40. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  41. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -2
  42. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -16
  43. sglang/srt/layers/moe/router.py +60 -22
  44. sglang/srt/layers/moe/topk.py +10 -28
  45. sglang/srt/layers/parameter.py +67 -7
  46. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  47. sglang/srt/layers/quantization/fp8.py +44 -0
  48. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  49. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  50. sglang/srt/layers/quantization/gptq.py +5 -1
  51. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  52. sglang/srt/layers/quantization/quant_utils.py +166 -0
  53. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  54. sglang/srt/layers/rotary_embedding.py +2 -2
  55. sglang/srt/layers/vocab_parallel_embedding.py +11 -7
  56. sglang/srt/lora/lora.py +4 -5
  57. sglang/srt/lora/lora_manager.py +73 -20
  58. sglang/srt/managers/configure_logging.py +1 -1
  59. sglang/srt/managers/io_struct.py +50 -13
  60. sglang/srt/managers/mm_utils.py +73 -59
  61. sglang/srt/managers/multimodal_processor.py +2 -6
  62. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  63. sglang/srt/managers/schedule_batch.py +77 -84
  64. sglang/srt/managers/scheduler.py +113 -59
  65. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  66. sglang/srt/managers/session_controller.py +12 -3
  67. sglang/srt/managers/tokenizer_manager.py +314 -103
  68. sglang/srt/managers/tp_worker.py +13 -1
  69. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  70. sglang/srt/mem_cache/allocator.py +290 -0
  71. sglang/srt/mem_cache/chunk_cache.py +34 -2
  72. sglang/srt/mem_cache/memory_pool.py +289 -3
  73. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  74. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  75. sglang/srt/model_executor/forward_batch_info.py +17 -4
  76. sglang/srt/model_executor/model_runner.py +297 -56
  77. sglang/srt/model_loader/loader.py +41 -0
  78. sglang/srt/model_loader/weight_utils.py +72 -4
  79. sglang/srt/models/deepseek_nextn.py +1 -3
  80. sglang/srt/models/deepseek_v2.py +181 -45
  81. sglang/srt/models/deepseek_vl2.py +3 -5
  82. sglang/srt/models/gemma3_causal.py +1 -2
  83. sglang/srt/models/gemma3n_causal.py +4 -3
  84. sglang/srt/models/gemma3n_mm.py +4 -20
  85. sglang/srt/models/hunyuan.py +1 -1
  86. sglang/srt/models/kimi_vl.py +1 -2
  87. sglang/srt/models/llama.py +10 -4
  88. sglang/srt/models/llama4.py +32 -45
  89. sglang/srt/models/llama_eagle3.py +61 -11
  90. sglang/srt/models/llava.py +5 -5
  91. sglang/srt/models/minicpmo.py +2 -2
  92. sglang/srt/models/mistral.py +1 -1
  93. sglang/srt/models/mllama4.py +43 -11
  94. sglang/srt/models/phi4mm.py +1 -3
  95. sglang/srt/models/pixtral.py +3 -7
  96. sglang/srt/models/qwen2.py +31 -3
  97. sglang/srt/models/qwen2_5_vl.py +1 -3
  98. sglang/srt/models/qwen2_audio.py +200 -0
  99. sglang/srt/models/qwen2_moe.py +32 -6
  100. sglang/srt/models/qwen2_vl.py +1 -4
  101. sglang/srt/models/qwen3.py +94 -25
  102. sglang/srt/models/qwen3_moe.py +68 -21
  103. sglang/srt/models/vila.py +3 -8
  104. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  105. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  106. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  107. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  108. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  109. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  110. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  111. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  112. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  117. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  120. sglang/srt/operations_strategy.py +6 -2
  121. sglang/srt/reasoning_parser.py +26 -0
  122. sglang/srt/sampling/sampling_batch_info.py +39 -1
  123. sglang/srt/server_args.py +69 -22
  124. sglang/srt/speculative/build_eagle_tree.py +57 -18
  125. sglang/srt/speculative/eagle_worker.py +6 -4
  126. sglang/srt/two_batch_overlap.py +200 -27
  127. sglang/srt/utils.py +306 -146
  128. sglang/srt/warmup.py +12 -3
  129. sglang/test/runners.py +10 -1
  130. sglang/test/test_utils.py +15 -3
  131. sglang/version.py +1 -1
  132. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  133. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/RECORD +140 -133
  134. sglang/math_utils.py +0 -8
  135. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  136. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  137. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  138. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  139. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  140. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  141. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -10,7 +10,6 @@ import torch
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_params import TOP_K_ALL
-from sglang.srt.utils import merge_bias_tensor
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -345,3 +344,42 @@ class SamplingBatchInfo:
         self.logit_bias = merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0
         )
+
+
+def merge_bias_tensor(
+    lhs: Optional[torch.Tensor],
+    rhs: Optional[torch.Tensor],
+    bs1: int,
+    bs2: int,
+    device: str,
+    default: float,
+):
+    """Merge two bias tensors for batch merging.
+
+    Args:
+        lhs: Left-hand side tensor
+        rhs: Right-hand side tensor
+        bs1: Batch size of left-hand side tensor
+        bs2: Batch size of right-hand side tensor
+        device: Device to place the merged tensor on
+        default: Default value for missing tensor elements
+
+    Returns:
+        Merged tensor or None if both inputs are None
+    """
+    if lhs is None and rhs is None:
+        return None
+
+    if lhs is not None and rhs is not None:
+        return torch.cat([lhs, rhs])
+    else:
+        if lhs is not None:
+            shape, dtype = lhs.shape[1:], lhs.dtype
+        else:
+            shape, dtype = rhs.shape[1:], rhs.dtype
+
+        if lhs is None:
+            lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
+        if rhs is None:
+            rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
+        return torch.cat([lhs, rhs])
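For context, this hunk relocates merge_bias_tensor from sglang.srt.utils into this module. Its contract: when two running batches merge, a side without a logit bias is padded with the default value so the rows stay aligned after concatenation. A minimal standalone sketch of that behavior, assuming only torch (tensors here are illustrative, not sglang code):

    import torch

    # A 2-request batch with a (2, vocab) bias merges with a 3-request
    # batch that has no bias; the missing side is filled with the default.
    lhs = torch.tensor([[1.0, -1.0], [0.5, 0.0]])
    rhs = None
    bs2, default = 3, 0.0

    if rhs is None:
        rhs = torch.empty((bs2, *lhs.shape[1:]), dtype=lhs.dtype).fill_(default)
    merged = torch.cat([lhs, rhs])
    assert merged.shape == (5, 2)  # len(batch1) + len(batch2) rows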
sglang/srt/server_args.py CHANGED
@@ -20,7 +20,7 @@ import logging
 import os
 import random
 import tempfile
-from typing import List, Literal, Optional
+from typing import List, Literal, Optional, Union
 
 from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
 from sglang.srt.reasoning_parser import ReasoningParser
@@ -46,6 +46,7 @@ class ServerArgs:
     tokenizer_path: Optional[str] = None
     tokenizer_mode: str = "auto"
     skip_tokenizer_init: bool = False
+    skip_server_warmup: bool = False
     load_format: str = "auto"
     model_loader_extra_config: str = "{}"
     trust_remote_code: bool = False
@@ -61,11 +62,13 @@ class ServerArgs:
     is_embedding: bool = False
     enable_multimodal: Optional[bool] = None
     revision: Optional[str] = None
+    hybrid_kvcache_ratio: Optional[float] = None
     impl: str = "auto"
 
     # Port for the HTTP server
     host: str = "127.0.0.1"
     port: int = 30000
+    nccl_port: Optional[int] = None
 
     # Memory and scheduling
     mem_fraction_static: Optional[float] = None
@@ -98,6 +101,7 @@ class ServerArgs:
     log_level_http: Optional[str] = None
     log_requests: bool = False
     log_requests_level: int = 0
+    crash_dump_folder: Optional[str] = None
     show_time_cost: bool = False
     enable_metrics: bool = False
     bucket_time_to_first_token: Optional[List[float]] = None
@@ -129,7 +133,7 @@ class ServerArgs:
     preferred_sampling_params: Optional[str] = None
 
     # LoRA
-    lora_paths: Optional[List[str]] = None
+    lora_paths: Optional[Union[dict[str, str], List[str]]] = None
     max_loras_per_batch: int = 8
     lora_backend: str = "triton"
 
@@ -154,6 +158,7 @@ class ServerArgs:
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
     enable_flashinfer_moe: bool = False
+    enable_flashinfer_allreduce_fusion: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
     ep_num_redundant_experts: int = 0
     ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
@@ -315,6 +320,14 @@ class ServerArgs:
         else:
             self.mem_fraction_static = 0.88
 
+        # Lazy init to avoid circular import
+        from sglang.srt.configs.model_config import ModelConfig
+
+        # Multimodal models need more memory for the image processor
+        model_config = ModelConfig.from_server_args(self)
+        if model_config.is_multimodal:
+            self.mem_fraction_static *= 0.90
+
         # Set chunked prefill size, which depends on the gpu memory capacity
         if self.chunked_prefill_size is None:
             if gpu_mem is not None:
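Worked through with the defaults just above: a server that would otherwise run with mem_fraction_static = 0.88 drops to roughly 0.88 × 0.90 ≈ 0.79 for a multimodal model, shrinking the statically reserved pool by 10% so the freed VRAM can absorb image-processor activations.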
@@ -376,6 +389,12 @@ class ServerArgs:
             )
             self.disable_cuda_graph = True
 
+        if self.attention_backend == "ascend":
+            logger.warning(
+                "At this moment Ascend attention backend only supports a page_size of 128, change page_size to 128."
+            )
+            self.page_size = 128
+
         # Choose grammar backend
         if self.grammar_backend is None:
             self.grammar_backend = "xgrammar"
@@ -399,10 +418,6 @@ class ServerArgs:
 
         # DeepEP MoE
         if self.enable_deepep_moe:
-            if self.deepep_mode == "auto":
-                assert (
-                    not self.enable_dp_attention
-                ), "DeepEP MoE `auto` mode is not supported with DP Attention."
             if self.deepep_mode == "normal":
                 logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
                 self.disable_cuda_graph = True
@@ -485,12 +500,6 @@ class ServerArgs:
             self.speculative_num_draft_tokens,
         ) = auto_choose_speculative_params(self)
 
-        if self.page_size > 1 and self.speculative_eagle_topk > 1:
-            self.speculative_eagle_topk = 1
-            logger.warning(
-                "speculative_eagle_topk is adjusted to 1 when page_size > 1"
-            )
-
         if (
             self.speculative_eagle_topk == 1
             and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
@@ -587,6 +596,12 @@ class ServerArgs:
             default=ServerArgs.port,
             help="The port of the HTTP server.",
         )
+        parser.add_argument(
+            "--nccl-port",
+            type=int,
+            default=ServerArgs.nccl_port,
+            help="The port for NCCL distributed environment setup. Defaults to a random port.",
+        )
         parser.add_argument(
             "--tokenizer-mode",
             type=str,
@@ -601,6 +616,11 @@ class ServerArgs:
             action="store_true",
             help="If set, skip init tokenizer and pass input_ids in generate request.",
         )
+        parser.add_argument(
+            "--skip-server-warmup",
+            action="store_true",
+            help="If set, skip warmup.",
+        )
         parser.add_argument(
             "--load-format",
             type=str,
@@ -817,6 +837,18 @@ class ServerArgs:
             default=ServerArgs.page_size,
             help="The number of tokens in a page.",
         )
+        parser.add_argument(
+            "--hybrid-kvcache-ratio",
+            nargs="?",
+            const=0.5,
+            type=float,
+            default=ServerArgs.hybrid_kvcache_ratio,
+            help=(
+                "Mix ratio in [0,1] between uniform and hybrid kv buffers "
+                "(0.0 = pure uniform: swa_size / full_size = 1)"
+                "(1.0 = pure hybrid: swa_size / full_size = local_attention_size / context_length)"
+            ),
+        )
 
         # Other runtime options
         parser.add_argument(
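Worth calling out on --hybrid-kvcache-ratio: the nargs="?" plus const=0.5 combination makes the flag tri-state. A standalone argparse sketch of exactly that behavior (plain stdlib, independent of sglang):

    import argparse

    # Same flag shape as above: optional value, const used when the flag
    # is given bare, default used when the flag is absent entirely.
    p = argparse.ArgumentParser()
    p.add_argument("--hybrid-kvcache-ratio", nargs="?", const=0.5, type=float, default=None)

    assert p.parse_args([]).hybrid_kvcache_ratio is None                                 # omitted -> default
    assert p.parse_args(["--hybrid-kvcache-ratio"]).hybrid_kvcache_ratio == 0.5          # bare flag -> const
    assert p.parse_args(["--hybrid-kvcache-ratio", "0.8"]).hybrid_kvcache_ratio == 0.8   # explicit value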
@@ -920,8 +952,14 @@ class ServerArgs:
             "--log-requests-level",
             type=int,
             default=0,
-            help="0: Log metadata. 1. Log metadata and partial input/output. 2. Log every input/output.",
-            choices=[0, 1, 2],
+            help="0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output.",
+            choices=[0, 1, 2, 3],
+        )
+        parser.add_argument(
+            "--crash-dump-folder",
+            type=str,
+            default=ServerArgs.crash_dump_folder,
+            help="Folder path to dump requests from the last 5 min before a crash (if any). If not specified, crash dumping is disabled.",
         )
         parser.add_argument(
             "--show-time-cost",
@@ -1092,6 +1130,7 @@ class ServerArgs:
                 "flashmla",
                 "intel_amx",
                 "torch_native",
+                "ascend",
                 "triton",
             ],
             default=ServerArgs.attention_backend,
@@ -1186,6 +1225,11 @@ class ServerArgs:
             action="store_true",
             help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
         )
+        parser.add_argument(
+            "--enable-flashinfer-allreduce-fusion",
+            action="store_true",
+            help="Enable FlashInfer allreduce fusion for Add_RMSNorm.",
+        )
         parser.add_argument(
             "--enable-deepep-moe",
             action="store_true",
@@ -1706,14 +1750,17 @@ class PortArgs:
 
     @staticmethod
    def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
-        port = server_args.port + random.randint(100, 1000)
-        while True:
-            if is_port_available(port):
-                break
-            if port < 60000:
-                port += 42
-            else:
-                port -= 43
+        if server_args.nccl_port is None:
+            port = server_args.port + random.randint(100, 1000)
+            while True:
+                if is_port_available(port):
+                    break
+                if port < 60000:
+                    port += 42
+                else:
+                    port -= 43
+        else:
+            port = server_args.nccl_port
 
         if not server_args.enable_dp_attention:
             # Normal case, use IPC within a single node
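The behavior change, in short: when --nccl-port is set, PortArgs uses it verbatim; otherwise it keeps the old strategy of probing upward from a random offset above the HTTP port. A self-contained sketch of that fallback search, where is_port_available is stubbed with a plain socket bind check and pick_nccl_port is a hypothetical name (sglang's actual helper may differ):

    import random
    import socket

    def is_port_available(port: int) -> bool:
        # Illustrative stand-in: try to bind the port on localhost.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(("127.0.0.1", port))
                return True
            except OSError:
                return False

    def pick_nccl_port(http_port: int, nccl_port=None) -> int:
        if nccl_port is not None:
            return nccl_port  # explicit --nccl-port wins, no probing
        port = http_port + random.randint(100, 1000)
        while not is_port_available(port):
            port = port + 42 if port < 60000 else port - 43  # bounded walk
        return port

    print(pick_nccl_port(30000))         # some free port near 30000
    print(pick_nccl_port(30000, 29500))  # 29500, taken as-is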
sglang/srt/speculative/build_eagle_tree.py CHANGED
@@ -1,10 +1,12 @@
 # NOTE: Please run this file to make sure the test cases are correct.
 
-from typing import List
+import math
+from enum import IntEnum
+from typing import List, Optional
 
 import torch
 
-from sglang.srt.utils import is_cuda, is_hip, rank0_print
+from sglang.srt.utils import is_cuda, is_hip
 
 if is_cuda() or is_hip():
     from sgl_kernel import (
@@ -40,6 +42,12 @@ def build_tree_kernel_efficient_preprocess(
     return parent_list, top_scores_index, draft_tokens
 
 
+class TreeMaskMode(IntEnum):
+    FULL_MASK = 0
+    QLEN_ONLY = 1
+    QLEN_ONLY_BITPACKING = 2
+
+
 def build_tree_kernel_efficient(
     verified_id: torch.Tensor,
     score_list: List[torch.Tensor],
@@ -50,6 +58,9 @@ def build_tree_kernel_efficient(
     topk: int,
     spec_steps: int,
     num_verify_tokens: int,
+    tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK,
+    tree_mask_buf: Optional[torch.Tensor] = None,
+    position_buf: Optional[torch.Tensor] = None,
 ):
     parent_list, top_scores_index, draft_tokens = (
         build_tree_kernel_efficient_preprocess(
@@ -66,15 +77,37 @@ def build_tree_kernel_efficient(
     device = seq_lens.device
     # e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened)
     # where each row indicates the attending pattern of each draft token
+    # if use_partial_packed_tree_mask is True, tree_mask: num_draft_token (flattened, packed)
+    if tree_mask_buf is not None:
+        tree_mask = tree_mask_buf
+    elif tree_mask_mode == TreeMaskMode.QLEN_ONLY:
+        tree_mask = torch.full(
+            (num_verify_tokens * bs * num_verify_tokens,),
+            True,
+            dtype=torch.bool,
+            device=device,
+        )
+    elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
+        packed_dtypes = [torch.uint8, torch.uint16, torch.uint32]
+        packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
+        tree_mask = torch.zeros(
+            (num_verify_tokens * bs,),
+            dtype=packed_dtypes[packed_dtype_idx],
+            device=device,
+        )
+    elif tree_mask_mode == TreeMaskMode.FULL_MASK:
+        tree_mask = torch.full(
+            (
+                seq_lens_sum * num_verify_tokens
+                + num_verify_tokens * num_verify_tokens * bs,
+            ),
+            True,
+            device=device,
+        )
+    else:
+        raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")
+
     # TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
-    tree_mask = torch.full(
-        (
-            seq_lens_sum * num_verify_tokens
-            + num_verify_tokens * num_verify_tokens * bs,
-        ),
-        True,
-        device=device,
-    )
     retrive_index = torch.full(
         (bs, num_verify_tokens), -1, device=device, dtype=torch.long
     )
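The QLEN_ONLY_BITPACKING branch picks the narrowest unsigned integer dtype that can hold one mask bit per draft token per row. A worked sketch of just that arithmetic (pure stdlib, dtype names as strings for illustration):

    import math

    # Mirror of the dtype selection above: (n + 7) // 8 is the byte count
    # per row, and ceil(log2(...)) indexes uint8/uint16/uint32.
    packed_dtypes = ["uint8", "uint16", "uint32"]
    for num_verify_tokens in (4, 8, 16, 32):
        idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
        print(num_verify_tokens, "->", packed_dtypes[idx])
    # 4 -> uint8, 8 -> uint8, 16 -> uint16, 32 -> uint32

Note that, as written, the three-entry table tops out at 32 verify tokens; 64 would index past the list.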
@@ -87,7 +120,12 @@ def build_tree_kernel_efficient(
     # position: where each token belongs to
     # e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7
     # then, positions = [7, 8, 8, 9]
-    positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
+    if position_buf is not None:
+        positions = position_buf
+    else:
+        positions = torch.empty(
+            (bs * num_verify_tokens,), device=device, dtype=torch.long
+        )
 
     sgl_build_tree_kernel_efficient(
         parent_list,
@@ -101,6 +139,7 @@ def build_tree_kernel_efficient(
         topk,
         spec_steps,
         num_verify_tokens,
+        tree_mask_mode,
     )
     return (
         tree_mask,
@@ -344,13 +383,13 @@ def test_build_tree_kernel_efficient():
         num_verify_tokens=num_draft_token,
     )
 
-    rank0_print("=========== build tree kernel efficient ==========")
-    # rank0_print(f"{tree_mask=}", flush=True)
-    rank0_print(f"{position=}", flush=True)
-    rank0_print(f"{retrive_index=}", flush=True)
-    rank0_print(f"{retrive_next_token=}", flush=True)
-    rank0_print(f"{retrive_next_sibling=}", flush=True)
-    rank0_print(f"{draft_tokens=}", flush=True)
+    print("=========== build tree kernel efficient ==========")
+    print(f"{tree_mask=}")
+    print(f"{position=}")
+    print(f"{retrive_index=}")
+    print(f"{retrive_next_token=}")
+    print(f"{retrive_next_sibling=}")
+    print(f"{draft_tokens=}")
     assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
     assert retrive_index.tolist() == [
         [0, 1, 2, 3, 4, 5, 6, 7],
sglang/srt/speculative/eagle_worker.py CHANGED
@@ -140,9 +140,11 @@ class EAGLEWorker(TpModelWorker):
             self.draft_model_runner.model.set_embed(embed)
 
             # grab hot token ids
-            self.hot_token_id = self.draft_model_runner.model.get_hot_token_id().to(
-                embed.device
-            )
+            if self.draft_model_runner.model.hot_token_id is not None:
+                self.hot_token_id = self.draft_model_runner.model.hot_token_id.to(
+                    embed.device
+                )
+
         else:
             if self.hot_token_id is not None:
                 head = head.clone()
@@ -842,7 +844,7 @@ class EAGLEWorker(TpModelWorker):
         )
         batch.return_hidden_states = False
         model_worker_batch = batch.get_model_worker_batch()
-        model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
+        model_worker_batch.spec_num_draft_tokens = self.speculative_num_steps + 1
         assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner