sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +10 -8
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +157 -21
- sglang/bench_serving.py +137 -59
- sglang/compile_deep_gemm.py +5 -5
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +40 -28
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +49 -44
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +129 -135
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +238 -122
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +10 -19
- sglang/srt/disaggregation/prefill.py +132 -47
- sglang/srt/disaggregation/utils.py +123 -6
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +44 -9
- sglang/srt/entrypoints/http_server.py +23 -6
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +64 -18
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +6 -4
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +61 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
- sglang/srt/layers/moe/ep_moe/layer.py +105 -51
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +67 -10
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +8 -3
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +77 -74
- sglang/srt/layers/quantization/fp8.py +92 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +20 -7
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +2 -4
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +19 -4
- sglang/srt/managers/mm_utils.py +294 -140
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +122 -42
- sglang/srt/managers/schedule_policy.py +1 -5
- sglang/srt/managers/scheduler.py +205 -138
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +232 -58
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +76 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +314 -39
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +29 -19
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +5 -1
- sglang/srt/model_executor/model_runner.py +163 -68
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +308 -351
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +15 -8
- sglang/srt/models/llava.py +258 -7
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +58 -20
- sglang/srt/openai_api/protocol.py +6 -8
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +162 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +138 -7
- sglang/srt/speculative/eagle_worker.py +69 -21
- sglang/srt/utils.py +74 -17
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +55 -14
- sglang/utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/test/send_one.py
CHANGED
@@ -27,6 +27,7 @@ class BenchArgs:
         "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
     )
     image: bool = False
+    many_images: bool = False
     stream: bool = False
 
     @staticmethod
@@ -48,6 +49,7 @@ class BenchArgs:
         parser.add_argument("--return-logprob", action="store_true")
         parser.add_argument("--prompt", type=str, default=BenchArgs.prompt)
         parser.add_argument("--image", action="store_true")
+        parser.add_argument("--many-images", action="store_true")
        parser.add_argument("--stream", action="store_true")
 
     @classmethod
@@ -62,6 +64,17 @@ def send_one_prompt(args):
            "Human: Describe this image in a very short sentence.\n\nAssistant:"
        )
        image_data = "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
+    elif args.many_images:
+        args.prompt = (
+            "Human: I have one reference image and many images."
+            "Describe their relationship in a very short sentence.\n\nAssistant:"
+        )
+        image_data = [
+            "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
+            "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
+            "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
+            "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
+        ]
     else:
         image_data = None
 
@@ -74,9 +87,6 @@ def send_one_prompt(args):
            "Write in a format of json.\nAssistant:"
        )
        json_schema = "$$ANY$$"
-        json_schema = (
-            '{"type": "object", "properties": {"population": {"type": "integer"}}}'
-        )
     else:
         json_schema = None
 
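Note on the new flag: with --many-images, send_one sends a single prompt together with a list of image URLs. As a hedged sketch only (the request construction lives in parts of send_one.py not shown in this diff, and the /generate payload fields below are assumptions based on sglang's native HTTP API), the resulting request looks roughly like:

import requests

payload = {
    # Prompt and image list mirror the values added in the diff above.
    "text": "Human: I have one reference image and many images."
    "Describe their relationship in a very short sentence.\n\nAssistant:",
    "image_data": [
        "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
    ]
    * 4,
    "sampling_params": {"temperature": 0, "max_new_tokens": 64},
}
# Assumes a locally running sglang server on the default port.
resp = requests.post("http://127.0.0.1:30000/generate", json=payload)
print(resp.json())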
sglang/test/simple_eval_common.py
CHANGED
@@ -140,7 +140,7 @@ class ChatCompletionSampler(SamplerBase):
                    max_tokens=self.max_tokens,
                )
                return response.choices[0].message.content
-            # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are
+            # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are rerunning MMMU
            except openai.BadRequestError as e:
                print("Bad Request Error", e)
                return ""
sglang/test/simple_eval_humaneval.py
CHANGED
@@ -121,7 +121,7 @@ class HumanEval(Eval):
                convo=convo,
                metrics={
                    f"pass@{k}": estimate_pass_at_k([total], [correct], k)
-                    # this will be
+                    # this will be aggregated so no need of .mean()
                    for k in self._ks_passes
                    if total >= k
                },
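For context on the metric in this hunk: estimate_pass_at_k computes the standard unbiased pass@k estimator (as in the original HumanEval evaluation). A minimal reference sketch of that estimator, not the sglang implementation:

import numpy as np

def pass_at_k(n: int, c: int, k: int) -> float:
    # n = samples generated, c = samples that passed, k = evaluation budget.
    # Unbiased estimator: 1 - C(n - c, k) / C(n, k), computed stably as a product.
    if n - c < k:
        return 1.0
    return float(1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))

print(pass_at_k(n=20, c=3, k=5))  # probability that at least 1 of 5 drawn samples passes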
sglang/test/test_cutlass_moe.py
ADDED
@@ -0,0 +1,278 @@
+import argparse
+import time
+
+import torch
+import triton  # Added import
+import triton.testing  # Added import
+from transformers import AutoConfig
+
+from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+
+
+def get_model_config(tp_size: int):
+    config = AutoConfig.from_pretrained(
+        "deepseek-ai/deepseek-R1", trust_remote_code=True
+    )
+    E = config.n_routed_experts
+    topk = config.num_experts_per_tok
+    intermediate_size = config.moe_intermediate_size
+    shard_intermediate_size = 2 * intermediate_size // tp_size
+
+    return {
+        "num_experts": E,
+        "topk": topk,
+        "hidden_size": config.hidden_size,
+        "shard_intermediate_size": shard_intermediate_size,
+        "dtype": config.torch_dtype,
+        "block_shape": config.quantization_config["weight_block_size"],
+    }
+
+
+def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
+    """Converts tensor to FP8 E4M3, scaling values to fit the range."""
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    # Calculate max absolute value safely
+    max_val = torch.max(torch.abs(tensor))
+    # Avoid division by zero if tensor is all zeros
+    if max_val == 0:
+        scale_factor = 1.0
+    else:
+        # Scale factor to bring the max value to finfo.max
+        scale_factor = finfo.max / max_val
+
+    # Apply scaling
+    scaled_tensor = tensor * scale_factor
+
+    # Clamp and convert
+    fp8_tensor = scaled_tensor.clamp(min=finfo.min, max=finfo.max).to(
+        dtype=torch.float8_e4m3fn
+    )
+    return fp8_tensor
+
+
+def run_test(tp_size, batch_size, model_config, check=False):
+    print(f"\n--- Batch Size: {batch_size} ---")
+    torch.set_default_device("cuda")
+    torch.cuda.manual_seed_all(42)  # For reproducible random numbers
+
+    E = model_config["num_experts"]
+    topk = model_config["topk"]
+    H = model_config["hidden_size"]
+    I = model_config["shard_intermediate_size"]
+    block_shape = model_config["block_shape"]  # Tuple (BLOCK_N, BLOCK_K)
+    dtype = model_config["dtype"]  # e.g., torch.bfloat16
+
+    print(
+        f"Config: E={E}, topk={topk}, H={H}, I_shard={I}, dtype={dtype}, block_shape={block_shape}"
+    )
+
+    # --- Input Data ---
+    # Use bf16/fp16 for input activation based on model config
+    x = torch.randn((batch_size, H), device="cuda", dtype=dtype) * 0.0001
+    # --- Weights (Generate in higher precision, then convert to FP8) ---
+    # Generate weights suitable for FP8 conversion (e.g., scaled appropriately)
+    w1_hp = (
+        torch.randn((E, I, H), device="cuda", dtype=torch.float32) * 0.00001 + 0.00001
+    )
+    w2_hp = (
+        torch.randn((E, H, I // 2), device="cuda", dtype=torch.float32) * 0.00001
+        + 0.00001
+    )
+
+    w1 = to_fp8(w1_hp)
+    w2 = to_fp8(w2_hp)
+
+    # --- Scales for FP8 Weights ---
+    block_n, block_k = block_shape
+    # Calculate number of blocks needed
+    w1_blocks_dim1 = (I + block_n - 1) // block_n
+    w1_blocks_dim2 = (H + block_k - 1) // block_k
+    w2_blocks_dim1 = (H + block_n - 1) // block_n
+    w2_blocks_dim2 = (I // 2 + block_k - 1) // block_k
+
+    # Scales are typically float32 or float16/bfloat16
+    scale_dtype = torch.float32  # Or dtype if scales match model dtype
+    w1_scale = torch.full(
+        (E, w1_blocks_dim1, w1_blocks_dim2), 1, device="cuda", dtype=scale_dtype
+    )  # Avoid zero scales
+    w2_scale = torch.full(
+        (E, w2_blocks_dim1, w2_blocks_dim2), 1, device="cuda", dtype=scale_dtype
+    )  # Avoid zero scales
+
+    # --- Routing Information ---
+    topk_weights = torch.softmax(
+        torch.rand(batch_size, topk, device="cuda", dtype=dtype), dim=-1
+    )
+    topk_ids = torch.randint(0, E, (batch_size, topk), dtype=torch.int32, device="cuda")
+
+    a1_strides = torch.full((E,), H, dtype=torch.int64, device="cuda")
+    c1_strides = torch.full((E,), I, dtype=torch.int64, device="cuda")
+    a2_strides = torch.full((E,), I // 2, dtype=torch.int64, device="cuda")
+    c2_strides = torch.full((E,), H, dtype=torch.int64, device="cuda")
+
+    workspace = torch.empty(
+        (7182 * 1024), device="cuda", dtype=torch.uint8
+    )  # Allocate sufficient workspace
+    # Pointer arrays (often filled by the kernel or a prep step, but needed as args)
+    a_ptrs = torch.empty((E,), dtype=torch.int64, device="cuda")
+    b_ptrs = torch.empty((E,), dtype=torch.int64, device="cuda")
+    out_ptrs = torch.empty((E,), dtype=torch.int64, device="cuda")
+    a_scales_ptrs = torch.empty((E,), dtype=torch.int64, device="cuda")
+    b_scales_ptrs = torch.empty((E,), dtype=torch.int64, device="cuda")
+    expert_offsets = torch.empty((E + 1,), dtype=torch.int32, device="cuda")
+    problem_sizes1 = torch.empty((E, 3), dtype=torch.int32, device="cuda")
+    problem_sizes2 = torch.empty((E, 3), dtype=torch.int32, device="cuda")
+
+    # --- Lambdas for Benchmarking ---
+    cutlass_lambda = lambda: cutlass_fused_experts(
+        x,
+        w1.transpose(1, 2),  # Transposed
+        w2.transpose(1, 2),  # Transposed
+        w1_scale.transpose(1, 2),
+        w2_scale.transpose(1, 2),
+        topk_weights,
+        topk_ids,
+        a1_strides,
+        c1_strides,
+        a2_strides,
+        c2_strides,
+        workspace,
+        a_ptrs,
+        b_ptrs,
+        out_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        expert_offsets,
+        problem_sizes1,
+        problem_sizes2,
+    )
+
+    # Note: Triton expects non-transposed weights
+    triton_lambda = lambda: fused_experts(
+        x,
+        w1,
+        w2,
+        topk_weights,
+        topk_ids,
+        inplace=False,  # Use False for benchmarking to avoid side effects if run multiple times
+        activation="silu",  # Assuming SiLU activation common in MoEs
+        use_fp8_w8a8=True,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        block_shape=block_shape,
+    )
+
+    # --- Warmup ---
+    print("Warming up...")
+    for _ in range(10):
+        _ = cutlass_lambda()
+        _ = triton_lambda()
+    torch.cuda.synchronize()
+
+    # --- Benchmarking ---
+    quantiles = [0.5, 0.2, 0.8]
+    print(f"Benchmarking Cutlass fused_experts...")
+    cutlass_ms, cutlass_min, cutlass_max = triton.testing.do_bench_cudagraph(
+        cutlass_lambda, rep=1000, quantiles=quantiles
+    )
+
+    print(f"Benchmarking Triton fused_experts...")
+    triton_ms, triton_min, triton_max = triton.testing.do_bench_cudagraph(
+        triton_lambda, rep=1000, quantiles=quantiles
+    )
+    print(
+        f"Cutlass fused_experts time: {cutlass_ms:.3f} ms (median) [{cutlass_min:.3f} - {cutlass_max:.3f}]"
+    )
+    print(
+        f"Triton fused_experts time: {triton_ms:.3f} ms (median) [{triton_min:.3f} - {triton_max:.3f}]"
+    )
+
+    # --- Correctness Check ---
+    if check:
+        print("Running correctness check...")
+        with torch.no_grad():
+            # Run CUTLASS version (requires transposed weights)
+            y_cutlass = cutlass_fused_experts(
+                x,
+                w1.transpose(1, 2),  # Transposed
+                w2.transpose(1, 2),  # Transposed
+                w1_scale.transpose(1, 2),
+                w2_scale.transpose(1, 2),
+                topk_weights,
+                topk_ids,
+                a1_strides,
+                c1_strides,
+                a2_strides,
+                c2_strides,
+                workspace,
+                a_ptrs,
+                b_ptrs,
+                out_ptrs,
+                a_scales_ptrs,
+                b_scales_ptrs,
+                expert_offsets,
+                problem_sizes1,
+                problem_sizes2,
+            )
+
+            # Run Triton version (requires original shape weights, use inplace=False)
+            y_triton = fused_experts(
+                x,
+                w1,  # Original shape
+                w2,  # Original shape
+                topk_weights,
+                topk_ids,
+                inplace=False,  # Important: Use False to get output tensor
+                activation="silu",
+                use_fp8_w8a8=True,
+                w1_scale=w1_scale,
+                w2_scale=w2_scale,
+                block_shape=block_shape,
+            )
+
+            # Ensure outputs are same dtype for comparison
+            y_cutlass = y_cutlass.to(dtype)
+            y_triton = y_triton.to(dtype)
+
+            abs_error = torch.abs(y_cutlass - y_triton)
+            rel_error = abs_error / torch.clamp(torch.abs(y_triton), min=1e-2)
+
+            max_abs_err = abs_error.max().item()
+            max_rel_err = rel_error.max().item()
+
+            print("y_cutlass:", y_cutlass[:, :10])
+            print("y_triton:", y_triton[:, :10])
+            print(f"Max absolute error: {max_abs_err:.6f}")
+            print(f"Max relative error: {max_rel_err:.6f}")
+
+            # Tolerance might need adjustment based on FP8 specifics and kernel differences
+            # FP8 comparisons often require higher tolerance than FP16/BF16
+            assert max_rel_err < 5e-1, f"Relative error too high! {max_rel_err}"
+            print("Correctness check passed.")
+
+
+def main(tp_size=8, batch_sizes=[1, 4, 8, 16, 32, 64, 128, 256, 512], check=False):
+    model_config = get_model_config(tp_size)
+    print("Model Config:", model_config)
+    for batch_size in batch_sizes:
+        run_test(tp_size, batch_size, model_config, check)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--tp-size", type=int, default=8, help="Tensor Parallel size")
+    parser.add_argument(
+        "--batch-sizes",
+        type=int,
+        nargs="+",
+        default=[1, 4, 8, 16, 32, 64, 128, 256, 512],  # Adjusted default
+        help="List of batch sizes to test",
+    )
+    parser.add_argument("--check", action="store_true", help="Enable check mode")
+    args = parser.parse_args()
+
+    print(f"Running benchmarks with TP size: {args.tp_size}")
+    print(f"Testing batch sizes: {args.batch_sizes}")
+
+    main(tp_size=args.tp_size, batch_sizes=args.batch_sizes, check=args.check)
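A hedged usage sketch for the new benchmark above (function names are taken from the file itself; running it needs a CUDA GPU, the sglang CUTLASS and Triton MoE kernels, and access to the deepseek-ai/deepseek-R1 config):

from sglang.test.test_cutlass_moe import main

# Reduced sweep for a quick local check; these batch sizes are illustrative, not the defaults.
main(tp_size=8, batch_sizes=[1, 64], check=True)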
sglang/test/test_programs.py
CHANGED
@@ -370,7 +370,7 @@ def test_dtype_gen():
     @sgl.function
     def dtype_gen(s):
         s += "Q: What is the full name of DNS?\n"
-        s += "A: The full
+        s += "A: The full names is " + sgl.gen("str_res", dtype=str, stop="\n") + "\n"
         s += "Q: Which year was DNS invented?\n"
         s += "A: " + sgl.gen("int_res", dtype=int) + "\n"
         s += "Q: What is the value of pi?\n"
@@ -503,7 +503,7 @@ def test_hellaswag_select():
     #####################################
 
     # Run requests
-    tic = time.time()
+    tic = time.perf_counter()
     rets = few_shot_hellaswag.run_batch(
         arguments,
         temperature=0,
@@ -514,13 +514,13 @@ def test_hellaswag_select():
     preds = []
     for i, ret in enumerate(rets):
         preds.append(choices[i].index(ret["answer"]))
-    latency = time.time() - tic
+    latency = time.perf_counter() - tic
 
     # Compute accuracy
     accuracy = np.mean(np.array(preds) == np.array(labels))
 
     # Test generator style of run_batch
-    tic = time.time()
+    tic = time.perf_counter()
     rets = few_shot_hellaswag.run_batch(
         arguments,
         temperature=0,
@@ -531,7 +531,7 @@ def test_hellaswag_select():
     preds_gen = []
     for i, ret in enumerate(rets):
         preds_gen.append(choices[i].index(ret["answer"]))
-    latency_gen = time.time() - tic
+    latency_gen = time.perf_counter() - tic
 
     # Compute accuracy
     accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
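Note: this file (along with test_utils.py and sglang/utils.py below) switches interval timing from time.time() to time.perf_counter(). perf_counter is monotonic and high-resolution, so elapsed-time measurements are not skewed by NTP or other system clock adjustments. A minimal illustration of the pattern:

import time

tic = time.perf_counter()           # monotonic, high-resolution start mark
time.sleep(0.1)                     # stand-in for the work being timed
elapsed = time.perf_counter() - tic
print(f"elapsed: {elapsed:.3f}s")   # wall-clock changes do not affect this delta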
sglang/test/test_utils.py
CHANGED
@@ -395,12 +395,12 @@ def popen_launch_server(
     other_args: list[str] = (),
     env: Optional[dict] = None,
     return_stdout_stderr: Optional[tuple] = None,
-    pd_seperated: bool = False,
+    pd_separated: bool = False,
 ):
     _, host, port = base_url.split(":")
     host = host[2:]
 
-    if pd_seperated:
+    if pd_separated:
         command = "sglang.launch_pd_server"
     else:
         command = "sglang.launch_server"
@@ -414,7 +414,7 @@ def popen_launch_server(
         *[str(x) for x in other_args],
     ]
 
-    if pd_seperated:
+    if pd_separated:
         command.extend(
             [
                 "--lb-host",
@@ -449,9 +449,9 @@ def popen_launch_server(
     else:
         process = subprocess.Popen(command, stdout=None, stderr=None, env=env)
 
-    start_time = time.time()
+    start_time = time.perf_counter()
     with requests.Session() as session:
-        while time.time() - start_time < timeout:
+        while time.perf_counter() - start_time < timeout:
             try:
                 headers = {
                     "Content-Type": "application/json; charset=utf-8",
@@ -478,6 +478,47 @@ def popen_launch_server(
     raise TimeoutError("Server failed to start within the timeout period.")
 
 
+def popen_launch_pd_server(
+    model: str,
+    base_url: str,
+    timeout: float,
+    api_key: Optional[str] = None,
+    other_args: list[str] = (),
+    env: Optional[dict] = None,
+):
+    _, host, port = base_url.split(":")
+    host = host[2:]
+
+    command = "sglang.launch_server"
+
+    command = [
+        "python3",
+        "-m",
+        command,
+        "--model-path",
+        model,
+        *[str(x) for x in other_args],
+    ]
+
+    command.extend(
+        [
+            "--host",
+            host,
+            "--port",
+            port,
+        ]
+    )
+
+    if api_key:
+        command += ["--api-key", api_key]
+
+    print(f"command={' '.join(command)}")
+
+    process = subprocess.Popen(command, stdout=None, stderr=None, env=env)
+
+    return process
+
+
 def run_with_timeout(
     func: Callable,
     args: tuple = (),
@@ -509,7 +550,7 @@ class TestFile:
 
 
 def run_unittest_files(files: List[TestFile], timeout_per_file: float):
-    tic = time.time()
+    tic = time.perf_counter()
     success = True
 
     for i, file in enumerate(files):
@@ -524,13 +565,13 @@ def run_unittest_files(files: List[TestFile], timeout_per_file: float):
             f".\n.\nBegin ({i}/{len(files) - 1}):\npython3 {filename}\n.\n.\n",
             flush=True,
         )
-        tic = time.time()
+        tic = time.perf_counter()
 
         process = subprocess.Popen(
             ["python3", filename], stdout=None, stderr=None, env=os.environ
         )
         process.wait()
-        elapsed = time.time() - tic
+        elapsed = time.perf_counter() - tic
 
         print(
             f".\n.\nEnd ({i}/{len(files) - 1}):\n{filename=}, {elapsed=:.0f}, {estimated_time=}\n.\n.\n",
@@ -556,9 +597,9 @@ def run_unittest_files(files: List[TestFile], timeout_per_file: float):
             break
 
     if success:
-        print(f"Success. Time elapsed: {time.time() - tic:.2f}s", flush=True)
+        print(f"Success. Time elapsed: {time.perf_counter() - tic:.2f}s", flush=True)
     else:
-        print(f"Fail. Time elapsed: {time.time() - tic:.2f}s", flush=True)
+        print(f"Fail. Time elapsed: {time.perf_counter() - tic:.2f}s", flush=True)
 
     return 0 if success else -1
 
@@ -581,7 +622,7 @@ def get_benchmark_args(
     disable_stream=False,
     disable_ignore_eos=False,
     seed: int = 0,
-    pd_seperated: bool = False,
+    pd_separated: bool = False,
 ):
     return SimpleNamespace(
         backend="sglang",
@@ -611,7 +652,7 @@ def get_benchmark_args(
         profile=None,
         lora_name=None,
         prompt_suffix="",
-        pd_seperated=pd_seperated,
+        pd_separated=pd_separated,
     )
 
 
@@ -675,7 +716,7 @@ def run_bench_serving_multi(
     other_server_args,
     benchmark_args,
     need_warmup=False,
-    pd_seperated=False,
+    pd_separated=False,
 ):
     # Launch the server
     process = popen_launch_server(
@@ -683,7 +724,7 @@ def run_bench_serving_multi(
         base_url,
         timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
         other_args=other_server_args,
-        pd_seperated=pd_seperated,
+        pd_separated=pd_separated,
     )
 
     # run benchmark for all
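A hedged sketch of how a test might use the renamed pd_separated flag and the new popen_launch_pd_server helper (the model path, URLs, and timeout are illustrative values, not taken from this diff):

from sglang.test.test_utils import popen_launch_pd_server, popen_launch_server

# Route the launch through sglang.launch_pd_server via the renamed flag.
proc = popen_launch_server(
    "Qwen/Qwen2.5-7B-Instruct",  # illustrative model path
    "http://127.0.0.1:30000",
    timeout=600,
    pd_separated=True,
)

# The new helper just starts a sglang.launch_server process and returns it;
# unlike popen_launch_server, it does not poll the server for readiness.
worker = popen_launch_pd_server(
    "Qwen/Qwen2.5-7B-Instruct",
    "http://127.0.0.1:30100",
    timeout=600,
)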
sglang/utils.py
CHANGED
@@ -278,7 +278,7 @@ def graceful_registry(sub_module_name: str):
            f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..."
        )
        if signum == signal.SIGTERM:
-            logger.info(f"{sub_module_name}
+            logger.info(f"{sub_module_name} receive sigterm")
 
    signal.signal(signal.SIGTERM, graceful_shutdown)
 
@@ -436,7 +436,7 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
        base_url: The base URL of the server
        timeout: Maximum time to wait in seconds. None means wait forever.
    """
-    start_time = time.time()
+    start_time = time.perf_counter()
    while True:
        try:
            response = requests.get(
@@ -455,7 +455,7 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
                )
                break
 
-            if timeout and time.time() - start_time > timeout:
+            if timeout and time.perf_counter() - start_time > timeout:
                raise TimeoutError("Server did not become ready within timeout period")
        except requests.exceptions.RequestException:
            time.sleep(1)
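A minimal usage sketch of the helper touched above (the URL is illustrative):

from sglang.utils import wait_for_server

# Blocks until the server responds, or raises TimeoutError after 300 seconds.
wait_for_server("http://127.0.0.1:30000", timeout=300)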
sglang/version.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.4.6.post3"
+__version__ = "0.4.6.post5"