sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -1
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +8 -7
  6. sglang/srt/disaggregation/decode.py +8 -4
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +68 -5
  14. sglang/srt/entrypoints/openai/protocol.py +2 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +60 -265
  16. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  18. sglang/srt/function_call/ebnf_composer.py +1 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  21. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  22. sglang/srt/function_call/kimik2_detector.py +3 -3
  23. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  24. sglang/srt/jinja_template_utils.py +6 -0
  25. sglang/srt/layers/attention/aiter_backend.py +370 -107
  26. sglang/srt/layers/attention/ascend_backend.py +3 -0
  27. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  28. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  29. sglang/srt/layers/attention/flashinfer_backend.py +55 -13
  30. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
  31. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  32. sglang/srt/layers/attention/triton_backend.py +24 -27
  33. sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
  34. sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
  35. sglang/srt/layers/attention/vision.py +9 -1
  36. sglang/srt/layers/attention/wave_backend.py +627 -0
  37. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  38. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  39. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  40. sglang/srt/layers/communicator.py +11 -13
  41. sglang/srt/layers/dp_attention.py +118 -27
  42. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  43. sglang/srt/layers/linear.py +1 -0
  44. sglang/srt/layers/logits_processor.py +12 -18
  45. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  46. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  47. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  48. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  63. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  64. sglang/srt/layers/moe/topk.py +4 -1
  65. sglang/srt/layers/multimodal.py +156 -40
  66. sglang/srt/layers/quantization/__init__.py +10 -35
  67. sglang/srt/layers/quantization/awq.py +15 -16
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
  69. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  70. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  71. sglang/srt/layers/quantization/gptq.py +12 -17
  72. sglang/srt/layers/quantization/marlin_utils.py +15 -5
  73. sglang/srt/layers/quantization/modelopt_quant.py +58 -41
  74. sglang/srt/layers/quantization/mxfp4.py +20 -3
  75. sglang/srt/layers/quantization/utils.py +52 -2
  76. sglang/srt/layers/quantization/w4afp8.py +20 -11
  77. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  78. sglang/srt/layers/rotary_embedding.py +281 -2
  79. sglang/srt/layers/sampler.py +5 -2
  80. sglang/srt/lora/backend/base_backend.py +3 -23
  81. sglang/srt/lora/layers.py +66 -116
  82. sglang/srt/lora/lora.py +17 -62
  83. sglang/srt/lora/lora_manager.py +12 -48
  84. sglang/srt/lora/lora_registry.py +20 -9
  85. sglang/srt/lora/mem_pool.py +20 -63
  86. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  87. sglang/srt/lora/utils.py +25 -58
  88. sglang/srt/managers/cache_controller.py +24 -29
  89. sglang/srt/managers/detokenizer_manager.py +1 -1
  90. sglang/srt/managers/io_struct.py +20 -6
  91. sglang/srt/managers/mm_utils.py +1 -2
  92. sglang/srt/managers/multimodal_processor.py +1 -1
  93. sglang/srt/managers/schedule_batch.py +43 -49
  94. sglang/srt/managers/schedule_policy.py +6 -6
  95. sglang/srt/managers/scheduler.py +18 -11
  96. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  97. sglang/srt/managers/tokenizer_manager.py +53 -44
  98. sglang/srt/mem_cache/allocator.py +39 -214
  99. sglang/srt/mem_cache/allocator_ascend.py +158 -0
  100. sglang/srt/mem_cache/chunk_cache.py +1 -1
  101. sglang/srt/mem_cache/hicache_storage.py +1 -1
  102. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  103. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  104. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  105. sglang/srt/mem_cache/radix_cache.py +2 -5
  106. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  107. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  108. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  109. sglang/srt/model_executor/cuda_graph_runner.py +29 -23
  110. sglang/srt/model_executor/forward_batch_info.py +33 -14
  111. sglang/srt/model_executor/model_runner.py +179 -81
  112. sglang/srt/model_loader/loader.py +18 -6
  113. sglang/srt/models/deepseek_nextn.py +2 -1
  114. sglang/srt/models/deepseek_v2.py +79 -38
  115. sglang/srt/models/gemma2.py +0 -34
  116. sglang/srt/models/gemma3n_mm.py +8 -9
  117. sglang/srt/models/glm4.py +6 -0
  118. sglang/srt/models/glm4_moe.py +11 -11
  119. sglang/srt/models/glm4_moe_nextn.py +2 -1
  120. sglang/srt/models/glm4v.py +589 -0
  121. sglang/srt/models/glm4v_moe.py +400 -0
  122. sglang/srt/models/gpt_oss.py +142 -20
  123. sglang/srt/models/granite.py +0 -25
  124. sglang/srt/models/llama.py +10 -27
  125. sglang/srt/models/llama4.py +19 -6
  126. sglang/srt/models/qwen2.py +2 -2
  127. sglang/srt/models/qwen2_5_vl.py +7 -3
  128. sglang/srt/models/qwen2_audio.py +10 -9
  129. sglang/srt/models/qwen2_moe.py +20 -5
  130. sglang/srt/models/qwen3.py +0 -24
  131. sglang/srt/models/qwen3_classification.py +78 -0
  132. sglang/srt/models/qwen3_moe.py +18 -5
  133. sglang/srt/models/registry.py +1 -1
  134. sglang/srt/models/step3_vl.py +6 -2
  135. sglang/srt/models/torch_native_llama.py +0 -24
  136. sglang/srt/multimodal/processors/base_processor.py +23 -13
  137. sglang/srt/multimodal/processors/glm4v.py +132 -0
  138. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  139. sglang/srt/operations.py +17 -2
  140. sglang/srt/reasoning_parser.py +316 -0
  141. sglang/srt/sampling/sampling_batch_info.py +7 -4
  142. sglang/srt/server_args.py +142 -140
  143. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
  144. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  145. sglang/srt/speculative/eagle_worker.py +16 -0
  146. sglang/srt/two_batch_overlap.py +16 -12
  147. sglang/srt/utils.py +3 -3
  148. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  149. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  150. sglang/test/doc_patch.py +59 -0
  151. sglang/test/few_shot_gsm8k.py +1 -1
  152. sglang/test/few_shot_gsm8k_engine.py +1 -1
  153. sglang/test/run_eval.py +4 -1
  154. sglang/test/simple_eval_common.py +6 -0
  155. sglang/test/simple_eval_gpqa.py +2 -0
  156. sglang/test/test_fp4_moe.py +118 -36
  157. sglang/test/test_marlin_moe.py +1 -1
  158. sglang/test/test_marlin_utils.py +1 -1
  159. sglang/utils.py +1 -1
  160. sglang/version.py +1 -1
  161. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
  162. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
  163. sglang/lang/backend/__init__.py +0 -0
  164. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  165. sglang/srt/layers/quantization/scalar_type.py +0 -352
  166. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  167. /sglang/{api.py → lang/api.py} +0 -0
  168. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
  169. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
  170. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/two_batch_overlap.py CHANGED
@@ -26,11 +26,13 @@ from sglang.srt.model_executor.forward_batch_info import (
 from sglang.srt.operations import execute_operations, execute_overlapped_operations
 from sglang.srt.operations_strategy import OperationsStrategy
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import BumpAllocator, get_bool_env_var
+from sglang.srt.utils import BumpAllocator, get_bool_env_var, is_hip
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import DispatchOutput
 
+_is_hip = is_hip()
+
 _tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
 
 logger = logging.getLogger(__name__)
@@ -676,16 +678,12 @@ class TboForwardBatchPreparer:
         # TODO improve, e.g. unify w/ `init_raw`
         if (
             global_server_args_dict["moe_dense_tp_size"] == 1
-            and batch.gathered_buffer is not None
+            and batch.global_dp_buffer_len is not None
         ):
             sum_len = end_token_index - start_token_index
-            gathered_buffer = torch.zeros(
-                (sum_len, batch.gathered_buffer.shape[1]),
-                dtype=batch.gathered_buffer.dtype,
-                device=batch.gathered_buffer.device,
-            )
+            global_dp_buffer_len = sum_len
         else:
-            gathered_buffer = None
+            global_dp_buffer_len = None
 
         output_dict.update(
             dict(
@@ -704,7 +702,7 @@ class TboForwardBatchPreparer:
                 global_num_tokens_gpu=None,
                 global_num_tokens_cpu=None,
                 dp_padding_mode=None,
-                gathered_buffer=gathered_buffer,
+                global_dp_buffer_len=global_dp_buffer_len,
                 global_num_tokens_for_logprob_gpu=None,
                 global_num_tokens_for_logprob_cpu=None,
                 sampling_info=None,
@@ -822,9 +820,15 @@ def _model_forward_tbo(
     )
     del inputs
 
-    with deep_gemm_wrapper.configure_deep_gemm_num_sms(
-        operations_strategy.deep_gemm_num_sms
-    ):
+    context = (
+        empty_context()
+        if _is_hip
+        else deep_gemm_wrapper.configure_deep_gemm_num_sms(
+            operations_strategy.deep_gemm_num_sms
+        )
+    )
+
+    with context:
         outputs_arr = execute_overlapped_operations(
             inputs_arr=inputs_arr,
             operations_arr=[operations_strategy.operations] * 2,
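The last hunk above swaps the unconditional DeepGEMM SM-limit context for a no-op on ROCm. A minimal sketch of that conditional-context pattern, assuming `empty_context` behaves like `contextlib.nullcontext` (the callables below are stand-ins, not sglang APIs):

    import contextlib

    def run_under_optional_sm_limit(is_hip: bool, num_sms: int, configure_sms, body):
        # On HIP there is no DeepGEMM SM tuning, so fall back to a do-nothing context;
        # otherwise limit the number of SMs DeepGEMM may use while `body` runs.
        context = contextlib.nullcontext() if is_hip else configure_sms(num_sms)
        with context:
            return body()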
sglang/srt/utils.py CHANGED
@@ -815,7 +815,7 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
             vr = VideoReader(tmp_file.name, ctx=ctx)
         elif video_file.startswith("data:"):
             _, encoded = video_file.split(",", 1)
-            video_bytes = base64.b64decode(encoded)
+            video_bytes = pybase64.b64decode(encoded)
             tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
             tmp_file.write(video_bytes)
             tmp_file.close()
@@ -823,7 +823,7 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
         elif os.path.isfile(video_file):
             vr = VideoReader(video_file, ctx=ctx)
         else:
-            video_bytes = base64.b64decode(video_file)
+            video_bytes = pybase64.b64decode(video_file)
             tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
             tmp_file.write(video_bytes)
             tmp_file.close()
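`pybase64` is used here as a faster drop-in for the standard-library decoder; the call signature is unchanged. A small round-trip check, assuming `pybase64` is installed:

    import base64

    import pybase64

    payload = base64.b64encode(b"example video bytes")
    # pybase64.b64decode accepts the same input and returns the same bytes;
    # it is simply a faster implementation.
    assert pybase64.b64decode(payload) == base64.b64decode(payload)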
@@ -2960,7 +2960,7 @@ class ConcurrentCounter:
         This suspends the calling coroutine without blocking the thread, allowing
         other tasks to run while waiting. When the counter becomes zero, the coroutine resumes.
         """
-        self.wait_for(lambda count: count == 0)
+        await self.wait_for(lambda count: count == 0)
 
 
 @lru_cache(maxsize=1)
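The one-word fix above matters because `wait_for` is a coroutine: without `await`, the call only created a coroutine object and returned immediately, so the wait never happened. A toy illustration of the pattern (not the real `ConcurrentCounter`):

    import asyncio

    class ToyCounter:
        def __init__(self):
            self._value = 0
            self._cond = asyncio.Condition()

        async def wait_for(self, predicate):
            async with self._cond:
                await self._cond.wait_for(lambda: predicate(self._value))

        async def wait_until_zero(self):
            # Forgetting `await` here would return instantly without suspending.
            await self.wait_for(lambda count: count == 0)

    asyncio.run(ToyCounter().wait_until_zero())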
sglang/srt/weight_sync/tensor_bucket.py ADDED
@@ -0,0 +1,106 @@
+from dataclasses import dataclass
+from typing import List, Tuple
+
+import torch
+
+
+@dataclass
+class FlattenedTensorMetadata:
+    """Metadata for a tensor in a flattened bucket"""
+
+    name: str
+    shape: torch.Size
+    dtype: torch.dtype
+    start_idx: int
+    end_idx: int
+    numel: int
+
+
+class FlattenedTensorBucket:
+    """
+    A bucket that flattens multiple tensors into a single tensor for efficient processing
+    while preserving all metadata needed for reconstruction.
+    """
+
+    def __init__(
+        self,
+        named_tensors: List[Tuple[str, torch.Tensor]] = None,
+        flattened_tensor: torch.Tensor = None,
+        metadata: List[FlattenedTensorMetadata] = None,
+    ):
+        """
+        Initialize a tensor bucket from a list of named tensors OR from pre-flattened data.
+        Args:
+            named_tensors: List of (name, tensor) tuples (for creating new bucket)
+            flattened_tensor: Pre-flattened tensor (for reconstruction)
+            metadata: Pre-computed metadata (for reconstruction)
+        """
+        if named_tensors is not None:
+            # Create bucket from named tensors
+            self.metadata: List[FlattenedTensorMetadata] = [None] * len(named_tensors)
+            self.flattened_tensor: torch.Tensor = None
+
+            if not named_tensors:
+                raise ValueError("Cannot create empty tensor bucket")
+
+            # Collect metadata and flatten tensors
+            current_idx = 0
+            flattened_tensors: List[torch.Tensor] = [None] * len(named_tensors)
+
+            for i, (name, tensor) in enumerate(named_tensors):
+                flattened = tensor.flatten()
+                flattened_tensors[i] = flattened
+
+                # Store metadata
+
+                numel = flattened.numel()
+                metadata_obj = FlattenedTensorMetadata(
+                    name=name,
+                    shape=tensor.shape,
+                    dtype=tensor.dtype,
+                    start_idx=current_idx,
+                    end_idx=current_idx + numel,
+                    numel=numel,
+                )
+                self.metadata[i] = metadata_obj
+                current_idx += numel
+
+            # Concatenate all flattened tensors
+            self.flattened_tensor = torch.cat(flattened_tensors, dim=0)
+        else:
+            # Initialize from pre-flattened data
+            if flattened_tensor is None or metadata is None:
+                raise ValueError(
+                    "Must provide either named_tensors or both flattened_tensor and metadata"
+                )
+            self.flattened_tensor = flattened_tensor
+            self.metadata = metadata
+
+    def get_flattened_tensor(self) -> torch.Tensor:
+        """Get the flattened tensor containing all bucket tensors"""
+        return self.flattened_tensor
+
+    def get_metadata(self) -> List[FlattenedTensorMetadata]:
+        """Get metadata for all tensors in the bucket"""
+        return self.metadata
+
+    def reconstruct_tensors(self) -> List[Tuple[str, torch.Tensor]]:
+        """
+        Reconstruct original tensors from flattened tensor with optimized performance.
+        Uses memory-efficient operations to minimize allocations and copies.
+        """
+        # preallocate the result list
+        reconstructed = [None] * len(self.metadata)
+
+        for i, meta in enumerate(self.metadata):
+            tensor = self.flattened_tensor[meta.start_idx : meta.end_idx].reshape(
+                meta.shape
+            )
+
+            # batch dtype conversion (if needed)
+            if tensor.dtype != meta.dtype:
+                tensor = tensor.to(meta.dtype)
+
+            reconstructed[i] = (meta.name, tensor)
+
+        return reconstructed
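A short usage sketch of the new bucket (CPU tensors, illustrative names): two named tensors are flattened into one buffer and then reconstructed with their original shapes.

    import torch

    from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket

    named = [("w1", torch.randn(4, 8)), ("b1", torch.zeros(8))]
    bucket = FlattenedTensorBucket(named_tensors=named)

    flat = bucket.get_flattened_tensor()  # one 1-D tensor holding both weights
    meta = bucket.get_metadata()          # names, offsets, shapes, dtypes

    # Rebuild the bucket from the flattened data, e.g. after transferring it.
    restored = dict(
        FlattenedTensorBucket(flattened_tensor=flat, metadata=meta).reconstruct_tensors()
    )
    assert restored["w1"].shape == (4, 8) and restored["b1"].shape == (8,)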
sglang/test/attention/test_trtllm_mla_backend.py CHANGED
@@ -43,6 +43,37 @@ DEFAULT_CONFIG = {
     "layer_id": 0,
 }
 
+ROPE_BASE = 10000
+ROPE_SCALING_CONFIG = {
+    "beta_fast": 32,
+    "beta_slow": 1,
+    "factor": 40,
+    "mscale": 1.0,
+    "mscale_all_dim": 1.0,
+    "original_max_position_embeddings": 4096,
+    "type": "yarn",
+    "rope_type": "deepseek_yarn",
+}
+
+
+def build_rotary_emb(config, device=None):
+    from sglang.srt.layers.rotary_embedding import get_rope_wrapper
+
+    dev = device or config["device"]
+    rope_scaling = config.get("rope_scaling", ROPE_SCALING_CONFIG)
+    rotary = get_rope_wrapper(
+        head_size=config["qk_rope_head_dim"],
+        rotary_dim=config["qk_rope_head_dim"],
+        max_position=config["context_len"],
+        base=ROPE_BASE,
+        rope_scaling=rope_scaling,
+        is_neox_style=False,
+        device=dev,
+    )
+    rotary.cos_sin_cache = rotary.cos_sin_cache.to(dev)
+    return rotary
+
+
 # Centralized test cases for different test scenarios
 TEST_CASES = {
     "basic_functionality": [
@@ -63,18 +94,36 @@ TEST_CASES = {
     ],
     "decode_output_match": [
         {
-            "name": "single",
+            "name": "single_fp16",
             "batch_size": 1,
             "max_seq_len": 64,
             "page_size": 32,
-            "description": "Single vs reference",
+            "description": "Single FP16 vs reference",
         },
         {
-            "name": "batch",
+            "name": "single_fp8",
+            "batch_size": 1,
+            "max_seq_len": 64,
+            "page_size": 64,
+            "tolerance": 1e-1,
+            "kv_cache_dtype": torch.float8_e4m3fn,
+            "description": "Single FP8 vs reference",
+        },
+        {
+            "name": "batch_fp16",
             "batch_size": 32,
             "max_seq_len": 64,
             "page_size": 32,
-            "description": "Batch vs reference",
+            "description": "Batch FP16 vs reference",
+        },
+        {
+            "name": "batch_fp8",
+            "batch_size": 32,
+            "max_seq_len": 64,
+            "page_size": 64,
+            "tolerance": 1e-1,
+            "kv_cache_dtype": torch.float8_e4m3fn,
+            "description": "Batch FP8 vs reference",
         },
     ],
     "page_size_consistency": [
@@ -293,26 +342,52 @@ class TestTRTLLMMLA(CustomTestCase):
             layer,
         )
 
-    def _create_qkv_tensors(self, batch_size, config):
-        """Create Q, K, V tensors for testing."""
-        head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"]
+    def _create_qkv_tensors(self, batch_size, config, dtype_override=None):
+        """Create Q, K, V random tensors for given batch size with separate MLA components.
+
+        Args:
+            batch_size: Batch size.
+            config: Configuration dict with model dims and device.
+            dtype_override: Optional torch dtype to override config["dtype"].
+
+        Returns:
+            Tuple of (q_nope, q_rope, k_nope, k_rope, v, cos_sin_cache)
+        """
         device = config["device"]
-        dtype = config["dtype"]
+        target_dtype = dtype_override or config["dtype"]
 
-        q = torch.randn(
-            (batch_size, config["num_attention_heads"], head_dim),
-            dtype=dtype,
+        # Create separate nope and rope components for Q
+        q_nope = torch.randn(
+            (batch_size, config["num_attention_heads"], config["kv_lora_rank"]),
+            dtype=config["dtype"],
             device=device,
         )
-        k = torch.randn(
-            (batch_size, config["num_kv_heads"], head_dim), dtype=dtype, device=device
+        q_rope = torch.randn(
+            (batch_size, config["num_attention_heads"], config["qk_rope_head_dim"]),
+            dtype=config["dtype"],
+            device=device,
+        )
+
+        # Create separate nope and rope components for K
+        k_nope = torch.randn(
+            (batch_size, config["num_kv_heads"], config["kv_lora_rank"]),
+            dtype=config["dtype"],
+            device=device,
+        )
+        k_rope = torch.randn(
+            (batch_size, config["num_kv_heads"], config["qk_rope_head_dim"]),
+            dtype=config["dtype"],
+            device=device,
         )
+
+        # V tensor (unchanged)
         v = torch.randn(
             (batch_size, config["num_kv_heads"], config["v_head_dim"]),
-            dtype=dtype,
+            dtype=config["dtype"],
             device=device,
         )
-        return q, k, v
+
+        return q_nope, q_rope, k_nope, k_rope, v
 
     def _create_forward_batch(
         self, batch_size, seq_lens, backend, model_runner, config
@@ -331,6 +406,10 @@
         )
         fb.req_to_token_pool = model_runner.req_to_token_pool
         fb.token_to_kv_pool = model_runner.token_to_kv_pool
+
+        # Add position information for RoPE
+        fb.positions = torch.arange(batch_size, device=config["device"])
+
         return fb
 
     def _populate_kv_cache(self, batch_size, seq_lens, model_runners, layer, config):
@@ -344,7 +423,7 @@
             for token_idx in range(seq_len - 1):
                 # Create random K components for MLA
                 cache_k_nope = torch.randn(
-                    (1, config["qk_nope_head_dim"]),
+                    (1, config["kv_lora_rank"]),
                     dtype=config["dtype"],
                     device=config["device"],
                 )
@@ -411,12 +490,16 @@
             batch_size, seq_lens, [model_runner_trtllm], layer, config
         )
 
-        # Create Q, K, V tensors
+        # Create Q, K, V tensors with separate MLA components
         torch.manual_seed(config["seed_qkv"])
-        q, k, v = self._create_qkv_tensors(batch_size, config)
+        q_nope, q_rope, k_nope, k_rope, v = self._create_qkv_tensors(
+            batch_size, config
+        )
 
-        # Run forward decode
-        output = trtllm_backend.forward_decode(q, k, v, layer, fb)
+        # Run forward decode with separate MLA components
+        output = trtllm_backend.forward_decode(
+            q_nope, k_nope, None, layer, fb, q_rope=q_rope, k_rope=k_rope
+        )
 
         # Basic checks
         expected_shape = (
@@ -439,6 +522,7 @@
             config = self._merge_config(test_case)
             batch_size = config["batch_size"]
             max_seq_len = config["max_seq_len"]
+            use_fp8 = config["kv_cache_dtype"] == torch.float8_e4m3fn
 
             # Create components
             (
@@ -487,19 +571,66 @@
 
             # Create Q, K, V tensors for current decode step
             torch.manual_seed(config["seed_qkv"])
-            q, k, v = self._create_qkv_tensors(batch_size, config)
+
+            q_nope_ref, q_rope_ref, k_nope_ref, k_rope_ref, v_ref = (
+                self._create_qkv_tensors(batch_size, config)
+            )
+            q_nope_trt, q_rope_trt, k_nope_trt, k_rope_trt, v_trt = (
+                q_nope_ref.clone(),
+                q_rope_ref.clone(),
+                k_nope_ref.clone(),
+                k_rope_ref.clone(),
+                v_ref.clone(),
+            )
+            tolerance = config["tolerance"]
+
+            extra_args = {}
+            if use_fp8:
+                # TRT kernel applies RoPE + FP8 quantization internally
+                # pre-apply RoPE on the reference (FlashInfer) path here so
+                # both paths share the same rope params/cache while keeping
+                # the TRT path unrotated.
+                rotary_emb = build_rotary_emb(config)
+                q_rope_ref, k_rope_ref = rotary_emb(
+                    fb_reference.positions, q_rope_ref, k_rope_ref
+                )
+                extra_args = {
+                    "cos_sin_cache": rotary_emb.cos_sin_cache,
+                    "is_neox": rotary_emb.is_neox_style,
+                }
+
+                dtype = q_rope_ref.dtype
+                q_rope_ref = q_rope_ref.to(torch.float8_e4m3fn).to(dtype)
+                q_nope_ref = q_nope_ref.to(torch.float8_e4m3fn).to(dtype)
+                k_rope_ref = k_rope_ref.to(torch.float8_e4m3fn).to(dtype)
+                k_nope_ref = k_nope_ref.to(torch.float8_e4m3fn).to(dtype)
 
             # Run forward decode on both backends
             out_trtllm = trtllm_backend.forward_decode(
-                q.clone(), k.clone(), v.clone(), layer, fb_trtllm
+                q_nope_trt,
+                k_nope_trt,
+                None,
+                layer,
+                fb_trtllm,
+                q_rope=q_rope_trt,
+                k_rope=k_rope_trt,
+                **extra_args,
             )
+
+            # Reference backend should also take separate components, not concatenated
             out_reference = reference_backend.forward_decode(
-                q.clone(), k.clone(), v.clone(), layer, fb_reference
+                q_nope_ref,
+                k_nope_ref,
+                v_ref,
+                layer,
+                fb_reference,
+                q_rope=q_rope_ref,
+                k_rope=k_rope_ref,
             )
 
             # Compare outputs
             comparison_passed = compare_outputs(
-                out_trtllm, out_reference, tolerance=config["tolerance"]
+                out_trtllm, out_reference, tolerance=tolerance
             )
 
             self.assertTrue(
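In the FP8 cases above, the TRT-LLM kernel quantizes internally, so the reference inputs are fake-quantized instead: casting to `torch.float8_e4m3fn` and straight back keeps the original dtype but bakes in the same rounding error, which is why those cases use the looser 1e-1 tolerance. A standalone sketch of that round trip (assumes a PyTorch build with float8 dtypes):

    import torch

    x = torch.randn(4, 64, dtype=torch.bfloat16)
    # Fake quantization: down to FP8 and back, so later math stays in bf16
    # while carrying FP8 rounding error.
    x_fq = x.to(torch.float8_e4m3fn).to(torch.bfloat16)
    print((x - x_fq).abs().max())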
@@ -544,12 +675,16 @@
             batch_size, seq_lens, [model_runner], layer, config
         )
 
-        # Create Q, K, V tensors
+        # Create Q, K, V tensors with separate MLA components
         torch.manual_seed(config["seed_qkv"])
-        q, k, v = self._create_qkv_tensors(batch_size, config)
+        q_nope, q_rope, k_nope, k_rope, v = self._create_qkv_tensors(
+            batch_size, config
+        )
 
-        # Run forward decode
-        output = backend.forward_decode(q, k, v, layer, fb)
+        # Run forward decode with separate MLA components
+        output = backend.forward_decode(
+            q_nope, k_nope, None, layer, fb, q_rope=q_rope, k_rope=k_rope
+        )
 
         expected_shape = (
             batch_size,
@@ -591,23 +726,38 @@
         )
         backend.init_forward_metadata(fb)
 
-        # Create Q, K, V tensors
+        # Create Q, K, V tensors with separate MLA components
         torch.manual_seed(config["seed_qkv"])
-        head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"]
-        q = torch.randn(
-            (batch_size, config["num_attention_heads"], head_dim),
+        q_nope = torch.randn(
+            (batch_size, config["num_attention_heads"], config["kv_lora_rank"]),
             dtype=config["dtype"],
             device=config["device"],
         )
-        k = torch.randn(
-            (batch_size, config["num_kv_heads"], head_dim),
+        k_nope = torch.randn(
+            (batch_size, config["num_kv_heads"], config["kv_lora_rank"]),
             dtype=config["dtype"],
             device=config["device"],
         )
-        v = None
+        q_rope = torch.randn(
+            (
+                batch_size,
+                config["num_attention_heads"],
+                config["qk_rope_head_dim"],
+            ),
+            dtype=config["dtype"],
+            device=config["device"],
+        )
+        k_rope = torch.randn(
+            (batch_size, config["num_kv_heads"], config["qk_rope_head_dim"]),
+            dtype=config["dtype"],
+            device=config["device"],
+        )
+        v = None  # Test with None v
 
         # Run forward decode
-        output = backend.forward_decode(q, k, v, layer, fb)
+        output = backend.forward_decode(
+            q_nope, k_nope, v, layer, fb, q_rope=q_rope, k_rope=k_rope
+        )
 
         # Shape and sanity checks
         expected_shape = (
sglang/test/doc_patch.py ADDED
@@ -0,0 +1,59 @@
+"""
+Do some monkey patch to make the documentation compilation faster and more reliable.
+
+- Avoid port conflicts
+- Reduce the server launch time
+"""
+
+import weakref
+
+import nest_asyncio
+
+nest_asyncio.apply()
+
+import sglang.srt.server_args as server_args_mod
+from sglang.utils import execute_shell_command, reserve_port
+
+DEFAULT_MAX_RUNNING_REQUESTS = 128
+DEFAULT_MAX_TOTAL_TOKENS = 20480  # To allow multiple servers on the same machine
+
+_original_post_init = server_args_mod.ServerArgs.__post_init__
+
+
+def patched_post_init(self):
+    _original_post_init(self)
+    if self.max_running_requests is None:
+        self.max_running_requests = DEFAULT_MAX_RUNNING_REQUESTS
+    if self.max_total_tokens is None:
+        self.max_total_tokens = DEFAULT_MAX_TOTAL_TOKENS
+    self.cuda_graph_max_bs = 4
+
+
+server_args_mod.ServerArgs.__post_init__ = patched_post_init
+
+process_socket_map = weakref.WeakKeyDictionary()
+
+
+def launch_server_cmd(command: str, host: str = "0.0.0.0", port: int = None):
+    """
+    Launch the server using the given command.
+    If no port is specified, a free port is reserved.
+    """
+    if port is None:
+        port, lock_socket = reserve_port(host)
+    else:
+        lock_socket = None
+
+    extra_flags = (
+        f"--max-running-requests {DEFAULT_MAX_RUNNING_REQUESTS} "
+        f"--max-total-tokens {DEFAULT_MAX_TOTAL_TOKENS} "
+        f"--cuda-graph-max-bs 4"
+    )
+
+    full_command = f"{command} --port {port} {extra_flags}"
+    process = execute_shell_command(full_command)
+
+    if lock_socket is not None:
+        process_socket_map[process] = lock_socket
+
+    return process, port
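A hypothetical docs-notebook usage of the helper above; the model path is a placeholder and the returned handle is assumed to be the subprocess created by `execute_shell_command`:

    from sglang.test.doc_patch import launch_server_cmd

    # Reserves a free port, appends the reduced --max-running-requests /
    # --max-total-tokens / --cuda-graph-max-bs flags, and starts the server.
    server_process, port = launch_server_cmd(
        "python -m sglang.launch_server --model-path <your-model>"
    )
    print(f"docs server expected at http://localhost:{port}")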
sglang/test/few_shot_gsm8k.py CHANGED
@@ -12,7 +12,7 @@ import time
 
 import numpy as np
 
-from sglang.api import set_default_backend
+from sglang.lang.api import set_default_backend
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
 
sglang/test/few_shot_gsm8k_engine.py CHANGED
@@ -8,7 +8,7 @@ import time
 import numpy as np
 
 import sglang as sgl
-from sglang.api import set_default_backend
+from sglang.lang.api import set_default_backend
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
 
sglang/test/run_eval.py CHANGED
@@ -65,9 +65,10 @@ def run_eval(args):
 
     sampler = ChatCompletionSampler(
         model=args.model,
-        max_tokens=2048,
+        max_tokens=getattr(args, "max_tokens", 2048),
        base_url=base_url,
         temperature=getattr(args, "temperature", 0.0),
+        reasoning_effort=getattr(args, "reasoning_effort", None),
     )
 
     # Run eval
@@ -120,7 +121,9 @@ if __name__ == "__main__":
     parser.add_argument("--eval-name", type=str, default="mmlu")
     parser.add_argument("--num-examples", type=int)
     parser.add_argument("--num-threads", type=int, default=512)
+    parser.add_argument("--max-tokens", type=int, default=2048)
     parser.add_argument("--temperature", type=float, default=0.0)
+    parser.add_argument("--reasoning-effort", type=str)
     args = parser.parse_args()
 
     run_eval(args)
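The new flags reach the sampler through `getattr` with defaults, so callers that build their own `args` namespace without these attributes keep the old behavior. A tiny sketch of that fallback:

    from argparse import Namespace

    args = Namespace(temperature=0.0)  # no max_tokens / reasoning_effort set
    max_tokens = getattr(args, "max_tokens", 2048)              # -> 2048
    reasoning_effort = getattr(args, "reasoning_effort", None)  # -> None
    print(max_tokens, reasoning_effort)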
sglang/test/simple_eval_common.py CHANGED
@@ -91,6 +91,7 @@ class ChatCompletionSampler(SamplerBase):
         model: Optional[str] = None,
         system_message: Optional[str] = None,
         temperature: float = 0.0,
+        reasoning_effort: Optional[str] = None,
         max_tokens: int = 2048,
     ):
         self.client = OpenAI(base_url=base_url, http_client=LargerHttpxClient())
@@ -102,7 +103,11 @@
         self.system_message = system_message
         self.temperature = temperature
         self.max_tokens = max_tokens
+        self.reasoning_effort = reasoning_effort
         self.image_format = "url"
+        print(
+            f"ChatCompletionSampler initialized with {self.system_message=} {self.temperature=} {self.max_tokens=} {self.reasoning_effort=}"
+        )
 
     def _handle_image(
         self,
@@ -138,6 +143,7 @@
                 messages=message_list,
                 temperature=self.temperature,
                 max_tokens=self.max_tokens,
+                reasoning_effort=self.reasoning_effort,
             )
             return response.choices[0].message.content
         # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are rerunning MMMU
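A hedged construction example for the extended sampler; the endpoint and model name are placeholders, and `reasoning_effort` is simply forwarded to `chat.completions.create`:

    from sglang.test.simple_eval_common import ChatCompletionSampler

    sampler = ChatCompletionSampler(
        base_url="http://127.0.0.1:30000/v1",  # placeholder local endpoint
        model="default",                       # placeholder model name
        temperature=0.0,
        max_tokens=2048,
        reasoning_effort="low",  # forwarded as-is to the OpenAI-compatible client
    )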
sglang/test/simple_eval_gpqa.py CHANGED
@@ -71,6 +71,8 @@ class GPQAEval(Eval):
                 )
             ]
             response_text = sampler(prompt_messages)
+            if response_text is None:
+                response_text = ""
             match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
             extracted_answer = match.group(1) if match else None
             score = 1.0 if extracted_answer == correct_answer else 0.0