sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only.
Files changed (150)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +49 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +35 -0
  8. sglang/srt/custom_op.py +7 -1
  9. sglang/srt/disaggregation/base/conn.py +2 -0
  10. sglang/srt/disaggregation/decode.py +22 -6
  11. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  12. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  13. sglang/srt/disaggregation/nixl/conn.py +100 -52
  14. sglang/srt/disaggregation/prefill.py +5 -4
  15. sglang/srt/disaggregation/utils.py +13 -12
  16. sglang/srt/distributed/parallel_state.py +44 -17
  17. sglang/srt/entrypoints/EngineBase.py +8 -0
  18. sglang/srt/entrypoints/engine.py +45 -9
  19. sglang/srt/entrypoints/http_server.py +111 -24
  20. sglang/srt/entrypoints/openai/protocol.py +51 -6
  21. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  22. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  23. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  24. sglang/srt/eplb/__init__.py +0 -0
  25. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  26. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  27. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  28. sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
  29. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  30. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  31. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  32. sglang/srt/hf_transformers_utils.py +2 -1
  33. sglang/srt/layers/activation.py +7 -0
  34. sglang/srt/layers/amx_utils.py +86 -0
  35. sglang/srt/layers/attention/ascend_backend.py +219 -0
  36. sglang/srt/layers/attention/flashattention_backend.py +56 -23
  37. sglang/srt/layers/attention/tbo_backend.py +37 -9
  38. sglang/srt/layers/communicator.py +18 -2
  39. sglang/srt/layers/dp_attention.py +9 -3
  40. sglang/srt/layers/elementwise.py +76 -12
  41. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  42. sglang/srt/layers/layernorm.py +41 -0
  43. sglang/srt/layers/linear.py +99 -12
  44. sglang/srt/layers/logits_processor.py +15 -6
  45. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  46. sglang/srt/layers/moe/ep_moe/layer.py +115 -25
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  49. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
  50. sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
  51. sglang/srt/layers/moe/router.py +60 -22
  52. sglang/srt/layers/moe/topk.py +36 -28
  53. sglang/srt/layers/parameter.py +67 -7
  54. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  55. sglang/srt/layers/quantization/fp8.py +44 -0
  56. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  57. sglang/srt/layers/quantization/fp8_utils.py +6 -6
  58. sglang/srt/layers/quantization/gptq.py +5 -1
  59. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  60. sglang/srt/layers/quantization/quant_utils.py +166 -0
  61. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  62. sglang/srt/layers/rotary_embedding.py +105 -13
  63. sglang/srt/layers/vocab_parallel_embedding.py +19 -2
  64. sglang/srt/lora/lora.py +4 -5
  65. sglang/srt/lora/lora_manager.py +73 -20
  66. sglang/srt/managers/configure_logging.py +1 -1
  67. sglang/srt/managers/io_struct.py +60 -15
  68. sglang/srt/managers/mm_utils.py +73 -59
  69. sglang/srt/managers/multimodal_processor.py +2 -6
  70. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  71. sglang/srt/managers/schedule_batch.py +80 -79
  72. sglang/srt/managers/scheduler.py +153 -63
  73. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  74. sglang/srt/managers/session_controller.py +12 -3
  75. sglang/srt/managers/tokenizer_manager.py +314 -103
  76. sglang/srt/managers/tp_worker.py +13 -1
  77. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  78. sglang/srt/mem_cache/allocator.py +290 -0
  79. sglang/srt/mem_cache/chunk_cache.py +34 -2
  80. sglang/srt/mem_cache/memory_pool.py +289 -3
  81. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  82. sglang/srt/model_executor/cuda_graph_runner.py +3 -2
  83. sglang/srt/model_executor/forward_batch_info.py +17 -4
  84. sglang/srt/model_executor/model_runner.py +302 -58
  85. sglang/srt/model_loader/loader.py +86 -10
  86. sglang/srt/model_loader/weight_utils.py +160 -3
  87. sglang/srt/models/deepseek_nextn.py +5 -4
  88. sglang/srt/models/deepseek_v2.py +305 -26
  89. sglang/srt/models/deepseek_vl2.py +3 -5
  90. sglang/srt/models/gemma3_causal.py +1 -2
  91. sglang/srt/models/gemma3n_audio.py +949 -0
  92. sglang/srt/models/gemma3n_causal.py +1010 -0
  93. sglang/srt/models/gemma3n_mm.py +495 -0
  94. sglang/srt/models/hunyuan.py +771 -0
  95. sglang/srt/models/kimi_vl.py +1 -2
  96. sglang/srt/models/llama.py +10 -4
  97. sglang/srt/models/llama4.py +32 -45
  98. sglang/srt/models/llama_eagle3.py +61 -11
  99. sglang/srt/models/llava.py +5 -5
  100. sglang/srt/models/minicpmo.py +2 -2
  101. sglang/srt/models/mistral.py +1 -1
  102. sglang/srt/models/mllama4.py +43 -11
  103. sglang/srt/models/phi4mm.py +1 -3
  104. sglang/srt/models/pixtral.py +3 -7
  105. sglang/srt/models/qwen2.py +31 -3
  106. sglang/srt/models/qwen2_5_vl.py +1 -3
  107. sglang/srt/models/qwen2_audio.py +200 -0
  108. sglang/srt/models/qwen2_moe.py +32 -6
  109. sglang/srt/models/qwen2_vl.py +1 -4
  110. sglang/srt/models/qwen3.py +94 -25
  111. sglang/srt/models/qwen3_moe.py +68 -21
  112. sglang/srt/models/vila.py +3 -8
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  117. sglang/srt/multimodal/processors/gemma3n.py +82 -0
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  129. sglang/srt/operations_strategy.py +6 -2
  130. sglang/srt/reasoning_parser.py +26 -0
  131. sglang/srt/sampling/sampling_batch_info.py +39 -1
  132. sglang/srt/server_args.py +85 -24
  133. sglang/srt/speculative/build_eagle_tree.py +57 -18
  134. sglang/srt/speculative/eagle_worker.py +6 -4
  135. sglang/srt/two_batch_overlap.py +204 -28
  136. sglang/srt/utils.py +369 -138
  137. sglang/srt/warmup.py +12 -3
  138. sglang/test/runners.py +10 -1
  139. sglang/test/test_utils.py +15 -3
  140. sglang/version.py +1 -1
  141. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  142. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
  143. sglang/math_utils.py +0 -8
  144. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  145. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  146. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  147. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  148. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  149. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -10,7 +10,6 @@ import torch
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_params import TOP_K_ALL
-from sglang.srt.utils import merge_bias_tensor
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -345,3 +344,42 @@ class SamplingBatchInfo:
         self.logit_bias = merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0
         )
+
+
+def merge_bias_tensor(
+    lhs: Optional[torch.Tensor],
+    rhs: Optional[torch.Tensor],
+    bs1: int,
+    bs2: int,
+    device: str,
+    default: float,
+):
+    """Merge two bias tensors for batch merging.
+
+    Args:
+        lhs: Left-hand side tensor
+        rhs: Right-hand side tensor
+        bs1: Batch size of left-hand side tensor
+        bs2: Batch size of right-hand side tensor
+        device: Device to place the merged tensor on
+        default: Default value for missing tensor elements
+
+    Returns:
+        Merged tensor or None if both inputs are None
+    """
+    if lhs is None and rhs is None:
+        return None
+
+    if lhs is not None and rhs is not None:
+        return torch.cat([lhs, rhs])
+    else:
+        if lhs is not None:
+            shape, dtype = lhs.shape[1:], lhs.dtype
+        else:
+            shape, dtype = rhs.shape[1:], rhs.dtype
+
+        if lhs is None:
+            lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
+        if rhs is None:
+            rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
+        return torch.cat([lhs, rhs])
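For readers tracking the relocation: merge_bias_tensor moved out of sglang.srt.utils and into this module. The snippet below is a self-contained sketch of its merge semantics, with the function body copied from the hunk above so it runs standalone; the example tensors are illustrative only.

```python
from typing import Optional

import torch


def merge_bias_tensor(
    lhs: Optional[torch.Tensor],
    rhs: Optional[torch.Tensor],
    bs1: int,
    bs2: int,
    device: str,
    default: float,
):
    # Body copied from the hunk above.
    if lhs is None and rhs is None:
        return None
    if lhs is not None and rhs is not None:
        return torch.cat([lhs, rhs])
    if lhs is not None:
        shape, dtype = lhs.shape[1:], lhs.dtype
    else:
        shape, dtype = rhs.shape[1:], rhs.dtype
    if lhs is None:
        lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
    if rhs is None:
        rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
    return torch.cat([lhs, rhs])


# One batch of 1 request carries a logit bias; the other batch of 2 requests
# has none, so its rows are materialized with the 0.0 default before concat.
left = torch.tensor([[1.5, -1.5]])
merged = merge_bias_tensor(left, None, bs1=1, bs2=2, device="cpu", default=0.0)
print(merged.shape)  # torch.Size([3, 2])
print(merged[1:])    # two rows of zeros
```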
sglang/srt/server_args.py CHANGED
@@ -20,7 +20,7 @@ import logging
 import os
 import random
 import tempfile
-from typing import List, Literal, Optional
+from typing import List, Literal, Optional, Union
 
 from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
 from sglang.srt.reasoning_parser import ReasoningParser
@@ -46,7 +46,9 @@ class ServerArgs:
     tokenizer_path: Optional[str] = None
     tokenizer_mode: str = "auto"
     skip_tokenizer_init: bool = False
+    skip_server_warmup: bool = False
     load_format: str = "auto"
+    model_loader_extra_config: str = "{}"
     trust_remote_code: bool = False
     dtype: str = "auto"
     kv_cache_dtype: str = "auto"
@@ -60,11 +62,13 @@ class ServerArgs:
     is_embedding: bool = False
     enable_multimodal: Optional[bool] = None
     revision: Optional[str] = None
+    hybrid_kvcache_ratio: Optional[float] = None
     impl: str = "auto"
 
     # Port for the HTTP server
     host: str = "127.0.0.1"
     port: int = 30000
+    nccl_port: Optional[int] = None
 
     # Memory and scheduling
     mem_fraction_static: Optional[float] = None
@@ -97,6 +101,7 @@ class ServerArgs:
     log_level_http: Optional[str] = None
     log_requests: bool = False
     log_requests_level: int = 0
+    crash_dump_folder: Optional[str] = None
     show_time_cost: bool = False
     enable_metrics: bool = False
     bucket_time_to_first_token: Optional[List[float]] = None
@@ -128,7 +133,7 @@ class ServerArgs:
     preferred_sampling_params: Optional[str] = None
 
     # LoRA
-    lora_paths: Optional[List[str]] = None
+    lora_paths: Optional[Union[dict[str, str], List[str]]] = None
     max_loras_per_batch: int = 8
     lora_backend: str = "triton"
 
@@ -153,6 +158,7 @@ class ServerArgs:
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
     enable_flashinfer_moe: bool = False
+    enable_flashinfer_allreduce_fusion: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
     ep_num_redundant_experts: int = 0
     ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
@@ -314,6 +320,14 @@ class ServerArgs:
         else:
             self.mem_fraction_static = 0.88
 
+        # Lazy init to avoid circular import
+        from sglang.srt.configs.model_config import ModelConfig
+
+        # Multimodal models need more memory for the image processor
+        model_config = ModelConfig.from_server_args(self)
+        if model_config.is_multimodal:
+            self.mem_fraction_static *= 0.90
+
         # Set chunked prefill size, which depends on the gpu memory capacity
         if self.chunked_prefill_size is None:
             if gpu_mem is not None:
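The numbers in the new multimodal branch compose with the default chosen just above it; a quick illustration, using only the 0.88 default and the 0.90 factor visible in this hunk:

```python
# Illustration only: values taken from the hunk above.
mem_fraction_static = 0.88
is_multimodal = True  # e.g. a VLM checkpoint detected by ModelConfig
if is_multimodal:
    mem_fraction_static *= 0.90  # reserve headroom for the image processor
print(mem_fraction_static)  # ~0.792
```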
@@ -375,6 +389,12 @@
             )
             self.disable_cuda_graph = True
 
+        if self.attention_backend == "ascend":
+            logger.warning(
+                "At this moment Ascend attention backend only supports a page_size of 128, change page_size to 128."
+            )
+            self.page_size = 128
+
         # Choose grammar backend
         if self.grammar_backend is None:
             self.grammar_backend = "xgrammar"
@@ -398,10 +418,6 @@
 
         # DeepEP MoE
         if self.enable_deepep_moe:
-            if self.deepep_mode == "auto":
-                assert (
-                    not self.enable_dp_attention
-                ), "DeepEP MoE `auto` mode is not supported with DP Attention."
             if self.deepep_mode == "normal":
                 logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
                 self.disable_cuda_graph = True
@@ -484,12 +500,6 @@
                 self.speculative_num_draft_tokens,
             ) = auto_choose_speculative_params(self)
 
-            if self.page_size > 1 and self.speculative_eagle_topk > 1:
-                self.speculative_eagle_topk = 1
-                logger.warning(
-                    "speculative_eagle_topk is adjusted to 1 when page_size > 1"
-                )
-
             if (
                 self.speculative_eagle_topk == 1
                 and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
@@ -563,6 +573,7 @@ class ServerArgs:
         # Model and port args
         parser.add_argument(
             "--model-path",
+            "--model",
             type=str,
             help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
             required=True,
@@ -585,6 +596,12 @@ class ServerArgs:
             default=ServerArgs.port,
             help="The port of the HTTP server.",
         )
+        parser.add_argument(
+            "--nccl-port",
+            type=int,
+            default=ServerArgs.nccl_port,
+            help="The port for NCCL distributed environment setup. Defaults to a random port.",
+        )
         parser.add_argument(
             "--tokenizer-mode",
             type=str,
@@ -599,6 +616,11 @@ class ServerArgs:
             action="store_true",
             help="If set, skip init tokenizer and pass input_ids in generate request.",
         )
+        parser.add_argument(
+            "--skip-server-warmup",
+            action="store_true",
+            help="If set, skip warmup.",
+        )
         parser.add_argument(
             "--load-format",
             type=str,
@@ -632,6 +654,13 @@ class ServerArgs:
             "layer before loading another to make the peak memory envelope "
             "smaller.",
         )
+        parser.add_argument(
+            "--model-loader-extra-config",
+            type=str,
+            help="Extra config for model loader. "
+            "This will be passed to the model loader corresponding to the chosen load_format.",
+            default=ServerArgs.model_loader_extra_config,
+        )
         parser.add_argument(
             "--trust-remote-code",
             action="store_true",
@@ -808,6 +837,18 @@ class ServerArgs:
             default=ServerArgs.page_size,
             help="The number of tokens in a page.",
         )
+        parser.add_argument(
+            "--hybrid-kvcache-ratio",
+            nargs="?",
+            const=0.5,
+            type=float,
+            default=ServerArgs.hybrid_kvcache_ratio,
+            help=(
+                "Mix ratio in [0,1] between uniform and hybrid kv buffers "
+                "(0.0 = pure uniform: swa_size / full_size = 1)"
+                "(1.0 = pure hybrid: swa_size / full_size = local_attention_size / context_length)"
+            ),
+        )
 
         # Other runtime options
         parser.add_argument(
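Because --hybrid-kvcache-ratio is declared with nargs="?" and const=0.5, the flag has three behaviors: omitted, it keeps the default; passed bare, it selects the 0.5 midpoint; passed with a value, it takes that float. A standalone argparse sketch of exactly that behavior (a throwaway parser, not sglang's):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--hybrid-kvcache-ratio",
    nargs="?",
    const=0.5,     # used when the flag is given with no value
    type=float,
    default=None,  # used when the flag is absent
)

print(parser.parse_args([]).hybrid_kvcache_ratio)                          # None
print(parser.parse_args(["--hybrid-kvcache-ratio"]).hybrid_kvcache_ratio)  # 0.5
print(parser.parse_args(["--hybrid-kvcache-ratio", "0.8"]).hybrid_kvcache_ratio)  # 0.8
```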
@@ -911,8 +952,14 @@ class ServerArgs:
             "--log-requests-level",
             type=int,
             default=0,
-            help="0: Log metadata. 1. Log metadata and partial input/output. 2. Log every input/output.",
-            choices=[0, 1, 2],
+            help="0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output.",
+            choices=[0, 1, 2, 3],
+        )
+        parser.add_argument(
+            "--crash-dump-folder",
+            type=str,
+            default=ServerArgs.crash_dump_folder,
+            help="Folder path to dump requests from the last 5 min before a crash (if any). If not specified, crash dumping is disabled.",
         )
         parser.add_argument(
             "--show-time-cost",
@@ -1083,6 +1130,7 @@ class ServerArgs:
                 "flashmla",
                 "intel_amx",
                 "torch_native",
+                "ascend",
                 "triton",
             ],
             default=ServerArgs.attention_backend,
@@ -1177,6 +1225,11 @@ class ServerArgs:
             action="store_true",
             help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
         )
+        parser.add_argument(
+            "--enable-flashinfer-allreduce-fusion",
+            action="store_true",
+            help="Enable FlashInfer allreduce fusion for Add_RMSNorm.",
+        )
         parser.add_argument(
             "--enable-deepep-moe",
             action="store_true",
@@ -1692,16 +1745,22 @@ class PortArgs:
     # The ipc filename for rpc call between Engine and Scheduler
     rpc_ipc_name: str
 
+    # The ipc filename for Scheduler to send metrics
+    metrics_ipc_name: str
+
     @staticmethod
     def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
-        port = server_args.port + random.randint(100, 1000)
-        while True:
-            if is_port_available(port):
-                break
-            if port < 60000:
-                port += 42
-            else:
-                port -= 43
+        if server_args.nccl_port is None:
+            port = server_args.port + random.randint(100, 1000)
+            while True:
+                if is_port_available(port):
+                    break
+                if port < 60000:
+                    port += 42
+                else:
+                    port -= 43
+        else:
+            port = server_args.nccl_port
 
         if not server_args.enable_dp_attention:
             # Normal case, use IPC within a single node
@@ -1711,6 +1770,7 @@ class PortArgs:
                 detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                 nccl_port=port,
                 rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
+                metrics_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
             )
         else:
             # DP attention. Use TCP + port to handle both single-node and multi-node.
@@ -1730,9 +1790,9 @@ class PortArgs:
             port_base = int(dist_init_port) + 1
             if dp_rank is None:
                 # TokenizerManager to DataParallelController
-                scheduler_input_port = port_base + 3
+                scheduler_input_port = port_base + 4
             else:
-                scheduler_input_port = port_base + 3 + 1 + dp_rank
+                scheduler_input_port = port_base + 4 + 1 + dp_rank
 
         return PortArgs(
             tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
@@ -1740,6 +1800,7 @@
             detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
             nccl_port=port,
             rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}",
+            metrics_ipc_name=f"tcp://{dist_init_host}:{port_base + 3}",
         )
 
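PortArgs.init_new now honors a user-pinned --nccl-port and otherwise keeps the old probing walk. Below is a standalone sketch of that selection logic; is_port_available is sglang's own helper, approximated here with a plain socket bind (an assumption) so the snippet runs on its own:

```python
import random
import socket


def is_port_available(port: int) -> bool:
    # Approximation of sglang's helper: if we can bind, the port is free.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            s.bind(("127.0.0.1", port))
            return True
        except OSError:
            return False


def pick_nccl_port(http_port: int, nccl_port=None) -> int:
    if nccl_port is not None:
        return nccl_port  # pinned via --nccl-port
    # Otherwise: random offset from the HTTP port, then walk until free,
    # stepping +42 below 60000 and -43 above it (the loop from the hunk).
    port = http_port + random.randint(100, 1000)
    while not is_port_available(port):
        port = port + 42 if port < 60000 else port - 43
    return port


print(pick_nccl_port(30000))         # some free port near 30000
print(pick_nccl_port(30000, 29500))  # 29500, exactly as pinned
```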
sglang/srt/speculative/build_eagle_tree.py CHANGED
@@ -1,10 +1,12 @@
 # NOTE: Please run this file to make sure the test cases are correct.
 
-from typing import List
+import math
+from enum import IntEnum
+from typing import List, Optional
 
 import torch
 
-from sglang.srt.utils import is_cuda, is_hip, rank0_print
+from sglang.srt.utils import is_cuda, is_hip
 
 if is_cuda() or is_hip():
     from sgl_kernel import (
@@ -40,6 +42,12 @@ def build_tree_kernel_efficient_preprocess(
     return parent_list, top_scores_index, draft_tokens
 
 
+class TreeMaskMode(IntEnum):
+    FULL_MASK = 0
+    QLEN_ONLY = 1
+    QLEN_ONLY_BITPACKING = 2
+
+
 def build_tree_kernel_efficient(
     verified_id: torch.Tensor,
     score_list: List[torch.Tensor],
@@ -50,6 +58,9 @@
     topk: int,
     spec_steps: int,
     num_verify_tokens: int,
+    tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK,
+    tree_mask_buf: Optional[torch.Tensor] = None,
+    position_buf: Optional[torch.Tensor] = None,
 ):
     parent_list, top_scores_index, draft_tokens = (
         build_tree_kernel_efficient_preprocess(
@@ -66,15 +77,37 @@ def build_tree_kernel_efficient(
     device = seq_lens.device
     # e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened)
     # where each row indicates the attending pattern of each draft token
+    # if use_partial_packed_tree_mask is True, tree_mask: num_draft_token (flattened, packed)
+    if tree_mask_buf is not None:
+        tree_mask = tree_mask_buf
+    elif tree_mask_mode == TreeMaskMode.QLEN_ONLY:
+        tree_mask = torch.full(
+            (num_verify_tokens * bs * num_verify_tokens,),
+            True,
+            dtype=torch.bool,
+            device=device,
+        )
+    elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
+        packed_dtypes = [torch.uint8, torch.uint16, torch.uint32]
+        packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
+        tree_mask = torch.zeros(
+            (num_verify_tokens * bs,),
+            dtype=packed_dtypes[packed_dtype_idx],
+            device=device,
+        )
+    elif tree_mask_mode == TreeMaskMode.FULL_MASK:
+        tree_mask = torch.full(
+            (
+                seq_lens_sum * num_verify_tokens
+                + num_verify_tokens * num_verify_tokens * bs,
+            ),
+            True,
+            device=device,
+        )
+    else:
+        raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")
+
     # TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
-    tree_mask = torch.full(
-        (
-            seq_lens_sum * num_verify_tokens
-            + num_verify_tokens * num_verify_tokens * bs,
-        ),
-        True,
-        device=device,
-    )
     retrive_index = torch.full(
         (bs, num_verify_tokens), -1, device=device, dtype=torch.long
     )
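The QLEN_ONLY_BITPACKING branch sizes its dtype from num_verify_tokens: each draft token's mask row packs num_verify_tokens bits, i.e. (num_verify_tokens + 7) // 8 bytes, and the ceil-log2 index picks the smallest unsigned type that holds them. Worked through for a few values:

```python
import math

import torch

# Same selection arithmetic as the hunk above.
packed_dtypes = [torch.uint8, torch.uint16, torch.uint32]
for num_verify_tokens in (4, 8, 16, 32):
    n_bytes = (num_verify_tokens + 7) // 8
    idx = int(math.ceil(math.log2(n_bytes)))
    print(num_verify_tokens, n_bytes, packed_dtypes[idx])
# 4  1 torch.uint8
# 8  1 torch.uint8
# 16 2 torch.uint16
# 32 4 torch.uint32
```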
@@ -87,7 +120,12 @@
     # position: where each token belongs to
     # e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7
     # then, positions = [7, 8, 8, 9]
-    positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
+    if position_buf is not None:
+        positions = position_buf
+    else:
+        positions = torch.empty(
+            (bs * num_verify_tokens,), device=device, dtype=torch.long
+        )
 
     sgl_build_tree_kernel_efficient(
         parent_list,
@@ -101,6 +139,7 @@
         topk,
         spec_steps,
         num_verify_tokens,
+        tree_mask_mode,
     )
     return (
         tree_mask,
@@ -344,13 +383,13 @@ def test_build_tree_kernel_efficient():
         num_verify_tokens=num_draft_token,
     )
 
-    rank0_print("=========== build tree kernel efficient ==========")
-    # rank0_print(f"{tree_mask=}", flush=True)
-    rank0_print(f"{position=}", flush=True)
-    rank0_print(f"{retrive_index=}", flush=True)
-    rank0_print(f"{retrive_next_token=}", flush=True)
-    rank0_print(f"{retrive_next_sibling=}", flush=True)
-    rank0_print(f"{draft_tokens=}", flush=True)
+    print("=========== build tree kernel efficient ==========")
+    print(f"{tree_mask=}")
+    print(f"{position=}")
+    print(f"{retrive_index=}")
+    print(f"{retrive_next_token=}")
+    print(f"{retrive_next_sibling=}")
+    print(f"{draft_tokens=}")
     assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
     assert retrive_index.tolist() == [
         [0, 1, 2, 3, 4, 5, 6, 7],
sglang/srt/speculative/eagle_worker.py CHANGED
@@ -140,9 +140,11 @@ class EAGLEWorker(TpModelWorker):
             self.draft_model_runner.model.set_embed(embed)
 
             # grab hot token ids
-            self.hot_token_id = self.draft_model_runner.model.get_hot_token_id().to(
-                embed.device
-            )
+            if self.draft_model_runner.model.hot_token_id is not None:
+                self.hot_token_id = self.draft_model_runner.model.hot_token_id.to(
+                    embed.device
+                )
+
         else:
             if self.hot_token_id is not None:
                 head = head.clone()
@@ -842,7 +844,7 @@ class EAGLEWorker(TpModelWorker):
         )
         batch.return_hidden_states = False
         model_worker_batch = batch.get_model_worker_batch()
-        model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
+        model_worker_batch.spec_num_draft_tokens = self.speculative_num_steps + 1
         assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
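This ties the draft-extend batch size to the step count and matches the invariant checked in server_args.py above: with speculative_eagle_topk == 1, speculative_num_draft_tokens must equal speculative_num_steps + 1. A trivial sanity sketch of that accounting (the reading "one token per draft step plus the verified root" is an interpretation, not stated in the diff):

```python
speculative_num_steps = 3   # autoregressive draft steps
speculative_eagle_topk = 1  # chain drafting, one candidate per step

# The invariant from server_args.py above: draft tokens = steps + 1.
spec_num_draft_tokens = speculative_num_steps + 1
assert spec_num_draft_tokens == 4
```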