sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,57 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
class DummyModel(nn.Module):
    """Toy module comparing two equivalent ways of computing a per-head logit gate.

    Both ``_get_logits_head_gate_*`` methods return
    ``weights_proj(x).unsqueeze(-1) * n_heads**-0.5 * q_scale.unsqueeze(1) * softmax_scale``;
    they differ only in how the scalar factors are grouped before broadcasting.
    """

    def __init__(self, d_in=2048, n_heads=128, softmax_scale=0.5, d_out=1024):
        """
        Args:
            d_in: input feature size of ``weights_proj``.
            n_heads: used only for the ``n_heads**-0.5`` scale factor.
            softmax_scale: extra multiplicative factor applied to the gate.
            d_out: output feature size of ``weights_proj``; defaults to the
                previously hard-coded 1024.
        """
        super().__init__()
        self.weights_proj = nn.Linear(d_in, d_out)
        self.n_heads = n_heads
        self.softmax_scale = softmax_scale

    def _get_logits_head_gate_orig(self, x: torch.Tensor, q_scale: torch.Tensor):
        """Reference version: apply each scale factor to the projected weights in turn."""
        weights = self.weights_proj(x)
        weights = weights * self.n_heads**-0.5
        # For a (B, 1) q_scale this gives (B, 1, 1), which broadcasts over the
        # trailing axis added to ``weights`` below.
        q_scale = q_scale.unsqueeze(1)
        weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
        return weights

    def _get_logits_head_gate_opt(self, x: torch.Tensor, q_scale: torch.Tensor):
        """Fused version: combine all scalar factors first, then do one broadcast multiply."""
        weights = self.weights_proj(x)
        q_scale = q_scale.unsqueeze(1)  # (B, 1, 1) for a (B, 1) q_scale
        scale_const = self.n_heads**-0.5 * q_scale * self.softmax_scale
        # (B, d_out, 1) when x is (B, d_in)
        return weights.unsqueeze(-1) * scale_const
25
+
26
+
def main(num_iters: int = 1000):
    """Benchmark the two head-gate formulations of ``DummyModel`` and check they agree.

    Each formulation is called ``num_iters`` times on the same random inputs.
    The wall-clock time of each loop and the max absolute difference between
    the two outputs are printed, and the outputs are asserted to be ``allclose``.

    Args:
        num_iters: number of timed calls per formulation (default 1000, the
            original hard-coded value).

    Raises:
        ValueError: if ``num_iters`` is less than 1 (no output to compare).
    """
    import time

    if num_iters < 1:
        raise ValueError(f"num_iters must be >= 1, got {num_iters}")

    torch.manual_seed(0)
    model = DummyModel(d_in=2048, n_heads=128, softmax_scale=0.5)
    x = torch.randn(128, 2048)  # batch=128, d_in=2048
    q_scale = torch.randn(128, 1)

    # The Linear weights require grad, so without no_grad every call would also
    # record an autograd graph and that bookkeeping would be part of the timing.
    with torch.no_grad():
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        for _ in range(num_iters):
            out_orig = model._get_logits_head_gate_orig(x, q_scale)
        print("Original version time:", time.perf_counter() - start)

        start = time.perf_counter()
        for _ in range(num_iters):
            out_opt = model._get_logits_head_gate_opt(x, q_scale)
        print("Optimized version time:", time.perf_counter() - start)

    print("Difference:", (out_orig - out_opt).abs().max().item())
    assert torch.allclose(out_orig, out_opt), "Mismatch between original and optimized"
47
+
48
+
if __name__ == "__main__":
    main()

# Sample output recorded from one run of this script:
#   Original version time: 0.49235057830810547
#   Optimized version time: 0.4087331295013428
#   Difference: 1.4901161193847656e-08
sglang/test/run_eval.py CHANGED
@@ -10,11 +10,46 @@ import time
10
10
 
11
11
  from sglang.test.simple_eval_common import (
12
12
  ChatCompletionSampler,
13
+ Eval,
13
14
  make_report,
14
15
  set_ulimit,
15
16
  )
16
17
 
17
18
 
def get_thinking_kwargs(args):
    """Build the ``extra_body`` payload that turns on thinking mode, if requested.

    Reads the optional ``thinking_mode`` attribute of ``args``. When it is one of
    ``THINKING_MODE_CHOICES``, returns ``{"chat_template_kwargs": {flag: True}}``,
    where ``flag`` is ``"thinking"`` for ``deepseek-v3`` and ``"enable_thinking"``
    for the other choices. Any other value, including a missing attribute,
    yields ``{}``.
    """
    mode = getattr(args, "thinking_mode", None)
    if mode not in THINKING_MODE_CHOICES:
        return {}
    flag = "thinking" if mode == "deepseek-v3" else "enable_thinking"
    return {"chat_template_kwargs": {flag: True}}
30
+
31
+
def run_eval_once(args, base_url: str, eval_obj: Eval) -> tuple:
    """Run ``eval_obj`` once against the server at ``base_url``.

    A fresh ``ChatCompletionSampler`` is built from ``args``. It uses ``model``
    plus the optional ``max_tokens`` / ``temperature`` / ``reasoning_effort``
    attributes, which default to 2048 / 0.0 / None. The thinking-mode
    ``extra_body`` comes from ``get_thinking_kwargs``.

    Returns:
        A ``(result, latency, sampler)`` tuple:
        - ``result`` is the value returned by ``eval_obj(sampler)``;
        - ``latency`` is the wall-clock seconds that call took;
        - ``sampler`` is the sampler used.
        (The previous ``-> dict`` annotation did not match this return value.)
    """
    # Get thinking kwargs based on user's choice
    thinking_kwargs = get_thinking_kwargs(args)

    sampler = ChatCompletionSampler(
        model=args.model,
        max_tokens=getattr(args, "max_tokens", 2048),
        base_url=base_url,
        temperature=getattr(args, "temperature", 0.0),
        reasoning_effort=getattr(args, "reasoning_effort", None),
        extra_body=thinking_kwargs,
    )

    # Run eval; perf_counter measures elapsed wall-clock time monotonically.
    tic = time.perf_counter()
    result = eval_obj(sampler)
    latency = time.perf_counter() - tic

    return result, latency, sampler
51
+
52
+
18
53
  def run_eval(args):
19
54
  set_ulimit()
20
55
 
@@ -60,21 +95,40 @@ def run_eval(args):
60
95
  from sglang.test.simple_eval_humaneval import HumanEval
61
96
 
62
97
  eval_obj = HumanEval(args.num_examples, args.num_threads)
98
+ elif args.eval_name == "mmmu":
99
+ # VLM MMMU evaluation with fixed 100 examples by default
100
+ from sglang.test.simple_eval_mmmu_vlm import MMMUVLMEval
101
+
102
+ eval_obj = MMMUVLMEval(args.num_examples, args.num_threads)
63
103
  else:
64
104
  raise ValueError(f"Invalid eval name: {args.eval_name}")
65
105
 
66
- sampler = ChatCompletionSampler(
67
- model=args.model,
68
- max_tokens=getattr(args, "max_tokens", 2048),
69
- base_url=base_url,
70
- temperature=getattr(args, "temperature", 0.0),
71
- reasoning_effort=getattr(args, "reasoning_effort", None),
72
- )
106
+ if getattr(args, "repeat", 1) == 1:
107
+ result, latency, sampler = run_eval_once(args, base_url, eval_obj)
108
+ else:
109
+ from concurrent.futures import ThreadPoolExecutor
73
110
 
74
- # Run eval
75
- tic = time.perf_counter()
76
- result = eval_obj(sampler)
77
- latency = time.perf_counter() - tic
111
+ executor = ThreadPoolExecutor(max_workers=args.repeat)
112
+
113
+ futures = [
114
+ executor.submit(run_eval_once, args, base_url, eval_obj)
115
+ for _ in range(args.repeat)
116
+ ]
117
+
118
+ scores_repeat = []
119
+
120
+ for f in futures:
121
+ result, latency, sampler = f.result()
122
+ scores_repeat.append(result.score)
123
+
124
+ mean_score = sum(scores_repeat) / len(scores_repeat)
125
+ scores_repeat = [f"{s:.3f}" for s in scores_repeat]
126
+ print("=" * 20)
127
+ print(f"Repeat: {args.repeat}, mean: {mean_score:.3f}")
128
+ print(f"Scores: {scores_repeat}")
129
+ print("=" * 20)
130
+
131
+ executor.shutdown()
78
132
 
79
133
  # Dump reports
80
134
  metrics = result.metrics | {"score": result.score}
@@ -94,9 +148,13 @@ def run_eval(args):
94
148
  print(f"Total latency: {latency:.3f} s")
95
149
  print(f"Score: {metrics['score']:.3f}")
96
150
 
151
+ if getattr(args, "return_latency", False):
152
+ return metrics, latency
97
153
  return metrics
98
154
 
99
155
 
# Values accepted by --thinking-mode. get_thinking_kwargs() turns any of these
# into a chat_template_kwargs flag; every other value disables thinking mode.
THINKING_MODE_CHOICES = ["deepseek-r1", "deepseek-v3", "qwen3"]
157
+
100
158
  if __name__ == "__main__":
101
159
  parser = argparse.ArgumentParser()
102
160
  parser.add_argument(
@@ -118,12 +176,22 @@ if __name__ == "__main__":
118
176
  type=str,
119
177
  help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
120
178
  )
179
+ parser.add_argument(
180
+ "--repeat", type=int, default=1, help="repeat the evaluation n times"
181
+ )
121
182
  parser.add_argument("--eval-name", type=str, default="mmlu")
122
183
  parser.add_argument("--num-examples", type=int)
123
184
  parser.add_argument("--num-threads", type=int, default=512)
124
185
  parser.add_argument("--max-tokens", type=int, default=2048)
125
186
  parser.add_argument("--temperature", type=float, default=0.0)
126
187
  parser.add_argument("--reasoning-effort", type=str)
188
+ parser.add_argument(
189
+ "--thinking-mode",
190
+ default=None,
191
+ type=str,
192
+ choices=THINKING_MODE_CHOICES,
193
+ help="Enable thinking mode in Deepseek R1, V3.1/3.2, or Qwen3",
194
+ )
127
195
  args = parser.parse_args()
128
196
 
129
197
  run_eval(args)
sglang/test/runners.py CHANGED
@@ -30,8 +30,8 @@ from transformers import (
30
30
  )
31
31
 
32
32
  from sglang.srt.entrypoints.engine import Engine
33
- from sglang.srt.hf_transformers_utils import get_tokenizer
34
33
  from sglang.srt.utils import load_image
34
+ from sglang.srt.utils.hf_transformers_utils import get_tokenizer
35
35
  from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l
36
36
 
37
37
  DEFAULT_PROMPTS = [
@@ -93,6 +93,7 @@ class ChatCompletionSampler(SamplerBase):
93
93
  temperature: float = 0.0,
94
94
  reasoning_effort: Optional[str] = None,
95
95
  max_tokens: int = 2048,
96
+ extra_body: Optional[Dict[str, Any]] = None,
96
97
  ):
97
98
  self.client = OpenAI(base_url=base_url, http_client=LargerHttpxClient())
98
99
 
@@ -104,9 +105,10 @@ class ChatCompletionSampler(SamplerBase):
104
105
  self.temperature = temperature
105
106
  self.max_tokens = max_tokens
106
107
  self.reasoning_effort = reasoning_effort
108
+ self.extra_body = extra_body
107
109
  self.image_format = "url"
108
110
  print(
109
- f"ChatCompletionSampler initialized with {self.system_message=} {self.temperature=} {self.max_tokens=} {self.reasoning_effort=}"
111
+ f"ChatCompletionSampler initialized with {self.system_message=} {self.temperature=} {self.max_tokens=} {self.reasoning_effort=} {self.extra_body=}"
110
112
  )
111
113
 
112
114
  def _handle_image(
@@ -136,7 +138,7 @@ class ChatCompletionSampler(SamplerBase):
136
138
  self._pack_message("system", self.system_message)
137
139
  ] + message_list
138
140
  trial = 0
139
- while True:
141
+ while trial < 6: # 126 seconds in total
140
142
  try:
141
143
  response = self.client.chat.completions.create(
142
144
  model=self.model,
@@ -144,6 +146,7 @@ class ChatCompletionSampler(SamplerBase):
144
146
  temperature=self.temperature,
145
147
  max_tokens=self.max_tokens,
146
148
  reasoning_effort=self.reasoning_effort,
149
+ extra_body=self.extra_body,
147
150
  )
148
151
  return response.choices[0].message.content
149
152
  # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are rerunning MMMU
@@ -0,0 +1,441 @@
1
+ """
2
+ MMMU evaluation for VLMs using the run_eval simple-evals interface.
3
+
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import base64
9
+ import io
10
+ from typing import List, Optional, Tuple
11
+
12
+ from datasets import concatenate_datasets, load_dataset
13
+ from PIL import Image
14
+
15
+ from sglang.test import simple_eval_common as common
16
+ from sglang.test.simple_eval_common import (
17
+ HTML_JINJA,
18
+ Eval,
19
+ EvalResult,
20
+ SamplerBase,
21
+ SingleEvalResult,
22
+ map_with_progress,
23
+ )
24
+
25
+
26
+ class MMMUVLMEval(Eval):
27
+ DOMAIN_CAT2SUB_CAT = {
28
+ "Art and Design": ["Art", "Art_Theory", "Design", "Music"],
29
+ "Business": ["Accounting", "Economics", "Finance", "Manage", "Marketing"],
30
+ "Science": ["Biology", "Chemistry", "Geography", "Math", "Physics"],
31
+ "Health and Medicine": [
32
+ "Basic_Medical_Science",
33
+ "Clinical_Medicine",
34
+ "Diagnostics_and_Laboratory_Medicine",
35
+ "Pharmacy",
36
+ "Public_Health",
37
+ ],
38
+ "Humanities and Social Science": [
39
+ "History",
40
+ "Literature",
41
+ "Sociology",
42
+ "Psychology",
43
+ ],
44
+ "Tech and Engineering": [
45
+ "Agriculture",
46
+ "Architecture_and_Engineering",
47
+ "Computer_Science",
48
+ "Electronics",
49
+ "Energy_and_Power",
50
+ "Materials",
51
+ "Mechanical_Engineering",
52
+ ],
53
+ }
54
+
55
+ def __init__(
56
+ self, num_examples: Optional[int] = 100, num_threads: int = 32, seed: int = 42
57
+ ):
58
+ """Create MMMU VLM eval (Math subset, 100 fixed samples by default)."""
59
+ self.num_examples = num_examples
60
+ self.num_threads = num_threads
61
+ self.seed = seed
62
+ # Prepare samples deterministically across all MMMU subjects (validation split)
63
+ self.samples = self._prepare_mmmu_samples(self.num_examples)
64
+
65
+ @staticmethod
66
+ def _to_data_uri(image: Image.Image) -> str:
67
+ if image.mode == "RGBA":
68
+ image = image.convert("RGB")
69
+ buf = io.BytesIO()
70
+ image.save(buf, format="PNG")
71
+ b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
72
+ return f"data:image/png;base64,{b64}"
73
+
74
+ @staticmethod
75
+ def _build_mc_mapping(options: List[str]) -> Tuple[dict, List[str]]:
76
+ index2ans = {}
77
+ all_choices = []
78
+ ch = ord("A")
79
+ for opt in options:
80
+ letter = chr(ch)
81
+ index2ans[letter] = opt
82
+ all_choices.append(letter)
83
+ ch += 1
84
+ return index2ans, all_choices
85
+
86
+ def _prepare_mmmu_samples(self, k: int) -> List[dict]:
87
+ # Subjects and domains copied from MMMU data_utils to categorize results
88
+ subjects: List[str] = []
89
+ for subs in self.DOMAIN_CAT2SUB_CAT.values():
90
+ subjects.extend(subs)
91
+
92
+ # Load validation split of each subject
93
+ datasets = []
94
+ for subj in subjects:
95
+ try:
96
+ d = load_dataset("MMMU/MMMU", subj, split="validation")
97
+ # attach subject info via transform
98
+ d = d.add_column("__subject__", [subj] * len(d))
99
+ datasets.append(d)
100
+ except Exception:
101
+ continue
102
+ if not datasets:
103
+ raise RuntimeError("Failed to load MMMU datasets")
104
+
105
+ merged = concatenate_datasets(datasets)
106
+
107
+ # Deterministic selection: sort by id (fallback to subject+index)
108
+ def _key(idx):
109
+ ex = merged[idx]
110
+ return str(ex.get("id", f"{ex['__subject__']}:{idx}"))
111
+
112
+ order = sorted(range(len(merged)), key=_key)
113
+ picked_indices = order[:k]
114
+
115
+ samples: List[dict] = []
116
+ for idx in picked_indices:
117
+ ex = merged[idx]
118
+ subject = ex["__subject__"]
119
+ image = ex.get("image_1")
120
+ if image is None or not hasattr(image, "convert"):
121
+ continue
122
+ data_uri = self._to_data_uri(image)
123
+ question = ex.get("question", "")
124
+ answer = ex.get("answer")
125
+ raw_options = ex.get("options")
126
+ question_type = "open"
127
+ index2ans = None
128
+ all_choices = None
129
+ options = None
130
+ if raw_options:
131
+ try:
132
+ options = (
133
+ raw_options
134
+ if isinstance(raw_options, list)
135
+ else list(eval(raw_options))
136
+ )
137
+ if isinstance(options, list) and len(options) > 0:
138
+ index2ans, all_choices = self._build_mc_mapping(options)
139
+ question_type = "multiple-choice"
140
+ except Exception:
141
+ options = None
142
+
143
+ # Build final textual prompt; include choices if MC
144
+ prompt_text = f"Question: {question}\n\n"
145
+ if options:
146
+ letters = [chr(ord("A") + i) for i in range(len(options))]
147
+ for letter, opt in zip(letters, options):
148
+ prompt_text += f"{letter}) {opt}\n"
149
+ prompt_text += "\nAnswer: "
150
+
151
+ samples.append(
152
+ {
153
+ "id": ex.get("id", f"{subject}:{idx}"),
154
+ "final_input_prompt": prompt_text,
155
+ "image_data": data_uri,
156
+ "answer": answer,
157
+ "question_type": question_type,
158
+ "index2ans": index2ans,
159
+ "all_choices": all_choices,
160
+ "category": subject,
161
+ }
162
+ )
163
+
164
+ return samples
165
+
166
+ @staticmethod
167
+ def _split_prompt_for_image(prompt: str) -> tuple[str, str]:
168
+ """Split a prompt containing an inline image tag into prefix and suffix.
169
+
170
+ If no tag is present, treat the whole prompt as prefix and empty suffix.
171
+ """
172
+ if "<" in prompt and ">" in prompt:
173
+ prefix = prompt.split("<")[0]
174
+ suffix = prompt.split(">", 1)[1]
175
+ return prefix, suffix
176
+ return prompt, ""
177
+
178
+ @staticmethod
179
+ def build_chat_messages_from_prompt(prompt: str, image_data) -> List:
180
+ """Split a prompt containing an inline image tag into prefix and suffix.
181
+
182
+ If no tag is present, treat the whole prompt as prefix and empty suffix.
183
+ """
184
+ # Build a vision+text message for OpenAI-compatible API
185
+ prefix, suffix = MMMUVLMEval._split_prompt_for_image(prompt)
186
+
187
+ content: List[dict] = []
188
+ if prefix:
189
+ content.append({"type": "text", "text": prefix})
190
+ content.append({"type": "image_url", "image_url": {"url": image_data}})
191
+ if suffix:
192
+ content.append({"type": "text", "text": suffix})
193
+ prompt_messages = [{"role": "user", "content": content}]
194
+
195
+ return prompt_messages
196
+
197
def __call__(self, sampler: SamplerBase) -> EvalResult:
    """Evaluate ``sampler`` on ``self.samples`` and aggregate MMMU accuracy.

    For each sample, ``final_input_prompt`` and ``image_data`` are turned into
    one user message and sent to ``sampler``; the returned text is graded.
    Samples whose ``question_type`` is ``"multiple-choice"`` and which have
    non-empty ``all_choices``/``index2ans`` are graded with
    ``_parse_multi_choice_response``. All other samples are graded with
    ``_parse_open_response`` and ``_eval_open``.

    Returns:
        EvalResult whose ``score`` is the overall accuracy (correct / total
        over all samples) and whose ``metrics`` map each row name
        (``"Overall-<domain>"``, sub-category names, ``"Overall"``) to
        ``{"num": <count>, "acc": <accuracy rounded to 3 decimals>}``.
    """

    def fn(sample: dict):
        prompt = sample["final_input_prompt"]
        image_data = sample["image_data"]
        prompt_messages = MMMUVLMEval.build_chat_messages_from_prompt(
            prompt, image_data
        )

        # Sample
        response_text = sampler(prompt_messages)

        # Parse and score
        gold = sample["answer"]
        if (
            sample["question_type"] == "multiple-choice"
            and sample["all_choices"]
            and sample["index2ans"]
        ):
            pred = _parse_multi_choice_response(
                response_text, sample["all_choices"], sample["index2ans"]
            )
            score = 1.0 if (gold is not None and pred == gold) else 0.0
            extracted_answer = pred
        else:
            parsed_list = _parse_open_response(response_text)
            score = (
                1.0 if (gold is not None and _eval_open(gold, parsed_list)) else 0.0
            )
            extracted_answer = ", ".join(map(str, parsed_list))

        html_rendered = common.jinja_env.from_string(HTML_JINJA).render(
            prompt_messages=prompt_messages,
            next_message=dict(content=response_text, role="assistant"),
            score=score,
            correct_answer=gold,
            extracted_answer=extracted_answer,
        )

        convo = prompt_messages + [dict(content=response_text, role="assistant")]
        return SingleEvalResult(
            html=html_rendered,
            score=score,
            # Category is carried through metrics so it can be aggregated below.
            metrics={"__category__": sample["category"]},
            convo=convo,
        )

    results = map_with_progress(fn, self.samples, self.num_threads)

    # Tally per-category totals and correct counts from the raw results.
    per_cat_total: dict[str, int] = {}
    per_cat_correct: dict[str, int] = {}
    htmls = []
    convos = []
    for r in results:
        cat = r.metrics.get("__category__") if r.metrics else None
        if cat is None:
            cat = "Unknown"
        per_cat_total[cat] = per_cat_total.get(cat, 0) + 1
        if r.score:
            per_cat_correct[cat] = per_cat_correct.get(cat, 0) + 1
        htmls.append(r.html)
        convos.append(r.convo)

    def _row(correct: int, num: int) -> dict:
        # Accuracy comes from raw counts; rounding is applied only for display,
        # so domain/overall rows do not accumulate per-category rounding error.
        return {"num": num, "acc": round(correct / num, 3) if num > 0 else 0.0}

    printable_results = {}
    # Per domain: an aggregate row, then one row per sub-category that occurred.
    for domain, cats in self.DOMAIN_CAT2SUB_CAT.items():
        present = [cat for cat in cats if cat in per_cat_total]
        num_sum = sum(per_cat_total[cat] for cat in present)
        if num_sum > 0:
            correct_sum = sum(per_cat_correct.get(cat, 0) for cat in present)
            printable_results[f"Overall-{domain}"] = _row(correct_sum, num_sum)
        for cat in present:
            printable_results[cat] = _row(
                per_cat_correct.get(cat, 0), per_cat_total[cat]
            )

    # Overall covers every category, including ones not listed under any domain.
    total_num = sum(per_cat_total.values())
    total_correct = sum(per_cat_correct.values())
    overall_acc = total_correct / total_num if total_num > 0 else 0.0
    printable_results["Overall"] = {"num": total_num, "acc": round(overall_acc, 3)}

    return EvalResult(
        score=overall_acc, metrics=printable_results, htmls=htmls, convos=convos
    )
310
+
311
+
312
+ def _parse_multi_choice_response(
313
+ response: str, all_choices: List[str], index2ans: dict
314
+ ) -> str:
315
+ # loosely adapted from benchmark mmmu eval
316
+ for char in [",", ".", "!", "?", ";", ":", "'"]:
317
+ response = response.strip(char)
318
+ response = " " + response + " "
319
+
320
+ # Prefer explicit letter with bracket e.g. (A)
321
+ candidates: List[str] = []
322
+ for choice in all_choices:
323
+ if f"({choice})" in response:
324
+ candidates.append(choice)
325
+ if not candidates:
326
+ for choice in all_choices:
327
+ if f" {choice} " in response:
328
+ candidates.append(choice)
329
+ if not candidates and len(response.split()) > 5:
330
+ # try match by option text
331
+ for idx, ans in index2ans.items():
332
+ if ans and ans.lower() in response.lower():
333
+ candidates.append(idx)
334
+ if not candidates:
335
+ # fallback to first choice
336
+ return all_choices[0]
337
+ if len(candidates) == 1:
338
+ return candidates[0]
339
+ # choose the last occurrence
340
+ starts = []
341
+ for can in candidates:
342
+ pos = response.rfind(f"({can})")
343
+ if pos == -1:
344
+ pos = response.rfind(f" {can} ")
345
+ if pos == -1 and index2ans.get(can):
346
+ pos = response.lower().rfind(index2ans[can].lower())
347
+ starts.append(pos)
348
+ return candidates[int(max(range(len(starts)), key=lambda i: starts[i]))]
349
+
350
+
351
+ def _check_is_number(s: str) -> bool:
352
+ try:
353
+ float(s.replace(",", ""))
354
+ return True
355
+ except Exception:
356
+ return False
357
+
358
+
359
+ def _normalize_str(s: str):
360
+ s = s.strip()
361
+ if _check_is_number(s):
362
+ s = s.replace(",", "")
363
+ try:
364
+ v = round(float(s), 2)
365
+ return [v]
366
+ except Exception:
367
+ return [s.lower()]
368
+ return [s.lower()] if len(s) > 1 else [" " + s, s + " "]
369
+
370
+
371
+ def _extract_numbers(s: str) -> List[str]:
372
+ import re as _re
373
+
374
+ pattern_commas = r"-?\b\d{1,3}(?:,\d{3})+\b"
375
+ pattern_scientific = r"-?\d+(?:\.\d+)?[eE][+-]?\d+"
376
+ pattern_simple = r"-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])"
377
+ return (
378
+ _re.findall(pattern_commas, s)
379
+ + _re.findall(pattern_scientific, s)
380
+ + _re.findall(pattern_simple, s)
381
+ )
382
+
383
+
384
def _parse_open_response(response: str) -> list:
    """Extract candidate answers from a free-form (open-ended) model response.

    The response is split into sentences (a ". " followed by a capital letter)
    and lines. From each piece, the shortest tail that follows one of the
    answer indicators ("could be ", "so ", "is ", ...; plus "=" for the last
    piece) is kept as a key response. Numbers found in the key responses are
    appended, every candidate is passed through ``_normalize_str``, and
    duplicates are removed while keeping first-seen order.

    Returns:
        Normalized candidates (lowercased strings and/or floats) for
        ``_eval_open``.
    """
    import re as _re

    indicators = (
        "could be ",
        "so ",
        "is ",
        "thus ",
        "therefore ",
        "final ",
        "answer ",
        "result ",
    )
    # Key responses consisting only of a punctuation mark are discarded.
    trivial = {":", ",", ".", "!", "?", ";", "'"}

    def get_key_subresponses(resp: str) -> List[str]:
        resp = resp.strip().strip(".")
        # Split before lowercasing: the "(?=[A-Z])" lookahead needs the
        # capital letter that starts the next sentence to find the boundary.
        subs = [sub.lower() for sub in _re.split(r"\.\s(?=[A-Z])|\n", resp)]
        keys = []
        for i, s in enumerate(subs):
            cands = list(indicators)
            # Only the last piece may be a bare equation such as "x = 3".
            if i == len(subs) - 1:
                cands.append("=")
            shortest = None
            for ind in cands:
                if ind in s:
                    part = s.split(ind)[-1].strip()
                    if not shortest or len(part) < len(shortest):
                        shortest = part
            if shortest and shortest not in trivial:
                keys.append(shortest)
        # Fall back to the whole (lowercased) response when nothing matched.
        return keys or [resp.lower()]

    key_resps = get_key_subresponses(response)
    pred_list = key_resps.copy()
    for r in key_resps:
        pred_list.extend(_extract_numbers(r))
    out = []
    for x in pred_list:
        out.extend(_normalize_str(x))
    # Order-preserving dedup.
    return list(dict.fromkeys(out))
424
+
425
+
426
def _eval_open(gold, preds: List[str]) -> bool:
    """Return True if any parsed prediction matches the gold answer(s).

    ``gold`` is either a single answer or a list of acceptable answers; each
    is expanded with ``_normalize_str``. A string prediction matches when any
    normalized *string* gold is a substring of it. A non-string prediction
    (e.g. a float from ``_normalize_str``) matches when it equals one of the
    normalized golds.
    """
    gold_items = gold if isinstance(gold, list) else [gold]
    norm_answers = [norm for item in gold_items for norm in _normalize_str(item)]

    def _matches(pred) -> bool:
        if isinstance(pred, str):
            return any(isinstance(na, str) and na in pred for na in norm_answers)
        return pred in norm_answers

    return any(_matches(p) for p in preds)
@@ -621,11 +621,11 @@ class TestW8A8BlockFP8BatchedDeepGemm(CustomTestCase):
621
621
  w_s,
622
622
  )
623
623
 
624
- from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
624
+ from deep_gemm import fp8_m_grouped_gemm_nt_masked
625
625
 
626
626
  with torch.inference_mode():
627
627
  ref_out = torch_w8a8_block_fp8_bmm(a, a_s, w, w_s, block_size, dtype)
628
- m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs, rhs, oe, masked_m, expected_m)
628
+ fp8_m_grouped_gemm_nt_masked(lhs, rhs, oe, masked_m, expected_m)
629
629
  out = oe[:, :M, :]
630
630
 
631
631
  self.assertTrue(