sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -6,7 +6,7 @@ from torch.distributed.device_mesh import DeviceMesh
6
6
  from torch.distributed.tensor import DTensor
7
7
 
8
8
  from sglang.srt.entrypoints.engine import Engine
9
- from sglang.srt.managers.tokenizer_manager import UpdateWeightsFromTensorReqInput
9
+ from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput
10
10
  from sglang.srt.model_executor.model_runner import LocalSerializedTensor
11
11
  from sglang.srt.utils import MultiprocessingSerializer
12
12
 
@@ -33,7 +33,7 @@ async def update_weights(
33
33
  """
34
34
  infer_tp_size = device_mesh[device_mesh_key].mesh.size()[0]
35
35
  infer_tp_rank = device_mesh[device_mesh_key].get_local_rank()
36
- from sglang.srt.patch_torch import monkey_patch_torch_reductions
36
+ from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
37
37
 
38
38
  monkey_patch_torch_reductions()
39
39
 
@@ -41,6 +41,10 @@ DEFAULT_CONFIG = {
41
41
  "v_head_dim": 512,
42
42
  "num_kv_heads": 1,
43
43
  "layer_id": 0,
44
+ "tp_q_head_num": 128,
45
+ "tp_k_head_num": 128,
46
+ "prefill_head_dim": 192,
47
+ "prefill_v_head_dim": 128,
44
48
  }
45
49
 
46
50
  ROPE_BASE = 10000
@@ -92,7 +96,7 @@ TEST_CASES = {
92
96
  "description": "Medium-scale batch",
93
97
  },
94
98
  ],
95
- "decode_output_match": [
99
+ "output_match": [
96
100
  {
97
101
  "name": "single_fp16",
98
102
  "batch_size": 1,
@@ -322,7 +326,7 @@ class TestTRTLLMMLA(CustomTestCase):
322
326
  config.update(test_case)
323
327
  return config
324
328
 
325
- def _create_model_components(self, config):
329
+ def _create_model_components(self, config, is_prefill=False):
326
330
  """Create model runners, backends, and layer for testing."""
327
331
  # Create model runners
328
332
  model_runner_trtllm = MockModelRunner(config)
@@ -332,14 +336,23 @@ class TestTRTLLMMLA(CustomTestCase):
332
336
  trtllm_backend = TRTLLMMLABackend(model_runner_trtllm)
333
337
  reference_backend = FlashInferMLAAttnBackend(model_runner_reference)
334
338
 
339
+ head_dim = (
340
+ config["kv_lora_rank"] + config["qk_rope_head_dim"]
341
+ if not is_prefill
342
+ else config["prefill_head_dim"]
343
+ )
344
+ v_head_dim = (
345
+ config["v_head_dim"] if not is_prefill else config["prefill_v_head_dim"]
346
+ )
347
+
335
348
  # Create RadixAttention layer
336
349
  layer = RadixAttention(
337
350
  num_heads=config["num_attention_heads"],
338
- head_dim=config["kv_lora_rank"] + config["qk_rope_head_dim"],
351
+ head_dim=head_dim,
339
352
  scaling=model_runner_trtllm.model_config.scaling,
340
353
  num_kv_heads=config["num_kv_heads"],
341
354
  layer_id=config["layer_id"],
342
- v_head_dim=config["v_head_dim"],
355
+ v_head_dim=v_head_dim,
343
356
  prefix="attn_mqa",
344
357
  )
345
358
 
@@ -524,7 +537,7 @@ class TestTRTLLMMLA(CustomTestCase):
524
537
  """Test that TRTLLM and FlashInfer MLA backends produce matching outputs."""
525
538
  print(f"\nRunning decode output matching tests...")
526
539
 
527
- for test_case in TEST_CASES["decode_output_match"]:
540
+ for test_case in TEST_CASES["output_match"]:
528
541
  with self.subTest(test_case=test_case["name"]):
529
542
  print(f" Testing {test_case['name']}: {test_case['description']}")
530
543
 
@@ -1099,6 +1112,157 @@ class TestTRTLLMMLA(CustomTestCase):
1099
1112
  self.assertIsNotNone(metadata_3.block_kv_indices)
1100
1113
  self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"])
1101
1114
 
1115
def test_prefill_output_match_self_attention(self):
    """Test prefill (forward_extend) output of the TRTLLM MLA backend vs. the reference.

    For the first two ``output_match`` cases (a subset, for speed), build an
    EXTEND-mode ForwardBatch with no cached prefix, so every request prefills
    ``max_seq_len`` tokens. Then run ``forward_extend`` on the TRTLLM and the
    FlashInfer MLA backends with identical random Q/K/V and assert that the
    outputs agree within tolerance.
    """
    print("\nRunning prefill output tests...")

    def _create_forward_batch_prefill(
        batch_size,
        seq_lens,
        extend_prefix_lens,
        backend,
        model_runner,
        config,
    ):
        """Create an EXTEND-mode forward batch for the given backend.

        Per-token tensors (input_ids, out_cache_loc, positions) are sized by
        the number of extend tokens, sum(seq_lens - extend_prefix_lens). That
        is also the number of rows in the Q/K/V tensors passed to
        forward_extend, so the batch is self-consistent.
        """
        device = config["device"]
        prefix_lens_cpu = extend_prefix_lens.cpu().int().tolist()
        extend_seq_lens_cpu = (seq_lens - extend_prefix_lens).cpu().int().tolist()
        num_extend_tokens = sum(extend_seq_lens_cpu)

        fb = ForwardBatch(
            batch_size=batch_size,
            input_ids=torch.randint(0, 100, (num_extend_tokens,), device=device),
            out_cache_loc=torch.arange(num_extend_tokens, device=device),
            seq_lens_sum=int(seq_lens.sum().item()),
            extend_prefix_lens=extend_prefix_lens,
            extend_prefix_lens_cpu=prefix_lens_cpu,
            extend_seq_lens_cpu=extend_seq_lens_cpu,
            forward_mode=ForwardMode.EXTEND,
            req_pool_indices=torch.arange(batch_size, device=device),
            seq_lens=seq_lens,
            seq_lens_cpu=seq_lens.cpu(),
            attn_attend_prefix_cache=False,
            mha_return_lse=False,
            attn_backend=backend,
        )
        fb.req_to_token_pool = model_runner.req_to_token_pool
        fb.token_to_kv_pool = model_runner.token_to_kv_pool

        # Position information for RoPE: each request's new tokens occupy
        # positions [prefix_len, prefix_len + extend_len), concatenated
        # across the batch in request order.
        fb.positions = torch.cat(
            [
                torch.arange(start, start + length, device=device)
                for start, length in zip(prefix_lens_cpu, extend_seq_lens_cpu)
            ]
        )
        return fb

    def _create_qkv_tensors_prefill(batch_size, seq_len, config, dtype_override=None):
        """Create random prefill Q, K, V of shape (tokens, heads, head_dim).

        Head counts and head dims come from config. Q and K use
        ``prefill_head_dim``; V uses ``prefill_v_head_dim``.
        """
        device = config["device"]
        dtype = dtype_override or config["dtype"]

        total_tokens = batch_size * seq_len

        tp_q_head_num = config["tp_q_head_num"]
        tp_k_head_num = config["tp_k_head_num"]
        head_dim = config["prefill_head_dim"]
        v_head_dim = config["prefill_v_head_dim"]

        # Draw in q, k, v order so the seeded RNG stream is unchanged.
        q = torch.randn(
            (total_tokens, tp_q_head_num * head_dim), dtype=dtype, device=device
        )
        k = torch.randn(
            (total_tokens, tp_k_head_num * head_dim), dtype=dtype, device=device
        )
        v = torch.randn(
            (total_tokens, tp_k_head_num * v_head_dim), dtype=dtype, device=device
        )

        q = q.view(-1, tp_q_head_num, head_dim)
        k = k.view(-1, tp_k_head_num, head_dim)
        v = v.view(-1, tp_k_head_num, v_head_dim)
        return q, k, v

    for test_case in TEST_CASES["output_match"][:2]:  # Just a subset for speed
        with self.subTest(test_case=test_case["name"]):
            print(f"Prefill Testing {test_case['name']}: {test_case['description']}")

            config = self._merge_config(test_case)
            batch_size = config["batch_size"]
            max_seq_len = config["max_seq_len"]

            # Create components
            (
                model_runner_trtllm,
                model_runner_reference,
                trtllm_backend,
                reference_backend,
                layer,
            ) = self._create_model_components(config, is_prefill=True)

            # Prefill uses full sequences with no cached prefix.
            seq_lens = torch.full((batch_size,), max_seq_len, device=config["device"])

            # Create forward batches
            fb_trtllm = _create_forward_batch_prefill(
                batch_size,
                seq_lens.clone(),
                torch.zeros(batch_size, device=config["device"], dtype=torch.int32),
                trtllm_backend,
                model_runner_trtllm,
                config,
            )
            fb_reference = _create_forward_batch_prefill(
                batch_size,
                seq_lens.clone(),
                torch.zeros(batch_size, device=config["device"], dtype=torch.int32),
                reference_backend,
                model_runner_reference,
                config,
            )

            # Initialize metadata for both backends
            trtllm_backend.init_forward_metadata(fb_trtllm)
            reference_backend.init_forward_metadata(fb_reference)

            # Seed right before drawing Q/K/V so both backends see the same inputs.
            torch.manual_seed(config["seed_qkv"])
            q, k, v = _create_qkv_tensors_prefill(batch_size, max_seq_len, config)

            # Run prefill on both backends (save_kv_cache=False).
            out_trtllm = trtllm_backend.forward_extend(
                q, k, v, layer, fb_trtllm, False
            ).view(-1, layer.tp_q_head_num * layer.v_head_dim)
            out_reference = reference_backend.forward_extend(
                q, k, v, layer, fb_reference, False
            )

            tolerance = config.get("tolerance", 1e-2)
            comparison_passed = compare_outputs(
                out_trtllm, out_reference, tolerance=tolerance
            )
            # Compute the diagnostic max-diff only on failure. A shape mismatch
            # is then reported as this test failure, not as a subtraction error.
            if not comparison_passed:
                self.fail(
                    f"TRTLLM and Reference prefill outputs differ beyond tolerance. "
                    f"Config: {test_case['name']}, "
                    f"Max diff: {(out_trtllm - out_reference).abs().max().item()}"
                )
1265
+
1102
1266
 
1103
1267
if __name__ == "__main__":
    # Run every TestCase in this module through unittest's command-line runner.
    unittest.main()
@@ -0,0 +1,57 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class DummyModel(nn.Module):
    """Toy module comparing two equivalent ways of computing a per-head gate.

    ``weights_proj`` maps inputs of width ``d_in`` to ``d_out`` features. Both
    ``_get_logits_head_gate_*`` methods then multiply that projection by the
    same three factors -- ``n_heads ** -0.5``, a per-row ``q_scale`` and
    ``softmax_scale`` -- and differ only in how the multiplications are
    grouped.

    Args:
        d_in: Input width of ``weights_proj``.
        n_heads: Only used for the ``n_heads ** -0.5`` scaling factor.
        softmax_scale: Extra scalar multiplier applied to the gate.
        d_out: Output width of ``weights_proj``. Defaults to 1024, the value
            that used to be hard-coded, so existing callers are unaffected.
    """

    def __init__(self, d_in=2048, n_heads=128, softmax_scale=0.5, d_out=1024):
        super().__init__()
        self.weights_proj = nn.Linear(d_in, d_out)
        self.n_heads = n_heads
        self.softmax_scale = softmax_scale

    def _get_logits_head_gate_orig(self, x: torch.Tensor, q_scale: torch.Tensor):
        """Reference version: apply the scaling factors one after another.

        NOTE(review): written assuming ``x`` is (B, d_in) and ``q_scale`` is
        (B, 1), as in this file's ``main``; the result is then (B, d_out, 1).
        """
        weights = self.weights_proj(x)
        weights = weights * self.n_heads**-0.5
        q_scale = q_scale.unsqueeze(1)  # (B, 1) -> (B, 1, 1) under the assumption above
        weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
        return weights

    def _get_logits_head_gate_opt(self, x: torch.Tensor, q_scale: torch.Tensor):
        """Grouped version: fold all scalar factors into one tensor, then multiply once.

        Same inputs and result shape as ``_get_logits_head_gate_orig``; values
        may differ only by floating-point rounding from the different grouping.
        """
        weights = self.weights_proj(x)
        q_scale = q_scale.unsqueeze(1)  # (B, 1) -> (B, 1, 1) under the same assumption
        # A single small broadcast tensor replaces two full-size multiplies.
        scale_const = self.n_heads**-0.5 * q_scale * self.softmax_scale
        weights = weights.unsqueeze(-1) * scale_const  # (B, d_out, 1)
        return weights
25
+
26
+
27
def main(num_iters: int = 1000, warmup_iters: int = 10):
    """Benchmark the two ``DummyModel`` gate implementations and compare outputs.

    Each variant gets ``warmup_iters`` untimed calls, then ``num_iters`` calls
    timed with ``time.perf_counter``. Everything runs under ``torch.no_grad()``
    so the timings exclude autograd graph recording for ``weights_proj``'s
    parameters. Afterwards one output of each variant is compared with
    ``torch.allclose``.

    Args:
        num_iters: Timed calls per variant (default 1000, as before).
        warmup_iters: Untimed calls per variant before timing starts, so
            one-time start-up cost is not charged to whichever variant runs
            first.

    Raises:
        AssertionError: if the two variants' outputs are not ``allclose``.
    """
    import time

    torch.manual_seed(0)
    model = DummyModel(d_in=2048, n_heads=128, softmax_scale=0.5)
    x = torch.randn(128, 2048)  # batch=128, d_in=2048
    q_scale = torch.randn(128, 1)

    def _time(fn):
        # Seconds spent on `num_iters` calls of fn, after `warmup_iters` untimed calls.
        for _ in range(warmup_iters):
            fn(x, q_scale)
        start = time.perf_counter()
        for _ in range(num_iters):
            fn(x, q_scale)
        return time.perf_counter() - start

    with torch.no_grad():
        print("Original version time:", _time(model._get_logits_head_gate_orig))
        print("Optimized version time:", _time(model._get_logits_head_gate_opt))
        # Inputs are fixed, so a fresh call yields the same values the timed loops produced.
        out_orig = model._get_logits_head_gate_orig(x, q_scale)
        out_opt = model._get_logits_head_gate_opt(x, q_scale)

    print("Difference:", (out_orig - out_opt).abs().max().item())
    assert torch.allclose(out_orig, out_opt), "Mismatch between original and optimized"


if __name__ == "__main__":
    main()


# Sample output, presumably recorded with an earlier revision of this script
# (time.time(), no warm-up, autograd enabled); absolute numbers are
# machine-dependent:
#   Original version time: 0.49235057830810547
#   Optimized version time: 0.4087331295013428
#   Difference: 1.4901161193847656e-08
@@ -0,0 +1 @@
1
+ """LongBench-v2 auxiliary utilities and validation scripts."""
@@ -0,0 +1,238 @@
1
+ """
2
+ Test cases for LongBench-v2 evaluation utility.
3
+ """
4
+
5
+ import json
6
+ import os
7
+ import tempfile
8
+
9
+ from sglang.test.simple_eval_longbench_v2 import (
10
+ LongBenchV2Eval,
11
+ extract_longbench_v2_answer,
12
+ format_longbench_v2_question,
13
+ )
14
+
15
+
16
def test_format_longbench_v2_question():
    """Render a sample row and check every piece of the official prompt template is present."""
    row = dict(
        context="This is a sample context about environmental issues.",
        question="What is the main theme?",
        A="Technology",
        B="Environment",
        C="Economics",
        D="Politics",
        answer="B",
    )

    prompt = format_longbench_v2_question(row)

    # Context, question stem, the four lettered options and the answer cue
    # must all appear in the rendered prompt, checked in this order.
    expected_fragments = (
        "This is a sample context about environmental issues.",
        "What is the correct answer to this question: What is the main theme?",
        "(A) Technology",
        "(B) Environment",
        "(C) Economics",
        "(D) Politics",
        "The correct answer is",
    )
    for fragment in expected_fragments:
        assert fragment in prompt
    print("✓ Question formatting works correctly")
42
+
43
+
44
def test_extract_longbench_v2_answer():
    """Check answer-letter extraction for each supported response phrasing."""
    cases = [
        # Official phrasing, letter in parentheses.
        ("After analyzing the context, The correct answer is (B).", "B"),
        # Official phrasing, bare letter.
        ("Based on the evidence, The correct answer is C.", "C"),
        # Official phrasing wrapped in markdown emphasis.
        ("*The correct answer is (D)*", "D"),
        # Generic "the answer is X" fallback.
        ("I think the answer is A.", "A"),
    ]
    for response, letter in cases:
        assert extract_longbench_v2_answer(response) == letter

    # Nothing recognizable: extraction yields None.
    assert extract_longbench_v2_answer("I'm not sure about this.") is None

    print("✓ Answer extraction works correctly")
68
+
69
+
70
def test_longbench_v2_eval_initialization():
    """Load a two-row JSON file with ``num_examples=1`` and inspect the single example.

    The rows use different option-key styles (``choice_A`` vs ``A``); the
    assertions expect the loaded example to expose ``category`` and ``A``
    whichever row it came from.
    """
    rows = [
        {
            "_id": "test_001",
            "domain": "single_document_qa",
            "question": "What is X?",
            "choice_A": "Option A1",
            "choice_B": "Option B1",
            "choice_C": "Option C1",
            "choice_D": "Option D1",
            "answer": "A",
            "context": "Context 1",
        },
        {
            "_id": "test_002",
            "domain": "multi_document_qa",
            "question": "What is Y?",
            "A": "Option A2",
            "B": "Option B2",
            "C": "Option C2",
            "D": "Option D2",
            "answer": "B",
            "context": "Context 2",
        },
    ]

    # Keep the file on disk after closing so the evaluator can reopen it by path.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as handle:
        json.dump(rows, handle)
        path = handle.name

    try:
        evaluator = LongBenchV2Eval(data_source=path, num_examples=1)
        assert len(evaluator.examples) == 1
        example = evaluator.examples[0]
        known_categories = {"single_document_qa", "multi_document_qa"}
        assert example.get("category") in known_categories
        assert example.get("A") in {"Option A1", "Option A2"}
        print("✓ Evaluation class initialization works correctly")
    finally:
        os.unlink(path)
117
+
118
+
119
def test_category_filtering():
    """Request only ``multi_document_qa`` and check that exactly that row survives."""
    rows = [
        {
            "_id": "test_001",
            "domain": "single_document_qa",
            "question": "What is X?",
            "choice_A": "Option A1",
            "choice_B": "Option B1",
            "choice_C": "Option C1",
            "choice_D": "Option D1",
            "answer": "A",
            "context": "Context 1",
        },
        {
            "_id": "test_002",
            "domain": "multi_document_qa",
            "question": "What is Y?",
            "choice_A": "Option A2",
            "choice_B": "Option B2",
            "choice_C": "Option C2",
            "choice_D": "Option D2",
            "answer": "B",
            "context": "Context 2",
        },
    ]

    # delete=False: the evaluator reopens the file by name after it is closed here.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as handle:
        json.dump(rows, handle)
        path = handle.name

    try:
        wanted = ["multi_document_qa"]
        evaluator = LongBenchV2Eval(data_source=path, categories=wanted)
        kept = evaluator.examples
        assert len(kept) == 1
        assert kept[0]["category"] == "multi_document_qa"
        print("✓ Category filtering works correctly")
    finally:
        os.unlink(path)
161
+
162
+
163
def test_difficulty_metrics():
    """Score one easy and one hard row with an always-correct stub sampler.

    Expects the evaluator to report per-difficulty accuracy under the
    ``difficulty_easy`` / ``difficulty_hard`` metric keys, each 1.0.
    """

    class FixedSampler:  # noqa: D401 - simple helper
        """Stub sampler that answers from the question text alone."""

        def _pack_message(self, content: str, role: str):
            return dict(content=content, role=role)

        def __call__(self, messages):
            # Easy row's correct option is A, the hard row's is B.
            letter = "A" if "Easy question" in messages[0]["content"] else "B"
            return f"The correct answer is ({letter})"

    rows = [
        {
            "_id": "easy_001",
            "domain": "single_document_qa",
            "difficulty": "easy",
            "question": "Easy question?",
            "choice_A": "Correct",
            "choice_B": "Wrong",
            "choice_C": "Wrong",
            "choice_D": "Wrong",
            "answer": "A",
            "context": "Easy context",
        },
        {
            "_id": "hard_001",
            "domain": "single_document_qa",
            "difficulty": "hard",
            "question": "Hard question?",
            "choice_A": "Wrong",
            "choice_B": "Correct",
            "choice_C": "Wrong",
            "choice_D": "Wrong",
            "answer": "B",
            "context": "Hard context",
        },
    ]

    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as handle:
        json.dump(rows, handle)
        path = handle.name

    try:
        # Single thread keeps the stub sampler's call order deterministic.
        evaluator = LongBenchV2Eval(data_source=path, num_threads=1)
        metrics = evaluator(FixedSampler()).metrics

        assert metrics.get("difficulty_easy") == 1.0
        assert metrics.get("difficulty_hard") == 1.0
        print("✓ Difficulty metrics recorded correctly")
    finally:
        os.unlink(path)
218
+
219
+
220
def main():
    """Run every LongBench-v2 utility check in order, then print a success banner."""
    print("Testing simplified LongBench-v2 evaluation utility...\n")

    checks = (
        test_format_longbench_v2_question,
        test_extract_longbench_v2_answer,
        test_longbench_v2_eval_initialization,
        test_category_filtering,
        test_difficulty_metrics,
    )
    for check in checks:
        check()

    rule = "=" * 50
    print("\n" + rule)
    print("✅ ALL TESTS PASSED!")
    print("The simplified implementation follows SGLang patterns")
    print("while maintaining LongBench-v2 compatibility.")
    print(rule)


if __name__ == "__main__":
    main()