sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/test/send_one.py CHANGED
@@ -27,6 +27,7 @@ class BenchArgs:
         "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
     )
     image: bool = False
+    many_images: bool = False
     stream: bool = False
 
     @staticmethod
@@ -48,6 +49,7 @@ class BenchArgs:
         parser.add_argument("--return-logprob", action="store_true")
         parser.add_argument("--prompt", type=str, default=BenchArgs.prompt)
         parser.add_argument("--image", action="store_true")
+        parser.add_argument("--many-images", action="store_true")
         parser.add_argument("--stream", action="store_true")
 
     @classmethod
@@ -62,6 +64,17 @@ def send_one_prompt(args):
            "Human: Describe this image in a very short sentence.\n\nAssistant:"
        )
        image_data = "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
+    elif args.many_images:
+        args.prompt = (
+            "Human: I have one reference image and many images."
+            "Describe their relationship in a very short sentence.\n\nAssistant:"
+        )
+        image_data = [
+            "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
+            "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
+            "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
+            "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
+        ]
     else:
         image_data = None
 
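The new --many-images path passes a list of image URLs instead of a single string. As a rough sketch, the equivalent raw request against a running server might look like this (the /generate endpoint, field names, and port are assumptions for illustration, not taken from this diff):

    import requests

    image_url = "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
    payload = {
        "text": "Human: I have one reference image and many images."
        "Describe their relationship in a very short sentence.\n\nAssistant:",
        "image_data": [image_url] * 4,  # a list of URLs rather than one string
        "sampling_params": {"max_new_tokens": 32},
    }
    print(requests.post("http://127.0.0.1:30000/generate", json=payload).json())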
@@ -74,9 +87,6 @@ def send_one_prompt(args):
            "Write in a format of json.\nAssistant:"
        )
        json_schema = "$$ANY$$"
-        json_schema = (
-            '{"type": "object", "properties": {"population": {"type": "integer"}}}'
-        )
     else:
         json_schema = None
 
sglang/test/simple_eval_common.py CHANGED
@@ -140,7 +140,7 @@ class ChatCompletionSampler(SamplerBase):
                     max_tokens=self.max_tokens,
                 )
                 return response.choices[0].message.content
-            # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are reruning MMMU
+            # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are rerunning MMMU
             except openai.BadRequestError as e:
                 print("Bad Request Error", e)
                 return ""
sglang/test/simple_eval_humaneval.py CHANGED
@@ -121,7 +121,7 @@ class HumanEval(Eval):
                 convo=convo,
                 metrics={
                     f"pass@{k}": estimate_pass_at_k([total], [correct], k)
-                    # this will be aggrated so no need of .mean()
+                    # this will be aggregated so no need of .mean()
                     for k in self._ks_passes
                     if total >= k
                 },
sglang/test/test_cutlass_moe.py ADDED
@@ -0,0 +1,278 @@
+import argparse
+import time
+
+import torch
+import triton  # Added import
+import triton.testing  # Added import
+from transformers import AutoConfig
+
+from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+
+
+def get_model_config(tp_size: int):
+    config = AutoConfig.from_pretrained(
+        "deepseek-ai/deepseek-R1", trust_remote_code=True
+    )
+    E = config.n_routed_experts
+    topk = config.num_experts_per_tok
+    intermediate_size = config.moe_intermediate_size
+    shard_intermediate_size = 2 * intermediate_size // tp_size
+
+    return {
+        "num_experts": E,
+        "topk": topk,
+        "hidden_size": config.hidden_size,
+        "shard_intermediate_size": shard_intermediate_size,
+        "dtype": config.torch_dtype,
+        "block_shape": config.quantization_config["weight_block_size"],
+    }
+
+
+def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
+    """Converts tensor to FP8 E4M3, scaling values to fit the range."""
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    # Calculate max absolute value safely
+    max_val = torch.max(torch.abs(tensor))
+    # Avoid division by zero if tensor is all zeros
+    if max_val == 0:
+        scale_factor = 1.0
+    else:
+        # Scale factor to bring the max value to finfo.max
+        scale_factor = finfo.max / max_val
+
+    # Apply scaling
+    scaled_tensor = tensor * scale_factor
+
+    # Clamp and convert
+    fp8_tensor = scaled_tensor.clamp(min=finfo.min, max=finfo.max).to(
+        dtype=torch.float8_e4m3fn
+    )
+    return fp8_tensor
+
+
+def run_test(tp_size, batch_size, model_config, check=False):
+    print(f"\n--- Batch Size: {batch_size} ---")
+    torch.set_default_device("cuda")
+    torch.cuda.manual_seed_all(42)  # For reproducible random numbers
+
+    E = model_config["num_experts"]
+    topk = model_config["topk"]
+    H = model_config["hidden_size"]
+    I = model_config["shard_intermediate_size"]
+    block_shape = model_config["block_shape"]  # Tuple (BLOCK_N, BLOCK_K)
+    dtype = model_config["dtype"]  # e.g., torch.bfloat16
+
+    print(
+        f"Config: E={E}, topk={topk}, H={H}, I_shard={I}, dtype={dtype}, block_shape={block_shape}"
+    )
+
+    # --- Input Data ---
+    # Use bf16/fp16 for input activation based on model config
+    x = torch.randn((batch_size, H), device="cuda", dtype=dtype) * 0.0001
+    # --- Weights (Generate in higher precision, then convert to FP8) ---
+    # Generate weights suitable for FP8 conversion (e.g., scaled appropriately)
+    w1_hp = (
+        torch.randn((E, I, H), device="cuda", dtype=torch.float32) * 0.00001 + 0.00001
+    )
+    w2_hp = (
+        torch.randn((E, H, I // 2), device="cuda", dtype=torch.float32) * 0.00001
+        + 0.00001
+    )
+
+    w1 = to_fp8(w1_hp)
+    w2 = to_fp8(w2_hp)
+
+    # --- Scales for FP8 Weights ---
+    block_n, block_k = block_shape
+    # Calculate number of blocks needed
+    w1_blocks_dim1 = (I + block_n - 1) // block_n
+    w1_blocks_dim2 = (H + block_k - 1) // block_k
+    w2_blocks_dim1 = (H + block_n - 1) // block_n
+    w2_blocks_dim2 = (I // 2 + block_k - 1) // block_k
+
+    # Scales are typically float32 or float16/bfloat16
+    scale_dtype = torch.float32  # Or dtype if scales match model dtype
+    w1_scale = torch.full(
+        (E, w1_blocks_dim1, w1_blocks_dim2), 1, device="cuda", dtype=scale_dtype
+    )  # Avoid zero scales
+    w2_scale = torch.full(
+        (E, w2_blocks_dim1, w2_blocks_dim2), 1, device="cuda", dtype=scale_dtype
+    )  # Avoid zero scales
+
+    # --- Routing Information ---
+    topk_weights = torch.softmax(
+        torch.rand(batch_size, topk, device="cuda", dtype=dtype), dim=-1
+    )
+    topk_ids = torch.randint(0, E, (batch_size, topk), dtype=torch.int32, device="cuda")
+
+    a1_strides = torch.full((E,), H, dtype=torch.int64, device="cuda")
+    c1_strides = torch.full((E,), I, dtype=torch.int64, device="cuda")
+    a2_strides = torch.full((E,), I // 2, dtype=torch.int64, device="cuda")
+    c2_strides = torch.full((E,), H, dtype=torch.int64, device="cuda")
+
+    workspace = torch.empty(
+        (7182 * 1024), device="cuda", dtype=torch.uint8
+    )  # Allocate sufficient workspace
+    # Pointer arrays (often filled by the kernel or a prep step, but needed as args)
+    a_ptrs = torch.empty((E,), dtype=torch.int64, device="cuda")
+    b_ptrs = torch.empty((E,), dtype=torch.int64, device="cuda")
+    out_ptrs = torch.empty((E,), dtype=torch.int64, device="cuda")
+    a_scales_ptrs = torch.empty((E,), dtype=torch.int64, device="cuda")
+    b_scales_ptrs = torch.empty((E,), dtype=torch.int64, device="cuda")
+    expert_offsets = torch.empty((E + 1,), dtype=torch.int32, device="cuda")
+    problem_sizes1 = torch.empty((E, 3), dtype=torch.int32, device="cuda")
+    problem_sizes2 = torch.empty((E, 3), dtype=torch.int32, device="cuda")
+
+    # --- Lambdas for Benchmarking ---
+    cutlass_lambda = lambda: cutlass_fused_experts(
+        x,
+        w1.transpose(1, 2),  # Transposed
+        w2.transpose(1, 2),  # Transposed
+        w1_scale.transpose(1, 2),
+        w2_scale.transpose(1, 2),
+        topk_weights,
+        topk_ids,
+        a1_strides,
+        c1_strides,
+        a2_strides,
+        c2_strides,
+        workspace,
+        a_ptrs,
+        b_ptrs,
+        out_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        expert_offsets,
+        problem_sizes1,
+        problem_sizes2,
+    )
+
+    # Note: Triton expects non-transposed weights
+    triton_lambda = lambda: fused_experts(
+        x,
+        w1,
+        w2,
+        topk_weights,
+        topk_ids,
+        inplace=False,  # Use False for benchmarking to avoid side effects if run multiple times
+        activation="silu",  # Assuming SiLU activation common in MoEs
+        use_fp8_w8a8=True,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        block_shape=block_shape,
+    )
+
+    # --- Warmup ---
+    print("Warming up...")
+    for _ in range(10):
+        _ = cutlass_lambda()
+        _ = triton_lambda()
+    torch.cuda.synchronize()
+
+    # --- Benchmarking ---
+    quantiles = [0.5, 0.2, 0.8]
+    print(f"Benchmarking Cutlass fused_experts...")
+    cutlass_ms, cutlass_min, cutlass_max = triton.testing.do_bench_cudagraph(
+        cutlass_lambda, rep=1000, quantiles=quantiles
+    )
+
+    print(f"Benchmarking Triton fused_experts...")
+    triton_ms, triton_min, triton_max = triton.testing.do_bench_cudagraph(
+        triton_lambda, rep=1000, quantiles=quantiles
+    )
+    print(
+        f"Cutlass fused_experts time: {cutlass_ms:.3f} ms (median) [{cutlass_min:.3f} - {cutlass_max:.3f}]"
+    )
+    print(
+        f"Triton fused_experts time: {triton_ms:.3f} ms (median) [{triton_min:.3f} - {triton_max:.3f}]"
+    )
+
+    # --- Correctness Check ---
+    if check:
+        print("Running correctness check...")
+        with torch.no_grad():
+            # Run CUTLASS version (requires transposed weights)
+            y_cutlass = cutlass_fused_experts(
+                x,
+                w1.transpose(1, 2),  # Transposed
+                w2.transpose(1, 2),  # Transposed
+                w1_scale.transpose(1, 2),
+                w2_scale.transpose(1, 2),
+                topk_weights,
+                topk_ids,
+                a1_strides,
+                c1_strides,
+                a2_strides,
+                c2_strides,
+                workspace,
+                a_ptrs,
+                b_ptrs,
+                out_ptrs,
+                a_scales_ptrs,
+                b_scales_ptrs,
+                expert_offsets,
+                problem_sizes1,
+                problem_sizes2,
+            )
+
+            # Run Triton version (requires original shape weights, use inplace=False)
+            y_triton = fused_experts(
+                x,
+                w1,  # Original shape
+                w2,  # Original shape
+                topk_weights,
+                topk_ids,
+                inplace=False,  # Important: Use False to get output tensor
+                activation="silu",
+                use_fp8_w8a8=True,
+                w1_scale=w1_scale,
+                w2_scale=w2_scale,
+                block_shape=block_shape,
+            )
+
+        # Ensure outputs are same dtype for comparison
+        y_cutlass = y_cutlass.to(dtype)
+        y_triton = y_triton.to(dtype)
+
+        abs_error = torch.abs(y_cutlass - y_triton)
+        rel_error = abs_error / torch.clamp(torch.abs(y_triton), min=1e-2)
+
+        max_abs_err = abs_error.max().item()
+        max_rel_err = rel_error.max().item()
+
+        print("y_cutlass:", y_cutlass[:, :10])
+        print("y_triton:", y_triton[:, :10])
+        print(f"Max absolute error: {max_abs_err:.6f}")
+        print(f"Max relative error: {max_rel_err:.6f}")
+
+        # Tolerance might need adjustment based on FP8 specifics and kernel differences
+        # FP8 comparisons often require higher tolerance than FP16/BF16
+        assert max_rel_err < 5e-1, f"Relative error too high! {max_rel_err}"
+        print("Correctness check passed.")
+
+
+def main(tp_size=8, batch_sizes=[1, 4, 8, 16, 32, 64, 128, 256, 512], check=False):
+    model_config = get_model_config(tp_size)
+    print("Model Config:", model_config)
+    for batch_size in batch_sizes:
+        run_test(tp_size, batch_size, model_config, check)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--tp-size", type=int, default=8, help="Tensor Parallel size")
+    parser.add_argument(
+        "--batch-sizes",
+        type=int,
+        nargs="+",
+        default=[1, 4, 8, 16, 32, 64, 128, 256, 512],  # Adjusted default
+        help="List of batch sizes to test",
+    )
+    parser.add_argument("--check", action="store_true", help="Enable check mode")
+    args = parser.parse_args()
+
+    print(f"Running benchmarks with TP size: {args.tp_size}")
+    print(f"Testing batch sizes: {args.batch_sizes}")
+
+    main(tp_size=args.tp_size, batch_sizes=args.batch_sizes, check=args.check)
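The new file is a standalone benchmark comparing the CUTLASS fused-experts path against the Triton one. Since it exposes main() at module level, a sketch of driving it from Python (assuming CUDA GPUs and access to the deepseek-ai/deepseek-R1 config; this mirrors the script's own CLI entry point):

    # Equivalent to: python3 -m sglang.test.test_cutlass_moe --tp-size 8 --batch-sizes 1 16 128 --check
    from sglang.test.test_cutlass_moe import main

    main(tp_size=8, batch_sizes=[1, 16, 128], check=True)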
sglang/test/test_programs.py CHANGED
@@ -370,7 +370,7 @@ def test_dtype_gen():
     @sgl.function
     def dtype_gen(s):
         s += "Q: What is the full name of DNS?\n"
-        s += "A: The full nams is " + sgl.gen("str_res", dtype=str, stop="\n") + "\n"
+        s += "A: The full names is " + sgl.gen("str_res", dtype=str, stop="\n") + "\n"
         s += "Q: Which year was DNS invented?\n"
         s += "A: " + sgl.gen("int_res", dtype=int) + "\n"
         s += "Q: What is the value of pi?\n"
@@ -503,7 +503,7 @@ def test_hellaswag_select():
     #####################################
 
     # Run requests
-    tic = time.time()
+    tic = time.perf_counter()
     rets = few_shot_hellaswag.run_batch(
         arguments,
         temperature=0,
@@ -514,13 +514,13 @@
     preds = []
     for i, ret in enumerate(rets):
         preds.append(choices[i].index(ret["answer"]))
-    latency = time.time() - tic
+    latency = time.perf_counter() - tic
 
     # Compute accuracy
     accuracy = np.mean(np.array(preds) == np.array(labels))
 
     # Test generator style of run_batch
-    tic = time.time()
+    tic = time.perf_counter()
     rets = few_shot_hellaswag.run_batch(
         arguments,
         temperature=0,
@@ -531,7 +531,7 @@
     preds_gen = []
     for i, ret in enumerate(rets):
         preds_gen.append(choices[i].index(ret["answer"]))
-    latency_gen = time.time() - tic
+    latency_gen = time.perf_counter() - tic
 
     # Compute accuracy
     accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
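The time.time() to time.perf_counter() swaps here (and in test_utils.py and sglang/utils.py below) use the right clock for interval measurement: perf_counter() is monotonic and high-resolution, so a latency reading cannot be skewed, or go negative, if the system wall clock is adjusted mid-run. The pattern in isolation (do_work is a hypothetical stand-in for the timed section):

    import time

    def do_work():  # hypothetical stand-in for the code being timed
        sum(range(1_000_000))

    tic = time.perf_counter()            # monotonic start mark
    do_work()
    elapsed = time.perf_counter() - tic  # immune to wall-clock changes
    print(f"{elapsed:.3f}s")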
sglang/test/test_utils.py CHANGED
@@ -395,12 +395,12 @@ def popen_launch_server(
     other_args: list[str] = (),
     env: Optional[dict] = None,
     return_stdout_stderr: Optional[tuple] = None,
-    pd_seperated: bool = False,
+    pd_separated: bool = False,
 ):
     _, host, port = base_url.split(":")
     host = host[2:]
 
-    if pd_seperated:
+    if pd_separated:
         command = "sglang.launch_pd_server"
     else:
         command = "sglang.launch_server"
@@ -414,7 +414,7 @@
         *[str(x) for x in other_args],
     ]
 
-    if pd_seperated:
+    if pd_separated:
         command.extend(
             [
                 "--lb-host",
@@ -449,9 +449,9 @@
     else:
         process = subprocess.Popen(command, stdout=None, stderr=None, env=env)
 
-    start_time = time.time()
+    start_time = time.perf_counter()
     with requests.Session() as session:
-        while time.time() - start_time < timeout:
+        while time.perf_counter() - start_time < timeout:
             try:
                 headers = {
                     "Content-Type": "application/json; charset=utf-8",
@@ -478,6 +478,47 @@
     raise TimeoutError("Server failed to start within the timeout period.")
 
 
+def popen_launch_pd_server(
+    model: str,
+    base_url: str,
+    timeout: float,
+    api_key: Optional[str] = None,
+    other_args: list[str] = (),
+    env: Optional[dict] = None,
+):
+    _, host, port = base_url.split(":")
+    host = host[2:]
+
+    command = "sglang.launch_server"
+
+    command = [
+        "python3",
+        "-m",
+        command,
+        "--model-path",
+        model,
+        *[str(x) for x in other_args],
+    ]
+
+    command.extend(
+        [
+            "--host",
+            host,
+            "--port",
+            port,
+        ]
+    )
+
+    if api_key:
+        command += ["--api-key", api_key]
+
+    print(f"command={' '.join(command)}")
+
+    process = subprocess.Popen(command, stdout=None, stderr=None, env=env)
+
+    return process
+
+
 def run_with_timeout(
     func: Callable,
     args: tuple = (),
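Note that the new popen_launch_pd_server only spawns the process and returns the Popen handle; unlike popen_launch_server, it does not poll the server for readiness. A hypothetical call (the model path and URL are placeholders):

    proc = popen_launch_pd_server(
        "meta-llama/Llama-3.1-8B-Instruct",
        "http://127.0.0.1:30000",
        timeout=300,  # accepted for signature symmetry; readiness is not awaited
    )
    try:
        ...  # run requests against the server once it comes up
    finally:
        proc.terminate()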
@@ -509,7 +550,7 @@ class TestFile:
 
 
 def run_unittest_files(files: List[TestFile], timeout_per_file: float):
-    tic = time.time()
+    tic = time.perf_counter()
     success = True
 
     for i, file in enumerate(files):
@@ -524,13 +565,13 @@ def run_unittest_files(files: List[TestFile], timeout_per_file: float):
             f".\n.\nBegin ({i}/{len(files) - 1}):\npython3 {filename}\n.\n.\n",
             flush=True,
         )
-        tic = time.time()
+        tic = time.perf_counter()
 
         process = subprocess.Popen(
             ["python3", filename], stdout=None, stderr=None, env=os.environ
         )
         process.wait()
-        elapsed = time.time() - tic
+        elapsed = time.perf_counter() - tic
 
         print(
             f".\n.\nEnd ({i}/{len(files) - 1}):\n{filename=}, {elapsed=:.0f}, {estimated_time=}\n.\n.\n",
@@ -556,9 +597,9 @@ def run_unittest_files(files: List[TestFile], timeout_per_file: float):
             break
 
     if success:
-        print(f"Success. Time elapsed: {time.time() - tic:.2f}s", flush=True)
+        print(f"Success. Time elapsed: {time.perf_counter() - tic:.2f}s", flush=True)
     else:
-        print(f"Fail. Time elapsed: {time.time() - tic:.2f}s", flush=True)
+        print(f"Fail. Time elapsed: {time.perf_counter() - tic:.2f}s", flush=True)
 
     return 0 if success else -1
 
@@ -581,7 +622,7 @@ def get_benchmark_args(
     disable_stream=False,
     disable_ignore_eos=False,
     seed: int = 0,
-    pd_seperated: bool = False,
+    pd_separated: bool = False,
 ):
     return SimpleNamespace(
         backend="sglang",
@@ -611,7 +652,7 @@
         profile=None,
         lora_name=None,
         prompt_suffix="",
-        pd_seperated=pd_seperated,
+        pd_separated=pd_separated,
     )
 
 
@@ -675,7 +716,7 @@
     other_server_args,
     benchmark_args,
     need_warmup=False,
-    pd_seperated=False,
+    pd_separated=False,
 ):
     # Launch the server
     process = popen_launch_server(
@@ -683,7 +724,7 @@
         base_url,
         timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
         other_args=other_server_args,
-        pd_seperated=pd_seperated,
+        pd_separated=pd_separated,
     )
 
     # run benchmark for all
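Because pd_seperated is renamed to pd_separated (fixing the spelling) with no backward-compatible alias, external callers passing the old keyword will now fail with a TypeError. A sketch of an updated call site (the model and URL are placeholders, and the leading positional parameters are assumed from the surrounding code):

    process = popen_launch_server(
        "meta-llama/Llama-3.1-8B-Instruct",
        "http://127.0.0.1:30000",
        timeout=300,
        pd_separated=True,  # was: pd_seperated=True
    )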
sglang/utils.py CHANGED
@@ -278,7 +278,7 @@ def graceful_registry(sub_module_name: str):
             f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..."
         )
         if signum == signal.SIGTERM:
-            logger.info(f"{sub_module_name} recive sigterm")
+            logger.info(f"{sub_module_name} receive sigterm")
 
     signal.signal(signal.SIGTERM, graceful_shutdown)
 
@@ -436,7 +436,7 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
         base_url: The base URL of the server
         timeout: Maximum time to wait in seconds. None means wait forever.
     """
-    start_time = time.time()
+    start_time = time.perf_counter()
     while True:
         try:
             response = requests.get(
@@ -455,7 +455,7 @@
                 )
                 break
 
-            if timeout and time.time() - start_time > timeout:
+            if timeout and time.perf_counter() - start_time > timeout:
                 raise TimeoutError("Server did not become ready within timeout period")
         except requests.exceptions.RequestException:
             time.sleep(1)
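wait_for_server keeps its public signature (base_url, optional timeout in seconds, per the docstring above), so existing usage is unchanged, e.g.:

    from sglang.utils import wait_for_server

    # Block until the server answers, giving up after 120 seconds
    # (the URL is a placeholder for wherever the server listens).
    wait_for_server("http://127.0.0.1:30000", timeout=120)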
sglang/version.py CHANGED
@@ -1 +1 @@
-__version__ = "0.4.6.post3"
+__version__ = "0.4.6.post5"