sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,441 @@
1
+ """
2
+ MMMU evaluation for VLMs using the run_eval simple-evals interface.
3
+
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import base64
9
+ import io
10
+ from typing import List, Optional, Tuple
11
+
12
+ from datasets import concatenate_datasets, load_dataset
13
+ from PIL import Image
14
+
15
+ from sglang.test import simple_eval_common as common
16
+ from sglang.test.simple_eval_common import (
17
+ HTML_JINJA,
18
+ Eval,
19
+ EvalResult,
20
+ SamplerBase,
21
+ SingleEvalResult,
22
+ map_with_progress,
23
+ )
24
+
25
+
26
+ class MMMUVLMEval(Eval):
27
# Mapping of MMMU domain -> list of MMMU subjects, copied from MMMU's
# data_utils. Each subject string is passed as the dataset config name to
# load_dataset("MMMU/MMMU", subj, ...) in _prepare_mmmu_samples, and the
# subject name becomes each sample's "category".
DOMAIN_CAT2SUB_CAT = {
    "Art and Design": ["Art", "Art_Theory", "Design", "Music"],
    "Business": ["Accounting", "Economics", "Finance", "Manage", "Marketing"],
    "Science": ["Biology", "Chemistry", "Geography", "Math", "Physics"],
    "Health and Medicine": [
        "Basic_Medical_Science",
        "Clinical_Medicine",
        "Diagnostics_and_Laboratory_Medicine",
        "Pharmacy",
        "Public_Health",
    ],
    "Humanities and Social Science": [
        "History",
        "Literature",
        "Sociology",
        "Psychology",
    ],
    "Tech and Engineering": [
        "Agriculture",
        "Architecture_and_Engineering",
        "Computer_Science",
        "Electronics",
        "Energy_and_Power",
        "Materials",
        "Mechanical_Engineering",
    ],
}
54
+
55
def __init__(
    self, num_examples: Optional[int] = 100, num_threads: int = 32, seed: int = 42
):
    """Create the MMMU VLM eval over all MMMU subjects (validation split).

    Samples are prepared eagerly here: the validation split of every
    subject in ``DOMAIN_CAT2SUB_CAT`` is loaded via ``load_dataset`` and a
    deterministic selection is stored in ``self.samples``.

    Args:
        num_examples: Number of samples to prepare; ``None`` keeps every
            usable sample. Defaults to 100.
        num_threads: Stored as ``self.num_threads``.
        seed: Stored as ``self.seed``. Not used by the sample selection,
            which is deterministic (ordered by sample id).
    """
    self.num_examples = num_examples
    self.num_threads = num_threads
    self.seed = seed
    # Prepare samples deterministically across all MMMU subjects (validation split)
    self.samples = self._prepare_mmmu_samples(self.num_examples)
64
+
65
+ @staticmethod
66
def _to_data_uri(image: Image.Image) -> str:
    """Encode a PIL image as a base64 PNG ``data:`` URI.

    RGBA images are flattened to RGB. Images in modes the PNG encoder
    cannot write (e.g. ``CMYK``) are also converted to RGB so that saving
    does not fail.

    Args:
        image: The PIL image to encode.

    Returns:
        A string of the form ``"data:image/png;base64,<payload>"``.
    """
    # Modes Pillow's PNG writer accepts directly; anything else must be
    # converted first or ``save(format="PNG")`` raises.
    png_writable_modes = {"1", "L", "LA", "I", "I;16", "I;16B", "P", "RGB", "RGBA"}
    if image.mode == "RGBA" or image.mode not in png_writable_modes:
        image = image.convert("RGB")
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{b64}"
73
+
74
+ @staticmethod
75
+ def _build_mc_mapping(options: List[str]) -> Tuple[dict, List[str]]:
76
+ index2ans = {}
77
+ all_choices = []
78
+ ch = ord("A")
79
+ for opt in options:
80
+ letter = chr(ch)
81
+ index2ans[letter] = opt
82
+ all_choices.append(letter)
83
+ ch += 1
84
+ return index2ans, all_choices
85
+
86
def _prepare_mmmu_samples(self, k: Optional[int]) -> List[dict]:
    """Load the MMMU validation split and build up to ``k`` eval samples.

    Every subject in ``DOMAIN_CAT2SUB_CAT`` is loaded from ``MMMU/MMMU``.
    Subjects that fail to load are skipped and a warning is logged. Rows
    are ordered deterministically by ``id`` (falling back to
    ``"<subject>:<row index>"`` when there is no ``id`` column). Rows
    without a usable ``image_1`` are skipped, and the first ``k`` usable
    rows become samples.

    Args:
        k: Maximum number of samples to return, or ``None`` for all.

    Returns:
        A list of dicts with keys ``id``, ``final_input_prompt``,
        ``image_data`` (PNG data URI of ``image_1``), ``answer``,
        ``question_type`` (``"multiple-choice"`` or ``"open"``),
        ``index2ans``, ``all_choices`` and ``category`` (subject name).

    Raises:
        RuntimeError: If no subject could be loaded.
    """
    import ast
    import logging

    logger = logging.getLogger(__name__)

    # Subjects and domains copied from MMMU data_utils to categorize results
    subjects: List[str] = [
        subj for subs in self.DOMAIN_CAT2SUB_CAT.values() for subj in subs
    ]

    # Load validation split of each subject (best effort).
    datasets = []
    for subj in subjects:
        try:
            d = load_dataset("MMMU/MMMU", subj, split="validation")
            # Tag every row with its subject so results can be categorized.
            d = d.add_column("__subject__", [subj] * len(d))
        except Exception as exc:
            # Keep going with the remaining subjects, but don't hide the failure.
            logger.warning("Skipping MMMU subject %s: %s", subj, exc)
            continue
        datasets.append(d)
    if not datasets:
        raise RuntimeError("Failed to load MMMU datasets")

    merged = concatenate_datasets(datasets)

    # Deterministic selection: sort by id (fallback to subject+index).
    # Sort keys come from whole columns read once; indexing merged[idx]
    # here would materialize every full row (including images) just to sort.
    if "id" in merged.column_names:
        sort_keys = [str(value) for value in merged["id"]]
    else:
        sort_keys = [f"{s}:{i}" for i, s in enumerate(merged["__subject__"])]
    order = sorted(range(len(merged)), key=sort_keys.__getitem__)

    samples: List[dict] = []
    for idx in order:
        # Collect k *usable* samples; skipped rows must not shrink the result.
        if k is not None and len(samples) >= k:
            break
        ex = merged[idx]
        subject = ex["__subject__"]
        image = ex.get("image_1")
        if image is None or not hasattr(image, "convert"):
            continue
        data_uri = self._to_data_uri(image)
        question = ex.get("question", "")
        answer = ex.get("answer")

        question_type = "open"
        index2ans = None
        all_choices = None
        options = None
        raw_options = ex.get("options")
        if raw_options:
            parsed = raw_options
            if not isinstance(raw_options, list):
                try:
                    # Options are stored as a Python list literal string.
                    # literal_eval parses it without executing dataset content
                    # (the previous eval() would run arbitrary code).
                    parsed = ast.literal_eval(raw_options)
                except (ValueError, SyntaxError, TypeError, RecursionError):
                    parsed = None
            if isinstance(parsed, (list, tuple)) and len(parsed) > 0:
                options = list(parsed)
                index2ans, all_choices = self._build_mc_mapping(options)
                question_type = "multiple-choice"

        # Build final textual prompt; include lettered choices if MC
        prompt_text = f"Question: {question}\n\n"
        if options:
            prompt_text += "".join(
                f"{letter}) {opt}\n" for letter, opt in zip(all_choices, options)
            )
        prompt_text += "\nAnswer: "

        samples.append(
            {
                "id": ex.get("id", f"{subject}:{idx}"),
                "final_input_prompt": prompt_text,
                "image_data": data_uri,
                "answer": answer,
                "question_type": question_type,
                "index2ans": index2ans,
                "all_choices": all_choices,
                "category": subject,
            }
        )

    return samples
165
+
166
@staticmethod
def _split_prompt_for_image(prompt: str) -> tuple[str, str]:
    """Split a prompt containing an inline image tag into prefix and suffix.

    How the tag is found:
    - An MMMU-style ``<image ...>`` tag (e.g. ``<image 1>``) is preferred,
      so comparison operators elsewhere in the text (``x < 3``) are not
      taken for the tag.
    - Otherwise, the first ``<`` that has a ``>`` somewhere after it counts
      as the tag.

    The tag itself is dropped. If no tag is present, treat the whole prompt
    as prefix and empty suffix.

    Returns:
        ``(prefix, suffix)``: the text before and after the tag.
    """
    import re

    match = re.search(r"<image[^<>]*>", prompt)
    if match is not None:
        return prompt[: match.start()], prompt[match.end() :]

    # Generic fallback: only a '>' that comes after the '<' closes the tag;
    # a stray '>' earlier in the text must not produce overlapping slices.
    start = prompt.find("<")
    if start != -1:
        end = prompt.find(">", start + 1)
        if end != -1:
            return prompt[:start], prompt[end + 1 :]
    return prompt, ""
177
+
178
@staticmethod
def build_chat_messages_from_prompt(prompt: str, image_data) -> List:
    """Build a single-turn OpenAI-style vision chat message from a prompt.

    The prompt is split around its inline image tag with
    ``MMMUVLMEval._split_prompt_for_image``. The message content then holds,
    in order:
    - the text before the tag (omitted if empty);
    - the image;
    - the text after the tag (omitted if empty).

    Args:
        prompt: Text prompt, optionally containing an inline image tag.
        image_data: Value placed in ``image_url.url``. In this evaluator it
            is the ``data:image/png;base64,...`` URI built by
            ``_to_data_uri``.

    Returns:
        A one-element message list ``[{"role": "user", "content": [...]}]``.
    """
    # Build a vision+text message for OpenAI-compatible API
    prefix, suffix = MMMUVLMEval._split_prompt_for_image(prompt)

    content: List[dict] = []
    if prefix:
        content.append({"type": "text", "text": prefix})
    content.append({"type": "image_url", "image_url": {"url": image_data}})
    if suffix:
        content.append({"type": "text", "text": suffix})
    prompt_messages = [{"role": "user", "content": content}]

    return prompt_messages
196
+
197
def __call__(self, sampler: SamplerBase) -> EvalResult:
    """Query ``sampler`` on every prepared sample and aggregate accuracy.

    How each sample is scored:
    - Multiple-choice: the letter parsed from the response must equal
      ``answer``.
    - Open: ``_eval_open`` is applied to the candidates extracted by
      ``_parse_open_response``.

    Scores are aggregated per subject, per domain of
    ``self.DOMAIN_CAT2SUB_CAT``, and overall. Every accuracy is computed
    from raw correct/total counts.

    Args:
        sampler: Called with the chat message list; its return value is
            used as the response text.

    Returns:
        An ``EvalResult`` with these fields:
        - ``score``: the unrounded overall accuracy.
        - ``metrics``: maps ``"Overall-<domain>"``, each subject, and
          ``"Overall"`` to ``{"num": ..., "acc": ...}``, with ``acc``
          rounded to 3 decimals.
    """

    def fn(sample: dict):
        # Score one sample: query the sampler, then parse and grade its reply.
        prompt = sample["final_input_prompt"]
        image_data = sample["image_data"]
        prompt_messages = MMMUVLMEval.build_chat_messages_from_prompt(
            prompt, image_data
        )

        # Sample
        response_text = sampler(prompt_messages)

        # Parse and score
        gold = sample["answer"]
        if (
            sample["question_type"] == "multiple-choice"
            and sample["all_choices"]
            and sample["index2ans"]
        ):
            pred = _parse_multi_choice_response(
                response_text, sample["all_choices"], sample["index2ans"]
            )
            score = 1.0 if (gold is not None and pred == gold) else 0.0
            extracted_answer = pred
        else:
            parsed_list = _parse_open_response(response_text)
            score = (
                1.0 if (gold is not None and _eval_open(gold, parsed_list)) else 0.0
            )
            extracted_answer = ", ".join(map(str, parsed_list))

        html_rendered = common.jinja_env.from_string(HTML_JINJA).render(
            prompt_messages=prompt_messages,
            next_message=dict(content=response_text, role="assistant"),
            score=score,
            correct_answer=gold,
            extracted_answer=extracted_answer,
        )

        convo = prompt_messages + [dict(content=response_text, role="assistant")]
        return SingleEvalResult(
            html=html_rendered,
            score=score,
            metrics={"__category__": sample["category"]},
            convo=convo,
        )

    results = map_with_progress(fn, self.samples, self.num_threads)

    # Gather per-sample correctness and category (stored under metrics).
    per_cat_total: dict[str, int] = {}
    per_cat_correct: dict[str, int] = {}
    htmls = []
    convos = []
    for r in results:
        cat = r.metrics.get("__category__") if r.metrics else None
        if cat is None:
            cat = "Unknown"
        per_cat_total[cat] = per_cat_total.get(cat, 0) + 1
        if r.score:
            per_cat_correct[cat] = per_cat_correct.get(cat, 0) + 1
        htmls.append(r.html)
        convos.append(r.convo)

    # Accuracies are computed from raw counts. Multiplying already-rounded
    # per-subject accuracies back by their counts (as before) accumulated
    # rounding error in the domain rows and in the returned overall score.
    printable_results = {}
    # Domains first, each followed by its sub-category rows.
    for domain, cats in self.DOMAIN_CAT2SUB_CAT.items():
        present = [cat for cat in cats if cat in per_cat_total]
        num_sum = sum(per_cat_total[cat] for cat in present)
        if num_sum > 0:
            corr_sum = sum(per_cat_correct.get(cat, 0) for cat in present)
            printable_results[f"Overall-{domain}"] = {
                "num": num_sum,
                "acc": round(corr_sum / num_sum, 3),
            }
        # add each sub-category row if present (total is always >= 1 here)
        for cat in present:
            total = per_cat_total[cat]
            printable_results[cat] = {
                "num": total,
                "acc": round(per_cat_correct.get(cat, 0) / total, 3),
            }

    # Overall (includes categories outside DOMAIN_CAT2SUB_CAT, e.g. "Unknown")
    total_num = sum(per_cat_total.values())
    total_correct = sum(per_cat_correct.values())
    overall_acc = total_correct / total_num if total_num > 0 else 0.0
    printable_results["Overall"] = {"num": total_num, "acc": round(overall_acc, 3)}

    # Build EvalResult
    return EvalResult(
        score=overall_acc, metrics=printable_results, htmls=htmls, convos=convos
    )
310
+
311
+
312
def _parse_multi_choice_response(
    response: str, all_choices: List[str], index2ans: dict
) -> str:
    """Pick the option letter a free-form response most likely selects.

    Matching runs in stages, and each stage is used only if the previous
    one found nothing:
    1. bracketed letters such as ``(B)``;
    2. space-delimited letters such as `` B ``;
    3. option texts (case-insensitive), but only when the response has more
       than five words.

    Several candidates from one stage are resolved by the latest position in
    the response. If nothing matches, the first choice is returned.
    """
    # loosely adapted from benchmark mmmu eval
    # Each punctuation mark is stripped from both ends in turn (order matters).
    for mark in ",.!?;:'":
        response = response.strip(mark)
    padded = f" {response} "

    # Prefer explicit letter with bracket e.g. (A)
    candidates = [c for c in all_choices if f"({c})" in padded]
    if not candidates:
        candidates = [c for c in all_choices if f" {c} " in padded]
    if not candidates and len(padded.split()) > 5:
        # try match by option text
        lowered = padded.lower()
        candidates = [
            key for key, text in index2ans.items() if text and text.lower() in lowered
        ]

    if not candidates:
        # fallback to first choice
        return all_choices[0]
    if len(candidates) == 1:
        return candidates[0]

    def last_position(choice: str) -> int:
        # Same lookup order as the matching stages above.
        where = padded.rfind(f"({choice})")
        if where == -1:
            where = padded.rfind(f" {choice} ")
        if where == -1 and index2ans.get(choice):
            where = padded.lower().rfind(index2ans[choice].lower())
        return where

    positions = [last_position(c) for c in candidates]
    # max() keeps the first index on ties.
    best = max(range(len(positions)), key=positions.__getitem__)
    return candidates[best]
349
+
350
+
351
def _check_is_number(s: str) -> bool:
    """Return True if ``float`` accepts *s* once all commas are removed.

    Anything ``float`` parses counts, including exponents (``"1e3"``) and
    the special values ``"inf"`` and ``"nan"``. Any exception raised along
    the way, including for non-string input, yields False.
    """
    try:
        float(s.replace(",", ""))
    except Exception:
        return False
    else:
        return True
357
+
358
+
359
def _normalize_str(s: str):
    """Normalize one answer string into a list of comparison candidates.

    The input is stripped, then:
    - Empty strings yield ``[]``. Padding them would produce ``" "``, which
      matches almost any text in a substring comparison.
    - Numeric strings (commas allowed) become ``[round(float, 2)]``.
    - Other strings longer than one character become ``[s.lower()]``.
    - A single character is kept as-is (not lowercased) and returned as two
      space-padded variants, ``[" " + s, s + " "]``. Presumably this reduces
      accidental substring matches.
    """
    s = s.strip()
    if not s:
        return []
    if _check_is_number(s):
        # _check_is_number already proved float() accepts the comma-free
        # text, so this conversion cannot fail.
        return [round(float(s.replace(",", "")), 2)]
    if len(s) > 1:
        return [s.lower()]
    return [" " + s, s + " "]
369
+
370
+
371
def _extract_numbers(s: str) -> List[str]:
    """Return every number-like substring of *s*, as strings.

    Three regexes are applied in turn, and their matches are concatenated
    in this order (not by position in *s*):
    1. comma-grouped integers (``"1,234"``);
    2. scientific notation (``"2.5e3"``);
    3. plain integers/decimals not followed by an exponent, comma or digit.
    """
    import re as _re

    patterns = (
        r"-?\b\d{1,3}(?:,\d{3})+\b",  # thousands-separated integers
        r"-?\d+(?:\.\d+)?[eE][+-]?\d+",  # scientific notation
        r"-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])",  # plain numbers
    )
    return [found for pattern in patterns for found in _re.findall(pattern, s)]
382
+
383
+
384
def _parse_open_response(response: str) -> List[str]:
    """Extract normalized candidate answers from a free-form response.

    Steps:
    1. Lowercase the response, strip it, and split it into pieces.
    2. In each piece, take the shortest tail that follows an indicator
       phrase (``"so "``, ``"is "``, ``"answer "``, ...; the last piece also
       considers ``"="``). Pieces with no usable tail are skipped; if none
       remain, the whole cleaned response is used.
    3. Add every number found in those tails.
    4. Run all of it through ``_normalize_str`` and de-duplicate, keeping
       first-seen order.

    The result may mix strings and rounded floats (from ``_normalize_str``).
    """
    import re as _re

    indicators = (
        "could be ",
        "so ",
        "is ",
        "thus ",
        "therefore ",
        "final ",
        "answer ",
        "result ",
    )
    ignorable = {":", ",", ".", "!", "?", ";", "'"}

    def key_subresponses(text: str) -> List[str]:
        text = text.strip().strip(".").lower()
        # NOTE(review): the text is already lowercased here, so the
        # "(?=[A-Z])" look-ahead can never match and pieces are split on
        # newlines only. Kept as-is to preserve existing scores.
        pieces = _re.split(r"\.\s(?=[A-Z])|\n", text)
        last = len(pieces) - 1
        found = []
        for pos, piece in enumerate(pieces):
            markers = indicators + ("=",) if pos == last else indicators
            best = None
            for marker in markers:
                if marker not in piece:
                    continue
                tail = piece.split(marker)[-1].strip()
                # An empty `best` is replaced by any later tail, even a longer one.
                if not best or len(tail) < len(best):
                    best = tail
            if best and best not in ignorable:
                found.append(best)
        return found or [text]

    keys = key_subresponses(response)
    candidates = list(keys)
    for key in keys:
        candidates += _extract_numbers(key)
    normalized = [norm for cand in candidates for norm in _normalize_str(cand)]
    # dedup while keeping first-seen order
    return list(dict.fromkeys(normalized))
424
+
425
+
426
def _eval_open(gold, preds: List[str]) -> bool:
    """Return True if any parsed prediction matches the gold answer.

    ``gold`` is a single answer or a list of acceptable answers, and each
    one is expanded with ``_normalize_str``. Matching depends on the type of
    the prediction:
    - String: matches when a normalized *string* answer occurs in it as a
      substring.
    - Anything else (in this file, a rounded float from ``_normalize_str``):
      matches by equality with a normalized answer.
    """
    golds = gold if isinstance(gold, list) else [gold]
    norm_answers = [norm for answer in golds for norm in _normalize_str(answer)]

    def matches(pred) -> bool:
        if isinstance(pred, str):
            return any(isinstance(na, str) and na in pred for na in norm_answers)
        return pred in norm_answers

    return any(matches(p) for p in preds)
@@ -1,5 +1,4 @@
1
1
  import itertools
2
- import os
3
2
  import unittest
4
3
 
5
4
  import torch
@@ -577,7 +576,7 @@ class TestW8A8BlockFP8BatchedDeepGemm(CustomTestCase):
577
576
  if not torch.cuda.is_available():
578
577
  raise unittest.SkipTest("CUDA is not available")
579
578
  try:
580
- import deep_gemm
579
+ import deep_gemm # noqa: F401
581
580
  except ImportError:
582
581
  raise unittest.SkipTest("DeepGEMM is not available")
583
582
  torch.set_default_device("cuda")
@@ -621,11 +620,11 @@ class TestW8A8BlockFP8BatchedDeepGemm(CustomTestCase):
621
620
  w_s,
622
621
  )
623
622
 
624
- from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
623
+ from deep_gemm import fp8_m_grouped_gemm_nt_masked
625
624
 
626
625
  with torch.inference_mode():
627
626
  ref_out = torch_w8a8_block_fp8_bmm(a, a_s, w, w_s, block_size, dtype)
628
- m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs, rhs, oe, masked_m, expected_m)
627
+ fp8_m_grouped_gemm_nt_masked(lhs, rhs, oe, masked_m, expected_m)
629
628
  out = oe[:, :M, :]
630
629
 
631
630
  self.assertTrue(
@@ -1,5 +1,4 @@
1
1
  import itertools
2
- import os
3
2
  import unittest
4
3
  from typing import List, Tuple
5
4
 
@@ -1,5 +1,4 @@
1
1
  import argparse
2
- import time
3
2
 
4
3
  import torch
5
4
  import triton # Added import
@@ -34,7 +33,7 @@ def get_model_config(tp_size: int):
34
33
  "topk": topk,
35
34
  "hidden_size": config.hidden_size,
36
35
  "shard_intermediate_size": shard_intermediate_size,
37
- "dtype": config.torch_dtype,
36
+ "dtype": config.dtype,
38
37
  "block_shape": config.quantization_config["weight_block_size"],
39
38
  }
40
39
 
@@ -1,6 +1,6 @@
1
1
  # SPDX-License-Identifier: Apache-2.0
2
2
 
3
- from typing import Literal, Optional
3
+ from typing import Optional
4
4
 
5
5
  import pytest
6
6
  import torch
@@ -120,7 +120,7 @@ def test_cutlass_w4a8_moe(M, N, K, E, tp_size, use_ep_moe, topk, group_size, dty
120
120
  )
121
121
  topk_weights, topk_ids, _ = topk_output
122
122
  expert_map = torch.arange(E, dtype=torch.int32, device=device)
123
- expert_map[local_e:] = E
123
+ expert_map[local_e:] = -1
124
124
 
125
125
  output = cutlass_moe(
126
126
  a,
@@ -138,9 +138,7 @@ def test_cutlass_w4a8_moe(M, N, K, E, tp_size, use_ep_moe, topk, group_size, dty
138
138
  c_strides2,
139
139
  s_strides13,
140
140
  s_strides2,
141
- 0,
142
- local_e - 1,
143
- E,
141
+ local_e,
144
142
  a1_scale,
145
143
  a2_scale,
146
144
  expert_map,
@@ -178,7 +176,7 @@ def cutlass_moe(
178
176
  w1_scale: torch.Tensor,
179
177
  w2_scale: torch.Tensor,
180
178
  topk_weights: torch.Tensor,
181
- topk_ids_: torch.Tensor,
179
+ topk_ids: torch.Tensor,
182
180
  a_strides1: torch.Tensor,
183
181
  b_strides1: torch.Tensor,
184
182
  c_strides1: torch.Tensor,
@@ -187,40 +185,32 @@ def cutlass_moe(
187
185
  c_strides2: torch.Tensor,
188
186
  s_strides13: torch.Tensor,
189
187
  s_strides2: torch.Tensor,
190
- start_expert_id: int,
191
- end_expert_id: int,
192
- E: int,
188
+ num_local_experts: int,
193
189
  a1_scale: Optional[torch.Tensor] = None,
194
190
  a2_scale: Optional[torch.Tensor] = None,
195
191
  expert_map: Optional[torch.Tensor] = None,
196
192
  apply_router_weight_on_input: bool = False,
197
193
  ):
198
- local_topk_ids = topk_ids_
199
- local_topk_ids = torch.where(expert_map[topk_ids_] != E, expert_map[topk_ids_], E)
194
+ topk_ids = expert_map[topk_ids]
200
195
  device = a.device
201
196
 
202
- local_num_experts = end_expert_id - start_expert_id + 1
203
197
  expert_offsets = torch.empty(
204
- (local_num_experts + 1), dtype=torch.int32, device=device
198
+ (num_local_experts + 1), dtype=torch.int32, device=device
205
199
  )
206
200
  problem_sizes1 = torch.empty(
207
- (local_num_experts, 3), dtype=torch.int32, device=device
201
+ (num_local_experts, 3), dtype=torch.int32, device=device
208
202
  )
209
203
  problem_sizes2 = torch.empty(
210
- (local_num_experts, 3), dtype=torch.int32, device=device
204
+ (num_local_experts, 3), dtype=torch.int32, device=device
211
205
  )
212
206
  return cutlass_w4a8_moe(
213
- start_expert_id,
214
- end_expert_id,
215
- E,
216
207
  a,
217
208
  w1_q,
218
209
  w2_q,
219
210
  w1_scale,
220
211
  w2_scale,
221
212
  topk_weights,
222
- topk_ids_,
223
- local_topk_ids,
213
+ topk_ids,
224
214
  a_strides1,
225
215
  b_strides1,
226
216
  c_strides1,