sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +23 -2
  3. sglang/bench_serving.py +6 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/configs/model_config.py +37 -5
  12. sglang/srt/constrained/base_grammar_backend.py +26 -5
  13. sglang/srt/constrained/llguidance_backend.py +1 -0
  14. sglang/srt/constrained/outlines_backend.py +1 -0
  15. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  16. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  17. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  18. sglang/srt/constrained/xgrammar_backend.py +27 -4
  19. sglang/srt/custom_op.py +0 -62
  20. sglang/srt/disaggregation/base/__init__.py +8 -0
  21. sglang/srt/disaggregation/base/conn.py +113 -0
  22. sglang/srt/disaggregation/decode.py +80 -11
  23. sglang/srt/disaggregation/mini_lb.py +58 -123
  24. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  25. sglang/srt/disaggregation/mooncake/conn.py +585 -0
  26. sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
  27. sglang/srt/disaggregation/prefill.py +82 -22
  28. sglang/srt/disaggregation/utils.py +46 -0
  29. sglang/srt/entrypoints/EngineBase.py +53 -0
  30. sglang/srt/entrypoints/engine.py +36 -8
  31. sglang/srt/entrypoints/http_server.py +37 -8
  32. sglang/srt/entrypoints/http_server_engine.py +142 -0
  33. sglang/srt/entrypoints/verl_engine.py +42 -13
  34. sglang/srt/hf_transformers_utils.py +4 -0
  35. sglang/srt/layers/activation.py +6 -8
  36. sglang/srt/layers/attention/flashattention_backend.py +430 -257
  37. sglang/srt/layers/attention/flashinfer_backend.py +18 -9
  38. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  39. sglang/srt/layers/attention/triton_backend.py +6 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  41. sglang/srt/layers/attention/vision.py +1 -1
  42. sglang/srt/layers/dp_attention.py +2 -4
  43. sglang/srt/layers/elementwise.py +15 -2
  44. sglang/srt/layers/layernorm.py +1 -1
  45. sglang/srt/layers/linear.py +18 -3
  46. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  48. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +34 -34
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
  56. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  57. sglang/srt/layers/moe/router.py +7 -1
  58. sglang/srt/layers/moe/topk.py +63 -45
  59. sglang/srt/layers/parameter.py +0 -2
  60. sglang/srt/layers/quantization/__init__.py +13 -5
  61. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  62. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
  64. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  65. sglang/srt/layers/quantization/fp8.py +131 -136
  66. sglang/srt/layers/quantization/fp8_kernel.py +328 -46
  67. sglang/srt/layers/quantization/fp8_utils.py +206 -253
  68. sglang/srt/layers/quantization/kv_cache.py +43 -52
  69. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  70. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  71. sglang/srt/layers/quantization/utils.py +5 -11
  72. sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
  73. sglang/srt/layers/quantization/w8a8_int8.py +8 -7
  74. sglang/srt/layers/radix_attention.py +28 -1
  75. sglang/srt/layers/rotary_embedding.py +15 -3
  76. sglang/srt/layers/sampler.py +5 -10
  77. sglang/srt/lora/backend/base_backend.py +18 -2
  78. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  79. sglang/srt/lora/backend/triton_backend.py +1 -1
  80. sglang/srt/lora/layers.py +1 -1
  81. sglang/srt/lora/lora.py +1 -1
  82. sglang/srt/lora/lora_manager.py +1 -1
  83. sglang/srt/managers/detokenizer_manager.py +0 -1
  84. sglang/srt/managers/io_struct.py +255 -97
  85. sglang/srt/managers/mm_utils.py +7 -5
  86. sglang/srt/managers/multimodal_processor.py +0 -2
  87. sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
  88. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  89. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  90. sglang/srt/managers/schedule_batch.py +64 -25
  91. sglang/srt/managers/scheduler.py +80 -82
  92. sglang/srt/managers/tokenizer_manager.py +18 -3
  93. sglang/srt/managers/tp_worker.py +1 -0
  94. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  95. sglang/srt/mem_cache/memory_pool.py +21 -3
  96. sglang/srt/metrics/collector.py +9 -0
  97. sglang/srt/model_executor/cuda_graph_runner.py +9 -6
  98. sglang/srt/model_executor/forward_batch_info.py +234 -15
  99. sglang/srt/model_executor/model_runner.py +67 -35
  100. sglang/srt/model_loader/loader.py +31 -4
  101. sglang/srt/model_loader/weight_utils.py +4 -2
  102. sglang/srt/models/baichuan.py +2 -0
  103. sglang/srt/models/bert.py +398 -0
  104. sglang/srt/models/chatglm.py +1 -0
  105. sglang/srt/models/commandr.py +1 -0
  106. sglang/srt/models/dbrx.py +1 -0
  107. sglang/srt/models/deepseek.py +2 -1
  108. sglang/srt/models/deepseek_nextn.py +74 -70
  109. sglang/srt/models/deepseek_v2.py +494 -366
  110. sglang/srt/models/exaone.py +1 -0
  111. sglang/srt/models/gemma.py +1 -0
  112. sglang/srt/models/gemma2.py +1 -0
  113. sglang/srt/models/gemma3_causal.py +1 -0
  114. sglang/srt/models/gpt2.py +1 -0
  115. sglang/srt/models/gpt_bigcode.py +1 -0
  116. sglang/srt/models/granite.py +1 -0
  117. sglang/srt/models/grok.py +1 -0
  118. sglang/srt/models/internlm2.py +1 -0
  119. sglang/srt/models/llama.py +6 -5
  120. sglang/srt/models/llama4.py +101 -34
  121. sglang/srt/models/minicpm.py +1 -0
  122. sglang/srt/models/minicpm3.py +30 -200
  123. sglang/srt/models/mixtral.py +1 -0
  124. sglang/srt/models/mixtral_quant.py +1 -0
  125. sglang/srt/models/mllama.py +51 -8
  126. sglang/srt/models/mllama4.py +102 -29
  127. sglang/srt/models/olmo.py +1 -0
  128. sglang/srt/models/olmo2.py +1 -0
  129. sglang/srt/models/olmoe.py +1 -0
  130. sglang/srt/models/phi3_small.py +1 -0
  131. sglang/srt/models/qwen.py +1 -0
  132. sglang/srt/models/qwen2.py +5 -1
  133. sglang/srt/models/qwen2_5_vl.py +35 -70
  134. sglang/srt/models/qwen2_moe.py +15 -13
  135. sglang/srt/models/qwen2_vl.py +27 -25
  136. sglang/srt/models/qwen3.py +335 -0
  137. sglang/srt/models/qwen3_moe.py +423 -0
  138. sglang/srt/models/stablelm.py +1 -0
  139. sglang/srt/models/xverse.py +1 -0
  140. sglang/srt/models/xverse_moe.py +1 -0
  141. sglang/srt/openai_api/adapter.py +4 -1
  142. sglang/srt/patch_torch.py +11 -0
  143. sglang/srt/reasoning_parser.py +0 -1
  144. sglang/srt/sampling/sampling_batch_info.py +2 -3
  145. sglang/srt/server_args.py +55 -19
  146. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  147. sglang/srt/speculative/eagle_utils.py +1 -11
  148. sglang/srt/speculative/eagle_worker.py +10 -9
  149. sglang/srt/utils.py +136 -10
  150. sglang/test/attention/test_flashattn_backend.py +259 -221
  151. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  152. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  153. sglang/test/runners.py +5 -1
  154. sglang/test/test_block_fp8.py +224 -0
  155. sglang/test/test_custom_ops.py +1 -1
  156. sglang/test/test_utils.py +19 -8
  157. sglang/version.py +1 -1
  158. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
  159. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
  160. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  161. sglang/lang/__init__.py +0 -0
  162. sglang/srt/disaggregation/conn.py +0 -81
  163. sglang/srt/lora/backend/__init__.py +0 -25
  164. sglang/srt/server.py +0 -18
  165. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  166. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py CHANGED
@@ -155,7 +155,7 @@ class ServerArgs:
  enable_nccl_nvls: bool = False
  disable_outlines_disk_cache: bool = False
  disable_custom_all_reduce: bool = False
- disable_mla: bool = False
+ enable_llama4_multimodal: Optional[bool] = None
  disable_overlap_schedule: bool = False
  enable_mixed_chunk: bool = False
  enable_dp_attention: bool = False
@@ -179,12 +179,12 @@ class ServerArgs:
  tool_call_parser: Optional[str] = None
  enable_hierarchical_cache: bool = False
  hicache_ratio: float = 2.0
- enable_flashinfer_mla: bool = False # TODO: remove this argument
- enable_flashmla: bool = False
  flashinfer_mla_disable_ragged: bool = False
  warmups: Optional[str] = None
+ moe_dense_tp_size: Optional[int] = None
  n_share_experts_fusion: int = 0
- disable_shared_experts_fusion: bool = False
+ disable_chunked_prefix_cache: bool = False
+ disable_fast_image_processor: bool = False

  # Debug tensor dumps
  debug_tensor_dump_output_folder: Optional[str] = None
@@ -194,6 +194,8 @@ class ServerArgs:
  # For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
  disaggregation_mode: str = "null"
  disaggregation_bootstrap_port: int = 8998
+ disaggregation_transfer_backend: str = "mooncake"
+ disaggregation_ib_device: Optional[str] = None

  def __post_init__(self):
  # Expert parallelism
@@ -226,9 +228,6 @@ class ServerArgs:
  # GPU memory is not known yet or no GPU is available.
  gpu_mem = None

- if is_hip():
- self.disable_shared_experts_fusion = True
-
  # Set mem fraction static, which depends on the tensor parallelism size
  if self.mem_fraction_static is None:
  if self.tp_size >= 16:
@@ -251,7 +250,12 @@ class ServerArgs:

  assert self.chunked_prefill_size % self.page_size == 0

- if self.enable_flashmla is True:
+ assert self.moe_dense_tp_size in {
+ 1,
+ None,
+ }, f"moe_dense_tp_size only support 1 and None currently"
+
+ if self.attention_backend == "flashmla":
  logger.warning(
  "FlashMLA only supports a page_size of 64, change page_size to 64."
  )
@@ -294,6 +298,8 @@ class ServerArgs:
  f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
  )

+ self.enable_multimodal: Optional[bool] = self.enable_llama4_multimodal
+
  # Data parallelism attention
  if self.enable_dp_attention:
  self.schedule_conservativeness = self.schedule_conservativeness * 0.3
@@ -386,6 +392,10 @@ class ServerArgs:
  os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
  "1" if self.enable_torch_compile else "0"
  )
+ # Set env var before grammar backends init
+ os.environ["SGLANG_DISABLE_OUTLINES_DISK_CACHE"] = (
+ "1" if self.disable_outlines_disk_cache else "0"
+ )

  @staticmethod
  def add_cli_args(parser: argparse.ArgumentParser):
@@ -495,6 +505,7 @@ class ServerArgs:
  "bitsandbytes",
  "gguf",
  "modelopt",
+ "modelopt_fp4",
  "w8a8_int8",
  "w8a8_fp8",
  "moe_wna16",
@@ -817,7 +828,7 @@ class ServerArgs:
  parser.add_argument(
  "--attention-backend",
  type=str,
- choices=["flashinfer", "triton", "torch_native", "fa3"],
+ choices=["flashinfer", "triton", "torch_native", "fa3", "flashmla"],
  default=ServerArgs.attention_backend,
  help="Choose the kernels for attention layers.",
  )
@@ -837,13 +848,13 @@ class ServerArgs:
  )
  parser.add_argument(
  "--enable-flashinfer-mla",
- action="store_true",
- help="Enable FlashInfer MLA optimization. This argument will be deprecated soon! Please use '--attention-backend flashinfer' instead for switching on flashfiner mla!",
+ action=DeprecatedAction,
+ help="--enable-flashinfer-mla is deprecated. Please use '--attention-backend flashinfer' instead.",
  )
  parser.add_argument(
  "--enable-flashmla",
- action="store_true",
- help="Enable FlashMLA decode optimization",
+ action=DeprecatedAction,
+ help="--enable-flashmla is deprecated. Please use '--attention-backend flashmla' instead.",
  )
  parser.add_argument(
  "--flashinfer-mla-disable-ragged",
@@ -969,9 +980,10 @@ class ServerArgs:
  help="Disable the custom all-reduce kernel and fall back to NCCL.",
  )
  parser.add_argument(
- "--disable-mla",
+ "--enable-llama4-multimodal",
+ default=ServerArgs.enable_llama4_multimodal,
  action="store_true",
- help="Disable Multi-head Latent Attention (MLA) for DeepSeek V2/V3/R1 series models.",
+ help="Enable the multimodal functionality for Llama-4.",
  )
  parser.add_argument(
  "--disable-overlap-schedule",
@@ -1096,10 +1108,17 @@ class ServerArgs:
  action="store_true",
  help="Enabling DeepEP MoE implementation for EP MoE.",
  )
+ parser.add_argument(
+ "--moe-dense-tp-size",
+ type=int,
+ default=ServerArgs.moe_dense_tp_size,
+ help="TP size for MoE dense MLP layers. This flag is useful when, with large TP size, there are errors caused by weights in MLP layers having dimension smaller than the min dimension GEMM supports.",
+ )
  parser.add_argument(
  "--deepep-mode",
  type=str,
  choices=["normal", "low_latency", "auto"],
+ default="auto",
  help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
  )

@@ -1107,13 +1126,18 @@ class ServerArgs:
  "--n-share-experts-fusion",
  type=int,
  default=0,
- help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1 "
- "we use tp_size by default.",
+ help="The number of shared_experts need to be replicated to fuse with normal experts in deepseek v3/r1, "
+ "set it to tp_size can get best optimized performace.",
  )
  parser.add_argument(
- "--disable-shared-experts-fusion",
+ "--disable-chunked-prefix-cache",
  action="store_true",
- help="Disable shared experts fusion by setting n_share_experts_fusion to 0.",
+ help="Disable chunked prefix cache feature for deepseek, which should save overhead for short sequences.",
+ )
+ parser.add_argument(
+ "--disable-fast-image-processor",
+ action="store_true",
+ help="Adopt base image processor instead of fast image processor.",
  )

  # Server warmups
@@ -1159,6 +1183,18 @@ class ServerArgs:
  default=ServerArgs.disaggregation_bootstrap_port,
  help="Bootstrap server port on the prefill server. Default is 8998.",
  )
+ parser.add_argument(
+ "--disaggregation-transfer-backend",
+ type=str,
+ default=ServerArgs.disaggregation_transfer_backend,
+ help="The backend for disaggregation transfer. Default is mooncake.",
+ )
+ parser.add_argument(
+ "--disaggregation-ib-device",
+ type=str,
+ default=ServerArgs.disaggregation_ib_device,
+ help="The ib device for disaggregation transfer. Default is None, it will be detected automatically if using the mooncake backend.",
+ )

  @classmethod
  def from_cli_args(cls, args: argparse.Namespace):
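Note: the reworked CLI surface above can be exercised directly through ServerArgs. A minimal sketch follows; the model id and flag values are illustrative, and a full launch may still require a GPU environment:

    import argparse

    from sglang.srt.server_args import ServerArgs

    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    args = parser.parse_args(
        [
            "--model-path", "deepseek-ai/DeepSeek-V3",        # illustrative model id
            "--attention-backend", "flashmla",                # replaces the removed --enable-flashmla flag
            "--moe-dense-tp-size", "1",                       # new flag; only 1 or None are accepted
            "--disaggregation-transfer-backend", "mooncake",  # new flag; mooncake is the default
        ]
    )
    server_args = ServerArgs.from_cli_args(args)
    print(server_args.attention_backend, server_args.moe_dense_tp_size)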
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py CHANGED
@@ -84,10 +84,10 @@ class EAGLEDraftCudaGraphRunner:
  raise Exception(
  f"Capture cuda graph failed: {e}\n"
  "Possible solutions:\n"
- "1. disable cuda graph by --disable-cuda-graph\n"
- "2. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
- "3. disable torch compile by not using --enable-torch-compile\n"
- "4. specify --dtype to the same dtype (e.g. bfloat16)\n"
+ "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
+ "2. disable torch compile by not using --enable-torch-compile\n"
+ "3. specify --dtype to the same dtype (e.g. bfloat16)\n"
+ "4. disable cuda graph by --disable-cuda-graph\n"
  "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
  )

sglang/srt/speculative/eagle_utils.py CHANGED
@@ -19,7 +19,7 @@ from sglang.srt.managers.schedule_batch import (
  from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
  from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
  from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
- from sglang.srt.utils import is_cuda_available, is_hip, next_power_of_2
+ from sglang.srt.utils import fast_topk, is_cuda_available, is_hip, next_power_of_2

  if is_cuda_available():
  from sgl_kernel import (
@@ -772,16 +772,6 @@ def select_top_k_tokens(
  return input_ids, hidden_states, scores, tree_info


- def fast_topk(values, topk, dim):
- if topk == 1:
- # Use max along the specified dimension to get both value and index
- max_value, max_index = torch.max(values, dim=dim)
- return max_value.unsqueeze(1), max_index.unsqueeze(1)
- else:
- # Use topk for efficiency with larger k values
- return torch.topk(values, topk, dim=dim)
-
-
  def _generate_simulated_accept_index(
  accept_index,
  predict,
sglang/srt/speculative/eagle_worker.py CHANGED
@@ -31,11 +31,15 @@ from sglang.srt.speculative.eagle_utils import (
  EagleVerifyInput,
  EagleVerifyOutput,
  assign_draft_cache_locs,
- fast_topk,
  select_top_k_tokens,
  )
  from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
- from sglang.srt.utils import empty_context, get_available_gpu_memory, is_cuda_available
+ from sglang.srt.utils import (
+ empty_context,
+ fast_topk,
+ get_available_gpu_memory,
+ is_cuda_available,
+ )

  if is_cuda_available():
  from sgl_kernel import segment_packbits
@@ -267,14 +271,11 @@ class EAGLEWorker(TpModelWorker):
  )
  elif batch.forward_mode.is_idle():
  model_worker_batch = batch.get_model_worker_batch()
- logits_output, next_token_ids, _ = (
- self.target_worker.forward_batch_generation(
- ForwardBatch.init_new(
- model_worker_batch, self.target_worker.model_runner
- )
- )
+ logits_output, next_token_ids = self.target_worker.forward_batch_generation(
+ model_worker_batch
  )
- return logits_output, next_token_ids, model_worker_batch.bid, 0, False
+
+ return logits_output, next_token_ids, model_worker_batch.bid, 0
  else:
  logits_output, next_token_ids, bid = self.forward_target_extend(batch)
  with self.draft_tp_context(self.draft_model_runner.tp_group):
sglang/srt/utils.py CHANGED
@@ -16,6 +16,7 @@ import base64
  import builtins
  import ctypes
  import dataclasses
+ import importlib
  import io
  import ipaddress
  import itertools
@@ -54,7 +55,6 @@ import torch.distributed
  import torch.distributed as dist
  import triton
  import zmq
- from decord import VideoReader, cpu
  from fastapi.responses import ORJSONResponse
  from packaging import version as pkg_version
  from PIL import Image
@@ -127,7 +127,7 @@ def is_flashinfer_available():
  """
  if not get_bool_env_var("SGLANG_IS_FLASHINFER_AVAILABLE", default="true"):
  return False
- return is_cuda()
+ return importlib.util.find_spec("flashinfer") is not None and is_cuda()


  def is_cuda_available():
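The availability check above now probes for the flashinfer package before reporting True. A standalone sketch of the same probe pattern (the helper name here is ours, not part of sglang):

    import importlib.util

    def has_optional_package(name: str) -> bool:
        # find_spec locates an importable package without importing it,
        # so a missing optional wheel is detected cheaply and safely.
        return importlib.util.find_spec(name) is not None

    print(has_optional_package("flashinfer"))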
@@ -544,6 +544,9 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra


  def encode_video(video_path, frame_count_limit=None):
+ # Lazy import because decord is not available on some arm platforms.
+ from decord import VideoReader, cpu
+
  if not os.path.exists(video_path):
  logger.error(f"Video {video_path} does not exist")
  return []
@@ -568,7 +571,7 @@ def encode_video(video_path, frame_count_limit=None):


  def load_image(
- image_file: Union[Image.Image, str, bytes]
+ image_file: Union[Image.Image, str, bytes],
  ) -> tuple[Image.Image, tuple[int, int]]:
  image = image_size = None
  if isinstance(image_file, Image.Image):
@@ -845,33 +848,38 @@ def broadcast_pyobj(
  rank: int,
  dist_group: Optional[torch.distributed.ProcessGroup] = None,
  src: int = 0,
+ force_cpu_device: bool = True,
  ):
  """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
+ device = torch.device(
+ "cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
+ )

  if rank == 0:
  if len(data) == 0:
- tensor_size = torch.tensor([0], dtype=torch.long)
+ tensor_size = torch.tensor([0], dtype=torch.long, device=device)
  dist.broadcast(tensor_size, src=src, group=dist_group)
  else:
  serialized_data = pickle.dumps(data)
  size = len(serialized_data)
+
  tensor_data = torch.ByteTensor(
  np.frombuffer(serialized_data, dtype=np.uint8)
- )
- tensor_size = torch.tensor([size], dtype=torch.long)
+ ).to(device)
+ tensor_size = torch.tensor([size], dtype=torch.long, device=device)

  dist.broadcast(tensor_size, src=src, group=dist_group)
  dist.broadcast(tensor_data, src=src, group=dist_group)
  return data
  else:
- tensor_size = torch.tensor([0], dtype=torch.long)
+ tensor_size = torch.tensor([0], dtype=torch.long, device=device)
  dist.broadcast(tensor_size, src=src, group=dist_group)
  size = tensor_size.item()

  if size == 0:
  return []

- tensor_data = torch.empty(size, dtype=torch.uint8)
+ tensor_data = torch.empty(size, dtype=torch.uint8, device=device)
  dist.broadcast(tensor_data, src=src, group=dist_group)

  serialized_data = bytes(tensor_data.cpu().numpy())
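For reference, a minimal single-process sketch of calling broadcast_pyobj with the new force_cpu_device argument; it assumes a gloo process group, and the port and payload are arbitrary:

    import torch.distributed as dist

    from sglang.srt.utils import broadcast_pyobj

    # One-rank gloo group, just enough to exercise the call path.
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29511", rank=0, world_size=1
    )
    payload = [{"req_id": 0, "prompt": "hello"}]
    # force_cpu_device=True (the default) keeps the size/data tensors on CPU, which gloo
    # expects; pass False to place them on CUDA when broadcasting over an NCCL group.
    out = broadcast_pyobj(payload, rank=0, dist_group=None, src=0, force_cpu_device=True)
    assert out == payload
    dist.destroy_process_group()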
@@ -1480,14 +1488,43 @@ def permute_weight(x: torch.Tensor) -> torch.Tensor:


  class MultiprocessingSerializer:
  @staticmethod
- def serialize(obj):
+ def serialize(obj, output_str: bool = False):
+ """
+ Serialize a Python object using ForkingPickler.
+
+ Args:
+ obj: The object to serialize.
+ output_str (bool): If True, return a base64-encoded string instead of raw bytes.
+
+ Returns:
+ bytes or str: The serialized object.
+ """
  buf = io.BytesIO()
  ForkingPickler(buf).dump(obj)
  buf.seek(0)
- return buf.read()
+ output = buf.read()
+
+ if output_str:
+ # Convert bytes to base64-encoded string
+ output = base64.b64encode(output).decode("utf-8")
+
+ return output

  @staticmethod
  def deserialize(data):
+ """
+ Deserialize a previously serialized object.
+
+ Args:
+ data (bytes or str): The serialized data, optionally base64-encoded.
+
+ Returns:
+ The deserialized Python object.
+ """
+ if isinstance(data, str):
+ # Decode base64 string to bytes
+ data = base64.b64decode(data)
+
  return ForkingPickler.loads(data)

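A short round-trip showing the new output_str option; the payload is illustrative:

    from sglang.srt.utils import MultiprocessingSerializer

    payload = {"tensor_handles": [1, 2, 3]}
    raw = MultiprocessingSerializer.serialize(payload)                     # bytes, as before
    text = MultiprocessingSerializer.serialize(payload, output_str=True)   # base64 str, e.g. for JSON/HTTP payloads
    assert MultiprocessingSerializer.deserialize(raw) == payload
    assert MultiprocessingSerializer.deserialize(text) == payload          # str input is base64-decoded first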
@@ -1819,3 +1856,92 @@ class DeepEPMode(Enum):
  return DeepEPMode.low_latency
  else:
  return DeepEPMode.normal
+
+
+ def fast_topk(values, topk, dim):
+ if topk == 1:
+ # Use max along the specified dimension to get both value and index
+ return torch.max(values, dim=dim, keepdim=True)
+ else:
+ # Use topk for efficiency with larger k values
+ return torch.topk(values, topk, dim=dim)
+
+
+ def is_hopper_with_cuda_12_3():
+ if not is_cuda():
+ return False
+ is_hopper = torch.cuda.get_device_capability()[0] == 9
+ cuda_version = torch.version.cuda.split(".")
+ is_cuda_compatible = int(cuda_version[0]) == 12 and int(cuda_version[1]) >= 3
+ return is_hopper and is_cuda_compatible
+
+
+ def get_free_port():
+ # try ipv4
+ try:
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.bind(("", 0))
+ return s.getsockname()[1]
+ except OSError:
+ # try ipv6
+ with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
+ s.bind(("", 0))
+ return s.getsockname()[1]
+
+
+ def get_local_ip_by_remote() -> str:
+ # try ipv4
+ s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ try:
+ s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable
+ return s.getsockname()[0]
+ except Exception:
+ pass
+
+ # try ipv6
+ try:
+ s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
+ # Google's public DNS server, see
+ # https://developers.google.com/speed/public-dns/docs/using#addresses
+ s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable
+ return s.getsockname()[0]
+ except Exception:
+ raise ValueError(f"Can not get local ip")
+
+
+ def is_page_size_one(server_args):
+ return server_args.page_size == 1
+
+
+ def is_no_spec_infer_or_topk_one(server_args):
+ return server_args.speculative_eagle_topk is None or (
+ server_args.speculative_eagle_topk is not None
+ and server_args.speculative_eagle_topk == 1
+ and is_page_size_one(server_args)
+ )
+
+
+ def is_fa3_default_architecture(hf_config):
+ architectures = getattr(hf_config, "architectures", None)
+ if not isinstance(architectures, list) or not architectures:
+ return False
+ default_archs = {
+ "Qwen2ForCausalLM",
+ "Llama4ForConditionalGeneration",
+ "LlamaForCausalLM",
+ "MistralForCausalLM",
+ }
+ return architectures[0] in default_archs
+
+
+ # Can be more general if it is used in multiple places (keep it simple and thus not general now)
+ class BumpAllocator:
+ def __init__(self, buffer_size: int, dtype, device):
+ self._buffer = torch.zeros((buffer_size,), dtype=dtype, device=device)
+ self._pointer = 0
+
+ def allocate(self, size: int):
+ assert self._pointer + size <= len(self._buffer)
+ output = self._buffer[self._pointer : self._pointer + size]
+ self._pointer += size
+ return output
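A brief sketch of the new fast_topk and BumpAllocator helpers in isolation; shapes and sizes are illustrative:

    import torch

    from sglang.srt.utils import BumpAllocator, fast_topk

    logits = torch.randn(4, 32000)
    # topk == 1 takes the cheaper torch.max path; keepdim=True keeps the (4, 1) shape,
    # so both branches return a (values, indices) pair with matching shapes.
    values, indices = fast_topk(logits, topk=1, dim=-1)

    # BumpAllocator hands out consecutive views over one preallocated zero buffer.
    alloc = BumpAllocator(buffer_size=1024, dtype=torch.int64, device="cpu")
    a = alloc.allocate(128)  # view over elements [0, 128)
    b = alloc.allocate(256)  # view over elements [128, 384)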