sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_worker.py CHANGED
@@ -297,7 +297,7 @@ class EAGLEWorker(TpModelWorker):

     def forward_batch_speculative_generation(
         self, batch: ScheduleBatch
-    ) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
+    ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, int, bool]:
         """Run speculative decoding forward.

         NOTE: Many states of batch is modified as you go through. It is not guaranteed that
@@ -325,11 +325,16 @@ class EAGLEWorker(TpModelWorker):
                 self.verify(batch, spec_info)
             )

-            if self.check_forward_draft_extend_after_decode(batch):
-                with self.draft_tp_context(self.draft_model_runner.tp_group):
-                    self.forward_draft_extend_after_decode(
-                        batch,
-                    )
+            with self.draft_tp_context(self.draft_model_runner.tp_group):
+                # NOTE: We should use `check_forward_draft_extend_after_decode`
+                # when DP attention is enabled, but it is slow. Skip it for now.
+                if (
+                    self.server_args.enable_dp_attention
+                    or batch.spec_info.verified_id.shape[0] > 0
+                ):
+                    # decode is not finished
+                    self.forward_draft_extend_after_decode(batch)
+
             return (
                 logits_output,
                 verify_output.verified_id,
@@ -339,10 +344,7 @@ class EAGLEWorker(TpModelWorker):
             )

     def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
-        local_need_forward = (
-            batch.spec_info.verified_id is not None
-            and batch.spec_info.verified_id.shape[0] > 0
-        )
+        local_need_forward = batch.spec_info.verified_id.shape[0] > 0
         if not self.server_args.enable_dp_attention:
             return local_need_forward

@@ -361,7 +363,7 @@ class EAGLEWorker(TpModelWorker):

     def forward_target_extend(
         self, batch: ScheduleBatch
-    ) -> Tuple[LogitsProcessorOutput, List[int], int]:
+    ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, Optional[torch.Tensor]]:
         """Run the target extend.

         Args:
@@ -376,7 +378,6 @@ class EAGLEWorker(TpModelWorker):
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-        model_worker_batch.spec_num_draft_tokens = 1
         logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
             model_worker_batch
         )
@@ -508,13 +509,15 @@ class EAGLEWorker(TpModelWorker):
         self._draft_preprocess_decode(batch)

         spec_info = batch.spec_info
+        assert isinstance(spec_info, EagleDraftInput)

         spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        spec_info.num_tokens_per_batch = self.topk
+        spec_info.num_tokens_for_logprob_per_batch = self.topk
         batch.return_hidden_states = False

         # Get forward batch
         model_worker_batch = batch.get_model_worker_batch()
-        model_worker_batch.spec_num_draft_tokens = self.topk
         assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
@@ -527,6 +530,7 @@ class EAGLEWorker(TpModelWorker):
                 forward_batch
             )
         else:
+            forward_batch.can_run_dp_cuda_graph = False
             if not forward_batch.forward_mode.is_idle():
                 # Initialize attention backend
                 self.draft_attn_backend.init_forward_metadata(forward_batch)
@@ -578,6 +582,7 @@ class EAGLEWorker(TpModelWorker):
     def draft_forward(self, forward_batch: ForwardBatch):
         # Parse args
         spec_info = forward_batch.spec_info
+        assert isinstance(spec_info, EagleDraftInput)
         out_cache_loc = forward_batch.out_cache_loc
         topk_p, topk_index, hidden_states = (
             spec_info.topk_p,
@@ -621,8 +626,8 @@ class EAGLEWorker(TpModelWorker):
             spec_info.hidden_states = hidden_states

         # Run forward
-        logits_output = self.draft_model_runner.model.forward(
-            forward_batch.input_ids, forward_batch.positions, forward_batch
+        logits_output, _ = self.draft_model_runner.forward(
+            forward_batch, skip_attn_backend_init=True
         )
         self._detect_nan_if_needed(logits_output)
         probs = torch.softmax(logits_output.next_token_logits, dim=-1)
@@ -642,10 +647,10 @@ class EAGLEWorker(TpModelWorker):
             else ForwardMode.IDLE
         )
         batch.spec_info = spec_info
+
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=spec_info.seq_lens_cpu
         )
-        model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
         assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode

         if batch.has_grammar:
@@ -782,8 +787,8 @@ class EAGLEWorker(TpModelWorker):
         self,
         batch: ScheduleBatch,
         hidden_states: torch.Tensor,
-        next_token_ids: List[int],
-        seq_lens_cpu: torch.Tensor,
+        next_token_ids: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Run draft model extend. This API modifies the states of the batch.

@@ -795,6 +800,8 @@ class EAGLEWorker(TpModelWorker):
         batch.spec_info = EagleDraftInput(
             hidden_states=hidden_states,
             verified_id=next_token_ids,
+            num_tokens_per_batch=1,
+            num_tokens_for_logprob_per_batch=1,
         )
         batch.return_hidden_states = False
         batch.spec_info.prepare_for_extend(batch)
@@ -802,7 +809,6 @@ class EAGLEWorker(TpModelWorker):
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=seq_lens_cpu
         )
-        model_worker_batch.spec_num_draft_tokens = 1
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -814,37 +820,45 @@ class EAGLEWorker(TpModelWorker):
         self.capture_for_decode(logits_output, forward_batch.spec_info)

     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
+        assert isinstance(batch.spec_info, EagleDraftInput)
         # Backup fields that will be modified in-place
         seq_lens_backup = batch.seq_lens.clone()
         req_pool_indices_backup = batch.req_pool_indices
         accept_length_backup = batch.spec_info.accept_length
         return_logprob_backup = batch.return_logprob
+
         input_is_idle = batch.forward_mode.is_idle()
-        if not input_is_idle:
-            # Prepare metadata
-            if batch.spec_info.verified_id is not None:
-                batch.spec_info.prepare_extend_after_decode(
-                    batch,
-                    self.speculative_num_steps,
-                )
-        else:
-            batch = batch.copy()
-            batch.prepare_for_idle()
-            hidden_size = (
-                self.model_config.hidden_size * 3
-                if self.speculative_algorithm.is_eagle3()
-                else self.model_config.hidden_size
-            )
-            batch.spec_info = EagleDraftInput.create_idle_input(
-                device=self.device,
-                hidden_size=hidden_size,
-                dtype=self.model_config.dtype,
-                topk=self.topk,
-                capture_hidden_mode=CaptureHiddenMode.LAST,
-            )
+
+        if not input_is_idle and batch.spec_info.verified_id.numel() == 0:
+            batch = batch.copy()
+            batch.prepare_for_idle()
+            hidden_size = (
+                self.model_config.hidden_size * 3
+                if self.speculative_algorithm.is_eagle3()
+                else self.model_config.hidden_size
+            )
+            batch.spec_info = EagleDraftInput.create_idle_input(
+                device=self.device,
+                hidden_size=hidden_size,
+                dtype=self.model_config.dtype,
+                topk=self.topk,
+                capture_hidden_mode=CaptureHiddenMode.LAST,
+            )
+
+        batch.spec_info.num_tokens_per_batch = self.speculative_num_steps + 1
+        batch.spec_info.num_tokens_for_logprob_per_batch = 1
+        batch.spec_info.prepare_extend_after_decode(
+            batch,
+            self.speculative_num_steps,
+        )
+        batch.forward_mode = (
+            ForwardMode.DRAFT_EXTEND
+            if not batch.forward_mode.is_idle()
+            else ForwardMode.IDLE
+        )
+
         batch.return_hidden_states = False
         model_worker_batch = batch.get_model_worker_batch()
-        model_worker_batch.spec_num_draft_tokens = self.speculative_num_steps + 1
         assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
@@ -869,12 +883,13 @@ class EAGLEWorker(TpModelWorker):
             )
             forward_batch.spec_info.hidden_states = logits_output.hidden_states
         else:
+            forward_batch.can_run_dp_cuda_graph = False
             if not forward_batch.forward_mode.is_idle():
                 self.draft_model_runner.attn_backend.init_forward_metadata(
                     forward_batch
                 )
-            logits_output = self.draft_model_runner.model.forward(
-                forward_batch.input_ids, forward_batch.positions, forward_batch
+            logits_output, _ = self.draft_model_runner.forward(
+                forward_batch, skip_attn_backend_init=True
             )
             self.capture_for_decode(logits_output, forward_batch.spec_info)

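For readability, here is a minimal sketch of the gating the hunks above introduce for running draft extend after decode; the attribute names come from the diff, while the standalone helper itself is illustrative and not part of sglang.

import torch


def should_run_draft_extend_after_decode(
    enable_dp_attention: bool, verified_id: torch.Tensor
) -> bool:
    # With DP attention enabled, every rank takes the draft-extend path
    # (the per-rank `check_forward_draft_extend_after_decode` is skipped
    # because it is slow); otherwise run it only when this rank still has
    # verified tokens left to extend.
    return enable_dp_attention or verified_id.shape[0] > 0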
sglang/srt/two_batch_overlap.py CHANGED
@@ -341,15 +341,18 @@ class TboDPAttentionPreparer:

     @staticmethod
     def _compute_global_forward_mode(forward_modes):
-        converted_forward_modes = [
-            ForwardMode.DECODE.value if x == ForwardMode.IDLE.value else x
-            for x in forward_modes
+        forward_modes_excluding_idle = [
+            x for x in forward_modes if x != ForwardMode.IDLE.value
         ]
+
+        if not forward_modes_excluding_idle:
+            return ForwardMode.IDLE, False
+
         forward_mode_agree = TboDPAttentionPreparer._is_all_same(
-            converted_forward_modes
+            forward_modes_excluding_idle
         )
         global_forward_mode = (
-            ForwardMode(converted_forward_modes[0]) if forward_mode_agree else None
+            ForwardMode(forward_modes_excluding_idle[0]) if forward_mode_agree else None
         )
         return global_forward_mode, forward_mode_agree

@@ -500,6 +503,7 @@ class TboForwardBatchPreparer:
             "capture_hidden_mode",
             "padded_static_len",
             "mrope_positions",  # only used by qwen2-vl, thus not care
+            "split_index",  # for split prefill
         ]:
             output_dict[key] = getattr(batch, key)
         if not batch.forward_mode.is_target_verify():
@@ -541,6 +545,7 @@ class TboForwardBatchPreparer:
             tbo_children=None,
             global_num_tokens_gpu=None,
             global_num_tokens_cpu=None,
+            dp_padding_mode=None,
             gathered_buffer=gathered_buffer,
             global_num_tokens_for_logprob_gpu=None,
             global_num_tokens_for_logprob_cpu=None,
sglang/srt/utils.py CHANGED
@@ -691,12 +691,17 @@ def decode_video_base64(video_base64):
     )  # Return an empty array and size tuple if no frames were found


-def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarray:
+def load_audio(
+    audio_file: str, sr: Optional[int] = None, mono: bool = True
+) -> np.ndarray:
     # Use soundfile here, since librosa use it under the hood,
     # and librosa will not support audio loading in the future
     import soundfile as sf
     from scipy.signal import resample

+    if sr is None:
+        sr = 16000
+
     # Load audio data
     if isinstance(audio_file, bytes):
         audio, original_sr = sf.read(BytesIO(audio_file))
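A hedged usage sketch of the updated load_audio signature (the file name below is a placeholder): omitting sr keeps the previous 16 kHz behaviour, while an explicit value selects the target sample rate.

from sglang.srt.utils import load_audio

audio_default = load_audio("speech.wav")      # sr=None falls back to 16000 Hz
audio_8k = load_audio("speech.wav", sr=8000)  # explicit target sample rate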
@@ -739,9 +744,13 @@ def load_image(
         image = Image.open(BytesIO(image_file))
     elif image_file.startswith("http://") or image_file.startswith("https://"):
         timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
-        response = requests.get(image_file, stream=True, timeout=timeout).raw
-        image = Image.open(response)
-        response.close()
+        response = requests.get(image_file, stream=True, timeout=timeout)
+        try:
+            response.raise_for_status()
+            image = Image.open(response.raw)
+            image.load()  # Force loading to avoid issues after closing the stream
+        finally:
+            response.close()
     elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
         image = Image.open(image_file)
     elif image_file.startswith("data:"):
@@ -928,71 +937,6 @@ def monkey_patch_vllm_gguf_config():
     setattr(GGUFConfig, "get_quant_method", get_quant_method_with_embedding_replaced)


-def maybe_set_triton_cache_manager() -> None:
-    """Set environment variable to tell Triton to use a
-    custom cache manager"""
-    cache_manger = os.environ.get("TRITON_CACHE_MANAGER", None)
-    if cache_manger is None:
-        manager = "sglang.srt.utils:CustomCacheManager"
-        logger.debug("Setting Triton cache manager to: %s", manager)
-        os.environ["TRITON_CACHE_MANAGER"] = manager
-
-
-class CustomCacheManager(FileCacheManager):
-    # Adapted from: https://github.com/tdoublep/vllm/blob/3307522289fdfefe323b6c00d0db696651989a2f/vllm/triton_utils/custom_cache_manager.py
-    def __init__(self, key, override=False, dump=False):
-        from sglang.srt.distributed.parallel_state import get_tp_group
-
-        self.key = key
-        self.lock_path = None
-
-        try:
-            module_path = "triton.runtime.cache"
-            cache_module = importlib.import_module(module_path)
-
-            default_cache_dir = getattr(cache_module, "default_cache_dir", None)
-            default_dump_dir = getattr(cache_module, "default_dump_dir", None)
-            default_override_dir = getattr(cache_module, "default_override_dir", None)
-        except (ModuleNotFoundError, AttributeError) as e:
-            default_cache_dir = None
-            default_dump_dir = None
-            default_override_dir = None
-
-        if dump:
-            self.cache_dir = (
-                default_dump_dir()
-                if default_dump_dir is not None
-                else os.path.join(Path.home(), ".triton", "dump")
-            )
-            self.cache_dir = os.path.join(self.cache_dir, self.key)
-            self.lock_path = os.path.join(self.cache_dir, "lock")
-            os.makedirs(self.cache_dir, exist_ok=True)
-        elif override:
-            self.cache_dir = (
-                default_override_dir()
-                if default_override_dir is not None
-                else os.path.join(Path.home(), ".triton", "override")
-            )
-            self.cache_dir = os.path.join(self.cache_dir, self.key)
-        else:
-            # create cache directory if it doesn't exist
-            self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or (
-                default_cache_dir()
-                if default_cache_dir is not None
-                else os.path.join(Path.home(), ".triton", "cache")
-            )
-            if self.cache_dir:
-                try:
-                    self.cache_dir = f"{self.cache_dir}_{get_tp_group().local_rank}"
-                except:
-                    self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
-                self.cache_dir = os.path.join(self.cache_dir, self.key)
-                self.lock_path = os.path.join(self.cache_dir, "lock")
-                os.makedirs(self.cache_dir, exist_ok=True)
-            else:
-                raise RuntimeError("Could not create or locate cache dir")
-
-
 def set_ulimit(target_soft_limit=65535):
     # number of open files
     resource_type = resource.RLIMIT_NOFILE
@@ -1417,6 +1361,13 @@ def get_nvgpu_memory_capacity():
         ]

         if not memory_values:
+            # Fallback to torch.cuda.mem_get_info() when failed to get memory capacity from nvidia-smi,
+            # typically in NVIDIA MIG mode.
+            if torch.cuda.is_available():
+                logger.warning(
+                    "Failed to get GPU memory capacity from nvidia-smi, falling back to torch.cuda.mem_get_info()."
+                )
+                return torch.cuda.mem_get_info()[1] // 1024 // 1024  # unit: MB
             raise ValueError("No GPU memory values found.")

         # Return the minimum memory value
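The fallback above relies on torch.cuda.mem_get_info(), which reports free and total bytes for the current device and also works under MIG; a small self-contained sketch of the same conversion to MB:

import torch


def total_gpu_memory_mb() -> int:
    # mem_get_info() returns (free_bytes, total_bytes) for the current CUDA device.
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    return total_bytes // 1024 // 1024  # unit: MB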
@@ -2049,6 +2000,16 @@ def is_valid_ipv6_address(address: str) -> bool:
         return False


+def maybe_wrap_ipv6_address(address: str) -> str:
+    if is_valid_ipv6_address(address):
+        return f"[{address}]"
+    return address
+
+
+def format_tcp_address(ip: str, port: int) -> str:
+    return f"tcp://{maybe_wrap_ipv6_address(ip)}:{port}"
+
+
 def configure_ipv6(dist_init_addr):
     addr = dist_init_addr
     end = addr.find("]")
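A small usage sketch of the two address helpers added above; the address and port values are arbitrary examples.

from sglang.srt.utils import format_tcp_address, maybe_wrap_ipv6_address

maybe_wrap_ipv6_address("10.0.0.1")  # -> "10.0.0.1" (IPv4 is left untouched)
maybe_wrap_ipv6_address("::1")       # -> "[::1]" (IPv6 is wrapped in brackets)
format_tcp_address("::1", 30000)     # -> "tcp://[::1]:30000"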
@@ -2880,3 +2841,17 @@ def parse_module_path(module_path, function_name, create_dummy):
         return final_module, getattr(final_module, function_name)

     return final_module, None
+
+
+# LoRA-related constants and utilities
+SUPPORTED_LORA_TARGET_MODULES = [
+    "q_proj",
+    "k_proj",
+    "v_proj",
+    "o_proj",
+    "gate_proj",
+    "up_proj",
+    "down_proj",
+]
+
+LORA_TARGET_ALL_MODULES = "all"
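The two LoRA constants added above are plain data; the helper below is hypothetical (not part of sglang) and only illustrates how a caller might expand the "all" sentinel and validate user-supplied target modules.

from sglang.srt.utils import LORA_TARGET_ALL_MODULES, SUPPORTED_LORA_TARGET_MODULES


def expand_lora_target_modules(requested):
    # Hypothetical helper: "all" expands to every supported projection name.
    if requested == LORA_TARGET_ALL_MODULES or LORA_TARGET_ALL_MODULES in requested:
        return list(SUPPORTED_LORA_TARGET_MODULES)
    unknown = [m for m in requested if m not in SUPPORTED_LORA_TARGET_MODULES]
    if unknown:
        raise ValueError(f"Unsupported LoRA target modules: {unknown}")
    return list(requested)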
sglang/test/runners.py CHANGED
@@ -134,10 +134,12 @@ class HFRunner:
         model_type: str = "generation",
         output_str_only: bool = False,
         trust_remote_code: bool = False,
+        patch_model_do_sample_false: bool = False,
     ):
         self.model_type = model_type
         self.output_str_only = output_str_only
         self.trust_remote_code = trust_remote_code
+        self.patch_model_do_sample_false = patch_model_do_sample_false

         self.in_queue = mp.Queue()
         self.out_queue = mp.Queue()
@@ -292,6 +294,7 @@ class HFRunner:
                         torch_dtype=torch_dtype,
                         output_str_only=self.output_str_only,
                         token_ids_logprob=token_ids_logprob,
+                        patch_model_do_sample_false=self.patch_model_do_sample_false,
                     )
                 )
             elif self.model_type == "embedding":
@@ -380,6 +383,7 @@ class HFRunner:
         lora_paths: Optional[List[str]] = None,
         output_str_only: bool = False,
         token_ids_logprob: Optional[int] = None,
+        patch_model_do_sample_false: Optional[bool] = False,
     ) -> ModelOutput:
         output_strs = []
         top_input_logprobs = []
@@ -407,7 +411,8 @@ class HFRunner:
                 )
             else:
                 model = base_model
-
+            if patch_model_do_sample_false:
+                model.generation_config.do_sample = False
             outputs = model.generate(
                 input_ids=input_ids,
                 generation_config=GenerationConfig(
@@ -481,7 +486,7 @@ class SRTRunner:
         torch_dtype: torch.dtype,
         model_type: str,
         tp_size: int = 1,
-        impl: str = "auto",
+        model_impl: str = "auto",
         port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
         lora_paths: List[str] = None,
         max_loras_per_batch: int = 4,
@@ -505,6 +510,9 @@ class SRTRunner:
         torchao_config: Optional[str] = None,
         cuda_graph_max_bs: int = 4,
         sleep_on_idle=False,
+        max_lora_rank: Optional[int] = None,
+        lora_target_modules: Optional[List[str]] = None,
+        enable_lora: Optional[bool] = None,
     ):
         self.model_type = model_type
         self.is_generation = model_type == "generation"
@@ -523,7 +531,7 @@ class SRTRunner:
             tp_size=tp_size,
             dtype=get_dtype_str(torch_dtype),
             port=port,
-            impl=impl,
+            model_impl=model_impl,
             torchao_config=torchao_config,
             mem_fraction_static=mem_fraction_static,
             trust_remote_code=trust_remote_code,
@@ -543,6 +551,9 @@ class SRTRunner:
             cuda_graph_max_bs=cuda_graph_max_bs,
             disable_custom_all_reduce=disable_custom_all_reduce,
             sleep_on_idle=sleep_on_idle,
+            max_lora_rank=max_lora_rank,
+            lora_target_modules=lora_target_modules,
+            enable_lora=enable_lora,
             **spec_kwargs,
         )

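A hedged sketch of how a test could drive the renamed model_impl argument and the new LoRA arguments on SRTRunner; the model name and adapter path are placeholders, and the remaining arguments follow the signature shown above.

import torch

from sglang.test.runners import SRTRunner

runner = SRTRunner(
    "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model path
    torch_dtype=torch.float16,
    model_type="generation",
    model_impl="auto",  # renamed from `impl`
    enable_lora=True,
    max_lora_rank=64,
    lora_target_modules=["q_proj", "v_proj"],
    lora_paths=["/path/to/lora_adapter"],  # placeholder adapter path
)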
sglang/test/test_activation.py CHANGED
@@ -3,9 +3,12 @@

 import torch

-from sglang.srt.layers.activation import GeluAndMul
+from sglang.srt.layers.activation import GeluAndMul, QuickGELU
+from sglang.srt.utils import is_hip
 from sglang.test.test_utils import CustomTestCase

+_is_hip = is_hip()
+

 class TestGeluAndMul(CustomTestCase):
     DTYPES = [torch.half, torch.bfloat16]
@@ -52,5 +55,51 @@ class TestGeluAndMul(CustomTestCase):
                 self._run_gelu_and_mul_test(*params)


+class TestQuickGELU(CustomTestCase):
+    DTYPES = [torch.half, torch.bfloat16]
+    NUM_TOKENS = [7, 83, 2048]  # batch = sequence length
+    DIMS = [512, 4096, 5120, 13824]  # all multiples of 16 bytes
+    SEEDS = [0]
+
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        torch.set_default_device("cuda")
+
+    def _run_gelu_quick_test(self, n_tok: int, dim: int, dtype: torch.dtype, seed: int):
+        torch.manual_seed(seed)
+
+        layer = QuickGELU().to(dtype=dtype)
+
+        x = torch.randn(n_tok, dim, dtype=dtype, device="cuda")
+
+        with torch.inference_mode():
+            ref = layer.forward_native(x)  # x * sigmoid(1.702 * x), fp32 math
+            if _is_hip:
+                out = layer.forward_hip(x)  # 128-bit vectorised kernel from sgl-kernel
+            else:
+                out = layer.forward_cuda(x)
+
+        tol = 1e-2 if dtype is torch.bfloat16 else 1e-3
+        self.assertTrue(
+            torch.allclose(out, ref, atol=tol, rtol=tol),
+            msg=f"Mismatch @ B={n_tok}, D={dim}, dtype={dtype}",
+        )
+        print(f"Match @ B={n_tok}, D={dim}, dtype={dtype}")
+
+    def test_quick_gelu(self):
+        for params in itertools.product(
+            self.NUM_TOKENS, self.DIMS, self.DTYPES, self.SEEDS
+        ):
+            with self.subTest(
+                num_tokens=params[0],
+                dim=params[1],
+                dtype=params[2],
+                seed=params[3],
+            ):
+                self._run_gelu_quick_test(*params)
+
+
 if __name__ == "__main__":
     unittest.main(verbosity=2)
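For reference, the activation the new TestQuickGELU exercises is QuickGELU(x) = x * sigmoid(1.702 * x); a minimal standalone version of the fp32 reference path (not the sgl-kernel implementation) looks like this:

import torch


def quick_gelu_reference(x: torch.Tensor) -> torch.Tensor:
    # Same math as the test's forward_native reference: compute in fp32,
    # then cast back to the input dtype.
    return (x.float() * torch.sigmoid(1.702 * x.float())).to(x.dtype)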
sglang/test/test_block_fp8.py CHANGED
@@ -6,6 +6,7 @@ import torch

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_tensor_quant_mla_fp8,
     per_token_group_quant_fp8,
@@ -497,13 +498,17 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
         score = torch.randn((M, E), dtype=dtype)

         with torch.inference_mode():
+            topk_output = select_experts(
+                hidden_states=a,
+                router_logits=score,
+                top_k=topk,
+                renormalize=False,
+            )
             out = fused_moe(
                 a,
                 w1,
                 w2,
-                score,
-                topk,
-                renormalize=False,
+                topk_output,
                 use_fp8_w8a8=True,
                 w1_scale=w1_s,
                 w2_scale=w2_s,
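Condensed, the calling convention the hunk above switches to is: compute routing once with select_experts and pass its result to fused_moe instead of raw router logits. A hedged sketch with arbitrary shapes (the result also unpacks to a 3-tuple, as the other updated tests below show):

import torch

from sglang.srt.layers.moe.topk import select_experts

hidden_states = torch.randn(4, 128, device="cuda", dtype=torch.float16)
router_logits = torch.randn(4, 8, device="cuda", dtype=torch.float16)

topk_output = select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
    top_k=2,
    renormalize=False,
)
topk_weights, topk_ids, _ = topk_output  # unpacking pattern used by the other tests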
sglang/test/test_block_fp8_ep.py CHANGED
@@ -40,7 +40,7 @@ def ep_moe(
     block_shape: Optional[List[int]] = None,
 ):
     use_blockwise_fp8 = block_shape is not None
-    topk_weights, topk_ids = select_experts(
+    topk_weights, topk_ids, _ = select_experts(
         hidden_states=hidden_states,
         router_logits=router_logits,
         top_k=top_k,
sglang/test/test_custom_ops.py CHANGED
@@ -3,8 +3,13 @@
 import pytest
 import torch

-from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-from sglang.srt.utils import is_cuda
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
+from sglang.srt.utils import is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+_is_fp8_fnuz = is_fp8_fnuz()
+fp8_dtype = torch.float8_e4m3fnuz if _is_fp8_fnuz else torch.float8_e4m3fn


 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@@ -13,10 +18,10 @@ def test_scaled_fp8_quant_per_tensor(dtype) -> None:
     def quantize_ref_per_tensor(tensor, inv_scale):
         # The reference implementation that fully aligns to
         # the kernel being tested.
-        finfo = torch.finfo(torch.float8_e4m3fn)
+        finfo = torch.finfo(fp8_dtype)
         scale = inv_scale.reciprocal()
         qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
-        qweight = qweight.to(torch.float8_e4m3fn)
+        qweight = qweight.to(fp8_dtype)
         return qweight

     def dequantize_per_tensor(tensor, inv_scale, dtype):
@@ -48,19 +53,19 @@ def test_scaled_fp8_quant_per_tensor(dtype) -> None:
     )


-if is_cuda:
+if _is_cuda or _is_hip:

     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
     def test_scaled_fp8_quant_per_token_dynamic(dtype) -> None:
         def quantize_ref_per_token(tensor, inv_scale):
             # The reference implementation that fully aligns to
             # the kernel being tested.
-            finfo = torch.finfo(torch.float8_e4m3fn)
+            finfo = torch.finfo(fp8_dtype)
             scale = inv_scale.reciprocal()
             qweight = (tensor.to(torch.float32) * scale).clamp(
                 min=finfo.min, max=finfo.max
             )
-            qweight = qweight.to(torch.float8_e4m3fn)
+            qweight = qweight.to(fp8_dtype)
             return qweight

         def dequantize_per_token(tensor, inv_scale, dtype):
sglang/test/test_cutlass_w4a8_moe.py CHANGED
@@ -100,12 +100,10 @@ def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
     s_strides2 = c_strides2

     score = torch.randn((M, E), dtype=dtype, device=device)
-    topk_weights, topk_ids = select_experts(
+    topk_weights, topk_ids, _ = select_experts(
         hidden_states=a,
         router_logits=score,
         top_k=topk,
-        use_grouped_topk=False,
-        renormalize=False,
     )
     expert_map = torch.arange(E, dtype=torch.int32, device=device)
     expert_map[local_e:] = E
sglang/test/test_fp4_moe.py CHANGED
@@ -159,12 +159,10 @@ def test_cutlass_fp4_moe_no_graph(

     score = torch.randn((m, e), device="cuda", dtype=dtype)

-    topk_weights, topk_ids = select_experts(
+    topk_weights, topk_ids, _ = select_experts(
         hidden_states=a,
         router_logits=score,
         top_k=topk,
-        use_grouped_topk=False,
-        renormalize=False,
     )

     a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)