sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -1
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +8 -7
- sglang/srt/disaggregation/decode.py +8 -4
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +68 -5
- sglang/srt/entrypoints/openai/protocol.py +2 -9
- sglang/srt/entrypoints/openai/serving_chat.py +60 -265
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +55 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +24 -27
- sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +11 -13
- sglang/srt/layers/dp_attention.py +118 -27
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +12 -18
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +10 -35
- sglang/srt/layers/quantization/awq.py +15 -16
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/gptq.py +12 -17
- sglang/srt/layers/quantization/marlin_utils.py +15 -5
- sglang/srt/layers/quantization/modelopt_quant.py +58 -41
- sglang/srt/layers/quantization/mxfp4.py +20 -3
- sglang/srt/layers/quantization/utils.py +52 -2
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +66 -116
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +24 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +43 -49
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +18 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +53 -44
- sglang/srt/mem_cache/allocator.py +39 -214
- sglang/srt/mem_cache/allocator_ascend.py +158 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -23
- sglang/srt/model_executor/forward_batch_info.py +33 -14
- sglang/srt/model_executor/model_runner.py +179 -81
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_nextn.py +2 -1
- sglang/srt/models/deepseek_v2.py +79 -38
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +11 -11
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +142 -20
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +10 -27
- sglang/srt/models/llama4.py +19 -6
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +20 -5
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_classification.py +78 -0
- sglang/srt/models/qwen3_moe.py +18 -5
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +6 -2
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/operations.py +17 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +142 -140
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +16 -12
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/two_batch_overlap.py
CHANGED
@@ -26,11 +26,13 @@ from sglang.srt.model_executor.forward_batch_info import (
 from sglang.srt.operations import execute_operations, execute_overlapped_operations
 from sglang.srt.operations_strategy import OperationsStrategy
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import BumpAllocator, get_bool_env_var
+from sglang.srt.utils import BumpAllocator, get_bool_env_var, is_hip

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import DispatchOutput

+_is_hip = is_hip()
+
 _tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")

 logger = logging.getLogger(__name__)
@@ -676,16 +678,12 @@ class TboForwardBatchPreparer:
         # TODO improve, e.g. unify w/ `init_raw`
         if (
             global_server_args_dict["moe_dense_tp_size"] == 1
-            and batch.
+            and batch.global_dp_buffer_len is not None
         ):
             sum_len = end_token_index - start_token_index
-
-                (sum_len, batch.gathered_buffer.shape[1]),
-                dtype=batch.gathered_buffer.dtype,
-                device=batch.gathered_buffer.device,
-            )
+            global_dp_buffer_len = sum_len
         else:
-
+            global_dp_buffer_len = None

         output_dict.update(
             dict(
@@ -704,7 +702,7 @@ class TboForwardBatchPreparer:
                 global_num_tokens_gpu=None,
                 global_num_tokens_cpu=None,
                 dp_padding_mode=None,
-
+                global_dp_buffer_len=global_dp_buffer_len,
                 global_num_tokens_for_logprob_gpu=None,
                 global_num_tokens_for_logprob_cpu=None,
                 sampling_info=None,
@@ -822,9 +820,15 @@ def _model_forward_tbo(
     )
     del inputs

-
-
-
+    context = (
+        empty_context()
+        if _is_hip
+        else deep_gemm_wrapper.configure_deep_gemm_num_sms(
+            operations_strategy.deep_gemm_num_sms
+        )
+    )
+
+    with context:
         outputs_arr = execute_overlapped_operations(
             inputs_arr=inputs_arr,
             operations_arr=[operations_strategy.operations] * 2,
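Note: the last hunk above picks a context manager at runtime, using an empty context on HIP (ROCm) builds and the DeepGEMM SM-limiting context otherwise. A minimal generic sketch of that idiom, with contextlib.nullcontext and a dummy limiter standing in for SGLang's empty_context() and configure_deep_gemm_num_sms() (names below are placeholders, not from the diff):

from contextlib import contextmanager, nullcontext

@contextmanager
def limit_sms(num_sms):
    # Stand-in for deep_gemm_wrapper.configure_deep_gemm_num_sms(...)
    print(f"limiting DeepGEMM to {num_sms} SMs")
    try:
        yield
    finally:
        print("SM limit restored")

_is_hip = False  # on ROCm builds this is True and the limiter is skipped
context = nullcontext() if _is_hip else limit_sms(64)
with context:
    pass  # run the two overlapped micro-batches here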
sglang/srt/utils.py
CHANGED
@@ -815,7 +815,7 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
             vr = VideoReader(tmp_file.name, ctx=ctx)
         elif video_file.startswith("data:"):
             _, encoded = video_file.split(",", 1)
-            video_bytes =
+            video_bytes = pybase64.b64decode(encoded)
             tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
             tmp_file.write(video_bytes)
             tmp_file.close()
@@ -823,7 +823,7 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
         elif os.path.isfile(video_file):
             vr = VideoReader(video_file, ctx=ctx)
         else:
-            video_bytes =
+            video_bytes = pybase64.b64decode(video_file)
             tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
             tmp_file.write(video_bytes)
             tmp_file.close()
@@ -2960,7 +2960,7 @@ class ConcurrentCounter:
        This suspends the calling coroutine without blocking the thread, allowing
        other tasks to run while waiting. When the counter becomes zero, the coroutine resumes.
        """
-        self.wait_for(lambda count: count == 0)
+        await self.wait_for(lambda count: count == 0)


 @lru_cache(maxsize=1)
sglang/srt/weight_sync/tensor_bucket.py
ADDED
@@ -0,0 +1,106 @@
+from dataclasses import dataclass
+from typing import List, Tuple
+
+import torch
+
+
+@dataclass
+class FlattenedTensorMetadata:
+    """Metadata for a tensor in a flattened bucket"""
+
+    name: str
+    shape: torch.Size
+    dtype: torch.dtype
+    start_idx: int
+    end_idx: int
+    numel: int
+
+
+class FlattenedTensorBucket:
+    """
+    A bucket that flattens multiple tensors into a single tensor for efficient processing
+    while preserving all metadata needed for reconstruction.
+    """
+
+    def __init__(
+        self,
+        named_tensors: List[Tuple[str, torch.Tensor]] = None,
+        flattened_tensor: torch.Tensor = None,
+        metadata: List[FlattenedTensorMetadata] = None,
+    ):
+        """
+        Initialize a tensor bucket from a list of named tensors OR from pre-flattened data.
+        Args:
+            named_tensors: List of (name, tensor) tuples (for creating new bucket)
+            flattened_tensor: Pre-flattened tensor (for reconstruction)
+            metadata: Pre-computed metadata (for reconstruction)
+        """
+        if named_tensors is not None:
+            # Create bucket from named tensors
+            self.metadata: List[FlattenedTensorMetadata] = [None] * len(named_tensors)
+            self.flattened_tensor: torch.Tensor = None
+
+            if not named_tensors:
+                raise ValueError("Cannot create empty tensor bucket")
+
+            # Collect metadata and flatten tensors
+            current_idx = 0
+            flattened_tensors: List[torch.Tensor] = [None] * len(named_tensors)
+
+            for i, (name, tensor) in enumerate(named_tensors):
+                flattened = tensor.flatten()
+                flattened_tensors[i] = flattened
+
+                # Store metadata
+
+                numel = flattened.numel()
+                metadata_obj = FlattenedTensorMetadata(
+                    name=name,
+                    shape=tensor.shape,
+                    dtype=tensor.dtype,
+                    start_idx=current_idx,
+                    end_idx=current_idx + numel,
+                    numel=numel,
+                )
+                self.metadata[i] = metadata_obj
+                current_idx += numel
+
+            # Concatenate all flattened tensors
+            self.flattened_tensor = torch.cat(flattened_tensors, dim=0)
+        else:
+            # Initialize from pre-flattened data
+            if flattened_tensor is None or metadata is None:
+                raise ValueError(
+                    "Must provide either named_tensors or both flattened_tensor and metadata"
+                )
+            self.flattened_tensor = flattened_tensor
+            self.metadata = metadata
+
+    def get_flattened_tensor(self) -> torch.Tensor:
+        """Get the flattened tensor containing all bucket tensors"""
+        return self.flattened_tensor
+
+    def get_metadata(self) -> List[FlattenedTensorMetadata]:
+        """Get metadata for all tensors in the bucket"""
+        return self.metadata
+
+    def reconstruct_tensors(self) -> List[Tuple[str, torch.Tensor]]:
+        """
+        Reconstruct original tensors from flattened tensor with optimized performance.
+        Uses memory-efficient operations to minimize allocations and copies.
+        """
+        # preallocate the result list
+        reconstructed = [None] * len(self.metadata)
+
+        for i, meta in enumerate(self.metadata):
+            tensor = self.flattened_tensor[meta.start_idx : meta.end_idx].reshape(
+                meta.shape
+            )
+
+            # batch dtype conversion (if needed)
+            if tensor.dtype != meta.dtype:
+                tensor = tensor.to(meta.dtype)
+
+            reconstructed[i] = (meta.name, tensor)
+
+        return reconstructed
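A minimal round-trip sketch of the FlattenedTensorBucket API added above (tensor names and shapes are arbitrary placeholders; the import path is taken from the file list at the top of this diff):

import torch

from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket

named = [("w1", torch.randn(4, 8)), ("b1", torch.randn(8))]

bucket = FlattenedTensorBucket(named_tensors=named)
flat = bucket.get_flattened_tensor()   # one 1-D tensor, numel == 4*8 + 8
meta = bucket.get_metadata()           # names, shapes, dtypes, offsets

# e.g. after transferring `flat` and `meta`, rebuild the originals:
restored = FlattenedTensorBucket(
    flattened_tensor=flat, metadata=meta
).reconstruct_tensors()
assert all(torch.equal(t, dict(named)[name]) for name, t in restored)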
sglang/test/attention/test_trtllm_mla_backend.py
CHANGED
@@ -43,6 +43,37 @@ DEFAULT_CONFIG = {
     "layer_id": 0,
 }

+ROPE_BASE = 10000
+ROPE_SCALING_CONFIG = {
+    "beta_fast": 32,
+    "beta_slow": 1,
+    "factor": 40,
+    "mscale": 1.0,
+    "mscale_all_dim": 1.0,
+    "original_max_position_embeddings": 4096,
+    "type": "yarn",
+    "rope_type": "deepseek_yarn",
+}
+
+
+def build_rotary_emb(config, device=None):
+    from sglang.srt.layers.rotary_embedding import get_rope_wrapper
+
+    dev = device or config["device"]
+    rope_scaling = config.get("rope_scaling", ROPE_SCALING_CONFIG)
+    rotary = get_rope_wrapper(
+        head_size=config["qk_rope_head_dim"],
+        rotary_dim=config["qk_rope_head_dim"],
+        max_position=config["context_len"],
+        base=ROPE_BASE,
+        rope_scaling=rope_scaling,
+        is_neox_style=False,
+        device=dev,
+    )
+    rotary.cos_sin_cache = rotary.cos_sin_cache.to(dev)
+    return rotary
+
+
 # Centralized test cases for different test scenarios
 TEST_CASES = {
     "basic_functionality": [
@@ -63,18 +94,36 @@ TEST_CASES = {
     ],
     "decode_output_match": [
         {
-            "name": "
+            "name": "single_fp16",
             "batch_size": 1,
             "max_seq_len": 64,
             "page_size": 32,
-            "description": "Single vs reference",
+            "description": "Single FP16 vs reference",
         },
         {
-            "name": "
+            "name": "single_fp8",
+            "batch_size": 1,
+            "max_seq_len": 64,
+            "page_size": 64,
+            "tolerance": 1e-1,
+            "kv_cache_dtype": torch.float8_e4m3fn,
+            "description": "Single FP8 vs reference",
+        },
+        {
+            "name": "batch_fp16",
             "batch_size": 32,
             "max_seq_len": 64,
             "page_size": 32,
-            "description": "Batch vs reference",
+            "description": "Batch FP16 vs reference",
+        },
+        {
+            "name": "batch_fp8",
+            "batch_size": 32,
+            "max_seq_len": 64,
+            "page_size": 64,
+            "tolerance": 1e-1,
+            "kv_cache_dtype": torch.float8_e4m3fn,
+            "description": "Batch FP8 vs reference",
         },
     ],
     "page_size_consistency": [
@@ -293,26 +342,52 @@ class TestTRTLLMMLA(CustomTestCase):
             layer,
         )

-    def _create_qkv_tensors(self, batch_size, config):
-        """Create Q, K, V tensors for
-
+    def _create_qkv_tensors(self, batch_size, config, dtype_override=None):
+        """Create Q, K, V random tensors for given batch size with separate MLA components.
+
+        Args:
+            batch_size: Batch size.
+            config: Configuration dict with model dims and device.
+            dtype_override: Optional torch dtype to override config["dtype"].
+
+        Returns:
+            Tuple of (q_nope, q_rope, k_nope, k_rope, v, cos_sin_cache)
+        """
         device = config["device"]
-
+        target_dtype = dtype_override or config["dtype"]

-
-
-
+        # Create separate nope and rope components for Q
+        q_nope = torch.randn(
+            (batch_size, config["num_attention_heads"], config["kv_lora_rank"]),
+            dtype=config["dtype"],
             device=device,
         )
-
-            (batch_size, config["
+        q_rope = torch.randn(
+            (batch_size, config["num_attention_heads"], config["qk_rope_head_dim"]),
+            dtype=config["dtype"],
+            device=device,
+        )
+
+        # Create separate nope and rope components for K
+        k_nope = torch.randn(
+            (batch_size, config["num_kv_heads"], config["kv_lora_rank"]),
+            dtype=config["dtype"],
+            device=device,
+        )
+        k_rope = torch.randn(
+            (batch_size, config["num_kv_heads"], config["qk_rope_head_dim"]),
+            dtype=config["dtype"],
+            device=device,
         )
+
+        # V tensor (unchanged)
         v = torch.randn(
             (batch_size, config["num_kv_heads"], config["v_head_dim"]),
-            dtype=dtype,
+            dtype=config["dtype"],
             device=device,
         )
-
+
+        return q_nope, q_rope, k_nope, k_rope, v

     def _create_forward_batch(
         self, batch_size, seq_lens, backend, model_runner, config
@@ -331,6 +406,10 @@ class TestTRTLLMMLA(CustomTestCase):
         )
         fb.req_to_token_pool = model_runner.req_to_token_pool
         fb.token_to_kv_pool = model_runner.token_to_kv_pool
+
+        # Add position information for RoPE
+        fb.positions = torch.arange(batch_size, device=config["device"])
+
         return fb

     def _populate_kv_cache(self, batch_size, seq_lens, model_runners, layer, config):
@@ -344,7 +423,7 @@ class TestTRTLLMMLA(CustomTestCase):
             for token_idx in range(seq_len - 1):
                 # Create random K components for MLA
                 cache_k_nope = torch.randn(
-                    (1, config["
+                    (1, config["kv_lora_rank"]),
                     dtype=config["dtype"],
                     device=config["device"],
                 )
@@ -411,12 +490,16 @@ class TestTRTLLMMLA(CustomTestCase):
             batch_size, seq_lens, [model_runner_trtllm], layer, config
         )

-        # Create Q, K, V tensors
+        # Create Q, K, V tensors with separate MLA components
         torch.manual_seed(config["seed_qkv"])
-
+        q_nope, q_rope, k_nope, k_rope, v = self._create_qkv_tensors(
+            batch_size, config
+        )

-        # Run forward decode
-        output = trtllm_backend.forward_decode(
+        # Run forward decode with separate MLA components
+        output = trtllm_backend.forward_decode(
+            q_nope, k_nope, None, layer, fb, q_rope=q_rope, k_rope=k_rope
+        )

         # Basic checks
         expected_shape = (
@@ -439,6 +522,7 @@ class TestTRTLLMMLA(CustomTestCase):
         config = self._merge_config(test_case)
         batch_size = config["batch_size"]
         max_seq_len = config["max_seq_len"]
+        use_fp8 = config["kv_cache_dtype"] == torch.float8_e4m3fn

         # Create components
         (
@@ -487,19 +571,66 @@ class TestTRTLLMMLA(CustomTestCase):

         # Create Q, K, V tensors for current decode step
         torch.manual_seed(config["seed_qkv"])
-
+
+        q_nope_ref, q_rope_ref, k_nope_ref, k_rope_ref, v_ref = (
+            self._create_qkv_tensors(batch_size, config)
+        )
+        q_nope_trt, q_rope_trt, k_nope_trt, k_rope_trt, v_trt = (
+            q_nope_ref.clone(),
+            q_rope_ref.clone(),
+            k_nope_ref.clone(),
+            k_rope_ref.clone(),
+            v_ref.clone(),
+        )
+        tolerance = config["tolerance"]
+
+        extra_args = {}
+        if use_fp8:
+            # TRT kernel applies RoPE + FP8 quantization internally
+            # pre-apply RoPE on the reference (FlashInfer) path here so
+            # both paths share the same rope params/cache while keeping
+            # the TRT path unrotated.
+            rotary_emb = build_rotary_emb(config)
+            q_rope_ref, k_rope_ref = rotary_emb(
+                fb_reference.positions, q_rope_ref, k_rope_ref
+            )
+            extra_args = {
+                "cos_sin_cache": rotary_emb.cos_sin_cache,
+                "is_neox": rotary_emb.is_neox_style,
+            }
+
+            dtype = q_rope_ref.dtype
+            q_rope_ref = q_rope_ref.to(torch.float8_e4m3fn).to(dtype)
+            q_nope_ref = q_nope_ref.to(torch.float8_e4m3fn).to(dtype)
+            k_rope_ref = k_rope_ref.to(torch.float8_e4m3fn).to(dtype)
+            k_nope_ref = k_nope_ref.to(torch.float8_e4m3fn).to(dtype)

         # Run forward decode on both backends
         out_trtllm = trtllm_backend.forward_decode(
-
+            q_nope_trt,
+            k_nope_trt,
+            None,
+            layer,
+            fb_trtllm,
+            q_rope=q_rope_trt,
+            k_rope=k_rope_trt,
+            **extra_args,
         )
+
+        # Reference backend should also take separate components, not concatenated
         out_reference = reference_backend.forward_decode(
-
+            q_nope_ref,
+            k_nope_ref,
+            v_ref,
+            layer,
+            fb_reference,
+            q_rope=q_rope_ref,
+            k_rope=k_rope_ref,
         )

         # Compare outputs
         comparison_passed = compare_outputs(
-            out_trtllm, out_reference, tolerance=
+            out_trtllm, out_reference, tolerance=tolerance
         )

         self.assertTrue(
@@ -544,12 +675,16 @@ class TestTRTLLMMLA(CustomTestCase):
             batch_size, seq_lens, [model_runner], layer, config
         )

-        # Create Q, K, V tensors
+        # Create Q, K, V tensors with separate MLA components
         torch.manual_seed(config["seed_qkv"])
-
+        q_nope, q_rope, k_nope, k_rope, v = self._create_qkv_tensors(
+            batch_size, config
+        )

-        # Run forward decode
-        output = backend.forward_decode(
+        # Run forward decode with separate MLA components
+        output = backend.forward_decode(
+            q_nope, k_nope, None, layer, fb, q_rope=q_rope, k_rope=k_rope
+        )

         expected_shape = (
             batch_size,
@@ -591,23 +726,38 @@ class TestTRTLLMMLA(CustomTestCase):
         )
         backend.init_forward_metadata(fb)

-        # Create Q, K, V tensors
+        # Create Q, K, V tensors with separate MLA components
         torch.manual_seed(config["seed_qkv"])
-
-
-            (batch_size, config["num_attention_heads"], head_dim),
+        q_nope = torch.randn(
+            (batch_size, config["num_attention_heads"], config["kv_lora_rank"]),
             dtype=config["dtype"],
             device=config["device"],
         )
-
-            (batch_size, config["num_kv_heads"],
+        k_nope = torch.randn(
+            (batch_size, config["num_kv_heads"], config["kv_lora_rank"]),
             dtype=config["dtype"],
             device=config["device"],
         )
-
+        q_rope = torch.randn(
+            (
+                batch_size,
+                config["num_attention_heads"],
+                config["qk_rope_head_dim"],
+            ),
+            dtype=config["dtype"],
+            device=config["device"],
+        )
+        k_rope = torch.randn(
+            (batch_size, config["num_kv_heads"], config["qk_rope_head_dim"]),
+            dtype=config["dtype"],
+            device=config["device"],
+        )
+        v = None  # Test with None v

         # Run forward decode
-        output = backend.forward_decode(
+        output = backend.forward_decode(
+            q_nope, k_nope, v, layer, fb, q_rope=q_rope, k_rope=k_rope
+        )

         # Shape and sanity checks
         expected_shape = (
sglang/test/doc_patch.py
ADDED
@@ -0,0 +1,59 @@
+"""
+Do some monkey patch to make the documentation compilation faster and more reliable.
+
+- Avoid port conflicts
+- Reduce the server launch time
+"""
+
+import weakref
+
+import nest_asyncio
+
+nest_asyncio.apply()
+
+import sglang.srt.server_args as server_args_mod
+from sglang.utils import execute_shell_command, reserve_port
+
+DEFAULT_MAX_RUNNING_REQUESTS = 128
+DEFAULT_MAX_TOTAL_TOKENS = 20480  # To allow multiple servers on the same machine
+
+_original_post_init = server_args_mod.ServerArgs.__post_init__
+
+
+def patched_post_init(self):
+    _original_post_init(self)
+    if self.max_running_requests is None:
+        self.max_running_requests = DEFAULT_MAX_RUNNING_REQUESTS
+    if self.max_total_tokens is None:
+        self.max_total_tokens = DEFAULT_MAX_TOTAL_TOKENS
+    self.cuda_graph_max_bs = 4
+
+
+server_args_mod.ServerArgs.__post_init__ = patched_post_init
+
+process_socket_map = weakref.WeakKeyDictionary()
+
+
+def launch_server_cmd(command: str, host: str = "0.0.0.0", port: int = None):
+    """
+    Launch the server using the given command.
+    If no port is specified, a free port is reserved.
+    """
+    if port is None:
+        port, lock_socket = reserve_port(host)
+    else:
+        lock_socket = None
+
+    extra_flags = (
+        f"--max-running-requests {DEFAULT_MAX_RUNNING_REQUESTS} "
+        f"--max-total-tokens {DEFAULT_MAX_TOTAL_TOKENS} "
+        f"--cuda-graph-max-bs 4"
+    )
+
+    full_command = f"{command} --port {port} {extra_flags}"
+    process = execute_shell_command(full_command)
+
+    if lock_socket is not None:
+        process_socket_map[process] = lock_socket
+
+    return process, port
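A minimal usage sketch of the launch_server_cmd helper added above (the model path is a placeholder, and cleanup assumes execute_shell_command returns a subprocess handle with terminate()):

from sglang.test.doc_patch import launch_server_cmd

# A free port is reserved and "--port" plus the doc-friendly limits are appended.
process, port = launch_server_cmd(
    "python -m sglang.launch_server --model-path <your-model>"
)
print(f"server starting on http://127.0.0.1:{port}")

process.terminate()  # stop the server (the reserved-port lock is released with it)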
sglang/test/few_shot_gsm8k.py
CHANGED
@@ -12,7 +12,7 @@ import time

 import numpy as np

-from sglang.api import set_default_backend
+from sglang.lang.api import set_default_backend
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl

sglang/test/few_shot_gsm8k_engine.py
CHANGED
@@ -8,7 +8,7 @@ import time
 import numpy as np

 import sglang as sgl
-from sglang.api import set_default_backend
+from sglang.lang.api import set_default_backend
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl

sglang/test/run_eval.py
CHANGED
@@ -65,9 +65,10 @@ def run_eval(args):

     sampler = ChatCompletionSampler(
         model=args.model,
-        max_tokens=2048,
+        max_tokens=getattr(args, "max_tokens", 2048),
         base_url=base_url,
         temperature=getattr(args, "temperature", 0.0),
+        reasoning_effort=getattr(args, "reasoning_effort", None),
     )

     # Run eval
@@ -120,7 +121,9 @@ if __name__ == "__main__":
     parser.add_argument("--eval-name", type=str, default="mmlu")
     parser.add_argument("--num-examples", type=int)
     parser.add_argument("--num-threads", type=int, default=512)
+    parser.add_argument("--max-tokens", type=int, default=2048)
     parser.add_argument("--temperature", type=float, default=0.0)
+    parser.add_argument("--reasoning-effort", type=str)
     args = parser.parse_args()

     run_eval(args)
sglang/test/simple_eval_common.py
CHANGED
@@ -91,6 +91,7 @@ class ChatCompletionSampler(SamplerBase):
         model: Optional[str] = None,
         system_message: Optional[str] = None,
         temperature: float = 0.0,
+        reasoning_effort: Optional[str] = None,
         max_tokens: int = 2048,
     ):
         self.client = OpenAI(base_url=base_url, http_client=LargerHttpxClient())
@@ -102,7 +103,11 @@ class ChatCompletionSampler(SamplerBase):
         self.system_message = system_message
         self.temperature = temperature
         self.max_tokens = max_tokens
+        self.reasoning_effort = reasoning_effort
         self.image_format = "url"
+        print(
+            f"ChatCompletionSampler initialized with {self.system_message=} {self.temperature=} {self.max_tokens=} {self.reasoning_effort=}"
+        )

     def _handle_image(
         self,
@@ -138,6 +143,7 @@ class ChatCompletionSampler(SamplerBase):
                 messages=message_list,
                 temperature=self.temperature,
                 max_tokens=self.max_tokens,
+                reasoning_effort=self.reasoning_effort,
             )
             return response.choices[0].message.content
         # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are rerunning MMMU
sglang/test/simple_eval_gpqa.py
CHANGED
@@ -71,6 +71,8 @@ class GPQAEval(Eval):
             )
         ]
         response_text = sampler(prompt_messages)
+        if response_text is None:
+            response_text = ""
         match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
         extracted_answer = match.group(1) if match else None
         score = 1.0 if extracted_answer == correct_answer else 0.0